import math

import torch
import torch.nn as nn
import torch.nn.functional as F


def init_modules(model):
    """Initialize every weight with Kaiming-normal and every bias with 0.04.

    ``model.named_parameters()`` only ever yields ``nn.Parameter`` objects,
    so the per-parameter ``isinstance`` checks of the original were redundant
    and have been dropped.
    """
    for name, param in model.named_parameters():
        if 'weight' in name:
            nn.init.kaiming_normal_(param.data, mode='fan_out', nonlinearity='relu')
        elif 'bias' in name:
            nn.init.constant_(param.data, 0.04)


def _inter_planes_of(out_planes, minimum):
    """Number of per-branch channels: ceil(out_planes / 8), at least `minimum`."""
    # Equivalent to the original int(torch.ceil(torch.tensor(out_planes / 8)))
    # without the pointless tensor round-trip.
    return max(math.ceil(out_planes / 8), minimum)


def _grouped_concat(branch_outs, groups, group_num):
    """Concatenate branch outputs along channels, keeping each group's channels adjacent.

    With ``groups == 1`` this is a plain channel concat.  Otherwise, for each
    group g the slice ``[g*group_num:(g+1)*group_num]`` of every branch is
    concatenated first, so grouped convolutions downstream see matching groups.
    """
    if groups == 1:
        return torch.cat(branch_outs, 1)
    group_outs = []
    for g in range(groups):
        sl = slice(g * group_num, (g + 1) * group_num)
        group_outs.append(torch.cat([b[:, sl] for b in branch_outs], 1))
    return torch.cat(group_outs, 1)


def create_conv_block_3d(in_planes, out_planes, kernel_size=1, stride=1, padding=0,
                         dilation=1, groups=1, bias=True, bn=True, activation=True):
    """Conv3d -> optional InstanceNorm3d -> optional LeakyReLU, as a Sequential.

    Note: despite the parameter name, ``bn`` toggles InstanceNorm3d (no affine,
    no running stats by default), not BatchNorm.
    """
    layers = [nn.Conv3d(in_planes, out_planes, kernel_size=kernel_size, stride=stride,
                        padding=padding, dilation=dilation, groups=groups, bias=bias)]
    if bn:
        layers.append(nn.InstanceNorm3d(out_planes))
    if activation:
        layers.append(nn.LeakyReLU(inplace=True))
    return nn.Sequential(*layers)


def create_conv_block_k1_3d(in_planes, out_planes, kernel_size=1, stride=1, padding=0,
                            dilation=1, groups=1, bias=False, bn=True, activation=True):
    """1x1x1 conv block (bias off by default)."""
    return create_conv_block_3d(in_planes, out_planes, kernel_size=kernel_size,
                                stride=stride, padding=padding, dilation=dilation,
                                groups=groups, bias=bias, bn=bn, activation=activation)


def create_conv_block_k2_3d(in_planes, out_planes, kernel_size=2, stride=1, padding=0,
                            dilation=1, groups=1, bias=False, bn=True, activation=True):
    """2x2x2 conv block (bias off by default)."""
    return create_conv_block_3d(in_planes, out_planes, kernel_size=kernel_size,
                                stride=stride, padding=padding, dilation=dilation,
                                groups=groups, bias=bias, bn=bn, activation=activation)


def create_conv_block_k3_3d(in_planes, out_planes, kernel_size=3, stride=1, padding=1,
                            dilation=1, groups=1, bias=False, bn=True, activation=True):
    """3x3x3 'same' conv block (padding=1, bias off by default)."""
    return create_conv_block_3d(in_planes, out_planes, kernel_size=kernel_size,
                                stride=stride, padding=padding, dilation=dilation,
                                groups=groups, bias=bias, bn=bn, activation=activation)


