#' Support Vector Machine
#'
#' Fit a support vector machine for classification or regression with
#' `e1071::svm()`.
#'
#' Categorical explanatory variables (factor, character, logical) are encoded
#' as numeric codes in factor-level order; the levels are stored in the
#' returned object (`cat_levels`) so `predict()` applies the same coding to new
#' data. When `check` contains `"standardize"`, the explanatory variables (and,
#' for regression, the response) are centered and scaled; the means and
#' standard deviations are stored as attributes (`radiant_ms`, `radiant_sds`)
#' of the returned `dataset`. The response of a classification model keeps its
#' original labels, and a weights variable is never scaled.
#'
#' @param dataset Dataset name (string) or data frame.
#' @param rvar Name of the response variable.
#' @param evar Names of the explanatory variables.
#' @param type `"classification"` or `"regression"`.
#' @param lev Level of `rvar` used as the focus class (classification only);
#'   defaults to the first level.
#' @param kernel One of `"linear"`, `"radial"`, `"poly"`, `"sigmoid"`.
#' @param cost Cost parameter C (must be > 0).
#' @param gamma Kernel parameter gamma (must be > 0).
#' @param wts Name of a weights variable, or `"None"`.
#' @param seed Random seed set before training (`NA` for none).
#' @param check Options; `"standardize"` centers and scales the inputs.
#' @param form Optional model formula; defaults to
#'   `rvar ~ evar[1] + evar[2] + ...`.
#' @param data_filter,arr,rows Passed to `get_data()`.
#' @param envir Environment in which to look up `dataset`.
#' @return A list of class `c("svm", "model")`, or a character message of
#'   class `"svm"` when the inputs are invalid or training fails.
#' @export
svm <- function(dataset, rvar, evar, type = "classification", lev = "",
                kernel = "radial", cost = 1, gamma = 1, wts = "None",
                seed = NA, check = "standardize", form, data_filter = "",
                arr = "", rows = NULL, envir = parent.frame()) {
  ## ---- argument validation ----
  valid_kernels <- c("linear", "radial", "poly", "sigmoid")
  if (!kernel %in% valid_kernels) {
    return(paste0("Kernel must be one of: ", paste(valid_kernels, collapse = ", ")) %>%
      add_class("svm"))
  }
  if (is.empty(cost) || cost <= 0) {
    return("Cost should be greater than 0." %>% add_class("svm"))
  }
  if (is.empty(gamma) || gamma <= 0) {
    return("Gamma should be greater than 0." %>% add_class("svm"))
  }
  if (rvar %in% evar) {
    return("Response variable contained in the set of explanatory variables.\nPlease update model specification." %>%
      add_class("svm"))
  }

  ## ---- 1. weights variable ----
  vars <- c(rvar, evar)
  wtsname <- NULL
  # Test emptiness first: `NULL == "None"` is logical(0), which `||` rejects.
  if (is.empty(wts) || identical(wts, "None")) {
    wts <- NULL
  } else if (is_string(wts)) {
    wtsname <- wts
    vars <- c(rvar, evar, wtsname)
  }

  ## ---- 2. data retrieval and encoding ----
  df_name <- if (is_string(dataset)) dataset else deparse(substitute(dataset))
  dataset <- get_data(dataset, vars, filt = data_filter, arr = arr, rows = rows, envir = envir)

  is_class <- type == "classification"
  is_cat <- vapply(
    dataset,
    function(x) is.factor(x) || is.character(x) || is.logical(x),
    logical(1)
  )
  cat_vars <- names(is_cat)[is_cat]
  # A classification response keeps its labels (made a factor below); only
  # the remaining categorical columns become numeric codes.
  if (is_class) cat_vars <- setdiff(cat_vars, rvar)

  # Levels of every encoded factor/character column, so predict() can map
  # new data onto the same codes.
  cat_levels <- list()
  for (var in cat_vars) {
    x <- dataset[[var]]
    if (is.character(x)) x <- as.factor(x)
    if (is.factor(x)) cat_levels[[var]] <- levels(x)
    # Codes follow the level order (level1 -> 1, level2 -> 2, ...);
    # logicals become 0/1.
    dataset[[var]] <- as.numeric(x)
  }
  if (is_class) dataset[[rvar]] <- as.factor(dataset[[rvar]])

  # Missing values
  dataset <- na.omit(dataset)
  if (nrow(dataset) == 0) {
    return("No valid samples after removing missing values. Please check your data." %>%
      add_class("svm"))
  }

  ## ---- 3. response of a classification model ----
  if (is_class) {
    # Levels that only occurred in dropped rows would be empty classes.
    dataset[[rvar]] <- droplevels(dataset[[rvar]])
    rvar_levels <- levels(dataset[[rvar]])
    if (is.empty(lev)) {
      lev <- rvar_levels[1]
    } else if (!lev %in% rvar_levels) {
      return(paste0(
        "Level '", lev, "' not found in ", rvar, ". Available levels: ",
        paste(rvar_levels, collapse = ", ")
      ) %>% add_class("svm"))
    }
  }

  ## ---- 4. standardization ----
  if ("standardize" %in% check) {
    # Scale only the model inputs, plus a regression response (predict() maps
    # it back). The weights and a classification response stay untouched.
    sc_vars <- if (is_class) evar else c(rvar, evar)
    scaled <- scale_sv(dataset[, sc_vars, drop = FALSE])
    dataset[sc_vars] <- scaled
    for (a in c("radiant_ms", "radiant_sds", "radiant_sf")) {
      attr(dataset, a) <- attr(scaled, a)
    }
  }

  ## ---- 5. training arguments ----
  if (missing(form)) {
    # List the predictors explicitly: `rvar ~ .` would also use the weights
    # column as a predictor.
    form <- stats::reformulate(evar, response = rvar)
  }
  weights_vec <- if (!is.null(wtsname)) dataset[[wtsname]] else NULL
  svm_input <- list(
    formula = form,
    data = dataset,
    kernel = kernel,
    cost = cost,
    gamma = gamma,
    # NOTE(review): e1071::svm() documents no case-weight argument (only
    # `class.weights`); `weights` travels through `...` - confirm it is used.
    weights = weights_vec,
    type = if (is_class) "C-classification" else "eps-regression",
    na.action = na.omit,
    scale = FALSE,
    probability = is_class
  )

  if (!is.na(seed)) set.seed(seed)

  ## ---- 6. training ----
  model <- try(do.call(e1071::svm, svm_input), silent = TRUE)
  if (inherits(model, "try-error")) {
    return(paste("Model training failed:", conditionMessage(attr(model, "condition"))) %>%
      add_class("svm"))
  }
  if (is.null(model) || !inherits(model, "svm")) {
    return("Model training failed: Generated SVM model is invalid" %>% add_class("svm"))
  }

  ## ---- 7. metadata on the fitted model ----
  # Only fields e1071 does not use itself. e1071 keeps `type` and `kernel` as
  # integer codes for its predict method, so overwriting them with strings
  # breaks prediction; `cost`/`gamma` are already stored by e1071.
  model$df_name <- df_name
  model$rvar <- rvar
  model$evar <- evar
  model$lev <- lev
  model$wtsname <- wtsname
  model$seed <- seed

  list(
    model = model,
    df_name = df_name,
    rvar = rvar,
    evar = evar,
    type = type,
    lev = lev,
    wtsname = wtsname,
    seed = seed,
    cost = cost,
    gamma = gamma,
    kernel = kernel,
    cat_levels = cat_levels,
    dataset = dataset
  ) %>% add_class(c("svm", "model"))
}

