# ── job_data
#     ├── input (read-only mount)
#     │   ├── data_set.json (every path listed inside is readable by the task container)
#     │   └── input_config.json (input params, stored as key-value pairs)
#     └── output (read-write mount)
#         ├── eval
#         │   ├── eval_file_details (one result file per data item, named by SUID)
#         │   │   ├── 1.2.826.0.1.3680043.2.1125.20230207.120420561023786803566_label.nii.gz
#         │   │   └── 1.2.826.0.1.3680043.2.1125.20230207.120420561023786803566.json
#         │   ├── eval_result.json (overall result table; data type TBD)
#         │   ├── performance.json (performance metrics: DICE, FROC, ConfusionMatrix, etc.)
#         │   └── model.pth (optional, model required by the evaluation)
#         ├── preprocess
#         │   ├── preprocess_file_details
#         │   └── preprocess_result.json
#         ├── tmp (temp dir; the platform deletes it when the job finishes. Data persisted here is shared across nodes under DDP)
#         └── train
#             ├── best_model
#             │   ├── model.onnx
#             │   └── model.pth
#             ├── train_result.json
#             └── training_metrics.csv
import os
import json 
import zipfile
import shutil

def log_msg(str_msg):
    """Emit a log line; currently just prints to stdout."""
    print(str_msg)


class K8sConstants:
    """Well-known file and directory names of the job_data layout
    documented at the top of this file."""

    # --- input-side files ---
    DATASET_JSON = 'data_set.json'
    INPUTCONFIG_JSON = 'input_config.json'
    CANCEL_FLAG = 'cancel_flag'

    # --- result files ---
    EVALRESULT_JSON = 'eval_result.json'
    PERFORMANCE_JSON = 'performance.json'
    PERFORMANCE_MD = 'performance.md'
    PREPROCESSRESULT_JSON = 'preprocess_result.json'
    TRAINRESULT_JSON = 'train_result.json'
    TRAINMETRICS_CSV = 'training_metrics.csv'

    # --- model artifacts ---
    BEST_MODEL = 'model'
    MODEL_PTH = 'model.pth'
    MODEL_ONNX = 'model.onnx'

    # --- directory names ---
    INPUT_DIR = 'input'
    OUTPUT_DIR = 'output'
    EVAL_DIR = 'eval'
    EVALDETAILS_DIR = 'eval_file_details'
    PREPROCESS_DIR = 'preprocess'
    PREPROCESSDETAILS_DIR = 'preprocess_file_details'
    TMP_DIR = 'tmp'
    TRAIN_DIR = 'train'

    # --- preprocess / training intermediates ---
    PREPROCESS_MDB_NAME = 'json2mdb'
    TRAIN_DATA_TYPE = 'nii'
    TRAIN_DB_DOC = 'train_db'
    VAL_DB_DOC = 'val_db'
    RPNS = 'rpns'
    IMG = 'img'
    GIF = 'gif'

    # --- classification task ---
    DATA_TYPE = 'npy'
    WRITE_PATH = 'runs'

