import os
import sys
import json
import numpy as np
from datetime import datetime
import torch
import torch.distributed
import torch.multiprocessing as mp
import torch.distributed as dist
import logging
import argparse
import shutil
import glob
# Stop at import time when torch reports no usable CUDA device.
# NOTE: assert statements are stripped when Python runs with -O.
assert torch.cuda.is_available()

# Append the parent of this file's directory to sys.path; presumably this is
# what lets the project-level imports below resolve — confirm against the
# repository layout.
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
# sys.path.insert(0, os.path.join(os.path.dirname(os.path.dirname(os.path.abspath(__file__))), ".."))
# NOTE(review): presumably set so package-relative imports work when this file
# is executed directly as a script — confirm.
__package__ = "lungClassification3D"
from BaseClassifier.Se_Cls import Classification3D
from BaseClassifier.utils.gpu_utils import set_gpu
from k8s_utils import CParamFiller, CK8sPathWrapper

def init_dist(backend='nccl',
              master_ip='127.0.0.1',
              port=29500):
    """Initialise the default ``torch.distributed`` process group.

    Exports ``MASTER_ADDR`` / ``MASTER_PORT`` from the arguments, reads
    ``RANK``, ``WORLD_SIZE`` and ``LOCAL_RANK`` from the environment, selects
    the CUDA device ``LOCAL_RANK % device_count`` for this process and then
    calls ``dist.init_process_group``.

    Args:
        backend: Backend name forwarded to ``dist.init_process_group``.
        master_ip: Value written to the ``MASTER_ADDR`` environment variable.
        port: Value written (as a string) to ``MASTER_PORT``.

    Returns:
        Tuple ``(rank, world_size)`` parsed from the environment.

    Raises:
        KeyError: If ``RANK``, ``WORLD_SIZE`` or ``LOCAL_RANK`` is not set.
        ValueError: If one of those variables is not an integer literal.
    """
    # Only pick a start method when none has been chosen yet; calling
    # set_start_method a second time would raise.
    if mp.get_start_method(allow_none=True) is None:
        mp.set_start_method('spawn')
    os.environ['MASTER_ADDR'] = master_ip
    os.environ['MASTER_PORT'] = str(port)
    rank = int(os.environ['RANK'])
    world_size = int(os.environ['WORLD_SIZE'])
    num_gpus = torch.cuda.device_count()
    local_rank = os.environ['LOCAL_RANK']
    # Parse with int(), never eval(): an environment variable is external
    # input and must not be executed as Python code.
    deviceid = int(local_rank) % num_gpus
    torch.cuda.set_device(deviceid)

    print(fr'dist settings: local_rank {local_rank}, rank {rank}, worldsize {world_size}, gpus {num_gpus}, deviceid {deviceid}')

    dist.init_process_group(backend=backend)
    return rank, world_size

def get_logger(name, task_name=None):
    """Return the logger called ``name``, set to INFO and echoing to stdout.

    Args:
        name: Name passed to ``logging.getLogger``.
        task_name: Unused; kept so existing call sites keep working.

    Returns:
        The configured ``logging.Logger`` instance.
    """
    logger = logging.getLogger(name)
    logger.setLevel(logging.INFO)

    # logging.getLogger hands back the same object for the same name, so
    # attaching a handler unconditionally would make every repeated call
    # duplicate each emitted line. Attach the stdout handler only once.
    has_stdout_handler = any(
        isinstance(handler, logging.StreamHandler)
        and getattr(handler, 'stream', None) is sys.stdout
        for handler in logger.handlers
    )
    if not has_stdout_handler:
        logger.addHandler(logging.StreamHandler(sys.stdout))
    return logger

def _log_msg(strmsg="\n"):
    """Emit ``strmsg`` from rank 0 only, via ``g_logger`` when it exists.

    When a ``torch.distributed`` process group is initialised, every rank
    other than 0 stays silent. When no process group is initialised (for
    example a single-process run) the message is emitted instead of raising
    from ``get_rank()``.

    Args:
        strmsg: Message to log; defaults to a bare newline.
    """
    dist_active = (torch.distributed.is_available()
                   and torch.distributed.is_initialized())
    if dist_active and torch.distributed.get_rank() != 0:
        return

    # g_logger is only assigned in the ``__main__`` section of this file;
    # look it up defensively so the print fallback is reachable instead of
    # failing with NameError when it was never defined.
    logger = globals().get('g_logger')
    if logger is not None:
        logger.info(strmsg)
    else:
        print(strmsg)

class MyEncoder(json.JSONEncoder):
    """``json.JSONEncoder`` that also accepts NumPy values and ``datetime``."""

    def default(self, obj):
        """Return a JSON-serialisable stand-in for ``obj``.

        NumPy integer and floating scalars become built-in ``int`` / ``float``,
        arrays become nested lists and ``datetime`` objects become their string
        form. Anything else is delegated to the base class, which raises
        ``TypeError``.
        """
        if isinstance(obj, np.integer):
            return int(obj)
        if isinstance(obj, np.floating):
            return float(obj)
        if isinstance(obj, np.ndarray):
            return obj.tolist()
        if isinstance(obj, datetime):
            return str(obj)
        return super().default(obj)
        
# set_gpu(num_gpu=1, used_percent=0.1)

if __name__ == '__main__':
    # Script entry point: parse the CLI, build the run configuration, run the
    # preprocessing script, write the list of preprocessed volumes to a JSON
    # split file and finally invoke Classification3D in 'infer' mode.
    import subprocess  # only needed by the entry point, so imported locally

    parser = argparse.ArgumentParser(description="full functional execute script of Detection module.")
    # ddp
    parser.add_argument("--local_rank", default=-1, type=int)
    parser.add_argument('--port', default=29500, type=int, help='port of server')
    parser.add_argument('--world_size', default=2, type=int)
    parser.add_argument('--rank', default=0, type=int)
    parser.add_argument('--master_ip', default='127.0.0.1', type=str)

    parser.add_argument('--job_data_root', default="/data/job_1056/job_data_preprocess", type=str)
    args = parser.parse_args()

    g_pathWrapper = CK8sPathWrapper(args.job_data_root)
    config_path = g_pathWrapper.get_input_inputconfig_filepath()
    project_root = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))

    # fillin_args is applied twice: first with the job's input config and the
    # CLI namespace, then with the repository's infer config and that result.
    # NOTE(review): which source wins on conflicting keys is decided inside
    # CParamFiller, which is not visible here — confirm.
    cfg = CParamFiller.fillin_args(config_path, args)
    train_params = os.path.join(project_root, 'configs/infer/input_config.json')
    cfg = CParamFiller.fillin_args(train_params, cfg)

    # _log_msg looks this module-level name up when it emits a message.
    g_logger = get_logger(__name__)
    print(cfg)

    cfg.testing['test_data_path'] = os.path.join(g_pathWrapper.get_output_tmp_dirpath(), 'val_split.json')
    cfg.pretrain_msg = os.path.join(g_pathWrapper.get_input_dirpath(), cfg.pretrain_msg)

    npy_path = g_pathWrapper.get_output_preprocess_npy_path()

    # Preprocess. The command is passed as an argument list (no shell) so that
    # paths containing spaces or shell metacharacters are handled correctly,
    # and it runs under the interpreter that is executing this script.
    run_file = os.path.join(project_root, 'local_run_preprocess_lung_infer.py')
    print('Need to run {}'.format(run_file))
    completed = subprocess.run([sys.executable, run_file, '--job_data_root', args.job_data_root])
    retcode = completed.returncode
    if retcode != 0:
        # Previously the exit status was computed and then ignored; surface it
        # but keep going, as the script always has.
        g_logger.error('preprocess script %s exited with status %s', run_file, retcode)

    # Sorted so the generated split file is identical from run to run
    # (glob makes no ordering guarantee).
    data_list = sorted(glob.glob(os.path.join(npy_path, '*', '*_data.npy')))
    if not data_list:
        g_logger.warning('no *_data.npy files found under %s', npy_path)
    val_data = {'data': data_list}
    with open(cfg.testing['test_data_path'], 'w') as f:
        json.dump(val_data, f, indent=4)

    # Paths the classifier reads from cfg.
    cfg.dataset_info_path = g_pathWrapper.get_input_dataset_filepath()
    cfg.dcm_data_path = g_pathWrapper.get_output_tmp_dirpath()
    cfg.save_img_dir = g_pathWrapper.get_tmp_img_dirpath()
    cfg.save_gif_dir = g_pathWrapper.get_tmp_gif_dirpath()

    # Build the classifier in 'infer' mode and invoke it.
    train_obj = Classification3D(cfg, 'infer', _log_msg)
    train_obj()