import os
import sys
import argparse
import torch
# import horovod.torch as hvd
import numpy as np
import lmdb
import json
import SimpleITK as sitk
from pprint import pprint
import logging
from datetime import datetime
import torch.distributed
import torch.multiprocessing as mp
import torch.distributed as dist
import pdb

# if __name__ == "__main__" and __package__ is None:
#     sys.path.insert(0, os.path.join(os.path.dirname(os.path.dirname(os.path.abspath(__file__))), ".."))
#     __package__ = "lungDetection3D"
import sys
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

# from grouplung.BaseDetector.config import get_cfg_defaults
from BaseDetector.utils.gpu_utils import set_gpu
from NoduleDetector.engine import NoduleDetection3D
from k8s_utils import CParamFiller, CK8sPathWrapper

def get_logger(name, task_name=None):
    """Return the logger called *name*, set to INFO and writing to stdout.

    The stdout handler is attached only if the logger does not already have
    one, so calling this function repeatedly with the same name does not
    cause every record to be emitted multiple times.

    Args:
        name: Name passed to ``logging.getLogger``.
        task_name: Unused; kept so existing callers keep working.

    Returns:
        The configured ``logging.Logger`` instance.
    """
    logger = logging.getLogger(name)
    logger.setLevel(logging.INFO)

    # logging.getLogger returns the same object for the same name, so a
    # handler added on a previous call is still attached here.
    has_stdout_handler = any(
        isinstance(handler, logging.StreamHandler)
        and getattr(handler, "stream", None) is sys.stdout
        for handler in logger.handlers
    )
    if not has_stdout_handler:
        logger.addHandler(logging.StreamHandler(sys.stdout))
    return logger

def _log_msg(strmsg="\n"):
    """Print *strmsg* to standard output (a newline by default)."""
    print(strmsg, file=sys.stdout)

class MyEncoder(json.JSONEncoder):
    """JSON encoder that additionally handles NumPy values and datetimes."""

    def default(self, obj):
        """Convert *obj* into something the base JSON encoder can serialize.

        NumPy integers and floats become Python ``int``/``float``, NumPy
        arrays become (nested) lists, and ``datetime`` objects become their
        string form. Anything else is delegated to the base class, which
        raises ``TypeError``.
        """
        if isinstance(obj, np.integer):
            return int(obj)
        if isinstance(obj, np.floating):
            return float(obj)
        if isinstance(obj, np.ndarray):
            return obj.tolist()
        if isinstance(obj, datetime):
            return str(obj)
        return super().default(obj)


if __name__ == "__main__":
    # Script entry point: build the validation configuration from the job's
    # input config, make sure preprocessed data is present, then run
    # NoduleDetection3D.do_val().

    # Local import: only the entry point needs subprocess.
    import subprocess

    print()
    parser = argparse.ArgumentParser(description="full functional execute script of Detection module.")
    # ddp
    parser.add_argument("--local_rank", default=-1, type=int)
    parser.add_argument('--port', default=29500, type=int, help='port of server')
    parser.add_argument('--world_size', default=2, type=int)
    parser.add_argument('--rank', default=0, type=int)
    parser.add_argument('--master_ip', default='127.0.0.1', type=str)

    parser.add_argument('--job_data_root', default="/data/job_678/job_data_train", type=str)
    args = parser.parse_args()

    g_pathWrapper = CK8sPathWrapper(args.job_data_root)
    config_path = g_pathWrapper.get_input_inputconfig_filepath()
    if not os.path.isfile(config_path):
        # Without the config file there is no `cfg` to work with; stop here
        # instead of failing further down with a NameError.
        print(fr'given config file {config_path} not exist !')
        sys.exit(1)

    cfg = CParamFiller.fillin_args(config_path, args)
    # NOTE(review): this is a relative path, so it is resolved against the
    # current working directory -- confirm the script is always launched
    # from the expected directory.
    test_params = 'configs/val/input_config.json'
    cfg = CParamFiller.fillin_args(test_params, cfg)
    print()

    g_logger = get_logger(__name__)
    # Overwrite whatever values cfg holds for the distributed settings.
    cfg.rank = 0
    cfg.world_size = 1
    cfg.local_rank = -1

    # Path of the cancel flag, if one exists (per the original comment this
    # is an absolute path -- TODO confirm against CK8sPathWrapper).
    cfg.check_cancleflag = g_pathWrapper.get_input_cancle_flag()

    # All data is used for validation.
    data_dir = g_pathWrapper.get_train_nii_path()
    test_db = g_pathWrapper.get_output_preprocess_mdb_path()
    cfg.data['data_loader']['data_dir'] = data_dir
    cfg.data['data_loader']['test_db'] = test_db

    if not os.listdir(data_dir) or not os.listdir(test_db):
        # One of the two directories is empty: run the preprocessing script
        # that lives one directory above this file's directory.
        script_path = os.path.abspath(__file__)
        run_file = os.path.join(os.path.dirname(os.path.dirname(script_path)), 'local_run_preprocess_lung.py')
        print('Need to run {}'.format(run_file))
        # Argument list (no shell) so job_data_root is passed verbatim, and
        # the current interpreter is reused for the child process.
        retcode = subprocess.run(
            [sys.executable, run_file, '--job_data_root', args.job_data_root]
        ).returncode
        if retcode != 0:
            # Do not continue to validation when preprocessing failed.
            print('preprocess script {} failed with exit code {}'.format(run_file, retcode))
            # A negative return code means the child was killed by a signal.
            sys.exit(retcode if retcode > 0 else 1)

    print('All ready have data !!!')

    # Save paths for the evaluation outputs.
    cfg.training['saver']['saver_eval_pmd'] = g_pathWrapper.get_output_eval_performance_md_filepath()
    cfg.training['saver']['saver_eval_pjson'] = g_pathWrapper.get_output_eval_performance_filepath()
    cfg.training['saver']['saver_eval_result'] = g_pathWrapper.get_output_eval_evalresult_filepath()

    # Paths used by the testing stage (temporary and result locations).
    cfg.testing['test_tmp_rpns'] = g_pathWrapper.get_tmp_test_rpns_dirpath()
    cfg.testing['test_tmp_dir'] = g_pathWrapper.get_output_tmp_dirpath()
    cfg.testing['test_nii_dir'] = g_pathWrapper.get_train_nii_path()
    cfg.testing['saver_pred_result'] = g_pathWrapper.get_output_eval_evalresult_filepath()
    cfg.testing['saver_pred_details'] = g_pathWrapper.get_output_eval_evalfiledetails_dirpath()

    # Validation always needs model weights; turn the relative path from the
    # config into a path under the job's input directory.
    if cfg.pretrain_msg:
        cfg.pretrain_msg = os.path.join(g_pathWrapper.get_input_dirpath(), cfg.pretrain_msg)

    print(cfg)
    detector = NoduleDetection3D(cfg, 'test', log_fun=_log_msg)
    detector.do_val()