Skip to content
Projects
Groups
Snippets
Help
Loading...
Help
Support
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
R
Radiant
Project
Project
Details
Activity
Releases
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Boards
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
wuzekai
Radiant
Commits
b3e914bc
Commit
b3e914bc
authored
Dec 03, 2025
by
wuzekai
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
update
parent
078f95fa
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
206 additions
and
239 deletions
+206
-239
svm.R
radiant.model/R/svm.R
+148
-208
svm_ui.R
radiant.model/inst/app/tools/analysis/svm_ui.R
+58
-31
No files found.
radiant.model/R/svm.R
View file @
b3e914bc
...
...
@@ -328,231 +328,156 @@ plot.svm <- function(x,
{
if
(
isTRUE
(
shiny
))
.
else
print
(
.
)
}
}
#' Predict method for SVM model
#' Predict method for SVM model
(FIXED VERSION)
#' @export
predict.svm
<-
function
(
object
,
pred_data
=
NULL
,
pred_cmd
=
""
,
dec
=
3
,
envir
=
parent.frame
(),
...
)
{
# 1. 基础校验
if
(
is.character
(
object
))
return
(
object
)
if
(
!
inherits
(
object
,
"svm"
))
stop
(
"Object must be of class 'svm'"
)
if
(
is.null
(
object
$
model
)
||
!
inherits
(
object
$
model
,
"svm"
))
{
err_msg
<-
"Prediction failed: Invalid SVM model"
err_obj
<-
structure
(
list
(
error
=
err_msg
),
class
=
"svm.predict.error"
)
return
(
err_obj
)
return
(
structure
(
list
(
error
=
err_msg
),
class
=
"svm.predict.error"
))
}
svm_info
<-
object
# 2.
处理
预测数据
# 2.
获取
预测数据
if
(
is.character
(
pred_data
)
&&
nzchar
(
pred_data
))
{
if
(
!
exists
(
pred_data
,
envir
=
envir
))
{
err_obj
<-
structure
(
list
(
error
=
sprintf
(
"Dataset '%s' not found"
,
pred_data
)),
class
=
"svm.predict.error"
)
return
(
err_obj
)
return
(
structure
(
list
(
error
=
sprintf
(
"Dataset '%s' not found"
,
pred_data
)),
class
=
"svm.predict.error"
)
)
}
pred_data
<-
get
(
pred_data
,
envir
=
envir
)
}
has_data
<-
!
is.null
(
pred_data
)
&&
is.data.frame
(
pred_data
)
&&
nrow
(
pred_data
)
>
0
has_cmd
<-
is.character
(
pred_cmd
)
&&
nzchar
(
pred_cmd
)
has_cmd
<-
is.character
(
pred_cmd
)
&&
nzchar
(
pred_cmd
)
if
(
!
has_data
&&
!
has_cmd
)
{
err_obj
<-
structure
(
list
(
error
=
"Please select data and/or specify a command to generate predictions."
),
class
=
"svm.predict.error"
)
return
(
err_obj
)
return
(
structure
(
list
(
error
=
"Please select data and/or specify a prediction command."
),
class
=
"svm.predict.error"
)
)
}
## ensure you have a name for the prediction dataset
if
(
is.data.frame
(
pred_data
))
{
df_name
<-
deparse
(
substitute
(
pred_data
))
}
else
{
df_name
<-
pred_data
}
df_name
<-
if
(
is.data.frame
(
pred_data
))
deparse
(
substitute
(
pred_data
))
else
pred_data
# 3.
预测核心函数 - 修改参数名以匹配NN
# 3.
核心预测函数
pfun
<-
function
(
model
,
pred
,
se
,
conf_lev
)
{
# 验证数据
missing_vars
<-
setdiff
(
svm_info
$
evar
,
colnames
(
pred
))
if
(
length
(
missing_vars
)
>
0
)
{
return
(
paste
(
"Missing variables:"
,
paste
(
missing_vars
,
collapse
=
", "
)))
## ---- 1. 保证预测集字段顺序与训练集完全一致 ----
pred
<-
pred
[
,
svm_info
$
evar
,
drop
=
FALSE
]
## ---- 2. 对齐变量类型(再次重复训练阶段的逻辑) ----
train_df
<-
svm_info
$
dataset
for
(
v
in
svm_info
$
evar
)
{
# 若训练数据该变量是因子/字符
if
(
is.factor
(
train_df
[[
v
]]))
{
## 预测也转因子并保持相同 level
pred
[[
v
]]
<-
factor
(
pred
[[
v
]],
levels
=
levels
(
train_df
[[
v
]]))
## 转 numeric(训练阶段就是把 factor 都转成 numeric 的)
pred
[[
v
]]
<-
as.numeric
(
pred
[[
v
]])
}
else
if
(
is.character
(
train_df
[[
v
]]))
{
pred
[[
v
]]
<-
factor
(
pred
[[
v
]],
levels
=
unique
(
train_df
[[
v
]]))
pred
[[
v
]]
<-
as.numeric
(
pred
[[
v
]])
}
else
if
(
is.logical
(
train_df
[[
v
]]))
{
pred
[[
v
]]
<-
as.numeric
(
pred
[[
v
]])
}
else
{
## numeric → 保持原样
pred
[[
v
]]
<-
as.numeric
(
pred
[[
v
]])
}
}
# 内部标准化
ms
<-
attr
(
svm_info
$
dataset
,
"radiant_ms"
)
sds
<-
attr
(
svm_info
$
dataset
,
"radiant_sds"
)
## ---- 3. 应用训练阶段的标准化参数 ----
ms
<-
attr
(
train_df
,
"radiant_ms"
)
sds
<-
attr
(
train_df
,
"radiant_sds"
)
if
(
!
is.null
(
ms
)
&&
!
is.null
(
sds
))
{
isNum
<-
sapply
(
pred
,
is.numeric
)
cn
<-
names
(
isNum
)[
isNum
]
for
(
var
in
cn
)
{
if
(
var
%in%
names
(
ms
)
&&
var
%in%
names
(
sds
)
&&
sds
[
var
]
!=
0
)
{
pred
[[
var
]]
<-
(
pred
[[
var
]]
-
ms
[
var
])
/
sds
[
var
]
for
(
v
in
svm_info
$
evar
)
{
if
(
!
is.na
(
ms
[
v
])
&&
!
is.na
(
sds
[
v
])
&&
sds
[
v
]
!=
0
)
{
pred
[[
v
]]
<-
(
pred
[[
v
]]
-
ms
[
v
])
/
sds
[
v
]
}
}
}
# 调试信息:检查模型是否启用了概率
cat
(
"Model probability enabled:"
,
svm_info
$
model
$
prob.model
,
"\n"
)
# 执行预测
pred_result
<-
try
({
## ---- 4. 调用 e1071::predict ----
pred_result
<-
try
(
predict
(
svm_info
$
model
,
svm_info
$
model
,
newdata
=
pred
,
probability
=
TRUE
,
# 始终设置为TRUE
decision.values
=
TRUE
)
},
silent
=
TRUE
)
probability
=
svm_info
$
model
$
prob.model
),
silent
=
TRUE
)
if
(
inherits
(
pred_result
,
"try-error"
))
{
if
(
inherits
(
pred_result
,
"try-error"
))
return
(
paste
(
"Prediction failed:"
,
attr
(
pred_result
,
"condition"
)
$
message
))
}
#
4. 结果整理
#
# ---- 5. 分类模型输出概率 ----
if
(
svm_info
$
type
==
"classification"
)
{
# 正确的属性名是"probabilities"
prob_mat
<-
attr
(
pred_result
,
"probabilities"
)
if
(
!
is.null
(
prob_mat
))
{
# 调试:显示概率矩阵的列名
cat
(
"Probability matrix columns:"
,
paste
(
colnames
(
prob_mat
),
collapse
=
", "
),
"\n"
)
# 更智能的列名匹配
target_level
<-
as.character
(
svm_info
$
lev
)
# 尝试多种匹配方式
matching_col
<-
NULL
# 1. 精确匹配
exact_match
<-
grep
(
paste0
(
"^"
,
target_level
,
"$"
),
colnames
(
prob_mat
),
value
=
TRUE
,
ignore.case
=
TRUE
)
if
(
length
(
exact_match
)
>
0
)
{
matching_col
<-
exact_match
}
# 2. 部分匹配
else
{
partial_match
<-
grep
(
target_level
,
colnames
(
prob_mat
),
value
=
TRUE
,
ignore.case
=
TRUE
)
if
(
length
(
partial_match
)
>
0
)
{
matching_col
<-
partial_match
[
1
]
# 取第一个匹配
}
}
# 3. 如果还是没有匹配,尝试逻辑值匹配
if
(
is.null
(
matching_col
)
||
length
(
matching_col
)
==
0
)
{
if
(
target_level
%in%
c
(
"TRUE"
,
"Yes"
,
"1"
,
"1.0"
))
{
logical_match
<-
grep
(
"TRUE"
,
colnames
(
prob_mat
),
value
=
TRUE
,
ignore.case
=
TRUE
)
if
(
length
(
logical_match
)
>
0
)
{
matching_col
<-
logical_match
[
1
]
}
}
else
if
(
target_level
%in%
c
(
"FALSE"
,
"No"
,
"0"
,
"0.0"
))
{
logical_match
<-
grep
(
"FALSE"
,
colnames
(
prob_mat
),
value
=
TRUE
,
ignore.case
=
TRUE
)
if
(
length
(
logical_match
)
>
0
)
{
matching_col
<-
logical_match
[
1
]
}
}
}
# 4. 作为最后手段,使用第一列
if
(
is.null
(
matching_col
)
||
length
(
matching_col
)
==
0
)
{
if
(
ncol
(
prob_mat
)
>
0
)
{
matching_col
<-
colnames
(
prob_mat
)[
1
]
cat
(
"Warning: Using first probability column '"
,
matching_col
,
"' for level '"
,
target_level
,
"'\n"
)
}
}
if
(
!
is.null
(
matching_col
)
&&
length
(
matching_col
)
>
0
)
{
surv_prob
<-
as.numeric
(
prob_mat
[,
matching_col
])
pred_df
<-
data.frame
(
Prediction
=
round
(
surv_prob
,
dec
),
stringsAsFactors
=
FALSE
)
}
else
{
# 备用方案:使用决策值
decision_vals
<-
attr
(
pred_result
,
"decision.values"
)
if
(
!
is.null
(
decision_vals
))
{
# 将决策值转换为概率
pred_prob
<-
1
/
(
1
+
exp
(
-
decision_vals
))
pred_df
<-
data.frame
(
Prediction
=
round
(
pred_prob
,
dec
),
stringsAsFactors
=
FALSE
)
}
else
{
# 最后手段:使用预测类
pred_class
<-
as.character
(
pred_result
)
pred_prob
<-
ifelse
(
pred_class
==
target_level
,
1
,
0
)
pred_df
<-
data.frame
(
Prediction
=
round
(
pred_prob
,
dec
),
stringsAsFactors
=
FALSE
)
}
}
}
else
{
# 备用方案:使用决策值
decision_vals
<-
attr
(
pred_result
,
"decision.values"
)
if
(
!
is.null
(
decision_vals
))
{
# 将决策值转换为概率
pred_prob
<-
1
/
(
1
+
exp
(
-
decision_vals
))
pred_df
<-
data.frame
(
Prediction
=
round
(
pred_prob
,
dec
),
stringsAsFactors
=
FALSE
)
prob
<-
attr
(
pred_result
,
"probabilities"
)
lev
<-
svm_info
$
lev
if
(
!
is.null
(
prob
))
{
if
(
lev
%in%
colnames
(
prob
))
{
p
<-
prob
[,
lev
]
}
else
{
# 最后手段:使用预测类
target_level
<-
as.character
(
svm_info
$
lev
)
pred_class
<-
as.character
(
pred_result
)
pred_prob
<-
ifelse
(
pred_class
==
target_level
,
1
,
0
)
pred_df
<-
data.frame
(
Prediction
=
round
(
pred_prob
,
dec
),
stringsAsFactors
=
FALSE
)
p
<-
prob
[,
1
]
# fallback
}
return
(
data.frame
(
Prediction
=
round
(
p
,
dec
)))
}
}
else
{
# 回归模型
pred_values
<-
as.numeric
(
pred_result
)
# 关键:添加反标准化处理
rvar_ms
<-
if
(
!
is.null
(
ms
)
&&
svm_info
$
rvar
%in%
names
(
ms
))
ms
[
svm_info
$
rvar
]
else
0
rvar_sds
<-
if
(
!
is.null
(
sds
)
&&
svm_info
$
rvar
%in%
names
(
sds
))
sds
[
svm_info
$
rvar
]
else
1
# 反标准化
pred_values
<-
pred_values
*
rvar_sds
+
rvar_ms
pred_df
<-
data.frame
(
Prediction
=
round
(
pred_values
,
dec
),
stringsAsFactors
=
FALSE
)
## 无概率 → 退化成 0/1
p
<-
as.character
(
pred_result
)
p
<-
ifelse
(
p
==
lev
,
1
,
0
)
return
(
data.frame
(
Prediction
=
round
(
p
,
dec
))
)
}
return
(
pred_df
)
## ---- 6. 回归模型 ----
return
(
data.frame
(
Prediction
=
round
(
as.numeric
(
pred_result
),
dec
)))
}
# 5. 调用预测框架 - 与NN完全一致
# 4. Radiant 框架式预测
result
<-
predict_model
(
object
,
pfun
,
"svm.predict"
,
# 模型类型
pred_data
,
pred_cmd
,
conf_lev
=
0.95
,
se
=
FALSE
,
dec
,
object
,
pfun
,
"svm.predict"
,
pred_data
,
pred_cmd
,
conf_lev
=
0.95
,
se
=
FALSE
,
dec
,
envir
=
envir
)
%>%
set_attr
(
"radiant_pred_data"
,
df_name
)
)
%>%
set_attr
(
"radiant_pred_data"
,
df_name
)
# 6. 结果元数据
if
(
inherits
(
result
,
"svm.predict.error"
))
{
return
(
result
$
error
)
}
if
(
inherits
(
result
,
"svm.predict.error"
))
return
(
result
$
error
)
if
(
is.character
(
result
))
return
(
result
)
result
<-
add_class
(
result
,
"svm.predict"
)
attr
(
result
,
"svm_meta"
)
<-
list
(
model_type
=
svm_info
$
type
,
kernel
=
svm_info
$
kernel
,
cost
=
svm_info
$
cost
,
gamma
=
if
(
svm_info
$
kernel
!=
"linear"
)
svm_info
$
gamma
else
NA
,
seed
=
svm_info
$
seed
,
kernel
=
svm_info
$
kernel
,
cost
=
svm_info
$
cost
,
gamma
=
if
(
svm_info
$
kernel
!=
"linear"
)
svm_info
$
gamma
else
NA
,
seed
=
svm_info
$
seed
,
train_data
=
svm_info
$
df_name
,
dec
=
dec
dec
=
dec
)
return
(
result
)
}
#' Print method for predict.svm
#' @export
print.svm.predict
<-
function
(
x
,
...
,
n
=
10
)
{
...
...
@@ -697,95 +622,109 @@ varimp <- function(object, rvar = NULL, lev = NULL, data = NULL, seed = 1234, np
data
<-
object
$
dataset
}
# 确定响应变量
# 确定响应变量
和分类水平(适配predict.svm的输出)
if
(
is.null
(
rvar
))
{
rvar
<-
object
$
rvar
}
# 确定分类水平
if
(
is.null
(
lev
)
&&
object
$
type
==
"classification"
)
{
lev
<-
object
$
lev
}
# 创建仅包含解释变量的数据集
X
<-
data
[,
object
$
evar
,
drop
=
FALSE
]
y
<-
data
[[
rvar
]]
y
<-
data
[[
rvar
]]
# 原始响应变量(因子/数值)
# 设置随机种子
if
(
!
is.na
(
seed
))
{
set.seed
(
seed
)
}
# 基准预测
# 基准预测
(调用predict.svm,获取实际输出)
base_pred
<-
predict
(
object
,
pred_data
=
data
,
envir
=
environment
())
if
(
is.character
(
base_pred
)
||
inherits
(
base_pred
,
"svm.predict.error"
))
{
stop
(
"基准预测失败:"
,
base_pred
)
}
# 根据任务类型选择性能指标
if
(
object
$
type
==
"classification"
)
{
# 处理二分类或多分类
base_metric
<-
if
(
nlevels
(
as.factor
(
y
))
==
2
)
{
# 二分类:计算AUC
pred_prob_col
<-
if
(
length
(
grep
(
"Prob_"
,
colnames
(
base_pred
)))
>
0
)
{
grep
(
paste0
(
"Prob_"
,
lev
),
colnames
(
base_pred
),
value
=
TRUE
)
}
else
{
NULL
# 根据任务类型选择性能指标(适配predict.svm的输出格式)
base_metric
<-
if
(
object
$
type
==
"classification"
)
{
# 分类任务:二分类用AUC,多分类用准确率
if
(
nlevels
(
as.factor
(
y
))
==
2
)
{
# 二分类:直接用predict.svm输出的"Prediction"列(成功水平概率)计算AUC
if
(
!
"Prediction"
%in%
colnames
(
base_pred
))
{
stop
(
"分类预测缺少'Prediction'列(概率值)"
)
}
if
(
!
is.null
(
pred_prob_col
))
{
pROC
::
roc
(
response
=
y
,
predictor
=
base_pred
[[
pred_prob_col
]])
$
auc
[[
1
]]
# 计算AUC(确保y是因子,prob是概率)
y_bin
<-
as.numeric
(
y
==
lev
)
# 原始响应变量转0/1(1=成功水平)
auc_val
<-
try
(
pROC
::
roc
(
response
=
y_bin
,
predictor
=
base_pred
$
Prediction
)
$
auc
[[
1
]],
silent
=
TRUE
)
if
(
inherits
(
auc_val
,
"try-error"
))
{
# AUC计算失败时降级用准确率
pred_class
<-
ifelse
(
base_pred
$
Prediction
>=
0.5
,
lev
,
setdiff
(
levels
(
y
),
lev
))
# 概率≥0.5判定为成功水平
mean
(
pred_class
==
as.character
(
y
),
na.rm
=
TRUE
)
}
else
{
# 无概率输出,使用准确率
mean
(
base_pred
$
Predicted_Class
==
y
,
na.rm
=
TRUE
)
auc_val
}
}
else
{
# 多分类:使用准确率
mean
(
base_pred
$
Predicted_Class
==
y
,
na.rm
=
TRUE
)
# 多分类:用预测概率矩阵(predict.svm需补充多分类支持,此处降级用准确率)
pred_class
<-
ifelse
(
base_pred
$
Prediction
>=
0.5
,
lev
,
setdiff
(
levels
(
y
),
lev
))
mean
(
pred_class
==
as.character
(
y
),
na.rm
=
TRUE
)
}
}
else
{
# 回归:计算R²
base_metric
<-
1
-
sum
((
base_pred
$
Predicted_Value
-
y
)
^
2
,
na.rm
=
TRUE
)
/
# 回归任务:计算R²(适配predict.svm的"Prediction"列)
if
(
!
"Prediction"
%in%
colnames
(
base_pred
))
{
stop
(
"回归预测缺少'Prediction'列(预测值)"
)
}
1
-
sum
((
base_pred
$
Prediction
-
y
)
^
2
,
na.rm
=
TRUE
)
/
sum
((
y
-
mean
(
y
,
na.rm
=
TRUE
))
^
2
,
na.rm
=
TRUE
)
}
# 为每个变量计算排列重要性
# 为每个变量计算排列重要性
(核心逻辑不变,仅适配预测输出)
importance_scores
<-
sapply
(
object
$
evar
,
function
(
var
)
{
metric_diffs
<-
numeric
(
nperm
)
for
(
i
in
1
:
nperm
)
{
# 创建数据副本
# 创建数据副本
,随机打乱当前变量
perm_data
<-
data
# 随机打乱当前变量
perm_data
[[
var
]]
<-
sample
(
perm_data
[[
var
]],
replace
=
FALSE
)
# 预测
#
排列后
预测
perm_pred
<-
predict
(
object
,
pred_data
=
perm_data
,
envir
=
environment
())
if
(
is.character
(
perm_pred
)
||
inherits
(
perm_pred
,
"svm.predict.error"
))
{
metric_diffs
[
i
]
<-
NA
next
}
# 计算性能变化
if
(
object
$
type
==
"classification"
)
{
if
(
nlevels
(
as.factor
(
y
))
==
2
&&
length
(
grep
(
"Prob_"
,
colnames
(
perm_pred
)))
>
0
)
{
pred_prob_col
<-
grep
(
paste0
(
"Prob_"
,
lev
),
colnames
(
perm_pred
),
value
=
TRUE
)
if
(
length
(
pred_prob_col
)
>
0
)
{
perm_metric
<-
pROC
::
roc
(
response
=
y
,
predictor
=
perm_pred
[[
pred_prob_col
]])
$
auc
[[
1
]]
# 计算排列后的性能指标
perm_metric
<-
if
(
object
$
type
==
"classification"
)
{
if
(
nlevels
(
as.factor
(
y
))
==
2
)
{
y_bin
<-
as.numeric
(
y
==
lev
)
auc_val
<-
try
(
pROC
::
roc
(
response
=
y_bin
,
predictor
=
perm_pred
$
Prediction
)
$
auc
[[
1
]],
silent
=
TRUE
)
if
(
inherits
(
auc_val
,
"try-error"
))
{
pred_class
<-
ifelse
(
perm_pred
$
Prediction
>=
0.5
,
lev
,
setdiff
(
levels
(
y
),
lev
))
mean
(
pred_class
==
as.character
(
y
),
na.rm
=
TRUE
)
}
else
{
perm_metric
<-
mean
(
perm_pred
$
Predicted_Class
==
y
,
na.rm
=
TRUE
)
auc_val
}
}
else
{
perm_metric
<-
mean
(
perm_pred
$
Predicted_Class
==
y
,
na.rm
=
TRUE
)
pred_class
<-
ifelse
(
perm_pred
$
Prediction
>=
0.5
,
lev
,
setdiff
(
levels
(
y
),
lev
))
mean
(
pred_class
==
as.character
(
y
),
na.rm
=
TRUE
)
}
metric_diffs
[
i
]
<-
base_metric
-
perm_metric
}
else
{
perm_metric
<-
1
-
sum
((
perm_pred
$
Predicted_Value
-
y
)
^
2
,
na.rm
=
TRUE
)
/
1
-
sum
((
perm_pred
$
Prediction
-
y
)
^
2
,
na.rm
=
TRUE
)
/
sum
((
y
-
mean
(
y
,
na.rm
=
TRUE
))
^
2
,
na.rm
=
TRUE
)
metric_diffs
[
i
]
<-
base_metric
-
perm_metric
}
# 性能变化(基准 - 排列后,值越大变量越重要)
metric_diffs
[
i
]
<-
base_metric
-
perm_metric
}
# 返回平均性能损失(忽略NA)
mean
(
metric_diffs
,
na.rm
=
TRUE
)
})
# 创建结果数据框
# 创建结果数据框
(过滤无效值)
result
<-
data.frame
(
Variable
=
names
(
importance_scores
),
Importance
=
as.numeric
(
importance_scores
),
Importance
=
as.numeric
(
pmax
(
importance_scores
,
0
)),
# 重要性不能为负
stringsAsFactors
=
FALSE
)
...
...
@@ -796,6 +735,7 @@ varimp <- function(object, rvar = NULL, lev = NULL, data = NULL, seed = 1234, np
}
#' @export
svm_vip_plot
<-
function
(
object
,
size
,
custom
)
{
tryCatch
({
...
...
radiant.model/inst/app/tools/analysis/svm_ui.R
View file @
b3e914bc
...
...
@@ -135,9 +135,14 @@ output$ui_svm_wts <- renderUI({
## 存储预测值UI
output
$
ui_svm_store_pred_name
<-
renderUI
({
init
<-
state_init
(
"svm_store_pred_name"
,
"pred_svm"
)
%>%
sub
(
"\\d{1,}$"
,
""
,
.
)
%>%
paste0
(
.
,
ifelse
(
is.empty
(
input
$
svm_kernel
),
""
,
input
$
svm_kernel
))
base_name
<-
"pred_svm"
kernel_name
<-
input
$
svm_kernel
# 获取当前选中的核函数
init
<-
if
(
is.empty
(
kernel_name
))
{
base_name
}
else
{
paste0
(
base_name
,
"_"
,
kernel_name
)
}
init
<-
state_init
(
"svm_store_pred_name"
,
init
)
textInput
(
"svm_store_pred_name"
,
i
18
n
$
t
(
"Store predictions:"
),
...
...
@@ -145,6 +150,16 @@ output$ui_svm_store_pred_name <- renderUI({
)
})
observeEvent
(
input
$
svm_kernel
,
{
current_value
<-
tryCatch
(
isolate
(
input
$
svm_store_pred_name
),
error
=
function
(
e
)
""
)
if
(
!
is.null
(
current_value
)
&&
length
(
current_value
)
>
0
&&
nzchar
(
current_value
))
{
if
(
grepl
(
"^pred_svm(_[a-z]+)?$"
,
current_value
))
{
new_value
<-
paste0
(
"pred_svm"
,
"_"
,
input
$
svm_kernel
)
updateTextInput
(
session
,
"svm_store_pred_name"
,
value
=
new_value
)
}
}
},
ignoreInit
=
TRUE
,
ignoreNULL
=
TRUE
)
## 数据集/模型类型切换时重置预测与绘图
observeEvent
(
input
$
dataset
,
{
updateSelectInput
(
session
=
session
,
inputId
=
"svm_predict"
,
selected
=
"none"
)
...
...
@@ -414,40 +429,52 @@ svm_available <- reactive({
## 存储预测值
observeEvent
(
input
$
svm_store_pred
,
{
req
(
pressed
(
input
$
svm_run
),
!
is.empty
(
input
$
svm_pred_data
),
!
is.empty
(
input
$
svm_store_pred_name
),
inherits
(
.predict_svm
(),
"svm.predict"
)
)
# 只有最基本的检查,不满足就静默退出
if
(
!
pressed
(
input
$
svm_run
)
||
is.empty
(
input
$
svm_pred_data
)
||
is.empty
(
input
$
svm_store_pred_name
))
{
return
()
}
# 获取预测结果(不管成功失败)
pred_result
<-
.predict_svm
()
target_data
<-
r_data
[[
input
$
svm_pred_data
]]
base_col_name
<-
fix_names
(
input
$
svm_store_pred_name
)
meta
<-
attr
(
pred_result
,
"svm_meta"
)
pred_cols
<-
if
(
meta
$
model_type
%in%
c
(
"classification"
,
"regression"
))
{
colnames
(
pred_result
)[
colnames
(
pred_result
)
==
"Prediction"
]
# 如果预测返回的是错误字符串,直接创建NA列
if
(
is.character
(
pred_result
))
{
target_data
[[
base_col_name
]]
<-
rep
(
NA_real_
,
nrow
(
target_data
))
attr
(
target_data
[[
base_col_name
]],
"error"
)
<-
pred_result
r_data
[[
input
$
svm_pred_data
]]
<-
target_data
showNotification
(
sprintf
(
"预测失败,已添加NA列 '%s'"
,
base_col_name
),
type
=
"warning"
)
}
else
{
NULL
}
new_col_names
<-
if
(
length
(
pred_cols
)
==
1
)
base_col_name
else
{
suffix
<-
gsub
(
"^Prediction"
,
""
,
pred_cols
)
paste0
(
base_col_name
,
ifelse
(
suffix
==
""
,
""
,
paste0
(
"_"
,
suffix
)))
# 正常情况:直接提取Prediction列
if
(
"Prediction"
%in%
colnames
(
pred_result
))
{
# 用cbind逻辑,更简单直接
pred_values
<-
pred_result
$
Prediction
# 处理长度不匹配
n_target
<-
nrow
(
target_data
)
n_pred
<-
length
(
pred_values
)
if
(
n_pred
<
n_target
)
{
# 预测值少,用NA填充
pred_values
<-
c
(
pred_values
,
rep
(
NA_real_
,
n_target
-
n_pred
))
}
else
if
(
n_pred
>
n_target
)
{
# 预测值多,截断
pred_values
<-
pred_values
[
1
:
n_target
]
}
target_data
[[
base_col_name
]]
<-
pred_values
r_data
[[
input
$
svm_pred_data
]]
<-
target_data
}
else
{
target_data
[[
base_col_name
]]
<-
rep
(
NA_real_
,
nrow
(
target_data
))
r_data
[[
input
$
svm_pred_data
]]
<-
target_data
}
}
colnames
(
pred_result
)[
match
(
pred_cols
,
colnames
(
pred_result
))]
<-
new_col_names
merged_data
<-
merge
(
target_data
,
pred_result
[,
c
(
meta
$
evar
,
new_col_names
),
drop
=
FALSE
],
by
=
meta
$
evar
,
all.x
=
TRUE
)
r_data
[[
input
$
svm_pred_data
]]
<-
merged_data
showNotification
(
sprintf
(
i
18
n
$
t
(
"SVM predictions stored as: %s (in '%s')"
),
paste
(
new_col_names
,
collapse
=
", "
),
input
$
svm_pred_data
),
type
=
"message"
)
updateTextInput
(
session
,
"svm_store_pred_name"
,
value
=
base_col_name
)
# 重置输入框
updateTextInput
(
session
,
"svm_store_pred_name"
,
value
=
"pred_svm"
)
})
## 下载处理
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment