import json
import numpy as np
import pandas

# =========================================
# Keys of the top-level "annotationLabels" entries and of their properties.
KEY_ANNO_LABELS = "annotationLabels"
KEY_ANNO_LABEL_NAME = "name"
KEY_ANNO_LABEL_ID = "id"
KEY_ANNO_LABEL_PROPERTY = "properties"
KEY_ANNO_LABEL_PROPERTY_NAME = "name"
KEY_ANNO_LABEL_PROPERTY_ID = "id"

# Keys and values of the top-level "annotationSessions" entries.
KEY_ANNO_SESSIONS = "annotationSessions"
KEY_ANNO_SESSIONS_SESSIONTYPE = "sessionType"
KEY_ANNO_SESSIONS_SESSIONTYPE_ANNOTATION = "ANNOTATION"
KEY_ANNO_SESSIONS_SESSIONTYPE_REVIEW = "REVIEW"

# NOTE: the *_ANNOTATION_* and *_ANNOTATIONSET_* names below map to the same
# JSON keys ("annotationSet", "labelProperties", "labelID", "propertyID").
KEY_ANNO_SESSIONS_ANNOTATION = "annotationSet"
KEY_ANNO_SESSIONS_ANNOTATION_LABELPROPERTY = "labelProperties"
KEY_ANNO_SESSIONS_ANNOTATION_LABELPROPERTY_LABELID = "labelID"
KEY_ANNO_SESSIONS_ANNOTATION_LABELPROPERTY_PROPERTYID = "propertyID"
KEY_ANNO_SESSIONS_ANNOTATIONSET = "annotationSet"
KEY_ANNO_SESSIONS_ANNOTATIONSET_LABELPROPERTY = "labelProperties"
KEY_ANNO_SESSIONS_ANNOTATIONSET_COORDS = "coordinates"
KEY_ANNO_SESSIONS_ANNOTATIONSET_CONF = "confidence"
KEY_ANNO_SESSIONS_ANNOTATIONSET_COORDS_X = "x"
KEY_ANNO_SESSIONS_ANNOTATIONSET_COORDS_Y = "y"
KEY_ANNO_SESSIONS_ANNOTATIONSET_COORDS_Z = "z"
KEY_ANNO_SESSIONS_ANNOTATIONSET_SHAPETYPE = "shapeType"
KEY_ANNO_SESSIONS_ANNOTATIONSET_SHAPETYPE_RECTANGLE_3D = "Rectangle3D"
KEY_ANNO_SESSIONS_ANNOTATIONSET_SHAPETYPE_RECTANGLE_2D = "Rectangle2D"
KEY_ANNO_SESSIONS_ANNOTATIONSET_SHAPETYPE_SPHERE_3D = "Sphere3D"
KEY_ANNO_SESSIONS_ANNOTATIONSET_SHAPETYPE_SPHERE_2D = "Sphere2D"
KEY_ANNO_SESSIONS_ANNOTATIONSET_SHAPETYPE_POINT_3D = "Point3D"
KEY_ANNO_SESSIONS_ANNOTATIONSET_SHAPETYPE_POINT_2D = "Point2D"
KEY_ANNO_SESSIONS_ANNOTATIONSET_LABELPROPERTY_LABELID = "labelID"
KEY_ANNO_SESSIONS_ANNOTATIONSET_LABELPROPERTY_PROPERTYID = "propertyID"

# Key and values of the "annotationType" field.
# The misspelled names (DETECTIION, SEGMENTAION) are kept so existing callers
# keep working; correctly spelled aliases are provided right below.
KEY_ANNO_ANNOTATION_TYPE = "annotationType"
KEY_ANNO_ANNOTATION_TYPE_2D_DETECTIION = "2D_DETECTION"
KEY_ANNO_ANNOTATION_TYPE_3D_DETECTIION = "3D_DETECTION"
KEY_ANNO_ANNOTATION_TYPE_2D_CLASSIFICATION = "2D_CLASSIFICATION"
KEY_ANNO_ANNOTATION_TYPE_3D_CLASSIFICATION = "3D_CLASSIFICATION"
KEY_ANNO_ANNOTATION_TYPE_2D_SEGMENTAION = "2D_SEGMENTATION"
# Previously this held "3D_CLASSIFICATION" (copy-paste of the line above).
KEY_ANNO_ANNOTATION_TYPE_3D_SEGMENTATION = "3D_SEGMENTATION"

