Commit 53dd5828 authored by wuzekai's avatar wuzekai

更新了svm

parent e9df660e
......@@ -85,7 +85,6 @@ export(cv.crtree)
export(cv.gbt)
export(cv.nn)
export(cv.rforest)
export(cv.svm)
export(dtree)
export(dtree_parser)
export(evalbin)
......@@ -121,6 +120,9 @@ export(sim_splitter)
export(sim_summary)
export(simulater)
export(svm)
export(svm_boundary_plot)
export(svm_margin_plot)
export(svm_vip_plot)
export(test_specs)
export(uplift)
export(var_check)
......
......@@ -7,7 +7,7 @@ svm <- function(dataset, rvar, evar,
form, data_filter = "", arr = "",
rows = NULL, envir = parent.frame()) {
## ---- 参数合法性检查(SVM特有) ----
## ---- 参数合法性检查----
valid_kernels <- c("linear", "radial", "poly", "sigmoid")
if (!kernel %in% valid_kernels) {
return(paste0("Kernel must be one of: ", paste(valid_kernels, collapse = ", ")) %>%
......@@ -38,11 +38,32 @@ svm <- function(dataset, rvar, evar,
df_name <- if (is_string(dataset)) dataset else deparse(substitute(dataset))
dataset <- get_data(dataset, vars, filt = data_filter, arr = arr, rows = rows, envir = envir)
is_cat <- sapply(dataset, function(x) is.factor(x) || is.character(x) || is.logical(x))
cat_vars <- names(is_cat)[is_cat]
if (length(cat_vars) > 0) {
# 2. 字符型转因子(确保后续编码顺序一致)
for (var in cat_vars) {
if (is.character(dataset[[var]])) {
dataset[[var]] <- as.factor(dataset[[var]])
}
# 3. 因子/逻辑型转数值(按原始水平顺序编码,如level1→1, level2→2...)
dataset[[var]] <- as.numeric(dataset[[var]])
}
}
# 缺失值处理
dataset <- na.omit(dataset)
if (nrow(dataset) == 0) {
return("No valid samples after removing missing values. Please check your data." %>% add_class("svm"))
}
# 标准化
if ("standardize" %in% check) {
dataset <- scale_df(dataset)
}
## ---- 3. 分类任务的响应变量(转为因子) ----
## ---- 3. 分类任务的响应变量 ----
if (type == "classification") {
dataset[[rvar]] <- as.factor(dataset[[rvar]])
if (lev == "") {
......@@ -71,10 +92,19 @@ svm <- function(dataset, rvar, evar,
if (!is.na(seed)) set.seed(seed)
## ---- 6. 训练模型 ----
model <- do.call(e1071::svm, svm_input)
## ---- 模型训练----
model <- try({
do.call(e1071::svm, svm_input)
}, silent = TRUE)
if (inherits(model, "try-error")) {
return(paste("Model training failed:", attr(model, "condition")$message) %>% add_class("svm"))
}
if (is.null(model) || !inherits(model, "svm")) {
return("Model training failed: Generated SVM model is invalid" %>% add_class("svm"))
}
## ---- 7. 附加关键信息 ----
## ---- 附加关键信息----
model$df_name <- df_name
model$rvar <- rvar
model$evar <- evar
......@@ -86,7 +116,21 @@ svm <- function(dataset, rvar, evar,
model$gamma <- gamma
model$kernel <- kernel
as.list(environment()) %>% add_class(c("svm", "model"))
## ---- 返回模型对象----
list(
model = model,
df_name = df_name,
rvar = rvar,
evar = evar,
type = type,
lev = lev,
wtsname = wtsname,
seed = seed,
cost = cost,
gamma = gamma,
kernel = kernel,
dataset = dataset
) %>% add_class(c("svm", "model"))
}
#' Center or standardize variables in a data frame
......@@ -101,9 +145,9 @@ scale_df <- function(dataset, center = TRUE, scale = TRUE,
descr <- attr(dataset, "description") # 保留原描述属性
if (calc) {
# 计算均值(忽略NA)
# 计算均值
ms <- sapply(dataset[, cn, drop = FALSE], function(x) mean(x, na.rm = TRUE))
# 计算标准差(忽略NA,样本标准差ddof=1,避免除以零)
# 计算标准差
sds <- sapply(dataset[, cn, drop = FALSE], function(x) {
sd_val <- sd(x, na.rm = TRUE)
ifelse(sd_val == 0, 1, sd_val)
......@@ -135,7 +179,7 @@ summary.svm <- function(object, prn = TRUE, ...) {
svm_model <- object$model
n_obs <- nrow(object$dataset)
wtsname <- object$wtsname # 可能是 NULL 或长度 0 字符
wtsname <- object$wtsname
cat("Support Vector Machine\n")
cat(sprintf("Kernel type : %s (%s)\n", object$kernel, object$type))
......@@ -146,7 +190,7 @@ summary.svm <- function(object, prn = TRUE, ...) {
}
cat(sprintf("Explanatory variables: %s\n", paste(object$evar, collapse = ", ")))
if (!is.null(wtsname) && length(wtsname) && wtsname != "") {
if (!is.null(wtsname) && nzchar(wtsname)) {
cat(sprintf("Weights used : %s\n", wtsname))
}
......@@ -156,48 +200,72 @@ summary.svm <- function(object, prn = TRUE, ...) {
}
if (!is.na(object$seed)) cat(sprintf("Seed : %s\n", object$seed))
# 支持向量计数
if (object$type == "classification") {
n_sv_per_class <- svm_model$nSV
n_sv_per_class <- if (!is.null(svm_model$nSV)) svm_model$nSV else c(0, 0)
total_sv <- sum(n_sv_per_class)
sv_info <- paste(sprintf("class %s: %d", seq_along(n_sv_per_class), n_sv_per_class), collapse = ", ")
cat(sprintf("Support vectors : %d (%s)\n", total_sv, sv_info))
} else {
total_sv <- sum(svm_model$nSV)
total_sv <- if (!is.null(svm_model$index)) length(svm_model$index) else 0
cat(sprintf("Support vectors : %d\n", total_sv))
}
## ---- 权重样本数计算同样保护 ----
if (!is.null(wtsname) && length(wtsname) && wtsname != "") {
weights_values <- as.numeric(object$dataset[[wtsname]])
nr_obs <- if (all(!is.na(weights_values))) sum(weights_values, na.rm = TRUE) else n_obs
} else {
nr_obs <- n_obs
if (!is.null(wtsname) && nzchar(wtsname) && wtsname %in% names(object$dataset)) {
weights_values <- as.numeric(object$dataset[[wtsname]])
weights_clean <- weights_values[!is.na(weights_values)]
if (length(weights_clean) > 0) {
sum_weights <- sum(weights_clean)
if (sum_weights > 0) nr_obs <- sum_weights
}
}
cat(sprintf("Nr_obs : %s\n", format_nr(nr_obs, dec = 0)))
## ---- 系数输出(仅线性核) ----
if (prn) {
cat("Coefficients/Support Vectors:\n")
if (object$kernel == "linear") {
if (is.null(svm_model$w) || length(svm_model$w) == 0) {
cat(" Linear kernel coefficients not available (possible reasons:\n")
cat(" - Insufficient support vectors (data may be linearly inseparable)\n")
cat(" - Model training did not converge)\n")
if (object$type == "classification") {
cat(sprintf(" Support vectors count: %d\n", sum(svm_model$nSV)))
feat_coefs_raw <- rep(0, length(object$evar))
bias <- 0
if (!is.null(svm_model$terms) && !is.null(svm_model$coefs) && !is.null(svm_model$SV)) {
tryCatch({
model_vars <- attr(svm_model$terms, "term.labels")
coef_mat <- as.matrix(svm_model$coefs)
sv_mat <- as.matrix(svm_model$SV)
if (nrow(coef_mat) > 0 && nrow(sv_mat) > 0) {
w_full <- as.numeric(t(coef_mat) %*% sv_mat)
sv_colnames <- colnames(sv_mat)
if (!is.null(sv_colnames)) {
# 按变量名精确匹配
for (i in seq_along(object$evar)) {
var_name <- object$evar[i]
matching_cols <- which(sv_colnames == var_name)
if (length(matching_cols) > 0) {
feat_coefs_raw[i] <- sum(w_full[matching_cols])
}
}
} else {
feat_coefs_raw <- as.numeric(svm_model$w)
bias <- as.numeric(-svm_model$rho)
# 无列名则按位置取
n_evar <- length(object$evar)
feat_coefs <- matrix(data = feat_coefs_raw, nrow = 1, ncol = n_evar, byrow = TRUE,
dimnames = list(NULL, object$evar))
feat_coefs_raw[1:min(n_evar, length(w_full))] <- w_full[1:min(n_evar, length(w_full))]
}
bias <- as.numeric(-svm_model$rho)
}
}, error = function(e) {
warning("系数提取失败: ", e$message, call. = FALSE)
})
}
# 输出
coef_data <- data.frame(
Variable = c(object$evar, "bias"),
Value = as.numeric(c(as.vector(feat_coefs), bias)),
stringsAsFactors = FALSE,
check.names = FALSE
Value = c(feat_coefs_raw, bias),
stringsAsFactors = FALSE
)
for (i in seq(1, nrow(coef_data), 2)) {
if (i + 1 > nrow(coef_data)) {
cat(sprintf(" %-12s: %.2f\n", coef_data$Variable[i], coef_data$Value[i]))
......@@ -207,7 +275,6 @@ summary.svm <- function(object, prn = TRUE, ...) {
coef_data$Variable[i+1], coef_data$Value[i+1]))
}
}
}
} else {
cat(" Non-linear kernel: Coefficients not available\n")
}
......@@ -256,134 +323,234 @@ plot.svm <- function(x,
return("No valid plots selected for SVM")
}
## 返回 patchwork 对象(Shiny 自动打印)
## 返回 patchwork 对象
patchwork::wrap_plots(plot_list, ncol = min(2, length(plot_list))) %>%
{ if (isTRUE(shiny)) . else print(.) }
}
#' Predict method for the svm function
#' Predict method for SVM model
#' @export
predict.svm <- function(object, pred_data = NULL, pred_cmd = "",
dec = 3, envir = parent.frame(), ...) {
## 1. 基础校验
predict.svm <- function(object, pred_data = NULL, pred_cmd = "", dec = 3,
envir = parent.frame(), ...) {
# 1. 基础校验
if (is.character(object)) return(object)
if (!inherits(object, "svm")) stop("Object must be of class 'svm'")
if (is.null(object$model) || !inherits(object$model, "svm")) {
err_msg <- "Prediction failed: Invalid SVM model"
err_obj <- structure(list(error = err_msg), class = "svm.predict.error")
return(err_obj)
}
svm_info <- object
## 2.1 确定预测数据源
if (is.null(pred_data) || is.character(pred_data)) {
# 当pred_data为NULL或字符(数据集名)时,使用训练数据
pred_data_raw <- object$dataset[, object$evar, drop = FALSE]
} else {
# 当pred_data是数据框时,直接使用
pred_data_raw <- as.data.frame(pred_data)
# 2. 处理预测数据
if (is.character(pred_data) && nzchar(pred_data)) {
if (!exists(pred_data, envir = envir)) {
err_obj <- structure(list(error = sprintf("Dataset '%s' not found", pred_data)), class = "svm.predict.error")
return(err_obj)
}
pred_data <- get(pred_data, envir = envir)
}
has_data <- !is.null(pred_data) && is.data.frame(pred_data) && nrow(pred_data) > 0
has_cmd <- is.character(pred_cmd) && nzchar(pred_cmd)
if (!has_data && !has_cmd) {
err_obj <- structure(list(error = "Please select data and/or specify a command to generate predictions."), class = "svm.predict.error")
return(err_obj)
}
pred_names <- colnames(pred_data_raw)
missing_vars <- setdiff(object$evar, pred_names)
## ensure you have a name for the prediction dataset
if (is.data.frame(pred_data)) {
df_name <- deparse(substitute(pred_data))
} else {
df_name <- pred_data
}
# 3. 预测核心函数 - 修改参数名以匹配NN
pfun <- function(model, pred, se, conf_lev) {
# 验证数据
missing_vars <- setdiff(svm_info$evar, colnames(pred))
if (length(missing_vars) > 0) {
msg <- paste0(
"NA\n"
return(paste("Missing variables:", paste(missing_vars, collapse = ", ")))
}
# 内部标准化
ms <- attr(svm_info$dataset, "radiant_ms")
sds <- attr(svm_info$dataset, "radiant_sds")
if (!is.null(ms) && !is.null(sds)) {
isNum <- sapply(pred, is.numeric)
cn <- names(isNum)[isNum]
for (var in cn) {
if (var %in% names(ms) && var %in% names(sds) && sds[var] != 0) {
pred[[var]] <- (pred[[var]] - ms[var]) / sds[var]
}
}
}
# 调试信息:检查模型是否启用了概率
cat("Model probability enabled:", svm_info$model$prob.model, "\n")
# 执行预测
pred_result <- try({
predict(
svm_info$model,
newdata = pred,
probability = TRUE, # 始终设置为TRUE
decision.values = TRUE
)
return(msg %>% add_class("svm.predict"))
}, silent = TRUE)
if (inherits(pred_result, "try-error")) {
return(paste("Prediction failed:", attr(pred_result, "condition")$message))
}
pred_data <- pred_data_raw[, object$evar, drop = FALSE]
# 4. 结果整理
if (svm_info$type == "classification") {
# 正确的属性名是"probabilities"
prob_mat <- attr(pred_result, "probabilities")
if (!is.null(prob_mat)) {
# 调试:显示概率矩阵的列名
cat("Probability matrix columns:", paste(colnames(prob_mat), collapse = ", "), "\n")
# 更智能的列名匹配
target_level <- as.character(svm_info$lev)
if (!is.empty(pred_cmd)) {
pred_cmd <- gsub("\\s{2,}", " ", pred_cmd) %>%
gsub(";\\s+", ";", .) %>%
strsplit(";")[[1]]
for (cmd in pred_cmd) {
if (grepl("=", cmd)) {
var_val <- strsplit(trimws(cmd), "=")[[1]]
var <- trimws(var_val[1])
val <- try(eval(parse(text = trimws(var_val[2])), envir = envir), silent = TRUE)
if (!inherits(val, "try-error")) pred_data[[var]] <- val
else warning(sprintf("Invalid command '%s': using original values", cmd))
# 尝试多种匹配方式
matching_col <- NULL
# 1. 精确匹配
exact_match <- grep(paste0("^", target_level, "$"), colnames(prob_mat), value = TRUE, ignore.case = TRUE)
if (length(exact_match) > 0) {
matching_col <- exact_match
}
# 2. 部分匹配
else {
partial_match <- grep(target_level, colnames(prob_mat), value = TRUE, ignore.case = TRUE)
if (length(partial_match) > 0) {
matching_col <- partial_match[1] # 取第一个匹配
}
pred_data <- unique(pred_data)
}
## 3. 变量类型与因子水平校验
train_types <- sapply(object$dataset[, object$evar], class)
pred_types <- sapply(pred_data, class)
type_mismatch <- names(which(train_types != pred_types))
if (length(type_mismatch) > 0) {
return(paste0("Variable type mismatch (train vs pred):\n",
paste(sprintf(" %s: %s vs %s", type_mismatch,
train_types[type_mismatch], pred_types[type_mismatch]),
collapse = "\n")) %>% add_class("svm.predict"))
# 3. 如果还是没有匹配,尝试逻辑值匹配
if (is.null(matching_col) || length(matching_col) == 0) {
if (target_level %in% c("TRUE", "Yes", "1", "1.0")) {
logical_match <- grep("TRUE", colnames(prob_mat), value = TRUE, ignore.case = TRUE)
if (length(logical_match) > 0) {
matching_col <- logical_match[1]
}
for (var in object$evar) {
if (is.factor(object$dataset[[var]])) {
train_levs <- levels(object$dataset[[var]])
pred_levs <- levels(pred_data[[var]])
if (!identical(train_levs, pred_levs)) {
pred_data[[var]] <- factor(pred_data[[var]], levels = train_levs)
warning(sprintf("Aligned factor levels for '%s' to match training data", var))
} else if (target_level %in% c("FALSE", "No", "0", "0.0")) {
logical_match <- grep("FALSE", colnames(prob_mat), value = TRUE, ignore.case = TRUE)
if (length(logical_match) > 0) {
matching_col <- logical_match[1]
}
}
}
## 4. 标准化对齐
train_ms <- attr(object$dataset, "radiant_ms")
train_sds <- attr(object$dataset, "radiant_sds")
train_sf <- attr(object$dataset, "radiant_sf") %||% 2
if (!is.null(train_ms) && !is.null(train_sds)) {
pred_data <- scale_df(
dataset = pred_data,
center = TRUE, scale = TRUE,
sf = train_sf, wts = NULL,
calc = FALSE
)
attr(pred_data, "radiant_ms") <- train_ms
attr(pred_data, "radiant_sds") <- train_sds
# 4. 作为最后手段,使用第一列
if (is.null(matching_col) || length(matching_col) == 0) {
if (ncol(prob_mat) > 0) {
matching_col <- colnames(prob_mat)[1]
cat("Warning: Using first probability column '", matching_col, "' for level '", target_level, "'\n")
}
}
## 5. 生成预测值
predict_args <- list(
object = object$model,
newdata = pred_data,
na.action = na.omit
if (!is.null(matching_col) && length(matching_col) > 0) {
surv_prob <- as.numeric(prob_mat[, matching_col])
pred_df <- data.frame(
Prediction = round(surv_prob, dec),
stringsAsFactors = FALSE
)
pred_result <- if (object$type == "classification") {
predict_args$type <- "class"
pred_class <- do.call(predict, predict_args)
if (isTRUE(object$model$param$probability)) {
predict_args$type <- "probabilities"
pred_prob <- do.call(predict, predict_args) %>%
as.data.frame() %>%
set_colnames(paste0("Prob_", colnames(.)))
data.frame(
Predicted_Class = as.character(pred_class),
pred_prob,
} else {
# 备用方案:使用决策值
decision_vals <- attr(pred_result, "decision.values")
if (!is.null(decision_vals)) {
# 将决策值转换为概率
pred_prob <- 1 / (1 + exp(-decision_vals))
pred_df <- data.frame(
Prediction = round(pred_prob, dec),
stringsAsFactors = FALSE
)
} else {
data.frame(Predicted_Class = as.character(pred_class), stringsAsFactors = FALSE)
# 最后手段:使用预测类
pred_class <- as.character(pred_result)
pred_prob <- ifelse(pred_class == target_level, 1, 0)
pred_df <- data.frame(
Prediction = round(pred_prob, dec),
stringsAsFactors = FALSE
)
}
}
} else {
# 备用方案:使用决策值
decision_vals <- attr(pred_result, "decision.values")
if (!is.null(decision_vals)) {
# 将决策值转换为概率
pred_prob <- 1 / (1 + exp(-decision_vals))
pred_df <- data.frame(
Prediction = round(pred_prob, dec),
stringsAsFactors = FALSE
)
} else {
predict_args$type <- "response"
pred_val <- do.call(predict, predict_args)
pred_val <- as.numeric(pred_val)
data.frame(Predicted_Value = round(pred_val, dec), stringsAsFactors = FALSE)
}
pred_result <- cbind(pred_data, pred_result)
attr(pred_result, "svm_meta") <- list(
kernel = object$kernel,
cost = object$cost,
gamma = if (object$kernel != "linear") object$gamma else NA,
seed = object$seed,
train_data = object$df_name,
model_type = object$type
# 最后手段:使用预测类
target_level <- as.character(svm_info$lev)
pred_class <- as.character(pred_result)
pred_prob <- ifelse(pred_class == target_level, 1, 0)
pred_df <- data.frame(
Prediction = round(pred_prob, dec),
stringsAsFactors = FALSE
)
}
}
} else {
# 回归模型
pred_values <- as.numeric(pred_result)
# 关键:添加反标准化处理
rvar_ms <- if (!is.null(ms) && svm_info$rvar %in% names(ms)) ms[svm_info$rvar] else 0
rvar_sds <- if (!is.null(sds) && svm_info$rvar %in% names(sds)) sds[svm_info$rvar] else 1
# 反标准化
pred_values <- pred_values * rvar_sds + rvar_ms
pred_df <- data.frame(
Prediction = round(pred_values, dec),
stringsAsFactors = FALSE
)
}
return(pred_df)
}
# 5. 调用预测框架 - 与NN完全一致
result <- predict_model(
object,
pfun,
"svm.predict", # 模型类型
pred_data,
pred_cmd,
conf_lev = 0.95,
se = FALSE,
dec,
envir = envir
) %>%
set_attr("radiant_pred_data", df_name)
# 6. 结果元数据
if (inherits(result, "svm.predict.error")) {
return(result$error)
}
if (is.character(result)) return(result)
result <- add_class(result, "svm.predict")
attr(result, "svm_meta") <- list(
model_type = svm_info$type,
kernel = svm_info$kernel,
cost = svm_info$cost,
gamma = if (svm_info$kernel != "linear") svm_info$gamma else NA,
seed = svm_info$seed,
train_data = svm_info$df_name,
dec = dec
)
attr(pred_result, "class") <- c("svm.predict", "data.frame")
return(pred_result)
return(result)
}
#' Print method for predict.svm
......@@ -396,16 +563,30 @@ print.svm.predict <- function(x, ..., n = 10) {
if (!inherits(x, "svm.predict")) stop("Object must be of class 'svm.predict'")
meta <- attr(x, "svm_meta")
if (is.null(meta)) {
meta <- list(
model_type = "Unknown",
kernel = "Unknown",
cost = NA,
gamma = NA,
seed = NA,
train_data = "Unknown",
dec = 1
)
}
dec <- meta$dec %||% 1
n_pred <- nrow(x)
show_n <- if (n < 0 || n >= n_pred) n_pred else n
cat("SVM Predictions\n")
cat(sprintf("Model Type : %s\n", tools::toTitleCase(meta$model_type)))
cat(sprintf("Kernel : %s\n", meta$kernel))
cat(sprintf("Cost (C) : %.2f\n", meta$cost))
model_type <- if (is.empty(meta$model_type)) "Unknown" else as.character(meta$model_type)
cat(sprintf("Model Type : %s\n", tools::toTitleCase(model_type)))
cat(sprintf("Kernel : %s\n", ifelse(is.empty(meta$kernel), "Unknown", meta$kernel)))
cat(sprintf("Cost (C) : %.2f\n", ifelse(is.na(meta$cost), 0, meta$cost)))
if (!is.na(meta$gamma)) cat(sprintf("Gamma : %.2f\n", meta$gamma))
if (!is.na(meta$seed)) cat(sprintf("Seed : %s\n", meta$seed))
cat(sprintf("Training Dataset : %s\n", meta$train_data))
cat(sprintf("Training Dataset : %s\n", ifelse(is.empty(meta$train_data), "Unknown", meta$train_data)))
cat(sprintf("Total Predictions : %d\n", n_pred))
if (n_pred == 0) {
......@@ -413,8 +594,13 @@ print.svm.predict <- function(x, ..., n = 10) {
return(invisible(x))
}
x_show <- x[1:show_n, , drop = FALSE] # 确保保持数据框结构
x_show <- x[1:show_n, , drop = FALSE]
num_cols <- sapply(x_show, is.numeric)
x_show[num_cols] <- lapply(x_show[num_cols], function(col) {
sprintf(paste0("%.", dec, "f"), col)
})
# 重新计算列宽
col_widths <- sapply(colnames(x_show), function(cn) {
max(nchar(cn), max(nchar(as.character(x_show[[cn]])), na.rm = TRUE))
})
......@@ -435,38 +621,13 @@ print.svm.predict <- function(x, ..., n = 10) {
}
if (show_n < n_pred) {
cat(sprintf("\n... (showing first %d of %d; use 'n=-1' to view all)\n",
cat(sprintf("\n... (showing first %d of %d)\n",
show_n, n_pred))
}
return(invisible(x))
}
#' Cross-validation for SVM
#' @export
cv.svm <- function(object, K = 5, repeats = 1,
kernel = c("linear", "radial"),
cost = 2^(-2:2),
gamma = 2^(-2:2),
seed = 1234, trace = TRUE, fun, ...) {
if (!inherits(object, "svm")) stop("Object must be of class 'svm'")
tune_grid <- expand.grid(
kernel = kernel,
cost = cost,
gamma = if (any(kernel %in% c("radial", "poly", "sigmoid"))) gamma else NA
)
cv_results <- data.frame(
mean_perf = rep(NA, nrow(tune_grid)),
std_perf = rep(NA, nrow(tune_grid)),
kernel = tune_grid$kernel,
cost = tune_grid$cost,
gamma = tune_grid$gamma
)
cv_results
}
# ---- 决策边界 ----
#' @export
svm_boundary_plot <- function(object, size, custom) {
......
......@@ -79,11 +79,12 @@ options(
i18n$t("Estimate"),
tabPanel(i18n$t("Linear regression (OLS)"), uiOutput("regress")),
tabPanel(i18n$t("Logistic regression (GLM)"), uiOutput("logistic")),
tabPanel(i18n$t("Cox Proportional Hazards Regression"),uiOutput("coxp")),
tabPanel(i18n$t("Multinomial logistic regression (MNL)"), uiOutput("mnl")),
tabPanel(i18n$t("Naive Bayes"), uiOutput("nb")),
tabPanel(i18n$t("Neural Network"), uiOutput("nn")),
tabPanel(i18n$t("Support Vector Machine (SVM)"),uiOutput("svm")),
"----", i18n$t("Survival Analysis"),
tabPanel(i18n$t("Cox Proportional Hazards Regression"),uiOutput("coxp")),
"----", i18n$t("Trees"),
tabPanel(i18n$t("Classification and regression trees"), uiOutput("crtree")),
tabPanel(i18n$t("Random Forest"), uiOutput("rf")),
......
......@@ -23,7 +23,7 @@ svm_inputs <- reactive({
svm_args
})
## 预测参数(保留命令模式,未改动)
## 预测参数
svm_pred_args <- as.list(if (exists("predict.svm")) {
formals(predict.svm)
} else {
......@@ -52,7 +52,7 @@ svm_pred_inputs <- reactive({
return(svm_pred_args)
})
## 绘图参数(砍掉vip、pdp、svm_margin)
## 绘图参数
svm_plot_args <- as.list(if (exists("plot.svm")) {
formals(plot.svm)
} else {
......@@ -77,11 +77,7 @@ output$ui_svm_rvar <- renderUI({
vars <- varnames()[isNum]
}
})
init <- if (input$svm_type == "classification") {
if (is.empty(input$logit_rvar)) isolate(input$svm_rvar) else input$logit_rvar
} else {
if (is.empty(input$reg_rvar)) isolate(input$svm_rvar) else input$reg_rvar
}
init <- state_single("svm_rvar", vars, isolate(input$svm_rvar))
selectInput(
inputId = "svm_rvar",
label = i18n$t("Response variable:"),
......@@ -109,11 +105,7 @@ output$ui_svm_evar <- renderUI({
if (not_available(input$svm_rvar)) return()
vars <- varnames()
if (length(vars) > 0) vars <- vars[-which(vars == input$svm_rvar)]
init <- if (input$svm_type == "classification") {
if (is.empty(input$logit_evar)) isolate(input$svm_evar) else input$logit_evar
} else {
if (is.empty(input$reg_evar)) isolate(input$svm_evar) else input$reg_evar
}
init <- state_multiple("svm_evar", vars, isolate(input$svm_evar))
selectInput(
inputId = "svm_evar",
label = i18n$t("Explanatory variables:"),
......@@ -141,7 +133,7 @@ output$ui_svm_wts <- renderUI({
)
})
## 存储预测值UI(残差存储已删除)
## 存储预测值UI
output$ui_svm_store_pred_name <- renderUI({
init <- state_init("svm_store_pred_name", "pred_svm") %>%
sub("\\d{1,}$", "", .) %>%
......@@ -164,7 +156,7 @@ observeEvent(input$svm_type, {
updateSelectInput(session = session, inputId = "svm_plots", selected = "none")
})
## 绘图选项UI(已删vip、pdp、svm_margin)
## 绘图选项UI
output$ui_svm_plots <- renderUI({
req(input$svm_type)
avail_plots <- svm_plots
......@@ -178,19 +170,7 @@ output$ui_svm_plots <- renderUI({
)
})
## 数据点数量UI(仅dashboard用,保留)
output$ui_svm_nrobs <- renderUI({
nrobs <- nrow(.get_data())
choices <- c("1,000" = 1000, "5,000" = 5000, "10,000" = 10000, "All" = -1) %>%
.[. < nrobs]
selectInput(
"svm_nrobs", i18n$t("Number of data points plotted:"),
choices = choices,
selected = state_single("svm_nrobs", choices, 1000)
)
})
## 主UI面板(已删残差存储入口)
## 主UI面板
output$ui_svm <- renderUI({
req(input$dataset)
tagList(
......@@ -258,7 +238,7 @@ output$ui_svm <- renderUI({
)
)
),
# 预测面板(残差存储已删除)
# 预测面板
conditionalPanel(
condition = "input.tabs_svm == 'Predict'",
selectInput(
......@@ -303,19 +283,10 @@ output$ui_svm <- renderUI({
)
)
),
# 绘图面板(已删vip、pdp、svm_margin)
# 绘图面板
conditionalPanel(
condition = "input.tabs_svm == 'Plot'",
uiOutput("ui_svm_plots"),
conditionalPanel(
condition = "input.svm_plots == 'pred_plot'",
uiOutput("ui_svm_incl"),
uiOutput("ui_svm_incl_int")
),
conditionalPanel(
condition = "input.svm_plots == 'dashboard'",
uiOutput("ui_svm_nrobs")
)
uiOutput("ui_svm_plots")
)
),
# 帮助和报告面板
......@@ -327,7 +298,7 @@ output$ui_svm <- renderUI({
)
})
## 绘图尺寸计算(已删vip、pdp、svm_margin)
## 绘图尺寸计算
svm_plot <- reactive({
if (svm_available() != "available") return()
if (is.empty(input$svm_plots, "none")) return()
......@@ -337,9 +308,6 @@ svm_plot <- reactive({
plot_width <- 650
if ("decision_boundary" %in% input$svm_plots) {
plot_height <- 500
} else if (input$svm_plots == "pred_plot") {
nr_vars <- length(input$svm_incl) + length(input$svm_incl_int)
plot_height <- max(250, ceiling(nr_vars / 2) * 250)
} else {
plot_height <- max(500, length(res$evar) * 30)
}
......@@ -349,7 +317,7 @@ svm_plot <- reactive({
svm_plot_width <- function() svm_plot()$plot_width %||% 650
svm_plot_height <- function() svm_plot()$plot_height %||% 500
## 主输出面板(已删残差存储)
## 主输出面板
output$svm <- renderUI({
register_print_output("summary_svm", ".summary_svm")
register_print_output("predict_svm", ".predict_print_svm")
......@@ -393,14 +361,14 @@ svm_available <- reactive({
}
})
## 核心函数壳子
## 核心函数
.svm <- eventReactive(input$svm_run, {
svi <- svm_inputs()
svi$envir <- r_data
withProgress(message = i18n$t("Estimating SVM model"), value = 1, do.call(svm, svi))
})
## 辅助输出函数壳子
## 辅助输出函数
.summary_svm <- reactive({
if (not_pressed(input$svm_run)) return(i18n$t("** Press the Estimate button to estimate the SVM model **"))
if (svm_available() != "available") return(svm_available())
......@@ -444,7 +412,7 @@ svm_available <- reactive({
withProgress(message = i18n$t("Generating SVM plots"), value = 1, do.call(plot, c(list(x = .svm()), pinp)))
})
## 存储预测值(残差存储已删除)
## 存储预测值
observeEvent(input$svm_store_pred, {
req(
pressed(input$svm_run),
......@@ -457,14 +425,14 @@ observeEvent(input$svm_store_pred, {
base_col_name <- fix_names(input$svm_store_pred_name)
meta <- attr(pred_result, "svm_meta")
pred_cols <- if (meta$model_type == "classification") {
colnames(pred_result)[grepl("^Predicted_Class|^Prob_", colnames(pred_result))]
pred_cols <- if (meta$model_type %in% c("classification", "regression")) {
colnames(pred_result)[colnames(pred_result) == "Prediction"]
} else {
"Predicted_Value"
NULL
}
new_col_names <- if (length(pred_cols) == 1) base_col_name else {
suffix <- gsub("^Predicted_|^Prob_", "", pred_cols)
paste0(base_col_name, "_", suffix)
suffix <- gsub("^Prediction", "", pred_cols)
paste0(base_col_name, ifelse(suffix == "", "", paste0("_", suffix)))
}
colnames(pred_result)[match(pred_cols, colnames(pred_result))] <- new_col_names
......@@ -510,9 +478,75 @@ download_handler(
height = svm_plot_height
)
## 报告生成(空壳,保留接口)
svm_report <- function() invisible()
svm_report <- function() {
if (is.empty(input$svm_evar)) {
showNotification(i18n$t("Select at least one explanatory variable to generate report"), type = "error")
return(invisible())
}
outputs <- c("summary")
inp_out <- list(list(prn = TRUE), "")
figs <- FALSE
xcmd <- ""
if (!is.empty(input$svm_plots, "none")) {
inp <- check_plot_inputs(svm_plot_inputs())
inp$size <- NULL
inp_out[[2]] <- clean_args(inp, svm_plot_args[-1])
inp_out[[2]]$custom <- FALSE
outputs <- c(outputs, "plot")
figs <- TRUE
}
if (!is.empty(input$svm_predict, "none") &&
(!is.empty(input$svm_pred_data) || !is.empty(input$svm_pred_cmd))) {
pred_args <- clean_args(svm_pred_inputs(), svm_pred_args[-1])
if (!is.empty(pred_args$pred_cmd)) {
pred_args$pred_cmd <- strsplit(pred_args$pred_cmd, ";\\s*")[[1]]
} else {
pred_args$pred_cmd <- NULL
}
if (!is.empty(pred_args$pred_data)) {
pred_args$pred_data <- as.symbol(pred_args$pred_data)
} else {
pred_args$pred_data <- NULL
}
inp_out[[2 + figs]] <- pred_args
outputs <- c(outputs, "pred <- predict")
xcmd <- paste0(xcmd, "print(pred, n = 10)")
if (input$svm_predict %in% c("data", "datacmd") && !is.empty(input$svm_store_pred_name)) {
fixed <- fix_names(input$svm_store_pred_name)
updateTextInput(session, "svm_store_pred_name", value = fixed)
xcmd <- paste0(
xcmd, "\n", input$svm_pred_data, " <- store(",
input$svm_pred_data, ", pred, name = \"", fixed, "\")"
)
}
}
svm_inp <- svm_inputs()
if (input$svm_type == "regression") {
svm_inp$lev <- NULL
}
if (input$svm_kernel == "linear") {
svm_inp$gamma <- NULL
}
update_report(
inp_main = clean_args(svm_inp, svm_args),
fun_name = "svm",
inp_out = inp_out,
outputs = outputs,
figs = figs,
fig.width = svm_plot_width(),
fig.height = svm_plot_height(),
xcmd = xcmd
)
}
## 报告生成
observeEvent(input$svm_report, {
r_info[["latest_screenshot"]] <- NULL
svm_report()
......
% Generated by roxygen2: do not edit by hand
% Please edit documentation in R/svm.R
\name{cv.svm}
\alias{cv.svm}
\title{Cross-validation for SVM}
\usage{
cv.svm(
object,
K = 5,
repeats = 1,
kernel = c("linear", "radial"),
cost = seq(0.1, 10, by = 0.5),
gamma = seq(0.1, 5, by = 0.5),
seed = 1234,
trace = TRUE,
fun,
...
)
}
\description{
Cross-validation for SVM
}
......@@ -4,17 +4,7 @@
\alias{plot.svm}
\title{Plot method for the svm function}
\usage{
\method{plot}{svm}(
x,
plots = "vip",
size = 12,
nrobs = -1,
incl = NULL,
incl_int = NULL,
shiny = FALSE,
custom = FALSE,
...
)
\method{plot}{svm}(x, plots = "none", size = 12, shiny = FALSE, custom = FALSE, ...)
}
\description{
Plot method for the svm function
......
......@@ -2,7 +2,7 @@
% Please edit documentation in R/svm.R
\name{predict.svm}
\alias{predict.svm}
\title{Predict method for the svm function}
\title{Predict method for SVM model}
\usage{
\method{predict}{svm}(
object,
......@@ -14,5 +14,5 @@
)
}
\description{
Predict method for the svm function
Predict method for SVM model
}
......@@ -4,9 +4,9 @@
\alias{varimp}
\title{Variable importance using the vip package and permutation importance}
\usage{
varimp(object, rvar, lev, data = NULL, seed = 1234)
varimp(object, rvar = NULL, lev = NULL, data = NULL, seed = 1234, nperm = 10)
varimp(object, rvar, lev, data = NULL, seed = 1234)
varimp(object, rvar = NULL, lev = NULL, data = NULL, seed = 1234, nperm = 10)
}
\arguments{
\item{object}{Model object created by Radiant}
......@@ -22,5 +22,5 @@ varimp(object, rvar, lev, data = NULL, seed = 1234)
\description{
Variable importance using the vip package and permutation importance
Variable importance using the vip package and permutation importance
Variable importance for SVM using permutation importance
}
% Generated by roxygen2: do not edit by hand
% Please edit documentation in R/nn.R, R/svm.R
% Please edit documentation in R/nn.R
\name{varimp_plot}
\alias{varimp_plot}
\title{Plot permutation importance}
\usage{
varimp_plot(object, rvar, lev, data = NULL, seed = 1234)
varimp_plot(object, rvar, lev, data = NULL, seed = 1234)
}
\arguments{
......@@ -20,7 +18,5 @@ varimp_plot(object, rvar, lev, data = NULL, seed = 1234)
\item{seed}{Random seed for reproducibility}
}
\description{
Plot permutation importance
Plot permutation importance
}
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment