from yacs.config import CfgNode as CN


# -----------------------------------------------------------------------------
# Config definition
# -----------------------------------------------------------------------------

_C = CN()

# -----------------------------------------------------------------------------
# Data
# -----------------------------------------------------------------------------
# The whole DATA subtree is declared as one nested dict; the CfgNode
# constructor turns every nested dict into a child CfgNode.
_C.DATA = CN({
    # -------------------------------------------------------------------------
    # DataLoader options
    # -------------------------------------------------------------------------
    'DATA_LOADER': {
        'DATA_DIR': None,
        'TRAIN_DB': None,
        'VALIDATE_DB': None,
        'TEST_DB': None,
        'RAND_CROP_RATIO': 0.0,
        'NUM_WORKERS': 2,
        'BLACKLIST': [],
        'BALANCED_SAMPLING': False,
        'POS_TARGET_RANGE': [1, 2, 3, 4, 5, 6],
        'NEG_TARGET_RANGE': [],
        'IRRELEVANT_TARGET_RANGE': [],
        'DATA_MODE': 'rib',
        'MASK_DIR': None,  # Only for Centernet3D
    },
    # -------------------------------------------------------------------------
    # DataProcess options
    # -------------------------------------------------------------------------
    'DATA_PROCESS': {
        'CROP_SIZE': [128, 128, 128],
        'SPACING': [1.0, 1.0, 1.0],
        'BBOX_BORDER': 0,
        'PAD_VALUE': -300,
        'STRIDE': 4,
        'NUM_NEG': 800,
        'NUM_HARD': 2,
        'BOUND_SIZE': 8,
        'DIAMETERLIM': [5., 60.],
        'SCALELIM': [0.8, 1.2],
        # Augmentation switches
        'AUGTYPE': {
            'FLIP': True,
            'ROTATE': True,
            'SCALE': True,
            'SWAP': False,
        },
        'FPN_DIAMETER_RANGE': None,
        'HU_MIN': -300.0,
        'HU_MAX': 1000.0,
        'MAX_OBJS': 20,  # Only for Centernet3D
        'DOWNSAMPLING_RATIO': 2,  # Only for Centernet3D
        'MIN_OVERLAP': 0.1,  # Only for Centernet3D
    },
})

# -----------------------------------------------------------------------------
# Model
# -----------------------------------------------------------------------------
# The MODEL subtree is declared as one nested dict; the CfgNode constructor
# turns every nested dict into a child CfgNode.
_C.MODEL = CN({
    'META_ARCHITECTURE': 'NoduleNet',
    'WEIGHT': None,
    'BBOX_REG_WEIGHT': [1., 1., 1., 1., 1., 1.],
    # -------------------------------------------------------------------------
    # ANCHOR options
    # -------------------------------------------------------------------------
    'ANCHOR': {
        'BASES': [10, 15, 20, 30, 40],
        'ASPECT_RATIOS': [[1.0, 1.0, 1.0]],
    },
    # -------------------------------------------------------------------------
    # BACKBONE options
    # -------------------------------------------------------------------------
    'BACKBONE': {
        'CONV_BODY': "ResUNet",
        'BN_MOMENTUM': 0.1,
        'FPN': False,
    },
    # -------------------------------------------------------------------------
    # RPN options
    # -------------------------------------------------------------------------
    'RPN': {
        'USE_FPN': False,
        'BG_THRESH_HIGH': 0.02,
        'FG_THRESH_LOW': 0.7,
        'NMS_IOU_THRESH': 0.1,
        'TRAIN_PRE_NMS_SCORE_THRESH': 0.2,
        'TEST_PRE_NMS_SCORE_THRESH': 0.2,
        'DEPLOY_PRE_NMS_TOP_K': 10000,
        'BALANCED_CLS_LOSS': True,
    },
    # -------------------------------------------------------------------------
    # ROI_BOX_HEAD options
    # -------------------------------------------------------------------------
    'ROI_BOX_HEAD': {
        'ROI_BATCH_SIZE': 16,
        'NUM_CLASS': 2,
        'POOLER_RESOLUTION': 7,
        'SAMPLING_RATIO': 2,
        'BG_THRESH_HIGH': 0.02,
        'FG_THRESH_LOW': 0.7,
        'FG_FRACTION': 0.5,
        'TEST_PRE_NMS_SCORE_THRESH': 0,
    },
})

