# ResNet.py 11.8 KB
import os
import sys
import six
import math
import glob
import numpy as np

import keras
import tensorflow as tf
from keras import Model
from keras import backend as K
from keras.regularizers import l2
from keras.engine import Layer,InputSpec
from keras.layers.merge import concatenate
from keras.callbacks import TensorBoard,Callback
from keras.layers.advanced_activations import LeakyReLU
from keras.preprocessing.image import ImageDataGenerator
from keras.layers.normalization import BatchNormalization
from keras.callbacks import ModelCheckpoint, ReduceLROnPlateau
from keras import initializers, regularizers, constraints,optimizers
from keras.callbacks import ModelCheckpoint, LearningRateScheduler,TensorBoard
from keras.layers import Add,Input,Conv3D,Convolution3D,Dropout,UpSampling3D,Concatenate,MaxPooling3D,\
GlobalAveragePooling3D,Dense,GlobalMaxPooling3D,Lambda,Activation,Reshape,Permute, PReLU, Deconvolution3D

sys.path.append('../baseLayers/')
from CommonLayers import NormActi,ConvUnit

def get_block(identifier):
    """Resolve a block identifier to the object it refers to.

    Args:
        identifier: Either a string naming an object defined at module level
            in this file (looked up through ``globals()``), or any other
            object, which is handed back untouched.

    Returns:
        The module-level object registered under ``identifier`` when a string
        is given, otherwise ``identifier`` itself.

    Raises:
        ValueError: If ``identifier`` is a string and the lookup yields
            nothing (or a falsy value).
    """
    # Non-string identifiers are assumed to already be the block itself.
    if not isinstance(identifier, six.string_types):
        return identifier

    resolved = globals().get(identifier)
    if resolved:
        return resolved
    raise ValueError('Invalid {}'.format(identifier))


def ShortCut(x1,x2,**kwargs):
    """Merge the block input ``x1`` with the residual ``x2`` through an ``Add`` layer.

    The per-axis stride between the two tensors is estimated as
    ``ceil(x1_dim / x2_dim)`` from their ``_keras_shape`` attributes, reading
    axes 1-3 as the spatial dimensions and axis 4 as the channel dimension
    (the indices are hard-coded, so a channels-last layout is assumed).
    When any stride is greater than 1, or the channel counts differ, ``x1`` is
    first projected by a strided ``Convolution3D`` to ``x2``'s channel count;
    otherwise ``x1`` is added unchanged.

    Args:
        x1: Tensor entering the residual block (the skip path).
        x2: Tensor leaving the residual branch.

    Keyword Args:
        padding: Padding mode handed to the projection convolution.
        block_prefix: Prefix used to name the created layers
            (``<prefix>_shortcut`` and ``<prefix>_add``).
        kernel_size: Kernel size of the projection convolution (default 3).
        kernel_initializer: Initializer of the projection convolution
            (default ``'he_normal'``).
        kernel_regularizer: Regularizer of the projection convolution
            (default ``l2(1e-4)``).

    Returns:
        The output tensor of the ``Add`` layer.
    """
    padding = kwargs.get('padding')
    block_prefix = kwargs.get('block_prefix')
    kernel_size = kwargs.get('kernel_size',3)
    kernel_initializer = kwargs.get('kernel_initializer','he_normal')
    kernel_regularizer = kwargs.get('kernel_regularizer',l2(1e-4))

    spatial_axes = (1, 2, 3)
    channel_axis = 4
    src_shape = x1._keras_shape
    dst_shape = x2._keras_shape

    # float() forces true division (Python 2 would otherwise floor before the
    # ceil) and int() guarantees integer strides for the convolution.
    strides = tuple(int(math.ceil(float(src_shape[axis]) / dst_shape[axis]))
                    for axis in spatial_axes)
    equal_channels = src_shape[channel_axis] == dst_shape[channel_axis]

    add_tensor = x1
    # Every spatial axis is compared against 1; the third axis was previously
    # compared against 3, which let a depth-only downsampling slip through
    # without a projection and made the Add below fail on mismatched shapes.
    if any(stride > 1 for stride in strides) or not equal_channels:
        add_tensor = Convolution3D(filters=dst_shape[channel_axis],
                                   kernel_size=kernel_size,
                                   strides=strides,
                                   kernel_initializer=kernel_initializer,
                                   kernel_regularizer=kernel_regularizer,
                                   padding=padding,
                                   name='%s_shortcut'%(block_prefix))(x1)

    return Add(name='%s_add'%(block_prefix))([add_tensor,x2])


