#' Support Vector Machine
#'
#' Fits a support vector machine with `e1071::svm()` and returns the fitted
#' model together with the (recoded, optionally standardized) training data.
#' Validation problems are reported by returning a character message with
#' class "svm" rather than by raising an error.
#'
#' Helpers such as `get_data()`, `add_class()`, `is.empty()` and `is_string()`
#' are not defined in this file and are assumed to be in scope.
#'
#' @param dataset Dataset, or the name of a dataset, passed to `get_data()`
#' @param rvar Name of the response variable
#' @param evar Names of the explanatory variables
#' @param type "classification" fits a C-classification model; any other
#'   value fits an eps-regression model
#' @param lev Response level whose probability is reported by `predict()`;
#'   "" selects the first level of the response
#' @param kernel One of "linear", "radial", "poly", "sigmoid"
#' @param cost Cost parameter (> 0)
#' @param gamma Kernel gamma (> 0)
#' @param wts Name of a weights variable, or "None"
#' @param seed Random seed; `NA` leaves the RNG untouched
#' @param check Include "standardize" to standardize numeric variables
#' @param form Optional model formula; defaults to `rvar ~ .`
#' @param data_filter,arr,rows,envir Passed on to `get_data()`
#'
#' @return A list of class `c("svm", "model")`, or a character message of
#'   class "svm" when the model could not be estimated. For regression with
#'   "standardize" the response is standardized as well, so predictions are
#'   in standardized units.
#' @export
svm <- function(dataset, rvar, evar,
                type = "classification", lev = "",
                kernel = "radial", cost = 1, gamma = 1,
                wts = "None", seed = NA,
                check = "standardize", form,
                data_filter = "", arr = "", rows = NULL,
                envir = parent.frame()) {
  ## ---- Argument validation ----
  valid_kernels <- c("linear", "radial", "poly", "sigmoid")
  if (!kernel %in% valid_kernels) {
    return(
      paste0("Kernel must be one of: ", paste(valid_kernels, collapse = ", ")) %>%
        add_class("svm")
    )
  }
  if (is.empty(cost) || cost <= 0) {
    return("Cost should be greater than 0." %>% add_class("svm"))
  }
  if (is.empty(gamma) || gamma <= 0) {
    return("Gamma should be greater than 0." %>% add_class("svm"))
  }
  if (rvar %in% evar) {
    return("Response variable contained in the set of explanatory variables.\nPlease update model specification." %>% add_class("svm"))
  }

  ## ---- 1. Weights variable ----
  vars <- c(rvar, evar)
  wtsname <- NULL
  if (is.empty(wts) || wts == "None") {
    wts <- NULL
  } else if (is_string(wts)) {
    wtsname <- wts
    vars <- c(rvar, evar, wtsname)
  }

  ## ---- 2. Data selection, recoding and standardization ----
  ## The name has to be captured before `dataset` is overwritten below.
  df_name <- if (is_string(dataset)) dataset else deparse(substitute(dataset))
  dataset <- get_data(
    dataset, vars,
    filt = data_filter, arr = arr, rows = rows, envir = envir
  )

  is_classif <- type == "classification"

  ## Categorical columns are replaced by their integer codes. The response of
  ## a classification model is left alone so that its labels survive, and the
  ## levels of every recoded column are remembered so that prediction data
  ## can be coded identically.
  cat_levels <- list()
  for (v in setdiff(names(dataset), if (is_classif) rvar)) {
    col <- dataset[[v]]
    if (is.character(col)) col <- as.factor(col)
    if (is.factor(col)) {
      cat_levels[[v]] <- levels(col)
      dataset[[v]] <- as.numeric(col)
    } else if (is.logical(col)) {
      dataset[[v]] <- as.numeric(col)
    }
  }

  ## Missing values
  dataset <- stats::na.omit(dataset)
  if (nrow(dataset) == 0) {
    return("No valid samples after removing missing values. Please check your data." %>% add_class("svm"))
  }

  ## ---- 3. Response of a classification model ----
  ## Converted before standardizing so a numeric (e.g. 0/1) response is not
  ## scaled; droplevels() removes levels that vanished with the NA rows.
  if (is_classif) {
    dataset[[rvar]] <- droplevels(as.factor(dataset[[rvar]]))
  }

  ## Standardization. Weights are not a feature and keep their raw values.
  if ("standardize" %in% check) {
    raw_wts <- if (!is.null(wtsname)) dataset[[wtsname]] else NULL
    dataset <- scale_sv(dataset)
    if (!is.null(wtsname)) dataset[[wtsname]] <- raw_wts
  }

  if (is_classif && lev == "") {
    lev <- levels(dataset[[rvar]])[1]
  }

  ## ---- 4. Training arguments ----
  if (missing(form)) {
    form <- stats::as.formula(paste(rvar, "~ ."))
  }
  weights_vec <- if (!is.null(wtsname)) dataset[[wtsname]] else NULL

  ## NOTE(review): e1071::svm() documents `class.weights`, not case weights;
  ## `weights` is passed through as before but may have no effect -- confirm.
  svm_input <- list(
    formula = form,
    ## only response + explanatory variables, so `~ .` cannot pick up weights
    data = dataset[, c(rvar, evar), drop = FALSE],
    kernel = kernel,
    cost = cost,
    gamma = gamma,
    weights = weights_vec,
    type = if (is_classif) "C-classification" else "eps-regression",
    na.action = stats::na.omit,
    scale = FALSE,
    probability = is_classif
  )

  if (!is.na(seed)) set.seed(seed)

  ## ---- Model training ----
  model <- tryCatch(do.call(e1071::svm, svm_input), error = function(e) e)
  if (inherits(model, "error")) {
    return(
      paste("Model training failed:", conditionMessage(model)) %>%
        add_class("svm")
    )
  }
  if (is.null(model) || !inherits(model, "svm")) {
    return("Model training failed: Generated SVM model is invalid" %>% add_class("svm"))
  }

  ## ---- Attach meta data ----
  ## `type`, `kernel`, `cost` and `gamma` are deliberately NOT written onto
  ## the fitted object: e1071 keeps its own values under those names and its
  ## predict method reads them. They are available on the returned list.
  model$df_name <- df_name
  model$rvar <- rvar
  model$evar <- evar
  model$lev <- lev
  model$wtsname <- wtsname
  model$seed <- seed
  model$cat_levels <- cat_levels

  ## ---- Return value ----
  list(
    model = model,
    df_name = df_name,
    rvar = rvar,
    evar = evar,
    type = type,
    lev = lev,
    wtsname = wtsname,
    seed = seed,
    cost = cost,
    gamma = gamma,
    kernel = kernel,
    dataset = dataset,
    cat_levels = cat_levels
  ) %>% add_class(c("svm", "model"))
}