#' Center or standardize variables in a data frame
#'
#' @param dataset Data frame; only numeric columns are transformed.
#' @param center Subtract the column mean?
#' @param scale Divide by the column standard deviation? Columns with a zero
#'   or undefined (single value) standard deviation are divided by 1.
#' @param sf Scale factor; stored as attribute `radiant_sf`, not applied here.
#' @param wts Currently unused.
#' @param calc If `TRUE`, compute means/sds from `dataset` and store them as
#'   attributes `radiant_ms`/`radiant_sds`. If `FALSE`, reuse those attributes
#'   already on `dataset`, matching columns by name.
#' @return `dataset` with its numeric columns transformed.
#' @export
scale_sv <- function(dataset, center = TRUE, scale = TRUE, sf = 2, wts = NULL, calc = TRUE) {
  is_num <- vapply(dataset, is.numeric, logical(1))
  if (length(is_num) == 0 || !any(is_num)) {
    return(dataset)
  }
  cn <- names(is_num)[is_num]
  descr <- attr(dataset, "description") # keep the original description

  if (calc) {
    ms <- vapply(cn, function(v) {
      if (center) mean(dataset[[v]], na.rm = TRUE) else 0
    }, numeric(1))
    sds <- vapply(cn, function(v) {
      if (!scale) return(1)
      s <- stats::sd(dataset[[v]], na.rm = TRUE)
      # sd is 0 for a constant column and NA for a single value; dividing by
      # either would turn the whole column into NaN/NA.
      if (is.na(s) || s == 0) 1 else s
    }, numeric(1))
    attr(dataset, "radiant_ms") <- ms
    attr(dataset, "radiant_sds") <- sds
    attr(dataset, "radiant_sf") <- sf
  } else {
    ms <- attr(dataset, "radiant_ms")
    sds <- attr(dataset, "radiant_sds")
    if (is.null(ms) || is.null(sds)) {
      warning("Training data mean/std not found; skipping standardization.")
      return(dataset)
    }
    # Stored statistics may cover a different column set; match by name.
    cn <- intersect(cn, intersect(names(ms), names(sds)))
  }

  for (v in cn) {
    dataset[[v]] <- (dataset[[v]] - ms[[v]]) / sds[[v]]
  }
  attr(dataset, "description") <- descr
  dataset
}

