import os
import sys
import glob
import numpy as np
import tensorflow as tf
from keras import Model
from keras.regularizers import l2
from keras import backend as K
from keras.engine import Layer,InputSpec
from keras.layers.merge import concatenate
from keras.callbacks import TensorBoard,Callback
from keras.layers.advanced_activations import LeakyReLU
from keras.preprocessing.image import ImageDataGenerator
from keras.layers.normalization import BatchNormalization
from keras.callbacks import ModelCheckpoint, ReduceLROnPlateau
from keras import initializers, regularizers, constraints,optimizers
from keras.callbacks import ModelCheckpoint, LearningRateScheduler,TensorBoard
from keras.layers import Add,Input,Conv3D,Convolution3D,Dropout,UpSampling3D,Concatenate,MaxPooling3D,\
    GlobalAveragePooling3D,Dense,GlobalMaxPooling3D,Lambda,Activation,Reshape,Permute, PReLU, Deconvolution3D,Conv3DTranspose

# Make the project-local building blocks below importable; the path is relative
# to the current working directory, so this assumes the script runs from a
# sibling directory of baseLayers/.
sys.path.append('../baseLayers/')
from unet_utils import *
from SEB import SEB_block
from OCR import SpatialGather,OCRModule
from CommonLayers import NormActi,ConvUnit


def _all_known(values):
    """Return True when *values* is a non-empty sequence containing no ``None`` entry."""
    return bool(values) and all(value is not None for value in values)


def UpsampleBlock(x, **kwargs):
    """Upsample a 3-D feature map by ``rate`` and project it to ``num_filters`` channels.

    Keyword args:
        num_filters: output channel count of the convolution layer.
        kernel_size (3), kernel_initializer ('he_normal'),
        kernel_regularizer (l2(1e-4)), padding ('same'): forwarded to the conv layer.
        rate (2): UpSampling3D size, or the stride of the transposed convolution.
        upsample_choice ('upsample'): 'upsample' -> UpSampling3D then Convolution3D;
            'deconv' -> Deconvolution3D; any other value -> Conv3DTranspose.
        norm_func, activation_func: forwarded to NormActi. When activation_func is a
            string it is applied inside the conv layer instead, and NormActi receives
            activation_func=None.
        block_pre (''): prefix used to name the created layers.
        atrous_rate: accepted for interface compatibility; not used.

    Returns:
        The tensor returned by NormActi.
    """
    num_filters = kwargs.get('num_filters')
    kernel_initializer = kwargs.get('kernel_initializer', 'he_normal')
    kernel_regularizer = kwargs.get('kernel_regularizer', l2(1e-4))
    block_pre = kwargs.get('block_pre', '')
    kernel_size = kwargs.get('kernel_size', 3)
    rate = kwargs.get('rate', 2)
    upsample_choice = kwargs.get('upsample_choice', 'upsample')
    norm_func = kwargs.get('norm_func')
    activation_func = kwargs.get('activation_func')
    padding = kwargs.get('padding', 'same')

    if isinstance(activation_func, str):
        # A named activation is fused into the conv layer, so NormActi must not apply one again.
        conv_activation, activation_func = activation_func, None
    else:
        conv_activation = None

    # Keras 2 argument names (filters/padding) instead of the legacy nb_filter/border_mode.
    conv_kwargs = dict(filters=num_filters, kernel_size=kernel_size,
                       kernel_initializer=kernel_initializer,
                       kernel_regularizer=kernel_regularizer,
                       padding=padding, activation=conv_activation,
                       name='%s_conv' % block_pre)
    if upsample_choice == 'upsample':
        x = UpSampling3D(size=rate, name='%s_upsample' % (block_pre))(x)
        x = Convolution3D(**conv_kwargs)(x)
    elif upsample_choice == 'deconv':
        x = Deconvolution3D(strides=rate, **conv_kwargs)(x)
    else:
        x = Conv3DTranspose(strides=rate, **conv_kwargs)(x)
    return NormActi(x, norm_func=norm_func, activation_func=activation_func,
                    block_prefix=block_pre, block_suffix=block_pre)


