import torch
import torch.nn as nn
import torch.nn.functional as F
from collections import OrderedDict
import sys
# sys.path.append('../')
from ..baseLayers.BasicModules import *


class NoduleSegConvBlock_proxima(nn.Module):
    """One encoder stage: ``repeat_time`` layers of (optional ASPP +) ConvBlock.

    Layer 0 of the stage (the only one that receives the stage stride) is kept in
    ``first_layer_of_current_stage``. Layers 1..repeat_time-1 are kept in
    ``stage_layers`` and are added back onto the first layer's output as a
    residual branch in ``forward``.

    Keyword Args:
        basic_parameters (dict): options shared by every layer of the stage; merged
            into the keyword arguments passed to ``ASPP`` / ``ConvBlock``.
        stage_parameters (dict): per-stage options. Keys read here:
            ``stage_idx``, ``num_in_features``, ``num_filters``, ``atrous_rate``,
            ``ASPP_mode``, ``repeat_time``, and optionally ``stride`` (default 1)
            and ``ASPP_choice`` (an ASPP layer precedes every ConvBlock when it
            is > 0).
    """

    def __init__(self, **kwargs):
        super().__init__()
        basic_parameters = kwargs.get('basic_parameters')
        model_parameters = kwargs.get('stage_parameters')
        self.model_parameters = model_parameters
        stage_idx = model_parameters['stage_idx']
        self.stage_idx = stage_idx

        num_in_features = model_parameters['num_in_features']
        num_out_features = model_parameters['num_filters']
        use_aspp = model_parameters.get('ASPP_choice', -1) > 0

        layer_parameters = {
            'num_out_features': num_out_features,
            'atrous_rate': model_parameters['atrous_rate'],
            'ASPP_mode': model_parameters['ASPP_mode'],
        }
        layer_parameters.update(basic_parameters)
        # Re-assert after the merge so basic_parameters cannot override the stage width.
        layer_parameters['num_out_features'] = num_out_features

        # Kept as a public attribute for backward compatibility; the modules it holds
        # are registered through ``first_layer_of_current_stage`` below.
        self.first_layer_of_current_stage_dict = {}
        layer_map_dict = {}
        for layer_idx in range(model_parameters['repeat_time']):
            is_first = layer_idx == 0
            # Only the first layer of the stage uses the configured stride.
            conv_stride = int(model_parameters.get('stride', 1)) if is_first else 1
            if is_first:
                self.conv_stride = conv_stride
            # Layer 0 goes into the "first layer" container; the rest form the residual branch.
            target = self.first_layer_of_current_stage_dict if is_first else layer_map_dict
            prefix = 'NoduleSegConvBlock_proxima_stage%02d_layer%02d' % (stage_idx, layer_idx + 1)

            layer_parameters['stride'] = conv_stride
            layer_parameters['num_in_features'] = num_in_features
            if use_aspp:
                target[prefix + '_ASPP'] = ASPP(**layer_parameters)
                # The stride was handed to the ASPP layer, so the following ConvBlock runs at stride 1.
                layer_parameters['stride'] = 1
                if model_parameters['ASPP_mode'] == 1:
                    # NOTE(review): presumably ASPP mode 1 concatenates its input with its
                    # output, hence in + out channels for the next ConvBlock; confirm in BasicModules.
                    layer_parameters['num_in_features'] = num_in_features + num_out_features
                else:
                    layer_parameters['num_in_features'] = num_out_features
            target[prefix + '_conv'] = ConvBlock(**layer_parameters)
            num_in_features = num_out_features

        self.first_layer_of_current_stage = nn.Sequential(OrderedDict(self.first_layer_of_current_stage_dict))
        self.stage_layers = nn.Sequential(OrderedDict(layer_map_dict))

    def forward(self, x):
        x1 = self.first_layer_of_current_stage(x)
        # NOTE(review): when repeat_time == 1, stage_layers is an empty nn.Sequential
        # (identity), so this returns 2 * x1; confirm that is intended.
        return self.stage_layers(x1) + x1

    def GetEnlarge(self):
        """Return the stride applied by this stage's first layer."""
        return self.conv_stride