class CK8sPathWrapper:
    """Resolves every path of the job_data layout (see the tree at the top of
    this file), rooted at *job_root*.

    Directory getters create the directory on first use; output-side file
    getters ensure the containing directory exists.  Input-side file getters
    never create anything because the input mount is read-only.
    """

    def __init__(self, job_root) -> None:
        # Root of the mounted job_data directory.
        self.job_root = job_root

    def _ensure_path(self, given_path):
        """Create *given_path* (with parents) if it does not exist."""
        # exist_ok tolerates concurrent creation, e.g. several DDP workers
        # initialising the same tree (the old exists()+makedirs() raced).
        os.makedirs(given_path, exist_ok=True)

    def _ensured_dirpath(self, *parts):
        """Join *parts* under job_root, create the directory, return it."""
        path = os.path.join(self.job_root, *parts)
        self._ensure_path(path)
        return path

    def _output_filepath(self, *parts):
        """Path of output/<parts...>; the containing directory is created."""
        dirpath = self._ensured_dirpath(K8sConstants.OUTPUT_DIR, *parts[:-1])
        return os.path.join(dirpath, parts[-1])

    # ---- input side (read-only mount) ----

    def get_input_dirpath(self):
        """<job_root>/input"""
        return self._ensured_dirpath(K8sConstants.INPUT_DIR)

    def get_input_dataset_filepath(self):
        """input/data_set.json"""
        return os.path.join(self.job_root, K8sConstants.INPUT_DIR, K8sConstants.DATASET_JSON)

    def get_input_inputconfig_filepath(self):
        """input/input_config.json"""
        return os.path.join(self.job_root, K8sConstants.INPUT_DIR, K8sConstants.INPUTCONFIG_JSON)

    def get_input_inputmodel_filepath(self):
        """input/model.pth — optional model supplied to the job."""
        return os.path.join(self.job_root, K8sConstants.INPUT_DIR, K8sConstants.MODEL_PTH)

    def get_input_cancle_flag(self):
        """True when the platform dropped a cancel_flag file into input/.

        NOTE: the 'cancle' typo in the method name is kept for backward
        compatibility with existing callers.
        """
        return os.path.exists(os.path.join(self.job_root, K8sConstants.INPUT_DIR, K8sConstants.CANCEL_FLAG))

    def get_pretrain_ap(self):
        """Directory holding pretrained weights — the input dir itself (not created)."""
        return os.path.join(self.job_root, K8sConstants.INPUT_DIR)

    # ---- output/tmp (deleted by the platform after the job; DDP-shared) ----

    def get_output_tmp_dirpath(self):
        """output/tmp"""
        return self._ensured_dirpath(K8sConstants.OUTPUT_DIR, K8sConstants.TMP_DIR)

    def get_tmp_img_dirpath(self):
        """output/tmp/img"""
        return self._ensured_dirpath(K8sConstants.OUTPUT_DIR, K8sConstants.TMP_DIR, K8sConstants.IMG)

    def get_tmp_gif_dirpath(self):
        """output/tmp/gif"""
        return self._ensured_dirpath(K8sConstants.OUTPUT_DIR, K8sConstants.TMP_DIR, K8sConstants.GIF)

    def get_tmp_test_rpns_dirpath(self):
        """output/tmp/rpns"""
        return self._ensured_dirpath(K8sConstants.OUTPUT_DIR, K8sConstants.TMP_DIR, K8sConstants.RPNS)

    def get_output_train_latestmodel_pth_filepath(self):
        """output/tmp/model.pth — the latest (non-best) checkpoint lives in tmp."""
        return self._output_filepath(K8sConstants.TMP_DIR, K8sConstants.MODEL_PTH)

    # ---- output/eval ----

    def get_output_eval_dirpath(self):
        """output/eval"""
        return self._ensured_dirpath(K8sConstants.OUTPUT_DIR, K8sConstants.EVAL_DIR)

    def get_output_eval_evalfiledetails_dirpath(self):
        """output/eval/eval_file_details"""
        return self._ensured_dirpath(K8sConstants.OUTPUT_DIR, K8sConstants.EVAL_DIR, K8sConstants.EVALDETAILS_DIR)

    def get_output_eval_evalresult_filepath(self):
        """output/eval/eval_result.json"""
        return self._output_filepath(K8sConstants.EVAL_DIR, K8sConstants.EVALRESULT_JSON)

    def get_output_eval_performance_filepath(self):
        """output/eval/performance.json"""
        return self._output_filepath(K8sConstants.EVAL_DIR, K8sConstants.PERFORMANCE_JSON)

    def get_output_eval_performance_md_filepath(self):
        """output/eval/performance.md"""
        return self._output_filepath(K8sConstants.EVAL_DIR, K8sConstants.PERFORMANCE_MD)

    # ---- output/preprocess ----

    def get_output_preprocess_dirpath(self):
        """output/preprocess"""
        return self._ensured_dirpath(K8sConstants.OUTPUT_DIR, K8sConstants.PREPROCESS_DIR)

    def get_output_preprocess_preprocessfiledetails_dirpath(self):
        """output/preprocess/preprocess_file_details"""
        return self._ensured_dirpath(K8sConstants.OUTPUT_DIR, K8sConstants.PREPROCESS_DIR,
                                     K8sConstants.PREPROCESSDETAILS_DIR)

    def get_output_preprocess_preprocessresult_filepath(self):
        """output/preprocess/preprocess_result.json"""
        return self._output_filepath(K8sConstants.PREPROCESS_DIR, K8sConstants.PREPROCESSRESULT_JSON)

    def get_output_preprocess_mdb_path(self):
        """output/preprocess/preprocess_file_details/json2mdb"""
        return self._ensured_dirpath(K8sConstants.OUTPUT_DIR, K8sConstants.PREPROCESS_DIR,
                                     K8sConstants.PREPROCESSDETAILS_DIR, K8sConstants.PREPROCESS_MDB_NAME)

    def get_train_nii_path(self):
        """output/preprocess/preprocess_file_details/nii"""
        return self._ensured_dirpath(K8sConstants.OUTPUT_DIR, K8sConstants.PREPROCESS_DIR,
                                     K8sConstants.PREPROCESSDETAILS_DIR, K8sConstants.TRAIN_DATA_TYPE)

    def get_train_db_path(self):
        """output/preprocess/preprocess_file_details/train_db"""
        return self._ensured_dirpath(K8sConstants.OUTPUT_DIR, K8sConstants.PREPROCESS_DIR,
                                     K8sConstants.PREPROCESSDETAILS_DIR, K8sConstants.TRAIN_DB_DOC)

    def get_val_db_path(self):
        """output/preprocess/preprocess_file_details/val_db"""
        return self._ensured_dirpath(K8sConstants.OUTPUT_DIR, K8sConstants.PREPROCESS_DIR,
                                     K8sConstants.PREPROCESSDETAILS_DIR, K8sConstants.VAL_DB_DOC)

    def get_output_preprocess_npy_path(self):
        """output/preprocess/preprocess_file_details/npy (classification task)."""
        return self._ensured_dirpath(K8sConstants.OUTPUT_DIR, K8sConstants.PREPROCESS_DIR,
                                     K8sConstants.PREPROCESSDETAILS_DIR, K8sConstants.DATA_TYPE)

    # ---- output/train ----

    def get_output_train_dirpath(self):
        """output/train"""
        return self._ensured_dirpath(K8sConstants.OUTPUT_DIR, K8sConstants.TRAIN_DIR)

    def get_output_train_bestmodel_dirpath(self):
        """output/train/model"""
        return self._ensured_dirpath(K8sConstants.OUTPUT_DIR, K8sConstants.TRAIN_DIR, K8sConstants.BEST_MODEL)

    def get_output_train_bestmodel_pth_filepath(self):
        """output/train/model/model.pth"""
        return self._output_filepath(K8sConstants.TRAIN_DIR, K8sConstants.BEST_MODEL, K8sConstants.MODEL_PTH)

    def get_output_train_bestmodel_onnx_filepath(self):
        """output/train/model/model.onnx

        BUG FIX: the previous implementation joined MODEL_PTH here, so the
        'onnx' getter silently returned the .pth path.
        """
        return self._output_filepath(K8sConstants.TRAIN_DIR, K8sConstants.BEST_MODEL, K8sConstants.MODEL_ONNX)

    def get_output_train_trainresult_filepath(self):
        """output/train/train_result.json"""
        return self._output_filepath(K8sConstants.TRAIN_DIR, K8sConstants.TRAINRESULT_JSON)

    def get_output_train_trainingmetrics_filepath(self):
        """output/train/training_metrics.csv"""
        return self._output_filepath(K8sConstants.TRAIN_DIR, K8sConstants.TRAINMETRICS_CSV)

    def get_output_train_writer_path(self):
        """output/train/runs (classification task tensorboard-style writer dir)."""
        return self._ensured_dirpath(K8sConstants.OUTPUT_DIR, K8sConstants.TRAIN_DIR, K8sConstants.WRITE_PATH)


