import os
import sys
import glob
import numpy as np

import tensorflow as tf
from keras import Model
from keras.regularizers import l2
from keras import backend as K
from keras.engine import Layer,InputSpec
from keras.layers.merge import concatenate
from keras.callbacks import TensorBoard,Callback
from keras.layers.advanced_activations import LeakyReLU
from keras.preprocessing.image import ImageDataGenerator
from keras.layers.normalization import BatchNormalization
from keras.callbacks import ModelCheckpoint, ReduceLROnPlateau
from keras import initializers, regularizers, constraints,optimizers
from keras.callbacks import ModelCheckpoint, LearningRateScheduler,TensorBoard
from keras.layers import Add,Input,Conv3D,Convolution3D,Dropout,UpSampling3D,Concatenate,MaxPooling3D,\
GlobalAveragePooling3D,Dense,GlobalMaxPooling3D,Lambda,Activation,Reshape,Permute, PReLU, Deconvolution3D,Conv3DTranspose

sys.path.append('../baseLayers/')
from unet_utils import *
from SEB import SEB_block
from OCR import SpatialGather,OCRModule
from CommonLayers import NormActi,ConvUnit


def UpsampleBlock(x,**kwargs):
    """Upsample ``x`` and pass the result through ``NormActi``.

    Keyword Args:
        num_filters: number of filters of the (transposed) convolution.
        kernel_initializer: initializer for the convolution kernel
            (default ``'he_normal'``).
        kernel_regularizer: regularizer for the convolution kernel
            (default ``l2(1e-4)``).
        block_pre: prefix used to build the layer names (default ``''``).
        kernel_size: kernel size of the convolution (default ``3``).
        rate: upsampling factor; used as ``size`` of ``UpSampling3D`` or as
            ``strides`` of the transposed convolution (default ``2``).
        upsample_choice: ``'upsample'`` -> ``UpSampling3D`` followed by a
            convolution; ``'deconv'`` -> ``Deconvolution3D``; any other
            value -> ``Conv3DTranspose`` (default ``'upsample'``).
        norm_func: forwarded unchanged to ``NormActi``.
        activation_func: when it is a string it is applied by the
            convolution layer itself and ``NormActi`` receives ``None``;
            otherwise it is forwarded to ``NormActi`` and the convolution
            has no activation.
        padding: padding mode of the convolution (default ``'same'``).
        atrous_rate: accepted so existing call sites keep working; it is
            not used by this block.

    Returns:
        The tensor returned by ``NormActi``.
    """
    num_filters = kwargs.get('num_filters')
    kernel_initializer = kwargs.get('kernel_initializer','he_normal')
    kernel_regularizer = kwargs.get('kernel_regularizer',l2(1e-4))
    block_pre = kwargs.get('block_pre','')
    kernel_size = kwargs.get('kernel_size',3)
    rate = kwargs.get('rate',2)
    upsample_choice = kwargs.get('upsample_choice','upsample')
    norm_func = kwargs.get('norm_func')
    activation_func = kwargs.get('activation_func')
    padding = kwargs.get('padding','same')

    # A string activation is fused into the convolution layer; anything else
    # (callable / None) is left for NormActi to handle.
    if isinstance(activation_func,str):
        fused_activation = activation_func
        activation_func = None
    else:
        fused_activation = None

    # Arguments shared by every convolution variant below. Keras 2 keyword
    # names (`filters`, `padding`) are used instead of the Keras 1 aliases
    # (`nb_filter`, `border_mode`).
    conv_kwargs = dict(filters=num_filters,kernel_size=kernel_size,
                       kernel_initializer=kernel_initializer,
                       kernel_regularizer=kernel_regularizer,
                       padding=padding,activation=fused_activation,
                       name='%s_conv'%block_pre)

    if upsample_choice == 'upsample':
        x = UpSampling3D(size=rate,name='%s_upsample'%(block_pre))(x)
        x = Convolution3D(**conv_kwargs)(x)
    else:
        if upsample_choice == 'deconv':
            transpose_layer = Deconvolution3D
        else:
            transpose_layer = Conv3DTranspose
        x = transpose_layer(strides=rate,**conv_kwargs)(x)

    x = NormActi(x,norm_func=norm_func,activation_func=activation_func,
                 block_prefix=block_pre,block_suffix=block_pre)
    return x
    
    
    

