"""Loss wrappers: RPN(+RCNN) detection loss and a heatmap/center regression loss."""
import torch
from torch import nn

from .center_loss import _sigmoid, FocalLoss, RegL1Loss


def _zero_if_nan(loss):
    """Return ``loss`` unchanged, or a gradient-free zero tensor if it contains NaN.

    Overwriting ``loss.data`` would only replace the value: ``backward()``
    would still walk the graph that produced the NaN and could push NaN
    gradients into the network. A fresh ``zeros_like`` tensor has no
    ``grad_fn``, so the bad term is dropped from backpropagation, and it
    lives on ``loss``'s own device/dtype instead of a hard-coded GPU.
    """
    if torch.isnan(loss.detach()).any():
        return torch.zeros_like(loss)
    return loss


class Loss(nn.Module):
    """Sum the RPN losses and, optionally, the RCNN box-head losses.

    Args:
        rpn_loss_fun: callable taking the six RPN arguments of ``forward`` and
            returning ``(cls_loss, reg_loss, stats)``.
        bbox_loss_fun: callable taking ``(rcnn_logits, rcnn_deltas,
            rcnn_labels, rcnn_targets)`` and returning
            ``(cls_loss, reg_loss, stats)``; only called when ``use_rcnn``.
        use_rcnn: whether the RCNN terms are computed and added.
    """

    def __init__(self, rpn_loss_fun, bbox_loss_fun, use_rcnn):
        super().__init__()
        self.rpn_loss_fun = rpn_loss_fun
        self.bbox_loss_fun = bbox_loss_fun
        # Last RCNN terms computed by ``forward``. Initialised on the CPU so
        # building the module does not require CUDA; ``forward`` overwrites
        # them with tensors on the loss device on every call.
        self.rcnn_cls_loss, self.rcnn_reg_loss = torch.tensor(0.), torch.tensor(0.)
        self.rcnn_stats = None
        self.use_rcnn = use_rcnn

    def forward(self, rpn_logits_flat, rpn_deltas_flat, rpn_labels, rpn_label_weights,
                rpn_targets, rpn_target_weights, rcnn_logits=None, rcnn_deltas=None,
                rcnn_labels=None, rcnn_targets=None):
        """Compute the combined detection loss.

        Returns:
            ``(total_loss, [rpn_cls_loss, rpn_reg_loss, rcnn_cls_loss,
            rcnn_reg_loss], rpn_stats, rcnn_stats)``. When ``use_rcnn`` is
            false the RCNN losses are zeros and ``rcnn_stats`` is ``None``.
        """
        rpn_cls_loss, rpn_reg_loss, rpn_stats = self.rpn_loss_fun(
            rpn_logits_flat, rpn_deltas_flat, rpn_labels, rpn_label_weights,
            rpn_targets, rpn_target_weights)

        if self.use_rcnn:
            self.rcnn_cls_loss, self.rcnn_reg_loss, self.rcnn_stats = self.bbox_loss_fun(
                rcnn_logits, rcnn_deltas, rcnn_labels, rcnn_targets)
        else:
            # Reset instead of reusing values from an earlier call: a stale
            # RCNN loss would be added to this step's total and would keep the
            # previous iteration's autograd graph alive.
            self.rcnn_cls_loss = torch.zeros_like(rpn_cls_loss)
            self.rcnn_reg_loss = torch.zeros_like(rpn_cls_loss)
            self.rcnn_stats = None

        total_loss = rpn_cls_loss + rpn_reg_loss + self.rcnn_cls_loss + self.rcnn_reg_loss
        return (total_loss,
                [rpn_cls_loss, rpn_reg_loss, self.rcnn_cls_loss, self.rcnn_reg_loss],
                rpn_stats, self.rcnn_stats)


class CT_Loss(nn.Module):
    """Focal loss on a Gaussian target heatmap plus two masked L1 regressions.

    The total is ``hm_loss + wh_weight * reg_loss + off_weight * off_loss``,
    where ``reg_loss`` regresses ``rpn_deltas`` onto ``bboxes_diameters`` and
    ``off_loss`` regresses ``rpn_offsets`` onto ``reg_offset``, both gathered
    at ``center_idxs`` and masked by ``reg_mask`` (semantics of the gather and
    mask live in ``RegL1Loss`` — not visible here).

    Args:
        wh_weight: weight of the diameter regression term (default 0.1).
        off_weight: weight of the offset regression term (default 1.0).
    """

    def __init__(self, wh_weight=0.1, off_weight=1.0):
        super().__init__()
        self.crit = FocalLoss()
        self.crit_reg = RegL1Loss()
        self.wh_weight = wh_weight
        self.off_weight = off_weight

    def forward(self, rpn_logits, rpn_deltas, rpn_offsets, gaussian_hm, center_idxs,
                bboxes_diameters, reg_offset, reg_mask):
        """Compute the weighted total loss.

        Returns:
            ``(loss, [hm_loss, reg_loss, off_loss])``. A regression term that
            evaluates to NaN is replaced by a zero that carries no gradient.
        """
        # ``_sigmoid`` comes from center_loss; presumably maps raw logits to
        # (clamped) probabilities as FocalLoss expects — confirm there.
        hm_loss = self.crit(_sigmoid(rpn_logits), gaussian_hm)
        reg_loss = _zero_if_nan(
            self.crit_reg(rpn_deltas, reg_mask, center_idxs, bboxes_diameters))
        off_loss = _zero_if_nan(
            self.crit_reg(rpn_offsets, reg_mask, center_idxs, reg_offset))

        loss = hm_loss + self.wh_weight * reg_loss + self.off_weight * off_loss
        return loss, [hm_loss, reg_loss, off_loss]