from torch import nn


class CenterRpnHead(nn.Module):
    """Prediction head built from pointwise (1x1x1) 3-D convolutions.

    The input feature volume goes through a shared ``Conv3d`` + ``ReLU``
    stem, and three parallel ``Conv3d`` layers then read the stem output:

    * ``logits``  -- 1 output channel
    * ``deltas``  -- 3 output channels
    * ``offsets`` -- 3 output channels

    Every convolution uses ``kernel_size=1`` with default stride and
    padding, so the spatial extent of the input is preserved.

    NOTE(review): what ``deltas`` and ``offsets`` encode (e.g. size
    regression vs. centre refinement) is decided by the caller / loss and
    is not visible in this module -- confirm there.

    Args:
        in_channels: Number of channels of the input feature volume.
        mid_channels: Width of the shared stem (original hard-coded 64).
        logit_bias: Initial value for every element of the ``logits``
            layer's bias (original hard-coded -4).
    """

    def __init__(self, in_channels: int = 128, mid_channels: int = 64,
                 logit_bias: float = -4.0) -> None:
        super().__init__()

        # Shared stem: pointwise projection followed by a non-linearity.
        self.conv = nn.Sequential(
            nn.Conv3d(in_channels, mid_channels, kernel_size=1),
            nn.ReLU(),
        )

        self.logits = nn.Conv3d(mid_channels, 1, kernel_size=1)
        # Start the logit branch from a constant bias. Presumably the
        # negative default keeps initial scores low once the caller applies
        # a sigmoid -- TODO confirm against the loss.
        # nn.init.constant_ sets the value without going through `.data`.
        nn.init.constant_(self.logits.bias, logit_bias)

        self.deltas = nn.Conv3d(mid_channels, 3, kernel_size=1)
        self.offsets = nn.Conv3d(mid_channels, 3, kernel_size=1)

    def forward(self, f):
        """Run the stem and the three output branches.

        Args:
            f: Feature tensor accepted by ``nn.Conv3d`` whose channel
                dimension equals ``in_channels``.

        Returns:
            Tuple ``(logits, deltas, offsets)`` with 1, 3 and 3 channels
            respectively; no activation is applied to any of them here.
        """
        out = self.conv(f)
        logits = self.logits(out)
        deltas = self.deltas(out)
        offsets = self.offsets(out)
        return logits, deltas, offsets