# run.py
import os
import sys
import json
import numpy as np
from datetime import datetime
import torch
import torch.distributed
import torch.multiprocessing as mp
import torch.distributed as dist
import logging
import argparse

# Make the repository root importable so the package imports below resolve
# when this file is executed directly as a script.
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
# sys.path.insert(0, os.path.join(os.path.dirname(os.path.dirname(os.path.abspath(__file__))), ".."))
# Force the package name so intra-package imports behave as if this were run
# via `python -m lungClassification3D.run` — TODO confirm this is still needed.
__package__ = "lungClassification3D"
from BaseClassifier.Se_Cls import Classification3D
from BaseClassifier.utils.gpu_utils import set_gpu
from k8s_utils import CParamFiller, CK8sPathWrapper
# os.environ['CUDA_VISIBLE_DEVICES'] = "0,1,2,3,4,5,6,7"
def init_dist(backend='nccl',
              master_ip='127.0.0.1',
              port=29500):
    """Initialize torch.distributed from torchrun-style environment variables.

    Reads RANK / WORLD_SIZE / LOCAL_RANK from the environment, exports the
    rendezvous address, pins this process to GPU ``LOCAL_RANK % num_gpus``,
    and joins the process group.

    Args:
        backend: process-group backend (default ``'nccl'``).
        master_ip: rendezvous host, exported as MASTER_ADDR.
        port: rendezvous port, exported as MASTER_PORT.

    Returns:
        Tuple ``(rank, world_size)`` of ints.
    """
    # CUDA + DDP workers must be spawned, not forked.
    if mp.get_start_method(allow_none=True) is None:
        mp.set_start_method('spawn')
    os.environ['MASTER_ADDR'] = master_ip
    os.environ['MASTER_PORT'] = str(port)
    rank = int(os.environ['RANK'])
    world_size = int(os.environ['WORLD_SIZE'])
    num_gpus = torch.cuda.device_count()
    local_rank = os.environ['LOCAL_RANK']
    # int(), not eval(): the env var is a plain integer string, and eval()
    # would execute arbitrary text taken from the environment.
    deviceid = int(local_rank) % num_gpus
    torch.cuda.set_device(deviceid)

    print(fr'dist settings: local_rank {local_rank}, rank {rank}, worldsize {world_size}, gpus {num_gpus}, deviceid {deviceid}')

    dist.init_process_group(backend=backend)
    return rank, world_size

def get_logger(name, task_name=None):
    """Return an INFO-level logger that writes to stdout.

    Args:
        name: logger name (typically ``__name__``).
        task_name: unused; retained for backward compatibility with callers.

    Returns:
        logging.Logger instance.
    """
    logger = logging.getLogger(name)
    logger.setLevel(logging.INFO)
    # getLogger returns a shared instance per name; only attach a handler the
    # first time so repeated calls don't produce duplicated log lines.
    if not any(isinstance(h, logging.StreamHandler) for h in logger.handlers):
        logger.addHandler(logging.StreamHandler(sys.stdout))
    return logger

def _log_msg(strmsg="\n"):
    if torch.distributed.get_rank() == 0:
        if g_logger is not None:
            g_logger.info(strmsg)
        else:
            print(strmsg)

class MyEncoder(json.JSONEncoder):
    """JSON encoder that serializes NumPy scalars/arrays and datetimes."""

    def default(self, obj):
        """Convert NumPy and datetime objects to JSON-native equivalents."""
        if isinstance(obj, np.ndarray):
            return obj.tolist()
        if isinstance(obj, np.integer):
            return int(obj)
        if isinstance(obj, np.floating):
            return float(obj)
        if isinstance(obj, datetime):
            return str(obj)
        # Anything else: defer to the base class (raises TypeError).
        return super().default(obj)
        
# set_gpu(num_gpu=1, used_percent=0.1)

