# -*- coding:utf-8 -*-
import os
import sys
import six
import glob
import json
import random
import shutil
import datetime
import numpy as np
import pandas as pd
import scipy.ndimage as nd
from tqdm import tqdm
import datetime
import logging
import torch
import torch.nn as nn
from torch import optim
from matplotlib import pyplot as plt
from torch.utils.data import DataLoader
from torch.nn.utils import clip_grad_norm_
import copy
from tensorboardX import SummaryWriter

sys.path.append('./utils')
sys.path.append('./model')
from .utils.RWconfig import LoadJson
from .utils.ReadData import load_data, load_single_data
from .model.modelBuild import build_model
from .model.BasicModules import weight_init
from .utils.DataGen import SurDataSet, InferDataSet, TestDataSet
from .utils.OnlineEval import eval_net, eval_netSE, eval_netFSE
from .utils.OfflineEval import CalculateClsScore, CalculateAuc, CalculateClsScoreByTh
from .utils.loss_func import *
from .utils.gradcam import GradCam, GuidedBackpropReLUModel
from .utils.ImagePro import ShowEdges, HFFilter
from .utils.TripleLoss import TripletFocalLoss
from .utils.TripleBin import TripletBinFocal
from .utils.GradientBoost import GradientBoostFocal
from k8s_utils import CMetricsWriter
import torch.distributed as dist
from sklearn.metrics import precision_score, recall_score
import copy
from pydicom import dicomio

# Root-logger configuration: every record goes to both ``Train.log`` and the
# console, formatted identically.
formatter = logging.Formatter('%(asctime)s %(filename)s[line:%(lineno)d] %(levelname)s %(message)s')

fh = logging.FileHandler('Train.log', encoding='utf-8')
fh.setFormatter(formatter)

sh = logging.StreamHandler()
sh.setFormatter(formatter)

logger = logging.getLogger()
logger.addHandler(fh)
logger.addHandler(sh)
logger.setLevel(logging.DEBUG)  # numeric value 10

import os
import numpy as np
import SimpleITK as sitk
import cv2

import imageio
from tqdm import tqdm
import sys

from PIL import Image, ImageDraw, ImageFont


def norm(image, hu_min=-1000.0, hu_max=600.0):
    """Window ``image`` to ``[hu_min, hu_max]`` and rescale it to ``[0, 255]``.

    The input is cast to float32, clipped to the window, shifted so that
    ``hu_min`` maps to 0 and scaled so that ``hu_max`` maps to 255.
    """
    window_width = float(hu_max - hu_min)
    clipped = np.clip(image.astype(np.float32), hu_min, hu_max)
    unit_range = (clipped - hu_min) / window_width
    return unit_range * 255.

def img2gif(tmp_path, raw_path):
    """Draw detection boxes on CT slices and build one GIF per case.

    For every entry ``<case>`` in ``tmp_path`` the function loads
    ``<tmp_path>/<case>/all_rpns.npy`` and ``<raw_path>/<case>.nii.gz``,
    writes the annotated slices as PNG files to
    ``<parent of tmp_path>/save_images/<case>/`` and afterwards converts every
    directory found under ``save_images`` into
    ``<parent of tmp_path>/save_gifs/<case>.gif`` (24 fps).

    Cases without any detection, or whose detections lie completely outside
    the volume, are skipped instead of raising.

    Args:
        tmp_path: Directory with one sub-directory per case.
        raw_path: Directory with the ``.nii.gz`` volumes.
    """
    slice_margin = 3  # slices rendered around every detection centre

    target_path = os.path.join(os.path.dirname(tmp_path), 'save_images')
    os.makedirs(target_path, exist_ok=True)

    for file in os.listdir(tmp_path):
        print('deal with {}'.format(file))
        npy_info = np.load(os.path.join(tmp_path, file, 'all_rpns.npy'))
        if npy_info.size == 0:
            # min()/max() below would fail on an empty detection list.
            print('skip {}: no box found'.format(file))
            continue

        raw_img = sitk.GetArrayFromImage(
            sitk.ReadImage(os.path.join(raw_path, file + '.nii.gz')))

        # NOTE(review): the slice index is taken from column 2, while the box
        # unpacking below names column 1 ``z`` and column 2 ``y`` - confirm
        # the column layout of all_rpns.npy.
        all_z = [int(i) for i in npy_info[:, 2]]

        # Clamp to the volume: a negative start would silently wrap around to
        # the last slices and an end beyond the volume raises IndexError.
        first_z = max(0, min(all_z) - slice_margin)
        last_z = min(raw_img.shape[0], max(all_z) + slice_margin)
        if first_z >= last_z:
            print('skip {}: boxes outside of the volume'.format(file))
            continue

        save_path = os.path.join(target_path, file)
        os.makedirs(save_path, exist_ok=True)

        for frame_idx, it_z in enumerate(range(first_z, last_z)):
            # A box is shown on the slices [z - margin, z + margin).
            u_idx = [key for key, z in enumerate(all_z)
                     if z - slice_margin <= it_z < z + slice_margin]
            print('have {} box'.format(len(u_idx)))

            worked_slice = cv2.cvtColor(norm(raw_img[it_z, :, :]), cv2.COLOR_GRAY2BGR)
            for idx in u_idx:
                z, y, _x, dim_z, dim_y, _dim_x = npy_info[idx][1:]
                # NOTE(review): the rectangle is spanned by the (y, z) values
                # although the slice itself is raw_img[it_z, :, :] - verify
                # that these are the intended in-plane coordinates.
                start_y = max(0, y - dim_y / 2)
                end_y = min(worked_slice.shape[0], y + dim_y / 2)
                start_z = max(0, z - dim_z / 2)
                end_z = min(worked_slice.shape[1], z + dim_z / 2)
                cv2.rectangle(worked_slice, (int(start_y), int(start_z)),
                              (int(end_y), int(end_z)), (0, 0, 255), 2)

            # Write each frame immediately instead of keeping the whole
            # sequence in memory.
            cv2.imwrite(os.path.join(save_path, '{:03d}.png'.format(frame_idx)), worked_slice)

    output_gif_path = os.path.join(os.path.dirname(target_path), 'save_gifs')
    os.makedirs(output_gif_path, exist_ok=True)

    fps = 24.0
    for file in os.listdir(target_path):
        case_dir = os.path.join(target_path, file)
        # Frames are named 000.png, 001.png, ...; sort on the stem.
        files = sorted(os.listdir(case_dir), key=lambda name: name[:-4])
        frames = [imageio.imread(os.path.join(case_dir, name)) for name in tqdm(files)]
        if not frames:
            # Nothing to encode for an empty frame directory.
            continue
        imageio.mimsave(os.path.join(output_gif_path, file + '.gif'), frames, 'GIF', duration=1 / fps)


