Commit 53dd5828 authored by wuzekai's avatar wuzekai

更新了svm

parent e9df660e
...@@ -85,7 +85,6 @@ export(cv.crtree) ...@@ -85,7 +85,6 @@ export(cv.crtree)
export(cv.gbt) export(cv.gbt)
export(cv.nn) export(cv.nn)
export(cv.rforest) export(cv.rforest)
export(cv.svm)
export(dtree) export(dtree)
export(dtree_parser) export(dtree_parser)
export(evalbin) export(evalbin)
...@@ -121,6 +120,9 @@ export(sim_splitter) ...@@ -121,6 +120,9 @@ export(sim_splitter)
export(sim_summary) export(sim_summary)
export(simulater) export(simulater)
export(svm) export(svm)
export(svm_boundary_plot)
export(svm_margin_plot)
export(svm_vip_plot)
export(test_specs) export(test_specs)
export(uplift) export(uplift)
export(var_check) export(var_check)
......
This diff is collapsed.
...@@ -79,11 +79,12 @@ options( ...@@ -79,11 +79,12 @@ options(
i18n$t("Estimate"), i18n$t("Estimate"),
tabPanel(i18n$t("Linear regression (OLS)"), uiOutput("regress")), tabPanel(i18n$t("Linear regression (OLS)"), uiOutput("regress")),
tabPanel(i18n$t("Logistic regression (GLM)"), uiOutput("logistic")), tabPanel(i18n$t("Logistic regression (GLM)"), uiOutput("logistic")),
tabPanel(i18n$t("Cox Proportional Hazards Regression"),uiOutput("coxp")),
tabPanel(i18n$t("Multinomial logistic regression (MNL)"), uiOutput("mnl")), tabPanel(i18n$t("Multinomial logistic regression (MNL)"), uiOutput("mnl")),
tabPanel(i18n$t("Naive Bayes"), uiOutput("nb")), tabPanel(i18n$t("Naive Bayes"), uiOutput("nb")),
tabPanel(i18n$t("Neural Network"), uiOutput("nn")), tabPanel(i18n$t("Neural Network"), uiOutput("nn")),
tabPanel(i18n$t("Support Vector Machine (SVM)"),uiOutput("svm")), tabPanel(i18n$t("Support Vector Machine (SVM)"),uiOutput("svm")),
"----", i18n$t("Survival Analysis"),
tabPanel(i18n$t("Cox Proportional Hazards Regression"),uiOutput("coxp")),
"----", i18n$t("Trees"), "----", i18n$t("Trees"),
tabPanel(i18n$t("Classification and regression trees"), uiOutput("crtree")), tabPanel(i18n$t("Classification and regression trees"), uiOutput("crtree")),
tabPanel(i18n$t("Random Forest"), uiOutput("rf")), tabPanel(i18n$t("Random Forest"), uiOutput("rf")),
......
...@@ -23,7 +23,7 @@ svm_inputs <- reactive({ ...@@ -23,7 +23,7 @@ svm_inputs <- reactive({
svm_args svm_args
}) })
## 预测参数(保留命令模式,未改动) ## 预测参数
svm_pred_args <- as.list(if (exists("predict.svm")) { svm_pred_args <- as.list(if (exists("predict.svm")) {
formals(predict.svm) formals(predict.svm)
} else { } else {
...@@ -52,7 +52,7 @@ svm_pred_inputs <- reactive({ ...@@ -52,7 +52,7 @@ svm_pred_inputs <- reactive({
return(svm_pred_args) return(svm_pred_args)
}) })
## 绘图参数(砍掉vip、pdp、svm_margin) ## 绘图参数
svm_plot_args <- as.list(if (exists("plot.svm")) { svm_plot_args <- as.list(if (exists("plot.svm")) {
formals(plot.svm) formals(plot.svm)
} else { } else {
...@@ -77,11 +77,7 @@ output$ui_svm_rvar <- renderUI({ ...@@ -77,11 +77,7 @@ output$ui_svm_rvar <- renderUI({
vars <- varnames()[isNum] vars <- varnames()[isNum]
} }
}) })
init <- if (input$svm_type == "classification") { init <- state_single("svm_rvar", vars, isolate(input$svm_rvar))
if (is.empty(input$logit_rvar)) isolate(input$svm_rvar) else input$logit_rvar
} else {
if (is.empty(input$reg_rvar)) isolate(input$svm_rvar) else input$reg_rvar
}
selectInput( selectInput(
inputId = "svm_rvar", inputId = "svm_rvar",
label = i18n$t("Response variable:"), label = i18n$t("Response variable:"),
...@@ -109,11 +105,7 @@ output$ui_svm_evar <- renderUI({ ...@@ -109,11 +105,7 @@ output$ui_svm_evar <- renderUI({
if (not_available(input$svm_rvar)) return() if (not_available(input$svm_rvar)) return()
vars <- varnames() vars <- varnames()
if (length(vars) > 0) vars <- vars[-which(vars == input$svm_rvar)] if (length(vars) > 0) vars <- vars[-which(vars == input$svm_rvar)]
init <- if (input$svm_type == "classification") { init <- state_multiple("svm_evar", vars, isolate(input$svm_evar))
if (is.empty(input$logit_evar)) isolate(input$svm_evar) else input$logit_evar
} else {
if (is.empty(input$reg_evar)) isolate(input$svm_evar) else input$reg_evar
}
selectInput( selectInput(
inputId = "svm_evar", inputId = "svm_evar",
label = i18n$t("Explanatory variables:"), label = i18n$t("Explanatory variables:"),
...@@ -141,7 +133,7 @@ output$ui_svm_wts <- renderUI({ ...@@ -141,7 +133,7 @@ output$ui_svm_wts <- renderUI({
) )
}) })
## 存储预测值UI(残差存储已删除) ## 存储预测值UI
output$ui_svm_store_pred_name <- renderUI({ output$ui_svm_store_pred_name <- renderUI({
init <- state_init("svm_store_pred_name", "pred_svm") %>% init <- state_init("svm_store_pred_name", "pred_svm") %>%
sub("\\d{1,}$", "", .) %>% sub("\\d{1,}$", "", .) %>%
...@@ -164,7 +156,7 @@ observeEvent(input$svm_type, { ...@@ -164,7 +156,7 @@ observeEvent(input$svm_type, {
updateSelectInput(session = session, inputId = "svm_plots", selected = "none") updateSelectInput(session = session, inputId = "svm_plots", selected = "none")
}) })
## 绘图选项UI(已删vip、pdp、svm_margin) ## 绘图选项UI
output$ui_svm_plots <- renderUI({ output$ui_svm_plots <- renderUI({
req(input$svm_type) req(input$svm_type)
avail_plots <- svm_plots avail_plots <- svm_plots
...@@ -178,19 +170,7 @@ output$ui_svm_plots <- renderUI({ ...@@ -178,19 +170,7 @@ output$ui_svm_plots <- renderUI({
) )
}) })
## 数据点数量UI(仅dashboard用,保留) ## 主UI面板
output$ui_svm_nrobs <- renderUI({
nrobs <- nrow(.get_data())
choices <- c("1,000" = 1000, "5,000" = 5000, "10,000" = 10000, "All" = -1) %>%
.[. < nrobs]
selectInput(
"svm_nrobs", i18n$t("Number of data points plotted:"),
choices = choices,
selected = state_single("svm_nrobs", choices, 1000)
)
})
## 主UI面板(已删残差存储入口)
output$ui_svm <- renderUI({ output$ui_svm <- renderUI({
req(input$dataset) req(input$dataset)
tagList( tagList(
...@@ -258,7 +238,7 @@ output$ui_svm <- renderUI({ ...@@ -258,7 +238,7 @@ output$ui_svm <- renderUI({
) )
) )
), ),
# 预测面板(残差存储已删除) # 预测面板
conditionalPanel( conditionalPanel(
condition = "input.tabs_svm == 'Predict'", condition = "input.tabs_svm == 'Predict'",
selectInput( selectInput(
...@@ -303,19 +283,10 @@ output$ui_svm <- renderUI({ ...@@ -303,19 +283,10 @@ output$ui_svm <- renderUI({
) )
) )
), ),
# 绘图面板(已删vip、pdp、svm_margin) # 绘图面板
conditionalPanel( conditionalPanel(
condition = "input.tabs_svm == 'Plot'", condition = "input.tabs_svm == 'Plot'",
uiOutput("ui_svm_plots"), uiOutput("ui_svm_plots")
conditionalPanel(
condition = "input.svm_plots == 'pred_plot'",
uiOutput("ui_svm_incl"),
uiOutput("ui_svm_incl_int")
),
conditionalPanel(
condition = "input.svm_plots == 'dashboard'",
uiOutput("ui_svm_nrobs")
)
) )
), ),
# 帮助和报告面板 # 帮助和报告面板
...@@ -327,7 +298,7 @@ output$ui_svm <- renderUI({ ...@@ -327,7 +298,7 @@ output$ui_svm <- renderUI({
) )
}) })
## 绘图尺寸计算(已删vip、pdp、svm_margin) ## 绘图尺寸计算
svm_plot <- reactive({ svm_plot <- reactive({
if (svm_available() != "available") return() if (svm_available() != "available") return()
if (is.empty(input$svm_plots, "none")) return() if (is.empty(input$svm_plots, "none")) return()
...@@ -337,9 +308,6 @@ svm_plot <- reactive({ ...@@ -337,9 +308,6 @@ svm_plot <- reactive({
plot_width <- 650 plot_width <- 650
if ("decision_boundary" %in% input$svm_plots) { if ("decision_boundary" %in% input$svm_plots) {
plot_height <- 500 plot_height <- 500
} else if (input$svm_plots == "pred_plot") {
nr_vars <- length(input$svm_incl) + length(input$svm_incl_int)
plot_height <- max(250, ceiling(nr_vars / 2) * 250)
} else { } else {
plot_height <- max(500, length(res$evar) * 30) plot_height <- max(500, length(res$evar) * 30)
} }
...@@ -349,7 +317,7 @@ svm_plot <- reactive({ ...@@ -349,7 +317,7 @@ svm_plot <- reactive({
svm_plot_width <- function() svm_plot()$plot_width %||% 650 svm_plot_width <- function() svm_plot()$plot_width %||% 650
svm_plot_height <- function() svm_plot()$plot_height %||% 500 svm_plot_height <- function() svm_plot()$plot_height %||% 500
## 主输出面板(已删残差存储) ## 主输出面板
output$svm <- renderUI({ output$svm <- renderUI({
register_print_output("summary_svm", ".summary_svm") register_print_output("summary_svm", ".summary_svm")
register_print_output("predict_svm", ".predict_print_svm") register_print_output("predict_svm", ".predict_print_svm")
...@@ -393,14 +361,14 @@ svm_available <- reactive({ ...@@ -393,14 +361,14 @@ svm_available <- reactive({
} }
}) })
## 核心函数壳子 ## 核心函数
.svm <- eventReactive(input$svm_run, { .svm <- eventReactive(input$svm_run, {
svi <- svm_inputs() svi <- svm_inputs()
svi$envir <- r_data svi$envir <- r_data
withProgress(message = i18n$t("Estimating SVM model"), value = 1, do.call(svm, svi)) withProgress(message = i18n$t("Estimating SVM model"), value = 1, do.call(svm, svi))
}) })
## 辅助输出函数壳子 ## 辅助输出函数
.summary_svm <- reactive({ .summary_svm <- reactive({
if (not_pressed(input$svm_run)) return(i18n$t("** Press the Estimate button to estimate the SVM model **")) if (not_pressed(input$svm_run)) return(i18n$t("** Press the Estimate button to estimate the SVM model **"))
if (svm_available() != "available") return(svm_available()) if (svm_available() != "available") return(svm_available())
...@@ -444,7 +412,7 @@ svm_available <- reactive({ ...@@ -444,7 +412,7 @@ svm_available <- reactive({
withProgress(message = i18n$t("Generating SVM plots"), value = 1, do.call(plot, c(list(x = .svm()), pinp))) withProgress(message = i18n$t("Generating SVM plots"), value = 1, do.call(plot, c(list(x = .svm()), pinp)))
}) })
## 存储预测值(残差存储已删除) ## 存储预测值
observeEvent(input$svm_store_pred, { observeEvent(input$svm_store_pred, {
req( req(
pressed(input$svm_run), pressed(input$svm_run),
...@@ -457,14 +425,14 @@ observeEvent(input$svm_store_pred, { ...@@ -457,14 +425,14 @@ observeEvent(input$svm_store_pred, {
base_col_name <- fix_names(input$svm_store_pred_name) base_col_name <- fix_names(input$svm_store_pred_name)
meta <- attr(pred_result, "svm_meta") meta <- attr(pred_result, "svm_meta")
pred_cols <- if (meta$model_type == "classification") { pred_cols <- if (meta$model_type %in% c("classification", "regression")) {
colnames(pred_result)[grepl("^Predicted_Class|^Prob_", colnames(pred_result))] colnames(pred_result)[colnames(pred_result) == "Prediction"]
} else { } else {
"Predicted_Value" NULL
} }
new_col_names <- if (length(pred_cols) == 1) base_col_name else { new_col_names <- if (length(pred_cols) == 1) base_col_name else {
suffix <- gsub("^Predicted_|^Prob_", "", pred_cols) suffix <- gsub("^Prediction", "", pred_cols)
paste0(base_col_name, "_", suffix) paste0(base_col_name, ifelse(suffix == "", "", paste0("_", suffix)))
} }
colnames(pred_result)[match(pred_cols, colnames(pred_result))] <- new_col_names colnames(pred_result)[match(pred_cols, colnames(pred_result))] <- new_col_names
...@@ -510,9 +478,75 @@ download_handler( ...@@ -510,9 +478,75 @@ download_handler(
height = svm_plot_height height = svm_plot_height
) )
## 报告生成(空壳,保留接口) svm_report <- function() {
svm_report <- function() invisible() if (is.empty(input$svm_evar)) {
showNotification(i18n$t("Select at least one explanatory variable to generate report"), type = "error")
return(invisible())
}
outputs <- c("summary")
inp_out <- list(list(prn = TRUE), "")
figs <- FALSE
xcmd <- ""
if (!is.empty(input$svm_plots, "none")) {
inp <- check_plot_inputs(svm_plot_inputs())
inp$size <- NULL
inp_out[[2]] <- clean_args(inp, svm_plot_args[-1])
inp_out[[2]]$custom <- FALSE
outputs <- c(outputs, "plot")
figs <- TRUE
}
if (!is.empty(input$svm_predict, "none") &&
(!is.empty(input$svm_pred_data) || !is.empty(input$svm_pred_cmd))) {
pred_args <- clean_args(svm_pred_inputs(), svm_pred_args[-1])
if (!is.empty(pred_args$pred_cmd)) {
pred_args$pred_cmd <- strsplit(pred_args$pred_cmd, ";\\s*")[[1]]
} else {
pred_args$pred_cmd <- NULL
}
if (!is.empty(pred_args$pred_data)) {
pred_args$pred_data <- as.symbol(pred_args$pred_data)
} else {
pred_args$pred_data <- NULL
}
inp_out[[2 + figs]] <- pred_args
outputs <- c(outputs, "pred <- predict")
xcmd <- paste0(xcmd, "print(pred, n = 10)")
if (input$svm_predict %in% c("data", "datacmd") && !is.empty(input$svm_store_pred_name)) {
fixed <- fix_names(input$svm_store_pred_name)
updateTextInput(session, "svm_store_pred_name", value = fixed)
xcmd <- paste0(
xcmd, "\n", input$svm_pred_data, " <- store(",
input$svm_pred_data, ", pred, name = \"", fixed, "\")"
)
}
}
svm_inp <- svm_inputs()
if (input$svm_type == "regression") {
svm_inp$lev <- NULL
}
if (input$svm_kernel == "linear") {
svm_inp$gamma <- NULL
}
update_report(
inp_main = clean_args(svm_inp, svm_args),
fun_name = "svm",
inp_out = inp_out,
outputs = outputs,
figs = figs,
fig.width = svm_plot_width(),
fig.height = svm_plot_height(),
xcmd = xcmd
)
}
## 报告生成
observeEvent(input$svm_report, { observeEvent(input$svm_report, {
r_info[["latest_screenshot"]] <- NULL r_info[["latest_screenshot"]] <- NULL
svm_report() svm_report()
......
% Generated by roxygen2: do not edit by hand
% Please edit documentation in R/svm.R
\name{cv.svm}
\alias{cv.svm}
\title{Cross-validation for SVM}
\usage{
cv.svm(
object,
K = 5,
repeats = 1,
kernel = c("linear", "radial"),
cost = seq(0.1, 10, by = 0.5),
gamma = seq(0.1, 5, by = 0.5),
seed = 1234,
trace = TRUE,
fun,
...
)
}
\description{
Cross-validation for SVM
}
...@@ -4,17 +4,7 @@ ...@@ -4,17 +4,7 @@
\alias{plot.svm} \alias{plot.svm}
\title{Plot method for the svm function} \title{Plot method for the svm function}
\usage{ \usage{
\method{plot}{svm}( \method{plot}{svm}(x, plots = "none", size = 12, shiny = FALSE, custom = FALSE, ...)
x,
plots = "vip",
size = 12,
nrobs = -1,
incl = NULL,
incl_int = NULL,
shiny = FALSE,
custom = FALSE,
...
)
} }
\description{ \description{
Plot method for the svm function Plot method for the svm function
......
...@@ -2,7 +2,7 @@ ...@@ -2,7 +2,7 @@
% Please edit documentation in R/svm.R % Please edit documentation in R/svm.R
\name{predict.svm} \name{predict.svm}
\alias{predict.svm} \alias{predict.svm}
\title{Predict method for the svm function} \title{Predict method for SVM model}
\usage{ \usage{
\method{predict}{svm}( \method{predict}{svm}(
object, object,
...@@ -14,5 +14,5 @@ ...@@ -14,5 +14,5 @@
) )
} }
\description{ \description{
Predict method for the svm function Predict method for SVM model
} }
...@@ -4,9 +4,9 @@ ...@@ -4,9 +4,9 @@
\alias{varimp} \alias{varimp}
\title{Variable importance using the vip package and permutation importance} \title{Variable importance using the vip package and permutation importance}
\usage{ \usage{
varimp(object, rvar, lev, data = NULL, seed = 1234) varimp(object, rvar = NULL, lev = NULL, data = NULL, seed = 1234, nperm = 10)
varimp(object, rvar, lev, data = NULL, seed = 1234) varimp(object, rvar = NULL, lev = NULL, data = NULL, seed = 1234, nperm = 10)
} }
\arguments{ \arguments{
\item{object}{Model object created by Radiant} \item{object}{Model object created by Radiant}
...@@ -22,5 +22,5 @@ varimp(object, rvar, lev, data = NULL, seed = 1234) ...@@ -22,5 +22,5 @@ varimp(object, rvar, lev, data = NULL, seed = 1234)
\description{ \description{
Variable importance using the vip package and permutation importance Variable importance using the vip package and permutation importance
Variable importance using the vip package and permutation importance Variable importance for SVM using permutation importance
} }
% Generated by roxygen2: do not edit by hand % Generated by roxygen2: do not edit by hand
% Please edit documentation in R/nn.R, R/svm.R % Please edit documentation in R/nn.R
\name{varimp_plot} \name{varimp_plot}
\alias{varimp_plot} \alias{varimp_plot}
\title{Plot permutation importance} \title{Plot permutation importance}
\usage{ \usage{
varimp_plot(object, rvar, lev, data = NULL, seed = 1234)
varimp_plot(object, rvar, lev, data = NULL, seed = 1234) varimp_plot(object, rvar, lev, data = NULL, seed = 1234)
} }
\arguments{ \arguments{
...@@ -20,7 +18,5 @@ varimp_plot(object, rvar, lev, data = NULL, seed = 1234) ...@@ -20,7 +18,5 @@ varimp_plot(object, rvar, lev, data = NULL, seed = 1234)
\item{seed}{Random seed for reproducibility} \item{seed}{Random seed for reproducibility}
} }
\description{ \description{
Plot permutation importance
Plot permutation importance Plot permutation importance
} }
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment