from torch import nn
import torch.nn.functional as F


class BoxHead(nn.Module):
    """Fully connected head producing per-class logits and box deltas.

    Each crop is flattened to a vector, passed through two ReLU-activated
    linear layers (512 and 256 units), and then fed to two parallel linear
    layers: one emitting ``num_class`` logits and one emitting
    ``num_class * 6`` regression values.

    Args:
        num_class: Number of classes; sets the width of both outputs.
        resolution: Edge length of a crop. The first linear layer expects
            ``in_channels * resolution ** 3`` features per sample.
        in_channels: Number of feature channels per crop.
    """

    def __init__(self, num_class: int, resolution: int, in_channels: int = 128) -> None:
        super().__init__()
        # Attribute names are part of the state_dict layout; keep them stable
        # so existing checkpoints continue to load.
        self.fc1 = nn.Linear(in_channels * resolution * resolution * resolution, 512)
        self.fc2 = nn.Linear(512, 256)
        self.logit = nn.Linear(256, num_class)
        # Six regression values per class -- presumably a 3D box
        # parameterization (e.g. centre + size); confirm against the caller.
        self.delta = nn.Linear(256, num_class * 6)

    def forward(self, crops):
        """Compute class logits and box deltas for a batch of crops.

        Args:
            crops: Tensor whose first dimension is the batch and whose
                remaining dimensions hold ``in_channels * resolution ** 3``
                elements in total (presumably laid out as
                ``(N, in_channels, resolution, resolution, resolution)`` --
                confirm against the caller).

        Returns:
            A ``(logits, deltas)`` tuple of shapes ``(N, num_class)`` and
            ``(N, num_class * 6)``.
        """
        # flatten() falls back to a copy when the input is not contiguous
        # (e.g. after permute or slicing), where view() would raise.
        x = crops.flatten(start_dim=1)
        x = F.relu(self.fc1(x), inplace=True)
        x = F.relu(self.fc2(x), inplace=True)
        # NOTE: dropout (p=0.5) on these features was disabled in the original
        # implementation; re-enable here if regularization is required.
        logits = self.logit(x)
        deltas = self.delta(x)
        return logits, deltas