Commit 53dd5828 authored by wuzekai

Updated svm

parent e9df660e
# Generated by roxygen2: do not edit by hand
S3method(plot,confusion)
S3method(plot,coxp)
S3method(plot,crs)
S3method(plot,crtree)
S3method(plot,dtree)
S3method(plot,evalbin)
S3method(plot,evalreg)
S3method(plot,gbt)
S3method(plot,logistic)
S3method(plot,mnl)
S3method(plot,mnl.predict)
S3method(plot,model.predict)
S3method(plot,nb)
S3method(plot,nb.predict)
S3method(plot,nn)
S3method(plot,regress)
S3method(plot,repeater)
S3method(plot,rforest)
S3method(plot,rforest.predict)
S3method(plot,simulater)
S3method(plot,svm)
S3method(plot,uplift)
S3method(predict,coxp)
S3method(predict,crtree)
S3method(predict,gbt)
S3method(predict,logistic)
S3method(predict,mnl)
S3method(predict,nb)
S3method(predict,nn)
S3method(predict,regress)
S3method(predict,rforest)
S3method(predict,svm)
S3method(print,coxp.predict)
S3method(print,crtree.predict)
S3method(print,gbt.predict)
S3method(print,logistic.predict)
S3method(print,mnl.predict)
S3method(print,nb.predict)
S3method(print,nn.predict)
S3method(print,regress.predict)
S3method(print,rforest.predict)
S3method(print,svm.predict)
S3method(render,DiagrammeR)
S3method(sensitivity,dtree)
S3method(store,coxp.predict)
S3method(store,crs)
S3method(store,mnl.predict)
S3method(store,model)
S3method(store,model.predict)
S3method(store,nb.predict)
S3method(store,rforest.predict)
S3method(summary,confusion)
S3method(summary,coxp)
S3method(summary,crs)
S3method(summary,crtree)
S3method(summary,dtree)
S3method(summary,evalbin)
S3method(summary,evalreg)
S3method(summary,gbt)
S3method(summary,logistic)
S3method(summary,mnl)
S3method(summary,nb)
S3method(summary,nn)
S3method(summary,regress)
S3method(summary,repeater)
S3method(summary,rforest)
S3method(summary,simulater)
S3method(summary,svm)
S3method(summary,uplift)
export(.as_int)
export(.as_num)
export(MAE)
export(RMSE)
export(Rsq)
export(ann)
export(auc)
export(confint_robust)
export(confusion)
export(coxp)
export(crs)
export(crtree)
export(cv.crtree)
export(cv.gbt)
export(cv.nn)
export(cv.rforest)
export(cv.svm)
export(dtree)
export(dtree_parser)
export(evalbin)
export(evalreg)
export(find_max)
export(find_min)
export(gbt)
export(logistic)
export(minmax)
export(mnl)
export(nb)
export(nn)
export(onehot)
export(pdp_plot)
export(pred_plot)
export(predict_model)
export(print_predict_model)
export(profit)
export(radiant.model)
export(radiant.model_viewer)
export(radiant.model_window)
export(regress)
export(remove_comments)
export(repeater)
export(rforest)
export(rig)
export(scale_df)
export(sdw)
export(sensitivity)
export(sim_cleaner)
export(sim_cor)
export(sim_splitter)
export(sim_summary)
export(simulater)
export(svm)
export(test_specs)
export(uplift)
export(var_check)
export(varimp)
export(varimp_plot)
export(write.coeff)
import(ggplot2)
import(radiant.data)
import(shiny)
importFrom(DiagrammeR,DiagrammeR)
importFrom(DiagrammeR,DiagrammeROutput)
importFrom(DiagrammeR,mermaid)
importFrom(DiagrammeR,renderDiagrammeR)
importFrom(NeuralNetTools,garson)
importFrom(NeuralNetTools,olden)
importFrom(NeuralNetTools,plotnet)
importFrom(broom,augment)
importFrom(car,linearHypothesis)
importFrom(car,vif)
importFrom(data.tree,Clone)
importFrom(data.tree,FormatPercent)
importFrom(data.tree,Get)
importFrom(data.tree,Traverse)
importFrom(data.tree,as.Node)
importFrom(data.tree,isLeaf)
importFrom(data.tree,isNotLeaf)
importFrom(data.tree,isNotRoot)
importFrom(dplyr,across)
importFrom(dplyr,arrange)
importFrom(dplyr,arrange_at)
importFrom(dplyr,bind_cols)
importFrom(dplyr,bind_rows)
importFrom(dplyr,data_frame)
importFrom(dplyr,desc)
importFrom(dplyr,distinct_at)
importFrom(dplyr,everything)
importFrom(dplyr,filter)
importFrom(dplyr,first)
importFrom(dplyr,funs)
importFrom(dplyr,group_by)
importFrom(dplyr,group_by_)
importFrom(dplyr,group_by_at)
importFrom(dplyr,inner_join)
importFrom(dplyr,last)
importFrom(dplyr,min_rank)
importFrom(dplyr,mutate)
importFrom(dplyr,mutate_)
importFrom(dplyr,mutate_all)
importFrom(dplyr,mutate_at)
importFrom(dplyr,mutate_if)
importFrom(dplyr,near)
importFrom(dplyr,pull)
importFrom(dplyr,rename)
importFrom(dplyr,sample_n)
importFrom(dplyr,select)
importFrom(dplyr,select_at)
importFrom(dplyr,slice)
importFrom(dplyr,summarise)
importFrom(dplyr,summarise_)
importFrom(dplyr,summarise_all)
importFrom(dplyr,summarise_at)
importFrom(dplyr,summarize)
importFrom(dplyr,ungroup)
importFrom(e1071,naiveBayes)
importFrom(ggplot2,autoplot)
importFrom(ggrepel,geom_text_repel)
importFrom(graphics,par)
importFrom(import,from)
importFrom(lubridate,is.Date)
importFrom(lubridate,now)
importFrom(magrittr,"%<>%")
importFrom(magrittr,"%>%")
importFrom(magrittr,"%T>%")
importFrom(magrittr,extract2)
importFrom(magrittr,set_colnames)
importFrom(magrittr,set_names)
importFrom(magrittr,set_rownames)
importFrom(nnet,nnet)
importFrom(nnet,nnet.formula)
importFrom(patchwork,plot_annotation)
importFrom(patchwork,wrap_plots)
importFrom(pdp,partial)
importFrom(psych,cohen.kappa)
importFrom(radiant.data,launch)
importFrom(radiant.data,set_attr)
importFrom(radiant.data,visualize)
importFrom(ranger,ranger)
importFrom(rlang,":=")
importFrom(rlang,.data)
importFrom(rlang,parse_exprs)
importFrom(rpart,prune.rpart)
importFrom(rpart,rpart)
importFrom(rpart,rpart.control)
importFrom(sandwich,vcovHC)
importFrom(scales,percent)
importFrom(shiny,getDefaultReactiveDomain)
importFrom(shiny,incProgress)
importFrom(shiny,withProgress)
importFrom(stats,anova)
importFrom(stats,as.formula)
importFrom(stats,binomial)
importFrom(stats,coef)
importFrom(stats,confint)
importFrom(stats,confint.default)
importFrom(stats,contrasts)
importFrom(stats,cor)
importFrom(stats,deviance)
importFrom(stats,dnorm)
importFrom(stats,family)
importFrom(stats,formula)
importFrom(stats,glm)
importFrom(stats,lm)
importFrom(stats,logLik)
importFrom(stats,median)
importFrom(stats,model.frame)
importFrom(stats,model.matrix)
importFrom(stats,na.omit)
importFrom(stats,pnorm)
importFrom(stats,predict)
importFrom(stats,pt)
importFrom(stats,qnorm)
importFrom(stats,qt)
importFrom(stats,quantile)
importFrom(stats,rbinom)
importFrom(stats,relevel)
importFrom(stats,residuals)
importFrom(stats,rlnorm)
importFrom(stats,rnorm)
importFrom(stats,rpois)
importFrom(stats,runif)
importFrom(stats,sd)
importFrom(stats,setNames)
importFrom(stats,step)
importFrom(stats,terms)
importFrom(stats,terms.formula)
importFrom(stats,update)
importFrom(stats,weighted.mean)
importFrom(stats,wilcox.test)
importFrom(stringi,stri_trans_general)
importFrom(stringr,str_match)
importFrom(tidyr,gather)
importFrom(tidyr,spread)
importFrom(tidyselect,where)
importFrom(utils,as.relistable)
importFrom(utils,capture.output)
importFrom(utils,combn)
importFrom(utils,head)
importFrom(utils,relist)
importFrom(utils,tail)
importFrom(utils,write.table)
importFrom(vip,vi)
importFrom(xgboost,xgb.importance)
importFrom(xgboost,xgboost)
importFrom(yaml,yaml.load)
# Generated by roxygen2: do not edit by hand
S3method(plot,confusion)
S3method(plot,coxp)
S3method(plot,crs)
S3method(plot,crtree)
S3method(plot,dtree)
S3method(plot,evalbin)
S3method(plot,evalreg)
S3method(plot,gbt)
S3method(plot,logistic)
S3method(plot,mnl)
S3method(plot,mnl.predict)
S3method(plot,model.predict)
S3method(plot,nb)
S3method(plot,nb.predict)
S3method(plot,nn)
S3method(plot,regress)
S3method(plot,repeater)
S3method(plot,rforest)
S3method(plot,rforest.predict)
S3method(plot,simulater)
S3method(plot,svm)
S3method(plot,uplift)
S3method(predict,coxp)
S3method(predict,crtree)
S3method(predict,gbt)
S3method(predict,logistic)
S3method(predict,mnl)
S3method(predict,nb)
S3method(predict,nn)
S3method(predict,regress)
S3method(predict,rforest)
S3method(predict,svm)
S3method(print,coxp.predict)
S3method(print,crtree.predict)
S3method(print,gbt.predict)
S3method(print,logistic.predict)
S3method(print,mnl.predict)
S3method(print,nb.predict)
S3method(print,nn.predict)
S3method(print,regress.predict)
S3method(print,rforest.predict)
S3method(print,svm.predict)
S3method(render,DiagrammeR)
S3method(sensitivity,dtree)
S3method(store,coxp.predict)
S3method(store,crs)
S3method(store,mnl.predict)
S3method(store,model)
S3method(store,model.predict)
S3method(store,nb.predict)
S3method(store,rforest.predict)
S3method(summary,confusion)
S3method(summary,coxp)
S3method(summary,crs)
S3method(summary,crtree)
S3method(summary,dtree)
S3method(summary,evalbin)
S3method(summary,evalreg)
S3method(summary,gbt)
S3method(summary,logistic)
S3method(summary,mnl)
S3method(summary,nb)
S3method(summary,nn)
S3method(summary,regress)
S3method(summary,repeater)
S3method(summary,rforest)
S3method(summary,simulater)
S3method(summary,svm)
S3method(summary,uplift)
export(.as_int)
export(.as_num)
export(MAE)
export(RMSE)
export(Rsq)
export(ann)
export(auc)
export(confint_robust)
export(confusion)
export(coxp)
export(crs)
export(crtree)
export(cv.crtree)
export(cv.gbt)
export(cv.nn)
export(cv.rforest)
export(dtree)
export(dtree_parser)
export(evalbin)
export(evalreg)
export(find_max)
export(find_min)
export(gbt)
export(logistic)
export(minmax)
export(mnl)
export(nb)
export(nn)
export(onehot)
export(pdp_plot)
export(pred_plot)
export(predict_model)
export(print_predict_model)
export(profit)
export(radiant.model)
export(radiant.model_viewer)
export(radiant.model_window)
export(regress)
export(remove_comments)
export(repeater)
export(rforest)
export(rig)
export(scale_df)
export(sdw)
export(sensitivity)
export(sim_cleaner)
export(sim_cor)
export(sim_splitter)
export(sim_summary)
export(simulater)
export(svm)
export(svm_boundary_plot)
export(svm_margin_plot)
export(svm_vip_plot)
export(test_specs)
export(uplift)
export(var_check)
export(varimp)
export(varimp_plot)
export(write.coeff)
import(ggplot2)
import(radiant.data)
import(shiny)
importFrom(DiagrammeR,DiagrammeR)
importFrom(DiagrammeR,DiagrammeROutput)
importFrom(DiagrammeR,mermaid)
importFrom(DiagrammeR,renderDiagrammeR)
importFrom(NeuralNetTools,garson)
importFrom(NeuralNetTools,olden)
importFrom(NeuralNetTools,plotnet)
importFrom(broom,augment)
importFrom(car,linearHypothesis)
importFrom(car,vif)
importFrom(data.tree,Clone)
importFrom(data.tree,FormatPercent)
importFrom(data.tree,Get)
importFrom(data.tree,Traverse)
importFrom(data.tree,as.Node)
importFrom(data.tree,isLeaf)
importFrom(data.tree,isNotLeaf)
importFrom(data.tree,isNotRoot)
importFrom(dplyr,across)
importFrom(dplyr,arrange)
importFrom(dplyr,arrange_at)
importFrom(dplyr,bind_cols)
importFrom(dplyr,bind_rows)
importFrom(dplyr,data_frame)
importFrom(dplyr,desc)
importFrom(dplyr,distinct_at)
importFrom(dplyr,everything)
importFrom(dplyr,filter)
importFrom(dplyr,first)
importFrom(dplyr,funs)
importFrom(dplyr,group_by)
importFrom(dplyr,group_by_)
importFrom(dplyr,group_by_at)
importFrom(dplyr,inner_join)
importFrom(dplyr,last)
importFrom(dplyr,min_rank)
importFrom(dplyr,mutate)
importFrom(dplyr,mutate_)
importFrom(dplyr,mutate_all)
importFrom(dplyr,mutate_at)
importFrom(dplyr,mutate_if)
importFrom(dplyr,near)
importFrom(dplyr,pull)
importFrom(dplyr,rename)
importFrom(dplyr,sample_n)
importFrom(dplyr,select)
importFrom(dplyr,select_at)
importFrom(dplyr,slice)
importFrom(dplyr,summarise)
importFrom(dplyr,summarise_)
importFrom(dplyr,summarise_all)
importFrom(dplyr,summarise_at)
importFrom(dplyr,summarize)
importFrom(dplyr,ungroup)
importFrom(e1071,naiveBayes)
importFrom(ggplot2,autoplot)
importFrom(ggrepel,geom_text_repel)
importFrom(graphics,par)
importFrom(import,from)
importFrom(lubridate,is.Date)
importFrom(lubridate,now)
importFrom(magrittr,"%<>%")
importFrom(magrittr,"%>%")
importFrom(magrittr,"%T>%")
importFrom(magrittr,extract2)
importFrom(magrittr,set_colnames)
importFrom(magrittr,set_names)
importFrom(magrittr,set_rownames)
importFrom(nnet,nnet)
importFrom(nnet,nnet.formula)
importFrom(patchwork,plot_annotation)
importFrom(patchwork,wrap_plots)
importFrom(pdp,partial)
importFrom(psych,cohen.kappa)
importFrom(radiant.data,launch)
importFrom(radiant.data,set_attr)
importFrom(radiant.data,visualize)
importFrom(ranger,ranger)
importFrom(rlang,":=")
importFrom(rlang,.data)
importFrom(rlang,parse_exprs)
importFrom(rpart,prune.rpart)
importFrom(rpart,rpart)
importFrom(rpart,rpart.control)
importFrom(sandwich,vcovHC)
importFrom(scales,percent)
importFrom(shiny,getDefaultReactiveDomain)
importFrom(shiny,incProgress)
importFrom(shiny,withProgress)
importFrom(stats,anova)
importFrom(stats,as.formula)
importFrom(stats,binomial)
importFrom(stats,coef)
importFrom(stats,confint)
importFrom(stats,confint.default)
importFrom(stats,contrasts)
importFrom(stats,cor)
importFrom(stats,deviance)
importFrom(stats,dnorm)
importFrom(stats,family)
importFrom(stats,formula)
importFrom(stats,glm)
importFrom(stats,lm)
importFrom(stats,logLik)
importFrom(stats,median)
importFrom(stats,model.frame)
importFrom(stats,model.matrix)
importFrom(stats,na.omit)
importFrom(stats,pnorm)
importFrom(stats,predict)
importFrom(stats,pt)
importFrom(stats,qnorm)
importFrom(stats,qt)
importFrom(stats,quantile)
importFrom(stats,rbinom)
importFrom(stats,relevel)
importFrom(stats,residuals)
importFrom(stats,rlnorm)
importFrom(stats,rnorm)
importFrom(stats,rpois)
importFrom(stats,runif)
importFrom(stats,sd)
importFrom(stats,setNames)
importFrom(stats,step)
importFrom(stats,terms)
importFrom(stats,terms.formula)
importFrom(stats,update)
importFrom(stats,weighted.mean)
importFrom(stats,wilcox.test)
importFrom(stringi,stri_trans_general)
importFrom(stringr,str_match)
importFrom(tidyr,gather)
importFrom(tidyr,spread)
importFrom(tidyselect,where)
importFrom(utils,as.relistable)
importFrom(utils,capture.output)
importFrom(utils,combn)
importFrom(utils,head)
importFrom(utils,relist)
importFrom(utils,tail)
importFrom(utils,write.table)
importFrom(vip,vi)
importFrom(xgboost,xgb.importance)
importFrom(xgboost,xgboost)
importFrom(yaml,yaml.load)
......@@ -7,7 +7,7 @@ svm <- function(dataset, rvar, evar,
form, data_filter = "", arr = "",
rows = NULL, envir = parent.frame()) {
## ---- Parameter validation (SVM-specific) ----
## ---- Parameter validation ----
valid_kernels <- c("linear", "radial", "poly", "sigmoid")
if (!kernel %in% valid_kernels) {
return(paste0("Kernel must be one of: ", paste(valid_kernels, collapse = ", ")) %>%
......@@ -38,11 +38,32 @@ svm <- function(dataset, rvar, evar,
df_name <- if (is_string(dataset)) dataset else deparse(substitute(dataset))
dataset <- get_data(dataset, vars, filt = data_filter, arr = arr, rows = rows, envir = envir)
is_cat <- sapply(dataset, function(x) is.factor(x) || is.character(x) || is.logical(x))
cat_vars <- names(is_cat)[is_cat]
if (length(cat_vars) > 0) {
# 2. Convert character columns to factors (keeps the subsequent encoding order consistent)
for (var in cat_vars) {
if (is.character(dataset[[var]])) {
dataset[[var]] <- as.factor(dataset[[var]])
}
# 3. Convert factors/logicals to numeric (encoded in original level order, e.g. level1 -> 1, level2 -> 2, ...)
dataset[[var]] <- as.numeric(dataset[[var]])
}
}
# Handle missing values
dataset <- na.omit(dataset)
if (nrow(dataset) == 0) {
return("No valid samples after removing missing values. Please check your data." %>% add_class("svm"))
}
# Standardize
if ("standardize" %in% check) {
dataset <- scale_df(dataset)
}
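# Editorial note (illustrative sketch, not part of the original commit): the
# preprocessing above encodes character/factor columns by their level order
# before scaling, e.g.
#   x <- c("low", "high", "low")
#   as.numeric(as.factor(x))   # levels are c("high", "low"), so this gives 2 1 2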
## ---- 3. Response variable for classification (convert to factor) ----
## ---- 3. Response variable for classification ----
if (type == "classification") {
dataset[[rvar]] <- as.factor(dataset[[rvar]])
if (lev == "") {
......@@ -71,10 +92,19 @@ svm <- function(dataset, rvar, evar,
if (!is.na(seed)) set.seed(seed)
## ---- 6. Train the model ----
model <- do.call(e1071::svm, svm_input)
## ---- Model training ----
model <- try({
do.call(e1071::svm, svm_input)
}, silent = TRUE)
## ---- 7. Attach key information ----
if (inherits(model, "try-error")) {
return(paste("Model training failed:", attr(model, "condition")$message) %>% add_class("svm"))
}
if (is.null(model) || !inherits(model, "svm")) {
return("Model training failed: Generated SVM model is invalid" %>% add_class("svm"))
}
## ---- Attach key information ----
model$df_name <- df_name
model$rvar <- rvar
model$evar <- evar
......@@ -86,7 +116,21 @@ svm <- function(dataset, rvar, evar,
model$gamma <- gamma
model$kernel <- kernel
as.list(environment()) %>% add_class(c("svm", "model"))
## ---- Return the model object ----
list(
model = model,
df_name = df_name,
rvar = rvar,
evar = evar,
type = type,
lev = lev,
wtsname = wtsname,
seed = seed,
cost = cost,
gamma = gamma,
kernel = kernel,
dataset = dataset
) %>% add_class(c("svm", "model"))
}
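# Usage sketch (editorial; the `diamonds` data and the argument values below are
# assumptions chosen for illustration, not taken from the original commit):
#   result <- svm(diamonds, rvar = "price", evar = c("carat", "clarity"),
#                 type = "regression", kernel = "radial",
#                 cost = 1, gamma = 0.1, seed = 1234)
#   summary(result)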
#' Center or standardize variables in a data frame
......@@ -101,9 +145,9 @@ scale_df <- function(dataset, center = TRUE, scale = TRUE,
descr <- attr(dataset, "description") # keep the original description attribute
if (calc) {
# Compute means (ignoring NAs)
# Compute means
ms <- sapply(dataset[, cn, drop = FALSE], function(x) mean(x, na.rm = TRUE))
# Compute standard deviations (ignoring NAs; sample SD, guarding against division by zero)
# Compute standard deviations
sds <- sapply(dataset[, cn, drop = FALSE], function(x) {
sd_val <- sd(x, na.rm = TRUE)
ifelse(sd_val == 0, 1, sd_val)
......@@ -135,7 +179,7 @@ summary.svm <- function(object, prn = TRUE, ...) {
svm_model <- object$model
n_obs <- nrow(object$dataset)
wtsname <- object$wtsname # may be NULL or a zero-length character
wtsname <- object$wtsname
cat("Support Vector Machine\n")
cat(sprintf("Kernel type : %s (%s)\n", object$kernel, object$type))
......@@ -146,7 +190,7 @@ summary.svm <- function(object, prn = TRUE, ...) {
}
cat(sprintf("Explanatory variables: %s\n", paste(object$evar, collapse = ", ")))
if (!is.null(wtsname) && length(wtsname) && wtsname != "") {
if (!is.null(wtsname) && nzchar(wtsname)) {
cat(sprintf("Weights used : %s\n", wtsname))
}
......@@ -156,56 +200,79 @@ summary.svm <- function(object, prn = TRUE, ...) {
}
if (!is.na(object$seed)) cat(sprintf("Seed : %s\n", object$seed))
# Count support vectors
if (object$type == "classification") {
n_sv_per_class <- svm_model$nSV
total_sv <- sum(n_sv_per_class)
n_sv_per_class <- if (!is.null(svm_model$nSV)) svm_model$nSV else c(0, 0)
total_sv <- sum(n_sv_per_class)
sv_info <- paste(sprintf("class %s: %d", seq_along(n_sv_per_class), n_sv_per_class), collapse = ", ")
cat(sprintf("Support vectors : %d (%s)\n", total_sv, sv_info))
} else {
total_sv <- sum(svm_model$nSV)
total_sv <- if (!is.null(svm_model$index)) length(svm_model$index) else 0
cat(sprintf("Support vectors : %d\n", total_sv))
}
## ---- Guard the weighted sample-size calculation as well ----
if (!is.null(wtsname) && length(wtsname) && wtsname != "") {
nr_obs <- n_obs
if (!is.null(wtsname) && nzchar(wtsname) && wtsname %in% names(object$dataset)) {
weights_values <- as.numeric(object$dataset[[wtsname]])
nr_obs <- if (all(!is.na(weights_values))) sum(weights_values, na.rm = TRUE) else n_obs
} else {
nr_obs <- n_obs
weights_clean <- weights_values[!is.na(weights_values)]
if (length(weights_clean) > 0) {
sum_weights <- sum(weights_clean)
if (sum_weights > 0) nr_obs <- sum_weights
}
}
cat(sprintf("Nr_obs : %s\n", format_nr(nr_obs, dec = 0)))
## ---- Coefficient output (linear kernel only) ----
if (prn) {
cat("Coefficients/Support Vectors:\n")
if (object$kernel == "linear") {
if (is.null(svm_model$w) || length(svm_model$w) == 0) {
cat(" Linear kernel coefficients not available (possible reasons:\n")
cat(" - Insufficient support vectors (data may be linearly inseparable)\n")
cat(" - Model training did not converge)\n")
if (object$type == "classification") {
cat(sprintf(" Support vectors count: %d\n", sum(svm_model$nSV)))
}
} else {
feat_coefs_raw <- as.numeric(svm_model$w)
bias <- as.numeric(-svm_model$rho)
n_evar <- length(object$evar)
feat_coefs <- matrix(data = feat_coefs_raw, nrow = 1, ncol = n_evar, byrow = TRUE,
dimnames = list(NULL, object$evar))
coef_data <- data.frame(
Variable = c(object$evar, "bias"),
Value = as.numeric(c(as.vector(feat_coefs), bias)),
stringsAsFactors = FALSE,
check.names = FALSE
)
for (i in seq(1, nrow(coef_data), 2)) {
if (i + 1 > nrow(coef_data)) {
cat(sprintf(" %-12s: %.2f\n", coef_data$Variable[i], coef_data$Value[i]))
} else {
cat(sprintf(" %-12s: %.2f %-12s: %.2f\n",
coef_data$Variable[i], coef_data$Value[i],
coef_data$Variable[i+1], coef_data$Value[i+1]))
feat_coefs_raw <- rep(0, length(object$evar))
bias <- 0
if (!is.null(svm_model$terms) && !is.null(svm_model$coefs) && !is.null(svm_model$SV)) {
tryCatch({
model_vars <- attr(svm_model$terms, "term.labels")
coef_mat <- as.matrix(svm_model$coefs)
sv_mat <- as.matrix(svm_model$SV)
if (nrow(coef_mat) > 0 && nrow(sv_mat) > 0) {
w_full <- as.numeric(t(coef_mat) %*% sv_mat)
sv_colnames <- colnames(sv_mat)
if (!is.null(sv_colnames)) {
# Match coefficients to variables exactly by name
for (i in seq_along(object$evar)) {
var_name <- object$evar[i]
matching_cols <- which(sv_colnames == var_name)
if (length(matching_cols) > 0) {
feat_coefs_raw[i] <- sum(w_full[matching_cols])
}
}
} else {
# No column names available: fall back to positional matching
n_evar <- length(object$evar)
feat_coefs_raw[1:min(n_evar, length(w_full))] <- w_full[1:min(n_evar, length(w_full))]
}
bias <- as.numeric(-svm_model$rho)
}
}, error = function(e) {
warning("系数提取失败: ", e$message, call. = FALSE)
})
}
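# Editorial note: for a linear kernel, e1071 stores the dual coefficients
# (alpha_i * y_i) in `coefs` and the support vectors in `SV`, so the primal
# weight vector computed above is w = t(coefs) %*% SV and the intercept of the
# decision function f(x) = w . x - rho is -rho.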
# Print the coefficients
coef_data <- data.frame(
Variable = c(object$evar, "bias"),
Value = c(feat_coefs_raw, bias),
stringsAsFactors = FALSE
)
for (i in seq(1, nrow(coef_data), 2)) {
if (i + 1 > nrow(coef_data)) {
cat(sprintf(" %-12s: %.2f\n", coef_data$Variable[i], coef_data$Value[i]))
} else {
cat(sprintf(" %-12s: %.2f %-12s: %.2f\n",
coef_data$Variable[i], coef_data$Value[i],
coef_data$Variable[i+1], coef_data$Value[i+1]))
}
}
} else {
......@@ -256,134 +323,234 @@ plot.svm <- function(x,
return("No valid plots selected for SVM")
}
## Return a patchwork object (Shiny prints it automatically)
## Return a patchwork object
patchwork::wrap_plots(plot_list, ncol = min(2, length(plot_list))) %>%
{ if (isTRUE(shiny)) . else print(.) }
}
#' Predict method for the svm function
#' Predict method for SVM model
#' @export
predict.svm <- function(object, pred_data = NULL, pred_cmd = "",
dec = 3, envir = parent.frame(), ...) {
## 1. Basic validation
predict.svm <- function(object, pred_data = NULL, pred_cmd = "", dec = 3,
envir = parent.frame(), ...) {
# 1. Basic validation
if (is.character(object)) return(object)
if (!inherits(object, "svm")) stop("Object must be of class 'svm'")
if (is.null(object$model) || !inherits(object$model, "svm")) {
err_msg <- "Prediction failed: Invalid SVM model"
err_obj <- structure(list(error = err_msg), class = "svm.predict.error")
return(err_obj)
}
svm_info <- object
# 2. Resolve the prediction data
if (is.character(pred_data) && nzchar(pred_data)) {
if (!exists(pred_data, envir = envir)) {
err_obj <- structure(list(error = sprintf("Dataset '%s' not found", pred_data)), class = "svm.predict.error")
return(err_obj)
}
pred_data <- get(pred_data, envir = envir)
}
has_data <- !is.null(pred_data) && is.data.frame(pred_data) && nrow(pred_data) > 0
has_cmd <- is.character(pred_cmd) && nzchar(pred_cmd)
if (!has_data && !has_cmd) {
err_obj <- structure(list(error = "Please select data and/or specify a command to generate predictions."), class = "svm.predict.error")
return(err_obj)
}
## 2.1 Determine the prediction data source
if (is.null(pred_data) || is.character(pred_data)) {
# When pred_data is NULL or a character string (dataset name), use the training data
pred_data_raw <- object$dataset[, object$evar, drop = FALSE]
## ensure you have a name for the prediction dataset
if (is.data.frame(pred_data)) {
df_name <- deparse(substitute(pred_data))
} else {
# When pred_data is a data frame, use it directly
pred_data_raw <- as.data.frame(pred_data)
df_name <- pred_data
}
pred_names <- colnames(pred_data_raw)
missing_vars <- setdiff(object$evar, pred_names)
if (length(missing_vars) > 0) {
msg <- paste0(
  "Missing variables in prediction data: ",
  paste(missing_vars, collapse = ", "), "\n"
)
return(msg %>% add_class("svm.predict"))
}
pred_data <- pred_data_raw[, object$evar, drop = FALSE]
if (!is.empty(pred_cmd)) {
pred_cmd <- gsub("\\s{2,}", " ", pred_cmd) %>%
gsub(";\\s+", ";", .) %>%
strsplit(";")[[1]]
for (cmd in pred_cmd) {
if (grepl("=", cmd)) {
var_val <- strsplit(trimws(cmd), "=")[[1]]
var <- trimws(var_val[1])
val <- try(eval(parse(text = trimws(var_val[2])), envir = envir), silent = TRUE)
if (!inherits(val, "try-error")) pred_data[[var]] <- val
else warning(sprintf("Invalid command '%s': using original values", cmd))
}
# 3. Core prediction function - argument names changed to match the NN implementation
pfun <- function(model, pred, se, conf_lev) {
# Validate the prediction data
missing_vars <- setdiff(svm_info$evar, colnames(pred))
if (length(missing_vars) > 0) {
return(paste("Missing variables:", paste(missing_vars, collapse = ", ")))
}
pred_data <- unique(pred_data)
}
## 3. Check variable types and factor levels
train_types <- sapply(object$dataset[, object$evar], class)
pred_types <- sapply(pred_data, class)
type_mismatch <- names(which(train_types != pred_types))
if (length(type_mismatch) > 0) {
return(paste0("Variable type mismatch (train vs pred):\n",
paste(sprintf(" %s: %s vs %s", type_mismatch,
train_types[type_mismatch], pred_types[type_mismatch]),
collapse = "\n")) %>% add_class("svm.predict"))
}
for (var in object$evar) {
if (is.factor(object$dataset[[var]])) {
train_levs <- levels(object$dataset[[var]])
pred_levs <- levels(pred_data[[var]])
if (!identical(train_levs, pred_levs)) {
pred_data[[var]] <- factor(pred_data[[var]], levels = train_levs)
warning(sprintf("Aligned factor levels for '%s' to match training data", var))
# Internal standardization
ms <- attr(svm_info$dataset, "radiant_ms")
sds <- attr(svm_info$dataset, "radiant_sds")
if (!is.null(ms) && !is.null(sds)) {
isNum <- sapply(pred, is.numeric)
cn <- names(isNum)[isNum]
for (var in cn) {
if (var %in% names(ms) && var %in% names(sds) && sds[var] != 0) {
pred[[var]] <- (pred[[var]] - ms[var]) / sds[var]
}
}
}
}
## 4. Align standardization with the training data
train_ms <- attr(object$dataset, "radiant_ms")
train_sds <- attr(object$dataset, "radiant_sds")
train_sf <- attr(object$dataset, "radiant_sf") %||% 2
if (!is.null(train_ms) && !is.null(train_sds)) {
pred_data <- scale_df(
dataset = pred_data,
center = TRUE, scale = TRUE,
sf = train_sf, wts = NULL,
calc = FALSE
)
attr(pred_data, "radiant_ms") <- train_ms
attr(pred_data, "radiant_sds") <- train_sds
}
## 5. Generate predictions
predict_args <- list(
object = object$model,
newdata = pred_data,
na.action = na.omit
)
pred_result <- if (object$type == "classification") {
predict_args$type <- "class"
pred_class <- do.call(predict, predict_args)
if (isTRUE(object$model$param$probability)) {
predict_args$type <- "probabilities"
pred_prob <- do.call(predict, predict_args) %>%
as.data.frame() %>%
set_colnames(paste0("Prob_", colnames(.)))
data.frame(
Predicted_Class = as.character(pred_class),
pred_prob,
stringsAsFactors = FALSE
# Debug output: check whether the model was fit with probability estimates enabled
cat("Model probability enabled:", svm_info$model$prob.model, "\n")
# Run the prediction
pred_result <- try({
predict(
svm_info$model,
newdata = pred,
probability = TRUE, # always set to TRUE
decision.values = TRUE
)
}, silent = TRUE)
if (inherits(pred_result, "try-error")) {
return(paste("Prediction failed:", attr(pred_result, "condition")$message))
}
# 4. Assemble the results
if (svm_info$type == "classification") {
# The correct attribute name is "probabilities"
prob_mat <- attr(pred_result, "probabilities")
if (!is.null(prob_mat)) {
# Debug output: show the probability matrix column names
cat("Probability matrix columns:", paste(colnames(prob_mat), collapse = ", "), "\n")
# More robust column-name matching
target_level <- as.character(svm_info$lev)
# Try several matching strategies
matching_col <- NULL
# 1. Exact match
exact_match <- grep(paste0("^", target_level, "$"), colnames(prob_mat), value = TRUE, ignore.case = TRUE)
if (length(exact_match) > 0) {
matching_col <- exact_match
}
# 2. Partial match
else {
partial_match <- grep(target_level, colnames(prob_mat), value = TRUE, ignore.case = TRUE)
if (length(partial_match) > 0) {
matching_col <- partial_match[1] # take the first match
}
}
# 3. If there is still no match, try matching on logical-style labels
if (is.null(matching_col) || length(matching_col) == 0) {
if (target_level %in% c("TRUE", "Yes", "1", "1.0")) {
logical_match <- grep("TRUE", colnames(prob_mat), value = TRUE, ignore.case = TRUE)
if (length(logical_match) > 0) {
matching_col <- logical_match[1]
}
} else if (target_level %in% c("FALSE", "No", "0", "0.0")) {
logical_match <- grep("FALSE", colnames(prob_mat), value = TRUE, ignore.case = TRUE)
if (length(logical_match) > 0) {
matching_col <- logical_match[1]
}
}
}
# 4. As a last resort, use the first column
if (is.null(matching_col) || length(matching_col) == 0) {
if (ncol(prob_mat) > 0) {
matching_col <- colnames(prob_mat)[1]
cat("Warning: Using first probability column '", matching_col, "' for level '", target_level, "'\n")
}
}
if (!is.null(matching_col) && length(matching_col) > 0) {
surv_prob <- as.numeric(prob_mat[, matching_col])
pred_df <- data.frame(
Prediction = round(surv_prob, dec),
stringsAsFactors = FALSE
)
} else {
# Fallback: use the decision values
decision_vals <- attr(pred_result, "decision.values")
if (!is.null(decision_vals)) {
# Convert decision values to probabilities via a logistic transform
pred_prob <- 1 / (1 + exp(-decision_vals))
pred_df <- data.frame(
Prediction = round(pred_prob, dec),
stringsAsFactors = FALSE
)
} else {
# Last resort: use the predicted class
pred_class <- as.character(pred_result)
pred_prob <- ifelse(pred_class == target_level, 1, 0)
pred_df <- data.frame(
Prediction = round(pred_prob, dec),
stringsAsFactors = FALSE
)
}
}
} else {
# Fallback: use the decision values
decision_vals <- attr(pred_result, "decision.values")
if (!is.null(decision_vals)) {
# Convert decision values to probabilities via a logistic transform
pred_prob <- 1 / (1 + exp(-decision_vals))
pred_df <- data.frame(
Prediction = round(pred_prob, dec),
stringsAsFactors = FALSE
)
} else {
# Last resort: use the predicted class
target_level <- as.character(svm_info$lev)
pred_class <- as.character(pred_result)
pred_prob <- ifelse(pred_class == target_level, 1, 0)
pred_df <- data.frame(
Prediction = round(pred_prob, dec),
stringsAsFactors = FALSE
)
}
}
} else {
data.frame(Predicted_Class = as.character(pred_class), stringsAsFactors = FALSE)
# Regression model
pred_values <- as.numeric(pred_result)
# Important: reverse the response-variable standardization
rvar_ms <- if (!is.null(ms) && svm_info$rvar %in% names(ms)) ms[svm_info$rvar] else 0
rvar_sds <- if (!is.null(sds) && svm_info$rvar %in% names(sds)) sds[svm_info$rvar] else 1
# De-standardize the predictions
pred_values <- pred_values * rvar_sds + rvar_ms
pred_df <- data.frame(
Prediction = round(pred_values, dec),
stringsAsFactors = FALSE
)
}
} else {
predict_args$type <- "response"
pred_val <- do.call(predict, predict_args)
pred_val <- as.numeric(pred_val)
data.frame(Predicted_Value = round(pred_val, dec), stringsAsFactors = FALSE)
}
pred_result <- cbind(pred_data, pred_result)
attr(pred_result, "svm_meta") <- list(
kernel = object$kernel,
cost = object$cost,
gamma = if (object$kernel != "linear") object$gamma else NA,
seed = object$seed,
train_data = object$df_name,
model_type = object$type
return(pred_df)
}
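# Editorial note: the fallback above maps raw decision values through
# 1 / (1 + exp(-d)). This produces scores in (0, 1) but they are not the
# Platt-scaled probabilities e1071 returns when the model is fit with
# probability = TRUE, so they should be read as scores rather than calibrated
# class probabilities.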
# 5. Call the prediction framework - fully consistent with the NN implementation
result <- predict_model(
object,
pfun,
"svm.predict", # 模型类型
pred_data,
pred_cmd,
conf_lev = 0.95,
se = FALSE,
dec,
envir = envir
) %>%
set_attr("radiant_pred_data", df_name)
# 6. Attach result metadata
if (inherits(result, "svm.predict.error")) {
return(result$error)
}
if (is.character(result)) return(result)
result <- add_class(result, "svm.predict")
attr(result, "svm_meta") <- list(
model_type = svm_info$type,
kernel = svm_info$kernel,
cost = svm_info$cost,
gamma = if (svm_info$kernel != "linear") svm_info$gamma else NA,
seed = svm_info$seed,
train_data = svm_info$df_name,
dec = dec
)
attr(pred_result, "class") <- c("svm.predict", "data.frame")
return(pred_result)
return(result)
}
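# Usage sketch (editorial; `result` is a fitted svm object and `new_data` a
# hypothetical data frame containing the same explanatory variables):
#   pred <- predict(result, pred_data = new_data, dec = 3)
#   print(pred, n = 10)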
#' Print method for predict.svm
......@@ -396,16 +563,30 @@ print.svm.predict <- function(x, ..., n = 10) {
if (!inherits(x, "svm.predict")) stop("Object must be of class 'svm.predict'")
meta <- attr(x, "svm_meta")
if (is.null(meta)) {
meta <- list(
model_type = "Unknown",
kernel = "Unknown",
cost = NA,
gamma = NA,
seed = NA,
train_data = "Unknown",
dec = 1
)
}
dec <- meta$dec %||% 1
n_pred <- nrow(x)
show_n <- if (n < 0 || n >= n_pred) n_pred else n
cat("SVM Predictions\n")
cat(sprintf("Model Type : %s\n", tools::toTitleCase(meta$model_type)))
cat(sprintf("Kernel : %s\n", meta$kernel))
cat(sprintf("Cost (C) : %.2f\n", meta$cost))
model_type <- if (is.empty(meta$model_type)) "Unknown" else as.character(meta$model_type)
cat(sprintf("Model Type : %s\n", tools::toTitleCase(model_type)))
cat(sprintf("Kernel : %s\n", ifelse(is.empty(meta$kernel), "Unknown", meta$kernel)))
cat(sprintf("Cost (C) : %.2f\n", ifelse(is.na(meta$cost), 0, meta$cost)))
if (!is.na(meta$gamma)) cat(sprintf("Gamma : %.2f\n", meta$gamma))
if (!is.na(meta$seed)) cat(sprintf("Seed : %s\n", meta$seed))
cat(sprintf("Training Dataset : %s\n", meta$train_data))
cat(sprintf("Training Dataset : %s\n", ifelse(is.empty(meta$train_data), "Unknown", meta$train_data)))
cat(sprintf("Total Predictions : %d\n", n_pred))
if (n_pred == 0) {
......@@ -413,8 +594,13 @@ print.svm.predict <- function(x, ..., n = 10) {
return(invisible(x))
}
x_show <- x[1:show_n, , drop = FALSE] # make sure the data.frame structure is preserved
x_show <- x[1:show_n, , drop = FALSE]
num_cols <- sapply(x_show, is.numeric)
x_show[num_cols] <- lapply(x_show[num_cols], function(col) {
sprintf(paste0("%.", dec, "f"), col)
})
# Recompute column widths
col_widths <- sapply(colnames(x_show), function(cn) {
max(nchar(cn), max(nchar(as.character(x_show[[cn]])), na.rm = TRUE))
})
......@@ -435,38 +621,13 @@ print.svm.predict <- function(x, ..., n = 10) {
}
if (show_n < n_pred) {
cat(sprintf("\n... (showing first %d of %d; use 'n=-1' to view all)\n",
cat(sprintf("\n... (showing first %d of %d)\n",
show_n, n_pred))
}
return(invisible(x))
}
#' Cross-validation for SVM
#' @export
cv.svm <- function(object, K = 5, repeats = 1,
kernel = c("linear", "radial"),
cost = 2^(-2:2),
gamma = 2^(-2:2),
seed = 1234, trace = TRUE, fun, ...) {
if (!inherits(object, "svm")) stop("Object must be of class 'svm'")
tune_grid <- expand.grid(
kernel = kernel,
cost = cost,
gamma = if (any(kernel %in% c("radial", "poly", "sigmoid"))) gamma else NA
)
cv_results <- data.frame(
mean_perf = rep(NA, nrow(tune_grid)),
std_perf = rep(NA, nrow(tune_grid)),
kernel = tune_grid$kernel,
cost = tune_grid$cost,
gamma = tune_grid$gamma
)
cv_results
}
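# Editorial note: cv.svm above only builds the tuning grid and returns NA
# placeholders for mean_perf/std_perf. A minimal sketch of actually tuning cost
# and gamma with e1071 (assuming a model formula `form` and the training data
# are available) could look like:
#   e1071::tune.svm(form, data = object$dataset, kernel = "radial",
#                   cost = 2^(-2:2), gamma = 2^(-2:2),
#                   tunecontrol = e1071::tune.control(cross = K))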
# ---- Decision boundary ----
#' @export
svm_boundary_plot <- function(object, size, custom) {
......
......@@ -79,11 +79,12 @@ options(
i18n$t("Estimate"),
tabPanel(i18n$t("Linear regression (OLS)"), uiOutput("regress")),
tabPanel(i18n$t("Logistic regression (GLM)"), uiOutput("logistic")),
tabPanel(i18n$t("Cox Proportional Hazards Regression"),uiOutput("coxp")),
tabPanel(i18n$t("Multinomial logistic regression (MNL)"), uiOutput("mnl")),
tabPanel(i18n$t("Naive Bayes"), uiOutput("nb")),
tabPanel(i18n$t("Neural Network"), uiOutput("nn")),
tabPanel(i18n$t("Support Vector Machine (SVM)"),uiOutput("svm")),
"----", i18n$t("Survival Analysis"),
tabPanel(i18n$t("Cox Proportional Hazards Regression"),uiOutput("coxp")),
"----", i18n$t("Trees"),
tabPanel(i18n$t("Classification and regression trees"), uiOutput("crtree")),
tabPanel(i18n$t("Random Forest"), uiOutput("rf")),
......
......@@ -23,7 +23,7 @@ svm_inputs <- reactive({
svm_args
})
## Prediction arguments (command mode kept, unchanged)
## Prediction arguments
svm_pred_args <- as.list(if (exists("predict.svm")) {
formals(predict.svm)
} else {
......@@ -52,7 +52,7 @@ svm_pred_inputs <- reactive({
return(svm_pred_args)
})
## Plot arguments (vip, pdp, and svm_margin removed)
## Plot arguments
svm_plot_args <- as.list(if (exists("plot.svm")) {
formals(plot.svm)
} else {
......@@ -77,11 +77,7 @@ output$ui_svm_rvar <- renderUI({
vars <- varnames()[isNum]
}
})
init <- if (input$svm_type == "classification") {
if (is.empty(input$logit_rvar)) isolate(input$svm_rvar) else input$logit_rvar
} else {
if (is.empty(input$reg_rvar)) isolate(input$svm_rvar) else input$reg_rvar
}
init <- state_single("svm_rvar", vars, isolate(input$svm_rvar))
selectInput(
inputId = "svm_rvar",
label = i18n$t("Response variable:"),
......@@ -109,11 +105,7 @@ output$ui_svm_evar <- renderUI({
if (not_available(input$svm_rvar)) return()
vars <- varnames()
if (length(vars) > 0) vars <- vars[-which(vars == input$svm_rvar)]
init <- if (input$svm_type == "classification") {
if (is.empty(input$logit_evar)) isolate(input$svm_evar) else input$logit_evar
} else {
if (is.empty(input$reg_evar)) isolate(input$svm_evar) else input$reg_evar
}
init <- state_multiple("svm_evar", vars, isolate(input$svm_evar))
selectInput(
inputId = "svm_evar",
label = i18n$t("Explanatory variables:"),
......@@ -141,7 +133,7 @@ output$ui_svm_wts <- renderUI({
)
})
## UI for storing predictions (residual storage removed)
## UI for storing predictions
output$ui_svm_store_pred_name <- renderUI({
init <- state_init("svm_store_pred_name", "pred_svm") %>%
sub("\\d{1,}$", "", .) %>%
......@@ -164,7 +156,7 @@ observeEvent(input$svm_type, {
updateSelectInput(session = session, inputId = "svm_plots", selected = "none")
})
## Plot options UI (vip, pdp, and svm_margin removed)
## Plot options UI
output$ui_svm_plots <- renderUI({
req(input$svm_type)
avail_plots <- svm_plots
......@@ -178,19 +170,7 @@ output$ui_svm_plots <- renderUI({
)
})
## Number-of-data-points UI (dashboard only, kept)
output$ui_svm_nrobs <- renderUI({
nrobs <- nrow(.get_data())
choices <- c("1,000" = 1000, "5,000" = 5000, "10,000" = 10000, "All" = -1) %>%
.[. < nrobs]
selectInput(
"svm_nrobs", i18n$t("Number of data points plotted:"),
choices = choices,
selected = state_single("svm_nrobs", choices, 1000)
)
})
## Main UI panel (residual storage entry removed)
## Main UI panel
output$ui_svm <- renderUI({
req(input$dataset)
tagList(
......@@ -258,7 +238,7 @@ output$ui_svm <- renderUI({
)
)
),
# Predict panel (residual storage removed)
# Predict panel
conditionalPanel(
condition = "input.tabs_svm == 'Predict'",
selectInput(
......@@ -303,19 +283,10 @@ output$ui_svm <- renderUI({
)
)
),
# Plot panel (vip, pdp, and svm_margin removed)
# Plot panel
conditionalPanel(
condition = "input.tabs_svm == 'Plot'",
uiOutput("ui_svm_plots"),
conditionalPanel(
condition = "input.svm_plots == 'pred_plot'",
uiOutput("ui_svm_incl"),
uiOutput("ui_svm_incl_int")
),
conditionalPanel(
condition = "input.svm_plots == 'dashboard'",
uiOutput("ui_svm_nrobs")
)
uiOutput("ui_svm_plots")
)
),
# Help and report panel
......@@ -327,7 +298,7 @@ output$ui_svm <- renderUI({
)
})
## Plot size calculation (vip, pdp, and svm_margin removed)
## Plot size calculation
svm_plot <- reactive({
if (svm_available() != "available") return()
if (is.empty(input$svm_plots, "none")) return()
......@@ -337,9 +308,6 @@ svm_plot <- reactive({
plot_width <- 650
if ("decision_boundary" %in% input$svm_plots) {
plot_height <- 500
} else if (input$svm_plots == "pred_plot") {
nr_vars <- length(input$svm_incl) + length(input$svm_incl_int)
plot_height <- max(250, ceiling(nr_vars / 2) * 250)
} else {
plot_height <- max(500, length(res$evar) * 30)
}
......@@ -349,7 +317,7 @@ svm_plot <- reactive({
svm_plot_width <- function() svm_plot()$plot_width %||% 650
svm_plot_height <- function() svm_plot()$plot_height %||% 500
## Main output panel (residual storage removed)
## Main output panel
output$svm <- renderUI({
register_print_output("summary_svm", ".summary_svm")
register_print_output("predict_svm", ".predict_print_svm")
......@@ -393,14 +361,14 @@ svm_available <- reactive({
}
})
## Core function (stub)
## Core function
.svm <- eventReactive(input$svm_run, {
svi <- svm_inputs()
svi$envir <- r_data
withProgress(message = i18n$t("Estimating SVM model"), value = 1, do.call(svm, svi))
})
## Helper output functions (stubs)
## Helper output functions
.summary_svm <- reactive({
if (not_pressed(input$svm_run)) return(i18n$t("** Press the Estimate button to estimate the SVM model **"))
if (svm_available() != "available") return(svm_available())
......@@ -444,7 +412,7 @@ svm_available <- reactive({
withProgress(message = i18n$t("Generating SVM plots"), value = 1, do.call(plot, c(list(x = .svm()), pinp)))
})
## Store predictions (residual storage removed)
## Store predictions
observeEvent(input$svm_store_pred, {
req(
pressed(input$svm_run),
......@@ -457,14 +425,14 @@ observeEvent(input$svm_store_pred, {
base_col_name <- fix_names(input$svm_store_pred_name)
meta <- attr(pred_result, "svm_meta")
pred_cols <- if (meta$model_type == "classification") {
colnames(pred_result)[grepl("^Predicted_Class|^Prob_", colnames(pred_result))]
pred_cols <- if (meta$model_type %in% c("classification", "regression")) {
colnames(pred_result)[colnames(pred_result) == "Prediction"]
} else {
"Predicted_Value"
NULL
}
new_col_names <- if (length(pred_cols) == 1) base_col_name else {
suffix <- gsub("^Predicted_|^Prob_", "", pred_cols)
paste0(base_col_name, "_", suffix)
suffix <- gsub("^Prediction", "", pred_cols)
paste0(base_col_name, ifelse(suffix == "", "", paste0("_", suffix)))
}
colnames(pred_result)[match(pred_cols, colnames(pred_result))] <- new_col_names
......@@ -510,9 +478,75 @@ download_handler(
height = svm_plot_height
)
## Report generation (empty stub, interface kept)
svm_report <- function() invisible()
svm_report <- function() {
if (is.empty(input$svm_evar)) {
showNotification(i18n$t("Select at least one explanatory variable to generate report"), type = "error")
return(invisible())
}
outputs <- c("summary")
inp_out <- list(list(prn = TRUE), "")
figs <- FALSE
xcmd <- ""
if (!is.empty(input$svm_plots, "none")) {
inp <- check_plot_inputs(svm_plot_inputs())
inp$size <- NULL
inp_out[[2]] <- clean_args(inp, svm_plot_args[-1])
inp_out[[2]]$custom <- FALSE
outputs <- c(outputs, "plot")
figs <- TRUE
}
if (!is.empty(input$svm_predict, "none") &&
(!is.empty(input$svm_pred_data) || !is.empty(input$svm_pred_cmd))) {
pred_args <- clean_args(svm_pred_inputs(), svm_pred_args[-1])
if (!is.empty(pred_args$pred_cmd)) {
pred_args$pred_cmd <- strsplit(pred_args$pred_cmd, ";\\s*")[[1]]
} else {
pred_args$pred_cmd <- NULL
}
if (!is.empty(pred_args$pred_data)) {
pred_args$pred_data <- as.symbol(pred_args$pred_data)
} else {
pred_args$pred_data <- NULL
}
inp_out[[2 + figs]] <- pred_args
outputs <- c(outputs, "pred <- predict")
xcmd <- paste0(xcmd, "print(pred, n = 10)")
if (input$svm_predict %in% c("data", "datacmd") && !is.empty(input$svm_store_pred_name)) {
fixed <- fix_names(input$svm_store_pred_name)
updateTextInput(session, "svm_store_pred_name", value = fixed)
xcmd <- paste0(
xcmd, "\n", input$svm_pred_data, " <- store(",
input$svm_pred_data, ", pred, name = \"", fixed, "\")"
)
}
}
svm_inp <- svm_inputs()
if (input$svm_type == "regression") {
svm_inp$lev <- NULL
}
if (input$svm_kernel == "linear") {
svm_inp$gamma <- NULL
}
update_report(
inp_main = clean_args(svm_inp, svm_args),
fun_name = "svm",
inp_out = inp_out,
outputs = outputs,
figs = figs,
fig.width = svm_plot_width(),
fig.height = svm_plot_height(),
xcmd = xcmd
)
}
## Report generation
observeEvent(input$svm_report, {
r_info[["latest_screenshot"]] <- NULL
svm_report()
......
% Generated by roxygen2: do not edit by hand
% Please edit documentation in R/svm.R
\name{cv.svm}
\alias{cv.svm}
\title{Cross-validation for SVM}
\usage{
cv.svm(
object,
K = 5,
repeats = 1,
kernel = c("linear", "radial"),
cost = seq(0.1, 10, by = 0.5),
gamma = seq(0.1, 5, by = 0.5),
seed = 1234,
trace = TRUE,
fun,
...
)
}
\description{
Cross-validation for SVM
}
......@@ -4,17 +4,7 @@
\alias{plot.svm}
\title{Plot method for the svm function}
\usage{
\method{plot}{svm}(
x,
plots = "vip",
size = 12,
nrobs = -1,
incl = NULL,
incl_int = NULL,
shiny = FALSE,
custom = FALSE,
...
)
\method{plot}{svm}(x, plots = "none", size = 12, shiny = FALSE, custom = FALSE, ...)
}
\description{
Plot method for the svm function
......
......@@ -2,7 +2,7 @@
% Please edit documentation in R/svm.R
\name{predict.svm}
\alias{predict.svm}
\title{Predict method for the svm function}
\title{Predict method for SVM model}
\usage{
\method{predict}{svm}(
object,
......@@ -14,5 +14,5 @@
)
}
\description{
Predict method for the svm function
Predict method for SVM model
}
......@@ -4,9 +4,9 @@
\alias{varimp}
\title{Variable importance using the vip package and permutation importance}
\usage{
varimp(object, rvar, lev, data = NULL, seed = 1234)
varimp(object, rvar = NULL, lev = NULL, data = NULL, seed = 1234, nperm = 10)
varimp(object, rvar, lev, data = NULL, seed = 1234)
varimp(object, rvar = NULL, lev = NULL, data = NULL, seed = 1234, nperm = 10)
}
\arguments{
\item{object}{Model object created by Radiant}
......@@ -22,5 +22,5 @@ varimp(object, rvar, lev, data = NULL, seed = 1234)
\description{
Variable importance using the vip package and permutation importance
Variable importance using the vip package and permutation importance
Variable importance for SVM using permutation importance
}
% Generated by roxygen2: do not edit by hand
% Please edit documentation in R/nn.R, R/svm.R
% Please edit documentation in R/nn.R
\name{varimp_plot}
\alias{varimp_plot}
\title{Plot permutation importance}
\usage{
varimp_plot(object, rvar, lev, data = NULL, seed = 1234)
varimp_plot(object, rvar, lev, data = NULL, seed = 1234)
}
\arguments{
......@@ -20,7 +18,5 @@ varimp_plot(object, rvar, lev, data = NULL, seed = 1234)
\item{seed}{Random seed for reproducibility}
}
\description{
Plot permutation importance
Plot permutation importance
}