import torch
import torch.nn as nn
import torch.nn.functional as F
from collections import OrderedDict


def weight_init(m):
    if isinstance(m, nn.Linear):
        nn.init.xavier_normal_(m.weight)
        if m.bias is not None:
            nn.init.constant_(m.bias, 0)
    elif isinstance(m, nn.Conv3d):
        nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
    elif isinstance(m, nn.BatchNorm3d):
        nn.init.constant_(m.weight, 1)
        nn.init.constant_(m.bias, 0)


class ConvStage(nn.Module):
    """Stack of `repeat_times` ConvBlocks; the first block maps
    num_in_features -> num_out_features, the rest keep the channel count."""

    def __init__(self, **kwargs):
        super().__init__()
        self.negative_slope = kwargs.get('negative_slope', 0)
        self.norm_choice = kwargs.get('norm_choice', 'BN')
        self.num_in_features = kwargs.get('num_in_features')
        self.num_out_features = kwargs.get('num_out_features')
        self.kernel_size = kwargs.get('kernel_size', 3)
        self.repeat_times = kwargs.get('repeat_times', 3)
        blocks = OrderedDict()
        for idx in range(self.repeat_times):
            num_in = self.num_in_features if idx == 0 else self.num_out_features
            blocks['convBlock_%d' % idx] = ConvBlock(
                negative_slope=self.negative_slope,
                norm_choice=self.norm_choice,
                num_in_features=num_in,
                num_out_features=self.num_out_features,
                kernel_size=self.kernel_size)
        self.convStage_func = nn.Sequential(blocks)

    def forward(self, x):
        return self.convStage_func(x)


class ConvBlock(nn.Module):
    """3D convolution with 'same' padding, followed by norm + activation."""

    def __init__(self, **kwargs):
        super().__init__()
        self.negative_slope = kwargs.get('negative_slope', 0)
        self.norm_choice = kwargs.get('norm_choice', 'BN')
        self.num_in_features = kwargs.get('num_in_features')
        self.num_out_features = kwargs.get('num_out_features')
        self.kernel_size = kwargs.get('kernel_size', 3)
        self.conv = nn.Conv3d(in_channels=self.num_in_features,
                              out_channels=self.num_out_features,
                              kernel_size=self.kernel_size,
                              padding=self.kernel_size // 2)
        self.post_func = BNReLU(negative_slope=self.negative_slope,
                                norm_choice=self.norm_choice,
                                num_features=self.num_out_features)

    def forward(self, x):
        return self.post_func(self.conv(x))


class BNReLU(nn.Module):
    """Normalization (BatchNorm or InstanceNorm) followed by ReLU/LeakyReLU."""

    def __init__(self, **kwargs):
        super().__init__()
        self.negative_slope = kwargs.get('negative_slope', 0)
        self.norm_choice = kwargs.get('norm_choice', 'BN')
        self.num_features = kwargs.get('num_features')
        if self.negative_slope <= 0:
            self.relu = nn.ReLU()
        else:
            self.relu = nn.LeakyReLU(negative_slope=self.negative_slope)
        if self.norm_choice == 'IN':
            self.norm = nn.InstanceNorm3d(num_features=self.num_features)
        else:
            self.norm = nn.BatchNorm3d(num_features=self.num_features)

    def forward(self, x):
        return self.relu(self.norm(x))
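
# A minimal smoke test for the blocks above, sketched as a helper function.
# The shapes and hyperparameters are illustrative assumptions, not values
# from the original code. Because ConvBlock uses 'same' padding, ConvStage
# preserves spatial size and only changes the channel count.
def _demo_conv_stage():
    stage = ConvStage(num_in_features=1, num_out_features=16,
                      kernel_size=3, repeat_times=2, norm_choice='IN')
    x = torch.randn(2, 1, 16, 32, 32)  # (batch, channel, depth, height, width)
    out = stage(x)
    assert out.shape == (2, 16, 16, 32, 32)
    return out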
class ConvFunc3D(nn.Module):
    """3D convolution that is either a plain Conv3d ('basic', or any 1x1
    conv) or a depthwise-separable pair: a grouped spatial conv followed
    by a 1x1 pointwise conv."""

    def __init__(self, **kwargs):
        super().__init__()
        self.num_in_features = kwargs.get('num_in_features')
        self.num_out_features = kwargs.get('num_out_features')
        self.stride = kwargs.get('stride', 1)
        self.conv_choice = kwargs.get('conv_choice', 'basic')
        self.kernel_size = kwargs.get('kernel_size', 1)
        if self.conv_choice == 'basic' or self.kernel_size == 1:
            if self.kernel_size == 1:
                self.conv = nn.Conv3d(in_channels=self.num_in_features,
                                      out_channels=self.num_out_features,
                                      kernel_size=self.kernel_size,
                                      stride=self.stride)
            else:
                self.conv = nn.Conv3d(in_channels=self.num_in_features,
                                      out_channels=self.num_out_features,
                                      kernel_size=self.kernel_size,
                                      stride=self.stride,
                                      padding=self.kernel_size // 2)
        else:
            # Depthwise-separable variant: per-channel spatial conv, then 1x1.
            self.conv = nn.Sequential(
                nn.Conv3d(in_channels=self.num_in_features,
                          out_channels=self.num_in_features,
                          kernel_size=self.kernel_size,
                          stride=self.stride,
                          padding=self.kernel_size // 2,
                          groups=self.num_in_features),
                nn.Conv3d(in_channels=self.num_in_features,
                          out_channels=self.num_out_features,
                          kernel_size=1))

    def forward(self, x):
        return self.conv(x)


class BasicBlock(nn.Module):
    """Bottleneck-style residual block: 1x1 -> kxk (strided) -> 1x1 convs,
    each followed by BNReLU, with an optional downsample path for the
    residual."""

    expansion = 1  # set to 4 for a ResNet-style bottleneck

    def __init__(self, **kwargs):
        super().__init__()
        self.negative_slope = kwargs.get('negative_slope', 0)
        self.norm_choice = kwargs.get('norm_choice', 'BN')
        self.num_in_features = kwargs.get('num_in_features')
        self.num_out_features = kwargs.get('num_out_features')
        self.kernel_size = kwargs.get('kernel_size', 3)
        self.stride = kwargs.get('stride', 1)
        self.downsample = kwargs.get('downsample', None)
        self.conv_choice = kwargs.get('conv_choice', 'basic')
        self.conv1 = ConvFunc3D(num_in_features=self.num_in_features,
                                num_out_features=self.num_out_features,
                                conv_choice=self.conv_choice)
        self.bn1 = BNReLU(negative_slope=self.negative_slope,
                          norm_choice=self.norm_choice,
                          num_features=self.num_out_features)
        self.conv2 = ConvFunc3D(num_in_features=self.num_out_features,
                                num_out_features=self.num_out_features,
                                stride=self.stride,
                                kernel_size=self.kernel_size,
                                conv_choice=self.conv_choice)
        self.bn2 = BNReLU(negative_slope=self.negative_slope,
                          norm_choice=self.norm_choice,
                          num_features=self.num_out_features)
        self.conv3 = ConvFunc3D(num_in_features=self.num_out_features,
                                num_out_features=self.num_out_features * self.expansion,
                                conv_choice=self.conv_choice)
        self.bn3 = BNReLU(negative_slope=self.negative_slope,
                          norm_choice=self.norm_choice,
                          num_features=self.num_out_features * self.expansion)
        # Note: BNReLU already ends in an activation, so this re-activates.
        self.relu = nn.LeakyReLU(negative_slope=self.negative_slope)

    def forward(self, x):
        residual = x
        out = self.relu(self.bn1(self.conv1(x)))
        out = self.relu(self.bn2(self.conv2(out)))
        out = self.bn3(self.conv3(out))
        if self.downsample is not None:
            residual = self.downsample(x)
        out += residual
        return self.relu(out)


class DarkNetShortCut(nn.Module):
    """DarkNet-style shortcut: add two feature maps, overlapping on the
    smaller channel count when their channel dimensions differ."""

    def __init__(self, **kwargs):
        super().__init__()

    def forward(self, x, y):
        if x.shape[1] == y.shape[1]:
            return x + y
        elif x.shape[1] < y.shape[1]:
            # Add x onto the first x.shape[1] channels of y.
            y[:, :x.shape[1]] += x
            return y
        else:
            # Add y onto the first y.shape[1] channels of x.
            x[:, :y.shape[1]] += y
            return x
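
# Illustrative parameter-count comparison under assumed channel sizes. Any
# conv_choice other than 'basic' selects the depthwise-separable branch of
# ConvFunc3D; the string 'separable' below is just a placeholder label, not
# a value defined by the original code.
def _demo_conv_func3d_params():
    basic = ConvFunc3D(num_in_features=32, num_out_features=64,
                       kernel_size=3, conv_choice='basic')
    separable = ConvFunc3D(num_in_features=32, num_out_features=64,
                           kernel_size=3, conv_choice='separable')

    def n_params(m):
        return sum(p.numel() for p in m.parameters())

    # Depthwise + pointwise needs ~18x fewer weights here: 3008 vs. 55360.
    return n_params(basic), n_params(separable)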
class SCIModule(nn.Module):
    """Self-channel-interaction module: re-weights channels with a
    softmax-normalized channel-affinity matrix before convolving."""

    def __init__(self, **kwargs):
        super().__init__()
        n_input_channel = kwargs.get('n_input_channel')
        n_output_channel = kwargs.get('n_output_channel')
        self.conv_choice = kwargs.get('conv_choice')
        self.kernel_size = kwargs.get('kernel_size', 3)
        self.softmax = nn.Softmax(dim=2)
        self.conv_func = ConvBlock(num_in_features=n_input_channel,
                                   num_out_features=n_output_channel,
                                   negative_slope=0.05,
                                   kernel_size=self.kernel_size)
        self.conv_func1x1 = ConvBlock(num_in_features=n_input_channel,
                                      num_out_features=n_output_channel,
                                      negative_slope=0.05,
                                      kernel_size=1)

    def forward(self, x):
        """1. Flatten the spatial dims; 2. compute the negated channel Gram
        matrix; 3. softmax-normalize it and re-weight the flattened features;
        4. convolve, add the residual, and fuse with a 1x1 conv."""
        batch_size, channel, depth, height, width = x.shape
        x0 = x.view((batch_size, channel, -1))
        x_t = torch.transpose(x0, 1, 2)
        # Batched matmul replaces the original per-sample loop and stays on
        # x's device (the original allocated on CPU and called .cuda(),
        # which fails on CPU-only runs). The minus sign makes the softmax
        # emphasize the least-correlated channels.
        weights = -torch.bmm(x0, x_t)
        norm_weights = self.softmax(weights)
        weighted_feature = torch.bmm(norm_weights, x0)
        weighted_feature = weighted_feature.view(
            (batch_size, channel, depth, height, width))
        output = self.conv_func(weighted_feature)
        final_output = self.conv_func1x1(output + x)
        return final_output
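
# Hypothetical usage of SCIModule with assumed shapes. Because forward
# computes conv_func1x1(output + x), the module implicitly requires
# n_input_channel == n_output_channel for the residual add to broadcast.
def _demo_sci_module():
    sci = SCIModule(n_input_channel=8, n_output_channel=8, kernel_size=3)
    x = torch.randn(2, 8, 8, 16, 16)
    out = sci(x)
    assert out.shape == x.shape
    return out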