class CMetricsWriter:
    """Appends metric rows to a CSV file, writing the header row exactly once.

    Rows shorter than the header are right-padded with 0.0; longer rows are
    truncated to the header width.
    """

    def __init__(self, filename, headernames) -> None:
        self.filename = filename
        self.headernames = headernames
        self._write_headers()

    def _write_headers(self):
        """Create the CSV with its header line unless the file already exists."""
        if os.path.exists(self.filename):
            return
        with open(self.filename, 'w+') as fp:
            fp.write('{}\n'.format(",".join(self.headernames)))

    def append_one_line(self, list2write):
        """Normalise *list2write* to the header width, log it, append to the CSV."""
        self._write_headers()

        width = len(self.headernames)
        # Truncate to the header width, then pad with 0.0 if still short.
        row = list(list2write[:width])
        row.extend([0.0] * (width - len(row)))

        line = ",".join(str(item) for item in row)
        log_msg(line)
        with open(self.filename, 'a+') as fp:
            fp.write('{}\n'.format(line))

    def get_filename(self):
        """Return the CSV path this writer appends to."""
        return self.filename

    def get_headernames(self):
        """Return the list of CSV column names."""
        return self.headernames


class CParamFiller:
    """Loads key/value parameters from a JSON file into a dict or namespace."""

    @staticmethod
    def fillin_args(param_json, args):
        """Set every key of *param_json* as an attribute on *args*; return *args*."""
        for key, value in CParamFiller.fillin_dict(param_json=param_json).items():
            setattr(args, key, value)
        return args

    @staticmethod
    def fillin_dict(param_json):
        """Parse *param_json* into a dict; on any failure log and return {}."""
        try:
            with open(param_json, 'r') as fp:
                return json.load(fp)
        except Exception as ex:
            # Deliberate best-effort: a missing/broken config yields no params.
            log_msg(fr'open json file {param_json} failed !! exception {ex}')
            return {}

