#' Cox Proportional Hazards Regression
#'
#' Fits a Cox proportional hazards model with `survival::coxph()` and bundles
#' the fitted model with a tidy coefficient table and basic fit statistics.
#'
#' @param dataset Dataset, or the name of a dataset, handed to `get_data()`
#' @param time Name of the follow-up time variable
#' @param status Name of the event indicator. Accepted encodings: numeric 0/1,
#'   logical (`TRUE` = event), or a factor / character variable with exactly
#'   two values. For factors the second level is the event; for character
#'   variables the alphabetically second value is the event
#' @param evar Names of the explanatory variables
#' @param int Interaction terms appended to the right-hand side of the formula
#'   (only used when `form` is not supplied)
#' @param check Character vector of options; `"robust"` requests
#'   `robust = TRUE` in `survival::coxph()`
#' @param form Optional model formula. When supplied, `time`, `status` and
#'   `evar` are taken (in that order) from the variables in the formula
#' @param data_filter Filter expression passed to `get_data()` as `filt`
#' @param arr Arrange expression passed to `get_data()`
#' @param rows Row selection passed to `get_data()`
#' @param envir Environment passed to `get_data()`
#'
#' @return A list with classes `c("coxp", "model")` holding every local
#'   variable of the call (including `dataset`, `time`, `status`, `evar`,
#'   `form`) plus `model`, `df_name`, `type`, `check`, `coef_df`, `n`,
#'   `n_event` and `concordance`. When validation or estimation fails, a
#'   character message with class `"coxp"` is returned instead.
#'
#' @export
coxp <- function(dataset,
                 time,
                 status,
                 evar,
                 int = "",
                 check = "",
                 form,
                 data_filter = "",
                 arr = "",
                 rows = NULL,
                 envir = parent.frame()) {

  if (!requireNamespace("survival", quietly = TRUE)) {
    stop("survival package is required but not installed.")
  }
  ## The model formula refers to Surv() unqualified, so survival has to be on
  ## the search path while the model is fitted. attachNamespace() raises an
  ## error when the package is already attached, so only attach (and detach
  ## again on exit) when this call is the one that put it there.
  if (!"package:survival" %in% search()) {
    attachNamespace("survival")
    on.exit(detach("package:survival"), add = TRUE)
  }

  ## ---- formula interface ----------------------------------------------
  if (!missing(form)) {
    ## format() may split a long formula over several strings
    form <- as.formula(paste(format(form), collapse = " "))
    vars <- all.vars(form)
    time <- vars[1]
    status <- vars[2]
    evar <- vars[-(1:2)]
  }

  ## ---- basic checks ---------------------------------------------------
  if (time %in% evar || status %in% evar) {
    return("Time/status variable contained in explanatory variables." %>%
      add_class("coxp"))
  }

  vars <- unique(c(time, status, evar))
  ## must be captured before `dataset` is overwritten below
  df_name <- if (is_string(dataset)) dataset else deparse(substitute(dataset))
  dataset <- get_data(dataset, vars, filt = data_filter, arr = arr, rows = rows, envir = envir)

  ## ---- validate / recode the status variable to 0/1 -------------------
  surv_status <- dataset[[status]]
  if (is.logical(surv_status)) {
    dataset[[status]] <- as.integer(surv_status)
  } else if (is.factor(surv_status) || is.character(surv_status)) {
    ## Use a data-order independent coding: factor level order, or sorted
    ## values for character input. Coding by order of appearance would let a
    ## filter or sort silently flip which value counts as the event.
    lv <- if (is.factor(surv_status)) {
      levels(droplevels(surv_status))
    } else {
      sort(unique(surv_status))
    }
    if (length(lv) != 2) {
      return("Status variable must be binary (0/1 or two levels)." %>% add_class("coxp"))
    }
    dataset[[status]] <- as.numeric(factor(surv_status, levels = lv)) - 1L
  } else if (is.numeric(surv_status)) {
    if (!all(surv_status[!is.na(surv_status)] %in% c(0, 1))) {
      return("Status variable must contain only 0 and 1." %>% add_class("coxp"))
    }
  } else {
    return("Status variable must be numeric 0/1 or binary factor." %>% add_class("coxp"))
  }

  ## ---- build the formula when none was supplied -----------------------
  if (missing(form)) {
    rhs <- if (length(evar) == 0) "1" else paste(evar, collapse = " + ")
    if (!is.empty(int)) rhs <- paste(rhs, paste(int, collapse = " + "), sep = " + ")
    form <- as.formula(paste("Surv(", time, ", ", status, ") ~ ", rhs))
  }

  ## ---- estimation -----------------------------------------------------
  ## try() turns an estimation error into the "try-error" object that the
  ## guard below tests for. `robust` is only passed when requested so that
  ## coxph() keeps its own default otherwise.
  model <- try(
    if ("robust" %in% check) {
      survival::coxph(form, data = dataset, robust = TRUE)
    } else {
      survival::coxph(form, data = dataset)
    },
    silent = TRUE
  )

  if (inherits(model, "try-error")) {
    return("Model estimation failed. Check data separation or collinearity." %>% add_class("coxp"))
  }

  ## ---- summary information -------------------------------------------
  coef_df <- broom::tidy(model, conf.int = TRUE) # coefficients, CI, p-values
  n <- model$n # observations used in the fit
  n_event <- model$nevent # events used in the fit
  ## concordance() has a method for coxph fits that applies the direction
  ## appropriate for a risk score
  conc <- tryCatch(
    unname(survival::concordance(model)$concordance),
    error = function(e) NA
  )

  ## ---- package the result --------------------------------------------
  out <- as.list(environment())
  out$model <- model
  out$df_name <- df_name
  out$type <- "survival"
  out$check <- check
  out$coef_df <- coef_df
  out$n <- n
  out$n_event <- n_event
  out$concordance <- conc
  add_class(out, c("coxp", "model"))
}

#' Summary method for Cox proportional hazards regression
#'
#' Prints the model description, the coefficient table (coefficients, standard
#' errors, z-values, p-values and hazard ratios with confidence limits) and the
#' likelihood ratio, Wald and score tests.
#'
#' @param object Return value from `coxp()`; a character message is printed
#'   as-is
#' @param dec Number of decimals shown in the coefficient table and for the
#'   concordance
#' @param ... Not used
#'
#' @return `object`, invisibly
#'
#' @export
summary.coxp <- function(object, dec = 3, ...) {
  if (is.character(object)) {
    cat(object, "\n")
    return(invisible(object))
  }

  if (!inherits(object$model, "coxph")) {
    cat("** Invalid Cox model object. **\n")
    return(invisible(object))
  }

  num_fmt <- paste0("%.", dec, "f")

  ## model description
  cat("Cox Proportional Hazards Regression\n")
  cat("Data:", object$df_name, "\n")
  if (!is.empty(object$data_filter)) {
    cat("Filter:", gsub("\\n", "", object$data_filter), "\n")
  }
  if (!is.empty(object$arr)) {
    cat("Arrange:", gsub("\\n", "", object$arr), "\n")
  }
  if (!is.empty(object$rows)) {
    cat("Slice:", gsub("\\n", "", object$rows), "\n")
  }
  cat("Time variable   :", object$time, "\n")
  cat("Status variable :", object$status, "\n")
  cat("Explanatory vars:", paste(object$evar, collapse = ", "), "\n")
  cat("N =", object$n, ", Events =", object$n_event, "\n")
  cat("Concordance =", sprintf(num_fmt, object$concordance), "\n\n")

  ## coefficient table; the row label is the `term` column of the tidy
  ## output (its row names are only a running index)
  est <- object$coef_df
  p_val <- est$p.value
  coeff <- data.frame(
    label = est$term,
    coef = sprintf(num_fmt, est$estimate),
    se = sprintf(num_fmt, est$std.error),
    z = sprintf(num_fmt, est$statistic),
    p = ifelse(p_val < .001, "< .001", sprintf(num_fmt, p_val)),
    sig_star = format(sig_stars(p_val), justify = "left"),
    HR = sprintf(num_fmt, exp(est$estimate)),
    HR.low = sprintf(num_fmt, exp(est$conf.low)),
    HR.high = sprintf(num_fmt, exp(est$conf.high)),
    stringsAsFactors = FALSE
  )

  colnames(coeff) <- c(" ", "Coef", "SE", "z", "p", " ", "HR", "HR.lower", "HR.upper")
  print.data.frame(coeff, row.names = FALSE)

  cat("\nSignif. codes:  0 '***' 0.001 '**' 0.01 '*' 0.05 '.' 0.1 ' ' 1\n\n")

  ## overall model tests; each element of the coxph summary is read as
  ## (statistic, df, p-value)
  sm <- summary(object$model)
  print_test <- function(label, stat) {
    p_txt <- if (isTRUE(stat[3] < .0001)) "< .0001" else sprintf("%.4f", stat[3])
    cat(label, sprintf("%.2f", stat[1]), "on", stat[2], "df, p =", p_txt, "\n")
  }
  print_test("Likelihood ratio test =", sm$logtest)
  print_test("Wald test             =", sm$waldtest)
  print_test("Score (logrank) test  =", sm$sctest)

  invisible(object)
}