#' Summary method for the svm function
#'
#' @param object Object returned by `svm()`, or a character message.
#' @param prn Print the coefficients (linear kernel only)?
#' @param ... Not used.
#' @export
summary.svm <- function(object, prn = TRUE, ...) {
  if (is.character(object)) return(object)

  svm_model <- object$model
  wtsname <- object$wtsname
  has_wts <- !is.null(wtsname) && nzchar(wtsname)

  cat("Support Vector Machine\n")
  cat(sprintf("Kernel type : %s (%s)\n", object$kernel, object$type))
  cat(sprintf("Data : %s\n", object$df_name))
  cat(sprintf("Response variable : %s\n", object$rvar))
  if (object$type == "classification") {
    cat(sprintf("Level : %s in %s\n", object$lev, object$rvar))
  }
  cat(sprintf("Explanatory variables: %s\n", paste(object$evar, collapse = ", ")))
  if (has_wts) {
    cat(sprintf("Weights used : %s\n", wtsname))
  }
  cat(sprintf("Cost (C) : %.2f\n", object$cost))
  if (object$kernel %in% c("radial", "poly", "sigmoid")) {
    cat(sprintf("Gamma : %.2f\n", object$gamma))
  }
  if (!is.na(object$seed)) cat(sprintf("Seed : %s\n", object$seed))

  # Support vector counts
  if (object$type == "classification") {
    n_sv <- if (!is.null(svm_model$nSV)) svm_model$nSV else c(0, 0)
    sv_info <- paste(sprintf("class %s: %d", seq_along(n_sv), as.integer(n_sv)), collapse = ", ")
    cat(sprintf("Support vectors : %d (%s)\n", as.integer(sum(n_sv)), sv_info))
  } else {
    cat(sprintf("Support vectors : %d\n", length(svm_model$index)))
  }

  # Number of observations, or the sum of the weights when they are positive
  nr_obs <- nrow(object$dataset)
  if (has_wts && wtsname %in% names(object$dataset)) {
    w <- as.numeric(object$dataset[[wtsname]])
    w <- w[!is.na(w)]
    if (length(w) > 0 && sum(w) > 0) nr_obs <- sum(w)
  }
  cat(sprintf("Nr_obs : %s\n", format_nr(nr_obs, dec = 0)))

  if (prn) {
    cat("Coefficients/Support Vectors:\n")
    if (object$kernel != "linear") {
      cat(" Non-linear kernel: Coefficients not available\n")
    } else {
      feat_coefs <- rep(0, length(object$evar))
      bias <- 0
      show_coefs <- TRUE
      coef_mat <- svm_model$coefs
      sv_mat <- svm_model$SV
      if (!is.null(coef_mat) && !is.null(sv_mat)) {
        coef_mat <- as.matrix(coef_mat)
        sv_mat <- as.matrix(sv_mat)
        if (ncol(coef_mat) != 1 || length(svm_model$rho) != 1) {
          # Multi-class fits have one hyperplane per pair of classes; a single
          # weight vector (and bias) would mix them up.
          show_coefs <- FALSE
          cat(" Multi-class model: one hyperplane per class pair; coefficients not shown\n")
        } else {
          # Primal weights of a linear SVM: w = t(coefs) %*% SV
          w <- tryCatch(
            as.numeric(crossprod(coef_mat, sv_mat)),
            error = function(e) {
              warning("系数提取失败: ", conditionMessage(e), call. = FALSE)
              NULL
            }
          )
          if (!is.null(w)) {
            sv_names <- colnames(sv_mat)
            if (!is.null(sv_names)) {
              # Match by variable name (backquotes from non-syntactic names
              # removed).
              sv_names <- gsub("`", "", sv_names, fixed = TRUE)
              feat_coefs <- vapply(
                object$evar,
                function(v) sum(w[sv_names == v]),
                numeric(1),
                USE.NAMES = FALSE
              )
            } else {
              # No column names: take the weights by position.
              k <- min(length(object$evar), length(w))
              feat_coefs[seq_len(k)] <- w[seq_len(k)]
            }
            bias <- -as.numeric(svm_model$rho)
          }
        }
      }

      if (show_coefs) {
        labels <- c(object$evar, "bias")
        values <- c(feat_coefs, bias)
        # Two coefficients per output line
        for (i in seq(1, length(labels), by = 2)) {
          if (i == length(labels)) {
            cat(sprintf(" %-12s: %.2f\n", labels[i], values[i]))
          } else {
            cat(sprintf(
              " %-12s: %.2f %-12s: %.2f\n",
              labels[i], values[i], labels[i + 1], values[i + 1]
            ))
          }
        }
      }
    }
  }
}

#' Plot method for the svm function
#'
#' @param x Object returned by `svm()`.
#' @param plots Any of `"decision_boundary"`, `"margin"`, `"vip"`.
#' @param size Base font size.
#' @param shiny If `TRUE`, return the combined plot instead of printing it.
#' @param custom Kept for interface compatibility; the individual plots are
#'   always combined with patchwork.
#' @param ... Not used.
#' @export
plot.svm <- function(x, plots = "none", size = 12, shiny = FALSE, custom = FALSE, ...) {
  if (is.character(x) || !inherits(x, "svm")) return(x)

  # The helpers are always asked for the ggplot object (custom = TRUE); with
  # custom = FALSE they would print each plot before it is combined below.
  plot_list <- list()
  if ("decision_boundary" %in% plots) {
    if (length(x$evar) != 2) {
      warning("Decision boundary plot requires exactly 2 explanatory variables")
    } else if (x$type != "classification") {
      warning("Decision boundary plot only available for classification")
    } else if (nlevels(x$dataset[[x$rvar]]) != 2) {
      warning("Decision boundary plot only available for binary classification")
    } else {
      plot_list[["decision_boundary"]] <- svm_boundary_plot(x, size = size, custom = TRUE)
    }
  }
  if ("margin" %in% plots) {
    if (length(x$evar) < 2) {
      warning("Support vector plot requires at least 2 explanatory variables")
    } else {
      plot_list[["margin"]] <- svm_margin_plot(x, size = size, custom = TRUE)
    }
  }
  if ("vip" %in% plots) {
    if (length(x$evar) < 2) {
      warning("Variable importance needs at least 2 explanatory variables")
    } else {
      plot_list[["vip"]] <- svm_vip_plot(x, size = size, custom = TRUE)
    }
  }

  if (length(plot_list) == 0) {
    return("No valid plots selected for SVM")
  }

  # Return / print a patchwork object
  patchwork::wrap_plots(plot_list, ncol = min(2, length(plot_list))) %>%
    {
      if (isTRUE(shiny)) . else print(.)
    }
}

# Call e1071's own predict method on a fitted e1071 model. This is needed
# because the e1071 fit also has class "svm", so a plain predict() call from
# this file would dispatch to predict.svm below instead.
.svm_predict_raw <- function(model, ...) {
  e1071_predict <- utils::getS3method("predict", "svm", envir = asNamespace("e1071"))
  e1071_predict(model, ...)
}

# Pick the column of an e1071 probability matrix that belongs to `target`.
# Tries, in order: exact name, case-insensitive name, case-insensitive
# substring, TRUE/FALSE-style aliases, and finally the first column (with a
# warning). Matching is literal (no regular expressions), so labels such as
# "a+b" are safe. Returns a single column name, or NULL when there are none.
.svm_prob_col <- function(cols, target) {
  if (length(cols) == 0) return(NULL)
  target <- if (length(target) == 0 || is.na(target[1])) "" else as.character(target[1])
  lc_cols <- tolower(cols)
  lc_target <- tolower(target)

  hit <- cols[cols == target]
  if (length(hit) == 0) hit <- cols[lc_cols == lc_target]
  if (length(hit) == 0 && nzchar(lc_target)) {
    hit <- cols[grepl(lc_target, lc_cols, fixed = TRUE)]
  }
  if (length(hit) == 0 && target %in% c("TRUE", "Yes", "1", "1.0")) {
    hit <- cols[grepl("true", lc_cols, fixed = TRUE)]
  }
  if (length(hit) == 0 && target %in% c("FALSE", "No", "0", "0.0")) {
    hit <- cols[grepl("false", lc_cols, fixed = TRUE)]
  }
  if (length(hit) == 0) {
    warning(
      sprintf("Using first probability column '%s' for level '%s'", cols[1], target),
      call. = FALSE
    )
    hit <- cols[1]
  }
  hit[1]
}

