import torch
import torch.nn as nn
import torch.nn.functional as F

from ..center_utils import _transpose_and_gather_feat


def _sigmoid(x):
    """Apply a sigmoid to ``x`` and clamp the result to [1e-4, 1 - 1e-4].

    The clamp keeps probabilities away from exactly 0 and 1 so that the
    ``torch.log`` calls in ``_neg_loss`` stay finite.

    NOTE: ``x.sigmoid_()`` modifies ``x`` in place (presumably to save
    memory). After the call the caller's tensor holds the *un-clamped*
    sigmoid, while the returned tensor is the clamped one -- always use the
    return value.
    """
    y = torch.clamp(x.sigmoid_(), min=1e-4, max=1 - 1e-4)
    return y


def _neg_loss(pred, gt, alpha=2, gamma=4):
    """Modified (penalty-reduced) focal loss, as in CornerNet.

    Runs faster and costs a little bit more memory than an index-based
    variant, because it works on dense element-wise masks.

    Arguments:
        pred (batch x c x d x h x w): predicted probabilities. They should lie
            strictly inside (0, 1) (e.g. the output of ``_sigmoid``); exact 0
            or 1 makes ``torch.log`` return ``-inf``. The computation is
            element-wise, so any shape matching ``gt`` works.
        gt (batch x c x d x h x w): target heatmap. Elements equal to 1 are
            positives; elements with -1 < gt < 1 are negatives weighted by
            ``(1 - gt) ** gamma``; any other value (e.g. -1, presumably an
            "ignore" marker) contributes to neither term.
        alpha: exponent of the focusing factor ``(1 - pred) ** alpha`` for
            positives and ``pred ** alpha`` for negatives.
        gamma: exponent of the negative-sample penalty reduction.

    Returns:
        Scalar tensor: ``-(pos_loss + neg_loss) / num_pos``, or ``-neg_loss``
        when there are no positive elements.
    """
    pos_inds = gt.eq(1).float()
    # Only strictly-between-(-1, 1) targets count as negatives.
    neg_inds = (gt.lt(1) & gt.gt(-1)).float()

    # Penalty reduction: negatives whose target is close to 1 are down-weighted.
    neg_weights = torch.pow(1 - gt, gamma)

    pos_loss = torch.log(pred) * torch.pow(1 - pred, alpha) * pos_inds
    neg_loss = torch.log(1 - pred) * torch.pow(pred, alpha) * neg_weights * neg_inds

    num_pos = pos_inds.sum()
    pos_loss = pos_loss.sum()
    neg_loss = neg_loss.sum()

    if num_pos == 0:
        # No positives: nothing to normalise by, report the negative term only.
        return -neg_loss
    return -(pos_loss + neg_loss) / num_pos


class FocalLoss(nn.Module):
    """nn.Module wrapper for the modified focal loss ``_neg_loss``."""

    def __init__(self):
        super().__init__()
        self.neg_loss = _neg_loss

    def forward(self, out, target):
        """Return ``_neg_loss(out, target)`` with the default alpha/gamma."""
        return self.neg_loss(out, target)


class RegL1Loss(nn.Module):
    """Masked L1 regression loss on predictions gathered at given indices."""

    def __init__(self):
        super().__init__()

    def forward(self, output, mask, ind, target):
        """Compute the mask-normalised L1 loss.

        Arguments:
            output: dense prediction map; ``_transpose_and_gather_feat`` picks
                the entries addressed by ``ind`` to form ``pred`` (presumably
                shaped batch x max_objs x dim -- TODO confirm in center_utils).
            mask: per-entry validity mask; after ``unsqueeze(2)`` it is
                expanded to ``pred``'s shape, so it must be broadcast-compatible.
            ind: indices into ``output`` consumed by ``_transpose_and_gather_feat``.
            target: regression targets with the same shape as ``pred``.

        Returns:
            Sum of absolute errors over masked elements divided by the number
            of masked elements (+1e-4 to avoid division by zero).
        """
        pred = _transpose_and_gather_feat(output, ind)
        mask = mask.unsqueeze(2).expand_as(pred).float()
        # reduction='sum' is the non-deprecated spelling of size_average=False;
        # normalisation by the mask count is done explicitly below.
        loss = F.l1_loss(pred * mask, target * mask, reduction='sum')
        loss = loss / (mask.sum() + 1e-4)
        return loss