import os
import sys
import glob

import numpy as np
import keras
import tensorflow as tf
from keras import Model
from keras import backend as K
from keras import initializers, regularizers, constraints, optimizers
from keras.callbacks import (Callback, LearningRateScheduler, ModelCheckpoint,
                             ReduceLROnPlateau, TensorBoard)
from keras.engine import Layer, InputSpec
from keras.layers import (Activation, Add, Concatenate, Conv3D, Conv3DTranspose,
                          Convolution3D, Deconvolution3D, Dense, Dropout,
                          GlobalAveragePooling3D, GlobalMaxPooling3D, Input,
                          Lambda, MaxPooling3D, Permute, PReLU, Reshape,
                          UpSampling3D)
from keras.layers.advanced_activations import LeakyReLU
from keras.layers.merge import concatenate
from keras.layers.normalization import BatchNormalization
from keras.preprocessing.image import ImageDataGenerator
from keras.regularizers import l2

sys.path.append('../baseLayers/')
from unet_utils import *          # noqa: F401,F403 -- provides CCB / CAB used below (TODO confirm)
from SEB import SEB_block
from CommonLayers import NormActi, ConvUnit


def UpsampleBlock(x, **kwargs):
    """Spatially upsample ``x`` and project it to ``num_filters`` channels.

    Keyword Args:
        num_filters (int): output channels of the convolution.
        kernel_initializer: conv kernel initializer (default ``'he_normal'``).
        kernel_regularizer: conv kernel regularizer (default ``l2(1e-4)``).
        block_pre (str): prefix used to build unique layer names.
        kernel_size (int): convolution kernel size (default 3).
        rate (int): spatial upsampling factor (default 2).
        upsample_choice (str): ``'upsample'`` -> ``UpSampling3D`` + conv,
            ``'deconv'`` -> ``Deconvolution3D``, anything else ->
            ``Conv3DTranspose``.
        norm_func / activation_func: forwarded to ``NormActi``; if
            ``activation_func`` is a *string* it is fused into the conv layer
            instead and ``NormActi`` receives no activation.
        atrous_rate (int): accepted for API compatibility; currently unused.
        padding (str): conv padding mode (default ``'same'``).

    Returns:
        The upsampled, normalized, activated tensor.
    """
    num_filters = kwargs.get('num_filters')
    kernel_initializer = kwargs.get('kernel_initializer', 'he_normal')
    kernel_regularizer = kwargs.get('kernel_regularizer', l2(1e-4))
    block_pre = kwargs.get('block_pre', '')
    kernel_size = kwargs.get('kernel_size', 3)
    rate = kwargs.get('rate', 2)
    upsample_choice = kwargs.get('upsample_choice', 'upsample')
    norm_func = kwargs.get('norm_func')
    activation_func = kwargs.get('activation_func')
    padding = kwargs.get('padding', 'same')

    # A string activation is applied inside the conv layer; NormActi then only
    # normalizes. A callable (or None) is handled by NormActi itself.
    if isinstance(activation_func, str):
        current_acti = activation_func
        activation_func = None
    else:
        current_acti = None

    if upsample_choice == 'upsample':
        x = UpSampling3D(size=rate, name='%s_upsample' % (block_pre))(x)
        # FIX: replaced deprecated nb_filter/border_mode keywords with
        # filters/padding, consistent with the other branches below.
        x = Convolution3D(filters=num_filters, kernel_size=kernel_size,
                          kernel_initializer=kernel_initializer,
                          kernel_regularizer=kernel_regularizer,
                          padding=padding, activation=current_acti,
                          name='%s_conv' % block_pre)(x)
    elif upsample_choice == 'deconv':
        x = Deconvolution3D(filters=num_filters, kernel_size=kernel_size,
                            strides=rate,
                            kernel_initializer=kernel_initializer,
                            kernel_regularizer=kernel_regularizer,
                            padding=padding, activation=current_acti,
                            name='%s_conv' % block_pre)(x)
    else:
        x = Conv3DTranspose(filters=num_filters, kernel_size=kernel_size,
                            strides=rate,
                            kernel_initializer=kernel_initializer,
                            kernel_regularizer=kernel_regularizer,
                            padding=padding, activation=current_acti,
                            name='%s_conv' % block_pre)(x)

    x = NormActi(x, norm_func=norm_func, activation_func=activation_func,
                 block_prefix=block_pre, block_suffix=block_pre)
    return x


