# -*- coding:utf-8 -*-
import torch
import torch.nn as nn
from torch.autograd import Variable


class GradientBoostFocal(nn.Module):
    r"""
    An implementation of Focal Loss, proposed in "Focal Loss for Dense Object Detection":

        Loss(x, class) = -\alpha (1 - softmax(x)[class])^\gamma \log(softmax(x)[class])

    This "gradient boost" variant additionally masks out easy classes: the softmax
    normalization is restricted to the target class plus the top-k highest-scoring
    non-target classes, so the loss concentrates on the hardest competing classes.

    Args:
        class_num (int): number of classes C.
        alpha (1D Tensor, Variable, or list): per-class weighting factor. If a list is
            given, its values are inverted and normalized so that classes with smaller
            values receive larger weights.
        gamma (float): gamma > 0 reduces the relative loss for well-classified examples
            (p > .5), putting more focus on hard, misclassified examples.
        size_average (bool): if True, the losses are averaged over observations for each
            minibatch; otherwise (the default) they are summed.
        k (int): number of highest-scoring non-target classes kept in the normalization.
    """

    def __init__(self, class_num, alpha=None, gamma=2, size_average=False,
                 show_loss=False, sigmoid_choice=False, k=1):
        super(GradientBoostFocal, self).__init__()
        if alpha is None:
            self.alpha = Variable(torch.ones(class_num, 1))
        elif isinstance(alpha, Variable):
            self.alpha = alpha
        else:
            # Invert the provided per-class values and normalize them to sum to 1.
            new_alpha = [1.0 / val for val in alpha]
            sum_val = sum(new_alpha)
            new_alpha = [val / float(sum_val) for val in new_alpha]
            self.alpha = torch.tensor(new_alpha)
        self.gamma = gamma
        self.class_num = class_num
        self.size_average = size_average
        self.show_loss = show_loss
        self.sigmoid_choice = sigmoid_choice
        self.k = k

    def _maskOutEasyClasses(self, probs, label, e_val=1e-6):
        """Keep the target class and the top-k highest-scoring non-target classes;
        replace every other entry with the small constant e_val."""
        zero_vals = torch.ones_like(probs) * e_val
        label_indices_full = torch.zeros_like(probs)
        ids = label.view(-1, 1)
        label_indices_full.scatter_(1, ids, 1.)
        # Scores of the non-target classes only (target positions set to e_val).
        zeroOut_probs = torch.where(label_indices_full == 0, probs, zero_vals)
        _, indices = zeroOut_probs.topk(k=self.k, dim=1, largest=True)
        # Also mark the top-k non-target classes as kept.
        label_indices_full.scatter_(1, indices, 1.)
        probs = torch.where(label_indices_full > 0, probs, zero_vals)
        return probs

    def _exp(self, probs):
        # Numerically stable exponentiation: subtract the per-row maximum before exp.
        max_vals, _ = probs.max(dim=1, keepdim=True)
        exp_probs = torch.exp(probs - max_vals)
        return exp_probs

    def forward(self, loss_input):
        inputs, targets, _ = loss_input
        inputs = inputs[0]
        targets = targets[0]
        N = inputs.size(0)
        C = inputs.size(1)

        # Softmax restricted to the target class and the top-k hardest non-target classes.
        P = self._exp(inputs)
        zeroOutProbs = self._maskOutEasyClasses(P, targets)
        denom = torch.sum(zeroOutProbs, dim=1, keepdim=True)
        P = zeroOutProbs / denom

        # One-hot mask selecting the ground-truth class of each sample.
        class_mask = Variable(inputs.data.new(N, C).fill_(0))
        ids = targets.view(-1, 1)
        class_mask.scatter_(1, ids, 1.)

        if inputs.is_cuda and not self.alpha.is_cuda:
            self.alpha = self.alpha.cuda()
        alpha = self.alpha.view(-1)[ids.view(-1)]

        probs = (P * class_mask).sum(1)  # p_t: probability assigned to the true class
        log_p = probs.log()
        batch_loss = -alpha * torch.pow(1 - probs, self.gamma) * log_p

        if self.size_average:
            loss = batch_loss.mean()
        else:
            loss = batch_loss.sum()
        return loss
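

# A minimal usage sketch (not part of the original module): it assumes the loss is fed
# a tuple of ([logits], [targets], extra), matching how forward() unpacks its argument.
# The batch size and class count below are hypothetical, chosen only for illustration.
if __name__ == "__main__":
    torch.manual_seed(0)
    num_classes = 5
    criterion = GradientBoostFocal(class_num=num_classes, gamma=2, k=1, size_average=True)
    logits = torch.randn(4, num_classes)             # (N, C) raw scores
    targets = torch.randint(0, num_classes, (4,))    # (N,) ground-truth labels
    # forward() uses only the first element of the inputs and targets lists.
    loss = criterion(([logits], [targets], None))
    print("focal loss:", loss.item())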