#' Predict method for Cox proportional hazards regression
#'
#' Computes the linear predictor and the corresponding hazard ratio
#' (`exp(lp)`), each with a normal-approximation confidence interval.
#'
#' @param object Return value from `coxp()`; a character message is returned
#'   unchanged
#' @param pred_data Dataset (or dataset name) to predict for, retrieved with
#'   `get_data()`. When `NULL`, `envir$.model_frame`, the model frame stored
#'   in the fit, or the estimation data stored in `object$dataset` is used
#'   (in that order)
#' @param pred_cmd Command passed to `modify_data()` to alter the prediction
#'   data
#' @param conf_lev Confidence level for the intervals
#' @param dec Number of decimals the results are rounded to
#' @param envir Environment used to look up data
#' @param ... Not used
#'
#' @return A data frame with classes `c("coxp.predict", "model.predict")`
#'   containing the prediction data followed by the columns `lp`, `HR`,
#'   `lp_lower`, `lp_upper`, `HR_lower` and `HR_upper`; or a character
#'   message with class `"coxp.predict"` when prediction is not possible.
#'
#' @export
predict.coxp <- function(object, pred_data = NULL, pred_cmd = "",
                         conf_lev = 0.95, dec = 3, envir = parent.frame(), ...) {
  if (is.character(object)) return(object)

  # 1. Assemble the prediction data
  if (is.null(pred_data)) {
    newdata <- envir$.model_frame %||% object$model$model
    # coxph() does not keep a model frame unless asked to, so fall back on the
    # estimation data stored by coxp()
    if (is.null(newdata) && !is.null(object$dataset)) {
      keep <- intersect(object$evar, colnames(object$dataset))
      newdata <- object$dataset[, keep, drop = FALSE]
    }
    if (is.null(newdata)) {
      return("No data available for prediction. Please specify pred_data." %>%
        add_class("coxp.predict"))
    }
  } else {
    # Retrieve all columns first so missing variables can be reported
    newdata <- get_data(pred_data, vars = NULL, envir = envir)

    # Every explanatory variable in the model must be available
    model_evar <- object$evar
    pred_cols <- colnames(newdata)
    missing_vars <- setdiff(model_evar, pred_cols)

    if (length(missing_vars) > 0) {
      return(paste0(
        "All variables in the model must also be in the prediction data\n",
        "Variables in the model: ", paste(model_evar, collapse = ", "), "\n",
        "Variables not available in prediction data: ", paste(missing_vars, collapse = ", ")
      ) %>% add_class("coxp.predict"))
    }
    newdata <- newdata[, model_evar, drop = FALSE]
  }

  # 2. Apply pred_cmd
  if (!is.empty(pred_cmd)) {
    newdata <- modify_data(newdata, pred_cmd, envir = envir)
  }

  # 3. Drop columns that are entirely NA
  newdata <- newdata[, colSums(is.na(newdata)) != nrow(newdata), drop = FALSE]
  if (ncol(newdata) == 0 && length(object$evar) > 0) {
    return(paste("预测数据中无有效解释变量列（需包含：", paste(object$evar, collapse = ", "), "）") %>% add_class("coxp.predict"))
  }

  # 4. Linear predictor with standard errors; intervals are built on the
  #    linear-predictor scale and exponentiated to the hazard-ratio scale
  pred_cox <- predict(
    object$model,
    newdata = newdata,
    type = "lp",
    se.fit = TRUE
  )
  z_val <- qnorm((1 + conf_lev) / 2)
  lp_lower <- pred_cox$fit - z_val * pred_cox$se.fit
  lp_upper <- pred_cox$fit + z_val * pred_cox$se.fit
  hr <- exp(pred_cox$fit)
  hr_lower <- exp(lp_lower)
  hr_upper <- exp(lp_upper)

  # 5. Result columns
  pred_result <- data.frame(
    lp = round(pred_cox$fit, dec),
    HR = round(hr, dec),
    lp_lower = round(lp_lower, dec),
    lp_upper = round(lp_upper, dec),
    HR_lower = round(hr_lower, dec),
    HR_upper = round(hr_upper, dec),
    check.names = FALSE,
    stringsAsFactors = FALSE
  )
  pred_full <- cbind(newdata, pred_result)

  # 6. Meta information; radiant_pred_cmd is recorded because
  #    store.coxp.predict() reads it to line predictions up with the data
  pred_full <- pred_full %>%
    radiant.data::set_attr("radiant_df_name", object$df_name) %>%
    radiant.data::set_attr("radiant_time", object$time) %>%
    radiant.data::set_attr("radiant_status", object$status) %>%
    radiant.data::set_attr("radiant_evar_actual", colnames(newdata)) %>%
    radiant.data::set_attr("radiant_pred_cmd", pred_cmd) %>%
    radiant.data::set_attr("radiant_conf_lev", conf_lev) %>%
    radiant.data::set_attr("radiant_dec", dec) %>%
    add_class(c("coxp.predict", "model.predict"))

  return(pred_full)
}


