#' Support Vector Machine
#'
#' Estimate a support vector machine for classification or regression using
#' \code{e1071::svm}.
#'
#' @param dataset Dataset (data.frame) or the name of a dataset to fetch via
#'   \code{get_data}
#' @param rvar The response variable
#' @param evar Explanatory variables
#' @param type Model type: "classification" or "regression"
#' @param lev Level of interest for a categorical response (first level if "")
#' @param kernel Kernel: "linear", "radial", "poly", or "sigmoid"
#' @param cost Cost parameter (C > 0)
#' @param gamma Kernel coefficient (> 0); used by the non-linear kernels
#' @param wts Name of a weights variable or "None". NOTE(review):
#'   \code{e1071::svm} has no case-weight argument, so these values are
#'   recorded but ignored during fitting -- confirm intent
#' @param seed Random seed (NA to skip)
#' @param check If it contains "standardize", numeric variables are standardized
#' @param form Optional model formula; defaults to rvar regressed on evar
#' @param data_filter Filter expression passed to \code{get_data}
#' @param arr Arrange expression passed to \code{get_data}
#' @param rows Row selection passed to \code{get_data}
#' @param envir Environment to fetch the data from
#'
#' @return A list of class c("svm", "model"), or a character error message of
#'   class "svm"
#' @export
svm <- function(dataset, rvar, evar, type = "classification", lev = "",
                kernel = "radial", cost = 1, gamma = 1, wts = "None",
                seed = NA, check = "standardize", form, data_filter = "",
                arr = "", rows = NULL, envir = parent.frame()) {
  ## ---- argument checks (SVM specific) ----
  valid_kernels <- c("linear", "radial", "poly", "sigmoid")
  if (!kernel %in% valid_kernels) {
    return(paste0("Kernel must be one of: ", paste(valid_kernels, collapse = ", ")) %>%
      add_class("svm"))
  }
  if (is.empty(cost) || cost <= 0) {
    return("Cost should be greater than 0." %>% add_class("svm"))
  }
  if (is.empty(gamma) || gamma <= 0) {
    return("Gamma should be greater than 0." %>% add_class("svm"))
  }
  if (rvar %in% evar) {
    return("Response variable contained in the set of explanatory variables.\nPlease update model specification." %>%
      add_class("svm"))
  }

  ## ---- weights variable bookkeeping ----
  vars <- c(rvar, evar)
  wtsname <- NULL
  if (wts == "None" || is.empty(wts)) {
    wts <- NULL
  } else if (is_string(wts)) {
    wtsname <- wts
    vars <- c(rvar, evar, wtsname)
  }

  ## ---- fetch, filter, and (optionally) standardize the data ----
  df_name <- if (is_string(dataset)) dataset else deparse(substitute(dataset))
  dataset <- get_data(dataset, vars, filt = data_filter, arr = arr, rows = rows, envir = envir)
  if ("standardize" %in% check) {
    ## scale_df stores training means/sds as attributes, reused at predict time
    dataset <- scale_df(dataset)
  }

  ## ---- response as factor for classification ----
  if (type == "classification") {
    dataset[[rvar]] <- as.factor(dataset[[rvar]])
    if (lev == "") {
      lev <- levels(dataset[[rvar]])[1]
    }
  }

  ## ---- build the e1071::svm call ----
  if (missing(form)) {
    ## spell out the explanatory variables instead of "~ ." so that a weights
    ## column present in `dataset` is not silently used as a predictor
    form <- as.formula(paste(rvar, "~", paste(evar, collapse = " + ")))
  }
  ## NOTE(review): e1071::svm does not support case weights; the `weights`
  ## entry below is swallowed by `...` and ignored -- kept for bookkeeping only
  weights_vec <- if (!is.null(wtsname)) dataset[[wtsname]] else NULL
  svm_input <- list(
    formula = form,
    data = dataset,
    kernel = kernel,
    cost = cost,
    gamma = gamma,
    weights = weights_vec,
    type = if (type == "classification") "C-classification" else "eps-regression",
    na.action = na.omit,
    scale = FALSE, # any scaling was already done by scale_df above
    probability = (type == "classification")
  )
  if (!is.na(seed)) set.seed(seed)

  ## ---- fit the model ----
  model <- do.call(e1071::svm, svm_input)

  ## ---- attach key information for downstream methods ----
  model$df_name <- df_name
  model$rvar <- rvar
  model$evar <- evar
  model$type <- type
  model$lev <- lev
  model$wtsname <- wtsname
  model$seed <- seed
  model$cost <- cost
  model$gamma <- gamma
  model$kernel <- kernel

  as.list(environment()) %>% add_class(c("svm", "model"))
}