def load_ct_from_dicom(dcm_path, sort_by_distance=True):
    """Read the DICOM files in ``dcm_path`` into one SimpleITK image.

    Args:
        dcm_path: Directory that contains the DICOM files.
        sort_by_distance: If true, every file is parsed, files repeating an
            already seen ``SOPInstanceUID`` are dropped and the remaining
            slices are ordered by the projection of ``ImagePositionPatient``
            onto the slice normal. If false, the file list comes from
            ``ImageSeriesReader.GetGDCMSeriesFileNames``.

    Returns:
        The image returned by ``sitk.ImageSeriesReader.Execute``.
    """

    class DcmInfo(object):
        """Header fields of one DICOM file plus its position along the slice normal."""

        def __init__(self, dcm_path, series_instance_uid, acquisition_number, sop_instance_uid, instance_number,
                     image_orientation_patient, image_position_patient):
            super(DcmInfo, self).__init__()

            self.dcm_path = dcm_path
            self.series_instance_uid = series_instance_uid
            self.acquisition_number = acquisition_number
            self.sop_instance_uid = sop_instance_uid
            self.instance_number = instance_number
            self.image_orientation_patient = image_orientation_patient
            self.image_position_patient = image_position_patient

            self.slice_distance = self._cal_distance()

        def _cal_distance(self):
            # The six orientation values are the row and column direction
            # vectors; their cross product is the slice normal.
            row = self.image_orientation_patient[0:3]
            col = self.image_orientation_patient[3:6]
            normal = [row[1] * col[2] - row[2] * col[1],
                      row[2] * col[0] - row[0] * col[2],
                      row[0] * col[1] - row[1] * col[0]]

            distance = 0
            for i in range(3):
                distance += normal[i] * self.image_position_patient[i]
            return distance

    reader = sitk.ImageSeriesReader()
    if sort_by_distance:
        dcm_infos = []
        seen_sop_uids = set()  # O(1) duplicate check instead of rescanning dcm_infos

        # Sorted so that the kept duplicate and the order of equal distances
        # do not depend on the arbitrary os.listdir order.
        for file in sorted(os.listdir(dcm_path)):
            file_path = os.path.join(dcm_path, file)

            dcm = dicomio.read_file(file_path, force=True)
            dcm_info = DcmInfo(file_path, dcm.SeriesInstanceUID, dcm.AcquisitionNumber, dcm.SOPInstanceUID,
                               dcm.InstanceNumber, dcm.ImageOrientationPatient, dcm.ImagePositionPatient)

            if dcm_info.sop_instance_uid in seen_sop_uids:
                continue

            seen_sop_uids.add(dcm_info.sop_instance_uid)
            dcm_infos.append(dcm_info)

        dcm_infos.sort(key=lambda x: x.slice_distance)
        dcm_series = [info.dcm_path for info in dcm_infos]
    else:
        dcm_series = reader.GetGDCMSeriesFileNames(dcm_path)

    reader.SetFileNames(dcm_series)
    sitk_image = reader.Execute()
    return sitk_image

class Metric(object):
    """Running metric that is aggregated across all distributed ranks.

    ``sum`` accumulates, per ``update`` call, the sum of the values gathered
    from every rank; ``n`` counts the ``update`` calls. ``avg`` is therefore
    the mean over calls of the cross-rank sum.
    """

    def __init__(self, name):
        self.name = name
        self.sum = torch.tensor(0.)
        self.n = torch.tensor(0.)

    def update(self, val):
        """Gather ``val`` from every rank and add the total to ``sum``."""
        local_val = val.clone().detach()
        world_size = torch.distributed.get_world_size()
        # One receive buffer per rank, matching the device/dtype of the value.
        gathered = [local_val.clone().to(local_val) for _ in range(world_size)]
        dist.all_gather(gathered, local_val)
        total = torch.sum(torch.stack(gathered))
        self.sum += total.cpu().detach()
        self.n += 1

    @property
    def avg(self):
        """Return ``sum / n``."""
        return self.sum / self.n
    
def accuracy(outputs, labels):
    """Return the fraction of rows whose arg-max over dim 1 equals the label.

    The result is wrapped in a 0-dim tensor.
    """
    preds = torch.max(outputs, dim=1)[1]
    n_correct = torch.sum(preds == labels).item()
    return torch.tensor(n_correct / len(preds))

def precision(outputs, labels):
    """Return the macro-averaged precision of the arg-max predictions as a tensor."""
    preds = torch.max(outputs, dim=1)[1]
    y_true = labels.cpu().numpy()
    y_pred = preds.cpu().numpy()
    return torch.tensor(precision_score(y_true, y_pred, average='macro'))

def recall(outputs, labels):
    """Return the macro-averaged recall of the arg-max predictions as a tensor."""
    preds = torch.max(outputs, dim=1)[1]
    y_true = labels.cpu().numpy()
    y_pred = preds.cpu().numpy()
    return torch.tensor(recall_score(y_true, y_pred, average='macro'))

