import os
import sys
import glob
import numpy as np

import keras

import tensorflow as tf
from keras import Model
from keras import backend as K
from keras.regularizers import l2
from keras.engine import Layer,InputSpec
from keras.layers.merge import concatenate
from keras.callbacks import TensorBoard,Callback
from keras.layers.advanced_activations import LeakyReLU
from keras.preprocessing.image import ImageDataGenerator
from keras.layers.normalization import BatchNormalization
from keras.callbacks import ModelCheckpoint, ReduceLROnPlateau
from keras import initializers, regularizers, constraints,optimizers
from keras.callbacks import ModelCheckpoint, LearningRateScheduler,TensorBoard
from keras.layers import Add,Input,Conv3D,Convolution3D,Dropout,UpSampling3D,Concatenate,MaxPooling3D,\
GlobalAveragePooling3D,Dense,GlobalMaxPooling3D,Lambda,Activation,Reshape,Permute, PReLU, Deconvolution3D

sys.path.append('../baseLayers/')
from CommonLayers import NormActi,ConvUnit

def VGGConvBlock(x, **kwargs):
    """Build one VGG-style block: stacked ConvUnits, optional dropout, optional pooling.

    Args:
        x: input tensor of the block.

    Keyword Args:
        norm_func, activation_func: forwarded unchanged to ``ConvUnit``.
        num_filters: number of filters for every ConvUnit in the block.
        kernel_size: convolution kernel size (default 3).
        padding: convolution padding mode (default 'valid').
        block_idx: 1-based block index used to build the layer-name prefix
            ``'VGG_block%02d'`` (default 1).
        atrous_rate: dilation rate forwarded to ``ConvUnit`` (default 1).
        stride: downsampling factor. A ``MaxPooling3D(pool_size=stride)`` is
            appended only when ``stride > atrous_rate``. Missing/None means 1,
            i.e. no pooling.
        num_units: number of ConvUnits to stack (required).
        conv_first: forwarded to ``ConvUnit`` (default True).
        dropout_rate: when > 0, a ``Dropout`` layer is added after the conv
            stack and before pooling (default 0).
        kernel_initializer: forwarded to ``ConvUnit`` (default 'he_normal').
        kernel_regularizer: forwarded to ``ConvUnit`` (default ``l2(1e-4)``).

    Returns:
        The output tensor of the block.

    Raises:
        ValueError: if ``num_units`` is not provided.
    """
    norm_func = kwargs.get('norm_func')
    activation_func = kwargs.get('activation_func')
    num_filters = kwargs.get('num_filters')
    kernel_size = kwargs.get('kernel_size', 3)
    padding = kwargs.get('padding', 'valid')
    block_idx = kwargs.get('block_idx', 1)
    atrous_rate = kwargs.get('atrous_rate', 1)
    stride = kwargs.get('stride')
    num_units = kwargs.get('num_units')
    conv_first = kwargs.get('conv_first', True)
    dropout_rate = kwargs.get('dropout_rate', 0)
    kernel_initializer = kwargs.get('kernel_initializer', 'he_normal')
    kernel_regularizer = kwargs.get('kernel_regularizer', l2(1e-4))

    if num_units is None:
        # range(None) would otherwise fail with an opaque TypeError.
        raise ValueError('VGGConvBlock requires num_units')
    if stride is None:
        # No stride requested: keep resolution (1 never exceeds atrous_rate >= 1).
        stride = 1

    block_prefix = 'VGG_block%02d' % block_idx

    for idx in range(num_units):
        x = ConvUnit(x, norm_func=norm_func, activation_func=activation_func,
                     num_filters=num_filters, kernel_size=kernel_size,
                     atrous_rate=atrous_rate, padding=padding,
                     block_prefix=block_prefix,
                     kernel_initializer=kernel_initializer,
                     kernel_regularizer=kernel_regularizer,
                     layer_idx=idx + 1,  # 1-based index forwarded to ConvUnit
                     conv_first=conv_first)
    if dropout_rate > 0:
        x = Dropout(rate=dropout_rate, name='%s_dropout' % (block_prefix))(x)
    # Pool only when the stride is not already "absorbed" by the dilation rate.
    if stride > atrous_rate:
        x = MaxPooling3D(pool_size=stride, name='%s_pool' % (block_prefix))(x)

    return x
    
    
    
def _expand_per_block(value, num_blocks, name):
    """Broadcast a scalar to ``num_blocks`` entries, or validate a per-block sequence."""
    # np.number also accepts NumPy scalars (e.g. np.float32), which the old
    # ``type(x) == float`` check rejected and then tried to index.
    if isinstance(value, (int, float, np.number)):
        return [value] * num_blocks
    if value is None:
        raise ValueError('%s must be provided to VGG' % name)
    if len(value) < num_blocks:
        raise ValueError('%s has %d entries but num_blocks is %d'
                         % (name, len(value), num_blocks))
    return list(value)


