# -*- coding:utf-8 -*-
import os
import sys
import six
import glob
import json
import random
import shutil
import datetime
import numpy as np
import pandas as pd
import scipy.ndimage as nd
from tqdm import tqdm
import datetime
import logging
import torch
import torch.nn as nn
from torch import optim
from matplotlib import pyplot as plt
from torch.utils.data import DataLoader
from torch.nn.utils import clip_grad_norm_
# from torch.utils.tensorboard import SummaryWriter
from tensorboard_logger import Logger

sys.path.append('./utils')
sys.path.append('./model')
from RWconfig import LoadJson
from ReadData import load_data,load_single_data
from modelBuild import build_model
from BasicModules import weight_init
from PneuGen import PneuDataSet
from OnlineEval import eval_net,eval_netSE,eval_netFSE
from OfflineEval import CalculateClsScore,CalculateAuc,CalculateClsScoreByTh,PlotRoc
from loss_func import *

# Root logger shared by this module: every record goes both to SimpleCls.log
# (UTF-8) and to the console, with file/line information in the prefix.
logger = logging.getLogger()
fh = logging.FileHandler('SimpleCls.log', encoding='utf-8')
sh = logging.StreamHandler()
formatter = logging.Formatter(
    '%(asctime)s %(filename)s[line:%(lineno)d] %(levelname)s %(message)s'
)
for _handler in (fh, sh):
    _handler.setFormatter(formatter)
    logger.addHandler(_handler)

logger.setLevel(logging.DEBUG)

