# -*- coding:utf-8 -*-
import os
import sys
import six
import glob
import json
import random
import shutil
import datetime
import numpy as np
import pandas as pd
import scipy.ndimage as nd
from tqdm import tqdm
import datetime
import logging
import torch
import torch.nn as nn
from torch import optim
from torch.utils.data import DataLoader
from torch.nn.utils import clip_grad_norm_
# from matplotlib import pyplot as plt
from tensorboardX import SummaryWriter

sys.path.append('./utils')
sys.path.append('./model')
from .utils.RWconfig import LoadJson
from .utils.ReadData import load_data,load_single_data
from .model.modelBuild import build_model
from .model.BasicModules import weight_init
from .utils.DataGen import SurDataSet
from .utils.OnlineEval import eval_net
from .utils.OfflineEval import CalculateClsScore,CalculateAuc,CalculateClsScoreByTh
from .utils.loss_func import BCEFocalLoss,PeerLoss,MIFocalLoss,FocalLoss
from .utils.gradcam import GradCam,GuidedBackpropReLUModel


# Configure the root logger at import time: every record is written both to
# Train.log (UTF-8, relative to the current working directory) and to the
# console, with one shared timestamped format.
logger = logging.getLogger()
formatter = logging.Formatter('%(asctime)s %(filename)s[line:%(lineno)d] %(levelname)s %(message)s')
fh = logging.FileHandler('Train.log',encoding='utf-8')
sh = logging.StreamHandler()
for _handler in (fh, sh):
    _handler.setFormatter(formatter)
    logger.addHandler(_handler)
del _handler
logger.setLevel(logging.DEBUG)