def VGG(input_shape, **kwargs):
    """Build a 3D VGG-style classifier.

    Args:
        input_shape: shape passed to ``keras.layers.Input``.

    Keyword Args:
        num_blocks: number of VGG conv blocks (default 5).
        strides: per-block downsampling factor; a scalar applies to every
            block (required).
        atrous_rates: per-block dilation rate; a scalar applies to every
            block (default 1 for every block).
        num_layers: number of ConvUnits per block; a scalar applies to every
            block (default ``[2, 2, 3, 3, 3]``).
        base_filters: filters in the first block; block ``i`` (0-based) uses
            ``base_filters * 2**i`` (default 32).
        norm_func, activation_func, kernel_size, padding, kernel_initializer,
        kernel_regularizer: forwarded to every ``VGGConvBlock``.
        dropout_rate: per-block dropout rate; a scalar applies to every block
            (default 0).
        classification_layers: callable invoked as
            ``classification_layers(name='VGG_classification')``; the returned
            layer is applied to the last block's output (required).
        num_classes: number of output units; 1 uses sigmoid, otherwise
            softmax (default 1).

    Returns:
        ``(model, layers, real_stride_ratios)`` where ``model`` maps the input
        to the final Dense output, ``layers`` is
        ``[input_tensor, block1_out, ..., blockN_out]`` and
        ``real_stride_ratios`` holds, per block, the pooling factor actually
        applied (``stride`` when ``stride > atrous_rate``, else 1).

    Raises:
        ValueError: if a required argument is missing or a per-block sequence
            is shorter than ``num_blocks``.
    """
    num_blocks = kwargs.get('num_blocks', 5)
    base_filters = kwargs.get('base_filters', 32)
    norm_func = kwargs.get('norm_func')
    activation_func = kwargs.get('activation_func')
    kernel_size = kwargs.get('kernel_size', 3)
    padding = kwargs.get('padding', 'valid')
    classification_layers = kwargs.get('classification_layers')
    num_classes = kwargs.get('num_classes', 1)
    kernel_initializer = kwargs.get('kernel_initializer', 'he_normal')
    kernel_regularizer = kwargs.get('kernel_regularizer', l2(1e-4))

    atrous_rates = kwargs.get('atrous_rates')
    if atrous_rates is None:
        atrous_rates = 1  # no dilation, matching VGGConvBlock's own default

    # Validate/broadcast all per-block settings before building any tensors.
    strides = _expand_per_block(kwargs.get('strides'), num_blocks, 'strides')
    atrous_rates = _expand_per_block(atrous_rates, num_blocks, 'atrous_rates')
    num_layers = _expand_per_block(kwargs.get('num_layers', [2, 2, 3, 3, 3]),
                                   num_blocks, 'num_layers')
    dropout_rates = _expand_per_block(kwargs.get('dropout_rate', 0),
                                      num_blocks, 'dropout_rate')
    if classification_layers is None:
        raise ValueError('classification_layers must be provided to VGG')

    input_tensor = Input(shape=input_shape)
    layers = [input_tensor]
    real_stride_ratios = []
    for block_idx in range(num_blocks):
        stride = strides[block_idx]
        atrous_rate = atrous_rates[block_idx]
        # Same rule VGGConvBlock uses to decide whether it pools.
        real_stride_ratios.append(stride if stride > atrous_rate else 1)
        x = VGGConvBlock(layers[-1], norm_func=norm_func,
                         activation_func=activation_func,
                         num_filters=base_filters * (2 ** block_idx),
                         kernel_size=kernel_size, padding=padding,
                         block_idx=block_idx + 1, atrous_rate=atrous_rate,
                         stride=stride,
                         kernel_initializer=kernel_initializer,
                         kernel_regularizer=kernel_regularizer,
                         num_units=num_layers[block_idx],
                         dropout_rate=dropout_rates[block_idx])
        layers.append(x)

    # Classification head on the last feature map (the input itself when
    # num_blocks == 0, instead of a NameError on an unbound ``x``).
    final_feature = classification_layers(name='VGG_classification')(layers[-1])
    final_acti = 'sigmoid' if num_classes == 1 else 'softmax'
    final_output = Dense(num_classes, activation=final_acti,
                         name='VGG_final_output')(final_feature)

    model = Model(input_tensor, final_output)
    return model, layers, real_stride_ratios