def ResConvBlock(x,**kwargs):
    """Stack ``repetitions`` residual units of one kind on top of ``x``.

    The unit type is resolved with ``get_block`` and called once per
    repetition.  The first unit of every block except block 1 is called with
    ``stride=2``; all other units use ``stride=1``.

    Args:
        x: Input tensor of the block.

    Keyword Args:
        block_func: Name of a module-level block function (e.g.
            ``'BasicBlock'``) or the callable itself.
        activation_func: Forwarded to every unit.
        norm_func: Forwarded to every unit.
        kernel_size: Forwarded to every unit (default 3).
        padding: Forwarded to every unit.
        num_filters: Forwarded to every unit.
        block_idx: Index of this block; selects the stride of the first unit
            and is embedded in the layer-name prefix.
        repetitions: Number of units to stack.
        kernel_initializer: Forwarded to every unit (default ``'he_normal'``).
        kernel_regularizer: Forwarded to every unit (default ``l2(1e-4)``).

    Returns:
        The output tensor of the last unit (``x`` itself when
        ``repetitions`` is 0).

    Raises:
        ValueError: From ``get_block`` when ``block_func`` is an unknown name.
    """
    block_func = kwargs.get('block_func')
    activation_func = kwargs.get('activation_func')
    norm_func = kwargs.get('norm_func')
    kernel_size = kwargs.get('kernel_size',3)
    padding = kwargs.get('padding')

    num_filters = kwargs.get('num_filters')
    block_idx = kwargs.get('block_idx')
    repetitions = kwargs.get('repetitions')

    kernel_initializer = kwargs.get('kernel_initializer','he_normal')
    kernel_regularizer = kwargs.get('kernel_regularizer',l2(1e-4))

    converted_block_func = get_block(block_func)

    # Layer names are derived from the prefix, so a callable must contribute
    # its name rather than its repr ("<function ... at 0x...>").
    if isinstance(block_func, six.string_types):
        block_name = block_func
    else:
        block_name = getattr(converted_block_func, '__name__',
                             type(converted_block_func).__name__)
    block_prefix = 'ResNet_%sBlock%02d'%(block_name,block_idx)

    # NOTE(review): this function has never forwarded an
    # ``is_first_layer_of_first_block`` flag or an ``atrous_rate`` to the
    # units, so they always run with their own defaults for both.  That is
    # kept as is so the generated graph and layer names do not change.
    for layer_idx in range(repetitions):
        # Downsample once, at the entry of every block after the first.
        strides = 2 if (layer_idx == 0 and block_idx > 1) else 1
        x = converted_block_func(x,activation_func=activation_func,norm_func=norm_func,
                                 num_filters=num_filters,kernel_size=kernel_size,
                                 stride=strides,padding=padding,layer_idx=layer_idx+1,
                                 kernel_initializer=kernel_initializer,
                                 kernel_regularizer=kernel_regularizer,
                                 block_prefix=block_prefix)

    return x
    
    
def BasicBlock(x,**kwargs):
    """Build one two-convolution residual unit on top of ``x``.

    The first convolution carries ``stride``.  It is a bare ``Convolution3D``
    when ``is_first_layer_of_first_block`` is true and a ``ConvUnit``
    otherwise.  A second ``ConvUnit`` produces the residual, which is merged
    with the unit input through ``ShortCut``.

    Args:
        x: Input tensor of the unit.

    Keyword Args:
        is_first_layer_of_first_block: Use a bare convolution for the first
            layer (default ``False``).
        activation_func: Forwarded to ``ConvUnit``.
        norm_func: Forwarded to ``ConvUnit``.
        num_filters: Number of filters of both convolutions.
        kernel_size: Kernel size of both convolutions.
        stride: Stride of the first convolution.
        atrous_rate: Forwarded to ``ConvUnit`` (default 1).
        padding: Padding mode of all convolutions.
        layer_idx: Index of this unit inside its block; used in layer names.
        block_prefix: Prefix for the names of the created layers.
        kernel_initializer: Default ``'he_normal'``.
        kernel_regularizer: Default ``l2(1e-4)``.

    Returns:
        The output tensor of ``ShortCut`` (input plus residual).
    """
    is_first_layer_of_first_block = kwargs.get('is_first_layer_of_first_block',False)
    activation_func = kwargs.get('activation_func')
    norm_func = kwargs.get('norm_func')
    num_filters = kwargs.get('num_filters')
    kernel_size = kwargs.get('kernel_size')
    stride = kwargs.get('stride')
    atrous_rate = kwargs.get('atrous_rate',1)
    padding = kwargs.get('padding')
    layer_idx = kwargs.get('layer_idx')
    block_prefix = kwargs.get('block_prefix')

    block_suffix = '%02d'%layer_idx
    kernel_initializer = kwargs.get('kernel_initializer','he_normal')
    kernel_regularizer = kwargs.get('kernel_regularizer',l2(1e-4))

    if is_first_layer_of_first_block:
        # Keras names the argument ``filters``; the layer name needs the ``%``
        # operator (the string was previously *called*, raising TypeError).
        x1 = Convolution3D(filters=num_filters,kernel_size=kernel_size,padding=padding,strides=stride,
                           kernel_initializer=kernel_initializer,kernel_regularizer=kernel_regularizer,
                           name='%s_conv_%s'%(block_prefix,block_suffix))(x)
    else:
        x1 = ConvUnit(x,norm_func=norm_func,activation_func=activation_func,num_filters=num_filters,
                      conv_stride=stride,kernel_size=kernel_size,atrous_rate=atrous_rate,
                      padding=padding,block_prefix=block_prefix+'_basic',
                      kernel_initializer=kernel_initializer,kernel_regularizer=kernel_regularizer,
                      layer_idx=layer_idx+1,conv_first=True)

    residual = ConvUnit(x1,norm_func=norm_func,activation_func=activation_func,num_filters=num_filters,
                        kernel_size=kernel_size,atrous_rate=atrous_rate,padding=padding,
                        block_prefix=block_prefix+'_residual',
                        kernel_initializer=kernel_initializer,kernel_regularizer=kernel_regularizer,
                        layer_idx=layer_idx+1,conv_first=True)

    return ShortCut(x,residual,padding=padding,kernel_initializer=kernel_initializer,
                    kernel_regularizer=kernel_regularizer,
                    block_prefix=block_prefix+"_layer%02d"%layer_idx)

    

