From c56c5069794efb69ee00d488feac4dd5c1d6279c Mon Sep 17 00:00:00 2001 From: wuzekai <3025054974@qq.com> Date: Fri, 21 Nov 2025 16:07:59 +0800 Subject: [PATCH] =?UTF-8?q?=E4=BF=AE=E5=A4=8D=E4=BA=86=E7=A5=9E=E7=BB=8F?= =?UTF-8?q?=E7=BD=91=E7=BB=9C=E7=95=8C=E9=9D=A2=E7=9A=84bug?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- radiant.model/NAMESPACE | 1 + radiant.model/R/nn.R | 1452 ++++++++-------- radiant.model/R/svm.R | 4 +- radiant.model/inst/app/tools/analysis/nn_ui.R | 1458 ++++++++--------- radiant.model/man/scale_df.Rd | 6 +- radiant.model/man/scale_sv.Rd | 11 + 6 files changed, 1478 insertions(+), 1454 deletions(-) create mode 100644 radiant.model/man/scale_sv.Rd diff --git a/radiant.model/NAMESPACE b/radiant.model/NAMESPACE index a506c51..8bfb37f 100644 --- a/radiant.model/NAMESPACE +++ b/radiant.model/NAMESPACE @@ -112,6 +112,7 @@ export(repeater) export(rforest) export(rig) export(scale_df) +export(scale_sv) export(sdw) export(sensitivity) export(sim_cleaner) diff --git a/radiant.model/R/nn.R b/radiant.model/R/nn.R index b708440..5c72560 100644 --- a/radiant.model/R/nn.R +++ b/radiant.model/R/nn.R @@ -1,718 +1,734 @@ -#' Neural Networks using nnet -#' -#' @details See \url{https://radiant-rstats.github.io/docs/model/nn.html} for an example in Radiant -#' -#' @param dataset Dataset -#' @param rvar The response variable in the model -#' @param evar Explanatory variables in the model -#' @param type Model type (i.e., "classification" or "regression") -#' @param lev The level in the response variable defined as _success_ -#' @param size Number of units (nodes) in the hidden layer -#' @param decay Parameter decay -#' @param wts Weights to use in estimation -#' @param seed Random seed to use as the starting point -#' @param check Optional estimation parameters ("standardize" is the default) -#' @param form Optional formula to use instead of rvar and evar -#' @param data_filter Expression entered in, e.g., Data > View to filter the dataset in Radiant. The expression should be a string (e.g., "price > 10000") -#' @param arr Expression to arrange (sort) the data on (e.g., "color, desc(price)") -#' @param rows Rows to select from the specified dataset -#' @param envir Environment to extract data from -#' -#' @return A list with all variables defined in nn as an object of class nn -#' -#' @examples -#' nn(titanic, "survived", c("pclass", "sex"), lev = "Yes") %>% summary() -#' nn(titanic, "survived", c("pclass", "sex")) %>% str() -#' nn(diamonds, "price", c("carat", "clarity"), type = "regression") %>% summary() -#' @seealso \code{\link{summary.nn}} to summarize results -#' @seealso \code{\link{plot.nn}} to plot results -#' @seealso \code{\link{predict.nn}} for prediction -#' -#' @importFrom nnet nnet -#' -#' @export -nn <- function(dataset, rvar, evar, - type = "classification", lev = "", - size = 1, decay = .5, wts = "None", - seed = NA, check = "standardize", - form, data_filter = "", arr = "", - rows = NULL, envir = parent.frame()) { - if (!missing(form)) { - form <- as.formula(format(form)) - paste0(format(form), collapse = "") - - vars <- all.vars(form) - rvar <- vars[1] - evar <- vars[-1] - } - - if (rvar %in% evar) { - return("Response variable contained in the set of explanatory variables.\nPlease update model specification." %>% - add_class("nn")) - } else if (is.empty(size) || size < 1) { - return("Size should be larger than or equal to 1." %>% add_class("nn")) - } else if (is.empty(decay) || decay < 0) { - return("Decay should be larger than or equal to 0." %>% add_class("nn")) - } - - vars <- c(rvar, evar) - - if (is.empty(wts, "None")) { - wts <- NULL - } else if (is_string(wts)) { - wtsname <- wts - vars <- c(rvar, evar, wtsname) - } - - df_name <- if (is_string(dataset)) dataset else deparse(substitute(dataset)) - dataset <- get_data(dataset, vars, filt = data_filter, arr = arr, rows = rows, envir = envir) - - if (!is.empty(wts)) { - if (exists("wtsname")) { - wts <- dataset[[wtsname]] - dataset <- select_at(dataset, .vars = base::setdiff(colnames(dataset), wtsname)) - } - if (length(wts) != nrow(dataset)) { - return( - paste0("Length of the weights variable is not equal to the number of rows in the dataset (", format_nr(length(wts), dec = 0), " vs ", format_nr(nrow(dataset), dec = 0), ")") %>% - add_class("nn") - ) - } - } - - not_vary <- colnames(dataset)[summarise_all(dataset, does_vary) == FALSE] - if (length(not_vary) > 0) { - return(paste0("The following variable(s) show no variation. Please select other variables.\n\n** ", paste0(not_vary, collapse = ", "), " **") %>% - add_class("nn")) - } - - rv <- dataset[[rvar]] - - if (type == "classification") { - linout <- FALSE - entropy <- TRUE - if (lev == "") { - if (is.factor(rv)) { - lev <- levels(rv)[1] - } else { - lev <- as.character(rv) %>% - as.factor() %>% - levels() %>% - .[1] - } - } - - ## transformation to TRUE/FALSE depending on the selected level (lev) - dataset[[rvar]] <- dataset[[rvar]] == lev - } else { - linout <- TRUE - entropy <- FALSE - } - - ## standardize data to limit stability issues ... - # http://stats.stackexchange.com/questions/23235/how-do-i-improve-my-neural-network-stability - if ("standardize" %in% check) { - dataset <- scale_df(dataset, wts = wts) - } - - vars <- evar - ## in case : is used - if (length(vars) < (ncol(dataset) - 1)) { - vars <- evar <- colnames(dataset)[-1] - } - - if (missing(form)) form <- as.formula(paste(rvar, "~ . ")) - - ## use decay http://stats.stackexchange.com/a/70146/61693 - nninput <- list( - formula = form, - rang = .1, size = size, decay = decay, weights = wts, - maxit = 10000, linout = linout, entropy = entropy, - skip = FALSE, trace = FALSE, data = dataset - ) - - ## based on https://stackoverflow.com/a/14324316/1974918 - seed <- gsub("[^0-9]", "", seed) - if (!is.empty(seed)) { - if (exists(".Random.seed")) { - gseed <- .Random.seed - on.exit(.Random.seed <<- gseed) - } - set.seed(seed) - } - - ## need do.call so Garson/Olden plot will work - model <- do.call(nnet::nnet, nninput) - coefnames <- model$coefnames - hasLevs <- sapply(select(dataset, -1), function(x) is.factor(x) || is.logical(x) || is.character(x)) - if (sum(hasLevs) > 0) { - for (i in names(hasLevs[hasLevs])) { - coefnames %<>% gsub(paste0("^", i), paste0(i, "|"), .) %>% - gsub(paste0(":", i), paste0(":", i, "|"), .) - } - rm(i, hasLevs) - } - - ## nn returns residuals as a matrix - model$residuals <- model$residuals[, 1] - - ## nn model object does not include the data by default - model$model <- dataset - rm(dataset, envir) ## dataset not needed elsewhere - - as.list(environment()) %>% add_class(c("nn", "model")) -} - -#' Center or standardize variables in a data frame -#' -#' @param dataset Data frame -#' @param center Center data (TRUE or FALSE) -#' @param scale Scale data (TRUE or FALSE) -#' @param sf Scaling factor (default is 2) -#' @param wts Weights to use (default is NULL for no weights) -#' @param calc Calculate mean and sd or use attributes attached to dat -#' -#' @return Scaled data frame -#' -#' @export -scale_df <- function(dataset, center = TRUE, scale = TRUE, - sf = 2, wts = NULL, calc = TRUE) { - isNum <- sapply(dataset, function(x) is.numeric(x)) - if (length(isNum) == 0 || sum(isNum) == 0) { - return(dataset) - } - cn <- names(isNum)[isNum] - - ## remove set_attr calls when dplyr removes and keep attributes appropriately - descr <- attr(dataset, "description") - if (calc) { - if (length(wts) == 0) { - ms <- summarise_at(dataset, .vars = cn, .funs = ~ mean(., na.rm = TRUE)) %>% - set_attr("description", NULL) - if (scale) { - sds <- summarise_at(dataset, .vars = cn, .funs = ~ sd(., na.rm = TRUE)) %>% - set_attr("description", NULL) - } - } else { - ms <- summarise_at(dataset, .vars = cn, .funs = ~ weighted.mean(., wts, na.rm = TRUE)) %>% - set_attr("description", NULL) - if (scale) { - sds <- summarise_at(dataset, .vars = cn, .funs = ~ weighted.sd(., wts, na.rm = TRUE)) %>% - set_attr("description", NULL) - } - } - } else { - ms <- attr(dataset, "radiant_ms") - sds <- attr(dataset, "radiant_sds") - if (is.null(ms) && is.null(sds)) { - return(dataset) - } - } - if (center && scale) { - icn <- intersect(names(ms), cn) - dataset[icn] <- lapply(icn, function(var) (dataset[[var]] - ms[[var]]) / (sf * sds[[var]])) - dataset %>% - set_attr("radiant_ms", ms) %>% - set_attr("radiant_sds", sds) %>% - set_attr("radiant_sf", sf) %>% - set_attr("description", descr) - } else if (center) { - icn <- intersect(names(ms), cn) - dataset[icn] <- lapply(icn, function(var) dataset[[var]] - ms[[var]]) - dataset %>% - set_attr("radiant_ms", ms) %>% - set_attr("description", descr) - } else if (scale) { - icn <- intersect(names(sds), cn) - dataset[icn] <- lapply(icn, function(var) dataset[[var]] / (sf * sds[[var]])) - set_attr("radiant_sds", sds) %>% - set_attr("radiant_sf", sf) %>% - set_attr("description", descr) - } else { - dataset - } -} - -#' Summary method for the nn function -#' -#' @details See \url{https://radiant-rstats.github.io/docs/model/nn.html} for an example in Radiant -#' -#' @param object Return value from \code{\link{nn}} -#' @param prn Print list of weights -#' @param ... further arguments passed to or from other methods -#' -#' @examples -#' result <- nn(titanic, "survived", "pclass", lev = "Yes") -#' summary(result) -#' @seealso \code{\link{nn}} to generate results -#' @seealso \code{\link{plot.nn}} to plot results -#' @seealso \code{\link{predict.nn}} for prediction -#' -#' @export -summary.nn <- function(object, prn = TRUE, ...) { - if (is.character(object)) { - return(object) - } - cat("Neural Network\n") - if (object$type == "classification") { - cat("Activation function : Logistic (classification)") - } else { - cat("Activation function : Linear (regression)") - } - cat("\nData :", object$df_name) - if (!is.empty(object$data_filter)) { - cat("\nFilter :", gsub("\\n", "", object$data_filter)) - } - if (!is.empty(object$arr)) { - cat("\nArrange :", gsub("\\n", "", object$arr)) - } - if (!is.empty(object$rows)) { - cat("\nSlice :", gsub("\\n", "", object$rows)) - } - cat("\nResponse variable :", object$rvar) - if (object$type == "classification") { - cat("\nLevel :", object$lev, "in", object$rvar) - } - cat("\nExplanatory variables:", paste0(object$evar, collapse = ", "), "\n") - if (length(object$wtsname) > 0) { - cat("Weights used :", object$wtsname, "\n") - } - cat("Network size :", object$size, "\n") - cat("Parameter decay :", object$decay, "\n") - if (!is.empty(object$seed)) { - cat("Seed :", object$seed, "\n") - } - - network <- paste0(object$model$n, collapse = "-") - nweights <- length(object$model$wts) - cat("Network :", network, "with", nweights, "weights\n") - - if (!is.empty(object$wts, "None") && (length(unique(object$wts)) > 2 || min(object$wts) >= 1)) { - cat("Nr obs :", format_nr(sum(object$wts), dec = 0), "\n") - } else { - cat("Nr obs :", format_nr(length(object$rv), dec = 0), "\n") - } - - if (object$model$convergence != 0) { - cat("\n** The model did not converge **") - } else { - if (prn) { - cat("Weights :\n") - oop <- base::options(width = 100) - on.exit(base::options(oop), add = TRUE) - capture.output(summary(object$model))[-1:-2] %>% - gsub("^", " ", .) %>% - paste0(collapse = "\n") %>% - cat("\n") - } - } -} - -#' Variable importance using the vip package and permutation importance -#' -#' @param object Model object created by Radiant -#' @param rvar Label to identify the response or target variable -#' @param lev Reference class for binary classifier (rvar) -#' @param data Data to use for prediction. Will default to the data used to estimate the model -#' @param seed Random seed for reproducibility -#' -#' @importFrom vip vi -#' -#' @export -varimp <- function(object, rvar, lev, data = NULL, seed = 1234) { - if (is.null(data)) data <- object$model$model - - # needed to avoid rescaling during prediction - object$check <- setdiff(object$check, c("center", "standardize")) - - arg_list <- list(object, pred_data = data, se = FALSE) - if (missing(rvar)) rvar <- object$rvar - if (missing(lev) && object$type == "classification") { - if (!is.empty(object$lev)) { - lev <- object$lev - } - if (!is.logical(data[[rvar]])) { - # don't change if already logical - data[[rvar]] <- data[[rvar]] == lev - } - } else if (object$type == "classification") { - data[[rvar]] <- data[[rvar]] == lev - } - - fun <- function(object, arg_list) do.call(predict, arg_list)[["Prediction"]] - if (inherits(object, "rforest")) { - arg_list$OOB <- FALSE # all 0 importance scores when using OOB - if (object$type == "classification") { - fun <- function(object, arg_list) do.call(predict, arg_list)[[object$lev]] - } - } - - pred_fun <- function(object, newdata) { - arg_list$pred_data <- newdata - fun(object, arg_list) - } - - set.seed(seed) - if (object$type == "regression") { - vimp <- vip::vi( - object, - target = rvar, - method = "permute", - metric = "rsq", # "rmse" - pred_wrapper = pred_fun, - train = data - ) - } else { - # required after transition to yardstick by the vip package - data[[rvar]] <- factor(data[[rvar]], levels = c("TRUE", "FALSE")) - vimp <- vip::vi( - object, - target = rvar, - event_level = "first", - method = "permute", - metric = "roc_auc", - pred_wrapper = pred_fun, - train = data - ) - } - - vimp %>% - filter(Importance != 0) %>% - mutate(Variable = factor(Variable, levels = rev(Variable))) -} - -#' Plot permutation importance -#' -#' @param object Model object created by Radiant -#' @param rvar Label to identify the response or target variable -#' @param lev Reference class for binary classifier (rvar) -#' @param data Data to use for prediction. Will default to the data used to estimate the model -#' @param seed Random seed for reproducibility -#' -#' @importFrom vip vi -#' -#' @export -varimp_plot <- function(object, rvar, lev, data = NULL, seed = 1234) { - vi_scores <- varimp(object, rvar, lev, data = data, seed = seed) - visualize(vi_scores, yvar = "Importance", xvar = "Variable", type = "bar", custom = TRUE) + - labs( - title = "Permutation Importance", - x = NULL, - y = ifelse(object$type == "regression", "Importance (R-square decrease)", "Importance (AUC decrease)") - ) + - coord_flip() + - theme(axis.text.y = element_text(hjust = 0)) -} - -#' Plot method for the nn function -#' -#' @details See \url{https://radiant-rstats.github.io/docs/model/nn.html} for an example in Radiant -#' -#' @param x Return value from \code{\link{nn}} -#' @param plots Plots to produce for the specified Neural Network model. Use "" to avoid showing any plots (default). Options are "olden" or "garson" for importance plots, or "net" to depict the network structure -#' @param size Font size used -#' @param pad_x Padding for explanatory variable labels in the network plot. Default value is 0.9, smaller numbers (e.g., 0.5) increase the amount of padding -#' @param nrobs Number of data points to show in dashboard scatter plots (-1 for all) -#' @param incl Which variables to include in a coefficient plot or PDP plot -#' @param incl_int Which interactions to investigate in PDP plots -#' @param shiny Did the function call originate inside a shiny app -#' @param custom Logical (TRUE, FALSE) to indicate if ggplot object (or list of ggplot objects) should be returned. This option can be used to customize plots (e.g., add a title, change x and y labels, etc.). See examples and \url{https://ggplot2.tidyverse.org} for options. -#' @param ... further arguments passed to or from other methods -#' -#' @examples -#' result <- nn(titanic, "survived", c("pclass", "sex"), lev = "Yes") -#' plot(result, plots = "net") -#' plot(result, plots = "olden") -#' @seealso \code{\link{nn}} to generate results -#' @seealso \code{\link{summary.nn}} to summarize results -#' @seealso \code{\link{predict.nn}} for prediction -#' -#' @importFrom NeuralNetTools plotnet olden garson -#' @importFrom graphics par -#' -#' @export -plot.nn <- function(x, plots = "vip", size = 12, pad_x = 0.9, nrobs = -1, - incl = NULL, incl_int = NULL, - shiny = FALSE, custom = FALSE, ...) { - if (is.character(x) || !inherits(x$model, "nnet")) { - return(x) - } - plot_list <- list() - nrCol <- 1 - - if ("olden" %in% plots || "olsen" %in% plots) { ## legacy for typo - plot_list[["olsen"]] <- NeuralNetTools::olden(x$model, x_lab = x$coefnames, cex_val = 4) + - coord_flip() + - theme_set(theme_gray(base_size = size)) + - theme(legend.position = "none") + - labs(title = paste0("Olden plot of variable importance (size = ", x$size, ", decay = ", x$decay, ")")) - } - - if ("garson" %in% plots) { - plot_list[["garson"]] <- NeuralNetTools::garson(x$model, x_lab = x$coefnames) + - coord_flip() + - theme_set(theme_gray(base_size = size)) + - theme(legend.position = "none") + - labs(title = paste0("Garson plot of variable importance (size = ", x$size, ", decay = ", x$decay, ")")) - } - - if ("vip" %in% plots) { - vi_scores <- varimp(x) - plot_list[["vip"]] <- - visualize(vi_scores, yvar = "Importance", xvar = "Variable", type = "bar", custom = TRUE) + - labs( - title = paste0("Permutation Importance (size = ", x$size, ", decay = ", x$decay, ")"), - x = NULL, - y = ifelse(x$type == "regression", "Importance (R-square decrease)", "Importance (AUC decrease)") - ) + - coord_flip() + - theme(axis.text.y = element_text(hjust = 0)) - } - - if ("net" %in% plots) { - ## don't need as much spacing at the top and bottom - mar <- par(mar = c(0, 4.1, 0, 2.1)) - on.exit(par(mar = mar$mar)) - return(do.call(NeuralNetTools::plotnet, list(mod_in = x$model, x_names = x$coefnames, pad_x = pad_x, cex_val = size / 16))) - } - - if ("pred_plot" %in% plots) { - nrCol <- 2 - if (length(incl) > 0 | length(incl_int) > 0) { - plot_list <- pred_plot(x, plot_list, incl, incl_int, ...) - } else { - return("Select one or more variables to generate Prediction plots") - } - } - - if ("pdp" %in% plots) { - nrCol <- 2 - if (length(incl) > 0 || length(incl_int) > 0) { - plot_list <- pdp_plot(x, plot_list, incl, incl_int, ...) - } else { - return("Select one or more variables to generate Partial Dependence Plots") - } - } - - if (x$type == "regression" && "dashboard" %in% plots) { - plot_list <- plot.regress(x, plots = "dashboard", lines = "line", nrobs = nrobs, custom = TRUE) - nrCol <- 2 - } - - if (length(plot_list) > 0) { - if (custom) { - if (length(plot_list) == 1) plot_list[[1]] else plot_list - } else { - patchwork::wrap_plots(plot_list, ncol = nrCol) %>% - (function(x) if (isTRUE(shiny)) x else print(x)) - } - } -} - -#' Predict method for the nn function -#' -#' @details See \url{https://radiant-rstats.github.io/docs/model/nn.html} for an example in Radiant -#' -#' @param object Return value from \code{\link{nn}} -#' @param pred_data Provide the dataframe to generate predictions (e.g., diamonds). The dataset must contain all columns used in the estimation -#' @param pred_cmd Generate predictions using a command. For example, `pclass = levels(pclass)` would produce predictions for the different levels of factor `pclass`. To add another variable, create a vector of prediction strings, (e.g., c('pclass = levels(pclass)', 'age = seq(0,100,20)') -#' @param dec Number of decimals to show -#' @param envir Environment to extract data from -#' @param ... further arguments passed to or from other methods -#' -#' @examples -#' result <- nn(titanic, "survived", c("pclass", "sex"), lev = "Yes") -#' predict(result, pred_cmd = "pclass = levels(pclass)") -#' result <- nn(diamonds, "price", "carat:color", type = "regression") -#' predict(result, pred_cmd = "carat = 1:3") -#' predict(result, pred_data = diamonds) %>% head() -#' @seealso \code{\link{nn}} to generate the result -#' @seealso \code{\link{summary.nn}} to summarize results -#' -#' @export -predict.nn <- function(object, pred_data = NULL, pred_cmd = "", - dec = 3, envir = parent.frame(), ...) { - if (is.character(object)) { - return(object) - } - - ## ensure you have a name for the prediction dataset - if (is.data.frame(pred_data)) { - df_name <- deparse(substitute(pred_data)) - } else { - df_name <- pred_data - } - - pfun <- function(model, pred, se, conf_lev) { - pred_val <- try(sshhr(predict(model, pred)), silent = TRUE) - - if (!inherits(pred_val, "try-error")) { - pred_val %<>% as.data.frame(stringsAsFactors = FALSE) %>% - select(1) %>% - set_colnames("Prediction") - } - - pred_val - } - - predict_model(object, pfun, "nn.predict", pred_data, pred_cmd, conf_lev = 0.95, se = FALSE, dec, envir = envir) %>% - set_attr("radiant_pred_data", df_name) -} - -#' Print method for predict.nn -#' -#' @param x Return value from prediction method -#' @param ... further arguments passed to or from other methods -#' @param n Number of lines of prediction results to print. Use -1 to print all lines -#' -#' @export -print.nn.predict <- function(x, ..., n = 10) { - print_predict_model(x, ..., n = n, header = "Neural Network") -} - -#' Cross-validation for a Neural Network -#' -#' @details See \url{https://radiant-rstats.github.io/docs/model/nn.html} for an example in Radiant -#' -#' @param object Object of type "nn" or "nnet" -#' @param K Number of cross validation passes to use -#' @param repeats Repeated cross validation -#' @param size Number of units (nodes) in the hidden layer -#' @param decay Parameter decay -#' @param seed Random seed to use as the starting point -#' @param trace Print progress -#' @param fun Function to use for model evaluation (i.e., auc for classification and RMSE for regression) -#' @param ... Additional arguments to be passed to 'fun' -#' -#' @return A data.frame sorted by the mean of the performance metric -#' -#' @seealso \code{\link{nn}} to generate an initial model that can be passed to cv.nn -#' @seealso \code{\link{Rsq}} to calculate an R-squared measure for a regression -#' @seealso \code{\link{RMSE}} to calculate the Root Mean Squared Error for a regression -#' @seealso \code{\link{MAE}} to calculate the Mean Absolute Error for a regression -#' @seealso \code{\link{auc}} to calculate the area under the ROC curve for classification -#' @seealso \code{\link{profit}} to calculate profits for classification at a cost/margin threshold -#' -#' @importFrom nnet nnet.formula -#' @importFrom shiny getDefaultReactiveDomain withProgress incProgress -#' -#' @examples -#' \dontrun{ -#' result <- nn(dvd, "buy", c("coupon", "purch", "last")) -#' cv.nn(result, decay = seq(0, 1, .5), size = 1:2) -#' cv.nn(result, decay = seq(0, 1, .5), size = 1:2, fun = profit, cost = 1, margin = 5) -#' result <- nn(diamonds, "price", c("carat", "color", "clarity"), type = "regression") -#' cv.nn(result, decay = seq(0, 1, .5), size = 1:2) -#' cv.nn(result, decay = seq(0, 1, .5), size = 1:2, fun = Rsq) -#' } -#' -#' @export -cv.nn <- function(object, K = 5, repeats = 1, decay = seq(0, 1, .2), size = 1:5, - seed = 1234, trace = TRUE, fun, ...) { - if (inherits(object, "nn")) { - ms <- attr(object$model$model, "radiant_ms")[[object$rvar]] - sds <- attr(object$model$model, "radiant_sds")[[object$rvar]] - if (length(sds) == 0) { - sds <- sf <- 1 - } else { - sf <- attr(object$model$model, "radiant_sf") - sf <- ifelse(length(sf) == 0, 2, sf) - } - object <- object$model - } else { - ms <- 0 - sds <- 1 - sf <- 1 - } - - if (inherits(object, "nnet")) { - dv <- as.character(object$call$formula[[2]]) - m <- eval(object$call[["data"]]) - weights <- eval(object$call[["weights"]]) - if (is.numeric(m[[dv]])) { - type <- "regression" - } else { - type <- "classification" - if (is.factor(m[[dv]])) { - lev <- levels(m[[dv]])[1] - } else if (is.logical(m[[dv]])) { - lev <- TRUE - } else { - stop("The level to use for classification is not clear. Use a factor of logical as the response variable") - } - } - } else { - stop("The model object does not seems to be a neural network") - } - - set.seed(seed) - tune_grid <- expand.grid(decay = decay, size = size) - out <- data.frame(mean = NA, std = NA, min = NA, max = NA, decay = tune_grid[["decay"]], size = tune_grid[["size"]]) - - if (missing(fun)) { - if (type == "classification") { - fun <- radiant.model::auc - cn <- "AUC (mean)" - } else { - fun <- radiant.model::RMSE - cn <- "RMSE (mean)" - } - } else { - cn <- glue("{deparse(substitute(fun))} (mean)") - } - - if (length(shiny::getDefaultReactiveDomain()) > 0) { - trace <- FALSE - incProgress <- shiny::incProgress - withProgress <- shiny::withProgress - } else { - incProgress <- function(...) {} - withProgress <- function(...) list(...)[["expr"]] - } - - nitt <- nrow(tune_grid) - withProgress(message = "Running cross-validation (nn)", value = 0, { - for (i in seq_len(nitt)) { - perf <- double(K * repeats) - object$call[["decay"]] <- tune_grid[i, "decay"] - object$call[["size"]] <- tune_grid[i, "size"] - if (trace) cat("Working on size", tune_grid[i, "size"], "decay", tune_grid[i, "decay"], "\n") - for (j in seq_len(repeats)) { - rand <- sample(K, nrow(m), replace = TRUE) - for (k in seq_len(K)) { - object$call[["data"]] <- quote(m[rand != k, , drop = FALSE]) - if (length(weights) > 0) { - object$call[["weights"]] <- weights[rand != k] - } - pred <- predict(eval(object$call), newdata = m[rand == k, , drop = FALSE])[, 1] - if (type == "classification") { - if (missing(...)) { - perf[k + (j - 1) * K] <- fun(pred, unlist(m[rand == k, dv]), lev) - } else { - perf[k + (j - 1) * K] <- fun(pred, unlist(m[rand == k, dv]), lev, ...) - } - } else { - pred <- pred * sf * sds + ms - rvar <- unlist(m[rand == k, dv]) * sf * sds + ms - if (missing(...)) { - perf[k + (j - 1) * K] <- fun(pred, rvar) - } else { - perf[k + (j - 1) * K] <- fun(pred, rvar, ...) - } - } - } - } - out[i, 1:4] <- c(mean(perf), sd(perf), min(perf), max(perf)) - incProgress(1 / nitt, detail = paste("\nCompleted run", i, "out of", nitt)) - } - }) - - if (type == "classification") { - out <- arrange(out, desc(mean)) - } else { - out <- arrange(out, mean) - } - ## show evaluation metric in column name - colnames(out)[1] <- cn - out -} +#' Neural Networks using nnet +#' +#' @details See \url{https://radiant-rstats.github.io/docs/model/nn.html} for an example in Radiant +#' +#' @param dataset Dataset +#' @param rvar The response variable in the model +#' @param evar Explanatory variables in the model +#' @param type Model type (i.e., "classification" or "regression") +#' @param lev The level in the response variable defined as _success_ +#' @param size Number of units (nodes) in the hidden layer +#' @param decay Parameter decay +#' @param wts Weights to use in estimation +#' @param seed Random seed to use as the starting point +#' @param check Optional estimation parameters ("standardize" is the default) +#' @param form Optional formula to use instead of rvar and evar +#' @param data_filter Expression entered in, e.g., Data > View to filter the dataset in Radiant. The expression should be a string (e.g., "price > 10000") +#' @param arr Expression to arrange (sort) the data on (e.g., "color, desc(price)") +#' @param rows Rows to select from the specified dataset +#' @param envir Environment to extract data from +#' +#' @return A list with all variables defined in nn as an object of class nn +#' +#' @examples +#' nn(titanic, "survived", c("pclass", "sex"), lev = "Yes") %>% summary() +#' nn(titanic, "survived", c("pclass", "sex")) %>% str() +#' nn(diamonds, "price", c("carat", "clarity"), type = "regression") %>% summary() +#' @seealso \code{\link{summary.nn}} to summarize results +#' @seealso \code{\link{plot.nn}} to plot results +#' @seealso \code{\link{predict.nn}} for prediction +#' +#' @importFrom nnet nnet +#' +#' @export +nn <- function(dataset, rvar, evar, + type = "classification", lev = "", + size = 1, decay = .5, wts = "None", + seed = NA, check = "standardize", + form, data_filter = "", arr = "", + rows = NULL, envir = parent.frame()) { + if (!missing(form)) { + form <- as.formula(format(form)) + paste0(format(form), collapse = "") + + vars <- all.vars(form) + rvar <- vars[1] + evar <- vars[-1] + } + + if (rvar %in% evar) { + return("Response variable contained in the set of explanatory variables.\nPlease update model specification." %>% + add_class("nn")) + } else if (is.empty(size) || size < 1) { + return("Size should be larger than or equal to 1." %>% add_class("nn")) + } else if (is.empty(decay) || decay < 0) { + return("Decay should be larger than or equal to 0." %>% add_class("nn")) + } + + vars <- c(rvar, evar) + + if (is.empty(wts, "None")) { + wts <- NULL + } else if (is_string(wts)) { + wtsname <- wts + vars <- c(rvar, evar, wtsname) + } + + df_name <- if (is_string(dataset)) dataset else deparse(substitute(dataset)) + dataset <- get_data(dataset, vars, filt = data_filter, arr = arr, rows = rows, envir = envir) + + if (!is.empty(wts)) { + if (exists("wtsname")) { + wts <- dataset[[wtsname]] + dataset <- select_at(dataset, .vars = base::setdiff(colnames(dataset), wtsname)) + } + if (length(wts) != nrow(dataset)) { + return( + paste0("Length of the weights variable is not equal to the number of rows in the dataset (", format_nr(length(wts), dec = 0), " vs ", format_nr(nrow(dataset), dec = 0), ")") %>% + add_class("nn") + ) + } + } + + not_vary <- colnames(dataset)[summarise_all(dataset, does_vary) == FALSE] + if (length(not_vary) > 0) { + return(paste0("The following variable(s) show no variation. Please select other variables.\n\n** ", paste0(not_vary, collapse = ", "), " **") %>% + add_class("nn")) + } + + rv <- dataset[[rvar]] + + if (type == "classification") { + linout <- FALSE + entropy <- TRUE + if (lev == "") { + if (is.factor(rv)) { + lev <- levels(rv)[1] + } else { + lev <- as.character(rv) %>% + as.factor() %>% + levels() %>% + .[1] + } + } + + ## transformation to TRUE/FALSE depending on the selected level (lev) + dataset[[rvar]] <- dataset[[rvar]] == lev + } else { + linout <- TRUE + entropy <- FALSE + } + + ## standardize data to limit stability issues ... + # http://stats.stackexchange.com/questions/23235/how-do-i-improve-my-neural-network-stability + if ("standardize" %in% check) { + dataset <- scale_df(dataset, wts = wts) + } + + vars <- evar + ## in case : is used + if (length(vars) < (ncol(dataset) - 1)) { + vars <- evar <- colnames(dataset)[-1] + } + + if (missing(form)) form <- as.formula(paste(rvar, "~ . ")) + + ## use decay http://stats.stackexchange.com/a/70146/61693 + nninput <- list( + formula = form, + rang = .1, size = size, decay = decay, weights = wts, + maxit = 10000, linout = linout, entropy = entropy, + skip = FALSE, trace = FALSE, data = dataset + ) + + ## based on https://stackoverflow.com/a/14324316/1974918 + seed <- gsub("[^0-9]", "", seed) + if (!is.empty(seed)) { + if (exists(".Random.seed")) { + gseed <- .Random.seed + on.exit(.Random.seed <<- gseed) + } + set.seed(seed) + } + + ## need do.call so Garson/Olden plot will work + model <- do.call(nnet::nnet, nninput) + coefnames <- model$coefnames + hasLevs <- sapply(select(dataset, -1), function(x) is.factor(x) || is.logical(x) || is.character(x)) + if (sum(hasLevs) > 0) { + for (i in names(hasLevs[hasLevs])) { + coefnames %<>% gsub(paste0("^", i), paste0(i, "|"), .) %>% + gsub(paste0(":", i), paste0(":", i, "|"), .) + } + rm(i, hasLevs) + } + + ## nn returns residuals as a matrix + model$residuals <- model$residuals[, 1] + + ## nn model object does not include the data by default + model$model <- dataset + rm(dataset, envir) ## dataset not needed elsewhere + + as.list(environment()) %>% add_class(c("nn", "model")) +} + +#' Center or standardize variables in a data frame +#' +#' @param dataset Data frame +#' @param center Center data (TRUE or FALSE) +#' @param scale Scale data (TRUE or FALSE) +#' @param sf Scaling factor (default is 2) +#' @param wts Weights to use (default is NULL for no weights) +#' @param calc Calculate mean and sd or use attributes attached to dat +#' +#' @return Scaled data frame +#' +#' @export +scale_df <- function(dataset, center = TRUE, scale = TRUE, + sf = 2, wts = NULL, calc = TRUE) { + isNum <- sapply(dataset, function(x) is.numeric(x)) + if (length(isNum) == 0 || sum(isNum) == 0) { + return(dataset) + } + cn <- names(isNum)[isNum] + ## remove set_attr calls when dplyr removes and keep attributes appropriately + descr <- attr(dataset, "description") + if (calc) { + if (length(wts) == 0) { + ms <- summarise_at(dataset, .vars = cn, .funs = ~ mean(., na.rm = TRUE)) %>% + set_attr("description", NULL) + if (scale) { + sds <- summarise_at(dataset, .vars = cn, .funs = ~ sd(., na.rm = TRUE)) %>% + set_attr("description", NULL) + } + } else { + ms <- summarise_at(dataset, .vars = cn, .funs = ~ weighted.mean(., wts, na.rm = TRUE)) %>% + set_attr("description", NULL) + if (scale) { + sds <- summarise_at(dataset, .vars = cn, .funs = ~ weighted.sd(., wts, na.rm = TRUE)) %>% + set_attr("description", NULL) + } + } + } else { + ms <- attr(dataset, "radiant_ms") + sds <- attr(dataset, "radiant_sds") + if (is.null(ms) && is.null(sds)) { + return(dataset) + } + } + if (center && scale) { + icn <- intersect(names(ms), cn) + dataset[icn] <- lapply(icn, function(var) (dataset[[var]] - ms[[var]]) / (sf * sds[[var]])) + dataset %>% + set_attr("radiant_ms", ms) %>% + set_attr("radiant_sds", sds) %>% + set_attr("radiant_sf", sf) %>% + set_attr("description", descr) + } else if (center) { + icn <- intersect(names(ms), cn) + dataset[icn] <- lapply(icn, function(var) dataset[[var]] - ms[[var]]) + dataset %>% + set_attr("radiant_ms", ms) %>% + set_attr("description", descr) + } else if (scale) { + icn <- intersect(names(sds), cn) + dataset[icn] <- lapply(icn, function(var) dataset[[var]] / (sf * sds[[var]])) + dataset %>% + set_attr("radiant_sds", sds) %>% + set_attr("radiant_sf", sf) %>% + set_attr("description", descr) + } else { + dataset + } +} + +#' Summary method for the nn function +#' +#' @details See \url{https://radiant-rstats.github.io/docs/model/nn.html} for an example in Radiant +#' +#' @param object Return value from \code{\link{nn}} +#' @param prn Print list of weights +#' @param ... further arguments passed to or from other methods +#' +#' @examples +#' result <- nn(titanic, "survived", "pclass", lev = "Yes") +#' summary(result) +#' @seealso \code{\link{nn}} to generate results +#' @seealso \code{\link{plot.nn}} to plot results +#' @seealso \code{\link{predict.nn}} for prediction +#' +#' @export +summary.nn <- function(object, prn = TRUE, ...) { + if (is.character(object)) { + return(object) + } + cat("Neural Network\n") + if (object$type == "classification") { + cat("Activation function : Logistic (classification)") + } else { + cat("Activation function : Linear (regression)") + } + cat("\nData :", object$df_name) + if (!is.empty(object$data_filter)) { + cat("\nFilter :", gsub("\\n", "", object$data_filter)) + } + if (!is.empty(object$arr)) { + cat("\nArrange :", gsub("\\n", "", object$arr)) + } + if (!is.empty(object$rows)) { + cat("\nSlice :", gsub("\\n", "", object$rows)) + } + cat("\nResponse variable :", object$rvar) + if (object$type == "classification") { + cat("\nLevel :", object$lev, "in", object$rvar) + } + cat("\nExplanatory variables:", paste0(object$evar, collapse = ", "), "\n") + if (length(object$wtsname) > 0) { + cat("Weights used :", object$wtsname, "\n") + } + cat("Network size :", object$size, "\n") + cat("Parameter decay :", object$decay, "\n") + if (!is.empty(object$seed)) { + cat("Seed :", object$seed, "\n") + } + + network <- paste0(object$model$n, collapse = "-") + nweights <- length(object$model$wts) + cat("Network :", network, "with", nweights, "weights\n") + + if (!is.empty(object$wts, "None") && (length(unique(object$wts)) > 2 || min(object$wts) >= 1)) { + cat("Nr obs :", format_nr(sum(object$wts), dec = 0), "\n") + } else { + cat("Nr obs :", format_nr(length(object$rv), dec = 0), "\n") + } + + if (object$model$convergence != 0) { + cat("\n** The model did not converge **") + } else { + if (prn) { + cat("Weights :\n") + oop <- base::options(width = 100) + on.exit(base::options(oop), add = TRUE) + capture.output(summary(object$model))[-1:-2] %>% + gsub("^", " ", .) %>% + paste0(collapse = "\n") %>% + cat("\n") + } + } +} + +#' Variable importance using the vip package and permutation importance +#' +#' @param object Model object created by Radiant +#' @param rvar Label to identify the response or target variable +#' @param lev Reference class for binary classifier (rvar) +#' @param data Data to use for prediction. Will default to the data used to estimate the model +#' @param seed Random seed for reproducibility +#' +#' @importFrom vip vi +#' +#' @export +varimp <- function(object, rvar, lev, data = NULL, seed = 1234) { + if (is.null(data)) data <- object$model$model + + # needed to avoid rescaling during prediction + object$check <- setdiff(object$check, c("center", "standardize")) + + arg_list <- list(object, pred_data = data, se = FALSE) + if (missing(rvar)) rvar <- object$rvar + if (missing(lev) && object$type == "classification") { + if (!is.empty(object$lev)) { + lev <- object$lev + } + if (!is.logical(data[[rvar]])) { + # don't change if already logical + data[[rvar]] <- data[[rvar]] == lev + } + } else if (object$type == "classification") { + data[[rvar]] <- data[[rvar]] == lev + } + + fun <- function(object, arg_list) do.call(predict, arg_list)[["Prediction"]] + if (inherits(object, "rforest")) { + arg_list$OOB <- FALSE # all 0 importance scores when using OOB + if (object$type == "classification") { + fun <- function(object, arg_list) do.call(predict, arg_list)[[object$lev]] + } + } + + pred_fun <- function(object, newdata) { + arg_list$pred_data <- newdata + fun(object, arg_list) + } + + set.seed(seed) + if (object$type == "regression") { + vimp <- vip::vi( + object, + target = rvar, + method = "permute", + metric = "rsq", # "rmse" + pred_wrapper = pred_fun, + train = data + ) + } else { + # required after transition to yardstick by the vip package + data[[rvar]] <- factor(data[[rvar]], levels = c("TRUE", "FALSE")) + vimp <- vip::vi( + object, + target = rvar, + event_level = "first", + method = "permute", + metric = "roc_auc", + pred_wrapper = pred_fun, + train = data + ) + } + + vimp %>% + filter(Importance != 0) %>% + mutate(Variable = factor(Variable, levels = rev(Variable))) +} + +#' Plot permutation importance +#' +#' @param object Model object created by Radiant +#' @param rvar Label to identify the response or target variable +#' @param lev Reference class for binary classifier (rvar) +#' @param data Data to use for prediction. Will default to the data used to estimate the model +#' @param seed Random seed for reproducibility +#' +#' @importFrom vip vi +#' +#' @export +varimp_plot <- function(object, rvar, lev, data = NULL, seed = 1234) { + vi_scores <- varimp(object, rvar, lev, data = data, seed = seed) + visualize(vi_scores, yvar = "Importance", xvar = "Variable", type = "bar", custom = TRUE) + + labs( + title = "Permutation Importance", + x = NULL, + y = ifelse(object$type == "regression", "Importance (R-square decrease)", "Importance (AUC decrease)") + ) + + coord_flip() + + theme(axis.text.y = element_text(hjust = 0)) +} + +#' Plot method for the nn function +#' +#' @details See \url{https://radiant-rstats.github.io/docs/model/nn.html} for an example in Radiant +#' +#' @param x Return value from \code{\link{nn}} +#' @param plots Plots to produce for the specified Neural Network model. Use "" to avoid showing any plots (default). Options are "olden" or "garson" for importance plots, or "net" to depict the network structure +#' @param size Font size used +#' @param pad_x Padding for explanatory variable labels in the network plot. Default value is 0.9, smaller numbers (e.g., 0.5) increase the amount of padding +#' @param nrobs Number of data points to show in dashboard scatter plots (-1 for all) +#' @param incl Which variables to include in a coefficient plot or PDP plot +#' @param incl_int Which interactions to investigate in PDP plots +#' @param shiny Did the function call originate inside a shiny app +#' @param custom Logical (TRUE, FALSE) to indicate if ggplot object (or list of ggplot objects) should be returned. This option can be used to customize plots (e.g., add a title, change x and y labels, etc.). See examples and \url{https://ggplot2.tidyverse.org} for options. +#' @param ... further arguments passed to or from other methods +#' +#' @examples +#' result <- nn(titanic, "survived", c("pclass", "sex"), lev = "Yes") +#' plot(result, plots = "net") +#' plot(result, plots = "olden") +#' @seealso \code{\link{nn}} to generate results +#' @seealso \code{\link{summary.nn}} to summarize results +#' @seealso \code{\link{predict.nn}} for prediction +#' +#' @importFrom NeuralNetTools plotnet olden garson +#' @importFrom graphics par +#' +#' @export +plot.nn <- function(x, plots = "vip", size = 12, pad_x = 0.9, nrobs = -1, + incl = NULL, incl_int = NULL, + shiny = FALSE, custom = FALSE, ...) { + if (is.character(x) || !inherits(x$model, "nnet")) { + return(x) + } + plot_list <- list() + nrCol <- 1 + + if ("olden" %in% plots || "olsen" %in% plots) { ## legacy for typo + plot_list[["olsen"]] <- NeuralNetTools::olden(x$model, x_lab = x$coefnames, cex_val = 4) + + coord_flip() + + theme_set(theme_gray(base_size = size)) + + theme(legend.position = "none") + + labs(title = paste0("Olden plot of variable importance (size = ", x$size, ", decay = ", x$decay, ")")) + } + + if ("garson" %in% plots) { + plot_list[["garson"]] <- NeuralNetTools::garson(x$model, x_lab = x$coefnames) + + coord_flip() + + theme_set(theme_gray(base_size = size)) + + theme(legend.position = "none") + + labs(title = paste0("Garson plot of variable importance (size = ", x$size, ", decay = ", x$decay, ")")) + } + + if ("vip" %in% plots) { + vi_scores <- varimp(x) + plot_list[["vip"]] <- + visualize(vi_scores, yvar = "Importance", xvar = "Variable", type = "bar", custom = TRUE) + + labs( + title = paste0("Permutation Importance (size = ", x$size, ", decay = ", x$decay, ")"), + x = NULL, + y = ifelse(x$type == "regression", "Importance (R-square decrease)", "Importance (AUC decrease)") + ) + + coord_flip() + + theme(axis.text.y = element_text(hjust = 0)) + } + + if ("net" %in% plots) { + ## don't need as much spacing at the top and bottom + mar <- par(mar = c(0, 4.1, 0, 2.1)) + on.exit(par(mar = mar$mar)) + return(do.call(NeuralNetTools::plotnet, list(mod_in = x$model, x_names = x$coefnames, pad_x = pad_x, cex_val = size / 16))) + } + + if ("pred_plot" %in% plots) { + nrCol <- 2 + if (length(incl) > 0 | length(incl_int) > 0) { + plot_list <- pred_plot(x, plot_list, incl, incl_int, ...) + } else { + return("Select one or more variables to generate Prediction plots") + } + } + + if ("pdp" %in% plots) { + nrCol <- 2 + if (length(incl) > 0 || length(incl_int) > 0) { + plot_list <- pdp_plot(x, plot_list, incl, incl_int, ...) + } else { + return("Select one or more variables to generate Partial Dependence Plots") + } + } + + if (x$type == "regression" && "dashboard" %in% plots) { + plot_list <- plot.regress(x, plots = "dashboard", lines = "line", nrobs = nrobs, custom = TRUE) + nrCol <- 2 + } + + if (length(plot_list) > 0) { + if (custom) { + if (length(plot_list) == 1) plot_list[[1]] else plot_list + } else { + patchwork::wrap_plots(plot_list, ncol = nrCol) %>% + (function(x) if (isTRUE(shiny)) x else print(x)) + } + } +} + +#' Predict method for the nn function +#' +#' @details See \url{https://radiant-rstats.github.io/docs/model/nn.html} for an example in Radiant +#' +#' @param object Return value from \code{\link{nn}} +#' @param pred_data Provide the dataframe to generate predictions (e.g., diamonds). The dataset must contain all columns used in the estimation +#' @param pred_cmd Generate predictions using a command. For example, `pclass = levels(pclass)` would produce predictions for the different levels of factor `pclass`. To add another variable, create a vector of prediction strings, (e.g., c('pclass = levels(pclass)', 'age = seq(0,100,20)') +#' @param dec Number of decimals to show +#' @param envir Environment to extract data from +#' @param ... further arguments passed to or from other methods +#' +#' @examples +#' result <- nn(titanic, "survived", c("pclass", "sex"), lev = "Yes") +#' predict(result, pred_cmd = "pclass = levels(pclass)") +#' result <- nn(diamonds, "price", "carat:color", type = "regression") +#' predict(result, pred_cmd = "carat = 1:3") +#' predict(result, pred_data = diamonds) %>% head() +#' @seealso \code{\link{nn}} to generate the result +#' @seealso \code{\link{summary.nn}} to summarize results +#' +#' @export +predict.nn <- function(object, pred_data = NULL, pred_cmd = "", + dec = 3, envir = parent.frame(), ...) { + if (is.character(object)) { + return(object) + } + + ## ensure you have a name for the prediction dataset + if (is.data.frame(pred_data)) { + df_name <- deparse(substitute(pred_data)) + } else { + df_name <- pred_data + } + + pfun <- function(model, pred, se, conf_lev) { + pred_val <- try(sshhr(predict(model, pred)), silent = TRUE) + if (!inherits(pred_val, "try-error")) { + if (is.vector(pred_val)) { + pred_val <- data.frame(Prediction = pred_val, stringsAsFactors = FALSE) + } else if (is.matrix(pred_val) || is.array(pred_val)) { + pred_val <- as.data.frame(pred_val, stringsAsFactors = FALSE) + if (ncol(pred_val) > 0) { + if (ncol(pred_val) == 1) { + colnames(pred_val) <- "Prediction" + } else { + pred_val <- pred_val[, 1, drop = FALSE] + colnames(pred_val) <- "Prediction" + } + } else { + pred_val <- data.frame(Prediction = numeric(0)) + } + } else { + pred_val <- as.data.frame(pred_val, stringsAsFactors = FALSE) + if (ncol(pred_val) > 0) { + pred_val <- pred_val[, 1, drop = FALSE] + colnames(pred_val) <- "Prediction" + } + } + } + pred_val + } + + predict_model(object, pfun, "nn.predict", pred_data, pred_cmd, conf_lev = 0.95, se = FALSE, dec, envir = envir) %>% + set_attr("radiant_pred_data", df_name) +} + +#' Print method for predict.nn +#' +#' @param x Return value from prediction method +#' @param ... further arguments passed to or from other methods +#' @param n Number of lines of prediction results to print. Use -1 to print all lines +#' +#' @export +print.nn.predict <- function(x, ..., n = 10) { + print_predict_model(x, ..., n = n, header = "Neural Network") +} + +#' Cross-validation for a Neural Network +#' +#' @details See \url{https://radiant-rstats.github.io/docs/model/nn.html} for an example in Radiant +#' +#' @param object Object of type "nn" or "nnet" +#' @param K Number of cross validation passes to use +#' @param repeats Repeated cross validation +#' @param size Number of units (nodes) in the hidden layer +#' @param decay Parameter decay +#' @param seed Random seed to use as the starting point +#' @param trace Print progress +#' @param fun Function to use for model evaluation (i.e., auc for classification and RMSE for regression) +#' @param ... Additional arguments to be passed to 'fun' +#' +#' @return A data.frame sorted by the mean of the performance metric +#' +#' @seealso \code{\link{nn}} to generate an initial model that can be passed to cv.nn +#' @seealso \code{\link{Rsq}} to calculate an R-squared measure for a regression +#' @seealso \code{\link{RMSE}} to calculate the Root Mean Squared Error for a regression +#' @seealso \code{\link{MAE}} to calculate the Mean Absolute Error for a regression +#' @seealso \code{\link{auc}} to calculate the area under the ROC curve for classification +#' @seealso \code{\link{profit}} to calculate profits for classification at a cost/margin threshold +#' +#' @importFrom nnet nnet.formula +#' @importFrom shiny getDefaultReactiveDomain withProgress incProgress +#' +#' @examples +#' \dontrun{ +#' result <- nn(dvd, "buy", c("coupon", "purch", "last")) +#' cv.nn(result, decay = seq(0, 1, .5), size = 1:2) +#' cv.nn(result, decay = seq(0, 1, .5), size = 1:2, fun = profit, cost = 1, margin = 5) +#' result <- nn(diamonds, "price", c("carat", "color", "clarity"), type = "regression") +#' cv.nn(result, decay = seq(0, 1, .5), size = 1:2) +#' cv.nn(result, decay = seq(0, 1, .5), size = 1:2, fun = Rsq) +#' } +#' +#' @export +cv.nn <- function(object, K = 5, repeats = 1, decay = seq(0, 1, .2), size = 1:5, + seed = 1234, trace = TRUE, fun, ...) { + if (inherits(object, "nn")) { + ms <- attr(object$model$model, "radiant_ms")[[object$rvar]] + sds <- attr(object$model$model, "radiant_sds")[[object$rvar]] + if (length(sds) == 0) { + sds <- sf <- 1 + } else { + sf <- attr(object$model$model, "radiant_sf") + sf <- ifelse(length(sf) == 0, 2, sf) + } + object <- object$model + } else { + ms <- 0 + sds <- 1 + sf <- 1 + } + + if (inherits(object, "nnet")) { + dv <- as.character(object$call$formula[[2]]) + m <- eval(object$call[["data"]]) + weights <- eval(object$call[["weights"]]) + if (is.numeric(m[[dv]])) { + type <- "regression" + } else { + type <- "classification" + if (is.factor(m[[dv]])) { + lev <- levels(m[[dv]])[1] + } else if (is.logical(m[[dv]])) { + lev <- TRUE + } else { + stop("The level to use for classification is not clear. Use a factor of logical as the response variable") + } + } + } else { + stop("The model object does not seems to be a neural network") + } + + set.seed(seed) + tune_grid <- expand.grid(decay = decay, size = size) + out <- data.frame(mean = NA, std = NA, min = NA, max = NA, decay = tune_grid[["decay"]], size = tune_grid[["size"]]) + + if (missing(fun)) { + if (type == "classification") { + fun <- radiant.model::auc + cn <- "AUC (mean)" + } else { + fun <- radiant.model::RMSE + cn <- "RMSE (mean)" + } + } else { + cn <- glue("{deparse(substitute(fun))} (mean)") + } + + if (length(shiny::getDefaultReactiveDomain()) > 0) { + trace <- FALSE + incProgress <- shiny::incProgress + withProgress <- shiny::withProgress + } else { + incProgress <- function(...) {} + withProgress <- function(...) list(...)[["expr"]] + } + + nitt <- nrow(tune_grid) + withProgress(message = "Running cross-validation (nn)", value = 0, { + for (i in seq_len(nitt)) { + perf <- double(K * repeats) + object$call[["decay"]] <- tune_grid[i, "decay"] + object$call[["size"]] <- tune_grid[i, "size"] + if (trace) cat("Working on size", tune_grid[i, "size"], "decay", tune_grid[i, "decay"], "\n") + for (j in seq_len(repeats)) { + rand <- sample(K, nrow(m), replace = TRUE) + for (k in seq_len(K)) { + object$call[["data"]] <- quote(m[rand != k, , drop = FALSE]) + if (length(weights) > 0) { + object$call[["weights"]] <- weights[rand != k] + } + pred <- predict(eval(object$call), newdata = m[rand == k, , drop = FALSE])[, 1] + if (type == "classification") { + if (missing(...)) { + perf[k + (j - 1) * K] <- fun(pred, unlist(m[rand == k, dv]), lev) + } else { + perf[k + (j - 1) * K] <- fun(pred, unlist(m[rand == k, dv]), lev, ...) + } + } else { + pred <- pred * sf * sds + ms + rvar <- unlist(m[rand == k, dv]) * sf * sds + ms + if (missing(...)) { + perf[k + (j - 1) * K] <- fun(pred, rvar) + } else { + perf[k + (j - 1) * K] <- fun(pred, rvar, ...) + } + } + } + } + out[i, 1:4] <- c(mean(perf), sd(perf), min(perf), max(perf)) + incProgress(1 / nitt, detail = paste("\nCompleted run", i, "out of", nitt)) + } + }) + + if (type == "classification") { + out <- arrange(out, desc(mean)) + } else { + out <- arrange(out, mean) + } + ## show evaluation metric in column name + colnames(out)[1] <- cn + out +} diff --git a/radiant.model/R/svm.R b/radiant.model/R/svm.R index 91d4b45..2694271 100644 --- a/radiant.model/R/svm.R +++ b/radiant.model/R/svm.R @@ -60,7 +60,7 @@ svm <- function(dataset, rvar, evar, # 标准化 if ("standardize" %in% check) { - dataset <- scale_df(dataset) + dataset <- scale_sv(dataset) } ## ---- 3. 分类任务的响应变量 ---- @@ -135,7 +135,7 @@ svm <- function(dataset, rvar, evar, #' Center or standardize variables in a data frame #' @export -scale_df <- function(dataset, center = TRUE, scale = TRUE, +scale_sv <- function(dataset, center = TRUE, scale = TRUE, sf = 2, wts = NULL, calc = TRUE) { isNum <- sapply(dataset, function(x) is.numeric(x)) if (length(isNum) == 0 || sum(isNum) == 0) { diff --git a/radiant.model/inst/app/tools/analysis/nn_ui.R b/radiant.model/inst/app/tools/analysis/nn_ui.R index 42290fe..8bd0258 100644 --- a/radiant.model/inst/app/tools/analysis/nn_ui.R +++ b/radiant.model/inst/app/tools/analysis/nn_ui.R @@ -1,729 +1,729 @@ -nn_plots <- c( - "none", "net", "vip", "pred_plot", "pdp", "olden", "garson", "dashboard" -) -names(nn_plots) <- c( - i18n$t("None"), - i18n$t("Network"), - i18n$t("Permutation Importance"), - i18n$t("Prediction plots"), - i18n$t("Partial Dependence"), - i18n$t("Olden"), - i18n$t("Garson"), - i18n$t("Dashboard") -) - -## list of function arguments -nn_args <- as.list(formals(nn)) - -## list of function inputs selected by user -nn_inputs <- reactive({ - ## loop needed because reactive values don't allow single bracket indexing - nn_args$data_filter <- if (input$show_filter) input$data_filter else "" - nn_args$arr <- if (input$show_filter) input$data_arrange else "" - nn_args$rows <- if (input$show_filter) input$data_rows else "" - nn_args$dataset <- input$dataset - for (i in r_drop(names(nn_args))) { - nn_args[[i]] <- input[[paste0("nn_", i)]] - } - nn_args -}) - -nn_pred_args <- as.list(if (exists("predict.nn")) { - formals(predict.nn) -} else { - formals(radiant.model:::predict.nn) -}) - -# list of function inputs selected by user -nn_pred_inputs <- reactive({ - # loop needed because reactive values don't allow single bracket indexing - for (i in names(nn_pred_args)) { - nn_pred_args[[i]] <- input[[paste0("nn_", i)]] - } - - nn_pred_args$pred_cmd <- nn_pred_args$pred_data <- "" - if (input$nn_predict == "cmd") { - nn_pred_args$pred_cmd <- gsub("\\s{2,}", " ", input$nn_pred_cmd) %>% - gsub(";\\s+", ";", .) %>% - gsub("\"", "\'", .) - } else if (input$nn_predict == "data") { - nn_pred_args$pred_data <- input$nn_pred_data - } else if (input$nn_predict == "datacmd") { - nn_pred_args$pred_cmd <- gsub("\\s{2,}", " ", input$nn_pred_cmd) %>% - gsub(";\\s+", ";", .) %>% - gsub("\"", "\'", .) - nn_pred_args$pred_data <- input$nn_pred_data - } - nn_pred_args -}) - -nn_plot_args <- as.list(if (exists("plot.nn")) { - formals(plot.nn) -} else { - formals(radiant.model:::plot.nn) -}) - -## list of function inputs selected by user -nn_plot_inputs <- reactive({ - ## loop needed because reactive values don't allow single bracket indexing - for (i in names(nn_plot_args)) { - nn_plot_args[[i]] <- input[[paste0("nn_", i)]] - } - nn_plot_args -}) - -nn_pred_plot_args <- as.list(if (exists("plot.model.predict")) { - formals(plot.model.predict) -} else { - formals(radiant.model:::plot.model.predict) -}) - -# list of function inputs selected by user -nn_pred_plot_inputs <- reactive({ - # loop needed because reactive values don't allow single bracket indexing - for (i in names(nn_pred_plot_args)) { - nn_pred_plot_args[[i]] <- input[[paste0("nn_", i)]] - } - nn_pred_plot_args -}) - -output$ui_nn_rvar <- renderUI({ - req(input$nn_type) - - withProgress(message = i18n$t("Acquiring variable information"), value = 1, { - if (input$nn_type == "classification") { - vars <- two_level_vars() - } else { - isNum <- .get_class() %in% c("integer", "numeric", "ts") - vars <- varnames()[isNum] - } - }) - - init <- if (input$nn_type == "classification") { - if (is.empty(input$logit_rvar)) isolate(input$nn_rvar) else input$logit_rvar - } else { - if (is.empty(input$reg_rvar)) isolate(input$nn_rvar) else input$reg_rvar - } - - selectInput( - inputId = "nn_rvar", - label = i18n$t("Response variable:"), - choices = vars, - selected = state_single("nn_rvar", vars, init), - multiple = FALSE - ) -}) - -output$ui_nn_lev <- renderUI({ - req(input$nn_type == "classification") - req(available(input$nn_rvar)) - levs <- .get_data()[[input$nn_rvar]] %>% - as_factor() %>% - levels() - - init <- if (is.empty(input$logit_lev)) isolate(input$nn_lev) else input$logit_lev - selectInput( - inputId = "nn_lev", label = i18n$t("Choose level:"), - choices = levs, - selected = state_init("nn_lev", init) - ) -}) - -output$ui_nn_evar <- renderUI({ - if (not_available(input$nn_rvar)) { - return() - } - vars <- varnames() - if (length(vars) > 0) { - vars <- vars[-which(vars == input$nn_rvar)] - } - - init <- if (input$nn_type == "classification") { - # input$logit_evar - if (is.empty(input$logit_evar)) isolate(input$nn_evar) else input$logit_evar - } else { - # input$reg_evar - if (is.empty(input$reg_evar)) isolate(input$nn_evar) else input$reg_evar - } - - selectInput( - inputId = "nn_evar", - label = i18n$t("Explanatory variables:"), - choices = vars, - selected = state_multiple("nn_evar", vars, init), - multiple = TRUE, - size = min(10, length(vars)), - selectize = FALSE - ) -}) - -# function calls generate UI elements -output_incl("nn") -output_incl_int("nn") - -output$ui_nn_wts <- renderUI({ - isNum <- .get_class() %in% c("integer", "numeric", "ts") - vars <- varnames()[isNum] - if (length(vars) > 0 && any(vars %in% input$nn_evar)) { - vars <- base::setdiff(vars, input$nn_evar) - names(vars) <- varnames() %>% - { - .[match(vars, .)] - } %>% - names() - } - vars <- c("None", vars) - - selectInput( - inputId = "nn_wts", label = i18n$t("Weights:"), choices = vars, - selected = state_single("nn_wts", vars), - multiple = FALSE - ) -}) - -output$ui_nn_store_pred_name <- renderUI({ - init <- state_init("nn_store_pred_name", "pred_nn") %>% - sub("\\d{1,}$", "", .) %>% - paste0(., ifelse(is.empty(input$nn_size), "", input$nn_size)) - textInput( - "nn_store_pred_name", - i18n$t("Store predictions:"), - init - ) -}) - -output$ui_nn_store_res_name <- renderUI({ - req(input$dataset) - textInput("nn_store_res_name", i18n$t("Store residuals:"), "", placeholder = i18n$t("Provide variable name")) -}) - -## reset prediction and plot settings when the dataset changes -observeEvent(input$dataset, { - updateSelectInput(session = session, inputId = "nn_predict", selected = "none") - updateSelectInput(session = session, inputId = "nn_plots", selected = "none") -}) - -## reset prediction settings when the model type changes -observeEvent(input$nn_type, { - updateSelectInput(session = session, inputId = "nn_predict", selected = "none") - updateSelectInput(session = session, inputId = "nn_plots", selected = "none") -}) - -output$ui_nn_predict_plot <- renderUI({ - predict_plot_controls("nn") -}) - -output$ui_nn_plots <- renderUI({ - req(input$nn_type) - if (input$nn_type != "regression") { - nn_plots <- head(nn_plots, -1) - } - selectInput( - "nn_plots", i18n$t("Plots:"), - choices = nn_plots, - selected = state_single("nn_plots", nn_plots) - ) -}) - -output$ui_nn_nrobs <- renderUI({ - nrobs <- nrow(.get_data()) - choices <- c("1,000" = 1000, "5,000" = 5000, "10,000" = 10000, "All" = -1) %>% - .[. < nrobs] - selectInput( - "nn_nrobs", i18n$t("Number of data points plotted:"), - choices = choices, - selected = state_single("nn_nrobs", choices, 1000) - ) -}) - -## add a spinning refresh icon if the model needs to be (re)estimated -run_refresh(nn_args, "nn", tabs = "tabs_nn", label = i18n$t("Estimate model"), relabel = i18n$t("Re-estimate model")) - -output$ui_nn <- renderUI({ - req(input$dataset) - tagList( - conditionalPanel( - condition = "input.tabs_nn == 'Summary'", - wellPanel( - actionButton("nn_run", i18n$t("Estimate model"), width = "100%", icon = icon("play", verify_fa = FALSE), class = "btn-success") - ) - ), - wellPanel( - conditionalPanel( - condition = "input.tabs_nn == 'Summary'", - radioButtons( - "nn_type", - label = NULL, - choices = c("classification", "regression") %>% - { names(.) <- c(i18n$t("Classification"), i18n$t("Regression")); . }, - inline = TRUE - ), - uiOutput("ui_nn_rvar"), - uiOutput("ui_nn_lev"), - uiOutput("ui_nn_evar"), - uiOutput("ui_nn_wts"), - tags$table( - tags$td(numericInput( - "nn_size", - label = i18n$t("Size:"), min = 1, max = 20, - value = state_init("nn_size", 1), width = "77px" - )), - tags$td(numericInput( - "nn_decay", - label = i18n$t("Decay:"), min = 0, max = 1, - step = .1, value = state_init("nn_decay", .5), width = "77px" - )), - tags$td(numericInput( - "nn_seed", - label = i18n$t("Seed:"), - value = state_init("nn_seed", 1234), width = "77px" - )), - width = "100%" - ) - ), - conditionalPanel( - condition = "input.tabs_nn == 'Predict'", - selectInput( - "nn_predict", - label = i18n$t("Prediction input type:"), reg_predict, - selected = state_single("nn_predict", reg_predict, "none") - ), - conditionalPanel( - "input.nn_predict == 'data' | input.nn_predict == 'datacmd'", - selectizeInput( - inputId = "nn_pred_data", label = i18n$t("Prediction data:"), - choices = c("None" = "", r_info[["datasetlist"]]), - selected = state_single("nn_pred_data", c("None" = "", r_info[["datasetlist"]])), - multiple = FALSE - ) - ), - conditionalPanel( - "input.nn_predict == 'cmd' | input.nn_predict == 'datacmd'", - returnTextAreaInput( - "nn_pred_cmd", i18n$t("Prediction command:"), - value = state_init("nn_pred_cmd", ""), - rows = 3, - placeholder = i18n$t("Type a formula to set values for model variables (e.g., carat = 1; cut = 'Ideal') and press return") - ) - ), - conditionalPanel( - condition = "input.nn_predict != 'none'", - checkboxInput("nn_pred_plot", i18n$t("Plot predictions"), state_init("nn_pred_plot", FALSE)), - conditionalPanel( - "input.nn_pred_plot == true", - uiOutput("ui_nn_predict_plot") - ) - ), - ## only show if full data is used for prediction - conditionalPanel( - "input.nn_predict == 'data' | input.nn_predict == 'datacmd'", - tags$table( - tags$td(uiOutput("ui_nn_store_pred_name")), - tags$td(actionButton("nn_store_pred", i18n$t("Store"), icon = icon("plus", verify_fa = FALSE)), class = "top") - ) - ) - ), - conditionalPanel( - condition = "input.tabs_nn == 'Plot'", - uiOutput("ui_nn_plots"), - conditionalPanel( - condition = "input.nn_plots == 'pdp' | input.nn_plots == 'pred_plot'", - uiOutput("ui_nn_incl"), - uiOutput("ui_nn_incl_int") - ), - conditionalPanel( - condition = "input.nn_plots == 'dashboard'", - uiOutput("ui_nn_nrobs") - ) - ), - conditionalPanel( - condition = "input.tabs_nn == 'Summary'", - tags$table( - tags$td(uiOutput("ui_nn_store_res_name")), - tags$td(actionButton("nn_store_res", i18n$t("Store"), icon = icon("plus", verify_fa = FALSE)), class = "top") - ) - ) - ), - help_and_report( - modal_title = i18n$t("Neural Network"), - fun_name = "nn", - help_file = inclMD(file.path(getOption("radiant.path.model"), "app/tools/help/nn.md")) - ) - ) -}) - -nn_plot <- reactive({ - if (nn_available() != "available") { - return() - } - if (is.empty(input$nn_plots, "none")) { - return() - } - res <- .nn() - if (is.character(res)) { - return() - } - plot_width <- 650 - if ("dashboard" %in% input$nn_plots) { - plot_height <- 750 - } else if (input$nn_plots %in% c("pdp", "pred_plot")) { - nr_vars <- length(input$nn_incl) + length(input$nn_incl_int) - plot_height <- max(250, ceiling(nr_vars / 2) * 250) - if (length(input$nn_incl_int) > 0) { - plot_width <- plot_width + min(2, length(input$nn_incl_int)) * 90 - } - } else { - mlt <- if ("net" %in% input$nn_plots) 45 else 30 - plot_height <- max(500, length(res$model$coefnames) * mlt) - } - - list(plot_width = plot_width, plot_height = plot_height) -}) - -nn_plot_width <- function() { - nn_plot() %>% - (function(x) if (is.list(x)) x$plot_width else 650) -} - -nn_plot_height <- function() { - nn_plot() %>% - (function(x) if (is.list(x)) x$plot_height else 500) -} - -nn_pred_plot_height <- function() { - if (input$nn_pred_plot) 500 else 1 -} - -## output is called from the main radiant ui.R -output$nn <- renderUI({ - register_print_output("summary_nn", ".summary_nn") - register_print_output("predict_nn", ".predict_print_nn") - register_plot_output( - "predict_plot_nn", ".predict_plot_nn", - height_fun = "nn_pred_plot_height" - ) - register_plot_output( - "plot_nn", ".plot_nn", - height_fun = "nn_plot_height", - width_fun = "nn_plot_width" - ) - - ## three separate tabs - nn_output_panels <- tabsetPanel( - id = "tabs_nn", - tabPanel( - i18n$t("Summary"), value = "Summary", - verbatimTextOutput("summary_nn") - ), - tabPanel( - i18n$t("Predict"), value = "Predict", - conditionalPanel( - "input.nn_pred_plot == true", - download_link("dlp_nn_pred"), - plotOutput("predict_plot_nn", width = "100%", height = "100%") - ), - download_link("dl_nn_pred"), br(), - verbatimTextOutput("predict_nn") - ), - tabPanel( - i18n$t("Plot"), value = "Plot", - download_link("dlp_nn"), - plotOutput("plot_nn", width = "100%", height = "100%") - ) - ) - - stat_tab_panel( - menu = i18n$t("Model > Estimate"), - tool = i18n$t("Neural Network"), - tool_ui = "ui_nn", - output_panels = nn_output_panels - ) -}) - -nn_available <- reactive({ - req(input$nn_type) - if (not_available(input$nn_rvar)) { - if (input$nn_type == "classification") { - i18n$t("This analysis requires a response variable with two levels and one\nor more explanatory variables. If these variables are not available\nplease select another dataset.\n\n") %>% - suggest_data("titanic") - } else { - i18n$t("This analysis requires a response variable of type integer\nor numeric and one or more explanatory variables.\nIf these variables are not available please select another dataset.\n\n") %>% - suggest_data("diamonds") - } - } else if (not_available(input$nn_evar)) { - if (input$nn_type == "classification") { - i18n$t("Please select one or more explanatory variables.") %>% - suggest_data("titanic") - } else { - i18n$t("Please select one or more explanatory variables.") %>% - suggest_data("diamonds") - } - } else { - "available" - } -}) - -.nn <- eventReactive(input$nn_run, { - nni <- nn_inputs() - nni$envir <- r_data - withProgress( - message = i18n$t("Estimating model"), value = 1, - do.call(nn, nni) - ) -}) - -.summary_nn <- reactive({ - if (not_pressed(input$nn_run)) { - return(i18n$t("** Press the Estimate button to estimate the model **")) - } - if (nn_available() != "available") { - return(nn_available()) - } - summary(.nn()) -}) - -.predict_nn <- reactive({ - if (not_pressed(input$nn_run)) { - return(i18n$t("** Press the Estimate button to estimate the model **")) - } - if (nn_available() != "available") { - return(nn_available()) - } - if (is.empty(input$nn_predict, "none")) { - return(i18n$t("** Select prediction input **")) - } - - if ((input$nn_predict == "data" || input$nn_predict == "datacmd") && is.empty(input$nn_pred_data)) { - return(i18n$t("** Select data for prediction **")) - } - if (input$nn_predict == "cmd" && is.empty(input$nn_pred_cmd)) { - return(i18n$t("** Enter prediction commands **")) - } - - withProgress(message = i18n$t("Generating predictions"), value = 1, { - nni <- nn_pred_inputs() - nni$object <- .nn() - nni$envir <- r_data - do.call(predict, nni) - }) -}) - -.predict_print_nn <- reactive({ - .predict_nn() %>% - { - if (is.character(.)) cat(., "\n") else print(.) - } -}) - -.predict_plot_nn <- reactive({ - req( - pressed(input$nn_run), input$nn_pred_plot, - available(input$nn_xvar), - !is.empty(input$nn_predict, "none") - ) - - # if (not_pressed(input$nn_run)) return(invisible()) - # if (nn_available() != "available") return(nn_available()) - # req(input$nn_pred_plot, available(input$nn_xvar)) - # if (is.empty(input$nn_predict, "none")) return(invisible()) - # if ((input$nn_predict == "data" || input$nn_predict == "datacmd") && is.empty(input$nn_pred_data)) { - # return(invisible()) - # } - # if (input$nn_predict == "cmd" && is.empty(input$nn_pred_cmd)) { - # return(invisible()) - # } - - withProgress(message = i18n$t("Generating prediction plot"), value = 1, { - do.call(plot, c(list(x = .predict_nn()), nn_pred_plot_inputs())) - }) -}) - -.plot_nn <- reactive({ - if (not_pressed(input$nn_run)) { - return(i18n$t("** Press the Estimate button to estimate the model **")) - } else if (nn_available() != "available") { - return(nn_available()) - } - req(input$nn_size) - if (is.empty(input$nn_plots, "none")) { - return(i18n$t("Please select a neural network plot from the drop-down menu")) - } - pinp <- nn_plot_inputs() - pinp$shiny <- TRUE - pinp$size <- NULL - if (input$nn_plots == "dashboard") { - req(input$nn_nrobs) - } - - if (input$nn_plots == "net") { - .nn() %>% - (function(x) if (is.character(x)) invisible() else capture_plot(do.call(plot, c(list(x = x), pinp)))) - } else { - withProgress(message = i18n$t("Generating plots"), value = 1, { - do.call(plot, c(list(x = .nn()), pinp)) - }) - } -}) - -observeEvent(input$nn_store_res, { - req(pressed(input$nn_run)) - robj <- .nn() - if (!is.list(robj)) { - return() - } - fixed <- fix_names(input$nn_store_res_name) - updateTextInput(session, "nn_store_res_name", value = fixed) - withProgress( - message = i18n$t("Storing residuals"), value = 1, - r_data[[input$dataset]] <- store(r_data[[input$dataset]], robj, name = fixed) - ) -}) - -observeEvent(input$nn_store_pred, { - req(!is.empty(input$nn_pred_data), pressed(input$nn_run)) - pred <- .predict_nn() - if (is.null(pred)) { - return() - } - fixed <- fix_names(input$nn_store_pred_name) - updateTextInput(session, "nn_store_pred_name", value = fixed) - withProgress( - message = i18n$t("Storing predictions"), value = 1, - r_data[[input$nn_pred_data]] <- store( - r_data[[input$nn_pred_data]], pred, - name = fixed - ) - ) -}) - -nn_report <- function() { - if (is.empty(input$nn_evar)) { - return(invisible()) - } - - outputs <- c("summary") - inp_out <- list(list(prn = TRUE), "") - figs <- FALSE - - if (!is.empty(input$nn_plots, "none")) { - inp <- check_plot_inputs(nn_plot_inputs()) - inp$size <- NULL - inp_out[[2]] <- clean_args(inp, nn_plot_args[-1]) - inp_out[[2]]$custom <- FALSE - outputs <- c(outputs, "plot") - figs <- TRUE - } - - if (!is.empty(input$nn_store_res_name)) { - fixed <- fix_names(input$nn_store_res_name) - updateTextInput(session, "nn_store_res_name", value = fixed) - xcmd <- paste0(input$dataset, " <- store(", input$dataset, ", result, name = \"", fixed, "\")\n") - } else { - xcmd <- "" - } - - if (!is.empty(input$nn_predict, "none") && - (!is.empty(input$nn_pred_data) || !is.empty(input$nn_pred_cmd))) { - pred_args <- clean_args(nn_pred_inputs(), nn_pred_args[-1]) - - if (!is.empty(pred_args$pred_cmd)) { - pred_args$pred_cmd <- strsplit(pred_args$pred_cmd, ";\\s*")[[1]] - } else { - pred_args$pred_cmd <- NULL - } - - if (!is.empty(pred_args$pred_data)) { - pred_args$pred_data <- as.symbol(pred_args$pred_data) - } else { - pred_args$pred_data <- NULL - } - - inp_out[[2 + figs]] <- pred_args - outputs <- c(outputs, "pred <- predict") - xcmd <- paste0(xcmd, "print(pred, n = 10)") - if (input$nn_predict %in% c("data", "datacmd")) { - fixed <- fix_names(input$nn_store_pred_name) - updateTextInput(session, "nn_store_pred_name", value = fixed) - xcmd <- paste0( - xcmd, "\n", input$nn_pred_data, " <- store(", - input$nn_pred_data, ", pred, name = \"", fixed, "\")" - ) - } - - if (input$nn_pred_plot && !is.empty(input$nn_xvar)) { - inp_out[[3 + figs]] <- clean_args(nn_pred_plot_inputs(), nn_pred_plot_args[-1]) - inp_out[[3 + figs]]$result <- "pred" - outputs <- c(outputs, "plot") - figs <- TRUE - } - } - - nn_inp <- nn_inputs() - if (input$nn_type == "regression") { - nn_inp$lev <- NULL - } - - update_report( - inp_main = clean_args(nn_inp, nn_args), - fun_name = "nn", - inp_out = inp_out, - outputs = outputs, - figs = figs, - fig.width = nn_plot_width(), - fig.height = nn_plot_height(), - xcmd = xcmd - ) -} - -dl_nn_pred <- function(path) { - if (pressed(input$nn_run)) { - write.csv(.predict_nn(), file = path, row.names = FALSE) - } else { - cat(i18n$t("No output available. Press the Estimate button to generate results"), file = path) - } -} - -download_handler( - id = "dl_nn_pred", - fun = dl_nn_pred, - fn = function() paste0(input$dataset, "_nn_pred"), - type = "csv", - caption = i18n$t("Save predictions") -) - -download_handler( - id = "dlp_nn_pred", - fun = download_handler_plot, - fn = function() paste0(input$dataset, "_nn_pred"), - type = "png", - caption = i18n$t("Save neural network prediction plot"), - plot = .predict_plot_nn, - width = plot_width, - height = nn_pred_plot_height -) - -download_handler( - id = "dlp_nn", - fun = download_handler_plot, - fn = function() paste0(input$dataset, "_nn"), - type = "png", - caption = i18n$t("Save neural network plot"), - plot = .plot_nn, - width = nn_plot_width, - height = nn_plot_height -) - -observeEvent(input$nn_report, { - r_info[["latest_screenshot"]] <- NULL - nn_report() -}) - -observeEvent(input$nn_screenshot, { - r_info[["latest_screenshot"]] <- NULL - radiant_screenshot_modal("modal_nn_screenshot") -}) - -observeEvent(input$modal_nn_screenshot, { - nn_report() - removeModal() ## remove shiny modal after save -}) +nn_plots <- c( + "none", "net", "vip", "pred_plot", "pdp", "olden", "garson", "dashboard" +) +names(nn_plots) <- c( + i18n$t("None"), + i18n$t("Network"), + i18n$t("Permutation Importance"), + i18n$t("Prediction plots"), + i18n$t("Partial Dependence"), + i18n$t("Olden"), + i18n$t("Garson"), + i18n$t("Dashboard") +) + +## list of function arguments +nn_args <- as.list(formals(nn)) + +## list of function inputs selected by user +nn_inputs <- reactive({ + ## loop needed because reactive values don't allow single bracket indexing + nn_args$data_filter <- if (input$show_filter) input$data_filter else "" + nn_args$arr <- if (input$show_filter) input$data_arrange else "" + nn_args$rows <- if (input$show_filter) input$data_rows else "" + nn_args$dataset <- input$dataset + for (i in r_drop(names(nn_args))) { + nn_args[[i]] <- input[[paste0("nn_", i)]] + } + nn_args +}) + +nn_pred_args <- as.list(if (exists("predict.nn")) { + formals(predict.nn) +} else { + formals(radiant.model:::predict.nn) +}) + +# list of function inputs selected by user +nn_pred_inputs <- reactive({ + # loop needed because reactive values don't allow single bracket indexing + for (i in names(nn_pred_args)) { + nn_pred_args[[i]] <- input[[paste0("nn_", i)]] + } + + nn_pred_args$pred_cmd <- nn_pred_args$pred_data <- "" + if (input$nn_predict == "cmd") { + nn_pred_args$pred_cmd <- gsub("\\s{2,}", " ", input$nn_pred_cmd) %>% + gsub(";\\s+", ";", .) %>% + gsub("\"", "\'", .) + } else if (input$nn_predict == "data") { + nn_pred_args$pred_data <- input$nn_pred_data + } else if (input$nn_predict == "datacmd") { + nn_pred_args$pred_cmd <- gsub("\\s{2,}", " ", input$nn_pred_cmd) %>% + gsub(";\\s+", ";", .) %>% + gsub("\"", "\'", .) + nn_pred_args$pred_data <- input$nn_pred_data + } + nn_pred_args +}) + +nn_plot_args <- as.list(if (exists("plot.nn")) { + formals(plot.nn) +} else { + formals(radiant.model:::plot.nn) +}) + +## list of function inputs selected by user +nn_plot_inputs <- reactive({ + ## loop needed because reactive values don't allow single bracket indexing + for (i in names(nn_plot_args)) { + nn_plot_args[[i]] <- input[[paste0("nn_", i)]] + } + nn_plot_args +}) + +nn_pred_plot_args <- as.list(if (exists("plot.model.predict")) { + formals(plot.model.predict) +} else { + formals(radiant.model:::plot.model.predict) +}) + +# list of function inputs selected by user +nn_pred_plot_inputs <- reactive({ + # loop needed because reactive values don't allow single bracket indexing + for (i in names(nn_pred_plot_args)) { + nn_pred_plot_args[[i]] <- input[[paste0("nn_", i)]] + } + nn_pred_plot_args +}) + +output$ui_nn_rvar <- renderUI({ + req(input$nn_type) + + withProgress(message = i18n$t("Acquiring variable information"), value = 1, { + if (input$nn_type == "classification") { + vars <- two_level_vars() + } else { + isNum <- .get_class() %in% c("integer", "numeric", "ts") + vars <- varnames()[isNum] + } + }) + + init <- if (input$nn_type == "classification") { + if (is.empty(input$logit_rvar)) isolate(input$nn_rvar) else input$logit_rvar + } else { + if (is.empty(input$reg_rvar)) isolate(input$nn_rvar) else input$reg_rvar + } + + selectInput( + inputId = "nn_rvar", + label = i18n$t("Response variable:"), + choices = vars, + selected = state_single("nn_rvar", vars, init), + multiple = FALSE + ) +}) + +output$ui_nn_lev <- renderUI({ + req(input$nn_type == "classification") + req(available(input$nn_rvar)) + levs <- .get_data()[[input$nn_rvar]] %>% + as_factor() %>% + levels() + + init <- if (is.empty(input$logit_lev)) isolate(input$nn_lev) else input$logit_lev + selectInput( + inputId = "nn_lev", label = i18n$t("Choose level:"), + choices = levs, + selected = state_init("nn_lev", init) + ) +}) + +output$ui_nn_evar <- renderUI({ + if (not_available(input$nn_rvar)) { + return() + } + vars <- varnames() + if (length(vars) > 0) { + vars <- vars[-which(vars == input$nn_rvar)] + } + + init <- if (input$nn_type == "classification") { + # input$logit_evar + if (is.empty(input$logit_evar)) isolate(input$nn_evar) else input$logit_evar + } else { + # input$reg_evar + if (is.empty(input$reg_evar)) isolate(input$nn_evar) else input$reg_evar + } + + selectInput( + inputId = "nn_evar", + label = i18n$t("Explanatory variables:"), + choices = vars, + selected = state_multiple("nn_evar", vars, init), + multiple = TRUE, + size = min(10, length(vars)), + selectize = FALSE + ) +}) + +# function calls generate UI elements +output_incl("nn") +output_incl_int("nn") + +output$ui_nn_wts <- renderUI({ + isNum <- .get_class() %in% c("integer", "numeric", "ts") + vars <- varnames()[isNum] + if (length(vars) > 0 && any(vars %in% input$nn_evar)) { + vars <- base::setdiff(vars, input$nn_evar) + names(vars) <- varnames() %>% + { + .[match(vars, .)] + } %>% + names() + } + vars <- c("None", vars) + + selectInput( + inputId = "nn_wts", label = i18n$t("Weights:"), choices = vars, + selected = state_single("nn_wts", vars), + multiple = FALSE + ) +}) + +output$ui_nn_store_pred_name <- renderUI({ + init <- state_init("nn_store_pred_name", "pred_nn") %>% + sub("\\d{1,}$", "", .) %>% + paste0(., ifelse(is.empty(input$nn_size), "", input$nn_size)) + textInput( + "nn_store_pred_name", + i18n$t("Store predictions:"), + init + ) +}) + +output$ui_nn_store_res_name <- renderUI({ + req(input$dataset) + textInput("nn_store_res_name", i18n$t("Store residuals:"), "", placeholder = i18n$t("Provide variable name")) +}) + +## reset prediction and plot settings when the dataset changes +observeEvent(input$dataset, { + updateSelectInput(session = session, inputId = "nn_predict", selected = "none") + updateSelectInput(session = session, inputId = "nn_plots", selected = "none") +}) + +## reset prediction settings when the model type changes +observeEvent(input$nn_type, { + updateSelectInput(session = session, inputId = "nn_predict", selected = "none") + updateSelectInput(session = session, inputId = "nn_plots", selected = "none") +}) + +output$ui_nn_predict_plot <- renderUI({ + predict_plot_controls("nn") +}) + +output$ui_nn_plots <- renderUI({ + req(input$nn_type) + if (input$nn_type != "regression") { + nn_plots <- head(nn_plots, -1) + } + selectInput( + "nn_plots", i18n$t("Plots:"), + choices = nn_plots, + selected = state_single("nn_plots", nn_plots) + ) +}) + +output$ui_nn_nrobs <- renderUI({ + nrobs <- nrow(.get_data()) + choices <- c("1,000" = 1000, "5,000" = 5000, "10,000" = 10000, "All" = -1) %>% + .[. < nrobs] + selectInput( + "nn_nrobs", i18n$t("Number of data points plotted:"), + choices = choices, + selected = state_single("nn_nrobs", choices, 1000) + ) +}) + +## add a spinning refresh icon if the model needs to be (re)estimated +run_refresh(nn_args, "nn", tabs = "tabs_nn", label = i18n$t("Estimate model"), relabel = i18n$t("Re-estimate model")) + +output$ui_nn <- renderUI({ + req(input$dataset) + tagList( + conditionalPanel( + condition = "input.tabs_nn == 'Summary'", + wellPanel( + actionButton("nn_run", i18n$t("Estimate model"), width = "100%", icon = icon("play", verify_fa = FALSE), class = "btn-success") + ) + ), + wellPanel( + conditionalPanel( + condition = "input.tabs_nn == 'Summary'", + radioButtons( + "nn_type", + label = NULL, + choices = c("classification", "regression") %>% + { names(.) <- c(i18n$t("Classification"), i18n$t("Regression")); . }, + inline = TRUE + ), + uiOutput("ui_nn_rvar"), + uiOutput("ui_nn_lev"), + uiOutput("ui_nn_evar"), + uiOutput("ui_nn_wts"), + tags$table( + tags$td(numericInput( + "nn_size", + label = i18n$t("Size:"), min = 1, max = 20, + value = state_init("nn_size", 1), width = "77px" + )), + tags$td(numericInput( + "nn_decay", + label = i18n$t("Decay:"), min = 0, max = 1, + step = .1, value = state_init("nn_decay", .5), width = "77px" + )), + tags$td(numericInput( + "nn_seed", + label = i18n$t("Seed:"), + value = state_init("nn_seed", 1234), width = "77px" + )), + width = "100%" + ) + ), + conditionalPanel( + condition = "input.tabs_nn == 'Predict'", + selectInput( + "nn_predict", + label = i18n$t("Prediction input type:"), reg_predict, + selected = state_single("nn_predict", reg_predict, "none") + ), + conditionalPanel( + "input.nn_predict == 'data' | input.nn_predict == 'datacmd'", + selectizeInput( + inputId = "nn_pred_data", label = i18n$t("Prediction data:"), + choices = c("None" = "", r_info[["datasetlist"]]), + selected = state_single("nn_pred_data", c("None" = "", r_info[["datasetlist"]])), + multiple = FALSE + ) + ), + conditionalPanel( + "input.nn_predict == 'cmd' | input.nn_predict == 'datacmd'", + returnTextAreaInput( + "nn_pred_cmd", i18n$t("Prediction command:"), + value = state_init("nn_pred_cmd", ""), + rows = 3, + placeholder = i18n$t("Type a formula to set values for model variables (e.g., carat = 1; cut = 'Ideal') and press return") + ) + ), + conditionalPanel( + condition = "input.nn_predict != 'none'", + checkboxInput("nn_pred_plot", i18n$t("Plot predictions"), state_init("nn_pred_plot", FALSE)), + conditionalPanel( + "input.nn_pred_plot == true", + uiOutput("ui_nn_predict_plot") + ) + ), + ## only show if full data is used for prediction + conditionalPanel( + "input.nn_predict == 'data' | input.nn_predict == 'datacmd'", + tags$table( + tags$td(uiOutput("ui_nn_store_pred_name")), + tags$td(actionButton("nn_store_pred", i18n$t("Store"), icon = icon("plus", verify_fa = FALSE)), class = "top") + ) + ) + ), + conditionalPanel( + condition = "input.tabs_nn == 'Plot'", + uiOutput("ui_nn_plots"), + conditionalPanel( + condition = "input.nn_plots == 'pdp' | input.nn_plots == 'pred_plot'", + uiOutput("ui_nn_incl"), + uiOutput("ui_nn_incl_int") + ), + conditionalPanel( + condition = "input.nn_plots == 'dashboard'", + uiOutput("ui_nn_nrobs") + ) + ), + conditionalPanel( + condition = "input.tabs_nn == 'Summary'", + tags$table( + tags$td(uiOutput("ui_nn_store_res_name")), + tags$td(actionButton("nn_store_res", i18n$t("Store"), icon = icon("plus", verify_fa = FALSE)), class = "top") + ) + ) + ), + help_and_report( + modal_title = i18n$t("Neural Network"), + fun_name = "nn", + help_file = inclMD(file.path(getOption("radiant.path.model"), "app/tools/help/nn.md")) + ) + ) +}) + +nn_plot <- reactive({ + if (nn_available() != "available") { + return() + } + if (is.empty(input$nn_plots, "none")) { + return() + } + res <- .nn() + if (is.character(res)) { + return() + } + plot_width <- 650 + if ("dashboard" %in% input$nn_plots) { + plot_height <- 750 + } else if (input$nn_plots %in% c("pdp", "pred_plot")) { + nr_vars <- length(input$nn_incl) + length(input$nn_incl_int) + plot_height <- max(250, ceiling(nr_vars / 2) * 250) + if (length(input$nn_incl_int) > 0) { + plot_width <- plot_width + min(2, length(input$nn_incl_int)) * 90 + } + } else { + mlt <- if ("net" %in% input$nn_plots) 45 else 30 + plot_height <- max(500, length(res$model$coefnames) * mlt) + } + + list(plot_width = plot_width, plot_height = plot_height) +}) + +nn_plot_width <- function() { + nn_plot() %>% + (function(x) if (is.list(x)) x$plot_width else 650) +} + +nn_plot_height <- function() { + nn_plot() %>% + (function(x) if (is.list(x)) x$plot_height else 500) +} + +nn_pred_plot_height <- function() { + if (input$nn_pred_plot) 500 else 1 +} + +## output is called from the main radiant ui.R +output$nn <- renderUI({ + register_print_output("summary_nn", ".summary_nn") + register_print_output("predict_nn", ".predict_print_nn") + register_plot_output( + "predict_plot_nn", ".predict_plot_nn", + height_fun = "nn_pred_plot_height" + ) + register_plot_output( + "plot_nn", ".plot_nn", + height_fun = "nn_plot_height", + width_fun = "nn_plot_width" + ) + + ## three separate tabs + nn_output_panels <- tabsetPanel( + id = "tabs_nn", + tabPanel( + i18n$t("Summary"), value = "Summary", + verbatimTextOutput("summary_nn") + ), + tabPanel( + i18n$t("Predict"), value = "Predict", + conditionalPanel( + "input.nn_pred_plot == true", + download_link("dlp_nn_pred"), + plotOutput("predict_plot_nn", width = "100%", height = "100%") + ), + download_link("dl_nn_pred"), br(), + verbatimTextOutput("predict_nn") + ), + tabPanel( + i18n$t("Plot"), value = "Plot", + download_link("dlp_nn"), + plotOutput("plot_nn", width = "100%", height = "100%") + ) + ) + + stat_tab_panel( + menu = i18n$t("Model > Estimate"), + tool = i18n$t("Neural Network"), + tool_ui = "ui_nn", + output_panels = nn_output_panels + ) +}) + +nn_available <- reactive({ + req(input$nn_type) + if (not_available(input$nn_rvar)) { + if (input$nn_type == "classification") { + i18n$t("This analysis requires a response variable with two levels and one\nor more explanatory variables. If these variables are not available\nplease select another dataset.\n\n") %>% + suggest_data("titanic") + } else { + i18n$t("This analysis requires a response variable of type integer\nor numeric and one or more explanatory variables.\nIf these variables are not available please select another dataset.\n\n") %>% + suggest_data("diamonds") + } + } else if (not_available(input$nn_evar)) { + if (input$nn_type == "classification") { + i18n$t("Please select one or more explanatory variables.") %>% + suggest_data("titanic") + } else { + i18n$t("Please select one or more explanatory variables.") %>% + suggest_data("diamonds") + } + } else { + "available" + } +}) + +.nn <- eventReactive(input$nn_run, { + nni <- nn_inputs() + nni$envir <- r_data + withProgress( + message = i18n$t("Estimating model"), value = 1, + do.call(nn, nni) + ) +}) + +.summary_nn <- reactive({ + if (not_pressed(input$nn_run)) { + return(i18n$t("** Press the Estimate button to estimate the model **")) + } + if (nn_available() != "available") { + return(nn_available()) + } + summary(.nn()) +}) + +.predict_nn <- reactive({ + if (not_pressed(input$nn_run)) { + return(i18n$t("** Press the Estimate button to estimate the model **")) + } + if (nn_available() != "available") { + return(nn_available()) + } + if (is.empty(input$nn_predict, "none")) { + return(i18n$t("** Select prediction input **")) + } + + if ((input$nn_predict == "data" || input$nn_predict == "datacmd") && is.empty(input$nn_pred_data)) { + return(i18n$t("** Select data for prediction **")) + } + if (input$nn_predict == "cmd" && is.empty(input$nn_pred_cmd)) { + return(i18n$t("** Enter prediction commands **")) + } + + withProgress(message = i18n$t("Generating predictions"), value = 1, { + nni <- nn_pred_inputs() + nni$object <- .nn() + nni$envir <- r_data + do.call(predict, nni) + }) +}) + +.predict_print_nn <- reactive({ + .predict_nn() %>% + { + if (is.character(.)) cat(., "\n") else print(.) + } +}) + +.predict_plot_nn <- reactive({ + req( + pressed(input$nn_run), input$nn_pred_plot, + available(input$nn_xvar), + !is.empty(input$nn_predict, "none") + ) + + # if (not_pressed(input$nn_run)) return(invisible()) + # if (nn_available() != "available") return(nn_available()) + # req(input$nn_pred_plot, available(input$nn_xvar)) + # if (is.empty(input$nn_predict, "none")) return(invisible()) + # if ((input$nn_predict == "data" || input$nn_predict == "datacmd") && is.empty(input$nn_pred_data)) { + # return(invisible()) + # } + # if (input$nn_predict == "cmd" && is.empty(input$nn_pred_cmd)) { + # return(invisible()) + # } + + withProgress(message = i18n$t("Generating prediction plot"), value = 1, { + do.call(plot, c(list(x = .predict_nn()), nn_pred_plot_inputs())) + }) +}) + +.plot_nn <- reactive({ + if (not_pressed(input$nn_run)) { + return(i18n$t("** Press the Estimate button to estimate the model **")) + } else if (nn_available() != "available") { + return(nn_available()) + } + req(input$nn_size) + if (is.empty(input$nn_plots, "none")) { + return(i18n$t("Please select a neural network plot from the drop-down menu")) + } + pinp <- nn_plot_inputs() + pinp$shiny <- TRUE + pinp$size <- NULL + if (input$nn_plots == "dashboard") { + req(input$nn_nrobs) + } + + if (input$nn_plots == "net") { + .nn() %>% + (function(x) if (is.character(x)) invisible() else capture_plot(do.call(plot, c(list(x = x), pinp)))) + } else { + withProgress(message = i18n$t("Generating plots"), value = 1, { + do.call(plot, c(list(x = .nn()), pinp)) + }) + } +}) + +observeEvent(input$nn_store_res, { + req(pressed(input$nn_run)) + robj <- .nn() + if (!is.list(robj)) { + return() + } + fixed <- fix_names(input$nn_store_res_name) + updateTextInput(session, "nn_store_res_name", value = fixed) + withProgress( + message = i18n$t("Storing residuals"), value = 1, + r_data[[input$dataset]] <- store(r_data[[input$dataset]], robj, name = fixed) + ) +}) + +observeEvent(input$nn_store_pred, { + req(!is.empty(input$nn_pred_data), pressed(input$nn_run)) + pred <- .predict_nn() + if (is.null(pred)) { + return() + } + fixed <- fix_names(input$nn_store_pred_name) + updateTextInput(session, "nn_store_pred_name", value = fixed) + withProgress( + message = i18n$t("Storing predictions"), value = 1, + r_data[[input$nn_pred_data]] <- store( + r_data[[input$nn_pred_data]], pred, + name = fixed + ) + ) +}) + +nn_report <- function() { + if (is.empty(input$nn_evar)) { + return(invisible()) + } + + outputs <- c("summary") + inp_out <- list(list(prn = TRUE), "") + figs <- FALSE + + if (!is.empty(input$nn_plots, "none")) { + inp <- check_plot_inputs(nn_plot_inputs()) + inp$size <- NULL + inp_out[[2]] <- clean_args(inp, nn_plot_args[-1]) + inp_out[[2]]$custom <- FALSE + outputs <- c(outputs, "plot") + figs <- TRUE + } + + if (!is.empty(input$nn_store_res_name)) { + fixed <- fix_names(input$nn_store_res_name) + updateTextInput(session, "nn_store_res_name", value = fixed) + xcmd <- paste0(input$dataset, " <- store(", input$dataset, ", result, name = \"", fixed, "\")\n") + } else { + xcmd <- "" + } + + if (!is.empty(input$nn_predict, "none") && + (!is.empty(input$nn_pred_data) || !is.empty(input$nn_pred_cmd))) { + pred_args <- clean_args(nn_pred_inputs(), nn_pred_args[-1]) + + if (!is.empty(pred_args$pred_cmd)) { + pred_args$pred_cmd <- strsplit(pred_args$pred_cmd, ";\\s*")[[1]] + } else { + pred_args$pred_cmd <- NULL + } + + if (!is.empty(pred_args$pred_data)) { + pred_args$pred_data <- as.symbol(pred_args$pred_data) + } else { + pred_args$pred_data <- NULL + } + + inp_out[[2 + figs]] <- pred_args + outputs <- c(outputs, "pred <- predict") + xcmd <- paste0(xcmd, "print(pred, n = 10)") + if (input$nn_predict %in% c("data", "datacmd")) { + fixed <- fix_names(input$nn_store_pred_name) + updateTextInput(session, "nn_store_pred_name", value = fixed) + xcmd <- paste0( + xcmd, "\n", input$nn_pred_data, " <- store(", + input$nn_pred_data, ", pred, name = \"", fixed, "\")" + ) + } + + if (input$nn_pred_plot && !is.empty(input$nn_xvar)) { + inp_out[[3 + figs]] <- clean_args(nn_pred_plot_inputs(), nn_pred_plot_args[-1]) + inp_out[[3 + figs]]$result <- "pred" + outputs <- c(outputs, "plot") + figs <- TRUE + } + } + + nn_inp <- nn_inputs() + if (input$nn_type == "regression") { + nn_inp$lev <- NULL + } + + update_report( + inp_main = clean_args(nn_inp, nn_args), + fun_name = "nn", + inp_out = inp_out, + outputs = outputs, + figs = figs, + fig.width = nn_plot_width(), + fig.height = nn_plot_height(), + xcmd = xcmd + ) +} + +dl_nn_pred <- function(path) { + if (pressed(input$nn_run)) { + write.csv(.predict_nn(), file = path, row.names = FALSE) + } else { + cat(i18n$t("No output available. Press the Estimate button to generate results"), file = path) + } +} + +download_handler( + id = "dl_nn_pred", + fun = dl_nn_pred, + fn = function() paste0(input$dataset, "_nn_pred"), + type = "csv", + caption = i18n$t("Save predictions") +) + +download_handler( + id = "dlp_nn_pred", + fun = download_handler_plot, + fn = function() paste0(input$dataset, "_nn_pred"), + type = "png", + caption = i18n$t("Save neural network prediction plot"), + plot = .predict_plot_nn, + width = plot_width, + height = nn_pred_plot_height +) + +download_handler( + id = "dlp_nn", + fun = download_handler_plot, + fn = function() paste0(input$dataset, "_nn"), + type = "png", + caption = i18n$t("Save neural network plot"), + plot = .plot_nn, + width = nn_plot_width, + height = nn_plot_height +) + +observeEvent(input$nn_report, { + r_info[["latest_screenshot"]] <- NULL + nn_report() +}) + +observeEvent(input$nn_screenshot, { + r_info[["latest_screenshot"]] <- NULL + radiant_screenshot_modal("modal_nn_screenshot") +}) + +observeEvent(input$modal_nn_screenshot, { + nn_report() + removeModal() ## remove shiny modal after save +}) diff --git a/radiant.model/man/scale_df.Rd b/radiant.model/man/scale_df.Rd index 0353553..d5c196e 100644 --- a/radiant.model/man/scale_df.Rd +++ b/radiant.model/man/scale_df.Rd @@ -1,11 +1,9 @@ % Generated by roxygen2: do not edit by hand -% Please edit documentation in R/nn.R, R/svm.R +% Please edit documentation in R/nn.R \name{scale_df} \alias{scale_df} \title{Center or standardize variables in a data frame} \usage{ -scale_df(dataset, center = TRUE, scale = TRUE, sf = 2, wts = NULL, calc = TRUE) - scale_df(dataset, center = TRUE, scale = TRUE, sf = 2, wts = NULL, calc = TRUE) } \arguments{ @@ -25,7 +23,5 @@ scale_df(dataset, center = TRUE, scale = TRUE, sf = 2, wts = NULL, calc = TRUE) Scaled data frame } \description{ -Center or standardize variables in a data frame - Center or standardize variables in a data frame } diff --git a/radiant.model/man/scale_sv.Rd b/radiant.model/man/scale_sv.Rd new file mode 100644 index 0000000..0bde119 --- /dev/null +++ b/radiant.model/man/scale_sv.Rd @@ -0,0 +1,11 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/svm.R +\name{scale_sv} +\alias{scale_sv} +\title{Center or standardize variables in a data frame} +\usage{ +scale_sv(dataset, center = TRUE, scale = TRUE, sf = 2, wts = NULL, calc = TRUE) +} +\description{ +Center or standardize variables in a data frame +} -- 2.22.0