unet_utils.py 3.21 KB
import sys
import keras
import numpy as np
import random
import tensorflow as tf
from keras import backend as K
from keras.engine import Layer,InputSpec
from keras import initializers, regularizers, constraints,optimizers
from keras.layers import Add,Input,Conv3D,Convolution3D,Dropout,UpSampling3D,Concatenate,Multiply,GlobalAveragePooling3D,Dense,Permute,Dot
from keras.layers.advanced_activations import LeakyReLU
from keras import Model
from keras.callbacks import TensorBoard,Callback
from keras.layers.merge import concatenate
from keras.preprocessing.image import ImageDataGenerator
from keras.layers.normalization import BatchNormalization
from keras.callbacks import ModelCheckpoint, ReduceLROnPlateau
from keras.callbacks import ModelCheckpoint, LearningRateScheduler,TensorBoard
from keras.layers import Add,Input,Conv3D,Conv1D,Convolution3D,Dropout,UpSampling3D,Concatenate,Reshape,Softmax,Multiply

sys.path.append('../baseLayers/')
from InstanceNorm import InstanceNormalization

def CCB(input_tensors,num_classes=2):
    '''
    Class-context block: pool the last feature map into one feature vector per class.

    The coarse segmentation map acts as a set of per-voxel class weights.
    For every class k, the channel-compressed feature map is summed over all
    voxels, weighted by P_coarse[..., k]:

        out[b, c, k] = sum_n F'[b, n, c] * P_coarse[b, n, k]

    Args:
        input_tensors: sequence whose first two items are
            P_coarse    -- coarse segmentation map, shape (batch, D, H, W, K);
            feature_map -- last feature map F, shape (batch, D, H, W, C).
            Both must have the same number of voxels D*H*W, because the
            flattened spatial axis is contracted in the dot product.
        num_classes: unused; kept only for backward compatibility. The number
            of classes K is always read from the last axis of P_coarse.

    Returns:
        Tensor of shape (batch, C', K), where C' = max(C // 2, 1) is the
        channel count after the 1x1x1 compression conv. This is the layout
        CAB contracts over (its Dot uses axes=2, i.e. the K axis).
    '''
    P_coarse = input_tensors[0]
    feature_map = input_tensors[1]

    num_channel = K.int_shape(feature_map)[-1]
    # K is taken from the data, not from the `num_classes` argument.
    n_classes = K.int_shape(P_coarse)[-1]

    # (batch, D, H, W, K) -> (batch, DHW, K); -1 lets Keras infer DHW, so this
    # also works when the spatial dims are not statically known.
    flat_coarse = Reshape(target_shape=(-1, n_classes))(P_coarse)

    # Halve the channel count with a 1x1x1 conv (at least one filter).
    # NOTE(review): the fixed layer name means CCB can only be applied once
    # per model; a second call would produce a duplicate layer name.
    compressed_channels = max(num_channel // 2, 1)
    compress_feature_map = Conv3D(filters=compressed_channels,kernel_size=1,padding='same',name='CCB_conv')(feature_map)

    # (batch, D, H, W, C') -> (batch, DHW, C')
    flat_features = Reshape(target_shape=(-1, compressed_channels))(compress_feature_map)

    # Contract over the flattened voxel axis (axis 1, counting batch as 0):
    # (batch, DHW, C') . (batch, DHW, K) -> (batch, C', K)
    feature_map_class = Dot(axes=1)([flat_features, flat_coarse])

    return feature_map_class


def CAB(input_tensors,num_classes=2,activaion='softmax'):
    '''
    Class-attention block: redistribute per-class features back to every voxel.

    Each voxel receives the sum of the per-class feature vectors (from CCB),
    weighted by that voxel's coarse class scores:

        out[b, n, c] = sum_k P_coarse[b, n, k] * feature_map_class[b, c, k]

    Args:
        input_tensors: sequence whose first two items are
            P_coarse          -- coarse segmentation map, shape (batch, D, H, W, K);
            feature_map_class -- per-class features as returned by CCB,
                                 shape (batch, C', K).
        num_classes: unused; kept only for backward compatibility. K is read
            from the last axis of P_coarse.
        activaion: unused (the misspelling is part of the public signature).
            NOTE(review): no activation is applied to P_coarse here; confirm
            whether callers are expected to pass already-normalized scores.

    Returns:
        Tensor of shape (batch, D, H, W, C').

    Raises:
        ValueError: if D, H or W of P_coarse is not statically known. The
            result is reshaped back to (D, H, W), so these sizes are required.
    '''
    P_coarse = input_tensors[0]
    feature_map_class = input_tensors[1]

    _, depth, height, width, n_classes = K.int_shape(P_coarse)
    if None in (depth, height, width):
        raise ValueError(
            'CAB requires statically known spatial dimensions, got '
            '(D, H, W) = ({}, {}, {})'.format(depth, height, width))

    # (batch, D, H, W, K) -> (batch, DHW, K)
    flat_coarse = Reshape(target_shape=(-1, n_classes))(P_coarse)

    # Contract the class axis of both inputs (axis 2, counting batch as 0):
    # (batch, DHW, K) . (batch, C', K) -> (batch, DHW, C')
    multi_result = Dot(axes=2)([flat_coarse, feature_map_class])

    # Restore the spatial layout: (batch, DHW, C') -> (batch, D, H, W, C')
    out_channels = K.int_shape(multi_result)[-1]
    reshape_multi_result = Reshape(target_shape=(depth,height,width,out_channels))(multi_result)

    return reshape_multi_result