class RfbfBlock3d(nn.Module):
    """Receptive-field-block "first" layer (3D): 8 parallel multi-scale branches.

    Input is first gated by a squeeze-excite style channel attention, then fed
    through 8 two-conv branches with different kernel sizes / dilations; the
    branch outputs (inter_planes channels each) are concatenated to
    8 * inter_planes ~= out_planes channels.
    """

    def __init__(self, in_planes, out_planes, stride=1, groups=1, droprate=0):
        super(RfbfBlock3d, self).__init__()
        inter_planes = _inter_planes_of(out_planes, 1)
        self.groups = groups
        self.group_num = inter_planes // groups
        self.droprate = droprate
        self.rotation_attn = nn.Sequential(
            nn.AdaptiveAvgPool3d(1),
            nn.Conv3d(in_planes, max(inter_planes // 2, 1), 1),
            nn.LeakyReLU(inplace=True),
            # BUGFIX: the gate multiplies the *input* (in_planes channels), so it
            # must emit in_planes values per sample.  The original emitted
            # inter_planes channels, which only broadcast correctly when
            # inter_planes == in_planes or either was 1 (the case used by the
            # FPNs below, so their behavior is unchanged).
            nn.Conv3d(max(inter_planes // 2, 1), in_planes, 1),
            nn.Sigmoid()
        )
        self.branch1 = nn.Sequential(
            create_conv_block_3d(in_planes, inter_planes, kernel_size=(3, 3, 3),
                                 padding=(1, 1, 1), stride=stride, groups=groups),
            create_conv_block_3d(inter_planes, inter_planes, kernel_size=(3, 3, 3),
                                 padding=(1, 1, 1), dilation=(1, 1, 1), groups=groups)
        )
        self.branch2 = nn.Sequential(
            create_conv_block_3d(in_planes, inter_planes, kernel_size=(3, 5, 5),
                                 padding=(1, 2, 2), stride=stride, groups=groups),
            create_conv_block_3d(inter_planes, inter_planes, kernel_size=(3, 5, 5),
                                 padding=(2, 4, 4), dilation=(2, 2, 2), groups=groups)
        )
        self.branch3 = nn.Sequential(
            create_conv_block_3d(in_planes, inter_planes, kernel_size=(3, 7, 7),
                                 padding=(1, 3, 3), stride=stride, groups=groups),
            create_conv_block_3d(inter_planes, inter_planes, kernel_size=(3, 7, 7),
                                 padding=(3, 3, 3), dilation=(3, 1, 1), groups=groups)
        )
        self.branch4 = nn.Sequential(
            create_conv_block_3d(in_planes, inter_planes, kernel_size=(3, 5, 5),
                                 padding=(1, 2, 2), stride=stride, groups=groups),
            create_conv_block_3d(inter_planes, inter_planes, kernel_size=(1, 3, 3),
                                 padding=(0, 1, 1), dilation=(1, 1, 1), groups=groups)
        )
        self.branch5 = nn.Sequential(
            create_conv_block_3d(in_planes, inter_planes, kernel_size=(5, 5, 5),
                                 padding=(2, 2, 2), stride=stride, groups=groups),
            create_conv_block_3d(inter_planes, inter_planes, kernel_size=(3, 3, 3),
                                 padding=(2, 2, 2), dilation=(2, 2, 2), groups=groups)
        )
        self.branch6 = nn.Sequential(
            create_conv_block_3d(in_planes, inter_planes, kernel_size=(3, 9, 9),
                                 padding=(1, 4, 4), stride=stride, groups=groups),
            create_conv_block_3d(inter_planes, inter_planes, kernel_size=(1, 5, 5),
                                 padding=(0, 2, 2), dilation=(1, 1, 1), groups=groups)
        )
        self.branch7 = nn.Sequential(
            create_conv_block_3d(in_planes, inter_planes, kernel_size=(5, 11, 11),
                                 padding=(2, 5, 5), stride=stride, groups=groups),
            create_conv_block_3d(inter_planes, inter_planes, kernel_size=(3, 5, 5),
                                 padding=(1, 4, 4), dilation=(1, 2, 2), groups=groups)
        )
        self.branch8 = nn.Sequential(
            create_conv_block_3d(in_planes, inter_planes, kernel_size=(7, 7, 7),
                                 padding=(3, 3, 3), stride=stride, groups=groups),
            create_conv_block_3d(inter_planes, inter_planes, kernel_size=(3, 3, 3),
                                 padding=(2, 2, 2), dilation=(2, 2, 2), groups=groups)
        )

    def forward(self, x):
        # Channel-attention gate on the raw input.
        attn = self.rotation_attn(x)
        x = x * attn
        # NOTE(review): the original carried a stale debug docstring claiming
        # branch3 output 250x250 for a 256x256 input; by the conv arithmetic
        # (padding compensates kernel/dilation) all branches preserve spatial
        # size at stride 1 — the stale comment was dropped.
        branch_outs = [self.branch1(x), self.branch2(x), self.branch3(x),
                       self.branch4(x), self.branch5(x), self.branch6(x),
                       self.branch7(x), self.branch8(x)]
        out = _grouped_concat(branch_outs, self.groups, self.group_num)
        if self.droprate > 0:
            out = F.dropout3d(out, p=self.droprate, training=self.training)
        return out


class RfbeBlock3d(nn.Module):
    """RFB "enhanced" block (3D): 4 dilated branches + 1x1 fuse + residual shortcut."""

    def __init__(self, in_planes, out_planes, stride=1, groups=1):
        super(RfbeBlock3d, self).__init__()
        inter_planes = _inter_planes_of(out_planes, 2)
        self.groups = groups
        self.group_num = 2 * inter_planes // groups
        self.branch1 = nn.Sequential(
            create_conv_block_3d(in_planes, 2 * inter_planes, kernel_size=3,
                                 stride=stride, padding=1, groups=groups),
            create_conv_block_k3_3d(2 * inter_planes, 2 * inter_planes,
                                    padding=1, dilation=1, groups=groups)
        )
        self.branch2 = nn.Sequential(
            create_conv_block_3d(in_planes, 2 * inter_planes, kernel_size=3,
                                 stride=stride, padding=1, groups=groups),
            create_conv_block_k3_3d(2 * inter_planes, 2 * inter_planes,
                                    padding=2, dilation=2, groups=groups)
        )
        self.branch3 = nn.Sequential(
            create_conv_block_3d(in_planes, 2 * inter_planes, kernel_size=5,
                                 stride=stride, padding=2, groups=groups),
            create_conv_block_k3_3d(2 * inter_planes, 2 * inter_planes,
                                    padding=3, dilation=3, groups=groups)
        )
        self.branch4 = nn.Sequential(
            create_conv_block_3d(in_planes, 2 * inter_planes, kernel_size=7,
                                 stride=stride, padding=3, groups=groups),
            create_conv_block_k3_3d(2 * inter_planes, 2 * inter_planes,
                                    padding=4, dilation=4, groups=groups)
        )
        # 4 branches x 2*inter_planes channels -> out_planes.
        self.concat_conv = create_conv_block_k1_3d(8 * inter_planes, out_planes,
                                                   groups=groups, activation=False)
        self.shortcut = create_conv_block_k1_3d(in_planes, out_planes, stride=stride,
                                                groups=groups, activation=False)

    def forward(self, x):
        identity = self.shortcut(x)
        branch_outs = [self.branch1(x), self.branch2(x),
                       self.branch3(x), self.branch4(x)]
        out = _grouped_concat(branch_outs, self.groups, self.group_num)
        out = self.concat_conv(out)
        out = out + identity
        out = F.leaky_relu(out, inplace=True)
        return out


class RfbBlock3d(nn.Module):
    """RFB block (3D): 3 bottlenecked dilated branches + 1x1 fuse + residual shortcut."""

    def __init__(self, in_planes, out_planes, stride=1, groups=1):
        super(RfbBlock3d, self).__init__()
        inter_planes = _inter_planes_of(out_planes, 2)
        self.groups = groups
        self.group_num = 2 * inter_planes // groups
        self.branch1 = nn.Sequential(
            create_conv_block_k1_3d(in_planes, 2 * inter_planes, stride=stride,
                                    groups=groups),
            create_conv_block_k3_3d(2 * inter_planes, 2 * inter_planes,
                                    padding=1, dilation=1, groups=groups)
        )
        self.branch2 = nn.Sequential(
            create_conv_block_k1_3d(in_planes, inter_planes, groups=groups),
            create_conv_block_3d(inter_planes, 2 * inter_planes, kernel_size=3,
                                 stride=stride, padding=1, groups=groups),
            create_conv_block_k3_3d(2 * inter_planes, 2 * inter_planes,
                                    padding=2, dilation=2, groups=groups)
        )
        self.branch3 = nn.Sequential(
            create_conv_block_k1_3d(in_planes, inter_planes, groups=groups),
            create_conv_block_3d(inter_planes, 2 * inter_planes, kernel_size=5,
                                 stride=stride, padding=2, groups=groups),
            create_conv_block_k3_3d(2 * inter_planes, 2 * inter_planes,
                                    padding=3, dilation=3, groups=groups)
        )
        # 3 branches x 2*inter_planes channels -> out_planes.
        self.concat_conv = create_conv_block_k1_3d(6 * inter_planes, out_planes,
                                                   groups=groups, activation=False)
        self.shortcut = create_conv_block_k1_3d(in_planes, out_planes, stride=stride,
                                                groups=groups, activation=False)

    def forward(self, x):
        identity = self.shortcut(x)
        branch_outs = [self.branch1(x), self.branch2(x), self.branch3(x)]
        out = _grouped_concat(branch_outs, self.groups, self.group_num)
        out = self.concat_conv(out)
        out = out + identity
        out = F.leaky_relu(out, inplace=True)
        return out


class ResBasicBlock3d(nn.Module):
    """Basic 3D residual block: two 3x3x3 convs, optional SE gating and dropout.

    The 1x1x1 shortcut projection is created only when stride or channel count
    changes; otherwise the identity is used.
    """

    def __init__(self, in_planes, out_planes, stride=1, groups=1, droprate=0, se=False):
        super(ResBasicBlock3d, self).__init__()
        self.conv1 = create_conv_block_k3_3d(in_planes, out_planes, stride=stride,
                                             groups=groups)
        self.conv2 = create_conv_block_k3_3d(out_planes, out_planes, groups=groups,
                                             activation=False)
        self.shortcut = None
        if stride != 1 or in_planes != out_planes:
            self.shortcut = create_conv_block_k1_3d(in_planes, out_planes, stride=stride,
                                                    groups=groups, activation=False)
        self.droprate = droprate
        self.se = se
        if se:
            # Squeeze-excite: reduce by 4x, then re-expand to per-channel gates.
            self.fc1 = nn.Linear(in_features=out_planes, out_features=out_planes // 4)
            self.fc2 = nn.Linear(in_features=out_planes // 4, out_features=out_planes)

    def forward(self, x):
        identity = self.shortcut(x) if self.shortcut is not None else x
        out = self.conv1(x)
        if self.droprate > 0:
            out = F.dropout3d(out, p=self.droprate, training=self.training)
        out = self.conv2(out)
        if self.se:
            original_out = out
            out = F.adaptive_avg_pool3d(out, (1, 1, 1))
            out = torch.flatten(out, 1)
            out = self.fc1(out)
            out = F.leaky_relu(out, inplace=True)
            out = self.fc2(out)
            out = out.sigmoid()
            out = out.view(out.size(0), out.size(1), 1, 1, 1)
            out = out * original_out
        out = out + identity
        out = F.leaky_relu(out, inplace=True)
        return out


class UpBlock3d(nn.Module):
    """Decoder fusion block: upsample x1, project both inputs, add, activate."""

    def __init__(self, in_planes1, in_planes2, out_planes, groups=1, scale_factor=2):
        super(UpBlock3d, self).__init__()
        self.scale_factor = scale_factor
        self.conv1 = create_conv_block_k3_3d(in_planes1, out_planes, groups=groups,
                                             activation=False)
        self.conv2 = create_conv_block_k3_3d(in_planes2, out_planes, groups=groups,
                                             activation=False)

    def forward(self, x1, x2):
        # Skip the interpolation when the scale factor is a no-op.
        if self.scale_factor != 1 and self.scale_factor != (1, 1, 1):
            x1 = F.interpolate(x1, scale_factor=self.scale_factor, mode='nearest')
        out = self.conv1(x1) + self.conv2(x2)
        out = F.leaky_relu(out, inplace=True)
        return out


def create_conv_block_2d(in_planes, out_planes, kernel_size=1, stride=1, padding=0,
                         dilation=1, groups=1, bias=True, bn=True, activation=True):
    """Conv2d -> optional InstanceNorm2d -> optional LeakyReLU, as a Sequential.

    2D counterpart of create_conv_block_3d; ``bn`` likewise toggles InstanceNorm.
    """
    layers = [nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, stride=stride,
                        padding=padding, dilation=dilation, groups=groups, bias=bias)]
    if bn:
        layers.append(nn.InstanceNorm2d(out_planes))
    if activation:
        layers.append(nn.LeakyReLU(inplace=True))
    return nn.Sequential(*layers)


def create_conv_block_k1_2d(in_planes, out_planes, kernel_size=1, stride=1, padding=0,
                            dilation=1, groups=1, bias=False, bn=True, activation=True):
    """1x1 conv block (bias off by default)."""
    return create_conv_block_2d(in_planes, out_planes, kernel_size=kernel_size,
                                stride=stride, padding=padding, dilation=dilation,
                                groups=groups, bias=bias, bn=bn, activation=activation)


def create_conv_block_k2_2d(in_planes, out_planes, kernel_size=2, stride=1, padding=0,
                            dilation=1, groups=1, bias=False, bn=True, activation=True):
    """2x2 conv block (bias off by default)."""
    return create_conv_block_2d(in_planes, out_planes, kernel_size=kernel_size,
                                stride=stride, padding=padding, dilation=dilation,
                                groups=groups, bias=bias, bn=bn, activation=activation)


def create_conv_block_k3_2d(in_planes, out_planes, kernel_size=3, stride=1, padding=1,
                            dilation=1, groups=1, bias=False, bn=True, activation=True):
    """3x3 'same' conv block (padding=1, bias off by default)."""
    return create_conv_block_2d(in_planes, out_planes, kernel_size=kernel_size,
                                stride=stride, padding=padding, dilation=dilation,
                                groups=groups, bias=bias, bn=bn, activation=activation)


class RfbfBlock2d(nn.Module):
    """2D counterpart of RfbfBlock3d: channel attention + 8 multi-scale branches."""

    def __init__(self, in_planes, out_planes, stride=1, groups=1, droprate=0):
        super(RfbfBlock2d, self).__init__()
        inter_planes = _inter_planes_of(out_planes, 1)
        self.groups = groups
        self.group_num = inter_planes // groups
        self.droprate = droprate
        self.rotation_attn = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            nn.Conv2d(in_planes, max(inter_planes // 2, 1), 1),
            nn.LeakyReLU(inplace=True),
            # BUGFIX: emit in_planes gate channels so ``x * attn`` broadcasts
            # correctly (same issue as RfbfBlock3d.rotation_attn).
            nn.Conv2d(max(inter_planes // 2, 1), in_planes, 1),
            nn.Sigmoid()
        )
        self.branch1 = nn.Sequential(
            create_conv_block_2d(in_planes, inter_planes, kernel_size=(5, 5),
                                 padding=(2, 2), stride=stride, groups=groups),
            create_conv_block_2d(inter_planes, inter_planes, kernel_size=(3, 3),
                                 padding=(1, 1), dilation=(1, 1), groups=groups)
        )
        self.branch2 = nn.Sequential(
            create_conv_block_2d(in_planes, inter_planes, kernel_size=(5, 5),
                                 padding=(2, 2), stride=stride, groups=groups),
            create_conv_block_2d(inter_planes, inter_planes, kernel_size=(3, 3),
                                 padding=(2, 2), dilation=(2, 2), groups=groups)
        )
        self.branch3 = nn.Sequential(
            create_conv_block_2d(in_planes, inter_planes, kernel_size=(7, 7),
                                 padding=(3, 3), stride=stride, groups=groups),
            create_conv_block_2d(inter_planes, inter_planes, kernel_size=(3, 3),
                                 padding=(3, 3), dilation=(3, 3), groups=groups)
        )
        self.branch4 = nn.Sequential(
            create_conv_block_2d(in_planes, inter_planes, kernel_size=(5, 5),
                                 padding=(2, 2), stride=stride, groups=groups),
            create_conv_block_2d(inter_planes, inter_planes, kernel_size=(5, 5),
                                 padding=(4, 4), dilation=(2, 2), groups=groups)
        )
        self.branch5 = nn.Sequential(
            create_conv_block_2d(in_planes, inter_planes, kernel_size=(7, 7),
                                 padding=(3, 3), stride=stride, groups=groups),
            create_conv_block_2d(inter_planes, inter_planes, kernel_size=(5, 5),
                                 padding=(4, 4), dilation=(2, 2), groups=groups)
        )
        self.branch6 = nn.Sequential(
            create_conv_block_2d(in_planes, inter_planes, kernel_size=(9, 9),
                                 padding=(4, 4), stride=stride, groups=groups),
            create_conv_block_2d(inter_planes, inter_planes, kernel_size=(3, 3),
                                 padding=(2, 2), dilation=(2, 2), groups=groups)
        )
        self.branch7 = nn.Sequential(
            create_conv_block_2d(in_planes, inter_planes, kernel_size=(7, 7),
                                 padding=(3, 3), stride=stride, groups=groups),
            create_conv_block_2d(inter_planes, inter_planes, kernel_size=(7, 7),
                                 padding=(6, 6), dilation=(2, 2), groups=groups)
        )
        self.branch8 = nn.Sequential(
            create_conv_block_2d(in_planes, inter_planes, kernel_size=(11, 11),
                                 padding=(5, 5), stride=stride, groups=groups),
            create_conv_block_2d(inter_planes, inter_planes, kernel_size=(5, 5),
                                 padding=(4, 4), dilation=(2, 2), groups=groups)
        )

    def forward(self, x):
        attn = self.rotation_attn(x)
        x = x * attn
        branch_outs = [self.branch1(x), self.branch2(x), self.branch3(x),
                       self.branch4(x), self.branch5(x), self.branch6(x),
                       self.branch7(x), self.branch8(x)]
        out = _grouped_concat(branch_outs, self.groups, self.group_num)
        if self.droprate > 0:
            out = F.dropout2d(out, p=self.droprate, training=self.training)
        return out


class ResBasicBlock2d(nn.Module):
    """Basic 2D residual block: two 3x3 convs, optional SE gating and dropout."""

    def __init__(self, in_planes, out_planes, stride=1, groups=1, droprate=0, se=False):
        super(ResBasicBlock2d, self).__init__()
        self.conv1 = create_conv_block_k3_2d(in_planes, out_planes, stride=stride,
                                             groups=groups)
        self.conv2 = create_conv_block_k3_2d(out_planes, out_planes, groups=groups,
                                             activation=False)
        self.shortcut = None
        if stride != 1 or in_planes != out_planes:
            self.shortcut = create_conv_block_k1_2d(in_planes, out_planes, stride=stride,
                                                    groups=groups, activation=False)
        self.droprate = droprate
        self.se = se
        if se:
            self.fc1 = nn.Linear(in_features=out_planes, out_features=out_planes // 4)
            self.fc2 = nn.Linear(in_features=out_planes // 4, out_features=out_planes)

    def forward(self, x):
        identity = self.shortcut(x) if self.shortcut is not None else x
        out = self.conv1(x)
        if self.droprate > 0:
            out = F.dropout2d(out, p=self.droprate, training=self.training)
        out = self.conv2(out)
        if self.se:
            # NOTE: original had a comment "这里需要测试一下" ("needs testing here")
            # on the SE gating below.
            original_out = out
            out = F.adaptive_avg_pool2d(out, (1, 1))
            out = torch.flatten(out, 1)
            out = self.fc1(out)
            out = F.leaky_relu(out, inplace=True)
            out = self.fc2(out)
            out = out.sigmoid()
            out = out.view(out.size(0), out.size(1), 1, 1)
            out = out * original_out
        out = out + identity
        out = F.leaky_relu(out, inplace=True)
        return out


class FPN3d(nn.Module):
    """3D encoder: RFBF stem, then 6 residual downsampling stages + valid 3x3x3 head.

    Levels 2-3 downsample only the in-plane (H, W) dims (stride (1, 2, 2));
    levels 4-7 downsample all three dims.  Level 8 uses padding=0, so the
    level-7 output must be at least 3 in every spatial dim.
    """

    def __init__(self, n_channels, n_base_filters, groups=1):
        super(FPN3d, self).__init__()
        self.down_level1 = nn.Sequential(
            RfbfBlock3d(n_channels, n_base_filters, groups=groups)
        )
        self.down_level2 = nn.Sequential(
            ResBasicBlock3d(1 * n_base_filters, 2 * n_base_filters, stride=(1, 2, 2),
                            groups=groups),
            ResBasicBlock3d(2 * n_base_filters, 2 * n_base_filters, groups=groups)
        )
        self.down_level3 = nn.Sequential(
            ResBasicBlock3d(2 * n_base_filters, 4 * n_base_filters, stride=(1, 2, 2),
                            groups=groups),
            ResBasicBlock3d(4 * n_base_filters, 4 * n_base_filters, groups=groups)
        )
        self.down_level4 = nn.Sequential(
            ResBasicBlock3d(4 * n_base_filters, 8 * n_base_filters, stride=2,
                            groups=groups, se=True),
            ResBasicBlock3d(8 * n_base_filters, 8 * n_base_filters, groups=groups,
                            se=True),
            ResBasicBlock3d(8 * n_base_filters, 8 * n_base_filters, groups=groups,
                            se=True)
        )
        self.down_level5 = nn.Sequential(
            ResBasicBlock3d(8 * n_base_filters, 16 * n_base_filters, stride=2,
                            groups=groups, se=True),
            # Deliberate bottleneck: 16x -> 8x -> 16x filters.
            ResBasicBlock3d(16 * n_base_filters, 8 * n_base_filters, groups=groups,
                            se=True),
            ResBasicBlock3d(8 * n_base_filters, 16 * n_base_filters, groups=groups,
                            se=True)
        )
        self.down_level6 = nn.Sequential(
            ResBasicBlock3d(16 * n_base_filters, 16 * n_base_filters, stride=2,
                            groups=groups, se=True)
        )
        self.down_level7 = nn.Sequential(
            ResBasicBlock3d(16 * n_base_filters, 16 * n_base_filters, stride=2,
                            groups=groups, se=True)
        )
        self.down_level8 = nn.Sequential(
            create_conv_block_k3_3d(16 * n_base_filters, 16 * n_base_filters,
                                    padding=0, bn=False)
        )

    def forward(self, x):
        out = self.down_level1(x)
        out = self.down_level2(out)
        out = self.down_level3(out)
        out = self.down_level4(out)
        out = self.down_level5(out)
        out = self.down_level6(out)
        out = self.down_level7(out)
        out = self.down_level8(out)
        return out


class Net3d(nn.Module):
    """3D classifier: FPN3d features -> avg+max pooled vector -> attention -> sigmoid.

    Returns per-sample probabilities with the class dim squeezed out
    (shape (batch,) for n_diff_classes == 1).
    """

    def __init__(self, n_channels=1, n_diff_classes=1, n_base_filters=8):
        super(Net3d, self).__init__()
        self.fpn = FPN3d(n_channels, n_base_filters)
        # 32x = 16x (avg pooled) + 16x (max pooled) base filters.
        self.channel_attn = nn.Sequential(
            nn.Linear(32 * n_base_filters, 32 * n_base_filters // 4),
            nn.LeakyReLU(inplace=True),
            nn.Linear(32 * n_base_filters // 4, 32 * n_base_filters),
            nn.Sigmoid()
        )
        self.diff_classifier = nn.Linear(32 * n_base_filters, n_diff_classes)

    def forward(self, data):
        out = self.fpn(data)
        # Combine global max and average pooling (original note: "changed to
        # max pooling"); concatenation order is avg first, then max.
        out_max = F.adaptive_max_pool3d(out, (1, 1, 1))
        out_avg = F.adaptive_avg_pool3d(out, (1, 1, 1))
        out = torch.cat([torch.flatten(out_avg, 1), torch.flatten(out_max, 1)], 1)
        attn = self.channel_attn(out)
        out = out * attn
        diff_output = self.diff_classifier(out)
        # F.sigmoid is deprecated; torch.sigmoid is the supported equivalent.
        diff_output = torch.sigmoid(diff_output)
        diff_output = diff_output.squeeze(1)
        return diff_output


def net_test_3d():
    pass


class FPN2d(nn.Module):
    """2D encoder mirroring FPN3d: RFBF stem + 6 residual stages + valid 3x3 head."""

    def __init__(self, n_channels, n_base_filters, groups=1):
        super(FPN2d, self).__init__()
        self.down_level1 = nn.Sequential(
            RfbfBlock2d(n_channels, n_base_filters, groups=groups)
        )
        self.down_level2 = nn.Sequential(
            ResBasicBlock2d(1 * n_base_filters, 2 * n_base_filters, stride=(2, 2),
                            groups=groups),
            ResBasicBlock2d(2 * n_base_filters, 2 * n_base_filters, groups=groups)
        )
        self.down_level3 = nn.Sequential(
            ResBasicBlock2d(2 * n_base_filters, 4 * n_base_filters, stride=(2, 2),
                            groups=groups),
            ResBasicBlock2d(4 * n_base_filters, 4 * n_base_filters, groups=groups)
        )
        self.down_level4 = nn.Sequential(
            ResBasicBlock2d(4 * n_base_filters, 8 * n_base_filters, stride=2,
                            groups=groups, se=True),
            ResBasicBlock2d(8 * n_base_filters, 8 * n_base_filters, groups=groups,
                            se=True),
            ResBasicBlock2d(8 * n_base_filters, 8 * n_base_filters, groups=groups,
                            se=True)
        )
        self.down_level5 = nn.Sequential(
            ResBasicBlock2d(8 * n_base_filters, 16 * n_base_filters, stride=2,
                            groups=groups, se=True),
            # Deliberate bottleneck: 16x -> 8x -> 16x filters.
            ResBasicBlock2d(16 * n_base_filters, 8 * n_base_filters, groups=groups,
                            se=True),
            ResBasicBlock2d(8 * n_base_filters, 16 * n_base_filters, groups=groups,
                            se=True)
        )
        self.down_level6 = nn.Sequential(
            ResBasicBlock2d(16 * n_base_filters, 16 * n_base_filters, stride=2,
                            groups=groups, se=True)
        )
        self.down_level7 = nn.Sequential(
            ResBasicBlock2d(16 * n_base_filters, 16 * n_base_filters, stride=2,
                            groups=groups, se=True)
        )
        self.down_level8 = nn.Sequential(
            create_conv_block_k3_2d(16 * n_base_filters, 16 * n_base_filters,
                                    padding=0, bn=False)
        )

    def forward(self, x):
        out = self.down_level1(x)
        out = self.down_level2(out)
        out = self.down_level3(out)
        out = self.down_level4(out)
        out = self.down_level5(out)
        out = self.down_level6(out)
        out = self.down_level7(out)
        out = self.down_level8(out)
        return out


class Net2d(nn.Module):
    """2D classifier mirroring Net3d: FPN2d -> avg+max pooling -> attention -> sigmoid."""

    def __init__(self, n_channels=1, n_diff_classes=1, n_base_filters=8):
        super(Net2d, self).__init__()
        self.fpn = FPN2d(n_channels, n_base_filters)
        self.channel_attn = nn.Sequential(
            nn.Linear(32 * n_base_filters, 32 * n_base_filters // 4),
            nn.LeakyReLU(inplace=True),
            nn.Linear(32 * n_base_filters // 4, 32 * n_base_filters),
            nn.Sigmoid()
        )
        self.diff_classifier = nn.Linear(32 * n_base_filters, n_diff_classes)

    def forward(self, data):
        out_feat = self.fpn(data)
        out_avg = F.adaptive_avg_pool2d(out_feat, (1, 1))
        out_max = F.adaptive_max_pool2d(out_feat, (1, 1))
        out_flat = torch.cat([torch.flatten(out_avg, 1), torch.flatten(out_max, 1)], 1)
        out_attn = self.channel_attn(out_flat)
        out = out_flat * out_attn
        diff_output = self.diff_classifier(out)
        diff_output = torch.sigmoid(diff_output)
        diff_output = diff_output.squeeze(1)
        return diff_output


def net_test_2d():
    pass


class Net2d3d(nn.Module):
    """Wrapper running independent 2D and 3D classifiers on paired inputs."""

    def __init__(self):
        super(Net2d3d, self).__init__()
        self.net2d = Net2d()
        self.net3d = Net3d()

    def forward(self, data_2d, data_3d):
        out2d = self.net2d(data_2d)
        out3d = self.net3d(data_3d)
        return out2d, out3d


def net_test_2d3d():
    pass


def test_initialization():
    pass


if __name__ == '__main__':
    test_initialization()