# -*- coding:utf-8 -*-
import math
import torch
import numpy as np
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable


def _soft_dice(pred, label, smooth):
    """Return the soft Dice coefficient of ``pred`` against ``label``.

    Both tensors are flattened, so the score pools every element of the
    tensor (all batch items together) rather than averaging per sample.
    ``smooth`` is added to numerator and denominator so that two empty
    tensors score 1 instead of producing 0/0.
    """
    y_pred_f = torch.flatten(pred)
    y_true_f = torch.flatten(label)
    intersection = torch.sum(torch.mul(y_pred_f, y_true_f))
    union = torch.sum(y_pred_f) + torch.sum(y_true_f)
    return (2 * intersection + smooth) / (union + smooth)


class SurfaceLoss(nn.Module):
    """Mean of sigmoid probabilities weighted element-wise by ``label``.

    ``label`` is multiplied with the predicted probabilities, so it is
    presumably a precomputed (signed) distance map aligned with ``pred``
    -- TODO confirm against the data pipeline.
    """

    def __init__(self):
        super(SurfaceLoss, self).__init__()

    def forward(self, pred, label):
        """Return ``mean(sigmoid(pred) * label)`` over all elements.

        Args:
            pred: raw logits; ``torch.sigmoid`` is applied here.
            label: weight / distance-map tensor, flattened like ``pred``.
        """
        prob_map_flatten = torch.flatten(torch.sigmoid(pred))
        distance_map_flatten = torch.flatten(label)
        return torch.mul(distance_map_flatten, prob_map_flatten).mean()


class Tversky_loss(nn.Module):
    """Tversky loss ``1 - TP / (TP + alpha*FP + beta*FN)`` on sigmoid outputs.

    ``alpha`` weights false positives and ``beta`` weights false negatives.
    ``alpha == beta == 0.5`` gives the soft Dice loss, and
    ``alpha == beta == 1`` gives the soft Jaccard (IoU) loss.
    All elements are pooled into a single score (no per-sample averaging).
    """

    def __init__(self, alpha, beta, smooth=1e-6):
        super(Tversky_loss, self).__init__()
        self.alpha = alpha
        self.beta = beta
        # Added to numerator and denominator to avoid 0/0 on empty masks.
        self.smooth_val = smooth

    def forward(self, pred, label):
        """Return the Tversky loss as a 0-d tensor.

        Args:
            pred: raw logits; ``torch.sigmoid`` is applied here.
            label: ground-truth mask, presumably binary {0, 1} -- TODO confirm.
        """
        y_pred_f = torch.flatten(torch.sigmoid(pred))
        y_true_f = torch.flatten(label)
        # Soft confusion-matrix counts.
        true_pos = torch.sum(torch.mul(y_pred_f, y_true_f))
        false_pos = torch.sum(torch.mul(1 - y_true_f, y_pred_f))
        false_neg = torch.sum(torch.mul(y_true_f, (1 - y_pred_f).float()))
        tversky = (true_pos + self.smooth_val) / (
            self.alpha * false_pos + self.beta * false_neg
            + true_pos + self.smooth_val)
        return 1 - tversky


class DiceLoss_TwoLabels(nn.Module):
    """Weighted sum of two soft Dice losses over nested regions of ``label``.

    Region 0 is ``label >= 1`` (called the "union" region) and region 1 is
    ``label >= 2`` (the "intersection" region). The same prediction is
    scored against both:
    ``loss = alpha * (1 - dice_region0) + beta * (1 - dice_region1)``.
    The labels presumably encode agreement between two annotators
    (0 = neither, 1 = one, 2 = both) -- TODO confirm against the data.
    """

    def __init__(self, alpha, beta, smooth=1e-6):
        super(DiceLoss_TwoLabels, self).__init__()
        self.alpha = alpha
        self.beta = beta
        # Added to numerator and denominator to avoid 0/0 on empty masks.
        self.smooth_val = smooth

    def _splitLabel(self, label, val):
        """Return a mask that is 1 where ``label >= val`` and 0 elsewhere.

        The mask has the same dtype and device as ``label``.
        """
        return torch.where(label >= val,
                           torch.ones_like(label),
                           torch.zeros_like(label))

    def _calculateDice(self, pred, label):
        """Return the soft Dice coefficient of probabilities ``pred`` vs ``label``."""
        return _soft_dice(pred, label, self.smooth_val)

    def forward(self, pred, label):
        """Return the combined two-region Dice loss.

        Args:
            pred: raw logits; ``torch.sigmoid`` is applied here.
            label: integer-valued label map thresholded at 1 and 2.
        """
        pred = torch.sigmoid(pred)
        # Loss over the whole union region (label >= 1).
        dice0 = self._calculateDice(pred, self._splitLabel(label, 1))
        # Loss over the intersection region (label >= 2).
        dice1 = self._calculateDice(pred, self._splitLabel(label, 2))
        return self.alpha * (1 - dice0) + self.beta * (1 - dice1)


class DiceLoss(nn.Module):
    """Soft Dice loss ``1 - dice(sigmoid(pred), label)`` pooled over all elements."""

    def __init__(self, alpha, beta, smooth=1e-6):
        super(DiceLoss, self).__init__()
        # NOTE: alpha and beta are stored but not used by forward().
        self.alpha = alpha
        self.beta = beta
        # Added to numerator and denominator to avoid 0/0 on empty masks.
        self.smooth_val = smooth

    def forward(self, pred, label):
        """Return the soft Dice loss as a 0-d tensor.

        Args:
            pred: raw logits; ``torch.sigmoid`` is applied here.
            label: ground-truth mask, presumably binary {0, 1} -- TODO confirm.
        """
        return 1 - _soft_dice(torch.sigmoid(pred), label, self.smooth_val)


if __name__ == "__main__":
    import os

    # Default the demo to GPU 7 without overriding a caller's choice.
    # Setting this after ``import torch`` still works because CUDA is
    # initialised lazily on first use.
    os.environ.setdefault("CUDA_VISIBLE_DEVICES", "7")
    a = np.array([[0, 0, 0], [0, 1, 0], [1, 0, 0]])
    b = np.array([[0, 1, 0], [0, 1, 0], [0, 1, 1]])
    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
    # Note: ``pred`` is treated as logits, so these 0/1 values become
    # probabilities 0.5 / ~0.73 inside the loss.
    pred = torch.from_numpy(a).to(device=device, dtype=torch.float32)
    label = torch.from_numpy(b).to(device=device, dtype=torch.float32)
    criterion = Tversky_loss(1, 1)
    result = criterion(pred, label)
    print('result is', result)