"""3D residual U-Net backbones with optional Horovod synchronized batch norm."""

import torch
from torch import nn

try:
    import horovod.torch as hvd
except ImportError:  # Horovod is only needed when bn_type='sync' is requested.
    hvd = None

_BN_TYPES = ("normal", "sync")


def _make_bn(num_features, bn_type, **bn_kwargs):
    """Build the 3D batch-norm layer selected by ``bn_type``.

    :param num_features: number of channels to normalize.
    :param bn_type: ``'normal'`` for ``nn.BatchNorm3d`` or ``'sync'`` for
        ``hvd.SyncBatchNorm``.
    :param bn_kwargs: forwarded to the batch-norm constructor (e.g. ``momentum``).
    :raises ValueError: if ``bn_type`` is not one of ``_BN_TYPES``.
    :raises ImportError: if ``bn_type == 'sync'`` but Horovod is not installed.
    """
    if bn_type == "normal":
        return nn.BatchNorm3d(num_features, **bn_kwargs)
    if bn_type == "sync":
        if hvd is None:
            raise ImportError(
                "bn_type='sync' requires horovod.torch, which could not be imported")
        return hvd.SyncBatchNorm(num_features, **bn_kwargs)
    # Reject unknown values instead of silently picking one BN flavour: the
    # previous per-layer conditionals disagreed on the fallback, which could
    # mix synchronized and local BN inside one network.
    raise ValueError(f"bn_type must be one of {_BN_TYPES}, got {bn_type!r}")


class ResBlock3d(nn.Module):
    """Residual block: two 3x3x3 conv + BN layers plus an identity or 1x1x1 projection shortcut."""

    def __init__(self, n_in, n_out, parameters, stride=1, bn_type='normal'):
        """
        :param n_in: number of input channels.
        :param n_out: number of output channels.
        :param parameters: dict of keyword arguments forwarded to every batch-norm
            layer in the block (e.g. ``{"momentum": 0.1}``).
        :param stride: stride of the first convolution and of the projection shortcut.
        :param bn_type: ``'normal'`` for ``nn.BatchNorm3d`` or ``'sync'`` for
            ``hvd.SyncBatchNorm``; any other value raises ValueError.
        """
        super(ResBlock3d, self).__init__()
        self.conv1 = nn.Conv3d(n_in, n_out, kernel_size=3, stride=stride, padding=1)
        self.bn1 = _make_bn(n_out, bn_type, **parameters)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv3d(n_out, n_out, kernel_size=3, padding=1)
        self.bn2 = _make_bn(n_out, bn_type, **parameters)
        if stride != 1 or n_out != n_in:
            # Projection shortcut so the residual matches the output's channels/resolution.
            self.shortcut = nn.Sequential(
                nn.Conv3d(n_in, n_out, kernel_size=1, stride=stride),
                _make_bn(n_out, bn_type, **parameters))
        else:
            self.shortcut = None

    def forward(self, x):
        residual = x
        if self.shortcut is not None:
            residual = self.shortcut(x)
        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)
        out = self.conv2(out)
        out = self.bn2(out)
        out += residual
        out = self.relu(out)
        return out


def _make_stage(n_in, n_out, n_blocks, parameters, bn_type):
    """Stack ``n_blocks`` ResBlock3d layers: the first maps ``n_in -> n_out``, the rest ``n_out -> n_out``."""
    blocks = [ResBlock3d(n_in, n_out, parameters, bn_type=bn_type)]
    blocks += [ResBlock3d(n_out, n_out, parameters, bn_type=bn_type)
               for _ in range(n_blocks - 1)]
    return nn.Sequential(*blocks)


def _make_pre_block(in_channels, parameters, bn_type):
    """Stem: stride-2 conv to 24 channels followed by a 24->24 conv, each with BN + ReLU."""
    return nn.Sequential(
        nn.Conv3d(in_channels, 24, kernel_size=3, padding=1, stride=2),
        _make_bn(24, bn_type, **parameters),
        nn.ReLU(inplace=True),
        nn.Conv3d(24, 24, kernel_size=3, padding=1),
        _make_bn(24, bn_type, **parameters),
        nn.ReLU(inplace=True))


def _make_up_path(channels, bn_type):
    """2x transposed-conv upsampling followed by BN + ReLU."""
    return nn.Sequential(
        nn.ConvTranspose3d(channels, channels, kernel_size=2, stride=2),
        # NOTE(review): unlike every other BN layer, these do not receive the
        # configured momentum (default 0.1 is used). Preserved as-is to keep
        # training behaviour unchanged — confirm whether this is intentional.
        _make_bn(channels, bn_type),
        nn.ReLU(inplace=True))


class ResUNet(nn.Module):
    """Residual 3D U-Net backbone with two decoder stages (output at 1/4 input resolution)."""

    def __init__(self, cfg, in_channels, out_channels, bn_type='normal'):
        """
        :param cfg: config object; ``cfg.MODEL.BACKBONE.BN_MOMENTUM`` sets BN momentum.
        :param in_channels: number of channels of the input volume.
        :param out_channels: accepted for interface compatibility; not used.
        :param bn_type: ``'normal'`` or ``'sync'`` (see ``ResBlock3d``).
        """
        super(ResUNet, self).__init__()
        parameters = {"momentum": cfg.MODEL.BACKBONE.BN_MOMENTUM}
        # Construction order matches the historical layout so parameter order,
        # state_dict keys and seeded initialisation stay identical.
        self.preBlock = _make_pre_block(in_channels, parameters, bn_type)
        self.forw1 = _make_stage(24, 32, 2, parameters, bn_type)
        self.forw2 = _make_stage(32, 64, 2, parameters, bn_type)
        self.forw3 = _make_stage(64, 64, 3, parameters, bn_type)
        self.forw4 = _make_stage(64, 64, 3, parameters, bn_type)
        # Decoder stages applied to [upsampled, skip] concatenations.
        self.back2 = _make_stage(128, 128, 3, parameters, bn_type)
        self.back3 = _make_stage(128, 64, 3, parameters, bn_type)
        # maxpool1 is never used in forward(); kept so the module layout is unchanged.
        self.maxpool1 = nn.MaxPool3d(kernel_size=2, stride=2, return_indices=True)
        self.maxpool2 = nn.MaxPool3d(kernel_size=2, stride=2, return_indices=True)
        self.maxpool3 = nn.MaxPool3d(kernel_size=2, stride=2, return_indices=True)
        self.maxpool4 = nn.MaxPool3d(kernel_size=2, stride=2, return_indices=True)
        # Upsampling paths.
        self.path1 = _make_up_path(64, bn_type)
        self.path2 = _make_up_path(64, bn_type)

    def forward(self, x):
        """Return ``([x, forw1 features, final features], final features)``.

        Shape comments are ``channels, spatial edge``; the edge sizes assume a
        128-voxel input edge (TODO confirm against the data pipeline). Each
        spatial edge must survive three 2x poolings after the stride-2 stem.
        """
        out_0 = self.preBlock(x)  # 24, 64
        out1 = self.forw1(out_0)  # 32, 64
        out1_pool, _ = self.maxpool2(out1)  # 32, 32
        out2 = self.forw2(out1_pool)  # 64, 32
        out2_pool, _ = self.maxpool3(out2)  # 64, 16
        out3 = self.forw3(out2_pool)  # 64, 16
        out3_pool, _ = self.maxpool4(out3)  # 64, 8
        out4 = self.forw4(out3_pool)  # 64, 8
        rev3 = self.path1(out4)  # 64, 16
        comb3 = self.back3(torch.cat((rev3, out3), 1))  # 64+64, 16 -> 64, 16
        rev2 = self.path2(comb3)  # 64, 32
        comb2 = self.back2(torch.cat((rev2, out2), 1))  # 64+64, 32 -> 128, 32
        return [x, out1, comb2], comb2


class ResUNet_Large(nn.Module):
    """Residual 3D U-Net backbone with three decoder stages (output at 1/2 input resolution)."""

    def __init__(self, cfg, in_channels, out_channels, bn_type='sync'):
        """
        :param cfg: config object; ``cfg.MODEL.BACKBONE.BN_MOMENTUM`` sets BN momentum.
        :param in_channels: number of channels of the input volume.
        :param out_channels: accepted for interface compatibility; not used.
        :param bn_type: ``'normal'`` or ``'sync'`` (see ``ResBlock3d``).
        """
        super(ResUNet_Large, self).__init__()
        parameters = {"momentum": cfg.MODEL.BACKBONE.BN_MOMENTUM}
        # Construction order matches the historical layout so parameter order,
        # state_dict keys and seeded initialisation stay identical.
        self.preBlock = _make_pre_block(in_channels, parameters, bn_type)
        self.forw1 = _make_stage(24, 32, 2, parameters, bn_type)
        self.forw2 = _make_stage(32, 64, 2, parameters, bn_type)
        self.forw3 = _make_stage(64, 64, 3, parameters, bn_type)
        self.forw4 = _make_stage(64, 64, 3, parameters, bn_type)
        # Decoder stages applied to [upsampled, skip] concatenations.
        self.back1 = _make_stage(96, 128, 3, parameters, bn_type)
        self.back2 = _make_stage(128, 64, 3, parameters, bn_type)
        self.back3 = _make_stage(128, 64, 3, parameters, bn_type)
        # maxpool1 is never used in forward(); kept so the module layout is unchanged.
        self.maxpool1 = nn.MaxPool3d(kernel_size=2, stride=2, return_indices=True)
        self.maxpool2 = nn.MaxPool3d(kernel_size=2, stride=2, return_indices=True)
        self.maxpool3 = nn.MaxPool3d(kernel_size=2, stride=2, return_indices=True)
        self.maxpool4 = nn.MaxPool3d(kernel_size=2, stride=2, return_indices=True)
        # Upsampling paths.
        self.path1 = _make_up_path(64, bn_type)
        self.path2 = _make_up_path(64, bn_type)
        self.path3 = _make_up_path(64, bn_type)

    def forward(self, x):
        """Return ``([x, forw1 features, final features], final features)``.

        Shape comments are ``channels, spatial edge``; the edge sizes assume a
        128-voxel input edge (TODO confirm against the data pipeline).
        """
        out_0 = self.preBlock(x)  # 24, 64
        out1 = self.forw1(out_0)  # 32, 64
        out1_pool, _ = self.maxpool2(out1)  # 32, 32
        out2 = self.forw2(out1_pool)  # 64, 32
        out2_pool, _ = self.maxpool3(out2)  # 64, 16
        out3 = self.forw3(out2_pool)  # 64, 16
        out3_pool, _ = self.maxpool4(out3)  # 64, 8
        out4 = self.forw4(out3_pool)  # 64, 8
        rev3 = self.path1(out4)  # 64, 16
        comb3 = self.back3(torch.cat((rev3, out3), 1))  # 64+64, 16 -> 64, 16
        rev2 = self.path2(comb3)  # 64, 32
        comb2 = self.back2(torch.cat((rev2, out2), 1))  # 64+64, 32 -> 64, 32
        rev1 = self.path3(comb2)  # 64, 64
        comb1 = self.back1(torch.cat((rev1, out1), 1))  # 64+32, 64 -> 128, 64
        return [x, out1, comb1], comb1