import torch
import torch.nn as nn
import torch.nn.functional as F
from collections import OrderedDict
import sys

sys.path.append('../')
from baseLayers.BasicModules import *


class SEB(nn.Module):
    '''
    Fuses deeper (lower-resolution) encoder features into a shallower feature map:
    the deeper tensors are concatenated, convolved, optionally upsampled, and then
    used to gate `x` by element-wise multiplication.
    '''
    def __init__(self, **kwargs):
        super().__init__()
        basic_model_params = kwargs.get('basic_model_params')
        num_in_filters = kwargs.get('num_in_filters')
        num_out_filters = kwargs.get('num_out_filters')
        self.upsample_rate = kwargs.get('upsample_rate')

        layer_parameters = {'num_in_features': num_in_filters, 'num_out_features': num_out_filters}
        layer_parameters.update(basic_model_params)
        # print('layer_parameters is', layer_parameters)
        self.conv_func = ConvBlock(**layer_parameters)
        if self.upsample_rate > 1:
            self.up_layer = nn.Upsample(scale_factor=self.upsample_rate)

    def forward(self, x, x_list):
        # Concatenate the deeper encoder tensors along the channel axis.
        if len(x_list) == 1:
            x1 = x_list[0]
        else:
            x1 = torch.cat(x_list, 1)
        x1 = self.conv_func(x1)
        if self.upsample_rate > 1:
            x1 = self.up_layer(x1)
        # Gate the shallower feature map with the fused deep features.
        output = torch.mul(x, x1)
        return output


class NoduleSegDecoderBlockV2(nn.Module):
    '''Decoder stage without SEB: upsample, conv, concatenate the skip tensor, refine.'''
    def __init__(self, **kwargs):
        super().__init__()
        basic_model_params = kwargs.get('basic_model_params')
        stage_idx = kwargs.get('stage_idx')
        num_units = kwargs.get('num_units')
        SEB_choice = kwargs.get('SEB_choice', False)
        num_in_filters = kwargs.get('num_in_filters')
        num_out_filters = kwargs.get('num_out_filters')
        upsample_rate = kwargs.get('upsample_rate')

        layer_parameters = {'num_in_features': num_in_filters, 'num_out_features': num_out_filters}
        layer_parameters.update(basic_model_params)
        self.upsample_layer = nn.Upsample(scale_factor=upsample_rate)
        self.conv0 = ConvBlock(**layer_parameters)

        layer_map_dict = {}
        for layer_idx in range(1, num_units):
            # The first refinement conv sees the concatenated (decoder + skip) channels.
            layer_parameters['num_in_features'] = 2 * num_out_filters if layer_idx == 1 else num_out_filters
            layer_map_dict['NoduleSegDecoderBlock_stage%02d_layer%02d_conv' % (stage_idx, layer_idx + 1)] = ConvBlock(**layer_parameters)
        self.stage_layers = nn.Sequential(OrderedDict([(key, layer_map_dict[key]) for key in layer_map_dict.keys()]))

    def forward(self, x, encoder_tensors):
        x0 = self.upsample_layer(x)
        x0 = self.conv0(x0)
        encoder_target_tensors = encoder_tensors[0]
        merge_x = torch.cat([x0, encoder_target_tensors], 1)
        output = self.stage_layers(merge_x)
        return output + x0


class NoduleSegDecoderBlock(nn.Module):
    '''Decoder stage with optional SEB fusion of deeper encoder tensors into the skip connection.'''
    def __init__(self, **kwargs):
        super().__init__()
        basic_model_params = kwargs.get('basic_model_params')
        stage_idx = kwargs.get('stage_idx')
        num_units = kwargs.get('num_units')
        SEB_choice = kwargs.get('SEB_choice', False)
        num_in_filters = kwargs.get('num_in_filters')
        num_out_filters = kwargs.get('num_out_filters')
        upsample_rate = kwargs.get('upsample_rate')
        SEB_params = kwargs.get('SEB_params')
        self.SEB_upsample_stride_list = SEB_params[0]
        self.upsample_rate = upsample_rate
        SEB_in_filters = SEB_params[1]

        layer_parameters = {'num_in_features': num_in_filters, 'num_out_features': num_out_filters}
        layer_parameters.update(basic_model_params)
        self.upsample_layer = nn.Upsample(scale_factor=upsample_rate)
        self.upsample_layer02 = nn.Upsample(scale_factor=2)
        self.conv0 = ConvBlock(**layer_parameters)
        if SEB_choice:
            self.SEB_layer = SEB(basic_model_params=basic_model_params, num_in_filters=SEB_in_filters,
                                 num_out_filters=num_out_filters, upsample_rate=upsample_rate)

        layer_map_dict = {}
        for layer_idx in range(1, num_units):
            # The first refinement conv sees the concatenated (decoder + skip) channels.
            layer_parameters['num_in_features'] = 2 * num_out_filters if layer_idx == 1 else num_out_filters
            layer_map_dict['NoduleSegDecoderBlock_stage%02d_layer%02d_conv' % (stage_idx, layer_idx + 1)] = ConvBlock(**layer_parameters)
        self.stage_layers = nn.Sequential(OrderedDict([(key, layer_map_dict[key]) for key in layer_map_dict.keys()]))
        self.SEB_choice = SEB_choice

    def forward(self, x, encoder_tensors):
        SEB_tensors = []
        x0 = self.upsample_layer(x)
        x0 = self.conv0(x0)
        # Upsample every deeper encoder tensor to the resolution of encoder_tensors[0].
        for idx in range(1, len(encoder_tensors)):
            current_stride = self.SEB_upsample_stride_list[idx - 1]
            x1 = encoder_tensors[idx]
            while current_stride >= 2:
                x1 = self.upsample_layer02(x1)
                current_stride /= 2
            SEB_tensors.append(x1)
        if self.SEB_choice:
            encoder_target_tensors = self.SEB_layer(encoder_tensors[0], SEB_tensors)
        else:
            encoder_target_tensors = encoder_tensors[0]
        merge_x = torch.cat([x0, encoder_target_tensors], 1)
        output = self.stage_layers(merge_x)
        return output + x0


class NoduleSegDeepCombineBlock(nn.Module):
    '''Fuses the three decoder outputs and produces segmentation logits (three deep-supervision heads when deep_combine is False).'''
    def __init__(self, **kwargs):
        super().__init__()
        basic_model_params = kwargs.get('basic_model_params')
        self.upsample_rates = kwargs.get('upsample_rates')
        num_in_features = kwargs.get('num_in_features')
        num_out_features = kwargs.get('num_out_features')
        self.deep_combine = kwargs.get('deep_combine')
        seg_output_features = kwargs.get('seg_num_class', 1)

        layer_parameters = {'num_in_features': num_in_features[2], 'num_out_features': num_out_features}
        layer_parameters.update(basic_model_params)
        self.conv0 = ConvBlock(**layer_parameters)
        layer_parameters['num_in_features'] = num_in_features[1]
        self.conv1 = ConvBlock(**layer_parameters)
        layer_parameters['num_in_features'] = num_in_features[0]
        self.conv2 = ConvBlock(**layer_parameters)
        output_layer_parameters = layer_parameters
        output_layer_parameters['num_in_features'] = num_out_features
        output_layer_parameters['num_out_features'] = seg_output_features

        kernel_size = 1
        if self.deep_combine:
            # Single head on the fused feature map.
            self.final_conv = nn.Conv3d(in_channels=num_out_features, out_channels=seg_output_features,
                                        kernel_size=kernel_size, padding=[int(kernel_size / 2) for _ in range(3)])
        else:
            # One head per resolution level for deep supervision.
            self.final_conv1 = nn.Conv3d(in_channels=num_out_features, out_channels=seg_output_features,
                                         kernel_size=kernel_size, padding=[int(kernel_size / 2) for _ in range(3)])
            self.final_conv2 = nn.Conv3d(in_channels=num_out_features, out_channels=seg_output_features,
                                         kernel_size=kernel_size, padding=[int(kernel_size / 2) for _ in range(3)])
            self.final_conv3 = nn.Conv3d(in_channels=num_out_features, out_channels=seg_output_features,
                                         kernel_size=kernel_size, padding=[int(kernel_size / 2) for _ in range(3)])

    def forward(self, x2, x1, x0):
        # x0 is the coarsest decoder output, x2 the finest.
        x0 = self.conv0(x0)
        upsample_x0 = nn.Upsample(scale_factor=self.upsample_rates[0])(x0)
        x1 = self.conv1(x1)
        upsample_x1 = nn.Upsample(scale_factor=self.upsample_rates[1])(x1 + upsample_x0)
        x2 = self.conv2(x2)
        output = upsample_x1 + x2
        if self.deep_combine:
            # Only self.final_conv exists in this configuration.
            return [self.final_conv(output)]
        output1 = self.final_conv1(nn.Upsample(scale_factor=self.upsample_rates[0] * self.upsample_rates[1])(x0))
        output2 = self.final_conv2(upsample_x1)
        output3 = self.final_conv3(output)
        return output1, output2, output3


class NoduleSegDecoder_proxima(nn.Module):
    '''Three-stage decoder built on an encoder exposing GetRealStrides() and GetOutputFeatures().'''
    def __init__(self, **kwargs):
        super().__init__()
        model_parameters = kwargs.get('model_params')
        self.model_parameters = model_parameters
        basic_model_params = {'conv_func': model_parameters.get('conv_func', 'basic'),
                              'norm_choice': model_parameters.get('norm_choice'),
                              'activation_func': model_parameters.get('activation_func', None),
                              'kernel_size': model_parameters.get('kernel_size', 3),
                              'norm_axis': model_parameters.get('norm_axis'),
                              'negative_slop': model_parameters.get('negative_slop'),
                              'padding_func': model_parameters.get('padding_func')}
        '''
        Important parameters:
        1. SEB
        2. deep supervision
        3. num_units_list
        4. seg_num_class
        5. divide_ratio
        6. encoder parameters
        '''
        encoder_model = model_parameters.get('encoder_model')
        self.real_strides = encoder_model.GetRealStrides()[1:]
        # self.real_strides.append(1)
        output_features = encoder_model.GetOutputFeatures()
        self.output_features = output_features
        self.encoder_model = encoder_model
        self.tensor_name_list = kwargs.get('tensor_name_list')
        num_stages = len(self.real_strides)
        self.calculateStageShape()

        # Build one NoduleSegDecoderBlock per stage, then the deep-combine head (deep supervision).
        divide_ratio = kwargs.get('divide_ratio', 4)
        deepcombine_output_features = int(self.output_features[0] / divide_ratio)

        stage_idx = 1
        self.stage_1_func = NoduleSegDecoderBlock(basic_model_params=basic_model_params, stage_idx=stage_idx,
                                                  num_units=model_parameters['num_units'][stage_idx - 1],
                                                  SEB_choice=model_parameters['SEB_choice'],
                                                  num_in_filters=output_features[-stage_idx],
                                                  num_out_filters=output_features[-(stage_idx + 1)],
                                                  upsample_rate=self.real_strides[-stage_idx],
                                                  SEB_params=self.calculateSEBParameters(stage_idx))
        stage_idx = 2
        self.stage_2_func = NoduleSegDecoderBlock(basic_model_params=basic_model_params, stage_idx=stage_idx,
                                                  num_units=model_parameters['num_units'][stage_idx - 1],
                                                  SEB_choice=model_parameters['SEB_choice'],
                                                  num_in_filters=output_features[-stage_idx],
                                                  num_out_filters=output_features[-(stage_idx + 1)],
                                                  upsample_rate=self.real_strides[-stage_idx],
                                                  SEB_params=self.calculateSEBParameters(stage_idx))
        stage_idx = 3
        self.stage_3_func = NoduleSegDecoderBlock(basic_model_params=basic_model_params, stage_idx=stage_idx,
                                                  num_units=model_parameters['num_units'][stage_idx - 1],
                                                  SEB_choice=model_parameters['SEB_choice'],
                                                  num_in_filters=output_features[-stage_idx],
                                                  num_out_filters=output_features[-(stage_idx + 1)],
                                                  upsample_rate=self.real_strides[-stage_idx],
                                                  SEB_params=self.calculateSEBParameters(stage_idx))

        self.combine_func = NoduleSegDeepCombineBlock(basic_model_params=basic_model_params,
                                                      upsample_rates=self.real_strides[:2][::-1],
                                                      num_in_features=self.output_features[-4:-1],
                                                      num_out_features=deepcombine_output_features,
                                                      deep_combine=model_parameters['deep_combine'],
                                                      seg_num_class=model_parameters['seg_num_class'])

    def calculateStageShape(self):
        # Spatial size of each encoder stage relative to the deepest one (deepest = 1).
        base_size = 1
        self.stage_sizes = [base_size]
        for val in self.real_strides[::-1]:
            self.stage_sizes.append(self.stage_sizes[-1] * val)
        self.stage_sizes.reverse()

    def calculateSEBParameters(self, stage_idx):
        '''
        For stage_idx in 1..3 the included encoder tensors are levels 2/3/4/5,
        assuming encoder strides [val1, val2, val3, val4, val5].
        Returns the per-tensor upsampling strides and the summed channel count
        of the tensors fed into the SEB.
        '''
        target_shape = self.stage_sizes[-stage_idx]
        strides = []
        for idx in range(stage_idx):
            strides.append(target_shape / self.stage_sizes[-(idx + 1)])
        strides.reverse()
        num_output_features = sum(self.output_features[-stage_idx:])
        return strides, num_output_features

    def getMulti(self, vals):
        # Product of a list of values.
        base = 1
        for val in vals:
            base *= val
        return base

    def forward(self, x):
        encoder_tensors = self.encoder_model(x)
        x0 = encoder_tensors[-1]
        x1 = self.stage_1_func(x0, encoder_tensors[-2:])
        x2 = self.stage_2_func(x1, encoder_tensors[-3:])
        x3 = self.stage_3_func(x2, encoder_tensors[-4:])
        # Deep combine / segmentation heads.
        output = self.combine_func(x3, x2, x1)
        return output
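

# ---------------------------------------------------------------------------
# Usage sketch (illustrative only). How this decoder is wired up depends on the
# encoder implementation and on the exact keyword arguments ConvBlock accepts
# (both live outside this file), so the encoder name, the example parameter
# values, and the input size below are assumptions, not the project's actual
# configuration.
#
#   encoder = SomeNoduleSegEncoder(...)   # hypothetical; must expose
#                                         # GetRealStrides() and GetOutputFeatures()
#   model_params = {
#       'encoder_model': encoder,
#       'num_units': [2, 2, 2],           # conv units for decoder stages 1-3
#       'SEB_choice': True,
#       'deep_combine': False,            # False -> three deep-supervision heads
#       'seg_num_class': 1,
#       'norm_choice': 'BN',              # example values; ConvBlock defines the
#       'norm_axis': 1,                   # valid choices for these keys
#       'negative_slop': 0.01,
#       'padding_func': 'same',
#   }
#   decoder = NoduleSegDecoder_proxima(model_params=model_params, divide_ratio=4)
#   outputs = decoder(torch.randn(1, 1, 64, 64, 64))   # 5D input: N, C, D, H, W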