1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
from collections import OrderedDict
from functools import partial

import torch
import torch.nn as nn
import torch.nn.functional as F

from .BasicModules import *
class ResNet(nn.Module):
    """3D ResNet configured entirely through keyword arguments.

    Recognized keyword arguments (defaults in parentheses):
        block: residual block class. It must expose an ``expansion``
            attribute and accept the keyword arguments passed in
            ``_make_layer``.
        block_inplanes: base channel width of each stage (required).
        negative_slop (0): forwarded to ``BNReLU`` and to the blocks.
        norm_choice ('BN'): forwarded to ``BNReLU`` and to the blocks.
        kernel_size (3): kernel size for the stem conv and the blocks.
        num_stages (4): number of residual stages to build.
        pooling_ratios: per-stage stride of each stage's first block (required).
        repeat_time_list: per-stage number of blocks (required).
        dropout_rate (0): if > 0, an ``nn.Dropout3d`` follows every stage.
        num_class (1): output size of the final linear layer.
        shortcut_type ('B'): 'A' selects the subsample + zero-pad shortcut.
            Anything else selects a 1x1x1 conv + ``BNReLU`` shortcut.
        widen_factor (1.0): multiplier applied to ``block_inplanes``.
        conv_choice ('basic'): forwarded to ``ConvFunc3D``, ``SCIModule``
            and the blocks.
        SCI_choice (False): stored on the instance but not read by this class.
        n_input_channel (1): number of channels in the input volume.
    """

    def __init__(self, **kwargs):
        super().__init__()
        self.block = kwargs.get('block')
        self.block_inplanes = kwargs.get('block_inplanes')
        self.negative_slop = kwargs.get('negative_slop', 0)
        self.norm_choice = kwargs.get('norm_choice', 'BN')
        self.kernel_size = kwargs.get('kernel_size', 3)
        self.num_stages = kwargs.get('num_stages', 4)
        self.pooling_ratios = kwargs.get('pooling_ratios')
        self.repeat_time_list = kwargs.get('repeat_time_list')
        self.dropout_rate = kwargs.get('dropout_rate', 0)
        self.num_class = kwargs.get('num_class', 1)
        self.shortcut_type = kwargs.get('shortcut_type', 'B')
        self.widen_factor = kwargs.get('widen_factor', 1.0)
        self.conv_choice = kwargs.get('conv_choice', 'basic')
        self.SCI_choice = kwargs.get('SCI_choice', False)
        # Previously hard-coded to 1; still defaults to 1.
        n_input_channel = kwargs.get('n_input_channel', 1)

        self.block_inplanes = [int(val * self.widen_factor) for val in self.block_inplanes]
        # Running channel count. _make_layer reads it as the next stage's
        # input width and advances it after each stage.
        self.block_in_channels = self.block_inplanes[0]

        # Stem: one convolution followed by normalization/activation.
        self.conv1 = ConvFunc3D(num_in_features=n_input_channel,
                                num_out_features=self.block_in_channels,
                                conv_choice=self.conv_choice,
                                kernel_size=self.kernel_size)
        self.bn1 = BNReLU(negative_slop=self.negative_slop,
                          norm_choice=self.norm_choice,
                          num_features=self.block_in_channels)
        # NOTE(review): a stem max-pool
        # (nn.MaxPool3d(kernel_size, stride=pooling_ratios[0], padding=1))
        # was left disabled by the original author, who questioned whether it
        # is needed.

        stages = OrderedDict()
        for stage_idx in range(self.num_stages):
            stage_name = 'ResNet_%02d' % (stage_idx + 1)
            stages[stage_name] = self._make_layer(self.block,
                                                  self.block_inplanes[stage_idx],
                                                  self.repeat_time_list[stage_idx],
                                                  self.shortcut_type,
                                                  self.pooling_ratios[stage_idx])
            # Register dropout only when enabled. A None entry in nn.Sequential
            # would be called during forward() and raise TypeError.
            if self.dropout_rate > 0:
                stages[stage_name + '_dropout'] = nn.Dropout3d(p=self.dropout_rate)
        self.res_layer_func = nn.Sequential(stages)

        # Channel count leaving the last stage (planes * block.expansion),
        # exactly as tracked by _make_layer.
        final_feature_num = self.block_in_channels

        # NOTE(review): SCI is registered as a submodule, but forward() never
        # calls it and SCI_choice is never consulted. Confirm the intent.
        self.SCI = SCIModule(n_input_channel=final_feature_num,
                             n_output_channel=final_feature_num,
                             conv_choice=self.conv_choice)
        self.avgpool = nn.AdaptiveAvgPool3d((1, 1, 1))

        # TODO: revisit the classification head.
        hidden_features = final_feature_num // 4
        self.fc0 = nn.Linear(final_feature_num, hidden_features)
        # fc0's output is a 2D (batch, features) tensor, so use element-wise
        # Dropout. Dropout3d on 2D input drops entire samples in recent PyTorch.
        self.dp0 = nn.Dropout(p=0.5)
        self.fc = nn.Linear(hidden_features, self.num_class)

    def _downsample_basic_block(self, x, out_num_channel, stride):
        """Type-'A' shortcut: subsample spatially, then zero-pad channels.

        Args:
            x: rank-5 tensor (N, C, D, H, W).
            out_num_channel: channel count after padding; must be >= C.
            stride: spatial step of the kernel-size-1 average pooling.

        Returns:
            Tensor of shape (N, out_num_channel, D', H', W'). Channels past C
            are zeros.
        """
        out = F.avg_pool3d(x, kernel_size=1, stride=stride)
        # new_zeros inherits device and dtype from `out`. This replaces the old
        # torch.cuda.FloatTensor check, which missed half/double CUDA tensors.
        zero_pads = out.new_zeros(out.size(0), out_num_channel - out.size(1),
                                  out.size(2), out.size(3), out.size(4))
        # Concatenate `out` itself (not `out.data`) so gradients flow through
        # the shortcut.
        return torch.cat([out, zero_pads], dim=1)

    def _make_layer(self, block_func, planes, repeat_time, shortcut_type, stride=1):
        """Build one residual stage of ``repeat_time`` blocks.

        Only the first block receives ``stride`` and a shortcut. After the first
        block, ``self.block_in_channels`` is advanced to
        ``planes * block_func.expansion``.

        Args:
            block_func: block class exposing ``expansion``.
            planes: base width of this stage.
            repeat_time: number of blocks in the stage.
            shortcut_type: 'A' for the parameter-free zero-pad shortcut;
                anything else for a 1x1x1 conv + BNReLU projection.
            stride: stride of the first block (and of its shortcut).

        Returns:
            nn.Sequential containing the stage's blocks.
        """
        out_channels = planes * block_func.expansion
        downsample = None
        if stride != 1 or self.block_in_channels != out_channels:
            if shortcut_type == 'A':
                # Bind by keyword so a single positional argument (the input
                # tensor) fills `x`.
                downsample = partial(self._downsample_basic_block,
                                     out_num_channel=out_channels,
                                     stride=stride)
            else:
                downsample = nn.Sequential(
                    nn.Conv3d(in_channels=self.block_in_channels,
                              out_channels=out_channels,
                              kernel_size=1,
                              stride=stride),
                    BNReLU(negative_slop=self.negative_slop,
                           norm_choice=self.norm_choice,
                           num_features=out_channels))

        layers = [block_func(negative_slop=self.negative_slop,
                             norm_choice=self.norm_choice,
                             num_in_features=self.block_in_channels,
                             num_out_features=planes,
                             kernel_size=self.kernel_size,
                             downsample=downsample,
                             stride=stride,
                             conv_choice=self.conv_choice)]
        self.block_in_channels = out_channels
        for _ in range(1, repeat_time):
            layers.append(block_func(negative_slop=self.negative_slop,
                                     norm_choice=self.norm_choice,
                                     num_in_features=self.block_in_channels,
                                     num_out_features=planes,
                                     kernel_size=self.kernel_size,
                                     stride=1,
                                     conv_choice=self.conv_choice))
        return nn.Sequential(*layers)

    def forward(self, x, **kwargs):
        """Run the network and return the output of the final linear layer.

        Extra keyword arguments (e.g. ``labels``) are accepted for interface
        compatibility and ignored.
        """
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.res_layer_func(x)
        x = self.avgpool(x)
        x = x.view(x.size(0), -1)
        # NOTE(review): there is no nonlinearity between fc0 and fc, so apart
        # from dropout they compose to a single affine map. Confirm this is
        # intended.
        x = self.fc0(x)
        x = self.dp0(x)
        return self.fc(x)