import argparse import copy import os, sys import pathlib current_dir = pathlib.Path(__file__).parent.resolve() while "cls_train" != os.path.basename(current_dir): current_dir = current_dir.parent sys.path.append(current_dir.as_posix()) os.environ["OMP_NUM_THREADS"] = "1" from cls_utils.log_utils import get_logger from sklearn.metrics import classification_report import torch import torch.nn as nn from torch.utils.data import DataLoader from torch.distributed.optim import ZeroRedundancyOptimizer from torch.utils.data.distributed import DistributedSampler from torch.nn.parallel import DistributedDataParallel from pytorch_train.classification_model_2d3d import Net2d3d, Net2d, Net3d, init_modules from data.dataset_2d3d import ClassificationDataset2d3d, ClassificationDataset3d, ClassificationDataset2d, custom_collate_fn_2d, custom_collate_fn_3d, custom_collate_fn_2d3d, cls_report_dict_to_string, ClassificationDatasetError2d, ClassificationDatasetError3d, ClassificationDatasetError2d3d, custom_collate_fn_2d_error, custom_collate_fn_3d_error, custom_collate_fn_2d3d_error from pytorch_train.encoder_cls import S3dClassifier as NetS3d from pytorch_train.encoder_cls import D2dClassifier as NetD2d from data.dataset_2d3d import ClassificationDatasetS3d, ClassificationDatasetErrorS3d, custom_collate_fn_s3d, custom_collate_fn_s3d_error, ClassificationDatasetD2d, ClassificationDatasetErrorD2d, custom_collate_fn_d2d, custom_collate_fn_d2d_error from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, confusion_matrix import multiprocessing as mp import traceback import pandas as pd from pathlib import Path from torch.optim import lr_scheduler, Adam import random import numpy as np def set_seed(seed=1004): random.seed(seed) np.random.seed(seed) torch.manual_seed(seed) torch.cuda.manual_seed_all(seed) def get_model_params(model): params = "" with torch.no_grad(): for name, param in model.named_parameters(): if "net2d.diff_classifier.weight" in name or "net3d.diff_classifier.weight" in name: params += f"{name}\n{param[0][:20]}\n\n" return params def test_on_single_gpu(init_model, config, device_index, train_2d3d_data_2d_csv_file, val_2d3d_data_2d_csv_file, test_2d3d_data_2d_csv_file, train_2d3d_data_3d_csv_file, val_2d3d_data_3d_csv_file, test_2d3d_data_3d_csv_file, train_csv_file, val_csv_file, test_csv_file): logger = config['logger'] net_id = config['net_id'] save_dir = config['save_dir'] current_device = torch.device(f"cuda:{device_index}") test_train_dataset_info_dict = config['test_train_dataset_info_dict'] test_val_dataset_info_dict = config['test_val_dataset_info_dict'] test_test_dataset_info_dict = config['test_test_dataset_info_dict'] train_batch_size = config['train_batch_size'] val_batch_size = config['val_batch_size'] test_batch_size = config['test_batch_size'] num_workers = config['num_workers'] if net_id == "2d": train_dataset = ClassificationDataset2d(train_csv_file, data_info="train_dataset_2d") val_dataset = ClassificationDataset2d(val_csv_file, data_info="val_dataset_2d") test_dataset = ClassificationDataset2d(test_csv_file, data_info="test_dataset_2d") custom_collate_fn = custom_collate_fn_2d elif net_id == "3d": train_dataset = ClassificationDataset3d(train_csv_file, data_info="train_dataset_3d") val_dataset = ClassificationDataset3d(val_csv_file, data_info="val_dataset_3d") test_dataset = ClassificationDataset3d(test_csv_file, data_info="test_dataset_3d") custom_collate_fn = custom_collate_fn_3d elif net_id == "2d3d": train_dataset = ClassificationDataset2d3d(train_2d3d_data_2d_csv_file, train_2d3d_data_3d_csv_file, data_info="train_dataset_2d3d") val_dataset = ClassificationDataset2d3d(val_2d3d_data_2d_csv_file, val_2d3d_data_3d_csv_file, data_info="val_dataset_2d3d") test_dataset = ClassificationDataset2d3d(test_2d3d_data_2d_csv_file, test_2d3d_data_3d_csv_file, data_info="test_dataset_2d3d") custom_collate_fn = custom_collate_fn_2d3d elif net_id == "s3d": train_dataset = ClassificationDatasetS3d(train_csv_file, data_info="train_dataset_S3d") val_dataset = ClassificationDatasetS3d(val_csv_file, data_info="val_dataset_S3d") test_dataset = ClassificationDatasetS3d(test_csv_file, data_info="test_dataset_s3d") custom_collate_fn = custom_collate_fn_s3d elif net_id == "d2d": train_dataset = ClassificationDatasetD2d(train_csv_file, data_info="train_dataset_d2d") val_dataset = ClassificationDatasetD2d(val_csv_file, data_info="val_dataset_d2d") test_dataset = ClassificationDatasetD2d(test_csv_file, data_info="test_dataset_d2d") custom_collate_fn = custom_collate_fn_d2d else: raise ValueError(f"net_id {net_id} not supported") logger.info(f"before start test, train_dataset: {len(train_dataset)}\n val_dataset: {len(val_dataset)}\n test_dataset: {len(test_dataset)}") train_data_loader = DataLoader( train_dataset, batch_size=train_batch_size, drop_last=False, shuffle=False, num_workers=num_workers, collate_fn=custom_collate_fn ) val_data_loader = DataLoader( val_dataset, batch_size=val_batch_size, drop_last=False, shuffle=False, num_workers=num_workers, collate_fn=custom_collate_fn ) test_data_loader = DataLoader( test_dataset, batch_size=test_batch_size, drop_last=False, shuffle=False, num_workers=num_workers, collate_fn=custom_collate_fn ) logger.info(f"local rank: {device_index}, train data total batch: {len(train_data_loader)}, val data total batch: {len(val_data_loader)}, test data total batch: {len(test_data_loader)}") params = get_model_params(copy.deepcopy(init_model)) logger.info(f"local rank: {device_index}\nbefore start test, load init, params:\n{params}") # 评测训练集 torch.cuda.empty_cache() test_train_dataset_flag = test_train_dataset_info_dict["test_train_dataset_flag"] if test_train_dataset_flag: test_train_epoch_list = test_train_dataset_info_dict["test_train_epoch_list"] test_train_epoch_list = sorted(test_train_epoch_list) logger.info(f"local rank: {device_index}, start test_[train]_dataset") for idx_test_train_epoch in test_train_epoch_list: idx_epoch_pt_file = f"train_epoch_{idx_test_train_epoch}_net_{net_id}.pt" idx_epoch_pt_file = os.path.join(save_dir, idx_epoch_pt_file) model = copy.deepcopy(init_model) model.load_state_dict(torch.load(idx_epoch_pt_file, map_location="cpu")) model = model.to(current_device) logger.info(f"local rank: {device_index}, test_[train]_dataset, start test_[train]_dataset epoch: {idx_test_train_epoch}") logger.info(f"local rank: {device_index}, test_[train]_dataset, after load epoch {idx_test_train_epoch}\nparams:\n{get_model_params(model)}") model.eval() train_all_label_id = [] train_all_label_id_2d = [] train_all_label_id_3d = [] train_all_z_index = [] train_all_z_index_2d = [] train_all_z_index_3d = [] train_all_rotate_count = [] train_all_rotate_count_2d = [] train_all_rotate_count_3d = [] train_all_labels = [] train_all_labels_2d = [] train_all_labels_3d = [] train_all_preds = [] train_all_preds_2d = [] train_all_preds_3d = [] model.eval() with torch.no_grad(): logger.info(f"local rank: {device_index}, test_[train]_dataset, start test_[train]_dataset epoch: {idx_test_train_epoch}, train_data_loader: {len(train_data_loader)}") for step, batch_data in enumerate(train_data_loader): if net_id == "2d3d": label_id_2d, data_2d, label_2d, z_index_2d, rotate_count_2d, label_id_3d, data_3d, label_3d, z_index_3d, rotate_count_3d = batch_data train_all_label_id_2d.extend(label_id_2d) train_all_label_id_3d.extend(label_id_3d) train_all_z_index_2d.extend(z_index_2d) train_all_z_index_3d.extend(z_index_3d) train_all_rotate_count_2d.extend(rotate_count_2d) train_all_rotate_count_3d.extend(rotate_count_3d) train_all_labels_2d.append(label_2d.cpu()) train_all_labels_3d.append(label_3d.cpu()) data_2d = data_2d.to(current_device) data_3d = data_3d.to(current_device) label_2d = label_2d.to(current_device) label_3d = label_3d.to(current_device) y_pred_2d, y_pred_3d = model(data_2d, data_3d) train_all_preds_2d.append(y_pred_2d.detach().cpu()) train_all_preds_3d.append(y_pred_3d.detach().cpu()) else: label_id, data, label, z_index, rotate_count = batch_data train_all_label_id.extend(label_id) train_all_z_index.extend(z_index) train_all_rotate_count.extend(rotate_count) train_all_labels.append(label.cpu()) data = data.to(current_device) label = label.to(current_device) y_pred = model(data) train_all_preds.append(y_pred.detach().cpu()) logger.info(f"local rank: {device_index}, test_[train]_dataset, step: {step}, train_data_loader: {len(train_data_loader)}") if net_id == "2d3d": del model, data_2d, data_3d, label_2d, label_3d, y_pred_2d, y_pred_3d, z_index_2d, z_index_3d, rotate_count_2d, rotate_count_3d else: del model, data, label, y_pred, z_index, rotate_count torch.cuda.empty_cache() if net_id == "2d3d": train_all_label_id_2d = train_all_label_id_2d[:] train_all_label_id_3d = train_all_label_id_3d[:] train_all_net_type_2d = ["2d_net"] * len(train_all_label_id_2d) train_all_net_type_3d = ["3d_net"] * len(train_all_label_id_3d) train_all_labels_2d = torch.cat(train_all_labels_2d, dim=0) train_all_labels_3d = torch.cat(train_all_labels_3d, dim=0) train_all_preds_2d = torch.cat(train_all_preds_2d, dim=0) train_all_preds_3d = torch.cat(train_all_preds_3d, dim=0) concat_label_id = train_all_label_id_2d + train_all_label_id_3d concat_z_index = train_all_z_index_2d[:] + train_all_z_index_3d[:] concat_rotate_count = train_all_rotate_count_2d[:] + train_all_rotate_count_3d[:] concat_net_type = train_all_net_type_2d + train_all_net_type_3d concat_label = torch.cat([train_all_labels_2d, train_all_labels_3d], dim=0) concat_pred = torch.cat([train_all_preds_2d, train_all_preds_3d], dim=0) logger.info(f"local rank: {device_index}, test_[train]_dataset, concat_label_id: {len(concat_label_id)}, concat_label: {concat_label.shape}, concat_pred: {concat_pred.shape}") lens = len(concat_label_id) idx_test_train_epoch_df = pd.DataFrame({ "test_type": ["train_dataset"] * lens, "local_rank": [device_index] * lens, "net_id": [net_id] * lens, "net_type": concat_net_type, "epoch": [idx_test_train_epoch] * lens, "label_id": concat_label_id, "z_index": concat_z_index, "rotate_count": concat_rotate_count, "label": concat_label.tolist(), "pred": concat_pred.tolist() }) idx_test_train_epoch_df_file = f"local_rank_{device_index}_test_type_train_dataset_epoch_{idx_test_train_epoch}_net_{net_id}.csv" idx_test_train_epoch_df_file = os.path.join(save_dir, idx_test_train_epoch_df_file) idx_test_train_epoch_df.to_csv(idx_test_train_epoch_df_file, index=False, encoding="utf-8") logger.info(f"local rank: {device_index}, test_[train]_dataset epoch: {idx_test_train_epoch}, save to {idx_test_train_epoch_df_file}") else: train_all_label_id = train_all_label_id[:] train_all_labels = torch.cat(train_all_labels, dim=0) train_all_preds = torch.cat(train_all_preds, dim=0) lens = len(train_all_label_id) idx_test_train_epoch_df = pd.DataFrame({ "test_type": ["train_dataset"] * lens, "local_rank": [device_index] * lens, "net_id": [net_id] * lens, "net_type": [f"{net_id}_net"] * lens, "epoch": [idx_test_train_epoch] * lens, "label_id": train_all_label_id, "z_index": train_all_z_index[:], "rotate_count": train_all_rotate_count[:], "label": train_all_labels.tolist(), "pred": train_all_preds.tolist() }) idx_test_train_epoch_df_file = f"local_rank_{device_index}_test_type_train_dataset_epoch_{idx_test_train_epoch}_net_{net_id}.csv" idx_test_train_epoch_df_file = os.path.join(save_dir, idx_test_train_epoch_df_file) idx_test_train_epoch_df.to_csv(idx_test_train_epoch_df_file, index=False, encoding="utf-8") logger.info(f"local rank: {device_index}, test_[train]_dataset epoch: {idx_test_train_epoch}, save to {idx_test_train_epoch_df_file}") logger.info(f"local rank: {device_index}, test_[train]_dataset finished") # 评测验证集 torch.cuda.empty_cache() test_val_dataset_flag = test_val_dataset_info_dict["test_val_dataset_flag"] if test_val_dataset_flag: test_val_epoch_list = test_val_dataset_info_dict["test_val_epoch_list"] test_val_epoch_list = sorted(test_val_epoch_list) logger.info(f"local rank: {device_index}, start test_[val]_dataset") for idx_test_val_epoch in test_val_epoch_list: idx_epoch_pt_file = f"train_epoch_{idx_test_val_epoch}_net_{net_id}.pt" idx_epoch_pt_file = os.path.join(save_dir, idx_epoch_pt_file) model = copy.deepcopy(init_model) model.load_state_dict(torch.load(idx_epoch_pt_file, map_location="cpu")) model = model.to(current_device) logger.info(f"local rank: {device_index}, test_[val]_dataset, start test_[val]_dataset epoch: {idx_test_val_epoch}") logger.info(f"local rank: {device_index}, test_[val]_dataset, after load epoch {idx_test_val_epoch}\nparams:\n{get_model_params(model)}") model.eval() val_all_label_id = [] val_all_label_id_2d = [] val_all_label_id_3d = [] val_all_z_index = [] val_all_z_index_2d = [] val_all_z_index_3d = [] val_all_rotate_count = [] val_all_rotate_count_2d = [] val_all_rotate_count_3d = [] val_all_labels = [] val_all_labels_2d = [] val_all_labels_3d = [] val_all_preds = [] val_all_preds_2d = [] val_all_preds_3d = [] model.eval() with torch.no_grad(): logger.info(f"local rank: {device_index}, test_[val]_dataset, start test_[val]_dataset epoch: {idx_test_val_epoch}, val_data_loader: {len(val_data_loader)}") for step, batch_data in enumerate(val_data_loader): if net_id == "2d3d": label_id_2d, data_2d, label_2d, z_index_2d, rotate_count_2d, label_id_3d, data_3d, label_3d, z_index_3d, rotate_count_3d = batch_data val_all_label_id_2d.extend(label_id_2d) val_all_label_id_3d.extend(label_id_3d) val_all_z_index_2d.extend(z_index_2d) val_all_z_index_3d.extend(z_index_3d) val_all_rotate_count_2d.extend(rotate_count_2d) val_all_rotate_count_3d.extend(rotate_count_3d) val_all_labels_2d.append(label_2d.cpu()) val_all_labels_3d.append(label_3d.cpu()) data_2d = data_2d.to(current_device) data_3d = data_3d.to(current_device) label_2d = label_2d.to(current_device) label_3d = label_3d.to(current_device) y_pred_2d, y_pred_3d = model(data_2d, data_3d) val_all_preds_2d.append(y_pred_2d.detach().cpu()) val_all_preds_3d.append(y_pred_3d.detach().cpu()) else: label_id, data, label, z_index, rotate_count = batch_data val_all_label_id.extend(label_id) val_all_z_index.extend(z_index) val_all_rotate_count.extend(rotate_count) val_all_labels.append(label.cpu()) data = data.to(current_device) label = label.to(current_device) y_pred = model(data) val_all_preds.append(y_pred.detach().cpu()) logger.info(f"local rank: {device_index}, test_[val]_dataset, step: {step}, val_data_loader: {len(val_data_loader)}") if net_id == "2d3d": del model, data_2d, data_3d, label_2d, label_3d, y_pred_2d, y_pred_3d, z_index_2d, z_index_3d, rotate_count_2d, rotate_count_3d else: del model, data, label, y_pred, z_index, rotate_count torch.cuda.empty_cache() if net_id == "2d3d": val_all_label_id_2d = val_all_label_id_2d[:] val_all_label_id_3d = val_all_label_id_3d[:] val_all_net_type_2d = ["2d_net"] * len(val_all_label_id_2d) val_all_net_type_3d = ["3d_net"] * len(val_all_label_id_3d) val_all_labels_2d = torch.cat(val_all_labels_2d, dim=0) val_all_labels_3d = torch.cat(val_all_labels_3d, dim=0) val_all_preds_2d = torch.cat(val_all_preds_2d, dim=0) val_all_preds_3d = torch.cat(val_all_preds_3d, dim=0) concat_label_id = val_all_label_id_2d + val_all_label_id_3d concat_z_index = val_all_z_index_2d[:] + val_all_z_index_3d[:] concat_rotate_count = val_all_rotate_count_2d[:] + val_all_rotate_count_3d[:] concat_net_type = val_all_net_type_2d + val_all_net_type_3d concat_label = torch.cat([val_all_labels_2d, val_all_labels_3d], dim=0) concat_pred = torch.cat([val_all_preds_2d, val_all_preds_3d], dim=0) logger.info(f"local rank: {device_index}, test_[val]_dataset, concat_label_id: {len(concat_label_id)}, concat_label: {concat_label.shape}, concat_pred: {concat_pred.shape}") lens = len(concat_label_id) idx_test_val_epoch_df = pd.DataFrame({ "test_type": ["val_dataset"] * lens, "local_rank": [device_index] * lens, "net_id": [net_id] * lens, "net_type": concat_net_type, "epoch": [idx_test_val_epoch] * lens, "label_id": concat_label_id, "z_index": concat_z_index, "rotate_count": concat_rotate_count, "label": concat_label.tolist(), "pred": concat_pred.tolist() }) idx_test_val_epoch_df_file = f"local_rank_{device_index}_test_type_val_dataset_epoch_{idx_test_val_epoch}_net_{net_id}.csv" idx_test_val_epoch_df_file = os.path.join(save_dir, idx_test_val_epoch_df_file) idx_test_val_epoch_df.to_csv(idx_test_val_epoch_df_file, index=False, encoding="utf-8") logger.info(f"local rank: {device_index}, test_[val]_dataset epoch: {idx_test_val_epoch}, save to {idx_test_val_epoch_df_file}") else: val_all_label_id = val_all_label_id[:] val_all_labels = torch.cat(val_all_labels, dim=0) val_all_preds = torch.cat(val_all_preds, dim=0) lens = len(val_all_label_id) idx_test_val_epoch_df = pd.DataFrame({ "test_type": ["val_dataset"] * lens, "local_rank": [device_index] * lens, "net_id": [net_id] * lens, "net_type": [f"{net_id}_net"] * lens, "epoch": [idx_test_val_epoch] * lens, "label_id": val_all_label_id, "z_index": val_all_z_index[:], "rotate_count": val_all_rotate_count[:], "label": val_all_labels.tolist(), "pred": val_all_preds.tolist() }) idx_test_val_epoch_df_file = f"local_rank_{device_index}_test_type_val_dataset_epoch_{idx_test_val_epoch}_net_{net_id}.csv" idx_test_val_epoch_df_file = os.path.join(save_dir, idx_test_val_epoch_df_file) idx_test_val_epoch_df.to_csv(idx_test_val_epoch_df_file, index=False, encoding="utf-8") logger.info(f"local rank: {device_index}, test_[val]_dataset epoch: {idx_test_val_epoch}, save to {idx_test_val_epoch_df_file}") logger.info(f"local rank: {device_index}, test_[val]_dataset finished") # 评测测试集 torch.cuda.empty_cache() test_test_dataset_flag = test_test_dataset_info_dict["test_test_dataset_flag"] if test_test_dataset_flag: test_test_epoch_list = test_test_dataset_info_dict["test_test_epoch_list"] test_test_epoch_list = sorted(test_test_epoch_list) logger.info(f"local rank: {device_index}, start test_[test]_dataset") for idx_test_test_epoch in test_test_epoch_list: idx_epoch_pt_file = f"train_epoch_{idx_test_test_epoch}_net_{net_id}.pt" idx_epoch_pt_file = os.path.join(save_dir, idx_epoch_pt_file) model = copy.deepcopy(init_model) model.load_state_dict(torch.load(idx_epoch_pt_file, map_location="cpu")) model = model.to(current_device) logger.info(f"local rank: {device_index}, test_[test]_dataset, start test_[test]_dataset epoch: {idx_test_test_epoch}") logger.info(f"local rank: {device_index}, test_[test]_dataset, after load epoch {idx_test_test_epoch}\nparams:\n{get_model_params(model)}") model.eval() test_all_label_id = [] test_all_label_id_2d = [] test_all_label_id_3d = [] test_all_z_index = [] test_all_z_index_2d = [] test_all_z_index_3d = [] test_all_rotate_count = [] test_all_rotate_count_2d = [] test_all_rotate_count_3d = [] test_all_labels = [] test_all_labels_2d = [] test_all_labels_3d = [] test_all_preds = [] test_all_preds_2d = [] test_all_preds_3d = [] model.eval() with torch.no_grad(): logger.info(f"local rank: {device_index}, test_[test]_dataset, start test_[test]_dataset epoch: {idx_test_test_epoch}, test_data_loader: {len(test_data_loader)}") for step, batch_data in enumerate(test_data_loader): if net_id == "2d3d": label_id_2d, data_2d, label_2d, z_index_2d, rotate_count_2d, label_id_3d, data_3d, label_3d, z_index_3d, rotate_count_3d = batch_data test_all_label_id_2d.extend(label_id_2d) test_all_label_id_3d.extend(label_id_3d) test_all_z_index_2d.extend(z_index_2d) test_all_z_index_3d.extend(z_index_3d) test_all_rotate_count_2d.extend(rotate_count_2d) test_all_rotate_count_3d.extend(rotate_count_3d) test_all_labels_2d.append(label_2d.cpu()) test_all_labels_3d.append(label_3d.cpu()) data_2d = data_2d.to(current_device) data_3d = data_3d.to(current_device) label_2d = label_2d.to(current_device) label_3d = label_3d.to(current_device) y_pred_2d, y_pred_3d = model(data_2d, data_3d) test_all_preds_2d.append(y_pred_2d.detach().cpu()) test_all_preds_3d.append(y_pred_3d.detach().cpu()) else: label_id, data, label, z_index, rotate_count = batch_data test_all_label_id.extend(label_id) test_all_z_index.extend(z_index) test_all_rotate_count.extend(rotate_count) test_all_labels.append(label.cpu()) data = data.to(current_device) label = label.to(current_device) y_pred = model(data) test_all_preds.append(y_pred.detach().cpu()) logger.info(f"local rank: {device_index}, test_[test]_dataset, step: {step}, test_data_loader: {len(test_data_loader)}") if net_id == "2d3d": del model, data_2d, data_3d, label_2d, label_3d, y_pred_2d, y_pred_3d, z_index_2d, z_index_3d, rotate_count_2d, rotate_count_3d else: del model, data, label, y_pred, z_index, rotate_count torch.cuda.empty_cache() if net_id == "2d3d": test_all_label_id_2d = test_all_label_id_2d[:] test_all_label_id_3d = test_all_label_id_3d[:] test_all_net_type_2d = ["2d_net"] * len(test_all_label_id_2d) test_all_net_type_3d = ["3d_net"] * len(test_all_label_id_3d) test_all_labels_2d = torch.cat(test_all_labels_2d, dim=0) test_all_labels_3d = torch.cat(test_all_labels_3d, dim=0) test_all_preds_2d = torch.cat(test_all_preds_2d, dim=0) test_all_preds_3d = torch.cat(test_all_preds_3d, dim=0) concat_label_id = test_all_label_id_2d + test_all_label_id_3d concat_z_index = test_all_z_index_2d[:] + test_all_z_index_3d[:] concat_rotate_count = test_all_rotate_count_2d[:] + test_all_rotate_count_3d[:] concat_net_type = test_all_net_type_2d + test_all_net_type_3d concat_label = torch.cat([test_all_labels_2d, test_all_labels_3d], dim=0) concat_pred = torch.cat([test_all_preds_2d, test_all_preds_3d], dim=0) logger.info(f"local rank: {device_index}, test_[test]_dataset, concat_label_id: {len(concat_label_id)}, concat_label: {concat_label.shape}, concat_pred: {concat_pred.shape}") lens = len(concat_label_id) idx_test_test_epoch_df = pd.DataFrame({ "test_type": ["test_dataset"] * lens, "local_rank": [device_index] * lens, "net_id": [net_id] * lens, "net_type": concat_net_type, "epoch": [idx_test_test_epoch] * lens, "label_id": concat_label_id, "z_index": concat_z_index, "rotate_count": concat_rotate_count, "label": concat_label.tolist(), "pred": concat_pred.tolist() }) idx_test_test_epoch_df_file = f"local_rank_{device_index}_test_type_test_dataset_epoch_{idx_test_test_epoch}_net_{net_id}.csv" idx_test_test_epoch_df_file = os.path.join(save_dir, idx_test_test_epoch_df_file) idx_test_test_epoch_df.to_csv(idx_test_test_epoch_df_file, index=False, encoding="utf-8") logger.info(f"local rank: {device_index}, test_[test]_dataset epoch: {idx_test_test_epoch}, save to {idx_test_test_epoch_df_file}") else: test_all_label_id = test_all_label_id[:] test_all_labels = torch.cat(test_all_labels, dim=0) test_all_preds = torch.cat(test_all_preds, dim=0) lens = len(test_all_label_id) idx_test_test_epoch_df = pd.DataFrame({ "test_type": ["test_dataset"] * lens, "local_rank": [device_index] * lens, "net_id": [net_id] * lens, "net_type": [f"{net_id}_net"] * lens, "epoch": [idx_test_test_epoch] * lens, "label_id": test_all_label_id, "z_index": test_all_z_index[:], "rotate_count": test_all_rotate_count[:], "label": test_all_labels.tolist(), "pred": test_all_preds.tolist() }) idx_test_test_epoch_df_file = f"local_rank_{device_index}_test_type_test_dataset_epoch_{idx_test_test_epoch}_net_{net_id}.csv" idx_test_test_epoch_df_file = os.path.join(save_dir, idx_test_test_epoch_df_file) idx_test_test_epoch_df.to_csv(idx_test_test_epoch_df_file, index=False, encoding="utf-8") logger.info(f"local rank: {device_index}, test_[test]_dataset epoch: {idx_test_test_epoch}, save to {idx_test_test_epoch_df_file}") logger.info(f"local rank: {device_index}, test_[test]_dataset finished") def start_test_ddp(model, config): logger = config['logger'] net_id = config['net_id'] train_2d3d_data_2d_csv_file = config['train_2d3d_data_2d_csv_file'] val_2d3d_data_2d_csv_file = config['val_2d3d_data_2d_csv_file'] test_2d3d_data_2d_csv_file = config['test_2d3d_data_2d_csv_file'] train_2d3d_data_3d_csv_file = config['train_2d3d_data_3d_csv_file'] val_2d3d_data_3d_csv_file = config['val_2d3d_data_3d_csv_file'] test_2d3d_data_3d_csv_file = config['test_2d3d_data_3d_csv_file'] train_csv_file = config['train_csv_file'] val_csv_file = config['val_csv_file'] test_csv_file = config['test_csv_file'] device_index = config['device_index'] process_count = len(device_index) save_dir = config['save_dir'] if net_id == "2d3d": train_data_2d_df = pd.read_csv(train_2d3d_data_2d_csv_file, header=0, encoding="utf-8") train_data_3d_df = pd.read_csv(train_2d3d_data_3d_csv_file, header=0, encoding="utf-8") val_data_2d_df = pd.read_csv(val_2d3d_data_2d_csv_file, header=0, encoding="utf-8") val_data_3d_df = pd.read_csv(val_2d3d_data_3d_csv_file, header=0, encoding="utf-8") test_data_2d_df = pd.read_csv(test_2d3d_data_2d_csv_file, header=0, encoding="utf-8") test_data_3d_df = pd.read_csv(test_2d3d_data_3d_csv_file, header=0, encoding="utf-8") train_data_2d_file_list = [ os.path.join(save_dir, f"local_rank_{idx_process}_test_type_train_data_2d_dataset_net_{net_id}.csv") for idx_process in range(process_count) ] train_data_3d_file_list = [ os.path.join(save_dir, f"local_rank_{idx_process}_test_type_train_data_3d_dataset_net_{net_id}.csv") for idx_process in range(process_count) ] val_data_2d_file_list = [ os.path.join(save_dir, f"local_rank_{idx_process}_test_type_val_data_2d_dataset_net_{net_id}.csv") for idx_process in range(process_count) ] val_data_3d_file_list = [ os.path.join(save_dir, f"local_rank_{idx_process}_test_type_val_data_3d_dataset_net_{net_id}.csv") for idx_process in range(process_count) ] test_data_2d_file_list = [ os.path.join(save_dir, f"local_rank_{idx_process}_test_type_test_data_2d_dataset_net_{net_id}.csv") for idx_process in range(process_count) ] test_data_3d_file_list = [ os.path.join(save_dir, f"local_rank_{idx_process}_test_type_test_data_3d_dataset_net_{net_id}.csv") for idx_process in range(process_count) ] train_data_2d_df_list = np.array_split(train_data_2d_df, process_count) val_data_2d_df_list = np.array_split(val_data_2d_df, process_count) test_data_2d_df_list = np.array_split(test_data_2d_df, process_count) train_data_3d_df_list = np.array_split(train_data_3d_df, process_count) val_data_3d_df_list = np.array_split(val_data_3d_df, process_count) test_data_3d_df_list = np.array_split(test_data_3d_df, process_count) for idx_process in range(process_count): train_data_2d_df_list[idx_process].to_csv(train_data_2d_file_list[idx_process], index=False, encoding="utf-8") val_data_2d_df_list[idx_process].to_csv(val_data_2d_file_list[idx_process], index=False, encoding="utf-8") test_data_2d_df_list[idx_process].to_csv(test_data_2d_file_list[idx_process], index=False, encoding="utf-8") train_data_3d_df_list[idx_process].to_csv(train_data_3d_file_list[idx_process], index=False, encoding="utf-8") val_data_3d_df_list[idx_process].to_csv(val_data_3d_file_list[idx_process], index=False, encoding="utf-8") test_data_3d_df_list[idx_process].to_csv(test_data_3d_file_list[idx_process], index=False, encoding="utf-8") else: train_df = pd.read_csv(train_csv_file, header=0, encoding="utf-8") val_df = pd.read_csv(val_csv_file, header=0, encoding="utf-8") test_df = pd.read_csv(test_csv_file, header=0, encoding="utf-8") train_data_file_list = [ os.path.join(save_dir, f"local_rank_{idx_process}_test_type_train_dataset_net_{net_id}.csv") for idx_process in range(process_count) ] val_data_file_list = [ os.path.join(save_dir, f"local_rank_{idx_process}_test_type_val_dataset_net_{net_id}.csv") for idx_process in range(process_count) ] test_data_file_list = [ os.path.join(save_dir, f"local_rank_{idx_process}_test_type_test_dataset_net_{net_id}.csv") for idx_process in range(process_count) ] train_df_list = np.array_split(train_df, process_count) val_df_list = np.array_split(val_df, process_count) test_df_list = np.array_split(test_df, process_count) for idx_process in range(process_count): train_df_list[idx_process].to_csv(train_data_file_list[idx_process], index=False, encoding="utf-8") val_df_list[idx_process].to_csv(val_data_file_list[idx_process], index=False, encoding="utf-8") test_df_list[idx_process].to_csv(test_data_file_list[idx_process], index=False, encoding="utf-8") process_count = len(device_index) process_list = [] if net_id == "2d3d": train_data_file_list = [None] * process_count val_data_file_list = [None] * process_count test_data_file_list = [None] * process_count else: train_data_2d_df_list = [None] * process_count val_data_2d_df_list = [None] * process_count test_data_2d_df_list = [None] * process_count train_data_3d_df_list = [None] * process_count val_data_3d_df_list = [None] * process_count test_data_3d_df_list = [None] * process_count for idx in range(process_count): idx_process = mp.Process( target=test_on_single_gpu, args=( model, config, device_index[idx], train_data_2d_file_list[idx], val_data_2d_file_list[idx], test_data_2d_file_list[idx], train_data_3d_file_list[idx], val_data_3d_file_list[idx], test_data_3d_file_list[idx], train_data_file_list[idx], val_data_file_list[idx], test_data_file_list[idx] ) ) idx_process.start() process_list.append(idx_process) for idx_process in process_list: idx_process.join() logger.info("test_ddp finished") def test_ddp(model, config): try: set_seed() start_test_ddp(model, config) except Exception as e: print(f"error in test_ddp: {traceback.format_exc()}") raise e finally: print("test_ddp finished") def gather_test_result(config): net_id = config['net_id'] node_id = config['node_id'] date_id = config['date_id'] save_dir = config['save_dir'] logger = config['logger'] local_rank_list = config['device_index'] test_train_all_epoch_file = None test_val_all_epoch_file = None test_test_all_epoch_file = None def parse_label_id(df=None, label_id_column="label_id"): if df is None: return df for idx in range(len(df)): idx_label_id = df.loc[idx, label_id_column] idx_node = idx_label_id.split("_")[0] idx_label = idx_label_id.split("_")[1] df.loc[idx, "label_id"] = f"{int(float(idx_node))}_{int(float(idx_label))}" return df # 训练集评测 test_train_dataset_info_dict = config['test_train_dataset_info_dict'] test_train_dataset_flag = test_train_dataset_info_dict['test_train_dataset_flag'] if test_train_dataset_flag: print(f"start gather test_[train]_dataset") test_train_epoch_list = test_train_dataset_info_dict['test_train_epoch_list'] test_train_df_list = [] for idx_test_train_epoch in test_train_epoch_list: for idx_local_rank in local_rank_list: idx_test_train_epoch_df_file = f"local_rank_{idx_local_rank}_test_type_train_dataset_epoch_{idx_test_train_epoch}_net_{net_id}.csv" idx_test_train_epoch_df_file = os.path.join(config['save_dir'], idx_test_train_epoch_df_file) idx_df = pd.read_csv(idx_test_train_epoch_df_file) print(f"load idx_epoch_test_[train]_file: {idx_test_train_epoch_df_file}") test_train_df_list.append(idx_df) test_train_df = pd.concat(test_train_df_list, ignore_index=True) test_train_df = parse_label_id(test_train_df) test_train_df = test_train_df.drop_duplicates(subset=["epoch", "net_type", "label_id", "z_index", "rotate_count"], keep="first") test_train_df = test_train_df.reset_index(drop=True) test_train_df = test_train_df.sort_values(by=["test_type", "local_rank", "epoch", "net_type", "label_id", "z_index", "rotate_count"], ascending=[True, True, True, True, True, True, True]) test_train_all_epoch_file = f"gather_test_[train]_dataset_net_{net_id}_node_{node_id}_date_{date_id}.csv" test_train_all_epoch_file = os.path.join(save_dir, test_train_all_epoch_file) test_train_df.to_csv(test_train_all_epoch_file, index=False, encoding="utf-8") print(f"gather test_[train]_dataset finished, save to {test_train_all_epoch_file}") # 验证集评测 test_val_dataset_info_dict = config['test_val_dataset_info_dict'] test_val_dataset_flag = test_val_dataset_info_dict['test_val_dataset_flag'] if test_val_dataset_flag: print(f"start gather test_[val]_dataset") test_val_epoch_list = test_val_dataset_info_dict['test_val_epoch_list'] test_val_df_list = [] for idx_test_val_epoch in test_val_epoch_list: for idx_local_rank in local_rank_list: idx_test_val_epoch_df_file = f"local_rank_{idx_local_rank}_test_type_val_dataset_epoch_{idx_test_val_epoch}_net_{net_id}.csv" idx_test_val_epoch_df_file = os.path.join(config['save_dir'], idx_test_val_epoch_df_file) idx_df = pd.read_csv(idx_test_val_epoch_df_file) print(f"load idx_epoch_test_[val]_file: {idx_test_val_epoch_df_file}") test_val_df_list.append(idx_df) test_val_df = pd.concat(test_val_df_list, ignore_index=True) test_val_df = parse_label_id(test_val_df) test_val_df = test_val_df.drop_duplicates(subset=["epoch", "net_type", "label_id", "z_index", "rotate_count"], keep="first") test_val_df = test_val_df.reset_index(drop=True) test_val_df = test_val_df.sort_values(by=["test_type", "local_rank", "epoch", "net_type", "label_id", "z_index", "rotate_count"], ascending=[True, True, True, True, True, True, True]) test_val_all_epoch_file = f"gather_test_[val]_dataset_net_{net_id}_node_{node_id}_date_{date_id}.csv" test_val_all_epoch_file = os.path.join(save_dir, test_val_all_epoch_file) test_val_df.to_csv(test_val_all_epoch_file, index=False, encoding="utf-8") print(f"gather test_[val]_dataset finished, save to {test_val_all_epoch_file}") # 测试集评测 test_test_dataset_info_dict = config['test_test_dataset_info_dict'] test_test_dataset_flag = test_test_dataset_info_dict['test_test_dataset_flag'] if test_test_dataset_flag: print(f"start gather test_[test]_dataset") test_test_epoch_list = test_test_dataset_info_dict['test_test_epoch_list'] test_test_df_list = [] for idx_test_test_epoch in test_test_epoch_list: for idx_local_rank in local_rank_list: idx_test_test_epoch_df_file = f"local_rank_{idx_local_rank}_test_type_test_dataset_epoch_{idx_test_test_epoch}_net_{net_id}.csv" idx_test_test_epoch_df_file = os.path.join(config['save_dir'], idx_test_test_epoch_df_file) idx_df = pd.read_csv(idx_test_test_epoch_df_file) print(f"load idx_epoch_test_[test]_file: {idx_test_test_epoch_df_file}") test_test_df_list.append(idx_df) test_test_df = pd.concat(test_test_df_list, ignore_index=True) test_test_df = parse_label_id(test_test_df) test_test_df = test_test_df.drop_duplicates(subset=["epoch", "net_type", "label_id", "z_index", "rotate_count"], keep="first") test_test_df = test_test_df.reset_index(drop=True) test_test_df = test_test_df.sort_values(by=["test_type", "local_rank", "epoch", "net_type", "label_id", "z_index", "rotate_count"], ascending=[True, True, True, True, True, True, True]) test_test_all_epoch_file = f"gather_test_[test]_dataset_net_{net_id}_node_{node_id}_date_{date_id}.csv" test_test_all_epoch_file = os.path.join(save_dir, test_test_all_epoch_file) test_test_df.to_csv(test_test_all_epoch_file, index=False, encoding="utf-8") print(f"gather test_[test]_dataset finished, save to {test_test_all_epoch_file}") return test_train_all_epoch_file, test_val_all_epoch_file, test_test_all_epoch_file def get_gather_result_train_dataset(gather_all_epoch_file, config, test_type="train_dataset", positive_threshold_list=[], negative_threshold_list=[], get_epoch_list=[]): ''' 筛选出数据集,预测错误的数据 ''' net_id = config['net_id'] save_dir = config['save_dir'] train_csv_file = config['train_csv_file'] train_df = pd.read_csv(train_csv_file) label_id_patient_id_dict = {} for idx in range(len(train_df)): idx_node = train_df.loc[idx, "node_time"] idx_label= train_df.loc[idx, "label_id"] idx_patient_id = train_df.loc[idx, "patient_id"] idx_label_id = f"{idx_node}_{idx_label}" label_id_patient_id_dict[idx_label_id] = idx_patient_id df = pd.read_csv(gather_all_epoch_file) df = df[df['test_type'] == test_type].reset_index(drop=True) print(f"df: {len(df)}") epoch_list = df['epoch'].unique().tolist() if get_epoch_list == [] else get_epoch_list epoch_list = sorted(epoch_list) df_epoch_list = [] df_threshold_list = [] epoch_net_id_list = [] epoch_net_type_list = [] epoch_label_id_list = [] epoch_z_index_list = [] epoch_roate_count_list = [] epoch_label_list = [] epoch_pred_list = [] for idx_epoch in epoch_list: df_epoch = df[df['epoch'] == idx_epoch].reset_index(drop=True) if net_id == "2d3d": for idx_positive_threshold, idx_negative_threshold in zip(positive_threshold_list, negative_threshold_list): idx_cp_df_epoch = copy.deepcopy(df_epoch) idx_df_2d = idx_cp_df_epoch[idx_cp_df_epoch['net_type'] == "2d_net"].reset_index(drop=True) idx_df_3d = idx_cp_df_epoch[idx_cp_df_epoch['net_type'] == "3d_net"].reset_index(drop=True) print(f"idx_df_2d: {len(idx_df_2d)}, idx_df_3d: {len(idx_df_3d)}") for idx in range(len(idx_df_2d)): idx_preds_2d = idx_df_2d.loc[idx, "pred"] idx_label_2d = idx_df_2d.loc[idx, "label"] idx_label_id = idx_df_2d.loc[idx, "label_id"] idx_z_index = idx_df_2d.loc[idx, "z_index"] idx_rotate_count = idx_df_2d.loc[idx, "rotate_count"] idx_net_type = idx_df_2d.loc[idx, "net_type"] if idx_label_2d == 1 and idx_preds_2d < idx_positive_threshold: df_epoch_list.append(idx_epoch) df_threshold_list.append(idx_positive_threshold) epoch_net_id_list.append(net_id) epoch_net_type_list.append(idx_net_type) epoch_label_id_list.append(idx_label_id) epoch_z_index_list.append(idx_z_index) epoch_roate_count_list.append(idx_rotate_count) epoch_label_list.append(idx_label_2d) epoch_pred_list.append(idx_preds_2d) elif idx_label_2d == 0 and idx_preds_2d > idx_negative_threshold: df_epoch_list.append(idx_epoch) df_threshold_list.append(idx_negative_threshold) epoch_net_id_list.append(net_id) epoch_net_type_list.append(idx_net_type) epoch_label_id_list.append(idx_label_id) epoch_z_index_list.append(idx_z_index) epoch_roate_count_list.append(idx_rotate_count) epoch_label_list.append(idx_label_2d) epoch_pred_list.append(idx_preds_2d) for idx in range(len(idx_df_3d)): idx_preds_3d = idx_df_3d.loc[idx, "pred"] idx_label_3d = idx_df_3d.loc[idx, "label"] idx_label_id = idx_df_3d.loc[idx, "label_id"] idx_z_index = idx_df_2d.loc[idx, "z_index"] idx_rotate_count = idx_df_2d.loc[idx, "rotate_count"] idx_net_type = idx_df_3d.loc[idx, "net_type"] if idx_label_3d == 1 and idx_preds_3d < idx_positive_threshold: df_epoch_list.append(idx_epoch) df_threshold_list.append(idx_positive_threshold) epoch_net_id_list.append(net_id) epoch_net_type_list.append(idx_net_type) epoch_label_id_list.append(idx_label_id) epoch_z_index_list.append(idx_z_index) epoch_roate_count_list.append(idx_rotate_count) epoch_label_list.append(idx_label_3d) epoch_pred_list.append(idx_preds_3d) elif idx_label_3d == 0 and idx_preds_3d > idx_negative_threshold: df_epoch_list.append(idx_epoch) df_threshold_list.append(idx_negative_threshold) epoch_net_id_list.append(net_id) epoch_net_type_list.append(idx_net_type) epoch_label_id_list.append(idx_label_id) epoch_z_index_list.append(idx_z_index) epoch_roate_count_list.append(idx_rotate_count) epoch_label_list.append(idx_label_3d) epoch_pred_list.append(idx_preds_3d) else: for idx_positive_threshold, idx_negative_threshold in zip(positive_threshold_list, negative_threshold_list): idx_cp_df_epoch = copy.deepcopy(df_epoch) for idx in range(len(idx_cp_df_epoch)): idx_preds = idx_cp_df_epoch.loc[idx, "pred"] idx_label = idx_cp_df_epoch.loc[idx, "label"] idx_label_id = idx_cp_df_epoch.loc[idx, "label_id"] idx_z_index = idx_df_2d.loc[idx, "z_index"] idx_rotate_count = idx_df_2d.loc[idx, "rotate_count"] idx_net_type = idx_cp_df_epoch.loc[idx, "net_type"] if idx_label == 1 and idx_preds < idx_positive_threshold: df_epoch_list.append(idx_epoch) df_threshold_list.append(idx_positive_threshold) epoch_net_id_list.append(net_id) epoch_net_type_list.append(idx_net_type) epoch_label_id_list.append(idx_label_id) epoch_z_index_list.append(idx_z_index) epoch_roate_count_list.append(idx_rotate_count) epoch_label_list.append(idx_label) epoch_pred_list.append(idx_preds) elif idx_label == 0 and idx_preds > idx_negative_threshold: df_threshold_list.append(idx_positive_threshold) epoch_net_id_list.append(net_id) epoch_net_type_list.append(idx_net_type) epoch_label_id_list.append(idx_label_id) epoch_z_index_list.append(idx_z_index) epoch_roate_count_list.append(idx_rotate_count) epoch_label_list.append(idx_label) epoch_pred_list.append(idx_preds) epoch_error_df = pd.DataFrame({ "epoch": df_epoch_list, "threshold": df_threshold_list, "net_id": epoch_net_id_list, "net_type": epoch_net_type_list, "label_id": epoch_label_id_list, "z_index": epoch_z_index_list, "rotate_count": epoch_roate_count_list, "patient_id": [label_id_patient_id_dict[idx_label_id] for idx_label_id in epoch_label_id_list], "label": epoch_label_list, "pred": epoch_pred_list, }) epoch_error_file = f"{gather_all_epoch_file.replace('.csv', '_epoch_预测不在阈值区间的数据.csv')}" epoch_error_file = os.path.join(save_dir, epoch_error_file) epoch_error_df.to_csv(epoch_error_file, index=False, encoding="utf-8") print(f"epoch_预测不在阈值区间的数据, save epoch_error_df to {epoch_error_file}") def get_gather_result(gather_all_epoch_file, config, test_type="test_dataset", positive_threshold_list=[], step=0.01): net_id = config['net_id'] save_dir = config['save_dir'] if net_id == "2d3d": if test_type == "train_dataset": origin_file_data_2d = config['train_2d3d_data_2d_csv_file'] origin_file_data_3d = config['train_2d3d_data_3d_csv_file'] print(f"origin_file_data_2d: {origin_file_data_2d}\norigin_file_data_3d: {origin_file_data_3d}") elif test_type == "test_dataset": origin_file_data_2d = config['test_2d3d_data_2d_csv_file'] origin_file_data_3d = config['test_2d3d_data_3d_csv_file'] else: origin_file_data_2d = config['val_2d3d_data_2d_csv_file'] origin_file_data_3d = config['val_2d3d_data_3d_csv_file'] origin_file_data_2d = os.path.join(save_dir, origin_file_data_2d) origin_file_data_3d = os.path.join(save_dir, origin_file_data_3d) origin_df_data_2d = pd.read_csv(origin_file_data_2d) origin_df_data_3d = pd.read_csv(origin_file_data_3d) origin_df = pd.concat([origin_df_data_2d, origin_df_data_3d], ignore_index=True) else: if test_type == "train_dataset": origin_file = config["train_csv_file"] elif test_type == "test_dataset": origin_file = config["test_csv_file"] else: origin_file = config["val_csv_file"] origin_file = os.path.join(save_dir, origin_file) origin_df = pd.read_csv(origin_file) label_id_patient_id_dict = {} for idx in range(len(origin_df)): idx_node = origin_df.loc[idx, "node"] idx_label_id = origin_df.loc[idx, "label_id"] idx_patient_id = origin_df.loc[idx, "patient_id"] idx_label_id = f"{int(float(idx_node))}_{int(float(idx_label_id))}" label_id_patient_id_dict[idx_label_id] = idx_patient_id df = pd.read_csv(gather_all_epoch_file) # 统计训练集多个评测结果 df = df[df['test_type'] == test_type].reset_index(drop=True) lens = len(df) # 去除重复的 df = df.drop_duplicates(subset=['test_type', 'net_id', 'net_type', 'epoch', 'label_id', 'label', 'z_index', 'rotate_count'], keep="first") print(f"去除重复的: {gather_all_epoch_file}\n{lens} --> {len(df)}") # # 2d_net, 取平均 # aggregated_df = df.groupby(['test_type', 'local_rank', 'net_id', 'net_type', 'epoch', 'label_id', 'label'])['pred'].mean().reset_index() # df = copy.deepcopy(aggregated_df) # avg_2d_df_file = f"{gather_all_epoch_file.replace('.csv', '_2d_net_avg.csv')}" # avg_2d_df_file = os.path.join(save_dir, avg_2d_df_file) # df.to_csv(avg_2d_df_file, index=False, encoding="utf-8") # print(f"save 2d_net_avg to {avg_2d_df_file}") if positive_threshold_list == []: positive_threshold_list = [round(i, 2) for i in np.arange(0.1, 1.0, 0.01)] print(f"positive_threshold_list: {len(positive_threshold_list)}, {positive_threshold_list}") epoch_list = df['epoch'].unique().tolist() epoch_list = sorted(epoch_list) print(f"epoch_list: {len(epoch_list)}, {epoch_list}") df_epoch_list = [] df_threshold_list = [] epoch_accuracy_list = [] epoch_precision_list = [] epoch_recall_list = [] epoch_f1_list = [] epoch_error_list = [] epoch_positive_accuracy_list = [] epoch_negative_accuracy_list = [] epoch_positive_precision_list = [] epoch_positive_recall_list = [] epoch_positive_f1_list = [] epoch_negative_precision_list = [] epoch_negative_recall_list = [] epoch_negative_f1_list = [] epoch_accuracy_2d_list = [] epoch_precision_2d_list = [] epoch_recall_2d_list = [] epoch_f1_2d_list = [] epoch_error_2d_list = [] epoch_positive_accuracy_2d_list = [] epoch_negative_accuracy_2d_list = [] epoch_positive_precision_2d_list = [] epoch_positive_recall_2d_list = [] epoch_positive_f1_2d_list = [] epoch_negative_precision_2d_list = [] epoch_negative_recall_2d_list = [] epoch_negative_f1_2d_list = [] epoch_accuracy_3d_list = [] epoch_precision_3d_list = [] epoch_recall_3d_list = [] epoch_f1_3d_list = [] epoch_error_3d_list = [] epoch_positive_accuracy_3d_list = [] epoch_negative_accuracy_3d_list = [] epoch_positive_precision_3d_list = [] epoch_positive_recall_3d_list = [] epoch_positive_f1_3d_list = [] epoch_negative_precision_3d_list = [] epoch_negative_recall_3d_list = [] epoch_negative_f1_3d_list = [] for idx_epoch in epoch_list: df_epoch = df[df['epoch'] == idx_epoch].reset_index(drop=True) if net_id == "2d3d": df_2d = df_epoch[df_epoch['net_type'] == "2d_net"].reset_index(drop=True) df_3d = df_epoch[df_epoch['net_type'] == "3d_net"].reset_index(drop=True) print(f"df_2d: {len(df_2d)}, df_3d: {len(df_3d)}") epoch_df_2d_preds = torch.tensor(df_2d['pred'].tolist(), dtype=torch.float32) epoch_df_2d_labels = torch.tensor(df_2d['label'].tolist()) epoch_df_3d_preds = torch.tensor(df_3d['pred'].tolist(), dtype=torch.float32) epoch_df_3d_labels = torch.tensor(df_3d['label'].tolist()) epoch_df_2d_label_id = df_2d['label_id'].tolist() epoch_df_3d_label_id = df_3d['label_id'].tolist() epoch_df_2d_z_index = df_2d['z_index'].tolist() epoch_df_3d_z_index = df_3d['z_index'].tolist() epoch_df_2d_rotate_count = df_2d['rotate_count'].tolist() epoch_df_3d_rotate_count = df_3d['rotate_count'].tolist() for idx_positive_threshold in positive_threshold_list: idx_epoch_df_2d_preds = copy.deepcopy(epoch_df_2d_preds) idx_epoch_df_3d_preds = copy.deepcopy(epoch_df_3d_preds) idx_epoch_df_2d_labels = copy.deepcopy(epoch_df_2d_labels) idx_epoch_df_3d_labels = copy.deepcopy(epoch_df_3d_labels) # 根据阈值将预测概率转换为二分类标签 idx_epoch_df_2d_preds_binary = (idx_epoch_df_2d_preds >= idx_positive_threshold).to(torch.int) idx_epoch_df_3d_preds_binary = (idx_epoch_df_3d_preds >= idx_positive_threshold).to(torch.int) # 将 tensor 转换为 numpy 数组以便于计算 idx_epoch_df_2d_preds_binary = idx_epoch_df_2d_preds_binary.numpy() idx_epoch_df_3d_preds_binary = idx_epoch_df_3d_preds_binary.numpy() idx_epoch_df_2d_labels = idx_epoch_df_2d_labels.numpy() idx_epoch_df_3d_labels = idx_epoch_df_3d_labels.numpy() # 整体准确率 acc_2d = accuracy_score(idx_epoch_df_2d_labels, idx_epoch_df_2d_preds_binary) acc_3d = accuracy_score(idx_epoch_df_3d_labels, idx_epoch_df_3d_preds_binary,) prec_2d = precision_score(idx_epoch_df_2d_labels, idx_epoch_df_2d_preds_binary) rec_2d = recall_score(idx_epoch_df_2d_labels, idx_epoch_df_2d_preds_binary) f1_2d = f1_score(idx_epoch_df_2d_labels, idx_epoch_df_2d_preds_binary) prec_3d = precision_score(idx_epoch_df_3d_labels, idx_epoch_df_3d_preds_binary) rec_3d = recall_score(idx_epoch_df_3d_labels, idx_epoch_df_3d_preds_binary) f1_3d = f1_score(idx_epoch_df_3d_labels, idx_epoch_df_3d_preds_binary) # 混淆矩阵 c_2d = confusion_matrix(idx_epoch_df_2d_labels, idx_epoch_df_2d_preds_binary) c_3d = confusion_matrix(idx_epoch_df_3d_labels, idx_epoch_df_3d_preds_binary) tp_2d, fp_2d, fn_2d, tn_2d = c_2d[1, 1], c_2d[0, 1], c_2d[1, 0], c_2d[0, 0] tp_3d, fp_3d, fn_3d, tn_3d = c_3d[1, 1], c_3d[0, 1], c_3d[1, 0], c_3d[0, 0] # 正样本指标 acc_pos_2d = (tp_2d) / (tp_2d + fn_2d) if (tp_2d + fn_2d) > 0 else 0 prec_pos_2d = tp_2d / (tp_2d + fp_2d) if (tp_2d + fp_2d) > 0 else 0 rec_pos_2d = tp_2d / (tp_2d + fn_2d) if (tp_2d + fn_2d) > 0 else 0 f1_pos_2d = 2 * (prec_pos_2d * rec_pos_2d) / (prec_pos_2d + rec_pos_2d) if (prec_pos_2d + rec_pos_2d) > 0 else 0 acc_pos_3d = (tp_3d) / (tp_3d + fn_3d) if (tp_3d + fn_3d) > 0 else 0 prec_pos_3d = tp_3d / (tp_3d + fp_3d) if (tp_3d + fp_3d) > 0 else 0 rec_pos_3d = tp_3d / (tp_3d + fn_3d) if (tp_3d + fn_3d) > 0 else 0 f1_pos_3d = 2 * (prec_pos_3d * rec_pos_3d) / (prec_pos_3d + rec_pos_3d) if (prec_pos_3d + rec_pos_3d) > 0 else 0 # 负样本指标 acc_neg_2d = (tn_2d) / (tn_2d + fp_2d) if (tn_2d + fp_2d) > 0 else 0 prec_neg_2d = tn_2d / (tn_2d + fn_2d) if (tn_2d + fn_2d) > 0 else 0 rec_neg_2d = tn_2d / (tn_2d + fp_2d) if (tn_2d + fp_2d) > 0 else 0 f1_neg_2d = 2 * (prec_neg_2d * rec_neg_2d) / (prec_neg_2d + rec_neg_2d) if (prec_neg_2d + rec_neg_2d) > 0 else 0 acc_neg_3d = (tn_3d) / (tn_3d + fp_3d) if (tn_3d + fp_3d) > 0 else 0 prec_neg_3d = tn_3d / (tn_3d + fn_3d) if (tn_3d + fn_3d) > 0 else 0 rec_neg_3d = tn_3d / (tn_3d + fp_3d) if (tn_3d + fp_3d) > 0 else 0 f1_neg_3d = 2 * (prec_neg_3d * rec_neg_3d) / (prec_neg_3d + rec_neg_3d) if (prec_neg_3d + rec_neg_3d) > 0 else 0 # 记录预测错误的数据 error_2d = [(label_id, z_index, rotate_count, label, pred_prob.numpy(), label_id_patient_id_dict[label_id]) for label_id, z_index, rotate_count, label, pred_prob, pred_binary in zip(epoch_df_2d_label_id, epoch_df_2d_z_index, epoch_df_2d_rotate_count, idx_epoch_df_2d_labels, idx_epoch_df_2d_preds, idx_epoch_df_2d_preds_binary) if label != pred_binary] error_3d = [(label_id, z_index, rotate_count, label, pred_prob.numpy(), label_id_patient_id_dict[label_id]) for label_id, z_index, rotate_count, label, pred_prob, pred_binary in zip(epoch_df_3d_label_id, epoch_df_3d_z_index, epoch_df_3d_rotate_count, idx_epoch_df_3d_labels, idx_epoch_df_3d_preds, idx_epoch_df_3d_preds_binary) if label != pred_binary] df_epoch_list.append(idx_epoch) df_threshold_list.append(idx_positive_threshold) epoch_accuracy_2d_list.append(acc_2d) epoch_precision_2d_list.append(prec_2d) epoch_recall_2d_list.append(rec_2d) epoch_f1_2d_list.append(f1_2d) epoch_error_2d_list.append(error_2d) epoch_accuracy_3d_list.append(acc_3d) epoch_precision_3d_list.append(prec_3d) epoch_recall_3d_list.append(rec_3d) epoch_f1_3d_list.append(f1_3d) epoch_error_3d_list.append(error_3d) epoch_positive_accuracy_2d_list.append(acc_pos_2d) epoch_negative_accuracy_2d_list.append(acc_neg_2d) epoch_positive_accuracy_3d_list.append(acc_pos_3d) epoch_negative_accuracy_3d_list.append(acc_neg_3d) epoch_positive_precision_2d_list.append(prec_pos_2d) epoch_negative_precision_2d_list.append(prec_neg_2d) epoch_positive_precision_3d_list.append(prec_pos_3d) epoch_negative_precision_3d_list.append(prec_neg_3d) epoch_positive_recall_2d_list.append(rec_pos_2d) epoch_negative_recall_2d_list.append(rec_neg_2d) epoch_positive_recall_3d_list.append(rec_pos_3d) epoch_negative_recall_3d_list.append(rec_neg_3d) epoch_positive_f1_2d_list.append(f1_pos_2d) epoch_negative_f1_2d_list.append(f1_neg_2d) epoch_positive_f1_3d_list.append(f1_pos_3d) epoch_negative_f1_3d_list.append(f1_neg_3d) else: epoch_df_preds = torch.tensor(df_epoch['pred'].tolist(), dtype=torch.float32) epoch_df_labels = torch.tensor(df_epoch['label'].tolist()) epoch_df_label_id = df_epoch['label_id'].tolist() epoch_df_z_index = df_epoch['z_index'].tolist() epoch_df_rotate_count = df_epoch['rotate_count'].tolist() for idx_positive_threshold in positive_threshold_list: idx_epoch_df_preds = copy.deepcopy(epoch_df_preds) idx_epoch_df_labels = copy.deepcopy(epoch_df_labels) # 根据阈值将预测概率转换为二分类标签 idx_epoch_df_preds_binary = (idx_epoch_df_preds >= idx_positive_threshold).to(torch.int) # 将 tensor 转换为 numpy 数组以便于计算 idx_epoch_df_preds_binary = idx_epoch_df_preds_binary.numpy() idx_epoch_df_labels = idx_epoch_df_labels.numpy() # 整体准确率 acc = accuracy_score(idx_epoch_df_labels, idx_epoch_df_preds_binary) prec = precision_score(idx_epoch_df_labels, idx_epoch_df_preds_binary) rec = recall_score(idx_epoch_df_labels, idx_epoch_df_preds_binary) f1 = f1_score(idx_epoch_df_labels, idx_epoch_df_preds_binary) # 混淆矩阵 c = confusion_matrix(idx_epoch_df_labels, idx_epoch_df_preds_binary) tp, fp, fn, tn = c[1, 1], c[0, 1], c[1, 0], c[0, 0] # 正样本指标 acc_pos = (tp) / (tp + fn) if (tp + fn) > 0 else 0 prec_pos = tp / (tp + fp) if (tp + fp) > 0 else 0 rec_pos = tp / (tp + fn) if (tp + fn) > 0 else 0 f1_pos = 2 * (prec_pos * rec_pos) / (prec_pos + rec_pos) if (prec_pos + rec_pos) > 0 else 0 # 负样本指标 acc_neg = (tn) / (tn + fp) if (tn + fp) > 0 else 0 prec_neg = tn / (tn + fn) if (tn + fn) > 0 else 0 rec_neg = tn / (tn + fp) if (tn + fp) > 0 else 0 f1_neg = 2 * (prec_neg * rec_neg) / (prec_neg + rec_neg) if (prec_neg + rec_neg) > 0 else 0 # 记录预测错误的数据 error = [(label_id, z_index, rotate_count, label, pred_prob.numpy(), label_id_patient_id_dict[label_id]) for label_id, z_index, rotate_count, label, pred_prob, pred_binary in zip(epoch_df_label_id, epoch_df_z_index, epoch_df_rotate_count, idx_epoch_df_labels, idx_epoch_df_preds, idx_epoch_df_preds_binary) if label != pred_binary] df_epoch_list.append(idx_epoch) df_threshold_list.append(idx_positive_threshold) epoch_accuracy_list.append(acc) epoch_precision_list.append(prec) epoch_recall_list.append(rec) epoch_f1_list.append(f1) epoch_error_list.append(error) epoch_positive_accuracy_list.append(acc_pos) epoch_negative_accuracy_list.append(acc_neg) epoch_positive_precision_list.append(prec_pos) epoch_negative_precision_list.append(prec_neg) epoch_positive_recall_list.append(rec_pos) epoch_negative_recall_list.append(rec_neg) epoch_positive_f1_list.append(f1_pos) epoch_negative_f1_list.append(f1_neg) if net_id == "2d3d": df_epoch_result = pd.DataFrame({ "test_type": [test_type] * len(df_epoch_list), "epoch": df_epoch_list, "threshold": df_threshold_list, "acc_2d": epoch_accuracy_2d_list, "prec_2d": epoch_precision_2d_list, "rec_2d": epoch_recall_2d_list, "f1_2d": epoch_f1_2d_list, "acc_3d": epoch_accuracy_3d_list, "prec_3d": epoch_precision_3d_list, "rec_3d": epoch_recall_3d_list, "f1_3d": epoch_f1_3d_list, "acc_pos_2d": epoch_positive_accuracy_2d_list, "prec_pos_2d": epoch_positive_precision_2d_list, "rec_pos_2d": epoch_positive_recall_2d_list, "f1_pos_2d": epoch_positive_f1_2d_list, "acc_neg_2d": epoch_negative_accuracy_2d_list, "prec_neg_2d": epoch_negative_precision_2d_list, "rec_neg_2d": epoch_negative_recall_2d_list, "f1_neg_2d": epoch_negative_f1_2d_list, "acc_pos_3d": epoch_positive_accuracy_3d_list, "prec_pos_3d": epoch_positive_precision_3d_list, "rec_pos_3d": epoch_positive_recall_3d_list, "f1_pos_3d": epoch_positive_f1_3d_list, "acc_neg_3d": epoch_negative_accuracy_3d_list, "prec_neg_3d": epoch_negative_precision_3d_list, "rec_neg_3d": epoch_negative_recall_3d_list, "f1_neg_3d": epoch_negative_f1_3d_list, "error_data_2d": epoch_error_2d_list, "error_data_3d": epoch_error_3d_list, }) else: df_epoch_result = pd.DataFrame({ "test_type": [test_type] * len(df_epoch_list), "epoch": df_epoch_list, "threshold": df_threshold_list, "acc": epoch_accuracy_list, "prec": epoch_precision_list, "rec": epoch_recall_list, "f1": epoch_f1_list, "acc_pos": epoch_positive_accuracy_list, "prec_pos": epoch_positive_precision_list, "rec_pos": epoch_positive_recall_list, "f1_pos": epoch_positive_f1_list, "acc_neg": epoch_negative_accuracy_list, "prec_neg": epoch_negative_precision_list, "rec_neg": epoch_negative_recall_list, "f1_neg": epoch_negative_f1_list, "error_data": epoch_error_list, }) df_epoch_result_file = f"{gather_all_epoch_file.replace('.csv', '_不同阈值_评测结果.csv')}" df_epoch_result_file = os.path.join(save_dir, df_epoch_result_file) df_epoch_result.to_csv(df_epoch_result_file, index=False, encoding="utf-8") print(f"gather result finished, 不同阈值, save to {df_epoch_result_file}") def get_test_config(): from pytorch_train.train_2d3d_config import node_net_train_file_dict config = {} net_id = "2d3d" epochs = 1 node_id = "2021_2031" device_index = [0] print(f"node_id: {node_id}") date_id = "20241217" task_info = f"test_{net_id}_{node_id}_{date_id}_rotate_10_ddp_count_00000005" train_log_dir = "/df_lung/ai-project/cls_train/log/train" save_dir = "/df_lung/ai-project/cls_train/cls_ckpt" save_folder = f"{net_id}_{node_id}_{date_id}" log_file = os.path.join(train_log_dir, f"{task_info}.log") save_dir = os.path.join(save_dir, save_folder) train_data_dir = "/df_lung/cls_train_data/train_csv_data/" train_csv_file = None val_csv_file = None test_csv_file = None train_2d3d_data_2d_csv_file = None val_2d3d_data_2d_csv_file = None test_2d3d_data_2d_csv_file = None train_2d3d_data_3d_csv_file = None val_2d3d_data_3d_csv_file = None test_2d3d_data_3d_csv_file = None if net_id == "2d3d": train_2d3d_data_2d_csv_file = node_net_train_file_dict[(node_id, net_id)]["2d_train_file"] val_2d3d_data_2d_csv_file = node_net_train_file_dict[(node_id, net_id)]["2d_val_file"] test_2d3d_data_2d_csv_file = node_net_train_file_dict[(node_id, net_id)]["2d_test_file"] train_2d3d_data_3d_csv_file = node_net_train_file_dict[(node_id, net_id)]["3d_train_file"] val_2d3d_data_3d_csv_file = node_net_train_file_dict[(node_id, net_id)]["3d_val_file"] test_2d3d_data_3d_csv_file = node_net_train_file_dict[(node_id, net_id)]["3d_test_file"] else: train_csv_file = node_net_train_file_dict[(node_id, net_id)]["train_file"] val_csv_file = node_net_train_file_dict[(node_id, net_id)]["val_file"] test_csv_file = node_net_train_file_dict[(node_id, net_id)]["test_file"] logger = get_logger(log_file) if net_id == "2d3d": model = Net2d3d() batch_size = 32 elif net_id == "2d": model = Net2d() batch_size = 200 elif net_id == "3d": model = Net3d() batch_size = 6 elif net_id == "s3d": model = NetS3d() batch_size = 6 elif net_id == "d2d": model = NetD2d() batch_size = 1 if net_id not in ["s3d", "d2d"]: init_modules(model) # 创建文件夹 Path(train_log_dir).mkdir(parents=True, exist_ok=True) Path(save_dir).mkdir(parents=True, exist_ok=True) config['net_id'] = net_id config['node_id'] = node_id config['date_id'] = date_id config['logger'] = logger config['criterion'] = nn.BCELoss() config['device_index'] = device_index config['num_workers'] = 1 config['train_batch_size'] = batch_size config['val_batch_size'] = batch_size config['test_batch_size'] = batch_size config['save_dir'] = save_dir config['train_2d3d_data_2d_csv_file'] = train_2d3d_data_2d_csv_file config['val_2d3d_data_2d_csv_file'] = val_2d3d_data_2d_csv_file config['test_2d3d_data_2d_csv_file'] = test_2d3d_data_2d_csv_file config['train_2d3d_data_3d_csv_file'] = train_2d3d_data_3d_csv_file config['val_2d3d_data_3d_csv_file'] = val_2d3d_data_3d_csv_file config['test_2d3d_data_3d_csv_file'] = test_2d3d_data_3d_csv_file config['train_csv_file'] = train_csv_file config['val_csv_file'] = val_csv_file config['test_csv_file'] = test_csv_file config['num_epochs'] = epochs config['test_train_dataset_info_dict'] = { "test_train_dataset_flag": True, "test_train_epoch_list": [ 1 ], "positive_threshold_list": [0.7], "negative_threshold_list": [0.3] } config['test_val_dataset_info_dict'] = { "test_val_dataset_flag": True, "test_val_epoch_list": [ idx_test_val_epoch+1 for idx_test_val_epoch in range(config['num_epochs']) ] } config['test_test_dataset_info_dict'] = { "test_test_dataset_flag": True, "test_test_epoch_list": [ idx_test_test_epoch+1 for idx_test_test_epoch in range(config['num_epochs']) ] } # train_data_loader = DataLoader(train_dataset, batch_size=config['train_batch_size'], # num_workers=config['num_workers'], shuffle=True) # val_data_loader = DataLoader(val_dataset, batch_size=config['val_batch_size'], # num_workers=config['num_workers'], shuffle=True) return model, config # if __name__ == "__main__": # try: # model, config = get_test_config() # test_ddp(model, config) # except Exception as e: # print(f"error info: {traceback.format_exc()}") # raise e # finally: # print("多卡评测 finished") # python pytorch_train/_test_2d3d.py if __name__ == "__main__": _, config = get_test_config() # 所有epoch的评测结果 config['test_train_dataset_info_dict']['test_train_dataset_flag'] = False config['test_val_dataset_info_dict']['test_val_dataset_flag'] = True config['test_test_dataset_info_dict']['test_test_dataset_flag'] = True # test_train_all_epoch_file, test_val_all_epoch_file, test_test_all_epoch_file = gather_test_result(config) # print(f"test_train_all_epoch_file: {test_train_all_epoch_file}\ntest_val_all_epoch_file: {test_val_all_epoch_file}\ntest_test_all_epoch_file: {test_test_all_epoch_file}") # # # # /df_lung/ai-project/cls_train/cls_ckpt/2d3d_2021_2031_20241124/gather_test_[train]_dataset_net_2d3d_node_2021_2031_date_20241124.csv # # # # /df_lung/ai-project/cls_train/cls_ckpt/2d3d_2021_2031_20241124/gather_test_[val]_dataset_net_2d3d_node_2021_2031_date_20241124.csv # # # # /df_lung/ai-project/cls_train/cls_ckpt/2d3d_2021_2031_20241124/gather_test_[test]_dataset_net_2d3d_node_2021_2031_date_20241124.csv # # # # /df_lung/ai-project/cls_train/cls_ckpt/3d_2021_2031_20241130/gather_test_[val]_dataset_net_3d_node_2021_2031_date_20241130.csv # # # # /df_lung/ai-project/cls_train/cls_ckpt/3d_2021_2031_20241130/gather_test_[test]_dataset_net_3d_node_2021_2031_date_20241130.csv # # # # /df_lung/ai-project/cls_train/cls_ckpt/encoder3d_2021_2031_20241131/gather_test_[val]_dataset_net_encoder3d_node_2021_2031_date_20241131.csv # # # # /df_lung/ai-project/cls_train/cls_ckpt/encoder3d_2021_2031_20241131/gather_test_[test]_dataset_net_encoder3d_node_2021_2031_date_20241131.csv # # # # /df_lung/ai-project/cls_train/cls_ckpt/2d_2021_2031_20241201/gather_test_[val]_dataset_net_2d_node_2021_2031_date_20241201.csv # # # # /df_lung/ai-project/cls_train/cls_ckpt/2d_2021_2031_20241201/gather_test_[test]_dataset_net_2d_node_2021_2031_date_20241201.csv # # # # /df_lung/ai-project/cls_train/cls_ckpt/3d_2041_2031_20241201/gather_test_[val]_dataset_net_3d_node_2041_2031_date_20241201.csv # # # # /df_lung/ai-project/cls_train/cls_ckpt/3d_2041_2031_20241201/gather_test_[test]_dataset_net_3d_node_2041_2031_date_20241201.csv # 统计评测结果 test_train_all_epoch_file = "/df_lung/ai-project/cls_train/cls_ckpt/2d3d_2021_2031_20241217/gather_test_[train]_dataset_net_2d3d_node_2021_2031_date_20241217.csv" test_val_all_epoch_file = "/df_lung/ai-project/cls_train/cls_ckpt/2d3d_2021_2031_20241217/gather_test_[val]_dataset_net_2d3d_node_2021_2031_date_20241217.csv" test_test_all_epoch_file = "/df_lung/ai-project/cls_train/cls_ckpt/2d3d_2021_2031_20241217/gather_test_[test]_dataset_net_2d3d_node_2021_2031_date_20241217.csv" # # # test_train_all_epoch_file = None # # # test_val_all_epoch_file = "/df_lung/ai-project/cls_train/cls_ckpt/3d_2041_2031_20241201/gather_test_[val]_dataset_net_3d_node_2041_2031_date_20241201.csv" # # # test_test_all_epoch_file = "/df_lung/ai-project/cls_train/cls_ckpt/3d_2041_2031_20241201/gather_test_[test]_dataset_net_3d_node_2041_2031_date_20241201.csv" # # # positive_threshold_list = config['test_train_dataset_info_dict']['positive_threshold_list'] # # # negative_threshold_list = config['test_train_dataset_info_dict']['negative_threshold_list'] # # # assert len(positive_threshold_list) == len(negative_threshold_list) # test_test_all_epoch_file = "/df_lung/ai-project/cls_train/cls_ckpt/2d3d_2021_2031_20241207/gather_test_[test]_dataset_net_2d3d_node_2021_2031_date_20241207.csv" # # # # 查看指定epoch的不在阈值区间的数据 # # # get_gather_result_train_dataset( # # # test_train_all_epoch_file, # # # config, # # # test_type="train_dataset", # # # positive_threshold_list=positive_threshold_list, # # # negative_threshold_list=negative_threshold_list # # # ) # 不同阈值的评测指标 get_gather_result(test_train_all_epoch_file, config, test_type="train_dataset") get_gather_result(test_val_all_epoch_file, config, test_type="val_dataset") get_gather_result(test_test_all_epoch_file, config, test_type="test_dataset")