import os
import sys
import argparse
import torch
# import horovod.torch as hvd
import numpy as np
import lmdb
import json
import SimpleITK as sitk
from pprint import pprint
import logging
from datetime import datetime
import torch.distributed
import torch.multiprocessing as mp
import torch.distributed as dist
import pdb

# if __name__ == "__main__" and __package__ is None:
#     sys.path.insert(0, os.path.join(os.path.dirname(os.path.dirname(os.path.abspath(__file__))), ".."))
#     __package__ = "lungDetection3D"

sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

# from grouplung.BaseDetector.config import get_cfg_defaults
from BaseDetector.utils.gpu_utils import set_gpu
from NoduleDetector.engine import NoduleDetection3D
from k8s_utils import CParamFiller, CK8sPathWrapper

# Global logger; set in __main__. _log_msg() falls back to print() while it is None.
g_logger = None


def init_dist(backend='nccl', master_ip='127.0.0.1', port=29500):
    """Initialise torch.distributed from the RANK / WORLD_SIZE / LOCAL_RANK env vars."""
    if mp.get_start_method(allow_none=True) is None:
        mp.set_start_method('spawn')
    os.environ['MASTER_ADDR'] = master_ip
    os.environ['MASTER_PORT'] = str(port)
    rank = int(os.environ['RANK'])
    world_size = int(os.environ['WORLD_SIZE'])
    num_gpus = torch.cuda.device_count()
    local_rank = int(os.environ['LOCAL_RANK'])
    deviceid = local_rank % num_gpus
    torch.cuda.set_device(deviceid)
    print(fr'dist settings: local_rank {local_rank}, rank {rank}, worldsize {world_size}, gpus {num_gpus}, deviceid {deviceid}')
    dist.init_process_group(backend=backend)
    return rank, world_size


def get_logger(name, task_name=None):
    stream_handler = logging.StreamHandler(sys.stdout)
    logger = logging.getLogger(name)
    logger.setLevel(logging.INFO)
    logger.addHandler(stream_handler)
    return logger


def _log_msg(strmsg="\n"):
    # In distributed mode only rank 0 logs; non-distributed runs simply print.
    if int(os.environ.get('RANK', -1)) != -1:
        if torch.distributed.get_rank() == 0:
            if g_logger is not None:
                g_logger.info(strmsg)
            else:
                print(strmsg)
    else:
        print(strmsg)


class MyEncoder(json.JSONEncoder):
    """JSON encoder that also handles numpy scalars/arrays and datetime objects."""

    def default(self, obj):
        if isinstance(obj, np.integer):
            return int(obj)
        elif isinstance(obj, np.floating):
            return float(obj)
        elif isinstance(obj, np.ndarray):
            return obj.tolist()
        elif isinstance(obj, datetime):
            return obj.__str__()
        else:
            return super(MyEncoder, self).default(obj)


def split_trainval_and_create_new_db(mdb_path, train_db, test_db, split_rate):
    """Randomly split the preprocessed LMDB into separate train and validation LMDBs."""
    train_rate, test_rate = split_rate['train'], split_rate['test']

    # Read every (series key, label json) pair from the preprocessed db.
    env = lmdb.open(mdb_path, readonly=True)
    txn = env.begin()
    all_niis = []
    all_labels = []
    for key, value in txn.cursor():
        key = str(key, encoding='utf-8')
        value = str(value, encoding='utf-8')
        label_info = json.loads(value)
        all_niis.append(key)
        all_labels.append(label_info)
    txn.commit()
    env.close()

    # Randomly assign train_rate of the series to the train split, the rest to test.
    train_idxs = list(np.random.choice(len(all_niis), int(len(all_niis) * train_rate), replace=False))
    train_dicts = {}
    test_dicts = {}
    for i in range(len(all_niis)):
        if i in train_idxs:
            train_dicts[all_niis[i]] = all_labels[i]
        else:
            test_dicts[all_niis[i]] = all_labels[i]

    train_env = lmdb.open(train_db, map_size=int(1e9))
    train_txn = train_env.begin(write=True)
    for key, values in train_dicts.items():
        train_txn.put(key=str(key).encode(), value=json.dumps(values, cls=MyEncoder).encode())
    train_txn.commit()
    train_env.close()

    test_env = lmdb.open(test_db, map_size=int(1e9))
    test_txn = test_env.begin(write=True)
    for key, values in test_dicts.items():
        test_txn.put(key=str(key).encode(), value=json.dumps(values, cls=MyEncoder).encode())
    test_txn.commit()
    test_env.close()


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Full-featured execution script for the Detection module.")
    group = parser.add_mutually_exclusive_group()

    # ddp / distributed-launch arguments
    parser.add_argument("--local_rank", default=-1, type=int)
    parser.add_argument('--port', type=int, help='port of server')
    parser.add_argument('--world_size', default=2, type=int)
    parser.add_argument('--rank', default=0, type=int)
    parser.add_argument('--master_ip', default='127.0.0.1', type=str)
    parser.add_argument('--job_data_root', default="/data/job_678/job_data_train", type=str)
    args = parser.parse_args()

    g_pathWrapper = CK8sPathWrapper(args.job_data_root)
    config_path = g_pathWrapper.get_input_inputconfig_filepath()
    if not os.path.isfile(config_path):
        print(fr'given config file {config_path} does not exist !')
        sys.exit(1)

    # print(fr'using config file {config_path} to fill args !')
    cfg = CParamFiller.fillin_args(config_path, args)
    train_params = 'configs/train/input_config.json'
    cfg = CParamFiller.fillin_args(train_params, cfg)

    # cfg = get_cfg_defaults()
    # cfg.merge_from_file('/root/Documents/st_sample_codedata/grouplung/NoduleDetector/config.yaml')
    # cfg.freeze()

    g_logger = get_logger(__name__)

    rank, world_size = init_dist(backend='nccl', master_ip=args.master_ip, port=args.port)
    cfg.rank = rank
    cfg.world_size = world_size
    cfg.local_rank = int(os.environ['LOCAL_RANK'])
    _log_msg('pretrain_msg: ' + cfg.pretrain_msg)

    # Split the dataset into train and validation sets, and produce the data and
    # LMDB-format labels needed for training.
    cfg.data['data_loader']['train_db'] = g_pathWrapper.get_train_db_path()
    cfg.data['data_loader']['validate_db'] = g_pathWrapper.get_val_db_path()
    if cfg.rank == 0:
        _log_msg('rank: {}'.format(cfg.rank))
        _log_msg('split trainval dataset in {}'.format(cfg.rank))
        mdb_path = g_pathWrapper.get_output_preprocess_mdb_path()
        train_db_path = g_pathWrapper.get_train_db_path()
        test_db_path = g_pathWrapper.get_val_db_path()
        if len(os.listdir(train_db_path)) == 0 or len(os.listdir(test_db_path)) == 0:
            # If only one of the two dbs is populated, clear it so both are rebuilt together.
            if len(os.listdir(train_db_path)) != 0:
                files = [os.path.join(train_db_path, file) for file in os.listdir(train_db_path)]
                for file in files:
                    os.remove(file)
            if len(os.listdir(test_db_path)) != 0:
                files = [os.path.join(test_db_path, file) for file in os.listdir(test_db_path)]
                for file in files:
                    os.remove(file)
            _log_msg('Need to create train and val db !!!')
            split_trainval_and_create_new_db(mdb_path, train_db_path, test_db_path, cfg.split)
        _log_msg('train db exists: {}'.format(os.path.exists(g_pathWrapper.get_train_db_path())))
        _log_msg('val db exists: {}'.format(os.path.exists(g_pathWrapper.get_val_db_path())))

    cfg.data['data_loader']['data_dir'] = g_pathWrapper.get_train_nii_path()

    # Output paths for saving the trained model and evaluation results.
    cfg.training['saver']['saver_root'] = g_pathWrapper.get_output_train_dirpath()
    cfg.training['saver']['saver_train_bestmodel'] = g_pathWrapper.get_output_train_bestmodel_pth_filepath()
    cfg.training['saver']['saver_eval_pmd'] = g_pathWrapper.get_output_eval_performance_md_filepath()
    cfg.training['saver']['saver_eval_pjson'] = g_pathWrapper.get_output_eval_performance_filepath()
    cfg.training['saver']['saver_train_metrics'] = g_pathWrapper.get_output_train_trainingmetrics_filepath()
    cfg.training['saver']['saver_train_result'] = g_pathWrapper.get_output_train_trainresult_filepath()

    # Initialise output/train/train_result.json.
    train_result_dicts = {'successFlag': '', 'bestModelEpoch': 0}
    fp = cfg.training['saver']['saver_train_result']
    with open(fp, 'w+') as file:
        json.dump(train_result_dicts, file, indent=4)

    # If a pretrained model is provided, resolve it to an absolute path under the input dir.
    if cfg.pretrain_msg:
        cfg.pretrain_msg = os.path.join(g_pathWrapper.get_input_dirpath(), cfg.pretrain_msg)

    # set gpu, train gpus >= 1, val/pred gpu == 1
    # set_gpu(cfg.training['num_gpus'], used_percent=0.1, log_fun=None)
    # os.environ['CUDA_VISIBLE_DEVICES'] = "1,2"
    # if cfg.train:
    #     os.environ['CUDA_VISIBLE_DEVICES'] = "1,2"
    #     set_gpu(cfg.training['num_gpus'], used_percent=0.1, log_fun=None)
    # if cfg.inference:
    #     os.environ['CUDA_VISIBLE_DEVICES'] = "7"
    #     set_gpu(num_gpu=1, used_percent=0.1)

    if cfg.train:
        # Rank 0 logs the resolved config; other ranks print it directly.
        if torch.distributed.get_rank() == 0:
            _log_msg(cfg)
        else:
            print(cfg)
        detector = NoduleDetection3D(cfg, log_fun=_log_msg)
        detector.do_train()
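
# Launch sketch (assumption: the entry-point filename below is hypothetical; it is
# not given in this snippet). The script relies on torchrun-style environment
# variables (RANK, WORLD_SIZE, LOCAL_RANK), so a single-node, 2-GPU run would look
# something like:
#
#   torchrun --nproc_per_node=2 train_detector.py \
#       --port 29500 --job_data_root /data/job_678/job_data_train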