Commit b3e914bc authored by wuzekai's avatar wuzekai

update

parent 078f95fa
......@@ -328,216 +328,140 @@ plot.svm <- function(x,
{ if (isTRUE(shiny)) . else print(.) }
}
#' Predict method for SVM model
#' @export
predict.svm <- function(object, pred_data = NULL, pred_cmd = "", dec = 3,
envir = parent.frame(), ...) {
# 1. Basic validation
if (is.character(object)) return(object)
if (!inherits(object, "svm")) stop("Object must be of class 'svm'")
if (is.null(object$model) || !inherits(object$model, "svm")) {
err_msg <- "Prediction failed: Invalid SVM model"
return(structure(list(error = err_msg), class = "svm.predict.error"))
}
svm_info <- object
# 2. Get the prediction data
if (is.character(pred_data) && nzchar(pred_data)) {
if (!exists(pred_data, envir = envir)) {
return(structure(list(error = sprintf("Dataset '%s' not found", pred_data)),
                 class = "svm.predict.error"))
}
pred_data <- get(pred_data, envir = envir)
}
has_data <- !is.null(pred_data) && is.data.frame(pred_data) && nrow(pred_data) > 0
has_cmd <- is.character(pred_cmd) && nzchar(pred_cmd)
if (!has_data && !has_cmd) {
return(structure(list(error = "Please select data and/or specify a prediction command."),
                 class = "svm.predict.error"))
}
## ensure you have a name for the prediction dataset
df_name <- if (is.data.frame(pred_data)) deparse(substitute(pred_data)) else pred_data
# 3. Core prediction function
pfun <- function(model, pred, se, conf_lev) {
  # Validate that all explanatory variables are present
  missing_vars <- setdiff(svm_info$evar, colnames(pred))
  if (length(missing_vars) > 0) {
    return(paste("Missing variables:", paste(missing_vars, collapse = ", ")))
  }
  ## ---- 1. Keep the prediction columns in the same order as the training data ----
  pred <- pred[, svm_info$evar, drop = FALSE]
  ## ---- 2. Align variable types with the training data (repeat the training-stage logic) ----
  train_df <- svm_info$dataset
  for (v in svm_info$evar) {
    if (is.factor(train_df[[v]])) {
      ## convert predictions to a factor with the same levels, then to numeric
      ## (factors were converted to numeric during training as well)
      pred[[v]] <- factor(pred[[v]], levels = levels(train_df[[v]]))
      pred[[v]] <- as.numeric(pred[[v]])
    } else if (is.character(train_df[[v]])) {
      pred[[v]] <- factor(pred[[v]], levels = unique(train_df[[v]]))
      pred[[v]] <- as.numeric(pred[[v]])
    } else if (is.logical(train_df[[v]])) {
      pred[[v]] <- as.numeric(pred[[v]])
    } else {
      ## numeric -> keep as is
      pred[[v]] <- as.numeric(pred[[v]])
    }
  }
  ## ---- 3. Apply the standardization parameters from the training stage ----
  ms <- attr(train_df, "radiant_ms")
  sds <- attr(train_df, "radiant_sds")
  if (!is.null(ms) && !is.null(sds)) {
    for (v in svm_info$evar) {
      if (!is.na(ms[v]) && !is.na(sds[v]) && sds[v] != 0) {
        pred[[v]] <- (pred[[v]] - ms[v]) / sds[v]
      }
    }
  }
  ## ---- 4. Call e1071::predict ----
  pred_result <- try(
    predict(
      svm_info$model,
      newdata = pred,
      probability = svm_info$model$prob.model
    ),
    silent = TRUE
  )
  if (inherits(pred_result, "try-error")) {
    return(paste("Prediction failed:", attr(pred_result, "condition")$message))
  }
  ## ---- 5. Classification: return the probability of the chosen level ----
  if (svm_info$type == "classification") {
    prob <- attr(pred_result, "probabilities")
    lev <- svm_info$lev
    if (!is.null(prob)) {
      if (lev %in% colnames(prob)) {
        p <- prob[, lev]
      } else {
        p <- prob[, 1] # fallback
      }
      return(data.frame(Prediction = round(p, dec)))
    }
    ## no probabilities -> degrade to a 0/1 indicator for the chosen level
    p <- as.character(pred_result)
    p <- ifelse(p == lev, 1, 0)
    return(data.frame(Prediction = round(p, dec)))
  }
  ## ---- 6. Regression ----
  return(data.frame(Prediction = round(as.numeric(pred_result), dec)))
}
# 4. Radiant framework-style prediction
result <- predict_model(
object,
pfun,
"svm.predict", # 模型类型
"svm.predict",
pred_data,
pred_cmd,
conf_lev = 0.95,
se = FALSE,
dec,
envir = envir
) %>% set_attr("radiant_pred_data", df_name)
if (inherits(result, "svm.predict.error")) return(result$error)
if (is.character(result)) return(result)
result <- add_class(result, "svm.predict")
......@@ -553,6 +477,7 @@ predict.svm <- function(object, pred_data = NULL, pred_cmd = "", dec = 3,
return(result)
}
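## A minimal self-contained sketch of the probability lookup that steps 4-5 in
## pfun() rely on: e1071's predict() returns a "probabilities" attribute whose
## columns are named after the factor levels (iris used purely for illustration).
library(e1071)
ex_dat <- droplevels(iris[iris$Species %in% c("setosa", "versicolor"), ])
ex_fit <- svm(Species ~ Sepal.Length + Sepal.Width, data = ex_dat, probability = TRUE)
ex_pr <- predict(ex_fit, newdata = ex_dat, probability = TRUE)
ex_prob <- attr(ex_pr, "probabilities") # one column per class level
head(ex_prob[, "versicolor"])           # probability of the chosen level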
#' Print method for predict.svm
#' @export
print.svm.predict <- function(x, ..., n = 10) {
......@@ -697,95 +622,109 @@ varimp <- function(object, rvar = NULL, lev = NULL, data = NULL, seed = 1234, np
data <- object$dataset
}
# Determine the response variable (matching the output of predict.svm)
if (is.null(rvar)) {
rvar <- object$rvar
}
# Determine the classification level
if (is.null(lev) && object$type == "classification") {
lev <- object$lev
}
# Data set containing only the explanatory variables
X <- data[, object$evar, drop = FALSE]
y <- data[[rvar]] # original response variable (factor / numeric)
# Set the random seed
if (!is.na(seed)) {
set.seed(seed)
}
# Baseline prediction (uses predict.svm to obtain its actual output)
base_pred <- predict(object, pred_data = data, envir = environment())
if (is.character(base_pred) || inherits(base_pred, "svm.predict.error")) {
  stop("Baseline prediction failed: ", base_pred)
}
# Choose the performance metric based on the task type (adapted to the predict.svm output format)
base_metric <- if (object$type == "classification") {
  # Classification: AUC for binary outcomes, accuracy for multi-class
  if (nlevels(as.factor(y)) == 2) {
    # Binary: compute the AUC from the "Prediction" column returned by predict.svm (probability of the success level)
    if (!"Prediction" %in% colnames(base_pred)) {
      stop("Classification prediction is missing the 'Prediction' (probability) column")
    }
    # Compute the AUC (response coded 0/1, 1 = success level)
    y_bin <- as.numeric(y == lev)
    auc_val <- try(pROC::roc(response = y_bin, predictor = base_pred$Prediction)$auc[[1]], silent = TRUE)
    if (inherits(auc_val, "try-error")) {
      # Fall back to accuracy when the AUC cannot be computed
      pred_class <- ifelse(base_pred$Prediction >= 0.5, lev, setdiff(levels(y), lev)) # probability >= 0.5 is treated as the success level
      mean(pred_class == as.character(y), na.rm = TRUE)
    } else {
      auc_val
    }
  } else {
    # Multi-class: would need a probability matrix from predict.svm; fall back to accuracy for now
    pred_class <- ifelse(base_pred$Prediction >= 0.5, lev, setdiff(levels(y), lev))
    mean(pred_class == as.character(y), na.rm = TRUE)
  }
} else {
  # Regression: compute R-squared from the "Prediction" column returned by predict.svm
  if (!"Prediction" %in% colnames(base_pred)) {
    stop("Regression prediction is missing the 'Prediction' column")
  }
  1 - sum((base_pred$Prediction - y)^2, na.rm = TRUE) /
    sum((y - mean(y, na.rm = TRUE))^2, na.rm = TRUE)
}
# Compute the permutation importance for each variable (same core logic, adapted to the predict.svm output)
importance_scores <- sapply(object$evar, function(var) {
metric_diffs <- numeric(nperm)
for (i in 1:nperm) {
# Create a copy of the data and randomly permute the current variable
perm_data <- data
perm_data[[var]] <- sample(perm_data[[var]], replace = FALSE)
# Predict on the permuted data
perm_pred <- predict(object, pred_data = perm_data, envir = environment())
if (is.character(perm_pred) || inherits(perm_pred, "svm.predict.error")) {
  metric_diffs[i] <- NA
  next
}
# Performance metric after permutation
perm_metric <- if (object$type == "classification") {
  if (nlevels(as.factor(y)) == 2) {
    y_bin <- as.numeric(y == lev)
    auc_val <- try(pROC::roc(response = y_bin, predictor = perm_pred$Prediction)$auc[[1]], silent = TRUE)
    if (inherits(auc_val, "try-error")) {
      pred_class <- ifelse(perm_pred$Prediction >= 0.5, lev, setdiff(levels(y), lev))
      mean(pred_class == as.character(y), na.rm = TRUE)
    } else {
      auc_val
    }
  } else {
    pred_class <- ifelse(perm_pred$Prediction >= 0.5, lev, setdiff(levels(y), lev))
    mean(pred_class == as.character(y), na.rm = TRUE)
  }
} else {
  1 - sum((perm_pred$Prediction - y)^2, na.rm = TRUE) /
    sum((y - mean(y, na.rm = TRUE))^2, na.rm = TRUE)
}
# Performance drop (baseline - permuted; larger values mean the variable is more important)
metric_diffs[i] <- base_metric - perm_metric
}
# Return the average performance loss (ignoring NAs)
mean(metric_diffs, na.rm = TRUE)
})
# Build the result data frame (filter out invalid values)
result <- data.frame(
Variable = names(importance_scores),
Importance = as.numeric(pmax(importance_scores, 0)), # importance cannot be negative
stringsAsFactors = FALSE
)
......@@ -796,6 +735,7 @@ varimp <- function(object, rvar = NULL, lev = NULL, data = NULL, seed = 1234, np
}
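## A minimal permutation-importance sketch of the idea varimp() implements above:
## permute one predictor, re-predict, and record the drop in a fit metric
## (plain accuracy here for brevity; varimp() uses AUC / R-squared). Uses e1071
## and iris for illustration only, not the radiant objects.
library(e1071)
set.seed(1234)
ex_fit <- svm(Species ~ ., data = iris)
ex_base <- mean(predict(ex_fit, iris) == iris$Species)
ex_imp <- sapply(setdiff(names(iris), "Species"), function(v) {
  mean(replicate(5, {
    perm <- iris
    perm[[v]] <- sample(perm[[v]])
    ex_base - mean(predict(ex_fit, perm) == iris$Species)
  }))
})
sort(ex_imp, decreasing = TRUE)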
#' @export
svm_vip_plot <- function(object, size, custom) {
tryCatch({
......
......@@ -135,9 +135,14 @@ output$ui_svm_wts <- renderUI({
## UI for storing predictions
output$ui_svm_store_pred_name <- renderUI({
init <- state_init("svm_store_pred_name", "pred_svm") %>%
sub("\\d{1,}$", "", .) %>%
paste0(., ifelse(is.empty(input$svm_kernel), "", input$svm_kernel))
base_name <- "pred_svm"
kernel_name <- input$svm_kernel # 获取当前选中的核函数
init <- if (is.empty(kernel_name)) {
base_name
} else {
paste0(base_name, "_", kernel_name)
}
init <- state_init("svm_store_pred_name", init)
textInput(
"svm_store_pred_name",
i18n$t("Store predictions:"),
......@@ -145,6 +150,16 @@ output$ui_svm_store_pred_name <- renderUI({
)
})
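## Keep the default "Store predictions" name in sync with the selected kernel;
## only names that still look like the auto-generated default are rewritten.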
observeEvent(input$svm_kernel, {
current_value <- tryCatch(isolate(input$svm_store_pred_name), error = function(e) "")
if (!is.null(current_value) && length(current_value) > 0 && nzchar(current_value)) {
if (grepl("^pred_svm(_[a-z]+)?$", current_value)) {
new_value <- paste0("pred_svm", "_", input$svm_kernel)
updateTextInput(session, "svm_store_pred_name", value = new_value)
}
}
}, ignoreInit = TRUE, ignoreNULL = TRUE)
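## For reference, the pattern above matches only auto-generated names, so
## user-defined names are left untouched:
grepl("^pred_svm(_[a-z]+)?$", c("pred_svm", "pred_svm_radial", "my_preds"))
## [1]  TRUE  TRUE FALSE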
## Reset predictions and plots when the data set or model type changes
observeEvent(input$dataset, {
updateSelectInput(session = session, inputId = "svm_predict", selected = "none")
......@@ -414,40 +429,52 @@ svm_available <- reactive({
## Store predictions
observeEvent(input$svm_store_pred, {
# Only basic checks; exit silently when they are not met
if (!pressed(input$svm_run) || is.empty(input$svm_pred_data) || is.empty(input$svm_store_pred_name)) {
  return()
}
# Get the prediction result (whether it succeeded or failed)
pred_result <- .predict_svm()
target_data <- r_data[[input$svm_pred_data]]
base_col_name <- fix_names(input$svm_store_pred_name)
# If the prediction returned an error string, add a column of NAs instead
if (is.character(pred_result)) {
  target_data[[base_col_name]] <- rep(NA_real_, nrow(target_data))
  attr(target_data[[base_col_name]], "error") <- pred_result
  r_data[[input$svm_pred_data]] <- target_data
  showNotification(
    sprintf("Prediction failed; an NA column '%s' was added", base_col_name),
    type = "warning"
  )
} else {
  # Normal case: extract the Prediction column directly
  if ("Prediction" %in% colnames(pred_result)) {
    pred_values <- pred_result$Prediction
    # Handle length mismatches between the predictions and the target data
    n_target <- nrow(target_data)
    n_pred <- length(pred_values)
    if (n_pred < n_target) {
      # Too few predictions: pad with NA
      pred_values <- c(pred_values, rep(NA_real_, n_target - n_pred))
    } else if (n_pred > n_target) {
      # Too many predictions: truncate
      pred_values <- pred_values[1:n_target]
    }
    target_data[[base_col_name]] <- pred_values
    r_data[[input$svm_pred_data]] <- target_data
  } else {
    target_data[[base_col_name]] <- rep(NA_real_, nrow(target_data))
    r_data[[input$svm_pred_data]] <- target_data
  }
}
# Reset the input field
updateTextInput(session, "svm_store_pred_name", value = "pred_svm")
})
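## Sketch of the length-alignment rule used above when storing predictions
## (align_length is a hypothetical helper shown for illustration only, not part of the module):
align_length <- function(x, n) {
  if (length(x) < n) c(x, rep(NA_real_, n - length(x))) else x[seq_len(n)]
}
align_length(c(0.2, 0.8), 4) # 0.2 0.8  NA  NA
align_length(1:5, 3)         # 1 2 3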
## Download handler
......