{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 整体训练示例"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import sys\n",
    "import six\n",
    "import json\n",
    "import shutil\n",
    "import datetime\n",
    "import numpy as np\n",
    "import scipy.ndimage as nd\n",
    "\n",
    "sys.path.append('./models/backbones/')\n",
    "sys.path.append('./models/baseLayers/')\n",
    "sys.path.append('./models/decoders/')\n",
    "sys.path.append('./models/')\n",
    "sys.path.append('./utils/')\n",
    "\n",
    "from data_generator import UNet_3D_Generator\n",
    "from ImagePreprocess import Image_preprocess,DataNormlization\n",
    "from parameter_utils import RectifyParameters\n",
    "from ResultEval import dice_coefficient_post\n",
    "from data_aug import *\n",
    "from RWconfig import LoadJson,GenerateParameters,GenerateParameters,WriteJson\n",
    "from ReadData import load_data\n",
    "from ResultPostProcess import post_process\n",
    "\n",
    "from modelBuild import build_model\n",
    "from model_utils import get_block\n",
    "from Segmentation_loss import Tversky_loss,loss_dice_coefficient_error,NoduleSegAuxLoss\n",
    "from train_utils import get_callbacks\n",
    "\n",
    "\n",
    "\n",
    "from keras import optimizers\n",
    "\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 定义对应类实现相应功能,本类包括\n",
    "1. init\n",
    "2. LoadGroupFiles:因为数据存储成多个npy文件,多组data_path/mask_path/info_path写在train/val/test files json文件中。本函数用于将所有data_path,mask_path,info_path组合成数组\n",
    "3. BuildSegModel:搭建分割模型(train/inference都需要) 调用了GenerateParameters、RectifyParameters产生模型参数\n",
    "4. GenerateTrainParameters:产生训练相关的参数。(step/generator/callbacks等)\n",
    "5. CompileSegModel:compileModel\n",
    "6. Train:调用上面的函数,开始训练\n",
    "7. Inference:搭建模型、inference数据预处理、预测、后处理、产生指标(目前写定了用dice)\n",
    "8. LoadInferenceData:inference阶段载入数据\n",
    "9. LoadPretrainModel:根据config文件中定义的pretrain模型路径,载入模型。如果路径不存在,则提示"
   ]
  },
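  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The sketch below shows the record layout that LoadGroupFiles expects in the train/val/test JSON files: a list of records, each with data_path/mask_path/info_path keys. The paths here are hypothetical placeholders, not files from this project."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Hypothetical example of a train/val/test file consumed by LoadGroupFiles;\n",
    "# each record groups one npy data volume with its mask and info arrays.\n",
    "example_group_file = [\n",
    "    {'data_path': './data/group00_data.npy',\n",
    "     'mask_path': './data/group00_mask.npy',\n",
    "     'info_path': './data/group00_info.npy'},\n",
    "    {'data_path': './data/group01_data.npy',\n",
    "     'mask_path': './data/group01_mask.npy',\n",
    "     'info_path': './data/group01_info.npy'}\n",
    "]\n",
    "# LoadGroupFiles turns such a list into three parallel path lists:\n",
    "# {'data_paths': [...], 'mask_paths': [...], 'info_paths': [...]}"
   ]
  },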
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [],
   "source": [
    "class Segmentation3D(object):\n",
    "    def __init__(self,filename):\n",
    "        self.filename = filename\n",
    "        current_config = LoadJson(filename)\n",
    "        key = current_config.keys()[0]\n",
    "        parameters = current_config[key]\n",
    "        self.parameters = parameters\n",
    "        ### Generate parameters\n",
    "        \n",
    "        '''\n",
    "        General parameters\n",
    "        '''\n",
    "        \n",
    "        self.num_slice = parameters['num_slice']\n",
    "        self.voxel_size = parameters['voxel_size']\n",
    "        self.n_channels = parameters['n_channels']\n",
    "        self.input_shape = [self.num_slice,self.voxel_size,self.voxel_size,self.n_channels]\n",
    "        self.num_classes = parameters['num_classes']\n",
    "        self.final_ksize = parameters['final_ksize']\n",
    "        self.num_units = parameters['num_units']\n",
    "        self.model_parameters_txt = os.path.join('./files',parameters['model_parameters_txt']) \n",
    "        self.model_map_parameters_txt = os.path.join('./files',parameters['model_map_parameters_txt']) \n",
    "        \n",
    "        self.encoder = parameters['encoder']\n",
    "        self.decoder = parameters['decoder']\n",
    "        self.normalizer = parameters['normalizer']\n",
    "        self.base_path = parameters['base_path']\n",
    "        self.pre_model_file = parameters['pre_model_file']\n",
    "        self.model = None\n",
    "        \n",
    "        \n",
    "        if 'train' in key.lower():\n",
    "            '''\n",
    "            parameters for training only\n",
    "            '''\n",
    "            lr = parameters['initial_learning_rate']\n",
    "\n",
    "            self.train_file = os.path.join('./files',parameters['train_file']) \n",
    "            self.val_file = os.path.join('./files',parameters['val_txt'])\n",
    "            self.load_pre_trained = parameters['load_pre_trained']\n",
    "\n",
    "            self.initial_learning_rate = parameters['initial_learning_rate']\n",
    "            self.learning_rate_drop = parameters['learning_rate_drop']\n",
    "            self.learning_rate_patience = parameters['learning_rate_patience']\n",
    "            self.early_stop = parameters['early_stop']\n",
    "            self.base_path = parameters['base_path']\n",
    "            self.n_epochs = parameters['n_epochs']\n",
    "            self.aug = parameters['aug']\n",
    "            self.prefix = parameters['prefix']\n",
    "            self.train_batch_size = parameters['train_batch_size']\n",
    "            self.val_batch_size  = parameters['val_batch_size']\n",
    "            self.shuffle = parameters['shuffle']\n",
    "            \n",
    "            self.trainfiles = self.LoadGroupFiles(self.train_file)\n",
    "            self.valfiles = self.LoadGroupFiles(self.val_file)\n",
    "        \n",
    "        else:\n",
    "            self.test_file = os.path.join('./files',parameters['test_file'])\n",
    "            self.testfiles = self.LoadGroupFiles(self.test_file)\n",
    "            \n",
    "    \n",
    "    def LoadGroupFiles(self,filename):\n",
    "        result = LoadJson(filename)\n",
    "        data_paths = [record['data_path'] for record in result]\n",
    "        mask_paths = [record['mask_path'] for record in result]\n",
    "        info_paths = [record['info_path'] for record in result]\n",
    "        return {\n",
    "            'data_paths':data_paths,\n",
    "            'mask_paths':mask_paths,\n",
    "            'info_paths':info_paths\n",
    "        }\n",
    "            \n",
    "    def BuildSegModel(self):\n",
    "        '''\n",
    "        Generate encoder/decoder parameters with model_parameters_txt and model name\n",
    "        '''\n",
    "\n",
    "    \n",
    "        self.encoder_parameter_list = GenerateParameters(self.model_parameters_txt,self.model_map_parameters_txt,model=self.encoder)\n",
    "        self.decoder_parameter_list = GenerateParameters(self.model_parameters_txt,self.model_map_parameters_txt,model=self.decoder)\n",
    "\n",
    "\n",
    "        '''\n",
    "        Rectify parameters \n",
    "        '''\n",
    "        self.encoder_parameter_list = RectifyParameters(self.encoder_parameter_list,self.parameters)\n",
    "        self.decoder_parameter_list = RectifyParameters(self.decoder_parameter_list,self.parameters)\n",
    "        \n",
    "\n",
    "        '''\n",
    "        Build encoder and decoder\n",
    "        '''\n",
    "\n",
    "        self.encoder_result = build_model(self.encoder,input_shape=self.input_shape,parameter_list=[self.encoder_parameter_list])\n",
    "        self.model = build_model(self.decoder,input_shape=self.input_shape,parameter_list=[self.encoder_result,self.decoder_parameter_list])\n",
    "\n",
    "#         self.model.summary()\n",
    "        \n",
    "    def CompileSegModel(self):\n",
    "        \n",
    "\n",
    "        adam=optimizers.Adam(lr=self.initial_learning_rate,beta_1=0.9,beta_2=0.999,epsilon=1e-08)\n",
    "        \n",
    "        ## todo: change loss to parameters in config file\n",
    "        seg_loss = loss_dice_coefficient_error\n",
    "        cls_loss = NoduleSegAuxLoss\n",
    "\n",
    "        final_loss = 0\n",
    "\n",
    "        \n",
    "        if self.decoder_parameter_list['deep_supervision']:\n",
    "            final_loss = [seg_loss for _ in range(3)]\n",
    "        elif self.decoder_parameter_list['ACF_choice']:\n",
    "            final_loss = [seg_loss for _ in range(2)]\n",
    "        else:\n",
    "            final_loss = [seg_loss]\n",
    "            \n",
    "        if self.decoder_parameter_list['aux_task']:\n",
    "            final_loss.append(cls_loss)\n",
    "\n",
    "        if len(final_loss)==1:\n",
    "            final_loss = final_loss[0]\n",
    "\n",
    "        self.model.compile(optimizer= adam,  loss=final_loss, metrics=['accuracy'])\n",
    "    \n",
    "    def GenerateTrainParameters(self):\n",
    "        self.train_step_per_epoch = int(np.sum([np.load(path).shape[0] for path in self.trainfiles['info_paths']] )/self.train_batch_size)\n",
    "        self.valid_step_per_epoch = int(np.sum([np.load(path).shape[0] for path in self.valfiles['info_paths']] )/self.val_batch_size )\n",
    "        \n",
    "        \n",
    "#         self.train_step_per_epoch = 2\n",
    "#         self.valid_step_per_epoch = 2\n",
    "        \n",
    "        self.train_generator = UNet_3D_Generator(self.trainfiles['data_paths'],self.trainfiles['mask_paths'],\n",
    "                                                 self.trainfiles['info_paths'],input_size=self.input_shape,\n",
    "                                                 batch_size = self.train_batch_size,aug=self.aug,\n",
    "                                                 HU_min = self.normalizer[0],HU_max=self.normalizer[1],crop_choice=True,\n",
    "                                                config_dict=self.decoder_parameter_list)\n",
    "         \n",
    "        self.val_generator = UNet_3D_Generator(self.valfiles['data_paths'],self.valfiles['mask_paths'],\n",
    "                                             self.valfiles['info_paths'],input_size=self.input_shape,\n",
    "                                             batch_size = self.val_batch_size,aug=False,\n",
    "                                             HU_min = self.normalizer[0],HU_max=self.normalizer[1],crop_choice=True,\n",
    "                                              config_dict=self.decoder_parameter_list)\n",
    "        \n",
    "        current_DT = datetime.datetime.now()\n",
    "        current_day = '%02d%02d' % (current_DT.month,current_DT.day)\n",
    "\n",
    "        self.current_base_path = '%s/%s_%s/'%(self.base_path,self.prefix,current_day)\n",
    "        if not os.path.exists(self.current_base_path):\n",
    "            os.mkdir(self.current_base_path)\n",
    "            \n",
    "            \n",
    "        for target_filename in [self.filename,self.model_parameters_txt,self.model_map_parameters_txt,self.train_file,self.val_file]:\n",
    "            shutil.copyfile(self.filename,os.path.join(self.current_base_path,os.path.basename(target_filename)))\n",
    "        \n",
    "        \n",
    "        self.filepath = str('%s/Train-{epoch:02d}-{val_loss:.5f}.hdf5'%(self.current_base_path))\n",
    "        self.filepath = str(self.filepath)\n",
    "        print ('self.filepath is ',type(self.filepath))\n",
    "        self.callbacks_list = get_callbacks(self.filepath,initial_learning_rate = self.initial_learning_rate,\n",
    "                                  learning_rate_drop = self.learning_rate_drop,learning_rate_epochs = None,\n",
    "                                  learning_rate_patience = self.learning_rate_patience,\n",
    "                                  early_stopping_patience = self.early_stop)\n",
    "        \n",
    "        \n",
    "    def Train(self):\n",
    "        \n",
    "        self.BuildSegModel()\n",
    "        print ('After building model')\n",
    "        self.CompileSegModel()\n",
    "        print ('After compiling model')\n",
    "        self.LoadPretrainModel()\n",
    "        print ('After loading pretrain model')\n",
    "        self.GenerateTrainParameters()\n",
    "                                  \n",
    "        self.history_callbacks = self.model.fit_generator(self.train_generator,steps_per_epoch=self.train_step_per_epoch,epochs=self.n_epochs,\n",
    "                                                     callbacks=self.callbacks_list,shuffle=self.shuffle,\n",
    "                                                     validation_data=self.val_generator,validation_steps=self.valid_step_per_epoch)\n",
    "        \n",
    "        WriteJson(os.path.join(self.current_base_path,'history.json'),self.history_callbacks.history)\n",
    "                                                    \n",
    "    \n",
    "    def Inference(self):\n",
    "        \n",
    "        self.BuildSegModel()\n",
    "        self.LoadPretrainModel()\n",
    "        \n",
    "        print ('Before Loading test data')\n",
    "        self.LoadInferenceData()\n",
    "        print ('After loading test data')\n",
    "        \n",
    "        deltas = [0,0,0]\n",
    "        self.test_data = np.squeeze(np.asarray([Image_preprocess(single_data,deltas,ratio=1) for single_data in self.test_data]))\n",
    "        self.test_mask = np.squeeze(np.asarray([Image_preprocess(single_mask,deltas,ratio=1) for single_mask in self.test_mask]))\n",
    "\n",
    "        test_data_cp = DataNormlization(self.test_data,self.test_info,HU_min=self.normalizer[0],HU_max=self.normalizer[1],mode='HU')\n",
    "        self.test_data = test_data_cp[...,np.newaxis]\n",
    "        self.test_mask = self.test_mask[...,np.newaxis]\n",
    "        \n",
    "        predict_result = self.model.predict(self.test_data,batch_size=1)\n",
    "        \n",
    "        print ('After making prediction')\n",
    "        inference_result_dict = {\n",
    "            'model_weights':self.pre_model_file,\n",
    "            'test_data':self.testfiles,\n",
    "            'result':{}\n",
    "            \n",
    "        }\n",
    "        \n",
    "        print ('length of predict_result is',len(predict_result))\n",
    "        for branch_idx in range(min(3,len(predict_result))):\n",
    "            print ('branch_idx is ',branch_idx)\n",
    "            final_result = predict_result[branch_idx]>0.05\n",
    "            postprocess_result = np.asarray([np.squeeze(post_process(mask_image,spacing_z=1.0,inp_shape=self.input_shape,debug=False)[2]) for mask_image in final_result])\n",
    "            result_dice_post = [dice_coefficient_post(y_true,y_pred) for y_true,y_pred in zip(np.squeeze(self.test_mask),np.squeeze(postprocess_result))]\n",
    "            result_dice_infer = [dice_coefficient_post(y_true,y_pred) for y_true,y_pred in zip(np.squeeze(self.test_mask),np.squeeze(final_result))]\n",
    "            inference_result_dict['result']['branch_%02d_post_dice'] = [np.mean(result_dice_post),np.std(result_dice_post)]\n",
    "            inference_result_dict['result']['branch_%02d_infer_dice'] = [np.mean(result_dice_infer),np.std(result_dice_infer)]\n",
    "\n",
    "            print ('dice result(label and mask after threshold): branch %02d is %s'%(branch_idx,str(inference_result_dict['result']['branch_%02d_infer_dice'])))\n",
    "            print ('dice result(label and mask after postprocess): branch %02d is %s'%(branch_idx,str(inference_result_dict['result']['branch_%02d_infer_dice'])))\n",
    "        \n",
    "        print ('Before writing json file to %s'%os.path.join(os.path.dirname(self.pre_model_file),'inference_result.json'))\n",
    "        WriteJson(os.path.join(os.path.dirname(self.pre_model_file),'inference_result.json'),inference_result_dict)\n",
    "        print ('After writing json file')\n",
    "        \n",
    "    \n",
    "    def LoadInferenceData(self):\n",
    "        self.test_data_list = load_data([self.testfiles['data_paths'],self.testfiles['mask_paths'],\n",
    "                                                 self.testfiles['info_paths']])\n",
    "        self.test_data,self.test_mask,self.test_info = self.test_data_list\n",
    "        \n",
    "    \n",
    "    \n",
    "    def LoadPretrainModel(self):\n",
    "        if os.path.exists(self.pre_model_file) and self.model:\n",
    "            self.model.load_weights(self.pre_model_file)\n",
    "        elif not os.path.exists(self.pre_model_file):\n",
    "            print ('Pretrain model path does not exist')\n",
    "        else:\n",
    "            print ('model has not been defined')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [],
   "source": [
    "os.environ[\"CUDA_VISIBLE_DEVICES\"]=\"6\""
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 参数\n",
    "模型相关基本参数定义在/files/train_config.json中。其他相关参数文件也定义在该文件中。\n",
    "可以通过修改link的文件/模型名字/其他参数,调整模型、训练过程"
   ]
  },
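  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a reference, the dict below sketches the keys that Segmentation3D.__init__ reads from the config. The keys mirror the code above, but every value is an illustrative placeholder; the real settings live in ./files/train_config.json and may differ."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Hypothetical train_config.json skeleton. The top-level key must contain\n",
    "# 'train' (case-insensitive) to enable the training branch of __init__;\n",
    "# all values below are placeholders, not the project's real settings.\n",
    "example_train_config = {\n",
    "    'train': {\n",
    "        'num_slice': 32, 'voxel_size': 64, 'n_channels': 1,\n",
    "        'num_classes': 1, 'final_ksize': 1, 'num_units': 2,\n",
    "        'model_parameters_txt': 'model_parameters.json',\n",
    "        'model_map_parameters_txt': 'model_map_parameters.json',\n",
    "        'encoder': 'some_encoder', 'decoder': 'some_decoder',\n",
    "        'normalizer': [-1000, 400],\n",
    "        'base_path': './output', 'pre_model_file': '',\n",
    "        'load_pre_trained': False,\n",
    "        'initial_learning_rate': 1e-3, 'learning_rate_drop': 0.5,\n",
    "        'learning_rate_patience': 10, 'early_stop': 20,\n",
    "        'n_epochs': 100, 'aug': True, 'prefix': 'NoduleSeg',\n",
    "        'train_batch_size': 8, 'val_batch_size': 8, 'shuffle': True,\n",
    "        'train_file': 'train_files.json',\n",
    "        'val_txt': 'val_files.json'  # note: __init__ reads 'val_txt', not 'val_file'\n",
    "    }\n",
    "}"
   ]
  },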
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [],
   "source": [
    "train_file_path = './files/train_config.json'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [],
   "source": [
    "train_obj = Segmentation3D(train_file_path)"
   ]
  },
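  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "With the object constructed, training is presumably started by calling Train(), which builds and compiles the model, optionally loads pretrained weights, and then fits it (see item 6 in the class description above)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Start training; checkpoints and history.json are written under\n",
    "# base_path/prefix_MMDD/ as set up in GenerateTrainParameters.\n",
    "train_obj.Train()"
   ]
  },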
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 2",
   "language": "python",
   "name": "python2"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 2
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython2",
   "version": "2.7.15"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}