#' Print method for predictions from a Cox proportional hazards regression
#'
#' Prints a header describing the prediction followed by the first `n` rows.
#' The last six columns of `x` are treated as the prediction results (linear
#' predictor, hazard ratio and their confidence limits); every column before
#' them is treated as an explanatory variable.
#'
#' @param x Return value from `predict.coxp()`; a character message is printed
#'   as-is
#' @param ... Not used
#' @param n Number of rows to show; `-1` shows all rows
#'
#' @return `x`, invisibly
#'
#' @export
print.coxp.predict <- function(x, ..., n = 10) {
  if (is.character(x)) {
    cat(x, "\n")
    return(invisible(x))
  }

  # Work on a plain data frame
  x_df <- as.data.frame(x, stringsAsFactors = FALSE)

  df_name <- attr(x_df, "radiant_df_name") %||% "Unknown"
  time_var <- attr(x_df, "radiant_time") %||% "Unknown"
  status_var <- attr(x_df, "radiant_status") %||% "Unknown"
  conf_lev <- attr(x_df, "radiant_conf_lev") %||% 0.95
  dec <- attr(x_df, "radiant_dec") %||% 3
  # Percentile labels for the interval bounds, e.g. "2.5%" and "97.5%"
  ci_perc <- paste0(c(round((1 - conf_lev) / 2 * 100, 1),
                      round((1 + conf_lev) / 2 * 100, 1)), "%")

  total_cols <- ncol(x_df)
  result_count <- 6

  if (total_cols < result_count) {
    cat("Error: Not enough columns for prediction results (need at least 6 result columns).\n")
    return(invisible(x))
  }

  evar_count_actual <- total_cols - result_count
  # seq_len() yields an empty selection when there are no explanatory
  # columns, whereas 1:0 would select the first result column
  evar_cols_actual <- colnames(x_df)[seq_len(evar_count_actual)]

  new_result_names <- c(
    "lp",
    "HR",
    ci_perc[1],
    ci_perc[2],
    paste0("HR_", ci_perc[1]),
    paste0("HR_", ci_perc[2])
  )

  # Relabel the result columns with the interval percentiles; the combined
  # vector always has exactly total_cols names
  colnames(x_df) <- c(evar_cols_actual, new_result_names)

  cat("Cox Proportional Hazards Regression\n")
  cat("Data                 :", df_name, "\n")
  cat("Time variable        :", time_var, "\n")
  cat("Status variable      :", status_var, "\n")
  cat("Explanatory variables:", if (length(evar_cols_actual) > 0) paste(evar_cols_actual, collapse = ", ") else "None", "\n")
  cat("Confidence level     :", paste0(conf_lev * 100, "%"), "\n")
  cat("Total columns        :", total_cols, "(Explanatory:", evar_count_actual, ", Result:", result_count, ")\n")

  total_rows <- nrow(x_df)
  if (n == -1 || total_rows <= n) {
    cat("Rows shown           :", total_rows, "of", total_rows, "\n")
    out_df <- x_df
  } else {
    cat("Rows shown           :", n, "of", total_rows, "\n")
    out_df <- head(x_df, n)
  }
  cat("\n")

  # Format the six result columns with a fixed number of decimals
  numeric_cols <- (evar_count_actual + 1):total_cols
  out_df[, numeric_cols] <- lapply(out_df[, numeric_cols, drop = FALSE], function(col) {
    sprintf(paste0("%.", dec, "f"), as.numeric(col))
  })

  print(out_df, row.names = FALSE)
  invisible(x)
}