def NoduleSegDecoderBlock(x,encoder_tensors,**kwargs):
    """Build one decoder stage: upsample, merge with an encoder tensor, convolve.

    Args:
        x: tensor coming from the previous (coarser) stage.
        encoder_tensors: sequence of encoder tensors; the skip connection is
            ``encoder_tensors[-(block_idx+1)]``.

    Keyword Args:
        block_idx: 1-based index of this block; selects the skip tensor and
            is part of every layer name.
        num_units: ``num_units - 1`` ``ConvUnit`` calls are applied after the
            concatenation (the loop runs over ``range(1, num_units)``).
        kernel_initializer / kernel_regularizer / kernel_size / padding /
        atrous_rate / norm_func / activation_func / conv_first: forwarded to
            ``UpsampleBlock``, ``SEB_block`` and ``ConvUnit`` as applicable.
        num_filters: filter count for the upsampling convolution and the
            ``ConvUnit`` calls.
        upsample_rate: upsampling factor handed to ``UpsampleBlock``
            (default ``2``).
        upsample_choice: upsampling variant handed to ``UpsampleBlock``
            (default ``'upsample'``).
        merge_axis: concatenation axis (default ``-1``).
        SEB_choice: when true, the skip tensor is first passed through
            ``SEB_block`` together with the last ``block_idx`` encoder
            tensors (default ``False``).

    Returns:
        The output tensor of the stage.
    """
    block_idx = kwargs.get('block_idx')
    num_units = kwargs.get('num_units')
    kernel_initializer = kwargs.get('kernel_initializer','he_normal')
    kernel_regularizer = kwargs.get('kernel_regularizer',l2(1e-4))
    SEB_choice = kwargs.get('SEB_choice',False)
    kernel_size = kwargs.get('kernel_size',3)
    norm_func = kwargs.get('norm_func')
    activation_func = kwargs.get('activation_func')
    num_filters = kwargs.get('num_filters')
    padding = kwargs.get('padding','same')
    atrous_rate = kwargs.get('atrous_rate',1)
    stride = kwargs.get('upsample_rate',2)
    conv_first = kwargs.get('conv_first',True)
    merge_axis = kwargs.get('merge_axis',-1)
    upsample_choice = kwargs.get('upsample_choice','upsample')

    block_prefix = 'NoduleSegDecoder_Block%02d'%block_idx

    x1 = UpsampleBlock(x,num_filters=num_filters,kernel_size=kernel_size,rate=stride,atrous_rate=atrous_rate,
                 padding=padding,block_pre=block_prefix,kernel_initializer=kernel_initializer,
                 kernel_regularizer=kernel_regularizer,upsample_choice=upsample_choice,
                 norm_func=norm_func,activation_func=activation_func)

    skip_tensor = encoder_tensors[-(block_idx+1)]

    if not SEB_choice:
        encoder_target_tensor = skip_tensor
    else:
        # Axes 1..3 are treated as the three spatial dimensions
        # (assumes channels-last 5D tensors -- TODO confirm).
        spatial_axes = (1,2,3)
        # Reference size is the block input `x`, i.e. before the upsampling above.
        reference_dims = [x._keras_shape[axis] for axis in spatial_axes]

        # The last `block_idx` encoder tensors, deepest first, are each
        # resized to the spatial size of `x`.
        upsample_tensor_list = []
        for layer_idx,current_tensor in enumerate(encoder_tensors[-(block_idx):][::-1]):
            upsample_rate = [int(reference_dim/current_tensor._keras_shape[axis])
                             for reference_dim,axis in zip(reference_dims,spatial_axes)]
            upsample_tensor = UpSampling3D(size=upsample_rate,
                                           name='%s_SEB_preUpsample%02d'%(block_prefix,layer_idx+1))(current_tensor)
            upsample_tensor_list.append(upsample_tensor)

        target_num_filters = skip_tensor._keras_shape[-1]

        encoder_target_tensor = SEB_block([upsample_tensor_list,skip_tensor],norm_func=norm_func,activation_func=activation_func,
                                         num_filters=target_num_filters,kernel_size=kernel_size,atrous_rate=atrous_rate,padding=padding,
                                         block_prefix=block_prefix,kernel_initializer=kernel_initializer,kernel_regularizer=kernel_regularizer,
                                         conv_first=conv_first,merge_axis=merge_axis)

    x = Concatenate(axis=merge_axis,name='%s_Concatenate'%(block_prefix))([encoder_target_tensor,x1])

    # layer_idx starts at 2 for the first ConvUnit of the block.
    for layer_idx in range(1,num_units):
        x = ConvUnit(x,norm_func=norm_func,activation_func=activation_func,num_filters=num_filters,
                    kernel_size=kernel_size,atrous_rate=atrous_rate,padding=padding,block_prefix=block_prefix,
                     kernel_initializer=kernel_initializer,kernel_regularizer=kernel_regularizer,
                    layer_idx=layer_idx+1,conv_first=conv_first)
    return x


