import os
import sys
import shutil
import datetime
import numpy as np

sys.path.append('./models/backbones/')
sys.path.append('./models/baseLayers/')
sys.path.append('./models/decoders/')
sys.path.append('./models/')
sys.path.append('./utils/')

from data_generator import UNet_3D_Generator,DataStat
from ImagePreprocess import Image_preprocess,DataNormlization,CalculatePatchSizeWithDiameter,Data_preprocess
from parameter_utils import RectifyParameters
from ResultEval import dice_coefficient_post
from data_aug import *
from RWconfig import LoadJson,GenerateParameters,WriteJson
from ReadData import load_data,load_single_data
from ResultPostProcess import post_process

from modelBuild import build_model
from model_utils import get_block
from Segmentation_loss import Tversky_loss,loss_dice_coefficient_error,NoduleSegAuxLoss,focal_loss
from train_utils import get_callbacks
from matplotlib import pyplot as plt
from keras import optimizers

class Segmentation3D(object):
    def __init__(self,cfg):
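        '''
        cfg is expected to behave like a dict holding a 'training'
        hyper-parameter section, a 'mode' key, and lists of data/mask/info
        .npy paths per split. The exact schema is inferred from the
        accesses below.
        '''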

        self.gamma = 1.0

        self.cfg = cfg
        parameters = cfg['training']
        self.parameters = parameters
        ### Generate parameters
        
        '''
        General parameters
        '''
        
        self.num_slice = parameters['num_slice'] if parameters['num_slice']>0 else None
        self.voxel_size = parameters['voxel_size'] if parameters['voxel_size']>0 else None
        self.n_channels = parameters['n_channels']
        self.input_shape = [self.num_slice,self.voxel_size,self.voxel_size,self.n_channels]
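        # None entries in input_shape (from non-positive num_slice/voxel_size)
        # presumably let the network accept variable-sized patches along those axes.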
        self.num_classes = parameters['num_classes']
        self.final_ksize = parameters['final_ksize']
        self.num_units = parameters['num_units']
        self.model_parameters_txt = os.path.join('./files',parameters['model_parameters_txt']) 
        self.model_map_parameters_txt = os.path.join('./files',parameters['model_map_parameters_txt']) 
        
        self.encoder = parameters['encoder']
        self.decoder = parameters['decoder']
        self.normalizer = parameters['normalizer']
        self.base_path = parameters['base_path']
        self.pre_model_file = parameters['pre_model_file']
        self.model = None
        self.nonfix_crop = parameters['nonfix_crop']
        self.stride_ratio = parameters['stride_ratio']
        self.densityTypeIndex = parameters['densityTypeIndex']
        print ('self.densityTypeIndex is',self.densityTypeIndex)

        # default alpha for the Tversky loss when not specified in the config
        self.T_loss_alpha = parameters.get('T_loss_alpha',0.65)

        print ('self.nonfix_crop is ',self.nonfix_crop)
        
        
        if cfg['mode'] == 'training':
            '''
            parameters for training only
            '''
            lr = parameters['initial_learning_rate']

            self.load_pre_trained = parameters['load_pre_trained']

            self.initial_learning_rate = parameters['initial_learning_rate']
            self.learning_rate_drop = parameters['learning_rate_drop']
            self.learning_rate_patience = parameters['learning_rate_patience']
            self.early_stop = parameters['early_stop']
            self.base_path = parameters['base_path']
            self.n_epochs = parameters['n_epochs']
            self.aug = parameters['aug']
            self.prefix = parameters['prefix']
            self.train_batch_size = parameters['train_batch_size']
            self.val_batch_size  = parameters['val_batch_size']
            self.shuffle = parameters['shuffle']
            
            
            self.trainfiles = {
                'data_paths':self.cfg['train_data_path'],
                'mask_paths':self.cfg['train_mask_path'],
                'info_paths':self.cfg['train_info_path']
            }
            self.valfiles = {
                'data_paths':self.cfg['val_data_path'],
                'mask_paths':self.cfg['val_mask_path'],
                'info_paths':self.cfg['val_info_path']
            }
        
        else:
            self.testfiles = {
                'data_paths':self.cfg['test_data_path'],
                'mask_paths':self.cfg['test_mask_path'],
                'info_paths':self.cfg['test_info_path']
            }
            
    
            
    def BuildSegModel(self):
        '''
        Generate encoder/decoder parameters with model_parameters_txt and model name
        '''

    
        self.encoder_parameter_list = GenerateParameters(self.model_parameters_txt,self.model_map_parameters_txt,model=self.encoder)
        self.decoder_parameter_list = GenerateParameters(self.model_parameters_txt,self.model_map_parameters_txt,model=self.decoder)
  

        '''
        Rectify parameters 
        '''
        self.encoder_parameter_list = RectifyParameters(self.encoder_parameter_list,self.parameters)
        self.decoder_parameter_list = RectifyParameters(self.decoder_parameter_list,self.parameters)

        '''
        Build encoder and decoder
        '''

        self.encoder_result = build_model(self.encoder,input_shape=self.input_shape,parameter_list=[self.encoder_parameter_list])
        self.model = build_model(self.decoder,input_shape=self.input_shape,parameter_list=[self.encoder_result,self.decoder_parameter_list])
        # self.model.summary()
        
    def CompileSegModel(self):
        adam = optimizers.Adam(lr=self.initial_learning_rate,beta_1=0.9,beta_2=0.999,epsilon=1e-08)

        ## TODO: read the loss choice from the config file
        seg_loss = Tversky_loss(alpha=self.T_loss_alpha,beta=1-self.T_loss_alpha)
        
        if self.decoder_parameter_list['deep_supervision']:
            # deep supervision: three segmentation heads, one loss each
            final_loss = [seg_loss for _ in range(3)]
        elif self.decoder_parameter_list['ACF_choice']:
            # ACF: two segmentation heads
            final_loss = [seg_loss for _ in range(2)]
        else:
            final_loss = [seg_loss]

        
        if self.decoder_parameter_list['aux_task']:
            self._calculate_ratio()
            cls_loss = focal_loss(self.focal_class_weights,self.gamma)
            final_loss.append(cls_loss)
            # weight the auxiliary classification head slightly higher (1.5)
            class_weights = [1 for _ in range(len(final_loss)-1)] + [1.5]
        else:
            class_weights = [1 for _ in range(len(final_loss))]
        if len(final_loss)==1:
            final_loss = final_loss[0]

        self.model.compile(optimizer= adam,  loss=final_loss,loss_weights=class_weights, metrics=['accuracy'])
        
    
    def GenerateTrainParameters(self):
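        '''
        Steps per epoch come from the number of rows in each info .npy file
        (one row per sample) divided by the batch size.
        '''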
        self.train_step_per_epoch = int(np.sum([np.load(path).shape[0] for path in self.trainfiles['info_paths']] )/self.train_batch_size)
        self.valid_step_per_epoch = int(np.sum([np.load(path).shape[0] for path in self.valfiles['info_paths']] )/self.val_batch_size )
        
        
        
        self.train_generator = UNet_3D_Generator(self.trainfiles['data_paths'],self.trainfiles['mask_paths'],
                                                 self.trainfiles['info_paths'],input_size=self.input_shape,
                                                 batch_size = self.train_batch_size,aug=self.aug,
                                                 HU_min = self.normalizer[0],HU_max=self.normalizer[1],crop_choice=True,
                                                config_dict=self.decoder_parameter_list,nonfix_crop=self.nonfix_crop,
                                                stride_ratio = self.stride_ratio,
                                                densityTypeIndex = self.densityTypeIndex)
         
        self.val_generator = UNet_3D_Generator(self.valfiles['data_paths'],self.valfiles['mask_paths'],
                                             self.valfiles['info_paths'],input_size=self.input_shape,
                                             batch_size = self.val_batch_size,aug=False,
                                             HU_min = self.normalizer[0],HU_max=self.normalizer[1],crop_choice=True,
                                              config_dict=self.decoder_parameter_list,nonfix_crop=self.nonfix_crop,
                                              stride_ratio = self.stride_ratio,densityTypeIndex = self.densityTypeIndex)
        
        current_DT = datetime.datetime.now()
        current_day = '%02d%02d' % (current_DT.month,current_DT.day)

        self.current_base_path = '%s/%s_%s/'%(self.base_path,self.prefix,current_day)
        if not os.path.exists(self.current_base_path):
            os.mkdir(self.current_base_path)

        # keep a copy of the model parameter files next to the checkpoints
        for target_filename in [self.model_parameters_txt,self.model_map_parameters_txt]:
            shutil.copyfile(target_filename,os.path.join(self.current_base_path,os.path.basename(target_filename)))

        # checkpoint filename pattern passed to get_callbacks
        self.filepath = str('%s/Train-{epoch:02d}-{val_loss:.5f}.hdf5'%(self.current_base_path))
        self.callbacks_list = get_callbacks(self.filepath,initial_learning_rate = self.initial_learning_rate,
                                  learning_rate_drop = self.learning_rate_drop,learning_rate_epochs = None,
                                  learning_rate_patience = self.learning_rate_patience,
                                  early_stopping_patience = self.early_stop)

    def _calculate_ratio(self):
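        '''
        Derive focal-loss class weights from the training set: count samples
        per density type (4 classes assumed), normalize by the first class
        count, then take the square root to soften the imbalance.
        '''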
        train_info = load_single_data(self.trainfiles['info_paths'],self.decoder_parameter_list['aux_task'])
        densityTypeConvertMap = self.decoder_parameter_list['densityTypeConvertMap']
        keys = sorted(densityTypeConvertMap.keys())
        print ('keys',keys)
        counts = [0 for _ in range(4)]
        for info in train_info:
            val = str(int(float(info[self.densityTypeIndex])))
            if val not in keys:
                continue
            counts[densityTypeConvertMap[val]-1] +=1
        
        focal_class_weights = [val/float(counts[0]) for val in counts]
        focal_class_weights = [val**0.5 for val in focal_class_weights]
        # focal_class_weights[2] = focal_class_weights[2] * 1.4

        self.focal_class_weights = [round(val,2) for val in focal_class_weights]
        print ('self.focal_class_weights is ',self.focal_class_weights)

    def Train(self):
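        '''
        End-to-end training: build and compile the model, load pretrained
        weights if available, set up the generators and callbacks, then fit.
        '''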
        
        self.BuildSegModel()
        print ('After building model')
        
        
        self.CompileSegModel()
        print ('After compiling model')

        self.LoadPretrainModel()
        print ('After loading pretrain model')
        self.GenerateTrainParameters()
   
                                  
        self.history_callbacks = self.model.fit_generator(self.train_generator,steps_per_epoch=self.train_step_per_epoch,epochs=self.n_epochs,
                                                     callbacks=self.callbacks_list,shuffle=self.shuffle,
                                                     validation_data=self.val_generator,validation_steps=self.valid_step_per_epoch)
        
        WriteJson(os.path.join(self.current_base_path,'history.json'),self.history_callbacks.history)
        print ('After saving config to file %s'%os.path.join(self.current_base_path,'history.json'))
                                                    
    
    def MakeInference(self):
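        '''
        Run the model over every test sample and binarize the predicted
        probability maps at mask_th, picking the relevant output head
        according to the decoder configuration (deep supervision / ACF).
        '''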
        final_images,final_masks,result,patch_shapes = [],[],[],[]
        for single_data,single_mask,single_info in zip(self.test_data,self.test_mask,self.test_info):
            if self.nonfix_crop:
                patch_shape = CalculatePatchSizeWithDiameter(single_mask,stride_ratio=self.stride_ratio,aug=False)
            else:
                patch_shape = self.input_shape
            patch_shapes.append(patch_shape)
            current_image,current_mask = Data_preprocess(single_data,single_mask,aug=False,input_shape=patch_shape,shift_ratio=1.0)
            # window the HU values and rescale to [0,1]
            current_image = np.clip(current_image,self.normalizer[0],self.normalizer[1])
            current_image = (current_image - self.normalizer[0])/float(self.normalizer[1]-self.normalizer[0])

            current_result = self.model.predict(current_image)
            result.append(current_result)
            final_images.append(current_image)
            final_masks.append(current_mask)

        mask_th = 0.05
        if self.decoder_parameter_list['deep_supervision']:
            # with deep supervision the last (third) output is the final segmentation
            target_idx = 2
            final_result = [single_result[target_idx]>mask_th for single_result in result]
        elif self.decoder_parameter_list['ACF_choice']:
            # with ACF the second output is the refined segmentation
            target_idx = 1
            final_result = [single_result[target_idx]>mask_th for single_result in result]
        else:
            final_result = [single_result>mask_th for single_result in result]
        # morphological post-processing per sample, using each sample's own patch shape
        final_prediction = [post_process(mask_image[0],spacing_z=1.0,inp_shape=shape,debug=False)[2]
                            for mask_image,shape in zip(final_result,patch_shapes)]
        return final_result,final_prediction,final_masks,final_images


    def Inference(self):
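        '''
        Run inference on the test set and write overall and per-diameter
        dice statistics to inference_result.json next to the model weights.
        '''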
        
        self.BuildSegModel()
        self.LoadPretrainModel()
        
        print ('Before Loading test data')
        self.LoadInferenceData()
        print ('After loading test data')
        
        '''
        TODO: filter out unqualified results
        '''
        final_result,final_prediction,test_masks_post,test_images_post = self.MakeInference()
        
        print ('After making prediction')
        inference_result_dict = {
            'model_weights':self.pre_model_file,
            'test_data':self.testfiles,
            'result':{}
        }

        self.final_result = final_result
        self.final_prediction = final_prediction
        self.test_masks_post = test_masks_post
        self.test_images_post = test_images_post
        result_dice = [dice_coefficient_post(np.squeeze(y_true),np.squeeze(y_pred)) for y_true,y_pred in zip(test_masks_post,final_prediction)]
        result_dice_infer = [dice_coefficient_post(np.squeeze(y_true),np.squeeze(y_pred)) for y_true,y_pred in zip(test_masks_post,final_result)]

        self.result_dice = result_dice
        self.result_dice_infer = result_dice_infer

        inference_result_dict['result']['overall_post_dice'] = [np.mean(result_dice),np.std(result_dice)]
        inference_result_dict['result']['overall_infer_dice'] = [np.mean(result_dice_infer),np.std(result_dice_infer)]
        print ('Overall dice value on infer result is ',inference_result_dict['result']['overall_infer_dice'][0])
        print ('Overall dice value on post result is ',inference_result_dict['result']['overall_post_dice'][0])

        # evaluate dice per nodule-diameter bin (mm); the very large upper
        # bounds act as open-ended bins. Column 4 of test_info holds the diameter.
        diameter_ths = [[0,4],[4,10],[10,30],[30,100000],[4,1000000]]
        diameter_idx = 4

        for diameter_th in diameter_ths:
            diameter_min,diameter_max = diameter_th
            indices = [sample_idx for sample_idx in range(self.test_info.shape[0]) if self.test_info[sample_idx][diameter_idx].astype('float')<diameter_max and 
                        self.test_info[sample_idx][diameter_idx].astype('float')>=diameter_min]
            current_dice_post = [result_dice[sample_idx] for sample_idx in indices]
            current_dice_infer = [result_dice_infer[sample_idx] for sample_idx in indices]

            inference_result_dict['result']['Diameter_[%d,%d]_post_dice'%(diameter_min,diameter_max)] = [np.mean(current_dice_post),np.std(current_dice_post)]
            inference_result_dict['result']['Diameter_[%d,%d]_infer_dice'%(diameter_min,diameter_max)] = [np.mean(current_dice_infer),np.std(current_dice_infer)]
            print ('dice value of nodules between [%d,%d] on infer result is '%(diameter_min,diameter_max),inference_result_dict['result']['Diameter_[%d,%d]_infer_dice'%(diameter_min,diameter_max)][0])
            print ('dice value of nodules between [%d,%d] on post result is '%(diameter_min,diameter_max),inference_result_dict['result']['Diameter_[%d,%d]_post_dice'%(diameter_min,diameter_max)][0])


        inference_json_path = os.path.join(os.path.dirname(self.pre_model_file),'inference_result.json')
        print ('Before writing json file to %s'%inference_json_path)
        WriteJson(inference_json_path,inference_result_dict)
        print ('After writing json file to %s'%inference_json_path)
        
    
    def LoadInferenceData(self):
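        '''
        Load the test volumes, masks and per-sample info arrays.
        '''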
        self.test_data_list = load_data([self.testfiles['data_paths'],self.testfiles['mask_paths'],
                                                 self.testfiles['info_paths']])
        self.test_data,self.test_mask,self.test_info = self.test_data_list
        
    
    
    def LoadPretrainModel(self):
        print ('load model ',self.pre_model_file)
        if os.path.exists(self.pre_model_file) and self.model:
            self.model.load_weights(self.pre_model_file)
        elif not os.path.exists(self.pre_model_file):
            print ('Pretrained model path does not exist')
        else:
            print ('Model has not been built yet')



    def Visualize(self,dice_list,iou_th,diameter_th=[0,float('inf')],densityType=np.arange(1,10000)):
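        '''
        Show axial slices (image / ground-truth mask / prediction) for samples
        whose dice value falls below iou_th (despite its name, the threshold
        is compared against the dice values passed in), filtered by diameter
        range and density type.
        '''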
        for idx in range(len(dice_list)):
            dice_value = dice_list[idx]
            current_data = np.squeeze(self.test_images_post[idx])
            current_mask = np.squeeze(self.test_masks_post[idx])
            current_info = np.squeeze(self.test_info[idx])
            current_result = np.squeeze(self.final_result[idx])
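            # assumed info-array layout: column 4 = diameter, column -4 = density type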
            diameter_idx = 4
            densityType_idx = -4
            print ('current_info[diameter_idx]',current_info[diameter_idx],current_info[densityType_idx],dice_value)
            if float(current_info[diameter_idx])>=diameter_th[0] and float(current_info[diameter_idx])<diameter_th[1]:
                if int(float(current_info[densityType_idx])) in set(densityType):
                    if dice_value<iou_th:
                        print ('='*60)
                        for slice_idx in range(current_mask.shape[0]):
                            if np.sum(current_mask[slice_idx])>0 or np.sum(current_result[slice_idx])>0:
                                _,axs = plt.subplots(1,3)
                                axs[0].imshow(current_data[slice_idx],cmap='bone')
                                axs[1].imshow(current_mask[slice_idx],cmap='bone')
                                axs[2].imshow(current_result[slice_idx],cmap='bone')
                                plt.show()


    def _DataStat(self):
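        '''
        Summarize the train/val sample statistics per diameter bin.
        '''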
        diameter_ths = [[0,1000000000],[0,4],[4,10],[10,20],[20,30],[30,1000000],[4,100000000000000000]]
        info_list = self.trainfiles['info_paths'] + self.valfiles['info_paths']
        for path in info_list:
            print ('path is %s'%path)
        DataStat(info_list,diameter_ths)
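

if __name__ == '__main__':
    # Minimal usage sketch (assumptions): the config file name and the .npy
    # paths below are hypothetical. The cfg schema is inferred from the
    # accesses in Segmentation3D.__init__: a 'training' hyper-parameter dict,
    # a 'mode' key, and lists of data/mask/info .npy paths per split.
    config = LoadJson('./files/train_config.json')  # hypothetical file name
    cfg = {
        'mode': 'training',
        'training': config[list(config.keys())[0]],
        'train_data_path': ['./data/train_data.npy'],
        'train_mask_path': ['./data/train_mask.npy'],
        'train_info_path': ['./data/train_info.npy'],
        'val_data_path': ['./data/val_data.npy'],
        'val_mask_path': ['./data/val_mask.npy'],
        'val_info_path': ['./data/val_info.npy'],
    }
    trainer = Segmentation3D(cfg)
    trainer.Train()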