"""Full-function execute script for the 3D nodule Detection module.

Supports three mutually exclusive modes driven by the job config
(``cfg.train`` / ``cfg.test`` / ``cfg.pred``): distributed training,
single-process evaluation, and single-process prediction.
"""
import os
import sys
import argparse
import json
import logging
import pdb  # noqa: F401  (kept: debugging convenience used by this project)
from datetime import datetime
from pprint import pprint  # noqa: F401

import numpy as np
import lmdb
import SimpleITK as sitk  # noqa: F401
import torch
import torch.distributed
import torch.distributed as dist
import torch.multiprocessing as mp

# Make the package root importable when this file is run as a plain script.
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

from BaseDetector.utils.gpu_utils import set_gpu  # noqa: F401
from NoduleDetector.engine import NoduleDetection3D
from k8s_utils import CParamFiller, CK8sPathWrapper

# Global logger, populated in __main__ once the config is loaded.  Defined at
# module level so _log_msg() never raises NameError if called before that.
g_logger = None


def init_dist(backend='nccl', master_ip='127.0.0.1', port=29500):
    """Initialise torch.distributed from the RANK / WORLD_SIZE / LOCAL_RANK
    environment variables (as set by the launcher).

    Args:
        backend: process-group backend passed to ``init_process_group``.
        master_ip: rendezvous address, exported as MASTER_ADDR.
        port: rendezvous port, exported as MASTER_PORT.

    Returns:
        (rank, world_size) of the current process.
    """
    if mp.get_start_method(allow_none=True) is None:
        mp.set_start_method('spawn')
    os.environ['MASTER_ADDR'] = master_ip
    os.environ['MASTER_PORT'] = str(port)
    rank = int(os.environ['RANK'])
    world_size = int(os.environ['WORLD_SIZE'])
    num_gpus = torch.cuda.device_count()
    local_rank = os.environ['LOCAL_RANK']
    # Fixed: the original used eval() on an environment variable; int() is
    # safe and sufficient here.
    deviceid = int(local_rank) % num_gpus
    torch.cuda.set_device(deviceid)
    print(fr'dist settings: local_rank {local_rank}, rank {rank}, worldsize {world_size}, gpus {num_gpus}, deviceid {deviceid}')
    dist.init_process_group(backend=backend)
    return rank, world_size


def get_logger(name, task_name=None):
    """Return an INFO-level logger that writes to stdout.

    ``task_name`` is unused but kept for API compatibility with callers.
    """
    file_handler = logging.StreamHandler(sys.stdout)
    logger = logging.getLogger(name)
    logger.setLevel(logging.INFO)
    logger.addHandler(file_handler)
    return logger


def _log_msg(strmsg="\n"):
    """Log ``strmsg`` on rank 0 only when running distributed; otherwise print.

    Fixed: environment values are strings, so the unset-default must also be
    the string '-1' (the original compared a str against the int -1, which is
    always unequal whenever RANK is set — even to -1).
    """
    if os.environ.get('RANK', '-1') != '-1':
        if torch.distributed.get_rank() == 0:
            if g_logger is not None:
                g_logger.info(strmsg)
            else:
                print(strmsg)
    else:
        print(strmsg)


class MyEncoder(json.JSONEncoder):
    """JSON encoder that additionally serialises numpy scalars/arrays and
    datetime objects (datetime is rendered via str())."""

    def default(self, obj):
        if isinstance(obj, np.integer):
            return int(obj)
        elif isinstance(obj, np.floating):
            return float(obj)
        elif isinstance(obj, np.ndarray):
            return obj.tolist()
        elif isinstance(obj, datetime):
            return obj.__str__()
        else:
            return super(MyEncoder, self).default(obj)


def split_tainval_and_create_new_db(mdb_path, train_db, test_db, split_rate):
    """Randomly split the samples stored in the LMDB at ``mdb_path`` into a
    train LMDB and a test LMDB.

    Args:
        mdb_path: source LMDB; keys are sample ids, values are JSON labels.
        train_db: path of the train LMDB to create.
        test_db: path of the test LMDB to create.
        split_rate: dict with 'train' and 'test' fractions.  Only 'train' is
            used for the draw; every non-selected sample goes to the test db.
    """
    train_rate, test_rate = split_rate['train'], split_rate['test']

    # Read every (key, label) pair out of the source db.
    env = lmdb.open(mdb_path)
    txn = env.begin()
    all_niis = []
    all_labels = []
    for key, value in txn.cursor():
        key = str(key, encoding='utf-8')
        value = str(value, encoding='utf-8')
        label_info = json.loads(value)
        all_niis.append(key)
        all_labels.append(label_info)
    txn.commit()
    env.close()

    # Fixed: use a set for membership tests (the original used a list, making
    # the split loop O(n^2)).
    train_idxs = set(np.random.choice(len(all_niis), int(len(all_niis) * train_rate), replace=False))
    train_dicts = {}
    test_dicts = {}
    for i in range(len(all_niis)):
        if i in train_idxs:
            train_dicts[all_niis[i]] = all_labels[i]
        else:
            test_dicts[all_niis[i]] = all_labels[i]

    # Write the two output dbs.
    train_env = lmdb.open(train_db, map_size=int(1e9))
    train_txn = train_env.begin(write=True)
    for key, values in train_dicts.items():
        train_txn.put(key=str(key).encode(), value=json.dumps(values, cls=MyEncoder).encode())
    train_txn.commit()
    train_env.close()

    test_env = lmdb.open(test_db, map_size=int(1e9))
    test_txn = test_env.begin(write=True)
    for key, values in test_dicts.items():
        test_txn.put(key=str(key).encode(), value=json.dumps(values, cls=MyEncoder).encode())
    test_txn.commit()
    test_env.close()


