# 整体训练示例

In [10]:
import os
import sys
import six
import json
import shutil
import datetime
import numpy as np
import scipy.ndimage as nd

sys.path.append('./models/backbones/')
sys.path.append('./models/baseLayers/')
sys.path.append('./models/decoders/')
sys.path.append('./models/')
sys.path.append('./utils/')

from data_generator import UNet_3D_Generator
from ImagePreprocess import Image_preprocess,DataNormlization
from parameter_utils import RectifyParameters
from ResultEval import dice_coefficient_post
from data_aug import *
from RWconfig import LoadJson,GenerateParameters,GenerateParameters,WriteJson
from ReadData import load_data
from ResultPostProcess import post_process

from modelBuild import build_model
from model_utils import get_block
from Segmentation_loss import Tversky_loss,loss_dice_coefficient_error,NoduleSegAuxLoss
from train_utils import get_callbacks



from keras import optimizers



# 定义对应类实现相应功能,本类包括
1. init
2. LoadGroupFiles:因为数据存储成多个npy文件,多组data_path/mask_path/info_path写在train/val/test files json文件中。本函数用于将所有data_path,mask_path,info_path组合成数组
3. BuildSegModel:搭建分割模型(train/inference都需要) 调用了GenerateParameters、RectifyParameters产生模型参数
4. GenerateTrainParameters:产生训练相关的参数。(step/generator/callbacks等)
5. CompileSegModel:compileModel
6. Train:调用上面的函数,开始训练
7. Inference:搭建模型、inference数据预处理、预测、后处理、产生指标(目前写定了用dice)
8. LoadInferenceData:inference阶段载入数据
9. LoadPretrainModel:根据config文件中定义的pretrain模型路径,载入模型。如果路径不存在,则提示