#' Predict method for SVM model
#'
#' New data is encoded and standardized with the statistics stored from
#' training. Classification returns the probability of the level `lev` (with
#' fallbacks to decision values or predicted classes); regression returns
#' predictions on the original scale of the response.
#'
#' @param object Object returned by `svm()`. A bare `e1071` fit (which shares
#'   the class `"svm"`) is passed on to e1071's predict method.
#' @param pred_data Data frame, or the name of one in `envir`.
#' @param pred_cmd Command used to generate prediction data.
#' @param dec Number of decimals for the predictions.
#' @param envir Environment in which to look up data.
#' @param ... Forwarded to e1071's predict method for bare e1071 fits.
#' @export
predict.svm <- function(object, pred_data = NULL, pred_cmd = "", dec = 3,
                        envir = parent.frame(), ...) {
  # 1. Basic checks
  if (is.character(object)) return(object)
  if (!inherits(object, "svm")) stop("Object must be of class 'svm'")
  if (is.null(object[["dataset"]]) && !is.null(object[["SV"]])) {
    # A bare e1071 model dispatched here by its class name.
    if (is.null(pred_data)) {
      return(.svm_predict_raw(object, ...))
    }
    return(.svm_predict_raw(object, newdata = pred_data, ...))
  }
  if (is.null(object$model) || !inherits(object$model, "svm")) {
    err_msg <- "Prediction failed: Invalid SVM model"
    err_obj <- structure(list(error = err_msg), class = "svm.predict.error")
    return(err_obj)
  }
  svm_info <- object

  # Name of the prediction data; captured before `pred_data` is replaced
  # below, otherwise the whole data frame would be deparsed.
  df_name <- if (is.data.frame(pred_data)) {
    paste(deparse(substitute(pred_data)), collapse = " ")
  } else {
    pred_data
  }

  # 2. Prediction data
  if (is.character(pred_data) && nzchar(pred_data)) {
    if (!exists(pred_data, envir = envir)) {
      err_obj <- structure(list(error = sprintf("Dataset '%s' not found", pred_data)), class = "svm.predict.error")
      return(err_obj)
    }
    pred_data <- get(pred_data, envir = envir)
  }

  has_data <- !is.null(pred_data) && is.data.frame(pred_data) && nrow(pred_data) > 0
  has_cmd <- is.character(pred_cmd) && nzchar(pred_cmd)
  if (!has_data && !has_cmd) {
    err_obj <- structure(list(error = "Please select data and/or specify a command to generate predictions."), class = "svm.predict.error")
    return(err_obj)
  }

  # 3. Core prediction function (argument names follow predict_model())
  pfun <- function(model, pred, se, conf_lev) {
    missing_vars <- setdiff(svm_info$evar, colnames(pred))
    if (length(missing_vars) > 0) {
      return(paste("Missing variables:", paste(missing_vars, collapse = ", ")))
    }

    # Same categorical coding as in training: codes follow the training
    # levels, so new data with fewer or reordered levels maps identically.
    # Unseen levels become NA.
    cat_levels <- svm_info$cat_levels
    for (var in svm_info$evar) {
      x <- pred[[var]]
      if (!is.null(cat_levels[[var]])) {
        pred[[var]] <- as.numeric(match(as.character(x), cat_levels[[var]]))
      } else if (is.factor(x) || is.character(x) || is.logical(x)) {
        # Logical inputs, or fits created without stored levels.
        if (is.character(x)) x <- as.factor(x)
        pred[[var]] <- as.numeric(x)
      }
    }

    # Standardize with the training means / standard deviations
    ms <- attr(svm_info$dataset, "radiant_ms")
    sds <- attr(svm_info$dataset, "radiant_sds")
    if (!is.null(ms) && !is.null(sds)) {
      for (var in intersect(intersect(names(ms), names(sds)), colnames(pred))) {
        if (is.numeric(pred[[var]]) && isTRUE(sds[[var]] != 0)) {
          pred[[var]] <- (pred[[var]] - ms[[var]]) / sds[[var]]
        }
      }
    }

    pred_result <- try(
      .svm_predict_raw(
        svm_info$model,
        newdata = pred,
        probability = TRUE,
        decision.values = TRUE
      ),
      silent = TRUE
    )
    if (inherits(pred_result, "try-error")) {
      return(paste("Prediction failed:", conditionMessage(attr(pred_result, "condition"))))
    }

    # 4. Collect the results
    if (svm_info$type == "classification") {
      target_level <- as.character(svm_info$lev)
      prob_mat <- attr(pred_result, "probabilities")
      prob_col <- .svm_prob_col(colnames(prob_mat), target_level)
      if (!is.null(prob_col)) {
        pred_val <- as.numeric(prob_mat[, prob_col])
      } else {
        decision_vals <- attr(pred_result, "decision.values")
        if (!is.null(decision_vals)) {
          # Logistic transform of the decision values.
          # NOTE(review): assumes positive values favour `lev`; e1071 orients
          # them by class pair - confirm.
          pred_val <- 1 / (1 + exp(-decision_vals))
        } else {
          # Last resort: 1 when the predicted class is the focus level
          pred_val <- as.numeric(as.character(pred_result) == target_level)
        }
      }
    } else {
      pred_val <- as.numeric(pred_result)
      # The response was standardized for training; map predictions back.
      rvar <- svm_info$rvar
      if (!is.null(ms) && !is.null(sds) && rvar %in% names(ms) && rvar %in% names(sds)) {
        pred_val <- pred_val * sds[[rvar]] + ms[[rvar]]
      }
    }

    data.frame(Prediction = round(pred_val, dec), stringsAsFactors = FALSE)
  }

  # 5. Shared prediction framework
  result <- predict_model(
    object,
    pfun,
    "svm.predict",
    pred_data,
    pred_cmd,
    conf_lev = 0.95,
    se = FALSE,
    dec,
    envir = envir
  ) %>%
    set_attr("radiant_pred_data", df_name)

  # 6. Result metadata
  if (inherits(result, "svm.predict.error")) {
    return(result$error)
  }
  if (is.character(result)) return(result)

  result <- add_class(result, "svm.predict")
  attr(result, "svm_meta") <- list(
    model_type = svm_info$type,
    kernel = svm_info$kernel,
    cost = svm_info$cost,
    gamma = if (svm_info$kernel != "linear") svm_info$gamma else NA,
    seed = svm_info$seed,
    train_data = svm_info$df_name,
    dec = dec
  )
  result
}

#' Print method for predict.svm
#'
#' @param x Object returned by `predict.svm()`, or a character message.
#' @param ... Not used.
#' @param n Number of rows to show; a negative value shows all rows.
#' @export
print.svm.predict <- function(x, ..., n = 10) {
  if (is.character(x)) {
    cat(x, "\n")
    return(invisible(x))
  }
  if (!inherits(x, "svm.predict")) stop("Object must be of class 'svm.predict'")

  meta <- attr(x, "svm_meta")
  if (is.null(meta)) {
    meta <- list(
      model_type = "Unknown", kernel = "Unknown", cost = NA, gamma = NA,
      seed = NA, train_data = "Unknown", dec = 1
    )
  }
  dec <- if (is.null(meta$dec)) 1 else meta$dec
  n_pred <- nrow(x)
  show_n <- if (n < 0 || n >= n_pred) n_pred else max(n, 0)

  cat("SVM Predictions\n")
  model_type <- if (is.empty(meta$model_type)) "Unknown" else as.character(meta$model_type)
  cat(sprintf("Model Type : %s\n", tools::toTitleCase(model_type)))
  cat(sprintf("Kernel : %s\n", if (is.empty(meta$kernel)) "Unknown" else meta$kernel))
  cat(sprintf("Cost (C) : %.2f\n", if (is.null(meta$cost) || is.na(meta$cost)) 0 else meta$cost))
  if (!is.na(meta$gamma)) cat(sprintf("Gamma : %.2f\n", meta$gamma))
  if (!is.na(meta$seed)) cat(sprintf("Seed : %s\n", meta$seed))
  cat(sprintf("Training Dataset : %s\n", if (is.empty(meta$train_data)) "Unknown" else meta$train_data))
  cat(sprintf("Total Predictions : %d\n", n_pred))

  if (n_pred == 0) {
    cat("No predictions generated.\n")
    return(invisible(x))
  }

  x_show <- x[seq_len(show_n), , drop = FALSE]
  num_cols <- vapply(x_show, is.numeric, logical(1))
  x_show[num_cols] <- lapply(x_show[num_cols], function(col) {
    sprintf(paste0("%.", dec, "f"), col)
  })
  # Cell text per column; as.character() on each column keeps factor labels.
  cells <- lapply(x_show, function(col) {
    v <- as.character(col)
    v[is.na(v)] <- "NA"
    v
  })

  # Column widths: the widest of header and cells
  col_widths <- vapply(names(cells), function(cn) {
    max(nchar(cn), nchar(cells[[cn]]))
  }, integer(1))
  gap <- "  "
  fmt <- paste(paste0("%-", col_widths, "s"), collapse = gap)

  cat(do.call(sprintf, c(list(fmt), as.list(names(cells)))), "\n")
  divider_len <- max(0, sum(col_widths) + nchar(gap) * (length(col_widths) - 1))
  cat(strrep("-", divider_len), "\n")
  for (i in seq_len(show_n)) {
    row_vals <- vapply(cells, function(v) v[i], character(1))
    cat(do.call(sprintf, c(list(fmt), as.list(row_vals))), "\n")
  }

  if (show_n < n_pred) {
    cat(sprintf("\n... (showing first %d of %d)\n", as.integer(show_n), n_pred))
  }
  invisible(x)
}