#' Center and/or standardize numeric variables in a data frame
#'
#' @param dataset A data.frame
#' @param center Subtract the (stored) mean? Defaults to TRUE
#' @param scale Divide by the (stored) standard deviation? Defaults to TRUE
#' @param sf Scaling factor stored as an attribute (informational only here)
#' @param wts Accepted for interface compatibility; currently unused
#' @param calc If TRUE, compute means/sds from `dataset` and store them as
#'   attributes; if FALSE, reuse the "radiant_ms"/"radiant_sds" attributes
#'   already attached to `dataset` (e.g. copied over from the training data)
#'
#' @return The data.frame with its numeric columns transformed
#' @export
scale_df <- function(dataset, center = TRUE, scale = TRUE, sf = 2, wts = NULL, calc = TRUE) {
  is_num <- vapply(dataset, is.numeric, logical(1))
  if (length(is_num) == 0 || !any(is_num)) {
    return(dataset)
  }
  cn <- names(is_num)[is_num] # numeric column names
  descr <- attr(dataset, "description") # preserve the description attribute

  if (calc) {
    ms <- vapply(dataset[cn], function(x) mean(x, na.rm = TRUE), numeric(1))
    ## sample sd; guard against zero AND undefined (all-NA / single value) sd
    ## so the division below cannot produce Inf or NA columns
    sds <- vapply(dataset[cn], function(x) {
      s <- sd(x, na.rm = TRUE)
      if (is.na(s) || s == 0) 1 else s
    }, numeric(1))
    attr(dataset, "radiant_ms") <- ms
    attr(dataset, "radiant_sds") <- sds
    attr(dataset, "radiant_sf") <- sf
  } else {
    ms <- attr(dataset, "radiant_ms")
    sds <- attr(dataset, "radiant_sds")
    if (is.null(ms) || is.null(sds)) {
      warning("Training data mean/std not found; skipping standardization.")
      return(dataset)
    }
  }

  ## transform by NAME (not position) so a prediction set that lacks some of
  ## the training columns (e.g. the response) or orders them differently is
  ## still aligned with the stored statistics
  for (nm in cn) {
    if (nm %in% names(ms) && nm %in% names(sds)) {
      mu <- if (center) ms[[nm]] else 0
      s <- if (scale) sds[[nm]] else 1
      dataset[[nm]] <- (dataset[[nm]] - mu) / s
    } else if (!calc) {
      warning(sprintf("No stored mean/sd for '%s'; column left unscaled.", nm))
    }
  }
  attr(dataset, "description") <- descr
  dataset
}

