import torch
from torch import nn


def _zero_loss(reference=None):
    """Return a 0-dim float32 zero tensor on the same device as *reference*.

    Used as the R-CNN loss placeholder when the R-CNN head is disabled. The
    placeholder follows the RPN losses onto whatever device they live on,
    instead of being pinned to CUDA (which fails on CPU-only machines).

    Args:
        reference: a tensor whose device should be used, or anything else
            (including ``None``) to fall back to torch's default device.
    """
    device = reference.device if torch.is_tensor(reference) else None
    return torch.zeros((), device=device)


class Loss(nn.Module):
    """Combine RPN and (optionally) R-CNN losses into a single training loss.

    Args:
        rpn_loss_fun: callable invoked as ``rpn_loss_fun(logits, deltas,
            labels, label_weights, targets, target_weights)`` and expected to
            return ``(cls_loss, reg_loss, stats)``.
        bbox_loss_fun: callable invoked as ``bbox_loss_fun(logits, deltas,
            labels, targets)`` and expected to return
            ``(cls_loss, reg_loss, stats)`` for the R-CNN head.
        use_rcnn: whether ``bbox_loss_fun`` is evaluated; read on every
            forward call, so it may be toggled between calls.
    """

    def __init__(self, rpn_loss_fun, bbox_loss_fun, use_rcnn):
        super().__init__()
        self.rpn_loss_fun = rpn_loss_fun
        self.bbox_loss_fun = bbox_loss_fun
        # Most recent R-CNN losses/stats, kept as attributes for callers that
        # inspect them. Created without a hard-coded .cuda(); every forward
        # replaces them with values on the losses' own device.
        self.rcnn_cls_loss, self.rcnn_reg_loss = _zero_loss(), _zero_loss()
        self.rcnn_stats = None
        self.use_rcnn = use_rcnn

    def forward(self, rpn_logits_flat, rpn_deltas_flat, rpn_labels,
                rpn_label_weights, rpn_targets, rpn_target_weights,
                rcnn_logits=None, rcnn_deltas=None, rcnn_labels=None,
                rcnn_targets=None):
        """Compute the summed detection loss.

        Tensor shapes are whatever the supplied loss callables expect; this
        module only forwards them.

        Returns:
            ``(total_loss, [rpn_cls_loss, rpn_reg_loss, rcnn_cls_loss,
            rcnn_reg_loss], rpn_stats, rcnn_stats)``. When ``use_rcnn`` is
            false the R-CNN losses are zero and ``rcnn_stats`` is ``None``.
        """
        rpn_cls_loss, rpn_reg_loss, rpn_stats = self.rpn_loss_fun(
            rpn_logits_flat, rpn_deltas_flat, rpn_labels,
            rpn_label_weights, rpn_targets, rpn_target_weights)

        if self.use_rcnn:
            self.rcnn_cls_loss, self.rcnn_reg_loss, self.rcnn_stats = \
                self.bbox_loss_fun(rcnn_logits, rcnn_deltas, rcnn_labels,
                                   rcnn_targets)
        else:
            # Reset every call: otherwise losses from an earlier
            # use_rcnn=True step (whose graph may already be freed) would be
            # re-added to total_loss after the head is switched off.
            self.rcnn_cls_loss = _zero_loss(rpn_cls_loss)
            self.rcnn_reg_loss = _zero_loss(rpn_cls_loss)
            self.rcnn_stats = None

        total_loss = (rpn_cls_loss + rpn_reg_loss
                      + self.rcnn_cls_loss + self.rcnn_reg_loss)
        return (total_loss,
                [rpn_cls_loss, rpn_reg_loss,
                 self.rcnn_cls_loss, self.rcnn_reg_loss],
                rpn_stats, self.rcnn_stats)


class Loss_FPN(nn.Module):
    """Like ``Loss``, but the RPN loss is evaluated once per FPN level.

    Args:
        rpn_loss_fun: per-level RPN loss callable, invoked as
            ``rpn_loss_fun(logits, deltas, labels, label_weights, targets,
            target_weights)`` and expected to return
            ``(cls_loss, reg_loss, stats)``.
        bbox_loss_fun: R-CNN loss callable, invoked as
            ``bbox_loss_fun(logits, deltas, labels, targets)``.
        use_rcnn: whether ``bbox_loss_fun`` is evaluated; may be toggled
            between calls.
        num_levels: number of pyramid levels to read from each per-level
            input sequence (default 3, the previously hard-coded value).
    """

    def __init__(self, rpn_loss_fun, bbox_loss_fun, use_rcnn, num_levels=3):
        super().__init__()
        self.rpn_loss_fun = rpn_loss_fun
        self.bbox_loss_fun = bbox_loss_fun
        # See Loss.__init__: device-agnostic placeholders, refreshed per call.
        self.rcnn_cls_loss, self.rcnn_reg_loss = _zero_loss(), _zero_loss()
        self.rcnn_stats = None
        self.use_rcnn = use_rcnn
        self.num_levels = num_levels

    def forward(self, rpn_logits_flat, rpn_deltas_flat, rpn_labels,
                rpn_label_weights, rpn_targets, rpn_target_weights,
                rcnn_logits=None, rcnn_deltas=None, rcnn_labels=None,
                rcnn_targets=None):
        """Compute the summed multi-level detection loss.

        Each ``rpn_*`` argument must be indexable by level
        ``0 .. num_levels - 1``; element ``i`` of each is passed to
        ``rpn_loss_fun`` for level ``i``.

        Returns:
            ``(total_loss, [per_level_cls_losses, per_level_reg_losses,
            rcnn_cls_loss, rcnn_reg_loss], per_level_stats, rcnn_stats)``.
        """
        fpn_cls_loss, fpn_reg_loss, fpn_stats = [], [], []
        for i in range(self.num_levels):
            cls_loss, reg_loss, stats = self.rpn_loss_fun(
                rpn_logits_flat[i], rpn_deltas_flat[i], rpn_labels[i],
                rpn_label_weights[i], rpn_targets[i], rpn_target_weights[i])
            fpn_cls_loss.append(cls_loss)
            fpn_reg_loss.append(reg_loss)
            fpn_stats.append(stats)

        if self.use_rcnn:
            self.rcnn_cls_loss, self.rcnn_reg_loss, self.rcnn_stats = \
                self.bbox_loss_fun(rcnn_logits, rcnn_deltas, rcnn_labels,
                                   rcnn_targets)
        else:
            # Reset every call so stale R-CNN losses are never re-summed.
            reference = fpn_cls_loss[0] if fpn_cls_loss else None
            self.rcnn_cls_loss = _zero_loss(reference)
            self.rcnn_reg_loss = _zero_loss(reference)
            self.rcnn_stats = None

        # Out-of-place accumulation (same summation order as before) so no
        # tensor, including the stored rcnn attributes, is mutated in place.
        total_loss = self.rcnn_cls_loss + self.rcnn_reg_loss
        for cls_loss, reg_loss in zip(fpn_cls_loss, fpn_reg_loss):
            total_loss = total_loss + (cls_loss + reg_loss)
        return (total_loss,
                [fpn_cls_loss, fpn_reg_loss,
                 self.rcnn_cls_loss, self.rcnn_reg_loss],
                fpn_stats, self.rcnn_stats)