def _combine_upsample_rate(coarse_tensor,fine_tensor,fallback_stride=None):
    """Return the factor needed to bring ``coarse_tensor`` to ``fine_tensor``'s size.

    The factor is the ratio of the static sizes along axis 1. When one of
    those sizes is ``None`` and a truthy ``fallback_stride`` is given, the
    fallback is returned instead of failing on the division.
    """
    coarse_size = coarse_tensor._keras_shape[1]
    fine_size = fine_tensor._keras_shape[1]
    if fallback_stride and (coarse_size is None or fine_size is None):
        return fallback_stride
    return int(fine_size/coarse_size)


def NoduleSegDeepCombineBlock(input_tensors,**kwargs):
    """Fuse three decoder tensors, coarsest first, into one tensor.

    Each tensor goes through its own ``ConvUnit``; the coarsest result is
    upsampled and added to the middle one, that sum is upsampled and added
    to the finest one.

    Args:
        input_tensors: exactly three tensors ordered from coarsest to finest.

    Keyword Args:
        kernel_initializer / kernel_regularizer / kernel_size / padding /
        atrous_rate / norm_func / activation_func / conv_first / num_filters:
            forwarded to ``ConvUnit``.
        stride_list: optional sequence of at least three strides. Entries
            ``[1]`` and ``[2]`` are used as the two upsampling factors, but
            only when the factor cannot be read from the static shapes.
            NOTE(review): this index mapping mirrors how the caller in this
            file derives its deep-supervision rates from the same list --
            confirm against the encoder that produces the list.
        upsample_rate / merge_axis / upsample_choice: accepted for call-site
            compatibility; not used by this block.

    Returns:
        ``(fused_tensor, [x1, x2, x3])`` where the list holds the three
        ``ConvUnit`` outputs in input order.

    Raises:
        ValueError: if ``input_tensors`` does not contain exactly three tensors.
    """
    kernel_initializer = kwargs.get('kernel_initializer','he_normal')
    kernel_regularizer = kwargs.get('kernel_regularizer',l2(1e-4))
    kernel_size = kwargs.get('kernel_size',3)
    norm_func = kwargs.get('norm_func')
    activation_func = kwargs.get('activation_func')
    num_filters = kwargs.get('num_filters')
    padding = kwargs.get('padding','same')
    atrous_rate = kwargs.get('atrous_rate',1)
    conv_first = kwargs.get('conv_first',True)
    stride_list = kwargs.get('stride_list')

    print ('input_tensors is ',input_tensors)
    tensors = list(input_tensors)
    if len(tensors) != 3:
        raise ValueError('NoduleSegDeepCombineBlock expects exactly 3 input tensors, got %d'%len(tensors))
    x1,x2,x3 = tensors
    block_prefix = 'NoduleSegDeepCombineBlock'

    has_strides = stride_list is not None and len(stride_list) >= 3
    uprate_1 = _combine_upsample_rate(x1,x2,stride_list[1] if has_strides else None)
    uprate_2 = _combine_upsample_rate(x2,x3,stride_list[2] if has_strides else None)

    # Arguments common to the three ConvUnit calls.
    unit_kwargs = dict(norm_func=norm_func,activation_func=activation_func,num_filters=num_filters,
                       kernel_size=kernel_size,atrous_rate=atrous_rate,padding=padding,
                       kernel_initializer=kernel_initializer,kernel_regularizer=kernel_regularizer,
                       layer_idx=1,conv_first=conv_first)

    x1 = ConvUnit(x1,block_prefix=block_prefix+'_block01_',**unit_kwargs)
    up_x1 = UpSampling3D(size=uprate_1,name='%s_block01_upsample'%block_prefix)(x1)

    x2 = ConvUnit(x2,block_prefix=block_prefix+'_block02_',**unit_kwargs)
    sum_12 = Add(name='%s_Add01'%block_prefix)([up_x1,x2])
    up_x2 = UpSampling3D(size=uprate_2,name='%s_block02_upsample'%block_prefix)(sum_12)

    x3 = ConvUnit(x3,block_prefix=block_prefix+'_block03_',**unit_kwargs)

    final_result = Add(name='%s_Add02'%block_prefix)([up_x2,x3])

    return final_result,[x1,x2,x3]
    
    
    
    
    