# ---- Decision boundary ----
#' Decision boundary plot for a binary SVM classifier with 2 inputs
#'
#' The background shows the decision value over a 200 x 200 grid spanning
#' the (encoded, standardized) training inputs; points are the training data
#' coloured by class.
#'
#' @param object Object returned by `svm()`.
#' @param size Base font size.
#' @param custom If `TRUE` return the ggplot object, otherwise print it.
#' @export
svm_boundary_plot <- function(object, size, custom) {
  df <- object$dataset
  rvar <- object$rvar
  f1 <- object$evar[1]
  f2 <- object$evar[2]

  ## 1. Grid: factor levels, or an evenly spaced sequence for numeric inputs
  axis_values <- function(v) {
    if (is.factor(df[[v]])) {
      levels(df[[v]])
    } else {
      seq(min(df[[v]], na.rm = TRUE), max(df[[v]], na.rm = TRUE), length.out = 200)
    }
  }
  grid <- expand.grid(f1 = axis_values(f1), f2 = axis_values(f2))
  names(grid) <- c(f1, f2)

  ## 2. Score: decision values for classification, response for regression.
  # (e1071's predict has no `type` argument; decision values must be
  # requested explicitly.)
  if (object$type == "classification") {
    pred <- .svm_predict_raw(object$model, newdata = grid, decision.values = TRUE, na.action = na.pass)
    dv <- attr(pred, "decision.values")
    score <- if (!is.null(dv)) as.numeric(as.matrix(dv)[, 1]) else as.numeric(pred)
  } else {
    score <- as.numeric(.svm_predict_raw(object$model, newdata = grid, na.action = na.pass))
  }
  grid$.svm_score <- score

  ## 3. Plot
  p <- ggplot(df, aes(x = .data[[f1]], y = .data[[f2]])) +
    geom_tile(data = grid, aes(fill = .data[[".svm_score"]]), alpha = 0.65) +
    geom_point(aes(color = .data[[rvar]]), size = 2) +
    scale_fill_gradient(low = "#A4C4FF", high = "#FF9A9A", na.value = "grey90") +
    labs(title = "SVM Decision Boundary", x = f1, y = f2, fill = "Score", color = rvar) +
    theme_gray(base_size = size) +
    theme(legend.position = "bottom")

  if (custom) p else print(p)
}

# ---- Support vectors / margin ----
#' Scatter plot of the first two inputs with the support vectors highlighted
#'
#' @param object Object returned by `svm()` (needs at least 2 inputs).
#' @param size Base font size.
#' @param custom If `TRUE` return the ggplot object, otherwise print it.
#' @export
svm_margin_plot <- function(object, size, custom) {
  df <- object$dataset
  x_var <- object$evar[1]
  y_var <- object$evar[2]

  # Row indices of the support vectors -> one label per row of df. A factor
  # with fixed levels keeps a single, complete legend even when all points
  # fall into one group.
  is_sv <- seq_len(nrow(df)) %in% object$model$index
  df$.svm_point <- factor(
    ifelse(is_sv, "Support Vector", "Ordinary"),
    levels = c("Ordinary", "Support Vector")
  )

  p <- ggplot(df, aes(x = .data[[x_var]], y = .data[[y_var]])) +
    geom_point(aes(color = .data$.svm_point, shape = .data$.svm_point), size = 2, alpha = 0.8) +
    scale_color_manual(values = c("Ordinary" = "grey60", "Support Vector" = "red"), drop = FALSE) +
    scale_shape_manual(values = c("Ordinary" = 16, "Support Vector" = 17), drop = FALSE) +
    labs(title = "Support Vectors & Margin", x = x_var, y = y_var, color = NULL, shape = NULL) +
    theme_gray(base_size = size) +
    theme(legend.position = "bottom")

  if (custom) p else print(p)
}

