from collections import OrderedDict
from functools import partial

import torch
import torch.nn.functional as F
from torch import nn

from .BasicModules import *


class FSe_ResNet(nn.Module):
    """Weight-shared ("siamese") 3D ResNet applied to a pair of volumes.

    Both inputs of ``forward`` go through the same stem (``conv1``/``bn1``),
    the same residual stages (``res_layer_func``) and the same head
    (``fc0`` -> ``dp0`` -> ``fc``).  ``forward`` returns the per-input
    logits, the per-input embeddings (``fc0`` outputs, taken *before*
    ``dp0``) and one extra output computed by ``fc2`` from the two
    concatenated embeddings.

    All configuration is passed as keyword arguments.

    Required:
        block: residual block class.  It must expose an ``expansion``
            attribute and accept the keyword arguments used in
            ``_make_layer``.
        block_inplanes: base channel count of every stage (scaled by
            ``widen_factor``).
        pooling_ratios: stride of the first block of every stage.
        repeat_time_list: number of blocks in every stage.

    Optional (default):
        negative_slop (0), norm_choice ('BN'): forwarded to ``BNReLU`` and
            to the blocks.
        kernel_size (3): kernel size of the stem conv and of the blocks.
        num_stages (4): number of residual stages that are built.
        dropout_rate (0): if > 0, a ``Dropout3d`` follows every stage.
        num_class (1): output size of ``fc``.
        shortcut_type ('B'): 'A' selects a parameter-free shortcut
            (strided subsampling + zero channel padding).  Any other value
            selects a 1x1x1 conv followed by ``BNReLU``.
        widen_factor (1.0): multiplier applied to ``block_inplanes``.
    """

    _REQUIRED_KWARGS = ('block', 'block_inplanes', 'pooling_ratios', 'repeat_time_list')

    def __init__(self, **kwargs):
        super().__init__()
        missing = [name for name in self._REQUIRED_KWARGS if kwargs.get(name) is None]
        if missing:
            raise TypeError(
                'FSe_ResNet missing required keyword argument(s): %s' % ', '.join(missing))

        self.block = kwargs.get('block')
        self.block_inplanes = kwargs.get('block_inplanes')
        self.negative_slop = kwargs.get('negative_slop', 0)
        self.norm_choice = kwargs.get('norm_choice', 'BN')
        self.kernel_size = kwargs.get('kernel_size', 3)
        self.num_stages = kwargs.get('num_stages', 4)
        self.pooling_ratios = kwargs.get('pooling_ratios')
        self.repeat_time_list = kwargs.get('repeat_time_list')
        self.dropout_rate = kwargs.get('dropout_rate', 0)
        self.num_class = kwargs.get('num_class', 1)
        self.shortcut_type = kwargs.get('shortcut_type', 'B')
        self.widen_factor = kwargs.get('widen_factor', 1.0)

        self.block_inplanes = [int(val * self.widen_factor) for val in self.block_inplanes]
        # Running channel count; _make_layer advances it stage by stage.
        self.block_in_channels = self.block_inplanes[0]

        n_input_channel = 1
        self.conv1 = nn.Conv3d(n_input_channel, self.block_in_channels,
                               kernel_size=self.kernel_size,
                               padding=self.kernel_size // 2)
        self.bn1 = BNReLU(negative_slop=self.negative_slop, norm_choice=self.norm_choice,
                          num_features=self.block_in_channels)
        # NOTE: there is no stem max-pool.  Spatial downsampling comes only from
        # the per-stage strides in pooling_ratios (including pooling_ratios[0]).

        stages = OrderedDict()
        for stage_idx in range(self.num_stages):
            stage_name = 'ResNet_%02d' % (stage_idx + 1)
            stages[stage_name] = self._make_layer(
                self.block,
                self.block_inplanes[stage_idx],
                self.repeat_time_list[stage_idx],
                self.shortcut_type,
                self.pooling_ratios[stage_idx],
            )
            # Register dropout only when enabled: nn.Sequential calls every
            # registered entry, so a None placeholder would raise in forward().
            if self.dropout_rate > 0:
                stages[stage_name + '_dropout'] = nn.Dropout3d(p=self.dropout_rate)
        self.res_layer_func = nn.Sequential(stages)

        # After the loop, _make_layer has left block_in_channels at the channel
        # count produced by the last stage (planes * expansion).  That count is
        # the length of the pooled feature vector.
        final_feature_num = self.block_in_channels

        self.avgpool = nn.AdaptiveAvgPool3d((1, 1, 1))
        # TODO: revisit the classifier head design.
        hidden_num = final_feature_num // 4
        self.fc0 = nn.Linear(final_feature_num, hidden_num)
        # fc0's output is 2-D (N, hidden_num), so plain element-wise Dropout is
        # used.  Dropout3d on 2-D input is deprecated, and recent PyTorch treats
        # dim 0 as channels, zeroing whole samples.
        self.dp0 = nn.Dropout(p=0.5)
        self.fc = nn.Linear(hidden_num, self.num_class)
        # fc2 consumes the two concatenated fc0 embeddings.
        self.fc2 = nn.Linear(2 * hidden_num, 1)

    def _downsample_basic_block(self, x, out_num_channel, stride):
        """Type-'A' shortcut: subsample ``x`` spatially, then zero-pad channels.

        Args:
            x: 5-D tensor (N, C, D, H, W).
            out_num_channel: channel count of the result; must be >= C.
            stride: spatial subsampling stride, as accepted by ``F.avg_pool3d``.

        Returns:
            Tensor (N, out_num_channel, D', H', W').  Its first C channels are
            the subsampled input and the remaining channels are zeros.

        Raises:
            ValueError: if ``out_num_channel`` is smaller than C.
        """
        # Average pooling with kernel_size=1 is plain strided subsampling.
        out = F.avg_pool3d(x, kernel_size=1, stride=stride)
        pad_channels = out_num_channel - out.size(1)
        if pad_channels < 0:
            raise ValueError('out_num_channel (%d) is smaller than the input channel count (%d)'
                             % (out_num_channel, out.size(1)))
        # new_zeros matches out's dtype and device (CPU or any CUDA dtype).
        zero_pads = out.new_zeros(out.size(0), pad_channels, *out.shape[2:])
        # Concatenate `out` itself (not `out.data`) so gradients keep flowing
        # through the identity shortcut.
        return torch.cat([out, zero_pads], dim=1)

    def _make_layer(self, block_func, planes, repeat_time, shortcut_type, stride=1):
        """Build one residual stage and advance ``self.block_in_channels``.

        Args:
            block_func: residual block class (needs an ``expansion`` attribute).
            planes: base channel count of the stage.
            repeat_time: number of blocks in the stage.
            shortcut_type: 'A' for the parameter-free shortcut; any other
                value for a 1x1x1 conv + BNReLU shortcut.
            stride: stride of the first block of the stage.

        Returns:
            ``nn.Sequential`` containing the stage's blocks.
        """
        out_channels = planes * block_func.expansion
        downsample = None
        # A shortcut projection is needed only when the first block changes the
        # spatial size or the channel count.
        if stride != 1 or self.block_in_channels != out_channels:
            if shortcut_type == 'A':
                # Bind by keyword: the block later calls downsample(x), and x is
                # the first positional parameter of _downsample_basic_block.
                downsample = partial(self._downsample_basic_block,
                                     out_num_channel=out_channels, stride=stride)
            else:
                downsample = nn.Sequential(
                    nn.Conv3d(in_channels=self.block_in_channels, out_channels=out_channels,
                              kernel_size=1, stride=stride),
                    BNReLU(negative_slop=self.negative_slop, norm_choice=self.norm_choice,
                           num_features=out_channels),
                )

        block_kwargs = dict(negative_slop=self.negative_slop, norm_choice=self.norm_choice,
                            num_out_features=planes, kernel_size=self.kernel_size)
        layers = [block_func(num_in_features=self.block_in_channels,
                             downsample=downsample, stride=stride, **block_kwargs)]
        self.block_in_channels = out_channels
        layers.extend(block_func(num_in_features=self.block_in_channels, stride=1, **block_kwargs)
                      for _ in range(1, repeat_time))
        return nn.Sequential(*layers)

    def _forward_single(self, x):
        """Run one input through the shared trunk and head.

        Returns:
            (logits, embedding).  ``embedding`` is the ``fc0`` output taken
            before ``dp0``.
        """
        x = self.bn1(self.conv1(x))
        x = self.res_layer_func(x)
        x = torch.flatten(self.avgpool(x), 1)
        feature = self.fc0(x)
        logits = self.fc(self.dp0(feature))
        return logits, feature

    def forward(self, x_list):
        """Apply the shared network to both inputs of ``x_list``.

        Args:
            x_list: pair ``(x, x2)``.  Each element presumably is a 5-D
                single-channel volume (N, 1, D, H, W), as implied by
                ``conv1`` -- TODO confirm with callers.

        Returns:
            Tuple ``(x, x2, feature0, feature1, conc_output)``:
            - ``x``, ``x2``: per-input ``fc`` logits.
            - ``feature0``, ``feature1``: per-input ``fc0`` embeddings.
            - ``conc_output``: ``fc2`` applied to the two embeddings
              concatenated along dim 1.
        """
        x, x2 = x_list
        x, feature0 = self._forward_single(x)
        x2, feature1 = self._forward_single(x2)
        conc_output = self.fc2(torch.cat([feature0, feature1], dim=1))
        return x, x2, feature0, feature1, conc_output