import torch
from torch import nn

from .loss import CT_Loss
from ..backbone import ResUNet, ResUNet_Large
from ..center_utils import _topk, _transpose_and_gather_feat
from ..rpn import CenterRpnHead


class CenterNet(nn.Module):
    """Center-based detector: a ResUNet backbone followed by a CenterRpnHead.

    In 'train' / 'valid' mode ``forward`` returns the raw head outputs; in
    'test' mode it decodes them into detections with ``ctdet_decode``.
    """

    # Supported values of cfg.MODEL.BACKBONE.CONV_BODY -> backbone class.
    _BACKBONES = {
        "ResUNet": ResUNet,
        "ResUNet_Large": ResUNet_Large,
    }

    def __init__(self, cfg, mode='train'):
        """Build the backbone, RPN head and loss from ``cfg``.

        Args:
            cfg: config object; reads MODEL.RPN.DEPLOY_PRE_NMS_TOP_K and
                MODEL.BACKBONE.CONV_BODY.
            mode: initial mode string, one of 'train', 'valid', 'test'.

        Raises:
            ValueError: if the configured backbone name is not supported.
        """
        super().__init__()
        self.mode = mode
        self.cfg = cfg
        # Number of heatmap peaks kept by ctdet_decode in 'test' mode.
        self.ct_top_k = cfg.MODEL.RPN.DEPLOY_PRE_NMS_TOP_K

        backbone_name = cfg.MODEL.BACKBONE.CONV_BODY
        backbone_cls = self._BACKBONES.get(backbone_name)
        if backbone_cls is None:
            raise ValueError("Detector backbone %s is not implemented." % backbone_name)
        # NOTE(review): (1, 128) presumably = (input channels, feature channels);
        # 128 must match the head's in_channels below — confirm against ResUNet.
        self.feature_net = backbone_cls(cfg, 1, 128)

        self.rpn_head = CenterRpnHead(in_channels=128)
        self.loss = CT_Loss()

    def forward(self, inputs):
        """Run the backbone and RPN head on ``inputs``.

        The head outputs are also stored on the module as ``rpn_logits``,
        ``rpn_deltas`` and ``rpn_offsets``.

        Returns:
            In 'train' / 'valid' mode: the tuple (logits, deltas, offsets).
            In 'test' mode: the detections tensor from ``ctdet_decode``.

        Raises:
            ValueError: if ``self.mode`` is not a supported mode.
        """
        # The backbone returns (features, extra); only the last feature map is used.
        features, _ = self.feature_net(inputs)
        self.rpn_logits, self.rpn_deltas, self.rpn_offsets = self.rpn_head(features[-1])

        if self.mode in ['train', 'valid']:
            return self.rpn_logits, self.rpn_deltas, self.rpn_offsets
        if self.mode == 'test':
            return ctdet_decode(self.rpn_logits, self.rpn_deltas, self.rpn_offsets, self.ct_top_k)
        raise ValueError('CenterNet.forward(): invalid mode = %s' % self.mode)

    def set_mode(self, mode):
        """Set the mode and switch the module to train() or eval() accordingly.

        Only 'train' enables training behavior; 'valid' and 'test' both use eval().
        """
        assert mode in ['train', 'valid', 'test']
        self.mode = mode
        if mode == 'train':
            self.train()
        else:
            self.eval()


def ctdet_decode(heat, diameters, offsets=None, K=100, down_ratio=2):
    """Decode center-heatmap outputs into the top-K scored 3-D boxes.

    Args:
        heat: heatmap logits of shape (batch, C, depth, height, width).
            A sigmoid is applied here, so raw logits are expected.
        diameters: per-location size regression. It is gathered at the peak
            indices and viewed as (batch, K, 3).
        offsets: optional per-location sub-voxel offsets. They are gathered
            and viewed as (batch, K, 3) and added to (z, y, x) in that order.
            When None, 0.5 is added to each coordinate instead.
        K: number of peaks to keep.
        down_ratio: factor multiplied into every coordinate and diameter at
            the end (formerly hard-coded as 2). Presumably this is the
            heatmap stride relative to the input; TODO confirm.

    Returns:
        Tensor of shape (batch, K, 7) laid out as [z, y, x, d0, d1, d2, score].

    Raises:
        ValueError: if ``heat`` is not 5-dimensional.
    """
    if heat.dim() != 5:
        raise ValueError("ctdet_decode(): expected 5-D heat, got shape %s" % (tuple(heat.shape),))
    batch = heat.size(0)
    heat = torch.sigmoid(heat)
    # NOTE: no heatmap NMS is applied before peak selection.
    # The third _topk output (class indices) is not part of the result.
    scores, inds, _, zs, ys, xs = _topk(heat, K=K)

    if offsets is not None:
        offsets = _transpose_and_gather_feat(offsets, inds).view(batch, K, 3)
        zs = zs.view(batch, K, 1) + offsets[:, :, 0:1]
        ys = ys.view(batch, K, 1) + offsets[:, :, 1:2]
        xs = xs.view(batch, K, 1) + offsets[:, :, 2:3]
    else:
        zs = zs.view(batch, K, 1) + 0.5
        ys = ys.view(batch, K, 1) + 0.5
        xs = xs.view(batch, K, 1) + 0.5

    diameters = _transpose_and_gather_feat(diameters, inds).view(batch, K, 3)
    scores = scores.view(batch, K, 1)

    # Coordinates and diameters are both scaled; the score is not.
    bboxes = torch.cat([zs, ys, xs, diameters], dim=2) * down_ratio
    return torch.cat([bboxes, scores], dim=2)