#' Store predicted hazard ratios from a Cox regression in a dataset
#'
#' Adds the `HR` column of a prediction object to `dataset` under the given
#' name. Rows are matched through the index returned by `indexr()`; rows not
#' covered by that index are set to `NA`.
#'
#' @param dataset Dataset to add the predictions to
#' @param object Return value from `predict.coxp()`
#' @param name Name of the new column. Only the first name is used when
#'   several are given (separated by commas, semicolons or whitespace);
#'   defaults to `"hr"` when empty
#' @param ... Not used
#'
#' @return `dataset` with the additional column
#'
#' @export
store.coxp.predict <- function(dataset, object, name = "hr", ...) {
  if (is.empty(name)) name <- "hr"
  name <- unlist(strsplit(name, "(\\s*,\\s*|\\s*;\\s*|\\s+)")) %>%
    gsub("\\s", "", .)
  # A leading separator (e.g. " hr") produces an empty first token; skip
  # empty tokens so the stored column never ends up without a name
  name <- name[nzchar(name)]
  name <- if (length(name) == 0) "hr" else name[1]

  pred_col <- "HR"
  if (!pred_col %in% colnames(object)) {
    stop("Prediction column 'HR' not found in prediction object.")
  }

  pred_df <- object[, pred_col, drop = FALSE]
  colnames(pred_df) <- name

  evar_actual <- attr(object, "radiant_evar_actual") %||%
    attr(object, "radiant_evar") %||% character(0)

  # The prediction command attribute may be absent; pass an empty command
  # rather than NULL in that case
  indr <- indexr(dataset, vars = evar_actual, filt = "", rows = NULL,
                 cmd = attr(object, "radiant_pred_cmd") %||% "")

  # Start from an all-NA column and fill in the rows selected by the index
  out_df <- as.data.frame(matrix(NA, nrow = nrow(dataset), ncol = 1), stringsAsFactors = FALSE)
  out_df[indr$ind, 1] <- pred_df[[1]]
  colnames(out_df) <- name

  dataset[, name] <- out_df
  dataset
}

