from collections import OrderedDict

# Explicit torch imports so this module does not depend on the star import
# below to provide ``torch`` / ``nn``. They come first so that any names the
# star import re-exports still take precedence, exactly as before.
import torch
import torch.nn.functional as F
from torch import nn

from .BasicModules import *


class VGG3D(nn.Module):
    """VGG-style 3D convolutional network with up to three linear heads.

    The backbone is a sequence of ``num_stages`` stages. Each stage is a
    ``ConvStage`` (imported from ``BasicModules``), optionally followed by
    ``nn.Dropout3d``, then ``nn.MaxPool3d``. The backbone output is reduced
    with adaptive average pooling to a single value per channel, flattened
    and fed to linear heads.

    Keyword Args:
        negative_slop: Forwarded to ``ConvStage`` as ``negative_slope``
            (note the kwarg spelling). Default 0.
        norm_choice: Forwarded to ``ConvStage``. Default ``'BN'``.
        base_filters: Output channels of the first stage; stage ``i``
            produces ``base_filters * 2**i`` channels. Default 16.
        kernel_size: Forwarded to ``ConvStage``. Default 3.
        num_stages: Number of stages. Default 4.
        pooling_ratios: Required. Sequence with at least ``num_stages``
            entries; entry ``i`` is the ``MaxPool3d`` kernel size of stage
            ``i``.
        repeat_time_list: Required. Sequence with at least ``num_stages``
            entries; entry ``i`` is forwarded to ``ConvStage`` as
            ``repeat_times``.
        dropout_rate: If > 0, a ``Dropout3d`` with this probability is
            inserted after every ``ConvStage``. Default 0.
        num_class: Output features of every linear head. Default 1.
        num_dense: Stored on the instance; not used by this class.
            Default 256.
        num_task: Number of outputs returned by ``forward``. Default 1.

    Raises:
        TypeError: If ``pooling_ratios`` or ``repeat_time_list`` is missing.
        AssertionError: If either sequence has fewer than ``num_stages``
            entries.
    """

    def __init__(self, **kwargs):
        super(VGG3D, self).__init__()
        self.negative_slop = kwargs.get('negative_slop', 0)
        self.norm_choice = kwargs.get('norm_choice', 'BN')
        self.base_filters = kwargs.get('base_filters', 16)
        self.kernel_size = kwargs.get('kernel_size', 3)
        self.num_stages = kwargs.get('num_stages', 4)
        self.pooling_ratios = kwargs.get('pooling_ratios')
        self.repeat_time_list = kwargs.get('repeat_time_list')
        self.dropout_rate = kwargs.get('dropout_rate', 0)
        self.num_class = kwargs.get('num_class', 1)
        self.num_dense = kwargs.get('num_dense', 256)
        self.num_task = kwargs.get('num_task', 1)

        # Only indices 0..num_stages-1 are read below, so a sequence of
        # exactly ``num_stages`` entries is sufficient (hence ``>=``).
        for name in ('pooling_ratios', 'repeat_time_list'):
            per_stage = getattr(self, name)
            if per_stage is None:
                raise TypeError(
                    "VGG3D requires the '%s' keyword argument" % name)
            assert len(per_stage) >= self.num_stages, (
                "'%s' needs at least num_stages=%d entries, got %d"
                % (name, self.num_stages, len(per_stage)))

        self.func_dict = {}
        final_features = 0
        size_after_gap = 1
        for idx in range(self.num_stages):
            # The first stage takes a single input channel; later stages
            # take the previous stage's output channels.
            in_filters = self.base_filters * (2 ** (idx - 1)) if idx > 0 else 1
            out_filters = self.base_filters * (2 ** idx)
            final_features = out_filters

            stage_no = idx + 1
            self.func_dict['convStage_%02d_BlockFunc' % stage_no] = ConvStage(
                negative_slope=self.negative_slop,
                norm_choice=self.norm_choice,
                num_in_features=in_filters,
                num_out_features=out_filters,
                kernel_size=self.kernel_size,
                repeat_times=self.repeat_time_list[idx])
            if self.dropout_rate > 0:
                self.func_dict['convStage_%02d_Dropout' % stage_no] = \
                    nn.Dropout3d(p=self.dropout_rate)
            self.func_dict['convStage_%02d_MP' % stage_no] = nn.MaxPool3d(
                kernel_size=self.pooling_ratios[idx])

        # Sub-module names come from the dict keys (insertion order), so they
        # must stay stable for existing checkpoints to load.
        self.vgg_basic_func = nn.Sequential(OrderedDict(self.func_dict))

        # Despite the name, fc1 is the global average pooling layer.
        self.fc1 = nn.AdaptiveAvgPool3d(output_size=size_after_gap)
        self.fc2 = torch.nn.Linear(in_features=final_features,
                                   out_features=self.num_class)
        self.fc3 = torch.nn.Linear(in_features=final_features,
                                   out_features=self.num_class)
        self.fc4 = torch.nn.Linear(in_features=final_features,
                                   out_features=self.num_class)

    def forward(self, x):
        """Run the backbone and the linear head(s).

        Args:
            x: Input tensor for the backbone. Presumably
                ``(batch, 1, D, H, W)`` since the first stage is built with
                one input feature -- confirm against ``ConvStage``.

        Returns:
            If ``num_task > 1``: a list of ``num_task`` tensors. Task 0 uses
            ``fc2``, task 1 uses ``fc3`` and every further task uses ``fc4``
            (tasks with index >= 2 therefore share one head).
            Otherwise: a single tensor produced by ``fc2``.
            No final activation is applied to any output.
        """
        features = self.vgg_basic_func(x)
        features = self.fc1(features)
        # Flatten everything after the batch dimension.
        features = features.view(features.size(0), -1)

        if self.num_task > 1:
            outputs = []
            for idx in range(self.num_task):
                if idx == 0:
                    head = self.fc2
                elif idx == 1:
                    head = self.fc3
                else:
                    head = self.fc4
                outputs.append(head(features))
            return outputs

        return self.fc2(features)


class TempModel(nn.Module):
    """Minimal module wrapping one ``Conv3d`` (1 -> 16 channels, kernel 3).

    Keyword arguments are accepted and ignored.
    """

    def __init__(self, **kwargs):
        super(TempModel, self).__init__()
        self.conv = nn.Conv3d(in_channels=1, out_channels=16, kernel_size=3)

    def forward(self, x):
        """Apply the convolution to ``x`` and return the result."""
        return self.conv(x)