#' Center or standardize variables in a data frame
#'
#' Every numeric column is replaced by `(x - mean) / sd`. Columns without
#' variation, or whose standard deviation cannot be computed (for example a
#' single row), are only centered.
#'
#' @param dataset Data frame
#' @param center,scale,wts Accepted for interface compatibility; not used by
#'   the current implementation
#' @param sf Stored in the "radiant_sf" attribute; not used in the
#'   calculation
#' @param calc If TRUE compute means and standard deviations from `dataset`
#'   and store them in the "radiant_ms" / "radiant_sds" attributes. If FALSE
#'   re-use the values found in those attributes, matched by column name.
#'
#' @return `dataset` with standardized numeric columns
#' @export
scale_sv <- function(dataset, center = TRUE, scale = TRUE,
                     sf = 2, wts = NULL, calc = TRUE) {
  is_num <- vapply(dataset, is.numeric, logical(1))
  if (length(is_num) == 0 || !any(is_num)) {
    return(dataset)
  }
  cn <- names(is_num)[is_num]
  descr <- attr(dataset, "description") # restored at the end

  if (calc) {
    ms <- vapply(
      dataset[cn], function(x) mean(x, na.rm = TRUE), numeric(1)
    )
    sds <- vapply(dataset[cn], function(x) {
      sd_val <- stats::sd(x, na.rm = TRUE)
      ## sd() is NA for fewer than two values; dividing by NA or 0 would
      ## wipe out the column
      if (is.na(sd_val) || sd_val == 0) 1 else sd_val
    }, numeric(1))
  } else {
    ms <- attr(dataset, "radiant_ms")
    sds <- attr(dataset, "radiant_sds")
    if (is.null(ms) || is.null(sds)) {
      warning("Training data mean/std not found; skipping standardization.")
      return(dataset)
    }
    ## unnamed parameters can only be matched by position
    if (is.null(names(ms)) && length(ms) == length(cn)) names(ms) <- cn
    if (is.null(names(sds)) && length(sds) == length(cn)) names(sds) <- cn
    cn <- intersect(cn, intersect(names(ms), names(sds)))
  }

  if (length(cn) > 0) {
    dataset[cn] <- lapply(cn, function(v) {
      (dataset[[v]] - ms[[v]]) / sds[[v]]
    })
  }

  if (calc) {
    attr(dataset, "radiant_ms") <- ms
    attr(dataset, "radiant_sds") <- sds
    attr(dataset, "radiant_sf") <- sf
  }
  attr(dataset, "description") <- descr
  dataset
}

# Call e1071's predict method for a fitted svm directly. S3 dispatch from
# inside this file would find the `predict.svm` defined below first.
.svm_e1071_predict <- function(model, newdata, ...) {
  e1071_predict <- utils::getS3method(
    "predict", "svm",
    envir = asNamespace("e1071")
  )
  e1071_predict(model, newdata = newdata, ...)
}

# Recode and standardize explanatory variables the way svm() did in training.
# `standardized = TRUE` states that `newdata` is already on the training scale.
.svm_prepare_newdata <- function(object, newdata, standardized = FALSE) {
  prepared <- newdata[, object$evar, drop = FALSE]
  cat_levels <- object$cat_levels

  for (v in object$evar) {
    col <- prepared[[v]]
    if (!is.numeric(col)) {
      lv <- cat_levels[[v]]
      col <- if (is.null(lv)) {
        as.numeric(col)
      } else {
        ## same level order, hence same integer codes, as in training
        as.numeric(factor(as.character(col), levels = lv))
      }
    }
    prepared[[v]] <- as.numeric(col)
  }

  ms <- attr(object$dataset, "radiant_ms")
  sds <- attr(object$dataset, "radiant_sds")
  if (!standardized && !is.null(ms) && !is.null(sds)) {
    for (v in intersect(object$evar, intersect(names(ms), names(sds)))) {
      if (!is.na(ms[[v]]) && !is.na(sds[[v]]) && sds[[v]] != 0) {
        prepared[[v]] <- (prepared[[v]] - ms[[v]]) / sds[[v]]
      }
    }
  }
  prepared
}

# Numeric prediction for `newdata`: the probability of `object$lev` for
# classification (1/0 when no probabilities are available), the predicted
# value for regression. Returns a character message when prediction fails.
.svm_score <- function(object, newdata, standardized = FALSE) {
  prepared <- .svm_prepare_newdata(object, newdata, standardized)
  fit <- object$model

  raw <- tryCatch(
    .svm_e1071_predict(fit, prepared, probability = isTRUE(fit$compprob)),
    error = function(e) e
  )
  if (inherits(raw, "error")) {
    return(paste("Prediction failed:", conditionMessage(raw)))
  }
  if (object$type != "classification") {
    return(as.numeric(raw))
  }

  prob <- attr(raw, "probabilities")
  lev <- object$lev
  if (is.null(prob)) {
    return(ifelse(as.character(raw) == lev, 1, 0))
  }
  ## fall back to the first column when `lev` is not a class label
  as.numeric(if (lev %in% colnames(prob)) prob[, lev] else prob[, 1])
}

# Weights and bias of a linear-kernel model. Returns NULL for multi-class
# models, where `coefs` has one column per pairwise classifier.
.svm_linear_weights <- function(object) {
  fit <- object$model
  weights <- rep(0, length(object$evar))
  bias <- 0
  if (is.null(fit$terms) || is.null(fit$coefs) || is.null(fit$SV)) {
    return(list(weights = weights, bias = bias))
  }

  coef_mat <- as.matrix(fit$coefs)
  sv_mat <- as.matrix(fit$SV)
  if (ncol(coef_mat) != 1) {
    return(NULL)
  }

  if (nrow(coef_mat) > 0 && nrow(sv_mat) > 0) {
    w_full <- as.numeric(t(coef_mat) %*% sv_mat)
    sv_names <- colnames(sv_mat)
    if (!is.null(sv_names)) {
      ## match on the variable name
      for (i in seq_along(object$evar)) {
        hits <- which(sv_names == object$evar[i])
        if (length(hits) > 0) weights[i] <- sum(w_full[hits])
      }
    } else {
      ## no column names: match by position
      k <- min(length(object$evar), length(w_full))
      weights[seq_len(k)] <- w_full[seq_len(k)]
    }
    bias <- as.numeric(-fit$rho)
  }
  list(weights = weights, bias = bias)
}

