# -*- coding:utf-8 -*- import os import sys import six import glob import json import random import shutil import datetime import numpy as np import pandas as pd import scipy.ndimage as nd from tqdm import tqdm import datetime import logging import torch import torch.nn as nn from torch import optim from matplotlib import pyplot as plt from torch.utils.data import DataLoader from torch.nn.utils import clip_grad_norm_ import copy from tensorboardX import SummaryWriter sys.path.append('./utils') sys.path.append('./model') from .utils.RWconfig import LoadJson from .utils.ReadData import load_data, load_single_data from .model.modelBuild import build_model from .model.BasicModules import weight_init from .utils.DataGen import SurDataSet, InferDataSet, TestDataSet from .utils.OnlineEval import eval_net, eval_netSE, eval_netFSE from .utils.OfflineEval import CalculateClsScore, CalculateAuc, CalculateClsScoreByTh from .utils.loss_func import * from .utils.gradcam import GradCam, GuidedBackpropReLUModel from .utils.ImagePro import ShowEdges, HFFilter from .utils.TripleLoss import TripletFocalLoss from .utils.TripleBin import TripletBinFocal from .utils.GradientBoost import GradientBoostFocal from k8s_utils import CMetricsWriter import torch.distributed as dist from sklearn.metrics import precision_score, recall_score import copy from pydicom import dicomio logger = logging.getLogger() fh = logging.FileHandler('Train.log', encoding='utf-8') sh = logging.StreamHandler() formatter = logging.Formatter('%(asctime)s %(filename)s[line:%(lineno)d] %(levelname)s %(message)s') fh.setFormatter(formatter) sh.setFormatter(formatter) logger.addHandler(fh) logger.addHandler(sh) logger.setLevel(10) import os import numpy as np import SimpleITK as sitk import cv2 import imageio from tqdm import tqdm import sys from PIL import Image, ImageDraw, ImageFont def norm(image, hu_min=-1000.0, hu_max=600.0): image = (np.clip(image.astype(np.float32), hu_min, hu_max) - hu_min) / float(hu_max - hu_min) return image * 255. def img2gif(tmp_path, raw_path): target_path = os.path.join(os.path.dirname(tmp_path), 'save_images') if not os.path.exists(target_path): os.makedirs(target_path) tmp_files = os.listdir(tmp_path) for file in tmp_files: print('deal with {}'.format(file)) npy_info = np.load(os.path.join(tmp_path, file, 'all_rpns.npy')) raw_img = sitk.ReadImage(os.path.join(raw_path, file + '.nii.gz')) raw_img = sitk.GetArrayFromImage(raw_img) all_z = [int(i) for i in list(npy_info[:,2])] min_z, max_z = min(all_z)-3, max(all_z)+3 used_slices = [] for z in all_z: temp = [] for i in range(z-3, z+3): temp.append(i) used_slices.append(temp) all_images = [] for it_z in range(min_z, max_z): u_idx = [] for key, value in enumerate(used_slices): if it_z in value: u_idx.append(key) u_idx = sorted(u_idx) print('have {} box'.format(len(u_idx))) if len(u_idx) == 0: all_images.append(cv2.cvtColor(norm(raw_img[it_z,:,:]), cv2.COLOR_GRAY2BGR)) else: worked_slice = norm(raw_img[it_z,:,:]) worked_slice = cv2.cvtColor(worked_slice, cv2.COLOR_GRAY2BGR) for idx in u_idx: z,y,x,dim_z,dim_y,dim_x = npy_info[idx][1:] start_y = max(0, y - dim_y / 2) end_y = min(worked_slice.shape[0], y + dim_y / 2) start_z = max(0, z - dim_z / 2) end_z = min(worked_slice.shape[1], z + dim_z / 2) cv2.rectangle(worked_slice, (int(start_y), int(start_z)), (int(end_y), int(end_z)), (0,0,255), 2) all_images.append(worked_slice) for idx, img in enumerate(all_images): save_path = os.path.join(target_path, file) if not os.path.exists(save_path): os.makedirs(save_path) cv2.imwrite(os.path.join(save_path, '{:03d}.png'.format(idx)), img) output_gif_path = os.path.join(os.path.dirname(target_path), 'save_gifs') if not os.path.exists(output_gif_path): os.makedirs(output_gif_path) for file in os.listdir(target_path): files = [] for f in os.listdir(os.path.join(target_path, file)): files.append(f) files.sort(key=lambda x: x[:-4]) frames = [] for i in tqdm(range(len(files))): im = imageio.imread(os.path.join(target_path, file) + '/' + files[i]) frames.append(im) fps = 24.0 imageio.mimsave(os.path.join(output_gif_path, file+'.gif'), frames, 'GIF', duration=1/fps) def load_ct_from_dicom(dcm_path, sort_by_distance=True): class DcmInfo(object): def __init__(self, dcm_path, series_instance_uid, acquisition_number, sop_instance_uid, instance_number, image_orientation_patient, image_position_patient): super(DcmInfo, self).__init__() self.dcm_path = dcm_path self.series_instance_uid = series_instance_uid self.acquisition_number = acquisition_number self.sop_instance_uid = sop_instance_uid self.instance_number = instance_number self.image_orientation_patient = image_orientation_patient self.image_position_patient = image_position_patient self.slice_distance = self._cal_distance() def _cal_distance(self): normal = [self.image_orientation_patient[1] * self.image_orientation_patient[5] - self.image_orientation_patient[2] * self.image_orientation_patient[4], self.image_orientation_patient[2] * self.image_orientation_patient[3] - self.image_orientation_patient[0] * self.image_orientation_patient[5], self.image_orientation_patient[0] * self.image_orientation_patient[4] - self.image_orientation_patient[1] * self.image_orientation_patient[3]] distance = 0 for i in range(3): distance += normal[i] * self.image_position_patient[i] return distance def is_sop_instance_uid_exist(dcm_info, dcm_infos): for item in dcm_infos: if dcm_info.sop_instance_uid == item.sop_instance_uid: return True return False def get_dcm_path(dcm_info): return dcm_info.dcm_path reader = sitk.ImageSeriesReader() if sort_by_distance: dcm_infos = [] files = os.listdir(dcm_path) for file in files: file_path = os.path.join(dcm_path, file) dcm = dicomio.read_file(file_path, force=True) _series_instance_uid = dcm.SeriesInstanceUID _sop_instance_uid = dcm.SOPInstanceUID _instance_number = dcm.InstanceNumber _acquisition_number = dcm.AcquisitionNumber _image_orientation_patient = dcm.ImageOrientationPatient _image_position_patient = dcm.ImagePositionPatient dcm_info = DcmInfo(file_path, _series_instance_uid, _acquisition_number, _sop_instance_uid, _instance_number, _image_orientation_patient, _image_position_patient) if is_sop_instance_uid_exist(dcm_info, dcm_infos): continue dcm_infos.append(dcm_info) dcm_infos.sort(key=lambda x: x.slice_distance) dcm_series = list(map(get_dcm_path, dcm_infos)) else: dcm_series = reader.GetGDCMSeriesFileNames(dcm_path) reader.SetFileNames(dcm_series) sitk_image = reader.Execute() return sitk_image class Metric(object): def __init__(self, name): self.name = name self.sum = torch.tensor(0.) self.n = torch.tensor(0.) def update(self, val): val = val.clone().detach() output_tensors = [val.clone().to(val) for _ in range(torch.distributed.get_world_size())] dist.all_gather(output_tensors, val) self.sum += torch.sum(torch.stack(output_tensors)).cpu().detach() self.n += 1 @property def avg(self): return self.sum / self.n def accuracy(outputs, labels): _, preds = torch.max(outputs, dim=1) # get the accuracy of number preds correctly # if math.isnan(torch.sum(preds == labels).item() / len(preds)): # print('preds:', preds) # print('check: ', preds == labels) return torch.tensor(torch.sum(preds == labels).item() / len(preds)) def precision(outputs, labels): _, preds = torch.max(outputs, dim=1) # get the accuracy of number preds correctly return torch.tensor(precision_score(labels.cpu().numpy(), preds.cpu().numpy(), average='macro')) def recall(outputs, labels): _, preds = torch.max(outputs, dim=1) # get the accuracy of number preds correctly return torch.tensor(recall_score(labels.cpu().numpy(), preds.cpu().numpy(), average='macro')) class Classification3D(object): def __init__(self, cfg, mode, log_msg): self.cfg = cfg self.mode = mode self.log_msg = log_msg if mode == 'training': self.mode_cfg = cfg.training else: logger.info('load testing parameters...') self.mode_cfg = cfg.testing if mode == 'training': self.metrics_writer = CMetricsWriter( filename=cfg.training['train_metrics'], headernames=['epoch', 'accuracy', 'f1_score', 'precision', 'recall', 'lr', 'lossTotal', 'performanceAccuracy'] ) self.basic_file_path = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) self.n_input_channel = 1 self._clsInfoDecode() if mode == 'training': # self.traindata_npy = self.mode_cfg['train_data_path'] # self.traininfo_npy = self.mode_cfg['train_info_path'] # self.valdata_npy = self.mode_cfg['val_data_path'] # self.valinfo_npy = self.mode_cfg['val_info_path'] # self.testdata_npy = [] # self.testinfo_npy = [] self.traindata = self.mode_cfg['train_data_path'] self.valdata = self.mode_cfg['val_data_path'] self.testdata = [] print(self.traindata, self.valdata) else: # self.traindata_npy = [] # self.traininfo_npy = [] # self.valdata_npy = [] # self.valinfo_npy = [] # self.testdata_npy = self.mode_cfg['test_data_path'] # self.testinfo_npy = self.mode_cfg['test_info_path'] self.traindata = [] self.valdata = [] self.testdata = self.mode_cfg['test_data_path'] self.ParametersDecode() self.mode_cfg['model_mode'] = 'basic' # self.device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu') def __call__(self): ''' 训练/infer流程 1. buildModel&compileModel 2. 载入数据 3. 根据选择模式,进行训练or inference ''' # self._LoadData(self.mode) if self.mode == 'training': self._buildModel() self._compileModel() self._trainModel() elif self.mode == 'infer': self._buildModel() self._inferModel() else: self._buildModel() self._testModel() def _clsInfoDecode(self): self.cls_map_dict = eval(self.mode_cfg['cls_map_dict']) self.ori_cls_types = set([self.cls_map_dict[key] for key in self.cls_map_dict.keys()]) self.input_shape = [self.mode_cfg['patch_height'], self.mode_cfg['patch_height'], self.mode_cfg['patch_depth']] self.cls_name_map_dict = eval(self.mode_cfg['cls_name_map_dict']) # self.save_path = self.mode_cfg['save_path'] # if (not os.path.exists(self.save_path)) and self.mode == 'training': # os.makedirs(self.save_path) if type(self.mode_cfg['cls_label_index']) == list: self.mode_cfg['num_task'] = len(self.mode_cfg['cls_label_index']) else: self.mode_cfg['num_task'] = 1 def _buildModel(self): ''' buildModel 流程 1. 读取模型参数并修改部分 1. num_task:有几个要分类的任务 2. num_class:有几类(目前多个要分类的任务类别数目必须相同) 3. arcface_norm_choice:arcface是否进行归一化 4. easy_margin: arcface相关参数 如果设置为True,只要cos_theta大于0,就设置为乘性压缩的角度 否则,要cos_theta大于th,设置为乘性压缩角度,cos_theta小于th的情况,设置为减性压缩角度 5. freeze_blocks:不训练的block ''' cfg = self.model_params cfg['num_task'] = self.mode_cfg['num_task'] cfg['num_class'] = self.mode_cfg['num_class'] cfg['arcface_norm_choice'] = self.mode_cfg.get('norm_choice', False) cfg['easy_margin'] = self.mode_cfg.get('easy_margin', True) cfg['freeze_blocks'] = eval(self.mode_cfg.get('freeze_blocks', "[]")) if cfg['num_class'] == 2: cfg['num_class'] = 1 self.n_input_channel = cfg['n_input_channel'] if 'n_input_channel' in cfg.keys() else self.n_input_channel self.net = build_model(self.mode_cfg['model_name'], cfg) # print('model:',self.net) self.net_init = copy.deepcopy(self.net) if self.cfg.pretrain_msg: print('pretrain weight:', self.cfg.pretrain_msg) # if self.cfg.pretrain_msg.split('.')[-1] == 'pt': try: logger.info("load model from %s" % self.cfg.pretrain_msg) need_keys = [key for key,_ in self.net.named_parameters()] ## print('get need_keys') logger.info('get need_keys') checkpoints = torch.load(self.cfg.pretrain_msg, map_location=torch.device('cpu')) print('get checkpoints') # print(checkpoints.keys()) logger.info('get checkpoints') if self.cfg.mode == 'testing': new_checkpoints = {key[7:]:value for key, value in checkpoints.items() if key[7:] in need_keys} else: new_checkpoints = {key:value for key, value in checkpoints.items() if key in need_keys} print('Start load ...') # print(list(new_checkpoints.keys())[0], new_checkpoints[list(new_checkpoints.keys())[0]]) logger.info('Start load ...') self.net_init.load_state_dict(new_checkpoints) print('Have already load pretrain_msg !') logger.info('Have already load pretrain_msg !') except: self.net_init = torch.jit.load(self.cfg.pretrain_msg) print('load pt!!!!!') # for a,b in self.net.named_parameters(): # break # logger.info(b) # print(b) ## DDP logger.info(self.mode) if self.mode == 'training': print('start') logger.info('start') self.net = self.net.cuda() logger.info('cuda') logger.info(self.cfg.local_rank) logger.info(torch.cuda.device_count()) self.net = torch.nn.parallel.DistributedDataParallel(self.net, device_ids=[int(self.cfg.local_rank)]) # broadcast_buffers=False, # find_unused_parameters=True) logger.info('have ddp') else: self.net = self.net_init def _compileModel(self): self.optimizer = optim.Adam(filter(lambda p: p.requires_grad, self.net.parameters()), lr=self.cfg.lr) self.scheduler = optim.lr_scheduler.ReduceLROnPlateau(self.optimizer, patience=self.mode_cfg['patience'], factor=0.5, verbose=True) if self.mode_cfg['num_class'] > 2: self.criterion = nn.CrossEntropyLoss() else: self.criterion = nn.BCEWithLogitsLoss() # conf_paths = [self.filename, self.model_param_name] # if self.mode == 'training': # conf_paths += [self.trainfile_conf_path, self.valfile_conf_path] # for conf_path in conf_paths: # filename = conf_path.split('/')[-1] # target_save_path = os.path.join(self.save_path, filename) # shutil.copy(conf_path, target_save_path) # else: # conf_paths += [self.test_conf_path] def _DecodeGeneratorParam(self): high_freq_filter = self.mode_cfg['high_freq_filter'] if 'high_freq_filter' in self.mode_cfg.keys() else False rescale_intensity = self.mode_cfg['rescale_intensity'] if 'rescale_intensity' in self.mode_cfg.keys() else False nonfix_crop = self.mode_cfg['nonfix_crop'] if 'nonfix_crop' in self.mode_cfg.keys() else False aug = True if self.mode == 'training' else False model_mode = self.mode_cfg['model_mode'] if 'model_mode' in self.mode_cfg.keys() else 'basic' reg_map_dict = eval(self.mode_cfg['reg_map_dict']) if 'reg_map_dict' in self.mode_cfg.keys() else None self.regression_choice = self.mode_cfg['regression_choice'] if 'regression_choice' in self.mode_cfg.keys() else False self.generator_parameters = { 'cls_index': self.mode_cfg['cls_label_index'], 'cls_map_dict': self.cls_map_dict, 'aug': aug, 'nonfix_crop': nonfix_crop, 'diameter_index': self.mode_cfg['diameter_index'], 'diameter_max_th': self.mode_cfg['diameter_max_th'], 'diameter_min_th': self.mode_cfg['diameter_min_th'], 'num_class': self.mode_cfg['num_class'], 'HU_min': self.mode_cfg['HU_min'], 'HU_max': self.mode_cfg['HU_max'], 'cls_types': self.ori_cls_types, 'n_input_channel': self.n_input_channel, 'high_freq_filter': high_freq_filter, 'rescale_intensity': rescale_intensity, "input_shape": self.input_shape, "model_mode": model_mode, "regression_choice": self.regression_choice, "reg_map_dict": reg_map_dict, 'cls_balance': self.mode_cfg.get('cls_balance', False) } def _DecodeDataParam(self, filename): paths = LoadJson(filename) data_paths = [record['data_path'] for record in paths] info_paths = [record['info_path'] for record in paths] return {'data_paths': data_paths, 'info_paths': info_paths} # def _LoadData(self, mode): # ''' # 载入npy格式的数据 # ''' # if mode == 'training': # self.train_data, self.train_info = load_data([self.traindata_npy, self.traininfo_npy]) # self.val_data, self.val_info = load_data([self.valdata_npy, self.valinfo_npy]) # else: # print('testdata_npy and info: ',self.testdata_npy, self.testinfo_npy) # self.test_data, self.test_info = load_data([self.testdata_npy, self.testinfo_npy]) def _GetLoss(self, prob_diff, gamma): model_mode = self.mode_cfg['model_mode'] if 'model_mode' in self.mode_cfg.keys() else 'basic' loss_choice = self.mode_cfg['loss_func'] if 'loss_func' in self.mode_cfg.keys() else 'focal' lambda_val = self.mode_cfg['lambda_val'] if 'lambda_val' in self.mode_cfg.keys() else 1 margin_val = self.mode_cfg['margin_val'] if 'margin_val' in self.mode_cfg.keys() else 0.0 print('model_mode is ', model_mode) ''' gradientBoost的参数,是否使用gradient_boost/选择保留几个类别 ''' gradientBoost_choice = self.mode_cfg['gradientBoost'] if 'gradientBoost' in self.mode_cfg.keys() else False gradientBoost_k_val = self.mode_cfg['gradientBoost_k'] if 'gradientBoost_k' in self.mode_cfg.keys() else 1 ''' 基于focal loss基本的weights(1/该类样本量->归一化),是否要对某一类的weights进行修改,如果否就设置为None,否则设置为对应的loss ''' alpha_multi_ratio = eval(self.mode_cfg['alpha_multi_ratio']) if 'alpha_multi_ratio' in self.mode_cfg.keys() else None ''' 得到regression_task的loss ''' self.regression_loss = torch.nn.MSELoss() if self.regression_choice else None self.regression_loss = torch.nn.SmoothL1Loss() if self.mode_cfg.get('regression_loss', 'MSE') == 'SmoothL1Loss' else self.regression_loss self.convert_softmax = True print("self.mode_cfg['num_task']: ", self.mode_cfg['num_task']) if self.mode_cfg['focal_loss'] and self.mode_cfg['num_task'] == 1: # ratios = self.train_dataset._GenerateRatios() ratios = [0.25, 0.25, 0.25, 0.25] if self.mode_cfg['num_class'] <= 2: alpha = ratios[0] if model_mode == 'deux': self.criterion = SeLossWithBCEFocalLoss(alpha=alpha, gamma=gamma, lambda_val=lambda_val, prob_diff=prob_diff) elif model_mode == 'triple': self.criterion = TripletBinFocal(self.cfg['num_class'], alpha=alpha, gamma=gamma, lambda_val=lambda_val, margin=margin_val) else: self.criterion = BCEFocalLoss(alpha=alpha, gamma=gamma) else: ''' 多类情况下的loss ''' alpha = ratios if model_mode == 'deux': # ratios = [0.25, 0.25, 0.25, 0.25] print('parameters-alpha: ', ratios) self.criterion = SE_FocalLoss(self.mode_cfg['num_class'], alpha=ratios, gamma=gamma, lambda_val=lambda_val, gradient_boost=gradientBoost_choice, alpha_multi_ratio=alpha_multi_ratio, convert_softmax=self.convert_softmax, k=gradientBoost_k_val, margin=margin_val, new_se_loss=self.mode_cfg.get('new_se_loss', False)) print('self.alpha: ', self.criterion.alpha) elif model_mode == 'triple': self.criterion = TripletFocalLoss(self.mode_cfg['num_class'], alpha=ratios, gamma=gamma, lambda_val=lambda_val, margin=margin_val) else: if gradientBoost_choice: self.criterion = GradientBoostFocal(self.mode_cfg['num_class'], alpha=ratios, gamma=gamma) else: # ratios = [0.25, 0.25, 0.25, 0.25] print('parameters-alpha: ', ratios) self.criterion = FocalLoss(self.mode_cfg['num_class'], alpha=ratios, gamma=gamma, alpha_multi_ratio=alpha_multi_ratio, convert_softmax=self.convert_softmax) print('self.alpha: ', self.criterion.alpha) else: alpha = None # return alpha def _trainOnBatch(self, epoch, batch, label_dtype): imgs, labels, label_diff = batch[0] label_diff = [single_label_diff.type(label_dtype).cuda() for single_label_diff in label_diff] imgs = [single_set_img.type(torch.float32).cuda() for single_set_img in imgs] labels = [single_label.type(label_dtype).cuda() for single_label in labels] prediction = self.net(imgs, labels=labels) # 计算训练输出 whole_pred = torch.cat(prediction, dim=0) whole_labels = torch.cat(labels, dim=0) acc = accuracy(whole_pred, whole_labels) out_precision = precision(whole_pred, whole_labels) out_recall = recall(whole_pred, whole_labels) out_f1 = 2*out_precision*out_recall / (out_precision + out_recall) ''' 网络预测的前len(labels)是logits。len(labels)->2*len(labels)部分如果存在就是features ''' logits = [pred.squeeze(-1) for pred in prediction[:len(labels)]] features = [pred.squeeze(-1) for pred in prediction[len(labels):len(labels) * 2]] cons = prediction[-1].squeeze(-1) loss_input = [logits] + [labels] + [label_diff] loss = self.criterion(loss_input) # self.metrics_writer.append_one_line( # [epoch, acc.item(), out_f1.item(), out_precision.item(), out_recall.item(), self.optimizer.param_groups[0]['lr'], loss.item(), acc.item()] # ) return loss, acc, out_f1, out_precision, out_recall def _trainOnBatchWithReg(self, epoch, batch, label_dtype, device): regression_ratio = self.mode_cfg.get('regression_ratio', 0.05) imgs, labels, label_diff = batch[0] label_diff = [single_label_diff.cuda().to(device=device, dtype=label_dtype) for single_label_diff in label_diff] imgs = [single_set_img.cuda().to(device=device, dtype=torch.float32) for single_set_img in imgs] cls_labels = [single_label[0].cuda().to(device=device, dtype=label_dtype) for single_label in labels] reg_labels = [single_label[1].cuda().to(device=device, dtype=torch.float32) for single_label in labels] prediction = self.net(imgs, labels=cls_labels) logits = [pred[0].squeeze(-1) for pred in prediction[:len(labels)]] reg_outputs = [pred[1].squeeze(-1) for pred in prediction[:len(labels)]] loss_input = [logits] + [cls_labels] + [label_diff] loss = self.criterion(loss_input) reg_loss = 0 for sample_idx in range(len(reg_outputs)): reg_loss_val = self.regression_loss(reg_outputs[sample_idx], reg_labels[sample_idx]) if sample_idx == 0: reg_loss = reg_loss_val else: reg_loss += reg_loss_val loss = regression_ratio * reg_loss + loss # 计算训练输出 whole_pred = torch.cat(prediction, dim=0) whole_labels = torch.cat(labels, dim=0) acc = accuracy(whole_pred, whole_labels) out_precision = precision(whole_pred, whole_labels) out_recall = recall(whole_pred, whole_labels) out_f1 = 2*out_precision*out_recall / (out_precision + out_recall) # self.metrics_writer.append_one_line( # [epoch, acc, out_f1, out_precision, out_recall, self.optimizer.param_groups[0]['lr'], loss.item()] # ) return def _trainModel(self): ########## Add parameters ''' 设置训练的device ''' # device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu') ''' 如果pretrain_weight_path不为空,则load进来 ''' # if self.cfg.pretrain_msg and self.cfg.rank == 0: # logger.info("load model from %s" % self.cfg.pretrain_msg) # need_keys = [key for key,_ in self.net.named_parameters()] # checkpoints = torch.load(self.cfg.pretrain_msg) # new_checkpoints = {key:value for key, value in checkpoints.items() if key in need_keys} # self.net.load_state_dict(new_checkpoints) # for name, value in self.net.named_parameters(): # print(name, value.shape) ''' 读取关于loss和数据处理的参数 nonfix_crop:输入的大小是否一致 train_aug_choice:train的数据是否进行augmentation gamma:focal loss参数 feature_sim: prob_diff: ''' nonfix_crop = self.mode_cfg['nonfix_crop'] if 'nonfix_crop' in self.mode_cfg.keys() else False train_aug_choice = self.mode_cfg['train_aug_choice'] if 'train_aug_choice' in self.mode_cfg.keys() else False # print('type: ', type(self.cfg)) # print('cfg: ', self.cfg) gamma = self.cfg.lr_gamma if hasattr(self.cfg, 'lr_gamma') else 1 prob_diff = self.mode_cfg['prob_diff'] if 'prob_diff' in self.mode_cfg.keys() else False ############ generate params for generator self._DecodeGeneratorParam() # train_dataset = SurDataSet(self.train_data, self.train_info, **self.generator_parameters) train_dataset = SurDataSet(self.traindata, **self.generator_parameters) self.generator_parameters['aug'] = False # val_dataset = SurDataSet(self.val_data, self.val_info, **self.generator_parameters) val_dataset = SurDataSet(self.valdata, **self.generator_parameters) balance_data_kws = self.mode_cfg['balance_data'] if 'balance_data' in self.mode_cfg.keys() else [0, 0] modes = ['val' if val == 0 else 'train' for val in balance_data_kws] # train_dataset._GenerateTrainData(mode=modes[0]) # val_dataset._GenerateTrainData(mode=modes[1]) self.train_dataset = train_dataset ''' 根据参数情况得到训练的loss ''' self._GetLoss(prob_diff, gamma) train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset) val_sampler = torch.utils.data.distributed.DistributedSampler(val_dataset) train_loader = DataLoader(train_dataset, batch_size=self.cfg.batch_size, sampler=train_sampler) val_loader = DataLoader(val_dataset, batch_size=self.cfg.batch_size, sampler=val_sampler) lr = self.cfg.lr batch_size = self.cfg.batch_size global_step = 0 epochs = self.cfg.epochs criterion = self.criterion val_criterion = self.criterion label_dtype = torch.long if self.mode_cfg['num_class'] > 2 else torch.float32 n_train = int(round(len(train_dataset) / self.cfg.batch_size)) best_val_score = 1000000 for epoch in range(epochs): train_sampler.set_epoch(epoch) self.net.train() # self.net.to(device) self.net.cuda() epoch_loss = 0 print('size of train dataset is ', len(train_dataset)) if torch.distributed.get_rank() == 0: epoch_acc = 0 epoch_f1 = 0 epoch_precision = 0 epoch_recall = 0 epoch_num = 0 with tqdm(total=n_train, desc=f'Epoch{epoch + 1}/{epochs}', unit='img') as pbar: batch_count = 0 for batch in zip(train_loader): batch_count += 1 loss_values = [] # if self.regression_choice: # loss = self._trainOnBatchWithReg(epoch, batch, label_dtype, device) # else: loss, acc, out_f1, out_precision, out_recall = self._trainOnBatch(epoch, batch, label_dtype) # logger.info('torch.distributed.get_rank(): ', torch.distributed.get_rank()) if torch.distributed.get_rank() == 0: # logger.info('deal with epoch_num ...') epoch_acc += acc.item() epoch_f1 += out_f1.item() epoch_precision += out_precision.item() epoch_recall += out_recall.item() epoch_num += 1 self.optimizer.zero_grad() loss.backward() epoch_loss += loss.item() nn.utils.clip_grad_value_(self.net.parameters(), 1e-1) self.optimizer.step() global_step += 1 pbar.update() if torch.distributed.get_rank() == 0: print('self.metrics_writer.append_one_line -- loss: ', loss) self.metrics_writer.append_one_line( [epoch, epoch_acc / epoch_num, epoch_f1 / epoch_num, epoch_precision / epoch_num, epoch_recall / epoch_num, self.optimizer.param_groups[0]['lr'], loss.item(), epoch_acc / epoch_num] ) # random.shuffle(train_dataset.indices) print('val_criterion: ', val_criterion) val_score = eval_netSE(self.net, val_dataset, val_loader, base_cfg=self.cfg, cfg=self.mode_cfg, criterion=val_criterion, reg_loss=self.regression_loss) ############ change lr based on val result self.scheduler.step(val_score) logger.info("Train loss: {}".format(epoch_loss / n_train)) logging.info("Validation loss: {}".format(val_score)) if val_score <= best_val_score and self.cfg.rank == 0: torch.save(self.net.state_dict(), self.mode_cfg['save_path'] + '/model.pth') best_val_score = val_score with open(self.mode_cfg['train_result'], 'w+') as file: json.dump( { "successFlag":"TRAINING", "bestModelEpoch": epoch }, file, indent=4) # 更新eval中的performance.json with open(self.mode_cfg['eval_pjson'], 'w+') as file: json.dump( { "loss": val_score, }, file, indent=4) # 更新eval中的performance.md with open(self.mode_cfg['eval_pmd'], 'w+') as file: file.write('# overall performance \n') file.write('| loss | \n') file.write('| -------- | \n') file.write(fr'| {val_score} | \n') if self.cfg.rank == 0: json_info = json.load(open(self.mode_cfg['train_result'], 'r')) json_info['successFlag'] = 'SUCCESS' with open(self.mode_cfg['train_result'], 'w') as file: json.dump(json_info, file, indent=4) shutil.copy(self.cfg.pretrain_msg, self.mode_cfg['save_path'] + '/model.pth') print('save best ok') def _testModel_ori(self): # device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu') # model_path = self.mode_cfg['weight_path'] # self.net.load_state_dict(torch.load(model_path, map_location=device)) # logger.info("load model from %s" % model_path) # self.net.to(device=device) assert torch.cuda.is_available() self.net = self.net.cuda() self.net.eval() gamma = self.cfg.lr_gamma if hasattr(self.cfg, 'lr_gamma') else 1 prob_diff = self.mode_cfg['prob_diff'] if 'prob_diff' in self.mode_cfg.keys() else False ############ generate params for generator self._DecodeGeneratorParam() self.generator_parameters['aug'] = False val_dataset = SurDataSet(self.test_data, self.test_info, **self.generator_parameters) # print('val_dataset: ',len(val_dataset)) val_dataset._GenerateTrainData() self._GetLoss(prob_diff, gamma) val_loader = DataLoader(val_dataset, batch_size=self.mode_cfg['batch_size'], shuffle=False) val_criterion = self.criterion val_score = eval_netSE(self.net, val_dataset, val_loader, base_cfg=self.cfg, cfg=self.mode_cfg, criterion=val_criterion, reg_loss=self.regression_loss) # 更新eval中的performance.json with open(self.mode_cfg['eval_pjson'], 'w+') as file: json.dump( { "precision": val_score, }, file, indent=4) # 更新eval中的performance.md with open(self.mode_cfg['eval_pmd'], 'w+') as file: file.write('# overall performance \n') file.write('| precision | \n') file.write('| -------- | \n') file.write(fr'| {val_score} | \n') with open(self.mode_cfg['eval_result'], 'w+') as file: json.dump( { "successFlag": "SUCCESS", },file, indent=4) def _testModel(self): device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu') # model_path = self.mode_cfg['weight_path'] # self.net.load_state_dict(torch.load(model_path, map_location=device)) # logger.info("load model from %s" % model_path) # self.net.to(device=device) print('all cfg:') print(self.cfg) print('model cfg') print(self.mode_cfg) assert torch.cuda.is_available() self.net = torch.jit.load(self.cfg.pretrain_msg) self.net = self.net.cuda() self.net.eval() gamma = self.cfg.lr_gamma if hasattr(self.cfg, 'lr_gamma') else 1 prob_diff = self.mode_cfg['prob_diff'] if 'prob_diff' in self.mode_cfg.keys() else False ############ generate params for generator self._DecodeGeneratorParam() self.generator_parameters['aug'] = False val_dataset = TestDataSet(self.testdata, **self.generator_parameters) # val_loader = DataLoader(val_dataset, batch_size=self.mode_cfg['batch_size'], shuffle=False) # batch_size = self.mode_cfg['batch_size'] # print(f'batch size: {batch_size}') labels = [] preds = [] limits = [] # for data, label, is_limit in val_loader: for data, label, is_limit in val_dataset: # data = data.float().cuda() data = torch.from_numpy(data[None, ...]).to(device=device, dtype=torch.float32) pred = self.net(data) # print(pred) labels.append(label) pred = torch.softmax(pred[0], -1) pred = torch.max(pred, -1)[1] # pred = torch.argmax(pred[0]) # print('pred', pred, pred.shape) preds.append(pred) # limits += is_limit.reshape(-1).tolist() limits.append(is_limit) # print(f'data shape: {data.shape}, label shape: {label.shape}, pred shape: {pred.shape}, is limit shape: {is_limit.shape}') # labels = torch.cat(labels, 0).cpu().numpy() labels = np.array(labels) preds = torch.cat(preds, 0).cpu().numpy() # print(labels.shape, preds.shape) # preds = torch.cat(preds, 0).cpu().numpy() # print(f'labels shape: {labels.shape}, preds shape: {preds.shape}, limits shape: {len(limits)}') # print(f'labels: {labels}') # print(f'preds: {preds}') preds_new = [] for pred_cls_ori, limit in zip(preds, limits): if pred_cls_ori == 1: pred_cls = 2 elif pred_cls_ori == 2: pred_cls = 1 else: pred_cls = pred_cls_ori # if pred_cls_ori == 3: # if limit: # pred_cls = 5 # else: # if limit: # pred_cls = 4 preds_new.append(pred_cls) preds_new = np.array(preds_new) acc = np.sum(preds_new == labels) / preds_new.shape[0] num_gts = preds_new.shape[0] num_tps = np.sum(preds_new == labels) print('gt info:----------------') elems, elems_count = np.unique(labels, return_counts=True) print('gts:', elems) print('gt counts:', elems_count) print('pred info:----------------') elems, elems_count = np.unique(preds_new, return_counts=True) print('preds:', elems) print('pred counts:', elems_count) print('=================') print('num gts:', num_gts) print('num tps:', num_tps) # 更新eval中的performance.json with open(self.mode_cfg['eval_pjson'], 'w+') as file: json.dump( { "accuracy": acc, }, file, indent=4) # 更新eval中的performance.md with open(self.mode_cfg['eval_pmd'], 'w+') as file: file.write('# overall performance \n') file.write('| accuracy |\n') file.write('| -------- |\n') file.write(fr'| {acc} |\n') with open(self.mode_cfg['eval_result'], 'w+') as file: json.dump( { "successFlag": "SUCCESS", },file, indent=4) def make_infer_data_info(self, info_path, data_path): info_dict = {} with open(info_path, 'r+') as f: infer_data_info = json.load(f) for i in infer_data_info['dataList']: uid = i['rawDataUrls'][0].split('/')[-1].split('.zip')[0] if uid not in info_dict: info_dict[uid] = {'dcm_path': os.path.join(data_path, uid)} for j in i['annotations']: ann_path = j['annotationUrls'][0] with open(ann_path, 'r+') as f: ann = json.load(f) for index, k in enumerate(ann['annotationSessions'][0]['annotationSet']): info_dict[uid][index] = k['coordinates'] return info_dict def _inferModel(self): class_map = {0: '肺内磨玻璃结节', 1:'肺内混合结节', 2: '肺内实性结节', 3: '肺内钙化结节', 4: '胸膜结节或斑块', 5: '胸膜钙化结节'} class_map2 = {'肺内磨玻璃结节': 'mbl', '肺内混合结节': 'hh', '肺内实性结节': 'sx', '肺内钙化结节': 'gh', '胸膜结节或斑块': 'xm', '胸膜钙化结节': 'xmgh'} print(class_map) assert torch.cuda.is_available() self.net = self.net.cuda() self.net.eval() gamma = self.cfg.lr_gamma if hasattr(self.cfg, 'lr_gamma') else 1 prob_diff = self.mode_cfg['prob_diff'] if 'prob_diff' in self.mode_cfg.keys() else False ############ generate params for generator self._DecodeGeneratorParam() self.generator_parameters['aug'] = False val_dataset = InferDataSet(self.testdata, **self.generator_parameters) batch_size = self.mode_cfg['batch_size'] print(f'batch size: {batch_size}') # preds = [] info_dict = self.make_infer_data_info(self.cfg.dataset_info_path, self.cfg.dcm_data_path) print(f'info dict: {info_dict}') pred_dict = {} font = cv2.FONT_HERSHEY_SIMPLEX font_scale = 0.2 font_color = (0,0,255) font_thickness = 1 font_size = 10 # text_position = (0,0) text = '肺结节' (text_w, text_h), _ = cv2.getTextSize(text, font, font_scale, font_thickness) adjusted_position = (20-text_w, 20+text_h) font_path = '/shared/msyhl.ttc' for data, info in val_dataset: uid = info[4] is_limit = info[5] if uid not in pred_dict: pred_dict[uid] = {} index = info[1] # print(data.shape) data = torch.from_numpy(data[None, ...]).float().cuda() pred = self.net(data) pred = torch.softmax(pred[0], -1) pred = torch.max(pred, -1)[1] pred = pred.cpu().numpy()[0] if pred == 3: if is_limit: pred = 5 else: if is_limit: pred = 4 pred_cls = class_map[pred] pred_dict[uid][index] = pred_cls # print('pred', pred, pred.shape) # preds.append(pred) # print(f'data shape: {data.shape}, pred shape: {pred.shape}') # preds = torch.cat(preds, 0).cpu().numpy() # for i in preds: # pred_cls = class_map[i] # print(f'pred class: {pred_cls}') print(f'pred dict: {pred_dict}') target_path = self.cfg.save_img_dir for uid, pred in pred_dict.items(): dcm_path = info_dict[uid]['dcm_path'] itk_image = load_ct_from_dicom(dcm_path) origin = itk_image.GetOrigin() spacing = itk_image.GetSpacing() image = sitk.GetArrayFromImage(itk_image) x_origin, y_origin, z_origin = origin x_spacing, y_spacing, z_spacing = spacing for index, pred_cls in pred.items(): gt_info = info_dict[uid][index] x1, y1, z1, x2, y2, z2 = gt_info[0]['x'], gt_info[0]['y'], gt_info[0]['z'], gt_info[1]['x'], gt_info[1]['y'], gt_info[1]['z'] x1 = int((x1 - x_origin)/x_spacing) y1 = int((y1 - y_origin)/y_spacing) z1 = int((z1 - z_origin)/z_spacing) x2 = int((x2 - x_origin)/x_spacing) y2 = int((y2 - y_origin)/y_spacing) z2 = int((z2 - z_origin)/z_spacing) all_images = [] z_min = max(0, z1-4) z_max = min(image.shape[0]-1, z2+4) for z in range(z_min, z_max, 1): img_z = image.copy()[z, :, :] img_z = norm(img_z) img_z = cv2.cvtColor(img_z, cv2.COLOR_GRAY2BGR) # cv2.rectangle(img_z, (int(y1), int(x1)), (int(y2), int(x2)), (0,0,255), 2) if z>=z1 and z<=z2: cv2.rectangle(img_z, (int(x1), int(y1)), (int(x2), int(y2)), (0,0,255), 2) cut_x_min = max(0, x1-64) cut_y_min = max(0, y1-64) cut_x_max = min(image.shape[2]-1, x2+64) cut_y_max = min(image.shape[1]-1, y2+64) img_z = img_z[cut_y_min: cut_y_max, cut_x_min: cut_x_max, :] # img_z = img_z[cut_x_min: cut_x_max, cut_y_min: cut_y_max, :] img_pil = Image.fromarray(cv2.cvtColor(img_z, cv2.COLOR_BGR2RGB).astype(np.uint8)) draw = ImageDraw.Draw(img_pil) font = ImageFont.truetype(font_path, font_size) draw.text((0, 0), pred_cls, font=font, fill=(255,0,0)) img_z = cv2.cvtColor(np.array(img_pil), cv2.COLOR_RGB2BGR) # cv2.putText(img_z, text, adjusted_position, font, font_scale, font_color, font_thickness) all_images.append(img_z) for idx, img in enumerate(all_images): save_path = os.path.join(target_path, uid+'_'+str(index)+'_'+class_map2[pred_cls]) if not os.path.exists(save_path): os.makedirs(save_path) # print(img) cv2.imwrite(os.path.join(save_path, '{:03d}.png'.format(idx)), img) output_gif_path = self.cfg.save_gif_dir if not os.path.exists(output_gif_path): os.makedirs(output_gif_path) for file in os.listdir(target_path): files = [] for f in os.listdir(os.path.join(target_path, file)): files.append(f) files.sort(key=lambda x: x[:-4]) frames = [] for i in tqdm(range(len(files))): im = imageio.imread(os.path.join(target_path, file) + '/' + files[i]) frames.append(im) fps = 24.0 imageio.mimsave(os.path.join(output_gif_path, file+'.gif'), frames, 'GIF', duration=1/fps) def _GetTargetClassProb(self, labels, probs): target_class_probs = [prob[0][label] for prob, label in zip(probs, labels)] return target_class_probs def ProbVis(self, labels, probs): values = np.unique(labels) from matplotlib import pyplot as plt for val in values: print('=' * 60) print('Current label is %d' % val) indices = [idx for idx in range(len(labels)) if labels[idx] == val] current_probs = np.take(probs, indices, axis=0) print('Number of samples are %d' % len(current_probs)) plt.hist(current_probs, range=(0, 1)) plt.show() plt.show() def BadCaseVis(self, num_incies, show_edges=False, gaussian_choice=False, sigma=1): for idx in self.badcase_lists[:num_incies]: current_data = self.test_dataset[idx] img = np.squeeze(current_data[0]) label = current_data[1] center_slice_idx = int(img.shape[0] / 2) slice_range = 1 print('=' * 60) print('Current label is ', label, ' result is ', self.result[idx]) print('prob is', self.result_list[idx]) if not show_edges: _, axs = plt.subplots(1, 2 * slice_range + 1) else: _, axs = plt.subplots(2, 2 * slice_range + 1) for slice_idx in range(center_slice_idx - slice_range, center_slice_idx + slice_range + 1): if not show_edges: axs[slice_idx - (center_slice_idx - slice_range)].imshow(img[slice_idx], cmap='bone') else: edges = ShowEdges(img[slice_idx], gaussian_choice=gaussian_choice, sigma=sigma) # edges = HFFilter(img[slice_idx]) axs[0][slice_idx - (center_slice_idx - slice_range)].imshow(img[slice_idx], cmap='bone') axs[1][slice_idx - (center_slice_idx - slice_range)].imshow(edges, cmap='bone') plt.show() def ParametersDecode(self): ''' Generate model parameters ''' # self.model_param_name = os.path.join(self.basic_file_path, self.cfg['model_params']) # self.model_params = LoadJson(self.model_param_name)[self.cfg['model_name']] if self.mode_cfg['model_name'] == 'VGG3D': self.model_params = self.cfg.VGG3D elif self.mode_cfg['model_name'] == 'ResNet': self.model_params = self.cfg.ResNet elif self.mode_cfg['model_name'] == 'Se_ResNet': self.model_params = self.cfg.Se_ResNet elif self.mode_cfg['model_name'] == 'Darknet': self.model_params = self.cfg.Darknet elif self.mode_cfg['model_name'] == 'FSe_ResNet': self.model_params = self.cfg.FSe_ResNet