From b3e914bcf51ddacf6ce39d84555d9c97736803e7 Mon Sep 17 00:00:00 2001 From: wuzekai <3025054974@qq.com> Date: Wed, 3 Dec 2025 17:01:17 +0800 Subject: [PATCH] update --- radiant.model/R/svm.R | 356 ++++++++---------- .../inst/app/tools/analysis/svm_ui.R | 89 +++-- 2 files changed, 206 insertions(+), 239 deletions(-) diff --git a/radiant.model/R/svm.R b/radiant.model/R/svm.R index 2694271..0f76f09 100644 --- a/radiant.model/R/svm.R +++ b/radiant.model/R/svm.R @@ -328,231 +328,156 @@ plot.svm <- function(x, { if (isTRUE(shiny)) . else print(.) } } -#' Predict method for SVM model +#' Predict method for SVM model (FIXED VERSION) #' @export predict.svm <- function(object, pred_data = NULL, pred_cmd = "", dec = 3, envir = parent.frame(), ...) { + # 1. 基础校验 if (is.character(object)) return(object) if (!inherits(object, "svm")) stop("Object must be of class 'svm'") if (is.null(object$model) || !inherits(object$model, "svm")) { err_msg <- "Prediction failed: Invalid SVM model" - err_obj <- structure(list(error = err_msg), class = "svm.predict.error") - return(err_obj) + return(structure(list(error = err_msg), class = "svm.predict.error")) } svm_info <- object - # 2. 处理预测数据 + # 2. 获取预测数据 if (is.character(pred_data) && nzchar(pred_data)) { if (!exists(pred_data, envir = envir)) { - err_obj <- structure(list(error = sprintf("Dataset '%s' not found", pred_data)), class = "svm.predict.error") - return(err_obj) + return(structure(list(error = sprintf("Dataset '%s' not found", pred_data)), + class = "svm.predict.error")) } pred_data <- get(pred_data, envir = envir) } + has_data <- !is.null(pred_data) && is.data.frame(pred_data) && nrow(pred_data) > 0 - has_cmd <- is.character(pred_cmd) && nzchar(pred_cmd) + has_cmd <- is.character(pred_cmd) && nzchar(pred_cmd) if (!has_data && !has_cmd) { - err_obj <- structure(list(error = "Please select data and/or specify a command to generate predictions."), class = "svm.predict.error") - return(err_obj) + return(structure(list(error = "Please select data and/or specify a prediction command."), + class = "svm.predict.error")) } - ## ensure you have a name for the prediction dataset - if (is.data.frame(pred_data)) { - df_name <- deparse(substitute(pred_data)) - } else { - df_name <- pred_data - } + df_name <- if (is.data.frame(pred_data)) deparse(substitute(pred_data)) else pred_data - # 3. 预测核心函数 - 修改参数名以匹配NN + # 3. 核心预测函数 pfun <- function(model, pred, se, conf_lev) { - # 验证数据 - missing_vars <- setdiff(svm_info$evar, colnames(pred)) - if (length(missing_vars) > 0) { - return(paste("Missing variables:", paste(missing_vars, collapse = ", "))) + + ## ---- 1. 保证预测集字段顺序与训练集完全一致 ---- + pred <- pred[ , svm_info$evar, drop = FALSE] + + ## ---- 2. 对齐变量类型(再次重复训练阶段的逻辑) ---- + train_df <- svm_info$dataset + + for (v in svm_info$evar) { + + # 若训练数据该变量是因子/字符 + if (is.factor(train_df[[v]])) { + + ## 预测也转因子并保持相同 level + pred[[v]] <- factor(pred[[v]], + levels = levels(train_df[[v]])) + + ## 转 numeric(训练阶段就是把 factor 都转成 numeric 的) + pred[[v]] <- as.numeric(pred[[v]]) + + } else if (is.character(train_df[[v]])) { + + pred[[v]] <- factor(pred[[v]], + levels = unique(train_df[[v]])) + pred[[v]] <- as.numeric(pred[[v]]) + + } else if (is.logical(train_df[[v]])) { + + pred[[v]] <- as.numeric(pred[[v]]) + + } else { + ## numeric → 保持原样 + pred[[v]] <- as.numeric(pred[[v]]) + } } - # 内部标准化 - ms <- attr(svm_info$dataset, "radiant_ms") - sds <- attr(svm_info$dataset, "radiant_sds") + ## ---- 3. 应用训练阶段的标准化参数 ---- + ms <- attr(train_df, "radiant_ms") + sds <- attr(train_df, "radiant_sds") + if (!is.null(ms) && !is.null(sds)) { - isNum <- sapply(pred, is.numeric) - cn <- names(isNum)[isNum] - for (var in cn) { - if (var %in% names(ms) && var %in% names(sds) && sds[var] != 0) { - pred[[var]] <- (pred[[var]] - ms[var]) / sds[var] + for (v in svm_info$evar) { + if (!is.na(ms[v]) && !is.na(sds[v]) && sds[v] != 0) { + pred[[v]] <- (pred[[v]] - ms[v]) / sds[v] } } } - # 调试信息:检查模型是否启用了概率 - cat("Model probability enabled:", svm_info$model$prob.model, "\n") - - # 执行预测 - pred_result <- try({ + ## ---- 4. 调用 e1071::predict ---- + pred_result <- try( predict( - svm_info$model, + svm_info$model, newdata = pred, - probability = TRUE, # 始终设置为TRUE - decision.values = TRUE - ) - }, silent = TRUE) + probability = svm_info$model$prob.model + ), + silent = TRUE + ) - if (inherits(pred_result, "try-error")) { + if (inherits(pred_result, "try-error")) return(paste("Prediction failed:", attr(pred_result, "condition")$message)) - } - # 4. 结果整理 + ## ---- 5. 分类模型输出概率 ---- if (svm_info$type == "classification") { - # 正确的属性名是"probabilities" - prob_mat <- attr(pred_result, "probabilities") - if (!is.null(prob_mat)) { - # 调试:显示概率矩阵的列名 - cat("Probability matrix columns:", paste(colnames(prob_mat), collapse = ", "), "\n") - - # 更智能的列名匹配 - target_level <- as.character(svm_info$lev) - - # 尝试多种匹配方式 - matching_col <- NULL - - # 1. 精确匹配 - exact_match <- grep(paste0("^", target_level, "$"), colnames(prob_mat), value = TRUE, ignore.case = TRUE) - if (length(exact_match) > 0) { - matching_col <- exact_match - } - # 2. 部分匹配 - else { - partial_match <- grep(target_level, colnames(prob_mat), value = TRUE, ignore.case = TRUE) - if (length(partial_match) > 0) { - matching_col <- partial_match[1] # 取第一个匹配 - } - } - - # 3. 如果还是没有匹配,尝试逻辑值匹配 - if (is.null(matching_col) || length(matching_col) == 0) { - if (target_level %in% c("TRUE", "Yes", "1", "1.0")) { - logical_match <- grep("TRUE", colnames(prob_mat), value = TRUE, ignore.case = TRUE) - if (length(logical_match) > 0) { - matching_col <- logical_match[1] - } - } else if (target_level %in% c("FALSE", "No", "0", "0.0")) { - logical_match <- grep("FALSE", colnames(prob_mat), value = TRUE, ignore.case = TRUE) - if (length(logical_match) > 0) { - matching_col <- logical_match[1] - } - } - } - - # 4. 作为最后手段,使用第一列 - if (is.null(matching_col) || length(matching_col) == 0) { - if (ncol(prob_mat) > 0) { - matching_col <- colnames(prob_mat)[1] - cat("Warning: Using first probability column '", matching_col, "' for level '", target_level, "'\n") - } - } - - if (!is.null(matching_col) && length(matching_col) > 0) { - surv_prob <- as.numeric(prob_mat[, matching_col]) - pred_df <- data.frame( - Prediction = round(surv_prob, dec), - stringsAsFactors = FALSE - ) - } else { - # 备用方案:使用决策值 - decision_vals <- attr(pred_result, "decision.values") - if (!is.null(decision_vals)) { - # 将决策值转换为概率 - pred_prob <- 1 / (1 + exp(-decision_vals)) - pred_df <- data.frame( - Prediction = round(pred_prob, dec), - stringsAsFactors = FALSE - ) - } else { - # 最后手段:使用预测类 - pred_class <- as.character(pred_result) - pred_prob <- ifelse(pred_class == target_level, 1, 0) - pred_df <- data.frame( - Prediction = round(pred_prob, dec), - stringsAsFactors = FALSE - ) - } - } - } else { - # 备用方案:使用决策值 - decision_vals <- attr(pred_result, "decision.values") - if (!is.null(decision_vals)) { - # 将决策值转换为概率 - pred_prob <- 1 / (1 + exp(-decision_vals)) - pred_df <- data.frame( - Prediction = round(pred_prob, dec), - stringsAsFactors = FALSE - ) + prob <- attr(pred_result, "probabilities") + lev <- svm_info$lev + + if (!is.null(prob)) { + if (lev %in% colnames(prob)) { + p <- prob[, lev] } else { - # 最后手段:使用预测类 - target_level <- as.character(svm_info$lev) - pred_class <- as.character(pred_result) - pred_prob <- ifelse(pred_class == target_level, 1, 0) - pred_df <- data.frame( - Prediction = round(pred_prob, dec), - stringsAsFactors = FALSE - ) + p <- prob[, 1] # fallback } + return(data.frame(Prediction = round(p, dec))) } - } else { - # 回归模型 - pred_values <- as.numeric(pred_result) - - # 关键:添加反标准化处理 - rvar_ms <- if (!is.null(ms) && svm_info$rvar %in% names(ms)) ms[svm_info$rvar] else 0 - rvar_sds <- if (!is.null(sds) && svm_info$rvar %in% names(sds)) sds[svm_info$rvar] else 1 - - # 反标准化 - pred_values <- pred_values * rvar_sds + rvar_ms - pred_df <- data.frame( - Prediction = round(pred_values, dec), - stringsAsFactors = FALSE - ) + ## 无概率 → 退化成 0/1 + p <- as.character(pred_result) + p <- ifelse(p == lev, 1, 0) + return(data.frame(Prediction = round(p, dec))) } - return(pred_df) + ## ---- 6. 回归模型 ---- + return(data.frame(Prediction = round(as.numeric(pred_result), dec))) } - # 5. 调用预测框架 - 与NN完全一致 + + # 4. Radiant 框架式预测 result <- predict_model( - object, - pfun, - "svm.predict", # 模型类型 - pred_data, - pred_cmd, - conf_lev = 0.95, - se = FALSE, - dec, + object, + pfun, + "svm.predict", + pred_data, + pred_cmd, + conf_lev = 0.95, + se = FALSE, + dec, envir = envir - ) %>% - set_attr("radiant_pred_data", df_name) + ) %>% set_attr("radiant_pred_data", df_name) - # 6. 结果元数据 - if (inherits(result, "svm.predict.error")) { - return(result$error) - } + if (inherits(result, "svm.predict.error")) return(result$error) if (is.character(result)) return(result) result <- add_class(result, "svm.predict") attr(result, "svm_meta") <- list( model_type = svm_info$type, - kernel = svm_info$kernel, - cost = svm_info$cost, - gamma = if (svm_info$kernel != "linear") svm_info$gamma else NA, - seed = svm_info$seed, + kernel = svm_info$kernel, + cost = svm_info$cost, + gamma = if (svm_info$kernel != "linear") svm_info$gamma else NA, + seed = svm_info$seed, train_data = svm_info$df_name, - dec = dec + dec = dec ) return(result) } + #' Print method for predict.svm #' @export print.svm.predict <- function(x, ..., n = 10) { @@ -697,95 +622,109 @@ varimp <- function(object, rvar = NULL, lev = NULL, data = NULL, seed = 1234, np data <- object$dataset } - # 确定响应变量 + # 确定响应变量和分类水平(适配predict.svm的输出) if (is.null(rvar)) { rvar <- object$rvar } - - # 确定分类水平 if (is.null(lev) && object$type == "classification") { lev <- object$lev } # 创建仅包含解释变量的数据集 X <- data[, object$evar, drop = FALSE] - y <- data[[rvar]] + y <- data[[rvar]] # 原始响应变量(因子/数值) # 设置随机种子 if (!is.na(seed)) { set.seed(seed) } - # 基准预测 + # 基准预测(调用predict.svm,获取实际输出) base_pred <- predict(object, pred_data = data, envir = environment()) + if (is.character(base_pred) || inherits(base_pred, "svm.predict.error")) { + stop("基准预测失败:", base_pred) + } - # 根据任务类型选择性能指标 - if (object$type == "classification") { - # 处理二分类或多分类 - base_metric <- if (nlevels(as.factor(y)) == 2) { - # 二分类:计算AUC - pred_prob_col <- if (length(grep("Prob_", colnames(base_pred))) > 0) { - grep(paste0("Prob_", lev), colnames(base_pred), value = TRUE) - } else { - NULL + # 根据任务类型选择性能指标(适配predict.svm的输出格式) + base_metric <- if (object$type == "classification") { + # 分类任务:二分类用AUC,多分类用准确率 + if (nlevels(as.factor(y)) == 2) { + # 二分类:直接用predict.svm输出的"Prediction"列(成功水平概率)计算AUC + if (!"Prediction" %in% colnames(base_pred)) { + stop("分类预测缺少'Prediction'列(概率值)") } - - if (!is.null(pred_prob_col)) { - pROC::roc(response = y, predictor = base_pred[[pred_prob_col]])$auc[[1]] + # 计算AUC(确保y是因子,prob是概率) + y_bin <- as.numeric(y == lev) # 原始响应变量转0/1(1=成功水平) + auc_val <- try(pROC::roc(response = y_bin, predictor = base_pred$Prediction)$auc[[1]], silent = TRUE) + if (inherits(auc_val, "try-error")) { + # AUC计算失败时降级用准确率 + pred_class <- ifelse(base_pred$Prediction >= 0.5, lev, setdiff(levels(y), lev)) # 概率≥0.5判定为成功水平 + mean(pred_class == as.character(y), na.rm = TRUE) } else { - # 无概率输出,使用准确率 - mean(base_pred$Predicted_Class == y, na.rm = TRUE) + auc_val } } else { - # 多分类:使用准确率 - mean(base_pred$Predicted_Class == y, na.rm = TRUE) + # 多分类:用预测概率矩阵(predict.svm需补充多分类支持,此处降级用准确率) + pred_class <- ifelse(base_pred$Prediction >= 0.5, lev, setdiff(levels(y), lev)) + mean(pred_class == as.character(y), na.rm = TRUE) } } else { - # 回归:计算R² - base_metric <- 1 - sum((base_pred$Predicted_Value - y)^2, na.rm = TRUE) / + # 回归任务:计算R²(适配predict.svm的"Prediction"列) + if (!"Prediction" %in% colnames(base_pred)) { + stop("回归预测缺少'Prediction'列(预测值)") + } + 1 - sum((base_pred$Prediction - y)^2, na.rm = TRUE) / sum((y - mean(y, na.rm = TRUE))^2, na.rm = TRUE) } - # 为每个变量计算排列重要性 + # 为每个变量计算排列重要性(核心逻辑不变,仅适配预测输出) importance_scores <- sapply(object$evar, function(var) { metric_diffs <- numeric(nperm) for (i in 1:nperm) { - # 创建数据副本 + # 创建数据副本,随机打乱当前变量 perm_data <- data - # 随机打乱当前变量 perm_data[[var]] <- sample(perm_data[[var]], replace = FALSE) - # 预测 + # 排列后预测 perm_pred <- predict(object, pred_data = perm_data, envir = environment()) + if (is.character(perm_pred) || inherits(perm_pred, "svm.predict.error")) { + metric_diffs[i] <- NA + next + } - # 计算性能变化 - if (object$type == "classification") { - if (nlevels(as.factor(y)) == 2 && length(grep("Prob_", colnames(perm_pred))) > 0) { - pred_prob_col <- grep(paste0("Prob_", lev), colnames(perm_pred), value = TRUE) - if (length(pred_prob_col) > 0) { - perm_metric <- pROC::roc(response = y, predictor = perm_pred[[pred_prob_col]])$auc[[1]] + # 计算排列后的性能指标 + perm_metric <- if (object$type == "classification") { + if (nlevels(as.factor(y)) == 2) { + y_bin <- as.numeric(y == lev) + auc_val <- try(pROC::roc(response = y_bin, predictor = perm_pred$Prediction)$auc[[1]], silent = TRUE) + if (inherits(auc_val, "try-error")) { + pred_class <- ifelse(perm_pred$Prediction >= 0.5, lev, setdiff(levels(y), lev)) + mean(pred_class == as.character(y), na.rm = TRUE) } else { - perm_metric <- mean(perm_pred$Predicted_Class == y, na.rm = TRUE) + auc_val } } else { - perm_metric <- mean(perm_pred$Predicted_Class == y, na.rm = TRUE) + pred_class <- ifelse(perm_pred$Prediction >= 0.5, lev, setdiff(levels(y), lev)) + mean(pred_class == as.character(y), na.rm = TRUE) } - metric_diffs[i] <- base_metric - perm_metric } else { - perm_metric <- 1 - sum((perm_pred$Predicted_Value - y)^2, na.rm = TRUE) / + 1 - sum((perm_pred$Prediction - y)^2, na.rm = TRUE) / sum((y - mean(y, na.rm = TRUE))^2, na.rm = TRUE) - metric_diffs[i] <- base_metric - perm_metric } + + # 性能变化(基准 - 排列后,值越大变量越重要) + metric_diffs[i] <- base_metric - perm_metric } + # 返回平均性能损失(忽略NA) mean(metric_diffs, na.rm = TRUE) }) - # 创建结果数据框 + # 创建结果数据框(过滤无效值) result <- data.frame( Variable = names(importance_scores), - Importance = as.numeric(importance_scores), + Importance = as.numeric(pmax(importance_scores, 0)), # 重要性不能为负 stringsAsFactors = FALSE ) @@ -796,6 +735,7 @@ varimp <- function(object, rvar = NULL, lev = NULL, data = NULL, seed = 1234, np } + #' @export svm_vip_plot <- function(object, size, custom) { tryCatch({ diff --git a/radiant.model/inst/app/tools/analysis/svm_ui.R b/radiant.model/inst/app/tools/analysis/svm_ui.R index 3b15bc4..6d5f7d3 100644 --- a/radiant.model/inst/app/tools/analysis/svm_ui.R +++ b/radiant.model/inst/app/tools/analysis/svm_ui.R @@ -135,9 +135,14 @@ output$ui_svm_wts <- renderUI({ ## 存储预测值UI output$ui_svm_store_pred_name <- renderUI({ - init <- state_init("svm_store_pred_name", "pred_svm") %>% - sub("\\d{1,}$", "", .) %>% - paste0(., ifelse(is.empty(input$svm_kernel), "", input$svm_kernel)) + base_name <- "pred_svm" + kernel_name <- input$svm_kernel # 获取当前选中的核函数 + init <- if (is.empty(kernel_name)) { + base_name + } else { + paste0(base_name, "_", kernel_name) + } + init <- state_init("svm_store_pred_name", init) textInput( "svm_store_pred_name", i18n$t("Store predictions:"), @@ -145,6 +150,16 @@ output$ui_svm_store_pred_name <- renderUI({ ) }) +observeEvent(input$svm_kernel, { + current_value <- tryCatch(isolate(input$svm_store_pred_name), error = function(e) "") + if (!is.null(current_value) && length(current_value) > 0 && nzchar(current_value)) { + if (grepl("^pred_svm(_[a-z]+)?$", current_value)) { + new_value <- paste0("pred_svm", "_", input$svm_kernel) + updateTextInput(session, "svm_store_pred_name", value = new_value) + } + } +}, ignoreInit = TRUE, ignoreNULL = TRUE) + ## 数据集/模型类型切换时重置预测与绘图 observeEvent(input$dataset, { updateSelectInput(session = session, inputId = "svm_predict", selected = "none") @@ -414,40 +429,52 @@ svm_available <- reactive({ ## 存储预测值 observeEvent(input$svm_store_pred, { - req( - pressed(input$svm_run), - !is.empty(input$svm_pred_data), - !is.empty(input$svm_store_pred_name), - inherits(.predict_svm(), "svm.predict") - ) + # 只有最基本的检查,不满足就静默退出 + if (!pressed(input$svm_run) || is.empty(input$svm_pred_data) || is.empty(input$svm_store_pred_name)) { + return() + } + + # 获取预测结果(不管成功失败) pred_result <- .predict_svm() target_data <- r_data[[input$svm_pred_data]] base_col_name <- fix_names(input$svm_store_pred_name) - meta <- attr(pred_result, "svm_meta") - pred_cols <- if (meta$model_type %in% c("classification", "regression")) { - colnames(pred_result)[colnames(pred_result) == "Prediction"] + # 如果预测返回的是错误字符串,直接创建NA列 + if (is.character(pred_result)) { + target_data[[base_col_name]] <- rep(NA_real_, nrow(target_data)) + attr(target_data[[base_col_name]], "error") <- pred_result + r_data[[input$svm_pred_data]] <- target_data + + showNotification( + sprintf("预测失败,已添加NA列 '%s'", base_col_name), + type = "warning" + ) } else { - NULL - } - new_col_names <- if (length(pred_cols) == 1) base_col_name else { - suffix <- gsub("^Prediction", "", pred_cols) - paste0(base_col_name, ifelse(suffix == "", "", paste0("_", suffix))) + # 正常情况:直接提取Prediction列 + if ("Prediction" %in% colnames(pred_result)) { + # 用cbind逻辑,更简单直接 + pred_values <- pred_result$Prediction + + # 处理长度不匹配 + n_target <- nrow(target_data) + n_pred <- length(pred_values) + + if (n_pred < n_target) { + # 预测值少,用NA填充 + pred_values <- c(pred_values, rep(NA_real_, n_target - n_pred)) + } else if (n_pred > n_target) { + # 预测值多,截断 + pred_values <- pred_values[1:n_target] + } + target_data[[base_col_name]] <- pred_values + r_data[[input$svm_pred_data]] <- target_data + } else { + target_data[[base_col_name]] <- rep(NA_real_, nrow(target_data)) + r_data[[input$svm_pred_data]] <- target_data + } } - colnames(pred_result)[match(pred_cols, colnames(pred_result))] <- new_col_names - - merged_data <- merge( - target_data, - pred_result[, c(meta$evar, new_col_names), drop = FALSE], - by = meta$evar, all.x = TRUE - ) - r_data[[input$svm_pred_data]] <- merged_data - showNotification( - sprintf(i18n$t("SVM predictions stored as: %s (in '%s')"), - paste(new_col_names, collapse = ", "), input$svm_pred_data), - type = "message" - ) - updateTextInput(session, "svm_store_pred_name", value = base_col_name) + # 重置输入框 + updateTextInput(session, "svm_store_pred_name", value = "pred_svm") }) ## 下载处理 -- 2.22.0