def BottleNeck(x,**kwargs):
    """Build one three-convolution residual unit on top of ``x``.

    Layers, in order:
      1. a kernel-size-1 convolution with ``num_filters`` filters and
         ``stride`` (bare ``Convolution3D`` when
         ``is_first_layer_of_first_block`` is true, ``ConvUnit`` otherwise);
      2. a ``ConvUnit`` with ``num_filters * 4`` filters and ``kernel_size``;
      3. a ``ConvUnit`` with ``num_filters`` filters and ``kernel_size``,
         which is the residual.
    The residual is merged with the unit input through ``ShortCut``.

    NOTE(review): the widths (wide middle, narrow last) and the use of
    ``kernel_size`` for the last layer differ from the usual bottleneck
    layout -- confirm this is intended before changing it, as it alters the
    trained architecture.

    Args:
        x: Input tensor of the unit.

    Keyword Args:
        is_first_layer_of_first_block: Use a bare convolution for the first
            layer (default ``False``).
        activation_func: Forwarded to ``ConvUnit``.
        norm_func: Forwarded to ``ConvUnit``.
        num_filters: Base number of filters.
        kernel_size: Kernel size of the second and third convolutions.
        stride: Stride of the first convolution.
        atrous_rate: Forwarded to ``ConvUnit`` (default 1).
        padding: Padding mode of all convolutions.
        layer_idx: Index of this unit inside its block; used in layer names.
        block_prefix: Prefix for the names of the created layers.
        kernel_initializer: Default ``'he_normal'``.
        kernel_regularizer: Default ``l2(1e-4)``.

    Returns:
        The output tensor of ``ShortCut`` (input plus residual).
    """
    is_first_layer_of_first_block = kwargs.get('is_first_layer_of_first_block',False)
    activation_func = kwargs.get('activation_func')
    norm_func = kwargs.get('norm_func')
    num_filters = kwargs.get('num_filters')
    kernel_size = kwargs.get('kernel_size')
    stride = kwargs.get('stride')
    atrous_rate = kwargs.get('atrous_rate',1)
    padding = kwargs.get('padding')
    layer_idx = kwargs.get('layer_idx')
    block_prefix = kwargs.get('block_prefix')
    kernel_initializer = kwargs.get('kernel_initializer','he_normal')
    kernel_regularizer = kwargs.get('kernel_regularizer',l2(1e-4))

    block_suffix = '%02d'%layer_idx

    if is_first_layer_of_first_block:
        # Three fixes on this path: Keras calls the argument ``filters``, the
        # local holding the stride is ``stride`` (``strides`` was undefined),
        # and the layer name needs the ``%`` formatting operator.
        conv_1_1 = Convolution3D(filters=num_filters,kernel_size=1,padding=padding,strides=stride,
                                 kernel_initializer=kernel_initializer,
                                 kernel_regularizer=kernel_regularizer,
                                 name='%s_conv1v1_%s'%(block_prefix,block_suffix))(x)
    else:
        conv_1_1 = ConvUnit(x,norm_func=norm_func,activation_func=activation_func,num_filters=num_filters,
                            conv_stride=stride,kernel_size=1,atrous_rate=atrous_rate,padding=padding,
                            block_prefix=block_prefix+'_1v1',
                            kernel_initializer=kernel_initializer,kernel_regularizer=kernel_regularizer,
                            layer_idx=layer_idx+1,conv_first=True)

    conv_3_3 = ConvUnit(conv_1_1,norm_func=norm_func,activation_func=activation_func,
                        num_filters=num_filters*4,kernel_size=kernel_size,atrous_rate=atrous_rate,
                        padding=padding,block_prefix=block_prefix+'_3v3_',
                        kernel_initializer=kernel_initializer,kernel_regularizer=kernel_regularizer,
                        layer_idx=layer_idx+1,conv_first=True)
    residual = ConvUnit(conv_3_3,norm_func=norm_func,activation_func=activation_func,
                        num_filters=num_filters,kernel_size=kernel_size,atrous_rate=atrous_rate,
                        padding=padding,block_prefix=block_prefix+'_residual',
                        kernel_initializer=kernel_initializer,kernel_regularizer=kernel_regularizer,
                        layer_idx=layer_idx+1,conv_first=True)

    # Forward the initializer/regularizer so the projection convolution of
    # the shortcut follows the caller's settings, as BasicBlock does.
    return ShortCut(x,residual,padding=padding,kernel_initializer=kernel_initializer,
                    kernel_regularizer=kernel_regularizer,
                    block_prefix=block_prefix+"_layer%02d"%layer_idx)