#' Summary method for the svm function
#'
#' @param object Return value from \code{\link{svm}}
#' @param prn Print coefficient information (linear kernel only)?
#' @param ... Unused
#' @export
summary.svm <- function(object, prn = TRUE, ...) {
  if (is.character(object)) {
    return(object)
  }
  svm_model <- object$model
  n_obs <- nrow(object$dataset)
  wtsname <- object$wtsname # may be NULL, zero-length, or an empty string
  has_wts <- !is.null(wtsname) && length(wtsname) > 0 && wtsname != ""

  cat("Support Vector Machine\n")
  cat(sprintf("Kernel type : %s (%s)\n", object$kernel, object$type))
  cat(sprintf("Data : %s\n", object$df_name))
  cat(sprintf("Response variable : %s\n", object$rvar))
  if (object$type == "classification") {
    cat(sprintf("Level : %s in %s\n", object$lev, object$rvar))
  }
  cat(sprintf("Explanatory variables: %s\n", paste(object$evar, collapse = ", ")))
  if (has_wts) {
    cat(sprintf("Weights used : %s\n", wtsname))
  }
  cat(sprintf("Cost (C) : %.2f\n", object$cost))
  if (object$kernel %in% c("radial", "poly", "sigmoid")) {
    cat(sprintf("Gamma : %.2f\n", object$gamma))
  }
  if (!is.na(object$seed)) cat(sprintf("Seed : %s\n", object$seed))

  if (object$type == "classification") {
    n_sv_per_class <- svm_model$nSV
    sv_info <- paste(
      sprintf("class %s: %d", seq_along(n_sv_per_class), n_sv_per_class),
      collapse = ", "
    )
    cat(sprintf("Support vectors : %d (%s)\n", sum(n_sv_per_class), sv_info))
  } else {
    cat(sprintf("Support vectors : %d\n", sum(svm_model$nSV)))
  }

  ## weighted number of observations (fall back to the row count on any NA)
  if (has_wts) {
    weights_values <- as.numeric(object$dataset[[wtsname]])
    nr_obs <- if (all(!is.na(weights_values))) sum(weights_values) else n_obs
  } else {
    nr_obs <- n_obs
  }
  cat(sprintf("Nr_obs : %s\n", format_nr(nr_obs, dec = 0)))

  ## ---- coefficients (linear kernel only) ----
  if (prn) {
    cat("Coefficients/Support Vectors:\n")
    if (object$kernel == "linear") {
      ## e1071 does not store primal weights on the fit; reconstruct them from
      ## the dual representation, w = t(coefs) %*% SV. This yields one weight
      ## per predictor for binary classification and regression with numeric
      ## predictors; anything else falls through to the "not available" text.
      w <- svm_model$w
      if (is.null(w)) {
        w <- tryCatch(drop(t(svm_model$coefs) %*% svm_model$SV), error = function(e) NULL)
      }
      if (is.null(w) || length(w) != length(object$evar)) {
        cat(" Linear kernel coefficients not available (possible reasons:\n")
        cat(" - Insufficient support vectors (data may be linearly inseparable)\n")
        cat(" - Model training did not converge)\n")
        if (object$type == "classification") {
          cat(sprintf(" Support vectors count: %d\n", sum(svm_model$nSV)))
        }
      } else {
        bias <- as.numeric(-svm_model$rho)[1] # intercept = -rho
        coef_data <- data.frame(
          Variable = c(object$evar, "bias"),
          Value = c(as.numeric(w), bias),
          stringsAsFactors = FALSE,
          check.names = FALSE
        )
        ## print two coefficients per output line
        for (i in seq(1, nrow(coef_data), 2)) {
          if (i + 1 > nrow(coef_data)) {
            cat(sprintf(" %-12s: %.2f\n", coef_data$Variable[i], coef_data$Value[i]))
          } else {
            cat(sprintf(
              " %-12s: %.2f %-12s: %.2f\n",
              coef_data$Variable[i], coef_data$Value[i],
              coef_data$Variable[i + 1], coef_data$Value[i + 1]
            ))
          }
        }
      }
    } else {
      cat(" Non-linear kernel: Coefficients not available\n")
    }
  }
}

#' Plot method for the svm function
#'
#' @param x Return value from \code{\link{svm}}
#' @param plots Any of "decision_boundary", "margin", "vip" ("none" = no plots)
#' @param size Base font size for the plots
#' @param shiny Return the patchwork object instead of printing it?
#' @param custom Passed through to the individual plot helpers
#' @param ... Unused
#' @export
plot.svm <- function(x, plots = "none", size = 12, shiny = FALSE, custom = FALSE, ...) {
  if (is.character(x) || !inherits(x, "svm")) {
    return(x)
  }
  plot_list <- list()

  if ("decision_boundary" %in% plots) {
    if (length(x$evar) != 2) {
      warning("Decision boundary plot requires exactly 2 explanatory variables")
    } else if (x$type != "classification") {
      warning("Decision boundary plot only available for classification")
    } else if (nlevels(x$dataset[[x$rvar]]) != 2) {
      warning("Decision boundary plot only available for binary classification")
    } else {
      plot_list[["decision_boundary"]] <- svm_boundary_plot(x, size = size, custom = custom)
    }
  }
  if ("margin" %in% plots) {
    plot_list[["margin"]] <- svm_margin_plot(x, size = size, custom = custom)
  }
  if ("vip" %in% plots) {
    if (length(x$evar) < 2) {
      warning("Variable importance needs at least 2 explanatory variables")
    } else {
      plot_list[["vip"]] <- svm_vip_plot(x, size = size, custom = custom)
    }
  }
  if (length(plot_list) == 0) {
    return("No valid plots selected for SVM")
  }

  ## return the patchwork object (Shiny prints it automatically)
  patchwork::wrap_plots(plot_list, ncol = min(2, length(plot_list))) %>%
    {
      if (isTRUE(shiny)) . else print(.)
    }
}

