#' Support Vector Machine using e1071
#'
#' Fits a support vector machine with \code{e1071::svm()} on data retrieved
#' through \code{get_data()}. For classification, the response is recoded to a
#' factor with levels \code{TRUE} (observations equal to \code{lev}) and
#' \code{FALSE}.
#'
#' @param dataset Data, or the name of a dataset as a string; passed to
#'   \code{get_data()}.
#' @param rvar Name of the response variable.
#' @param evar Names of the explanatory variables.
#' @param type \code{"classification"} fits a \code{"C-classification"} model;
#'   any other value fits an \code{"eps-regression"} model.
#' @param lev Value of \code{rvar} coded as \code{TRUE} for classification.
#'   When empty, the first level of \code{as.factor()} of the response is used.
#' @param kernel,cost,degree,coef0,nu,epsilon,probability Passed unchanged to
#'   \code{e1071::svm()}.
#' @param gamma Kernel parameter. \code{"auto"} uses 1 divided by the number of
#'   explanatory columns in the prepared data; any other value is coerced with
#'   \code{as.numeric()}.
#' @param wts Name of a weights column, or \code{"None"}. \code{e1071::svm()}
#'   has no case-weight argument, so the weights are only used by the
#'   \code{"standardize"} option and a warning is issued.
#' @param seed Random seed. Non-digit characters are removed; if anything
#'   remains, \code{set.seed()} is called, which changes the global RNG state.
#' @param check Character vector of options; \code{"standardize"} rescales the
#'   data with \code{scale_df()} before fitting.
#' @param form Optional formula. When supplied, \code{rvar} is its first
#'   variable and \code{evar} the remaining ones.
#' @param data_filter,arr,rows Passed to \code{get_data()} as \code{filt},
#'   \code{arr} and \code{rows}.
#' @param envir Passed to \code{get_data()}.
#'
#' @return A list of the function's local variables (inputs, the prepared
#'   \code{dataset}, and the fitted e1071 model in \code{model}) tagged via
#'   \code{add_class()} with \code{c("svm", "model")}; or a character message
#'   tagged with \code{"svm"} when \code{rvar} also appears in \code{evar}.
#'
#' @export
svm <- function(dataset, rvar, evar, type = "classification", lev = "",
                kernel = "radial", cost = 1, gamma = "auto", degree = 3,
                coef0 = 0, nu = 0.5, epsilon = 0.1, probability = FALSE,
                wts = "None", seed = 1234, check = NULL, form,
                data_filter = "", arr = "", rows = NULL,
                envir = parent.frame()) {
  # Formula entry point: response and predictors come from the formula
  if (!missing(form)) {
    form <- as.formula(format(form))
    vars <- all.vars(form)
    rvar <- vars[1]
    evar <- vars[-1]
  }

  # Basic checks
  if (rvar %in% evar) {
    return(add_class(
      "Response variable contained in explanatory variables", "svm"
    ))
  }

  vars <- c(rvar, evar)
  if (is.empty(wts, "None")) {
    wts <- NULL
  } else {
    vars <- c(vars, wts)
  }

  # Data extraction; the name is captured before `dataset` is overwritten
  df_name <- if (is_string(dataset)) dataset else deparse(substitute(dataset))
  dataset <- get_data(dataset, vars, filt = data_filter, arr = arr,
                      rows = rows, envir = envir)

  if (is.empty(wts)) {
    wts_vec <- NULL
  } else {
    wts_vec <- dataset[[wts]]
    dataset <- dataset[setdiff(colnames(dataset), wts)]
    # e1071::svm() only supports class.weights; a `weights` argument would be
    # silently swallowed by `...`, so say so instead of pretending to use it.
    warning(
      "e1071::svm() has no case-weight argument; 'wts' is only used for ",
      "standardization and is ignored when fitting the model.",
      call. = FALSE
    )
  }

  rv <- dataset[[rvar]]
  if (type == "classification") {
    if (is.empty(lev)) lev <- levels(as.factor(rv))[1]
    dataset[[rvar]] <- factor(dataset[[rvar]] == lev, levels = c(TRUE, FALSE))
  }

  # Optional standardization
  if ("standardize" %in% check) dataset <- scale_df(dataset, wts = wts_vec)

  # Default formula: response against every remaining column; backticks keep
  # non-syntactic column names parseable
  if (missing(form)) form <- as.formula(paste0("`", rvar, "` ~ ."))

  seed <- gsub("[^0-9]", "", seed)
  if (!is.empty(seed)) set.seed(as.integer(seed))

  # "auto" gamma = 1 / number of explanatory columns, counted with setdiff()
  # rather than tidyselect so a string-valued `rvar` cannot be misresolved
  gamma_val <- if (identical(gamma, "auto")) {
    1 / length(setdiff(colnames(dataset), rvar))
  } else {
    as.numeric(gamma)
  }
  svm_type <- if (type == "classification") "C-classification" else "eps-regression"

  # Direct call (not do.call) so model$call stays compact instead of
  # embedding the full data frame
  model <- e1071::svm(
    form,
    data = dataset,
    type = svm_type,
    kernel = kernel,
    cost = cost,
    gamma = gamma_val,
    degree = degree,
    coef0 = coef0,
    nu = nu,
    epsilon = epsilon,
    probability = probability,
    fitted = TRUE
  )

  # Package every local (inputs, data, model, df_name, check, ...) for the
  # summary/predict methods; `lev` is only meaningful for classification
  out <- as.list(environment())
  out$lev <- if (type == "classification") lev else NULL
  add_class(out, c("svm", "model"))
}

