# -*- coding:utf-8 -*-
import math

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable


class FocalLoss(nn.Module):
    r"""
    This criterion is an implementation of Focal Loss, proposed in
    "Focal Loss for Dense Object Detection":

        Loss(x, class) = - \alpha (1 - softmax(x)[class])^gamma \log(softmax(x)[class])

    Args:
        alpha (1D Tensor, Variable): the scalar factor for this criterion.
        gamma (float, double): gamma > 0; reduces the relative loss for well-classified
            examples (p > .5), putting more focus on hard, misclassified examples.
        size_average (bool): By default the losses are summed over observations for
            each minibatch. If size_average is set to True, the losses are averaged
            instead.
    """

    def __init__(self, class_num, alpha=None, gamma=2, size_average=False, show_loss=False,
                 sigmoid_choice=False, alpha_multi_ratio=None, convert_softmax=True):
        super(FocalLoss, self).__init__()
        if alpha is None:
            self.alpha = Variable(torch.ones(class_num, 1))
        elif isinstance(alpha, Variable):
            self.alpha = alpha
        else:
            # Interpret alpha as per-class frequencies: weight each class by its
            # normalized inverse frequency, optionally rescaled by alpha_multi_ratio.
            new_alpha = [1.0 / val for val in alpha]
            sum_val = sum(new_alpha)
            new_alpha = [val / float(sum_val) for val in new_alpha]
            if alpha_multi_ratio is not None:
                new_alpha = [x * y for x, y in zip(new_alpha, alpha_multi_ratio)]
            self.alpha = torch.tensor(new_alpha)
        self.gamma = gamma
        self.class_num = class_num
        self.size_average = size_average
        self.show_loss = show_loss
        self.sigmoid_choice = sigmoid_choice
        self.convert_softmax = convert_softmax

    def forward(self, loss_input):
        inputs, targets, _ = loss_input
        inputs = inputs[0]
        targets = targets[0]
        N = inputs.size(0)
        C = inputs.size(1)

        # Turn the raw scores into probabilities unless the caller already did so.
        if self.convert_softmax:
            if not self.sigmoid_choice:
                P = torch.softmax(inputs, dim=1)
            else:
                P = torch.sigmoid(inputs)
        else:
            P = inputs

        # One-hot mask selecting the target class of each sample.
        class_mask = Variable(inputs.data.new(N, C).fill_(0))
        ids = targets.view(-1, 1)
        class_mask.scatter_(1, ids.data, 1.)

        if inputs.is_cuda and not self.alpha.is_cuda:
            self.alpha = self.alpha.cuda()
        alpha = self.alpha[ids.data.view(-1)]

        # Probability assigned to the target class, and its log.
        probs = (P * class_mask).sum(1).view(-1, 1).squeeze(-1)
        log_p = probs.log()

        batch_loss = -alpha * torch.pow(1 - probs, self.gamma) * log_p
        if self.size_average:
            loss = batch_loss.mean()
        else:
            loss = batch_loss.sum()
        return loss
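
# --- Usage sketch (illustrative only, not part of the original pipeline) ---------
# The helper name below is hypothetical; the nesting of `loss_input` is an
# assumption inferred from FocalLoss.forward, which unwraps inputs[0]/targets[0]
# from a 3-tuple, and `alpha` is assumed to hold per-class sample counts.
def _focal_loss_usage_sketch():
    criterion = FocalLoss(class_num=3, alpha=[500, 300, 200])
    logits = torch.randn(4, 3)           # (N, C) raw scores
    target = torch.tensor([0, 2, 1, 1])  # class indices, shape (N,)
    return criterion(([logits], [target], None))
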

class FocalLossBin(nn.Module):
    def __init__(self, alpha=1, gamma=2, logits=False, reduce=True):
        super(FocalLossBin, self).__init__()
        self.alpha = alpha
        self.gamma = gamma
        self.logits = logits
        self.reduce = reduce

    def forward(self, loss_input):
        inputs, targets = loss_input[0], loss_input[1]
        if self.logits:
            BCE_loss = F.binary_cross_entropy_with_logits(inputs, targets, reduction='none')
        else:
            BCE_loss = F.binary_cross_entropy(inputs, targets, reduction='none')
        # Target class probability: if the target is 1 this is the predicted
        # probability, otherwise it is 1 - predicted probability.
        pt = torch.exp(-BCE_loss)
        F_loss = self.alpha * (1 - pt) ** self.gamma * BCE_loss
        if self.reduce:
            return torch.mean(F_loss)
        else:
            return F_loss


class BCEFocalLoss(torch.nn.Module):
    """Binary focal loss with a fixed alpha."""

    def __init__(self, gamma=2, alpha=0.25, reduction='sum'):
        super().__init__()
        self.gamma = gamma
        self.alpha = alpha
        self.reduction = reduction

    def forward(self, loss_input):
        _input, target = loss_input[0], loss_input[1]
        pt = torch.sigmoid(_input)
        pt = torch.clamp(pt, 1e-5, 1.0 - 1e-5)
        alpha = self.alpha
        log_pt, minus_log_pt = torch.log(pt), torch.log(1 - pt)
        loss = - alpha * (1 - pt) ** self.gamma * target * log_pt - \
               (1 - alpha) * pt ** self.gamma * (1 - target) * minus_log_pt
        if self.reduction == 'elementwise_mean':
            loss = torch.mean(loss)
        elif self.reduction == 'sum':
            loss = torch.sum(loss)
        return loss


class SeLossWithBCEFocalLoss(torch.nn.Module):
    def __init__(self, gamma=2, alpha=0.25, reduction='sum', lambda_val=1, prob_diff=False):
        super().__init__()
        self.gamma = gamma
        self.alpha = alpha
        self.reduction = reduction
        self.lambda_val = lambda_val
        self.prob_diff_choice = prob_diff

    def forward(self, loss_input):
        '''
        Focal classification loss for both predictions plus a consistency term
        between samples that share the same label.
        '''
        preds, labels, _ = loss_input
        pred1, pred2 = preds
        label1, label2 = labels
        alpha = self.alpha

        pt = torch.sigmoid(pred1)
        pt = torch.clamp(pt, 1e-5, 1.0 - 1e-5)
        loss0 = - alpha * (1 - pt) ** self.gamma * label1 * torch.log(pt) - \
                (1 - alpha) * pt ** self.gamma * (1 - label1) * torch.log(1 - pt)

        pt_se = torch.sigmoid(pred2)
        pt_se = torch.clamp(pt_se, 1e-5, 1.0 - 1e-5)
        loss1 = - alpha * (1 - pt_se) ** self.gamma * label2 * torch.log(pt_se) - \
                (1 - alpha) * pt_se ** self.gamma * (1 - label2) * torch.log(1 - pt_se)

        # Consistency term: penalize prediction differences only where the two
        # inputs carry the same (positive or negative) label.
        same_label_pos = label1 * label2
        same_label_neg = (1 - label1) * (1 - label2)
        if not self.prob_diff_choice:
            pred_diff = (pred1 - pred2) ** 2
        else:
            pred_diff = (pt - pt_se) ** 2
        pos_se_loss = same_label_pos * pred_diff
        neg_se_loss = same_label_neg * pred_diff
        loss = alpha * pos_se_loss + (1 - alpha) * neg_se_loss
        loss = (loss0 + loss1) + self.lambda_val * loss

        if self.reduction == 'elementwise_mean':
            loss = torch.mean(loss)
        elif self.reduction == 'sum':
            loss = torch.sum(loss)
        return loss


class PeerLoss(torch.nn.Module):
    def __init__(self, gamma=2, alpha=0.25, reduction='elementwise_mean', balance_alpha=0.05, device=None):
        super().__init__()
        self.gamma = gamma
        self.alpha = alpha
        self.reduction = reduction
        self.balance_alpha = balance_alpha
        self.device = device

    def _calculateLoss(self, _input, target):
        pt = torch.sigmoid(_input)
        pt = torch.clamp(pt, 1e-5, 1.0 - 1e-5)
        alpha = self.alpha
        loss = - alpha * (1 - pt) ** self.gamma * target * torch.log(pt) - \
               (1 - alpha) * pt ** self.gamma * (1 - target) * torch.log(1 - pt)
        if self.reduction == 'elementwise_mean':
            loss = torch.mean(loss)
        elif self.reduction == 'sum':
            loss = torch.sum(loss)
        return loss

    def forward(self, _input, target):
        # Focal term on the real labels ...
        loss_0 = self._calculateLoss(_input, target)
        # ... minus a down-weighted "peer" term computed against randomly drawn labels.
        random_label = [np.random.choice([0, 1], size=_input.size(-1))]
        random_label = torch.from_numpy(np.array(random_label))
        device = self.device if self.device is not None else \
            torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
        random_label = random_label.to(device=device, dtype=torch.float32)
        loss_1 = self._calculateLoss(_input, random_label)
        return loss_0 - self.balance_alpha * loss_1
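
# --- Usage sketch for the binary focal loss (illustrative only) -------------------
# The helper name and the example shapes are assumptions, not part of the original
# training code; it simply feeds random logits and binary labels through BCEFocalLoss.
def _bce_focal_loss_usage_sketch():
    criterion = BCEFocalLoss(gamma=2, alpha=0.25, reduction='sum')
    logits = torch.randn(8)                      # raw scores for 8 samples
    target = torch.randint(0, 2, (8,)).float()   # binary labels in {0, 1}
    return criterion((logits, target))
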

class MIFocalLoss(torch.nn.Module):
    """Binary focal loss with a fixed alpha and a prior-probability shift on the logits."""

    def __init__(self, gamma=2, alpha=0.25, reduction='sum'):
        super().__init__()
        self.gamma = gamma
        self.alpha = alpha
        self.reduction = reduction

    def forward(self, _input, target):
        # Shift the logits by log(1 - p0) before the sigmoid, where p0 = alpha
        # plays the role of the assumed prior of the positive class.
        p0 = self.alpha
        log_p1 = math.log(1 - p0)
        pt = torch.sigmoid(_input + log_p1)
        pt = torch.clamp(pt, 1e-5, 1.0 - 1e-5)
        alpha = self.alpha
        loss = - alpha * (1 - pt) ** self.gamma * target * torch.log(pt) - \
               (1 - alpha) * pt ** self.gamma * (1 - target) * torch.log(1 - pt)
        if self.reduction == 'elementwise_mean':
            loss = torch.mean(loss)
        elif self.reduction == 'sum':
            loss = torch.sum(loss)
        return loss
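
# --- Illustration of the logit shift used in MIFocalLoss --------------------------
# Adding log(1 - alpha) to a logit before the sigmoid pulls the predicted positive
# probability down toward the assumed prior. The helper name and the numbers below
# are assumptions for demonstration only.
def _mi_focal_shift_sketch():
    alpha = 0.25
    logit = torch.tensor([0.0])
    p_plain = torch.sigmoid(logit)                           # 0.5 without the shift
    p_shifted = torch.sigmoid(logit + math.log(1 - alpha))   # ~0.43 after the shift
    return p_plain, p_shifted
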

class SE_FocalLoss(nn.Module):
    def __init__(self, class_num, alpha=None, gamma=2, size_average=False, lambda_val=1, prob_diff=True,
                 gradient_boost=False, k=1, alpha_multi_ratio=None, convert_softmax=True, margin=0,
                 new_se_loss=False):
        super(SE_FocalLoss, self).__init__()
        if alpha is None:
            self.alpha = Variable(torch.ones(class_num, 1))
        elif isinstance(alpha, Variable):
            self.alpha = alpha
        else:
            new_alpha = [1.0 / val for val in alpha]
            sum_val = sum(new_alpha)
            new_alpha = [val / float(sum_val) for val in new_alpha]
            if alpha_multi_ratio is not None:
                new_alpha = [x * y for x, y in zip(new_alpha, alpha_multi_ratio)]
            self.alpha = torch.tensor(new_alpha)
        self.gamma = gamma
        self.class_num = class_num
        self.size_average = size_average
        self.lambda_val = lambda_val
        self.prob_diff = prob_diff
        self.k = k
        self.gradient_boost = gradient_boost
        self.convert_softmax = convert_softmax
        self.margin = margin
        self.new_se_loss = new_se_loss

    def _maskOutEasyClasses(self, probs, label, e_val=1e-6):
        # Keep the probabilities of the target class and of the top-k hardest
        # non-target classes; replace every other entry with a small constant.
        zero_vals = torch.ones_like(probs) * e_val
        label_indices_full = torch.zeros_like(probs)
        ids = label.view(-1, 1)
        label_indices_full.scatter_(1, ids.data, 1.)
        zeroOut_probs = torch.where(label_indices_full == 0, probs, zero_vals)
        values, indices = zeroOut_probs.topk(k=self.k, dim=1, largest=True)
        indices = indices.view(-1, 1)
        label_indices_full.scatter_(1, indices.data, 1.)
        probs = torch.where(label_indices_full > 0, probs, zero_vals)
        return probs

    def _exp(self, probs):
        # Subtract the per-sample top value before exponentiating, for numerical stability.
        values, indices = probs.topk(k=self.k, dim=1, largest=True)
        new_probs = probs - values
        exp_probs = torch.exp(new_probs)
        return exp_probs

    def GetLogProb(self, pred, label):
        N1 = pred.size(0)
        C1 = pred.size(1)
        if not self.gradient_boost:
            if self.convert_softmax:
                P1 = torch.softmax(pred, dim=1)
            else:
                P1 = pred
            P1 = torch.clamp(P1, 1e-5, 1.0 - 1e-5)
        else:
            # Gradient-boost variant: renormalize the exp-scores over the target
            # class and the hardest remaining classes only.
            P1 = self._exp(pred)
            zeroOutProbs = self._maskOutEasyClasses(P1, label)
            total_loss = torch.sum(zeroOutProbs, dim=1)
            total_loss = total_loss.repeat((C1, 1)).transpose(0, 1)
            P1 = zeroOutProbs / total_loss

        class_mask_1 = Variable(pred.data.new(N1, C1).fill_(0))
        self.class_mask = Variable(pred.data.new(N1, C1).fill_(0))
        ids1 = label.view(-1, 1)
        class_mask_1.scatter_(1, ids1.data, 1.)

        probs1 = (P1 * class_mask_1).sum(1).view(-1, 1)
        log_p1 = probs1.log()

        if pred.is_cuda and not self.alpha.is_cuda:
            self.alpha = self.alpha.cuda()
        alpha = self.alpha[ids1.data.view(-1)]
        return ids1, P1, probs1, log_p1.squeeze(-1), alpha

    def _CalculateSamePosProbLoss(self, P1, P2, ids1, ids2):
        # Only sample pairs whose two views carry the same label contribute.
        equal = torch.eq(ids1, ids2)
        ids = torch.zeros_like(ids1)
        ids = torch.where(equal > 0, ids1, ids)
        class_mask = self.class_mask
        class_mask.scatter_(1, ids.data, 1.)
        same_probs1 = (P1 * class_mask).sum(1).view(-1, 1)
        same_probs2 = (P2 * class_mask).sum(1).view(-1, 1)
        zero_prob = torch.zeros_like(same_probs1)
        same_log_p1 = torch.where(equal > 0, same_probs1, zero_prob)
        same_log_p2 = torch.where(equal > 0, same_probs2, zero_prob)
        diff_prob = (same_log_p1 - same_log_p2) ** 2
        diff_prob = diff_prob.squeeze(1)
        alpha = self.alpha[ids.data.view(-1)]
        zero_alpha = torch.zeros_like(self.alpha[ids.data.view(-1)])
        equal = equal.squeeze(1)
        alpha = torch.where(equal > 0, alpha, zero_alpha)
        batch_diff_loss = -alpha * diff_prob
        return batch_diff_loss

    def _CalculateSamePosProbLossNew(self, P1, P2, ids1, ids2):
        equal = torch.eq(ids1, ids2).squeeze(1)
        # Generate the class mask from the first view's labels.
        class_mask = self.class_mask
        class_mask.scatter_(1, ids1.data, 1)
        target_class_prob1 = (P1 * class_mask).sum(1).view(-1, 1)
        target_class_prob2 = (P2 * class_mask).sum(1).view(-1, 1)
        # Same label: penalize the squared gap between the target-class probabilities.
        diff_prob = ((target_class_prob1 - target_class_prob2) ** 2).squeeze(1)
        # Different labels: push the target-class probabilities at least `margin` apart.
        diff_with_margin = (torch.clamp(self.margin - torch.abs(target_class_prob1 - target_class_prob2),
                                        min=0) ** 2).squeeze(1)
        deux_loss = torch.where(equal > 0, diff_prob, diff_with_margin)
        return deux_loss

    def forward(self, loss_input):
        preds, labels, _ = loss_input
        pred1, pred2 = preds
        label1, label2 = labels

        ids1, p1, mask_p1, log_p1, alpha1 = self.GetLogProb(pred1, label1)
        ids2, p2, mask_p2, log_p2, alpha2 = self.GetLogProb(pred2, label2)

        if pred1.is_cuda and not self.alpha.is_cuda:
            self.alpha = self.alpha.cuda()

        if not self.new_se_loss:
            if self.prob_diff:
                se_loss = self._CalculateSamePosProbLoss(p1, p2, ids1, ids2)
            else:
                se_loss = self._CalculateSamePosProbLoss(pred1, pred2, ids1, ids2)
        else:
            if self.prob_diff:
                se_loss = self._CalculateSamePosProbLossNew(p1, p2, ids1, ids2)
            else:
                se_loss = self._CalculateSamePosProbLossNew(pred1, pred2, ids1, ids2)

        # Per-view focal classification losses plus the weighted consistency term.
        batch_loss01 = -alpha1 * torch.pow(1 - mask_p1, self.gamma).squeeze(1) * log_p1
        batch_loss02 = -alpha2 * torch.pow(1 - mask_p2, self.gamma).squeeze(1) * log_p2
        batch_loss = batch_loss01 + batch_loss02 + self.lambda_val * se_loss

        if self.size_average:
            loss = batch_loss.mean()
        else:
            loss = batch_loss.sum()
        return loss
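
# --- Usage sketch for SE_FocalLoss on two "views" of the same batch ---------------
# The helper name, tensor shapes, and alpha values are assumptions inferred from
# forward(), which expects loss_input = ((pred1, pred2), (label1, label2), extra).
def _se_focal_loss_usage_sketch():
    criterion = SE_FocalLoss(class_num=3, alpha=[500, 300, 200], lambda_val=1)
    pred1, pred2 = torch.randn(4, 3), torch.randn(4, 3)   # logits of the two views
    label = torch.tensor([0, 2, 1, 1])                    # shared class indices
    return criterion(((pred1, pred2), (label, label), None))
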

class FeatureSeLossWithBCEFocalLoss(torch.nn.Module):
    def __init__(self, gamma=2, alpha=0.25, reduction='sum', lambda_val=1, prob_diff=False):
        super().__init__()
        self.gamma = gamma
        self.alpha = alpha
        self.reduction = reduction
        self.lambda_val = lambda_val

    def forward(self, loss_input):
        '''
        Combine the per-sample focal classification loss with a feature-similarity
        term for same-label pairs and a BCE-with-logits loss on the auxiliary
        pair-level prediction.
        '''
        preds, features, cons, labels, labels_diff = loss_input
        pred1, pred2 = preds
        feature1, feature2 = features
        pred0 = cons[0]
        label1, label2 = labels
        label0 = labels_diff[0]
        feature_loss_weights = 0.25
        alpha = self.alpha

        # Single-sample classification (focal) loss for each of the two inputs.
        pt = torch.sigmoid(pred1)
        pt = torch.clamp(pt, 1e-5, 1.0 - 1e-5)
        loss0 = - alpha * (1 - pt) ** self.gamma * label1 * torch.log(pt) - \
                (1 - alpha) * pt ** self.gamma * (1 - label1) * torch.log(1 - pt)

        pt_se = torch.sigmoid(pred2)
        pt_se = torch.clamp(pt_se, 1e-5, 1.0 - 1e-5)
        loss_single = - alpha * (1 - pt_se) ** self.gamma * label2 * torch.log(pt_se) - \
                      (1 - alpha) * pt_se ** self.gamma * (1 - label2) * torch.log(1 - pt_se)

        # Feature-similarity loss: squared feature distance for pairs that share a label.
        same_label_pos = label1 * label2
        same_label_neg = (1 - label1) * (1 - label2)
        same_label = same_label_pos + same_label_neg
        feature_diff = feature1 - feature2
        base_loss = feature_diff ** 2
        base_loss = torch.sum(base_loss, dim=1)
        feature_sim_loss = same_label * base_loss

        # Auxiliary BCE-with-logits loss on pred0 against label0.
        F_cls_loss = F.binary_cross_entropy_with_logits(pred0, label0, reduction='none')

        # Note: loss0 is computed above but is not included in the total below.
        loss = loss_single + feature_loss_weights * feature_sim_loss + self.lambda_val * F_cls_loss
        if self.reduction == 'elementwise_mean':
            loss = torch.mean(loss)
        elif self.reduction == 'sum':
            loss = torch.sum(loss)
        return loss
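
if __name__ == '__main__':
    # Minimal smoke test (illustrative only): run the hypothetical usage sketches
    # defined above on random data and print the resulting scalar losses. The
    # wrapping of each loss input mirrors the assumptions documented in the sketches.
    torch.manual_seed(0)
    print('FocalLoss     :', _focal_loss_usage_sketch().item())
    print('BCEFocalLoss  :', _bce_focal_loss_usage_sketch().item())
    print('SE_FocalLoss  :', _se_focal_loss_usage_sketch().item())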