#' Summary method for the svm function
#'
#' Prints the model specification, the number of support vectors and, for a
#' linear kernel, the feature weights and bias.
#'
#' @param object Return value from `svm()`
#' @param prn Print the coefficient section
#' @param ... Not used
#'
#' @return `object` itself when it is a character message; otherwise called
#'   for its printed output
#' @export
summary.svm <- function(object, prn = TRUE, ...) {
  if (is.character(object)) {
    return(object)
  }

  svm_model <- object$model
  wtsname <- object$wtsname
  has_wts <- !is.null(wtsname) && nzchar(wtsname)
  is_classif <- object$type == "classification"

  cat("Support Vector Machine\n")
  cat(sprintf("Kernel type : %s (%s)\n", object$kernel, object$type))
  cat(sprintf("Data : %s\n", object$df_name))
  cat(sprintf("Response variable : %s\n", object$rvar))
  if (is_classif) {
    cat(sprintf("Level : %s in %s\n", object$lev, object$rvar))
  }
  cat(sprintf(
    "Explanatory variables: %s\n", paste(object$evar, collapse = ", ")
  ))
  if (has_wts) {
    cat(sprintf("Weights used : %s\n", wtsname))
  }
  cat(sprintf("Cost (C) : %.2f\n", object$cost))
  if (object$kernel %in% c("radial", "poly", "sigmoid")) {
    cat(sprintf("Gamma : %.2f\n", object$gamma))
  }
  if (!is.na(object$seed)) cat(sprintf("Seed : %s\n", object$seed))

  ## Support vector counts
  if (is_classif) {
    n_sv <- if (!is.null(svm_model$nSV)) svm_model$nSV else c(0, 0)
    sv_info <- paste(
      sprintf("class %s: %d", seq_along(n_sv), n_sv),
      collapse = ", "
    )
    cat(sprintf("Support vectors : %d (%s)\n", sum(n_sv), sv_info))
  } else {
    total_sv <- if (!is.null(svm_model$index)) length(svm_model$index) else 0
    cat(sprintf("Support vectors : %d\n", total_sv))
  }

  ## Number of observations: the sum of the weights when weights are used
  nr_obs <- nrow(object$dataset)
  if (has_wts && wtsname %in% names(object$dataset)) {
    wts_values <- as.numeric(object$dataset[[wtsname]])
    wts_values <- wts_values[!is.na(wts_values)]
    if (length(wts_values) > 0 && sum(wts_values) > 0) {
      nr_obs <- sum(wts_values)
    }
  }
  cat(sprintf("Nr_obs : %s\n", format_nr(nr_obs, dec = 0)))

  if (prn) {
    cat("Coefficients/Support Vectors:\n")
    if (object$kernel == "linear") {
      lin <- tryCatch(
        .svm_linear_weights(object),
        error = function(e) {
          warning("系数提取失败: ", e$message, call. = FALSE)
          list(weights = rep(0, length(object$evar)), bias = 0)
        }
      )
      if (is.null(lin)) {
        cat(" Linear coefficients are only reported for regression and two-class models\n")
      } else {
        coef_data <- data.frame(
          Variable = c(object$evar, "bias"),
          Value = c(lin$weights, lin$bias),
          stringsAsFactors = FALSE
        )
        ## two name/value pairs per output line
        for (i in seq(1, nrow(coef_data), 2)) {
          if (i + 1 > nrow(coef_data)) {
            cat(sprintf(
              " %-12s: %.2f\n",
              coef_data$Variable[i], coef_data$Value[i]
            ))
          } else {
            cat(sprintf(
              " %-12s: %.2f %-12s: %.2f\n",
              coef_data$Variable[i], coef_data$Value[i],
              coef_data$Variable[i + 1], coef_data$Value[i + 1]
            ))
          }
        }
      }
    } else {
      cat(" Non-linear kernel: Coefficients not available\n")
    }
  }
}

#' Plot method for the svm function
#'
#' @param x Return value from `svm()`
#' @param plots Any of "decision_boundary", "margin", "vip"
#' @param size Base font size
#' @param shiny If TRUE return the combined plot without printing it
#' @param custom Kept for interface compatibility. The individual plots are
#'   always requested as objects so each is drawn once, as part of the
#'   combined figure.
#' @param ... Not used
#'
#' @return A patchwork object, `x` when it is not a fitted svm, or a
#'   character message when no plot could be produced
#' @export
plot.svm <- function(x, plots = "none", size = 12,
                     shiny = FALSE, custom = FALSE, ...) {
  if (is.character(x) || !inherits(x, "svm")) {
    return(x)
  }
  plot_list <- list()

  if ("decision_boundary" %in% plots) {
    if (length(x$evar) != 2) {
      warning("Decision boundary plot requires exactly 2 explanatory variables")
    } else if (x$type != "classification") {
      warning("Decision boundary plot only available for classification")
    } else if (nlevels(x$dataset[[x$rvar]]) != 2) {
      warning("Decision boundary plot only available for binary classification")
    } else {
      plot_list[["decision_boundary"]] <-
        svm_boundary_plot(x, size = size, custom = TRUE)
    }
  }

  if ("margin" %in% plots) {
    if (length(x$evar) < 2) {
      warning("Support vector plot requires at least 2 explanatory variables")
    } else {
      plot_list[["margin"]] <- svm_margin_plot(x, size = size, custom = TRUE)
    }
  }

  if ("vip" %in% plots) {
    if (length(x$evar) < 2) {
      warning("Variable importance needs at least 2 explanatory variables")
    } else {
      plot_list[["vip"]] <- svm_vip_plot(x, size = size, custom = TRUE)
    }
  }

  if (length(plot_list) == 0) {
    return("No valid plots selected for SVM")
  }

  ## Combine into one patchwork object
  patchwork::wrap_plots(plot_list, ncol = min(2, length(plot_list))) %>%
    {
      if (isTRUE(shiny)) . else print(.)
    }
}

#' Predict method for the svm function
#'
#' @param object Return value from `svm()`
#' @param pred_data Data frame, or name of a data frame in `envir`
#' @param pred_cmd Prediction command, handed to `predict_model()`
#' @param dec Number of decimals for the predictions
#' @param envir Environment in which a named `pred_data` is looked up
#' @param ... Not used
#'
#' @return A data frame of class "svm.predict" whose `Prediction` column is
#'   the probability of `object$lev` (classification) or the predicted value
#'   (regression); a character message or a list of class
#'   "svm.predict.error" when no prediction could be made
#' @export
predict.svm <- function(object, pred_data = NULL, pred_cmd = "",
                        dec = 3, envir = parent.frame(), ...) {
  ## substitute() only sees the caller's expression while `pred_data` has
  ## not been reassigned, so it is captured first
  pred_expr <- substitute(pred_data)

  # 1. Basic checks
  if (is.character(object)) {
    return(object)
  }
  if (!inherits(object, "svm")) stop("Object must be of class 'svm'")
  if (is.null(object$model) || !inherits(object$model, "svm")) {
    err_msg <- "Prediction failed: Invalid SVM model"
    return(structure(list(error = err_msg), class = "svm.predict.error"))
  }
  svm_info <- object

  # 2. Prediction data
  if (is.character(pred_data) && length(pred_data) == 1 &&
    nzchar(pred_data)) {
    df_name <- pred_data
    if (!exists(pred_data, envir = envir)) {
      return(structure(
        list(error = sprintf("Dataset '%s' not found", pred_data)),
        class = "svm.predict.error"
      ))
    }
    pred_data <- get(pred_data, envir = envir)
  } else if (is.data.frame(pred_data)) {
    df_name <- paste(deparse(pred_expr), collapse = "")
  } else {
    df_name <- pred_data
  }

  has_data <- !is.null(pred_data) && is.data.frame(pred_data) &&
    nrow(pred_data) > 0
  has_cmd <- is.character(pred_cmd) && nzchar(pred_cmd)
  if (!has_data && !has_cmd) {
    return(structure(
      list(error = "Please select data and/or specify a prediction command."),
      class = "svm.predict.error"
    ))
  }

  # 3. Prediction function handed to predict_model()
  pfun <- function(model, pred, se, conf_lev) {
    scores <- .svm_score(svm_info, pred)
    if (is.character(scores)) {
      return(scores)
    }
    data.frame(Prediction = round(scores, dec))
  }

  # 4. Radiant-style prediction
  result <- predict_model(
    object, pfun, "svm.predict",
    pred_data, pred_cmd,
    conf_lev = 0.95, se = FALSE, dec, envir = envir
  ) %>% set_attr("radiant_pred_data", df_name)

  if (inherits(result, "svm.predict.error")) {
    return(result$error)
  }
  if (is.character(result)) {
    return(result)
  }

  result <- add_class(result, "svm.predict")
  attr(result, "svm_meta") <- list(
    model_type = svm_info$type,
    kernel = svm_info$kernel,
    cost = svm_info$cost,
    gamma = if (svm_info$kernel != "linear") svm_info$gamma else NA,
    seed = svm_info$seed,
    train_data = svm_info$df_name,
    dec = dec
  )
  result
}

#' Print method for predict.svm
#'
#' @param x Return value from `predict.svm()`
#' @param ... Not used
#' @param n Number of rows to show; a negative value shows all rows
#'
#' @return `x`, invisibly
#' @export
print.svm.predict <- function(x, ..., n = 10) {
  if (is.character(x)) {
    cat(x, "\n")
    return(invisible(x))
  }
  if (!inherits(x, "svm.predict")) {
    stop("Object must be of class 'svm.predict'")
  }

  meta <- attr(x, "svm_meta")
  if (is.null(meta)) {
    meta <- list(
      model_type = "Unknown",
      kernel = "Unknown",
      cost = NA,
      gamma = NA,
      seed = NA,
      train_data = "Unknown",
      dec = 1
    )
  }
  dec <- meta$dec %||% 1

  n_pred <- nrow(x)
  show_n <- if (n < 0 || n >= n_pred) n_pred else n

  cat("SVM Predictions\n")
  model_type <- if (is.empty(meta$model_type)) {
    "Unknown"
  } else {
    as.character(meta$model_type)
  }
  cat(sprintf("Model Type : %s\n", tools::toTitleCase(model_type)))
  cat(sprintf(
    "Kernel : %s\n",
    if (is.empty(meta$kernel)) "Unknown" else meta$kernel
  ))
  cat(sprintf("Cost (C) : %.2f\n", if (is.na(meta$cost)) 0 else meta$cost))
  if (!is.na(meta$gamma)) cat(sprintf("Gamma : %.2f\n", meta$gamma))
  if (!is.na(meta$seed)) cat(sprintf("Seed : %s\n", meta$seed))
  cat(sprintf(
    "Training Dataset : %s\n",
    if (is.empty(meta$train_data)) "Unknown" else meta$train_data
  ))
  cat(sprintf("Total Predictions : %d\n", n_pred))

  if (n_pred == 0) {
    cat("No predictions generated.\n")
    return(invisible(x))
  }

  x_show <- x[seq_len(show_n), , drop = FALSE]
  col_names <- names(x_show)

  ## Every column is rendered as text up front, so factors print their
  ## labels and numbers use `dec` decimals
  cells <- lapply(x_show, function(col) {
    txt <- if (is.numeric(col)) {
      sprintf(paste0("%.", dec, "f"), col)
    } else {
      as.character(col)
    }
    txt[is.na(txt)] <- "NA"
    txt
  })

  widths <- vapply(seq_along(cells), function(j) {
    max(nchar(col_names[j]), nchar(cells[[j]]))
  }, integer(1))
  fmt <- paste(paste0("%-", widths, "s"), collapse = " ")

  cat(do.call(sprintf, c(list(fmt), as.list(col_names))), "\n")
  cat(strrep("-", sum(widths) + length(widths) - 1), "\n")
  for (i in seq_len(show_n)) {
    row_vals <- unname(lapply(cells, function(txt) txt[[i]]))
    cat(do.call(sprintf, c(list(fmt), row_vals)), "\n")
  }

  if (show_n < n_pred) {
    cat(sprintf("\n... (showing first %d of %d)\n", show_n, n_pred))
  }

  invisible(x)
}