#' Summary method for objects returned by svm()
#'
#' Prints the data source, response, explanatory variables, kernel settings,
#' and the number of support vectors.
#'
#' @param object Object returned by \code{svm()}.
#' @param ... Unused.
#'
#' @return \code{object}, invisibly. A character message returned by
#'   \code{svm()} is passed back unchanged.
#'
#' @export
summary.svm <- function(object, ...) {
  if (is.character(object)) return(object)

  fit <- object$model
  # e1071 stores the kernel as an integer code 0-3 in this order
  kernel_name <- c("linear", "polynomial", "radial", "sigmoid")[fit$kernel + 1L]

  cat("Support Vector Machine\n")
  cat("Data :", object$df_name, "\n")
  if (!is.empty(object$data_filter)) cat("Filter :", object$data_filter, "\n")
  cat("Response :", object$rvar, "\n")
  if (object$type == "classification") cat("Level :", object$lev, "\n")
  cat("Variables :", paste(object$evar, collapse = ", "), "\n")
  cat("Kernel :", kernel_name, "\n")
  cat("Cost (C) :", fit$cost, "\n")
  if (kernel_name != "linear") cat("Gamma :", fit$gamma, "\n")
  # SV has one row per support vector; length() would count every cell
  cat("Support vectors :", nrow(fit$SV), "\n")
  invisible(object)
}

#' Predict method for objects returned by svm()
#'
#' @param object Object returned by \code{svm()}.
#' @param pred_data Data to predict for; passed to \code{predict_model()}.
#' @param pred_cmd Prediction command; passed to \code{predict_model()}.
#' @param dec Passed to \code{predict_model()}.
#' @param envir Passed to \code{predict_model()}.
#' @param ... Unused.
#'
#' @return The result of \code{predict_model()}, called with
#'   \code{"svm.predict"} as its class argument. For classification models fit
#'   with \code{probability = TRUE} the predictions are the probabilities of
#'   the \code{TRUE} level; otherwise they are predicted classes or values. A
#'   character message returned by \code{svm()} is passed back unchanged.
#'
#' @export
predict.svm <- function(object, pred_data = NULL, pred_cmd = "", dec = 3,
                        envir = parent.frame(), ...) {
  if (is.character(object)) return(object)

  # This file defines its own predict.svm, and the e1071 fit also has class
  # "svm", so a plain predict(model, ...) can dispatch back here. Resolve
  # e1071's method explicitly.
  e1071_predict <- utils::getS3method("predict", "svm",
                                      envir = asNamespace("e1071"))
  is_classification <- object$type == "classification"

  pfun <- function(model, newdata, ...) {
    # e1071 records whether probabilities were fit in `compprob`
    if (is_classification && isTRUE(model$compprob)) {
      pred <- e1071_predict(model, newdata, probability = TRUE)
      # Probability columns follow e1071's internal label order, so select
      # the TRUE level by name rather than by position
      attr(pred, "probabilities")[, "TRUE"]
    } else {
      # Predictions are a plain vector or factor, not a matrix
      as.vector(e1071_predict(model, newdata))
    }
  }

  predict_model(object, pfun, "svm.predict", pred_data, pred_cmd,
                dec = dec, envir = envir)
}

#' Print method for predictions from predict.svm()
#'
#' @param x Object returned by \code{predict.svm()}.
#' @param ... Passed to \code{print_predict_model()}.
#' @param n Passed to \code{print_predict_model()} (presumably the number of
#'   rows shown; confirm against that helper).
#'
#' @export
print.svm.predict <- function(x, ..., n = 10) {
  print_predict_model(x, ..., n = n, header = "SVM")
}