# Performance of a fitted e1071 model on `newdata`: AUC for a binary
# classifier with probabilities for `lev` (requires pROC), accuracy for other
# classifiers, R-squared for regression.
.svm_perf <- function(model, newdata, y, is_class, lev) {
  if (is_class) {
    pred <- .svm_predict_raw(model, newdata = newdata, probability = TRUE)
    prob <- attr(pred, "probabilities")
    y_fac <- as.factor(y)
    if (nlevels(y_fac) == 2 && !is.null(prob) && isTRUE(lev %in% colnames(prob)) &&
      requireNamespace("pROC", quietly = TRUE)) {
      as.numeric(pROC::roc(response = y_fac, predictor = prob[, lev], quiet = TRUE)$auc)
    } else {
      mean(as.character(pred) == as.character(y), na.rm = TRUE)
    }
  } else {
    pred <- as.numeric(.svm_predict_raw(model, newdata = newdata))
    1 - sum((pred - y)^2, na.rm = TRUE) / sum((y - mean(y, na.rm = TRUE))^2, na.rm = TRUE)
  }
}

#' Variable importance for SVM using permutation importance
#'
#' Importance of a variable is the mean drop in performance (AUC for binary
#' classification, accuracy for multi-class, R-squared for regression) over
#' `nperm` random permutations of that variable.
#'
#' @param object Object returned by `svm()`.
#' @param rvar Response variable (defaults to the model's).
#' @param lev Focus level for classification (defaults to the model's).
#' @param data Data on the same encoded/standardized scale as
#'   `object$dataset` (the default); it is passed straight to the model.
#' @param seed Random seed (`NA` for none).
#' @param nperm Number of permutations per variable.
#' @return Data frame with columns `Variable` and `Importance`, sorted by
#'   decreasing importance.
#' @export
varimp <- function(object, rvar = NULL, lev = NULL, data = NULL, seed = 1234, nperm = 10) {
  if (!inherits(object, "svm")) {
    stop("Object must be of class 'svm'")
  }
  if (is.null(data)) data <- object$dataset
  if (is.null(rvar)) rvar <- object$rvar
  is_class <- object$type == "classification"
  if (is.null(lev) && is_class) lev <- object$lev

  y <- data[[rvar]]
  if (!is.na(seed)) set.seed(seed)

  # `data` is already on the model's scale, so e1071's predict is called
  # directly rather than predict.svm (which would standardize it again).
  score <- function(newdata) .svm_perf(object$model, newdata, y, is_class, lev)
  base_metric <- score(data)

  importance <- vapply(object$evar, function(var) {
    diffs <- vapply(seq_len(nperm), function(i) {
      perm_data <- data
      # Index-based shuffle: sample() on a length-1 numeric would sample 1:x.
      perm_data[[var]] <- perm_data[[var]][sample.int(length(perm_data[[var]]))]
      base_metric - score(perm_data)
    }, numeric(1))
    mean(diffs, na.rm = TRUE)
  }, numeric(1))

  result <- data.frame(
    Variable = names(importance),
    Importance = unname(importance),
    stringsAsFactors = FALSE
  )
  result <- result[order(-result$Importance), , drop = FALSE]
  rownames(result) <- NULL
  result
}

#' Permutation variable-importance bar chart for an SVM
#'
#' @param object Object returned by `svm()`.
#' @param size Base font size.
#' @param custom If `TRUE` return the ggplot object, otherwise print it.
#' @export
svm_vip_plot <- function(object, size, custom) {
  # The value of tryCatch() is the plot; assigning inside the error handler
  # would only set a variable local to that handler.
  p <- tryCatch({
    vi_scores <- varimp(
      object,
      rvar = object$rvar,
      lev = if (object$type == "classification") object$lev else NULL,
      data = object$dataset,
      seed = 1234
    )
    # Negative importance (noise) is shown as zero
    vi_scores$Importance <- as.numeric(pmax(vi_scores$Importance, 0))

    if (nrow(vi_scores) == 0 || all(is.na(vi_scores$Importance))) {
      ggplot() +
        annotate("text", x = 0.5, y = 0.5, label = "Could not compute variable importance\n(check model/data)", size = 5, color = "red") +
        theme_void()
    } else {
      y_lab <- if (object$type == "regression") "Importance (R² decrease)" else "Importance (Performance decrease)"
      ggplot(vi_scores, aes(x = reorder(.data$Variable, .data$Importance), y = .data$Importance)) +
        geom_col(fill = "#377eb8") +
        coord_flip() +
        labs(title = "SVM Variable Importance (Permutation)", x = NULL, y = y_lab) +
        theme_gray(base_size = size) +
        theme(
          axis.text.y = element_text(hjust = 0),
          panel.grid.major.y = element_line(color = "grey90"),
          panel.grid.minor = element_blank()
        )
    }
  }, error = function(e) {
    ggplot() +
      annotate("text", x = 0.5, y = 0.5, label = paste("Error calculating importance:\n", e$message), size = 4, color = "red") +
      theme_void()
  })

  if (custom) {
    return(p)
  }
  print(p)
  invisible(p)
}