# -*- coding:utf-8 -*-
import math
import torch
import numpy as np
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable


class TripletBinFocal(torch.nn.Module):
    """Binary focal loss on three predictions plus a triplet-margin penalty.

    ``forward`` receives an (anchor, same, diff) triple of logits together
    with their binary targets. Each of the three logits contributes an
    alpha-balanced binary focal loss term. A hinge term additionally asks
    the anchor to be closer to the "same" prediction than to the "diff"
    prediction by at least ``margin``.

    Args:
        class_num: Accepted for interface compatibility; not used.
        alpha: Weight of the positive-target term; the negative-target term
            is weighted by ``1 - alpha``. ``None`` disables class balancing
            (both terms get weight 1).
        gamma: Focusing exponent of the focal loss.
        size_average: If True return the mean over all elements, else the sum.
        lambda_val: Weight of the triplet term relative to the focal terms.
        prob_diff: If True, triplet distances are measured between sigmoid
            probabilities; otherwise between raw logits.
        margin: Hinge margin of the triplet term.
    """

    def __init__(self, class_num, alpha=None, gamma=2, size_average=True,
                 lambda_val=1, prob_diff=True, margin=0.5):
        super().__init__()
        self.gamma = gamma
        self.alpha = alpha
        self.size_average = size_average
        self.lambda_val = lambda_val
        self.margin = margin
        self.prob_diff_choice = prob_diff

    def _focal_term(self, logits, target):
        """Return the element-wise binary focal loss of ``logits`` vs ``target``."""
        if self.alpha is None:
            pos_weight, neg_weight = 1.0, 1.0
        else:
            pos_weight, neg_weight = self.alpha, 1 - self.alpha
        pt = torch.sigmoid(logits)
        # logsigmoid(x) == log(sigmoid(x)) and logsigmoid(-x) == log(1 - sigmoid(x)),
        # but both stay finite when sigmoid saturates to exactly 0 or 1.
        # torch.log on a saturated probability yields -inf, and the masked-out
        # branch then evaluates 0 * -inf = NaN.
        log_pt = F.logsigmoid(logits)
        log_one_minus_pt = F.logsigmoid(-logits)
        return (-pos_weight * (1 - pt) ** self.gamma * target * log_pt
                - neg_weight * pt ** self.gamma * (1 - target) * log_one_minus_pt)

    def forward(self, loss_input):
        """Compute the combined focal + triplet loss.

        Args:
            loss_input: ``(preds, labels, _)``. ``preds`` is a
                ``(pred1, pred2, pred3)`` triple of logits (anchor, "same",
                "diff"); ``labels`` is the matching triple of binary targets.
                The third element is ignored. The tensors are assumed to be
                broadcast-compatible — TODO confirm shapes with callers.

        Returns:
            Scalar tensor: the mean (``size_average``) or the sum of
            ``focal(pred1) + focal(pred2) + focal(pred3) + lambda_val * triplet``.
        """
        preds, labels, _ = loss_input
        pred1, pred2, pred3 = preds
        label1, label2, label3 = labels

        sample_focal_loss = (self._focal_term(pred1, label1)
                             + self._focal_term(pred2, label2)
                             + self._focal_term(pred3, label3))

        if self.prob_diff_choice:
            anchor, same, diff = (torch.sigmoid(p) for p in (pred1, pred2, pred3))
        else:
            anchor, same, diff = pred1, pred2, pred3
        same_dist = (anchor - same) ** 2
        diff_dist = (anchor - diff) ** 2
        # Hinge: penalize only when "same" is not closer than "diff" by `margin`.
        triplet_loss = F.relu(same_dist - diff_dist + self.margin)

        loss = sample_focal_loss + self.lambda_val * triplet_loss
        return torch.mean(loss) if self.size_average else torch.sum(loss)