# Correctly spelled aliases of the names above.
KEY_ANNO_ANNOTATION_TYPE_2D_DETECTION = KEY_ANNO_ANNOTATION_TYPE_2D_DETECTIION
KEY_ANNO_ANNOTATION_TYPE_3D_DETECTION = KEY_ANNO_ANNOTATION_TYPE_3D_DETECTIION
KEY_ANNO_ANNOTATION_TYPE_2D_SEGMENTATION = KEY_ANNO_ANNOTATION_TYPE_2D_SEGMENTAION


# Top-level key holding the series UID of the annotated image.
KEY_SUID = 'seriesUID'

# =========================================
# Header row of the class CSV (GT_CLS_HEADERs is what annojson2cls writes).
GT_CLS_HEADERs = ["seriesuid", "class"]
PRED_CLS_HEADERs = ["seriesuid", "class"]

# Header rows of the coordinate-based CSVs (presumably ground truth vs.
# prediction, going by the names; they are not used in this part of the file).
GT_CSV_HEADERs = ["seriesuid", "coordX", "coordY", "coordZ", "diameter_mm"]
PRED_CSV_HEADERs = ["seriesuid", "coordX", "coordY", "coordZ", "probability"]

# Class value emitted for a series whose annotation-session list is empty.
CLS_UNDEFINED_LABEL = 0


def annolabels2labelidlabelmap(anno_labels):
    """Split one annotation-label entry into its id and a property lookup table.

    Args:
        anno_labels: Mapping with an ``'id'`` entry and a ``'properties'``
            iterable whose items each carry ``'id'`` and ``'name'``, or
            ``None`` when no label entry is available.

    Returns:
        A ``(label_id, label_map)`` tuple. ``label_map`` maps ``str(id)`` of
        every property to its name. ``(-1, None)`` is returned when
        ``anno_labels`` is ``None``.
    """
    if anno_labels is None:
        return -1, None

    label_id = anno_labels['id']
    names_by_property_id = {}
    for prop in anno_labels['properties']:
        # Keys are stringified so lookups do not depend on the JSON id type.
        names_by_property_id[str(prop['id'])] = prop['name']
    return label_id, names_by_property_id


def annojson2cls(anno_fullpath, full_save_path=None, given_labelid=-1, log_func=None):
    """Extract classification labels for one series from an annotation JSON file.

    The JSON document is read from ``anno_fullpath``. The series UID, the
    annotation-label definition and the annotation sessions are looked up by
    key, so the result does not depend on the order of keys in the document.

    Args:
        anno_fullpath: Path of the annotation JSON file.
        full_save_path: Optional path of a CSV file to write. The file always
            gets the ``GT_CLS_HEADERs`` header row, followed by one row per
            extracted label (none when nothing was extracted).
        given_labelid: Id of the annotation label to use. ``-1`` selects the
            first entry of the annotation-label list.
        log_func: Callable taking one message string; defaults to ``print``.

    Returns:
        A tuple ``(output_anno, working_classmap_id, working_classmap_dict)``:
        ``output_anno`` is a list of ``[series_uid, class_id]`` rows,
        ``working_classmap_id`` is the id of the selected annotation label
        (``-1`` when none was selected) and ``working_classmap_dict`` maps
        property id to property name (``None`` when unavailable).

    Raises:
        OSError: If ``anno_fullpath`` cannot be opened.
        ValueError: If the file content is not valid JSON.
    """
    if log_func is None:
        log_func = print

    # Load eagerly so the file handle is released before any processing.
    with open(anno_fullpath, 'rb') as fp:
        json_dict = json.load(fp)
    log_func(f'jsondict {json_dict}')

    # Direct lookup: scanning the items and stopping at the label entry would
    # miss the UID whenever it is stored after the labels in the document.
    this_uid = json_dict.get(KEY_SUID, '')

    working_classmap_id = -1
    working_classmap_dict = None
    if KEY_ANNO_LABELS in json_dict:
        working_classmap_id, working_classmap_dict = _select_classmap(
            json_dict[KEY_ANNO_LABELS], given_labelid, log_func)

    output_anno = []
    if KEY_ANNO_SESSIONS in json_dict:
        sessions = json_dict[KEY_ANNO_SESSIONS]
        if not sessions:
            # No session at all: pure prediction data, use the undefined label.
            output_anno.append([this_uid, CLS_UNDEFINED_LABEL])
            log_func(f'pure predict data {this_uid}...using undefined-label...')
        else:
            output_anno.extend(
                _collect_session_labels(sessions, this_uid, working_classmap_id, log_func))

    log_func(f'suid {this_uid}, anno length {len(output_anno)}')

    if full_save_path is not None:
        if len(output_anno) > 0:
            df = pandas.DataFrame(output_anno)
            df.to_csv(full_save_path, header=GT_CLS_HEADERs, index=False)
        else:
            # Nothing extracted: still emit a CSV holding only the header row.
            with open(full_save_path, "w") as f:
                f.write(",".join(GT_CLS_HEADERs) + "\n")

    return output_anno, working_classmap_id, working_classmap_dict