#' Plot method for Cox proportional hazards regression
#'
#' @param x Return value from `coxp()`; a character message is returned
#'   unchanged
#' @param plots Character vector of plots to produce: `"coef"` (hazard ratios
#'   with confidence intervals), `"dist"` (distribution of the time, status
#'   and explanatory variables), `"vip"` (absolute coefficient size). `"pdp"`,
#'   `"pred_plot"` and `"influence"` currently only yield a placeholder
#'   message. `"none"` produces nothing
#' @param incl Variable names; only terms starting with one of these are shown
#'   in the coefficient plot
#' @param incl_int Currently not used
#' @param conf_lev Confidence level for the intervals in the coefficient plot
#' @param intercept When `FALSE`, terms matching "Intercept" are dropped from
#'   the coefficient plot
#' @param shiny When `TRUE` the combined plot is returned rather than printed
#' @param custom When `TRUE` the individual plot (or the list of plots and
#'   placeholder messages) is returned for further customization
#' @param ... Not used
#'
#' @return With `custom = TRUE`, a single plot or a named list. Otherwise the
#'   combined patchwork plot, or a character message when only placeholder
#'   messages were produced. Returns `invisible()` when there is nothing to
#'   plot.
#'
#' @export
plot.coxp <- function(x, plots = "none", incl = NULL, incl_int = NULL,
                      conf_lev = 0.95, intercept = FALSE,
                      shiny = FALSE, custom = FALSE, ...) {
  if (is.character(x)) return(x)
  # `plots` may hold several entries; only test the first one so the
  # condition passed to `||` always has length one
  if (is.empty(plots) || isTRUE(plots[1] == "none")) return(invisible())

  plot_list <- list()

  if ("coef" %in% plots) {
    # Coefficients and confidence limits, exponentiated to hazard ratios
    coef_df <- broom::tidy(x$model, conf.int = TRUE, conf.level = conf_lev)
    coef_df$hr <- exp(coef_df$estimate)
    coef_df$hr_low <- exp(coef_df$conf.low)
    coef_df$hr_high <- exp(coef_df$conf.high)

    if (!intercept) {
      coef_df <- coef_df[!grepl("Intercept", coef_df$term), ]
    }

    if (length(incl) > 0) {
      incl_regex <- paste0("^(", paste(incl, collapse = "|"), ")")
      coef_df <- coef_df[grepl(incl_regex, coef_df$term), ]
    }

    if (nrow(coef_df) == 0) {
      plot_list[["coef"]] <- "** No coefficients to plot **"
    } else {
      p <- ggplot(coef_df, aes(x = term, y = hr, ymin = hr_low, ymax = hr_high)) +
        geom_pointrange() +
        geom_hline(yintercept = 1, linetype = "dashed", color = "red") +
        scale_x_discrete(limits = rev) +
        coord_flip() +
        labs(x = "", y = "Hazard Ratio (HR)", title = "Coefficient Plot (HR)")
      plot_list[["coef"]] <- p
    }
  }

  if ("dist" %in% plots) {
    # Prefer the estimation data stored by coxp(); the model frame inside the
    # coxph fit is only present when the model was fitted with model = TRUE
    data <- x$dataset
    if (is.null(data)) data <- x$model$model
    vars <- c(x$time, x$status, x$evar)
    for (v in vars) {
      if (v %in% colnames(data)) {
        p <- visualize(data, xvar = v, bins = 30, custom = TRUE)
        plot_list[[paste0("dist_", v)]] <- p
      }
    }
  }

  if ("vip" %in% plots) {
    coef_df <- broom::tidy(x$model)
    coef_df$Importance <- abs(coef_df$estimate)
    coef_df <- coef_df[order(coef_df$Importance, decreasing = TRUE), ]
    p <- visualize(coef_df, xvar = "term", yvar = "Importance", type = "bar", custom = TRUE) +
      coord_flip() + labs(title = "Variable Importance (|coef|)")
    plot_list[["vip"]] <- p
  }

  if ("pdp" %in% plots || "pred_plot" %in% plots) {
    plot_list[["pdp"]] <- "** PDP not yet implemented for Cox **"
  }

  if ("influence" %in% plots) {
    plot_list[["influence"]] <- "** Influence plot not yet implemented **"
  }

  # Output
  if (length(plot_list) == 0) return(invisible())
  if (custom) {
    if (length(plot_list) == 1) return(plot_list[[1]]) else return(plot_list)
  }

  # Placeholder messages are character strings and cannot be combined into a
  # patchwork; separate them from the actual plots
  is_msg <- vapply(plot_list, is.character, logical(1))
  msgs <- unlist(plot_list[is_msg], use.names = FALSE)
  plot_list <- plot_list[!is_msg]
  if (length(plot_list) == 0) return(paste(msgs, collapse = "\n"))
  if (length(msgs) > 0) message(paste(msgs, collapse = "\n"))

  patchwork::wrap_plots(plot_list, ncol = 2) %>%
    (function(x) if (isTRUE(shiny)) x else print(x))
}