#' Predict method for the svm function
#'
#' @param object Return value from \code{\link{svm}}
#' @param pred_data Data.frame (or dataset name) to predict for; the training
#'   data is used when NULL or a character value
#' @param pred_cmd Command(s) of the form "var = value; ..." used to generate
#'   predictions at specific values of the explanatory variables
#' @param dec Number of decimals for regression predictions
#' @param envir Environment in which to evaluate \code{pred_cmd}
#' @param ... Unused
#'
#' @return A data.frame of class c("svm.predict", "data.frame"), or a
#'   character error message of class "svm.predict"
#' @export
predict.svm <- function(object, pred_data = NULL, pred_cmd = "", dec = 3,
                        envir = parent.frame(), ...) {
  ## ---- basic validation ----
  if (is.character(object)) {
    return(object)
  }
  if (!inherits(object, "svm")) stop("Object must be of class 'svm'")

  ## ---- determine the prediction data source ----
  ## training data stored on the object is ALREADY standardized; fresh data
  ## supplied by the caller is not
  from_training <- is.null(pred_data) || is.character(pred_data)
  if (from_training) {
    pred_data_raw <- object$dataset[, object$evar, drop = FALSE]
  } else {
    pred_data_raw <- as.data.frame(pred_data)
  }
  ## data that carries scale_df's attributes has been standardized already
  already_scaled <- from_training ||
    (!is.null(attr(pred_data_raw, "radiant_ms")) && !is.null(attr(pred_data_raw, "radiant_sds")))

  missing_vars <- setdiff(object$evar, colnames(pred_data_raw))
  if (length(missing_vars) > 0) {
    return(
      paste0(
        "Prediction data is missing explanatory variable(s): ",
        paste(missing_vars, collapse = ", "), "\n"
      ) %>% add_class("svm.predict")
    )
  }
  pred_data <- pred_data_raw[, object$evar, drop = FALSE]

  ## ---- apply prediction commands ("var = value; ...") ----
  if (!is.empty(pred_cmd)) {
    cmds <- strsplit(gsub(";\\s+", ";", gsub("\\s{2,}", " ", pred_cmd)), ";")[[1]]
    for (cmd in cmds) {
      if (grepl("=", cmd)) {
        var_val <- strsplit(trimws(cmd), "=")[[1]]
        var <- trimws(var_val[1])
        ## pred_cmd is a user-facing mini-language; eval(parse()) is intentional
        val <- try(eval(parse(text = trimws(var_val[2])), envir = envir), silent = TRUE)
        if (inherits(val, "try-error")) {
          warning(sprintf("Invalid command '%s': using original values", cmd))
        } else {
          pred_data[[var]] <- val
        }
      }
    }
    pred_data <- unique(pred_data)
  }

  ## ---- check variable types and factor levels against the training data ----
  train_types <- vapply(
    object$dataset[, object$evar, drop = FALSE],
    function(x) class(x)[1], character(1)
  )
  pred_types <- vapply(pred_data, function(x) class(x)[1], character(1))
  type_mismatch <- names(which(train_types != pred_types))
  if (length(type_mismatch) > 0) {
    return(
      paste0(
        "Variable type mismatch (train vs pred):\n",
        paste(
          sprintf(
            " %s: %s vs %s", type_mismatch,
            train_types[type_mismatch], pred_types[type_mismatch]
          ),
          collapse = "\n"
        )
      ) %>% add_class("svm.predict")
    )
  }
  for (var in object$evar) {
    if (is.factor(object$dataset[[var]])) {
      train_levs <- levels(object$dataset[[var]])
      if (!identical(train_levs, levels(pred_data[[var]]))) {
        pred_data[[var]] <- factor(pred_data[[var]], levels = train_levs)
        warning(sprintf("Aligned factor levels for '%s' to match training data", var))
      }
    }
  }

  ## ---- standardize NEW data with the TRAINING means/sds ----
  ## the stats must be attached BEFORE scale_df(calc = FALSE) is called, and
  ## already-standardized data must not be transformed a second time
  train_ms <- attr(object$dataset, "radiant_ms")
  train_sds <- attr(object$dataset, "radiant_sds")
  if (!already_scaled && !is.null(train_ms) && !is.null(train_sds)) {
    attr(pred_data, "radiant_ms") <- train_ms
    attr(pred_data, "radiant_sds") <- train_sds
    attr(pred_data, "radiant_sf") <- attr(object$dataset, "radiant_sf") %||% 2
    pred_data <- scale_df(pred_data, calc = FALSE)
  }

  ## ---- generate predictions ----
  ## NOTE(review): na.omit may drop rows inside predict, which would break the
  ## cbind with pred_data below -- rows with NAs should be filtered upstream
  if (object$type == "classification") {
    ## e1071 returns class probabilities via attr(., "probabilities"), and only
    ## when the model was fit with probability = TRUE (stored as $compprob)
    want_prob <- isTRUE(object$model$compprob)
    raw_pred <- predict(object$model,
      newdata = pred_data,
      probability = want_prob, na.action = na.omit
    )
    pred_result <- data.frame(
      Predicted_Class = as.character(raw_pred),
      stringsAsFactors = FALSE
    )
    if (want_prob) {
      prob <- attr(raw_pred, "probabilities")
      if (!is.null(prob)) {
        prob <- as.data.frame(prob)
        colnames(prob) <- paste0("Prob_", colnames(prob))
        pred_result <- cbind(pred_result, prob)
      }
    }
  } else {
    pred_val <- as.numeric(predict(object$model, newdata = pred_data, na.action = na.omit))
    pred_result <- data.frame(
      Predicted_Value = round(pred_val, dec),
      stringsAsFactors = FALSE
    )
  }

  pred_result <- cbind(pred_data, pred_result)
  attr(pred_result, "svm_meta") <- list(
    kernel = object$kernel,
    cost = object$cost,
    gamma = if (object$kernel != "linear") object$gamma else NA,
    seed = object$seed,
    train_data = object$df_name,
    model_type = object$type
  )
  class(pred_result) <- c("svm.predict", "data.frame")
  pred_result
}

