Commit 53dd5828 authored by wuzekai

Updated svm

parent e9df660e
...@@ -85,7 +85,6 @@ export(cv.crtree)
export(cv.gbt)
export(cv.nn)
export(cv.rforest)
export(cv.svm)
export(dtree)
export(dtree_parser)
export(evalbin)
...@@ -121,6 +120,9 @@ export(sim_splitter)
export(sim_summary)
export(simulater)
export(svm)
export(svm_boundary_plot)
export(svm_margin_plot)
export(svm_vip_plot)
export(test_specs)
export(uplift)
export(var_check)
......
...@@ -7,7 +7,7 @@ svm <- function(dataset, rvar, evar,
form, data_filter = "", arr = "",
rows = NULL, envir = parent.frame()) {
## ---- Argument validation ----
valid_kernels <- c("linear", "radial", "poly", "sigmoid")
if (!kernel %in% valid_kernels) {
return(paste0("Kernel must be one of: ", paste(valid_kernels, collapse = ", ")) %>%
...@@ -38,11 +38,32 @@ svm <- function(dataset, rvar, evar,
df_name <- if (is_string(dataset)) dataset else deparse(substitute(dataset))
dataset <- get_data(dataset, vars, filt = data_filter, arr = arr, rows = rows, envir = envir)
is_cat <- sapply(dataset, function(x) is.factor(x) || is.character(x) || is.logical(x))
cat_vars <- names(is_cat)[is_cat]
if (length(cat_vars) > 0) {
# 2. Convert character variables to factors (keeps the encoding order consistent)
for (var in cat_vars) {
if (is.character(dataset[[var]])) {
dataset[[var]] <- as.factor(dataset[[var]])
}
# 3. Convert factors/logicals to numeric, coded by the original level order (level1 -> 1, level2 -> 2, ...)
dataset[[var]] <- as.numeric(dataset[[var]])
}
}
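# Note: as.numeric() on a factor returns the level index, so the numeric codes
# follow the factor's level order; an illustrative check:
#   as.numeric(factor(c("b", "a", "b")))  # 2 1 2, because levels are "a" < "b"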
# Handle missing values
dataset <- na.omit(dataset)
if (nrow(dataset) == 0) {
return("No valid samples after removing missing values. Please check your data." %>% add_class("svm"))
}
# Standardization
if ("standardize" %in% check) {
dataset <- scale_df(dataset)
}
## ---- 3. Response variable for classification tasks ----
if (type == "classification") {
dataset[[rvar]] <- as.factor(dataset[[rvar]])
if (lev == "") {
...@@ -71,10 +92,19 @@ svm <- function(dataset, rvar, evar,
if (!is.na(seed)) set.seed(seed)
## ---- Model training ----
model <- try({
do.call(e1071::svm, svm_input)
}, silent = TRUE)
if (inherits(model, "try-error")) {
return(paste("Model training failed:", attr(model, "condition")$message) %>% add_class("svm"))
}
if (is.null(model) || !inherits(model, "svm")) {
return("Model training failed: Generated SVM model is invalid" %>% add_class("svm"))
}
## ---- Attach key information ----
model$df_name <- df_name
model$rvar <- rvar
model$evar <- evar
...@@ -86,7 +116,21 @@ svm <- function(dataset, rvar, evar,
model$gamma <- gamma
model$kernel <- kernel
## ---- Return the model object ----
list(
model = model,
df_name = df_name,
rvar = rvar,
evar = evar,
type = type,
lev = lev,
wtsname = wtsname,
seed = seed,
cost = cost,
gamma = gamma,
kernel = kernel,
dataset = dataset
) %>% add_class(c("svm", "model"))
}
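For orientation, a call to the updated estimation function might look like the sketch below. This is illustrative only: the data set and variable names are placeholders, and arguments such as type, lev, check, and seed are assumed from how they are used in the body above rather than from the (collapsed) signature.

## minimal usage sketch (placeholder data and variable names)
result <- svm(
  titanic, rvar = "survived", evar = c("pclass", "sex", "age"),
  type = "classification", lev = "Yes",
  kernel = "radial", cost = 1, gamma = 0.1, seed = 1234
)
summary(result)                         # kernel, support vectors, Nr_obs, coefficients
pred <- predict(result, pred_data = titanic)
print(pred, n = 10)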
#' Center or standardize variables in a data frame
...@@ -101,9 +145,9 @@ scale_df <- function(dataset, center = TRUE, scale = TRUE,
descr <- attr(dataset, "description") # keep the original description attribute
if (calc) {
# Compute means
ms <- sapply(dataset[, cn, drop = FALSE], function(x) mean(x, na.rm = TRUE))
# Compute standard deviations (zero SDs are replaced with 1 to avoid division by zero)
sds <- sapply(dataset[, cn, drop = FALSE], function(x) {
sd_val <- sd(x, na.rm = TRUE)
ifelse(sd_val == 0, 1, sd_val)
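The standardization helper pairs with predict.svm further down, which reads the training means and standard deviations back from the radiant_ms and radiant_sds attributes. A minimal sketch of that round trip, on made-up data and assuming those attribute names:

set.seed(1)
train <- data.frame(x = rnorm(20, mean = 10, sd = 3))
ms  <- sapply(train, mean, na.rm = TRUE)
sds <- sapply(train, function(x) {
  s <- sd(x, na.rm = TRUE)
  ifelse(s == 0, 1, s)           # guard against zero variance, as above
})
scaled <- as.data.frame(scale(train, center = ms, scale = sds))
attr(scaled, "radiant_ms") <- ms
attr(scaled, "radiant_sds") <- sds
## new data must be transformed with the *training* statistics before predict()
x_new <- (5 - ms[["x"]]) / sds[["x"]]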
...@@ -135,7 +179,7 @@ summary.svm <- function(object, prn = TRUE, ...) {
svm_model <- object$model
n_obs <- nrow(object$dataset)
wtsname <- object$wtsname
cat("Support Vector Machine\n")
cat(sprintf("Kernel type : %s (%s)\n", object$kernel, object$type))
...@@ -146,7 +190,7 @@ summary.svm <- function(object, prn = TRUE, ...) {
}
cat(sprintf("Explanatory variables: %s\n", paste(object$evar, collapse = ", ")))
if (!is.null(wtsname) && nzchar(wtsname)) {
cat(sprintf("Weights used : %s\n", wtsname))
}
...@@ -156,48 +200,72 @@ summary.svm <- function(object, prn = TRUE, ...) {
}
if (!is.na(object$seed)) cat(sprintf("Seed : %s\n", object$seed))
# Count support vectors
if (object$type == "classification") {
n_sv_per_class <- if (!is.null(svm_model$nSV)) svm_model$nSV else c(0, 0)
total_sv <- sum(n_sv_per_class)
sv_info <- paste(sprintf("class %s: %d", seq_along(n_sv_per_class), n_sv_per_class), collapse = ", ")
cat(sprintf("Support vectors : %d (%s)\n", total_sv, sv_info))
} else {
total_sv <- if (!is.null(svm_model$index)) length(svm_model$index) else 0
cat(sprintf("Support vectors : %d\n", total_sv))
}
# Weighted number of observations, with the same protections
nr_obs <- n_obs
if (!is.null(wtsname) && nzchar(wtsname) && wtsname %in% names(object$dataset)) {
weights_values <- as.numeric(object$dataset[[wtsname]])
weights_clean <- weights_values[!is.na(weights_values)]
if (length(weights_clean) > 0) {
sum_weights <- sum(weights_clean)
if (sum_weights > 0) nr_obs <- sum_weights
}
}
cat(sprintf("Nr_obs : %s\n", format_nr(nr_obs, dec = 0))) cat(sprintf("Nr_obs : %s\n", format_nr(nr_obs, dec = 0)))
## ---- 系数输出(仅线性核) ----
if (prn) { if (prn) {
cat("Coefficients/Support Vectors:\n") cat("Coefficients/Support Vectors:\n")
if (object$kernel == "linear") { if (object$kernel == "linear") {
if (is.null(svm_model$w) || length(svm_model$w) == 0) { feat_coefs_raw <- rep(0, length(object$evar))
cat(" Linear kernel coefficients not available (possible reasons:\n") bias <- 0
cat(" - Insufficient support vectors (data may be linearly inseparable)\n")
cat(" - Model training did not converge)\n") if (!is.null(svm_model$terms) && !is.null(svm_model$coefs) && !is.null(svm_model$SV)) {
if (object$type == "classification") { tryCatch({
cat(sprintf(" Support vectors count: %d\n", sum(svm_model$nSV))) model_vars <- attr(svm_model$terms, "term.labels")
coef_mat <- as.matrix(svm_model$coefs)
sv_mat <- as.matrix(svm_model$SV)
if (nrow(coef_mat) > 0 && nrow(sv_mat) > 0) {
w_full <- as.numeric(t(coef_mat) %*% sv_mat)
sv_colnames <- colnames(sv_mat)
if (!is.null(sv_colnames)) {
# 按变量名精确匹配
for (i in seq_along(object$evar)) {
var_name <- object$evar[i]
matching_cols <- which(sv_colnames == var_name)
if (length(matching_cols) > 0) {
feat_coefs_raw[i] <- sum(w_full[matching_cols])
}
} }
} else { } else {
feat_coefs_raw <- as.numeric(svm_model$w) # 无列名则按位置取
bias <- as.numeric(-svm_model$rho)
n_evar <- length(object$evar) n_evar <- length(object$evar)
feat_coefs <- matrix(data = feat_coefs_raw, nrow = 1, ncol = n_evar, byrow = TRUE, feat_coefs_raw[1:min(n_evar, length(w_full))] <- w_full[1:min(n_evar, length(w_full))]
dimnames = list(NULL, object$evar)) }
bias <- as.numeric(-svm_model$rho)
}
}, error = function(e) {
warning("系数提取失败: ", e$message, call. = FALSE)
})
}
# 输出
coef_data <- data.frame( coef_data <- data.frame(
Variable = c(object$evar, "bias"), Variable = c(object$evar, "bias"),
Value = as.numeric(c(as.vector(feat_coefs), bias)), Value = c(feat_coefs_raw, bias),
stringsAsFactors = FALSE, stringsAsFactors = FALSE
check.names = FALSE
) )
for (i in seq(1, nrow(coef_data), 2)) { for (i in seq(1, nrow(coef_data), 2)) {
if (i + 1 > nrow(coef_data)) { if (i + 1 > nrow(coef_data)) {
cat(sprintf(" %-12s: %.2f\n", coef_data$Variable[i], coef_data$Value[i])) cat(sprintf(" %-12s: %.2f\n", coef_data$Variable[i], coef_data$Value[i]))
...@@ -207,7 +275,6 @@ summary.svm <- function(object, prn = TRUE, ...) { ...@@ -207,7 +275,6 @@ summary.svm <- function(object, prn = TRUE, ...) {
coef_data$Variable[i+1], coef_data$Value[i+1])) coef_data$Variable[i+1], coef_data$Value[i+1]))
} }
} }
}
} else { } else {
cat(" Non-linear kernel: Coefficients not available\n") cat(" Non-linear kernel: Coefficients not available\n")
} }
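The linear-kernel branch above recovers the primal weight vector from e1071's dual representation as t(coefs) %*% SV and the intercept as -rho. A standalone sketch of the same idea against e1071 directly, on made-up two-class data:

set.seed(42)
X <- matrix(rnorm(200), ncol = 2, dimnames = list(NULL, c("x1", "x2")))
y <- factor(ifelse(X[, 1] + X[, 2] > 0, "yes", "no"))
fit <- e1071::svm(X, y, kernel = "linear", scale = FALSE)
w <- as.numeric(t(fit$coefs) %*% fit$SV)  # one weight per column of X
b <- -fit$rho                             # intercept
## X %*% w + b reproduces the decision values (up to e1071's sign convention)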
...@@ -256,134 +323,234 @@ plot.svm <- function(x,
return("No valid plots selected for SVM")
}
## Return a patchwork object
patchwork::wrap_plots(plot_list, ncol = min(2, length(plot_list))) %>%
{ if (isTRUE(shiny)) . else print(.) }
}
#' Predict method for SVM model
#' @export
predict.svm <- function(object, pred_data = NULL, pred_cmd = "", dec = 3,
envir = parent.frame(), ...) {
# 1. Basic validation
if (is.character(object)) return(object)
if (!inherits(object, "svm")) stop("Object must be of class 'svm'")
if (is.null(object$model) || !inherits(object$model, "svm")) {
err_msg <- "Prediction failed: Invalid SVM model"
err_obj <- structure(list(error = err_msg), class = "svm.predict.error")
return(err_obj)
}
svm_info <- object
# 2. Resolve the prediction data
if (is.character(pred_data) && nzchar(pred_data)) {
if (!exists(pred_data, envir = envir)) {
err_obj <- structure(list(error = sprintf("Dataset '%s' not found", pred_data)), class = "svm.predict.error")
return(err_obj)
}
pred_data <- get(pred_data, envir = envir)
}
has_data <- !is.null(pred_data) && is.data.frame(pred_data) && nrow(pred_data) > 0
has_cmd <- is.character(pred_cmd) && nzchar(pred_cmd)
if (!has_data && !has_cmd) {
err_obj <- structure(list(error = "Please select data and/or specify a command to generate predictions."), class = "svm.predict.error")
return(err_obj)
}
## ensure you have a name for the prediction dataset
if (is.data.frame(pred_data)) {
df_name <- deparse(substitute(pred_data))
} else {
df_name <- pred_data
}
# 3. Core prediction function (argument names follow the NN implementation)
pfun <- function(model, pred, se, conf_lev) {
# Validate the prediction data
missing_vars <- setdiff(svm_info$evar, colnames(pred))
if (length(missing_vars) > 0) {
return(paste("Missing variables:", paste(missing_vars, collapse = ", ")))
}
# Apply the training standardization to numeric predictors
ms <- attr(svm_info$dataset, "radiant_ms")
sds <- attr(svm_info$dataset, "radiant_sds")
if (!is.null(ms) && !is.null(sds)) {
isNum <- sapply(pred, is.numeric)
cn <- names(isNum)[isNum]
for (var in cn) {
if (var %in% names(ms) && var %in% names(sds) && sds[var] != 0) {
pred[[var]] <- (pred[[var]] - ms[var]) / sds[var]
}
}
}
# Debug output: was the model fit with probability estimates?
cat("Model probability enabled:", svm_info$model$prob.model, "\n")
# Run the prediction
pred_result <- try({
predict(
svm_info$model,
newdata = pred,
probability = TRUE, # always request probabilities
decision.values = TRUE
)
}, silent = TRUE)
if (inherits(pred_result, "try-error")) {
return(paste("Prediction failed:", attr(pred_result, "condition")$message))
}
# 4. Assemble the results
if (svm_info$type == "classification") {
# e1071 stores the class probabilities in the "probabilities" attribute
prob_mat <- attr(pred_result, "probabilities")
if (!is.null(prob_mat)) {
# Debug output: show the probability matrix column names
cat("Probability matrix columns:", paste(colnames(prob_mat), collapse = ", "), "\n")
# Match the target level against the column names, trying several strategies
target_level <- as.character(svm_info$lev)
matching_col <- NULL
# 1. Exact match
exact_match <- grep(paste0("^", target_level, "$"), colnames(prob_mat), value = TRUE, ignore.case = TRUE)
if (length(exact_match) > 0) {
matching_col <- exact_match
} else {
# 2. Partial match
partial_match <- grep(target_level, colnames(prob_mat), value = TRUE, ignore.case = TRUE)
if (length(partial_match) > 0) {
matching_col <- partial_match[1] # take the first match
}
}
# 3. If still unmatched, try logical-value matching
if (is.null(matching_col) || length(matching_col) == 0) {
if (target_level %in% c("TRUE", "Yes", "1", "1.0")) {
logical_match <- grep("TRUE", colnames(prob_mat), value = TRUE, ignore.case = TRUE)
if (length(logical_match) > 0) {
matching_col <- logical_match[1]
}
} else if (target_level %in% c("FALSE", "No", "0", "0.0")) {
logical_match <- grep("FALSE", colnames(prob_mat), value = TRUE, ignore.case = TRUE)
if (length(logical_match) > 0) {
matching_col <- logical_match[1]
}
}
}
# 4. As a last resort, use the first column
if (is.null(matching_col) || length(matching_col) == 0) {
if (ncol(prob_mat) > 0) {
matching_col <- colnames(prob_mat)[1]
cat("Warning: Using first probability column '", matching_col, "' for level '", target_level, "'\n")
}
}
if (!is.null(matching_col) && length(matching_col) > 0) {
surv_prob <- as.numeric(prob_mat[, matching_col])
pred_df <- data.frame(
Prediction = round(surv_prob, dec),
stringsAsFactors = FALSE
)
} else {
# Fallback: use the decision values
decision_vals <- attr(pred_result, "decision.values")
if (!is.null(decision_vals)) {
# Map decision values to probabilities
pred_prob <- 1 / (1 + exp(-decision_vals))
pred_df <- data.frame(
Prediction = round(pred_prob, dec),
stringsAsFactors = FALSE
)
} else {
# Last resort: use the predicted class
pred_class <- as.character(pred_result)
pred_prob <- ifelse(pred_class == target_level, 1, 0)
pred_df <- data.frame(
Prediction = round(pred_prob, dec),
stringsAsFactors = FALSE
)
}
}
} else {
# No probability matrix: fall back to the decision values
decision_vals <- attr(pred_result, "decision.values")
if (!is.null(decision_vals)) {
# Map decision values to probabilities
pred_prob <- 1 / (1 + exp(-decision_vals))
pred_df <- data.frame(
Prediction = round(pred_prob, dec),
stringsAsFactors = FALSE
)
} else {
# Last resort: use the predicted class
target_level <- as.character(svm_info$lev)
pred_class <- as.character(pred_result)
pred_prob <- ifelse(pred_class == target_level, 1, 0)
pred_df <- data.frame(
Prediction = round(pred_prob, dec),
stringsAsFactors = FALSE
)
}
}
} else {
# Regression model
pred_values <- as.numeric(pred_result)
# De-standardize the predictions if the response was scaled during training
rvar_ms <- if (!is.null(ms) && svm_info$rvar %in% names(ms)) ms[svm_info$rvar] else 0
rvar_sds <- if (!is.null(sds) && svm_info$rvar %in% names(sds)) sds[svm_info$rvar] else 1
pred_values <- pred_values * rvar_sds + rvar_ms
pred_df <- data.frame(
Prediction = round(pred_values, dec),
stringsAsFactors = FALSE
)
}
return(pred_df)
}
# 5. Call the shared prediction framework (same approach as the NN module)
result <- predict_model(
object,
pfun,
"svm.predict", # model type
pred_data,
pred_cmd,
conf_lev = 0.95,
se = FALSE,
dec,
envir = envir
) %>%
set_attr("radiant_pred_data", df_name)
# 6. Attach result metadata
if (inherits(result, "svm.predict.error")) {
return(result$error)
}
if (is.character(result)) return(result)
result <- add_class(result, "svm.predict")
attr(result, "svm_meta") <- list(
model_type = svm_info$type,
kernel = svm_info$kernel,
cost = svm_info$cost,
gamma = if (svm_info$kernel != "linear") svm_info$gamma else NA,
seed = svm_info$seed,
train_data = svm_info$df_name,
dec = dec
)
return(result)
}
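Two details of the prediction path are worth illustrating: probability columns are only available when the model was fit with probability = TRUE, and the fallback above squashes raw decision values through a logistic function (a rough surrogate, not a calibrated probability). A sketch against e1071 directly, using built-in data:

fit <- e1071::svm(Species ~ ., data = iris, kernel = "radial", probability = TRUE)
pr  <- predict(fit, newdata = iris[1:5, ], probability = TRUE, decision.values = TRUE)
attr(pr, "probabilities")            # per-class probabilities, columns named by level
dv  <- attr(pr, "decision.values")   # raw (pairwise) decision values
1 / (1 + exp(-dv))                   # the sigmoid fallback used above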
#' Print method for predict.svm
...@@ -396,16 +563,30 @@ print.svm.predict <- function(x, ..., n = 10) {
if (!inherits(x, "svm.predict")) stop("Object must be of class 'svm.predict'")
meta <- attr(x, "svm_meta")
if (is.null(meta)) {
meta <- list(
model_type = "Unknown",
kernel = "Unknown",
cost = NA,
gamma = NA,
seed = NA,
train_data = "Unknown",
dec = 1
)
}
dec <- meta$dec %||% 1
n_pred <- nrow(x)
show_n <- if (n < 0 || n >= n_pred) n_pred else n
cat("SVM Predictions\n")
model_type <- if (is.empty(meta$model_type)) "Unknown" else as.character(meta$model_type)
cat(sprintf("Model Type : %s\n", tools::toTitleCase(model_type)))
cat(sprintf("Kernel : %s\n", ifelse(is.empty(meta$kernel), "Unknown", meta$kernel)))
cat(sprintf("Cost (C) : %.2f\n", ifelse(is.na(meta$cost), 0, meta$cost)))
if (!is.na(meta$gamma)) cat(sprintf("Gamma : %.2f\n", meta$gamma))
if (!is.na(meta$seed)) cat(sprintf("Seed : %s\n", meta$seed))
cat(sprintf("Training Dataset : %s\n", ifelse(is.empty(meta$train_data), "Unknown", meta$train_data)))
cat(sprintf("Total Predictions : %d\n", n_pred))
if (n_pred == 0) {
...@@ -413,8 +594,13 @@ print.svm.predict <- function(x, ..., n = 10) {
return(invisible(x))
}
x_show <- x[1:show_n, , drop = FALSE]
num_cols <- sapply(x_show, is.numeric)
x_show[num_cols] <- lapply(x_show[num_cols], function(col) {
sprintf(paste0("%.", dec, "f"), col)
})
# Recompute the column widths
col_widths <- sapply(colnames(x_show), function(cn) {
max(nchar(cn), max(nchar(as.character(x_show[[cn]])), na.rm = TRUE))
})
...@@ -435,38 +621,13 @@ print.svm.predict <- function(x, ..., n = 10) {
}
if (show_n < n_pred) {
cat(sprintf("\n... (showing first %d of %d)\n",
show_n, n_pred))
}
return(invisible(x))
}
#' Cross-validation for SVM
#' @export
cv.svm <- function(object, K = 5, repeats = 1,
kernel = c("linear", "radial"),
cost = 2^(-2:2),
gamma = 2^(-2:2),
seed = 1234, trace = TRUE, fun, ...) {
if (!inherits(object, "svm")) stop("Object must be of class 'svm'")
tune_grid <- expand.grid(
kernel = kernel,
cost = cost,
gamma = if (any(kernel %in% c("radial", "poly", "sigmoid"))) gamma else NA
)
cv_results <- data.frame(
mean_perf = rep(NA, nrow(tune_grid)),
std_perf = rep(NA, nrow(tune_grid)),
kernel = tune_grid$kernel,
cost = tune_grid$cost,
gamma = tune_grid$gamma
)
cv_results
}
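Note that the removed cv.svm stub above only assembled the tuning grid and returned empty performance columns; it never refit or scored the model. If cross-validated tuning is reinstated, e1071's own tuner already implements the k-fold grid search, roughly as sketched here:

tuned <- e1071::tune.svm(
  Species ~ ., data = iris,
  kernel = "radial",
  cost = 2^(-2:2), gamma = 2^(-2:2),
  tunecontrol = e1071::tune.control(cross = 5)
)
tuned$best.parameters
tuned$performances   # mean error and dispersion for every grid point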
# ---- Decision boundary ----
#' @export
svm_boundary_plot <- function(object, size, custom) {
......
...@@ -79,11 +79,12 @@ options(
i18n$t("Estimate"),
tabPanel(i18n$t("Linear regression (OLS)"), uiOutput("regress")),
tabPanel(i18n$t("Logistic regression (GLM)"), uiOutput("logistic")),
tabPanel(i18n$t("Multinomial logistic regression (MNL)"), uiOutput("mnl")),
tabPanel(i18n$t("Naive Bayes"), uiOutput("nb")),
tabPanel(i18n$t("Neural Network"), uiOutput("nn")),
tabPanel(i18n$t("Support Vector Machine (SVM)"), uiOutput("svm")),
"----", i18n$t("Survival Analysis"),
tabPanel(i18n$t("Cox Proportional Hazards Regression"),uiOutput("coxp")),
"----", i18n$t("Trees"), "----", i18n$t("Trees"),
tabPanel(i18n$t("Classification and regression trees"), uiOutput("crtree")), tabPanel(i18n$t("Classification and regression trees"), uiOutput("crtree")),
tabPanel(i18n$t("Random Forest"), uiOutput("rf")), tabPanel(i18n$t("Random Forest"), uiOutput("rf")),
......
...@@ -23,7 +23,7 @@ svm_inputs <- reactive({
svm_args
})
## Prediction arguments
svm_pred_args <- as.list(if (exists("predict.svm")) {
formals(predict.svm)
} else {
...@@ -52,7 +52,7 @@ svm_pred_inputs <- reactive({
return(svm_pred_args)
})
## Plot arguments
svm_plot_args <- as.list(if (exists("plot.svm")) {
formals(plot.svm)
} else {
...@@ -77,11 +77,7 @@ output$ui_svm_rvar <- renderUI({
vars <- varnames()[isNum]
}
})
init <- state_single("svm_rvar", vars, isolate(input$svm_rvar))
selectInput(
inputId = "svm_rvar",
label = i18n$t("Response variable:"),
...@@ -109,11 +105,7 @@ output$ui_svm_evar <- renderUI({
if (not_available(input$svm_rvar)) return()
vars <- varnames()
if (length(vars) > 0) vars <- vars[-which(vars == input$svm_rvar)]
init <- state_multiple("svm_evar", vars, isolate(input$svm_evar))
selectInput(
inputId = "svm_evar",
label = i18n$t("Explanatory variables:"),
...@@ -141,7 +133,7 @@ output$ui_svm_wts <- renderUI({
)
})
## Store predictions UI
output$ui_svm_store_pred_name <- renderUI({
init <- state_init("svm_store_pred_name", "pred_svm") %>%
sub("\\d{1,}$", "", .) %>%
...@@ -164,7 +156,7 @@ observeEvent(input$svm_type, {
updateSelectInput(session = session, inputId = "svm_plots", selected = "none")
})
## Plot options UI
output$ui_svm_plots <- renderUI({
req(input$svm_type)
avail_plots <- svm_plots
...@@ -178,19 +170,7 @@ output$ui_svm_plots <- renderUI({
)
})
## Number of data points UI (dashboard plot only)
output$ui_svm_nrobs <- renderUI({
nrobs <- nrow(.get_data())
choices <- c("1,000" = 1000, "5,000" = 5000, "10,000" = 10000, "All" = -1) %>%
.[. < nrobs]
selectInput(
"svm_nrobs", i18n$t("Number of data points plotted:"),
choices = choices,
selected = state_single("svm_nrobs", choices, 1000)
)
})
## Main UI panel
output$ui_svm <- renderUI({
req(input$dataset)
tagList(
...@@ -258,7 +238,7 @@ output$ui_svm <- renderUI({
)
)
),
# Prediction panel
conditionalPanel(
condition = "input.tabs_svm == 'Predict'",
selectInput(
...@@ -303,19 +283,10 @@ output$ui_svm <- renderUI({
)
)
),
# Plot panel
conditionalPanel(
condition = "input.tabs_svm == 'Plot'",
uiOutput("ui_svm_plots"),
conditionalPanel(
condition = "input.svm_plots == 'pred_plot'",
uiOutput("ui_svm_incl"),
uiOutput("ui_svm_incl_int")
),
conditionalPanel(
condition = "input.svm_plots == 'dashboard'",
uiOutput("ui_svm_nrobs")
)
)
),
# Help and report panel
...@@ -327,7 +298,7 @@ output$ui_svm <- renderUI({
)
})
## Plot size calculation
svm_plot <- reactive({
if (svm_available() != "available") return()
if (is.empty(input$svm_plots, "none")) return()
...@@ -337,9 +308,6 @@ svm_plot <- reactive({
plot_width <- 650
if ("decision_boundary" %in% input$svm_plots) {
plot_height <- 500
} else if (input$svm_plots == "pred_plot") {
nr_vars <- length(input$svm_incl) + length(input$svm_incl_int)
plot_height <- max(250, ceiling(nr_vars / 2) * 250)
} else {
plot_height <- max(500, length(res$evar) * 30)
}
...@@ -349,7 +317,7 @@ svm_plot <- reactive({
svm_plot_width <- function() svm_plot()$plot_width %||% 650
svm_plot_height <- function() svm_plot()$plot_height %||% 500
## Main output panel
output$svm <- renderUI({
register_print_output("summary_svm", ".summary_svm")
register_print_output("predict_svm", ".predict_print_svm")
...@@ -393,14 +361,14 @@ svm_available <- reactive({
}
})
## Core estimation function
.svm <- eventReactive(input$svm_run, {
svi <- svm_inputs()
svi$envir <- r_data
withProgress(message = i18n$t("Estimating SVM model"), value = 1, do.call(svm, svi))
})
## Helper output functions
.summary_svm <- reactive({
if (not_pressed(input$svm_run)) return(i18n$t("** Press the Estimate button to estimate the SVM model **"))
if (svm_available() != "available") return(svm_available())
...@@ -444,7 +412,7 @@ svm_available <- reactive({
withProgress(message = i18n$t("Generating SVM plots"), value = 1, do.call(plot, c(list(x = .svm()), pinp)))
})
## Store predictions
observeEvent(input$svm_store_pred, {
req(
pressed(input$svm_run),
...@@ -457,14 +425,14 @@ observeEvent(input$svm_store_pred, {
base_col_name <- fix_names(input$svm_store_pred_name)
meta <- attr(pred_result, "svm_meta")
pred_cols <- if (meta$model_type %in% c("classification", "regression")) {
colnames(pred_result)[colnames(pred_result) == "Prediction"]
} else {
NULL
}
new_col_names <- if (length(pred_cols) == 1) base_col_name else {
suffix <- gsub("^Prediction", "", pred_cols)
paste0(base_col_name, ifelse(suffix == "", "", paste0("_", suffix)))
}
colnames(pred_result)[match(pred_cols, colnames(pred_result))] <- new_col_names
...@@ -510,9 +478,75 @@ download_handler(
height = svm_plot_height
)
svm_report <- function() {
if (is.empty(input$svm_evar)) {
showNotification(i18n$t("Select at least one explanatory variable to generate report"), type = "error")
return(invisible())
}
outputs <- c("summary")
inp_out <- list(list(prn = TRUE), "")
figs <- FALSE
xcmd <- ""
if (!is.empty(input$svm_plots, "none")) {
inp <- check_plot_inputs(svm_plot_inputs())
inp$size <- NULL
inp_out[[2]] <- clean_args(inp, svm_plot_args[-1])
inp_out[[2]]$custom <- FALSE
outputs <- c(outputs, "plot")
figs <- TRUE
}
if (!is.empty(input$svm_predict, "none") &&
(!is.empty(input$svm_pred_data) || !is.empty(input$svm_pred_cmd))) {
pred_args <- clean_args(svm_pred_inputs(), svm_pred_args[-1])
if (!is.empty(pred_args$pred_cmd)) {
pred_args$pred_cmd <- strsplit(pred_args$pred_cmd, ";\\s*")[[1]]
} else {
pred_args$pred_cmd <- NULL
}
if (!is.empty(pred_args$pred_data)) {
pred_args$pred_data <- as.symbol(pred_args$pred_data)
} else {
pred_args$pred_data <- NULL
}
inp_out[[2 + figs]] <- pred_args
outputs <- c(outputs, "pred <- predict")
xcmd <- paste0(xcmd, "print(pred, n = 10)")
if (input$svm_predict %in% c("data", "datacmd") && !is.empty(input$svm_store_pred_name)) {
fixed <- fix_names(input$svm_store_pred_name)
updateTextInput(session, "svm_store_pred_name", value = fixed)
xcmd <- paste0(
xcmd, "\n", input$svm_pred_data, " <- store(",
input$svm_pred_data, ", pred, name = \"", fixed, "\")"
)
}
}
svm_inp <- svm_inputs()
if (input$svm_type == "regression") {
svm_inp$lev <- NULL
}
if (input$svm_kernel == "linear") {
svm_inp$gamma <- NULL
}
update_report(
inp_main = clean_args(svm_inp, svm_args),
fun_name = "svm",
inp_out = inp_out,
outputs = outputs,
figs = figs,
fig.width = svm_plot_width(),
fig.height = svm_plot_height(),
xcmd = xcmd
)
}
## Report generation
observeEvent(input$svm_report, {
r_info[["latest_screenshot"]] <- NULL
svm_report()
......
% Generated by roxygen2: do not edit by hand
% Please edit documentation in R/svm.R
\name{cv.svm}
\alias{cv.svm}
\title{Cross-validation for SVM}
\usage{
cv.svm(
object,
K = 5,
repeats = 1,
kernel = c("linear", "radial"),
cost = seq(0.1, 10, by = 0.5),
gamma = seq(0.1, 5, by = 0.5),
seed = 1234,
trace = TRUE,
fun,
...
)
}
\description{
Cross-validation for SVM
}
...@@ -4,17 +4,7 @@
\alias{plot.svm}
\title{Plot method for the svm function}
\usage{
\method{plot}{svm}(x, plots = "none", size = 12, shiny = FALSE, custom = FALSE, ...)
}
\description{
Plot method for the svm function
}
......
...@@ -2,7 +2,7 @@
% Please edit documentation in R/svm.R
\name{predict.svm}
\alias{predict.svm}
\title{Predict method for SVM model}
\usage{
\method{predict}{svm}(
object,
...@@ -14,5 +14,5 @@
)
}
\description{
Predict method for SVM model
}
...@@ -4,9 +4,9 @@
\alias{varimp}
\title{Variable importance using the vip package and permutation importance}
\usage{
varimp(object, rvar = NULL, lev = NULL, data = NULL, seed = 1234, nperm = 10)

varimp(object, rvar = NULL, lev = NULL, data = NULL, seed = 1234, nperm = 10)
}
\arguments{
\item{object}{Model object created by Radiant}
...@@ -22,5 +22,5 @@ varimp(object, rvar, lev, data = NULL, seed = 1234)
\description{
Variable importance using the vip package and permutation importance

Variable importance for SVM using permutation importance
}
% Generated by roxygen2: do not edit by hand
% Please edit documentation in R/nn.R
\name{varimp_plot}
\alias{varimp_plot}
\title{Plot permutation importance}
\usage{
varimp_plot(object, rvar, lev, data = NULL, seed = 1234)
}
\arguments{
...@@ -20,7 +18,5 @@ varimp_plot(object, rvar, lev, data = NULL, seed = 1234)
\item{seed}{Random seed for reproducibility}
}
\description{
Plot permutation importance
}