def VnetDecoderBlock(x, encoder_tensors, **kwargs):
    """One V-Net decoder stage: upsample, merge a skip connection, convolve.

    The skip connection is either the raw encoder tensor or, when
    ``SEB_choice`` is true, the output of a Squeeze-and-Excitation-like
    ``SEB_block`` that fuses the encoder tensor with upsampled deeper features.

    Args:
        x: decoder tensor from the previous (deeper) stage.
        encoder_tensors: encoder feature maps ordered shallow -> deep;
            ``encoder_tensors[-(block_idx + 1)]`` is this stage's skip tensor.

    Keyword Args mirror :func:`UpsampleBlock` plus ``block_idx``, ``num_units``
    (number of ConvUnit repetitions), ``SEB_choice``, ``upsample_rate``,
    ``conv_first`` and ``merge_axis``.

    Returns:
        ``x1 + conv_stack(concat(skip, x1))`` -- residual sum of the upsampled
        tensor and the convolved merge.
    """
    block_idx = kwargs.get('block_idx')
    num_units = kwargs.get('num_units')
    kernel_initializer = kwargs.get('kernel_initializer', 'he_normal')
    kernel_regularizer = kwargs.get('kernel_regularizer', l2(1e-4))
    SEB_choice = kwargs.get('SEB_choice', False)
    kernel_size = kwargs.get('kernel_size', 3)
    norm_func = kwargs.get('norm_func')
    activation_func = kwargs.get('activation_func')
    num_filters = kwargs.get('num_filters')
    padding = kwargs.get('padding', 'same')
    atrous_rate = kwargs.get('atrous_rate', 1)
    stride = kwargs.get('upsample_rate', 2)
    conv_first = kwargs.get('conv_first', True)
    merge_axis = kwargs.get('merge_axis', -1)
    upsample_choice = kwargs.get('upsample_choice', 'upsample')

    DIM1_AXIS, DIM2_AXIS, DIM3_AXIS = 1, 2, 3
    block_prefix = 'NoduleSegDecoder_Block%02d' % block_idx

    x1 = UpsampleBlock(x, num_filters=num_filters, kernel_size=kernel_size,
                       rate=stride, atrous_rate=atrous_rate, padding=padding,
                       block_pre=block_prefix,
                       kernel_initializer=kernel_initializer,
                       kernel_regularizer=kernel_regularizer,
                       upsample_choice=upsample_choice,
                       norm_func=norm_func, activation_func=activation_func)

    if not SEB_choice:
        # Plain skip connection from the matching encoder stage.
        encoder_target_tensor = encoder_tensors[-(block_idx + 1)]
    else:
        # Spatial size of the input BEFORE upsampling; deeper encoder maps are
        # upsampled to this size before entering the SEB block.
        # NOTE(review): assumes channels-last static shapes via _keras_shape.
        depth_before_SEB, height_before_SEB, width_before_SEB = (
            x._keras_shape[DIM1_AXIS],
            x._keras_shape[DIM2_AXIS],
            x._keras_shape[DIM3_AXIS])
        target_tensor = encoder_tensors[-(block_idx + 1)]
        SEB_tensors = encoder_tensors[-(block_idx):][::-1]
        upsample_tensor_list = []
        for layer_idx, current_tensor in enumerate(SEB_tensors):
            current_depth, current_height, current_width = (
                current_tensor._keras_shape[DIM1_AXIS],
                current_tensor._keras_shape[DIM2_AXIS],
                current_tensor._keras_shape[DIM3_AXIS])
            upsample_rate = [int(depth_before_SEB / current_depth),
                             int(height_before_SEB / current_height),
                             int(width_before_SEB / current_width)]
            upsample_tensor = UpSampling3D(
                size=upsample_rate,
                name='%s_SEB_preUpsample%02d' % (block_prefix, layer_idx + 1)
            )(current_tensor)
            upsample_tensor_list.append(upsample_tensor)
        target_num_filters = target_tensor._keras_shape[-1]
        encoder_target_tensor = SEB_block(
            [upsample_tensor_list, target_tensor],
            norm_func=norm_func, activation_func=activation_func,
            num_filters=target_num_filters, kernel_size=kernel_size,
            atrous_rate=atrous_rate, padding=padding,
            block_prefix=block_prefix,
            kernel_initializer=kernel_initializer,
            kernel_regularizer=kernel_regularizer,
            conv_first=conv_first, merge_axis=merge_axis)

    x = Concatenate(axis=merge_axis,
                    name='%s_Concatenate' % (block_prefix))([encoder_target_tensor, x1])
    # num_units - 1 ConvUnit layers (layer_idx starts at 2; UpsampleBlock is
    # effectively layer 1 of this stage).
    for layer_idx in range(1, num_units):
        x = ConvUnit(x, norm_func=norm_func, activation_func=activation_func,
                     num_filters=num_filters, kernel_size=kernel_size,
                     atrous_rate=atrous_rate, padding=padding,
                     block_prefix=block_prefix,
                     kernel_initializer=kernel_initializer,
                     kernel_regularizer=kernel_regularizer,
                     layer_idx=layer_idx + 1, conv_first=conv_first)
    # Residual connection around the conv stack.
    return Add(name='%s_FinalAdd' % (block_prefix))([x1, x])