class Classification3D(object):
    """Config-driven training / inference pipeline for a 3D patch classifier.

    One section (``mode``) of a JSON config drives everything. Calling the
    instance builds the network, then trains it (``mode == 'training'``) or
    evaluates the checkpoint(s) named by ``cfg['weight_path']`` (any other mode).
    """

    def __init__(self, filename, mode):
        """Load the ``mode`` section of the config and resolve every path it names.

        Args:
            filename: path of the JSON experiment config.
            mode: config section to read. ``'training'`` enables training,
                TensorBoard logging and checkpoint output; any other value
                prepares inference on ``cfg['test_file']``.
        """
        self.mode = mode
        self.filename = filename
        self.cfg = LoadJson(self.filename)[mode]
        # Config-relative files are resolved against the parent of this file's directory.
        self.basic_file_path = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
        current_DT = datetime.datetime.now()
        self.current_day = '%02d%02d' % (current_DT.month, current_DT.day)
        self.convert_dict = self.cfg.get('convert_dict', {})
        self._clsInfoDecode()
        if mode == 'training':
            writer_path = os.path.join(self.save_path, 'runs')
            self.writer = SummaryWriter(writer_path)
            self.train_writer = SummaryWriter(os.path.join(writer_path, 'train'))
            self.val_writer = SummaryWriter(os.path.join(writer_path, 'val'))
            self.trainfile_conf_path = os.path.join(self.basic_file_path, self.cfg['train_file'])
            self.valfile_conf_path = os.path.join(self.basic_file_path, self.cfg['val_file'])
            self.trainfiles = self._DecodeDataParam(self.trainfile_conf_path)
            self.valfiles = self._DecodeDataParam(self.valfile_conf_path)
            self.testfiles = []
        else:
            self.writer = None
            self.train_writer = None
            self.val_writer = None
            self.trainfiles = []
            self.valfiles = []
            self.test_conf_path = os.path.join(self.basic_file_path, self.cfg['test_file'])
            self.testfiles = self._DecodeDataParam(self.test_conf_path)

        self.ParametersDecode()
        self.device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

    def __call__(self):
        """Build and compile the model, load the data, then train or test.

        The TensorBoard writers (training mode only) are closed even when a
        step raises.
        """
        try:
            self._buildModel()
            self._compileModel()
            self._LoadData(self.mode)
            if self.mode == 'training':
                self._trainModel()
            else:
                import time
                start_time = time.time()
                self._testModel()
                dur = time.time() - start_time
                print('========================================== It takes %s s to make inference' % str(dur))
        finally:
            for writer in (self.writer, self.train_writer, self.val_writer):
                if writer:
                    writer.close()

    def _clsInfoDecode(self):
        """Decode the class/label settings and derive the output directory.

        Sets ``cls_map_dict``, ``ori_cls_types`` (set of mapped class values),
        ``input_shape``, ``cls_name_map_dict`` and ``save_path``. Creates
        ``save_path`` in training mode. Stores the number of label tasks in
        ``cfg['num_task']``: ``len(cls_label_index)`` for a list, otherwise 1.
        """
        # NOTE(review): eval() executes arbitrary code taken from the config file.
        # If the config is not fully trusted, ast.literal_eval is the safe
        # replacement for plain dict literals.
        self.cls_map_dict = eval(self.cfg['cls_map_dict'])
        self.ori_cls_types = set(self.cls_map_dict.values())
        # NOTE(review): patch_height is used for both in-plane sizes (square patches);
        # confirm that no separate width setting was intended.
        self.input_shape = [self.cfg['patch_height'], self.cfg['patch_height'], self.cfg['patch_depth']]
        self.cls_name_map_dict = eval(self.cfg['cls_name_map_dict'])
        model_order = self.cfg.get('model_order', 1)
        label_index = self.cfg['cls_label_index']
        if isinstance(label_index, int):
            cls_related_info = self.cls_name_map_dict[label_index]
        else:
            cls_related_info = '_'.join([self.cls_name_map_dict[val] for val in label_index])
        self.save_path = '%s/%s_%s_%dCls_%s_Patch%03d_model%s/%02d_%s/' % (
            self.cfg['base_path'], self.cfg['weights_pre'], cls_related_info,
            self.cfg['num_class'], self.current_day, self.input_shape[0],
            self.cfg['model_name'], model_order, self.cfg['weight_memo'])
        if self.mode == 'training':
            os.makedirs(self.save_path, exist_ok=True)
        self.cfg['num_task'] = len(label_index) if isinstance(label_index, list) else 1

    def _buildModel(self):
        """Instantiate ``self.net`` from ``self.model_params`` and apply ``weight_init``."""
        cfg = self.model_params
        cfg['num_task'] = self.cfg['num_task']
        # A 2-class config is built with a single output (used with sigmoid / BCE).
        cfg['num_class'] = 1 if self.cfg['num_class'] == 2 else self.cfg['num_class']
        self.net = build_model(self.cfg['model_name'], cfg)
        logger.info("After building model")
        self.net.apply(weight_init)
        logger.info("After Weight Init")

    @staticmethod
    def _listInputLoss(loss_fn):
        """Adapt ``loss_fn(input, target)`` so it can be called as ``loss([input, target])``.

        Every call site in this class invokes the criterion with a single
        ``[prediction, target]`` list, as the focal/peer losses expect. The stock
        ``torch.nn`` losses take two positional arguments instead. The plain
        two-argument form is still accepted.
        """
        def loss(pair, *rest):
            if rest:
                return loss_fn(pair, *rest)
            prediction, target = pair
            return loss_fn(prediction, target)
        return loss

    def _compileModel(self):
        """Create the optimizer, LR scheduler and default criterion.

        In training mode, also copy the experiment, model and data-list configs
        into ``self.save_path``.
        """
        self.optimizer = optim.Adam(self.net.parameters(), lr=self.cfg['learning_rate'])
        self.scheduler = optim.lr_scheduler.ReduceLROnPlateau(self.optimizer, patience=self.cfg['patience'])
        if self.cfg['num_class'] > 2:
            base_loss = nn.CrossEntropyLoss()
        else:
            base_loss = nn.BCEWithLogitsLoss()
        self.criterion = self._listInputLoss(base_loss)
        if self.mode == 'training':
            conf_paths = [self.filename, self.model_param_name,
                          self.trainfile_conf_path, self.valfile_conf_path]
            for conf_path in conf_paths:
                target_save_path = os.path.join(self.save_path, os.path.basename(conf_path))
                shutil.copy(conf_path, target_save_path)

    def _DecodeDataParam(self, filename):
        """Read a data-list config and split it into parallel path lists.

        Args:
            filename: JSON file holding a list of records, each with
                ``'data_path'`` and ``'info_path'`` entries.

        Returns:
            dict with ``'data_paths'`` and ``'info_paths'`` lists in record order.
        """
        records = LoadJson(filename)
        return {'data_paths': [record['data_path'] for record in records],
                'info_paths': [record['info_path'] for record in records]}

    def _LoadData(self, mode):
        """Load the train/val splits (``mode == 'training'``) or the test split via ``load_data``."""
        if mode == 'training':
            self.train_data, self.train_info = load_data([self.trainfiles['data_paths'], self.trainfiles['info_paths']])
            self.val_data, self.val_info = load_data([self.valfiles['data_paths'], self.valfiles['info_paths']])
        else:
            self.test_data, self.test_info = load_data([self.testfiles['data_paths'], self.testfiles['info_paths']])

    def _buildDataSet(self, data, info, **extra_kwargs):
        """Create a ``SurDataSet`` with the config-wide settings plus ``extra_kwargs``."""
        return SurDataSet(data, info,
                          cls_index=self.cfg['cls_label_index'],
                          cls_map_dict=self.cls_map_dict,
                          diameter_index=self.cfg['diameter_index'],
                          num_class=self.cfg['num_class'],
                          HU_min=self.cfg['HU_min'],
                          HU_max=self.cfg['HU_max'],
                          input_shape=self.input_shape,
                          cls_types=self.ori_cls_types,
                          diameter_max_th=self.cfg['diameter_max_th'],
                          diameter_min_th=self.cfg['diameter_min_th'],
                          **extra_kwargs)

    def _configureLoss(self, train_dataset, gamma):
        """Replace ``self.criterion`` with a class-ratio-weighted loss when ``cfg['focal_loss']`` is set.

        Only single-task configs are handled; multi-task runs keep the default
        criterion from ``_compileModel``.

        Returns:
            The ``alpha`` value later passed to ``eval_net``: ``ratios[0]`` for
            binary tasks, the full ratios for multi-class tasks, or ``None`` when
            focal loss is not used.
        """
        if not (self.cfg['focal_loss'] and self.cfg['num_task'] == 1):
            return None
        ratios = train_dataset._GenerateRatios()
        if self.cfg['num_class'] > 2:
            self.criterion = FocalLoss(self.cfg['num_class'], alpha=ratios, gamma=gamma)
            return ratios
        alpha = ratios[0]
        print('=' * 60)
        print('alpha is ', alpha)
        train_focal_loss = self.cfg.get('train_focal_loss', self.cfg['focal_loss'])
        loss_choice = self.cfg.get('loss_func', 'focal')
        print('loss choice is ', loss_choice)
        if loss_choice != 'focal':
            train_focal_loss = False
        if train_focal_loss:
            self.criterion = BCEFocalLoss(alpha=alpha, gamma=gamma)
        elif loss_choice == 'peer':
            self.criterion = PeerLoss(alpha=alpha, gamma=gamma)
        elif loss_choice == 'MIFocalLoss':
            self.criterion = MIFocalLoss(alpha=alpha, gamma=gamma)
        return alpha

    def _labelToDevice(self, label, device):
        """Move a batch label to ``device``.

        Labels become long for multi-class configs and float32 otherwise.
        Multi-task binary labels are converted element by element into a list.
        """
        if self.cfg['num_class'] > 2:
            return label.to(device=device, dtype=torch.long)
        if self.cfg['num_task'] > 1:
            return [val.to(device=device, dtype=torch.float32) for val in label]
        return label.to(device=device, dtype=torch.float32)

    def _batchLoss(self, criterion, prediction, label):
        """Compute the training loss of one batch.

        For multi-task configs, the per-task losses are summed so that every
        task contributes to the gradient (previously only the last task's loss
        was back-propagated).
        """
        if self.cfg['num_task'] > 1:
            task_losses = [criterion([task_pred.squeeze(-1), task_label])
                           for task_label, task_pred in zip(label, prediction)]
            return sum(task_losses)
        if self.cfg['num_class'] <= 2:
            if isinstance(prediction, list):
                prediction = prediction[0]
            prediction = prediction.squeeze(-1)
        return criterion([prediction, label])

    def _updateSampleSelection(self, train_dataset, device, epoch, gamma):
        """Apply the end-of-epoch strategy named by ``cfg['select_samples']``.

        ``'partial'``: pass per-sample losses to ``train_dataset._updateIndices``.

        ``'boost'``: from ``initial_epoch`` on, every ``epoch_interval`` epochs,
        pass per-sample ``|sigmoid(prediction) - label|`` to ``_updateWeights``,
        reselect samples and (binary configs) rebuild a ``BCEFocalLoss`` from the
        new class ratios.

        Any other value does nothing.
        """
        strategy = self.cfg['select_samples']
        if strategy == 'partial':
            loss_list = self._CalculateLossTrain(train_dataset, device)
            train_dataset._updateIndices(loss_list, ratio_pos=self.cfg['ratio_pos'], num_hard=100)
        elif strategy == 'boost':
            # NOTE(review): the losses are unused here; the pass is kept because it
            # also resets train_dataset.mode and net.train() as before.
            self._CalculateLossTrain(train_dataset, device)
            initial_epoch = self.cfg.get('initial_epoch', 0)
            epoch_interval = self.cfg.get('epoch_interval', 1)
            if epoch >= initial_epoch and epoch % epoch_interval == 0:
                diff_list = self._CalculateDiffTrain(train_dataset, device)
                increase_choice = self.cfg.get('increase_choice', False)
                train_dataset._updateWeights(diff_list, increase=increase_choice)
                train_dataset._SelectSamplesBasedOnWeights()
                new_ratios = train_dataset._GenerateRatios()
                # Same binary test (<= 2) as _configureLoss.
                if self.cfg['num_class'] <= 2:
                    self.criterion = BCEFocalLoss(alpha=new_ratios[0], gamma=gamma)

    def _trainModel(self):
        """Train ``self.net`` for ``cfg['num_epoch']`` epochs.

        After every epoch:
        - the sample-selection strategy is applied;
        - the net is scored with ``eval_net`` on the validation set;
        - the LR scheduler is stepped with that score;
        - a checkpoint is written to ``self.save_path`` if ``cfg['save_cp']`` is set.
        """
        gamma = self.cfg.get('gamma', 1)
        device = self.device
        nonfix_crop = self.cfg.get('nonfix_crop', False)
        train_aug_choice = self.cfg.get('train_aug_choice', False)

        train_dataset = self._buildDataSet(self.train_data, self.train_info,
                                           aug=train_aug_choice, nonfix_crop=nonfix_crop)
        val_dataset = self._buildDataSet(self.val_data, self.val_info, nonfix_crop=nonfix_crop)

        alpha = self._configureLoss(train_dataset, gamma)

        # One flag per split (train, val): 0 -> _GenerateTrainData(mode='val'),
        # any other value -> mode='train'.
        balance_data_kws = self.cfg.get('balance_data', [0, 0])
        modes = ['val' if val == 0 else 'train' for val in balance_data_kws]
        train_dataset._GenerateTrainData(mode=modes[0])
        val_dataset._GenerateTrainData(mode=modes[1])

        batch_size = self.cfg['batch_size']
        train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
        val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=True)

        epochs = self.cfg['num_epoch']
        for epoch in range(epochs):
            self.net.train()
            self.net.to(device)
            # Re-read every epoch: the 'boost' strategy may replace self.criterion.
            criterion = self.criterion
            epoch_loss = 0
            print('size of train dataset is ', len(train_dataset))
            with tqdm(total=len(train_loader), desc=f'Epoch{epoch+1}/{epochs}', unit='img') as pbar:
                batch_count = 0
                for batch in train_loader:
                    batch_count += 1
                    imgs = batch[0][0].to(device=device, dtype=torch.float32)
                    label = self._labelToDevice(batch[1][0], device)
                    prediction = self.net(imgs)
                    loss = self._batchLoss(criterion, prediction, label)
                    epoch_loss += loss.item()
                    self.optimizer.zero_grad()
                    loss.backward()
                    # Element-wise gradient clipping to [-0.1, 0.1].
                    nn.utils.clip_grad_value_(self.net.parameters(), 1e-1)
                    self.optimizer.step()
                    pbar.update(1)

            self._updateSampleSelection(train_dataset, device, epoch, gamma)
            print('batch_count is', batch_count)
            random.shuffle(train_dataset.indices)
            val_score = eval_net(self.net, val_dataset, val_loader, device, alpha=alpha, cfg=self.cfg) * 100
            # NOTE(review): ReduceLROnPlateau is created with its default mode='min';
            # confirm that a lower eval_net value is better.
            self.scheduler.step(val_score)
            print('epoch_loss is', epoch_loss, len(train_dataset))
            train_loss = epoch_loss / float(len(train_dataset)) * 100
            logger.info("Train loss: {}".format(train_loss))
            logging.info("Validation loss: {}".format(val_score))

            if self.cfg['save_cp']:
                model_save_path_current_epoch = os.path.join(
                    self.save_path, 'CP_epoch%d_loss_%.2f.pth' % (epoch + 1, val_score))
                torch.save(self.net.state_dict(), model_save_path_current_epoch)
                logging.info(f'Checkpoint{epoch + 1} saved to %s!' % model_save_path_current_epoch)
            if self.train_writer:
                self.train_writer.add_scalar('loss', train_loss, epoch)
            if self.val_writer:
                self.val_writer.add_scalar('loss', val_score, epoch)

    def _outputToProbs(self, output):
        """Convert the raw network output for one sample into numpy probabilities.

        Single-task: softmax (multi-class) or sigmoid output with the batch axis
        removed; a list output is unwrapped to its first element.

        Multi-task: a list with one entry per task, indexing the per-task outputs.
        """
        if self.cfg['num_task'] == 1:
            if isinstance(output, list):
                output = output[0]
            if self.net.num_class > 2:
                probs = torch.softmax(output, dim=1)
            else:
                probs = torch.sigmoid(output)
            return probs.squeeze(0).data.cpu().numpy()
        current_probs = []
        for task_idx in range(self.cfg['num_task']):
            if self.net.num_class > 1:
                probs = torch.softmax(output[task_idx], dim=1)
            else:
                probs = torch.sigmoid(output[task_idx])
            current_probs.append(probs.squeeze(0).cpu().numpy()[0])
        return current_probs

    def _scoreCurrentResults(self, print_thresholds):
        """Score ``self.result_list`` against ``self.label_list``.

        Returns:
            - single-task binary: the AUC (the threshold sweep is also printed
              when ``print_thresholds`` is true);
            - multi-class: ``CalculateClsScore``'s result on argmax predictions;
            - multi-task binary: ``None`` (per-task metrics are only printed).
        """
        if self.net.num_class < 2:
            if self.cfg['num_task'] == 1:
                th_result = [val > 0.5 for val in self.result_list]
                CalculateClsScore(th_result, self.label_list)
                auc_val = CalculateAuc(self.result_list, self.label_list)
                if print_thresholds:
                    CalculateClsScoreByTh(self.result_list, self.label_list, acc_flag=True)
                return auc_val
            for task_idx in range(self.cfg['num_task']):
                print('=' * 60)
                print('Current task is %s' % self.cls_name_map_dict[self.cfg['cls_label_index'][task_idx]])
                current_label_list = np.squeeze(np.array([val[task_idx] for val in self.label_list]))
                current_result_list = np.squeeze(np.array([val[task_idx] for val in self.result_list]))
                th_result = np.squeeze(np.array([val > 0.5 for val in current_result_list]))
                print(np.array(current_label_list).shape, np.array(current_result_list).shape, np.array(th_result).shape)
                CalculateClsScore(th_result, current_label_list)
                CalculateAuc(current_result_list, current_label_list)
            return None
        result = np.argmax(np.squeeze(np.array(self.result_list)), axis=1)
        kappa_val = CalculateClsScore(result, self.label_list)
        return kappa_val

    def _testModel(self):
        """Evaluate every checkpoint selected by ``cfg['weight_path']`` on the test set.

        ``weight_path`` is either one checkpoint file or a glob pattern (any path
        containing ``*``); for a pattern, every matching ``.pth`` file is
        evaluated in sorted order. For each checkpoint:
        - per-sample labels and probabilities are stored in ``self.label_list``
          and ``self.result_list``;
        - the results are scored.
        The checkpoint with the highest score is printed at the end.

        Raises:
            FileNotFoundError: if the glob pattern matches no ``.pth`` file.
        """
        weight_path = self.cfg['weight_path']
        if '*' in weight_path:
            model_paths = [path for path in sorted(glob.glob(weight_path)) if '.pth' in path]
        else:
            model_paths = [weight_path]
        if not model_paths:
            raise FileNotFoundError('No .pth checkpoint matches %s' % weight_path)

        device = self.device
        nonfix_crop = self.cfg.get('nonfix_crop', False)
        results = []
        for model_path in model_paths:
            print('=' * 60)
            print('model_path is ', model_path)

            self.net.to(device=device)
            self.net.load_state_dict(torch.load(model_path, map_location=device))
            logging.info('Model loaded !')

            # Put the model in eval mode to freeze BatchNorm statistics and disable dropout.
            self.net.eval()
            test_dataset = self._buildDataSet(self.test_data, self.test_info,
                                              nonfix_crop=nonfix_crop, aug=False, mode='test')
            test_dataset._GenerateTrainData()
            self.label_list = []
            self.result_list = []
            self.badcase_lists = []
            for idx in range(len(test_dataset)):
                current_data = test_dataset[idx]
                label = current_data[1][0]
                img = torch.from_numpy(current_data[0][0][np.newaxis, ...])
                img = img.to(device=device, dtype=torch.float32)
                with torch.no_grad():
                    probs = self._outputToProbs(self.net(img))
                self.label_list.append(label)
                self.result_list.append(probs)
            print('length of label and result', len(self.label_list), len(self.result_list))
            print('shape of result list is ', np.array(self.result_list).shape)
            print('max val of self.result_list', np.amin(self.result_list), np.amax(self.result_list))
            score = self._scoreCurrentResults(print_thresholds=len(model_paths) == 1)
            if score is not None:
                results.append(score)
        if results:
            print('results is ', results)
            pos = np.argmax(results)
            print('best weights is ', model_paths[pos])
        self.test_dataset = test_dataset

    def ParametersDecode(self):
        '''
        Load the parameter section for ``cfg['model_name']`` from the model-parameter
        JSON named by ``cfg['model_params']`` into ``self.model_params``.
        '''
        self.model_param_name = os.path.join(self.basic_file_path, self.cfg['model_params'])
        self.model_params = LoadJson(self.model_param_name)[self.cfg['model_name']]

    def _prepareSample(self, dataset, idx, device):
        """Fetch sample ``idx`` from ``dataset`` as a batch of one on ``device``.

        Returns:
            ``(data, label)``: float32 tensors. For multi-task configs, label is
            a list built by iterating the label tensor.
        """
        data, label, _ = dataset[idx]
        data = torch.from_numpy(data[0][np.newaxis, ...]).to(device=device, dtype=torch.float32)
        label = torch.from_numpy(np.array([label[0]]))
        if self.cfg['num_task'] == 1:
            label = label.to(device=device, dtype=torch.float32)
        else:
            label = [val.to(device=device, dtype=torch.float32) for val in label]
        return data, label

    def _CalculateLossTrain(self, train_dataset, device):
        """Compute the current criterion value for every sample of ``train_dataset``.

        The dataset is switched to ``'test'`` mode and the net to eval mode for
        the pass. Both are restored to ``'train'`` afterwards, even on error.

        Returns:
            list of ``[idx, loss]``; for multi-task configs ``loss`` is a list
            with one value per task.
        """
        loss_list = []
        print('number of data inside train dataset is ', len(train_dataset.data))
        train_dataset.mode = 'test'
        self.net.eval()
        try:
            for idx in range(len(train_dataset.data)):
                data, label = self._prepareSample(train_dataset, idx, device)
                with torch.no_grad():
                    prediction = self.net(data)
                if self.cfg['num_task'] == 1:
                    if isinstance(prediction, list):
                        prediction = prediction[0]
                    prediction = prediction.squeeze(-1)
                    if self.cfg['num_class'] > 2:
                        label = label.to(device=device, dtype=torch.long)
                    loss_list.append([idx, self.criterion([prediction, label]).item()])
                else:
                    current_loss = [self.criterion([x.squeeze(-1), y]).item()
                                    for x, y in zip(prediction, label)]
                    loss_list.append([idx, current_loss])
        finally:
            train_dataset.mode = 'train'
            self.net.train()
        return loss_list

    def _CalculateDiffTrain(self, train_dataset, device):
        """Compute ``|sigmoid(prediction) - label|`` for every sample of ``train_dataset``.

        The dataset and net are switched to test/eval mode for the pass and
        restored to ``'train'`` afterwards, even on error.
        NOTE(review): only single-task labels are supported here.

        Returns:
            list of ``[idx, diff_tensor]``.
        """
        diff_list = []
        print('number of data inside train dataset is ', len(train_dataset.data))
        train_dataset.mode = 'test'
        self.net.eval()
        try:
            for idx in range(len(train_dataset.data)):
                data, label = self._prepareSample(train_dataset, idx, device)
                with torch.no_grad():
                    prediction = self.net(data)
                    if self.cfg['num_task'] == 1 and isinstance(prediction, list):
                        prediction = prediction[0]
                    prediction = torch.sigmoid(prediction.squeeze(-1))
                diff_list.append([idx, abs(prediction - label)])
        finally:
            train_dataset.mode = 'train'
            self.net.train()
        return diff_list