def _deep_supervision_rates(stage_tensors,stride_list):
    """Return one upsampling factor per deep-supervision stage.

    When ``stride_list`` supplies a truthy stride for every stage, the factor
    of stage ``i`` is the product of the strides after position ``i`` (so the
    last stage gets 1). Otherwise the factors are derived from the static
    sizes along axis 1, relative to the last stage tensor.
    """
    strides = list(stride_list[:len(stage_tensors)])
    if len(strides) == len(stage_tensors) and all(strides):
        rates = []
        rate = np.prod(strides)
        for stride in strides:
            rate = int(rate/stride)
            rates.append(rate)
        return rates
    # No usable strides (e.g. the encoder did not provide any): use shapes.
    final_size = stage_tensors[-1]._keras_shape[1]
    return [int(final_size/tensor._keras_shape[1]) for tensor in stage_tensors]


def NoduleSegDecoder_proxima(encoder_result,**kwargs):
    """Attach the segmentation decoder to an encoder and return a Keras ``Model``.

    Args:
        encoder_result: indexable with ``[0]`` the encoder model, ``[1]`` the
            list of encoder tensors and, optionally, ``[2]`` a list of
            strides aligned with the encoder tensors.

    Keyword Args:
        kernel_initializer / kernel_regularizer / kernel_size / padding /
        atrous_rate / norm_func / activation_func / conv_first / merge_axis /
        upsample_choice: forwarded to the decoder building blocks.
        final_kernel_size: kernel size of the output convolutions (default 3).
        SEB_choice: enable ``SEB_block`` in the decoder blocks; it is always
            disabled for the last block (default ``True``).
        ACF_choice: only read when ``deep_supervision`` is false; adds the
            ``CCB`` / ``CAB`` refinement and makes the output a two-element list.
        OCR_choice: only read when ``deep_supervision`` is true; appends an
            OCR-refined output.
        deep_supervision: emit one output per combined stage (default ``True``).
        num_units: per-block ``num_units`` for ``NoduleSegDecoderBlock``
            (default ``[3,3,3,3]``).
        seg_num_class: number of output channels; 1 selects ``'sigmoid'``,
            anything else ``'softmax'``.
        aux_task: append a classification head built from
            ``classification_layers`` and two ``Dense`` layers.
        num_classes: size of the classification head output.
        classification_layers: layer class instantiated with ``name=`` for
            the classification head.
        divide_ratio: the combine block uses the channel count of the last
            decoder tensor divided by this value (default 4).

    Returns:
        ``Model(encoder_model.input, output)`` where ``output`` is a single
        tensor or a list of tensors depending on the options above.
    """
    kernel_initializer = kwargs.get('kernel_initializer','he_normal')
    kernel_regularizer = kwargs.get('kernel_regularizer',l2(1e-4))
    kernel_size = kwargs.get('kernel_size',3)
    final_kernel_size = kwargs.get('final_kernel_size',3)
    norm_func = kwargs.get('norm_func')
    activation_func = kwargs.get('activation_func')
    padding = kwargs.get('padding','same')
    atrous_rate = kwargs.get('atrous_rate',1)
    conv_first = kwargs.get('conv_first',True)
    merge_axis = kwargs.get('merge_axis',-1)

    SEB_choice = kwargs.get('SEB_choice',True)
    ACF_choice = kwargs.get('ACF_choice',False)
    OCR_choice = kwargs.get('OCR_choice',False)
    deep_supervision = kwargs.get('deep_supervision',True)
    num_units = kwargs.get('num_units',[3,3,3,3])
    seg_num_class = kwargs.get('seg_num_class')
    aux_task = kwargs.get('aux_task',False)
    num_classes = kwargs.get('num_classes')
    classification_layers = kwargs.get('classification_layers')
    upsample_choice = kwargs.get('upsample_choice','upsample')
    divide_ratio = kwargs.get('divide_ratio',4)

    encoder_model = encoder_result[0]
    encoder_tensors = encoder_result[1]

    # The stride list is optional; without it every stride is derived from
    # the static tensor shapes.
    try:
        encoder_stride_list = encoder_result[2]
    except (IndexError,KeyError):
        encoder_stride_list = None
    if encoder_stride_list is None:
        encoder_stride_list = [None]*len(encoder_tensors)

    input_tensor = encoder_model.input

    x = encoder_tensors[-1]
    decoder_tensors = []
    num_blocks = len(encoder_tensors)-2

    for block_idx in range(num_blocks):
        current_stride = encoder_stride_list[-(block_idx+1)]
        if not current_stride:
            current_stride = int(encoder_tensors[-(block_idx+2)]._keras_shape[1]/x._keras_shape[1])
        num_filters = int(x._keras_shape[-1]/2)
        # The last decoder block never uses SEB.
        if block_idx == num_blocks-1:
            SEB_choice = False
        x = NoduleSegDecoderBlock(x,encoder_tensors,block_idx=block_idx+1,num_units=num_units[block_idx],
                                 kernel_initializer=kernel_initializer,kernel_regularizer=kernel_regularizer,
                                 SEB_choice=SEB_choice,kernel_size=kernel_size,norm_func=norm_func,activation_func=activation_func,
                                 num_filters=num_filters,padding=padding,atrous_rate=atrous_rate,upsample_rate=current_stride,conv_first=conv_first,
                                 merge_axis=merge_axis,upsample_choice=upsample_choice)
        decoder_tensors.append(x)

    current_stride_list = encoder_stride_list[-4:-1][::-1]

    # The three finest decoder tensors are fused by the combine block.
    combined_tensor,stage_tensors = NoduleSegDeepCombineBlock(
        decoder_tensors[-3:],kernel_initializer=kernel_initializer,
        kernel_regularizer=kernel_regularizer,kernel_size=kernel_size,norm_func=norm_func,
        activation_func=activation_func,num_filters=int(decoder_tensors[-1]._keras_shape[-1]/divide_ratio),
        padding=padding,atrous_rate=atrous_rate,conv_first=conv_first,merge_axis=merge_axis,upsample_choice=upsample_choice,
        stride_list=current_stride_list)

    if seg_num_class==1:
        current_acti = 'sigmoid'
    else:
        current_acti = 'softmax'

    final_conv_func = Convolution3D(filters=seg_num_class,kernel_size=final_kernel_size,
                                 kernel_initializer=kernel_initializer,kernel_regularizer=kernel_regularizer,
                                padding='same',activation=current_acti,name='NoduleSegDecoder_conv3d_mins1')

    if not deep_supervision:
        conv_mins1 = final_conv_func(combined_tensor)

        ########### TODO:revise this part during non fix
        if ACF_choice:
            coarse_feature_map = conv_mins1
            # NOTE(review): CCB / CAB are not imported by name in this file
            # (presumably provided by `from unet_utils import *`), and the
            # `activaion` keyword is spelled as in the original call because
            # CAB's signature is not visible here -- confirm both.
            feature_map_class = CCB([coarse_feature_map,combined_tensor],num_classes=seg_num_class)
            CAB_feature_map = CAB([coarse_feature_map,feature_map_class],num_classes=seg_num_class,activaion=current_acti)
            concatenate_feature_map = Concatenate()([CAB_feature_map,combined_tensor])
            final_result = Convolution3D(filters=seg_num_class,kernel_size=final_kernel_size,
                                 kernel_initializer=kernel_initializer,kernel_regularizer=kernel_regularizer,
                                padding='same',activation=current_acti,name='NoduleSegDecoder_conv_ACF')(concatenate_feature_map)

            conv_mins1 = [coarse_feature_map,final_result]
    else:
        conv_mins1 = []
        stage_rates = _deep_supervision_rates(stage_tensors,current_stride_list)
        for idx,(stage_tensor,rate) in enumerate(zip(stage_tensors,stage_rates)):
            print ('During deep supervision rate of stage %d is %d'%(idx,rate))

            block_pre = 'NoduleSegDecoder_DeepSupervision_block%02d'%(idx+1)
            current_tensor = UpsampleBlock(stage_tensor,num_filters=seg_num_class,kernel_size=final_kernel_size,rate=rate,
                                  kernel_initializer=kernel_initializer,kernel_regularizer=kernel_regularizer,
                                   block_pre=block_pre,upsample_choice=upsample_choice,activation_func=current_acti)

            conv_mins1.append(current_tensor)

        if OCR_choice:
            features = stage_tensors[-1]
            coarse_seg_output = final_conv_func(conv_mins1[-1])
            print ('feature shape and coarse_seg_output shape',features.shape,coarse_seg_output.shape)
            proxy = SpatialGather(features,coarse_seg_output)
            ocr_output = OCRModule(features,proxy,norm_func=norm_func,activation_func=activation_func,num_filters=features._keras_shape[-1]*4,
                                   kernel_size=1,kernel_initializer=kernel_initializer,kernel_regularizer=kernel_regularizer)
            ocr_output = Convolution3D(filters=seg_num_class,kernel_size=final_kernel_size,
                                 kernel_initializer=kernel_initializer,kernel_regularizer=kernel_regularizer,
                                padding='same',activation=current_acti,name='NoduleSegDecoder_conv3d_ocr_output')(ocr_output)

            conv_mins1.append(ocr_output)

    output = conv_mins1
    if aux_task:
        # Without deep supervision / ACF the segmentation output is a single
        # tensor; wrap it so the classification output can be appended.
        if not isinstance(output,list):
            output = [output]
        ### Do classification
        pool_result = classification_layers(name='aux_task_pool_layer')(stage_tensors[-1])
        ### Dense unit
        dense_result = Dense(num_classes*32)(pool_result)
        class_result = Dense(num_classes)(dense_result)
        output.append(class_result)
    model = Model(input_tensor,output)
    return model