class Classification3D(object):
    def __init__(self,filename,mode):
        '''
        Load the `mode` section of a JSON config and derive run-level settings.

        filename: path of the JSON configuration file.
        mode: top-level config key to use; __call__ trains when it is
            'training' and runs inference otherwise.
        '''
        self.mode = mode
        self.filename = filename
        self.cfg = LoadJson(filename)[mode]
        self.basic_file_path = './files'
        now = datetime.datetime.now()
        # MMDD stamp embedded in the output directory name
        self.current_day = '{:02d}{:02d}'.format(now.month, now.day)
        self.n_input_channel = 1
        # sets self.save_path, self.cls_name_map_dict and cfg['num_task']
        self._clsInfoDecode()
        # loads the model-specific hyper-parameters
        self.ParametersDecode()
        self.device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

    def __call__(self):
        '''
        Run the whole pipeline:
        1. create the tensorboard logger, then build and compile the model;
        2. resolve the data file lists;
        3. train when self.mode == 'training', otherwise run inference.
        '''
        for setup_step in (self._makeTensorBoard, self._buildModel, self._compileModel):
            setup_step()
        self._LoadData(self.mode)
        run = self._trainModel if self.mode == 'training' else self._testModel
        run()
            

    def _makeTensorBoard(self):
        '''
        Create a tensorboard_logger Logger under <save_path>/tensorboard_logs
        (flushing every 10 seconds) and keep it as self.tensorboard_logger.
        '''
        log_dir = "%s/tensorboard_logs" % self.save_path
        self.tensorboard_logger = Logger(logdir=log_dir, flush_secs=10)
        logger.info('tensorboard path is %s' % log_dir)


    def _clsInfoDecode(self):
        '''
        Derive class-related settings from the config.

        Reads cfg['cls_name_map_dict'] (a string evaluated into an
        index -> class-name mapping) and cfg['cls_label_index'] (a single int,
        or a sequence of ints for multi-task runs), then:
        - builds self.save_path from base_path, weights_pre, the selected class
          names, num_class, the date stamp, model_name, model_order (default 1)
          and weight_memo, and creates that directory;
        - sets cfg['num_task'] to the number of selected label indices.
        '''
        # NOTE(review): eval() executes arbitrary code from the config file; only
        # use trusted configs (ast.literal_eval suffices for a plain dict literal).
        self.cls_name_map_dict = eval(self.cfg['cls_name_map_dict'])
        model_order = self.cfg.get('model_order', 1)
        label_index = self.cfg['cls_label_index']
        single_task = isinstance(label_index, int)
        if single_task:
            cls_related_info = self.cls_name_map_dict[label_index]
        else:
            cls_related_info = '_'.join(self.cls_name_map_dict[val] for val in label_index)
        self.save_path = '%s/%s_%s_%dCls_%s_model%s/%02d_%s/' % (
            self.cfg['base_path'], self.cfg['weights_pre'], cls_related_info,
            self.cfg['num_class'], self.current_day, self.cfg['model_name'],
            model_order, self.cfg['weight_memo'])
        # exist_ok avoids the check-then-create race of exists() + makedirs()
        os.makedirs(self.save_path, exist_ok=True)
        # one task per selected label index; any sequence (list or tuple) counts,
        # matching how the name above was built from every element
        self.cfg['num_task'] = 1 if single_task else len(label_index)

    def _buildModel(self):
        '''
        Build the network from self.model_params.

        The model parameter dict is updated in place (and also kept as
        self.net_config) with:
            num_task: number of classification tasks, from self.cfg
            num_class: classes per task (all tasks currently must share it);
                a binary setup (2) is collapsed to a single output (1)
            freeze_blocks: blocks excluded from training, evaluated from the
                config string (default "[]")
        n_input_channel is taken from the model parameters when present.
        weight_init is applied only when no pretrained 'weight_path' is configured.
        '''
        params = self.model_params
        params['num_task'] = self.cfg['num_task']
        params['num_class'] = self.cfg['num_class']
        params['freeze_blocks'] = eval(self.cfg.get('freeze_blocks', "[]"))
        # binary classification is modelled with a single logit
        if params['num_class'] == 2:
            params['num_class'] = 1

        self.n_input_channel = params.get('n_input_channel', self.n_input_channel)
        self.net = build_model(self.cfg['model_name'], params)
        self.net_config = params
        self.pretrain_weight_path = self.cfg.get('weight_path', '')
        if self.pretrain_weight_path == '':
            self.net.apply(weight_init)

    def _compileModel(self):
        self.optimizer = optim.Adam(filter(lambda p: p.requires_grad, self.net.parameters()),lr=self.cfg['learning_rate'])

        lr_schedule = self.cfg.get('lr_schedule','basic')
        if lr_schedule=='basic':
            self.scheduler = optim.lr_scheduler.ReduceLROnPlateau(self.optimizer,patience=self.cfg['patience'],
                                             factor=0.5,verbose=True)
        else:
            self.scheduler = optim.lr_scheduler.CyclicLR(self.optimizer,
                                    base_lr=self.cfg.get('base_lr',1e-5),
                                    max_lr=self.cfg.get('max_lr',1e-4),
                                    step_size_up=self.cfg.get('step_size_up',2000),
                                    cycle_momentum=False)

        if self.cfg['num_class']>2:
            self.criterion = nn.CrossEntropyLoss()
        else:
            self.criterion = nn.BCEWithLogitsLoss()

        conf_paths = [self.filename,self.model_param_name]
        # if self.mode == 'training':
        #     conf_paths += [self.train_data_files,self.val_data_files]
        # else:
        #     conf_paths += [self.test_data_files]
        for conf_path in conf_paths:
            filename = conf_path.split('/')[-1]
            target_save_path = os.path.join(self.save_path,filename)
            shutil.copy(conf_path,target_save_path)

    def _DecodeGeneratorParam(self):
        '''
        Assemble the PneuDataSet keyword arguments into self.generator_parameters.

        Augmentation is on only when self.mode == 'training'; 'model_mode'
        defaults to 'basic' and 'training_mode' to 2. The patch shape is
        [patch_depth, patch_height, patch_width].
        '''
        cfg = self.cfg
        self.generator_parameters = dict(
            cls_index=cfg['cls_label_index'],
            aug=self.mode == 'training',
            num_class=cfg['num_class'],
            HU_min=cfg['HU_min'],
            HU_max=cfg['HU_max'],
            n_input_channel=self.n_input_channel,
            model_mode=cfg.get('model_mode', 'basic'),
            patch_shape=[cfg['patch_depth'], cfg['patch_height'], cfg['patch_width']],
            training_mode=cfg.get('training_mode', 2),
            mode=self.mode,
        )


    def _DecodeDataParam(self,filename):
        '''
        Split a JSON list of records into parallel path lists.

        Returns {'data_paths': [...], 'info_paths': [...]} taken from each
        record's 'data_path' and 'info_path' fields, in file order.
        '''
        records = LoadJson(filename)
        return {
            'data_paths': [rec['data_path'] for rec in records],
            'info_paths': [rec['info_path'] for rec in records],
        }

    def _parseDataJosn(self,filename):
        '''
        Load a data-list JSON and return each record's 'df_path', in file order.

        filename: file name resolved against self.basic_file_path ('./files'),
            the same base directory ParametersDecode uses; an absolute path is
            used as-is.
        '''
        records = LoadJson(os.path.join(self.basic_file_path, filename))
        return [record["df_path"] for record in records]

    def _LoadData(self,mode):
        '''
        Resolve the train/val/test data-list files named in the config.

        Stores the parsed df_path lists as self.train_paths / self.val_paths /
        self.test_paths and the configured file names as
        self.train_data_files / self.val_data_files / self.test_data_files.
        `mode` is accepted for interface compatibility; all three splits are
        always parsed.
        '''
        split_files = [self.cfg[key] for key in ("train_file", "val_file", "test_file")]
        self.train_paths, self.val_paths, self.test_paths = [
            self._parseDataJosn(name) for name in split_files
        ]
        self.train_data_files, self.val_data_files, self.test_data_files = split_files
        
    def _GetLoss(self,prob_diff,gamma):
        '''
        Build the focal-loss criterion from the training-set class ratios.

        gamma: focal-loss focusing parameter.
        prob_diff: accepted for interface compatibility; not used here.

        Per the original author's note, the base focal-loss weights are the
        normalised inverse per-class sample counts returned by
        self.train_dataset._GenerateRatios(). The optional
        cfg['alpha_multi_ratio'] (a string that is evaluated) adjusts the
        weights of particular classes; None means no adjustment.

        cfg['num_class'] == 1: BCEFocalLoss with alpha = ratios[0]. With
            model_mode 'deux', an SeLossWithBCEFocalLoss is also stored in
            self.criterion_train.
        otherwise: FocalLoss over cfg['num_class'] classes with alpha = ratios.

        NOTE(review): self.cfg['num_class'] is not collapsed from 2 to 1 here;
        _buildModel only collapses the model parameters. A binary config with
        num_class == 2 therefore takes the FocalLoss branch and never creates
        self.criterion_train, which _trainOnBatchWithCons needs. Confirm this
        is intended.
        '''
        alpha_multi_ratio = eval(self.cfg['alpha_multi_ratio']) if 'alpha_multi_ratio' in self.cfg else None

        self.convert_softmax = True
        num_class = self.cfg['num_class']
        ratios = self.train_dataset._GenerateRatios()
        if num_class == 1:
            self.criterion = BCEFocalLoss(alpha=ratios[0], gamma=gamma)
            if self.cfg.get('model_mode', 'basic') == 'deux':
                self.criterion_train = SeLossWithBCEFocalLoss(lambda_val=self.cfg.get('lambda_val', 1))
        else:
            self.criterion = FocalLoss(num_class, alpha=ratios, gamma=gamma,
                                       alpha_multi_ratio=alpha_multi_ratio,
                                       convert_softmax=self.convert_softmax)

        print ('criterion is', self.criterion)

        
    def _trainOnBatch(self,batch,mode='train'):
        '''
        Forward one paired-input batch (training_mode 1) and return the summed loss.

        batch: 1-tuple (the training loop wraps loader batches with `zip(loader)`)
            whose only element is (imgs, labels, _).
            imgs: tensor or array-like whose first two axes are (batch, pair).
                These axes are swapped so imgs[0] / imgs[1] feed
                self.net(x1=..., x2=...).
            labels: one per-sample label sequence per task. Each sequence becomes
                a float32 tensor with a trailing singleton axis.
        mode: 'val' runs the forward pass under torch.no_grad() and records the
            per-task losses in self.val_loss. Any other value records them in
            self.train_loss.
        Returns the sum of self.criterion([prediction, label]) over the tasks
        for which both a prediction and a label exist.
        '''
        # not read in this method; kept as an attribute in case other code uses it
        self.loss_weights = self.cfg.get('loss_weights',[1.0 for _ in range(3)])

        imgs, labels, _ = batch[0]
        if not torch.is_tensor(imgs):
            imgs = torch.as_tensor(np.asarray(imgs))
        # (batch, pair, ...) -> (pair, batch, ...). Only .to(self.device) is used
        # (no hard-coded .cuda()), so the CPU fallback device keeps working.
        imgs = imgs.transpose(0, 1).to(device=self.device, dtype=torch.float32)

        task_labels = []
        for single_label in labels:
            values = np.array([np.array(value) for value in single_label])[..., np.newaxis]
            task_labels.append(torch.from_numpy(values).to(device=self.device, dtype=torch.float32))

        if mode == 'val':
            with torch.no_grad():
                prediction = self.net(x1=imgs[0], x2=imgs[1])
        else:
            prediction = self.net(x1=imgs[0], x2=imgs[1])

        # a single-head network returns one tensor; treat it as a one-task list
        if not isinstance(prediction, (list, tuple)):
            prediction = [prediction]

        loss = 0
        for idx, (task_pred, task_label) in enumerate(zip(prediction, task_labels)):
            current_loss = self.criterion([task_pred, task_label])
            loss += current_loss
            loss_store = self.val_loss if mode == 'val' else self.train_loss
            loss_store[idx].append(current_loss.item())
        return loss
    
    def _trainOnBatchV2(self,batch,mode='train'):
        '''
        Forward one single-input batch (training_mode 2) and return its loss.

        batch: 1-tuple (the training loop wraps loader batches with `zip(loader)`)
            whose only element is (imgs, labels, _). imgs[0] and labels[0] are
            fed to self.net and self.criterion respectively.
        mode: 'val' runs the forward pass under torch.no_grad() and records the
            loss in self.val_loss[0]. Otherwise it is recorded in self.train_loss[0].
        '''
        imgs, labels, _ = batch[0]
        # only .to(self.device): a hard-coded .cuda() breaks the CPU fallback
        imgs = imgs[0].to(device=self.device, dtype=torch.float32)
        labels = labels[0].to(device=self.device, dtype=torch.float32)

        if mode == 'val':
            with torch.no_grad():
                prediction = self.net(imgs)
        else:
            prediction = self.net(imgs)
        loss = self.criterion([prediction, labels])

        # record this batch's loss (must be `loss`; no other loss value exists here)
        loss_store = self.val_loss if mode == 'val' else self.train_loss
        loss_store[0].append(loss.item())
        return loss

    def _trainOnBatchWithCons(self,batch,mode='train'):
        '''
        Forward one batch for model_mode 'deux' and return the consistency loss.

        batch: 1-tuple whose only element is (imgs, labels, _); imgs[0] and
            labels[0] are used.
        Besides the predictions and labels, their batch-reversed copies are
        passed to self.criterion_train as
            [[prediction, prediction_reversed], [labels, labels_reversed], []].
        prediction_reversed is detached, so no gradient flows through it.
        mode: 'val' runs the forward pass under torch.no_grad().
        '''
        imgs, labels, _ = batch[0]
        # only .to(self.device): a hard-coded .cuda() breaks the CPU fallback
        imgs = imgs[0].to(device=self.device, dtype=torch.float32)
        labels = labels[0].to(device=self.device, dtype=torch.float32)
        # reverse along the batch axis, pairing sample i with sample n-1-i
        labels_reversed = torch.flip(labels, dims=[0])

        if mode == 'val':
            with torch.no_grad():
                prediction = self.net(imgs)
        else:
            prediction = self.net(imgs)

        prediction_reversed = torch.flip(prediction.detach(), dims=[0]).to(dtype=torch.float32)

        loss_input = [[prediction, prediction_reversed], [labels, labels_reversed], []]
        return self.criterion_train(loss_input)

    
        
    def _trainModel(self):
        '''
        Train the network for cfg['num_epoch'] epochs.

        Steps:
        1. select the device and, when cfg 'weight_path' is set, load those weights;
        2. build the train/val PneuDataSet objects (validation without
           augmentation) and the loss via _GetLoss;
        3. for each epoch: one training pass (_trainOneEpoch), validation (Val),
           LR scheduling, tensorboard logging, and an optional checkpoint
           (cfg['save_cp']) whose name includes the validation loss.

        The reported train loss is the mean over the batches that completed.
        A ReduceLROnPlateau scheduler steps once per epoch on the validation
        loss. Any other scheduler (the CyclicLR built by _compileModel) steps
        after every optimizer step, which is how CyclicLR is meant to be driven.
        '''
        training_mode = self.cfg.get('training_mode', 2)

        # device used for training
        self.device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

        # load pretrained weights if a path is configured
        if self.pretrain_weight_path != '':
            self.net.load_state_dict(torch.load(self.pretrain_weight_path, map_location=self.device))
            logger.info("load model from %s" % self.pretrain_weight_path)

        # gamma: focal-loss parameter; prob_diff is forwarded to _GetLoss
        gamma = self.cfg.get('gamma', 1)
        prob_diff = self.cfg.get('prob_diff', False)

        self._DecodeGeneratorParam()
        train_dataset = PneuDataSet(self.train_paths, **self.generator_parameters)
        self.generator_parameters['aug'] = False
        val_dataset = PneuDataSet(self.val_paths, **self.generator_parameters)
        self.train_dataset = train_dataset
        self.val_dataset = val_dataset

        # _GetLoss reads self.train_dataset, so it must run after the assignment above
        self._GetLoss(prob_diff, gamma)

        batch_size = self.cfg['batch_size']
        train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
        val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=True)
        self.val_loader = val_loader

        epochs = self.cfg['num_epoch']
        loss_names = {
            0: 'left_loss',
            1: 'right_loss',
            2: 'whole_loss',
        }
        plateau = isinstance(self.scheduler, optim.lr_scheduler.ReduceLROnPlateau)

        for epoch in range(epochs):
            self.net.train()
            # .to(self.device) alone; an extra .cuda() would break the CPU fallback
            self.net.to(self.device)

            self.train_loss = [[] for _ in range(3)]
            self.val_loss = [[] for _ in range(3)]
            epoch_loss, n_done = self._trainOneEpoch(
                train_loader, epoch, epochs, training_mode, step_scheduler=not plateau)

            # PneuDataSet keeps its own index list; reshuffle it between epochs
            random.shuffle(train_dataset.indices)

            val_score = self.Val(val_loader, epoch, epochs)
            # change lr based on the validation result
            if plateau:
                self.scheduler.step(val_score)
            train_score = epoch_loss / max(n_done, 1)
            logger.info("Train loss:{}".format(train_score))
            logger.info("Validation loss:{}".format(val_score))

            self.tensorboard_logger.log_value('train_loss', train_score, epoch)
            self.tensorboard_logger.log_value('val_loss', val_score, epoch)
            self._AddTensorData(self.train_loss, loss_names, 'train', epoch)
            self._AddTensorData(self.val_loss, loss_names, 'val', epoch)

            if self.cfg['save_cp']:
                model_save_path_current_epoch = os.path.join(
                    self.save_path, f'CP_epoch{epoch+1}_loss_%.2f.pth' % val_score)
                torch.save(self.net.state_dict(), model_save_path_current_epoch)
                logger.info(f'Checkpoint{epoch+1} saved to %s!' % model_save_path_current_epoch)

    def _trainOneEpoch(self, train_loader, epoch, epochs, training_mode, step_scheduler=False):
        '''
        Run one optimisation pass; return (summed batch loss, completed batches).

        Each loader batch is wrapped in a 1-tuple via zip(), which is the shape
        the _trainOnBatch* helpers unpack with batch[0]. A batch that raises is
        logged with its traceback and skipped (best-effort), so one bad sample
        does not abort the epoch. When step_scheduler is True, the LR scheduler
        is stepped after every successful optimizer step.
        '''
        epoch_loss = 0.0
        n_done = 0
        use_consistency = self.cfg.get("model_mode", "basic") == "deux"
        with tqdm(total=len(train_loader), desc=f'Epoch{epoch+1}/{epochs}', unit='img') as pbar:
            for batch in zip(train_loader):
                try:
                    if use_consistency:
                        loss = self._trainOnBatchWithCons(batch) / 2.0
                    elif training_mode == 1:
                        loss = self._trainOnBatch(batch)
                    else:
                        loss = self._trainOnBatchV2(batch)

                    self.optimizer.zero_grad()
                    loss.backward()
                    epoch_loss += loss.item()
                    nn.utils.clip_grad_value_(self.net.parameters(), 1e-1)
                    self.optimizer.step()
                    if step_scheduler:
                        self.scheduler.step()
                except Exception:
                    logger.exception('training batch failed, skipping it')
                    continue
                n_done += 1
                pbar.update()
        return epoch_loss, n_done

    def _AddTensorData(self,losses,name_dict,mode,epoch):
        '''
        Log the mean of every non-empty per-index loss list to tensorboard.

        losses: one list of float losses per index.
        name_dict: index -> loss name; each scalar is logged as '<mode>_<name>'.
        '''
        for idx, values in enumerate(losses):
            if len(values) == 0:
                continue
            tag = '%s_%s' % (mode, name_dict[idx])
            self.tensorboard_logger.log_value(tag, np.mean(values), epoch)



    def Val(self,val_Loader,epoch,num_epoch):
        '''
        Evaluate on the validation loader and return the mean loss per batch.

        Uses _trainOnBatch for training_mode 1 and _trainOnBatchV2 otherwise, both
        in 'val' mode (no gradients; per-task losses go to self.val_loss).
        The network is put in eval mode for the pass and always restored to
        train mode, even if a batch raises.
        Returns float('nan') when the loader yields no batches.
        '''
        training_mode = self.cfg.get('training_mode', 2)
        run_batch = self._trainOnBatch if training_mode == 1 else self._trainOnBatchV2
        self.net.eval()
        total_loss = 0.0
        n_batches = 0
        # progress-bar length only (ceil of samples / batch size); the average
        # below divides by the number of batches actually processed
        n_val = -(-len(self.val_dataset) // self.cfg['batch_size'])
        try:
            with tqdm(total=n_val, desc=f'Epoch{epoch+1}/{num_epoch}', unit='img') as pbar:
                for batch in zip(val_Loader):
                    total_loss += run_batch(batch, mode='val').item()
                    n_batches += 1
                    pbar.update()
        finally:
            self.net.train()
        return total_loss / n_batches if n_batches else float('nan')

    def _testModel(self):
        '''
        Evaluate one or more checkpoints on the test split.

        cfg['weight_path'] is either one file or a glob pattern (contains '*').
        With a pattern, every matching path containing '.pth' is evaluated in
        sorted order. For each checkpoint, per-sample labels and probabilities
        are collected into self.label_list / self.result_list and scored:
            - num_class <= 2, single task: threshold 0.5, CalculateClsScore and
              CalculateAuc (the AUC is collected to choose the best checkpoint);
              with a single checkpoint CalculateClsScoreByTh is also run
            - num_class <= 2, several tasks: score and AUC per task
            - num_class > 2: argmax prediction, CalculateClsScore (its return
              value is collected to choose the best checkpoint)
        Afterwards self.test_dataset, self.result (hard predictions of the last
        checkpoint) and self.badcase_lists (indices where prediction != label;
        empty when no per-sample predictions exist, e.g. multi-task runs) are set.
        '''
        training_mode = self.cfg.get('training_mode', 2)
        if '*' in self.cfg['weight_path']:
            model_paths = [path for path in sorted(glob.glob(self.cfg['weight_path'])) if '.pth' in path]
        else:
            model_paths = [self.cfg['weight_path']]
        results = []
        result = []
        self._DecodeGeneratorParam()
        infer_idx = -1

        self.convert_softmax = True
        test_dataset = PneuDataSet(self.test_paths, **self.generator_parameters)
        self.label_list = []
        self.result_list = []

        for model_path in model_paths:
            print ('='*60)
            print ('model_path is ',model_path)

            self.net.to(device=self.device)
            self.net.load_state_dict(torch.load(model_path, map_location=self.device))
            logger.info('Model loaded !')

            # eval mode fixes BatchNorm statistics and disables dropout
            self.net.eval()

            self.label_list = []
            self.result_list = []
            for idx in tqdm(np.arange(len(test_dataset))):
                imgs, labels, _ = test_dataset[idx]
                label, probs = self._inferSample(imgs, labels, training_mode, infer_idx)
                self.label_list.append(label)
                self.result_list.append(probs)

            if self.cfg.get('show_prob', False):
                target_cls_probs = self._GetTargetClassProb(self.label_list, self.result_list)
                self.ProbVis(self.label_list, target_cls_probs)

            print ('length of label and result',len(self.label_list),len(self.result_list))
            print ('shape of result list is ',np.array(self.result_list).shape)
            print ('min/max val of self.result_list',np.amin(self.result_list),np.amax(self.result_list))
            if self.cfg['num_class'] <= 2:
                if self.cfg['num_task'] == 1:
                    th_result = [val > 0.5 for val in self.result_list]
                    result = th_result
                    CalculateClsScore(th_result, self.label_list)
                    auc_val = CalculateAuc(self.result_list, self.label_list)
                    results.append(auc_val)
                    if len(model_paths) == 1:
                        CalculateClsScoreByTh(self.result_list, self.label_list, acc_flag=True)
                else:
                    for task_idx in range(self.cfg['num_task']):
                        print ('='*60)
                        print ('Current task is %s'%self.cls_name_map_dict[self.cfg['cls_label_index'][task_idx]])

                        current_label_list = np.squeeze(np.array([val[task_idx] for val in self.label_list]))
                        current_result_list = np.squeeze(np.array([val[task_idx] for val in self.result_list]))
                        th_result = np.squeeze(np.array([val > 0.5 for val in current_result_list]))
                        print (np.array(current_label_list).shape,np.array(current_result_list).shape,np.array(th_result).shape)
                        CalculateClsScore(th_result, current_label_list)
                        CalculateAuc(current_result_list, current_label_list)
            else:
                result = np.argmax(np.squeeze(np.array(self.result_list)), axis=1)
                print ('shape of result is',result.shape)
                print ('shape of self.label_list is',np.array(self.label_list).shape)
                kappa_val = CalculateClsScore(result, self.label_list)
                results.append(kappa_val)

        if len(results) > 0:
            print ('results is ',results)
            pos = np.argmax(results)
            print ('best weights is ',model_paths[pos])
        self.test_dataset = test_dataset
        self.result = result
        if len(result) == len(self.label_list):
            self.badcase_lists = [case_id for case_id, (pred, label) in enumerate(zip(result, self.label_list))
                                  if pred != label]
        else:
            # multi-task runs never fill `result`, so there is nothing to compare
            self.badcase_lists = []

    def _inferSample(self, imgs, labels, training_mode, infer_idx=-1):
        '''
        Run the network on one test sample and return (label, probs).

        Input handling by training_mode:
            1: imgs is a sequence of arrays (one per input branch), each given a
               leading batch axis; the returned label is labels[infer_idx]
            2: only imgs[0] (with a batch axis added) and labels[0] are used
            other: imgs is converted as-is and labels returned unchanged
        A list of exactly two inputs goes to self.net(x1, x2); anything else is
        passed as a single tensor.
        probs is a numpy array for single-task configs, or a list with one entry
        per task otherwise (softmax when self.net.num_class > 1, else sigmoid).
        '''
        device = self.device
        if training_mode == 1:
            imgs = [torch.from_numpy(img[np.newaxis, ...]).to(device=device, dtype=torch.float32) for img in imgs]
            label = labels[infer_idx]
        elif training_mode == 2:
            imgs = torch.from_numpy(imgs[0][np.newaxis, ...]).to(device=device, dtype=torch.float32)
            label = labels[0]
        else:
            imgs = torch.from_numpy(imgs).to(device=device, dtype=torch.float32)
            label = labels

        def _activate(logits):
            # softmax across classes for multi-class heads, sigmoid for one logit
            if self.net.num_class > 1:
                return torch.softmax(logits, dim=1)
            return torch.sigmoid(logits)

        with torch.no_grad():
            # only a list of inputs is a branch pair; a tensor whose first axis
            # happens to be 2 must not be split
            if isinstance(imgs, list) and len(imgs) == 2:
                output = self.net(imgs[0], imgs[1])
                output = [single_output.squeeze(0) for single_output in output]
            else:
                output = self.net(imgs)

            if self.cfg['num_task'] == 1:
                if training_mode == 1:
                    probs = _activate(output[infer_idx])
                elif training_mode == 2:
                    probs = _activate(output[0])
                else:
                    # reduce to the single largest probability of the sample
                    probs = _activate(output).max()
                probs = probs.squeeze(0).data.cpu().numpy()
            else:
                probs = [_activate(output[task_idx]).squeeze(0).cpu().numpy()[0]
                         for task_idx in range(self.cfg['num_task'])]
        return label, probs
    
        
    def _GetTargetClassProb(self,labels,probs):
        '''
        For each sample, return the probability assigned to its true class.

        Element i of the result is probs[i][0][labels[i]]. The output is
        truncated to the shorter of the two inputs.
        '''
        picked = []
        for prob, label in zip(probs, labels):
            picked.append(prob[0][label])
        return picked
        
    def ProbVis(self,labels,probs):
        '''
        Show one histogram (range 0-1) of `probs` for each distinct label value.

        labels: per-sample labels.
        probs: per-sample probabilities aligned with labels.
        Prints the label value and the sample count before each plot.
        '''
        from matplotlib import pyplot as plt
        for val in np.unique(labels):
            print ('='*60)
            print ('Current label is %d'%val)
            indices = [idx for idx, label in enumerate(labels) if label == val]
            current_probs = np.take(probs, indices, axis=0)
            print ('Number of samples are %d'%len(current_probs))
            # one figure per class, so that with non-interactive backends, where
            # show() does not close the figure, histograms are not drawn on top
            # of each other; show() is called once (it was duplicated before)
            plt.figure()
            plt.hist(current_probs, range=(0, 1))
            plt.show()
    
    def BadCaseVis(self,num_incies,show_edges=False,gaussian_choice=False,sigma=1):
        '''
        Plot the central slices of the first `num_incies` misclassified test cases.

        For each index in self.badcase_lists, the image is squeezed and the three
        slices around the middle of its first axis are shown. With
        show_edges=True, a second row shows
        ShowEdges(slice, gaussian_choice=..., sigma=...).
        NOTE(review): ShowEdges is not imported anywhere in this file. Confirm
        it is provided (perhaps via `from loss_func import *`) before calling
        this with show_edges=True.
        '''
        half_width = 1
        for case_idx in self.badcase_lists[:num_incies]:
            sample = self.test_dataset[case_idx]
            img = np.squeeze(sample[0])
            label = sample[1]
            center = int(img.shape[0] / 2)
            first_slice = center - half_width
            print ('='*60)
            print ('Current label is ',label,' result is ',self.result[case_idx])
            print ('prob is',self.result_list[case_idx])
            n_rows = 2 if show_edges else 1
            _, axs = plt.subplots(n_rows, 2 * half_width + 1)
            for col, slice_idx in enumerate(range(first_slice, center + half_width + 1)):
                if show_edges:
                    edges = ShowEdges(img[slice_idx], gaussian_choice=gaussian_choice, sigma=sigma)
                    axs[0][col].imshow(img[slice_idx], cmap='bone')
                    axs[1][col].imshow(edges, cmap='bone')
                else:
                    axs[col].imshow(img[slice_idx], cmap='bone')
            plt.show()


    def ParametersDecode(self):
        '''
        Load the model hyper-parameters.

        Reads the JSON file <basic_file_path>/<cfg['model_params']> and keeps its
        entry for cfg['model_name'] as self.model_params. The file path is kept
        as self.model_param_name; _compileModel copies that file into save_path.
        '''
        param_file = os.path.join(self.basic_file_path, self.cfg['model_params'])
        self.model_param_name = param_file
        self.model_params = LoadJson(param_file)[self.cfg['model_name']]