import torch
import torch.nn as nn
import torch.nn.functional as F
from collections import OrderedDict
import sys
# sys.path.append('../')
from ..baseLayers.BasicModules import *
from ..backbones.NoduleSegEncoder import NoduleSegConvBlock_proxima


class XnetEncoder(nn.Module):
    """Encoder made of a pre-convolution ``ConvStage`` followed by up to five
    ``NoduleSegConvBlock_proxima`` stages.

    Configuration is read from ``kwargs['model_params']`` (a dict). Stage
    ``idx`` (0-based) outputs ``base_filters * 2 ** (idx + 1)`` channels.
    ``forward`` returns the list of intermediate feature maps or, when the
    ``do_cls`` option is set, the output of a global-average-pool + linear head.
    """

    def __init__(self, **kwargs):
        super().__init__()
        model_parameters = kwargs.get('model_params')
        print ('model_parameters', model_parameters)
        self.model_parameters = model_parameters

        # Parameters shared by every encoder stage (and used as the base of the
        # pre-conv layer's parameters).
        basic_model_params = {
            'conv_func': model_parameters.get('conv_func', 'basic'),
            'norm_choice': model_parameters.get('norm_choice'),
            'activation_func': model_parameters.get('activation_func', None),
            'kernel_size': model_parameters.get('kernel_size', 3),
            'norm_axis': model_parameters.get('norm_axis'),
            'negative_slop': model_parameters.get('negative_slop'),
        }
        self.basic_model_params = basic_model_params

        ASPP_choice = model_parameters.get('ASPP_choice', 1)
        self.ASPP_choice = ASPP_choice
        dropout_rate = model_parameters.get('dropout_rate', 0.0)
        num_blocks = model_parameters.get('num_blocks', 5)
        # NOTE(review): 'strides' and 'atrous_rates' have no defaults; they must be
        # indexable with at least num_blocks entries or the stage loop fails.
        strides = model_parameters.get('strides')
        atrous_rates = model_parameters.get('atrous_rates')
        base_filters = model_parameters.get('base_filters', 32)
        ASPP_block_list = model_parameters.get('ASPP_block_list', [num_blocks - 1])
        repeat_times_list = model_parameters.get("repeat_times_list", 3)
        ASPP_pooling_choice = model_parameters.get('ASPP_pooling_choice', True)
        do_cls = model_parameters.get('do_cls', False)
        # NOTE: the config key read here is spelled 'pred_conv_num'.
        self.pre_conv_num = model_parameters.get('pred_conv_num', 2)
        self.pre_conv_stride = model_parameters.get('pre_conv_stride', 1)

        # A single int means "use the same repeat count for every stage".
        if isinstance(repeat_times_list, int):
            repeat_times_list = [repeat_times_list] * num_blocks

        self._GetPreConvLayer(base_filters)
        self.real_stride_ratios = [1]
        self.output_features = [base_filters]

        num_in_features = base_filters
        layers = []
        for idx in range(num_blocks):
            stage_id = idx + 1
            num_out_features = base_filters * (2 ** stage_id)
            current_layer_parameters = {
                'atrous_rate': atrous_rates[idx],
                'stride': strides[idx],
                'num_filters': num_out_features,
                'repeat_time': repeat_times_list[idx],
                # ASPP is enabled only on listed stages and only when ASPP_choice > 0.
                'ASPP_choice': bool((idx in ASPP_block_list) and (ASPP_choice > 0)),
                'stage_idx': stage_id,
                'num_in_features': num_in_features,
                'ASPP_pooling_choice': ASPP_pooling_choice,
                "ASPP_mode": ASPP_choice,
            }
            current_layer = NoduleSegConvBlock_proxima(
                basic_parameters=basic_model_params,
                stage_parameters=current_layer_parameters,
            )
            self.real_stride_ratios.append(current_layer.GetEnlarge())

            layer_map_dict = OrderedDict()
            layer_map_dict['NoduleSegEncoder_proxima_stage%02d_func' % stage_id] = current_layer
            if dropout_rate > 0:
                layer_map_dict['NoduleSegEncoder_proxima_stage%02d_dropout' % stage_id] = nn.Dropout3d(p=dropout_rate)

            num_in_features = num_out_features
            self.output_features.append(num_in_features)
            layers.append(nn.Sequential(layer_map_dict))

        # Only the first five stages are registered and used by forward().
        # NOTE(review): num_blocks < 3 raises IndexError below; with num_blocks > 5
        # the extra stages are built but discarded, while real_stride_ratios and
        # output_features still list them (so fc0's input width would not match).
        self.encoder_layer0 = layers[0]
        self.encoder_layer1 = layers[1]
        self.encoder_layer2 = layers[2]
        self.encoder_layer3 = layers[3] if len(layers) > 3 else None
        self.encoder_layer4 = layers[4] if len(layers) > 4 else None

        if do_cls:
            self.avgpool = nn.AdaptiveAvgPool3d((1, 1, 1))
            print ('='*60)
            print ('self.output_features[-1] is', self.output_features[-1])
            # Single-output head; the 'num_classes' config entry is not consulted.
            self.fc0 = nn.Linear(self.output_features[-1], 1)
        self.do_cls = do_cls

    def _GetPreConvLayer(self, base_filters):
        """Build ``self.preLayer``, the ``ConvStage`` run before the encoder stages.

        It takes 1 input channel and emits ``base_filters`` channels, with
        ``self.pre_conv_num`` repeats and stride ``self.pre_conv_stride``.
        """
        # Work on a copy: self.basic_model_params is the same dict __init__ passes
        # to every encoder stage, so the pre-conv specific keys must not leak into it.
        current_params = dict(self.basic_model_params)
        current_params['num_in_features'] = 1
        current_params['num_out_features'] = base_filters
        current_params['repeat_times'] = self.pre_conv_num
        current_params['stride'] = self.pre_conv_stride
        self.preLayer = ConvStage(**current_params)

    def GetRealStrides(self):
        """Return ``[1]`` (pre-conv output) followed by each stage's ``GetEnlarge()`` value."""
        return self.real_stride_ratios

    def GetOutputFeatures(self):
        """Return channel counts: ``base_filters`` for the pre-conv output, then each stage's width."""
        return self.output_features

    def forward(self, x):
        """Encode ``x``.

        1. pre-conv module (``self.preLayer``; stride ``pre_conv_stride``, 1 by default),
        2. encoder stages ``encoder_layer0`` .. ``encoder_layer4`` (the last two
           only when they were built).

        Input is presumably a 5D volume (N, 1, D, H, W), given the 3D pooling /
        dropout and the single pre-conv input channel -- TODO confirm.

        Returns:
            The list ``[x0, x1, x2, x3(, x4(, x5))]`` of pre-conv and per-stage
            outputs; or, if ``do_cls`` is set, ``fc0`` applied to the globally
            average-pooled last feature map, flattened per sample.
        """
        x0 = self.preLayer(x)
        x1 = self.encoder_layer0(x0)
        x2 = self.encoder_layer1(x1)
        x3 = self.encoder_layer2(x2)
        output = [x0, x1, x2, x3]
        if self.encoder_layer3 is not None:
            x4 = self.encoder_layer3(x3)
            output.append(x4)
            # encoder_layer4 is only ever built when encoder_layer3 is.
            if self.encoder_layer4 is not None:
                output.append(self.encoder_layer4(x4))
        if self.do_cls:
            gap_x = self.avgpool(output[-1])
            gap_x = gap_x.view(x.size(0), -1)
            output = self.fc0(gap_x)
        return output


if __name__ == "__main__":
    # Developer smoke test. NOTE(review): because of the relative imports at the
    # top this presumably must be launched as a module (python -m ...); the
    # config and RWconfig paths below are machine-specific.
    cfg_path = '../../files/modelParams/model_parameters.json'
    model_name = "NoduleSegEncoder_proxima"
    sys.path.append('/hdd/disk6/zrx/projects/Pulmonary_Nodule/develop_code/grouplung/lungNoduleDensityClassification/nodulecls/utils/')
    from RWconfig import LoadJson
    cfg = LoadJson(cfg_path)[model_name]
    from torchsummary import summary
    # Instantiate the encoder defined in this file.
    model = XnetEncoder(model_params=cfg)
    # summary(model.cuda(),(1,24,24,24))
    # total_params = sum(p.numel() for p in model.parameters())
    # print ('output of model is',model.output[0])