# -----------------------------------------------------------------------------
# Training
# -----------------------------------------------------------------------------   
# The TRAINING subtree is declared as one nested dict; the CfgNode constructor
# turns every nested dict into a child CfgNode.
_C.TRAINING = CN({
    'NUM_GPUS': 1,
    'AMP': False,
    'FP16_ALLREDUCE': False,
    'BATCH_SIZE': 4,
    # -------------------------------------------------------------------------
    # SOLVER options
    # -------------------------------------------------------------------------
    'SOLVER': {
        'OPTIMIZER': 'SGD',
        'BASE_LR': 0.01,
        'MOMENTUM': 0.9,
        # NOTE(review): key is spelled WRIGHT_DECAY (presumably "weight
        # decay"); kept verbatim because it is part of the config interface.
        'WRIGHT_DECAY': 1e-4,
    },
    # -------------------------------------------------------------------------
    # SAVER options
    # -------------------------------------------------------------------------
    'SAVER': {
        'SAVER_DIR': None,
        'SAVER_FREQUENCY': 1,
    },
    # -------------------------------------------------------------------------
    # Scheduler options
    # NOTE(review): the keys are spelled SHEDULER / *_SHEDULE; kept verbatim
    # because they are part of the config interface.
    # -------------------------------------------------------------------------
    'SHEDULER': {
        'TOTAL_EPOCHS': 300,
        'WARMUP': 5,
        'SWITCH_ROI_EPOCH': 301,
        'SWITCH_BALANCED_SAMPLING_EPOCH': 301,
        'LR_SHEDULE': False,
        'NUM_NEG_SHEDULE': False,
        'RAND_CROP_RATIO_SHEDULE': False,
    },
})
   
# -----------------------------------------------------------------------------
# Testing
# -----------------------------------------------------------------------------   
# All TESTING defaults, declared in one dict handed to the CfgNode constructor.
_C.TESTING = CN({
    'WEIGHT': None,
    'SAVER_DIR': None,
    'USE_RCNN': False,
})

# -----------------------------------------------------------------------------
# Deploy
# -----------------------------------------------------------------------------   
# All DEPLOY defaults, declared in one dict handed to the CfgNode constructor.
_C.DEPLOY = CN({
    'TORCHSCRIPT_SAVE_PATH': None,
    'TORCHSCRIPT_COMPARE_RES_DIR': None,
})

def get_cfg_defaults():
  """Return a fresh copy of the default configuration.

  The module-level ``_C`` node is cloned rather than returned directly, so
  callers can modify or merge into the result without altering the shared
  defaults.
  """
  defaults = _C.clone()
  return defaults

if __name__ == '__main__':
  import json

  # Load the defaults and overlay the project YAML on top of them.
  merged = get_cfg_defaults()
  merged.merge_from_file('/root/Documents/st_sample_codedata/grouplung/NoduleDetector/config.yaml')

  # Serialize, lower-case the whole text, then restore Python-style spellings
  # of the JSON literals.
  # NOTE(review): lower() and the replacements act on the entire string, so
  # string values (not only keys) are lower-cased, and any occurrence of
  # 'true'/'false'/'null' inside a value is rewritten too. The result uses
  # True/False/None and is therefore not strict JSON -- confirm intent.
  text = json.dumps(merged).lower()
  for json_token, py_token in (('true', 'True'), ('false', 'False'), ('null', 'None')):
    text = text.replace(json_token, py_token)

  with open('/root/Documents/st_sample_codedata/grouplung/json_params.json', 'w') as out:
    out.write(text)
  print('finish')