# -*- coding:utf-8 -*-
import torch
import torch.nn as nn


class TripletFocalLoss(nn.Module):
    """Focal loss over three predictions (anchor, positive, negative) plus a
    triplet-style hinge on probability differences at the anchor's class."""

    def __init__(self, class_num, alpha=None, gamma=2, size_average=False,
                 lambda_val=1, prob_diff=True, margin=0.5, alpha_multi_ratio=None):
        super(TripletFocalLoss, self).__init__()
        if alpha is None:
            self.alpha = torch.ones(class_num)
        elif isinstance(alpha, torch.Tensor):
            self.alpha = alpha.view(-1)
        else:
            # Inverse-frequency class weights, normalized to sum to 1,
            # optionally rescaled per class by alpha_multi_ratio.
            new_alpha = [1.0 / val for val in alpha]
            sum_val = float(sum(new_alpha))
            new_alpha = [val / sum_val for val in new_alpha]
            if alpha_multi_ratio is not None:
                new_alpha = [x * y for x, y in zip(new_alpha, alpha_multi_ratio)]
            self.alpha = torch.tensor(new_alpha)
        self.gamma = gamma
        self.class_num = class_num
        self.size_average = size_average  # kept for API compatibility; the class term is always averaged
        self.lambda_val = lambda_val
        self.prob_diff = prob_diff
        self.class_mask = None
        self.margin = margin

    def GetLogProb(self, pred, label):
        """Return target ids, the full softmax, the target-class probability,
        its log, and the per-sample alpha weight."""
        N, C = pred.size(0), pred.size(1)
        P = torch.softmax(pred, dim=1)
        P = torch.clamp(P, 1e-5, 1.0 - 1e-5)  # keep log() finite
        ids = label.view(-1, 1)
        class_mask = pred.new_zeros(N, C).scatter_(1, ids, 1.0)  # one-hot targets
        # Cache the anchor's one-hot mask on the first call of the batch;
        # _CalculateSamePosProbLoss reads probabilities at the anchor's class
        # positions for every branch.
        if self.class_mask is None:
            self.class_mask = class_mask
        probs = (P * class_mask).sum(1).view(-1, 1)  # p_t, shape (N, 1)
        log_p = probs.log()
        if pred.is_cuda and not self.alpha.is_cuda:
            self.alpha = self.alpha.cuda()
        alpha = self.alpha[ids.view(-1)].view(-1)  # shape (N,)
        return ids, P, probs, log_p.squeeze(-1), alpha

    def _CalculateSamePosProbLoss(self, P1, P2, ids1, ids2):
        """Squared difference of the two branches' scores at the anchor's
        class: pulled together for same-label pairs, pushed apart otherwise."""
        class_mask = self.class_mask
        equal = torch.eq(ids1, ids2)  # (N, 1) bool: same label?
        same_probs1 = (P1 * class_mask).sum(1).view(-1, 1)
        same_probs2 = (P2 * class_mask).sum(1).view(-1, 1)
        zero_prob = torch.zeros_like(same_probs1)
        # Same-label pairs: penalize any gap between the two branches.
        same_p1 = torch.where(equal, same_probs1, zero_prob)
        same_p2 = torch.where(equal, same_probs2, zero_prob)
        same_prob_diff = ((same_p1 - same_p2) ** 2).squeeze(1)
        # Different-label pairs: reward a large gap.
        diff_p1 = torch.where(~equal, same_probs1, zero_prob)
        diff_p2 = torch.where(~equal, same_probs2, zero_prob)
        diff_prob_diff = -((diff_p1 - diff_p2) ** 2).squeeze(1)
        return same_prob_diff + diff_prob_diff

    def forward(self, loss_input):
        self.class_mask = None  # re-cache the anchor mask for this batch
        preds, labels, labels_diff = loss_input  # labels_diff is unpacked but unused
        pred1, pred2, pred3 = preds
        label1, label2, label3 = labels
        ids1, p1, mask_p1, log_p1, alpha1 = self.GetLogProb(pred1, label1)
        ids2, p2, mask_p2, log_p2, alpha2 = self.GetLogProb(pred2, label2)
        ids3, p3, mask_p3, log_p3, alpha3 = self.GetLogProb(pred3, label3)

        # Per-branch focal loss: -alpha * (1 - p_t)^gamma * log(p_t)
        batch_loss01 = -alpha1 * torch.pow(1 - mask_p1, self.gamma).squeeze(1) * log_p1
        batch_loss02 = -alpha2 * torch.pow(1 - mask_p2, self.gamma).squeeze(1) * log_p2
        batch_loss03 = -alpha3 * torch.pow(1 - mask_p3, self.gamma).squeeze(1) * log_p3

        # Triplet term over softmax probabilities (prob_diff=True) or raw logits.
        if self.prob_diff:
            se_loss_0 = self._CalculateSamePosProbLoss(p1, p2, ids1, ids2)
            se_loss_1 = self._CalculateSamePosProbLoss(p1, p3, ids1, ids3)
        else:
            se_loss_0 = self._CalculateSamePosProbLoss(pred1, pred2, ids1, ids2)
            se_loss_1 = self._CalculateSamePosProbLoss(pred1, pred3, ids1, ids3)
        # Hinge at the margin: only violating samples contribute.
        se_loss = se_loss_0 + se_loss_1 + self.margin
        se_loss = torch.where(se_loss > 0, se_loss, torch.zeros_like(se_loss))
        counts = (se_loss > 0).sum()

        loss_class = (batch_loss01 + batch_loss02 + batch_loss03).mean()
        if counts == 0:
            return loss_class
        # Average the triplet term over the active samples only; computing it
        # here (not before the check) avoids a divide-by-zero when counts == 0.
        loss_triplet = self.lambda_val * (se_loss.sum() / counts)
        return loss_class + loss_triplet
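

# Minimal usage sketch (not part of the original file): forward() expects a
# tuple of (preds, labels, labels_diff), where preds holds three logit tensors
# (anchor, positive, negative), labels their targets, and labels_diff a pair
# that is unpacked but never used. The tensor shapes and the placeholder
# labels_diff values below are assumptions for illustration.
if __name__ == "__main__":
    torch.manual_seed(0)
    batch, num_classes = 4, 10
    criterion = TripletFocalLoss(class_num=num_classes, gamma=2,
                                 lambda_val=1.0, margin=0.5)

    preds = tuple(torch.randn(batch, num_classes, requires_grad=True)
                  for _ in range(3))
    labels = tuple(torch.randint(num_classes, (batch,)) for _ in range(3))
    labels_diff = (torch.zeros(batch), torch.zeros(batch))  # placeholder; unused

    loss = criterion((preds, labels, labels_diff))
    loss.backward()  # gradients flow to all three logit tensors
    print(loss.item())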