#' Decision boundary plot for a model with two explanatory variables
#'
#' The background shows the decision value (classification) or the predicted
#' value (otherwise) on a grid spanning the training data.
#'
#' @param object Return value from `svm()`
#' @param size Base font size
#' @param custom If TRUE return the plot without printing it
#' @export
svm_boundary_plot <- function(object, size, custom) {
  df <- object$dataset
  rvar <- object$rvar
  f1 <- object$evar[1]
  f2 <- object$evar[2]

  ## 1. Grid: levels for factors, an even sequence for numeric variables
  axis_values <- function(col) {
    if (is.factor(col)) {
      levels(col)
    } else {
      seq(min(col, na.rm = TRUE), max(col, na.rm = TRUE), length.out = 200)
    }
  }
  grid <- expand.grid(f1 = axis_values(df[[f1]]), f2 = axis_values(df[[f2]]))
  names(grid) <- c(f1, f2)

  ## 2. Predict on the grid; `df` is already on the training scale
  is_classif <- object$type == "classification"
  pred <- .svm_e1071_predict(
    object$model, grid,
    decision.values = is_classif,
    na.action = stats::na.pass
  )
  dec_values <- attr(pred, "decision.values")
  grid$pred <- if (is_classif && !is.null(dec_values)) {
    as.numeric(dec_values[, 1])
  } else {
    as.numeric(pred)
  }

  ## 3. Plot
  p <- ggplot(df, aes(x = .data[[f1]], y = .data[[f2]])) +
    geom_tile(data = grid, aes(fill = .data[["pred"]]), alpha = 0.65) +
    geom_point(aes(color = .data[[rvar]]), size = 2) +
    scale_fill_gradient(
      low = "#A4C4FF", high = "#FF9A9A", na.value = "grey90"
    ) +
    labs(
      title = "SVM Decision Boundary",
      x = f1, y = f2, fill = "Score", color = rvar
    ) +
    theme_gray(base_size = size) +
    theme(legend.position = "bottom")

  if (custom) p else print(p)
}

#' Scatter plot of the first two explanatory variables with the support
#' vectors highlighted
#'
#' @param object Return value from `svm()`
#' @param size Base font size
#' @param custom If TRUE return the plot without printing it
#' @export
svm_margin_plot <- function(object, size, custom) {
  mod <- object$model
  df <- object$dataset
  f1 <- object$evar[1]
  f2 <- object$evar[2]

  ## `mod$index` holds the training rows that are support vectors. Both
  ## levels are always defined so the scales work if only one group occurs.
  df$sv <- factor(
    seq_len(nrow(df)) %in% mod$index,
    levels = c(FALSE, TRUE),
    labels = c("Ordinary", "Support Vector")
  )

  p <- ggplot(df, aes(x = .data[[f1]], y = .data[[f2]])) +
    geom_point(
      aes(color = .data[["sv"]], shape = .data[["sv"]]),
      size = 2, alpha = 0.8
    ) +
    scale_color_manual(
      values = c("Ordinary" = "grey60", "Support Vector" = "red")
    ) +
    scale_shape_manual(values = c("Ordinary" = 16, "Support Vector" = 17)) +
    labs(title = "Support Vectors & Margin", color = NULL, shape = NULL) +
    theme_gray(base_size = size) +
    theme(legend.position = "bottom")

  if (custom) p else print(p)
}

# Performance measure used for permutation importance: R-squared for
# regression, AUC for a two-class response (accuracy when the AUC cannot be
# computed) and one-vs-rest accuracy at a 0.5 cut-off otherwise.
.svm_metric <- function(object, y, scores, lev) {
  if (object$type != "classification") {
    return(
      1 - sum((scores - y)^2, na.rm = TRUE) /
        sum((y - mean(y, na.rm = TRUE))^2, na.rm = TRUE)
    )
  }

  is_lev <- as.character(y) == lev
  accuracy <- mean((scores >= 0.5) == is_lev, na.rm = TRUE)
  if (nlevels(as.factor(y)) != 2) {
    return(accuracy)
  }

  ## A fixed direction keeps a model that got worse from being scored as if
  ## its predictions had been reversed
  auc_val <- tryCatch(
    as.numeric(pROC::roc(
      response = as.numeric(is_lev), predictor = scores,
      levels = c(0, 1), direction = "<", quiet = TRUE
    )$auc[[1]]),
    error = function(e) NA_real_
  )
  if (is.na(auc_val)) accuracy else auc_val
}

