import os
import sys
import six
import json
import shutil
import datetime
import numpy as np
import scipy.ndimage as nd

sys.path.append('./models/backbones/')
sys.path.append('./models/baseLayers/')
sys.path.append('./models/decoders/')
sys.path.append('./models/')
sys.path.append('./utils/')

from data_generator import UNet_3D_Generator, DataStat
from ImagePreprocess import Image_preprocess, DataNormlization, CalculatePatchSizeWithDiameter, Data_preprocess
from parameter_utils import RectifyParameters
from ResultEval import dice_coefficient_post
from data_aug import *
from RWconfig import LoadJson, GenerateParameters, WriteJson
from ReadData import load_data, load_single_data
from ResultPostProcess import post_process
from modelBuild import build_model
from model_utils import get_block
from Segmentation_loss import Tversky_loss, loss_dice_coefficient_error, NoduleSegAuxLoss, focal_loss
from train_utils import get_callbacks
from matplotlib import pyplot as plt
from keras import optimizers


class Segmentation3D(object):

    def __init__(self, cfg):
        self.gamma = 1.0
        self.cfg = cfg
        parameters = cfg['training']
        self.parameters = parameters

        '''
        General parameters
        '''
        self.num_slice = parameters['num_slice'] if parameters['num_slice'] > 0 else None
        self.voxel_size = parameters['voxel_size'] if parameters['voxel_size'] > 0 else None
        self.n_channels = parameters['n_channels']
        self.input_shape = [self.num_slice, self.voxel_size, self.voxel_size, self.n_channels]
        self.num_classes = parameters['num_classes']
        self.final_ksize = parameters['final_ksize']
        self.num_units = parameters['num_units']
        self.model_parameters_txt = os.path.join('./files', parameters['model_parameters_txt'])
        self.model_map_parameters_txt = os.path.join('./files', parameters['model_map_parameters_txt'])
        self.encoder = parameters['encoder']
        self.decoder = parameters['decoder']
        self.normalizer = parameters['normalizer']
        self.base_path = parameters['base_path']
        self.pre_model_file = parameters['pre_model_file']
        self.model = None
        self.nonfix_crop = parameters['nonfix_crop']
        self.stride_ratio = parameters['stride_ratio']
        self.densityTypeIndex = parameters['densityTypeIndex']
        print('self.densityTypeIndex is', self.densityTypeIndex)
        self.T_loss_alpha = parameters.get('T_loss_alpha', 0.65)
        print('self.nonfix_crop is', self.nonfix_crop)

        if cfg.mode == 'training':
            '''
            Parameters for training only
            '''
            self.load_pre_trained = parameters['load_pre_trained']
            self.initial_learning_rate = parameters['initial_learning_rate']
            self.learning_rate_drop = parameters['learning_rate_drop']
            self.learning_rate_patience = parameters['learning_rate_patience']
            self.early_stop = parameters['early_stop']
            self.n_epochs = parameters['n_epochs']
            self.aug = parameters['aug']
            self.prefix = parameters['prefix']
            self.train_batch_size = parameters['train_batch_size']
            self.val_batch_size = parameters['val_batch_size']
            self.shuffle = parameters['shuffle']
            self.trainfiles = {
                'data_paths': self.cfg['train_data_path'],
                'mask_paths': self.cfg['train_mask_path'],
                'info_paths': self.cfg['train_info_path']
            }
            self.valfiles = {
                'data_paths': self.cfg['val_data_path'],
                'mask_paths': self.cfg['val_mask_path'],
                'info_paths': self.cfg['val_info_path']
            }
        else:
            self.testfiles = {
                'data_paths': self.cfg['test_data_path'],
                'mask_paths': self.cfg['test_mask_path'],
                'info_paths': self.cfg['test_info_path']
            }
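    # Expected cfg layout, inferred from the accesses in __init__ above (the
    # exact schema lives with the caller, so treat this as an assumption):
    #   cfg['training']        -> dict of model/optimizer hyper-parameters
    #   cfg['train_data_path'] / cfg['train_mask_path'] / cfg['train_info_path']
    #   cfg['val_data_path']   / cfg['val_mask_path']   / cfg['val_info_path']
    #   cfg['test_data_path']  / cfg['test_mask_path']  / cfg['test_info_path']
    #   cfg.mode               -> 'training' for training, anything else for inference
    # cfg therefore needs both dict-style access and a `.mode` attribute
    # (e.g. an easydict-style object).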
    def BuildSegModel(self):
        '''
        Generate encoder/decoder parameters with model_parameters_txt and the model
        name, rectify them against the runtime config, then build encoder and decoder.
        '''
        self.encoder_parameter_list = GenerateParameters(self.model_parameters_txt, self.model_map_parameters_txt, model=self.encoder)
        self.decoder_parameter_list = GenerateParameters(self.model_parameters_txt, self.model_map_parameters_txt, model=self.decoder)

        # Rectify parameters
        self.encoder_parameter_list = RectifyParameters(self.encoder_parameter_list, self.parameters)
        self.decoder_parameter_list = RectifyParameters(self.decoder_parameter_list, self.parameters)

        # Build encoder and decoder
        self.encoder_result = build_model(self.encoder, input_shape=self.input_shape, parameter_list=[self.encoder_parameter_list])
        self.model = build_model(self.decoder, input_shape=self.input_shape, parameter_list=[self.encoder_result, self.decoder_parameter_list])
        # self.model.summary()

    def CompileSegModel(self):
        adam = optimizers.Adam(lr=self.initial_learning_rate, beta_1=0.9, beta_2=0.999, epsilon=1e-08)
        # TODO: expose the loss choice as a parameter in the config file
        seg_loss = Tversky_loss(alpha=self.T_loss_alpha, beta=1 - self.T_loss_alpha)
        # One segmentation loss per model output: three heads with deep
        # supervision, two with ACF, otherwise a single head.
        if self.decoder_parameter_list['deep_supervision']:
            final_loss = [seg_loss for _ in range(3)]
        elif self.decoder_parameter_list['ACF_choice']:
            final_loss = [seg_loss for _ in range(2)]
        else:
            final_loss = [seg_loss]
        if self.decoder_parameter_list['aux_task']:
            # The auxiliary density-type classification head uses a focal loss
            # weighted by class frequencies from the training set.
            self._calculate_ratio()
            cls_loss = focal_loss(self.focal_class_weights, self.gamma)
            final_loss.append(cls_loss)
            class_weights = [1 for _ in range(len(final_loss) - 1)] + [1.5]
        else:
            class_weights = [1 for _ in range(len(final_loss))]
        if len(final_loss) == 1:
            final_loss = final_loss[0]
        self.model.compile(optimizer=adam, loss=final_loss, loss_weights=class_weights,
                           metrics=['accuracy'])
    def GenerateTrainParameters(self):
        # Steps per epoch are derived from the sample counts recorded in the info files.
        self.train_step_per_epoch = int(np.sum([np.load(path).shape[0] for path in self.trainfiles['info_paths']]) / self.train_batch_size)
        self.valid_step_per_epoch = int(np.sum([np.load(path).shape[0] for path in self.valfiles['info_paths']]) / self.val_batch_size)
        self.train_generator = UNet_3D_Generator(self.trainfiles['data_paths'], self.trainfiles['mask_paths'],
                                                 self.trainfiles['info_paths'], input_size=self.input_shape,
                                                 batch_size=self.train_batch_size, aug=self.aug,
                                                 HU_min=self.normalizer[0], HU_max=self.normalizer[1], crop_choice=True,
                                                 config_dict=self.decoder_parameter_list, nonfix_crop=self.nonfix_crop,
                                                 stride_ratio=self.stride_ratio, densityTypeIndex=self.densityTypeIndex)
        self.val_generator = UNet_3D_Generator(self.valfiles['data_paths'], self.valfiles['mask_paths'],
                                               self.valfiles['info_paths'], input_size=self.input_shape,
                                               batch_size=self.val_batch_size, aug=False,
                                               HU_min=self.normalizer[0], HU_max=self.normalizer[1], crop_choice=True,
                                               config_dict=self.decoder_parameter_list, nonfix_crop=self.nonfix_crop,
                                               stride_ratio=self.stride_ratio, densityTypeIndex=self.densityTypeIndex)
        # Restored from the previously commented-out block: Train() and the
        # checkpoint callback below rely on self.current_base_path and
        # self.filepath being defined. (The old block also copied config files
        # into the run directory via attributes that no longer exist.)
        current_DT = datetime.datetime.now()
        current_day = '%02d%02d' % (current_DT.month, current_DT.day)
        self.current_base_path = '%s/%s_%s/' % (self.base_path, self.prefix, current_day)
        if not os.path.exists(self.current_base_path):
            os.mkdir(self.current_base_path)
        self.filepath = str('%s/Train-{epoch:02d}-{val_loss:.5f}.hdf5' % self.current_base_path)
        self.callbacks_list = get_callbacks(self.filepath, initial_learning_rate=self.initial_learning_rate,
                                            learning_rate_drop=self.learning_rate_drop, learning_rate_epochs=None,
                                            learning_rate_patience=self.learning_rate_patience,
                                            early_stopping_patience=self.early_stop)

    def _calculate_ratio(self):
        train_info = load_single_data(self.trainfiles['info_paths'], self.decoder_parameter_list['aux_task'])
        densityTypeConvertMap = self.decoder_parameter_list['densityTypeConvertMap']
        keys = sorted(densityTypeConvertMap.keys())
        print('keys', keys)
        counts = [0 for _ in range(4)]
        for info in train_info:
            val = str(int(float(info[self.densityTypeIndex])))
            if val not in keys:
                continue
            counts[densityTypeConvertMap[val] - 1] += 1
        # Weight each class by its frequency relative to the first class,
        # softened with a square root.
        focal_class_weights = [val / float(counts[0]) for val in counts]
        focal_class_weights = [val ** 0.5 for val in focal_class_weights]
        self.focal_class_weights = [round(val, 2) for val in focal_class_weights]
        print('self.focal_class_weights is', self.focal_class_weights)

    def Train(self):
        self.BuildSegModel()
        print('After building model')
        self.CompileSegModel()
        print('After compiling model')
        self.LoadPretrainModel()
        print('After loading pretrain model')
        self.GenerateTrainParameters()
        self.history_callbacks = self.model.fit_generator(self.train_generator, steps_per_epoch=self.train_step_per_epoch,
                                                          epochs=self.n_epochs, callbacks=self.callbacks_list, shuffle=self.shuffle,
                                                          validation_data=self.val_generator, validation_steps=self.valid_step_per_epoch)
        WriteJson(os.path.join(self.current_base_path, 'history.json'), self.history_callbacks.history)
        print('After saving config to file %s' % os.path.join(self.current_base_path, 'history.json'))
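    # Output layout assumed by MakeInference below, mirroring the loss assembly
    # in CompileSegModel: with deep_supervision the model yields three
    # segmentation maps per sample and index 2 is the final one; with
    # ACF_choice it yields two and index 1 is final; otherwise a single map.
    # An aux_task classification head, if present, is appended after the
    # segmentation outputs and is not thresholded here.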
    def MakeInference(self):
        final_images, final_masks, result = [], [], []
        for single_data, single_mask, single_info in zip(self.test_data, self.test_mask, self.test_info):
            uid = single_info[0]
            if self.nonfix_crop:
                patch_shape = CalculatePatchSizeWithDiameter(single_mask, stride_ratio=self.stride_ratio, aug=False)
            else:
                patch_shape = self.input_shape
            current_image, current_mask = Data_preprocess(single_data, single_mask, aug=False,
                                                          input_shape=patch_shape, shift_ratio=1.0)
            # Min-max normalize the HU window defined by self.normalizer.
            current_image = np.clip(current_image, self.normalizer[0], self.normalizer[1])
            current_image = (current_image - self.normalizer[0]) / float(self.normalizer[1] - self.normalizer[0])
            current_result = self.model.predict(current_image)
            result.append(current_result)
            final_images.append(current_image)
            final_masks.append(current_mask)
        mask_th = 0.05
        if not self.decoder_parameter_list['deep_supervision']:
            if not self.decoder_parameter_list['ACF_choice']:
                # Single output: threshold each per-sample prediction.
                final_result = [single_result > mask_th for single_result in result]
            else:
                # Two outputs per sample with ACF; the second one is the final map.
                target_idx = 1
                final_result = [single_result[target_idx] > mask_th for single_result in result]
        else:
            # Three outputs per sample with deep supervision; the last one is final.
            target_idx = 2
            final_result = [single_result[target_idx] > mask_th for single_result in result]
        # NOTE: post_process reuses the patch_shape of the last sample; with
        # nonfix_crop enabled the shapes can differ per sample.
        final_prediction = [post_process(mask_image[0], spacing_z=1.0, inp_shape=patch_shape, debug=False)[2]
                            for mask_image in final_result]
        return final_result, final_prediction, final_masks, final_images

    def Inference(self):
        self.BuildSegModel()
        self.LoadPretrainModel()
        print('Before loading test data')
        self.LoadInferenceData()
        print('After loading test data')
        # TODO: filter unqualified results
        final_result, final_prediction, test_masks_post, test_images_post = self.MakeInference()
        print('After making prediction')
        inference_result_dict = {
            'model_weights': self.pre_model_file,
            'test_data': self.testfiles,
            'result': {}
        }
        self.final_result = final_result
        self.final_prediction = final_prediction
        self.test_masks_post = test_masks_post
        self.test_images_post = test_images_post
        result_dice = [dice_coefficient_post(np.squeeze(y_true), np.squeeze(y_pred))
                       for y_true, y_pred in zip(test_masks_post, final_prediction)]
        result_dice_infer = [dice_coefficient_post(np.squeeze(y_true), np.squeeze(y_pred))
                             for y_true, y_pred in zip(test_masks_post, final_result)]
        self.result_dice = result_dice
        self.result_dice_infer = result_dice_infer
        inference_result_dict['result']['overall_post_dice'] = [np.mean(result_dice), np.std(result_dice)]
        inference_result_dict['result']['overall_infer_dice'] = [np.mean(result_dice_infer), np.std(result_dice_infer)]
        print('Overall dice value on infer result is', inference_result_dict['result']['overall_infer_dice'][0])
        print('Overall dice value on post result is', inference_result_dict['result']['overall_post_dice'][0])
        # Break the dice statistics down by nodule diameter (half-open bins [min, max)).
        diameter_ths = [[0, 4], [4, 10], [10, 30], [30, 100000], [4, 1000000]]
        diameter_idx = 4
        for diameter_th in diameter_ths:
            diameter_min, diameter_max = diameter_th
            indices = [sample_idx for sample_idx in range(self.test_info.shape[0])
                       if diameter_min <= self.test_info[sample_idx][diameter_idx].astype('float') < diameter_max]
            current_dice_post = [result_dice[sample_idx] for sample_idx in indices]
            current_dice_infer = [result_dice_infer[sample_idx] for sample_idx in indices]
            inference_result_dict['result']['Diameter_[%d,%d]_post_dice' % (diameter_min, diameter_max)] = [np.mean(current_dice_post), np.std(current_dice_post)]
            inference_result_dict['result']['Diameter_[%d,%d]_infer_dice' % (diameter_min, diameter_max)] = [np.mean(current_dice_infer), np.std(current_dice_infer)]
            print('dice value of nodules between [%d,%d] on post result is' % (diameter_min, diameter_max),
                  inference_result_dict['result']['Diameter_[%d,%d]_post_dice' % (diameter_min, diameter_max)][0])
            print('dice value of nodules between [%d,%d] on infer result is' % (diameter_min, diameter_max),
                  inference_result_dict['result']['Diameter_[%d,%d]_infer_dice' % (diameter_min, diameter_max)][0])
        print('Before writing json file to %s' % os.path.join(os.path.dirname(self.pre_model_file), 'inference_result.json'))
        WriteJson(os.path.join(os.path.dirname(self.pre_model_file), 'inference_result.json'), inference_result_dict)
        print('After writing json file to %s' % os.path.join(os.path.dirname(self.pre_model_file), 'inference_result.json'))
    def LoadInferenceData(self):
        self.test_data_list = load_data([self.testfiles['data_paths'], self.testfiles['mask_paths'],
                                         self.testfiles['info_paths']])
        self.test_data, self.test_mask, self.test_info = self.test_data_list

    def LoadPretrainModel(self):
        print('load model', self.pre_model_file)
        if os.path.exists(self.pre_model_file) and self.model:
            self.model.load_weights(self.pre_model_file)
        elif not os.path.exists(self.pre_model_file):
            print('Pretrain model path does not exist')
        else:
            print('model has not been defined')

    def Visualize(self, dice_list, iou_th, diameter_th=[0, float('inf')], densityType=np.arange(1, 10000)):
        for idx in range(len(dice_list)):
            dice_value = dice_list[idx]
            current_data = np.squeeze(self.test_images_post[idx])
            current_mask = np.squeeze(self.test_masks_post[idx])
            current_info = np.squeeze(self.test_info[idx])
            current_result = np.squeeze(self.final_result[idx])
            diameter_idx = 4
            densityType_idx = -4
            print('current_info[diameter_idx]', current_info[diameter_idx], current_info[densityType_idx], dice_value)
            # The original body is truncated after the diameter check. A minimal
            # completion follows, assuming the intent was to plot samples whose
            # diameter and density type match the filters and whose dice falls
            # below iou_th; the plotting layout is an assumption.
            if (float(current_info[diameter_idx]) >= diameter_th[0]
                    and float(current_info[diameter_idx]) < diameter_th[1]
                    and int(float(current_info[densityType_idx])) in densityType
                    and dice_value < iou_th):
                center_slice = current_data.shape[0] // 2
                fig, axes = plt.subplots(1, 3)
                axes[0].imshow(current_data[center_slice], cmap='gray')
                axes[0].set_title('image')
                axes[1].imshow(current_mask[center_slice], cmap='gray')
                axes[1].set_title('ground truth')
                axes[2].imshow(current_result[center_slice], cmap='gray')
                axes[2].set_title('prediction')
                plt.show()
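
# A minimal usage sketch, not part of the original module. It assumes the
# caller builds a cfg object exposing both dict-style access and a `.mode`
# attribute (e.g. an easydict loaded from a JSON config); the config path and
# the easydict dependency below are hypothetical, while the key names mirror
# the accesses in __init__.
if __name__ == '__main__':
    from easydict import EasyDict  # assumed helper; any dict-with-attributes works

    cfg = EasyDict(LoadJson('./files/seg_train_config.json'))  # hypothetical path
    cfg.mode = 'training'  # or anything else to run inference instead
    segmentor = Segmentation3D(cfg)
    if cfg.mode == 'training':
        segmentor.Train()
    else:
        segmentor.Inference()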