if __name__ == "__main__":
    print()
    parser = argparse.ArgumentParser(description="full functional execute script of Detection module.")
    group = parser.add_mutually_exclusive_group()
    # DDP launcher arguments.
    parser.add_argument("--local_rank", default=-1, type=int)
    parser.add_argument('--port', default=29500, type=int, help='port of server')
    parser.add_argument('--world-size', default=2, type=int)
    parser.add_argument('--rank', default=0, type=int)
    parser.add_argument('--master_ip', default='127.0.0.1', type=str)
    parser.add_argument('--job_data_root', default="/data/job_678/job_data_train", type=str)
    args = parser.parse_args()

    g_pathWrapper = CK8sPathWrapper(args.job_data_root)
    config_path = g_pathWrapper.get_input_inputconfig_filepath()
    if not os.path.isfile(config_path):
        print(fr'given config file {config_path} not exist !')
    else:
        # Fill the parsed args with the values from the job config file.
        cfg = CParamFiller.fillin_args(config_path, args)

        g_logger = get_logger(__name__)

        # Distributed setup is only needed for training; eval/pred run as a
        # single process.
        if cfg.train:
            rank, world_size = init_dist(backend='nccl', master_ip=args.master_ip, port=args.port)
            cfg.rank = rank
            cfg.world_size = world_size
            cfg.local_rank = os.environ['LOCAL_RANK']
        else:
            cfg.rank = 0
            cfg.world_size = 1
            cfg.local_rank = -1

        if cfg.train:
            # Split the dataset into train/val and build the db-format labels
            # needed for training.
            cfg.data['data_loader']['train_db'] = g_pathWrapper.get_train_db_path()
            cfg.data['data_loader']['validate_db'] = g_pathWrapper.get_val_db_path()
            if cfg.rank == 0:
                _log_msg('rank: {}'.format(cfg.rank))
                _log_msg('split trainval dataset in {}'.format(cfg.rank))
                mdb_path = g_pathWrapper.get_output_preprocess_mdb_path()
                train_db_path = g_pathWrapper.get_train_db_path()
                test_db_path = g_pathWrapper.get_val_db_path()
                # (Re)create both dbs whenever either one is empty; clear any
                # leftover files first so lmdb starts from a clean directory.
                if len(os.listdir(train_db_path)) == 0 or len(os.listdir(test_db_path)) == 0:
                    if len(os.listdir(train_db_path)) != 0:
                        files = [os.path.join(train_db_path, file) for file in os.listdir(train_db_path)]
                        for file in files:
                            os.remove(file)
                    if len(os.listdir(test_db_path)) != 0:
                        files = [os.path.join(test_db_path, file) for file in os.listdir(test_db_path)]
                        for file in files:
                            os.remove(file)
                    _log_msg('Need to create train and val test !!!')
                    split_tainval_and_create_new_db(mdb_path, train_db_path, test_db_path, cfg.split)
                _log_msg(os.path.exists(g_pathWrapper.get_train_db_path()))
                _log_msg(os.path.exists(g_pathWrapper.get_val_db_path()))
            cfg.data['data_loader']['data_dir'] = g_pathWrapper.get_train_nii_path()

            # Paths where model outputs are saved.
            cfg.training['saver']['saver_root'] = g_pathWrapper.get_output_train_dirpath()
            cfg.training['saver']['saver_train_bestmodel'] = g_pathWrapper.get_output_train_bestmodel_pth_filepath()
            cfg.training['saver']['saver_eval_pmd'] = g_pathWrapper.get_output_eval_performance_md_filepath()
            cfg.training['saver']['saver_eval_pjson'] = g_pathWrapper.get_output_eval_performance_filepath()
            cfg.training['saver']['saver_train_metrics'] = g_pathWrapper.get_output_train_trainingmetrics_filepath()
            cfg.training['saver']['saver_train_result'] = g_pathWrapper.get_output_train_trainresult_filepath()

            # Initialise output/train/train_result.json.
            train_result_dicts = {'successFlag': '', 'bestModelEpoch': 0}
            fp = cfg.training['saver']['saver_train_result']
            with open(fp, 'w+') as file:
                json.dump(train_result_dicts, file, indent=4)

            # If a pretrained model is supplied, resolve it to an absolute path.
            if cfg.pretrain_msg:
                cfg.pretrain_msg = os.path.join(g_pathWrapper.get_input_dirpath(), cfg.pretrain_msg)

        elif cfg.test:
            # Absolute path of the cancel flag, if one exists.
            cfg.check_cancleflag = g_pathWrapper.get_input_cancle_flag()
            # All data is used for validation.
            cfg.data['data_loader']['data_dir'] = g_pathWrapper.get_train_nii_path()
            cfg.data['data_loader']['validate_db'] = g_pathWrapper.get_output_preprocess_mdb_path()
            # If the preprocessed data is missing, run the preprocessing
            # script for this job first.
            if len(os.listdir(cfg.data['data_loader']['data_dir'])) == 0 or len(os.listdir(cfg.data['data_loader']['validate_db'])) == 0:
                current_dirextory = os.path.abspath(__file__)
                run_file = os.path.join(os.path.dirname(os.path.dirname(current_dirextory)), 'local_run_preprocess_lung.py')
                print('Need to run {}'.format(run_file))
                syscommand = fr"python {run_file} --job_data_root {args.job_data_root}"
                retcode = os.WEXITSTATUS(os.system(syscommand))
                print('All ready have data !!!')

            # Paths where model outputs are saved.
            cfg.training['saver']['saver_eval_pmd'] = g_pathWrapper.get_output_eval_performance_md_filepath()
            cfg.training['saver']['saver_eval_pjson'] = g_pathWrapper.get_output_eval_performance_filepath()
            cfg.training['saver']['saver_eval_result'] = g_pathWrapper.get_output_eval_evalresult_filepath()

            # Evaluation always supplies model weights; convert the relative
            # path to an absolute one.
            if cfg.pretrain_msg:
                cfg.pretrain_msg = os.path.join(g_pathWrapper.get_input_dirpath(), cfg.pretrain_msg)

        elif cfg.pred:
            # Required paths: where the preprocessing results are saved.
            cfg.testing['preprocess_save_path'] = g_pathWrapper.get_output_preprocess_preprocessfiledetails_dirpath()
            # Absolute path of the cancel flag, if one exists.
            cfg.check_cancleflag = g_pathWrapper.get_input_cancle_flag()
            # All data is used for prediction.
            cfg.data['data_loader']['data_dir'] = g_pathWrapper.get_train_nii_path()
            cfg.data['data_loader']['test_db'] = g_pathWrapper.get_output_preprocess_mdb_path()

            # Paths where model outputs are saved.
            cfg.testing['test_tmp_rpns'] = g_pathWrapper.get_tmp_test_rpns_dirpath()
            cfg.testing['test_tmp_dir'] = g_pathWrapper.get_output_tmp_dirpath()
            cfg.testing['test_nii_dir'] = g_pathWrapper.get_train_nii_path()
            cfg.testing['saver_pred_result'] = g_pathWrapper.get_output_eval_evalresult_filepath()
            cfg.testing['saver_pred_details'] = g_pathWrapper.get_output_eval_evalfiledetails_dirpath()

            # Prediction always supplies model weights; convert the relative
            # path to an absolute one.  NOTE(review): here pretrain_msg is
            # expected to be a list of dicts with a 'model_name' key — confirm
            # against the job config schema.
            if cfg.pretrain_msg:
                cfg.pretrain_cp = cfg.pretrain_msg
                cfg.pretrain_msg = os.path.join(g_pathWrapper.get_input_dirpath(), cfg.pretrain_msg[0]['model_name'])

        # Echo the final config (rank 0 only when distributed).
        if cfg.train:
            if torch.distributed.get_rank() == 0:
                _log_msg(cfg)
        else:
            print(cfg)

        # Dispatch to the requested mode.
        if cfg.train:
            detector = NoduleDetection3D(cfg, log_fun=_log_msg)
            detector.do_train()
        elif cfg.test:
            detector = NoduleDetection3D(cfg, 'valid', log_fun=_log_msg)
            detector.do_val()
        elif cfg.pred:
            detector = NoduleDetection3D(cfg, "test")
            detector.do_test()