class NoduleSegEncoder_proxima(nn.Module):
    """Multi-stage encoder built from ``NoduleSegConvBlock_proxima`` stages.

    Stage ``i`` (0-based) has ``base_filters * 2**i`` output channels and is stored
    as ``self.encoder_layer{i}``. For configurations with fewer than 5 stages, the
    unused ``encoder_layer3`` / ``encoder_layer4`` attributes are kept as ``None``.
    The first stage expects a single-channel input.

    Keyword Args:
        model_params (dict): configuration. Keys read (defaults in parentheses):
            ``conv_func`` ('basic'), ``norm_choice``, ``activation_func`` (None),
            ``kernel_size`` (3), ``norm_axis``, ``negative_slop``,
            ``padding_func`` ('zero'), ``ASPP_choice`` (1; also forwarded as each
            stage's ``ASPP_mode``), ``dropout_rate`` (0.0), ``num_blocks`` (5),
            ``strides`` and ``atrous_rates`` (per-stage sequences, indexed by stage),
            ``base_filters`` (32), ``ASPP_block_list`` (0-based stage indices that
            get ASPP; default ``[num_blocks - 1]``), ``repeat_times_list`` (int or
            per-stage sequence; 3), ``ASPP_pooling_choice`` (True), ``do_cls`` (False).

    Raises:
        ValueError: if ``num_blocks`` < 1.
    """

    # Number of ``encoder_layer{i}`` attributes that always exist (None when unused),
    # matching the attributes exposed by earlier versions of this class.
    _NUM_LEGACY_STAGE_ATTRS = 5

    def __init__(self, **kwargs):
        super().__init__()
        model_parameters = kwargs.get('model_params')
        self.model_parameters = model_parameters
        basic_model_params = {'conv_func': model_parameters.get('conv_func', 'basic'),
                              'norm_choice': model_parameters.get('norm_choice'),
                              'activation_func': model_parameters.get('activation_func', None),
                              'kernel_size': model_parameters.get('kernel_size', 3),
                              'norm_axis': model_parameters.get('norm_axis'),
                              'negative_slop': model_parameters.get('negative_slop'),
                              'padding_func': model_parameters.get('padding_func', 'zero')}
        ASPP_choice = model_parameters.get('ASPP_choice', 1)
        dropout_rate = model_parameters.get('dropout_rate', 0.0)
        num_blocks = model_parameters.get('num_blocks', 5)
        strides = model_parameters.get('strides')
        atrous_rates = model_parameters.get('atrous_rates')
        base_filters = model_parameters.get('base_filters', 32)
        ASPP_block_list = model_parameters.get('ASPP_block_list', [num_blocks - 1])
        repeat_times_list = model_parameters.get("repeat_times_list", 3)
        ASPP_pooling_choice = model_parameters.get('ASPP_pooling_choice', True)
        do_cls = model_parameters.get('do_cls', False)
        # NOTE(review): 'num_classes' in the config is not used; fc0 below always has a single output.

        if num_blocks < 1:
            raise ValueError('num_blocks must be >= 1, got %r' % (num_blocks,))
        if isinstance(repeat_times_list, int):
            repeat_times_list = [repeat_times_list] * num_blocks

        self.real_stride_ratios = []
        self.output_features = []
        num_in_features = 1
        for idx in range(num_blocks):
            num_filters = base_filters * (2 ** idx)
            stage_parameters = {'atrous_rate': atrous_rates[idx],
                                'stride': strides[idx],
                                'num_filters': num_filters,
                                'repeat_time': repeat_times_list[idx],
                                'ASPP_choice': bool((idx in ASPP_block_list) and ASPP_choice),
                                'stage_idx': idx + 1,
                                'num_in_features': num_in_features,
                                # NOTE(review): NoduleSegConvBlock_proxima does not read this key.
                                'ASPP_pooling_choice': ASPP_pooling_choice,
                                'ASPP_mode': ASPP_choice}
            stage = NoduleSegConvBlock_proxima(basic_parameters=basic_model_params,
                                               stage_parameters=stage_parameters)
            self.real_stride_ratios.append(stage.GetEnlarge())

            stage_modules = OrderedDict()
            stage_modules['NoduleSegEncoder_proxima_stage%02d_func' % (idx + 1)] = stage
            if dropout_rate > 0:
                stage_modules['NoduleSegEncoder_proxima_stage%02d_dropout' % (idx + 1)] = nn.Dropout3d(p=dropout_rate)
            num_in_features = num_filters
            self.output_features.append(num_filters)
            # Registering every stage (not just the first five) keeps the modules that
            # run in forward() in sync with output_features / real_stride_ratios; the
            # attribute names, and hence state_dict keys, are unchanged.
            setattr(self, 'encoder_layer%d' % idx, nn.Sequential(stage_modules))
        for idx in range(num_blocks, self._NUM_LEGACY_STAGE_ATTRS):
            setattr(self, 'encoder_layer%d' % idx, None)

        if do_cls:
            self.avgpool = nn.AdaptiveAvgPool3d((1, 1, 1))
            self.fc0 = nn.Linear(self.output_features[-1], 1)
        self.do_cls = do_cls

    def GetRealStrides(self):
        """Return the per-stage strides (first-layer stride of each stage)."""
        return self.real_stride_ratios

    def GetOutputFeatures(self):
        """Return the per-stage output channel counts."""
        return self.output_features

    def _iter_stages(self):
        """Yield encoder_layer0, encoder_layer1, ... up to the first missing/None one."""
        idx = 0
        while True:
            stage = getattr(self, 'encoder_layer%d' % idx, None)
            if stage is None:
                return
            yield stage
            idx += 1

    def forward(self, x):
        """Run all stages in order.

        Returns:
            list of per-stage feature maps (shallowest first), or, when ``do_cls``
            is set, the ``fc0`` output of shape ``(x.size(0), 1)`` computed from the
            globally average-pooled last stage.
        """
        output = []
        features = x
        for stage in self._iter_stages():
            features = stage(features)
            output.append(features)
        if self.do_cls:
            gap_x = self.avgpool(output[-1])
            gap_x = gap_x.view(x.size(0), -1)
            output = self.fc0(gap_x)
        return output


if __name__ == "__main__":
    # Developer smoke test; the paths below are machine-specific.
    cfg_path = '../../files/modelParams/model_parameters.json'
    model_name = "NoduleSegEncoder_proxima"
    sys.path.append('/hdd/disk6/zrx/projects/Pulmonary_Nodule/develop_code/grouplung/lungNoduleDensityClassification/nodulecls/utils/')
    from RWconfig import LoadJson
    cfg = LoadJson(cfg_path)[model_name]
    from torchsummary import summary
    model = NoduleSegEncoder_proxima(model_params=cfg)
    # summary(model.cuda(),(1,24,24,24))
    # total_params = sum(p.numel() for p in model.parameters())
    # print ('output of model is',model.output[0])