#' Print method for predict.svm
#'
#' @param x Return value from \code{\link{predict.svm}}
#' @param ... Unused
#' @param n Number of rows to print (a negative value prints all rows)
#' @export
print.svm.predict <- function(x, ..., n = 10) {
  if (is.character(x)) {
    cat(x, "\n")
    return(invisible(x))
  }
  if (!inherits(x, "svm.predict")) stop("Object must be of class 'svm.predict'")

  meta <- attr(x, "svm_meta")
  n_pred <- nrow(x)
  show_n <- if (n < 0 || n >= n_pred) n_pred else n

  cat("SVM Predictions\n")
  cat(sprintf("Model Type : %s\n", tools::toTitleCase(meta$model_type)))
  cat(sprintf("Kernel : %s\n", meta$kernel))
  cat(sprintf("Cost (C) : %.2f\n", meta$cost))
  if (!is.na(meta$gamma)) cat(sprintf("Gamma : %.2f\n", meta$gamma))
  if (!is.na(meta$seed)) cat(sprintf("Seed : %s\n", meta$seed))
  cat(sprintf("Training Dataset : %s\n", meta$train_data))
  cat(sprintf("Total Predictions : %d\n", n_pred))
  if (n_pred == 0) {
    cat("No predictions generated.\n")
    return(invisible(x))
  }

  if (show_n > 0) {
    x_show <- x[seq_len(show_n), , drop = FALSE] # keep data.frame structure
    ## column width = widest of header and (NA-safe) formatted values
    col_widths <- vapply(colnames(x_show), function(cn) {
      vals <- as.character(x_show[[cn]])
      vals[is.na(vals)] <- "NA"
      max(nchar(cn), nchar(vals))
    }, numeric(1))
    ## two-space separator, matching the divider-length arithmetic below
    fmt <- paste(paste0("%-", col_widths, "s"), collapse = "  ")
    cat(do.call(sprintf, c(list(fmt), as.character(colnames(x_show)))), "\n")
    divider <- paste0(rep("-", sum(col_widths) + 2 * (length(col_widths) - 1)), collapse = "")
    cat(divider, "\n")
    ## seq_len avoids the 1:0 trap when show_n is 0
    for (i in seq_len(show_n)) {
      row_vals <- as.character(x_show[i, ])
      row_vals[is.na(row_vals)] <- "NA"
      cat(do.call(sprintf, c(list(fmt), row_vals)), "\n")
    }
  }
  if (show_n < n_pred) {
    cat(sprintf("\n... (showing first %d of %d; use 'n=-1' to view all)\n", show_n, n_pred))
  }
  invisible(x)
}

#' Cross-validation for SVM
#'
#' Build a tuning grid over kernel / cost / gamma. NOTE(review): this is a
#' placeholder -- the K-fold / repeat evaluation loop is not implemented, so
#' the performance columns are returned as NA. \code{K}, \code{repeats},
#' \code{seed}, \code{trace}, \code{fun}, and \code{...} are accepted but
#' currently unused.
#'
#' @param object Return value from \code{\link{svm}}
#' @param K Number of folds (currently unused)
#' @param repeats Number of repeats (currently unused)
#' @param kernel Kernels to evaluate
#' @param cost Cost values to evaluate
#' @param gamma Gamma values to evaluate (only relevant for non-linear kernels)
#' @param seed Random seed (currently unused)
#' @param trace Print progress? (currently unused)
#' @param fun Optional performance function (currently unused)
#' @param ... Unused
#'
#' @return A data.frame with one row per parameter combination and NA
#'   placeholders for mean/std performance
#' @export
cv.svm <- function(object, K = 5, repeats = 1, kernel = c("linear", "radial"),
                   cost = 2^(-2:2), gamma = 2^(-2:2), seed = 1234,
                   trace = TRUE, fun, ...) {
  if (!inherits(object, "svm")) stop("Object must be of class 'svm'")
  tune_grid <- expand.grid(
    kernel = kernel,
    cost = cost,
    gamma = if (any(kernel %in% c("radial", "poly", "sigmoid"))) gamma else NA
  )
  ## TODO: implement the actual K-fold cross-validation loop
  data.frame(
    mean_perf = rep(NA, nrow(tune_grid)),
    std_perf = rep(NA, nrow(tune_grid)),
    kernel = tune_grid$kernel,
    cost = tune_grid$cost,
    gamma = tune_grid$gamma
  )
}

