{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# 整体训练示例" ] }, { "cell_type": "code", "execution_count": 10, "metadata": {}, "outputs": [], "source": [ "import os\n", "import sys\n", "import six\n", "import json\n", "import shutil\n", "import datetime\n", "import numpy as np\n", "import scipy.ndimage as nd\n", "\n", "sys.path.append('./models/backbones/')\n", "sys.path.append('./models/baseLayers/')\n", "sys.path.append('./models/decoders/')\n", "sys.path.append('./models/')\n", "sys.path.append('./utils/')\n", "\n", "from data_generator import UNet_3D_Generator\n", "from ImagePreprocess import Image_preprocess,DataNormlization\n", "from parameter_utils import RectifyParameters\n", "from ResultEval import dice_coefficient_post\n", "from data_aug import *\n", "from RWconfig import LoadJson,GenerateParameters,GenerateParameters,WriteJson\n", "from ReadData import load_data\n", "from ResultPostProcess import post_process\n", "\n", "from modelBuild import build_model\n", "from model_utils import get_block\n", "from Segmentation_loss import Tversky_loss,loss_dice_coefficient_error,NoduleSegAuxLoss\n", "from train_utils import get_callbacks\n", "\n", "\n", "\n", "from keras import optimizers\n", "\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# 定义对应类实现相应功能,本类包括\n", "1. init\n", "2. LoadGroupFiles:因为数据存储成多个npy文件,多组data_path/mask_path/info_path写在train/val/test files json文件中。本函数用于将所有data_path,mask_path,info_path组合成数组\n", "3. BuildSegModel:搭建分割模型(train/inference都需要) 调用了GenerateParameters、RectifyParameters产生模型参数\n", "4. GenerateTrainParameters:产生训练相关的参数。(step/generator/callbacks等)\n", "5. CompileSegModel:compileModel\n", "6. Train:调用上面的函数,开始训练\n", "7. Inference:搭建模型、inference数据预处理、预测、后处理、产生指标(目前写定了用dice)\n", "8. LoadInferenceData:inference阶段载入数据\n", "9. 
 { "cell_type": "code", "execution_count": 13, "metadata": {}, "outputs": [], "source": [
 "class Segmentation3D(object):\n",
 "    def __init__(self,filename):\n",
 "        self.filename = filename\n",
 "        current_config = LoadJson(filename)\n",
 "        # The config holds a single top-level key; its name ('train...' or not)\n",
 "        # selects between the training and inference setup.\n",
 "        key = list(current_config.keys())[0]\n",
 "        parameters = current_config[key]\n",
 "        self.parameters = parameters\n",
 "\n",
 "        '''\n",
 "        General parameters\n",
 "        '''\n",
 "        self.num_slice = parameters['num_slice']\n",
 "        self.voxel_size = parameters['voxel_size']\n",
 "        self.n_channels = parameters['n_channels']\n",
 "        self.input_shape = [self.num_slice,self.voxel_size,self.voxel_size,self.n_channels]\n",
 "        self.num_classes = parameters['num_classes']\n",
 "        self.final_ksize = parameters['final_ksize']\n",
 "        self.num_units = parameters['num_units']\n",
 "        self.model_parameters_txt = os.path.join('./files',parameters['model_parameters_txt'])\n",
 "        self.model_map_parameters_txt = os.path.join('./files',parameters['model_map_parameters_txt'])\n",
 "\n",
 "        self.encoder = parameters['encoder']\n",
 "        self.decoder = parameters['decoder']\n",
 "        self.normalizer = parameters['normalizer']\n",
 "        self.base_path = parameters['base_path']\n",
 "        self.pre_model_file = parameters['pre_model_file']\n",
 "        self.model = None\n",
 "\n",
 "        if 'train' in key.lower():\n",
 "            '''\n",
 "            Parameters for training only\n",
 "            '''\n",
 "            self.train_file = os.path.join('./files',parameters['train_file'])\n",
 "            self.val_file = os.path.join('./files',parameters['val_txt'])\n",
 "            self.load_pre_trained = parameters['load_pre_trained']\n",
 "\n",
 "            self.initial_learning_rate = parameters['initial_learning_rate']\n",
 "            self.learning_rate_drop = parameters['learning_rate_drop']\n",
 "            self.learning_rate_patience = parameters['learning_rate_patience']\n",
 "            self.early_stop = parameters['early_stop']\n",
 "            self.n_epochs = parameters['n_epochs']\n",
 "            self.aug = parameters['aug']\n",
 "            self.prefix = parameters['prefix']\n",
 "            self.train_batch_size = parameters['train_batch_size']\n",
 "            self.val_batch_size = parameters['val_batch_size']\n",
 "            self.shuffle = parameters['shuffle']\n",
 "\n",
 "            self.trainfiles = self.LoadGroupFiles(self.train_file)\n",
 "            self.valfiles = self.LoadGroupFiles(self.val_file)\n",
 "        else:\n",
 "            self.test_file = os.path.join('./files',parameters['test_file'])\n",
 "            self.testfiles = self.LoadGroupFiles(self.test_file)\n",
 "\n",
 "    def LoadGroupFiles(self,filename):\n",
 "        # Regroup the list of {data_path, mask_path, info_path} records\n",
 "        # into one list per key.\n",
 "        result = LoadJson(filename)\n",
 "        return {\n",
 "            'data_paths':[record['data_path'] for record in result],\n",
 "            'mask_paths':[record['mask_path'] for record in result],\n",
 "            'info_paths':[record['info_path'] for record in result]\n",
 "        }\n",
 "\n",
 "    def BuildSegModel(self):\n",
 "        '''\n",
 "        Generate encoder/decoder parameters from model_parameters_txt and the\n",
 "        model names, then rectify them with the values from the config file.\n",
 "        '''\n",
 "        self.encoder_parameter_list = GenerateParameters(self.model_parameters_txt,self.model_map_parameters_txt,model=self.encoder)\n",
 "        self.decoder_parameter_list = GenerateParameters(self.model_parameters_txt,self.model_map_parameters_txt,model=self.decoder)\n",
 "\n",
 "        self.encoder_parameter_list = RectifyParameters(self.encoder_parameter_list,self.parameters)\n",
 "        self.decoder_parameter_list = RectifyParameters(self.decoder_parameter_list,self.parameters)\n",
 "\n",
 "        # Build the encoder first; its result is fed into the decoder build.\n",
 "        self.encoder_result = build_model(self.encoder,input_shape=self.input_shape,parameter_list=[self.encoder_parameter_list])\n",
 "        self.model = build_model(self.decoder,input_shape=self.input_shape,parameter_list=[self.encoder_result,self.decoder_parameter_list])\n",
 "\n",
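 "    # CompileSegModel below matches the loss to the number of model outputs:\n",
 "    # deep_supervision -> 3 segmentation outputs, ACF_choice -> 2, otherwise 1;\n",
 "    # aux_task additionally appends an auxiliary classification output/loss.\n",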
 "    def CompileSegModel(self):\n",
 "        adam = optimizers.Adam(lr=self.initial_learning_rate,beta_1=0.9,beta_2=0.999,epsilon=1e-08)\n",
 "\n",
 "        ## todo: move the loss choice into the config file\n",
 "        seg_loss = loss_dice_coefficient_error\n",
 "        cls_loss = NoduleSegAuxLoss\n",
 "\n",
 "        if self.decoder_parameter_list['deep_supervision']:\n",
 "            final_loss = [seg_loss for _ in range(3)]\n",
 "        elif self.decoder_parameter_list['ACF_choice']:\n",
 "            final_loss = [seg_loss for _ in range(2)]\n",
 "        else:\n",
 "            final_loss = [seg_loss]\n",
 "\n",
 "        if self.decoder_parameter_list['aux_task']:\n",
 "            final_loss.append(cls_loss)\n",
 "\n",
 "        # A single-output model expects a loss, not a one-element list.\n",
 "        if len(final_loss)==1:\n",
 "            final_loss = final_loss[0]\n",
 "\n",
 "        self.model.compile(optimizer=adam, loss=final_loss, metrics=['accuracy'])\n",
 "\n",
 "    def GenerateTrainParameters(self):\n",
 "        # Steps per epoch: total sample count (from the info arrays) / batch size.\n",
 "        self.train_step_per_epoch = int(np.sum([np.load(path).shape[0] for path in self.trainfiles['info_paths']])/self.train_batch_size)\n",
 "        self.valid_step_per_epoch = int(np.sum([np.load(path).shape[0] for path in self.valfiles['info_paths']])/self.val_batch_size)\n",
 "\n",
 "        self.train_generator = UNet_3D_Generator(self.trainfiles['data_paths'],self.trainfiles['mask_paths'],\n",
 "                                                 self.trainfiles['info_paths'],input_size=self.input_shape,\n",
 "                                                 batch_size=self.train_batch_size,aug=self.aug,\n",
 "                                                 HU_min=self.normalizer[0],HU_max=self.normalizer[1],crop_choice=True,\n",
 "                                                 config_dict=self.decoder_parameter_list)\n",
 "\n",
 "        self.val_generator = UNet_3D_Generator(self.valfiles['data_paths'],self.valfiles['mask_paths'],\n",
 "                                               self.valfiles['info_paths'],input_size=self.input_shape,\n",
 "                                               batch_size=self.val_batch_size,aug=False,\n",
 "                                               HU_min=self.normalizer[0],HU_max=self.normalizer[1],crop_choice=True,\n",
 "                                               config_dict=self.decoder_parameter_list)\n",
 "\n",
 "        current_DT = datetime.datetime.now()\n",
 "        current_day = '%02d%02d' % (current_DT.month,current_DT.day)\n",
 "\n",
 "        self.current_base_path = '%s/%s_%s/'%(self.base_path,self.prefix,current_day)\n",
 "        if not os.path.exists(self.current_base_path):\n",
 "            os.mkdir(self.current_base_path)\n",
 "\n",
 "        # Snapshot every config file into the run directory for reproducibility.\n",
 "        for target_filename in [self.filename,self.model_parameters_txt,self.model_map_parameters_txt,self.train_file,self.val_file]:\n",
 "            shutil.copyfile(target_filename,os.path.join(self.current_base_path,os.path.basename(target_filename)))\n",
 "\n",
 "        self.filepath = str('%s/Train-{epoch:02d}-{val_loss:.5f}.hdf5'%(self.current_base_path))\n",
 "        self.callbacks_list = get_callbacks(self.filepath,initial_learning_rate=self.initial_learning_rate,\n",
 "                                            learning_rate_drop=self.learning_rate_drop,learning_rate_epochs=None,\n",
 "                                            learning_rate_patience=self.learning_rate_patience,\n",
 "                                            early_stopping_patience=self.early_stop)\n",
 "\n",
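 "    # Train wires the steps together: build the graph, compile it, optionally\n",
 "    # restore pretrained weights, then create generators/callbacks and fit.\n",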
 "    def Train(self):\n",
 "        self.BuildSegModel()\n",
 "        print ('After building model')\n",
 "        self.CompileSegModel()\n",
 "        print ('After compiling model')\n",
 "        self.LoadPretrainModel()\n",
 "        print ('After loading pretrain model')\n",
 "        self.GenerateTrainParameters()\n",
 "\n",
 "        self.history_callbacks = self.model.fit_generator(self.train_generator,steps_per_epoch=self.train_step_per_epoch,epochs=self.n_epochs,\n",
 "                                                          callbacks=self.callbacks_list,shuffle=self.shuffle,\n",
 "                                                          validation_data=self.val_generator,validation_steps=self.valid_step_per_epoch)\n",
 "\n",
 "        WriteJson(os.path.join(self.current_base_path,'history.json'),self.history_callbacks.history)\n",
 "\n",
 "    def Inference(self):\n",
 "        self.BuildSegModel()\n",
 "        self.LoadPretrainModel()\n",
 "\n",
 "        print ('Before loading test data')\n",
 "        self.LoadInferenceData()\n",
 "        print ('After loading test data')\n",
 "\n",
 "        deltas = [0,0,0]\n",
 "        self.test_data = np.squeeze(np.asarray([Image_preprocess(single_data,deltas,ratio=1) for single_data in self.test_data]))\n",
 "        self.test_mask = np.squeeze(np.asarray([Image_preprocess(single_mask,deltas,ratio=1) for single_mask in self.test_mask]))\n",
 "\n",
 "        test_data_cp = DataNormlization(self.test_data,self.test_info,HU_min=self.normalizer[0],HU_max=self.normalizer[1],mode='HU')\n",
 "        self.test_data = test_data_cp[...,np.newaxis]\n",
 "        self.test_mask = self.test_mask[...,np.newaxis]\n",
 "\n",
 "        predict_result = self.model.predict(self.test_data,batch_size=1)\n",
 "\n",
 "        print ('After making prediction')\n",
 "        inference_result_dict = {\n",
 "            'model_weights':self.pre_model_file,\n",
 "            'test_data':self.testfiles,\n",
 "            'result':{}\n",
 "        }\n",
 "\n",
 "        print ('length of predict_result is',len(predict_result))\n",
 "        for branch_idx in range(min(3,len(predict_result))):\n",
 "            print ('branch_idx is ',branch_idx)\n",
 "            final_result = predict_result[branch_idx]>0.05\n",
 "            postprocess_result = np.asarray([np.squeeze(post_process(mask_image,spacing_z=1.0,inp_shape=self.input_shape,debug=False)[2]) for mask_image in final_result])\n",
 "            result_dice_post = [dice_coefficient_post(y_true,y_pred) for y_true,y_pred in zip(np.squeeze(self.test_mask),np.squeeze(postprocess_result))]\n",
 "            result_dice_infer = [dice_coefficient_post(y_true,y_pred) for y_true,y_pred in zip(np.squeeze(self.test_mask),np.squeeze(final_result))]\n",
 "            # Format the keys with the branch index so branches do not overwrite each other.\n",
 "            inference_result_dict['result']['branch_%02d_post_dice'%branch_idx] = [np.mean(result_dice_post),np.std(result_dice_post)]\n",
 "            inference_result_dict['result']['branch_%02d_infer_dice'%branch_idx] = [np.mean(result_dice_infer),np.std(result_dice_infer)]\n",
 "\n",
 "            print ('dice result(label and mask after threshold): branch %02d is %s'%(branch_idx,str(inference_result_dict['result']['branch_%02d_infer_dice'%branch_idx])))\n",
 "            print ('dice result(label and mask after postprocess): branch %02d is %s'%(branch_idx,str(inference_result_dict['result']['branch_%02d_post_dice'%branch_idx])))\n",
 "\n",
 "        print ('Before writing json file to %s'%os.path.join(os.path.dirname(self.pre_model_file),'inference_result.json'))\n",
 "        WriteJson(os.path.join(os.path.dirname(self.pre_model_file),'inference_result.json'),inference_result_dict)\n",
 "        print ('After writing json file')\n",
 "\n",
 "    def LoadInferenceData(self):\n",
 "        self.test_data_list = load_data([self.testfiles['data_paths'],self.testfiles['mask_paths'],\n",
 "                                         self.testfiles['info_paths']])\n",
 "        self.test_data,self.test_mask,self.test_info = self.test_data_list\n",
 "\n",
 "    def LoadPretrainModel(self):\n",
 "        if os.path.exists(self.pre_model_file) and self.model:\n",
 "            self.model.load_weights(self.pre_model_file)\n",
 "        elif not os.path.exists(self.pre_model_file):\n",
 "            print ('Pretrain model path does not exist')\n",
 "        else:\n",
 "            print ('model has not been defined')" ] },
 { "cell_type": "code", "execution_count": 16, "metadata": {}, "outputs": [], "source": [ "os.environ[\"CUDA_VISIBLE_DEVICES\"]=\"6\"" ] },
 { "cell_type": "markdown", "metadata": {}, "source": [
 "## Parameters\n",
 "The basic model parameters are defined in ./files/train_config.json, which also links the other parameter files.\n",
 "The model and the training process can be adjusted by editing the linked files, the model names, and the other parameters; a hedged sketch of such a config follows in the next cell." ] },
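 { "cell_type": "markdown", "metadata": {}, "source": [ "A minimal sketch of what train_config.json might contain, reconstructed from the fields the constructor reads. Every value, and the encoder/decoder names in particular, is a hypothetical placeholder; the file shipped in ./files is authoritative. After constructing Segmentation3D with the real config (as done below), training is started with train_obj.Train(); with an inference config (top-level key without 'train'), use train_obj.Inference() instead." ] },
 { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [
 "# Hedged sketch of train_config.json; all values are hypothetical placeholders.\n",
 "example_train_config = {\n",
 "    'train_config': {                          # top-level key must contain 'train'\n",
 "        'num_slice': 32, 'voxel_size': 64, 'n_channels': 1,\n",
 "        'num_classes': 1, 'final_ksize': 1, 'num_units': 2,\n",
 "        'model_parameters_txt': 'model_parameters.json',\n",
 "        'model_map_parameters_txt': 'model_map_parameters.json',\n",
 "        'encoder': 'some_encoder_name',        # hypothetical model name\n",
 "        'decoder': 'some_decoder_name',        # hypothetical model name\n",
 "        'normalizer': [-1000, 600],            # [HU_min, HU_max]\n",
 "        'base_path': './output',\n",
 "        'pre_model_file': './output/pretrain.hdf5',\n",
 "        'train_file': 'train_files.json',\n",
 "        'val_txt': 'val_files.json',\n",
 "        'load_pre_trained': False,\n",
 "        'initial_learning_rate': 1e-3,\n",
 "        'learning_rate_drop': 0.5,\n",
 "        'learning_rate_patience': 10,\n",
 "        'early_stop': 20,\n",
 "        'n_epochs': 100,\n",
 "        'aug': True,\n",
 "        'prefix': 'NoduleSeg',\n",
 "        'train_batch_size': 4, 'val_batch_size': 4,\n",
 "        'shuffle': True\n",
 "    }\n",
 "}\n",
 "# WriteJson could persist this sketch for a dry run, e.g.:\n",
 "# WriteJson('./demo_train_config.json', example_train_config)" ] },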
 { "cell_type": "code", "execution_count": 14, "metadata": {}, "outputs": [], "source": [ "train_file_path = './files/train_config.json'" ] },
 { "cell_type": "code", "execution_count": 15, "metadata": {}, "outputs": [], "source": [ "train_obj = Segmentation3D(train_file_path)" ] },
 { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [] }
 ],
 "metadata": { "kernelspec": { "display_name": "Python 2", "language": "python", "name": "python2" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 2 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython2", "version": "2.7.15" } }, "nbformat": 4, "nbformat_minor": 2 }