{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import sys\n",
    "import glob\n",
    "import json\n",
    "import numpy as np"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Directory holding the NoduleSegModule JSON configs; set NODULE_SEG_FILES_DIR to override.\n",
    "FILES_DIR = os.environ.get(\n",
    "    'NODULE_SEG_FILES_DIR',\n",
    "    '/hdd/disk6/zrx/projects/Pulmonary_Nodule/develop_code/GroupLung/grouplung/lungNoduleSegmentation/NoduleSegModule/files')\n",
    "filename = os.path.join(FILES_DIR, 'model_parameters.json')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Model-map JSON lives in the same directory as the model-parameters file.\n",
    "path1 = os.path.join(os.path.dirname(filename), 'model_map.json')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "def GenerateParameters(model_parameter_filename, model_map_parameter_filename,\n",
    "                       model='VGG', missing_value=float('inf')):\n",
    "    \"\"\"Collect the parameters required by `model` from two JSON files.\n",
    "\n",
    "    Args:\n",
    "        model_parameter_filename: path to a JSON object mapping parameter\n",
    "            names to values.\n",
    "        model_map_parameter_filename: path to a JSON object keyed by model\n",
    "            name; the entry for `model` is iterated to obtain the names of\n",
    "            the parameters that model needs.\n",
    "        model: key to look up in the model-map file.\n",
    "        missing_value: value used for a parameter that the map lists but the\n",
    "            parameter file does not define. Defaults to float('inf'); note\n",
    "            json.dumps writes that as the non-standard token `Infinity`.\n",
    "\n",
    "    Returns:\n",
    "        dict mapping each required parameter name to its value, or {} if\n",
    "        either file does not exist (a message is printed in that case).\n",
    "\n",
    "    Raises:\n",
    "        KeyError: if `model` is not a key of the model-map file.\n",
    "    \"\"\"\n",
    "    if not os.path.exists(model_parameter_filename):\n",
    "        print('Invalid model_parameter_filename')\n",
    "        return {}\n",
    "    if not os.path.exists(model_map_parameter_filename):\n",
    "        print('Invalid model_map_parameter_filename')\n",
    "        return {}\n",
    "\n",
    "    # `with` closes each file on exit, so no explicit close() is needed.\n",
    "    with open(model_parameter_filename, 'r') as f:\n",
    "        model_parameters = json.load(f)\n",
    "    with open(model_map_parameter_filename, 'r') as f:\n",
    "        parameter_map_dict = json.load(f)\n",
    "\n",
    "    if model not in parameter_map_dict:\n",
    "        raise KeyError('model %r not found in %s' % (model, model_map_parameter_filename))\n",
    "    current_parameters = parameter_map_dict[model]\n",
    "\n",
    "    return {key: model_parameters.get(key, missing_value)\n",
    "            for key in current_parameters}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{u'ACF_choice': True,\n",
       " u'OCR_choice': True,\n",
       " u'SEB_choice': True,\n",
       " u'activation_func': u'LeakyReLU',\n",
       " u'deep_supervision': True,\n",
       " u'final_kernel_size': 3,\n",
       " u'kernel_initializer': u'he_normal',\n",
       " u'kernel_regularizer': None,\n",
       " u'kernel_size': 3,\n",
       " u'merge_axis': inf,\n",
       " u'norm_func': u'InstanceNormalization',\n",
       " u'num_units': [3, 3, 3, 3],\n",
       " u'padding': u'same',\n",
       " u'seg_num_class': inf}"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Resolve the decoder's parameters; names absent from the parameter file come back as inf.\n",
    "GenerateParameters(model_parameter_filename=filename,\n",
    "                   model_map_parameter_filename=path1,\n",
    "                   model='NoduleSegDecoder')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Training configuration (serialized to JSON by the cells below).\n",
    "temp = {\n",
    "    'training': dict(\n",
    "        # input file names and network architecture\n",
    "        train_file='train_files.json',\n",
    "        val_txt='val_files.json',\n",
    "        model_parameters_txt='model_parameters.json',\n",
    "        model_map_parameters_txt='model_map.json',\n",
    "        encoder='VGG',\n",
    "        decoder='NoduleSegDecoder_proxima',\n",
    "        # model / input settings\n",
    "        num_classes=1,\n",
    "        num_units=1,\n",
    "        voxel_size=64,\n",
    "        num_slice=64,\n",
    "        final_ksize=1,\n",
    "        load_pre_trained=False,\n",
    "        pre_model_file='',\n",
    "        base_path='/hdd/disk4/Segmentation/Weights',\n",
    "        n_channels=1,\n",
    "        # stopping and learning-rate schedule\n",
    "        early_stop=10,\n",
    "        initial_learning_rate=5e-4,\n",
    "        learning_rate_drop=0.5,\n",
    "        # batching, normalization, epochs, augmentation\n",
    "        train_batch_size=2,\n",
    "        val_batch_size=2,\n",
    "        normalizer=[-1024.0, 400.0],\n",
    "        n_epochs=25,\n",
    "        shuffle=True,\n",
    "        aug=True,\n",
    "    ),\n",
    "}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'{\"training\": {\"model_parameters_txt\": \"model_parameters.json\", \"num_slice\": 64, \"n_epochs\": 25, \"n_channels\": 1, \"shuffle\": true, \"early_stop\": 10, \"aug\": true, \"train_file\": \"train_files.json\", \"model_map_parameters_txt\": \"model_map.json\", \"normalizer\": [-1024.0, 400.0], \"load_pre_trained\": false, \"base_path\": \"/hdd/disk4/Segmentation/Weights\", \"encoder\": \"VGG\", \"learning_rate_drop\": 0.5, \"decoder\": \"NoduleSegDecoder_proxima\", \"train_batch_size\": 2, \"val_txt\": \"val_files.json\", \"final_ksize\": 1, \"num_classes\": 1, \"pre_model_file\": \"\", \"voxel_size\": 64, \"initial_learning_rate\": 0.0005, \"num_units\": 1, \"val_batch_size\": 2}}'"
      ]
     },
     "execution_count": 18,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Preview the serialized training config.\n",
    "temp_json = json.dumps(temp)\n",
    "temp_json"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "{\"training\": {\"model_parameters_txt\": \"model_parameters.json\", \"num_slice\": 64, \"n_epochs\": 25, \"n_channels\": 1, \"shuffle\": true, \"early_stop\": 10, \"aug\": true, \"train_file\": \"train_files.json\", \"model_map_parameters_txt\": \"model_map.json\", \"normalizer\": [-1024.0, 400.0], \"load_pre_trained\": false, \"base_path\": \"/hdd/disk4/Segmentation/Weights\", \"encoder\": \"VGG\", \"learning_rate_drop\": 0.5, \"decoder\": \"NoduleSegDecoder_proxima\", \"train_batch_size\": 2, \"val_txt\": \"val_files.json\", \"final_ksize\": 1, \"num_classes\": 1, \"pre_model_file\": \"\", \"voxel_size\": 64, \"initial_learning_rate\": 0.0005, \"num_units\": 1, \"val_batch_size\": 2}}\n"
     ]
    }
   ],
   "source": [
    "# Serialize the config once; print() form works on both Python 2 and Python 3.\n",
    "temptemp = json.dumps(temp)\n",
    "print(temptemp)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Write the config straight from `temp` so this cell does not depend on the preview cell.\n",
    "# The path is relative to the notebook's working directory; `with` closes the file.\n",
    "with open('../files/train_config.json', 'w') as f:\n",
    "    json.dump(temp, f)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 2",
   "language": "python",
   "name": "python2"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 2
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython2",
   "version": "2.7.15"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}