# ---- decision boundary ----
#' Decision boundary plot (binary classification, two explanatory variables)
#'
#' @param object Return value from \code{\link{svm}}
#' @param size Base font size
#' @param custom Return the ggplot object instead of printing it?
#' @export
svm_boundary_plot <- function(object, size, custom) {
  df <- object$dataset
  rvar <- object$rvar
  f1 <- object$evar[1]
  f2 <- object$evar[2]

  ## 1. build a grid (factor levels for factors, a sequence for numerics)
  grid_seq <- function(x) {
    if (is.factor(x)) {
      levels(x)
    } else {
      seq(min(x, na.rm = TRUE), max(x, na.rm = TRUE), length.out = 200)
    }
  }
  grid <- expand.grid(grid_seq(df[[f1]]), grid_seq(df[[f2]]))
  names(grid) <- c(f1, f2)

  ## 2. classification: plot the decision value (signed distance to the
  ## boundary); e1071 exposes it via attr(., "decision.values") when
  ## decision.values = TRUE. Regression: plot the fitted response.
  if (object$type == "classification") {
    raw <- predict(object$model, newdata = grid, decision.values = TRUE, na.action = na.pass)
    dv <- attr(raw, "decision.values")
    grid$pred <- if (is.null(dv)) as.numeric(raw) else as.numeric(dv[, 1])
  } else {
    grid$pred <- as.numeric(predict(object$model, newdata = grid, na.action = na.pass))
  }

  ## 3. plot
  p <- ggplot(df, aes(.data[[f1]], .data[[f2]])) +
    geom_tile(data = grid, aes(fill = pred), alpha = 0.65) +
    geom_point(aes(color = .data[[rvar]]), size = 2) +
    scale_fill_gradient(low = "#A4C4FF", high = "#FF9A9A", na.value = "grey90") +
    labs(title = "SVM Decision Boundary", x = f1, y = f2, fill = "Score", color = rvar) +
    theme_gray(base_size = size) +
    theme(legend.position = "bottom")
  if (custom) p else print(p)
}

# ---- support vectors / margin ----
#' Support vector / margin plot
#'
#' @param object Return value from \code{\link{svm}}
#' @param size Base font size
#' @param custom Return the ggplot object instead of printing it?
#' @export
svm_margin_plot <- function(object, size, custom) {
  mod <- object$model
  df <- object$dataset
  ## flag the rows that are support vectors (mod$index holds their row numbers)
  df$sv <- seq_len(nrow(df)) %in% mod$index

  p <- ggplot(df, aes(.data[[object$evar[1]]], .data[[object$evar[2]]])) +
    geom_point(aes(color = sv, shape = sv), size = 2, alpha = 0.8) +
    scale_color_manual(
      values = c("FALSE" = "grey60", "TRUE" = "red"),
      labels = c("Ordinary", "Support Vector")
    ) +
    labs(title = "Support Vectors & Margin", color = NULL, shape = NULL) +
    theme_gray(base_size = size) +
    theme(legend.position = "bottom")
  if (custom) p else print(p)
}

