# ResNet.py
from collections import OrderedDict
import torch.nn.functional as F
from .BasicModules import *

class ResNet(nn.Module):
    """3D residual network: stem conv -> residual stages -> global pool -> MLP head.

    All configuration is passed as keyword arguments.

    Keyword Args:
        block: residual block class. It must expose an ``expansion`` attribute
            and accept the keyword arguments used in ``_make_layer``. Required.
        block_inplanes: per-stage base channel counts; each entry is scaled by
            ``widen_factor`` and truncated to ``int``. Required.
        pooling_ratios: per-stage stride handed to the first block of each
            stage. Required.
        repeat_time_list: per-stage number of blocks. Required.
        negative_slop: forwarded to ``BNReLU`` and the blocks (default 0).
        norm_choice: forwarded to ``BNReLU`` and the blocks (default 'BN').
        kernel_size: kernel size of the stem conv and the blocks (default 3).
        num_stages: number of residual stages to build (default 4).
        dropout_rate: if > 0, an ``nn.Dropout3d`` with this probability is
            appended after every stage (default 0, i.e. no stage dropout).
        num_class: size of the final linear output (default 1).
        shortcut_type: 'A' selects the parameter-free pooled/zero-padded
            shortcut; any other value selects a 1x1x1 conv projection
            (default 'B').
        widen_factor: multiplier applied to ``block_inplanes`` (default 1.0).
        conv_choice: forwarded to ``ConvFunc3D``, ``SCIModule`` and the blocks
            (default 'basic').
        SCI_choice: stored on the instance; not consulted by ``forward``
            (default False).

    Raises:
        ValueError: if a required argument is missing or a per-stage list is
            shorter than ``num_stages``.
    """

    def __init__(self,**kwargs):
        super().__init__()

        self.block = kwargs.get('block')
        self.block_inplanes = kwargs.get('block_inplanes')
        self.negative_slop = kwargs.get('negative_slop',0)
        self.norm_choice = kwargs.get('norm_choice','BN')
        self.kernel_size = kwargs.get('kernel_size',3)
        self.num_stages = kwargs.get('num_stages',4)
        self.pooling_ratios = kwargs.get('pooling_ratios')
        self.repeat_time_list = kwargs.get('repeat_time_list')
        self.dropout_rate = kwargs.get('dropout_rate',0)
        self.num_class = kwargs.get('num_class',1)
        self.shortcut_type = kwargs.get('shortcut_type','B')
        self.widen_factor = kwargs.get('widen_factor',1.0)
        self.conv_choice = kwargs.get('conv_choice','basic')
        self.SCI_choice = kwargs.get('SCI_choice',False)

        # Fail early with a readable message instead of an opaque TypeError on None.
        required = {'block': self.block,
                    'block_inplanes': self.block_inplanes,
                    'pooling_ratios': self.pooling_ratios,
                    'repeat_time_list': self.repeat_time_list}
        missing = [name for name, value in required.items() if value is None]
        if missing:
            raise ValueError('ResNet is missing required keyword argument(s): %s' % ', '.join(missing))
        for name in ('block_inplanes', 'pooling_ratios', 'repeat_time_list'):
            if len(required[name]) < self.num_stages:
                raise ValueError('%s has %d entries but num_stages is %d'
                                 % (name, len(required[name]), self.num_stages))

        self.block_inplanes = [int(val*self.widen_factor) for val in self.block_inplanes]

        # Running channel count of the tensor entering the next block;
        # _make_layer updates it as stages are appended.
        self.block_in_channels = self.block_inplanes[0]

        # The stem is built for a single input channel.
        n_input_channel = 1

        self.conv1 = ConvFunc3D(num_in_features=n_input_channel,num_out_features=self.block_in_channels,
                                conv_choice=self.conv_choice,kernel_size=self.kernel_size)
        self.bn1 = BNReLU(negative_slop=self.negative_slop,norm_choice=self.norm_choice,
                          num_features=self.block_in_channels)

        # NOTE: there is no max-pool after the stem; downsampling is done by the stage strides.
        stage_modules = OrderedDict()
        final_feature_num = 0
        for stage_idx in range(self.num_stages):
            final_feature_num = self.block_inplanes[stage_idx]
            stage_modules['ResNet_%02d'%(stage_idx+1)] = self._make_layer(
                self.block,self.block_inplanes[stage_idx],self.repeat_time_list[stage_idx],
                self.shortcut_type,self.pooling_ratios[stage_idx])
            # Only register dropout when it is enabled: nn.Sequential calls every
            # registered entry in forward, so a None placeholder would raise
            # "'NoneType' object is not callable".
            if self.dropout_rate>0:
                stage_modules['ResNet_%02d_dropout'%(stage_idx+1)] = nn.Dropout3d(p=self.dropout_rate)

        self.res_layer_func = nn.Sequential(stage_modules)
        # Constructed but not applied in forward(); kept so the set of
        # registered submodules (and their parameters) stays unchanged.
        self.SCI = SCIModule(n_input_channel=final_feature_num,n_output_channel=final_feature_num,
                             conv_choice=self.conv_choice)
        self.avgpool = nn.AdaptiveAvgPool3d((1, 1, 1))

        # TODO: revisit the classification head.
        hidden_feature_num = final_feature_num//4
        self.fc0 = nn.Linear(final_feature_num,hidden_feature_num)
        # The head operates on flattened (N, C) activations, so element-wise
        # dropout is the right layer here; Dropout3d is specified for 4-D/5-D
        # inputs only.
        self.dp0 = nn.Dropout(p=0.5)
        self.fc = nn.Linear(hidden_feature_num,self.num_class)

    def _downsample_basic_block(self,x,out_num_channel,stride):
        """Parameter-free (type 'A') shortcut.

        Subsamples ``x`` spatially with a kernel-size-1 average pool at the
        given stride, then appends zero-filled channels so the result has
        ``out_num_channel`` channels.

        Args:
            x: 5-D tensor; dims 2..4 are pooled and dim 1 is padded.
            out_num_channel: channel count of the returned tensor.
            stride: stride passed to ``F.avg_pool3d``.

        Returns:
            Tensor with ``out_num_channel`` channels on the same device and
            with the same dtype as ``x``; gradients flow back to ``x``.
        """
        out = F.avg_pool3d(x,kernel_size=1,stride=stride)
        # new_zeros inherits dtype and device from `out`, so no explicit
        # CUDA/float special-casing is needed.
        zero_pads = out.new_zeros(out.size(0),out_num_channel-out.size(1),
                                  out.size(2),out.size(3),out.size(4))
        # Concatenate the tensor itself (not .data) to keep the autograd graph intact.
        return torch.cat([out,zero_pads],dim=1)

    def _make_layer(self,block_func,planes,repeat_time,shortcut_type,stride=1):
        """Build one residual stage as an ``nn.Sequential`` of ``repeat_time`` blocks.

        The first block receives ``stride`` and, when the stride is not 1 or
        the channel count changes, a shortcut that matches the block output.
        Remaining blocks use stride 1 and no explicit shortcut.

        Side effect: ``self.block_in_channels`` is set to
        ``planes * block_func.expansion`` for the next stage.

        Args:
            block_func: block class (see the class docstring).
            planes: base channel count of the stage.
            repeat_time: number of blocks in the stage.
            shortcut_type: 'A' for the pooled/zero-padded shortcut, anything
                else for a 1x1x1 conv followed by ``BNReLU``.
            stride: stride of the first block and of the shortcut.

        Returns:
            ``nn.Sequential`` containing the stage's blocks.
        """
        from functools import partial

        out_channels = planes*block_func.expansion
        downsample = None
        if stride!=1 or self.block_in_channels!=out_channels:
            if shortcut_type=='A':
                # Bind by keyword so that the later call downsample(x) supplies
                # x as the first positional argument.
                downsample = partial(self._downsample_basic_block,
                                     out_num_channel=out_channels,stride=stride)
            else:
                downsample = nn.Sequential(
                    nn.Conv3d(in_channels=self.block_in_channels,out_channels=out_channels,
                              kernel_size=1,stride=stride),
                    BNReLU(negative_slop=self.negative_slop,norm_choice=self.norm_choice,
                           num_features=out_channels))

        layers = [block_func(negative_slop=self.negative_slop,norm_choice=self.norm_choice,
                             num_in_features=self.block_in_channels,num_out_features=planes,
                             kernel_size=self.kernel_size,downsample=downsample,stride=stride,
                             conv_choice=self.conv_choice)]
        self.block_in_channels = out_channels
        for _ in range(1,repeat_time):
            layers.append(block_func(negative_slop=self.negative_slop,norm_choice=self.norm_choice,
                                     num_in_features=self.block_in_channels,num_out_features=planes,
                                     kernel_size=self.kernel_size,stride=1,
                                     conv_choice=self.conv_choice))
        return nn.Sequential(*layers)

    def forward(self,x,**kwargs):
        """Run the network and return the head output.

        Args:
            x: input batch; presumably (N, 1, D, H, W) since the stem is built
                with one input channel -- confirm against ``ConvFunc3D``.
            **kwargs: accepted for call-site compatibility (e.g. ``labels``)
                and ignored.

        Returns:
            Tensor of shape (N, num_class).
        """
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.res_layer_func(x)
        x = self.avgpool(x)
        # Flatten the pooled (N, C, 1, 1, 1) features to (N, C) for the linear head.
        x = x.view(x.size(0), -1)
        x = self.fc0(x)
        x = self.dp0(x)
        x = self.fc(x)
        return x