import torch
import torch.nn as nn
import torch.nn.functional as F
from collections import OrderedDict


def weight_init(m):
    """Initialize module weights by layer type."""
    if isinstance(m, nn.Linear):
        nn.init.xavier_normal_(m.weight)
        nn.init.constant_(m.bias, 0)
    elif isinstance(m, nn.Conv3d):
        nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
    elif isinstance(m, nn.BatchNorm3d):
        nn.init.constant_(m.weight, 1)
        nn.init.constant_(m.bias, 0)


class ConvStage(nn.Module):
    """A stack of `repeat_times` ConvBlocks; only the first block changes the
    channel count (and optionally applies the stride)."""

    def __init__(self, **kwargs):
        super().__init__()
        self.negative_slope = kwargs.get('negative_slope', 0)
        self.norm_choice = kwargs.get('norm_choice', 'BN')
        self.num_in_features = kwargs.get('num_in_features')
        self.num_out_features = kwargs.get('num_out_features')
        self.kernel_size = kwargs.get('kernel_size', 3)
        self.repeat_times = kwargs.get('repeat_times', 3)
        padding_func = kwargs.get('padding_func', 'zero')
        stride = kwargs.get('stride', 1)

        blocks = OrderedDict()
        for idx in range(self.repeat_times):
            if idx == 0:
                # First block maps in_features -> out_features and applies the stride.
                current_func = ConvBlock(negative_slope=self.negative_slope,
                                         norm_choice=self.norm_choice,
                                         num_in_features=self.num_in_features,
                                         num_out_features=self.num_out_features,
                                         kernel_size=self.kernel_size,
                                         padding_func=padding_func,
                                         stride=stride)
            else:
                # Remaining blocks keep the channel count and spatial size.
                current_func = ConvBlock(negative_slope=self.negative_slope,
                                         norm_choice=self.norm_choice,
                                         num_in_features=self.num_out_features,
                                         num_out_features=self.num_out_features,
                                         kernel_size=self.kernel_size,
                                         padding_func=padding_func)
            blocks['convBlock_%d' % idx] = current_func
        self.convStage_func = nn.Sequential(blocks)

    def forward(self, x):
        return self.convStage_func(x)


class ConvBlock(nn.Module):
    """Conv3d followed by normalization + activation (BNReLU)."""

    def __init__(self, **kwargs):
        super().__init__()
        self.negative_slope = kwargs.get('negative_slope', 0)
        self.norm_choice = kwargs.get('norm_choice', 'BN')
        self.num_in_features = kwargs.get('num_in_features')
        self.num_out_features = kwargs.get('num_out_features')
        self.kernel_size = kwargs.get('kernel_size', 3)
        atrous_rate = kwargs.get('atrous_rate', 1)
        stride = kwargs.get('stride', 1)
        self.atrous_rate, self.stride = atrous_rate, stride
        padding_func = kwargs.get('padding_func', 'zero')

        if padding_func == 'zero':
            # "Same" padding for odd kernel sizes, scaled by the dilation rate.
            padding_size = [int(self.kernel_size / 2) * self.atrous_rate for _ in range(3)]
        else:
            # Non-zero padding modes are not implemented (an earlier revision
            # sketched nn.ReplicationPad3d here); fall back to no padding.
            padding_size = [0 for _ in range(3)]
        self.conv = nn.Conv3d(in_channels=self.num_in_features,
                              out_channels=self.num_out_features,
                              kernel_size=self.kernel_size,
                              padding=padding_size,
                              dilation=atrous_rate,
                              stride=stride)
        self.post_func = BNReLU(negative_slope=self.negative_slope,
                                norm_choice=self.norm_choice,
                                num_features=self.num_out_features)

    def forward(self, x):
        output = self.conv(x)
        output = self.post_func(output)
        return output


class BNReLU(nn.Module):
    """Normalization ('BN', 'IN', or none) followed by (Leaky)ReLU."""

    def __init__(self, **kwargs):
        super().__init__()
        self.negative_slope = kwargs.get('negative_slope', 0)
        self.norm_choice = kwargs.get('norm_choice', 'BN')
        self.num_features = kwargs.get('num_features')
        if self.negative_slope <= 0:
            self.relu = nn.ReLU()
        else:
            self.relu = nn.LeakyReLU(negative_slope=self.negative_slope)
        # momentum=1.0 makes the running statistics track the most recent batch.
        if self.norm_choice == 'IN':
            self.norm = nn.InstanceNorm3d(num_features=self.num_features, momentum=1.0)
        elif self.norm_choice == 'BN':
            self.norm = nn.BatchNorm3d(num_features=self.num_features, momentum=1.0)
        else:
            self.norm = None

    def forward(self, x):
        if self.norm is not None:
            x = self.norm(x)
        x = self.relu(x)
        return x


class ConvFunc3D(nn.Module):
    """Plain Conv3d, or a depthwise-separable pair when conv_choice != 'basic'."""

    def __init__(self, **kwargs):
        super().__init__()
        self.num_in_features = kwargs.get('num_in_features')
        self.num_out_features = kwargs.get('num_out_features')
        self.stride = kwargs.get('stride', 1)
        self.conv_choice = kwargs.get('conv_choice', 'basic')
        self.kernel_size = kwargs.get('kernel_size', 1)
        if self.conv_choice == 'basic' or self.kernel_size == 1:
            if self.kernel_size == 1:
                # 1x1x1 convolutions need no padding.
                self.conv = nn.Conv3d(in_channels=self.num_in_features,
                                      out_channels=self.num_out_features,
                                      kernel_size=self.kernel_size,
                                      stride=self.stride)
            else:
                self.conv = nn.Conv3d(in_channels=self.num_in_features,
                                      out_channels=self.num_out_features,
                                      kernel_size=self.kernel_size,
                                      stride=self.stride,
                                      padding=[int(self.kernel_size / 2) for _ in range(3)])
        else:
            # Depthwise convolution followed by a pointwise (1x1x1) convolution.
            self.conv = nn.Sequential(
                nn.Conv3d(in_channels=self.num_in_features,
                          out_channels=self.num_in_features,
                          kernel_size=self.kernel_size,
                          stride=self.stride,
                          padding=[int(self.kernel_size / 2) for _ in range(3)],
                          groups=self.num_in_features),
                nn.Conv3d(in_channels=self.num_in_features,
                          out_channels=self.num_out_features,
                          kernel_size=1)
            )

    def forward(self, x):
        return self.conv(x)


class BasicBlock(nn.Module):
    """Residual block: 1x1x1 -> kxkxk -> 1x1x1 convolutions with a shortcut."""

    expansion = 1

    def __init__(self, **kwargs):
        super().__init__()
        self.negative_slope = kwargs.get('negative_slope', 0)
        self.norm_choice = kwargs.get('norm_choice', 'BN')
        self.num_in_features = kwargs.get('num_in_features')
        self.num_out_features = kwargs.get('num_out_features')
        self.kernel_size = kwargs.get('kernel_size', 3)
        self.stride = kwargs.get('stride', 1)
        self.downsample = kwargs.get('downsample', None)
        self.conv_choice = kwargs.get('conv_choice', 'basic')
        padding_func = kwargs.get('padding_func', 'zero')

        self.conv1 = ConvFunc3D(num_in_features=self.num_in_features,
                                num_out_features=self.num_out_features,
                                conv_choice=self.conv_choice,
                                padding_func=padding_func)
        self.bn1 = BNReLU(negative_slope=self.negative_slope, norm_choice=self.norm_choice,
                          num_features=self.num_out_features)
        self.conv2 = ConvFunc3D(num_in_features=self.num_out_features,
                                num_out_features=self.num_out_features,
                                stride=self.stride, kernel_size=self.kernel_size,
                                conv_choice=self.conv_choice,
                                padding_func=padding_func)
        self.bn2 = BNReLU(negative_slope=self.negative_slope, norm_choice=self.norm_choice,
                          num_features=self.num_out_features)
        self.conv3 = ConvFunc3D(num_in_features=self.num_out_features,
                                num_out_features=self.num_out_features * self.expansion,
                                conv_choice=self.conv_choice,
                                padding_func=padding_func)
        self.bn3 = BNReLU(negative_slope=self.negative_slope, norm_choice=self.norm_choice,
                          num_features=self.num_out_features * self.expansion)
        # BNReLU already applies an activation, so this extra activation is
        # redundant for plain ReLU (ReLU is idempotent); it also serves the
        # activation after the residual sum.
        self.relu = nn.LeakyReLU(negative_slope=self.negative_slope)

    def forward(self, x):
        residual = x
        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)
        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu(out)
        out = self.conv3(out)
        out = self.bn3(out)
        if self.downsample is not None:
            residual = self.downsample(x)
        out += residual
        out = self.relu(out)
        return out


class DarkNetShortCut(nn.Module):
    """Elementwise shortcut that tolerates mismatched channel counts."""

    def __init__(self, **kwargs):
        super().__init__()

    def forward(self, x, y):
        if x.shape[1] == y.shape[1]:
            return x + y
        # NOTE: the original source is truncated at this branch; the code below
        # is an assumed completion that adds over the overlapping leading channels.
        elif x.shape[1] < y.shape[1]:
            out = y.clone()
            out[:, :x.shape[1]] = out[:, :x.shape[1]] + x
            return out
        else:
            out = x.clone()
            out[:, :y.shape[1]] = out[:, :y.shape[1]] + y
            return out
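

# --- Usage sketch (an illustrative addition, not part of the original file) ---
# A minimal smoke test for the blocks above, assuming 5D inputs of shape
# (batch, channels, D, H, W). All shapes and hyperparameters below are
# arbitrary example values.
if __name__ == '__main__':
    x = torch.randn(2, 4, 8, 16, 16)

    # Stack of two ConvBlocks: 4 -> 8 channels, spatial size preserved.
    stage = ConvStage(num_in_features=4, num_out_features=8, repeat_times=2)
    stage.apply(weight_init)
    feat = stage(x)
    print('ConvStage :', tuple(feat.shape))         # expected: (2, 8, 8, 16, 16)

    # Residual block with matching in/out channels, so no downsample is needed.
    block = BasicBlock(num_in_features=8, num_out_features=8)
    print('BasicBlock:', tuple(block(feat).shape))  # expected: (2, 8, 8, 16, 16)

    # Shortcut with mismatched channels (4 vs. 8) exercises the assumed
    # overlapping-channel branch of DarkNetShortCut.
    shortcut = DarkNetShortCut()
    y = torch.randn(2, 4, 8, 16, 16)
    print('ShortCut  :', tuple(shortcut(y, feat).shape))  # expected: (2, 8, 8, 16, 16)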