def _seb_upsample_rates(x, seb_tensors, stride_list):
    """Return one UpSampling3D ``size`` per tensor in *seb_tensors* so each reaches x's spatial size.

    *stride_list* is ordered like *seb_tensors* (last encoder tensor first). It is
    used when every entry is known; otherwise the factors are derived from the
    static shapes along axes 1, 2 and 3.
    """
    if _all_known(stride_list):
        try:
            # Factor for the first (deepest) SEB tensor; it shrinks as shallower tensors follow.
            remaining = int(np.prod(stride_list[:-1], axis=0))
        except (TypeError, ValueError):
            remaining = None
        if remaining is not None:
            rates = []
            for layer_idx in range(len(seb_tensors)):
                rates.append(remaining)
                # Floor division keeps the factor an int. True division produced a float from
                # the second tensor on, which UpSampling3D rejects.
                remaining = int(remaining // stride_list[layer_idx])
            return rates
    spatial_axes = (1, 2, 3)
    reference = [x._keras_shape[axis] for axis in spatial_axes]
    return [[int(ref / tensor._keras_shape[axis]) for ref, axis in zip(reference, spatial_axes)]
            for tensor in seb_tensors]


def NoduleSegDecoderBlock(x, encoder_tensors, **kwargs):
    """One decoder stage: upsample ``x``, concatenate a skip tensor, refine with ConvUnits.

    The skip tensor is ``encoder_tensors[-(block_idx + 1)]``. When ``SEB_choice`` is
    truthy it is instead the output of SEB_block. SEB_block receives that encoder
    tensor plus the last ``block_idx`` encoder tensors, each upsampled to the
    spatial size of ``x``.

    Keyword args:
        block_idx: 1-based decoder stage index; selects the skip tensor and names layers.
        num_units: ConvUnits with layer_idx 2..num_units are appended after the concat.
        num_filters: channel count of the upsampling conv and of each ConvUnit.
        upsample_rate (2): factor passed to UpsampleBlock as ``rate``.
        SEB_choice (False): enable the SEB skip path.
        SEB_upsample_stride_list (None): encoder strides, last encoder stage first. When
            absent or containing ``None``, SEB upsample factors come from tensor shapes.
        merge_axis (-1): concatenation axis.
        kernel_size, kernel_initializer, kernel_regularizer, norm_func, activation_func,
        padding, atrous_rate, conv_first, upsample_choice: forwarded to the sub-blocks.

    Returns:
        The output tensor of the stage.
    """
    block_idx = kwargs.get('block_idx')
    num_units = kwargs.get('num_units')
    kernel_initializer = kwargs.get('kernel_initializer', 'he_normal')
    kernel_regularizer = kwargs.get('kernel_regularizer', l2(1e-4))
    SEB_choice = kwargs.get('SEB_choice', False)
    kernel_size = kwargs.get('kernel_size', 3)
    norm_func = kwargs.get('norm_func')
    activation_func = kwargs.get('activation_func')
    num_filters = kwargs.get('num_filters')
    padding = kwargs.get('padding', 'same')
    atrous_rate = kwargs.get('atrous_rate', 1)
    stride = kwargs.get('upsample_rate', 2)
    conv_first = kwargs.get('conv_first', True)
    merge_axis = kwargs.get('merge_axis', -1)
    upsample_choice = kwargs.get('upsample_choice', 'upsample')
    SEB_upsample_stride_list = kwargs.get('SEB_upsample_stride_list', None)

    block_prefix = 'NoduleSegDecoder_Block%02d' % block_idx
    x1 = UpsampleBlock(x, num_filters=num_filters, kernel_size=kernel_size, rate=stride,
                       atrous_rate=atrous_rate, padding=padding, block_pre=block_prefix,
                       kernel_initializer=kernel_initializer,
                       kernel_regularizer=kernel_regularizer, upsample_choice=upsample_choice,
                       norm_func=norm_func, activation_func=activation_func)

    skip_tensor = encoder_tensors[-(block_idx + 1)]
    if SEB_choice:
        target_tensor = skip_tensor
        seb_tensors = encoder_tensors[-(block_idx):][::-1]
        rates = _seb_upsample_rates(x, seb_tensors, SEB_upsample_stride_list)
        upsample_tensor_list = [
            UpSampling3D(size=rate,
                         name='%s_SEB_preUpsample%02d' % (block_prefix, layer_idx + 1))(tensor)
            for layer_idx, (tensor, rate) in enumerate(zip(seb_tensors, rates))]
        skip_tensor = SEB_block([upsample_tensor_list, target_tensor], norm_func=norm_func,
                                activation_func=activation_func,
                                num_filters=target_tensor._keras_shape[-1],
                                kernel_size=kernel_size, atrous_rate=atrous_rate,
                                padding=padding, block_prefix=block_prefix,
                                kernel_initializer=kernel_initializer,
                                kernel_regularizer=kernel_regularizer, conv_first=conv_first,
                                merge_axis=merge_axis, upsample_rate=stride)

    x = Concatenate(axis=merge_axis, name='%s_Concatenate' % (block_prefix))([skip_tensor, x1])
    for layer_idx in range(1, num_units):
        x = ConvUnit(x, norm_func=norm_func, activation_func=activation_func,
                     num_filters=num_filters, kernel_size=kernel_size, atrous_rate=atrous_rate,
                     padding=padding, block_prefix=block_prefix,
                     kernel_initializer=kernel_initializer,
                     kernel_regularizer=kernel_regularizer,
                     layer_idx=layer_idx + 1, conv_first=conv_first)
    return x


def NoduleSegDeepCombineBlock(input_tensors, **kwargs):
    """Fuse three decoder outputs, coarse to fine, by upsample-and-add.

    Each of x1, x2, x3 first passes through one ConvUnit with ``num_filters``
    channels. Then ``up(up(x1) + x2) + x3`` is computed.

    The upsampling factors are the ratios of axis-1 static sizes (x2/x1, x3/x2).
    When those are unavailable, ``stride_list[0]`` and ``stride_list[1]`` are used.

    Keyword args:
        num_filters, kernel_size, kernel_initializer, kernel_regularizer, norm_func,
        activation_func, padding, atrous_rate, conv_first: forwarded to ConvUnit.
        stride_list (None): fallback [x1->x2, x2->x3] upsampling factors.
        upsample_rate, merge_axis, upsample_choice: accepted for interface
        compatibility; not used.

    Returns:
        (fused_tensor, [x1, x2, x3]) where the list holds the post-ConvUnit tensors.
    """
    kernel_initializer = kwargs.get('kernel_initializer', 'he_normal')
    kernel_regularizer = kwargs.get('kernel_regularizer', l2(1e-4))
    kernel_size = kwargs.get('kernel_size', 3)
    norm_func = kwargs.get('norm_func')
    activation_func = kwargs.get('activation_func')
    num_filters = kwargs.get('num_filters')
    padding = kwargs.get('padding', 'same')
    atrous_rate = kwargs.get('atrous_rate', 1)
    conv_first = kwargs.get('conv_first', True)
    stride_list = kwargs.get('stride_list', None)

    x1, x2, x3 = input_tensors
    block_prefix = 'NoduleSegDeepCombineBlock'
    try:
        uprate_1 = int(x2._keras_shape[1] / x1._keras_shape[1])
        uprate_2 = int(x3._keras_shape[1] / x2._keras_shape[1])
    except (AttributeError, TypeError, ZeroDivisionError):
        # Static shapes unknown (e.g. None dims): fall back to the supplied strides.
        uprate_1 = stride_list[0]
        uprate_2 = stride_list[1]

    def _project(tensor, branch_idx):
        # One ConvUnit per branch so all three share num_filters channels before adding.
        return ConvUnit(tensor, norm_func=norm_func, activation_func=activation_func,
                        num_filters=num_filters, kernel_size=kernel_size,
                        atrous_rate=atrous_rate, padding=padding,
                        block_prefix='%s_block%02d_' % (block_prefix, branch_idx),
                        kernel_initializer=kernel_initializer,
                        kernel_regularizer=kernel_regularizer,
                        layer_idx=1, conv_first=conv_first)

    x1 = _project(x1, 1)
    up_x1 = UpSampling3D(size=uprate_1, name='%s_block01_upsample' % block_prefix)(x1)
    x2 = _project(x2, 2)
    up_x2 = UpSampling3D(size=uprate_2, name='%s_block02_upsample' % block_prefix)(
        Add(name='%s_Add01' % block_prefix)([up_x1, x2]))
    x3 = _project(x3, 3)
    final_result = Add(name='%s_Add02' % block_prefix)([up_x2, x3])
    return final_result, [x1, x2, x3]


def NoduleSegDecoder_proxima(encoder_result, **kwargs):
    """Build the full segmentation model on top of an encoder.

    Args:
        encoder_result: (encoder_model, encoder_tensors[, encoder_stride_list]).
            The decoder starts from ``encoder_tensors[-1]`` and works backwards. It
            builds ``len(encoder_tensors) - 2`` NoduleSegDecoderBlocks, so at least 5
            encoder tensors are required.
            ``encoder_stride_list[-(k + 1)]``, when truthy, is the upsampling factor of
            decoder block k. Otherwise that factor is the axis-1 size ratio of
            ``encoder_tensors[-(k + 2)]`` to the block input.

    Keyword args (selection):
        seg_num_class: output channels. 1 -> sigmoid, otherwise softmax.
        deep_supervision (True): if True, each of the three combined decoder branches
            is upsampled to the size of the finest one and emitted as an output.
            Otherwise a single conv head is applied to the fused tensor.
        SEB_choice (True): SEB skip path for every decoder block except the last.
        ACF_choice, OCR_choice (False): optional refinement heads.
        aux_task (False): append a classification head built from
            ``classification_layers`` and two Dense layers with ``num_classes`` outputs.
        divide_ratio (4): channel divisor for the combine block.

    Returns:
        keras Model from ``encoder_model.input`` to the output tensor(s).

    Raises:
        ValueError: if fewer than 5 encoder tensors are supplied.
    """
    kernel_initializer = kwargs.get('kernel_initializer', 'he_normal')
    kernel_regularizer = kwargs.get('kernel_regularizer', l2(1e-4))
    kernel_size = kwargs.get('kernel_size', 3)
    final_kernel_size = kwargs.get('final_kernel_size', 3)
    norm_func = kwargs.get('norm_func')
    activation_func = kwargs.get('activation_func')
    padding = kwargs.get('padding', 'same')
    atrous_rate = kwargs.get('atrous_rate', 1)
    conv_first = kwargs.get('conv_first', True)
    merge_axis = kwargs.get('merge_axis', -1)
    SEB_choice = kwargs.get('SEB_choice', True)
    ACF_choice = kwargs.get('ACF_choice', False)
    OCR_choice = kwargs.get('OCR_choice', False)
    deep_supervision = kwargs.get('deep_supervision', True)
    num_units = kwargs.get('num_units', [3, 3, 3, 3])
    seg_num_class = kwargs.get('seg_num_class')
    aux_task = kwargs.get('aux_task', False)
    num_classes = kwargs.get('num_classes')
    classification_layers = kwargs.get('classification_layers')
    upsample_choice = kwargs.get('upsample_choice', 'upsample')
    divide_ratio = kwargs.get('divide_ratio', 4)

    encoder_model = encoder_result[0]
    encoder_tensors = encoder_result[1]
    try:
        encoder_stride_list = encoder_result[2]
    except (IndexError, KeyError):
        encoder_stride_list = None
    if encoder_stride_list is None:
        encoder_stride_list = [None for _ in range(len(encoder_tensors))]

    num_blocks = len(encoder_tensors) - 2
    if num_blocks < 3:
        raise ValueError('NoduleSegDecoder_proxima needs at least 5 encoder tensors, got %d'
                         % len(encoder_tensors))

    input_tensor = encoder_model.input
    x = encoder_tensors[-1]
    decoder_tensors = []
    # Upsampling factor actually used by each decoder block (stride list or shape ratio).
    decoder_strides = []
    for block_idx in range(num_blocks):
        current_stride = encoder_stride_list[-(block_idx + 1)]
        if not current_stride:
            current_stride = int(encoder_tensors[-(block_idx + 2)]._keras_shape[1] / x._keras_shape[1])
        is_last_block = block_idx == num_blocks - 1
        x = NoduleSegDecoderBlock(x, encoder_tensors, block_idx=block_idx + 1,
                                  num_units=num_units[block_idx],
                                  kernel_initializer=kernel_initializer,
                                  kernel_regularizer=kernel_regularizer,
                                  SEB_choice=SEB_choice and not is_last_block,
                                  kernel_size=kernel_size, norm_func=norm_func,
                                  activation_func=activation_func,
                                  num_filters=int(x._keras_shape[-1] / 2), padding=padding,
                                  atrous_rate=atrous_rate, upsample_rate=current_stride,
                                  conv_first=conv_first, merge_axis=merge_axis,
                                  upsample_choice=upsample_choice,
                                  SEB_upsample_stride_list=encoder_stride_list[::-1][:block_idx + 1])
        decoder_tensors.append(x)
        decoder_strides.append(current_stride)

    # The last two decoder strides are exactly the x1->x2 and x2->x3 factors of the three
    # tensors fed to the combine block. The former hard-coded encoder_stride_list[-4:-1]
    # only matched this for exactly 5 encoder tensors.
    combine_strides = decoder_strides[-2:]
    combine_tensors = NoduleSegDeepCombineBlock(
        decoder_tensors[-3:], kernel_initializer=kernel_initializer,
        kernel_regularizer=kernel_regularizer, kernel_size=kernel_size, norm_func=norm_func,
        activation_func=activation_func,
        num_filters=int(decoder_tensors[-1]._keras_shape[-1] / divide_ratio),
        padding=padding, atrous_rate=atrous_rate, conv_first=conv_first,
        merge_axis=merge_axis, upsample_choice=upsample_choice, stride_list=combine_strides)

    current_acti = 'sigmoid' if seg_num_class == 1 else 'softmax'
    final_conv_func = Convolution3D(filters=seg_num_class, kernel_size=final_kernel_size,
                                    kernel_initializer=kernel_initializer,
                                    kernel_regularizer=kernel_regularizer,
                                    padding='same', activation=current_acti,
                                    name='NoduleSegDecoder_conv3d_mins1')
    if not deep_supervision:
        conv_mins1 = final_conv_func(combine_tensors[0])
        # TODO: revise this part during non fix
        if ACF_choice:
            # NOTE(review): CCB/CAB are not defined in this file; presumably they come from
            # `from unet_utils import *`.
            coarse_feature_map = conv_mins1
            feature_map_class = CCB([coarse_feature_map, combine_tensors[0]], num_classes=seg_num_class)
            # NOTE(review): keyword spelled 'activaion' as in the original call; confirm
            # against CAB's signature before correcting it.
            CAB_feature_map = CAB([coarse_feature_map, feature_map_class],
                                  num_classes=seg_num_class, activaion=current_acti)
            concatenate_feature_map = Concatenate()([CAB_feature_map, combine_tensors[0]])
            final_result = Convolution3D(filters=seg_num_class, kernel_size=final_kernel_size,
                                         kernel_initializer=kernel_initializer,
                                         kernel_regularizer=kernel_regularizer,
                                         padding='same', activation=current_acti,
                                         name='NoduleSegDecoder_conv_ACF')(concatenate_feature_map)
            conv_mins1 = [coarse_feature_map, final_result]
    else:
        conv_mins1 = []
        for idx, branch in enumerate(combine_tensors[1]):
            # Bring every branch to the spatial size of the finest one (rate 1 for the last).
            # Previously this used the strides one position too far along the list, which
            # only agreed when all strides were equal.
            rate = int(np.prod(combine_strides[idx:]))
            block_pre = 'NoduleSegDecoder_DeepSupervision_block%02d' % (idx + 1)
            conv_mins1.append(UpsampleBlock(branch, num_filters=seg_num_class,
                                            kernel_size=final_kernel_size, rate=rate,
                                            kernel_initializer=kernel_initializer,
                                            kernel_regularizer=kernel_regularizer,
                                            block_pre=block_pre, upsample_choice=upsample_choice,
                                            activation_func=current_acti))

    if OCR_choice:
        # NOTE(review): this branch indexes and appends to conv_mins1, so it presumably
        # expects conv_mins1 to be a list (deep_supervision or ACF enabled).
        features = combine_tensors[1][-1]
        coarse_seg_output = final_conv_func(conv_mins1[-1])
        print ('feature shape and coarse_seg_output shape', features.shape, coarse_seg_output.shape)
        # TODO
        proxy = SpatialGather(features, coarse_seg_output, upsample_ratio=2)
        ocr_output = OCRModule(features, proxy, norm_func=norm_func, activation_func=activation_func,
                               num_filters=features._keras_shape[-1] * 4, kernel_size=1,
                               kernel_initializer=kernel_initializer,
                               kernel_regularizer=kernel_regularizer)
        ocr_output = Convolution3D(filters=seg_num_class, kernel_size=final_kernel_size,
                                   kernel_initializer=kernel_initializer,
                                   kernel_regularizer=kernel_regularizer,
                                   padding='same', activation=current_acti,
                                   name='NoduleSegDecoder_conv3d_ocr_output')(ocr_output)
        conv_mins1.append(ocr_output)

    output = conv_mins1
    if aux_task:
        # Auxiliary classification head on the finest combined branch.
        temp01 = encoder_tensors[-1]
        temp02 = combine_tensors[1][-1]
        print ('shape of temp01 and temp02 ', temp01.shape, temp02.shape)
        pool_result = classification_layers(name='aux_task_pool_layer')(combine_tensors[1][-1])
        dense_result = Dense(num_classes * 32)(pool_result)
        class_result = Dense(num_classes)(dense_result)
        if not isinstance(output, list):
            # A single segmentation tensor cannot be appended to; make it a one-element list.
            output = [output]
        output.append(class_result)

    model = Model(input_tensor, output)
    return model