from torch import nn
import torch.nn.functional as F


class RpnHead(nn.Module):
    """3D region-proposal head producing per-anchor logits and box deltas.

    A shared 1x1x1 convolution followed by ReLU feeds two 1x1x1 convolution
    branches:

    * one with ``anchor_len`` output channels (logits);
    * one with ``6 * anchor_len`` output channels (deltas).

    Both outputs are flattened so that every (depth, height, width, anchor)
    position becomes one row.

    Args:
        anchor_len: Number of anchors per spatial location.
        in_channels: Channel count of the input feature map.
        hidden_channels: Channel count of the shared intermediate convolution.
            Defaults to 64, the value previously hard-coded here.
    """

    def __init__(self, anchor_len, in_channels=128, hidden_channels=64):
        super().__init__()
        self.anchor_len = anchor_len
        # NOTE: constructed but not applied in forward() (the original call
        # was commented out); kept so existing attribute access still works.
        self.drop = nn.Dropout3d(p=0.5, inplace=False)
        self.conv = nn.Sequential(
            nn.Conv3d(in_channels, hidden_channels, kernel_size=1),
            nn.ReLU(),
        )
        self.logits = nn.Conv3d(
            hidden_channels, 1 * self.anchor_len, kernel_size=1
        )
        self.deltas = nn.Conv3d(
            hidden_channels, 6 * self.anchor_len, kernel_size=1
        )

    @staticmethod
    def _flatten_per_anchor(x, values_per_anchor):
        """Reshape (B, A*K, D, H, W) into (B, D*H*W*A, K).

        Rows are ordered by depth, height, width, then anchor. Within one
        voxel, channel ``a * K + k`` lands at row ``a``, column ``k``.
        """
        # Channels-last makes each voxel's A*K values contiguous, so the
        # final reshape splits them into A consecutive rows of K values.
        return x.permute(0, 2, 3, 4, 1).reshape(
            x.size(0), -1, values_per_anchor
        )

    def forward(self, f):
        """Run the head on a 5D feature map.

        Args:
            f: Tensor of shape (B, in_channels, D, H, W).

        Returns:
            A ``(logits, deltas)`` tuple:

            * ``logits`` has shape (B, D*H*W*anchor_len, 1);
            * ``deltas`` has shape (B, D*H*W*anchor_len, 6).

            The meaning of the 6 delta values is defined by the caller's box
            encoding (presumably 3D center offsets plus sizes; TODO confirm).
        """
        out = self.conv(f)
        logits = self._flatten_per_anchor(self.logits(out), 1)
        deltas = self._flatten_per_anchor(self.deltas(out), 6)
        return logits, deltas