class Classification3D(object):
    def __init__(self, cfg, mode, log_msg):
        """Store the configuration and derive the mode-dependent settings.

        Args:
            cfg: Configuration object; ``cfg.training`` is used when ``mode``
                is ``'training'``, ``cfg.testing`` otherwise.
            mode: ``'training'`` selects the training setup; any other value
                selects the testing configuration.
            log_msg: Stored unchanged on the instance.
        """
        self.cfg = cfg
        self.mode = mode
        self.log_msg = log_msg

        is_training = (mode == 'training')
        if is_training:
            self.mode_cfg = cfg.training
            # Per-epoch training metrics are written through this writer.
            self.metrics_writer = CMetricsWriter(
                filename=cfg.training['train_metrics'],
                headernames=['epoch', 'accuracy', 'f1_score', 'precision', 'recall', 'lr', 'lossTotal', 'performanceAccuracy']
            )
        else:
            logger.info('load testing parameters...')
            self.mode_cfg = cfg.testing

        self.basic_file_path = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
        self.n_input_channel = 1
        self._clsInfoDecode()

        # Only the data sources of the active mode are filled in.
        if is_training:
            self.traindata = self.mode_cfg['train_data_path']
            self.valdata = self.mode_cfg['val_data_path']
            self.testdata = []
            print(self.traindata, self.valdata)
        else:
            self.traindata = []
            self.valdata = []
            self.testdata = self.mode_cfg['test_data_path']

        self.ParametersDecode()
        # The model mode is always forced to 'basic' here.
        self.mode_cfg['model_mode'] = 'basic'

    def __call__(self):
        """Run the training / inference workflow.

        1. Build the model (and compile it when training).
        2. Depending on ``self.mode`` run training (``'training'``),
           inference (``'infer'``) or, for any other value, testing.
        """
        selected_mode = self.mode
        # Every mode starts by building the network.
        self._buildModel()

        if selected_mode == 'training':
            self._compileModel()
            self._trainModel()
            return

        if selected_mode == 'infer':
            self._inferModel()
            return

        self._testModel()


    def _clsInfoDecode(self):
        """Decode the class mappings, input shape and task count from ``self.mode_cfg``.

        Sets ``cls_map_dict``, ``ori_cls_types`` (the distinct values of
        ``cls_map_dict``), ``input_shape`` and ``cls_name_map_dict`` on the
        instance and stores ``num_task`` back into ``self.mode_cfg``:
        the length of ``cls_label_index`` when it is a list, otherwise 1.
        """
        # SECURITY NOTE: both mapping strings are evaluated with eval(); this
        # is only acceptable while the configuration comes from a trusted
        # source.
        self.cls_map_dict = eval(self.mode_cfg['cls_map_dict'])
        self.ori_cls_types = set(self.cls_map_dict.values())

        # NOTE(review): 'patch_height' is used for both of the first two
        # axes - confirm that patches are always square in-plane.
        patch_height = self.mode_cfg['patch_height']
        self.input_shape = [patch_height, patch_height, self.mode_cfg['patch_depth']]

        self.cls_name_map_dict = eval(self.mode_cfg['cls_name_map_dict'])

        label_index = self.mode_cfg['cls_label_index']
        self.mode_cfg['num_task'] = len(label_index) if isinstance(label_index, list) else 1

    def _buildModel(self):
        """Build the network, optionally load pretrained weights and prepare it for the current mode.

        Workflow:
        1. Read the model parameters and override some of them:
            1. num_task: number of classification tasks.
            2. num_class: number of classes (all tasks currently have to share
               the same number of classes).
            3. arcface_norm_choice: whether arcface applies normalisation.
            4. easy_margin: arcface parameter. If True, the multiplicative
               margin angle is used whenever cos_theta > 0; otherwise it is
               used when cos_theta > th and the subtractive margin angle is
               used when cos_theta < th.
            5. freeze_blocks: blocks that are not trained.
        2. Build ``self.net`` and keep a deep copy in ``self.net_init``.
        3. If ``cfg.pretrain_msg`` is set, load it into ``self.net_init`` as a
           state dict; if that fails, load the file with ``torch.jit.load``.
        4. In training mode move ``self.net`` to CUDA and wrap it in
           DistributedDataParallel; otherwise use ``self.net_init``.
        """

        cfg = self.model_params
        cfg['num_task'] = self.mode_cfg['num_task']
        cfg['num_class'] = self.mode_cfg['num_class']
        cfg['arcface_norm_choice'] = self.mode_cfg.get('norm_choice', False)
        cfg['easy_margin'] = self.mode_cfg.get('easy_margin', True)
        # SECURITY NOTE: eval() on a configuration string - trusted input only.
        cfg['freeze_blocks'] = eval(self.mode_cfg.get('freeze_blocks', "[]"))

        # A two-class problem is built with a single output.
        if cfg['num_class'] == 2:
            cfg['num_class'] = 1

        self.n_input_channel = cfg['n_input_channel'] if 'n_input_channel' in cfg.keys() else self.n_input_channel
        self.net = build_model(self.mode_cfg['model_name'], cfg)
        self.net_init = copy.deepcopy(self.net)
        if self.cfg.pretrain_msg:
            print('pretrain weight:', self.cfg.pretrain_msg)
            try:
                logger.info("load model from %s" % self.cfg.pretrain_msg)

                # Set for O(1) membership tests while filtering the checkpoint.
                need_keys = {key for key, _ in self.net.named_parameters()}
                print('get need_keys')
                logger.info('get need_keys')
                checkpoints = torch.load(self.cfg.pretrain_msg, map_location=torch.device('cpu'))
                print('get checkpoints')
                logger.info('get checkpoints')
                if self.cfg.mode == 'testing':
                    # Drop the first 7 characters of every key
                    # (presumably the 'module.' prefix added by DDP - verify).
                    new_checkpoints = {key[7:]: value for key, value in checkpoints.items() if key[7:] in need_keys}
                else:
                    new_checkpoints = {key: value for key, value in checkpoints.items() if key in need_keys}

                print('Start load ...')
                logger.info('Start load ...')
                self.net_init.load_state_dict(new_checkpoints)
                print('Have already load pretrain_msg !')
                logger.info('Have already load pretrain_msg !')
            except Exception:
                # Was a bare ``except:`` - that also intercepted
                # KeyboardInterrupt/SystemExit and hid the actual reason for
                # the fallback. Keep the fallback, but record the cause.
                logger.warning('loading pretrain_msg as state dict failed, trying torch.jit.load',
                               exc_info=True)
                self.net_init = torch.jit.load(self.cfg.pretrain_msg)
                print('load pt!!!!!')

        ## DDP
        logger.info(self.mode)
        if self.mode == 'training':
            # NOTE(review): the DDP wrapper is built from ``self.net``, not
            # from ``self.net_init`` which received the pretrained weights -
            # confirm that this is intended.
            print('start')
            logger.info('start')
            self.net = self.net.cuda()
            logger.info('cuda')
            logger.info(self.cfg.local_rank)
            logger.info(torch.cuda.device_count())
            self.net = torch.nn.parallel.DistributedDataParallel(self.net,
                                                                 device_ids=[int(self.cfg.local_rank)])
            logger.info('have ddp')
        else:
            self.net = self.net_init

    def _compileModel(self):
        self.optimizer = optim.Adam(filter(lambda p: p.requires_grad, self.net.parameters()),
                                    lr=self.cfg.lr)
        self.scheduler = optim.lr_scheduler.ReduceLROnPlateau(self.optimizer, patience=self.mode_cfg['patience'],
                                                              factor=0.5, verbose=True)
        if self.mode_cfg['num_class'] > 2:
            self.criterion = nn.CrossEntropyLoss()
        else:

            self.criterion = nn.BCEWithLogitsLoss()

        # conf_paths = [self.filename, self.model_param_name]
        # if self.mode == 'training':
        #     conf_paths += [self.trainfile_conf_path, self.valfile_conf_path]
        #     for conf_path in conf_paths:
        #         filename = conf_path.split('/')[-1]
        #         target_save_path = os.path.join(self.save_path, filename)
        #         shutil.copy(conf_path, target_save_path)
        # else:
        #     conf_paths += [self.test_conf_path]

    def _DecodeGeneratorParam(self):
        """Assemble the dataset keyword arguments in ``self.generator_parameters``.

        Optional entries of ``self.mode_cfg`` fall back to their defaults
        (``False`` for the flags, ``'basic'`` for ``model_mode``, ``None`` for
        ``reg_map_dict``). Augmentation is enabled only in training mode.
        Also sets ``self.regression_choice``.
        """
        mode_cfg = self.mode_cfg

        # SECURITY NOTE: eval() on a configuration string - trusted input only.
        reg_map_dict = eval(mode_cfg['reg_map_dict']) if 'reg_map_dict' in mode_cfg else None
        self.regression_choice = mode_cfg.get('regression_choice', False)

        params = {}
        # Label / class description.
        params['cls_index'] = mode_cfg['cls_label_index']
        params['cls_map_dict'] = self.cls_map_dict
        params['cls_types'] = self.ori_cls_types
        params['num_class'] = mode_cfg['num_class']
        params['cls_balance'] = mode_cfg.get('cls_balance', False)
        # Sample filtering by diameter.
        params['diameter_index'] = mode_cfg['diameter_index']
        params['diameter_max_th'] = mode_cfg['diameter_max_th']
        params['diameter_min_th'] = mode_cfg['diameter_min_th']
        # Image pre-processing.
        params['HU_min'] = mode_cfg['HU_min']
        params['HU_max'] = mode_cfg['HU_max']
        params['high_freq_filter'] = mode_cfg.get('high_freq_filter', False)
        params['rescale_intensity'] = mode_cfg.get('rescale_intensity', False)
        params['nonfix_crop'] = mode_cfg.get('nonfix_crop', False)
        params['aug'] = (self.mode == 'training')
        # Network input / task description.
        params['n_input_channel'] = self.n_input_channel
        params['input_shape'] = self.input_shape
        params['model_mode'] = mode_cfg.get('model_mode', 'basic')
        params['regression_choice'] = self.regression_choice
        params['reg_map_dict'] = reg_map_dict

        self.generator_parameters = params

    def _DecodeDataParam(self, filename):
        """Load the JSON record list in ``filename`` and split it into path lists.

        Returns:
            A dict with ``'data_paths'`` and ``'info_paths'``, holding the
            ``'data_path'`` and ``'info_path'`` value of every record.
        """
        records = LoadJson(filename)
        decoded = {'data_paths': [], 'info_paths': []}
        decoded['data_paths'].extend(record['data_path'] for record in records)
        decoded['info_paths'].extend(record['info_path'] for record in records)
        return decoded

    # def _LoadData(self, mode):
    #     '''
    #     载入npy格式的数据
    #     '''
    #     if mode == 'training':
    #         self.train_data, self.train_info = load_data([self.traindata_npy, self.traininfo_npy])
    #         self.val_data, self.val_info = load_data([self.valdata_npy, self.valinfo_npy])
    #     else:
    #         print('testdata_npy and info: ',self.testdata_npy, self.testinfo_npy)
    #         self.test_data, self.test_info = load_data([self.testdata_npy, self.testinfo_npy])

    def _GetLoss(self, prob_diff, gamma):
        model_mode = self.mode_cfg['model_mode'] if 'model_mode' in self.mode_cfg.keys() else 'basic'
        loss_choice = self.mode_cfg['loss_func'] if 'loss_func' in self.mode_cfg.keys() else 'focal'
        lambda_val = self.mode_cfg['lambda_val'] if 'lambda_val' in self.mode_cfg.keys() else 1
        margin_val = self.mode_cfg['margin_val'] if 'margin_val' in self.mode_cfg.keys() else 0.0
        print('model_mode is ', model_mode)

        '''
        gradientBoost的参数,是否使用gradient_boost/选择保留几个类别
        '''
        gradientBoost_choice = self.mode_cfg['gradientBoost'] if 'gradientBoost' in self.mode_cfg.keys() else False
        gradientBoost_k_val = self.mode_cfg['gradientBoost_k'] if 'gradientBoost_k' in self.mode_cfg.keys() else 1

        '''
        基于focal loss基本的weights(1/该类样本量->归一化),是否要对某一类的weights进行修改,如果否就设置为None,否则设置为对应的loss
        '''
        alpha_multi_ratio = eval(self.mode_cfg['alpha_multi_ratio']) if 'alpha_multi_ratio' in self.mode_cfg.keys() else None

        '''
        得到regression_task的loss
        '''
        self.regression_loss = torch.nn.MSELoss() if self.regression_choice else None
        self.regression_loss = torch.nn.SmoothL1Loss() if self.mode_cfg.get('regression_loss',
                                                                       'MSE') == 'SmoothL1Loss' else self.regression_loss

        self.convert_softmax = True
        print("self.mode_cfg['num_task']: ", self.mode_cfg['num_task'])
        if self.mode_cfg['focal_loss'] and self.mode_cfg['num_task'] == 1:
            # ratios = self.train_dataset._GenerateRatios()
            ratios = [0.25, 0.25, 0.25, 0.25]
            if self.mode_cfg['num_class'] <= 2:
                alpha = ratios[0]
                if model_mode == 'deux':
                    self.criterion = SeLossWithBCEFocalLoss(alpha=alpha, gamma=gamma, lambda_val=lambda_val,
                                                            prob_diff=prob_diff)
                elif model_mode == 'triple':
                    self.criterion = TripletBinFocal(self.cfg['num_class'], alpha=alpha, gamma=gamma,
                                                     lambda_val=lambda_val, margin=margin_val)
                else:
                    self.criterion = BCEFocalLoss(alpha=alpha, gamma=gamma)

            else:
                '''
                多类情况下的loss
                '''
                alpha = ratios
                if model_mode == 'deux':
                    
                    # ratios = [0.25, 0.25, 0.25, 0.25]
                    print('parameters-alpha: ', ratios)
                    self.criterion = SE_FocalLoss(self.mode_cfg['num_class'], alpha=ratios, gamma=gamma,
                                                  lambda_val=lambda_val, gradient_boost=gradientBoost_choice,
                                                  alpha_multi_ratio=alpha_multi_ratio,
                                                  convert_softmax=self.convert_softmax, k=gradientBoost_k_val,
                                                  margin=margin_val, new_se_loss=self.mode_cfg.get('new_se_loss', False))
                    print('self.alpha: ', self.criterion.alpha)
                elif model_mode == 'triple':
                    self.criterion = TripletFocalLoss(self.mode_cfg['num_class'], alpha=ratios, gamma=gamma,
                                                      lambda_val=lambda_val, margin=margin_val)
                else:
                    if gradientBoost_choice:
                        self.criterion = GradientBoostFocal(self.mode_cfg['num_class'], alpha=ratios, gamma=gamma)
                    else:
                        
                        # ratios = [0.25, 0.25, 0.25, 0.25]
                        print('parameters-alpha: ', ratios)
                        self.criterion = FocalLoss(self.mode_cfg['num_class'], alpha=ratios, gamma=gamma,
                                                   alpha_multi_ratio=alpha_multi_ratio,
                                                   convert_softmax=self.convert_softmax)
                        print('self.alpha: ', self.criterion.alpha)
        else:
            alpha = None

    #         return alpha

    def _trainOnBatch(self, epoch, batch, label_dtype):
        """Run the forward pass for one batch and compute its loss and metrics.

        Args:
            epoch: Current epoch index (not used by the computation).
            batch: Sequence whose first element is the
                ``(imgs, labels, label_diff)`` triple of one loader batch.
            label_dtype: dtype the labels and ``label_diff`` are cast to.

        Returns:
            ``(loss, acc, f1, precision, recall)``; the F1 score is 0 when
            precision and recall are both 0.
        """
        imgs, labels, label_diff = batch[0]
        label_diff = [single_label_diff.type(label_dtype).cuda() for single_label_diff in label_diff]
        imgs = [single_set_img.type(torch.float32).cuda() for single_set_img in imgs]
        labels = [single_label.type(label_dtype).cuda() for single_label in labels]

        prediction = self.net(imgs, labels=labels)

        # Training metrics of this batch.
        whole_pred = torch.cat(prediction, dim=0)
        whole_labels = torch.cat(labels, dim=0)
        acc = accuracy(whole_pred, whole_labels)
        out_precision = precision(whole_pred, whole_labels)
        out_recall = recall(whole_pred, whole_labels)
        # Guard the harmonic mean: 0 / 0 would give NaN and poison the
        # per-epoch F1 average accumulated by the caller.
        pr_sum = out_precision + out_recall
        if pr_sum > 0:
            out_f1 = 2 * out_precision * out_recall / pr_sum
        else:
            out_f1 = torch.zeros_like(pr_sum)

        # The first len(labels) network outputs are the logits; outputs
        # len(labels)..2*len(labels), when present, are features and are not
        # needed for the loss.
        logits = [pred.squeeze(-1) for pred in prediction[:len(labels)]]

        loss = self.criterion([logits, labels, label_diff])
        return loss, acc, out_f1, out_precision, out_recall

    def _trainOnBatchWithReg(self, epoch, batch, label_dtype, device):
        """Forward pass for one batch with an additional regression head.

        The total loss is the classification criterion plus
        ``regression_ratio`` (``mode_cfg['regression_ratio']``, default 0.05)
        times the summed regression loss of all outputs.

        Args:
            epoch: Current epoch index (not used by the computation).
            batch: Sequence whose first element is the
                ``(imgs, labels, label_diff)`` triple; every entry of
                ``labels`` is indexed as ``[0]`` (class label) and ``[1]``
                (regression target).
            label_dtype: dtype of the class labels and ``label_diff``.
            device: Target device of all tensors.

        Returns:
            ``(loss, acc, f1, precision, recall)``, the same layout that
            ``_trainOnBatch`` returns. Previously the computed values were
            discarded and ``None`` was returned.
        """
        regression_ratio = self.mode_cfg.get('regression_ratio', 0.05)

        imgs, labels, label_diff = batch[0]
        label_diff = [single_label_diff.cuda().to(device=device, dtype=label_dtype) for single_label_diff in label_diff]
        imgs = [single_set_img.cuda().to(device=device, dtype=torch.float32) for single_set_img in imgs]
        cls_labels = [single_label[0].cuda().to(device=device, dtype=label_dtype) for single_label in labels]
        reg_labels = [single_label[1].cuda().to(device=device, dtype=torch.float32) for single_label in labels]

        prediction = self.net(imgs, labels=cls_labels)
        # Every prediction entry holds (classification output, regression output).
        logits = [pred[0].squeeze(-1) for pred in prediction[:len(labels)]]
        reg_outputs = [pred[1].squeeze(-1) for pred in prediction[:len(labels)]]

        loss = self.criterion([logits, cls_labels, label_diff])

        reg_loss = 0
        for reg_output, reg_label in zip(reg_outputs, reg_labels):
            reg_loss = reg_loss + self.regression_loss(reg_output, reg_label)

        loss = regression_ratio * reg_loss + loss

        # Training metrics of this batch.
        # NOTE(review): ``prediction`` and ``labels`` are indexed as pairs
        # above; confirm that torch.cat accepts them unchanged here.
        whole_pred = torch.cat(prediction, dim=0)
        whole_labels = torch.cat(labels, dim=0)
        acc = accuracy(whole_pred, whole_labels)
        out_precision = precision(whole_pred, whole_labels)
        out_recall = recall(whole_pred, whole_labels)
        # Guard the harmonic mean against 0 / 0 (NaN).
        pr_sum = out_precision + out_recall
        if pr_sum > 0:
            out_f1 = 2 * out_precision * out_recall / pr_sum
        else:
            out_f1 = torch.zeros_like(pr_sum)

        return loss, acc, out_f1, out_precision, out_recall

    def _trainModel(self):
        ########## Add parameters

        '''
        设置训练的device
        '''
        # device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

        '''
        如果pretrain_weight_path不为空,则load进来
        '''
        # if self.cfg.pretrain_msg and self.cfg.rank == 0:
        #     logger.info("load model from %s" % self.cfg.pretrain_msg)
            
        #     need_keys = [key for key,_ in self.net.named_parameters()]
        #     checkpoints = torch.load(self.cfg.pretrain_msg)
        #     new_checkpoints = {key:value for key, value in checkpoints.items() if key in need_keys}
        #     self.net.load_state_dict(new_checkpoints)
            # for name, value in self.net.named_parameters():
            #     print(name, value.shape)
            

        '''
        读取关于loss和数据处理的参数
        nonfix_crop:输入的大小是否一致
        train_aug_choice:train的数据是否进行augmentation
        gamma:focal loss参数
        feature_sim:
        prob_diff:
        '''
        nonfix_crop = self.mode_cfg['nonfix_crop'] if 'nonfix_crop' in self.mode_cfg.keys() else False
        train_aug_choice = self.mode_cfg['train_aug_choice'] if 'train_aug_choice' in self.mode_cfg.keys() else False
        # print('type: ', type(self.cfg))
        # print('cfg: ', self.cfg)
        gamma = self.cfg.lr_gamma if hasattr(self.cfg, 'lr_gamma') else 1
        prob_diff = self.mode_cfg['prob_diff'] if 'prob_diff' in self.mode_cfg.keys() else False

        ############ generate params for generator
        self._DecodeGeneratorParam()

        # train_dataset = SurDataSet(self.train_data, self.train_info, **self.generator_parameters)
        train_dataset = SurDataSet(self.traindata, **self.generator_parameters)
        self.generator_parameters['aug'] = False
        # val_dataset = SurDataSet(self.val_data, self.val_info, **self.generator_parameters)
        val_dataset = SurDataSet(self.valdata, **self.generator_parameters)

        balance_data_kws = self.mode_cfg['balance_data'] if 'balance_data' in self.mode_cfg.keys() else [0, 0]
        modes = ['val' if val == 0 else 'train' for val in balance_data_kws]
        # train_dataset._GenerateTrainData(mode=modes[0])
        # val_dataset._GenerateTrainData(mode=modes[1])
        self.train_dataset = train_dataset

        '''
        根据参数情况得到训练的loss
        '''
        self._GetLoss(prob_diff, gamma)
        train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset)
        val_sampler = torch.utils.data.distributed.DistributedSampler(val_dataset)
        train_loader = DataLoader(train_dataset, batch_size=self.cfg.batch_size, sampler=train_sampler)
        val_loader = DataLoader(val_dataset, batch_size=self.cfg.batch_size, sampler=val_sampler)

        lr = self.cfg.lr
        batch_size = self.cfg.batch_size
        global_step = 0

        epochs = self.cfg.epochs
        criterion = self.criterion
        val_criterion = self.criterion
        label_dtype = torch.long if self.mode_cfg['num_class'] > 2 else torch.float32

        n_train = int(round(len(train_dataset) / self.cfg.batch_size))

        best_val_score = 1000000

        for epoch in range(epochs):
            train_sampler.set_epoch(epoch)
            self.net.train()
            # self.net.to(device)
            self.net.cuda()
            epoch_loss = 0
            print('size of train dataset is ', len(train_dataset))
            if torch.distributed.get_rank() == 0:
                epoch_acc = 0
                epoch_f1 = 0
                epoch_precision = 0
                epoch_recall = 0
                epoch_num = 0
            with tqdm(total=n_train, desc=f'Epoch{epoch + 1}/{epochs}', unit='img') as pbar:
                
                batch_count = 0
                for batch in zip(train_loader):
                    batch_count += 1
                    loss_values = []
                    # if self.regression_choice:
                    #     loss = self._trainOnBatchWithReg(epoch, batch, label_dtype, device)
                        
                    # else:
                    loss, acc, out_f1, out_precision, out_recall = self._trainOnBatch(epoch, batch, label_dtype)
                    # logger.info('torch.distributed.get_rank(): ', torch.distributed.get_rank())
                    if torch.distributed.get_rank() == 0:
                        # logger.info('deal with epoch_num ...')
                        epoch_acc += acc.item()
                        epoch_f1 += out_f1.item()
                        epoch_precision += out_precision.item()
                        epoch_recall += out_recall.item()
                        epoch_num += 1

                    self.optimizer.zero_grad()
                    loss.backward()
                    epoch_loss += loss.item()
                    nn.utils.clip_grad_value_(self.net.parameters(), 1e-1)
                    self.optimizer.step()
                    global_step += 1
                    pbar.update()
            
            if torch.distributed.get_rank() == 0:
                print('self.metrics_writer.append_one_line -- loss: ', loss)
                self.metrics_writer.append_one_line(
                    [epoch, 
                     epoch_acc / epoch_num, 
                     epoch_f1 / epoch_num,
                     epoch_precision / epoch_num, 
                     epoch_recall / epoch_num, 
                     self.optimizer.param_groups[0]['lr'], 
                     loss.item(), 
                     epoch_acc / epoch_num]
                )

            # random.shuffle(train_dataset.indices)
            print('val_criterion: ', val_criterion)
            val_score = eval_netSE(self.net, val_dataset, val_loader, base_cfg=self.cfg, cfg=self.mode_cfg, criterion=val_criterion,
                                   reg_loss=self.regression_loss)
            ############ change lr based on val result
            self.scheduler.step(val_score)
            logger.info("Train loss: {}".format(epoch_loss / n_train))
            logging.info("Validation loss: {}".format(val_score))

            if val_score <= best_val_score and self.cfg.rank == 0:
                torch.save(self.net.state_dict(), self.mode_cfg['save_path'] + '/model.pth')
                best_val_score = val_score
                with open(self.mode_cfg['train_result'], 'w+') as file:
                    json.dump(
                        {
                            "successFlag":"TRAINING", 
                            "bestModelEpoch": epoch
                        }, file, indent=4)
                # 更新eval中的performance.json
                with open(self.mode_cfg['eval_pjson'], 'w+') as file:
                    json.dump(
                        {
                            "loss": val_score,
                        }, file, indent=4)

                        
                # 更新eval中的performance.md
                with open(self.mode_cfg['eval_pmd'], 'w+') as file:
                    file.write('# overall performance \n')
                    file.write('| loss | \n')
                    file.write('| -------- | \n')
                    file.write(fr'| {val_score} | \n')

        if self.cfg.rank == 0:
            json_info = json.load(open(self.mode_cfg['train_result'], 'r'))
            json_info['successFlag'] = 'SUCCESS'
            with open(self.mode_cfg['train_result'], 'w') as file:
                json.dump(json_info, file, indent=4)
            shutil.copy(self.cfg.pretrain_msg, self.mode_cfg['save_path'] + '/model.pth')
            print('save best ok')
            
    
    def _testModel_ori(self):
        """Evaluate ``self.net`` with ``eval_netSE`` and write the eval report files.

        Builds a ``SurDataSet`` from ``self.test_data`` / ``self.test_info`` with
        augmentation disabled, scores it with ``eval_netSE`` and writes the
        score to ``mode_cfg['eval_pjson']`` (JSON) and ``mode_cfg['eval_pmd']``
        (markdown table), then marks ``mode_cfg['eval_result']`` as SUCCESS.

        Raises:
            AssertionError: if CUDA is not available.
        """
        assert torch.cuda.is_available()
        self.net = self.net.cuda()
        self.net.eval()

        gamma = getattr(self.cfg, 'lr_gamma', 1)
        prob_diff = self.mode_cfg['prob_diff'] if 'prob_diff' in self.mode_cfg.keys() else False

        # Build the dataset-generator parameters and switch augmentation off.
        self._DecodeGeneratorParam()
        self.generator_parameters['aug'] = False

        val_dataset = SurDataSet(self.test_data, self.test_info, **self.generator_parameters)
        val_dataset._GenerateTrainData()

        # _GetLoss is expected to populate self.criterion (and presumably
        # self.regression_loss) -- TODO confirm against its definition.
        self._GetLoss(prob_diff, gamma)
        val_loader = DataLoader(val_dataset, batch_size=self.mode_cfg['batch_size'], shuffle=False)
        val_criterion = self.criterion
        val_score = eval_netSE(self.net, val_dataset, val_loader, base_cfg=self.cfg, cfg=self.mode_cfg,
                               criterion=val_criterion, reg_loss=self.regression_loss)

        # Update performance.json used by eval.
        # NOTE(review): the training loop reports the eval_netSE result as
        # "loss"; here it is reported as "precision" -- confirm which is right.
        with open(self.mode_cfg['eval_pjson'], 'w+') as file:
            json.dump(
                {
                    "precision": val_score,
                }, file, indent=4)

        # Update performance.md used by eval.
        with open(self.mode_cfg['eval_pmd'], 'w+') as file:
            file.write('# overall performance \n')
            file.write('| precision | \n')
            file.write('| -------- | \n')
            # Plain f-string: with the former ``fr`` prefix the trailing "\n"
            # was written as a literal backslash + "n" instead of a newline.
            file.write(f'| {val_score} | \n')

        with open(self.mode_cfg['eval_result'], 'w+') as file:
            json.dump(
                {
                    "successFlag": "SUCCESS",
                }, file, indent=4)

    def _testModel(self):
        """Run the TorchScript model over the test set and report accuracy.

        Loads the scripted network from ``self.cfg.pretrain_msg``, predicts one
        class per sample of ``TestDataSet(self.testdata, ...)`` (augmentation
        disabled), swaps predicted class ids 1 and 2, compares the result with
        the dataset labels and writes the accuracy to
        ``mode_cfg['eval_pjson']`` and ``mode_cfg['eval_pmd']``. Finally
        ``mode_cfg['eval_result']`` is marked as SUCCESS.

        Raises:
            AssertionError: if CUDA is not available.
        """
        device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
        print('all cfg:')
        print(self.cfg)
        print('model cfg')
        print(self.mode_cfg)
        assert torch.cuda.is_available()
        self.net = torch.jit.load(self.cfg.pretrain_msg)
        self.net = self.net.cuda()
        self.net.eval()

        # Build the dataset-generator parameters and switch augmentation off.
        self._DecodeGeneratorParam()
        self.generator_parameters['aug'] = False

        val_dataset = TestDataSet(self.testdata, **self.generator_parameters)
        labels = []
        preds = []
        # The third item yielded by the dataset (an "is limit" flag) is
        # currently not used for the accuracy computation.
        for data, label, _is_limit in val_dataset:
            data = torch.from_numpy(data[None, ...]).to(device=device, dtype=torch.float32)
            # Inference only: do not build an autograd graph.
            with torch.no_grad():
                pred = self.net(data)
                pred = torch.softmax(pred[0], -1)
                pred = torch.max(pred, -1)[1]
            labels.append(label)
            preds.append(pred)

        labels = np.array(labels)
        preds = torch.cat(preds, 0).cpu().numpy()

        # Swap class ids 1 and 2; every other id is kept as predicted.
        # NOTE(review): presumably the model's class order differs from the
        # label order for these two classes -- confirm.
        preds_new = preds.copy()
        preds_new[preds == 1] = 2
        preds_new[preds == 2] = 1

        num_gts = preds_new.shape[0]
        num_tps = np.sum(preds_new == labels)
        acc = num_tps / num_gts

        print('gt info:----------------')
        elems, elems_count = np.unique(labels, return_counts=True)
        print('gts:', elems)
        print('gt counts:', elems_count)

        print('pred info:----------------')
        elems, elems_count = np.unique(preds_new, return_counts=True)
        print('preds:', elems)
        print('pred counts:', elems_count)

        print('=================')
        print('num gts:', num_gts)
        print('num tps:', num_tps)

        # Update performance.json used by eval.
        with open(self.mode_cfg['eval_pjson'], 'w+') as file:
            json.dump(
                {
                    "accuracy": acc,
                }, file, indent=4)

        # Update performance.md used by eval.
        with open(self.mode_cfg['eval_pmd'], 'w+') as file:
            file.write('# overall performance \n')
            file.write('| accuracy |\n')
            file.write('| -------- |\n')
            # Plain f-string: with the former ``fr`` prefix the trailing "\n"
            # was written as a literal backslash + "n" instead of a newline.
            file.write(f'| {acc} |\n')

        with open(self.mode_cfg['eval_result'], 'w+') as file:
            json.dump(
                {
                    "successFlag": "SUCCESS",
                }, file, indent=4)

    def make_infer_data_info(self, info_path, data_path):
        """Collect the DICOM folder and annotation coordinates of every series.

        Args:
            info_path: Path of a JSON file with a ``dataList`` list; each item
                provides ``rawDataUrls`` and ``annotations``.
            data_path: Directory that is joined with each uid to form the
                ``dcm_path`` entry.

        Returns:
            dict: ``{uid: {'dcm_path': <data_path>/<uid>, 0: coords, 1: coords, ...}}``
            where the integer keys are positions inside the first annotation
            session's ``annotationSet`` and the values are its ``coordinates``.
        """
        info_dict = {}
        # The files are only read, so open them read-only ('r+' fails on
        # read-only files and mounts).
        with open(info_path, 'r') as info_file:
            infer_data_info = json.load(info_file)
        for item in infer_data_info['dataList']:
            # uid = last URL component with everything from '.zip' on removed.
            uid = item['rawDataUrls'][0].split('/')[-1].split('.zip')[0]
            entry = info_dict.setdefault(uid, {'dcm_path': os.path.join(data_path, uid)})
            for annotation in item['annotations']:
                ann_path = annotation['annotationUrls'][0]
                with open(ann_path, 'r') as ann_file:
                    ann = json.load(ann_file)
                # NOTE(review): the index restarts at 0 for every annotation
                # file, so a later file overwrites the entries of an earlier
                # one for the same uid -- confirm this is intended.
                for index, k in enumerate(ann['annotationSessions'][0]['annotationSet']):
                    entry[index] = k['coordinates']
        return info_dict

    def _inferModel(self):
        """Classify every annotated nodule and export annotated PNG slices and GIFs.

        Steps:
            1. Run ``self.net`` on each sample of ``InferDataSet`` (augmentation
               disabled). The arg-max class is remapped when the sample's
               "is limit" flag (``info[5]``) is set: class 3 becomes 5, any
               other class becomes 4.
            2. For each series, load the CT through ``load_ct_from_dicom``,
               convert the annotation corners to voxel indices using the image
               origin and spacing, and render the slices around the box: the
               box is drawn on the slices it spans, the slice is cropped with a
               64-pixel margin and the predicted class name is written in the
               corner. Slices are saved as
               ``<save_img_dir>/<uid>_<index>_<abbr>/NNN.png``.
            3. Every sub-directory of ``save_img_dir`` is assembled into a GIF
               in ``save_gif_dir``.

        Raises:
            AssertionError: if CUDA is not available.
        """
        class_map = {0: '肺内磨玻璃结节',
                     1: '肺内混合结节',
                     2: '肺内实性结节',
                     3: '肺内钙化结节',
                     4: '胸膜结节或斑块',
                     5: '胸膜钙化结节'}
        # Short ASCII tags used in the output directory names.
        class_map2 = {'肺内磨玻璃结节': 'mbl',
                      '肺内混合结节': 'hh',
                      '肺内实性结节': 'sx',
                      '肺内钙化结节': 'gh',
                      '胸膜结节或斑块': 'xm',
                      '胸膜钙化结节': 'xmgh'}
        print(class_map)
        assert torch.cuda.is_available()
        self.net = self.net.cuda()
        self.net.eval()

        # Build the dataset-generator parameters and switch augmentation off.
        self._DecodeGeneratorParam()
        self.generator_parameters['aug'] = False

        val_dataset = InferDataSet(self.testdata, **self.generator_parameters)
        batch_size = self.mode_cfg['batch_size']
        print(f'batch size: {batch_size}')
        info_dict = self.make_infer_data_info(self.cfg.dataset_info_path, self.cfg.dcm_data_path)
        print(f'info dict: {info_dict}')

        # ---- 1. predict a class name for every (uid, annotation index) ----
        pred_dict = {}
        for data, info in val_dataset:
            uid = info[4]
            is_limit = info[5]
            index = info[1]
            if uid not in pred_dict:
                pred_dict[uid] = {}
            data = torch.from_numpy(data[None, ...]).float().cuda()
            # Inference only: do not build an autograd graph.
            with torch.no_grad():
                pred = self.net(data)
                pred = torch.softmax(pred[0], -1)
                pred = torch.max(pred, -1)[1]
            pred = pred.cpu().numpy()[0]
            if is_limit:
                pred = 5 if pred == 3 else 4
            pred_dict[uid][index] = class_map[pred]
        print(f'pred dict: {pred_dict}')

        # ---- 2. render annotated slices ----
        font_size = 10
        font_path = '/shared/msyhl.ttc'
        font = None  # loaded on first use, then reused for every slice
        margin = 64  # crop margin (pixels) around the box in x and y
        z_pad = 4    # extra slices rendered below / above the box

        target_path = self.cfg.save_img_dir
        for uid, pred in pred_dict.items():
            dcm_path = info_dict[uid]['dcm_path']
            itk_image = load_ct_from_dicom(dcm_path)
            x_origin, y_origin, z_origin = itk_image.GetOrigin()
            x_spacing, y_spacing, z_spacing = itk_image.GetSpacing()
            # Indexed below as image[z, y, x].
            image = sitk.GetArrayFromImage(itk_image)

            for index, pred_cls in pred.items():
                gt_info = info_dict[uid][index]
                first_corner, second_corner = gt_info[0], gt_info[1]
                # Convert the two annotated corners to voxel indices.
                x1 = int((first_corner['x'] - x_origin) / x_spacing)
                y1 = int((first_corner['y'] - y_origin) / y_spacing)
                z1 = int((first_corner['z'] - z_origin) / z_spacing)
                x2 = int((second_corner['x'] - x_origin) / x_spacing)
                y2 = int((second_corner['y'] - y_origin) / y_spacing)
                z2 = int((second_corner['z'] - z_origin) / z_spacing)

                z_min = max(0, z1 - z_pad)
                z_max = min(image.shape[0] - 1, z2 + z_pad)
                # The crop window is the same for every slice of this box.
                cut_x_min = max(0, x1 - margin)
                cut_y_min = max(0, y1 - margin)
                cut_x_max = min(image.shape[2] - 1, x2 + margin)
                cut_y_max = min(image.shape[1] - 1, y2 + margin)

                all_images = []
                for z in range(z_min, z_max, 1):
                    # Copy only the slice that is needed, not the whole volume.
                    img_z = norm(image[z, :, :].copy())
                    img_z = cv2.cvtColor(img_z, cv2.COLOR_GRAY2BGR)
                    if z1 <= z <= z2:
                        cv2.rectangle(img_z, (x1, y1), (x2, y2), (0, 0, 255), 2)

                    img_z = img_z[cut_y_min: cut_y_max, cut_x_min: cut_x_max, :]
                    # The class names are Chinese, so the text is drawn with
                    # PIL and a TrueType font.
                    img_pil = Image.fromarray(cv2.cvtColor(img_z, cv2.COLOR_BGR2RGB).astype(np.uint8))
                    if font is None:
                        font = ImageFont.truetype(font_path, font_size)
                    ImageDraw.Draw(img_pil).text((0, 0), pred_cls, font=font, fill=(255, 0, 0))
                    all_images.append(cv2.cvtColor(np.array(img_pil), cv2.COLOR_RGB2BGR))

                # Only create the directory when there is something to write,
                # so the GIF stage never meets an empty folder created here.
                if all_images:
                    save_path = os.path.join(target_path, uid + '_' + str(index) + '_' + class_map2[pred_cls])
                    os.makedirs(save_path, exist_ok=True)
                    for idx, img in enumerate(all_images):
                        cv2.imwrite(os.path.join(save_path, '{:03d}.png'.format(idx)), img)

        # ---- 3. assemble one GIF per slice directory ----
        output_gif_path = self.cfg.save_gif_dir
        os.makedirs(output_gif_path, exist_ok=True)

        fps = 24.0
        for case_dir in os.listdir(target_path):
            case_path = os.path.join(target_path, case_dir)
            # Sort by file name without its 4-character extension.
            frame_names = sorted(os.listdir(case_path), key=lambda name: name[:-4])
            frames = [imageio.imread(os.path.join(case_path, name)) for name in tqdm(frame_names)]
            imageio.mimsave(os.path.join(output_gif_path, case_dir + '.gif'), frames, 'GIF', duration=1 / fps)


    def _GetTargetClassProb(self, labels, probs):
        """Return, for each sample, the entry of ``probs[i][0]`` selected by ``labels[i]``.

        Pairs are formed with ``zip``, so the result is as long as the shorter
        of the two inputs.
        """
        picked = []
        for sample_probs, target in zip(probs, labels):
            first_row = sample_probs[0]
            picked.append(first_row[target])
        return picked

    def ProbVis(self, labels, probs):
        """Show one probability histogram per distinct label value.

        For every unique value in ``labels`` the rows of ``probs`` that belong
        to samples with that label are selected and plotted as a histogram
        over the range ``(0, 1)``.

        Args:
            labels: Sequence of integer labels, one per sample.
            probs: Array-like of probabilities aligned with ``labels`` along
                axis 0.
        """
        from matplotlib import pyplot as plt
        for val in np.unique(labels):
            print('=' * 60)
            print('Current label is %d' % val)
            indices = [idx for idx, label in enumerate(labels) if label == val]
            current_probs = np.take(probs, indices, axis=0)
            print('Number of samples are %d' % len(current_probs))
            plt.hist(current_probs, range=(0, 1))
            # One show() per figure; the original called it twice in a row.
            plt.show()

    def BadCaseVis(self, num_incies, show_edges=False, gaussian_choice=False, sigma=1):
        """Plot the central slices of the first ``num_incies`` bad cases.

        For each index in ``self.badcase_lists[:num_incies]`` the sample is
        fetched from ``self.test_dataset``, its label, result and probability
        are printed, and the centre slice plus one neighbour on each side are
        shown. With ``show_edges`` a second row shows the output of
        ``ShowEdges`` for the same slices.

        Args:
            num_incies: Maximum number of bad cases to display.
            show_edges: Add a row with the ``ShowEdges`` output of each slice.
            gaussian_choice: Forwarded to ``ShowEdges``.
            sigma: Forwarded to ``ShowEdges``.
        """
        slice_range = 1
        n_cols = 2 * slice_range + 1
        n_rows = 2 if show_edges else 1
        for case_idx in self.badcase_lists[:num_incies]:
            sample = self.test_dataset[case_idx]
            volume = np.squeeze(sample[0])
            label = sample[1]
            # Index of the left-most displayed slice.
            first_slice = int(volume.shape[0] / 2) - slice_range
            print('=' * 60)
            print('Current label is ', label, ' result is ', self.result[case_idx])
            print('prob is', self.result_list[case_idx])
            _, axs = plt.subplots(n_rows, n_cols)
            for col in range(n_cols):
                current_slice = volume[first_slice + col]
                if show_edges:
                    edges = ShowEdges(current_slice, gaussian_choice=gaussian_choice, sigma=sigma)
                    axs[0][col].imshow(current_slice, cmap='bone')
                    axs[1][col].imshow(edges, cmap='bone')
                else:
                    axs[col].imshow(current_slice, cmap='bone')
            plt.show()

    def ParametersDecode(self):
        '''
        Generate model parameters
        '''
        # self.model_param_name = os.path.join(self.basic_file_path, self.cfg['model_params'])
        # self.model_params = LoadJson(self.model_param_name)[self.cfg['model_name']]
        if self.mode_cfg['model_name'] == 'VGG3D':
            self.model_params = self.cfg.VGG3D
        elif self.mode_cfg['model_name'] == 'ResNet':
            self.model_params = self.cfg.ResNet
        elif self.mode_cfg['model_name'] == 'Se_ResNet':
            self.model_params = self.cfg.Se_ResNet
        elif self.mode_cfg['model_name'] == 'Darknet':
            self.model_params = self.cfg.Darknet
        elif self.mode_cfg['model_name'] == 'FSe_ResNet':
            self.model_params = self.cfg.FSe_ResNet