def NoduleSegDeepCombineBlock(input_tensors, **kwargs):
    """Fuse three decoder maps (coarse -> fine) into one full-resolution map.

    Each tensor is first projected to ``num_filters`` channels by a ConvUnit;
    coarser maps are upsampled and added into the next finer one.

    Args:
        input_tensors: exactly three tensors ``[x1, x2, x3]`` ordered from
            lowest to highest spatial resolution.

    Returns:
        tuple: ``(fused_tensor, [x1_conv, x2_conv, x3_conv])`` -- the fused
        map plus the three projected tensors (used for deep supervision).
    """
    kernel_initializer = kwargs.get('kernel_initializer', 'he_normal')
    kernel_regularizer = kwargs.get('kernel_regularizer', l2(1e-4))
    kernel_size = kwargs.get('kernel_size', 3)
    norm_func = kwargs.get('norm_func')
    activation_func = kwargs.get('activation_func')
    num_filters = kwargs.get('num_filters')
    padding = kwargs.get('padding', 'same')
    atrous_rate = kwargs.get('atrous_rate', 1)
    conv_first = kwargs.get('conv_first', True)

    print ('input_tensors is ',input_tensors)
    x1, x2, x3 = input_tensors
    block_prefix = 'NoduleSegDeepCombineBlock'
    # Upsample factors derived from the static spatial shapes.
    uprate_1 = int(x2._keras_shape[1] / x1._keras_shape[1])
    uprate_2 = int(x3._keras_shape[1] / x2._keras_shape[1])

    x1 = ConvUnit(x1, norm_func=norm_func, activation_func=activation_func,
                  num_filters=num_filters, kernel_size=kernel_size,
                  atrous_rate=atrous_rate, padding=padding,
                  block_prefix=block_prefix + '_block01_',
                  kernel_initializer=kernel_initializer,
                  kernel_regularizer=kernel_regularizer,
                  layer_idx=1, conv_first=conv_first)
    up_x1 = UpSampling3D(size=uprate_1,
                         name='%s_block01_upsample' % block_prefix)(x1)

    x2 = ConvUnit(x2, norm_func=norm_func, activation_func=activation_func,
                  num_filters=num_filters, kernel_size=kernel_size,
                  atrous_rate=atrous_rate, padding=padding,
                  block_prefix=block_prefix + '_block02_',
                  kernel_initializer=kernel_initializer,
                  kernel_regularizer=kernel_regularizer,
                  layer_idx=1, conv_first=conv_first)
    up_x2 = UpSampling3D(size=uprate_2,
                         name='%s_block02_upsample' % block_prefix)(
        Add(name='%s_Add01' % block_prefix)([up_x1, x2]))

    x3 = ConvUnit(x3, norm_func=norm_func, activation_func=activation_func,
                  num_filters=num_filters, kernel_size=kernel_size,
                  atrous_rate=atrous_rate, padding=padding,
                  block_prefix=block_prefix + '_block03_',
                  kernel_initializer=kernel_initializer,
                  kernel_regularizer=kernel_regularizer,
                  layer_idx=1, conv_first=conv_first)
    final_result = Add(name='%s_Add02' % block_prefix)([up_x2, x3])
    return final_result, [x1, x2, x3]