if __name__ == '__main__':

    # CLI: DDP rendezvous options plus the k8s job root that holds the config.
    parser = argparse.ArgumentParser(description="full functional execute script of Detection module.")
    group = parser.add_mutually_exclusive_group()
    # ddp
    parser.add_argument("--local_rank", default=-1, type=int)
    parser.add_argument('--port', default=29500, type=int, help='port of server')
    parser.add_argument('--world-size', default=2, type=int)
    parser.add_argument('--rank', default=0, type=int)
    parser.add_argument('--master_ip', default='127.0.0.1', type=str)

    parser.add_argument('--job_data_root', default="/fileser51/baoyc/Work/lung_tasks/GroupLung/Datasets/cls_data/job_data_cls_train_2", type=str)
    args = parser.parse_args()

    g_pathWrapper = CK8sPathWrapper(args.job_data_root)
    config_path = g_pathWrapper.get_input_inputconfig_filepath()
    if not os.path.isfile(config_path):
        # NOTE(review): when the config file is missing, `cfg` is never bound
        # and the `cfg.mode` access below raises NameError — consider exiting
        # here instead of just printing.
        print(fr'given config file {config_path} not exist !')
    else:
        # print(fr'using config file {config_path} to fill args !')
        # Merge the on-disk config into the argparse namespace.
        cfg = CParamFiller.fillin_args(config_path, args)

    g_logger = get_logger(__name__)

    if cfg.mode == 'training':
        # Join the NCCL process group; RANK/WORLD_SIZE/LOCAL_RANK come from env.
        rank, world_size = init_dist(backend='nccl', master_ip=args.master_ip, port=args.port)
        cfg.rank = rank
        cfg.world_size = world_size
        cfg.local_rank = os.environ['LOCAL_RANK']

    if cfg.mode == 'training':
        # Materialized train/val arrays live under the preprocess output dir.
        cfg.training['train_data_path'] = os.path.join(g_pathWrapper.get_output_preprocess_preprocessfiledetails_dirpath(), 'train_data.npy')
        cfg.training['train_info_path'] = os.path.join(g_pathWrapper.get_output_preprocess_preprocessfiledetails_dirpath(), 'train_info.npy')
        cfg.training['val_data_path'] = os.path.join(g_pathWrapper.get_output_preprocess_preprocessfiledetails_dirpath(), 'val_data.npy')
        cfg.training['val_info_path'] = os.path.join(g_pathWrapper.get_output_preprocess_preprocessfiledetails_dirpath(), 'val_info.npy')

        # Under DDP only rank 0 builds the split (cfg.rank == 0 in the guard);
        # the other ranks wait at the barrier below until the files exist.
        if (not os.path.exists(cfg.training['train_data_path']) or not os.path.exists(cfg.training['train_info_path']) or not os.path.exists(cfg.training['val_data_path']) or not os.path.exists(cfg.training['val_info_path'])) and cfg.rank == 0:
            # Read all per-uid npy info and split into train and validation sets.
            npy_path = g_pathWrapper.get_output_preprocess_npy_path()
            data2info = {}
            uids = []
            for file in os.listdir(npy_path):
                # Files come in '<uid>_data.npy' / '<uid>_info.npy' pairs; both
                # suffixes are 9 characters, so [:-9] recovers the uid.
                uid = file[:-9]
                if uid not in uids:
                    uids.append(uid)
            for uid in uids:
                data2info[uid + '_data.npy'] = uid + '_info.npy'
            
            whole_data = []
            whole_info = []
            whole_label = []
            for key, value in data2info.items():
                key = os.path.join(npy_path, key)
                value = os.path.join(npy_path, value)
                whole_data.append(np.load(key))
                working_info = np.load(value, allow_pickle=True)
                whole_info.append(working_info)
                # assumes the second-to-last info column is the class label —
                # TODO confirm against the preprocessing step's info layout.
                whole_label.append(working_info[:,-2])
                
            whole_data = np.concatenate(whole_data, axis=0)
            whole_info = np.concatenate(whole_info, axis=0)
            whole_label = np.concatenate(whole_label, axis=0)

            # Group sample indices by class label (labels assumed 1-based,
            # i.e. in 1..num_class — verify against the label encoding).
            label2idx = {}
            for label in range(1, cfg.training['num_class'] + 1):
                label2idx[label] = []
            for idx in range(len(whole_label)):
                label2idx[whole_label[idx]].append(idx)

            # Stratified split: a cfg.split['train'] fraction of each class is
            # sampled (without replacement) for train, the rest goes to val.
            train_data = []
            val_data = []
            train_info = []
            val_info = []
            for key in list(label2idx.keys()):
                choice_idxs = list(np.random.choice(label2idx[key], int(len(label2idx[key]) * cfg.split['train']), replace=False))
                relax_idxs = [idx for idx in label2idx[key] if idx not in choice_idxs]
                for idx in choice_idxs:
                    train_data.append(whole_data[idx][np.newaxis,:])
                    train_info.append(whole_info[idx].reshape(-1, len(whole_info[idx])))
                for idx in relax_idxs:
                    val_data.append(whole_data[idx][np.newaxis,:])
                    val_info.append(whole_info[idx].reshape(-1, len(whole_info[idx])))
            
            train_data = np.concatenate(train_data, axis=0)
            train_info = np.concatenate(train_info, axis=0)
            val_data = np.concatenate(val_data, axis=0)
            val_info = np.concatenate(val_info, axis=0)

            np.save(cfg.training['train_data_path'], train_data)
            np.save(cfg.training['train_info_path'], train_info)
            np.save(cfg.training['val_data_path'], val_data)
            np.save(cfg.training['val_info_path'], val_info)
        # All ranks sync here so non-zero ranks only proceed once rank 0 has
        # written the split files.
        torch.distributed.barrier()
        # Output file paths (checkpoints, tensorboard writer, metrics).
        cfg.training['base_path'] = g_pathWrapper.get_output_train_dirpath()
        cfg.training['save_path'] = g_pathWrapper.get_output_train_bestmodel_dirpath()
        cfg.training['writer_path'] = g_pathWrapper.get_output_train_writer_path()

        cfg.training['eval_pmd'] = g_pathWrapper.get_output_eval_performance_md_filepath()
        cfg.training['eval_pjson'] = g_pathWrapper.get_output_eval_performance_filepath()
        cfg.training['train_metrics'] = g_pathWrapper.get_output_train_trainingmetrics_filepath()
        cfg.training['train_result'] = g_pathWrapper.get_output_train_trainresult_filepath() 
        
        # Start training; mode choice in [training, testing].
        train_obj = Classification3D(cfg, cfg.mode, _log_msg)
        train_obj()
    elif cfg.mode == 'testing':
        cfg.testing['test_data_path'] = os.path.join(g_pathWrapper.get_output_preprocess_preprocessfiledetails_dirpath(), 'val_data.npy')
        cfg.testing['test_info_path'] = os.path.join(g_pathWrapper.get_output_preprocess_preprocessfiledetails_dirpath(), 'val_info.npy')

        # Rebuild the whole dataset from the per-uid npy pairs (same file
        # layout as in the training branch above).
        npy_path = g_pathWrapper.get_output_preprocess_npy_path()
        data2info = {}
        uids = []
        for file in os.listdir(npy_path):
            uid = file[:-9]
            if uid not in uids:
                uids.append(uid)
        for uid in uids:
            data2info[uid + '_data.npy'] = uid + '_info.npy'
        
        whole_data = []
        whole_info = []
        whole_label = []
        for key, value in data2info.items():
            key = os.path.join(npy_path, key)
            value = os.path.join(npy_path, value)
            whole_data.append(np.load(key))
            working_info = np.load(value, allow_pickle=True)
            whole_info.append(working_info)
            whole_label.append(working_info[:,-2])
            
        whole_data = np.concatenate(whole_data, axis=0)
        whole_info = np.concatenate(whole_info, axis=0)
        whole_label = np.concatenate(whole_label, axis=0)

        # # Split train/val sets according to whole_label (dead code kept for reference)
        # label2idx = {}
        # for label in range(1, cfg.training['num_class'] + 1):
        #     label2idx[label] = []
        # for idx in range(len(whole_label)):
        #     label2idx[whole_label[idx]].append(idx)

        # # get train and test npy and info
        # train_data = []
        # val_data = []
        # train_info = []
        # val_info = []
        # for key in list(label2idx.keys()):
        #     choice_idxs = list(np.random.choice(label2idx[key], int(len(label2idx[key]) * cfg.split['train']), replace=False))
        #     relax_idxs = [idx for idx in label2idx[key] if idx not in choice_idxs]
        #     for idx in choice_idxs:
        #         train_data.append(whole_data[idx][np.newaxis,:])
        #         train_info.append(whole_info[idx].reshape(-1, len(whole_info[idx])))
        #     for idx in relax_idxs:
        #         val_data.append(whole_data[idx][np.newaxis,:])
        #         val_info.append(whole_info[idx].reshape(-1, len(whole_info[idx])))
        
        # train_data = np.concatenate(train_data, axis=0)
        # train_info = np.concatenate(train_info, axis=0)
        # val_data = np.concatenate(val_data, axis=0)
        # val_info = np.concatenate(val_info, axis=0)
        # Evaluate in 15 shards; the last shard absorbs the remainder rows.
        all_num = whole_data.shape[0]
        inter = all_num // 15
        for i in range(15):
            if i == 14:
                data = whole_data[i*inter: , :]
                info = whole_info[i*inter: , :]
            else:
                data = whole_data[i*inter: (i+1)*inter, :]
                info = whole_info[i*inter: (i+1)*inter, :]
            print(data.shape, info.shape)
            save_data_path = os.path.join(g_pathWrapper.get_output_preprocess_preprocessfiledetails_dirpath(), 'val_data_' + str(i) + '.npy')
            save_info_path = os.path.join(g_pathWrapper.get_output_preprocess_preprocessfiledetails_dirpath(), 'val_info_' + str(i) + '.npy')
            np.save(save_data_path, data)
            np.save(save_info_path, info)
            cfg.testing['test_data_path'] = save_data_path
            cfg.testing['test_info_path'] = save_info_path

            # Output file paths. NOTE(review): these four assignments are
            # loop-invariant and could be hoisted above the loop.
            cfg.testing['eval_pmd'] = g_pathWrapper.get_output_eval_performance_md_filepath()
            cfg.testing['eval_pjson'] = g_pathWrapper.get_output_eval_performance_filepath()
            cfg.testing['eval_result'] = g_pathWrapper.get_output_eval_evalresult_filepath()
            cfg.testing['save_dir'] = g_pathWrapper.get_output_eval_dirpath()
            # Run evaluation on this shard; mode choice in [training, testing].
            train_obj = Classification3D(cfg, cfg.mode, _log_msg)
            train_obj()