#' Variable importance for SVM using permutation importance
#'
#' @param object Return value from \code{\link{svm}}
#' @param rvar Response variable (defaults to the one stored in `object`)
#' @param lev Level of interest for classification (defaults to stored level)
#' @param data Data to evaluate on (defaults to the training data)
#' @param seed Random seed (NA to skip)
#' @param nperm Number of permutations per variable
#'
#' @return A data.frame with variables sorted by decreasing importance (drop
#'   in AUC/accuracy for classification, drop in R-squared for regression)
#' @export
varimp <- function(object, rvar = NULL, lev = NULL, data = NULL, seed = 1234, nperm = 10) {
  if (!inherits(object, "svm")) {
    stop("Object must be of class 'svm'")
  }
  if (is.null(data)) data <- object$dataset
  if (is.null(rvar)) rvar <- object$rvar
  if (is.null(lev) && object$type == "classification") lev <- object$lev

  y <- data[[rvar]]
  is_binary <- object$type == "classification" && nlevels(as.factor(y)) == 2
  prob_col <- paste0("Prob_", lev)
  if (!is.na(seed)) set.seed(seed)

  ## performance of a prediction data.frame against y:
  ## binary classification -> AUC (accuracy when no probabilities available),
  ## multi-class -> accuracy, regression -> R-squared
  perf <- function(pred) {
    if (object$type == "classification") {
      if (is_binary && prob_col %in% colnames(pred)) {
        pROC::roc(response = y, predictor = pred[[prob_col]], quiet = TRUE)$auc[[1]]
      } else {
        mean(pred$Predicted_Class == y, na.rm = TRUE)
      }
    } else {
      1 - sum((pred$Predicted_Value - y)^2, na.rm = TRUE) /
        sum((y - mean(y, na.rm = TRUE))^2, na.rm = TRUE)
    }
  }

  base_pred <- predict(object, pred_data = data, envir = environment())
  if (is.character(base_pred)) stop(base_pred)
  base_metric <- perf(base_pred)

  ## permutation importance: average performance drop when one explanatory
  ## variable is shuffled while the others are left intact
  importance_scores <- vapply(object$evar, function(var) {
    diffs <- vapply(seq_len(nperm), function(i) {
      perm_data <- data
      perm_data[[var]] <- sample(perm_data[[var]], replace = FALSE)
      perm_pred <- predict(object, pred_data = perm_data, envir = environment())
      base_metric - perf(perm_pred)
    }, numeric(1))
    mean(diffs, na.rm = TRUE)
  }, numeric(1))

  result <- data.frame(
    Variable = names(importance_scores),
    Importance = as.numeric(importance_scores),
    stringsAsFactors = FALSE
  )
  result <- result[order(-result$Importance), ]
  rownames(result) <- NULL
  result
}

#' Variable importance plot for SVM
#'
#' @param object Return value from \code{\link{svm}}
#' @param size Base font size
#' @param custom Return the ggplot object instead of printing it?
#' @export
svm_vip_plot <- function(object, size, custom) {
  ## assign the tryCatch RESULT: an assignment inside the error handler is
  ## local to the handler function and would leave `p` undefined out here
  p <- tryCatch(
    {
      vi_scores <- varimp(
        object,
        rvar = object$rvar,
        lev = if (object$type == "classification") object$lev else NULL,
        data = object$dataset,
        seed = 1234
      )
      ## clamp to non-negative numeric importance values
      vi_scores$Importance <- as.numeric(pmax(vi_scores$Importance, 0))
      if (nrow(vi_scores) == 0 || all(is.na(vi_scores$Importance))) {
        ggplot() +
          annotate("text",
            x = 0.5, y = 0.5,
            label = "Could not compute variable importance\n(check model/data)",
            size = 5, color = "red"
          ) +
          theme_void()
      } else {
        ggplot(vi_scores, aes(x = reorder(Variable, Importance), y = Importance)) +
          geom_col(fill = "#377eb8") +
          coord_flip() +
          labs(
            title = "SVM Variable Importance (Permutation)",
            x = NULL,
            y = ifelse(object$type == "regression", "Importance (R² decrease)", "Importance (Performance decrease)")
          ) +
          theme_gray(base_size = size) +
          theme(
            axis.text.y = element_text(hjust = 0),
            panel.grid.major.y = element_line(color = "grey90"),
            panel.grid.minor = element_blank()
          )
      }
    },
    error = function(e) {
      ggplot() +
        annotate("text",
          x = 0.5, y = 0.5,
          label = paste("Error calculating importance:\n", e$message),
          size = 4, color = "red"
        ) +
        theme_void()
    }
  )
  if (custom) {
    p
  } else {
    print(p)
    invisible(p)
  }
}