Commit c56c5069 authored by wuzekai's avatar wuzekai

修复了神经网络界面的bug

parent 4bde86d7
...@@ -112,6 +112,7 @@ export(repeater) ...@@ -112,6 +112,7 @@ export(repeater)
export(rforest) export(rforest)
export(rig) export(rig)
export(scale_df) export(scale_df)
export(scale_sv)
export(sdw) export(sdw)
export(sensitivity) export(sensitivity)
export(sim_cleaner) export(sim_cleaner)
......
#' Neural Networks using nnet
#'
#' @details See \url{https://radiant-rstats.github.io/docs/model/nn.html} for an example in Radiant
#'
#' @param dataset Dataset
#' @param rvar The response variable in the model
#' @param evar Explanatory variables in the model
#' @param type Model type (i.e., "classification" or "regression")
#' @param lev The level in the response variable defined as _success_
#' @param size Number of units (nodes) in the hidden layer
#' @param decay Parameter decay
#' @param wts Weights to use in estimation
#' @param seed Random seed to use as the starting point
#' @param check Optional estimation parameters ("standardize" is the default)
#' @param form Optional formula to use instead of rvar and evar
#' @param data_filter Expression entered in, e.g., Data > View to filter the dataset in Radiant. The expression should be a string (e.g., "price > 10000")
#' @param arr Expression to arrange (sort) the data on (e.g., "color, desc(price)")
#' @param rows Rows to select from the specified dataset
#' @param envir Environment to extract data from
#'
#' @return A list with all variables defined in nn as an object of class nn. If the model specification is invalid, a character string describing the problem (with class nn) is returned instead
#'
#' @examples
#' nn(titanic, "survived", c("pclass", "sex"), lev = "Yes") %>% summary()
#' nn(titanic, "survived", c("pclass", "sex")) %>% str()
#' nn(diamonds, "price", c("carat", "clarity"), type = "regression") %>% summary()
#' @seealso \code{\link{summary.nn}} to summarize results
#' @seealso \code{\link{plot.nn}} to plot results
#' @seealso \code{\link{predict.nn}} for prediction
#'
#' @importFrom nnet nnet
#'
#' @export
nn <- function(dataset, rvar, evar,
               type = "classification", lev = "",
               size = 1, decay = .5, wts = "None",
               seed = NA, check = "standardize",
               form, data_filter = "", arr = "",
               rows = NULL, envir = parent.frame()) {
  ## NOTE: all local variables that still exist when the function returns are
  ## part of the result (see as.list(environment()) at the end). Local names
  ## are therefore part of the interface and should not be renamed
  if (!missing(form)) {
    ## format() can return several strings for a long formula so collapse
    ## them into a single string before converting back to a formula
    form <- as.formula(paste0(format(form), collapse = ""))
    vars <- all.vars(form)
    rvar <- vars[1]
    evar <- vars[-1]
  }

  if (rvar %in% evar) {
    return("Response variable contained in the set of explanatory variables.\nPlease update model specification." %>%
      add_class("nn"))
  } else if (is.empty(size) || size < 1) {
    return("Size should be larger than or equal to 1." %>% add_class("nn"))
  } else if (is.empty(decay) || decay < 0) {
    return("Decay should be larger than or equal to 0." %>% add_class("nn"))
  }

  vars <- c(rvar, evar)

  if (is.empty(wts, "None")) {
    wts <- NULL
  } else if (is_string(wts)) {
    ## weights were specified as the name of a column in the dataset
    wtsname <- wts
    vars <- c(rvar, evar, wtsname)
  }

  df_name <- if (is_string(dataset)) dataset else deparse(substitute(dataset))
  dataset <- get_data(dataset, vars, filt = data_filter, arr = arr, rows = rows, envir = envir)

  if (!is.empty(wts)) {
    ## only look in this function's own frame so that an unrelated 'wtsname'
    ## object defined elsewhere (e.g., in the global environment) is ignored
    if (exists("wtsname", inherits = FALSE)) {
      wts <- dataset[[wtsname]]
      dataset <- select_at(dataset, .vars = base::setdiff(colnames(dataset), wtsname))
    }
    if (length(wts) != nrow(dataset)) {
      return(
        paste0("Length of the weights variable is not equal to the number of rows in the dataset (", format_nr(length(wts), dec = 0), " vs ", format_nr(nrow(dataset), dec = 0), ")") %>%
          add_class("nn")
      )
    }
  }

  not_vary <- colnames(dataset)[summarise_all(dataset, does_vary) == FALSE]
  if (length(not_vary) > 0) {
    return(paste0("The following variable(s) show no variation. Please select other variables.\n\n** ", paste0(not_vary, collapse = ", "), " **") %>%
      add_class("nn"))
  }

  rv <- dataset[[rvar]]

  if (type == "classification") {
    linout <- FALSE
    entropy <- TRUE
    if (lev == "") {
      ## no level specified so use the first level of the response variable
      if (is.factor(rv)) {
        lev <- levels(rv)[1]
      } else {
        lev <- as.character(rv) %>%
          as.factor() %>%
          levels() %>%
          .[1]
      }
    }

    ## transformation to TRUE/FALSE depending on the selected level (lev)
    dataset[[rvar]] <- dataset[[rvar]] == lev
  } else {
    linout <- TRUE
    entropy <- FALSE
  }

  ## standardize data to limit stability issues ...
  # http://stats.stackexchange.com/questions/23235/how-do-i-improve-my-neural-network-stability
  if ("standardize" %in% check) {
    dataset <- scale_df(dataset, wts = wts)
  }

  vars <- evar
  ## in case : is used
  if (length(vars) < (ncol(dataset) - 1)) {
    vars <- evar <- colnames(dataset)[-1]
  }

  if (missing(form)) form <- as.formula(paste(rvar, "~ . "))

  ## use decay http://stats.stackexchange.com/a/70146/61693
  nninput <- list(
    formula = form,
    rang = .1, size = size, decay = decay, weights = wts,
    maxit = 10000, linout = linout, entropy = entropy,
    skip = FALSE, trace = FALSE, data = dataset
  )

  ## based on https://stackoverflow.com/a/14324316/1974918
  ## keep only the digits in the seed; if a seed is set, the random number
  ## state that was active before the call is put back on exit
  seed <- gsub("[^0-9]", "", seed)
  if (!is.empty(seed)) {
    if (exists(".Random.seed")) {
      gseed <- .Random.seed
      on.exit(.Random.seed <<- gseed, add = TRUE)
    }
    set.seed(seed)
  }

  ## need do.call so Garson/Olden plot will work
  model <- do.call(nnet::nnet, nninput)

  ## separate the variable name from the level in the coefficient labels
  coefnames <- model$coefnames
  hasLevs <- vapply(
    select(dataset, -1),
    function(x) is.factor(x) || is.logical(x) || is.character(x),
    logical(1)
  )
  if (sum(hasLevs) > 0) {
    for (i in names(hasLevs[hasLevs])) {
      coefnames %<>% gsub(paste0("^", i), paste0(i, "|"), .) %>%
        gsub(paste0(":", i), paste0(":", i, "|"), .)
    }
    rm(i, hasLevs)
  }

  ## nn returns residuals as a matrix
  model$residuals <- model$residuals[, 1]

  ## nn model object does not include the data by default
  model$model <- dataset
  rm(dataset, envir) ## dataset not needed elsewhere

  as.list(environment()) %>% add_class(c("nn", "model"))
}
#' Center or standardize variables in a data frame
#'
#' @details Numeric columns are transformed to (x - mean) / (sf * sd) when both \code{center} and \code{scale} are TRUE. The means, standard deviations, and scaling factor used are attached to the result as attributes "radiant_ms", "radiant_sds", and "radiant_sf". Non-numeric columns are left unchanged
#'
#' @param dataset Data frame
#' @param center Center data (TRUE or FALSE)
#' @param scale Scale data (TRUE or FALSE)
#' @param sf Scaling factor (default is 2)
#' @param wts Weights to use (default is NULL for no weights)
#' @param calc Calculate mean and sd (TRUE) or use the "radiant_ms" and "radiant_sds" attributes attached to dataset (FALSE)
#'
#' @return Scaled data frame
#'
#' @export
scale_df <- function(dataset, center = TRUE, scale = TRUE,
                     sf = 2, wts = NULL, calc = TRUE) {
  isNum <- vapply(dataset, is.numeric, logical(1))
  if (length(isNum) == 0 || sum(isNum) == 0) {
    return(dataset)
  }
  cn <- names(isNum)[isNum]

  ## remove set_attr calls when dplyr removes and keep attributes appropriately
  descr <- attr(dataset, "description")
  if (calc) {
    if (length(wts) == 0) {
      ms <- summarise_at(dataset, .vars = cn, .funs = ~ mean(., na.rm = TRUE)) %>%
        set_attr("description", NULL)
      if (scale) {
        sds <- summarise_at(dataset, .vars = cn, .funs = ~ sd(., na.rm = TRUE)) %>%
          set_attr("description", NULL)
      }
    } else {
      ms <- summarise_at(dataset, .vars = cn, .funs = ~ weighted.mean(., wts, na.rm = TRUE)) %>%
        set_attr("description", NULL)
      if (scale) {
        sds <- summarise_at(dataset, .vars = cn, .funs = ~ weighted.sd(., wts, na.rm = TRUE)) %>%
          set_attr("description", NULL)
      }
    }
  } else {
    ## re-use the values stored by an earlier call to scale_df
    ms <- attr(dataset, "radiant_ms")
    sds <- attr(dataset, "radiant_sds")
    if (is.null(ms) && is.null(sds)) {
      return(dataset)
    }
    ## only apply a transformation if the stored values it needs are
    ## available, e.g., data that was only centered has no "radiant_sds"
    center <- center && !is.null(ms)
    scale <- scale && !is.null(sds)
  }

  if (center && scale) {
    icn <- intersect(names(ms), cn)
    dataset[icn] <- lapply(icn, function(var) (dataset[[var]] - ms[[var]]) / (sf * sds[[var]]))
    dataset %>%
      set_attr("radiant_ms", ms) %>%
      set_attr("radiant_sds", sds) %>%
      set_attr("radiant_sf", sf) %>%
      set_attr("description", descr)
  } else if (center) {
    icn <- intersect(names(ms), cn)
    dataset[icn] <- lapply(icn, function(var) dataset[[var]] - ms[[var]])
    dataset %>%
      set_attr("radiant_ms", ms) %>%
      set_attr("description", descr)
  } else if (scale) {
    icn <- intersect(names(sds), cn)
    dataset[icn] <- lapply(icn, function(var) dataset[[var]] / (sf * sds[[var]]))
    dataset %>%
      set_attr("radiant_sds", sds) %>%
      set_attr("radiant_sf", sf) %>%
      set_attr("description", descr)
  } else {
    dataset
  }
}
#' Summary method for the nn function
#'
#' @details See \url{https://radiant-rstats.github.io/docs/model/nn.html} for an example in Radiant
#'
#' @param object Return value from \code{\link{nn}}
#' @param prn Print list of weights
#' @param ... further arguments passed to or from other methods
#'
#' @return The summary is printed to the console. If \code{object} is a character string (i.e., a message returned by \code{\link{nn}}), it is returned unchanged
#'
#' @examples
#' result <- nn(titanic, "survived", "pclass", lev = "Yes")
#' summary(result)
#' @seealso \code{\link{nn}} to generate results
#' @seealso \code{\link{plot.nn}} to plot results
#' @seealso \code{\link{predict.nn}} for prediction
#'
#' @export
summary.nn <- function(object, prn = TRUE, ...) {
  ## nn returns a message (string) if the model could not be estimated
  if (is.character(object)) {
    return(object)
  }

  ## labels are padded so all colons line up with "Explanatory variables:"
  cat("Neural Network\n")
  if (object$type == "classification") {
    cat("Activation function  : Logistic (classification)")
  } else {
    cat("Activation function  : Linear (regression)")
  }
  cat("\nData                 :", object$df_name)
  if (!is.empty(object$data_filter)) {
    cat("\nFilter               :", gsub("\\n", "", object$data_filter))
  }
  if (!is.empty(object$arr)) {
    cat("\nArrange              :", gsub("\\n", "", object$arr))
  }
  if (!is.empty(object$rows)) {
    cat("\nSlice                :", gsub("\\n", "", object$rows))
  }
  cat("\nResponse variable    :", object$rvar)
  if (object$type == "classification") {
    cat("\nLevel                :", object$lev, "in", object$rvar)
  }
  cat("\nExplanatory variables:", paste0(object$evar, collapse = ", "), "\n")
  if (length(object$wtsname) > 0) {
    cat("Weights used         :", object$wtsname, "\n")
  }
  cat("Network size         :", object$size, "\n")
  cat("Parameter decay      :", object$decay, "\n")
  if (!is.empty(object$seed)) {
    cat("Seed                 :", object$seed, "\n")
  }

  network <- paste0(object$model$n, collapse = "-")
  nweights <- length(object$model$wts)
  cat("Network              :", network, "with", nweights, "weights\n")

  ## report the sum of the weights as the number of observations when the
  ## weights take more than two distinct values or are all >= 1
  if (!is.empty(object$wts, "None") && (length(unique(object$wts)) > 2 || min(object$wts) >= 1)) {
    cat("Nr obs               :", format_nr(sum(object$wts), dec = 0), "\n")
  } else {
    cat("Nr obs               :", format_nr(length(object$rv), dec = 0), "\n")
  }

  if (object$model$convergence != 0) {
    cat("\n** The model did not converge **")
  } else {
    if (prn) {
      cat("Weights              :\n")
      ## temporarily widen the console so the weights are not wrapped
      oop <- base::options(width = 100)
      on.exit(base::options(oop), add = TRUE)
      ## drop the first two lines of the nnet summary output
      capture.output(summary(object$model))[-1:-2] %>%
        gsub("^", " ", .) %>%
        paste0(collapse = "\n") %>%
        cat("\n")
    }
  }
}
#' Variable importance using the vip package and permutation importance
#'
#' @param object Model object created by Radiant
#' @param rvar Label to identify the response or target variable
#' @param lev Reference class for binary classifier (rvar)
#' @param data Data to use for prediction. Will default to the data used to estimate the model
#' @param seed Random seed for reproducibility
#'
#' @return The importance scores produced by \code{vip::vi}, excluding variables with an importance of zero, with \code{Variable} converted to a factor
#'
#' @importFrom vip vi
#'
#' @export
varimp <- function(object, rvar, lev, data = NULL, seed = 1234) {
  if (is.null(data)) data <- object$model$model

  # needed to avoid rescaling during prediction
  object$check <- setdiff(object$check, c("center", "standardize"))

  arg_list <- list(object, pred_data = data, se = FALSE)
  if (missing(rvar)) rvar <- object$rvar

  if (object$type == "classification") {
    if (missing(lev)) {
      # don't change if already logical
      if (!is.logical(data[[rvar]])) {
        if (is.empty(object$lev)) {
          stop("No level (lev) specified for the response variable and none available in the model object", call. = FALSE)
        }
        data[[rvar]] <- data[[rvar]] == object$lev
      }
    } else {
      data[[rvar]] <- data[[rvar]] == lev
    }
  }

  fun <- function(object, arg_list) do.call(predict, arg_list)[["Prediction"]]
  if (inherits(object, "rforest")) {
    arg_list$OOB <- FALSE # all 0 importance scores when using OOB
    if (object$type == "classification") {
      fun <- function(object, arg_list) do.call(predict, arg_list)[[object$lev]]
    }
  }

  # prediction wrapper passed to vip::vi; generates predictions for newdata
  pred_fun <- function(object, newdata) {
    arg_list$pred_data <- newdata
    fun(object, arg_list)
  }

  # restore the random number state that was active before the call on exit
  # so that calculating importance scores does not affect the caller
  if (exists(".Random.seed", envir = globalenv(), inherits = FALSE)) {
    gseed <- get(".Random.seed", envir = globalenv(), inherits = FALSE)
    on.exit(assign(".Random.seed", gseed, envir = globalenv()), add = TRUE)
  }
  set.seed(seed)

  if (object$type == "regression") {
    vimp <- vip::vi(
      object,
      target = rvar,
      method = "permute",
      metric = "rsq", # "rmse"
      pred_wrapper = pred_fun,
      train = data
    )
  } else {
    # required after transition to yardstick by the vip package
    data[[rvar]] <- factor(data[[rvar]], levels = c("TRUE", "FALSE"))
    vimp <- vip::vi(
      object,
      target = rvar,
      event_level = "first",
      method = "permute",
      metric = "roc_auc",
      pred_wrapper = pred_fun,
      train = data
    )
  }

  vimp %>%
    filter(Importance != 0) %>%
    mutate(Variable = factor(Variable, levels = rev(Variable)))
}
#' Plot permutation importance
#'
#' @param object Model object created by Radiant
#' @param rvar Label to identify the response or target variable
#' @param lev Reference class for binary classifier (rvar)
#' @param data Data to use for prediction. Will default to the data used to estimate the model
#' @param seed Random seed for reproducibility
#'
#' @return A bar plot of the scores calculated by \code{\link{varimp}}
#'
#' @importFrom vip vi
#'
#' @export
varimp_plot <- function(object, rvar, lev, data = NULL, seed = 1234) {
  # rvar and lev are passed along as-is so varimp can apply its own defaults
  # when they were not specified
  vi_scores <- varimp(object, rvar, lev, data = data, seed = seed)

  y_lab <- if (object$type == "regression") {
    "Importance (R-square decrease)"
  } else {
    "Importance (AUC decrease)"
  }

  visualize(vi_scores, yvar = "Importance", xvar = "Variable", type = "bar", custom = TRUE) +
    labs(
      title = "Permutation Importance",
      x = NULL,
      y = y_lab
    ) +
    coord_flip() +
    theme(axis.text.y = element_text(hjust = 0))
}
#' Plot method for the nn function
#'
#' @details See \url{https://radiant-rstats.github.io/docs/model/nn.html} for an example in Radiant
#'
#' @param x Return value from \code{\link{nn}}
#' @param plots Plots to produce for the specified Neural Network model. Options are "vip" (permutation importance, the default), "olden" or "garson" for importance plots, "net" to depict the network structure, "pred_plot" for prediction plots, "pdp" for partial dependence plots, and "dashboard" (regression models only). Use "" to avoid showing any plots
#' @param size Font size used
#' @param pad_x Padding for explanatory variable labels in the network plot. Default value is 0.9, smaller numbers (e.g., 0.5) increase the amount of padding
#' @param nrobs Number of data points to show in dashboard scatter plots (-1 for all)
#' @param incl Which variables to include in a coefficient plot or PDP plot
#' @param incl_int Which interactions to investigate in PDP plots
#' @param shiny Did the function call originate inside a shiny app
#' @param custom Logical (TRUE, FALSE) to indicate if ggplot object (or list of ggplot objects) should be returned. This option can be used to customize plots (e.g., add a title, change x and y labels, etc.). See examples and \url{https://ggplot2.tidyverse.org} for options.
#' @param ... further arguments passed to or from other methods
#'
#' @examples
#' result <- nn(titanic, "survived", c("pclass", "sex"), lev = "Yes")
#' plot(result, plots = "net")
#' plot(result, plots = "olden")
#' @seealso \code{\link{nn}} to generate results
#' @seealso \code{\link{summary.nn}} to summarize results
#' @seealso \code{\link{predict.nn}} for prediction
#'
#' @importFrom NeuralNetTools plotnet olden garson
#' @importFrom graphics par
#'
#' @export
plot.nn <- function(x, plots = "vip", size = 12, pad_x = 0.9, nrobs = -1,
                    incl = NULL, incl_int = NULL,
                    shiny = FALSE, custom = FALSE, ...) {
  ## nn returns a message (string) if the model could not be estimated
  if (is.character(x) || !inherits(x$model, "nnet")) {
    return(x)
  }

  plot_list <- list()
  nrCol <- 1

  if ("olden" %in% plots || "olsen" %in% plots) { ## legacy for typo
    ## theme_gray is added directly to the plot; theme_set would change the
    ## global theme and return the previously active theme instead
    plot_list[["olsen"]] <- NeuralNetTools::olden(x$model, x_lab = x$coefnames, cex_val = 4) +
      coord_flip() +
      theme_gray(base_size = size) +
      theme(legend.position = "none") +
      labs(title = paste0("Olden plot of variable importance (size = ", x$size, ", decay = ", x$decay, ")"))
  }

  if ("garson" %in% plots) {
    plot_list[["garson"]] <- NeuralNetTools::garson(x$model, x_lab = x$coefnames) +
      coord_flip() +
      theme_gray(base_size = size) +
      theme(legend.position = "none") +
      labs(title = paste0("Garson plot of variable importance (size = ", x$size, ", decay = ", x$decay, ")"))
  }

  if ("vip" %in% plots) {
    vi_scores <- varimp(x)
    y_lab <- if (x$type == "regression") {
      "Importance (R-square decrease)"
    } else {
      "Importance (AUC decrease)"
    }
    plot_list[["vip"]] <-
      visualize(vi_scores, yvar = "Importance", xvar = "Variable", type = "bar", custom = TRUE) +
      labs(
        title = paste0("Permutation Importance (size = ", x$size, ", decay = ", x$decay, ")"),
        x = NULL,
        y = y_lab
      ) +
      coord_flip() +
      theme(axis.text.y = element_text(hjust = 0))
  }

  if ("net" %in% plots) {
    ## don't need as much spacing at the top and bottom
    mar <- par(mar = c(0, 4.1, 0, 2.1))
    on.exit(par(mar = mar$mar), add = TRUE)
    ## the network plot is drawn with base graphics and returned immediately
    return(do.call(NeuralNetTools::plotnet, list(mod_in = x$model, x_names = x$coefnames, pad_x = pad_x, cex_val = size / 16)))
  }

  if ("pred_plot" %in% plots) {
    nrCol <- 2
    if (length(incl) > 0 || length(incl_int) > 0) {
      plot_list <- pred_plot(x, plot_list, incl, incl_int, ...)
    } else {
      return("Select one or more variables to generate Prediction plots")
    }
  }

  if ("pdp" %in% plots) {
    nrCol <- 2
    if (length(incl) > 0 || length(incl_int) > 0) {
      plot_list <- pdp_plot(x, plot_list, incl, incl_int, ...)
    } else {
      return("Select one or more variables to generate Partial Dependence Plots")
    }
  }

  if (x$type == "regression" && "dashboard" %in% plots) {
    ## the dashboard replaces any plots generated above
    plot_list <- plot.regress(x, plots = "dashboard", lines = "line", nrobs = nrobs, custom = TRUE)
    nrCol <- 2
  }

  if (length(plot_list) > 0) {
    if (custom) {
      if (length(plot_list) == 1) plot_list[[1]] else plot_list
    } else {
      patchwork::wrap_plots(plot_list, ncol = nrCol) %>%
        (function(x) if (isTRUE(shiny)) x else print(x))
    }
  }
}
#' Predict method for the nn function
#'
#' @details See \url{https://radiant-rstats.github.io/docs/model/nn.html} for an example in Radiant
#'
#' @param object Return value from \code{\link{nn}}
#' @param pred_data Provide the dataframe to generate predictions (e.g., diamonds). The dataset must contain all columns used in the estimation
#' @param pred_cmd Generate predictions using a command. For example, `pclass = levels(pclass)` would produce predictions for the different levels of factor `pclass`. To add another variable, create a vector of prediction strings (e.g., c('pclass = levels(pclass)', 'age = seq(0,100,20)'))
#' @param dec Number of decimals to show
#' @param envir Environment to extract data from
#' @param ... further arguments passed to or from other methods
#'
#' @examples
#' result <- nn(titanic, "survived", c("pclass", "sex"), lev = "Yes")
#' predict(result, pred_cmd = "pclass = levels(pclass)")
#' result <- nn(diamonds, "price", "carat:color", type = "regression")
#' predict(result, pred_cmd = "carat = 1:3")
#' predict(result, pred_data = diamonds) %>% head()
#' @seealso \code{\link{nn}} to generate the result
#' @seealso \code{\link{summary.nn}} to summarize results
#'
#' @export
predict.nn <- function(object, pred_data = NULL, pred_cmd = "",
                       dec = 3, envir = parent.frame(), ...) {
  ## a character 'object' is a message from an earlier step; pass it through
  if (is.character(object)) {
    return(object)
  }

  ## ensure you have a name for the prediction dataset
  df_name <- if (is.data.frame(pred_data)) {
    deparse(substitute(pred_data))
  } else {
    pred_data
  }

  ## prediction function handed to predict_model; 'se' and 'conf_lev' are
  ## part of the expected signature but are not used here
  pfun <- function(model, pred, se, conf_lev) {
    raw <- try(sshhr(predict(model, pred)), silent = TRUE)

    ## a failed prediction is returned as-is (try-error object)
    if (inherits(raw, "try-error")) {
      return(raw)
    }

    ## plain vectors become a one-column data.frame directly
    if (is.vector(raw)) {
      return(data.frame(Prediction = raw, stringsAsFactors = FALSE))
    }

    ## everything else is coerced and reduced to its first column
    was_array <- is.array(raw)
    res <- as.data.frame(raw, stringsAsFactors = FALSE)
    if (ncol(res) > 0) {
      res <- res[, 1, drop = FALSE]
      colnames(res) <- "Prediction"
    } else if (was_array) {
      ## a matrix/array without columns yields an empty prediction frame
      res <- data.frame(Prediction = numeric(0))
    }
    res
  }

  predict_model(object, pfun, "nn.predict", pred_data, pred_cmd, conf_lev = 0.95, se = FALSE, dec, envir = envir) %>%
    set_attr("radiant_pred_data", df_name)
}
#' @param object Object of type "nn" or "nnet"
#' Print method for predict.nn
#'
#' @param x Return value from prediction method
#' @param ... further arguments passed to or from other methods
#' @param n Number of lines of prediction results to print. Use -1 to print all lines
#'
#' @export
print.nn.predict <- function(x, ..., n = 10) {
  ## delegate to the shared prediction printer with a model-specific header
  print_predict_model(
    x,
    ...,
    n = n,
    header = "Neural Network"
  )
}
#'
#' Cross-validation for a Neural Network
#'
#' @details See \url{https://radiant-rstats.github.io/docs/model/nn.html} for an example in Radiant
#'
#' @param object Object of type "nn" or "nnet"
#' @param K Number of cross validation passes to use
#' @param repeats Repeated cross validation
#' @param size Number of units (nodes) in the hidden layer
#' @param decay Parameter decay
#' @param seed Random seed to use as the starting point
#' @param trace Print progress
#' @param fun Function to use for model evaluation (i.e., auc for classification and RMSE for regression)
#' @param ... Additional arguments to be passed to 'fun'
#'
#' @return A data.frame sorted by the mean of the performance metric
#'
#' @seealso \code{\link{nn}} to generate an initial model that can be passed to cv.nn
#' @seealso \code{\link{Rsq}} to calculate an R-squared measure for a regression
#' @seealso \code{\link{RMSE}} to calculate the Root Mean Squared Error for a regression
#' @seealso \code{\link{MAE}} to calculate the Mean Absolute Error for a regression
#' @seealso \code{\link{auc}} to calculate the area under the ROC curve for classification
#' @seealso \code{\link{profit}} to calculate profits for classification at a cost/margin threshold
#'
#' @importFrom nnet nnet.formula
#' @importFrom shiny getDefaultReactiveDomain withProgress incProgress
#'
#' @examples
#' \dontrun{
#' result <- nn(dvd, "buy", c("coupon", "purch", "last"))
#' cv.nn(result, decay = seq(0, 1, .5), size = 1:2)
#' cv.nn(result, decay = seq(0, 1, .5), size = 1:2, fun = profit, cost = 1, margin = 5)
#' result <- nn(diamonds, "price", c("carat", "color", "clarity"), type = "regression")
#' cv.nn(result, decay = seq(0, 1, .5), size = 1:2)
#' cv.nn(result, decay = seq(0, 1, .5), size = 1:2, fun = Rsq)
#' }
#'
#' @export
cv.nn <- function(object, K = 5, repeats = 1, decay = seq(0, 1, .2), size = 1:5,
                  seed = 1234, trace = TRUE, fun, ...) {
  ## recover the scaling applied to the response (if any) so that regression
  ## metrics can be reported on the original scale
  if (inherits(object, "nn")) {
    ms <- attr(object$model$model, "radiant_ms")[[object$rvar]]
    sds <- attr(object$model$model, "radiant_sds")[[object$rvar]]
    if (length(sds) == 0) {
      sds <- sf <- 1
    } else {
      sf <- attr(object$model$model, "radiant_sf")
      sf <- ifelse(length(sf) == 0, 2, sf)
    }
    ## when no mean was stored 'ms' is NULL; adding NULL below would turn the
    ## predictions into numeric(0), so fall back to a zero shift
    if (length(ms) == 0) {
      ms <- 0
    }
    object <- object$model
  } else {
    ms <- 0
    sds <- 1
    sf <- 1
  }

  if (inherits(object, "nnet")) {
    dv <- as.character(object$call$formula[[2]])
    m <- eval(object$call[["data"]])
    weights <- eval(object$call[["weights"]])
    if (is.numeric(m[[dv]])) {
      type <- "regression"
    } else {
      type <- "classification"
      if (is.factor(m[[dv]])) {
        lev <- levels(m[[dv]])[1]
      } else if (is.logical(m[[dv]])) {
        lev <- TRUE
      } else {
        stop("The level to use for classification is not clear. Use a factor of logical as the response variable")
      }
    }
  } else {
    stop("The model object does not seems to be a neural network")
  }

  set.seed(seed)
  tune_grid <- expand.grid(decay = decay, size = size)
  out <- data.frame(mean = NA, std = NA, min = NA, max = NA, decay = tune_grid[["decay"]], size = tune_grid[["size"]])

  ## pick the default metric (and its column label) when none was supplied
  if (missing(fun)) {
    if (type == "classification") {
      fun <- radiant.model::auc
      cn <- "AUC (mean)"
    } else {
      fun <- radiant.model::RMSE
      cn <- "RMSE (mean)"
    }
  } else {
    cn <- glue("{deparse(substitute(fun))} (mean)")
  }

  if (length(shiny::getDefaultReactiveDomain()) > 0) {
    trace <- FALSE
    incProgress <- shiny::incProgress
    withProgress <- shiny::withProgress
  } else {
    ## outside of shiny: no-op progress functions. list(...) forces the
    ## expression passed to withProgress, which is what runs the loop below
    incProgress <- function(...) {}
    withProgress <- function(...) list(...)[["expr"]]
  }

  nitt <- nrow(tune_grid)
  withProgress(message = "Running cross-validation (nn)", value = 0, {
    for (i in seq_len(nitt)) {
      perf <- double(K * repeats)
      object$call[["decay"]] <- tune_grid[i, "decay"]
      object$call[["size"]] <- tune_grid[i, "size"]
      if (trace) cat("Working on size", tune_grid[i, "size"], "decay", tune_grid[i, "decay"], "\n")
      for (j in seq_len(repeats)) {
        rand <- sample(K, nrow(m), replace = TRUE)
        for (k in seq_len(K)) {
          ## the quoted data expression is resolved when the call is
          ## evaluated, using the current 'm', 'rand', and 'k'
          object$call[["data"]] <- quote(m[rand != k, , drop = FALSE])
          if (length(weights) > 0) {
            object$call[["weights"]] <- weights[rand != k]
          }
          pred <- predict(eval(object$call), newdata = m[rand == k, , drop = FALSE])[, 1]
          actual <- unlist(m[rand == k, dv])
          ## an empty '...' forwards nothing, so no separate branch is needed
          if (type == "classification") {
            perf[k + (j - 1) * K] <- fun(pred, actual, lev, ...)
          } else {
            pred <- pred * sf * sds + ms
            actual <- actual * sf * sds + ms
            perf[k + (j - 1) * K] <- fun(pred, actual, ...)
          }
        }
      }
      out[i, 1:4] <- c(mean(perf), sd(perf), min(perf), max(perf))
      incProgress(1 / nitt, detail = paste("\nCompleted run", i, "out of", nitt))
    }
  })

  ## best configuration first: descending for classification, ascending
  ## for regression
  if (type == "classification") {
    out <- arrange(out, desc(mean))
  } else {
    out <- arrange(out, mean)
  }

  ## show evaluation metric in column name
  colnames(out)[1] <- cn
  out
}
...@@ -60,7 +60,7 @@ svm <- function(dataset, rvar, evar, ...@@ -60,7 +60,7 @@ svm <- function(dataset, rvar, evar,
# 标准化 # 标准化
if ("standardize" %in% check) { if ("standardize" %in% check) {
dataset <- scale_df(dataset) dataset <- scale_sv(dataset)
} }
## ---- 3. 分类任务的响应变量 ---- ## ---- 3. 分类任务的响应变量 ----
...@@ -135,7 +135,7 @@ svm <- function(dataset, rvar, evar, ...@@ -135,7 +135,7 @@ svm <- function(dataset, rvar, evar,
#' Center or standardize variables in a data frame #' Center or standardize variables in a data frame
#' @export #' @export
scale_df <- function(dataset, center = TRUE, scale = TRUE, scale_sv <- function(dataset, center = TRUE, scale = TRUE,
sf = 2, wts = NULL, calc = TRUE) { sf = 2, wts = NULL, calc = TRUE) {
isNum <- sapply(dataset, function(x) is.numeric(x)) isNum <- sapply(dataset, function(x) is.numeric(x))
if (length(isNum) == 0 || sum(isNum) == 0) { if (length(isNum) == 0 || sum(isNum) == 0) {
......
## plot types available for a neural network; the names are the translated
## labels. "dashboard" is kept as the last entry
nn_plots <- structure(
  c("none", "net", "vip", "pred_plot", "pdp", "olden", "garson", "dashboard"),
  names = c(
    i18n$t("None"),
    i18n$t("Network"),
    i18n$t("Permutation Importance"),
    i18n$t("Prediction plots"),
    i18n$t("Partial Dependence"),
    i18n$t("Olden"),
    i18n$t("Garson"),
    i18n$t("Dashboard")
  )
)
## list of function arguments
nn_args <- as.list(formals(nn))

## list of function inputs selected by user
nn_inputs <- reactive({
  ## filter, arrange, and row settings only apply when the filter is shown
  if (input$show_filter) {
    nn_args$data_filter <- input$data_filter
    nn_args$arr <- input$data_arrange
    nn_args$rows <- input$data_rows
  } else {
    nn_args$data_filter <- ""
    nn_args$arr <- ""
    nn_args$rows <- ""
  }
  nn_args$dataset <- input$dataset

  ## loop needed because reactive values don't allow single bracket indexing
  for (arg in r_drop(names(nn_args))) {
    nn_args[[arg]] <- input[[paste0("nn_", arg)]]
  }
  nn_args
})
## arguments of predict.nn; use the exported version from radiant.model
## when the function is not visible from here
nn_pred_args <- if (exists("predict.nn")) {
  as.list(formals(predict.nn))
} else {
  as.list(formals(radiant.model:::predict.nn))
}

# list of function inputs selected by user
nn_pred_inputs <- reactive({
  # loop needed because reactive values don't allow single bracket indexing
  for (arg in names(nn_pred_args)) {
    nn_pred_args[[arg]] <- input[[paste0("nn_", arg)]]
  }

  ## collapse repeated whitespace, drop whitespace after ';', and turn
  ## double quotes into single quotes in the prediction command
  clean_cmd <- function(cmd) {
    cmd <- gsub("\\s{2,}", " ", cmd)
    cmd <- gsub(";\\s+", ";", cmd)
    gsub("\"", "\'", cmd)
  }

  nn_pred_args$pred_cmd <- nn_pred_args$pred_data <- ""
  pred_type <- input$nn_predict
  if (pred_type == "cmd") {
    nn_pred_args$pred_cmd <- clean_cmd(input$nn_pred_cmd)
  } else if (pred_type == "data") {
    nn_pred_args$pred_data <- input$nn_pred_data
  } else if (pred_type == "datacmd") {
    nn_pred_args$pred_cmd <- clean_cmd(input$nn_pred_cmd)
    nn_pred_args$pred_data <- input$nn_pred_data
  }
  nn_pred_args
})
## arguments of plot.nn; use the version from radiant.model when the
## function is not visible from here
nn_plot_args <- if (exists("plot.nn")) {
  as.list(formals(plot.nn))
} else {
  as.list(formals(radiant.model:::plot.nn))
}

## list of function inputs selected by user
nn_plot_inputs <- reactive({
  ## loop needed because reactive values don't allow single bracket indexing
  for (arg in names(nn_plot_args)) {
    nn_plot_args[[arg]] <- input[[paste0("nn_", arg)]]
  }
  nn_plot_args
})
## arguments of plot.model.predict; use the version from radiant.model when
## the function is not visible from here
nn_pred_plot_args <- if (exists("plot.model.predict")) {
  as.list(formals(plot.model.predict))
} else {
  as.list(formals(radiant.model:::plot.model.predict))
}

# list of function inputs selected by user
nn_pred_plot_inputs <- reactive({
  # loop needed because reactive values don't allow single bracket indexing
  for (arg in names(nn_pred_plot_args)) {
    nn_pred_plot_args[[arg]] <- input[[paste0("nn_", arg)]]
  }
  nn_pred_plot_args
})
## selector for the response variable; candidates depend on the model type
output$ui_nn_rvar <- renderUI({
  req(input$nn_type)
  is_classification <- input$nn_type == "classification"

  withProgress(message = i18n$t("Acquiring variable information"), value = 1, {
    if (is_classification) {
      vars <- two_level_vars()
    } else {
      ## regression: only numeric-like variables qualify
      isNum <- .get_class() %in% c("integer", "numeric", "ts")
      vars <- varnames()[isNum]
    }
  })

  ## start from the response variable chosen in the matching logit/regression
  ## input when one is set, otherwise keep the current selection
  linked <- if (is_classification) input$logit_rvar else input$reg_rvar
  init <- if (is.empty(linked)) isolate(input$nn_rvar) else linked

  selectInput(
    inputId = "nn_rvar",
    label = i18n$t("Response variable:"),
    choices = vars,
    selected = state_single("nn_rvar", vars, init),
    multiple = FALSE
  )
})
## selector for the response level used in classification
output$ui_nn_lev <- renderUI({
  req(input$nn_type == "classification")
  req(available(input$nn_rvar))

  rvar_values <- .get_data()[[input$nn_rvar]]
  levs <- levels(as_factor(rvar_values))

  ## prefer the level chosen in the logit input when one is set
  init <- if (is.empty(input$logit_lev)) {
    isolate(input$nn_lev)
  } else {
    input$logit_lev
  }

  selectInput(
    inputId = "nn_lev", label = i18n$t("Choose level:"),
    choices = levs,
    selected = state_init("nn_lev", init)
  )
})
## selector for the explanatory variables (everything but the response)
output$ui_nn_evar <- renderUI({
  if (not_available(input$nn_rvar)) {
    return()
  }

  vars <- varnames()
  if (length(vars) > 0) {
    ## logical mask instead of vars[-which(...)]: a negative index built from
    ## an empty which() result would drop every variable
    vars <- vars[!vars %in% input$nn_rvar]
  }

  ## start from the explanatory variables chosen in the matching
  ## logit/regression input when set, otherwise keep the current selection
  init <- if (input$nn_type == "classification") {
    if (is.empty(input$logit_evar)) isolate(input$nn_evar) else input$logit_evar
  } else {
    if (is.empty(input$reg_evar)) isolate(input$nn_evar) else input$reg_evar
  }

  selectInput(
    inputId = "nn_evar",
    label = i18n$t("Explanatory variables:"),
    choices = vars,
    selected = state_multiple("nn_evar", vars, init),
    multiple = TRUE,
    size = min(10, length(vars)),
    selectize = FALSE
  )
})
# function calls generate UI elements
output_incl("nn")
output_incl_int("nn")

## selector for an optional weights variable (numeric variables only)
output$ui_nn_wts <- renderUI({
  isNum <- .get_class() %in% c("integer", "numeric", "ts")
  vars <- varnames()[isNum]

  if (length(vars) > 0 && any(vars %in% input$nn_evar)) {
    ## explanatory variables cannot double as weights; re-attach the labels
    ## from varnames() that setdiff drops
    vars <- base::setdiff(vars, input$nn_evar)
    lookup <- varnames()
    names(vars) <- names(lookup[match(vars, lookup)])
  }

  vars <- c("None", vars)
  selectInput(
    inputId = "nn_wts", label = i18n$t("Weights:"), choices = vars,
    selected = state_single("nn_wts", vars),
    multiple = FALSE
  )
})
## name used to store predictions: the stored base name with any trailing
## digits replaced by the current number of hidden nodes
output$ui_nn_store_pred_name <- renderUI({
  base_name <- state_init("nn_store_pred_name", "pred_nn")
  base_name <- sub("\\d{1,}$", "", base_name)
  suffix <- ifelse(is.empty(input$nn_size), "", input$nn_size)
  init <- paste0(base_name, suffix)

  textInput(
    "nn_store_pred_name",
    i18n$t("Store predictions:"),
    init
  )
})

## name used to store residuals; empty by default
output$ui_nn_store_res_name <- renderUI({
  req(input$dataset)
  textInput(
    "nn_store_res_name",
    i18n$t("Store residuals:"),
    "",
    placeholder = i18n$t("Provide variable name")
  )
})
## reset prediction and plot settings when the dataset changes
observeEvent(input$dataset, {
  updateSelectInput(session, "nn_predict", selected = "none")
  updateSelectInput(session, "nn_plots", selected = "none")
})

## reset prediction and plot settings when the model type changes
observeEvent(input$nn_type, {
  updateSelectInput(session, "nn_predict", selected = "none")
  updateSelectInput(session, "nn_plots", selected = "none")
})
## controls for plotting predictions
output$ui_nn_predict_plot <- renderUI({
  predict_plot_controls("nn")
})

## selector for the type of model plot
output$ui_nn_plots <- renderUI({
  req(input$nn_type)

  ## the last entry of nn_plots is dropped unless the model is a regression
  plot_choices <- if (input$nn_type != "regression") {
    head(nn_plots, -1)
  } else {
    nn_plots
  }

  selectInput(
    "nn_plots", i18n$t("Plots:"),
    choices = plot_choices,
    selected = state_single("nn_plots", plot_choices)
  )
})
## selector for the number of data points to plot; only options smaller
## than the number of rows in the data are offered
output$ui_nn_nrobs <- renderUI({
  nrobs <- nrow(.get_data())
  choices <- c("1,000" = 1000, "5,000" = 5000, "10,000" = 10000, "All" = -1)
  choices <- choices[choices < nrobs]

  selectInput(
    "nn_nrobs", i18n$t("Number of data points plotted:"),
    choices = choices,
    selected = state_single("nn_nrobs", choices, 1000)
  )
})

## add a spinning refresh icon if the model needs to be (re)estimated
run_refresh(
  nn_args, "nn",
  tabs = "tabs_nn",
  label = i18n$t("Estimate model"),
  relabel = i18n$t("Re-estimate model")
)
output$ui_nn <- renderUI({ output$ui_nn <- renderUI({
req(input$dataset) req(input$dataset)
tagList( tagList(
conditionalPanel( conditionalPanel(
condition = "input.tabs_nn == 'Summary'", condition = "input.tabs_nn == 'Summary'",
wellPanel( wellPanel(
actionButton("nn_run", i18n$t("Estimate model"), width = "100%", icon = icon("play", verify_fa = FALSE), class = "btn-success") actionButton("nn_run", i18n$t("Estimate model"), width = "100%", icon = icon("play", verify_fa = FALSE), class = "btn-success")
) )
), ),
wellPanel( wellPanel(
conditionalPanel( conditionalPanel(
condition = "input.tabs_nn == 'Summary'", condition = "input.tabs_nn == 'Summary'",
radioButtons( radioButtons(
"nn_type", "nn_type",
label = NULL, label = NULL,
choices = c("classification", "regression") %>% choices = c("classification", "regression") %>%
{ names(.) <- c(i18n$t("Classification"), i18n$t("Regression")); . }, { names(.) <- c(i18n$t("Classification"), i18n$t("Regression")); . },
inline = TRUE inline = TRUE
), ),
uiOutput("ui_nn_rvar"), uiOutput("ui_nn_rvar"),
uiOutput("ui_nn_lev"), uiOutput("ui_nn_lev"),
uiOutput("ui_nn_evar"), uiOutput("ui_nn_evar"),
uiOutput("ui_nn_wts"), uiOutput("ui_nn_wts"),
tags$table( tags$table(
tags$td(numericInput( tags$td(numericInput(
"nn_size", "nn_size",
label = i18n$t("Size:"), min = 1, max = 20, label = i18n$t("Size:"), min = 1, max = 20,
value = state_init("nn_size", 1), width = "77px" value = state_init("nn_size", 1), width = "77px"
)), )),
tags$td(numericInput( tags$td(numericInput(
"nn_decay", "nn_decay",
label = i18n$t("Decay:"), min = 0, max = 1, label = i18n$t("Decay:"), min = 0, max = 1,
step = .1, value = state_init("nn_decay", .5), width = "77px" step = .1, value = state_init("nn_decay", .5), width = "77px"
)), )),
tags$td(numericInput( tags$td(numericInput(
"nn_seed", "nn_seed",
label = i18n$t("Seed:"), label = i18n$t("Seed:"),
value = state_init("nn_seed", 1234), width = "77px" value = state_init("nn_seed", 1234), width = "77px"
)), )),
width = "100%" width = "100%"
) )
), ),
conditionalPanel( conditionalPanel(
condition = "input.tabs_nn == 'Predict'", condition = "input.tabs_nn == 'Predict'",
selectInput( selectInput(
"nn_predict", "nn_predict",
label = i18n$t("Prediction input type:"), reg_predict, label = i18n$t("Prediction input type:"), reg_predict,
selected = state_single("nn_predict", reg_predict, "none") selected = state_single("nn_predict", reg_predict, "none")
), ),
conditionalPanel( conditionalPanel(
"input.nn_predict == 'data' | input.nn_predict == 'datacmd'", "input.nn_predict == 'data' | input.nn_predict == 'datacmd'",
selectizeInput( selectizeInput(
inputId = "nn_pred_data", label = i18n$t("Prediction data:"), inputId = "nn_pred_data", label = i18n$t("Prediction data:"),
choices = c("None" = "", r_info[["datasetlist"]]), choices = c("None" = "", r_info[["datasetlist"]]),
selected = state_single("nn_pred_data", c("None" = "", r_info[["datasetlist"]])), selected = state_single("nn_pred_data", c("None" = "", r_info[["datasetlist"]])),
multiple = FALSE multiple = FALSE
) )
), ),
conditionalPanel( conditionalPanel(
"input.nn_predict == 'cmd' | input.nn_predict == 'datacmd'", "input.nn_predict == 'cmd' | input.nn_predict == 'datacmd'",
returnTextAreaInput( returnTextAreaInput(
"nn_pred_cmd", i18n$t("Prediction command:"), "nn_pred_cmd", i18n$t("Prediction command:"),
value = state_init("nn_pred_cmd", ""), value = state_init("nn_pred_cmd", ""),
rows = 3, rows = 3,
placeholder = i18n$t("Type a formula to set values for model variables (e.g., carat = 1; cut = 'Ideal') and press return") placeholder = i18n$t("Type a formula to set values for model variables (e.g., carat = 1; cut = 'Ideal') and press return")
) )
), ),
conditionalPanel( conditionalPanel(
condition = "input.nn_predict != 'none'", condition = "input.nn_predict != 'none'",
checkboxInput("nn_pred_plot", i18n$t("Plot predictions"), state_init("nn_pred_plot", FALSE)), checkboxInput("nn_pred_plot", i18n$t("Plot predictions"), state_init("nn_pred_plot", FALSE)),
conditionalPanel( conditionalPanel(
"input.nn_pred_plot == true", "input.nn_pred_plot == true",
uiOutput("ui_nn_predict_plot") uiOutput("ui_nn_predict_plot")
) )
), ),
## only show if full data is used for prediction ## only show if full data is used for prediction
conditionalPanel( conditionalPanel(
"input.nn_predict == 'data' | input.nn_predict == 'datacmd'", "input.nn_predict == 'data' | input.nn_predict == 'datacmd'",
tags$table( tags$table(
tags$td(uiOutput("ui_nn_store_pred_name")), tags$td(uiOutput("ui_nn_store_pred_name")),
tags$td(actionButton("nn_store_pred", i18n$t("Store"), icon = icon("plus", verify_fa = FALSE)), class = "top") tags$td(actionButton("nn_store_pred", i18n$t("Store"), icon = icon("plus", verify_fa = FALSE)), class = "top")
) )
) )
), ),
conditionalPanel( conditionalPanel(
condition = "input.tabs_nn == 'Plot'", condition = "input.tabs_nn == 'Plot'",
uiOutput("ui_nn_plots"), uiOutput("ui_nn_plots"),
conditionalPanel( conditionalPanel(
condition = "input.nn_plots == 'pdp' | input.nn_plots == 'pred_plot'", condition = "input.nn_plots == 'pdp' | input.nn_plots == 'pred_plot'",
uiOutput("ui_nn_incl"), uiOutput("ui_nn_incl"),
uiOutput("ui_nn_incl_int") uiOutput("ui_nn_incl_int")
), ),
conditionalPanel( conditionalPanel(
condition = "input.nn_plots == 'dashboard'", condition = "input.nn_plots == 'dashboard'",
uiOutput("ui_nn_nrobs") uiOutput("ui_nn_nrobs")
) )
), ),
conditionalPanel( conditionalPanel(
condition = "input.tabs_nn == 'Summary'", condition = "input.tabs_nn == 'Summary'",
tags$table( tags$table(
tags$td(uiOutput("ui_nn_store_res_name")), tags$td(uiOutput("ui_nn_store_res_name")),
tags$td(actionButton("nn_store_res", i18n$t("Store"), icon = icon("plus", verify_fa = FALSE)), class = "top") tags$td(actionButton("nn_store_res", i18n$t("Store"), icon = icon("plus", verify_fa = FALSE)), class = "top")
) )
) )
), ),
help_and_report( help_and_report(
modal_title = i18n$t("Neural Network"), modal_title = i18n$t("Neural Network"),
fun_name = "nn", fun_name = "nn",
help_file = inclMD(file.path(getOption("radiant.path.model"), "app/tools/help/nn.md")) help_file = inclMD(file.path(getOption("radiant.path.model"), "app/tools/help/nn.md"))
) )
) )
}) })
## Reactive that derives the pixel dimensions (width and height) used by the
## Plot tab. Returns NULL when no plot can be drawn.
nn_plot <- reactive({
  ## bail out when inputs are missing or no plot type has been chosen
  if (nn_available() != "available" || is.empty(input$nn_plots, "none")) {
    return()
  }
  ## a character result from .nn() is a message, not a model
  fit <- .nn()
  if (is.character(fit)) {
    return()
  }

  selected <- input$nn_plots
  width <- 650
  if ("dashboard" %in% selected) {
    height <- 750
  } else if (selected %in% c("pdp", "pred_plot")) {
    ## one row of 250px for every two selected variables / interactions
    n_int <- length(input$nn_incl_int)
    n_panels <- length(input$nn_incl) + n_int
    height <- max(250, ceiling(n_panels / 2) * 250)
    if (n_int > 0) {
      ## widen the plot when interactions are included (capped at two steps)
      width <- width + min(2, n_int) * 90
    }
  } else {
    ## height grows with the number of model coefficients
    per_coef <- if ("net" %in% selected) 45 else 30
    height <- max(500, length(fit$model$coefnames) * per_coef)
  }

  list(plot_width = width, plot_height = height)
})
## Width (in pixels) for the nn plot; falls back to 650 when nn_plot()
## does not return a list of dimensions.
nn_plot_width <- function() {
  dims <- nn_plot()
  if (is.list(dims)) dims$plot_width else 650
}
## Height (in pixels) for the nn plot; falls back to 500 when nn_plot()
## does not return a list of dimensions.
nn_plot_height <- function() {
  dims <- nn_plot()
  if (is.list(dims)) dims$plot_height else 500
}
## Height (in pixels) for the prediction plot: 500 when the
## "Plot predictions" checkbox is ticked, 1 otherwise.
## isTRUE() is used because input$nn_pred_plot is NULL until the checkbox
## has been rendered; a bare `if (NULL)` would raise
## "argument is of length zero".
nn_pred_plot_height <- function() {
  if (isTRUE(input$nn_pred_plot)) 500 else 1
}
## output is called from the main radiant ui.R
## Registers the nn text/plot outputs and assembles the three result tabs.
output$nn <- renderUI({
  ## text outputs
  register_print_output("summary_nn", ".summary_nn")
  register_print_output("predict_nn", ".predict_print_nn")
  ## plot outputs; the sizing helpers are passed by name
  register_plot_output(
    "predict_plot_nn", ".predict_plot_nn",
    height_fun = "nn_pred_plot_height"
  )
  register_plot_output(
    "plot_nn", ".plot_nn",
    height_fun = "nn_plot_height",
    width_fun = "nn_plot_width"
  )

  ## three separate tabs
  summary_tab <- tabPanel(
    i18n$t("Summary"),
    value = "Summary",
    verbatimTextOutput("summary_nn")
  )
  predict_tab <- tabPanel(
    i18n$t("Predict"),
    value = "Predict",
    ## the prediction plot is only shown when the checkbox is ticked
    conditionalPanel(
      "input.nn_pred_plot == true",
      download_link("dlp_nn_pred"),
      plotOutput("predict_plot_nn", width = "100%", height = "100%")
    ),
    download_link("dl_nn_pred"), br(),
    verbatimTextOutput("predict_nn")
  )
  plot_tab <- tabPanel(
    i18n$t("Plot"),
    value = "Plot",
    download_link("dlp_nn"),
    plotOutput("plot_nn", width = "100%", height = "100%")
  )

  stat_tab_panel(
    menu = i18n$t("Model > Estimate"),
    tool = i18n$t("Neural Network"),
    tool_ui = "ui_nn",
    output_panels = tabsetPanel(
      id = "tabs_nn",
      summary_tab,
      predict_tab,
      plot_tab
    )
  )
})
## Reactive returning "available" when both a response variable and at least
## one explanatory variable are selected; otherwise a user-facing message
## that also suggests an example dataset.
nn_available <- reactive({
  req(input$nn_type)
  classify <- input$nn_type == "classification"
  ## example dataset suggested to the user depends on the model type
  example_data <- if (classify) "titanic" else "diamonds"

  if (not_available(input$nn_rvar)) {
    msg <- if (classify) {
      i18n$t("This analysis requires a response variable with two levels and one\nor more explanatory variables. If these variables are not available\nplease select another dataset.\n\n")
    } else {
      i18n$t("This analysis requires a response variable of type integer\nor numeric and one or more explanatory variables.\nIf these variables are not available please select another dataset.\n\n")
    }
    suggest_data(msg, example_data)
  } else if (not_available(input$nn_evar)) {
    suggest_data(
      i18n$t("Please select one or more explanatory variables."),
      example_data
    )
  } else {
    "available"
  }
})
## Estimate the model; only re-runs when the "Estimate model" button is pressed.
.nn <- eventReactive(input$nn_run, {
  est_args <- nn_inputs()
  ## pass the data environment along with the collected inputs
  est_args$envir <- r_data
  withProgress(message = i18n$t("Estimating model"), value = 1, {
    do.call(nn, est_args)
  })
})
## Reactive producing the model summary, or a message explaining why no
## summary is available yet.
.summary_nn <- reactive({
  if (not_pressed(input$nn_run)) {
    return(i18n$t("** Press the Estimate button to estimate the model **"))
  }
  status <- nn_available()
  if (status != "available") {
    return(status)
  }
  summary(.nn())
})
## Reactive that generates predictions from the estimated model. Returns a
## character message instead when a prerequisite is missing.
.predict_nn <- reactive({
  if (not_pressed(input$nn_run)) {
    return(i18n$t("** Press the Estimate button to estimate the model **"))
  }
  status <- nn_available()
  if (status != "available") {
    return(status)
  }
  if (is.empty(input$nn_predict, "none")) {
    return(i18n$t("** Select prediction input **"))
  }

  pred_type <- input$nn_predict
  ## data-based prediction needs a dataset to predict on
  if (pred_type %in% c("data", "datacmd") && is.empty(input$nn_pred_data)) {
    return(i18n$t("** Select data for prediction **"))
  }
  ## command-based prediction needs a command
  if (pred_type == "cmd" && is.empty(input$nn_pred_cmd)) {
    return(i18n$t("** Enter prediction commands **"))
  }

  withProgress(message = i18n$t("Generating predictions"), value = 1, {
    pred_args <- nn_pred_inputs()
    pred_args$object <- .nn()
    pred_args$envir <- r_data
    do.call(predict, pred_args)
  })
})
## Print the prediction result; character results are messages and are
## written with cat() rather than print().
.predict_print_nn <- reactive({
  pred <- .predict_nn()
  if (is.character(pred)) cat(pred, "\n") else print(pred)
})
## Reactive that draws the prediction plot.
.predict_plot_nn <- reactive({
  ## all prerequisites are checked in a single req() call: the model has been
  ## estimated, plotting is switched on, an x-variable is available, and a
  ## prediction input type has been selected
  req(
    pressed(input$nn_run),
    input$nn_pred_plot,
    available(input$nn_xvar),
    !is.empty(input$nn_predict, "none")
  )
  withProgress(message = i18n$t("Generating prediction plot"), value = 1, {
    plot_args <- c(list(x = .predict_nn()), nn_pred_plot_inputs())
    do.call(plot, plot_args)
  })
})
## Reactive that produces the plot for the Plot tab, or a message when no
## plot can be drawn yet.
.plot_nn <- reactive({
  if (not_pressed(input$nn_run)) {
    return(i18n$t("** Press the Estimate button to estimate the model **"))
  }
  if (nn_available() != "available") {
    return(nn_available())
  }
  req(input$nn_size)
  if (is.empty(input$nn_plots, "none")) {
    return(i18n$t("Please select a neural network plot from the drop-down menu"))
  }

  plot_args <- nn_plot_inputs()
  plot_args$shiny <- TRUE
  ## size is dropped from the argument list before calling plot
  plot_args$size <- NULL

  plot_type <- input$nn_plots
  if (plot_type == "dashboard") {
    req(input$nn_nrobs)
  }

  if (plot_type == "net") {
    fit <- .nn()
    ## a character result is a message; there is nothing to draw
    if (is.character(fit)) {
      invisible()
    } else {
      capture_plot(do.call(plot, c(list(x = fit), plot_args)))
    }
  } else {
    withProgress(message = i18n$t("Generating plots"), value = 1, {
      do.call(plot, c(list(x = .nn()), plot_args))
    })
  }
})
## Store the model residuals in the active dataset when "Store" is clicked.
observeEvent(input$nn_store_res, {
  req(pressed(input$nn_run))
  fit <- .nn()
  ## anything that is not a list is not an estimated model
  if (!is.list(fit)) {
    return()
  }
  ## sanitize the variable name and show the sanitized version in the UI
  res_name <- fix_names(input$nn_store_res_name)
  updateTextInput(session, "nn_store_res_name", value = res_name)
  withProgress(message = i18n$t("Storing residuals"), value = 1, {
    r_data[[input$dataset]] <- store(r_data[[input$dataset]], fit, name = res_name)
  })
})
## Store the predictions in the selected prediction dataset when "Store"
## is clicked.
observeEvent(input$nn_store_pred, {
  req(!is.empty(input$nn_pred_data), pressed(input$nn_run))
  predictions <- .predict_nn()
  if (is.null(predictions)) {
    return()
  }
  ## sanitize the variable name and show the sanitized version in the UI
  pred_name <- fix_names(input$nn_store_pred_name)
  updateTextInput(session, "nn_store_pred_name", value = pred_name)
  withProgress(message = i18n$t("Storing predictions"), value = 1, {
    r_data[[input$nn_pred_data]] <- store(
      r_data[[input$nn_pred_data]],
      predictions,
      name = pred_name
    )
  })
})
## Collect the current inputs (model, plots, predictions, stored variables)
## and hand them to update_report().
nn_report <- function() {
  ## nothing to report without explanatory variables
  if (is.empty(input$nn_evar)) {
    return(invisible())
  }

  outputs <- "summary"
  inp_out <- list(list(prn = TRUE), "")
  figs <- FALSE

  ## optional model plot
  if (!is.empty(input$nn_plots, "none")) {
    plot_inp <- check_plot_inputs(nn_plot_inputs())
    plot_inp$size <- NULL
    inp_out[[2]] <- clean_args(plot_inp, nn_plot_args[-1])
    inp_out[[2]]$custom <- FALSE
    outputs <- c(outputs, "plot")
    figs <- TRUE
  }

  ## optional command that stores the residuals
  xcmd <- ""
  if (!is.empty(input$nn_store_res_name)) {
    res_name <- fix_names(input$nn_store_res_name)
    updateTextInput(session, "nn_store_res_name", value = res_name)
    xcmd <- paste0(
      input$dataset, " <- store(", input$dataset,
      ", result, name = \"", res_name, "\")\n"
    )
  }

  ## optional prediction (and prediction plot)
  if (!is.empty(input$nn_predict, "none") &&
    (!is.empty(input$nn_pred_data) || !is.empty(input$nn_pred_cmd))) {
    pred_args <- clean_args(nn_pred_inputs(), nn_pred_args[-1])

    ## split a multi-part prediction command on ";"
    if (is.empty(pred_args$pred_cmd)) {
      pred_args$pred_cmd <- NULL
    } else {
      pred_args$pred_cmd <- strsplit(pred_args$pred_cmd, ";\\s*")[[1]]
    }

    ## refer to the prediction data by symbol rather than by string
    if (is.empty(pred_args$pred_data)) {
      pred_args$pred_data <- NULL
    } else {
      pred_args$pred_data <- as.symbol(pred_args$pred_data)
    }

    ## prediction arguments go in the slot after the (optional) plot arguments
    pred_slot <- 2 + figs
    inp_out[[pred_slot]] <- pred_args
    outputs <- c(outputs, "pred <- predict")
    xcmd <- paste0(xcmd, "print(pred, n = 10)")

    if (input$nn_predict %in% c("data", "datacmd")) {
      pred_name <- fix_names(input$nn_store_pred_name)
      updateTextInput(session, "nn_store_pred_name", value = pred_name)
      xcmd <- paste0(
        xcmd, "\n", input$nn_pred_data, " <- store(",
        input$nn_pred_data, ", pred, name = \"", pred_name, "\")"
      )
    }

    if (input$nn_pred_plot && !is.empty(input$nn_xvar)) {
      inp_out[[pred_slot + 1]] <- clean_args(nn_pred_plot_inputs(), nn_pred_plot_args[-1])
      inp_out[[pred_slot + 1]]$result <- "pred"
      outputs <- c(outputs, "plot")
      figs <- TRUE
    }
  }

  ## lev is dropped for regression models
  main_inp <- nn_inputs()
  if (input$nn_type == "regression") {
    main_inp$lev <- NULL
  }

  update_report(
    inp_main = clean_args(main_inp, nn_args),
    fun_name = "nn",
    inp_out = inp_out,
    outputs = outputs,
    figs = figs,
    fig.width = nn_plot_width(),
    fig.height = nn_plot_height(),
    xcmd = xcmd
  )
}
## Write the predictions to `path` as csv. When the model has not been
## estimated yet, a short explanatory message is written instead.
dl_nn_pred <- function(path) {
  if (!pressed(input$nn_run)) {
    cat(i18n$t("No output available. Press the Estimate button to generate results"), file = path)
  } else {
    write.csv(.predict_nn(), file = path, row.names = FALSE)
  }
}
## download predictions as csv
download_handler(
  id = "dl_nn_pred",
  type = "csv",
  caption = i18n$t("Save predictions"),
  fn = function() paste0(input$dataset, "_nn_pred"),
  fun = dl_nn_pred
)

## download the prediction plot as png
download_handler(
  id = "dlp_nn_pred",
  type = "png",
  caption = i18n$t("Save neural network prediction plot"),
  fn = function() paste0(input$dataset, "_nn_pred"),
  fun = download_handler_plot,
  plot = .predict_plot_nn,
  width = plot_width,
  height = nn_pred_plot_height
)

## download the model plot as png
download_handler(
  id = "dlp_nn",
  type = "png",
  caption = i18n$t("Save neural network plot"),
  fn = function() paste0(input$dataset, "_nn"),
  fun = download_handler_plot,
  plot = .plot_nn,
  width = nn_plot_width,
  height = nn_plot_height
)
## report button: clear any earlier screenshot reference, then add the
## analysis to the report
observeEvent(input$nn_report, {
  r_info[["latest_screenshot"]] <- NULL
  nn_report()
})

## screenshot button: clear any earlier screenshot reference, then open
## the screenshot modal
observeEvent(input$nn_screenshot, {
  r_info[["latest_screenshot"]] <- NULL
  radiant_screenshot_modal("modal_nn_screenshot")
})

## screenshot modal confirmed: add the analysis to the report
observeEvent(input$modal_nn_screenshot, {
  nn_report()
  ## remove shiny modal after save
  removeModal()
})
% Generated by roxygen2: do not edit by hand % Generated by roxygen2: do not edit by hand
% Please edit documentation in R/nn.R, R/svm.R % Please edit documentation in R/nn.R
\name{scale_df} \name{scale_df}
\alias{scale_df} \alias{scale_df}
\title{Center or standardize variables in a data frame} \title{Center or standardize variables in a data frame}
\usage{ \usage{
scale_df(dataset, center = TRUE, scale = TRUE, sf = 2, wts = NULL, calc = TRUE)
scale_df(dataset, center = TRUE, scale = TRUE, sf = 2, wts = NULL, calc = TRUE) scale_df(dataset, center = TRUE, scale = TRUE, sf = 2, wts = NULL, calc = TRUE)
} }
\arguments{ \arguments{
...@@ -25,7 +23,5 @@ scale_df(dataset, center = TRUE, scale = TRUE, sf = 2, wts = NULL, calc = TRUE) ...@@ -25,7 +23,5 @@ scale_df(dataset, center = TRUE, scale = TRUE, sf = 2, wts = NULL, calc = TRUE)
Scaled data frame Scaled data frame
} }
\description{ \description{
Center or standardize variables in a data frame
Center or standardize variables in a data frame Center or standardize variables in a data frame
} }
% Generated by roxygen2: do not edit by hand
% Please edit documentation in R/svm.R
\name{scale_sv}
\alias{scale_sv}
\title{Center or standardize variables in a data frame}
\usage{
scale_sv(dataset, center = TRUE, scale = TRUE, sf = 2, wts = NULL, calc = TRUE)
}
\description{
Center or standardize variables in a data frame
}
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment