# -*- coding:utf-8 -*-
import math
import torch
import numpy as np
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable


class SeLossWithBCEFocalLoss(torch.nn.Module):
    """Sigmoid focal loss on two prediction heads plus a consistency penalty.

    For each element the unreduced loss is::

        focal(pred1, label1) + focal(pred2, label2)
        + lambda_val * (alpha * label1 * label2
                        + (1 - alpha) * (1 - label1) * (1 - label2)) * diff

    where ``diff`` is ``(pred1 - pred2) ** 2`` on the raw logits when
    ``prob_diff_choice`` is False, or the squared difference of the clamped
    sigmoid probabilities when it is True. The penalty therefore only pulls
    the two heads together where both labels are positive or both negative.

    Args:
        gamma: focusing exponent of the focal loss.
        alpha: weight of the positive class (``1 - alpha`` weights the
            negative class); the same split weights the consistency penalty.
        reduction: ``'sum'`` (default), ``'mean'`` or its legacy alias
            ``'elementwise_mean'``; any other value (e.g. ``'none'``) returns
            the elementwise loss unreduced.
        lambda_val: weight of the consistency penalty.
        prob_diff_choice: compare probabilities instead of logits in the
            consistency penalty. May also be changed later via the
            ``prob_diff_choice`` attribute.
        eps: sigmoid outputs are clamped to ``[eps, 1 - eps]`` before taking
            ``log`` so the focal terms stay finite.
    """

    def __init__(self, gamma=2, alpha=0.25, reduction='sum', lambda_val=1,
                 prob_diff_choice=False, eps=1e-5):
        super().__init__()
        self.gamma = gamma
        self.alpha = alpha
        self.reduction = reduction
        self.lambda_val = lambda_val
        self.prob_diff_choice = prob_diff_choice
        self.eps = eps

    def _focal(self, logits, target):
        """Return ``(elementwise focal loss, clamped probability)`` for one head."""
        prob = torch.clamp(torch.sigmoid(logits), self.eps, 1.0 - self.eps)
        pos_term = -self.alpha * (1 - prob) ** self.gamma * target * torch.log(prob)
        neg_term = (1 - self.alpha) * prob ** self.gamma * (1 - target) * torch.log(1 - prob)
        return pos_term - neg_term, prob

    def forward(self, loss_input):
        """Compute the combined focal + consistency loss.

        Args:
            loss_input: a pair ``(preds, labels)`` where ``preds`` is
                ``(pred1, pred2)`` (logits) and ``labels`` is
                ``(label1, label2)``. All four tensors are combined
                elementwise, so they must be broadcast-compatible; labels
                are presumably 0/1 floats (TODO confirm with callers).

        Returns:
            The loss reduced according to ``self.reduction``.
        """
        (pred1, pred2), (label1, label2) = loss_input
        alpha = self.alpha

        loss0, pt = self._focal(pred1, label1)
        loss1, pt_se = self._focal(pred2, label2)

        # Masks selecting elements where both heads share the same label.
        same_label_pos = label1 * label2
        same_label_neg = (1 - label1) * (1 - label2)

        if self.prob_diff_choice:
            pred_diff = (pt - pt_se) ** 2
        else:
            pred_diff = (pred1 - pred2) ** 2

        consistency = (alpha * same_label_pos * pred_diff
                       + (1 - alpha) * same_label_neg * pred_diff)
        loss = (loss0 + loss1) + self.lambda_val * consistency

        # 'elementwise_mean' is the pre-1.0 PyTorch name for 'mean'.
        if self.reduction in ('mean', 'elementwise_mean'):
            loss = torch.mean(loss)
        elif self.reduction == 'sum':
            loss = torch.sum(loss)
        return loss