class Resnet3DBuilder(object):
    """Factory for 3D residual classification networks.

    Keyword Args:
        num_classes: Number of output units; 1 selects a sigmoid output,
            anything else a softmax output.
        classification_layers: Callable (typically a Keras layer class such
            as a global pooling layer) invoked as
            ``classification_layers(name=...)`` to create the layer that
            turns the last feature map into the classifier input.
        base_filters: Filters of the initial convolution (default 32);
            residual block ``i`` (1-based) uses ``base_filters * 2**i``.
        dropout_rate: Dropout rate applied before the final ``Dense`` layer;
            0 (default) disables dropout.
        padding: Padding mode of all convolutions (default ``'same'``).
        kernel_size: Kernel size inside the residual blocks (default 3).
        init_kernel_size: Kernel size of the initial convolution (default 3).
        norm_func: Forwarded to ``ConvUnit`` / the block functions.
        activation_func: Forwarded to ``ConvUnit`` / the block functions.
        kernel_initializer: Default ``'he_normal'``.
        kernel_regularizer: Default ``l2(1e-4)``.
    """

    def __init__(self,**kwargs):
        self.num_classes = kwargs.get('num_classes')
        self.classification_layers = kwargs.get('classification_layers')
        self.base_filters = kwargs.get('base_filters',32)
        self.dropout_rate = kwargs.get('dropout_rate',0)
        self.padding = kwargs.get('padding','same')
        self.kernel_size = kwargs.get('kernel_size',3)
        self.init_kernel_size = kwargs.get('init_kernel_size',3)
        self.norm_func = kwargs.get('norm_func')
        self.activation_func = kwargs.get('activation_func')
        self.kernel_initializer = kwargs.get('kernel_initializer','he_normal')
        self.kernel_regularizer = kwargs.get('kernel_regularizer',l2(1e-4))

    def build(self,input_shape,block_func,repetitions,model_prefix):
        """Assemble the network.

        Args:
            input_shape: Shape passed to ``Input`` (without the batch axis).
            block_func: Name of the residual unit function (or the callable)
                handed to ``ResConvBlock``.
            repetitions: Iterable with the number of units per residual block.
            model_prefix: Prefix for the names of the classification head
                layers.

        Returns:
            A tuple ``(model, layers, [])`` where ``model`` is the Keras
            ``Model`` and ``layers`` lists the input tensor, the output of the
            initial convolution and the output of every residual block (the
            pooled tensor is not included).  The third element is always an
            empty list.

        Raises:
            TypeError: If ``classification_layers`` is not callable.
        """
        classification_layers = self.classification_layers
        if not callable(classification_layers):
            # Fail before any layer is created, with a message that names the
            # missing option instead of "'NoneType' object is not callable".
            raise TypeError('classification_layers must be a callable that builds the '
                            'classification layer, got {!r}'.format(classification_layers))

        num_classes = self.num_classes
        base_filters = self.base_filters
        dropout_rate = self.dropout_rate
        padding = self.padding
        kernel_size = self.kernel_size
        init_kernel_size = self.init_kernel_size
        norm_func = self.norm_func
        activation_func = self.activation_func
        kernel_initializer = self.kernel_initializer
        kernel_regularizer = self.kernel_regularizer
        # TODO: change atrous_rate
        atrous_rate = 1

        input_tensor = Input(shape=input_shape)
        block_prefix = 'ResNet_InitBlock'
        conv1 = ConvUnit(input_tensor,norm_func=norm_func,activation_func=activation_func,
                         num_filters=base_filters,kernel_size=init_kernel_size,
                         atrous_rate=atrous_rate,padding=padding,block_prefix=block_prefix,
                         kernel_initializer=kernel_initializer,kernel_regularizer=kernel_regularizer,
                         layer_idx=1,conv_first=True)
        pool1 = MaxPooling3D(strides=2,name='%s_pool'%(block_prefix))(conv1)

        block = pool1
        # The pooled tensor is intentionally left out of the returned list.
        layers = [input_tensor,conv1]
        for block_idx,repeat_time in enumerate(repetitions, 1):
            block = ResConvBlock(block,block_func=block_func,norm_func=norm_func,
                                 activation_func=activation_func,kernel_size=kernel_size,
                                 atrous_rate=atrous_rate,padding=padding,
                                 kernel_initializer=kernel_initializer,
                                 kernel_regularizer=kernel_regularizer,
                                 num_filters=base_filters*(2**block_idx),
                                 block_idx=block_idx,repetitions=repeat_time)
            layers.append(block)

        final_feature = classification_layers(name='%s_classification'%(model_prefix))(block)
        if dropout_rate:
            # Honour the configured rate; with the default of 0 no layer is
            # added, so the graph is unchanged for existing configurations.
            final_feature = Dropout(dropout_rate,name='%s_dropout'%(model_prefix))(final_feature)

        final_acti = 'sigmoid' if num_classes == 1 else 'softmax'
        final_output = Dense(num_classes,activation=final_acti,
                             name='%s_final_output'%(model_prefix))(final_feature)

        model = Model(input_tensor,final_output)

        return model,layers,[]

    def build_resnet_basic(self,input_shape):
        """Build a ``BasicBlock`` network with 2, 2, 3 and 4 units per block."""
        return self.build(input_shape,'BasicBlock',[2,2,3,4],'resnet_basic')

    def build_resnet_18(self,input_shape):
        """Build a ``BasicBlock`` network with 3, 4, 6 and 3 units per block.

        NOTE(review): 3/4/6/3 basic units is the layout usually called
        ResNet-34; ResNet-18 is normally 2/2/2/2.  Left unchanged so models
        trained with this preset keep loading -- confirm which is intended.
        """
        return self.build(input_shape,'BasicBlock',[3,4,6,3],'resnet18')

    def build_resnet_50(self,input_shape):
        """Build a ``BottleNeck`` network with 3, 4, 6 and 3 units per block."""
        return self.build(input_shape,'BottleNeck',[3,4,6,3],'resnet50')

    def build_resnet_101(self,input_shape):
        """Build a ``BottleNeck`` network with 3, 4, 23 and 3 units per block."""
        return self.build(input_shape,'BottleNeck',[3,4,23,3],'resnet101')

    def build_resnet_152(self,input_shape):
        """Build a ``BottleNeck`` network with 3, 8, 36 and 3 units per block."""
        return self.build(input_shape,'BottleNeck',[3,8,36,3],'resnet152')