#' Variable importance for SVM using permutation importance
#'
#' Each explanatory variable is shuffled `nperm` times; its importance is
#' the average drop in performance relative to the unshuffled data, with
#' negative values set to 0.
#'
#' @param object Return value from `svm()`
#' @param rvar Response variable; defaults to the one used in `object`
#' @param lev Response level that counts as "success"; defaults to
#'   `object$lev`
#' @param data Data to evaluate on; defaults to the training data stored in
#'   `object`
#' @param seed Random seed; `NA` leaves the RNG untouched
#' @param nperm Number of permutations per variable
#'
#' @return Data frame with columns `Variable` and `Importance`, sorted by
#'   decreasing importance
#' @export
varimp <- function(object, rvar = NULL, lev = NULL, data = NULL,
                   seed = 1234, nperm = 10) {
  if (!inherits(object, "svm")) {
    stop("Object must be of class 'svm'")
  }

  if (is.null(data)) data <- object$dataset
  if (is.null(rvar)) rvar <- object$rvar
  if (is.null(lev) && object$type == "classification") lev <- object$lev

  ## Data carrying the training standardization parameters (such as
  ## `object$dataset`) is already on the model scale and must not be
  ## standardized again
  train_ms <- attr(object$dataset, "radiant_ms")
  train_sds <- attr(object$dataset, "radiant_sds")
  standardized <- !is.null(train_ms) &&
    identical(attr(data, "radiant_ms"), train_ms)

  y <- data[[rvar]]
  ## A regression model trained on a standardized response predicts in
  ## standardized units, so a raw response is put on the same scale
  if (object$type != "classification" && !standardized &&
    !is.null(train_sds) && rvar %in% names(train_ms) &&
    rvar %in% names(train_sds)) {
    y <- (as.numeric(y) - train_ms[[rvar]]) / train_sds[[rvar]]
  }

  if (!is.na(seed)) set.seed(seed)

  ## Baseline performance
  base_scores <- .svm_score(object, data, standardized)
  if (is.character(base_scores)) {
    stop("基准预测失败:", base_scores)
  }
  base_metric <- .svm_metric(object, y, base_scores, lev)

  ## Permutation importance per variable
  importance_scores <- vapply(object$evar, function(var) {
    drops <- rep(NA_real_, nperm)
    for (i in seq_len(nperm)) {
      perm_data <- data
      ## index-based shuffle: sample(x) treats a single number as 1:x
      shuffled <- sample.int(nrow(perm_data))
      perm_data[[var]] <- perm_data[[var]][shuffled]

      perm_scores <- .svm_score(object, perm_data, standardized)
      if (is.character(perm_scores)) next
      ## baseline minus permuted: larger means more important
      drops[i] <- base_metric - .svm_metric(object, y, perm_scores, lev)
    }
    mean(drops, na.rm = TRUE)
  }, numeric(1))

  result <- data.frame(
    Variable = as.character(names(importance_scores)),
    Importance = as.numeric(pmax(importance_scores, 0)),
    stringsAsFactors = FALSE
  )
  result <- result[order(-result$Importance), , drop = FALSE]
  rownames(result) <- NULL
  result
}

#' Permutation importance plot for the svm function
#'
#' @param object Return value from `svm()`
#' @param size Base font size
#' @param custom If TRUE return the plot without printing it
#'
#' @return A ggplot object (invisibly when it is printed). When importance
#'   cannot be computed the plot shows the error message instead.
#' @export
svm_vip_plot <- function(object, size, custom) {
  ## The plot is the VALUE of tryCatch(); an assignment made inside the
  ## error handler would not be visible here
  p <- tryCatch(
    {
      vi_scores <- varimp(
        object,
        rvar = object$rvar,
        lev = if (object$type == "classification") object$lev else NULL,
        data = object$dataset,
        seed = 1234
      )
      vi_scores$Importance <- as.numeric(pmax(vi_scores$Importance, 0))

      if (nrow(vi_scores) == 0 || all(is.na(vi_scores$Importance))) {
        ggplot() +
          annotate(
            "text",
            x = 0.5, y = 0.5,
            label = "Could not compute variable importance\n(check model/data)",
            size = 5, color = "red"
          ) +
          theme_void()
      } else {
        y_label <- if (object$type == "regression") {
          "Importance (R² decrease)"
        } else {
          "Importance (Performance decrease)"
        }
        ggplot(
          vi_scores,
          aes(x = reorder(Variable, Importance), y = Importance)
        ) +
          geom_col(fill = "#377eb8") +
          coord_flip() +
          labs(
            title = "SVM Variable Importance (Permutation)",
            x = NULL,
            y = y_label
          ) +
          theme_gray(base_size = size) +
          theme(
            axis.text.y = element_text(hjust = 0),
            panel.grid.major.y = element_line(color = "grey90"),
            panel.grid.minor = element_blank()
          )
      }
    },
    error = function(e) {
      ggplot() +
        annotate(
          "text",
          x = 0.5, y = 0.5,
          label = paste("Error calculating importance:\n", e$message),
          size = 4, color = "red"
        ) +
        theme_void()
    }
  )

  if (custom) {
    return(p)
  }
  print(p)
  invisible(p)
}