diff --git a/radiant.model/NAMESPACE b/radiant.model/NAMESPACE index d1cd23715c42f15c42c1720eb7ec3905d998962b..a506c5171a4d459b54f38e866ccf4454aad61045 100644 --- a/radiant.model/NAMESPACE +++ b/radiant.model/NAMESPACE @@ -1,277 +1,279 @@ -# Generated by roxygen2: do not edit by hand - -S3method(plot,confusion) -S3method(plot,coxp) -S3method(plot,crs) -S3method(plot,crtree) -S3method(plot,dtree) -S3method(plot,evalbin) -S3method(plot,evalreg) -S3method(plot,gbt) -S3method(plot,logistic) -S3method(plot,mnl) -S3method(plot,mnl.predict) -S3method(plot,model.predict) -S3method(plot,nb) -S3method(plot,nb.predict) -S3method(plot,nn) -S3method(plot,regress) -S3method(plot,repeater) -S3method(plot,rforest) -S3method(plot,rforest.predict) -S3method(plot,simulater) -S3method(plot,svm) -S3method(plot,uplift) -S3method(predict,coxp) -S3method(predict,crtree) -S3method(predict,gbt) -S3method(predict,logistic) -S3method(predict,mnl) -S3method(predict,nb) -S3method(predict,nn) -S3method(predict,regress) -S3method(predict,rforest) -S3method(predict,svm) -S3method(print,coxp.predict) -S3method(print,crtree.predict) -S3method(print,gbt.predict) -S3method(print,logistic.predict) -S3method(print,mnl.predict) -S3method(print,nb.predict) -S3method(print,nn.predict) -S3method(print,regress.predict) -S3method(print,rforest.predict) -S3method(print,svm.predict) -S3method(render,DiagrammeR) -S3method(sensitivity,dtree) -S3method(store,coxp.predict) -S3method(store,crs) -S3method(store,mnl.predict) -S3method(store,model) -S3method(store,model.predict) -S3method(store,nb.predict) -S3method(store,rforest.predict) -S3method(summary,confusion) -S3method(summary,coxp) -S3method(summary,crs) -S3method(summary,crtree) -S3method(summary,dtree) -S3method(summary,evalbin) -S3method(summary,evalreg) -S3method(summary,gbt) -S3method(summary,logistic) -S3method(summary,mnl) -S3method(summary,nb) -S3method(summary,nn) -S3method(summary,regress) -S3method(summary,repeater) -S3method(summary,rforest) -S3method(summary,simulater) -S3method(summary,svm) -S3method(summary,uplift) -export(.as_int) -export(.as_num) -export(MAE) -export(RMSE) -export(Rsq) -export(ann) -export(auc) -export(confint_robust) -export(confusion) -export(coxp) -export(crs) -export(crtree) -export(cv.crtree) -export(cv.gbt) -export(cv.nn) -export(cv.rforest) -export(cv.svm) -export(dtree) -export(dtree_parser) -export(evalbin) -export(evalreg) -export(find_max) -export(find_min) -export(gbt) -export(logistic) -export(minmax) -export(mnl) -export(nb) -export(nn) -export(onehot) -export(pdp_plot) -export(pred_plot) -export(predict_model) -export(print_predict_model) -export(profit) -export(radiant.model) -export(radiant.model_viewer) -export(radiant.model_window) -export(regress) -export(remove_comments) -export(repeater) -export(rforest) -export(rig) -export(scale_df) -export(sdw) -export(sensitivity) -export(sim_cleaner) -export(sim_cor) -export(sim_splitter) -export(sim_summary) -export(simulater) -export(svm) -export(test_specs) -export(uplift) -export(var_check) -export(varimp) -export(varimp_plot) -export(write.coeff) -import(ggplot2) -import(radiant.data) -import(shiny) -importFrom(DiagrammeR,DiagrammeR) -importFrom(DiagrammeR,DiagrammeROutput) -importFrom(DiagrammeR,mermaid) -importFrom(DiagrammeR,renderDiagrammeR) -importFrom(NeuralNetTools,garson) -importFrom(NeuralNetTools,olden) -importFrom(NeuralNetTools,plotnet) -importFrom(broom,augment) -importFrom(car,linearHypothesis) -importFrom(car,vif) -importFrom(data.tree,Clone) -importFrom(data.tree,FormatPercent) -importFrom(data.tree,Get) -importFrom(data.tree,Traverse) -importFrom(data.tree,as.Node) -importFrom(data.tree,isLeaf) -importFrom(data.tree,isNotLeaf) -importFrom(data.tree,isNotRoot) -importFrom(dplyr,across) -importFrom(dplyr,arrange) -importFrom(dplyr,arrange_at) -importFrom(dplyr,bind_cols) -importFrom(dplyr,bind_rows) -importFrom(dplyr,data_frame) -importFrom(dplyr,desc) -importFrom(dplyr,distinct_at) -importFrom(dplyr,everything) -importFrom(dplyr,filter) -importFrom(dplyr,first) -importFrom(dplyr,funs) -importFrom(dplyr,group_by) -importFrom(dplyr,group_by_) -importFrom(dplyr,group_by_at) -importFrom(dplyr,inner_join) -importFrom(dplyr,last) -importFrom(dplyr,min_rank) -importFrom(dplyr,mutate) -importFrom(dplyr,mutate_) -importFrom(dplyr,mutate_all) -importFrom(dplyr,mutate_at) -importFrom(dplyr,mutate_if) -importFrom(dplyr,near) -importFrom(dplyr,pull) -importFrom(dplyr,rename) -importFrom(dplyr,sample_n) -importFrom(dplyr,select) -importFrom(dplyr,select_at) -importFrom(dplyr,slice) -importFrom(dplyr,summarise) -importFrom(dplyr,summarise_) -importFrom(dplyr,summarise_all) -importFrom(dplyr,summarise_at) -importFrom(dplyr,summarize) -importFrom(dplyr,ungroup) -importFrom(e1071,naiveBayes) -importFrom(ggplot2,autoplot) -importFrom(ggrepel,geom_text_repel) -importFrom(graphics,par) -importFrom(import,from) -importFrom(lubridate,is.Date) -importFrom(lubridate,now) -importFrom(magrittr,"%<>%") -importFrom(magrittr,"%>%") -importFrom(magrittr,"%T>%") -importFrom(magrittr,extract2) -importFrom(magrittr,set_colnames) -importFrom(magrittr,set_names) -importFrom(magrittr,set_rownames) -importFrom(nnet,nnet) -importFrom(nnet,nnet.formula) -importFrom(patchwork,plot_annotation) -importFrom(patchwork,wrap_plots) -importFrom(pdp,partial) -importFrom(psych,cohen.kappa) -importFrom(radiant.data,launch) -importFrom(radiant.data,set_attr) -importFrom(radiant.data,visualize) -importFrom(ranger,ranger) -importFrom(rlang,":=") -importFrom(rlang,.data) -importFrom(rlang,parse_exprs) -importFrom(rpart,prune.rpart) -importFrom(rpart,rpart) -importFrom(rpart,rpart.control) -importFrom(sandwich,vcovHC) -importFrom(scales,percent) -importFrom(shiny,getDefaultReactiveDomain) -importFrom(shiny,incProgress) -importFrom(shiny,withProgress) -importFrom(stats,anova) -importFrom(stats,as.formula) -importFrom(stats,binomial) -importFrom(stats,coef) -importFrom(stats,confint) -importFrom(stats,confint.default) -importFrom(stats,contrasts) -importFrom(stats,cor) -importFrom(stats,deviance) -importFrom(stats,dnorm) -importFrom(stats,family) -importFrom(stats,formula) -importFrom(stats,glm) -importFrom(stats,lm) -importFrom(stats,logLik) -importFrom(stats,median) -importFrom(stats,model.frame) -importFrom(stats,model.matrix) -importFrom(stats,na.omit) -importFrom(stats,pnorm) -importFrom(stats,predict) -importFrom(stats,pt) -importFrom(stats,qnorm) -importFrom(stats,qt) -importFrom(stats,quantile) -importFrom(stats,rbinom) -importFrom(stats,relevel) -importFrom(stats,residuals) -importFrom(stats,rlnorm) -importFrom(stats,rnorm) -importFrom(stats,rpois) -importFrom(stats,runif) -importFrom(stats,sd) -importFrom(stats,setNames) -importFrom(stats,step) -importFrom(stats,terms) -importFrom(stats,terms.formula) -importFrom(stats,update) -importFrom(stats,weighted.mean) -importFrom(stats,wilcox.test) -importFrom(stringi,stri_trans_general) -importFrom(stringr,str_match) -importFrom(tidyr,gather) -importFrom(tidyr,spread) -importFrom(tidyselect,where) -importFrom(utils,as.relistable) -importFrom(utils,capture.output) -importFrom(utils,combn) -importFrom(utils,head) -importFrom(utils,relist) -importFrom(utils,tail) -importFrom(utils,write.table) -importFrom(vip,vi) -importFrom(xgboost,xgb.importance) -importFrom(xgboost,xgboost) -importFrom(yaml,yaml.load) +# Generated by roxygen2: do not edit by hand + +S3method(plot,confusion) +S3method(plot,coxp) +S3method(plot,crs) +S3method(plot,crtree) +S3method(plot,dtree) +S3method(plot,evalbin) +S3method(plot,evalreg) +S3method(plot,gbt) +S3method(plot,logistic) +S3method(plot,mnl) +S3method(plot,mnl.predict) +S3method(plot,model.predict) +S3method(plot,nb) +S3method(plot,nb.predict) +S3method(plot,nn) +S3method(plot,regress) +S3method(plot,repeater) +S3method(plot,rforest) +S3method(plot,rforest.predict) +S3method(plot,simulater) +S3method(plot,svm) +S3method(plot,uplift) +S3method(predict,coxp) +S3method(predict,crtree) +S3method(predict,gbt) +S3method(predict,logistic) +S3method(predict,mnl) +S3method(predict,nb) +S3method(predict,nn) +S3method(predict,regress) +S3method(predict,rforest) +S3method(predict,svm) +S3method(print,coxp.predict) +S3method(print,crtree.predict) +S3method(print,gbt.predict) +S3method(print,logistic.predict) +S3method(print,mnl.predict) +S3method(print,nb.predict) +S3method(print,nn.predict) +S3method(print,regress.predict) +S3method(print,rforest.predict) +S3method(print,svm.predict) +S3method(render,DiagrammeR) +S3method(sensitivity,dtree) +S3method(store,coxp.predict) +S3method(store,crs) +S3method(store,mnl.predict) +S3method(store,model) +S3method(store,model.predict) +S3method(store,nb.predict) +S3method(store,rforest.predict) +S3method(summary,confusion) +S3method(summary,coxp) +S3method(summary,crs) +S3method(summary,crtree) +S3method(summary,dtree) +S3method(summary,evalbin) +S3method(summary,evalreg) +S3method(summary,gbt) +S3method(summary,logistic) +S3method(summary,mnl) +S3method(summary,nb) +S3method(summary,nn) +S3method(summary,regress) +S3method(summary,repeater) +S3method(summary,rforest) +S3method(summary,simulater) +S3method(summary,svm) +S3method(summary,uplift) +export(.as_int) +export(.as_num) +export(MAE) +export(RMSE) +export(Rsq) +export(ann) +export(auc) +export(confint_robust) +export(confusion) +export(coxp) +export(crs) +export(crtree) +export(cv.crtree) +export(cv.gbt) +export(cv.nn) +export(cv.rforest) +export(dtree) +export(dtree_parser) +export(evalbin) +export(evalreg) +export(find_max) +export(find_min) +export(gbt) +export(logistic) +export(minmax) +export(mnl) +export(nb) +export(nn) +export(onehot) +export(pdp_plot) +export(pred_plot) +export(predict_model) +export(print_predict_model) +export(profit) +export(radiant.model) +export(radiant.model_viewer) +export(radiant.model_window) +export(regress) +export(remove_comments) +export(repeater) +export(rforest) +export(rig) +export(scale_df) +export(sdw) +export(sensitivity) +export(sim_cleaner) +export(sim_cor) +export(sim_splitter) +export(sim_summary) +export(simulater) +export(svm) +export(svm_boundary_plot) +export(svm_margin_plot) +export(svm_vip_plot) +export(test_specs) +export(uplift) +export(var_check) +export(varimp) +export(varimp_plot) +export(write.coeff) +import(ggplot2) +import(radiant.data) +import(shiny) +importFrom(DiagrammeR,DiagrammeR) +importFrom(DiagrammeR,DiagrammeROutput) +importFrom(DiagrammeR,mermaid) +importFrom(DiagrammeR,renderDiagrammeR) +importFrom(NeuralNetTools,garson) +importFrom(NeuralNetTools,olden) +importFrom(NeuralNetTools,plotnet) +importFrom(broom,augment) +importFrom(car,linearHypothesis) +importFrom(car,vif) +importFrom(data.tree,Clone) +importFrom(data.tree,FormatPercent) +importFrom(data.tree,Get) +importFrom(data.tree,Traverse) +importFrom(data.tree,as.Node) +importFrom(data.tree,isLeaf) +importFrom(data.tree,isNotLeaf) +importFrom(data.tree,isNotRoot) +importFrom(dplyr,across) +importFrom(dplyr,arrange) +importFrom(dplyr,arrange_at) +importFrom(dplyr,bind_cols) +importFrom(dplyr,bind_rows) +importFrom(dplyr,data_frame) +importFrom(dplyr,desc) +importFrom(dplyr,distinct_at) +importFrom(dplyr,everything) +importFrom(dplyr,filter) +importFrom(dplyr,first) +importFrom(dplyr,funs) +importFrom(dplyr,group_by) +importFrom(dplyr,group_by_) +importFrom(dplyr,group_by_at) +importFrom(dplyr,inner_join) +importFrom(dplyr,last) +importFrom(dplyr,min_rank) +importFrom(dplyr,mutate) +importFrom(dplyr,mutate_) +importFrom(dplyr,mutate_all) +importFrom(dplyr,mutate_at) +importFrom(dplyr,mutate_if) +importFrom(dplyr,near) +importFrom(dplyr,pull) +importFrom(dplyr,rename) +importFrom(dplyr,sample_n) +importFrom(dplyr,select) +importFrom(dplyr,select_at) +importFrom(dplyr,slice) +importFrom(dplyr,summarise) +importFrom(dplyr,summarise_) +importFrom(dplyr,summarise_all) +importFrom(dplyr,summarise_at) +importFrom(dplyr,summarize) +importFrom(dplyr,ungroup) +importFrom(e1071,naiveBayes) +importFrom(ggplot2,autoplot) +importFrom(ggrepel,geom_text_repel) +importFrom(graphics,par) +importFrom(import,from) +importFrom(lubridate,is.Date) +importFrom(lubridate,now) +importFrom(magrittr,"%<>%") +importFrom(magrittr,"%>%") +importFrom(magrittr,"%T>%") +importFrom(magrittr,extract2) +importFrom(magrittr,set_colnames) +importFrom(magrittr,set_names) +importFrom(magrittr,set_rownames) +importFrom(nnet,nnet) +importFrom(nnet,nnet.formula) +importFrom(patchwork,plot_annotation) +importFrom(patchwork,wrap_plots) +importFrom(pdp,partial) +importFrom(psych,cohen.kappa) +importFrom(radiant.data,launch) +importFrom(radiant.data,set_attr) +importFrom(radiant.data,visualize) +importFrom(ranger,ranger) +importFrom(rlang,":=") +importFrom(rlang,.data) +importFrom(rlang,parse_exprs) +importFrom(rpart,prune.rpart) +importFrom(rpart,rpart) +importFrom(rpart,rpart.control) +importFrom(sandwich,vcovHC) +importFrom(scales,percent) +importFrom(shiny,getDefaultReactiveDomain) +importFrom(shiny,incProgress) +importFrom(shiny,withProgress) +importFrom(stats,anova) +importFrom(stats,as.formula) +importFrom(stats,binomial) +importFrom(stats,coef) +importFrom(stats,confint) +importFrom(stats,confint.default) +importFrom(stats,contrasts) +importFrom(stats,cor) +importFrom(stats,deviance) +importFrom(stats,dnorm) +importFrom(stats,family) +importFrom(stats,formula) +importFrom(stats,glm) +importFrom(stats,lm) +importFrom(stats,logLik) +importFrom(stats,median) +importFrom(stats,model.frame) +importFrom(stats,model.matrix) +importFrom(stats,na.omit) +importFrom(stats,pnorm) +importFrom(stats,predict) +importFrom(stats,pt) +importFrom(stats,qnorm) +importFrom(stats,qt) +importFrom(stats,quantile) +importFrom(stats,rbinom) +importFrom(stats,relevel) +importFrom(stats,residuals) +importFrom(stats,rlnorm) +importFrom(stats,rnorm) +importFrom(stats,rpois) +importFrom(stats,runif) +importFrom(stats,sd) +importFrom(stats,setNames) +importFrom(stats,step) +importFrom(stats,terms) +importFrom(stats,terms.formula) +importFrom(stats,update) +importFrom(stats,weighted.mean) +importFrom(stats,wilcox.test) +importFrom(stringi,stri_trans_general) +importFrom(stringr,str_match) +importFrom(tidyr,gather) +importFrom(tidyr,spread) +importFrom(tidyselect,where) +importFrom(utils,as.relistable) +importFrom(utils,capture.output) +importFrom(utils,combn) +importFrom(utils,head) +importFrom(utils,relist) +importFrom(utils,tail) +importFrom(utils,write.table) +importFrom(vip,vi) +importFrom(xgboost,xgb.importance) +importFrom(xgboost,xgboost) +importFrom(yaml,yaml.load) diff --git a/radiant.model/R/svm.R b/radiant.model/R/svm.R index cadd0dc2128a426d69b5cb1d86d866a3b3e19a1d..91d4b45292a828f05c2feac9504f52b5faacfb5d 100644 --- a/radiant.model/R/svm.R +++ b/radiant.model/R/svm.R @@ -7,7 +7,7 @@ svm <- function(dataset, rvar, evar, form, data_filter = "", arr = "", rows = NULL, envir = parent.frame()) { - ## ---- 参数合法性检查(SVM特有) ---- + ## ---- 参数合法性检查---- valid_kernels <- c("linear", "radial", "poly", "sigmoid") if (!kernel %in% valid_kernels) { return(paste0("Kernel must be one of: ", paste(valid_kernels, collapse = ", ")) %>% @@ -38,11 +38,32 @@ svm <- function(dataset, rvar, evar, df_name <- if (is_string(dataset)) dataset else deparse(substitute(dataset)) dataset <- get_data(dataset, vars, filt = data_filter, arr = arr, rows = rows, envir = envir) + is_cat <- sapply(dataset, function(x) is.factor(x) || is.character(x) || is.logical(x)) + cat_vars <- names(is_cat)[is_cat] + + if (length(cat_vars) > 0) { + # 2. 字符型转因子(确保后续编码顺序一致) + for (var in cat_vars) { + if (is.character(dataset[[var]])) { + dataset[[var]] <- as.factor(dataset[[var]]) + } + # 3. 因子/逻辑型转数值(按原始水平顺序编码,如level1→1, level2→2...) + dataset[[var]] <- as.numeric(dataset[[var]]) + } + } + + # 缺失值处理 + dataset <- na.omit(dataset) + if (nrow(dataset) == 0) { + return("No valid samples after removing missing values. Please check your data." %>% add_class("svm")) + } + + # 标准化 if ("standardize" %in% check) { dataset <- scale_df(dataset) } - ## ---- 3. 分类任务的响应变量(转为因子) ---- + ## ---- 3. 分类任务的响应变量 ---- if (type == "classification") { dataset[[rvar]] <- as.factor(dataset[[rvar]]) if (lev == "") { @@ -71,10 +92,19 @@ svm <- function(dataset, rvar, evar, if (!is.na(seed)) set.seed(seed) - ## ---- 6. 训练模型 ---- - model <- do.call(e1071::svm, svm_input) + ## ---- 模型训练---- + model <- try({ + do.call(e1071::svm, svm_input) + }, silent = TRUE) - ## ---- 7. 附加关键信息 ---- + if (inherits(model, "try-error")) { + return(paste("Model training failed:", attr(model, "condition")$message) %>% add_class("svm")) + } + if (is.null(model) || !inherits(model, "svm")) { + return("Model training failed: Generated SVM model is invalid" %>% add_class("svm")) + } + + ## ---- 附加关键信息---- model$df_name <- df_name model$rvar <- rvar model$evar <- evar @@ -86,7 +116,21 @@ svm <- function(dataset, rvar, evar, model$gamma <- gamma model$kernel <- kernel - as.list(environment()) %>% add_class(c("svm", "model")) + ## ---- 返回模型对象---- + list( + model = model, + df_name = df_name, + rvar = rvar, + evar = evar, + type = type, + lev = lev, + wtsname = wtsname, + seed = seed, + cost = cost, + gamma = gamma, + kernel = kernel, + dataset = dataset + ) %>% add_class(c("svm", "model")) } #' Center or standardize variables in a data frame @@ -101,9 +145,9 @@ scale_df <- function(dataset, center = TRUE, scale = TRUE, descr <- attr(dataset, "description") # 保留原描述属性 if (calc) { - # 计算均值(忽略NA) + # 计算均值 ms <- sapply(dataset[, cn, drop = FALSE], function(x) mean(x, na.rm = TRUE)) - # 计算标准差(忽略NA,样本标准差ddof=1,避免除以零) + # 计算标准差 sds <- sapply(dataset[, cn, drop = FALSE], function(x) { sd_val <- sd(x, na.rm = TRUE) ifelse(sd_val == 0, 1, sd_val) @@ -135,7 +179,7 @@ summary.svm <- function(object, prn = TRUE, ...) { svm_model <- object$model n_obs <- nrow(object$dataset) - wtsname <- object$wtsname # 可能是 NULL 或长度 0 字符 + wtsname <- object$wtsname cat("Support Vector Machine\n") cat(sprintf("Kernel type : %s (%s)\n", object$kernel, object$type)) @@ -146,7 +190,7 @@ summary.svm <- function(object, prn = TRUE, ...) { } cat(sprintf("Explanatory variables: %s\n", paste(object$evar, collapse = ", "))) - if (!is.null(wtsname) && length(wtsname) && wtsname != "") { + if (!is.null(wtsname) && nzchar(wtsname)) { cat(sprintf("Weights used : %s\n", wtsname)) } @@ -156,56 +200,79 @@ summary.svm <- function(object, prn = TRUE, ...) { } if (!is.na(object$seed)) cat(sprintf("Seed : %s\n", object$seed)) + # 支持向量计数 if (object$type == "classification") { - n_sv_per_class <- svm_model$nSV - total_sv <- sum(n_sv_per_class) + n_sv_per_class <- if (!is.null(svm_model$nSV)) svm_model$nSV else c(0, 0) + total_sv <- sum(n_sv_per_class) sv_info <- paste(sprintf("class %s: %d", seq_along(n_sv_per_class), n_sv_per_class), collapse = ", ") cat(sprintf("Support vectors : %d (%s)\n", total_sv, sv_info)) } else { - total_sv <- sum(svm_model$nSV) + total_sv <- if (!is.null(svm_model$index)) length(svm_model$index) else 0 cat(sprintf("Support vectors : %d\n", total_sv)) } - ## ---- 权重样本数计算同样保护 ---- - if (!is.null(wtsname) && length(wtsname) && wtsname != "") { + nr_obs <- n_obs + if (!is.null(wtsname) && nzchar(wtsname) && wtsname %in% names(object$dataset)) { weights_values <- as.numeric(object$dataset[[wtsname]]) - nr_obs <- if (all(!is.na(weights_values))) sum(weights_values, na.rm = TRUE) else n_obs - } else { - nr_obs <- n_obs + weights_clean <- weights_values[!is.na(weights_values)] + if (length(weights_clean) > 0) { + sum_weights <- sum(weights_clean) + if (sum_weights > 0) nr_obs <- sum_weights + } } cat(sprintf("Nr_obs : %s\n", format_nr(nr_obs, dec = 0))) - ## ---- 系数输出(仅线性核) ---- if (prn) { cat("Coefficients/Support Vectors:\n") if (object$kernel == "linear") { - if (is.null(svm_model$w) || length(svm_model$w) == 0) { - cat(" Linear kernel coefficients not available (possible reasons:\n") - cat(" - Insufficient support vectors (data may be linearly inseparable)\n") - cat(" - Model training did not converge)\n") - if (object$type == "classification") { - cat(sprintf(" Support vectors count: %d\n", sum(svm_model$nSV))) - } - } else { - feat_coefs_raw <- as.numeric(svm_model$w) - bias <- as.numeric(-svm_model$rho) - n_evar <- length(object$evar) - feat_coefs <- matrix(data = feat_coefs_raw, nrow = 1, ncol = n_evar, byrow = TRUE, - dimnames = list(NULL, object$evar)) - coef_data <- data.frame( - Variable = c(object$evar, "bias"), - Value = as.numeric(c(as.vector(feat_coefs), bias)), - stringsAsFactors = FALSE, - check.names = FALSE - ) - for (i in seq(1, nrow(coef_data), 2)) { - if (i + 1 > nrow(coef_data)) { - cat(sprintf(" %-12s: %.2f\n", coef_data$Variable[i], coef_data$Value[i])) - } else { - cat(sprintf(" %-12s: %.2f %-12s: %.2f\n", - coef_data$Variable[i], coef_data$Value[i], - coef_data$Variable[i+1], coef_data$Value[i+1])) + feat_coefs_raw <- rep(0, length(object$evar)) + bias <- 0 + + if (!is.null(svm_model$terms) && !is.null(svm_model$coefs) && !is.null(svm_model$SV)) { + tryCatch({ + model_vars <- attr(svm_model$terms, "term.labels") + coef_mat <- as.matrix(svm_model$coefs) + sv_mat <- as.matrix(svm_model$SV) + + if (nrow(coef_mat) > 0 && nrow(sv_mat) > 0) { + w_full <- as.numeric(t(coef_mat) %*% sv_mat) + sv_colnames <- colnames(sv_mat) + + if (!is.null(sv_colnames)) { + # 按变量名精确匹配 + for (i in seq_along(object$evar)) { + var_name <- object$evar[i] + matching_cols <- which(sv_colnames == var_name) + if (length(matching_cols) > 0) { + feat_coefs_raw[i] <- sum(w_full[matching_cols]) + } + } + } else { + # 无列名则按位置取 + n_evar <- length(object$evar) + feat_coefs_raw[1:min(n_evar, length(w_full))] <- w_full[1:min(n_evar, length(w_full))] + } + bias <- as.numeric(-svm_model$rho) } + }, error = function(e) { + warning("系数提取失败: ", e$message, call. = FALSE) + }) + } + + # 输出 + coef_data <- data.frame( + Variable = c(object$evar, "bias"), + Value = c(feat_coefs_raw, bias), + stringsAsFactors = FALSE + ) + + for (i in seq(1, nrow(coef_data), 2)) { + if (i + 1 > nrow(coef_data)) { + cat(sprintf(" %-12s: %.2f\n", coef_data$Variable[i], coef_data$Value[i])) + } else { + cat(sprintf(" %-12s: %.2f %-12s: %.2f\n", + coef_data$Variable[i], coef_data$Value[i], + coef_data$Variable[i+1], coef_data$Value[i+1])) } } } else { @@ -256,134 +323,234 @@ plot.svm <- function(x, return("No valid plots selected for SVM") } - ## 返回 patchwork 对象(Shiny 自动打印) + ## 返回 patchwork 对象 patchwork::wrap_plots(plot_list, ncol = min(2, length(plot_list))) %>% { if (isTRUE(shiny)) . else print(.) } } -#' Predict method for the svm function +#' Predict method for SVM model #' @export -predict.svm <- function(object, pred_data = NULL, pred_cmd = "", - dec = 3, envir = parent.frame(), ...) { - ## 1. 基础校验 +predict.svm <- function(object, pred_data = NULL, pred_cmd = "", dec = 3, + envir = parent.frame(), ...) { + # 1. 基础校验 if (is.character(object)) return(object) if (!inherits(object, "svm")) stop("Object must be of class 'svm'") + if (is.null(object$model) || !inherits(object$model, "svm")) { + err_msg <- "Prediction failed: Invalid SVM model" + err_obj <- structure(list(error = err_msg), class = "svm.predict.error") + return(err_obj) + } + svm_info <- object + + # 2. 处理预测数据 + if (is.character(pred_data) && nzchar(pred_data)) { + if (!exists(pred_data, envir = envir)) { + err_obj <- structure(list(error = sprintf("Dataset '%s' not found", pred_data)), class = "svm.predict.error") + return(err_obj) + } + pred_data <- get(pred_data, envir = envir) + } + has_data <- !is.null(pred_data) && is.data.frame(pred_data) && nrow(pred_data) > 0 + has_cmd <- is.character(pred_cmd) && nzchar(pred_cmd) + if (!has_data && !has_cmd) { + err_obj <- structure(list(error = "Please select data and/or specify a command to generate predictions."), class = "svm.predict.error") + return(err_obj) + } - ## 2.1 确定预测数据源 - if (is.null(pred_data) || is.character(pred_data)) { - # 当pred_data为NULL或字符(数据集名)时,使用训练数据 - pred_data_raw <- object$dataset[, object$evar, drop = FALSE] + ## ensure you have a name for the prediction dataset + if (is.data.frame(pred_data)) { + df_name <- deparse(substitute(pred_data)) } else { - # 当pred_data是数据框时,直接使用 - pred_data_raw <- as.data.frame(pred_data) + df_name <- pred_data } - pred_names <- colnames(pred_data_raw) - missing_vars <- setdiff(object$evar, pred_names) - - if (length(missing_vars) > 0) { - msg <- paste0( - "NA\n" - ) - return(msg %>% add_class("svm.predict")) - } - - pred_data <- pred_data_raw[, object$evar, drop = FALSE] - - if (!is.empty(pred_cmd)) { - pred_cmd <- gsub("\\s{2,}", " ", pred_cmd) %>% - gsub(";\\s+", ";", .) %>% - strsplit(";")[[1]] - for (cmd in pred_cmd) { - if (grepl("=", cmd)) { - var_val <- strsplit(trimws(cmd), "=")[[1]] - var <- trimws(var_val[1]) - val <- try(eval(parse(text = trimws(var_val[2])), envir = envir), silent = TRUE) - if (!inherits(val, "try-error")) pred_data[[var]] <- val - else warning(sprintf("Invalid command '%s': using original values", cmd)) - } + # 3. 预测核心函数 - 修改参数名以匹配NN + pfun <- function(model, pred, se, conf_lev) { + # 验证数据 + missing_vars <- setdiff(svm_info$evar, colnames(pred)) + if (length(missing_vars) > 0) { + return(paste("Missing variables:", paste(missing_vars, collapse = ", "))) } - pred_data <- unique(pred_data) - } - - ## 3. 变量类型与因子水平校验 - train_types <- sapply(object$dataset[, object$evar], class) - pred_types <- sapply(pred_data, class) - type_mismatch <- names(which(train_types != pred_types)) - if (length(type_mismatch) > 0) { - return(paste0("Variable type mismatch (train vs pred):\n", - paste(sprintf(" %s: %s vs %s", type_mismatch, - train_types[type_mismatch], pred_types[type_mismatch]), - collapse = "\n")) %>% add_class("svm.predict")) - } - - for (var in object$evar) { - if (is.factor(object$dataset[[var]])) { - train_levs <- levels(object$dataset[[var]]) - pred_levs <- levels(pred_data[[var]]) - if (!identical(train_levs, pred_levs)) { - pred_data[[var]] <- factor(pred_data[[var]], levels = train_levs) - warning(sprintf("Aligned factor levels for '%s' to match training data", var)) + + # 内部标准化 + ms <- attr(svm_info$dataset, "radiant_ms") + sds <- attr(svm_info$dataset, "radiant_sds") + if (!is.null(ms) && !is.null(sds)) { + isNum <- sapply(pred, is.numeric) + cn <- names(isNum)[isNum] + for (var in cn) { + if (var %in% names(ms) && var %in% names(sds) && sds[var] != 0) { + pred[[var]] <- (pred[[var]] - ms[var]) / sds[var] + } } } - } - - ## 4. 标准化对齐 - train_ms <- attr(object$dataset, "radiant_ms") - train_sds <- attr(object$dataset, "radiant_sds") - train_sf <- attr(object$dataset, "radiant_sf") %||% 2 - if (!is.null(train_ms) && !is.null(train_sds)) { - pred_data <- scale_df( - dataset = pred_data, - center = TRUE, scale = TRUE, - sf = train_sf, wts = NULL, - calc = FALSE - ) - attr(pred_data, "radiant_ms") <- train_ms - attr(pred_data, "radiant_sds") <- train_sds - } - - ## 5. 生成预测值 - predict_args <- list( - object = object$model, - newdata = pred_data, - na.action = na.omit - ) - - pred_result <- if (object$type == "classification") { - predict_args$type <- "class" - pred_class <- do.call(predict, predict_args) - if (isTRUE(object$model$param$probability)) { - predict_args$type <- "probabilities" - pred_prob <- do.call(predict, predict_args) %>% - as.data.frame() %>% - set_colnames(paste0("Prob_", colnames(.))) - data.frame( - Predicted_Class = as.character(pred_class), - pred_prob, - stringsAsFactors = FALSE + + # 调试信息:检查模型是否启用了概率 + cat("Model probability enabled:", svm_info$model$prob.model, "\n") + + # 执行预测 + pred_result <- try({ + predict( + svm_info$model, + newdata = pred, + probability = TRUE, # 始终设置为TRUE + decision.values = TRUE ) + }, silent = TRUE) + + if (inherits(pred_result, "try-error")) { + return(paste("Prediction failed:", attr(pred_result, "condition")$message)) + } + + # 4. 结果整理 + if (svm_info$type == "classification") { + # 正确的属性名是"probabilities" + prob_mat <- attr(pred_result, "probabilities") + + if (!is.null(prob_mat)) { + # 调试:显示概率矩阵的列名 + cat("Probability matrix columns:", paste(colnames(prob_mat), collapse = ", "), "\n") + + # 更智能的列名匹配 + target_level <- as.character(svm_info$lev) + + # 尝试多种匹配方式 + matching_col <- NULL + + # 1. 精确匹配 + exact_match <- grep(paste0("^", target_level, "$"), colnames(prob_mat), value = TRUE, ignore.case = TRUE) + if (length(exact_match) > 0) { + matching_col <- exact_match + } + # 2. 部分匹配 + else { + partial_match <- grep(target_level, colnames(prob_mat), value = TRUE, ignore.case = TRUE) + if (length(partial_match) > 0) { + matching_col <- partial_match[1] # 取第一个匹配 + } + } + + # 3. 如果还是没有匹配,尝试逻辑值匹配 + if (is.null(matching_col) || length(matching_col) == 0) { + if (target_level %in% c("TRUE", "Yes", "1", "1.0")) { + logical_match <- grep("TRUE", colnames(prob_mat), value = TRUE, ignore.case = TRUE) + if (length(logical_match) > 0) { + matching_col <- logical_match[1] + } + } else if (target_level %in% c("FALSE", "No", "0", "0.0")) { + logical_match <- grep("FALSE", colnames(prob_mat), value = TRUE, ignore.case = TRUE) + if (length(logical_match) > 0) { + matching_col <- logical_match[1] + } + } + } + + # 4. 作为最后手段,使用第一列 + if (is.null(matching_col) || length(matching_col) == 0) { + if (ncol(prob_mat) > 0) { + matching_col <- colnames(prob_mat)[1] + cat("Warning: Using first probability column '", matching_col, "' for level '", target_level, "'\n") + } + } + + if (!is.null(matching_col) && length(matching_col) > 0) { + surv_prob <- as.numeric(prob_mat[, matching_col]) + pred_df <- data.frame( + Prediction = round(surv_prob, dec), + stringsAsFactors = FALSE + ) + } else { + # 备用方案:使用决策值 + decision_vals <- attr(pred_result, "decision.values") + if (!is.null(decision_vals)) { + # 将决策值转换为概率 + pred_prob <- 1 / (1 + exp(-decision_vals)) + pred_df <- data.frame( + Prediction = round(pred_prob, dec), + stringsAsFactors = FALSE + ) + } else { + # 最后手段:使用预测类 + pred_class <- as.character(pred_result) + pred_prob <- ifelse(pred_class == target_level, 1, 0) + pred_df <- data.frame( + Prediction = round(pred_prob, dec), + stringsAsFactors = FALSE + ) + } + } + } else { + # 备用方案:使用决策值 + decision_vals <- attr(pred_result, "decision.values") + if (!is.null(decision_vals)) { + # 将决策值转换为概率 + pred_prob <- 1 / (1 + exp(-decision_vals)) + pred_df <- data.frame( + Prediction = round(pred_prob, dec), + stringsAsFactors = FALSE + ) + } else { + # 最后手段:使用预测类 + target_level <- as.character(svm_info$lev) + pred_class <- as.character(pred_result) + pred_prob <- ifelse(pred_class == target_level, 1, 0) + pred_df <- data.frame( + Prediction = round(pred_prob, dec), + stringsAsFactors = FALSE + ) + } + } } else { - data.frame(Predicted_Class = as.character(pred_class), stringsAsFactors = FALSE) + # 回归模型 + pred_values <- as.numeric(pred_result) + + # 关键:添加反标准化处理 + rvar_ms <- if (!is.null(ms) && svm_info$rvar %in% names(ms)) ms[svm_info$rvar] else 0 + rvar_sds <- if (!is.null(sds) && svm_info$rvar %in% names(sds)) sds[svm_info$rvar] else 1 + + # 反标准化 + pred_values <- pred_values * rvar_sds + rvar_ms + + pred_df <- data.frame( + Prediction = round(pred_values, dec), + stringsAsFactors = FALSE + ) } - } else { - predict_args$type <- "response" - pred_val <- do.call(predict, predict_args) - pred_val <- as.numeric(pred_val) - data.frame(Predicted_Value = round(pred_val, dec), stringsAsFactors = FALSE) - } - - pred_result <- cbind(pred_data, pred_result) - attr(pred_result, "svm_meta") <- list( - kernel = object$kernel, - cost = object$cost, - gamma = if (object$kernel != "linear") object$gamma else NA, - seed = object$seed, - train_data = object$df_name, - model_type = object$type + + return(pred_df) + } + + # 5. 调用预测框架 - 与NN完全一致 + result <- predict_model( + object, + pfun, + "svm.predict", # 模型类型 + pred_data, + pred_cmd, + conf_lev = 0.95, + se = FALSE, + dec, + envir = envir + ) %>% + set_attr("radiant_pred_data", df_name) + + # 6. 结果元数据 + if (inherits(result, "svm.predict.error")) { + return(result$error) + } + if (is.character(result)) return(result) + + result <- add_class(result, "svm.predict") + attr(result, "svm_meta") <- list( + model_type = svm_info$type, + kernel = svm_info$kernel, + cost = svm_info$cost, + gamma = if (svm_info$kernel != "linear") svm_info$gamma else NA, + seed = svm_info$seed, + train_data = svm_info$df_name, + dec = dec ) - attr(pred_result, "class") <- c("svm.predict", "data.frame") - return(pred_result) + return(result) } #' Print method for predict.svm @@ -396,16 +563,30 @@ print.svm.predict <- function(x, ..., n = 10) { if (!inherits(x, "svm.predict")) stop("Object must be of class 'svm.predict'") meta <- attr(x, "svm_meta") + if (is.null(meta)) { + meta <- list( + model_type = "Unknown", + kernel = "Unknown", + cost = NA, + gamma = NA, + seed = NA, + train_data = "Unknown", + dec = 1 + ) + } + dec <- meta$dec %||% 1 + n_pred <- nrow(x) show_n <- if (n < 0 || n >= n_pred) n_pred else n cat("SVM Predictions\n") - cat(sprintf("Model Type : %s\n", tools::toTitleCase(meta$model_type))) - cat(sprintf("Kernel : %s\n", meta$kernel)) - cat(sprintf("Cost (C) : %.2f\n", meta$cost)) + model_type <- if (is.empty(meta$model_type)) "Unknown" else as.character(meta$model_type) + cat(sprintf("Model Type : %s\n", tools::toTitleCase(model_type))) + cat(sprintf("Kernel : %s\n", ifelse(is.empty(meta$kernel), "Unknown", meta$kernel))) + cat(sprintf("Cost (C) : %.2f\n", ifelse(is.na(meta$cost), 0, meta$cost))) if (!is.na(meta$gamma)) cat(sprintf("Gamma : %.2f\n", meta$gamma)) if (!is.na(meta$seed)) cat(sprintf("Seed : %s\n", meta$seed)) - cat(sprintf("Training Dataset : %s\n", meta$train_data)) + cat(sprintf("Training Dataset : %s\n", ifelse(is.empty(meta$train_data), "Unknown", meta$train_data))) cat(sprintf("Total Predictions : %d\n", n_pred)) if (n_pred == 0) { @@ -413,8 +594,13 @@ print.svm.predict <- function(x, ..., n = 10) { return(invisible(x)) } - x_show <- x[1:show_n, , drop = FALSE] # 确保保持数据框结构 + x_show <- x[1:show_n, , drop = FALSE] + num_cols <- sapply(x_show, is.numeric) + x_show[num_cols] <- lapply(x_show[num_cols], function(col) { + sprintf(paste0("%.", dec, "f"), col) + }) + # 重新计算列宽 col_widths <- sapply(colnames(x_show), function(cn) { max(nchar(cn), max(nchar(as.character(x_show[[cn]])), na.rm = TRUE)) }) @@ -435,38 +621,13 @@ print.svm.predict <- function(x, ..., n = 10) { } if (show_n < n_pred) { - cat(sprintf("\n... (showing first %d of %d; use 'n=-1' to view all)\n", + cat(sprintf("\n... (showing first %d of %d)\n", show_n, n_pred)) } return(invisible(x)) } -#' Cross-validation for SVM -#' @export -cv.svm <- function(object, K = 5, repeats = 1, - kernel = c("linear", "radial"), - cost = 2^(-2:2), - gamma = 2^(-2:2), - seed = 1234, trace = TRUE, fun, ...) { - if (!inherits(object, "svm")) stop("Object must be of class 'svm'") - - tune_grid <- expand.grid( - kernel = kernel, - cost = cost, - gamma = if (any(kernel %in% c("radial", "poly", "sigmoid"))) gamma else NA - ) - - cv_results <- data.frame( - mean_perf = rep(NA, nrow(tune_grid)), - std_perf = rep(NA, nrow(tune_grid)), - kernel = tune_grid$kernel, - cost = tune_grid$cost, - gamma = tune_grid$gamma - ) - cv_results -} - # ---- 决策边界 ---- #' @export svm_boundary_plot <- function(object, size, custom) { diff --git a/radiant.model/inst/app/init.R b/radiant.model/inst/app/init.R index a979f6cf02c54619d0efcebede6c6bbcb9353bba..45990f39132ce0b65add8bdc63ae3a768714f570 100644 --- a/radiant.model/inst/app/init.R +++ b/radiant.model/inst/app/init.R @@ -79,11 +79,12 @@ options( i18n$t("Estimate"), tabPanel(i18n$t("Linear regression (OLS)"), uiOutput("regress")), tabPanel(i18n$t("Logistic regression (GLM)"), uiOutput("logistic")), - tabPanel(i18n$t("Cox Proportional Hazards Regression"),uiOutput("coxp")), tabPanel(i18n$t("Multinomial logistic regression (MNL)"), uiOutput("mnl")), tabPanel(i18n$t("Naive Bayes"), uiOutput("nb")), tabPanel(i18n$t("Neural Network"), uiOutput("nn")), tabPanel(i18n$t("Support Vector Machine (SVM)"),uiOutput("svm")), + "----", i18n$t("Survival Analysis"), + tabPanel(i18n$t("Cox Proportional Hazards Regression"),uiOutput("coxp")), "----", i18n$t("Trees"), tabPanel(i18n$t("Classification and regression trees"), uiOutput("crtree")), tabPanel(i18n$t("Random Forest"), uiOutput("rf")), diff --git a/radiant.model/inst/app/tools/analysis/svm_ui.R b/radiant.model/inst/app/tools/analysis/svm_ui.R index 565b11f26a8706c2df5be185413ab28fd5471e9f..3b15bc41b438ae178c94f25f48c66329bf8dcff4 100644 --- a/radiant.model/inst/app/tools/analysis/svm_ui.R +++ b/radiant.model/inst/app/tools/analysis/svm_ui.R @@ -23,7 +23,7 @@ svm_inputs <- reactive({ svm_args }) -## 预测参数(保留命令模式,未改动) +## 预测参数 svm_pred_args <- as.list(if (exists("predict.svm")) { formals(predict.svm) } else { @@ -52,7 +52,7 @@ svm_pred_inputs <- reactive({ return(svm_pred_args) }) -## 绘图参数(砍掉vip、pdp、svm_margin) +## 绘图参数 svm_plot_args <- as.list(if (exists("plot.svm")) { formals(plot.svm) } else { @@ -77,11 +77,7 @@ output$ui_svm_rvar <- renderUI({ vars <- varnames()[isNum] } }) - init <- if (input$svm_type == "classification") { - if (is.empty(input$logit_rvar)) isolate(input$svm_rvar) else input$logit_rvar - } else { - if (is.empty(input$reg_rvar)) isolate(input$svm_rvar) else input$reg_rvar - } + init <- state_single("svm_rvar", vars, isolate(input$svm_rvar)) selectInput( inputId = "svm_rvar", label = i18n$t("Response variable:"), @@ -109,11 +105,7 @@ output$ui_svm_evar <- renderUI({ if (not_available(input$svm_rvar)) return() vars <- varnames() if (length(vars) > 0) vars <- vars[-which(vars == input$svm_rvar)] - init <- if (input$svm_type == "classification") { - if (is.empty(input$logit_evar)) isolate(input$svm_evar) else input$logit_evar - } else { - if (is.empty(input$reg_evar)) isolate(input$svm_evar) else input$reg_evar - } + init <- state_multiple("svm_evar", vars, isolate(input$svm_evar)) selectInput( inputId = "svm_evar", label = i18n$t("Explanatory variables:"), @@ -141,7 +133,7 @@ output$ui_svm_wts <- renderUI({ ) }) -## 存储预测值UI(残差存储已删除) +## 存储预测值UI output$ui_svm_store_pred_name <- renderUI({ init <- state_init("svm_store_pred_name", "pred_svm") %>% sub("\\d{1,}$", "", .) %>% @@ -164,7 +156,7 @@ observeEvent(input$svm_type, { updateSelectInput(session = session, inputId = "svm_plots", selected = "none") }) -## 绘图选项UI(已删vip、pdp、svm_margin) +## 绘图选项UI output$ui_svm_plots <- renderUI({ req(input$svm_type) avail_plots <- svm_plots @@ -178,19 +170,7 @@ output$ui_svm_plots <- renderUI({ ) }) -## 数据点数量UI(仅dashboard用,保留) -output$ui_svm_nrobs <- renderUI({ - nrobs <- nrow(.get_data()) - choices <- c("1,000" = 1000, "5,000" = 5000, "10,000" = 10000, "All" = -1) %>% - .[. < nrobs] - selectInput( - "svm_nrobs", i18n$t("Number of data points plotted:"), - choices = choices, - selected = state_single("svm_nrobs", choices, 1000) - ) -}) - -## 主UI面板(已删残差存储入口) +## 主UI面板 output$ui_svm <- renderUI({ req(input$dataset) tagList( @@ -258,7 +238,7 @@ output$ui_svm <- renderUI({ ) ) ), - # 预测面板(残差存储已删除) + # 预测面板 conditionalPanel( condition = "input.tabs_svm == 'Predict'", selectInput( @@ -303,19 +283,10 @@ output$ui_svm <- renderUI({ ) ) ), - # 绘图面板(已删vip、pdp、svm_margin) + # 绘图面板 conditionalPanel( condition = "input.tabs_svm == 'Plot'", - uiOutput("ui_svm_plots"), - conditionalPanel( - condition = "input.svm_plots == 'pred_plot'", - uiOutput("ui_svm_incl"), - uiOutput("ui_svm_incl_int") - ), - conditionalPanel( - condition = "input.svm_plots == 'dashboard'", - uiOutput("ui_svm_nrobs") - ) + uiOutput("ui_svm_plots") ) ), # 帮助和报告面板 @@ -327,7 +298,7 @@ output$ui_svm <- renderUI({ ) }) -## 绘图尺寸计算(已删vip、pdp、svm_margin) +## 绘图尺寸计算 svm_plot <- reactive({ if (svm_available() != "available") return() if (is.empty(input$svm_plots, "none")) return() @@ -337,9 +308,6 @@ svm_plot <- reactive({ plot_width <- 650 if ("decision_boundary" %in% input$svm_plots) { plot_height <- 500 - } else if (input$svm_plots == "pred_plot") { - nr_vars <- length(input$svm_incl) + length(input$svm_incl_int) - plot_height <- max(250, ceiling(nr_vars / 2) * 250) } else { plot_height <- max(500, length(res$evar) * 30) } @@ -349,7 +317,7 @@ svm_plot <- reactive({ svm_plot_width <- function() svm_plot()$plot_width %||% 650 svm_plot_height <- function() svm_plot()$plot_height %||% 500 -## 主输出面板(已删残差存储) +## 主输出面板 output$svm <- renderUI({ register_print_output("summary_svm", ".summary_svm") register_print_output("predict_svm", ".predict_print_svm") @@ -393,14 +361,14 @@ svm_available <- reactive({ } }) -## 核心函数壳子 +## 核心函数 .svm <- eventReactive(input$svm_run, { svi <- svm_inputs() svi$envir <- r_data withProgress(message = i18n$t("Estimating SVM model"), value = 1, do.call(svm, svi)) }) -## 辅助输出函数壳子 +## 辅助输出函数 .summary_svm <- reactive({ if (not_pressed(input$svm_run)) return(i18n$t("** Press the Estimate button to estimate the SVM model **")) if (svm_available() != "available") return(svm_available()) @@ -444,7 +412,7 @@ svm_available <- reactive({ withProgress(message = i18n$t("Generating SVM plots"), value = 1, do.call(plot, c(list(x = .svm()), pinp))) }) -## 存储预测值(残差存储已删除) +## 存储预测值 observeEvent(input$svm_store_pred, { req( pressed(input$svm_run), @@ -457,14 +425,14 @@ observeEvent(input$svm_store_pred, { base_col_name <- fix_names(input$svm_store_pred_name) meta <- attr(pred_result, "svm_meta") - pred_cols <- if (meta$model_type == "classification") { - colnames(pred_result)[grepl("^Predicted_Class|^Prob_", colnames(pred_result))] + pred_cols <- if (meta$model_type %in% c("classification", "regression")) { + colnames(pred_result)[colnames(pred_result) == "Prediction"] } else { - "Predicted_Value" + NULL } new_col_names <- if (length(pred_cols) == 1) base_col_name else { - suffix <- gsub("^Predicted_|^Prob_", "", pred_cols) - paste0(base_col_name, "_", suffix) + suffix <- gsub("^Prediction", "", pred_cols) + paste0(base_col_name, ifelse(suffix == "", "", paste0("_", suffix))) } colnames(pred_result)[match(pred_cols, colnames(pred_result))] <- new_col_names @@ -510,9 +478,75 @@ download_handler( height = svm_plot_height ) -## 报告生成(空壳,保留接口) -svm_report <- function() invisible() +svm_report <- function() { + if (is.empty(input$svm_evar)) { + showNotification(i18n$t("Select at least one explanatory variable to generate report"), type = "error") + return(invisible()) + } + + outputs <- c("summary") + inp_out <- list(list(prn = TRUE), "") + figs <- FALSE + xcmd <- "" + + if (!is.empty(input$svm_plots, "none")) { + inp <- check_plot_inputs(svm_plot_inputs()) + inp$size <- NULL + inp_out[[2]] <- clean_args(inp, svm_plot_args[-1]) + inp_out[[2]]$custom <- FALSE + outputs <- c(outputs, "plot") + figs <- TRUE + } + + if (!is.empty(input$svm_predict, "none") && + (!is.empty(input$svm_pred_data) || !is.empty(input$svm_pred_cmd))) { + pred_args <- clean_args(svm_pred_inputs(), svm_pred_args[-1]) + if (!is.empty(pred_args$pred_cmd)) { + pred_args$pred_cmd <- strsplit(pred_args$pred_cmd, ";\\s*")[[1]] + } else { + pred_args$pred_cmd <- NULL + } + if (!is.empty(pred_args$pred_data)) { + pred_args$pred_data <- as.symbol(pred_args$pred_data) + } else { + pred_args$pred_data <- NULL + } + + inp_out[[2 + figs]] <- pred_args + outputs <- c(outputs, "pred <- predict") + xcmd <- paste0(xcmd, "print(pred, n = 10)") + + if (input$svm_predict %in% c("data", "datacmd") && !is.empty(input$svm_store_pred_name)) { + fixed <- fix_names(input$svm_store_pred_name) + updateTextInput(session, "svm_store_pred_name", value = fixed) + xcmd <- paste0( + xcmd, "\n", input$svm_pred_data, " <- store(", + input$svm_pred_data, ", pred, name = \"", fixed, "\")" + ) + } + } + + svm_inp <- svm_inputs() + if (input$svm_type == "regression") { + svm_inp$lev <- NULL + } + if (input$svm_kernel == "linear") { + svm_inp$gamma <- NULL + } + + update_report( + inp_main = clean_args(svm_inp, svm_args), + fun_name = "svm", + inp_out = inp_out, + outputs = outputs, + figs = figs, + fig.width = svm_plot_width(), + fig.height = svm_plot_height(), + xcmd = xcmd + ) +} +## 报告生成 observeEvent(input$svm_report, { r_info[["latest_screenshot"]] <- NULL svm_report() diff --git a/radiant.model/man/cv.svm.Rd b/radiant.model/man/cv.svm.Rd deleted file mode 100644 index 11260bf3a2aaf50a508bd65fd9f3f42c50059ab1..0000000000000000000000000000000000000000 --- a/radiant.model/man/cv.svm.Rd +++ /dev/null @@ -1,22 +0,0 @@ -% Generated by roxygen2: do not edit by hand -% Please edit documentation in R/svm.R -\name{cv.svm} -\alias{cv.svm} -\title{Cross-validation for SVM} -\usage{ -cv.svm( - object, - K = 5, - repeats = 1, - kernel = c("linear", "radial"), - cost = seq(0.1, 10, by = 0.5), - gamma = seq(0.1, 5, by = 0.5), - seed = 1234, - trace = TRUE, - fun, - ... -) -} -\description{ -Cross-validation for SVM -} diff --git a/radiant.model/man/plot.svm.Rd b/radiant.model/man/plot.svm.Rd index 2a62b49b05300fda13f9ea8511347d30fb2000a1..f82ae52b0c679c0e7ae356314a74d3e555985128 100644 --- a/radiant.model/man/plot.svm.Rd +++ b/radiant.model/man/plot.svm.Rd @@ -4,17 +4,7 @@ \alias{plot.svm} \title{Plot method for the svm function} \usage{ -\method{plot}{svm}( - x, - plots = "vip", - size = 12, - nrobs = -1, - incl = NULL, - incl_int = NULL, - shiny = FALSE, - custom = FALSE, - ... -) +\method{plot}{svm}(x, plots = "none", size = 12, shiny = FALSE, custom = FALSE, ...) } \description{ Plot method for the svm function diff --git a/radiant.model/man/predict.svm.Rd b/radiant.model/man/predict.svm.Rd index e759df66988ebcb241a1a814381eaad0ef2c27ee..0db202c85349994fa0fb9f9e01ccbe08a1949b6b 100644 --- a/radiant.model/man/predict.svm.Rd +++ b/radiant.model/man/predict.svm.Rd @@ -2,7 +2,7 @@ % Please edit documentation in R/svm.R \name{predict.svm} \alias{predict.svm} -\title{Predict method for the svm function} +\title{Predict method for SVM model} \usage{ \method{predict}{svm}( object, @@ -14,5 +14,5 @@ ) } \description{ -Predict method for the svm function +Predict method for SVM model } diff --git a/radiant.model/man/varimp.Rd b/radiant.model/man/varimp.Rd index 0eb6704c1988a73cd3fbcaed0564a49b2343d310..7ad69f4bd72002459be0ab62e30d46c9068f1319 100644 --- a/radiant.model/man/varimp.Rd +++ b/radiant.model/man/varimp.Rd @@ -4,9 +4,9 @@ \alias{varimp} \title{Variable importance using the vip package and permutation importance} \usage{ -varimp(object, rvar, lev, data = NULL, seed = 1234) +varimp(object, rvar = NULL, lev = NULL, data = NULL, seed = 1234, nperm = 10) -varimp(object, rvar, lev, data = NULL, seed = 1234) +varimp(object, rvar = NULL, lev = NULL, data = NULL, seed = 1234, nperm = 10) } \arguments{ \item{object}{Model object created by Radiant} @@ -22,5 +22,5 @@ varimp(object, rvar, lev, data = NULL, seed = 1234) \description{ Variable importance using the vip package and permutation importance -Variable importance using the vip package and permutation importance +Variable importance for SVM using permutation importance } diff --git a/radiant.model/man/varimp_plot.Rd b/radiant.model/man/varimp_plot.Rd index 1480b815d0e7437501dc02ca0e23349b44b74bff..dda339c5875feb7f4a7723650aef24ef78ce8f6a 100644 --- a/radiant.model/man/varimp_plot.Rd +++ b/radiant.model/man/varimp_plot.Rd @@ -1,11 +1,9 @@ % Generated by roxygen2: do not edit by hand -% Please edit documentation in R/nn.R, R/svm.R +% Please edit documentation in R/nn.R \name{varimp_plot} \alias{varimp_plot} \title{Plot permutation importance} \usage{ -varimp_plot(object, rvar, lev, data = NULL, seed = 1234) - varimp_plot(object, rvar, lev, data = NULL, seed = 1234) } \arguments{ @@ -20,7 +18,5 @@ varimp_plot(object, rvar, lev, data = NULL, seed = 1234) \item{seed}{Random seed for reproducibility} } \description{ -Plot permutation importance - Plot permutation importance }