In [13]:
class Segmentation3D(object):
    """Config-driven wrapper for building, training and evaluating a 3D segmentation model.

    The constructor reads a JSON config through ``LoadJson``. The config is
    expected to be a mapping with one top-level key; when that key contains
    ``'train'`` (case-insensitive) the training parameters and the train/val
    file groups are loaded, otherwise the test file group is loaded.

    Typical usage is ``Segmentation3D(path).Train()`` for a training config or
    ``Segmentation3D(path).Inference()`` for a test config.
    """

    def __init__(self, filename):
        """Load the config at ``filename`` and copy its values onto the instance.

        Args:
            filename: path of the JSON config file, read with ``LoadJson``.
        """
        self.filename = filename
        current_config = LoadJson(filename)
        # dict.keys() is not indexable on Python 3, so materialize it first.
        key = list(current_config.keys())[0]
        parameters = current_config[key]
        self.parameters = parameters

        # General parameters (needed for both training and inference).
        self.num_slice = parameters['num_slice']
        self.voxel_size = parameters['voxel_size']
        self.n_channels = parameters['n_channels']
        self.input_shape = [self.num_slice, self.voxel_size, self.voxel_size, self.n_channels]
        self.num_classes = parameters['num_classes']
        self.final_ksize = parameters['final_ksize']
        self.num_units = parameters['num_units']
        self.model_parameters_txt = os.path.join('./files', parameters['model_parameters_txt'])
        self.model_map_parameters_txt = os.path.join('./files', parameters['model_map_parameters_txt'])

        self.encoder = parameters['encoder']
        self.decoder = parameters['decoder']
        # Indexed as [0] / [1] for HU_min / HU_max further down.
        self.normalizer = parameters['normalizer']
        self.base_path = parameters['base_path']
        self.pre_model_file = parameters['pre_model_file']
        self.model = None

        if 'train' in key.lower():
            # Parameters used for training only.
            self.train_file = os.path.join('./files', parameters['train_file'])
            # NOTE(review): the validation list is read from 'val_txt' while the
            # training list uses 'train_file' -- confirm against the config schema.
            self.val_file = os.path.join('./files', parameters['val_txt'])
            # NOTE(review): this flag is stored but Train() loads the pretrained
            # weights unconditionally -- confirm whether it should gate that call.
            self.load_pre_trained = parameters['load_pre_trained']

            self.initial_learning_rate = parameters['initial_learning_rate']
            self.learning_rate_drop = parameters['learning_rate_drop']
            self.learning_rate_patience = parameters['learning_rate_patience']
            self.early_stop = parameters['early_stop']
            self.n_epochs = parameters['n_epochs']
            self.aug = parameters['aug']
            self.prefix = parameters['prefix']
            self.train_batch_size = parameters['train_batch_size']
            self.val_batch_size = parameters['val_batch_size']
            self.shuffle = parameters['shuffle']

            self.trainfiles = self.LoadGroupFiles(self.train_file)
            self.valfiles = self.LoadGroupFiles(self.val_file)
        else:
            self.test_file = os.path.join('./files', parameters['test_file'])
            self.testfiles = self.LoadGroupFiles(self.test_file)

    def LoadGroupFiles(self, filename):
        """Collect the data/mask/info paths listed in a file-group JSON.

        Args:
            filename: path of a JSON file that ``LoadJson`` turns into an
                iterable of records, each with ``'data_path'``,
                ``'mask_path'`` and ``'info_path'`` entries.

        Returns:
            dict with keys ``'data_paths'``, ``'mask_paths'`` and
            ``'info_paths'``, each a list in record order.
        """
        result = LoadJson(filename)
        data_paths = [record['data_path'] for record in result]
        mask_paths = [record['mask_path'] for record in result]
        info_paths = [record['info_path'] for record in result]
        return {
            'data_paths': data_paths,
            'mask_paths': mask_paths,
            'info_paths': info_paths,
        }

    def BuildSegModel(self):
        """Build the encoder and decoder and store the result in ``self.model``."""
        # Generate encoder/decoder parameters from the parameter files and model names.
        self.encoder_parameter_list = GenerateParameters(
            self.model_parameters_txt, self.model_map_parameters_txt, model=self.encoder)
        self.decoder_parameter_list = GenerateParameters(
            self.model_parameters_txt, self.model_map_parameters_txt, model=self.decoder)

        # Rectify the generated parameters with the values from the config file.
        self.encoder_parameter_list = RectifyParameters(self.encoder_parameter_list, self.parameters)
        self.decoder_parameter_list = RectifyParameters(self.decoder_parameter_list, self.parameters)

        # The decoder is built from the encoder result plus its own parameters.
        self.encoder_result = build_model(
            self.encoder, input_shape=self.input_shape,
            parameter_list=[self.encoder_parameter_list])
        self.model = build_model(
            self.decoder, input_shape=self.input_shape,
            parameter_list=[self.encoder_result, self.decoder_parameter_list])

    def CompileSegModel(self):
        """Compile ``self.model`` with Adam and one loss per configured output.

        Three segmentation losses are used when ``deep_supervision`` is set,
        two when ``ACF_choice`` is set, otherwise one. ``NoduleSegAuxLoss`` is
        appended when ``aux_task`` is set. A single loss is passed bare rather
        than as a one-element list.
        """
        adam = optimizers.Adam(lr=self.initial_learning_rate, beta_1=0.9, beta_2=0.999, epsilon=1e-08)

        # TODO: make the losses configurable through the config file.
        seg_loss = loss_dice_coefficient_error
        cls_loss = NoduleSegAuxLoss

        if self.decoder_parameter_list['deep_supervision']:
            final_loss = [seg_loss] * 3
        elif self.decoder_parameter_list['ACF_choice']:
            final_loss = [seg_loss] * 2
        else:
            final_loss = [seg_loss]

        if self.decoder_parameter_list['aux_task']:
            final_loss.append(cls_loss)

        if len(final_loss) == 1:
            final_loss = final_loss[0]

        self.model.compile(optimizer=adam, loss=final_loss, metrics=['accuracy'])

    def GenerateTrainParameters(self):
        """Prepare steps per epoch, data generators, the output directory and callbacks.

        Side effects: creates ``<base_path>/<prefix>_<MMDD>/`` if it is missing
        and copies the config, model-parameter and train/val list files into it.
        """
        train_samples = np.sum([np.load(path).shape[0] for path in self.trainfiles['info_paths']])
        val_samples = np.sum([np.load(path).shape[0] for path in self.valfiles['info_paths']])
        self.train_step_per_epoch = int(train_samples // self.train_batch_size)
        self.valid_step_per_epoch = int(val_samples // self.val_batch_size)

        self.train_generator = UNet_3D_Generator(
            self.trainfiles['data_paths'], self.trainfiles['mask_paths'],
            self.trainfiles['info_paths'], input_size=self.input_shape,
            batch_size=self.train_batch_size, aug=self.aug,
            HU_min=self.normalizer[0], HU_max=self.normalizer[1], crop_choice=True,
            config_dict=self.decoder_parameter_list)

        # Validation never uses augmentation.
        self.val_generator = UNet_3D_Generator(
            self.valfiles['data_paths'], self.valfiles['mask_paths'],
            self.valfiles['info_paths'], input_size=self.input_shape,
            batch_size=self.val_batch_size, aug=False,
            HU_min=self.normalizer[0], HU_max=self.normalizer[1], crop_choice=True,
            config_dict=self.decoder_parameter_list)

        current_DT = datetime.datetime.now()
        current_day = '%02d%02d' % (current_DT.month, current_DT.day)

        self.current_base_path = '%s/%s_%s/' % (self.base_path, self.prefix, current_day)
        if not os.path.exists(self.current_base_path):
            # makedirs also creates base_path itself when it does not exist yet.
            os.makedirs(self.current_base_path)

        # Snapshot every file that defines this run next to the checkpoints.
        # Each file is copied from its own path; copying self.filename for all
        # of them would store the main config under five different names.
        run_files = [self.filename, self.model_parameters_txt, self.model_map_parameters_txt,
                     self.train_file, self.val_file]
        for target_filename in run_files:
            shutil.copyfile(
                target_filename,
                os.path.join(self.current_base_path, os.path.basename(target_filename)))

        # The {epoch}/{val_loss} placeholders are left unformatted here and
        # handed to get_callbacks as a checkpoint path template.
        self.filepath = str('%s/Train-{epoch:02d}-{val_loss:.5f}.hdf5' % (self.current_base_path))
        print ('self.filepath is ', type(self.filepath))
        self.callbacks_list = get_callbacks(
            self.filepath, initial_learning_rate=self.initial_learning_rate,
            learning_rate_drop=self.learning_rate_drop, learning_rate_epochs=None,
            learning_rate_patience=self.learning_rate_patience,
            early_stopping_patience=self.early_stop)

    def Train(self):
        """Build, compile and fit the model, then write the history to ``history.json``."""
        self.BuildSegModel()
        print ('After building model')
        self.CompileSegModel()
        print ('After compiling model')
        self.LoadPretrainModel()
        print ('After loading pretrain model')
        self.GenerateTrainParameters()

        self.history_callbacks = self.model.fit_generator(
            self.train_generator, steps_per_epoch=self.train_step_per_epoch, epochs=self.n_epochs,
            callbacks=self.callbacks_list, shuffle=self.shuffle,
            validation_data=self.val_generator, validation_steps=self.valid_step_per_epoch)

        WriteJson(os.path.join(self.current_base_path, 'history.json'), self.history_callbacks.history)

    def Inference(self):
        """Predict on the test group and write dice statistics to ``inference_result.json``.

        For each of up to three model outputs, the prediction is thresholded at
        0.05, passed through ``post_process``, and the mean/std of
        ``dice_coefficient_post`` is recorded both before and after
        post-processing. The JSON file is written next to ``pre_model_file``.
        """
        self.BuildSegModel()
        self.LoadPretrainModel()

        print ('Before Loading test data')
        self.LoadInferenceData()
        print ('After loading test data')

        deltas = [0, 0, 0]
        self.test_data = np.squeeze(np.asarray(
            [Image_preprocess(single_data, deltas, ratio=1) for single_data in self.test_data]))
        self.test_mask = np.squeeze(np.asarray(
            [Image_preprocess(single_mask, deltas, ratio=1) for single_mask in self.test_mask]))

        test_data_cp = DataNormlization(
            self.test_data, self.test_info,
            HU_min=self.normalizer[0], HU_max=self.normalizer[1], mode='HU')
        self.test_data = test_data_cp[..., np.newaxis]
        self.test_mask = self.test_mask[..., np.newaxis]

        predict_result = self.model.predict(self.test_data, batch_size=1)
        # A single-output model returns a bare array; wrap it so the loop below
        # iterates over output branches instead of over individual samples.
        if not isinstance(predict_result, (list, tuple)):
            predict_result = [predict_result]

        print ('After making prediction')
        inference_result_dict = {
            'model_weights': self.pre_model_file,
            'test_data': self.testfiles,
            'result': {},
        }

        threshold = 0.05
        # Loop-invariant, so squeeze the ground truth once.
        true_masks = np.squeeze(self.test_mask)

        print ('length of predict_result is', len(predict_result))
        for branch_idx in range(min(3, len(predict_result))):
            print ('branch_idx is ', branch_idx)
            final_result = predict_result[branch_idx] > threshold
            postprocess_result = np.asarray([
                np.squeeze(post_process(mask_image, spacing_z=1.0,
                                        inp_shape=self.input_shape, debug=False)[2])
                for mask_image in final_result])
            result_dice_post = [
                dice_coefficient_post(y_true, y_pred)
                for y_true, y_pred in zip(true_masks, np.squeeze(postprocess_result))]
            result_dice_infer = [
                dice_coefficient_post(y_true, y_pred)
                for y_true, y_pred in zip(true_masks, np.squeeze(final_result))]

            # Keys carry the branch index so each branch keeps its own entry.
            post_key = 'branch_%02d_post_dice' % branch_idx
            infer_key = 'branch_%02d_infer_dice' % branch_idx
            inference_result_dict['result'][post_key] = [np.mean(result_dice_post), np.std(result_dice_post)]
            inference_result_dict['result'][infer_key] = [np.mean(result_dice_infer), np.std(result_dice_infer)]

            print ('dice result(label and mask after threshold): branch %02d is %s' % (
                branch_idx, str(inference_result_dict['result'][infer_key])))
            print ('dice result(label and mask after postprocess): branch %02d is %s' % (
                branch_idx, str(inference_result_dict['result'][post_key])))

        result_path = os.path.join(os.path.dirname(self.pre_model_file), 'inference_result.json')
        print ('Before writing json file to %s' % result_path)
        WriteJson(result_path, inference_result_dict)
        print ('After writing json file')

    def LoadInferenceData(self):
        """Load test data, masks and info via ``load_data`` into instance attributes."""
        self.test_data_list = load_data([
            self.testfiles['data_paths'],
            self.testfiles['mask_paths'],
            self.testfiles['info_paths'],
        ])
        self.test_data, self.test_mask, self.test_info = self.test_data_list

    def LoadPretrainModel(self):
        """Load weights from ``self.pre_model_file`` into ``self.model`` when possible.

        Prints a message instead of raising when the weights path is unset or
        missing, or when the model has not been built yet.
        """
        # Guard against an unset path: os.path.exists(None) raises TypeError.
        path_exists = bool(self.pre_model_file) and os.path.exists(self.pre_model_file)
        if not path_exists:
            print ('Pretrain model path does not exist')
        elif self.model is None:
            print ('model has not been defined')
        else:
            self.model.load_weights(self.pre_model_file)

In [16]:
# Set the CUDA_VISIBLE_DEVICES environment variable to "6" for this process.
# NOTE(review): presumably this pins the run to GPU index 6 and has to execute
# before the deep-learning backend initializes CUDA -- confirm for the target machine.
os.environ["CUDA_VISIBLE_DEVICES"]="6"

## 参数
模型相关的基本参数定义在 ./files/train_config.json 中,其他相关参数文件的路径也在该文件中指定。
可以通过修改其中引用的文件、模型名称或其他参数,来调整模型和训练过程。

In [14]:
# Path of the training config JSON that is handed to Segmentation3D below.
train_file_path = './files/train_config.json'

In [15]:
# Construct the wrapper; the constructor reads the config and, for a training
# config, the train/val file lists. No model is built until Train() is called.
train_obj = Segmentation3D(train_file_path)