# Sample data_set.json:
# {
#     "protocol": "FILE",
#     "dataList": [
#         {
#             "uid": "1.2.826.0.1.3680043.2.1125.20230207.120420561023786803566", // md5
#             "rawDataType": "DCM", // or DCM_ZIP, PREPROCESS (preprocessed)
#             "rawDataUrls": [
#                 "/data/dicom/1.2.826.0.1.3680043.2.1125.20230207.120420561023786803566-1.dcm",
#                 "/data/dicom/1.2.826.0.1.3680043.2.1125.20230207.120420561023786803566-2.dcm"
#             ],
#             "annotations": [
#                 {
#                     "annotationType": "3D_SEGMENTATION",
#                     "annotationUrls": [
#                         "/data/annotation/1.2.826.0.1.3680043.2.1125.20230207.120420561023786803566.json",
#                         "/data/annotation/1.2.826.0.1.3680043.2.1125.20230207.120420561023786803566_label.nii.gz"
#                     ]
#                 }
#             ]
#         }
#     ]
# }
class CDatasetWrapper():
    """Parses a data_set.json manifest (see the sample comment above) into
    rows of the form [uid, raw_file_list, annotation_urls_or_None]."""

    DATASET_DATALIST = "dataList"
    DATASET_UID = "uid"
    DATASET_RAWDATATYPE = "rawDataType"
    DATASET_RAWDATAURL = "rawDataUrls"
    DATASET_ANNOTATION = "annotations"
    DATASET_ANNOTATION_TYPE = "annotationType"
    DATASET_ANNOTATION_URL = "annotationUrls"

    @staticmethod
    def extract_zipfile(zip_file_path, target_root_path):
        """Flat-extract every file of *zip_file_path* into *target_root_path*.

        Directory structure inside the archive is discarded: each member is
        written under its basename directly into the target directory.
        """
        with zipfile.ZipFile(zip_file_path, 'r') as zip_ref:
            for member in zip_ref.namelist():
                filename = os.path.basename(member)
                if not filename:
                    # pure directory entry — nothing to write
                    continue
                source = zip_ref.open(member)
                target = open(os.path.join(target_root_path, filename), "wb")
                with source, target:
                    shutil.copyfileobj(source, target)

    @staticmethod
    def _extract_zip_entry(dl, output_path):
        """Unzip one DCM_ZIP dataList entry; return the extracted file paths.

        Raises IOError when *output_path* is None — checked up front; the
        original code only validated this AFTER already joining paths with
        output_path, so a missing output_path crashed with TypeError instead.
        """
        if output_path is None:
            raise IOError('output path needed for extracting zip dataset files')

        zipfilepath = dl[CDatasetWrapper.DATASET_RAWDATAURL][0]
        zipfile_basename = os.path.splitext(os.path.basename(zipfilepath))[0]
        target_zip_path = os.path.join(output_path, zipfile_basename)
        os.makedirs(target_zip_path, exist_ok=True)

        log_msg(fr'UID {dl[CDatasetWrapper.DATASET_UID]}, basename {zipfile_basename}, {zipfilepath} -> {target_zip_path}')

        CDatasetWrapper.extract_zipfile(zipfilepath, target_zip_path)

        return [os.path.join(target_zip_path, fn) for fn in os.listdir(target_zip_path)]

    @staticmethod
    def get_dataset_list(config_json, output_path=None):
        """Read *config_json* and return [uid, fileslist, annotation_urls] rows.

        DCM_ZIP entries are extracted under *output_path* first (required for
        such entries).  Entries without annotations yield a None third field;
        a list of annotations yields one row per annotation.  Best-effort:
        any exception is logged and the rows parsed so far are returned.
        """
        output = []
        try:
            with open(config_json, 'r') as fp:
                data_dict = json.loads(fp.read())

            for dl in data_dict[CDatasetWrapper.DATASET_DATALIST]:
                if dl[CDatasetWrapper.DATASET_RAWDATATYPE] == 'DCM_ZIP':
                    fileslist = CDatasetWrapper._extract_zip_entry(dl, output_path)
                else:
                    fileslist = dl[CDatasetWrapper.DATASET_RAWDATAURL]

                log_msg(fr'final fileslist length {len(fileslist)}')

                uid = dl[CDatasetWrapper.DATASET_UID]
                annotations = dl.get(CDatasetWrapper.DATASET_ANNOTATION)

                if not annotations:
                    # key missing, None, or empty list: one row, no annotation
                    output.append([uid, fileslist, None])
                elif isinstance(annotations, list):
                    # one row per annotation object
                    for anno in annotations:
                        output.append([uid, fileslist, anno[CDatasetWrapper.DATASET_ANNOTATION_URL]])
                else:
                    # single annotation object.  BUG FIX: the original emitted
                    # a 4-element row here and referenced zipfile_basename,
                    # which was undefined for non-DCM_ZIP entries (NameError);
                    # rows are now uniformly 3 elements.
                    output.append([uid, fileslist, annotations[CDatasetWrapper.DATASET_ANNOTATION_URL]])
        except Exception as ex:
            # deliberate best-effort: log and return whatever was parsed
            log_msg(fr'exception during process dataset, msg {ex}')

        return output
        

def get_fullfilepaths_according2pattern(givenpath, pattern):
    """Recursively collect files under *givenpath* whose name contains *pattern*.

    Each result is re-rooted at '.' relative to *givenpath* (e.g.
    './sub/x_label.nii.gz').

    BUG FIX: the original used str.replace(givenpath, '.'), which substitutes
    EVERY occurrence of the root string in the full path; now only the leading
    prefix is replaced.
    """
    output = []
    for root, _dirs, files in os.walk(givenpath):
        for fn in files:
            fullpath = os.path.join(root, fn)
            if pattern in fn and fullpath.startswith(givenpath):
                output.append('.' + fullpath[len(givenpath):])
    return output