Skip to content
Projects
Groups
Snippets
Help
Loading...
Help
Support
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
R
Radiant
Project
Project
Details
Activity
Releases
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Boards
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
wuzekai
Radiant
Commits
53dd5828
Commit
53dd5828
authored
Nov 21, 2025
by
wuzekai
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
更新了svm
parent
e9df660e
Changes
9
Hide whitespace changes
Inline
Side-by-side
Showing
9 changed files
with
723 additions
and
561 deletions
+723
-561
NAMESPACE
radiant.model/NAMESPACE
+279
-277
svm.R
radiant.model/R/svm.R
+347
-186
init.R
radiant.model/inst/app/init.R
+2
-1
svm_ui.R
radiant.model/inst/app/tools/analysis/svm_ui.R
+88
-54
cv.svm.Rd
radiant.model/man/cv.svm.Rd
+0
-22
plot.svm.Rd
radiant.model/man/plot.svm.Rd
+1
-11
predict.svm.Rd
radiant.model/man/predict.svm.Rd
+2
-2
varimp.Rd
radiant.model/man/varimp.Rd
+3
-3
varimp_plot.Rd
radiant.model/man/varimp_plot.Rd
+1
-5
No files found.
radiant.model/NAMESPACE
View file @
53dd5828
# Generated by roxygen2: do not edit by hand
S3method(plot,confusion)
S3method(plot,coxp)
S3method(plot,crs)
S3method(plot,crtree)
S3method(plot,dtree)
S3method(plot,evalbin)
S3method(plot,evalreg)
S3method(plot,gbt)
S3method(plot,logistic)
S3method(plot,mnl)
S3method(plot,mnl.predict)
S3method(plot,model.predict)
S3method(plot,nb)
S3method(plot,nb.predict)
S3method(plot,nn)
S3method(plot,regress)
S3method(plot,repeater)
S3method(plot,rforest)
S3method(plot,rforest.predict)
S3method(plot,simulater)
S3method(plot,svm)
S3method(plot,uplift)
S3method(predict,coxp)
S3method(predict,crtree)
S3method(predict,gbt)
S3method(predict,logistic)
S3method(predict,mnl)
S3method(predict,nb)
S3method(predict,nn)
S3method(predict,regress)
S3method(predict,rforest)
S3method(predict,svm)
S3method(print,coxp.predict)
S3method(print,crtree.predict)
S3method(print,gbt.predict)
S3method(print,logistic.predict)
S3method(print,mnl.predict)
S3method(print,nb.predict)
S3method(print,nn.predict)
S3method(print,regress.predict)
S3method(print,rforest.predict)
S3method(print,svm.predict)
S3method(render,DiagrammeR)
S3method(sensitivity,dtree)
S3method(store,coxp.predict)
S3method(store,crs)
S3method(store,mnl.predict)
S3method(store,model)
S3method(store,model.predict)
S3method(store,nb.predict)
S3method(store,rforest.predict)
S3method(summary,confusion)
S3method(summary,coxp)
S3method(summary,crs)
S3method(summary,crtree)
S3method(summary,dtree)
S3method(summary,evalbin)
S3method(summary,evalreg)
S3method(summary,gbt)
S3method(summary,logistic)
S3method(summary,mnl)
S3method(summary,nb)
S3method(summary,nn)
S3method(summary,regress)
S3method(summary,repeater)
S3method(summary,rforest)
S3method(summary,simulater)
S3method(summary,svm)
S3method(summary,uplift)
export(.as_int)
export(.as_num)
export(MAE)
export(RMSE)
export(Rsq)
export(ann)
export(auc)
export(confint_robust)
export(confusion)
export(coxp)
export(crs)
export(crtree)
export(cv.crtree)
export(cv.gbt)
export(cv.nn)
export(cv.rforest)
export(cv.svm)
export(dtree)
export(dtree_parser)
export(evalbin)
export(evalreg)
export(find_max)
export(find_min)
export(gbt)
export(logistic)
export(minmax)
export(mnl)
export(nb)
export(nn)
export(onehot)
export(pdp_plot)
export(pred_plot)
export(predict_model)
export(print_predict_model)
export(profit)
export(radiant.model)
export(radiant.model_viewer)
export(radiant.model_window)
export(regress)
export(remove_comments)
export(repeater)
export(rforest)
export(rig)
export(scale_df)
export(sdw)
export(sensitivity)
export(sim_cleaner)
export(sim_cor)
export(sim_splitter)
export(sim_summary)
export(simulater)
export(svm)
export(test_specs)
export(uplift)
export(var_check)
export(varimp)
export(varimp_plot)
export(write.coeff)
import(ggplot2)
import(radiant.data)
import(shiny)
importFrom(DiagrammeR,DiagrammeR)
importFrom(DiagrammeR,DiagrammeROutput)
importFrom(DiagrammeR,mermaid)
importFrom(DiagrammeR,renderDiagrammeR)
importFrom(NeuralNetTools,garson)
importFrom(NeuralNetTools,olden)
importFrom(NeuralNetTools,plotnet)
importFrom(broom,augment)
importFrom(car,linearHypothesis)
importFrom(car,vif)
importFrom(data.tree,Clone)
importFrom(data.tree,FormatPercent)
importFrom(data.tree,Get)
importFrom(data.tree,Traverse)
importFrom(data.tree,as.Node)
importFrom(data.tree,isLeaf)
importFrom(data.tree,isNotLeaf)
importFrom(data.tree,isNotRoot)
importFrom(dplyr,across)
importFrom(dplyr,arrange)
importFrom(dplyr,arrange_at)
importFrom(dplyr,bind_cols)
importFrom(dplyr,bind_rows)
importFrom(dplyr,data_frame)
importFrom(dplyr,desc)
importFrom(dplyr,distinct_at)
importFrom(dplyr,everything)
importFrom(dplyr,filter)
importFrom(dplyr,first)
importFrom(dplyr,funs)
importFrom(dplyr,group_by)
importFrom(dplyr,group_by_)
importFrom(dplyr,group_by_at)
importFrom(dplyr,inner_join)
importFrom(dplyr,last)
importFrom(dplyr,min_rank)
importFrom(dplyr,mutate)
importFrom(dplyr,mutate_)
importFrom(dplyr,mutate_all)
importFrom(dplyr,mutate_at)
importFrom(dplyr,mutate_if)
importFrom(dplyr,near)
importFrom(dplyr,pull)
importFrom(dplyr,rename)
importFrom(dplyr,sample_n)
importFrom(dplyr,select)
importFrom(dplyr,select_at)
importFrom(dplyr,slice)
importFrom(dplyr,summarise)
importFrom(dplyr,summarise_)
importFrom(dplyr,summarise_all)
importFrom(dplyr,summarise_at)
importFrom(dplyr,summarize)
importFrom(dplyr,ungroup)
importFrom(e1071,naiveBayes)
importFrom(ggplot2,autoplot)
importFrom(ggrepel,geom_text_repel)
importFrom(graphics,par)
importFrom(import,from)
importFrom(lubridate,is.Date)
importFrom(lubridate,now)
importFrom(magrittr,"%<>%")
importFrom(magrittr,"%>%")
importFrom(magrittr,"%T>%")
importFrom(magrittr,extract2)
importFrom(magrittr,set_colnames)
importFrom(magrittr,set_names)
importFrom(magrittr,set_rownames)
importFrom(nnet,nnet)
importFrom(nnet,nnet.formula)
importFrom(patchwork,plot_annotation)
importFrom(patchwork,wrap_plots)
importFrom(pdp,partial)
importFrom(psych,cohen.kappa)
importFrom(radiant.data,launch)
importFrom(radiant.data,set_attr)
importFrom(radiant.data,visualize)
importFrom(ranger,ranger)
importFrom(rlang,":=")
importFrom(rlang,.data)
importFrom(rlang,parse_exprs)
importFrom(rpart,prune.rpart)
importFrom(rpart,rpart)
importFrom(rpart,rpart.control)
importFrom(sandwich,vcovHC)
importFrom(scales,percent)
importFrom(shiny,getDefaultReactiveDomain)
importFrom(shiny,incProgress)
importFrom(shiny,withProgress)
importFrom(stats,anova)
importFrom(stats,as.formula)
importFrom(stats,binomial)
importFrom(stats,coef)
importFrom(stats,confint)
importFrom(stats,confint.default)
importFrom(stats,contrasts)
importFrom(stats,cor)
importFrom(stats,deviance)
importFrom(stats,dnorm)
importFrom(stats,family)
importFrom(stats,formula)
importFrom(stats,glm)
importFrom(stats,lm)
importFrom(stats,logLik)
importFrom(stats,median)
importFrom(stats,model.frame)
importFrom(stats,model.matrix)
importFrom(stats,na.omit)
importFrom(stats,pnorm)
importFrom(stats,predict)
importFrom(stats,pt)
importFrom(stats,qnorm)
importFrom(stats,qt)
importFrom(stats,quantile)
importFrom(stats,rbinom)
importFrom(stats,relevel)
importFrom(stats,residuals)
importFrom(stats,rlnorm)
importFrom(stats,rnorm)
importFrom(stats,rpois)
importFrom(stats,runif)
importFrom(stats,sd)
importFrom(stats,setNames)
importFrom(stats,step)
importFrom(stats,terms)
importFrom(stats,terms.formula)
importFrom(stats,update)
importFrom(stats,weighted.mean)
importFrom(stats,wilcox.test)
importFrom(stringi,stri_trans_general)
importFrom(stringr,str_match)
importFrom(tidyr,gather)
importFrom(tidyr,spread)
importFrom(tidyselect,where)
importFrom(utils,as.relistable)
importFrom(utils,capture.output)
importFrom(utils,combn)
importFrom(utils,head)
importFrom(utils,relist)
importFrom(utils,tail)
importFrom(utils,write.table)
importFrom(vip,vi)
importFrom(xgboost,xgb.importance)
importFrom(xgboost,xgboost)
importFrom(yaml,yaml.load)
# Generated by roxygen2: do not edit by hand
S3method(plot,confusion)
S3method(plot,coxp)
S3method(plot,crs)
S3method(plot,crtree)
S3method(plot,dtree)
S3method(plot,evalbin)
S3method(plot,evalreg)
S3method(plot,gbt)
S3method(plot,logistic)
S3method(plot,mnl)
S3method(plot,mnl.predict)
S3method(plot,model.predict)
S3method(plot,nb)
S3method(plot,nb.predict)
S3method(plot,nn)
S3method(plot,regress)
S3method(plot,repeater)
S3method(plot,rforest)
S3method(plot,rforest.predict)
S3method(plot,simulater)
S3method(plot,svm)
S3method(plot,uplift)
S3method(predict,coxp)
S3method(predict,crtree)
S3method(predict,gbt)
S3method(predict,logistic)
S3method(predict,mnl)
S3method(predict,nb)
S3method(predict,nn)
S3method(predict,regress)
S3method(predict,rforest)
S3method(predict,svm)
S3method(print,coxp.predict)
S3method(print,crtree.predict)
S3method(print,gbt.predict)
S3method(print,logistic.predict)
S3method(print,mnl.predict)
S3method(print,nb.predict)
S3method(print,nn.predict)
S3method(print,regress.predict)
S3method(print,rforest.predict)
S3method(print,svm.predict)
S3method(render,DiagrammeR)
S3method(sensitivity,dtree)
S3method(store,coxp.predict)
S3method(store,crs)
S3method(store,mnl.predict)
S3method(store,model)
S3method(store,model.predict)
S3method(store,nb.predict)
S3method(store,rforest.predict)
S3method(summary,confusion)
S3method(summary,coxp)
S3method(summary,crs)
S3method(summary,crtree)
S3method(summary,dtree)
S3method(summary,evalbin)
S3method(summary,evalreg)
S3method(summary,gbt)
S3method(summary,logistic)
S3method(summary,mnl)
S3method(summary,nb)
S3method(summary,nn)
S3method(summary,regress)
S3method(summary,repeater)
S3method(summary,rforest)
S3method(summary,simulater)
S3method(summary,svm)
S3method(summary,uplift)
export(.as_int)
export(.as_num)
export(MAE)
export(RMSE)
export(Rsq)
export(ann)
export(auc)
export(confint_robust)
export(confusion)
export(coxp)
export(crs)
export(crtree)
export(cv.crtree)
export(cv.gbt)
export(cv.nn)
export(cv.rforest)
export(dtree)
export(dtree_parser)
export(evalbin)
export(evalreg)
export(find_max)
export(find_min)
export(gbt)
export(logistic)
export(minmax)
export(mnl)
export(nb)
export(nn)
export(onehot)
export(pdp_plot)
export(pred_plot)
export(predict_model)
export(print_predict_model)
export(profit)
export(radiant.model)
export(radiant.model_viewer)
export(radiant.model_window)
export(regress)
export(remove_comments)
export(repeater)
export(rforest)
export(rig)
export(scale_df)
export(sdw)
export(sensitivity)
export(sim_cleaner)
export(sim_cor)
export(sim_splitter)
export(sim_summary)
export(simulater)
export(svm)
export(svm_boundary_plot)
export(svm_margin_plot)
export(svm_vip_plot)
export(test_specs)
export(uplift)
export(var_check)
export(varimp)
export(varimp_plot)
export(write.coeff)
import(ggplot2)
import(radiant.data)
import(shiny)
importFrom(DiagrammeR,DiagrammeR)
importFrom(DiagrammeR,DiagrammeROutput)
importFrom(DiagrammeR,mermaid)
importFrom(DiagrammeR,renderDiagrammeR)
importFrom(NeuralNetTools,garson)
importFrom(NeuralNetTools,olden)
importFrom(NeuralNetTools,plotnet)
importFrom(broom,augment)
importFrom(car,linearHypothesis)
importFrom(car,vif)
importFrom(data.tree,Clone)
importFrom(data.tree,FormatPercent)
importFrom(data.tree,Get)
importFrom(data.tree,Traverse)
importFrom(data.tree,as.Node)
importFrom(data.tree,isLeaf)
importFrom(data.tree,isNotLeaf)
importFrom(data.tree,isNotRoot)
importFrom(dplyr,across)
importFrom(dplyr,arrange)
importFrom(dplyr,arrange_at)
importFrom(dplyr,bind_cols)
importFrom(dplyr,bind_rows)
importFrom(dplyr,data_frame)
importFrom(dplyr,desc)
importFrom(dplyr,distinct_at)
importFrom(dplyr,everything)
importFrom(dplyr,filter)
importFrom(dplyr,first)
importFrom(dplyr,funs)
importFrom(dplyr,group_by)
importFrom(dplyr,group_by_)
importFrom(dplyr,group_by_at)
importFrom(dplyr,inner_join)
importFrom(dplyr,last)
importFrom(dplyr,min_rank)
importFrom(dplyr,mutate)
importFrom(dplyr,mutate_)
importFrom(dplyr,mutate_all)
importFrom(dplyr,mutate_at)
importFrom(dplyr,mutate_if)
importFrom(dplyr,near)
importFrom(dplyr,pull)
importFrom(dplyr,rename)
importFrom(dplyr,sample_n)
importFrom(dplyr,select)
importFrom(dplyr,select_at)
importFrom(dplyr,slice)
importFrom(dplyr,summarise)
importFrom(dplyr,summarise_)
importFrom(dplyr,summarise_all)
importFrom(dplyr,summarise_at)
importFrom(dplyr,summarize)
importFrom(dplyr,ungroup)
importFrom(e1071,naiveBayes)
importFrom(ggplot2,autoplot)
importFrom(ggrepel,geom_text_repel)
importFrom(graphics,par)
importFrom(import,from)
importFrom(lubridate,is.Date)
importFrom(lubridate,now)
importFrom(magrittr,"%<>%")
importFrom(magrittr,"%>%")
importFrom(magrittr,"%T>%")
importFrom(magrittr,extract2)
importFrom(magrittr,set_colnames)
importFrom(magrittr,set_names)
importFrom(magrittr,set_rownames)
importFrom(nnet,nnet)
importFrom(nnet,nnet.formula)
importFrom(patchwork,plot_annotation)
importFrom(patchwork,wrap_plots)
importFrom(pdp,partial)
importFrom(psych,cohen.kappa)
importFrom(radiant.data,launch)
importFrom(radiant.data,set_attr)
importFrom(radiant.data,visualize)
importFrom(ranger,ranger)
importFrom(rlang,":=")
importFrom(rlang,.data)
importFrom(rlang,parse_exprs)
importFrom(rpart,prune.rpart)
importFrom(rpart,rpart)
importFrom(rpart,rpart.control)
importFrom(sandwich,vcovHC)
importFrom(scales,percent)
importFrom(shiny,getDefaultReactiveDomain)
importFrom(shiny,incProgress)
importFrom(shiny,withProgress)
importFrom(stats,anova)
importFrom(stats,as.formula)
importFrom(stats,binomial)
importFrom(stats,coef)
importFrom(stats,confint)
importFrom(stats,confint.default)
importFrom(stats,contrasts)
importFrom(stats,cor)
importFrom(stats,deviance)
importFrom(stats,dnorm)
importFrom(stats,family)
importFrom(stats,formula)
importFrom(stats,glm)
importFrom(stats,lm)
importFrom(stats,logLik)
importFrom(stats,median)
importFrom(stats,model.frame)
importFrom(stats,model.matrix)
importFrom(stats,na.omit)
importFrom(stats,pnorm)
importFrom(stats,predict)
importFrom(stats,pt)
importFrom(stats,qnorm)
importFrom(stats,qt)
importFrom(stats,quantile)
importFrom(stats,rbinom)
importFrom(stats,relevel)
importFrom(stats,residuals)
importFrom(stats,rlnorm)
importFrom(stats,rnorm)
importFrom(stats,rpois)
importFrom(stats,runif)
importFrom(stats,sd)
importFrom(stats,setNames)
importFrom(stats,step)
importFrom(stats,terms)
importFrom(stats,terms.formula)
importFrom(stats,update)
importFrom(stats,weighted.mean)
importFrom(stats,wilcox.test)
importFrom(stringi,stri_trans_general)
importFrom(stringr,str_match)
importFrom(tidyr,gather)
importFrom(tidyr,spread)
importFrom(tidyselect,where)
importFrom(utils,as.relistable)
importFrom(utils,capture.output)
importFrom(utils,combn)
importFrom(utils,head)
importFrom(utils,relist)
importFrom(utils,tail)
importFrom(utils,write.table)
importFrom(vip,vi)
importFrom(xgboost,xgb.importance)
importFrom(xgboost,xgboost)
importFrom(yaml,yaml.load)
radiant.model/R/svm.R
View file @
53dd5828
...
...
@@ -7,7 +7,7 @@ svm <- function(dataset, rvar, evar,
form
,
data_filter
=
""
,
arr
=
""
,
rows
=
NULL
,
envir
=
parent.frame
())
{
## ---- 参数合法性检查
(SVM特有)
----
## ---- 参数合法性检查----
valid_kernels
<-
c
(
"linear"
,
"radial"
,
"poly"
,
"sigmoid"
)
if
(
!
kernel
%in%
valid_kernels
)
{
return
(
paste0
(
"Kernel must be one of: "
,
paste
(
valid_kernels
,
collapse
=
", "
))
%>%
...
...
@@ -38,11 +38,32 @@ svm <- function(dataset, rvar, evar,
df_name
<-
if
(
is_string
(
dataset
))
dataset
else
deparse
(
substitute
(
dataset
))
dataset
<-
get_data
(
dataset
,
vars
,
filt
=
data_filter
,
arr
=
arr
,
rows
=
rows
,
envir
=
envir
)
is_cat
<-
sapply
(
dataset
,
function
(
x
)
is.factor
(
x
)
||
is.character
(
x
)
||
is.logical
(
x
))
cat_vars
<-
names
(
is_cat
)[
is_cat
]
if
(
length
(
cat_vars
)
>
0
)
{
# 2. 字符型转因子(确保后续编码顺序一致)
for
(
var
in
cat_vars
)
{
if
(
is.character
(
dataset
[[
var
]]))
{
dataset
[[
var
]]
<-
as.factor
(
dataset
[[
var
]])
}
# 3. 因子/逻辑型转数值(按原始水平顺序编码,如level1→1, level2→2...)
dataset
[[
var
]]
<-
as.numeric
(
dataset
[[
var
]])
}
}
# 缺失值处理
dataset
<-
na.omit
(
dataset
)
if
(
nrow
(
dataset
)
==
0
)
{
return
(
"No valid samples after removing missing values. Please check your data."
%>%
add_class
(
"svm"
))
}
# 标准化
if
(
"standardize"
%in%
check
)
{
dataset
<-
scale_df
(
dataset
)
}
## ---- 3. 分类任务的响应变量
(转为因子)
----
## ---- 3. 分类任务的响应变量 ----
if
(
type
==
"classification"
)
{
dataset
[[
rvar
]]
<-
as.factor
(
dataset
[[
rvar
]])
if
(
lev
==
""
)
{
...
...
@@ -71,10 +92,19 @@ svm <- function(dataset, rvar, evar,
if
(
!
is.na
(
seed
))
set.seed
(
seed
)
## ---- 6. 训练模型 ----
model
<-
do.call
(
e
1071
::
svm
,
svm_input
)
## ---- 模型训练----
model
<-
try
({
do.call
(
e
1071
::
svm
,
svm_input
)
},
silent
=
TRUE
)
## ---- 7. 附加关键信息 ----
if
(
inherits
(
model
,
"try-error"
))
{
return
(
paste
(
"Model training failed:"
,
attr
(
model
,
"condition"
)
$
message
)
%>%
add_class
(
"svm"
))
}
if
(
is.null
(
model
)
||
!
inherits
(
model
,
"svm"
))
{
return
(
"Model training failed: Generated SVM model is invalid"
%>%
add_class
(
"svm"
))
}
## ---- 附加关键信息----
model
$
df_name
<-
df_name
model
$
rvar
<-
rvar
model
$
evar
<-
evar
...
...
@@ -86,7 +116,21 @@ svm <- function(dataset, rvar, evar,
model
$
gamma
<-
gamma
model
$
kernel
<-
kernel
as.list
(
environment
())
%>%
add_class
(
c
(
"svm"
,
"model"
))
## ---- 返回模型对象----
list
(
model
=
model
,
df_name
=
df_name
,
rvar
=
rvar
,
evar
=
evar
,
type
=
type
,
lev
=
lev
,
wtsname
=
wtsname
,
seed
=
seed
,
cost
=
cost
,
gamma
=
gamma
,
kernel
=
kernel
,
dataset
=
dataset
)
%>%
add_class
(
c
(
"svm"
,
"model"
))
}
#' Center or standardize variables in a data frame
...
...
@@ -101,9 +145,9 @@ scale_df <- function(dataset, center = TRUE, scale = TRUE,
descr
<-
attr
(
dataset
,
"description"
)
# 保留原描述属性
if
(
calc
)
{
# 计算均值
(忽略NA)
# 计算均值
ms
<-
sapply
(
dataset
[,
cn
,
drop
=
FALSE
],
function
(
x
)
mean
(
x
,
na.rm
=
TRUE
))
# 计算标准差
(忽略NA,样本标准差ddof=1,避免除以零)
# 计算标准差
sds
<-
sapply
(
dataset
[,
cn
,
drop
=
FALSE
],
function
(
x
)
{
sd_val
<-
sd
(
x
,
na.rm
=
TRUE
)
ifelse
(
sd_val
==
0
,
1
,
sd_val
)
...
...
@@ -135,7 +179,7 @@ summary.svm <- function(object, prn = TRUE, ...) {
svm_model
<-
object
$
model
n_obs
<-
nrow
(
object
$
dataset
)
wtsname
<-
object
$
wtsname
# 可能是 NULL 或长度 0 字符
wtsname
<-
object
$
wtsname
cat
(
"Support Vector Machine\n"
)
cat
(
sprintf
(
"Kernel type : %s (%s)\n"
,
object
$
kernel
,
object
$
type
))
...
...
@@ -146,7 +190,7 @@ summary.svm <- function(object, prn = TRUE, ...) {
}
cat
(
sprintf
(
"Explanatory variables: %s\n"
,
paste
(
object
$
evar
,
collapse
=
", "
)))
if
(
!
is.null
(
wtsname
)
&&
length
(
wtsname
)
&&
wtsname
!=
""
)
{
if
(
!
is.null
(
wtsname
)
&&
nzchar
(
wtsname
)
)
{
cat
(
sprintf
(
"Weights used : %s\n"
,
wtsname
))
}
...
...
@@ -156,56 +200,79 @@ summary.svm <- function(object, prn = TRUE, ...) {
}
if
(
!
is.na
(
object
$
seed
))
cat
(
sprintf
(
"Seed : %s\n"
,
object
$
seed
))
# 支持向量计数
if
(
object
$
type
==
"classification"
)
{
n_sv_per_class
<-
svm_model
$
nSV
total_sv
<-
sum
(
n_sv_per_class
)
n_sv_per_class
<-
if
(
!
is.null
(
svm_model
$
nSV
))
svm_model
$
nSV
else
c
(
0
,
0
)
total_sv
<-
sum
(
n_sv_per_class
)
sv_info
<-
paste
(
sprintf
(
"class %s: %d"
,
seq_along
(
n_sv_per_class
),
n_sv_per_class
),
collapse
=
", "
)
cat
(
sprintf
(
"Support vectors : %d (%s)\n"
,
total_sv
,
sv_info
))
}
else
{
total_sv
<-
sum
(
svm_model
$
nSV
)
total_sv
<-
if
(
!
is.null
(
svm_model
$
index
))
length
(
svm_model
$
index
)
else
0
cat
(
sprintf
(
"Support vectors : %d\n"
,
total_sv
))
}
## ---- 权重样本数计算同样保护 ----
if
(
!
is.null
(
wtsname
)
&&
length
(
wtsname
)
&&
wtsname
!=
""
)
{
nr_obs
<-
n_obs
if
(
!
is.null
(
wtsname
)
&&
nzchar
(
wtsname
)
&&
wtsname
%in%
names
(
object
$
dataset
)
)
{
weights_values
<-
as.numeric
(
object
$
dataset
[[
wtsname
]])
nr_obs
<-
if
(
all
(
!
is.na
(
weights_values
)))
sum
(
weights_values
,
na.rm
=
TRUE
)
else
n_obs
}
else
{
nr_obs
<-
n_obs
weights_clean
<-
weights_values
[
!
is.na
(
weights_values
)]
if
(
length
(
weights_clean
)
>
0
)
{
sum_weights
<-
sum
(
weights_clean
)
if
(
sum_weights
>
0
)
nr_obs
<-
sum_weights
}
}
cat
(
sprintf
(
"Nr_obs : %s\n"
,
format_nr
(
nr_obs
,
dec
=
0
)))
## ---- 系数输出(仅线性核) ----
if
(
prn
)
{
cat
(
"Coefficients/Support Vectors:\n"
)
if
(
object
$
kernel
==
"linear"
)
{
if
(
is.null
(
svm_model
$
w
)
||
length
(
svm_model
$
w
)
==
0
)
{
cat
(
" Linear kernel coefficients not available (possible reasons:\n"
)
cat
(
" - Insufficient support vectors (data may be linearly inseparable)\n"
)
cat
(
" - Model training did not converge)\n"
)
if
(
object
$
type
==
"classification"
)
{
cat
(
sprintf
(
" Support vectors count: %d\n"
,
sum
(
svm_model
$
nSV
)))
}
}
else
{
feat_coefs_raw
<-
as.numeric
(
svm_model
$
w
)
bias
<-
as.numeric
(
-
svm_model
$
rho
)
n_evar
<-
length
(
object
$
evar
)
feat_coefs
<-
matrix
(
data
=
feat_coefs_raw
,
nrow
=
1
,
ncol
=
n_evar
,
byrow
=
TRUE
,
dimnames
=
list
(
NULL
,
object
$
evar
))
coef_data
<-
data.frame
(
Variable
=
c
(
object
$
evar
,
"bias"
),
Value
=
as.numeric
(
c
(
as.vector
(
feat_coefs
),
bias
)),
stringsAsFactors
=
FALSE
,
check.names
=
FALSE
)
for
(
i
in
seq
(
1
,
nrow
(
coef_data
),
2
))
{
if
(
i
+
1
>
nrow
(
coef_data
))
{
cat
(
sprintf
(
" %-12s: %.2f\n"
,
coef_data
$
Variable
[
i
],
coef_data
$
Value
[
i
]))
}
else
{
cat
(
sprintf
(
" %-12s: %.2f %-12s: %.2f\n"
,
coef_data
$
Variable
[
i
],
coef_data
$
Value
[
i
],
coef_data
$
Variable
[
i
+1
],
coef_data
$
Value
[
i
+1
]))
feat_coefs_raw
<-
rep
(
0
,
length
(
object
$
evar
))
bias
<-
0
if
(
!
is.null
(
svm_model
$
terms
)
&&
!
is.null
(
svm_model
$
coefs
)
&&
!
is.null
(
svm_model
$
SV
))
{
tryCatch
({
model_vars
<-
attr
(
svm_model
$
terms
,
"term.labels"
)
coef_mat
<-
as.matrix
(
svm_model
$
coefs
)
sv_mat
<-
as.matrix
(
svm_model
$
SV
)
if
(
nrow
(
coef_mat
)
>
0
&&
nrow
(
sv_mat
)
>
0
)
{
w_full
<-
as.numeric
(
t
(
coef_mat
)
%*%
sv_mat
)
sv_colnames
<-
colnames
(
sv_mat
)
if
(
!
is.null
(
sv_colnames
))
{
# 按变量名精确匹配
for
(
i
in
seq_along
(
object
$
evar
))
{
var_name
<-
object
$
evar
[
i
]
matching_cols
<-
which
(
sv_colnames
==
var_name
)
if
(
length
(
matching_cols
)
>
0
)
{
feat_coefs_raw
[
i
]
<-
sum
(
w_full
[
matching_cols
])
}
}
}
else
{
# 无列名则按位置取
n_evar
<-
length
(
object
$
evar
)
feat_coefs_raw
[
1
:
min
(
n_evar
,
length
(
w_full
))]
<-
w_full
[
1
:
min
(
n_evar
,
length
(
w_full
))]
}
bias
<-
as.numeric
(
-
svm_model
$
rho
)
}
},
error
=
function
(
e
)
{
warning
(
"系数提取失败: "
,
e
$
message
,
call.
=
FALSE
)
})
}
# 输出
coef_data
<-
data.frame
(
Variable
=
c
(
object
$
evar
,
"bias"
),
Value
=
c
(
feat_coefs_raw
,
bias
),
stringsAsFactors
=
FALSE
)
for
(
i
in
seq
(
1
,
nrow
(
coef_data
),
2
))
{
if
(
i
+
1
>
nrow
(
coef_data
))
{
cat
(
sprintf
(
" %-12s: %.2f\n"
,
coef_data
$
Variable
[
i
],
coef_data
$
Value
[
i
]))
}
else
{
cat
(
sprintf
(
" %-12s: %.2f %-12s: %.2f\n"
,
coef_data
$
Variable
[
i
],
coef_data
$
Value
[
i
],
coef_data
$
Variable
[
i
+1
],
coef_data
$
Value
[
i
+1
]))
}
}
}
else
{
...
...
@@ -256,134 +323,234 @@ plot.svm <- function(x,
return
(
"No valid plots selected for SVM"
)
}
## 返回 patchwork 对象
(Shiny 自动打印)
## 返回 patchwork 对象
patchwork
::
wrap_plots
(
plot_list
,
ncol
=
min
(
2
,
length
(
plot_list
)))
%>%
{
if
(
isTRUE
(
shiny
))
.
else
print
(
.
)
}
}
#' Predict method for
the svm function
#' Predict method for
SVM model
#' @export
predict.svm
<-
function
(
object
,
pred_data
=
NULL
,
pred_cmd
=
""
,
dec
=
3
,
envir
=
parent.frame
(),
...
)
{
#
#
1. 基础校验
predict.svm
<-
function
(
object
,
pred_data
=
NULL
,
pred_cmd
=
""
,
dec
=
3
,
envir
=
parent.frame
(),
...
)
{
# 1. 基础校验
if
(
is.character
(
object
))
return
(
object
)
if
(
!
inherits
(
object
,
"svm"
))
stop
(
"Object must be of class 'svm'"
)
if
(
is.null
(
object
$
model
)
||
!
inherits
(
object
$
model
,
"svm"
))
{
err_msg
<-
"Prediction failed: Invalid SVM model"
err_obj
<-
structure
(
list
(
error
=
err_msg
),
class
=
"svm.predict.error"
)
return
(
err_obj
)
}
svm_info
<-
object
# 2. 处理预测数据
if
(
is.character
(
pred_data
)
&&
nzchar
(
pred_data
))
{
if
(
!
exists
(
pred_data
,
envir
=
envir
))
{
err_obj
<-
structure
(
list
(
error
=
sprintf
(
"Dataset '%s' not found"
,
pred_data
)),
class
=
"svm.predict.error"
)
return
(
err_obj
)
}
pred_data
<-
get
(
pred_data
,
envir
=
envir
)
}
has_data
<-
!
is.null
(
pred_data
)
&&
is.data.frame
(
pred_data
)
&&
nrow
(
pred_data
)
>
0
has_cmd
<-
is.character
(
pred_cmd
)
&&
nzchar
(
pred_cmd
)
if
(
!
has_data
&&
!
has_cmd
)
{
err_obj
<-
structure
(
list
(
error
=
"Please select data and/or specify a command to generate predictions."
),
class
=
"svm.predict.error"
)
return
(
err_obj
)
}
## 2.1 确定预测数据源
if
(
is.null
(
pred_data
)
||
is.character
(
pred_data
))
{
# 当pred_data为NULL或字符(数据集名)时,使用训练数据
pred_data_raw
<-
object
$
dataset
[,
object
$
evar
,
drop
=
FALSE
]
## ensure you have a name for the prediction dataset
if
(
is.data.frame
(
pred_data
))
{
df_name
<-
deparse
(
substitute
(
pred_data
))
}
else
{
# 当pred_data是数据框时,直接使用
pred_data_raw
<-
as.data.frame
(
pred_data
)
df_name
<-
pred_data
}
pred_names
<-
colnames
(
pred_data_raw
)
missing_vars
<-
setdiff
(
object
$
evar
,
pred_names
)
if
(
length
(
missing_vars
)
>
0
)
{
msg
<-
paste0
(
"NA\n"
)
return
(
msg
%>%
add_class
(
"svm.predict"
))
}
pred_data
<-
pred_data_raw
[,
object
$
evar
,
drop
=
FALSE
]
if
(
!
is.empty
(
pred_cmd
))
{
pred_cmd
<-
gsub
(
"\\s{2,}"
,
" "
,
pred_cmd
)
%>%
gsub
(
";\\s+"
,
";"
,
.
)
%>%
strsplit
(
";"
)[[
1
]]
for
(
cmd
in
pred_cmd
)
{
if
(
grepl
(
"="
,
cmd
))
{
var_val
<-
strsplit
(
trimws
(
cmd
),
"="
)[[
1
]]
var
<-
trimws
(
var_val
[
1
])
val
<-
try
(
eval
(
parse
(
text
=
trimws
(
var_val
[
2
])),
envir
=
envir
),
silent
=
TRUE
)
if
(
!
inherits
(
val
,
"try-error"
))
pred_data
[[
var
]]
<-
val
else
warning
(
sprintf
(
"Invalid command '%s': using original values"
,
cmd
))
}
# 3. 预测核心函数 - 修改参数名以匹配NN
pfun
<-
function
(
model
,
pred
,
se
,
conf_lev
)
{
# 验证数据
missing_vars
<-
setdiff
(
svm_info
$
evar
,
colnames
(
pred
))
if
(
length
(
missing_vars
)
>
0
)
{
return
(
paste
(
"Missing variables:"
,
paste
(
missing_vars
,
collapse
=
", "
)))
}
pred_data
<-
unique
(
pred_data
)
}
## 3. 变量类型与因子水平校验
train_types
<-
sapply
(
object
$
dataset
[,
object
$
evar
],
class
)
pred_types
<-
sapply
(
pred_data
,
class
)
type_mismatch
<-
names
(
which
(
train_types
!=
pred_types
))
if
(
length
(
type_mismatch
)
>
0
)
{
return
(
paste0
(
"Variable type mismatch (train vs pred):\n"
,
paste
(
sprintf
(
" %s: %s vs %s"
,
type_mismatch
,
train_types
[
type_mismatch
],
pred_types
[
type_mismatch
]),
collapse
=
"\n"
))
%>%
add_class
(
"svm.predict"
))
}
for
(
var
in
object
$
evar
)
{
if
(
is.factor
(
object
$
dataset
[[
var
]]))
{
train_levs
<-
levels
(
object
$
dataset
[[
var
]])
pred_levs
<-
levels
(
pred_data
[[
var
]])
if
(
!
identical
(
train_levs
,
pred_levs
))
{
pred_data
[[
var
]]
<-
factor
(
pred_data
[[
var
]],
levels
=
train_levs
)
warning
(
sprintf
(
"Aligned factor levels for '%s' to match training data"
,
var
))
# 内部标准化
ms
<-
attr
(
svm_info
$
dataset
,
"radiant_ms"
)
sds
<-
attr
(
svm_info
$
dataset
,
"radiant_sds"
)
if
(
!
is.null
(
ms
)
&&
!
is.null
(
sds
))
{
isNum
<-
sapply
(
pred
,
is.numeric
)
cn
<-
names
(
isNum
)[
isNum
]
for
(
var
in
cn
)
{
if
(
var
%in%
names
(
ms
)
&&
var
%in%
names
(
sds
)
&&
sds
[
var
]
!=
0
)
{
pred
[[
var
]]
<-
(
pred
[[
var
]]
-
ms
[
var
])
/
sds
[
var
]
}
}
}
}
## 4. 标准化对齐
train_ms
<-
attr
(
object
$
dataset
,
"radiant_ms"
)
train_sds
<-
attr
(
object
$
dataset
,
"radiant_sds"
)
train_sf
<-
attr
(
object
$
dataset
,
"radiant_sf"
)
%||%
2
if
(
!
is.null
(
train_ms
)
&&
!
is.null
(
train_sds
))
{
pred_data
<-
scale_df
(
dataset
=
pred_data
,
center
=
TRUE
,
scale
=
TRUE
,
sf
=
train_sf
,
wts
=
NULL
,
calc
=
FALSE
)
attr
(
pred_data
,
"radiant_ms"
)
<-
train_ms
attr
(
pred_data
,
"radiant_sds"
)
<-
train_sds
}
## 5. 生成预测值
predict_args
<-
list
(
object
=
object
$
model
,
newdata
=
pred_data
,
na.action
=
na.omit
)
pred_result
<-
if
(
object
$
type
==
"classification"
)
{
predict_args
$
type
<-
"class"
pred_class
<-
do.call
(
predict
,
predict_args
)
if
(
isTRUE
(
object
$
model
$
param
$
probability
))
{
predict_args
$
type
<-
"probabilities"
pred_prob
<-
do.call
(
predict
,
predict_args
)
%>%
as.data.frame
()
%>%
set_colnames
(
paste0
(
"Prob_"
,
colnames
(
.
)))
data.frame
(
Predicted_Class
=
as.character
(
pred_class
),
pred_prob
,
stringsAsFactors
=
FALSE
# 调试信息:检查模型是否启用了概率
cat
(
"Model probability enabled:"
,
svm_info
$
model
$
prob.model
,
"\n"
)
# 执行预测
pred_result
<-
try
({
predict
(
svm_info
$
model
,
newdata
=
pred
,
probability
=
TRUE
,
# 始终设置为TRUE
decision.values
=
TRUE
)
},
silent
=
TRUE
)
if
(
inherits
(
pred_result
,
"try-error"
))
{
return
(
paste
(
"Prediction failed:"
,
attr
(
pred_result
,
"condition"
)
$
message
))
}
# 4. 结果整理
if
(
svm_info
$
type
==
"classification"
)
{
# 正确的属性名是"probabilities"
prob_mat
<-
attr
(
pred_result
,
"probabilities"
)
if
(
!
is.null
(
prob_mat
))
{
# 调试:显示概率矩阵的列名
cat
(
"Probability matrix columns:"
,
paste
(
colnames
(
prob_mat
),
collapse
=
", "
),
"\n"
)
# 更智能的列名匹配
target_level
<-
as.character
(
svm_info
$
lev
)
# 尝试多种匹配方式
matching_col
<-
NULL
# 1. 精确匹配
exact_match
<-
grep
(
paste0
(
"^"
,
target_level
,
"$"
),
colnames
(
prob_mat
),
value
=
TRUE
,
ignore.case
=
TRUE
)
if
(
length
(
exact_match
)
>
0
)
{
matching_col
<-
exact_match
}
# 2. 部分匹配
else
{
partial_match
<-
grep
(
target_level
,
colnames
(
prob_mat
),
value
=
TRUE
,
ignore.case
=
TRUE
)
if
(
length
(
partial_match
)
>
0
)
{
matching_col
<-
partial_match
[
1
]
# 取第一个匹配
}
}
# 3. 如果还是没有匹配,尝试逻辑值匹配
if
(
is.null
(
matching_col
)
||
length
(
matching_col
)
==
0
)
{
if
(
target_level
%in%
c
(
"TRUE"
,
"Yes"
,
"1"
,
"1.0"
))
{
logical_match
<-
grep
(
"TRUE"
,
colnames
(
prob_mat
),
value
=
TRUE
,
ignore.case
=
TRUE
)
if
(
length
(
logical_match
)
>
0
)
{
matching_col
<-
logical_match
[
1
]
}
}
else
if
(
target_level
%in%
c
(
"FALSE"
,
"No"
,
"0"
,
"0.0"
))
{
logical_match
<-
grep
(
"FALSE"
,
colnames
(
prob_mat
),
value
=
TRUE
,
ignore.case
=
TRUE
)
if
(
length
(
logical_match
)
>
0
)
{
matching_col
<-
logical_match
[
1
]
}
}
}
# 4. 作为最后手段,使用第一列
if
(
is.null
(
matching_col
)
||
length
(
matching_col
)
==
0
)
{
if
(
ncol
(
prob_mat
)
>
0
)
{
matching_col
<-
colnames
(
prob_mat
)[
1
]
cat
(
"Warning: Using first probability column '"
,
matching_col
,
"' for level '"
,
target_level
,
"'\n"
)
}
}
if
(
!
is.null
(
matching_col
)
&&
length
(
matching_col
)
>
0
)
{
surv_prob
<-
as.numeric
(
prob_mat
[,
matching_col
])
pred_df
<-
data.frame
(
Prediction
=
round
(
surv_prob
,
dec
),
stringsAsFactors
=
FALSE
)
}
else
{
# 备用方案:使用决策值
decision_vals
<-
attr
(
pred_result
,
"decision.values"
)
if
(
!
is.null
(
decision_vals
))
{
# 将决策值转换为概率
pred_prob
<-
1
/
(
1
+
exp
(
-
decision_vals
))
pred_df
<-
data.frame
(
Prediction
=
round
(
pred_prob
,
dec
),
stringsAsFactors
=
FALSE
)
}
else
{
# 最后手段:使用预测类
pred_class
<-
as.character
(
pred_result
)
pred_prob
<-
ifelse
(
pred_class
==
target_level
,
1
,
0
)
pred_df
<-
data.frame
(
Prediction
=
round
(
pred_prob
,
dec
),
stringsAsFactors
=
FALSE
)
}
}
}
else
{
# 备用方案:使用决策值
decision_vals
<-
attr
(
pred_result
,
"decision.values"
)
if
(
!
is.null
(
decision_vals
))
{
# 将决策值转换为概率
pred_prob
<-
1
/
(
1
+
exp
(
-
decision_vals
))
pred_df
<-
data.frame
(
Prediction
=
round
(
pred_prob
,
dec
),
stringsAsFactors
=
FALSE
)
}
else
{
# 最后手段:使用预测类
target_level
<-
as.character
(
svm_info
$
lev
)
pred_class
<-
as.character
(
pred_result
)
pred_prob
<-
ifelse
(
pred_class
==
target_level
,
1
,
0
)
pred_df
<-
data.frame
(
Prediction
=
round
(
pred_prob
,
dec
),
stringsAsFactors
=
FALSE
)
}
}
}
else
{
data.frame
(
Predicted_Class
=
as.character
(
pred_class
),
stringsAsFactors
=
FALSE
)
# 回归模型
pred_values
<-
as.numeric
(
pred_result
)
# 关键:添加反标准化处理
rvar_ms
<-
if
(
!
is.null
(
ms
)
&&
svm_info
$
rvar
%in%
names
(
ms
))
ms
[
svm_info
$
rvar
]
else
0
rvar_sds
<-
if
(
!
is.null
(
sds
)
&&
svm_info
$
rvar
%in%
names
(
sds
))
sds
[
svm_info
$
rvar
]
else
1
# 反标准化
pred_values
<-
pred_values
*
rvar_sds
+
rvar_ms
pred_df
<-
data.frame
(
Prediction
=
round
(
pred_values
,
dec
),
stringsAsFactors
=
FALSE
)
}
}
else
{
predict_args
$
type
<-
"response"
pred_val
<-
do.call
(
predict
,
predict_args
)
pred_val
<-
as.numeric
(
pred_val
)
data.frame
(
Predicted_Value
=
round
(
pred_val
,
dec
),
stringsAsFactors
=
FALSE
)
}
pred_result
<-
cbind
(
pred_data
,
pred_result
)
attr
(
pred_result
,
"svm_meta"
)
<-
list
(
kernel
=
object
$
kernel
,
cost
=
object
$
cost
,
gamma
=
if
(
object
$
kernel
!=
"linear"
)
object
$
gamma
else
NA
,
seed
=
object
$
seed
,
train_data
=
object
$
df_name
,
model_type
=
object
$
type
return
(
pred_df
)
}
# 5. 调用预测框架 - 与NN完全一致
result
<-
predict_model
(
object
,
pfun
,
"svm.predict"
,
# 模型类型
pred_data
,
pred_cmd
,
conf_lev
=
0.95
,
se
=
FALSE
,
dec
,
envir
=
envir
)
%>%
set_attr
(
"radiant_pred_data"
,
df_name
)
# 6. 结果元数据
if
(
inherits
(
result
,
"svm.predict.error"
))
{
return
(
result
$
error
)
}
if
(
is.character
(
result
))
return
(
result
)
result
<-
add_class
(
result
,
"svm.predict"
)
attr
(
result
,
"svm_meta"
)
<-
list
(
model_type
=
svm_info
$
type
,
kernel
=
svm_info
$
kernel
,
cost
=
svm_info
$
cost
,
gamma
=
if
(
svm_info
$
kernel
!=
"linear"
)
svm_info
$
gamma
else
NA
,
seed
=
svm_info
$
seed
,
train_data
=
svm_info
$
df_name
,
dec
=
dec
)
attr
(
pred_result
,
"class"
)
<-
c
(
"svm.predict"
,
"data.frame"
)
return
(
pred_result
)
return
(
result
)
}
#' Print method for predict.svm
...
...
@@ -396,16 +563,30 @@ print.svm.predict <- function(x, ..., n = 10) {
if
(
!
inherits
(
x
,
"svm.predict"
))
stop
(
"Object must be of class 'svm.predict'"
)
meta
<-
attr
(
x
,
"svm_meta"
)
if
(
is.null
(
meta
))
{
meta
<-
list
(
model_type
=
"Unknown"
,
kernel
=
"Unknown"
,
cost
=
NA
,
gamma
=
NA
,
seed
=
NA
,
train_data
=
"Unknown"
,
dec
=
1
)
}
dec
<-
meta
$
dec
%||%
1
n_pred
<-
nrow
(
x
)
show_n
<-
if
(
n
<
0
||
n
>=
n_pred
)
n_pred
else
n
cat
(
"SVM Predictions\n"
)
cat
(
sprintf
(
"Model Type : %s\n"
,
tools
::
toTitleCase
(
meta
$
model_type
)))
cat
(
sprintf
(
"Kernel : %s\n"
,
meta
$
kernel
))
cat
(
sprintf
(
"Cost (C) : %.2f\n"
,
meta
$
cost
))
model_type
<-
if
(
is.empty
(
meta
$
model_type
))
"Unknown"
else
as.character
(
meta
$
model_type
)
cat
(
sprintf
(
"Model Type : %s\n"
,
tools
::
toTitleCase
(
model_type
)))
cat
(
sprintf
(
"Kernel : %s\n"
,
ifelse
(
is.empty
(
meta
$
kernel
),
"Unknown"
,
meta
$
kernel
)))
cat
(
sprintf
(
"Cost (C) : %.2f\n"
,
ifelse
(
is.na
(
meta
$
cost
),
0
,
meta
$
cost
)))
if
(
!
is.na
(
meta
$
gamma
))
cat
(
sprintf
(
"Gamma : %.2f\n"
,
meta
$
gamma
))
if
(
!
is.na
(
meta
$
seed
))
cat
(
sprintf
(
"Seed : %s\n"
,
meta
$
seed
))
cat
(
sprintf
(
"Training Dataset : %s\n"
,
meta
$
train_data
))
cat
(
sprintf
(
"Training Dataset : %s\n"
,
ifelse
(
is.empty
(
meta
$
train_data
),
"Unknown"
,
meta
$
train_data
)
))
cat
(
sprintf
(
"Total Predictions : %d\n"
,
n_pred
))
if
(
n_pred
==
0
)
{
...
...
@@ -413,8 +594,13 @@ print.svm.predict <- function(x, ..., n = 10) {
return
(
invisible
(
x
))
}
x_show
<-
x
[
1
:
show_n
,
,
drop
=
FALSE
]
# 确保保持数据框结构
x_show
<-
x
[
1
:
show_n
,
,
drop
=
FALSE
]
num_cols
<-
sapply
(
x_show
,
is.numeric
)
x_show
[
num_cols
]
<-
lapply
(
x_show
[
num_cols
],
function
(
col
)
{
sprintf
(
paste0
(
"%."
,
dec
,
"f"
),
col
)
})
# 重新计算列宽
col_widths
<-
sapply
(
colnames
(
x_show
),
function
(
cn
)
{
max
(
nchar
(
cn
),
max
(
nchar
(
as.character
(
x_show
[[
cn
]])),
na.rm
=
TRUE
))
})
...
...
@@ -435,38 +621,13 @@ print.svm.predict <- function(x, ..., n = 10) {
}
if
(
show_n
<
n_pred
)
{
cat
(
sprintf
(
"\n... (showing first %d of %d
; use 'n=-1' to view all
)\n"
,
cat
(
sprintf
(
"\n... (showing first %d of %d)\n"
,
show_n
,
n_pred
))
}
return
(
invisible
(
x
))
}
#' Cross-validation for SVM
#' @export
cv.svm
<-
function
(
object
,
K
=
5
,
repeats
=
1
,
kernel
=
c
(
"linear"
,
"radial"
),
cost
=
2
^
(
-2
:
2
),
gamma
=
2
^
(
-2
:
2
),
seed
=
1234
,
trace
=
TRUE
,
fun
,
...
)
{
if
(
!
inherits
(
object
,
"svm"
))
stop
(
"Object must be of class 'svm'"
)
tune_grid
<-
expand.grid
(
kernel
=
kernel
,
cost
=
cost
,
gamma
=
if
(
any
(
kernel
%in%
c
(
"radial"
,
"poly"
,
"sigmoid"
)))
gamma
else
NA
)
cv_results
<-
data.frame
(
mean_perf
=
rep
(
NA
,
nrow
(
tune_grid
)),
std_perf
=
rep
(
NA
,
nrow
(
tune_grid
)),
kernel
=
tune_grid
$
kernel
,
cost
=
tune_grid
$
cost
,
gamma
=
tune_grid
$
gamma
)
cv_results
}
# ---- 决策边界 ----
#' @export
svm_boundary_plot
<-
function
(
object
,
size
,
custom
)
{
...
...
radiant.model/inst/app/init.R
View file @
53dd5828
...
...
@@ -79,11 +79,12 @@ options(
i
18
n
$
t
(
"Estimate"
),
tabPanel
(
i
18
n
$
t
(
"Linear regression (OLS)"
),
uiOutput
(
"regress"
)),
tabPanel
(
i
18
n
$
t
(
"Logistic regression (GLM)"
),
uiOutput
(
"logistic"
)),
tabPanel
(
i
18
n
$
t
(
"Cox Proportional Hazards Regression"
),
uiOutput
(
"coxp"
)),
tabPanel
(
i
18
n
$
t
(
"Multinomial logistic regression (MNL)"
),
uiOutput
(
"mnl"
)),
tabPanel
(
i
18
n
$
t
(
"Naive Bayes"
),
uiOutput
(
"nb"
)),
tabPanel
(
i
18
n
$
t
(
"Neural Network"
),
uiOutput
(
"nn"
)),
tabPanel
(
i
18
n
$
t
(
"Support Vector Machine (SVM)"
),
uiOutput
(
"svm"
)),
"----"
,
i
18
n
$
t
(
"Survival Analysis"
),
tabPanel
(
i
18
n
$
t
(
"Cox Proportional Hazards Regression"
),
uiOutput
(
"coxp"
)),
"----"
,
i
18
n
$
t
(
"Trees"
),
tabPanel
(
i
18
n
$
t
(
"Classification and regression trees"
),
uiOutput
(
"crtree"
)),
tabPanel
(
i
18
n
$
t
(
"Random Forest"
),
uiOutput
(
"rf"
)),
...
...
radiant.model/inst/app/tools/analysis/svm_ui.R
View file @
53dd5828
...
...
@@ -23,7 +23,7 @@ svm_inputs <- reactive({
svm_args
})
## 预测参数
(保留命令模式,未改动)
## 预测参数
svm_pred_args
<-
as.list
(
if
(
exists
(
"predict.svm"
))
{
formals
(
predict.svm
)
}
else
{
...
...
@@ -52,7 +52,7 @@ svm_pred_inputs <- reactive({
return
(
svm_pred_args
)
})
## 绘图参数
(砍掉vip、pdp、svm_margin)
## 绘图参数
svm_plot_args
<-
as.list
(
if
(
exists
(
"plot.svm"
))
{
formals
(
plot.svm
)
}
else
{
...
...
@@ -77,11 +77,7 @@ output$ui_svm_rvar <- renderUI({
vars
<-
varnames
()[
isNum
]
}
})
init
<-
if
(
input
$
svm_type
==
"classification"
)
{
if
(
is.empty
(
input
$
logit_rvar
))
isolate
(
input
$
svm_rvar
)
else
input
$
logit_rvar
}
else
{
if
(
is.empty
(
input
$
reg_rvar
))
isolate
(
input
$
svm_rvar
)
else
input
$
reg_rvar
}
init
<-
state_single
(
"svm_rvar"
,
vars
,
isolate
(
input
$
svm_rvar
))
selectInput
(
inputId
=
"svm_rvar"
,
label
=
i
18
n
$
t
(
"Response variable:"
),
...
...
@@ -109,11 +105,7 @@ output$ui_svm_evar <- renderUI({
if
(
not_available
(
input
$
svm_rvar
))
return
()
vars
<-
varnames
()
if
(
length
(
vars
)
>
0
)
vars
<-
vars
[
-
which
(
vars
==
input
$
svm_rvar
)]
init
<-
if
(
input
$
svm_type
==
"classification"
)
{
if
(
is.empty
(
input
$
logit_evar
))
isolate
(
input
$
svm_evar
)
else
input
$
logit_evar
}
else
{
if
(
is.empty
(
input
$
reg_evar
))
isolate
(
input
$
svm_evar
)
else
input
$
reg_evar
}
init
<-
state_multiple
(
"svm_evar"
,
vars
,
isolate
(
input
$
svm_evar
))
selectInput
(
inputId
=
"svm_evar"
,
label
=
i
18
n
$
t
(
"Explanatory variables:"
),
...
...
@@ -141,7 +133,7 @@ output$ui_svm_wts <- renderUI({
)
})
## 存储预测值UI
(残差存储已删除)
## 存储预测值UI
output
$
ui_svm_store_pred_name
<-
renderUI
({
init
<-
state_init
(
"svm_store_pred_name"
,
"pred_svm"
)
%>%
sub
(
"\\d{1,}$"
,
""
,
.
)
%>%
...
...
@@ -164,7 +156,7 @@ observeEvent(input$svm_type, {
updateSelectInput
(
session
=
session
,
inputId
=
"svm_plots"
,
selected
=
"none"
)
})
## 绘图选项UI
(已删vip、pdp、svm_margin)
## 绘图选项UI
output
$
ui_svm_plots
<-
renderUI
({
req
(
input
$
svm_type
)
avail_plots
<-
svm_plots
...
...
@@ -178,19 +170,7 @@ output$ui_svm_plots <- renderUI({
)
})
## 数据点数量UI(仅dashboard用,保留)
output
$
ui_svm_nrobs
<-
renderUI
({
nrobs
<-
nrow
(
.get_data
())
choices
<-
c
(
"1,000"
=
1000
,
"5,000"
=
5000
,
"10,000"
=
10000
,
"All"
=
-1
)
%>%
.
[
.
<
nrobs
]
selectInput
(
"svm_nrobs"
,
i
18
n
$
t
(
"Number of data points plotted:"
),
choices
=
choices
,
selected
=
state_single
(
"svm_nrobs"
,
choices
,
1000
)
)
})
## 主UI面板(已删残差存储入口)
## 主UI面板
output
$
ui_svm
<-
renderUI
({
req
(
input
$
dataset
)
tagList
(
...
...
@@ -258,7 +238,7 @@ output$ui_svm <- renderUI({
)
)
),
# 预测面板
(残差存储已删除)
# 预测面板
conditionalPanel
(
condition
=
"input.tabs_svm == 'Predict'"
,
selectInput
(
...
...
@@ -303,19 +283,10 @@ output$ui_svm <- renderUI({
)
)
),
# 绘图面板
(已删vip、pdp、svm_margin)
# 绘图面板
conditionalPanel
(
condition
=
"input.tabs_svm == 'Plot'"
,
uiOutput
(
"ui_svm_plots"
),
conditionalPanel
(
condition
=
"input.svm_plots == 'pred_plot'"
,
uiOutput
(
"ui_svm_incl"
),
uiOutput
(
"ui_svm_incl_int"
)
),
conditionalPanel
(
condition
=
"input.svm_plots == 'dashboard'"
,
uiOutput
(
"ui_svm_nrobs"
)
)
uiOutput
(
"ui_svm_plots"
)
)
),
# 帮助和报告面板
...
...
@@ -327,7 +298,7 @@ output$ui_svm <- renderUI({
)
})
## 绘图尺寸计算
(已删vip、pdp、svm_margin)
## 绘图尺寸计算
svm_plot
<-
reactive
({
if
(
svm_available
()
!=
"available"
)
return
()
if
(
is.empty
(
input
$
svm_plots
,
"none"
))
return
()
...
...
@@ -337,9 +308,6 @@ svm_plot <- reactive({
plot_width
<-
650
if
(
"decision_boundary"
%in%
input
$
svm_plots
)
{
plot_height
<-
500
}
else
if
(
input
$
svm_plots
==
"pred_plot"
)
{
nr_vars
<-
length
(
input
$
svm_incl
)
+
length
(
input
$
svm_incl_int
)
plot_height
<-
max
(
250
,
ceiling
(
nr_vars
/
2
)
*
250
)
}
else
{
plot_height
<-
max
(
500
,
length
(
res
$
evar
)
*
30
)
}
...
...
@@ -349,7 +317,7 @@ svm_plot <- reactive({
svm_plot_width
<-
function
()
svm_plot
()
$
plot_width
%||%
650
svm_plot_height
<-
function
()
svm_plot
()
$
plot_height
%||%
500
## 主输出面板
(已删残差存储)
## 主输出面板
output
$
svm
<-
renderUI
({
register_print_output
(
"summary_svm"
,
".summary_svm"
)
register_print_output
(
"predict_svm"
,
".predict_print_svm"
)
...
...
@@ -393,14 +361,14 @@ svm_available <- reactive({
}
})
## 核心函数
壳子
## 核心函数
.svm
<-
eventReactive
(
input
$
svm_run
,
{
svi
<-
svm_inputs
()
svi
$
envir
<-
r_data
withProgress
(
message
=
i
18
n
$
t
(
"Estimating SVM model"
),
value
=
1
,
do.call
(
svm
,
svi
))
})
## 辅助输出函数
壳子
## 辅助输出函数
.summary_svm
<-
reactive
({
if
(
not_pressed
(
input
$
svm_run
))
return
(
i
18
n
$
t
(
"** Press the Estimate button to estimate the SVM model **"
))
if
(
svm_available
()
!=
"available"
)
return
(
svm_available
())
...
...
@@ -444,7 +412,7 @@ svm_available <- reactive({
withProgress
(
message
=
i
18
n
$
t
(
"Generating SVM plots"
),
value
=
1
,
do.call
(
plot
,
c
(
list
(
x
=
.svm
()),
pinp
)))
})
## 存储预测值
(残差存储已删除)
## 存储预测值
observeEvent
(
input
$
svm_store_pred
,
{
req
(
pressed
(
input
$
svm_run
),
...
...
@@ -457,14 +425,14 @@ observeEvent(input$svm_store_pred, {
base_col_name
<-
fix_names
(
input
$
svm_store_pred_name
)
meta
<-
attr
(
pred_result
,
"svm_meta"
)
pred_cols
<-
if
(
meta
$
model_type
==
"classification"
)
{
colnames
(
pred_result
)[
grepl
(
"^Predicted_Class|^Prob_"
,
colnames
(
pred_result
))
]
pred_cols
<-
if
(
meta
$
model_type
%in%
c
(
"classification"
,
"regression"
)
)
{
colnames
(
pred_result
)[
colnames
(
pred_result
)
==
"Prediction"
]
}
else
{
"Predicted_Value"
NULL
}
new_col_names
<-
if
(
length
(
pred_cols
)
==
1
)
base_col_name
else
{
suffix
<-
gsub
(
"^Predict
ed_|^Prob_
"
,
""
,
pred_cols
)
paste0
(
base_col_name
,
"_"
,
suffix
)
suffix
<-
gsub
(
"^Predict
ion
"
,
""
,
pred_cols
)
paste0
(
base_col_name
,
ifelse
(
suffix
==
""
,
""
,
paste0
(
"_"
,
suffix
))
)
}
colnames
(
pred_result
)[
match
(
pred_cols
,
colnames
(
pred_result
))]
<-
new_col_names
...
...
@@ -510,9 +478,75 @@ download_handler(
height
=
svm_plot_height
)
## 报告生成(空壳,保留接口)
svm_report
<-
function
()
invisible
()
svm_report
<-
function
()
{
if
(
is.empty
(
input
$
svm_evar
))
{
showNotification
(
i
18
n
$
t
(
"Select at least one explanatory variable to generate report"
),
type
=
"error"
)
return
(
invisible
())
}
outputs
<-
c
(
"summary"
)
inp_out
<-
list
(
list
(
prn
=
TRUE
),
""
)
figs
<-
FALSE
xcmd
<-
""
if
(
!
is.empty
(
input
$
svm_plots
,
"none"
))
{
inp
<-
check_plot_inputs
(
svm_plot_inputs
())
inp
$
size
<-
NULL
inp_out
[[
2
]]
<-
clean_args
(
inp
,
svm_plot_args
[
-1
])
inp_out
[[
2
]]
$
custom
<-
FALSE
outputs
<-
c
(
outputs
,
"plot"
)
figs
<-
TRUE
}
if
(
!
is.empty
(
input
$
svm_predict
,
"none"
)
&&
(
!
is.empty
(
input
$
svm_pred_data
)
||
!
is.empty
(
input
$
svm_pred_cmd
)))
{
pred_args
<-
clean_args
(
svm_pred_inputs
(),
svm_pred_args
[
-1
])
if
(
!
is.empty
(
pred_args
$
pred_cmd
))
{
pred_args
$
pred_cmd
<-
strsplit
(
pred_args
$
pred_cmd
,
";\\s*"
)[[
1
]]
}
else
{
pred_args
$
pred_cmd
<-
NULL
}
if
(
!
is.empty
(
pred_args
$
pred_data
))
{
pred_args
$
pred_data
<-
as.symbol
(
pred_args
$
pred_data
)
}
else
{
pred_args
$
pred_data
<-
NULL
}
inp_out
[[
2
+
figs
]]
<-
pred_args
outputs
<-
c
(
outputs
,
"pred <- predict"
)
xcmd
<-
paste0
(
xcmd
,
"print(pred, n = 10)"
)
if
(
input
$
svm_predict
%in%
c
(
"data"
,
"datacmd"
)
&&
!
is.empty
(
input
$
svm_store_pred_name
))
{
fixed
<-
fix_names
(
input
$
svm_store_pred_name
)
updateTextInput
(
session
,
"svm_store_pred_name"
,
value
=
fixed
)
xcmd
<-
paste0
(
xcmd
,
"\n"
,
input
$
svm_pred_data
,
" <- store("
,
input
$
svm_pred_data
,
", pred, name = \""
,
fixed
,
"\")"
)
}
}
svm_inp
<-
svm_inputs
()
if
(
input
$
svm_type
==
"regression"
)
{
svm_inp
$
lev
<-
NULL
}
if
(
input
$
svm_kernel
==
"linear"
)
{
svm_inp
$
gamma
<-
NULL
}
update_report
(
inp_main
=
clean_args
(
svm_inp
,
svm_args
),
fun_name
=
"svm"
,
inp_out
=
inp_out
,
outputs
=
outputs
,
figs
=
figs
,
fig.width
=
svm_plot_width
(),
fig.height
=
svm_plot_height
(),
xcmd
=
xcmd
)
}
## 报告生成
observeEvent
(
input
$
svm_report
,
{
r_info
[[
"latest_screenshot"
]]
<-
NULL
svm_report
()
...
...
radiant.model/man/cv.svm.Rd
deleted
100644 → 0
View file @
e9df660e
% Generated by roxygen2: do not edit by hand
% Please edit documentation in R/svm.R
\name{cv.svm}
\alias{cv.svm}
\title{Cross-validation for SVM}
\usage{
cv.svm(
object,
K = 5,
repeats = 1,
kernel = c("linear", "radial"),
cost = seq(0.1, 10, by = 0.5),
gamma = seq(0.1, 5, by = 0.5),
seed = 1234,
trace = TRUE,
fun,
...
)
}
\description{
Cross-validation for SVM
}
radiant.model/man/plot.svm.Rd
View file @
53dd5828
...
...
@@ -4,17 +4,7 @@
\alias{plot.svm}
\title{Plot method for the svm function}
\usage{
\method{plot}{svm}(
x,
plots = "vip",
size = 12,
nrobs = -1,
incl = NULL,
incl_int = NULL,
shiny = FALSE,
custom = FALSE,
...
)
\method{plot}{svm}(x, plots = "none", size = 12, shiny = FALSE, custom = FALSE, ...)
}
\description{
Plot method for the svm function
...
...
radiant.model/man/predict.svm.Rd
View file @
53dd5828
...
...
@@ -2,7 +2,7 @@
% Please edit documentation in R/svm.R
\name{predict.svm}
\alias{predict.svm}
\title{Predict method for
the svm function
}
\title{Predict method for
SVM model
}
\usage{
\method{predict}{svm}(
object,
...
...
@@ -14,5 +14,5 @@
)
}
\description{
Predict method for
the svm function
Predict method for
SVM model
}
radiant.model/man/varimp.Rd
View file @
53dd5828
...
...
@@ -4,9 +4,9 @@
\alias{varimp}
\title{Variable importance using the vip package and permutation importance}
\usage{
varimp(object, rvar
, lev, data = NULL, seed = 1234
)
varimp(object, rvar
= NULL, lev = NULL, data = NULL, seed = 1234, nperm = 10
)
varimp(object, rvar
, lev, data = NULL, seed = 1234
)
varimp(object, rvar
= NULL, lev = NULL, data = NULL, seed = 1234, nperm = 10
)
}
\arguments{
\item{object}{Model object created by Radiant}
...
...
@@ -22,5 +22,5 @@ varimp(object, rvar, lev, data = NULL, seed = 1234)
\description{
Variable importance using the vip package and permutation importance
Variable importance
using the vip package and
permutation importance
Variable importance
for SVM using
permutation importance
}
radiant.model/man/varimp_plot.Rd
View file @
53dd5828
% Generated by roxygen2: do not edit by hand
% Please edit documentation in R/nn.R
, R/svm.R
% Please edit documentation in R/nn.R
\name{varimp_plot}
\alias{varimp_plot}
\title{Plot permutation importance}
\usage{
varimp_plot(object, rvar, lev, data = NULL, seed = 1234)
varimp_plot(object, rvar, lev, data = NULL, seed = 1234)
}
\arguments{
...
...
@@ -20,7 +18,5 @@ varimp_plot(object, rvar, lev, data = NULL, seed = 1234)
\item{seed}{Random seed for reproducibility}
}
\description{
Plot permutation importance
Plot permutation importance
}
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment