# train.py
# import logging
# logger = logging.getLogger()
# fh = logging.FileHandler('TrainLog/01_nodules.log',encoding='utf-8')
# fh.setLevel(logging.DEBUG)
# sh = logging.StreamHandler()
# sh.setLevel(logging.INFO)
# formatter = logging.Formatter('%(asctime)s %(filename)s[line:%(lineno)d] %(levelname)s %(message)s')
# fh.setFormatter(formatter)  
# sh.setFormatter(formatter)

# logger.addHandler(fh)   
# logger.addHandler(sh)

# logger.setLevel(10)  

import os
import sys
import json
import numpy as np
from datetime import datetime
import torch
import torch.distributed
import torch.multiprocessing as mp
import torch.distributed as dist
import logging
import argparse
from Segmentation import Segmentation3D
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
# sys.path.insert(0, os.path.join(os.path.dirname(os.path.dirname(os.path.abspath(__file__))), ".."))
__package__ = "LungNoduleSegmentation"
from k8s_utils import CParamFiller, CK8sPathWrapper
# os.environ["CUDA_VISIBLE_DEVICES"] = "0,1,2,3,4,5,6,7"
import glob

def init_dist(backend='nccl',
              master_ip='127.0.0.1',
              port=29500):
    """Initialize the default torch.distributed process group.

    Reads RANK / WORLD_SIZE / LOCAL_RANK from the environment (as set by
    torchrun or an equivalent launcher), pins this process to one CUDA
    device, and joins the process group.

    Args:
        backend: torch.distributed backend name (default 'nccl').
        master_ip: rendezvous address, exported as MASTER_ADDR.
        port: rendezvous port, exported as MASTER_PORT.

    Returns:
        Tuple ``(rank, world_size)`` for the calling process.
    """
    # CUDA + multiprocessing requires the 'spawn' start method.
    if mp.get_start_method(allow_none=True) is None:
        mp.set_start_method('spawn')
    os.environ['MASTER_ADDR'] = master_ip
    os.environ['MASTER_PORT'] = str(port)
    rank = int(os.environ['RANK'])
    world_size = int(os.environ['WORLD_SIZE'])
    num_gpus = torch.cuda.device_count()
    local_rank = os.environ['LOCAL_RANK']
    # int() instead of eval(): LOCAL_RANK is environment-controlled text,
    # so evaluating it as code is unsafe; it is always a decimal integer.
    deviceid = int(local_rank) % num_gpus
    torch.cuda.set_device(deviceid)

    print(fr'dist settings: local_rank {local_rank}, rank {rank}, worldsize {world_size}, gpus {num_gpus}, deviceid {deviceid}')

    dist.init_process_group(backend=backend)
    return rank, world_size

def get_logger(name, task_name=None):
    """Return a logger that writes INFO+ records to stdout.

    Args:
        name: logger name (usually ``__name__``).
        task_name: unused; kept for backward compatibility with callers.

    Returns:
        A ``logging.Logger`` configured at INFO level.
    """
    logger = logging.getLogger(name)
    logger.setLevel(logging.INFO)
    # Guard against duplicate handlers: logging.getLogger returns the same
    # object per name, so the original code stacked a fresh StreamHandler
    # on every call, duplicating each log line.
    if not logger.handlers:
        stream_handler = logging.StreamHandler(sys.stdout)
        # Same format the file's (commented-out) logging setup intended.
        stream_handler.setFormatter(logging.Formatter(
            '%(asctime)s %(filename)s[line:%(lineno)d] %(levelname)s %(message)s'))
        logger.addHandler(stream_handler)
    return logger

def _log_msg(strmsg="\n"):
    """Emit *strmsg* from the rank-0 process only.

    Uses the module-level ``g_logger`` when it exists, otherwise falls
    back to ``print``. Non-zero ranks stay silent.
    """
    if torch.distributed.get_rank() != 0:
        return
    if g_logger is None:
        print(strmsg)
    else:
        g_logger.info(strmsg)