def VnetDecoder(encoder_result, **kwargs):
    """Build the full V-Net style decoder on top of an encoder.

    Args:
        encoder_result: ``(encoder_model, encoder_tensors)`` where
            ``encoder_tensors`` are the encoder stage outputs, shallow -> deep.

    Keyword Args (selection):
        seg_num_class (int): segmentation classes; 1 -> sigmoid, else softmax.
        deep_supervision (bool): emit one upsampled prediction per combined
            decoder level instead of a single head (default True).
        SEB_choice / ACF_choice (bool): enable SEB skip fusion / the
            ACF (CCB + CAB) refinement head.
        OCR_choice (bool): accepted but currently unused.
        aux_task (bool): append a classification head
            (``classification_layers`` pooling + two Dense layers).
        divide_ratio (int): channel-reduction factor for the combine block.

    Returns:
        keras.Model mapping the encoder input to the segmentation output(s)
        (plus the classification output when ``aux_task``).
    """
    kernel_initializer = kwargs.get('kernel_initializer', 'he_normal')
    kernel_regularizer = kwargs.get('kernel_regularizer', l2(1e-4))
    kernel_size = kwargs.get('kernel_size', 3)
    final_kernel_size = kwargs.get('final_kernel_size', 3)
    norm_func = kwargs.get('norm_func')
    activation_func = kwargs.get('activation_func')
    padding = kwargs.get('padding', 'same')
    atrous_rate = kwargs.get('atrous_rate', 1)
    conv_first = kwargs.get('conv_first', True)
    merge_axis = kwargs.get('merge_axis', -1)
    SEB_choice = kwargs.get('SEB_choice', True)
    ACF_choice = kwargs.get('ACF_choice', False)
    deep_supervision = kwargs.get('deep_supervision', True)
    num_units = kwargs.get('num_units', [3, 3, 3, 3])
    seg_num_class = kwargs.get('seg_num_class')
    aux_task = kwargs.get('aux_task', False)
    num_classes = kwargs.get('num_classes')
    classification_layers = kwargs.get('classification_layers')
    upsample_choice = kwargs.get('upsample_choice', 'upsample')
    divide_ratio = kwargs.get('divide_ratio', 4)

    encoder_model = encoder_result[0]
    encoder_tensors = encoder_result[1]
    input_tensor = encoder_model.input

    # Walk the decoder from the deepest encoder output towards the input,
    # halving channels at each stage.
    x = encoder_tensors[-1]
    decoder_tensors = []
    for block_idx in range(len(encoder_tensors) - 2):
        current_stride = int(encoder_tensors[-(block_idx + 2)]._keras_shape[1]
                             / x._keras_shape[1])
        num_filters = int(x._keras_shape[-1] / 2)
        if block_idx == len(encoder_tensors) - 3:
            # Last decoder stage: no deeper features left to feed the SEB.
            SEB_choice = False
        x = VnetDecoderBlock(x, encoder_tensors, block_idx=block_idx + 1,
                             num_units=num_units[block_idx],
                             kernel_initializer=kernel_initializer,
                             kernel_regularizer=kernel_regularizer,
                             SEB_choice=SEB_choice, kernel_size=kernel_size,
                             norm_func=norm_func,
                             activation_func=activation_func,
                             num_filters=num_filters, padding=padding,
                             atrous_rate=atrous_rate,
                             upsample_rate=current_stride,
                             conv_first=conv_first, merge_axis=merge_axis,
                             upsample_choice=upsample_choice)
        decoder_tensors.append(x)

    # Combine the last three decoder outputs into one map (+ per-level maps).
    start_idx = len(decoder_tensors) - 3
    combine_tensors = NoduleSegDeepCombineBlock(
        decoder_tensors[start_idx:start_idx + 4],
        kernel_initializer=kernel_initializer,
        kernel_regularizer=kernel_regularizer,
        kernel_size=kernel_size, norm_func=norm_func,
        activation_func=activation_func,
        num_filters=int(decoder_tensors[-1]._keras_shape[-1] / divide_ratio),
        padding=padding, atrous_rate=atrous_rate, conv_first=conv_first,
        merge_axis=merge_axis, upsample_choice=upsample_choice)

    if seg_num_class == 1:
        current_acti = 'sigmoid'
    else:
        current_acti = 'softmax'
    # FIX: filters/padding instead of deprecated nb_filter/border_mode.
    final_conv_func = Convolution3D(filters=seg_num_class,
                                    kernel_size=final_kernel_size,
                                    kernel_initializer=kernel_initializer,
                                    kernel_regularizer=kernel_regularizer,
                                    padding='same', activation=current_acti,
                                    name='NoduleSegDecoder_conv3d_mins1')

    if not deep_supervision:
        conv_mins1 = final_conv_func(combine_tensors[0])
        if ACF_choice:
            # ACF refinement: class-center (CCB) + class-attention (CAB)
            # re-weighting of the fused features, then a second prediction.
            coarse_feature_map = conv_mins1
            feature_map_class = CCB([coarse_feature_map, combine_tensors[0]],
                                    num_classes=seg_num_class)
            # NOTE(review): 'activaion' keyword kept as-is -- it must match
            # the CAB helper's signature; confirm before renaming.
            CAB_feature_map = CAB([coarse_feature_map, feature_map_class],
                                  num_classes=seg_num_class,
                                  activaion=current_acti)
            concatenate_feature_map = Concatenate()(
                [CAB_feature_map, combine_tensors[0]])
            final_result = Convolution3D(
                filters=seg_num_class, kernel_size=final_kernel_size,
                kernel_initializer=kernel_initializer,
                kernel_regularizer=kernel_regularizer,
                padding='same', activation=current_acti,
                name='NoduleSegDecoder_conv_ACF')(concatenate_feature_map)
            conv_mins1 = [coarse_feature_map, final_result]
    else:
        # Deep supervision: one upsampled prediction per combined level.
        conv_mins1 = []
        input_tensors = combine_tensors[1]
        for idx in range(len(input_tensors)):
            rate = int(input_tensor._keras_shape[1]
                       / input_tensors[idx]._keras_shape[1])
            block_pre = 'NoduleSegDecoder_DeepSupervision_block%02d' % (idx + 1)
            current_tensor = UpsampleBlock(
                input_tensors[idx], num_filters=seg_num_class,
                kernel_size=final_kernel_size, rate=rate,
                kernel_initializer=kernel_initializer,
                kernel_regularizer=kernel_regularizer,
                block_pre=block_pre, upsample_choice=upsample_choice,
                activation_func=current_acti)
            conv_mins1.append(current_tensor)

    output = conv_mins1
    if aux_task:
        # Auxiliary classification head on the finest combined feature map.
        pool_result = classification_layers(
            name='aux_task_pool_layer')(combine_tensors[1][-1])
        dense_result = Dense(num_classes * 32)(pool_result)
        class_result = Dense(num_classes)(dense_result)
        # BUGFIX: in the no-deep-supervision / no-ACF path `output` is a
        # single tensor; appending to it raised AttributeError.
        if not isinstance(output, list):
            output = [output]
        output.append(class_result)

    model = Model(input_tensor, output)
    return model