import argparse
import os, sys
import pathlib

# Locate the project root: walk upward from this file until a directory named
# "cls_train" is found, then add it to sys.path so project-local packages
# (cls_utils, pytorch_train, data, ...) can be imported.
current_dir = pathlib.Path(__file__).parent.resolve()
while "cls_train" != os.path.basename(current_dir):
    if current_dir.parent == current_dir:
        # Reached the filesystem root. The original loop would spin forever
        # here (root.parent == root); fail loudly instead.
        raise RuntimeError(f"no ancestor directory named 'cls_train' found above {__file__}")
    current_dir = current_dir.parent
sys.path.append(current_dir.as_posix())

# Cap intra-op threading; must be set before the numeric libraries below are
# imported for it to take effect.
os.environ["OMP_NUM_THREADS"] = "1"

from cls_utils.log_utils import get_logger
from sklearn.metrics import classification_report
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torch.distributed.optim import ZeroRedundancyOptimizer
from torch.utils.data.distributed import DistributedSampler
from torch.nn.parallel import DistributedDataParallel
from pytorch_train.classification_model_2d3d import Net2d3d, Net2d, Net3d, init_modules
from pytorch_train.encoder_cls import S3dClassifier as NetS3d
from pytorch_train.encoder_cls import D2dClassifier as NetD2d
from pytorch_train.resnet import ResNetClassifier as NetResNet3d
from pytorch_train.resnet import Bottleneck
from data.dataset_2d3d import ClassificationDataset2d3d, ClassificationDataset3d, ClassificationDataset2d, custom_collate_fn_2d, custom_collate_fn_3d, custom_collate_fn_2d3d, cls_report_dict_to_string, ClassificationDatasetError2d, ClassificationDatasetError3d, ClassificationDatasetError2d3d, custom_collate_fn_2d_error, custom_collate_fn_3d_error, custom_collate_fn_2d3d_error
from transformers import get_cosine_schedule_with_warmup
from data.dataset_2d3d import ClassificationDatasetS3d, ClassificationDatasetErrorS3d, custom_collate_fn_s3d, custom_collate_fn_s3d_error
from data.dataset_2d3d import ClassificationDatasetResnet3d, ClassificationDatasetErrorResnet3d, custom_collate_fn_resnet3d, custom_collate_fn_resnet3d_error
from data.dataset_2d3d import ClassificationDatasetD2d, ClassificationDatasetErrorD2d, custom_collate_fn_d2d, custom_collate_fn_d2d_error
import traceback
import pandas as pd
from pathlib import Path
from torch.optim import lr_scheduler, Adam
import random
import numpy as np


def set_seed(seed=1004):
    """Seed the python, numpy and torch (all CUDA devices) RNGs for reproducibility."""
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)


def predict_on_single_gpu(net_id, model, dataloader=None, criterion=None, return_info_list=False, return_classification_report=False, return_cls_report_dict=False, threshold=0.5):
    """Evaluate ``model`` over ``dataloader`` on the device the model lives on.

    Parameters
    ----------
    net_id : str
        Network variant. ``"2d3d"`` selects the paired batch layout
        ``(label_id_2d, data_2d, label_id_3d, data_3d, label_2d, label_3d)``;
        any other value selects the single-stream layout
        ``(label_id, data, label)``.
    model : torch.nn.Module
        Model to evaluate; put into ``eval()`` mode here.
    dataloader : torch.utils.data.DataLoader
        Validation/test data.
    criterion : callable, optional
        Loss function; only used when ``return_info_list`` is True.
    return_info_list : bool
        When True, record ``[epoch, step, loss]`` per batch (``epoch`` is a
        constant 0 placeholder — this function has no epoch context).
    return_classification_report : bool
        When True, accumulate predictions/labels and build an sklearn
        classification report, binarizing predictions at ``threshold``.
    return_cls_report_dict : bool
        Passed through as ``output_dict`` to ``classification_report``.
    threshold : float
        Decision threshold applied to raw model outputs.

    Returns
    -------
    For ``net_id == "2d3d"``:
        ``(label_ids_2d, label_ids_3d, preds_2d, preds_3d, labels_2d,
        labels_3d, val_info_list, cls_report_2d, cls_report_3d,
        cls_report_str)``
    otherwise:
        ``(label_ids, preds, labels, val_info_list, cls_report,
        cls_report_str)``

    Note: when a classification report was produced, the accumulated
    prediction/label lists are cleared before returning (memory saving), so
    those return slots come back empty; the reports carry the results.
    """
    model.eval()
    all_preds = []
    all_labels = []
    all_label_id = []
    all_preds_2d = []
    all_preds_3d = []
    all_labels_2d = []
    all_labels_3d = []
    all_label_id_2d = []
    all_label_id_3d = []
    val_info_list = []
    epoch = 0  # placeholder: no epoch context inside this helper
    device = next(model.parameters()).device
    with torch.no_grad():
        for step, batch_data in enumerate(dataloader):
            if net_id == "2d3d":
                label_id_2d, data_2d, label_id_3d, data_3d, label_2d, label_3d = batch_data
                if return_classification_report:
                    # Keep CPU copies of ids/labels before moving to device.
                    all_label_id_2d.extend(label_id_2d)
                    all_label_id_3d.extend(label_id_3d)
                    all_labels_2d.extend(label_2d.numpy())
                    all_labels_3d.extend(label_3d.numpy())
                data_2d = data_2d.to(device)
                data_3d = data_3d.to(device)
                label_2d = label_2d.to(device)
                label_3d = label_3d.to(device)
                output_2d, output_3d = model(data_2d, data_3d)
                if return_classification_report:
                    all_preds_2d.extend(output_2d.detach().cpu().numpy())
                    all_preds_3d.extend(output_3d.detach().cpu().numpy())
            else:
                label_id, data, label = batch_data
                if return_classification_report:
                    all_label_id.extend(label_id)
                    all_labels.extend(label.numpy())
                data = data.to(device)
                label = label.to(device)
                output = model(data)
                if return_classification_report:
                    all_preds.extend(output.detach().cpu().numpy())
            if return_info_list:
                if net_id == "2d3d":
                    loss_2d = criterion(output_2d, label_2d)
                    loss_3d = criterion(output_3d, label_3d)
                    loss = loss_2d.sum() + loss_3d.sum()
                else:
                    loss = criterion(output, label)
                val_info_list.append([epoch, step+1, loss.item()])
            # NOTE(review): emptying the CUDA cache every batch is expensive;
            # kept as-is to preserve the original memory behavior.
            torch.cuda.empty_cache()
    cls_report = None
    # FIX: initialize the 2d/3d reports too — previously, calling with
    # net_id == "2d3d" and return_classification_report=False raised a
    # NameError at the return statement below.
    cls_report_2d = None
    cls_report_3d = None
    cls_report_str = ""
    cls_report_str_2d = ""
    cls_report_str_3d = ""
    if return_classification_report:
        if net_id == "2d3d":
            all_preds_2d = np.array(all_preds_2d)
            all_preds_3d = np.array(all_preds_3d)
            all_labels_2d = np.array(all_labels_2d)
            all_labels_3d = np.array(all_labels_3d)
            # Binarize raw scores at the decision threshold.
            all_preds_2d = (all_preds_2d > threshold).astype(np.int32)
            all_preds_3d = (all_preds_3d > threshold).astype(np.int32)
            cls_report_2d = classification_report(
                all_labels_2d, all_preds_2d, labels=[0, 1],
                target_names=['negative', 'positive'], zero_division=0,
                output_dict=return_cls_report_dict
            )
            cls_report_3d = classification_report(
                all_labels_3d, all_preds_3d, labels=[0, 1],
                target_names=['negative', 'positive'], zero_division=0,
                output_dict=return_cls_report_dict
            )
            cls_report_str_2d = f"2d_cls_report: \n{cls_report_dict_to_string(cls_report_2d)}\n"
            cls_report_str_3d = f"3d_cls_report: \n{cls_report_dict_to_string(cls_report_3d)}\n"
            cls_report_str = cls_report_str_2d + cls_report_str_3d
            # Release the accumulated arrays; callers use the reports instead.
            all_preds_2d = []
            all_preds_3d = []
            all_labels_2d = []
            all_labels_3d = []
        else:
            all_labels = np.array(all_labels)
            all_preds = np.array(all_preds)
            all_preds = (all_preds > threshold).astype(np.int32)
            cls_report = classification_report(
                all_labels, all_preds, labels=[0, 1],
                target_names=['negative', 'positive'], zero_division=0,
                output_dict=return_cls_report_dict
            )
            cls_report_str = f"cls_report: \n{cls_report_dict_to_string(cls_report)}\n"
            all_preds = []
            all_labels = []
    if net_id == "2d3d":
        return all_label_id_2d, all_label_id_3d, all_preds_2d, all_preds_3d, all_labels_2d, all_labels_3d, val_info_list, cls_report_2d, cls_report_3d, cls_report_str
    else:
        return all_label_id, all_preds, all_labels, val_info_list, cls_report, cls_report_str


