#!/usr/bin/python3 #coding=utf-8 import argparse import os import time import numpy as np import shutil import sys from tqdm import tqdm import json from python_tree import color_print_dir import random import torch import pandas import lmdb # from trans_dicom_2_nii import MyEncoder from quality_control import control_single from anno_json_wrapper import annojson2cls, annolabels2labelidlabelmap, CLS_UNDEFINED_LABEL from k8s_utils import CK8sPathWrapper, CParamFiller, CDatasetWrapper, get_fullfilepaths_according2pattern parser = argparse.ArgumentParser(description='PyTorch DataBowl3 Detector') parser.add_argument('--model_type', '-m', metavar='MODEL', default='mobilenet', help='model') parser.add_argument('-j', '--workers', default=2, type=int, metavar='N', help='number of data loading workers (default: 32)') parser.add_argument('--epochs', default=10, type=int, metavar='N', help='number of total epochs to run') parser.add_argument('--start-epoch', default=None, type=int, metavar='N', help='manual epoch number (useful on restarts)') parser.add_argument('-b', '--batch-size', default=16, type=int, metavar='N', help='mini-batch size (default: 16)') parser.add_argument('--lr', default=2e-3, type=float, metavar='LR', help='initial learning rate') parser.add_argument('--momentum', default=0.9, type=float, metavar='M', help='momentum') parser.add_argument('--weight-decay', '--wd', default=1e-4, type=float, metavar='W', help='weight decay (default: 1e-4)') parser.add_argument('--save-freq', default=10, type=int, metavar='S', help='save frequency') parser.add_argument('--resume', default='', type=str, metavar='PATH', help='path to latest checkpoint (default: none)') parser.add_argument('--save-dir', default=None, type=str, metavar='SAVE', help='directory to save checkpoint (default: none)') parser.add_argument('--test', default=0, type=int, metavar='TEST', help='1 do test evaluation, 0 not') parser.add_argument('--split', default=8, type=int, metavar='SPLIT', help='In the test phase, split the image to 8 parts') parser.add_argument('--gpu', default='0', type=str, metavar='N', help='use gpu, "all" or "0,1,2,3" or "0,2" etc') parser.add_argument('--n_test', default=4, type=int, metavar='N', help='number of gpu for test') parser.add_argument('--cross', default=None, type=str, metavar='N', help='which data cross be used') parser.add_argument('--cluster', action='store_true', default=False, help='enables CUDA training (default: False)') parser.add_argument('--config', default='', type=str) parser.add_argument('--classes', default='', type=str) parser.add_argument('--class_id', default=1, type=int) # ================================== # from devkit.core.dist_utils import init_dist # 通用的命令行参数 parser.add_argument("--local_rank", default=0, type=int) parser.add_argument('--port', default=23456, type=int, help='port of server') parser.add_argument('--world-size', default=1, type=int) parser.add_argument('--rank', default=0, type=int) parser.add_argument('--master_ip', default='10.100.39.11', type=str) parser.add_argument('--use_ddp', action='store_true', default=True, help='script started from ddp(default: False)') # 平台传递的数据集相关 parser.add_argument('--use_webui', action='store_true', default=True, help='script started from webui(default: False)') # 模型和配置相关 parser.add_argument('--train-ratio', default=0.8, type=float) parser.add_argument('--local_data', action='store_true', default=False, help='enables localdata training') # ================================== parser.add_argument('--job_data_root', default="/data/job_715/job_data_preprocess", type=str) parser.add_argument('--annotation_label', default=None) parser.add_argument('--preprocess_workers', default=0, type=int) # ================================== # ================================== args = parser.parse_args() g_pathWrapper = CK8sPathWrapper(args.job_data_root) config_path = g_pathWrapper.get_input_inputconfig_filepath() if not os.path.isfile(config_path): print(fr'given config file {config_path} not exist !') else: # print(fr'using config file {config_path} to fill args !') args = CParamFiller.fillin_args(config_path, args) pre_params = 'configs/preprocess/input_config.json' current_dirextory = os.path.abspath(__file__) pre_params = os.path.join(os.path.dirname(current_dirextory), 'configs/infer/input_config.json') args = CParamFiller.fillin_args(pre_params, args) # ================================== ###################################################### USE_WEBUI = args.use_webui USE_DDP = args.use_ddp g_writer = None import logging import sys def get_logger(name, task_name=None): file_handler = logging.StreamHandler(sys.stdout) logger = logging.getLogger(name) logger.setLevel(logging.INFO) logger.addHandler(file_handler) return logger g_logger = get_logger(__name__) ###################################################### ###################################################### g_local_rank = None g_rank = None g_world_size = None def _log_msg(strmsg="\n"): global g_rank if g_rank == 0: if g_logger is not None: g_logger.info(strmsg) else: print(strmsg) ###################################################### ###################################################### from utils import constants as constants ###################################################### constants.USE_WEBUI = USE_WEBUI constants.USE_DDP = USE_DDP ###################################################### def datestr(): now = time.gmtime() return '{}{:02}{:02}_{:02}{:02}{:02}_{}'.format(now.tm_year, now.tm_mon, now.tm_mday, now.tm_hour, now.tm_min, now.tm_sec, int(round(time.time() * 1000))) def main(): global args, best_loss, g_rank, g_world_size, g_logger, g_writer ###################################################### ###################################################### # rank, world_size = init_dist(backend='nccl', master_ip=args.master_ip, port=args.port) # g_rank = rank # g_world_size = world_size # g_local_rank = os.environ['LOCAL_RANK'] g_rank = 0 g_world_size = 1 g_local_rank = 0 ###################################################### ###################################################### constants.G_WORLD_SIZE = _ensure_params(g_world_size, int) constants.G_RANK = _ensure_params(g_rank, int) constants.G_LOCAL_RANK = _ensure_params(g_local_rank, int) constants.NUM_OF_EPOCHS = _ensure_params(args.epochs, int) constants.MAX_LEARNING_RATE = _ensure_params(args.lr, float) constants.NUM_WORKERS = _ensure_params(args.workers, int) constants.BATCH_SIZE = _ensure_params(args.batch_size, int) ###################################################### # ########################################## # train_ratio = args.train_ratio # train_ratio = max(train_ratio, 0.0) # train_ratio = min(train_ratio, 1.0) # val_ratio = 1.0 - train_ratio # constants.TEST_MODE= True if train_ratio <= 1e-6 else False # ########################################## # ================================== if not constants.USE_WEBUI or args.local_data: save_path = constants.DATA_PATH_ROOT else: save_path = g_pathWrapper.get_output_preprocess_preprocessfiledetails_dirpath() # preprocess结果目录 # ================================== ############################################ _log_msg(fr'doing prepare, save path {save_path}') _ensure_path(save_path) fulldicom_path = fr'{g_pathWrapper.get_output_tmp_dirpath()}' _ensure_path(fulldicom_path) fullnpy_path = fr'{save_path}/{constants.SUBDIR_NPY}' _ensure_path(fullnpy_path) ############################################ # ================================== _log_msg('using mlops data') data_config_json = g_pathWrapper.get_input_dataset_filepath() _log_msg(fr'target data config {data_config_json}') _log_msg(fr'annotation labels {args.annotation_label}') # ================================== # anno_labels = args.annotation_label # label_id, label_map = annolabels2labelidlabelmap(anno_labels) # _log_msg(fr'label_id {label_id}, label_map {label_map}') # ================================== # 将zip解压缩,dcm保存在tmp中,[uid,dicom,json] datainfo_list = CDatasetWrapper.get_dataset_list(data_config_json, g_pathWrapper.get_output_tmp_dirpath()) _log_msg(fr'loading datameta length {len(datainfo_list)}') _log_msg(fr'done...') # 这个函数需要修改,检测任务不需要分那么多类别 # working_classmap_id, working_classmap = _load_data_meta_from_MLOPs( # datainfo_list=datainfo_list, given_label_id=args.working_label) # ================================== # if label_id != -1 and label_map is not None: # working_classmap_id = label_id # working_classmap = label_map # _log_msg(fr'using classmap {working_classmap_id}, details {working_classmap}') # ================================== # ================================== # setup classes # if working_classmap is not None and working_classmap_id != -1: # args.classes = str(working_classmap) # args.class_id = working_classmap_id # save class settings # tar_clmap = os.path.join(save_path, 'class_settings.json') # _log_msg(fr'writing classmap to {tar_clmap}...') # with open(tar_clmap, 'w+') as fp: # json.dump( # { # "working_classmap": working_classmap, # "working_classmap_id": working_classmap_id # }, fp, indent=4) # _log_msg(fr'done...') # ================================== ## 进行质控,不满足指控条件的将进行删除 if args.if_qcontrol: exclude_uids = [] dcm_paths = [os.path.join(fulldicom_path, dcm) for dcm in os.listdir(fulldicom_path)] for path in dcm_paths: control_res = control_single(path) if control_res != 'SUCCESS': exclude_uids.append(path.split('/')[-1]) print('Remove: {} for {}'.format(path, control_res)) shutil.rmtree(path) ## need to renew datainfo_list rm_idxes = [] for idx, info in enumerate(datainfo_list): uid = '.'.join(info[-1][0].split('/')[-1].split('.')[:-1]) if uid in exclude_uids: rm_idxes.append(idx) rm_idxes = sorted(rm_idxes, reverse=True) for i in rm_idxes: del datainfo_list[i] _log_msg(fr'preprocessing...') preprocess_LUNGMLOPs(input_path=fulldicom_path, save_path=fullnpy_path, crop_size=args.crop_size, diameter_th=args.diameter_th, given_label_id=args.working_label, datainfos=datainfo_list, log_func=_log_msg, preprocess_workers=args.preprocess_workers) _log_msg(fr'done...') ############################################ map_fullpath2info = {} for info in datainfo_list: file_id = info[0] dirname = '.'.join(os.path.basename(info[2][0]).split('.')[:-1]) npy_path = os.path.join(fullnpy_path, dirname + '_data.npy') map_fullpath2info[npy_path] = file_id if map_fullpath2info is not None: # ================================== preprocessroot = g_pathWrapper.get_output_preprocess_dirpath() result_items = [] for k,v in map_fullpath2info.items(): with open(fr'{save_path}/{os.path.basename(k)[:-9]}_dataid.txt', 'w') as fp: fp.write(str(v)) for k,v in map_fullpath2info.items(): result_items.append( { "uid": v, "urls": get_fullfilepaths_according2pattern(preprocessroot, os.path.basename(k)[:-9]), } ) preprocessfilepath = g_pathWrapper.get_output_preprocess_preprocessresult_filepath() with open(preprocessfilepath, 'w+') as fp: json.dump( { "successFlag": "SUCCESS", "resultItems": result_items }, fp, indent=4) # ================================== ############################### constants.DATA_PATH_DCM = fulldicom_path constants.DATA_PATH_NPY = fullnpy_path constants.DATA_PATH = fr'{g_pathWrapper.get_output_tmp_dirpath()}/{constants.SUBDIR_TRAINVAL}' ############################### color_print_dir(constants.DATA_PATH_NPY) def preprocess_LUNGMLOPs(input_path, save_path, crop_size, diameter_th, given_label_id, datainfos, preprocess_workers=0, log_func=None): from preprocess_func import convert_dcm_2_npy if log_func is None: log_func = print # deal with dcms convert_dcm_2_npy(input_path, save_path, crop_size, diameter_th, given_label_id, datainfos, preprocess_workers=preprocess_workers, log_func=log_func) def _load_data_meta_from_MLOPs(datainfo_list, given_label_id=-1): KEY_NIIGZ = 'nii.gz' KEY_MHD = 'mhd' KEY_JSON = 'json' KEY_CSV = 'csv' output_map = {} # ================================== # 解析datainfo_list、并构建成我们切实可用的数据 working_classmap = None working_classmap_id = -1 for data_info in datainfo_list: print(data_info) data_id = data_info[0] suid = data_info[1] dcm_paths = data_info[2] anno_paths = data_info[3] _log_msg(fr'data_id {data_id}, dcmpath {dcm_paths}, annopath {anno_paths}') dcmpath = None if len(dcm_paths) > 0: # dcmpath = os.path.abspath(os.path.dirname(dcm_paths[0])) dcmpath = dcm_paths[0] else: dcmpath = dcm_paths if dcmpath is None: _log_msg(fr'{data_info} dcmpath not valid') continue ############################## working_classmap_id = -1 working_classmap = None if anno_paths is None: _log_msg(fr'{data_info} annopath not valid, regard data as pure-predict data') if CLS_UNDEFINED_LABEL not in output_map: output_map[CLS_UNDEFINED_LABEL] = [] output_map[CLS_UNDEFINED_LABEL].append([dcmpath, data_id]) else: cls_info_list = None annopath = None if len(anno_paths) > 0: annopath = anno_paths[0] else: annopath = anno_paths if KEY_JSON in annopath or KEY_CSV in annopath: try: _, tmp_working_classmap_id, tmp_working_classmap = \ annojson2cls(annopath, given_labelid=given_label_id) except Exception as ex: _log_msg(fr'{data_info} not valid, process cls failed, msg {ex}') if tmp_working_classmap_id > 0 and tmp_working_classmap is not None: working_classmap = tmp_working_classmap working_classmap_id = tmp_working_classmap_id return working_classmap_id, working_classmap def _ensure_path(given_path): if not os.path.exists(given_path): os.makedirs(given_path) def _ensure_params(param, giventype=int): return param if type(param) == giventype else eval(param) if __name__ == '__main__': main() exit() try: status = main() except Exception as ex: print(fr'exception occured, msg {ex}') status = -1 sys.exit(status)