"""Distributed (DDP) evaluation entry script for the 3D lung-classification model.

Reads a k8s job config, collects preprocessed ``*_data.npy`` volumes into a
``val_split.json`` file list, fills in output paths, and runs
``Classification3D`` in the mode given by the config.
"""
import argparse
import glob
import json
import logging
import os
import sys
from datetime import datetime

import numpy as np
import torch
import torch.distributed as dist
import torch.multiprocessing as mp

# Runtime requirement, not a debug check: raise explicitly so it survives `-O`.
if not torch.cuda.is_available():
    raise RuntimeError('CUDA is required but not available')

# Make the package root importable when this file is run as a script.
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
__package__ = "lungClassification3D"

from BaseClassifier.Se_Cls import Classification3D
from BaseClassifier.utils.gpu_utils import set_gpu
from k8s_utils import CParamFiller, CK8sPathWrapper


def init_dist(backend='nccl', master_ip='127.0.0.1', port=29500):
    """Initialise the default torch.distributed process group.

    Expects ``RANK``, ``WORLD_SIZE`` and ``LOCAL_RANK`` to be present in the
    environment (as set by the DDP launcher).  Sets the CUDA device for this
    process to ``LOCAL_RANK % num_gpus``.

    :param backend: process-group backend (default ``'nccl'``).
    :param master_ip: rendezvous address written to ``MASTER_ADDR``.
    :param port: rendezvous port written to ``MASTER_PORT``.
    :return: ``(rank, world_size)`` tuple.
    """
    if mp.get_start_method(allow_none=True) is None:
        # CUDA + fork is unsafe; force spawn before any worker is created.
        mp.set_start_method('spawn')
    os.environ['MASTER_ADDR'] = master_ip
    os.environ['MASTER_PORT'] = str(port)
    rank = int(os.environ['RANK'])
    world_size = int(os.environ['WORLD_SIZE'])
    num_gpus = torch.cuda.device_count()
    local_rank = os.environ['LOCAL_RANK']
    # Fix: parse with int() instead of eval() — never eval environment input.
    deviceid = int(local_rank) % num_gpus
    torch.cuda.set_device(deviceid)
    print(fr'dist settings: local_rank {local_rank}, rank {rank}, worldsize {world_size}, gpus {num_gpus}, deviceid {deviceid}')
    dist.init_process_group(backend=backend)
    return rank, world_size


def get_logger(name, task_name=None):
    """Return an INFO-level logger that writes to stdout.

    :param name: logger name (usually ``__name__``).
    :param task_name: unused; kept for interface compatibility.
    """
    stream_handler = logging.StreamHandler(sys.stdout)
    logger = logging.getLogger(name)
    logger.setLevel(logging.INFO)
    logger.addHandler(stream_handler)
    return logger


def _log_msg(strmsg="\n"):
    """Log a message from rank 0 only, via ``g_logger`` when available."""
    if dist.get_rank() == 0:
        if g_logger is not None:
            g_logger.info(strmsg)
        else:
            print(strmsg)


class MyEncoder(json.JSONEncoder):
    """JSON encoder that also handles numpy scalars/arrays and datetimes."""

    def default(self, obj):
        if isinstance(obj, np.integer):
            return int(obj)
        elif isinstance(obj, np.floating):
            return float(obj)
        elif isinstance(obj, np.ndarray):
            return obj.tolist()
        elif isinstance(obj, datetime):
            return obj.__str__()
        else:
            return super(MyEncoder, self).default(obj)


# set_gpu(num_gpu=1, used_percent=0.1)

if __name__ == '__main__':
    parser = argparse.ArgumentParser(description="full functional execute script of Detection module.")
    # ddp
    parser.add_argument("--local_rank", default=-1, type=int)
    parser.add_argument('--port', default=29500, type=int, help='port of server')
    parser.add_argument('--world_size', default=2, type=int)
    parser.add_argument('--rank', default=0, type=int)
    parser.add_argument('--master_ip', default='127.0.0.1', type=str)
    parser.add_argument('--job_data_root', default="", type=str)
    args = parser.parse_args()

    g_pathWrapper = CK8sPathWrapper(args.job_data_root)
    config_path = g_pathWrapper.get_input_inputconfig_filepath()
    if not os.path.isfile(config_path):
        print(fr'given config file {config_path} not exist !')
    else:
        # Layer the job config over the CLI args, then the validation config
        # over that.
        cfg = CParamFiller.fillin_args(config_path, args)
        test_params = 'configs/val/input_config.json'
        cfg = CParamFiller.fillin_args(test_params, cfg)

        g_logger = get_logger(__name__)

        cfg.testing['test_data_path'] = os.path.join(g_pathWrapper.get_output_tmp_dirpath(), 'val_split.json')
        # pretrain_msg in the config is relative to the job input directory.
        cfg.pretrain_msg = os.path.join(g_pathWrapper.get_input_dirpath(), cfg.pretrain_msg)

        # Collect every preprocessed volume (<case>/<uid>_data.npy) into the
        # validation split file consumed by the tester.
        npy_path = g_pathWrapper.get_output_preprocess_npy_path()
        data_list = glob.glob(os.path.join(npy_path, '*', '*_data.npy'))
        val_data = {'data': data_list}
        with open(cfg.testing['test_data_path'], 'w') as f:
            json.dump(val_data, f, indent=4)

        # Configure output file paths.
        cfg.testing['eval_pmd'] = g_pathWrapper.get_output_eval_performance_md_filepath()
        cfg.testing['eval_pjson'] = g_pathWrapper.get_output_eval_performance_filepath()
        cfg.testing['eval_result'] = g_pathWrapper.get_output_eval_evalresult_filepath()
        cfg.testing['save_dir'] = g_pathWrapper.get_output_eval_dirpath()

        # Run; mode choice in [training, testing].
        train_obj = Classification3D(cfg, cfg.mode, _log_msg)
        train_obj()