def start_train_ddp(model, config, args):
    logger = config['logger']
    net_id = config['net_id']
    train_2d3d_data_2d_csv_file = config['train_2d3d_data_2d_csv_file']
    val_2d3d_data_2d_csv_file = config['val_2d3d_data_2d_csv_file']
    test_2d3d_data_2d_csv_file = config['test_2d3d_data_2d_csv_file']
    train_2d3d_data_3d_csv_file = config['train_2d3d_data_3d_csv_file']
    val_2d3d_data_3d_csv_file = config['val_2d3d_data_3d_csv_file']
    test_2d3d_data_3d_csv_file = config['test_2d3d_data_3d_csv_file']
    train_csv_file = config['train_csv_file']
val_csv_file = config['val_csv_file'] pos_label = config['pos_label'] neg_label = config['neg_label'] train_batch_size = config['train_batch_size'] val_batch_size = config['val_batch_size'] val_metric = config['val_metric'] num_workers = config['num_workers'] device_index = config['device_index'] num_epochs = config['num_epochs'] weight_decay = config['weight_decay'] lr = config['lr'] criterion = config['criterion'] val_interval = config['val_interval'] save_dir = config['save_dir'] threshold = config['threshold'] logger_train_cls_report_flag = config['logger_train_cls_report_flag'] train_on_error_data_flag = config['train_on_error_data_flag'] train_on_error_data_epoch_dict_list = config['train_on_error_data_epoch_dict_list'] if train_on_error_data_flag else [] train_on_error_data_epoch_dict_list = sorted(train_on_error_data_epoch_dict_list, key=lambda x: x["train_epoch"]) current_train_epoch = 0 if len(device_index) > 1: world_size = torch.distributed.get_world_size() if len(device_index) > 1 and args.local_rank not in [-1, 0]: torch.distributed.barrier() if net_id == "2d": train_dataset = ClassificationDataset2d(train_csv_file, data_info="train_dataset_2d") val_dataset = ClassificationDataset2d(val_csv_file, data_info="val_dataset_2d") custom_collate_fn = custom_collate_fn_2d train_error_dataset_class = ClassificationDatasetError2d custom_collate_fn_error = custom_collate_fn_2d_error elif net_id == "3d": train_dataset = ClassificationDataset3d(train_csv_file, data_info="train_dataset_3d") val_dataset = ClassificationDataset3d(val_csv_file, data_info="val_dataset_3d") custom_collate_fn = custom_collate_fn_3d train_error_dataset_class = ClassificationDatasetError3d custom_collate_fn_error = custom_collate_fn_3d_error elif net_id == "2d3d": train_dataset = ClassificationDataset2d3d(train_2d3d_data_2d_csv_file, train_2d3d_data_3d_csv_file, data_info="train_dataset_2d3d") val_dataset = ClassificationDataset2d3d(val_2d3d_data_2d_csv_file, val_2d3d_data_3d_csv_file, 
data_info="val_dataset_2d3d") custom_collate_fn = custom_collate_fn_2d3d train_error_dataset_class = ClassificationDatasetError2d3d custom_collate_fn_error = custom_collate_fn_2d3d_error elif net_id == "s3d": train_dataset = ClassificationDatasetS3d(train_csv_file, data_info="train_dataset_S3d") val_dataset = ClassificationDatasetS3d(val_csv_file, data_info="val_dataset_S3d") custom_collate_fn = custom_collate_fn_s3d train_error_dataset_class = ClassificationDatasetErrorS3d custom_collate_fn_error = custom_collate_fn_s3d_error elif net_id == "resnet3d": train_dataset = ClassificationDatasetResnet3d(train_csv_file, data_info="train_dataset_resnet3d") val_dataset = ClassificationDatasetResnet3d(val_csv_file, data_info="val_dataset_resnet3d") custom_collate_fn = custom_collate_fn_resnet3d train_error_dataset_class = ClassificationDatasetErrorResnet3d custom_collate_fn_error = custom_collate_fn_resnet3d_error elif net_id == "d2d": train_dataset = ClassificationDatasetD2d(train_csv_file, data_info="train_dataset_d2d") val_dataset = ClassificationDatasetD2d(val_csv_file, data_info="val_dataset_d2d") custom_collate_fn = custom_collate_fn_d2d train_error_dataset_class = ClassificationDatasetErrorD2d custom_collate_fn_error = custom_collate_fn_d2d_error else: raise ValueError(f"net_id {net_id} not supported") if len(device_index) > 1 and args.local_rank == 0: torch.distributed.barrier() if len(device_index) > 1: train_sampler = DistributedSampler(train_dataset, shuffle=True, rank=args.local_rank, num_replicas=world_size) val_sampler = DistributedSampler(val_dataset, shuffle=False, rank=args.local_rank, num_replicas=world_size) train_data_loader = DataLoader( train_dataset, batch_size=train_batch_size // len(device_index), drop_last=False, shuffle=False, num_workers=num_workers, sampler=train_sampler, collate_fn=custom_collate_fn ) val_data_loader = DataLoader( val_dataset, batch_size=val_batch_size // len(device_index), drop_last=False, shuffle=False, 
num_workers=num_workers, sampler=val_sampler, collate_fn=custom_collate_fn ) else: train_data_loader = DataLoader( train_dataset, batch_size=train_batch_size, drop_last=False, shuffle=True, num_workers=num_workers, collate_fn=custom_collate_fn ) val_data_loader = DataLoader( val_dataset, batch_size=val_batch_size, drop_last=False, shuffle=False, num_workers=num_workers, collate_fn=custom_collate_fn ) logger.info(f"start_train_ddp, net_id: {net_id}, device_index: {device_index}") if len(device_index) > 1: logger.info(f"local rank: {args.local_rank}, train data total batch: {len(train_data_loader)}, val data total batch: {len(val_data_loader)}") logger.info(f"local rank: {args.local_rank}, world size: {world_size}") else: logger.info(f"local rank: {device_index[0]}, train data total batch: {len(train_data_loader)}, val data total batch: {len(val_data_loader)}") if len(device_index) > 1: logger.info(f"local rank: {args.local_rank}, train with DistributedDataParallel") model = DistributedDataParallel( model.cuda(), device_ids=[args.local_rank], output_device=args.local_rank ) else: logger.info(f"local rank: {device_index[0]}, train with single gpu") model = model.to(device_index[0]) if args.use_zero: optimizer = ZeroRedundancyOptimizer( model.parameters(), optimizer_class=torch.optim.Adam, lr=lr, weight_decay=weight_decay ) else: optimizer = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=weight_decay) if len(device_index) > 1: logger.info(f"local rank: {args.local_rank}, train with optimizer: {optimizer}") else: logger.info(f"local rank: {device_index[0]}, train with optimizer: {optimizer}") lr_scheduler = get_cosine_schedule_with_warmup( optimizer, num_warmup_steps=10, num_training_steps=len(train_data_loader) * num_epochs ) train_info_list = [] val_info_list = [] best_val_metric_score = 0 current_step = 0 if val_interval is None: val_interval = len(train_data_loader) if len(device_index) > 1: logger.info(f"local rank: {args.local_rank}, before train batch 
count: {len(train_data_loader)}, val interval batch count: {val_interval}") else: logger.info(f"local rank: {device_index[0]}, before train batch count: {len(train_data_loader)}, val interval batch count: {val_interval}") train_all_preds = [] train_all_labels = [] train_all_label_id = [] train_all_label_id_2d = [] train_all_label_id_3d = [] train_all_preds_2d = [] train_all_preds_3d = [] train_all_labels_2d = [] train_all_labels_3d = [] error_data_dict = {} error_data_2d_dict = {} error_data_3d_dict = {} print_init_params(model, string=f"start_train_ddp, train_{net_id}", logger=logger) for epoch in range(num_epochs): model.train() if len(device_index) > 1: train_sampler.set_epoch(epoch) current_train_epoch += 1 for step, batch_data in enumerate(train_data_loader): current_step += 1 if net_id == "2d3d": label_id_2d, data_2d, label_2d, z_index_2d, rotate_count_2d, label_id_3d, data_3d, label_3d, z_index_3d, rotate_count_3d = batch_data if logger_train_cls_report_flag or (train_on_error_data_flag and len(train_on_error_data_epoch_dict_list) > 0 and current_train_epoch == train_on_error_data_epoch_dict_list[0]["train_epoch"]): train_all_label_id_2d.extend(label_id_2d) train_all_label_id_3d.extend(label_id_3d) train_all_labels_2d.append(label_2d.cpu()) train_all_labels_3d.append(label_3d.cpu()) if len(device_index) > 1: data_2d = data_2d.cuda() data_3d = data_3d.cuda() label_2d = label_2d.cuda() label_3d = label_3d.cuda() else: data_2d = data_2d.to(device_index[0]) data_3d = data_3d.to(device_index[0]) label_2d = label_2d.to(device_index[0]) label_3d = label_3d.to(device_index[0]) y_pred_2d, y_pred_3d = model(data_2d, data_3d) if logger_train_cls_report_flag or (train_on_error_data_flag and len(train_on_error_data_epoch_dict_list) > 0 and current_train_epoch == train_on_error_data_epoch_dict_list[0]["train_epoch"]): train_all_preds_2d.append(y_pred_2d.detach().cpu()) train_all_preds_3d.append(y_pred_3d.detach().cpu()) else: label_id, data, label, z_index, rotate_count = 
batch_data if logger_train_cls_report_flag or (train_on_error_data_flag and len(train_on_error_data_epoch_dict_list) > 0 and current_train_epoch == train_on_error_data_epoch_dict_list[0]["train_epoch"]): train_all_label_id.extend(label_id) train_all_labels.append(label.cpu()) if len(device_index) > 1: data = data.cuda() label = label.cuda() else: data = data.to(device_index[0]) label = label.to(device_index[0]) y_pred = model(data) if logger_train_cls_report_flag or (train_on_error_data_flag and len(train_on_error_data_epoch_dict_list) > 0 and current_train_epoch == train_on_error_data_epoch_dict_list[0]["train_epoch"]): train_all_preds.append(y_pred.detach().cpu()) if net_id == "2d3d": loss = criterion(y_pred_2d, label_2d).sum() + criterion(y_pred_3d, label_3d).sum() else: loss = criterion(y_pred, label).sum() if len(device_index) > 1: logger.info(f"local rank: {args.local_rank}, train epoch: {epoch+1}, step: {step+1}, loss: {loss.item()}") else: logger.info(f"local rank: {device_index[0]}, train epoch: {epoch+1}, step: {step+1}, loss: {loss.item()}") optimizer.zero_grad() loss.backward() optimizer.step() if len(device_index) > 1: torch.distributed.all_reduce(loss, op=torch.distributed.ReduceOp.AVG) lr_scheduler.step() if net_id == "2d3d" and current_step > 0 and current_step % 1000 == 0: train_step_pt_file = f"train_epoch_{epoch+1}_step_{current_step}_net_{net_id}.pt" save_train_step_pt_file = os.path.join(save_dir, train_step_pt_file) if len(device_index) > 1: torch.save(model.module.state_dict(), save_train_step_pt_file) logger.info(f"local rank: {args.local_rank}, save train epoch step pt file to {save_train_step_pt_file}") else: torch.save(model.state_dict(), save_train_step_pt_file) logger.info(f"local rank: {device_index[0]}, save train epoch step pt file to {save_train_step_pt_file}") # 评测 if current_step % val_interval == 0 and ((len(device_index) > 1 and args.local_rank == 0) or len(device_index) == 1): if net_id == "2d3d": _, _, _, _, _, _, 
current_val_info_list, current_val_classification_report_2d, current_val_classification_report_3d, current_val_classification_report_str = predict_on_single_gpu( net_id, model, dataloader=val_data_loader, criterion=criterion, return_info_list=True, return_classification_report=True, return_cls_report_dict=True, threshold=threshold ) else: _, _, _, current_val_info_list, current_val_classification_report, current_val_classification_report_str = predict_on_single_gpu( net_id, model, dataloader=val_data_loader, criterion=criterion, return_info_list=True, return_classification_report=True, return_cls_report_dict=True, threshold=threshold ) val_info_list += current_val_info_list if len(device_index) > 1: logger.info(f"current_local_rank: {args.local_rank}, current_device_index: {device_index[args.local_rank]},\nepoch: {epoch+1}, step: {step+1}, loss: {current_val_info_list[-1][2]}\n 训练评测--验证集: classification report:\n {current_val_classification_report_str}") else: logger.info(f"current_local_rank: {device_index[0]}, current_device_index: {device_index[0]},\nepoch: {epoch+1}, step: {step+1}, loss: {current_val_info_list[-1][2]}\n 训练评测--验证集: classification report:\n {current_val_classification_report_str}") if net_id == "2d3d": current_score_list = [ current_val_classification_report_2d[pos_label][val_metric], current_val_classification_report_3d[pos_label][val_metric], current_val_classification_report_2d[neg_label][val_metric], current_val_classification_report_3d[neg_label][val_metric] ] current_val_metric_score = sum(current_score_list) / len(current_score_list) else: current_score_list = [ current_val_classification_report[pos_label][val_metric], current_val_classification_report[neg_label][val_metric] ] current_val_metric_score = sum(current_score_list) / len(current_score_list) if current_val_metric_score > best_val_metric_score + 0.000000001: best_val_metric_score = current_val_metric_score save_path = os.path.join(save_dir, 
f"best_epoch_{epoch+1}_step_{step+1}_score_{best_val_metric_score:.4f}.pth") if len(device_index) > 1: torch.save(model.module.state_dict(), save_path) else: torch.save(model.state_dict(), save_path) if len(device_index) > 1: logger.info(f"local rank: {args.local_rank}, best val {val_metric} score: {best_val_metric_score}, save model to {save_path}") else: logger.info(f"local rank: {device_index[0]}, best val {val_metric} score: {best_val_metric_score}, save model to {save_path}") # 评测训练集 - epoch 结束时 if logger_train_cls_report_flag: train_cls_report = None if net_id == "2d3d": cls_train_all_label_id_2d = train_all_label_id_2d[:] cls_train_all_label_id_3d = train_all_label_id_3d[:] cls_train_all_preds_2d = torch.cat(train_all_preds_2d, dim=0) cls_train_all_preds_3d = torch.cat(train_all_preds_3d, dim=0) cls_train_all_labels_2d = torch.cat(train_all_labels_2d, dim=0) cls_train_all_labels_3d = torch.cat(train_all_labels_3d, dim=0) cls_train_all_preds_2d_binary = (cls_train_all_preds_2d > threshold).int() cls_train_all_preds_3d_binary = (cls_train_all_preds_3d > threshold).int() # positive_indices_2d = torch.where(cls_train_all_labels_2d == 1) # negative_indices_2d = torch.where(cls_train_all_labels_2d == 0) # positive_indices_3d = torch.where(cls_train_all_labels_3d == 1) # negative_indices_3d = torch.where(cls_train_all_labels_3d == 0) positive_error_2d_indices = torch.where((cls_train_all_labels_2d == 1) & (cls_train_all_preds_2d_binary == 0)) negative_error_2d_indices = torch.where((cls_train_all_labels_2d == 0) & (cls_train_all_preds_2d_binary == 1)) positive_error_3d_indices = torch.where((cls_train_all_labels_3d == 1) & (cls_train_all_preds_3d_binary == 0)) negative_error_3d_indices = torch.where((cls_train_all_labels_3d == 0) & (cls_train_all_preds_3d_binary == 1)) def get_error_samples(indices, label_id, labels, preds): error_label_ids = [label_id[i] for i in indices[0]] error_labels = labels[indices] error_preds = preds[indices] return error_label_ids, 
error_labels, error_preds error_label_id_2d_pos, error_labels_2d_pos, error_preds_2d_pos = get_error_samples(positive_error_2d_indices, cls_train_all_label_id_2d, cls_train_all_labels_2d, cls_train_all_preds_2d) error_label_id_2d_neg, error_labels_2d_neg, error_preds_2d_neg = get_error_samples(negative_error_2d_indices, cls_train_all_label_id_2d, cls_train_all_labels_2d, cls_train_all_preds_2d) error_label_id_3d_pos, error_labels_3d_pos, error_preds_3d_pos = get_error_samples(positive_error_3d_indices, cls_train_all_label_id_3d, cls_train_all_labels_3d, cls_train_all_preds_3d) error_label_id_3d_neg, error_labels_3d_neg, error_preds_3d_neg = get_error_samples(negative_error_3d_indices, cls_train_all_label_id_3d, cls_train_all_labels_3d, cls_train_all_preds_3d) train_cls_report_2d = classification_report(cls_train_all_labels_2d.flatten().cpu().numpy(), cls_train_all_preds_2d_binary.flatten().cpu().numpy(), labels=[0, 1], target_names=['negative', 'positive'], zero_division=0) train_cls_report_3d = classification_report(cls_train_all_labels_3d.flatten().cpu().numpy(), cls_train_all_preds_3d_binary.flatten().cpu().numpy(), labels=[0, 1], target_names=['negative', 'positive'], zero_division=0) train_cls_report = f"2d_cls_report: \n{train_cls_report_2d}\n3d_cls_report: \n{train_cls_report_3d}" cls_error_info_str = "" cls_error_info_str += f"net_id_{net_id}_2d分类: 预测错误的正样本信息, 个数: {len(error_labels_2d_pos)}:\n" for label_id, label, pred in zip(error_label_id_2d_pos, error_labels_2d_pos, error_preds_2d_pos): cls_error_info_str += f"label_id: {label_id}, label: {label}, pred: {pred}\n" cls_error_info_str += f"net_id_{net_id}_2d分类: 预测错误的负样本信息, 个数: {len(error_labels_2d_neg)}:\n" for label_id, label, pred in zip(error_label_id_2d_neg, error_labels_2d_neg, error_preds_2d_neg): cls_error_info_str += f"label_id: {label_id}, label: {label}, pred: {pred}\n" cls_error_info_str += f"net_id_{net_id}_3d分类: 预测错误的正样本信息, 个数: {len(error_labels_3d_pos)}:\n" for label_id, label, pred in 
zip(error_label_id_3d_pos, error_labels_3d_pos, error_preds_3d_pos): cls_error_info_str += f"label_id: {label_id}, label: {label}, pred: {pred}\n" cls_error_info_str += f"net_id_{net_id}_3d分类: 预测错误的负样本信息, 个数: {len(error_labels_3d_neg)}:\n" for label_id, label, pred in zip(error_label_id_3d_neg, error_labels_3d_neg, error_preds_3d_neg): cls_error_info_str += f"label_id: {label_id}, label: {label}, pred: {pred}\n" del cls_train_all_label_id_2d, cls_train_all_label_id_3d del cls_train_all_preds_2d, cls_train_all_preds_3d, cls_train_all_labels_2d, cls_train_all_labels_3d del cls_train_all_preds_2d_binary, cls_train_all_preds_3d_binary # del positive_indices_2d, negative_indices_2d, positive_indices_3d, negative_indices_3d del positive_error_2d_indices, negative_error_2d_indices, positive_error_3d_indices, negative_error_3d_indices del error_label_id_2d_pos, error_labels_2d_pos, error_preds_2d_pos del error_label_id_2d_neg, error_labels_2d_neg, error_preds_2d_neg del error_label_id_3d_pos, error_labels_3d_pos, error_preds_3d_pos del error_label_id_3d_neg, error_labels_3d_neg, error_preds_3d_neg del train_cls_report_2d, train_cls_report_3d torch.cuda.empty_cache() else: cls_train_all_label_id = train_all_label_id[:] cls_train_all_labels = torch.cat(train_all_labels, dim=0) cls_train_all_preds = torch.cat(train_all_preds, dim=0) cls_train_all_preds_binary = (cls_train_all_preds > threshold).int() # positive_indices = torch.where(cls_train_all_labels == 1) # negative_indices = torch.where(cls_train_all_labels == 0) positive_error_indices = torch.where((cls_train_all_labels == 1) & (cls_train_all_preds_binary == 0)) negative_error_indices = torch.where((cls_train_all_labels == 0) & (cls_train_all_preds_binary == 1)) def get_error_samples(indices, label_id, labels, preds): error_label_ids = [label_id[i] for i in indices[0]] error_labels = labels[indices] error_preds = preds[indices] return error_label_ids, error_labels, error_preds error_label_id_pos, error_labels_pos, 
error_preds_pos = get_error_samples(positive_error_indices, cls_train_all_label_id, cls_train_all_labels, cls_train_all_preds) error_label_id_neg, error_labels_neg, error_preds_neg = get_error_samples(negative_error_indices, cls_train_all_label_id, cls_train_all_labels, cls_train_all_preds) train_cls_report = classification_report(cls_train_all_labels.flatten().cpu().numpy(), cls_train_all_preds_binary.flatten().cpu().numpy(), labels=[0, 1], target_names=['negative', 'positive'], zero_division=0) train_cls_report = f"cls_report: \n{train_cls_report}" cls_error_info_str = "" cls_error_info_str += f"net_id_{net_id}_分类: 预测错误的正样本信息, 个数: {len(error_labels_pos)}:\n" for label_id, label, pred in zip(error_label_id_pos, error_labels_pos, error_preds_pos): cls_error_info_str += f"label_id: {label_id}, label: {label}, pred: {pred}\n" cls_error_info_str += f"net_id_{net_id}_分类: 预测错误的负样本信息, 个数: {len(error_labels_neg)}:\n" for label_id, label, pred in zip(error_label_id_neg, error_labels_neg, error_preds_neg): cls_error_info_str += f"label_id: {label_id}, label: {label}, pred: {pred}\n" del cls_train_all_label_id, cls_train_all_labels, cls_train_all_preds del cls_train_all_preds_binary, positive_error_indices, negative_error_indices del error_label_id_pos, error_labels_pos, error_preds_pos del error_label_id_neg, error_labels_neg, error_preds_neg torch.cuda.empty_cache() train_info_list.append([epoch+1, step+1, loss.item(), str(train_cls_report)]) if len(device_index) > 1: logger.info(f"local rank: {args.local_rank}, 评测训练集, train epoch: {epoch+1}, step: {step+1}, loss: {loss.item()}\n train classification report:\n {train_cls_report}") logger.info(f"local rank: {args.local_rank}, 评测训练集, train epoch: {epoch+1}, 训练集分类错误的信息\n{cls_error_info_str}") else: logger.info(f"local rank: {device_index[0]}, 评测训练集, train epoch: {epoch+1}, step: {step+1}, loss: {loss.item()}\n train classification report:\n {train_cls_report}") logger.info(f"local rank: {device_index[0]}, 评测训练集, train epoch: 
{epoch+1}, 训练集分类错误的信息\n{cls_error_info_str}") # 保存 - epoch 结束时 if len(device_index) > 1: torch.distributed.barrier() if (len(device_index) > 1 and args.local_rank == 0) or len(device_index) == 1: train_epoch_pt_file = f"train_epoch_{epoch+1}_net_{net_id}.pt" save_train_epoch_pt_file = os.path.join(save_dir, train_epoch_pt_file) if len(device_index) > 1: torch.save(model.module.state_dict(), save_train_epoch_pt_file) else: torch.save(model.state_dict(), save_train_epoch_pt_file) if len(device_index) > 1: logger.info(f"local rank: {args.local_rank}, train epoch: {epoch+1}, save model to {save_train_epoch_pt_file}") else: logger.info(f"local rank: {device_index[0]}, train epoch: {epoch+1}, save model to {save_train_epoch_pt_file}") # 训练集错误的数据,单独训练 if train_on_error_data_flag and len(train_on_error_data_epoch_dict_list) > 0 and current_train_epoch == train_on_error_data_epoch_dict_list[0]["train_epoch"]: current_train_on_error_data_epoch_dict = train_on_error_data_epoch_dict_list.pop(0) if len(device_index) > 1: logger.info(f"local rank: {args.local_rank}, train_on_error_data_epoch_dict_list: {len(train_on_error_data_epoch_dict_list)}") logger.info(f"local rank: {args.local_rank}, train on error data epochs: {current_train_on_error_data_epoch_dict['train_epoch']}, train epoch: {epoch+1}") else: logger.info(f"local rank: {device_index[0]}, train_on_error_data_epoch_dict_list: {len(train_on_error_data_epoch_dict_list)}") logger.info(f"local rank: {device_index[0]}, train on error data epochs: {current_train_on_error_data_epoch_dict['train_epoch']}, train epoch: {epoch+1}") if net_id == "2d3d": cls_train_all_label_id_2d = train_all_label_id_2d[:] cls_train_all_label_id_3d = train_all_label_id_3d[:] cls_train_all_preds_2d = torch.cat(train_all_preds_2d, dim=0) cls_train_all_preds_3d = torch.cat(train_all_preds_3d, dim=0) cls_train_all_labels_2d = torch.cat(train_all_labels_2d, dim=0) cls_train_all_labels_3d = torch.cat(train_all_labels_3d, dim=0) # 
修改成正样本预测概率小于0.2,负样本预测概率大于0.8 current_positive_threshold = current_train_on_error_data_epoch_dict["positive_threshold"] current_negative_threshold = current_train_on_error_data_epoch_dict["negative_threshold"] positive_error_2d_indices = torch.where((cls_train_all_labels_2d == 1) & (cls_train_all_preds_2d < current_positive_threshold)) negative_error_2d_indices = torch.where((cls_train_all_labels_2d == 0) & (cls_train_all_preds_2d > current_negative_threshold)) positive_error_3d_indices = torch.where((cls_train_all_labels_3d == 1) & (cls_train_all_preds_3d < current_positive_threshold)) negative_error_3d_indices = torch.where((cls_train_all_labels_3d == 0) & (cls_train_all_preds_3d > current_negative_threshold)) log_train_error_data_str = "" log_train_error_data_str += f"train on error data, net_id_{net_id}_2d分类: 预测错误的正样本信息, 个数: {len(positive_error_2d_indices[0])}:\n" for idx_index in positive_error_2d_indices[0]: log_train_error_data_str += f"label_id: {cls_train_all_label_id_2d[idx_index]}, label: {cls_train_all_labels_2d[idx_index]}, pred: {cls_train_all_preds_2d[idx_index]}\n" log_train_error_data_str += f"train on error data, net_id_{net_id}_2d分类: 预测错误的负样本信息, 个数: {len(negative_error_2d_indices[0])}:\n" for idx_index in negative_error_2d_indices[0]: log_train_error_data_str += f"label_id: {cls_train_all_label_id_2d[idx_index]}, label: {cls_train_all_labels_2d[idx_index]}, pred: {cls_train_all_preds_2d[idx_index]}\n" log_train_error_data_str += f"train on error data, net_id_{net_id}_3d分类: 预测错误的正样本信息, 个数: {len(positive_error_3d_indices[0])}:\n" for idx_index in positive_error_3d_indices[0]: log_train_error_data_str += f"label_id: {cls_train_all_label_id_3d[idx_index]}, label: {cls_train_all_labels_3d[idx_index]}, pred: {cls_train_all_preds_3d[idx_index]}\n" log_train_error_data_str += f"train on error data, net_id_{net_id}_3d分类: 预测错误的负样本信息, 个数: {len(negative_error_3d_indices[0])}:\n" for idx_index in negative_error_3d_indices[0]: log_train_error_data_str += 
f"label_id: {cls_train_all_label_id_3d[idx_index]}, label: {cls_train_all_labels_3d[idx_index]}, pred: {cls_train_all_preds_3d[idx_index]}\n" if len(device_index) > 1: logger.info(f"local rank: {args.local_rank}, train on error data, train_error_data_epochs: {current_train_on_error_data_epoch_dict['train_epoch']}\ntrain epoch: {epoch+1}, 训练集错误样本单独训练,训练集预测错误的信息:\n{log_train_error_data_str}") else: logger.info(f"local rank: {device_index[0]}, train on error data, train_error_data_epochs: {current_train_on_error_data_epoch_dict['train_epoch']}\ntrain epoch: {epoch+1}, 训练集错误样本单独训练,训练集预测错误的信息:\n{log_train_error_data_str}") error_2d_indices = torch.cat((positive_error_2d_indices[0], negative_error_2d_indices[0])) error_3d_indices = torch.cat((positive_error_3d_indices[0], negative_error_3d_indices[0])) error_data_2d_label_id_list = [cls_train_all_label_id_2d[i] for i in error_2d_indices.cpu().numpy()] error_data_3d_label_id_list = [cls_train_all_label_id_3d[i] for i in error_3d_indices.cpu().numpy()] if len(error_data_2d_label_id_list) == 0 and len(error_data_3d_label_id_list) == 0: if len(device_index) > 1: logger.info(f"local rank: {args.local_rank}, train on error data, 预测错误的样本为0, 跳过错误样本单独训练, error_data_2d_label_id_list: {error_data_2d_label_id_list}\nerror_data_3d_label_id_list: {error_data_3d_label_id_list}") else: logger.info(f"local rank: {device_index[0]}, train on error data, 预测错误的样本为0, 跳过错误样本单独训练, error_data_2d_label_id_list: {error_data_2d_label_id_list}\nerror_data_3d_label_id_list: {error_data_3d_label_id_list}") del cls_train_all_label_id_2d del cls_train_all_label_id_3d del cls_train_all_preds_2d del cls_train_all_preds_3d del cls_train_all_labels_2d del cls_train_all_labels_3d del current_positive_threshold del current_negative_threshold del positive_error_2d_indices del negative_error_2d_indices del positive_error_3d_indices del negative_error_3d_indices del log_train_error_data_str del error_2d_indices del error_3d_indices del 
error_data_2d_label_id_list del error_data_3d_label_id_list torch.cuda.empty_cache() continue current_train_batch_size = current_train_on_error_data_epoch_dict["error_train_batch_size"] if len(device_index) > 1: error_file = f"train_error_data_local_rank_{args.local_rank}_current_train_epoch_{current_train_on_error_data_epoch_dict['train_epoch']}_[{num_epochs}]_train_error_data_epoch_{current_train_on_error_data_epoch_dict['error_train_epoch']}_net_{net_id}.csv" else: error_file = f"train_error_data_local_rank_{device_index[0]}_current_train_epoch_{current_train_on_error_data_epoch_dict['train_epoch']}_[{num_epochs}]_train_error_data_epoch_{current_train_on_error_data_epoch_dict['error_train_epoch']}_net_{net_id}.csv" train_error_csv_file = os.path.join(save_dir, error_file) error_data_2d_preds_list = [cls_train_all_preds_2d[i].cpu().numpy() for i in error_2d_indices.cpu().numpy()] error_data_3d_preds_list = [cls_train_all_preds_3d[i].cpu().numpy() for i in error_3d_indices.cpu().numpy()] error_data_2d_labels_list = [cls_train_all_labels_2d[i].cpu().numpy() for i in error_2d_indices.cpu().numpy()] error_data_3d_labels_list = [cls_train_all_labels_3d[i].cpu().numpy() for i in error_3d_indices.cpu().numpy()] if len(device_index) > 1: print(f"local rank: {args.local_rank}, train on error data\nerror_data_2d_label_id_list: {len(error_data_2d_label_id_list)}\nerror_data_3d_label_id_list: {len(error_data_3d_label_id_list)}") else: print(f"local rank: {device_index[0]}, train on error data\nerror_data_2d_label_id_list: {len(error_data_2d_label_id_list)}\nerror_data_3d_label_id_list: {len(error_data_3d_label_id_list)}") train_error_dataset = train_error_dataset_class( csv_file=train_csv_file, label_id_2d_list=error_data_2d_label_id_list, label_id_3d_list=error_data_3d_label_id_list, error_file=train_error_csv_file, preds_2d=error_data_2d_preds_list, preds_3d=error_data_3d_preds_list, labels_2d=error_data_2d_labels_list, labels_3d=error_data_3d_labels_list ) 
train_error_dataloader = DataLoader( train_error_dataset, batch_size=current_train_batch_size, shuffle=True, num_workers=num_workers, collate_fn = custom_collate_fn_error ) if len(device_index) > 1: logger.info(f"local rank: {args.local_rank}, train on error data, 原始训练数据, error_data_2d_label_id_list: {len(error_data_2d_label_id_list)}\nerror_data_3d_label_id_list: {len(error_data_3d_label_id_list)}") logger.info(f"local rank: {args.local_rank}, train on error data, 原始训练数据, error_data_2d_label_id_list: {error_data_2d_label_id_list}\nerror_data_3d_label_id_list: {error_data_3d_label_id_list}") logger.info(f"local rank: {args.local_rank}, train on error data, 处理训练数据, train_error_dataloader: {len(train_error_dataloader)}") else: logger.info(f"local rank: {device_index[0]}, train on error data, 原始训练数据, error_data_2d_label_id_list: {len(error_data_2d_label_id_list)}\nerror_data_3d_label_id_list: {len(error_data_3d_label_id_list)}") logger.info(f"local rank: {device_index[0]}, train on error data, 原始训练数据, error_data_2d_label_id_list: {error_data_2d_label_id_list}\nerror_data_3d_label_id_list: {error_data_3d_label_id_list}") logger.info(f"local rank: {device_index[0]}, train on error data, 处理训练数据, train_error_dataloader: {len(train_error_dataloader)}") current_train_on_error_epochs = current_train_on_error_data_epoch_dict["error_train_epoch"] for idx_train_error_data_epoch in range(current_train_on_error_epochs): for idx_batch in train_error_dataloader: idx_train_error_data_2d, idx_train_error_data_3d, idx_train_error_label_2d, idx_train_error_label_3d = idx_batch if len(device_index) > 1: idx_train_error_data_2d = idx_train_error_data_2d.cuda() idx_train_error_data_3d = idx_train_error_data_3d.cuda() idx_train_error_label_2d = idx_train_error_label_2d.cuda() idx_train_error_label_3d = idx_train_error_label_3d.cuda() else: idx_train_error_data_2d = idx_train_error_data_2d.to(device_index[0]) idx_train_error_data_3d = idx_train_error_data_3d.to(device_index[0]) 
idx_train_error_label_2d = idx_train_error_label_2d.to(device_index[0]) idx_train_error_label_3d = idx_train_error_label_3d.to(device_index[0]) idx_train_error_y_pred_2d, idx_train_error_y_pred_3d = model(idx_train_error_data_2d, idx_train_error_data_3d) idx_train_error_loss = criterion(idx_train_error_y_pred_2d, idx_train_error_label_2d) + criterion(idx_train_error_y_pred_3d, idx_train_error_label_3d) optimizer.zero_grad() idx_train_error_loss.backward() optimizer.step() if len(device_index) > 1: torch.distributed.all_reduce(idx_train_error_loss, op=torch.distributed.ReduceOp.AVG) if len(device_index) > 1: logger.info(f"local rank: {args.local_rank}, train on error data, train_error_data_epoch: {idx_train_error_data_epoch+1}, train_error_epochs: {current_train_on_error_epochs}, current_train_epoch: {epoch+1}, train_epoch: {num_epochs}, loss: {idx_train_error_loss.item()}") else: logger.info(f"local rank: {device_index[0]}, train on error data, train_error_data_epoch: {idx_train_error_data_epoch+1}, train_error_epochs: {current_train_on_error_epochs}, current_train_epoch: {epoch+1}, train_epoch: {num_epochs}, loss: {idx_train_error_loss.item()}") if len(device_index) > 1: logger.info(f"local rank: {args.local_rank}, train on error data, 错误样本单独训练结束") else: logger.info(f"local rank: {device_index[0]}, train on error data, 错误样本单独训练结束") del cls_train_all_label_id_2d del cls_train_all_label_id_3d del cls_train_all_preds_2d del cls_train_all_preds_3d del cls_train_all_labels_2d del cls_train_all_labels_3d del current_positive_threshold del current_negative_threshold del positive_error_2d_indices del negative_error_2d_indices del positive_error_3d_indices del negative_error_3d_indices del log_train_error_data_str del error_2d_indices del error_3d_indices del error_data_2d_label_id_list del error_data_3d_label_id_list del current_train_batch_size del error_file del train_error_csv_file del error_data_2d_preds_list del error_data_3d_preds_list del error_data_2d_labels_list 
del error_data_3d_labels_list del train_error_dataset del train_error_dataloader del current_train_on_error_epochs del idx_train_error_data_2d del idx_train_error_data_3d del idx_train_error_label_2d del idx_train_error_label_3d del idx_train_error_y_pred_2d del idx_train_error_y_pred_3d del idx_train_error_loss torch.cuda.empty_cache() else: cls_train_all_label_id = train_all_label_id[:] cls_train_all_preds = torch.cat(train_all_preds, dim=0) cls_train_all_labels = torch.cat(train_all_labels, dim=0) # 修改成正样本预测概率小于0.2,负样本预测概率大于0.8 current_positive_threshold = current_train_on_error_data_epoch_dict["positive_threshold"] current_negative_threshold = current_train_on_error_data_epoch_dict["negative_threshold"] positive_error_indices = torch.where((cls_train_all_labels == 1) & (cls_train_all_preds < current_positive_threshold)) negative_error_indices = torch.where((cls_train_all_labels == 0) & (cls_train_all_preds > current_negative_threshold)) log_train_error_data_str = "" log_train_error_data_str += f"train on error data, net_id_{net_id}_分类: 预测错误的正样本信息, 个数: {len(positive_error_indices[0])}:\n" for idx_index in positive_error_indices[0]: log_train_error_data_str += f"label_id: {cls_train_all_label_id[idx_index]}, label: {cls_train_all_labels[idx_index]}, pred: {cls_train_all_preds[idx_index]}\n" log_train_error_data_str += f"train on error data, net_id_{net_id}_分类: 预测错误的负样本信息, 个数: {len(negative_error_indices[0])}:\n" for idx_index in negative_error_indices[0]: log_train_error_data_str += f"label_id: {cls_train_all_label_id[idx_index]}, label: {cls_train_all_labels[idx_index]}, pred: {cls_train_all_preds[idx_index]}\n" if len(device_index) > 1: logger.info(f"local rank: {args.local_rank}, train on error data, train_error_data_epochs: {current_train_on_error_data_epoch_dict['train_epoch']}\ntrain epoch: {epoch+1}, 训练集错误样本单独训练,训练集预测错误的信息:\n{log_train_error_data_str}") else: logger.info(f"local rank: {device_index[0]}, train on error data, train_error_data_epochs: 
{current_train_on_error_data_epoch_dict['train_epoch']}\ntrain epoch: {epoch+1}, 训练集错误样本单独训练,训练集预测错误的信息:\n{log_train_error_data_str}") error_indices = torch.cat((positive_error_indices[0], negative_error_indices[0])) error_data_label_id_list = [cls_train_all_label_id[i] for i in error_indices.cpu().numpy()] if len(error_data_label_id_list) == 0: if len(device_index) > 1: logger.info(f"local rank: {args.local_rank}, train on error data, 预测错误的样本为0, 跳过错误样本单独训练, error_data_label_id_list: {error_data_label_id_list}") else: logger.info(f"local rank: {device_index[0]}, train on error data, 预测错误的样本为0, 跳过错误样本单独训练, error_data_label_id_list: {error_data_label_id_list}") del cls_train_all_label_id del cls_train_all_preds del cls_train_all_labels del current_positive_threshold del current_negative_threshold del positive_error_indices del negative_error_indices del log_train_error_data_str del error_indices del error_data_label_id_list torch.cuda.empty_cache() continue current_train_batch_size = current_train_on_error_data_epoch_dict["error_train_batch_size"] error_file = f"train_error_data_current_train_epoch_{current_train_on_error_data_epoch_dict['train_epoch']}_[{num_epochs}]_train_error_data_epoch_{current_train_on_error_data_epoch_dict['error_train_epoch']}_net_{net_id}.csv" train_error_csv_file = os.path.join(save_dir, error_file) error_data_preds_list = [cls_train_all_preds[i].cpu().numpy() for i in error_indices.cpu().numpy()] error_data_labels_list = [cls_train_all_labels[i].cpu().numpy() for i in error_indices.cpu().numpy()] train_error_dataset = None if net_id == "2d": train_error_dataset = train_error_dataset_class( csv_file=train_csv_file, label_id_2d_list=error_data_label_id_list, error_file=train_error_csv_file, preds_2d=error_data_preds_list, labels_2d=error_data_labels_list ) elif net_id == "3d": train_error_dataset = train_error_dataset_class( csv_file=train_csv_file, label_id_3d_list=error_data_label_id_list, error_file=train_error_csv_file, 
preds_3d=error_data_preds_list, labels_3d=error_data_labels_list ) elif net_id == "s3d": train_error_dataset = train_error_dataset_class( csv_file=train_csv_file, label_id_3d_list=error_data_label_id_list, error_file=train_error_csv_file, preds_3d=error_data_preds_list, labels_3d=error_data_labels_list ) elif net_id == "resnet3d": train_error_dataset = train_error_dataset_class( csv_file=train_csv_file, label_id_3d_list=error_data_label_id_list, error_file=train_error_csv_file, preds_3d=error_data_preds_list, labels_3d=error_data_labels_list ) elif net_id == "d2d": train_error_dataset = train_error_dataset_class( csv_file=train_csv_file, label_id_2d_list=error_data_label_id_list, error_file=train_error_csv_file, preds_2d=error_data_preds_list, labels_2d=error_data_labels_list ) if train_error_dataset is None: if len(device_index) > 1: raise ValueError(f"local rank: {args.local_rank}, train on error data, train_error_dataset is None, net_id: {net_id}") else: raise ValueError(f"local rank: {device_index[0]}, train on error data, train_error_dataset is None, net_id: {net_id}") train_error_dataloader = DataLoader( train_error_dataset, batch_size=current_train_batch_size, shuffle=True, num_workers=num_workers, collate_fn = custom_collate_fn_error ) if len(device_index) > 1: logger.info(f"local rank: {args.local_rank}, train on error data, 原始训练数据, error_data_label_id_list: {len(error_data_label_id_list)}") logger.info(f"local rank: {args.local_rank}, train on error data, 原始训练数据, error_data_label_id_list: {error_data_label_id_list}") logger.info(f"local rank: {args.local_rank}, train on error data, 处理训练数据, train_error_dataloader: {len(train_error_dataloader)}") else: logger.info(f"local rank: {device_index[0]}, train on error data, 原始训练数据, error_data_label_id_list: {len(error_data_label_id_list)}") logger.info(f"local rank: {device_index[0]}, train on error data, 原始训练数据, error_data_label_id_list: {error_data_label_id_list}") logger.info(f"local rank: {device_index[0]}, 
train on error data, 处理训练数据, train_error_dataloader: {len(train_error_dataloader)}") # 错误样本单独训练 current_train_on_error_epochs = current_train_on_error_data_epoch_dict["error_train_epoch"] for idx_train_error_data_epoch in range(current_train_on_error_epochs): for idx_batch in train_error_dataloader: idx_train_error_data, idx_train_error_label = idx_batch if len(device_index) > 1: idx_train_error_data = idx_train_error_data.cuda() idx_train_error_label = idx_train_error_label.cuda() else: idx_train_error_data = idx_train_error_data.to(device_index[0]) idx_train_error_label = idx_train_error_label.to(device_index[0]) idx_train_error_y_pred = model(idx_train_error_data) idx_train_error_loss = criterion(idx_train_error_y_pred, idx_train_error_label) optimizer.zero_grad() idx_train_error_loss.backward() optimizer.step() if len(device_index) > 1: torch.distributed.all_reduce(idx_train_error_loss, op=torch.distributed.ReduceOp.AVG) if len(device_index) > 1: logger.info(f"local rank: {args.local_rank}, train on error data, train_error_data_epoch: {idx_train_error_data_epoch+1}, train_error_epochs: {current_train_on_error_epochs}, current_train_epoch: {epoch+1}, train_epoch: {num_epochs}, loss: {idx_train_error_loss.item()}") else: logger.info(f"local rank: {device_index[0]}, train on error data, train_error_data_epoch: {idx_train_error_data_epoch+1}, train_error_epochs: {current_train_on_error_epochs}, current_train_epoch: {epoch+1}, train_epoch: {num_epochs}, loss: {idx_train_error_loss.item()}") if len(device_index) > 1: logger.info(f"local rank: {args.local_rank}, train on error data, 错误样本单独训练结束") else: logger.info(f"local rank: {device_index[0]}, train on error data, 错误样本单独训练结束") del cls_train_all_label_id del cls_train_all_preds del cls_train_all_labels del current_positive_threshold del current_negative_threshold del positive_error_indices del negative_error_indices del log_train_error_data_str del error_indices del error_data_label_id_list del 
def train_ddp(model, config, args):
    """Entry point for (optionally distributed) training.

    Selects the CUDA device from ``config['device_index']``, initialises the
    NCCL process group when more than one device is configured, seeds all RNGs
    and hands off to ``start_train_ddp``.  The process group is destroyed in
    ``finally`` so cleanup happens even when training raises.

    Args:
        model: network to train (wrapped in DDP further down the stack).
        config: training configuration dict; must contain 'device_index'.
        args: parsed CLI namespace; ``args.local_rank`` is used in DDP mode.
    """
    # Initialise before the try block so the finally clause cannot hit a
    # NameError when config['device_index'] itself raises.
    device_index = []
    try:
        device_index = config['device_index']
        if len(device_index) > 1:
            # Multi-GPU: bind this process to its local rank.
            torch.cuda.set_device(args.local_rank)
            device = torch.device("cuda", args.local_rank)
        else:
            device = torch.device(f"cuda:{device_index[0]}")
        print(f"train_ddp, device: {device_index}")
        if len(device_index) > 1:
            if torch.distributed.is_available() and not torch.distributed.is_initialized():
                torch.distributed.init_process_group(
                    backend="nccl",
                    init_method="env://"
                )
            world_size = torch.distributed.get_world_size()
            print(f"train_ddp, 多卡, local_rank: {args.local_rank}, world_size: {world_size}")
            print(f"train_ddp, 多卡, device: {device}")
        else:
            world_size = 1
            print(f"train_ddp, 单卡, local_rank: {device_index[0]}, world_size: {world_size}")
            print(f"train_ddp, 单卡, device: {device}")
        set_seed()
        start_train_ddp(model, config, args)
    except Exception:
        print(f"error in train_ddp: {traceback.format_exc()}")
        # Bare raise preserves the original traceback instead of re-raising
        # the bound exception object.
        raise
    finally:
        if len(device_index) > 1 and torch.distributed.is_initialized():
            print(f"finished ddp training, destroy process group")
            torch.distributed.destroy_process_group()
        else:
            print(f"finished single gpu training")
def print_init_params(model, string="s3d", logger=None):
    """Emit a tiny slice of every parameter tensor, for init sanity-checks.

    For each named parameter, at most the first three leading dimensions are
    stripped and the first element of the remaining leading axis is shown.
    Output goes to ``logger.info`` when a logger is supplied, else to stdout.
    """
    for param_name, weight in model.named_parameters():
        # Drill down: drop up to three leading dimensions so we print a
        # one-element slice regardless of tensor rank.
        sample = weight
        for _ in range(min(weight.dim() - 1, 3)):
            sample = sample[0]
        message = f"{string}, name: {param_name}, param: {sample[:1]}"
        if logger is not None:
            logger.info(message)
        else:
            print(message)
def get_train_config():
    """Assemble the model and the training configuration for one run.

    All run parameters (net/node ids, file locations, optimiser, schedule,
    error-retraining policy) are hard-coded here; edit this function to set
    up a new experiment.

    Returns:
        (model, config): the instantiated network and the dict consumed by
        ``train_ddp`` / ``start_train_ddp``.

    Raises:
        ValueError: if ``net_id`` is not one of the supported network ids.
    """
    from pytorch_train.train_2d3d_config import node_net_train_file_dict
    config = {}
    net_id = "2d3d"
    epochs = 20
    local_rank_index = 0
    # node_id values used in past runs:
    #   1010_1020_2011_2021_2041_2031
    #   1020_2011_2021_2031
    #   2046_2047_2048_2060_2061_2062_3001_4001_5001_6001_2031_2021_2041
    #   2011_2021_2031_2041_1010_1020_1016
    #   2011_2021_2031_2041
    node_id = "2021_2031"
    date_id = "20241217"
    task_info = f"train_{net_id}_{node_id}_{date_id}_rotate_10_ddp_count_00000005"
    train_log_dir = "/df_lung/ai-project/cls_train/log/train"
    save_dir = "/df_lung/ai-project/cls_train/cls_ckpt"
    save_folder = f"{net_id}_{node_id}_{date_id}"
    log_file = os.path.join(train_log_dir, f"{task_info}.log")
    save_dir = os.path.join(save_dir, save_folder)
    train_csv_file = None
    val_csv_file = None
    test_csv_file = None
    train_2d3d_data_2d_csv_file = None
    val_2d3d_data_2d_csv_file = None
    test_2d3d_data_2d_csv_file = None
    train_2d3d_data_3d_csv_file = None
    val_2d3d_data_3d_csv_file = None
    test_2d3d_data_3d_csv_file = None
    # The 2d3d net trains on paired 2d/3d csv files; every other net uses a
    # single train/val/test triple.
    if net_id == "2d3d":
        train_2d3d_data_2d_csv_file = node_net_train_file_dict[(node_id, net_id)]["2d_train_file"]
        val_2d3d_data_2d_csv_file = node_net_train_file_dict[(node_id, net_id)]["2d_val_file"]
        test_2d3d_data_2d_csv_file = node_net_train_file_dict[(node_id, net_id)]["2d_test_file"]
        train_2d3d_data_3d_csv_file = node_net_train_file_dict[(node_id, net_id)]["3d_train_file"]
        val_2d3d_data_3d_csv_file = node_net_train_file_dict[(node_id, net_id)]["3d_val_file"]
        test_2d3d_data_3d_csv_file = node_net_train_file_dict[(node_id, net_id)]["3d_test_file"]
    else:
        train_csv_file = node_net_train_file_dict[(node_id, net_id)]["train_file"]
        val_csv_file = node_net_train_file_dict[(node_id, net_id)]["val_file"]
        test_csv_file = node_net_train_file_dict[(node_id, net_id)]["test_file"]
    s3d_file = "/df_lung/cls_train_data/encoder_3d_classifier_state_dict.pth"
    resnet_3d_file = "/df_lung/cls_train_data/resnet_classifier_state_dict.pth"
    d2d_dir = "/df_lung/cls_train_data/dino2"
    config['s3d_file'] = s3d_file
    config['resnet_3d_file'] = resnet_3d_file
    config['d2d_dir'] = d2d_dir
    logger = get_logger(log_file)
    logger.info(f"train_config, node_id: {node_id}")
    if net_id == "2d3d":
        model = Net2d3d()
        batch_size = 6 if local_rank_index is not None else 24
    elif net_id == "2d":
        model = Net2d()
        batch_size = 200
    elif net_id == "3d":
        model = Net3d()
        batch_size = 6
    elif net_id == "s3d":
        model = NetS3d()
        # NOTE(review): weights_only=False unpickles arbitrary objects; only
        # load checkpoints from trusted locations.
        model.load_state_dict(torch.load(config['s3d_file'], weights_only=False, map_location="cpu"))
        batch_size = 1
    elif net_id == "resnet3d":
        model = NetResNet3d(Bottleneck, [3, 4, 23, 3], 128, 128, 128)
        model.load_state_dict(torch.load(config['resnet_3d_file'], weights_only=True))
        batch_size = 10
    elif net_id == "d2d":
        model = NetD2d(pretrain_dir=config["d2d_dir"])
        batch_size = 1
    else:
        # Previously an unknown net_id crashed later with NameError on
        # `model`/`batch_size`; fail fast with a clear message instead.
        raise ValueError(f"unsupported net_id: {net_id}")
    print_init_params(model, string="get_train_config, instance model", logger=logger)
    if net_id not in ["s3d", "resnet3d", "d2d"]:
        init_modules(model)
    # Warm-start selected 2d3d runs from the matching checkpoint of an
    # earlier run.
    if net_id == "2d3d" and node_id == "2021_2031":
        step_file = "/df_lung/ai-project/cls_train/cls_ckpt/2d3d_2021_2031_20241207/train_epoch_1_net_2d3d.pt"
        model.load_state_dict(torch.load(step_file, weights_only=True))
        logger.info(f"net_id: {net_id}, node_id: {node_id}, step_file: {step_file}")
    elif net_id == "2d3d" and node_id == "2041_2031":
        step_file = "/df_lung/ai-project/cls_train/cls_ckpt/2d3d_2041_2031_20241207/train_epoch_1_step_17000_net_2d3d.pt"
        model.load_state_dict(torch.load(step_file, weights_only=True))
        logger.info(f"net_id: {net_id}, node_id: {node_id}, step_file: {step_file}")
    elif net_id == "2d3d" and node_id == "1010_1020_2011_2021_2041_2031":
        step_file = "/df_lung/ai-project/cls_train/cls_ckpt/2d3d_1010_1020_2011_2021_2041_2031_20241207/train_epoch_1_step_17000_net_2d3d.pt"
        model.load_state_dict(torch.load(step_file, weights_only=True))
        logger.info(f"net_id: {net_id}, node_id: {node_id}, step_file: {step_file}")
    elif net_id == "2d3d" and node_id == "1010_1020_2011_2021_2031_2041":
        step_file = "/df_lung/ai-project/cls_train/cls_ckpt/2d3d_1010_1020_2011_2021_2031_2041_20241207/train_epoch_1_step_17000_net_2d3d.pt"
        model.load_state_dict(torch.load(step_file, weights_only=True))
        logger.info(f"net_id: {net_id}, node_id: {node_id}, step_file: {step_file}")
    Path(train_log_dir).mkdir(parents=True, exist_ok=True)
    Path(save_dir).mkdir(parents=True, exist_ok=True)
    config['net_id'] = net_id
    config['local_rank_index'] = local_rank_index
    config['lr'] = 1e-5
    # NOTE(review): weight_decay is recorded here but not passed to the Adam
    # constructor below — confirm whether that is intentional.
    config['weight_decay'] = 1e-4
    config['logger'] = logger
    config['criterion'] = nn.BCELoss()
    config['step_size'] = 200
    config['learning_rate_drop'] = 0.1
    config['optimizer'] = Adam(model.parameters(), lr=config['lr'])
    config['scheduler'] = lr_scheduler.StepLR(config['optimizer'], step_size=config['step_size'], gamma=config['learning_rate_drop'])
    config['num_epochs'] = epochs
    # A fixed local rank pins training to that single GPU; None means DDP
    # across GPUs 0-3.
    config['device_index'] = [config['local_rank_index']] if config['local_rank_index'] is not None else [0, 1, 2, 3]
    config['num_workers'] = 1
    config['train_batch_size'] = batch_size
    config['val_batch_size'] = batch_size
    config['save_dir'] = save_dir
    # Effectively disables mid-epoch validation.
    config['val_interval'] = 1000000000000000
    config['pos_label'] = 'positive'
    config['neg_label'] = 'negative'
    config['train_2d3d_data_2d_csv_file'] = train_2d3d_data_2d_csv_file
    config['val_2d3d_data_2d_csv_file'] = val_2d3d_data_2d_csv_file
    config['test_2d3d_data_2d_csv_file'] = test_2d3d_data_2d_csv_file
    config['train_2d3d_data_3d_csv_file'] = train_2d3d_data_3d_csv_file
    config['val_2d3d_data_3d_csv_file'] = val_2d3d_data_3d_csv_file
    config['test_2d3d_data_3d_csv_file'] = test_2d3d_data_3d_csv_file
    config['train_csv_file'] = train_csv_file
    config['val_csv_file'] = val_csv_file
    config['test_csv_file'] = test_csv_file
    config['val_metric'] = 'f1-score'
    config['threshold'] = 0.5
    config['logger_train_cls_report_flag'] = False
    config['train_on_error_data_flag'] = False
    # Error-sample retraining policy: at each listed train_epoch, samples
    # whose prediction crosses the listed thresholds are retrained alone.
    config['train_on_error_data_epoch_dict_list'] = [
        {"train_epoch": config['num_epochs'] // 2,
         "error_train_epoch": 1,
         "positive_threshold": 0.3,
         "negative_threshold": 0.7,
         "error_train_batch_size": 1},
        {"train_epoch": config['num_epochs'] - 10,
         "error_train_epoch": 1,
         "positive_threshold": 0.5,
         "negative_threshold": 0.5,
         "error_train_batch_size": 1}
    ]
    return model, config
if __name__ == "__main__":
    try:
        # Build model + config first, then parse the launcher-provided flags.
        model, config = get_train_config()
        parser = argparse.ArgumentParser()
        # torch.distributed.launch passes --local-rank; exposed as args.local_rank.
        parser.add_argument("--local-rank", type=int, default=-1, dest="local_rank")
        parser.add_argument("--use_zero", action="store_true")
        args = parser.parse_args()
        print(f"args: {args}")
        train_ddp(model, config, args)
    except Exception:
        print(f"error in main: {traceback.format_exc()}")
        # Bare raise keeps the original traceback intact instead of
        # re-raising the bound exception object.
        raise
    finally:
        # Also check is_available(): is_initialized() is not usable on
        # builds compiled without distributed support.
        if torch.distributed.is_available() and torch.distributed.is_initialized():
            print("finished, destroy process group")
            torch.distributed.destroy_process_group()

# python -m torch.distributed.launch --nproc_per_node=4 --nnodes=1 pytorch_train/train_2d3d.py --use_zero
# python pytorch_train/train_2d3d.py