import os
import re
import ast
import json
import glob

import numpy as np
from tqdm import tqdm
from matplotlib import pyplot as plt
import imageio
import torch
from torch import optim
from torch.utils.data import DataLoader

from k8s_utils import CMetricsWriter
from model.modelBuild import build_model
from model.baseLayers.BasicModules import weight_init
from utils.SegmentationLoss import Tversky_loss, DiceLoss, SurfaceLoss
from utils.SegmentationDataGen import SegDataGenerator
from utils.ImagePreprocess import *
from utils.ResultEval import dice_coefficient_post
from utils.respacing_func import *
from utils.cls_loss import *


def LoadJson(path):
    with open(path, 'r+') as f:
        return json.load(f)


def Dice(label, pred):
    x = label.sum()
    y = pred.sum()
    dice = 2 * (label * pred).sum() / (x + y + 1e-5)
    return dice


def norm(image, hu_min=-1000.0, hu_max=600.0):
    image = (np.clip(image.astype(np.float32), hu_min, hu_max) - hu_min) / float(hu_max - hu_min)
    return image * 255.


# Per-branch loss bookkeeping shared by Train() and Test().
LOSS_NAMES = {
    0: 'branch_1_dice_loss',
    1: 'branch_2_dice_loss',
    2: 'branch_3_dice_loss',
    3: 'branch_1_surface_loss',
    4: 'branch_2_surface_loss',
    5: 'branch_3_surface_loss',
    6: 'branch_1_trans_dice_loss',
    7: 'branch_2_trans_dice_loss',
    8: 'branch_3_trans_dice_loss',
    9: 'branch_1_trans_surface_loss',
    10: 'branch_2_trans_surface_loss',
    11: 'branch_3_trans_surface_loss',
    12: 'branch_1_sim_loss',
    13: 'branch_2_sim_loss',
    14: 'branch_3_sim_loss',
    15: 'cls_loss',
}
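# Illustrative sketch (not part of the pipeline): how Dice and norm behave on
# small synthetic arrays. The helper name _demo_dice_norm is hypothetical and
# exists only for this example.
def _demo_dice_norm():
    label = np.zeros((4, 4), dtype=np.float32)
    pred = np.zeros((4, 4), dtype=np.float32)
    label[1:3, 1:3] = 1.0          # 4 positive voxels
    pred[1:3, 1:4] = 1.0           # 6 predicted voxels, 4 overlapping
    # Dice = 2*|A∩B| / (|A|+|B|) = 2*4 / (4+6) = 0.8
    print('dice:', Dice(label, pred))
    hu = np.array([-2000.0, -1000.0, -200.0, 600.0, 3000.0])
    # HU values are clipped to [-1000, 600] and rescaled to [0, 255].
    print('norm:', norm(hu))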
class Segmentation3D(object):
    def __init__(self, cfg, logger):
        self.cfg_g = cfg
        self.logger = logger
        self.mode = cfg.mode
        self.parseParams(cfg)
        if self.mode == 'training':
            self.device = torch.device("cuda:%s" % (cfg.rank) if torch.cuda.is_available() else "cpu")
            self.metrics_writer = CMetricsWriter(
                filename=self.common_cfg['train_metrics'],
                headernames=['epoch', 'lr', 'loss', 'performanceAccuracy'])
            self.cfg = cfg.training
        else:
            self.device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
            self.cfg = cfg.testing

    def __call__(self):
        if self.mode == 'infer':
            self.Infer()
        else:
            self.buildModel()
            self.compileModel()
            self.loadData()
            if self.mode == 'training':
                self.model = self.model.to(self.device)
                self.Train()
            elif self.mode == 'testing':
                self.Test()

    def parseParams(self, cfg):
        self.common_cfg = cfg.common
        self.save_path = self.cfg_g.common['save_path']
        self.logger.debug('save path is %s' % self.save_path)
        if not os.path.exists(self.save_path):
            os.makedirs(self.save_path)

    def loadData(self):
        input_size = [64, 64, 64]
        generator_parameters = {
            'aug': self.cfg.get('aug', False),
            'HU_min': self.common_cfg['normalizer'][0],
            'HU_max': self.common_cfg['normalizer'][1],
            'nonfix_crop': self.cfg['nonfix_crop'],
            'stride_ratio': self.common_cfg['stride_ratio'],
            'densityTypeIndex': self.common_cfg['densityTypeIndex'],
            'diameter_enlarge_ratio': self.common_cfg['diameter_enlarge_ratio'],
            'diameter_idx': self.common_cfg['diameter_idx'],
            'spacing_idx': self.common_cfg['spacing_idx'],
            'input_size': input_size,
            'mode': self.mode,
            'min_patch_size': self.common_cfg.get('min_patch_size', 16),
            'do_cls': self.common_cfg.get('do_cls', False),
            'repeat_times': self.cfg.get('repeat_times', 4),
            'zoom_shape': self.common_cfg.get('zoom_shape', None),
            'calculate_shape_choice': self.common_cfg.get('calculate_shape_choice', 1),
            'use_label_diameter': self.cfg.get('use_label_diameter', False),
        }
        if self.mode == 'training':
            train_paths = self.common_cfg['train_data_path']
            val_paths = self.common_cfg['val_data_path']
            self.train_generator = SegDataGenerator(data_paths=train_paths, **generator_parameters)
            generator_parameters['aug'] = False
            generator_parameters['mode'] = 'val'
            self.val_generator = SegDataGenerator(data_paths=val_paths, **generator_parameters)
        elif self.mode == 'testing':
            val_paths = self.common_cfg['val_data_path']
            self.val_generator = SegDataGenerator(data_paths=val_paths, **generator_parameters)

    def _findLast(self, paths):
        # Return the checkpoint with the highest epoch number, for file names
        # like 'CP_epoch12_loss_0.53.pth'.
        def epoch_of(p):
            m = re.search(r'CP_epoch(\d+)', os.path.basename(p))
            return int(m.group(1)) if m else -1
        return max(paths, key=epoch_of)

    def loadModel(self, path=None):
        if path is not None:
            model_path = path
        else:
            self.pretrain = self.cfg_g.pretrain_msg
            model_path = self.pretrain
        print('model path is %s' % model_path)
        if os.path.isfile(model_path):
            pass
        elif glob.glob(str(model_path) + '/*pth'):
            model_path = self._findLast(glob.glob(str(model_path) + '/*pth'))
        else:
            model_path = None
        if model_path is not None:
            try:
                print('load model weights from %s' % model_path)
                self.model.to(device=self.device)
                self.model.load_state_dict(torch.load(model_path, map_location=self.device))
            except Exception:
                self.logger.error('Cannot load pretrain weight %s' % model_path)
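    # A minimal sketch of the checkpoint-resolution convention used by
    # loadModel/_findLast: the highest 'CP_epoch{n}' wins. The file names
    # here are hypothetical.
    @staticmethod
    def _demo_checkpoint_resolution():
        paths = ['run/CP_epoch3_loss_0.61.pth', 'run/CP_epoch12_loss_0.53.pth']
        def epoch_of(p):
            m = re.search(r'CP_epoch(\d+)', os.path.basename(p))
            return int(m.group(1)) if m else -1
        # CP_epoch12 is chosen over CP_epoch3.
        return max(paths, key=epoch_of)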
    def Train(self):
        train_batch_size = self.cfg_g.batch_size
        val_batch_size = self.cfg_g.batch_size
        print('train/val batch size:', train_batch_size, val_batch_size)
        train_Loader = DataLoader(self.train_generator, batch_size=train_batch_size, shuffle=True)
        val_Loader = DataLoader(self.val_generator, batch_size=val_batch_size, shuffle=False, drop_last=False)
        num_epoch = self.cfg_g.epochs
        n_train = len(train_Loader)
        self.loadModel()
        self.model.to(self.device)
        best_val_loss = float('inf')
        for epoch in range(num_epoch):
            self.model.train()
            epoch_loss = 0
            self.train_losses = [[] for _ in range(len(LOSS_NAMES))]
            self.val_losses = [[] for _ in range(len(LOSS_NAMES))]
            with tqdm(total=n_train, desc=f'Epoch{epoch + 1}/{num_epoch}', unit='img') as pbar:
                for batch in train_Loader:
                    loss, _ = self.TrainSingle(batch)
                    self.optimizer.zero_grad()
                    loss.backward()
                    epoch_loss += loss.item()
                    self.optimizer.step()
                    pbar.update()
            train_loss = epoch_loss / float(n_train)
            self.model.eval()
            val_loss, dice = self.Val(val_Loader, epoch, num_epoch)
            self.model.train()
            self.scheduler.step(val_loss)
            self.logger.info("Train loss:{}".format(train_loss))
            self.logger.info("Validation loss:{}".format(val_loss))
            if torch.distributed.get_rank() == 0:
                self.metrics_writer.append_one_line(
                    [epoch, self.optimizer.param_groups[0]['lr'], train_loss, dice])
                with open(self.common_cfg['eval_pjson'], 'w+') as file:
                    json.dump({"loss": val_loss, "dice": dice}, file, indent=4)
                # Update performance.md under eval.
                with open(self.common_cfg['eval_pmd'], 'w+') as file:
                    file.write('# overall performance \n')
                    file.write('| loss | dice |\n')
                    file.write('| ---- | ---- |\n')
                    file.write(f'| {val_loss} | {dice} |\n')
                if val_loss <= best_val_loss:
                    with open(self.common_cfg['train_result'], 'w+') as file:
                        json.dump({"successFlag": "TRAINING", "bestModelEpoch": epoch}, file, indent=4)
                    save_path = self.common_cfg['save_path'] + '/model.pth'
                    torch.save(self.model.state_dict(), save_path)
                    print(f'model saved to {save_path}')
                    best_val_loss = val_loss
        if torch.distributed.get_rank() == 0:
            json_info = json.load(open(self.common_cfg['train_result'], 'r'))
            json_info['successFlag'] = 'SUCCESS'
            with open(self.common_cfg['train_result'], 'w') as file:
                json.dump(json_info, file, indent=4)
        print('train done')
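    # A minimal, self-contained sketch of the LR policy wired up in
    # compileModel(): ReduceLROnPlateau lowers the LR by `factor` once the
    # validation loss stops improving for `patience` epochs. The values here
    # are illustrative, not the project's configuration.
    @staticmethod
    def _demo_lr_schedule():
        model = torch.nn.Linear(4, 1)
        opt = optim.Adam(model.parameters(), lr=1e-3)
        sched = optim.lr_scheduler.ReduceLROnPlateau(opt, patience=2, factor=0.5)
        fake_val_losses = [1.0, 0.9, 0.9, 0.9, 0.9]   # plateau after epoch 1
        for vl in fake_val_losses:
            sched.step(vl)
        # Once the plateau exceeds patience, lr is halved to 5e-4.
        return opt.param_groups[0]['lr']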
    def Test(self):
        self.val_losses = [[] for _ in range(len(LOSS_NAMES))]
        val_batch_size = 1
        val_Loader = DataLoader(self.val_generator, batch_size=val_batch_size, shuffle=False)
        self.loadModel()
        self.model.to(self.device)
        val_loss, dice = self.Val(val_Loader, 0, 1)
        with open(self.common_cfg['eval_pjson'], 'w+') as file:
            json.dump({"loss": val_loss, "dice": dice}, file, indent=4)
        with open(self.common_cfg['eval_pmd'], 'w+') as file:
            file.write('# overall performance \n')
            file.write('| loss | dice |\n')
            file.write('| ---- | ---- |\n')
            file.write(f'| {val_loss} | {dice} |\n')
        with open(self.common_cfg['eval_result'], 'w+') as file:
            json.dump({"successFlag": "SUCCESS"}, file, indent=4)

    def Val(self, val_Loader, epoch, num_epoch):
        self.model.eval()
        total_loss = 0
        total_dice = 0
        n_val = len(val_Loader)
        print('n_val:', n_val)
        with tqdm(total=n_val, desc=f'Epoch{epoch + 1}/{num_epoch}', unit='img') as pbar:
            for batch in val_Loader:
                loss, dice = self.TrainSingle(batch, mode='val')
                total_loss += loss.item()
                total_dice += dice.item()
                pbar.update()
        self.model.train()
        return total_loss / float(n_val), total_dice / float(n_val)
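    # A minimal sketch of the eval-mode discipline Val() relies on:
    # model.eval() freezes dropout/batch-norm behaviour, and torch.no_grad()
    # (applied inside TrainSingle for mode='val') skips autograd bookkeeping.
    # The tiny model below is illustrative only.
    @staticmethod
    def _demo_eval_mode():
        model = torch.nn.Sequential(torch.nn.Linear(8, 8), torch.nn.Dropout(0.5))
        model.eval()                       # dropout becomes a no-op
        with torch.no_grad():              # no graph is built for validation
            out = model(torch.randn(2, 8))
        model.train()                      # restore training behaviour
        return out.shape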
    def Infer(self):
        import cv2
        from skimage import measure
        from scipy.spatial import distance

        data = np.load(self.common_cfg['val_data_path'])
        info = np.load(self.common_cfg['val_info_path'])
        print(f'data shape: {data.shape}')
        self.buildModel()
        self.loadModel()
        num_imgs = data.shape[0]
        self.model = self.model.to(self.device)
        self.model.eval()
        target_path = self.cfg_g.save_img_dir
        save_json = os.path.dirname(target_path) + '/diameter.json'
        diameter_dict = {}
        for i in range(num_imgs):
            print(f'img: {i}')
            spacing_x, spacing_y, spacing_z = info[i]
            input_data = torch.from_numpy(data[i][None, None, ...]).float().to(self.device)
            pred = self.model(input_data)
            print(f'input shape: {input_data.shape}, output shape: {pred[0].shape}')
            pred = torch.sigmoid(pred[0][0, 0, ...]).detach().cpu().numpy()
            mask = np.where(pred > 0.9, 1, 0)
            labeled_mask, num_features = measure.label(mask, connectivity=1, return_num=True)
            regions = measure.regionprops(labeled_mask)
            if not regions:
                continue
            max_region = max(regions, key=lambda r: r.area)
            coords = max_region.coords
            if coords.shape[0] > 10000:
                coords = coords[:10000, :]
            # NOTE: assumes isotropic spacing; only spacing_x is applied.
            diameter = max(distance.pdist(coords) * spacing_x)
            print(f'max diameter: {diameter}')
            diameter_dict[i] = diameter
            imgs = norm(input_data[0, 0, ...].cpu().numpy())
            z = imgs.shape[0]
            for j in range(z):
                img = cv2.cvtColor(imgs[j], cv2.COLOR_GRAY2BGR)
                mask_i = mask[j][..., None]
                zero_mask = np.zeros_like(mask_i)
                mask_i = np.concatenate((zero_mask, zero_mask, mask_i * 255), -1).astype(np.uint8)
                img = 0.5 * img + 0.5 * mask_i
                save_path = os.path.join(target_path, str(i) + '_' + str(diameter))
                if not os.path.exists(save_path):
                    os.makedirs(save_path)
                cv2.imwrite(os.path.join(save_path, '{:03d}.png'.format(j)), img)
        print('save diameter json')
        with open(save_json, 'w+') as f:
            json.dump(diameter_dict, f)
        output_gif_path = self.cfg_g.save_gif_dir
        if not os.path.exists(output_gif_path):
            os.makedirs(output_gif_path)
        for file in os.listdir(target_path):
            files = sorted(os.listdir(os.path.join(target_path, file)), key=lambda x: x[:-4])
            frames = []
            for i in tqdm(range(len(files))):
                frames.append(imageio.imread(os.path.join(target_path, file, files[i])))
            fps = 24.0
            imageio.mimsave(os.path.join(output_gif_path, file + '.gif'), frames, 'GIF', duration=1 / fps)
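    # A minimal, self-contained sketch of the diameter estimate used in
    # Infer(): the largest pairwise distance between foreground voxels of the
    # biggest connected component, scaled by voxel spacing. Values are
    # illustrative; as in Infer, a single (isotropic) spacing is assumed.
    @staticmethod
    def _demo_diameter():
        from scipy.spatial import distance
        from skimage import measure
        mask = np.zeros((1, 8, 8), dtype=np.uint8)
        mask[0, 2:6, 3] = 1                      # a 4-voxel line
        labeled, _ = measure.label(mask, connectivity=1, return_num=True)
        region = max(measure.regionprops(labeled), key=lambda r: r.area)
        spacing = 0.7                            # mm per voxel (hypothetical)
        # max pairwise distance = 3 voxels -> 2.1 mm
        return max(distance.pdist(region.coords) * spacing)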
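    # A minimal sketch of the reduction TrainSingle performs below: a
    # weighted per-branch (deep-supervision) loss averaged over branches.
    # The three branch outputs, weights, and stand-in criterion are
    # illustrative only.
    @staticmethod
    def _demo_branch_loss_reduction():
        preds = [torch.rand(1, 1, 4, 4, 4) for _ in range(3)]   # 3 branches
        label = (torch.rand(1, 1, 4, 4, 4) > 0.5).float()
        weights = [1.0, 1.0, 1.0]
        criterion = torch.nn.functional.binary_cross_entropy   # stand-in loss
        total = sum(w * criterion(p, label) for w, p in zip(weights, preds))
        return total / len(preds)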
    def TrainSingle(self, batch, mode='training'):
        loss_weights = ast.literal_eval(self.cfg.get('loss_weights', '[1.0,1.0,1.0]'))
        cls_choice = self.common_cfg.get('do_cls', False)
        # calculate_surface_loss = self.common_cfg.get('surface_loss', False)
        calculate_surface_loss = False
        surface_loss_ratio = self.common_cfg.get('e_value', 0.1)
        data, label = batch
        cls_label = 1
        surface_criterion = SurfaceLoss()
        print(f'input data shape: {data.shape}, mask shape: {label.shape}')
        if calculate_surface_loss:
            if label.shape[0] > 1:
                distance_map_label = calc_dist_map_batch(np.squeeze(label.cpu().numpy()))
            else:
                distance_map_label = calc_dist_map(np.squeeze(label.cpu().numpy()))
            if len(distance_map_label.shape) == 4:
                distance_map_label = distance_map_label[:, np.newaxis, ...]
            distance_map_label = torch.from_numpy(distance_map_label).to(device=self.device, dtype=torch.float32)
        data = data.to(device=self.device, dtype=torch.float32)
        label = label.to(device=self.device, dtype=torch.float32)
        if mode == 'val':
            with torch.no_grad():
                pred = self.model(data)
        else:
            pred = self.model(data)
        if cls_choice:
            cls_pred = pred[-1]
            pred = pred[:-1]
        total_dice = 0
        total_loss = 0
        for idx in range(len(pred)):
            loss_val = self.criterion(pred[idx], label)
            if idx == 0:
                total_loss = loss_val * loss_weights[idx]
            else:
                total_loss += loss_val * loss_weights[idx]
            if mode == 'val':
                self.val_losses[idx].append(loss_val.item())
            else:
                self.train_losses[idx].append(loss_val.item())
            if calculate_surface_loss:
                surface_loss = surface_criterion(pred[idx], distance_map_label)
                total_loss += surface_loss_ratio * surface_loss
                if mode == 'val':
                    self.val_losses[idx + len(pred)].append(surface_loss.item() * surface_loss_ratio)
                else:
                    self.train_losses[idx + len(pred)].append(surface_loss.item() * surface_loss_ratio)
            total_dice += Dice(label, pred[idx])
        total_loss /= len(pred)
        total_dice /= len(pred)
        if cls_choice:
            # FIXME: label_diff is not defined in this scope; the do_cls path
            # cannot run as written.
            loss_input = [cls_pred] + [cls_label] + [label_diff]
            cls_loss_val = self.cls_criterion(loss_input)
            cls_loss_val = self.common_cfg.get('cls_loss_ratio', 0.1) * cls_loss_val
            total_loss += cls_loss_val
            if mode == 'val':
                self.val_losses[-1].append(cls_loss_val.item())
            else:
                self.train_losses[-1].append(cls_loss_val.item())
        return total_loss, total_dice

    def compileModel(self):
        lr_schedule = 'basic'
        self.optimizer = optim.Adam(self.model.parameters(), lr=self.cfg_g.lr)
        if lr_schedule == 'basic':
            self.scheduler = optim.lr_scheduler.ReduceLROnPlateau(
                self.optimizer,
                patience=self.common_cfg['learning_rate_patience'],
                factor=self.common_cfg['learning_rate_drop'],
                verbose=True)
        else:
            self.scheduler = optim.lr_scheduler.CyclicLR(
                self.optimizer,
                base_lr=self.common_cfg.get('base_lr', 1e-5),
                max_lr=self.common_cfg.get('max_lr', 1e-4),
                step_size_up=self.common_cfg.get('step_size_up', 2000),
                cycle_momentum=False)
        if self.common_cfg.get('do_cls', False):
            self.criterion = SeLossWithBCEFocalLoss()
        else:
            self.criterion = Tversky_loss(alpha=self.common_cfg['T_loss_alpha'],
                                          beta=1 - self.common_cfg['T_loss_alpha'])
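    # An illustrative reference implementation of the Tversky loss configured
    # above (the project's actual Tversky_loss lives in utils.SegmentationLoss
    # and may differ in detail). With alpha = beta = 0.5 it reduces to the
    # soft Dice loss.
    @staticmethod
    def _demo_tversky(pred, target, alpha=0.5, beta=0.5, eps=1e-5):
        tp = (pred * target).sum()          # true positives
        fp = (pred * (1 - target)).sum()    # false positives, weighted by alpha
        fn = ((1 - pred) * target).sum()    # false negatives, weighted by beta
        return 1 - (tp + eps) / (tp + alpha * fp + beta * fn + eps)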
    def buildModel(self):
        filename = self.common_cfg['model_parameters']
        model_cfg = LoadJson('TorchNoduleSeg/files/' + filename)
        encoder_model_name = self.common_cfg['encoder']
        decoder_model_name = self.common_cfg['decoder']
        encoder_model_cfg = model_cfg[encoder_model_name]
        decoder_model_cfg = model_cfg[decoder_model_name]
        encoder_model = build_model(encoder_model_name, encoder_model_cfg)
        decoder_model_cfg['encoder_model'] = encoder_model
        decoder_model_cfg['do_cls'] = self.common_cfg.get('do_cls', False)
        self.model = build_model(decoder_model_name, decoder_model_cfg)
        self.model.apply(weight_init)

    def AnayInferResult(self, dice_list, info_list, model_path=None, file_order=0,
                        diameter_ths=[[0, 4], [4, 10], [10, 30], [30, 100000], [4, 1000000]]):
        if model_path is None:
            model_path = self.pre_model_file
        inference_result_dict = {
            'model_weights': model_path,
            'test_data': self.common_cfg["test_file"],
            'result': {},
        }
        inference_result_dict['result']['overall_infer_dice'] = [np.mean(dice_list), np.std(dice_list)]
        print('length of info list and dice list is', len(info_list), len(dice_list))
        print('length of info is', len(info_list[0]))
        print('info_list[0]', info_list[0], info_list[0][self.common_cfg['diameter_idx']],
              type(info_list[0][self.common_cfg['diameter_idx']]))
        self.logger.info('overall_infer_dice is %s' % (str([np.mean(dice_list), np.std(dice_list)])))
        for diameter_th in diameter_ths:
            diameter_min, diameter_max = diameter_th
            indices = [idx for idx in range(len(dice_list))
                       if diameter_min <= float(info_list[idx][self.common_cfg['diameter_idx']]) < diameter_max]
            # Per-range aggregation (the remainder of the original loop body
            # is missing; mean/std mirrors the overall_infer_dice entry above).
            sub_dices = [dice_list[idx] for idx in indices]
            key = 'infer_dice_%d_%d' % (diameter_min, diameter_max)
            if sub_dices:
                inference_result_dict['result'][key] = [np.mean(sub_dices), np.std(sub_dices)]
                self.logger.info('%s is %s' % (key, str(inference_result_dict['result'][key])))

    def MakeSinglePred(self, sample):
        # NOTE: the head of this method (patch extraction, forward pass and
        # threshold selection) is missing from the source; only the tail
        # survives.
        ...
        th_pred = [val > th for val in pred]
        th_pred = [np.array(val).astype(np.uint8) for val in th_pred]
        if self.common_cfg.get('zoom_shape', None) is None:
            return data, pred, current_mask, margins
        else:
            zoom_back_preds = [self.test_generator._zoomPredBack(info, single_mask, new_spacing, patch_shape)
                               for single_mask in th_pred]
            return non_zoom_image, zoom_back_preds, non_zoom_mask, margins

    def MakeInference(self):
        # Load pretrain weights.
        pretrain = self.cfg['pre_model_file']
        if '*' in pretrain:
            cand_paths = sorted(glob.glob(pretrain + '*pth'))
        else:
            cand_paths = [pretrain]
        model_dice_lists = []
        model_order = 0
        for model_path in cand_paths:
            model_order += 1
            self.loadModel(model_path)
            infer_idx = self.cfg.get('infer_idx', -1)
            save_result_choice = self.cfg.get('save_result_choice', False)
            self.model.eval()
            indices = np.arange(len(self.test_generator))
            dice_vals = []
            image_list, mask_list, pred_list = [], [], []
            info_list = []
            margin_list = []
            for idx in indices:
                current_data, pred, current_mask, margins = self.MakeSinglePred(self.test_generator[idx])
                image_list.append(current_data)
                mask_list.append(current_mask)
                pred_list.append(pred)
                info_list.append(self.test_generator.info[idx])
                margin_list.append(margins)
                dice_val = dice_coefficient_post(current_mask, pred[infer_idx])
                dice_vals.append(dice_val)
            # Analyse per-model results.
            self.AnayInferResult(dice_vals, info_list, model_path, model_order)
            model_dice_lists.append(np.mean(dice_vals))
            self.image_list = image_list
            self.mask_list = mask_list
            self.pred_list = pred_list
            self.dice_vals = dice_vals
            self.info_list = info_list
            if save_result_choice:
                self.savePredResult(pred_list, margin_list, infer_idx)
        if len(model_dice_lists) > 0:
            max_idx = np.argmax(model_dice_lists)
            max_dice = max(model_dice_lists)
            max_dice_model_path = cand_paths[max_idx]
            self.logger.info('model_path %s has max avg dice val %.4f' % (max_dice_model_path, max_dice))
        if self.cfg.get('badcase_vis', False):
            result_list = [image_list, mask_list, pred_list, dice_vals, info_list]
            diameter_ths = ast.literal_eval(self.cfg['diameter_ths'])
            dice_range = ast.literal_eval(self.cfg['dice_range'])
            self.BadCaseVis(result_list, diameter_ths, dice_range)
    def savePredResult(self, result_list, margin_list, infer_idx):
        final_result_list = []
        for idx in range(len(result_list)):
            pad_result = self.padImageBack(result_list[idx][infer_idx], margin_list[idx])
            final_result_list.append(pad_result)
        save_path = os.path.join(os.path.dirname(self.pre_model_file), 'prediction.npy')
        np.save(save_path, final_result_list)
        self.logger.info('Saved pred result to %s' % save_path)

    def padImageBack(self, pred, margin):
        current_shape = pred.shape
        target_shape = [val + 2 * val_m for val, val_m in zip(current_shape, margin)]
        target_mask = np.zeros(target_shape)
        target_mask[margin[0]:margin[0] + current_shape[0],
                    margin[1]:margin[1] + current_shape[1],
                    margin[2]:margin[2] + current_shape[2]] = pred
        return target_mask

    def BadCaseVis(self, result_list=None, diameter_ths=[0, 1000000], dice_range=[0, 1]):
        if result_list is None:
            result_list = [self.image_list, self.mask_list, self.pred_list, self.dice_vals, self.info_list]
        diameter_min, diameter_max = diameter_ths
        data, mask, pred, dices, infos = result_list
        indices = [idx for idx in range(len(infos))
                   if diameter_min <= float(infos[idx][self.common_cfg['diameter_idx']]) < diameter_max
                   and dice_range[0] <= dices[idx] < dice_range[1]]
        for idx in indices:
            print('=' * 60)
            current_data, current_mask, current_pred = (np.squeeze(data[idx]),
                                                        np.squeeze(mask[idx]),
                                                        np.squeeze(pred[idx]))
            center_slice_idx = int(current_data.shape[0] / 2)
            _, axs = plt.subplots(1, 3)
            axs[0].imshow(current_data[center_slice_idx], cmap='bone')
            axs[1].imshow(current_mask[center_slice_idx], cmap='bone')
            axs[2].imshow(current_pred[center_slice_idx], cmap='bone')
            plt.show()

    def _DataStat(self, diameter_ths=[[0, 1000000000], [0, 4], [4, 10], [10, 20], [20, 30],
                                      [30, 1000000], [40, 100000], [4, 100000000000000000]]):
        if self.mode == 'training':
            self.logger.info('Nodule distribution on diameter in train dataset is')
            self.train_generator.DataStat(diameter_ths)
            self.logger.info('Nodule distribution on diameter in val dataset is')
            self.val_generator.DataStat(diameter_ths)
        else:
            self.test_generator.DataStat(diameter_ths)
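    # A minimal round-trip sketch of padImageBack: cropping a volume by
    # per-axis margins and padding the prediction back restores the original
    # shape. The margins and shapes are illustrative.
    @staticmethod
    def _demo_pad_image_back():
        margin = (2, 3, 4)
        cropped = np.ones((6, 10, 8))                 # e.g. a cropped prediction
        target_shape = tuple(s + 2 * m for s, m in zip(cropped.shape, margin))
        restored = np.zeros(target_shape)
        restored[margin[0]:margin[0] + 6,
                 margin[1]:margin[1] + 10,
                 margin[2]:margin[2] + 8] = cropped
        assert restored.shape == (10, 16, 16)
        return restored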