def _select_classmap(anno_labels, given_labelid, log_func):
    """Pick the working annotation label; return its id and id -> name map."""
    selected = None
    if isinstance(anno_labels, list):
        if given_labelid == -1:
            # only #0 annotation label supported when no id is requested
            if len(anno_labels) > 0:
                selected = anno_labels[0]
        else:
            for candidate in anno_labels:
                if candidate[KEY_ANNO_LABEL_ID] == given_labelid:
                    selected = candidate
                    break

    classmap_id = -1
    classmap_props = None
    if selected is not None:
        classmap_props = selected[KEY_ANNO_LABEL_PROPERTY]
        classmap_id = selected[KEY_ANNO_LABEL_ID]

    if classmap_props is None or classmap_id == -1:
        log_func(f'error process classmap, given KEYWORD {given_labelid} not found !')
        # Hand back whatever was found, unconverted.
        return classmap_id, classmap_props

    classmap = {
        prop[KEY_ANNO_LABEL_PROPERTY_ID]: prop[KEY_ANNO_LABEL_PROPERTY_NAME]
        for prop in classmap_props
    }
    return classmap_id, classmap


def _collect_session_labels(sessions, this_uid, working_classmap_id, log_func):
    """Return ``[uid, class_id]`` rows from the review session (else session #0)."""
    # Prefer a review session; when several exist the last one wins.
    working_session_idx = 0
    for session_idx, session in enumerate(sessions):
        log_func('looking for session index...')
        if session[KEY_ANNO_SESSIONS_SESSIONTYPE] == KEY_ANNO_SESSIONS_SESSIONTYPE_REVIEW:
            log_func(f'review session found for index {session_idx}...')
            working_session_idx = session_idx
    log_func(f'using session index {working_session_idx}...')

    label_list = sessions[working_session_idx][KEY_ANNO_SESSIONS_ANNOTATION]
    rows = []
    try:
        for entry in label_list:
            # Each entry is indexed with [0] before its label properties are
            # read -- presumably entries are lists themselves; TODO confirm
            # against the annotation schema.
            label_properties = entry[0][KEY_ANNO_SESSIONS_ANNOTATION_LABELPROPERTY]
            for this_label in label_properties:
                if working_classmap_id == -1:
                    continue
                if this_label[KEY_ANNO_SESSIONS_ANNOTATION_LABELPROPERTY_LABELID] != working_classmap_id:
                    continue

                this_label_cls = this_label[KEY_ANNO_SESSIONS_ANNOTATION_LABELPROPERTY_PROPERTYID]
                if this_label_cls >= 0:
                    rows.append([this_uid, this_label_cls])
                else:
                    log_func(fr'working_classmap_id {working_classmap_id}, label {this_label_cls} for image {this_uid}, NOT valid')
    except Exception as ex:
        # Best effort: malformed annotation data is logged, and the rows
        # gathered before the failure are still returned.
        log_func(f'exception during processing {label_list}, msg {ex}')
    return rows