import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F


def create_conv_block(in_planes, out_planes, kernel_size=1, stride=1, padding=0,
                      dilation=1, groups=1, bias=False, bn=True, activation=True):
    """Build a ``Conv2d -> [InstanceNorm2d] -> [LeakyReLU]`` stack.

    Args:
        in_planes: Input channel count.
        out_planes: Output channel count.
        kernel_size, stride, padding, dilation, groups, bias: Forwarded
            unchanged to ``nn.Conv2d``.
        bn: If true, append a normalization layer. Despite the name this is
            ``nn.InstanceNorm2d`` with default settings, not batch norm.
        activation: If true, append an in-place ``nn.LeakyReLU``.

    Returns:
        ``nn.Sequential`` holding 1 to 3 layers in the order above.
    """
    layers = [nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, stride=stride,
                        padding=padding, dilation=dilation, groups=groups, bias=bias)]
    if bn:
        layers.append(nn.InstanceNorm2d(out_planes))
    if activation:
        layers.append(nn.LeakyReLU(inplace=True))
    return nn.Sequential(*layers)


def create_conv_block_k1(in_planes, out_planes, kernel_size=1, stride=1, padding=0,
                         dilation=1, groups=1, bias=False, bn=True, activation=True):
    """``create_conv_block`` with 1x1 kernel / zero padding defaults."""
    return create_conv_block(in_planes, out_planes, kernel_size=kernel_size, stride=stride,
                             padding=padding, dilation=dilation, groups=groups, bias=bias,
                             bn=bn, activation=activation)


def create_conv_block_k2(in_planes, out_planes, kernel_size=2, stride=1, padding=0,
                         dilation=1, groups=1, bias=False, bn=True, activation=True):
    """``create_conv_block`` with 2x2 kernel / zero padding defaults."""
    return create_conv_block(in_planes, out_planes, kernel_size=kernel_size, stride=stride,
                             padding=padding, dilation=dilation, groups=groups, bias=bias,
                             bn=bn, activation=activation)


def create_conv_block_k3(in_planes, out_planes, kernel_size=3, stride=1, padding=1,
                         dilation=1, groups=1, bias=False, bn=True, activation=True):
    """``create_conv_block`` with 3x3 kernel / padding 1 defaults."""
    return create_conv_block(in_planes, out_planes, kernel_size=kernel_size, stride=stride,
                             padding=padding, dilation=dilation, groups=groups, bias=bias,
                             bn=bn, activation=activation)


class RfbfBlock2d(nn.Module):
    """Eight parallel conv branches whose outputs are concatenated channel-wise.

    Each branch is ``create_conv_block(k x k, stride) -> create_conv_block(5 x 5)``.
    It produces ``inter_planes = max(ceil(out_planes / 8), 1)`` channels, where
    ``k`` comes from ``_BRANCH_KERNELS`` and the padding is ``k // 2``.

    With ``groups > 1`` the output channels are interleaved by group: group 0
    of branch1..branch8, then group 1 of branch1..branch8, and so on.

    NOTE(review): the block outputs ``8 * inter_planes`` channels, which equals
    ``out_planes`` only when ``out_planes`` is a multiple of 8 (e.g.
    ``out_planes=12`` yields 16 channels). Confirm callers expect this.
    """

    # First-layer kernel size of branch1 .. branch8, in registration order.
    _BRANCH_KERNELS = (5, 5, 5, 7, 7, 7, 9, 9)

    def __init__(self, in_planes, out_planes, stride=1, groups=1, droprate=0):
        super().__init__()
        inter_planes = max(int(np.ceil(out_planes / 8)), 1)
        self.groups = groups
        # Channels per group inside one branch. nn.Conv2d already requires
        # inter_planes to be divisible by groups, so this division is exact.
        self.group_num = inter_planes // groups
        self.droprate = droprate
        # Registered as branch1 .. branch8 (not a ModuleList) so state_dict keys
        # and parameter-initialisation order match the explicit definitions.
        for index, k in enumerate(self._BRANCH_KERNELS, start=1):
            branch = nn.Sequential(
                create_conv_block(in_planes, inter_planes, kernel_size=(k, k),
                                  padding=(k // 2, k // 2), stride=stride, groups=groups),
                create_conv_block(inter_planes, inter_planes, kernel_size=(5, 5),
                                  padding=(2, 2), dilation=(1, 1), groups=groups),
            )
            setattr(self, f"branch{index}", branch)

    def _branch_modules(self):
        """Return branch1 .. branch8 in order."""
        return [getattr(self, f"branch{i}") for i in range(1, len(self._BRANCH_KERNELS) + 1)]

    def forward(self, x):
        outs = [branch(x) for branch in self._branch_modules()]
        if self.groups == 1:
            out = torch.cat(outs, 1)
        else:
            # Interleave by group in one pass. The previous per-group
            # slice-and-cat loop re-copied the growing result for every group.
            n, _, h, w = outs[0].shape
            stacked = torch.stack(outs, dim=1)  # (N, branches, inter_planes, H, W)
            stacked = stacked.view(n, len(outs), self.groups, self.group_num, h, w)
            out = stacked.transpose(1, 2).reshape(n, -1, h, w)
        if self.droprate > 0:
            out = F.dropout2d(out, p=self.droprate, training=self.training)
        return out


class ResBasicBlock2d(nn.Module):
    """Two 3x3 conv blocks plus a residual connection, with optional SE gating.

    Main path: conv3x3(stride) -> [dropout2d] -> conv3x3 (no activation) ->
    [squeeze-and-excitation]. The result is added to the input, or to a 1x1
    projection of it when stride or channel count changes, and then passed
    through LeakyReLU.
    """

    def __init__(self, in_planes, out_planes, stride=1, groups=1, droprate=0, se=False):
        super().__init__()
        self.conv1 = create_conv_block_k3(in_planes, out_planes, stride=stride, groups=groups)
        self.conv2 = create_conv_block_k3(out_planes, out_planes, groups=groups, activation=False)
        self.shortcut = None
        if stride != 1 or in_planes != out_planes:
            # 1x1 projection so the identity matches the main path's shape.
            self.shortcut = create_conv_block_k1(in_planes, out_planes, stride=stride,
                                                 groups=groups, activation=False)
        self.droprate = droprate
        self.se = se
        if se:
            # Bottleneck width out_planes // 4, clamped to at least 1. For
            # out_planes < 4 the unclamped width is 0, which reduced the gate to
            # the constant sigmoid(fc2.bias).
            hidden = max(out_planes // 4, 1)
            self.fc1 = nn.Linear(in_features=out_planes, out_features=hidden)
            self.fc2 = nn.Linear(in_features=hidden, out_features=out_planes)

    def forward(self, x):
        identity = x if self.shortcut is None else self.shortcut(x)
        out = self.conv1(x)
        if self.droprate > 0:
            out = F.dropout2d(out, p=self.droprate, training=self.training)
        out = self.conv2(out)
        if self.se:
            # Squeeze: global average pool to one value per channel.
            gate = torch.flatten(F.adaptive_avg_pool2d(out, (1, 1)), 1)
            # Excite: bottleneck MLP -> per-channel weights in (0, 1).
            gate = F.leaky_relu(self.fc1(gate), inplace=True)
            gate = self.fc2(gate).sigmoid()
            # TODO: original author's note (translated) - "this needs to be tested".
            out = out * gate.view(gate.size(0), gate.size(1), 1, 1)
        out = out + identity
        return F.leaky_relu(out, inplace=True)