if __name__ == '__main__':

    parser = argparse.ArgumentParser(description="full functional execute script of Segmentation module.")
    group = parser.add_mutually_exclusive_group()
    # ddp
    parser.add_argument("--local_rank", default=-1, type=int)
    parser.add_argument('--port', default=29500, type=int, help='port of server')
    parser.add_argument('--world_size', default=1, type=int)
    parser.add_argument('--rank', default=0, type=int)
    parser.add_argument('--master_ip', default='127.0.0.1', type=str)

    parser.add_argument('--job_data_root', default="/root/Documents/GroupLung/Datasets/seg_data/job_data_seg_train", type=str)
    args = parser.parse_args()

    # Merge config layers: CLI args <- job input config <- repo-local
    # training config (later fill-ins win).
    g_pathWrapper = CK8sPathWrapper(args.job_data_root)
    config_path = g_pathWrapper.get_input_inputconfig_filepath()
    cfg = CParamFiller.fillin_args(config_path, args)
    train_params = 'configs/train/input_config.json'
    cfg = CParamFiller.fillin_args(train_params, cfg)

    g_logger = get_logger(__name__)

    rank, world_size = init_dist(backend='nccl', master_ip=args.master_ip, port=args.port)
    cfg.rank = rank
    cfg.world_size = world_size
    # NOTE(review): LOCAL_RANK is kept as a string here (original behavior);
    # confirm whether Segmentation3D expects an int before converting.
    cfg.local_rank = os.environ['LOCAL_RANK']
    if cfg.pretrain_msg:
        # Pretrained weights are referenced relative to the job input dir.
        cfg.pretrain_msg = os.path.join(g_pathWrapper.get_input_dirpath(), cfg.pretrain_msg)
    print(cfg)

    # Train/val split files are produced once by rank 0, then read by all.
    cfg.common['train_data_path'] = os.path.join(g_pathWrapper.get_output_tmp_dirpath(), 'train_split.json')
    cfg.common['val_data_path'] = os.path.join(g_pathWrapper.get_output_tmp_dirpath(), 'val_split.json')

    train_rate = cfg.split['train']
    if (not os.path.exists(cfg.common['train_data_path'])) and cfg.rank == 0:
        # Collect all preprocessed volumes and split into train/val sets.
        npy_path = g_pathWrapper.get_output_preprocess_npy_path()
        # sorted() makes the split deterministic: glob() order is
        # filesystem-dependent, so unsorted results could differ per run.
        data_list = sorted(glob.glob(os.path.join(npy_path, '*', '*_data.npy')))
        train_len = int(len(data_list) * train_rate)
        train_data = {'data': data_list[:train_len]}
        val_data = {'data': data_list[train_len:]}

        with open(cfg.common['train_data_path'], 'w+') as f:
            json.dump(train_data, f, indent=4)
        with open(cfg.common['val_data_path'], 'w+') as f:
            json.dump(val_data, f, indent=4)

    # Barrier so non-zero ranks do not race ahead and read the split files
    # before rank 0 has finished writing them (was commented out, which
    # could crash or silently train on a partial file list).
    torch.distributed.barrier()

    # Output locations: checkpoints, tensorboard writer, metrics, results.
    cfg.common['base_path'] = g_pathWrapper.get_output_train_dirpath()
    cfg.common['save_path'] = g_pathWrapper.get_output_train_bestmodel_dirpath()
    cfg.common['writer_path'] = g_pathWrapper.get_output_train_writer_path()

    cfg.common['eval_pmd'] = g_pathWrapper.get_output_eval_performance_md_filepath()
    cfg.common['eval_pjson'] = g_pathWrapper.get_output_eval_performance_filepath()
    cfg.common['train_metrics'] = g_pathWrapper.get_output_train_trainingmetrics_filepath()
    cfg.common['train_result'] = g_pathWrapper.get_output_train_trainresult_filepath()

    # Seed the result file so downstream consumers always find valid JSON,
    # even if training aborts early.
    if cfg.rank == 0:
        train_result_dicts = {'successFlag': '', 'bestModelEpoch': 0}
        with open(cfg.common['train_result'], 'w+') as file:
            json.dump(train_result_dicts, file, indent=4)

    # Start training; Segmentation3D is callable and runs the configured
    # mode internally.
    train_obj = Segmentation3D(cfg, g_logger)
    train_obj()