# -*- coding:utf-8 -*-
import math
import torch
import numpy as np
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable

class FocalLoss(nn.Module):
    r"""
        This criterion is a implemenation of Focal Loss, which is proposed in 
        Focal Loss for Dense Object Detection.

            Loss(x, class) = - \alpha (1-softmax(x)[class])^gamma \log(softmax(x)[class])

        The losses are averaged across observations for each minibatch.

        Args:
            alpha(1D Tensor, Variable) : the scalar factor for this criterion
            gamma(float, double) : gamma > 0; reduces the relative loss for well-classified examples (p > .5), 
                                   putting more focus on hard, misclassified examples
            size_average(bool): By default, the losses are averaged over observations for each minibatch.
                                However, if the field size_average is set to False, the losses are
                                instead summed for each minibatch.


    """
    def __init__(self, class_num, alpha=None, gamma=2, size_average=False,show_loss=False,sigmoid_choice=False,alpha_multi_ratio=None,convert_softmax=True):
        super(FocalLoss, self).__init__()
        if alpha is None:
            self.alpha = Variable(torch.ones(class_num, 1))
        else:
            if isinstance(alpha, Variable):
                self.alpha = alpha
            else:
                # Invert the per-class counts/weights and normalize them to sum to 1.
                new_alpha = [1.0 / val for val in alpha]
                sum_val = sum(new_alpha)
                new_alpha = [val / float(sum_val) for val in new_alpha]
                if alpha_multi_ratio is not None:
                    new_alpha = [x * y for x, y in zip(new_alpha, alpha_multi_ratio)]
                self.alpha = torch.tensor(new_alpha)
  
        self.gamma = gamma
        self.class_num = class_num
        self.size_average = size_average
        self.show_loss = show_loss
        self.sigmoid_choice = sigmoid_choice
        self.convert_softmax = convert_softmax
        
    def forward(self, loss_input):

        inputs, targets, _ = loss_input

        # The caller packs the actual tensors inside sequences; unwrap them.
        inputs = inputs[0]
        targets = targets[0]

        N = inputs.size(0)
        C = inputs.size(1)

        if self.convert_softmax:
            if not self.sigmoid_choice:
                P = torch.softmax(inputs, dim=1)
            else:
                P = torch.sigmoid(inputs)
        else:
            P = inputs

        # Build a one-hot mask selecting the target class of every sample.
        class_mask = inputs.data.new(N, C).fill_(0)
        class_mask = Variable(class_mask)
        ids = targets.view(-1, 1)
        class_mask.scatter_(1, ids.data, 1.)

        if inputs.is_cuda and not self.alpha.is_cuda:
            self.alpha = self.alpha.cuda()
        alpha = self.alpha[ids.data.view(-1)]

        # Probability assigned to the target class, clamped so that log() stays finite.
        probs = (P * class_mask).sum(1).view(-1, 1).squeeze(-1)
        probs = torch.clamp(probs, 1e-5, 1.0 - 1e-5)
        log_p = probs.log()

        batch_loss = -alpha * (torch.pow((1 - probs), self.gamma)) * log_p

        if self.show_loss:
            print('probs: ', probs)
            print('batch_loss: ', batch_loss)

        if self.size_average:
            loss = batch_loss.mean()
        else:
            loss = batch_loss.sum()

        return loss
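
# A minimal usage sketch for FocalLoss. The packing convention below is an assumption
# inferred from forward(): loss_input is a 3-tuple whose first two items are sequences
# holding the [N, C] logits and the [N] integer class ids.
#
#     criterion = FocalLoss(class_num=3, alpha=[500, 200, 100], gamma=2, size_average=True)
#     logits = torch.randn(4, 3)
#     labels = torch.randint(0, 3, (4,))
#     loss = criterion(([logits], [labels], None))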

class FocalLossBin(nn.Module):
    def __init__(self, alpha=1, gamma=2, logits=False, reduce=True):
        super(FocalLossBin, self).__init__()
        self.alpha = alpha
        self.gamma = gamma
        self.logits = logits
        self.reduce = reduce

    def forward(self, loss_input):
        
        
        inputs, targets = loss_input[0], loss_input[1]

        if self.logits:
            BCE_loss = F.binary_cross_entropy_with_logits(inputs, targets, reduction='none')
        else:
            BCE_loss = F.binary_cross_entropy(inputs, targets, reduction='none')
        # pt is the probability assigned to the target class: the predicted probability
        # when the target is 1, and (1 - predicted probability) when the target is 0.
        pt = torch.exp(-BCE_loss)
        F_loss = self.alpha * (1-pt)**self.gamma * BCE_loss

        if self.reduce:
            return torch.mean(F_loss)
        else:
            return F_loss
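
# A minimal usage sketch for FocalLossBin, assuming raw logits and float 0/1 targets
# of the same shape:
#
#     criterion = FocalLossBin(alpha=1.0, gamma=2, logits=True, reduce=True)
#     logits = torch.randn(8)
#     targets = torch.randint(0, 2, (8,)).float()
#     loss = criterion((logits, targets))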

class BCEFocalLoss(torch.nn.Module):
    """
    二分类的Focalloss alpha 固定
    """
    def __init__(self, gamma=2, alpha=0.25, reduction='sum'):
        super().__init__()
        self.gamma = gamma
        self.alpha = alpha
        self.reduction = reduction
 
    def forward(self, loss_input):      
        _input, target = loss_input[0],loss_input[1]
        
        pt = torch.sigmoid(_input)
        pt = torch.clamp(pt,1e-5,1.0-(1e-5))
        alpha = self.alpha
        log_pt,minus_log_pt = torch.log(pt),torch.log(1-pt)
        loss = - alpha * (1 - pt) ** self.gamma * target * log_pt - \
               (1 - alpha) * pt ** self.gamma * (1 - target) * minus_log_pt

        if self.reduction == 'elementwise_mean':
            loss = torch.mean(loss)
        elif self.reduction == 'sum':
            loss = torch.sum(loss)
        return loss
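
# A minimal usage sketch for BCEFocalLoss, assuming raw logits and float 0/1 targets:
#
#     criterion = BCEFocalLoss(gamma=2, alpha=0.25, reduction='elementwise_mean')
#     logits = torch.randn(8)
#     targets = torch.randint(0, 2, (8,)).float()
#     loss = criterion((logits, targets))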

class SeLossWithBCEFocalLoss(torch.nn.Module):
    def __init__(self, gamma=2, alpha=0.25, reduction='sum',lambda_val=1,prob_diff=False):
        super().__init__()
        self.gamma = gamma
        self.alpha = alpha
        self.reduction = reduction
        self.lambda_val = lambda_val
        self.prob_diff_choice = prob_diff
    
    def forward(self,loss_input):
        '''
        Focal classification loss on both predictions, plus a consistency term on
        pairs whose labels agree.
        '''
        preds,labels,_ = loss_input
        pred1,pred2 = preds
        label1,label2 = labels
        
        alpha = self.alpha


        pt = torch.sigmoid(pred1)
        pt = torch.clamp(pt, 1e-5, 1.0 - 1e-5)
        loss0 = - alpha * (1 - pt) ** self.gamma * label1 * torch.log(pt) - \
               (1 - alpha) * pt ** self.gamma * (1 - label1) * torch.log(1 - pt)
        pt_se = torch.sigmoid(pred2)
        pt_se = torch.clamp(pt_se, 1e-5, 1.0 - 1e-5)
        loss1 = - alpha * (1 - pt_se) ** self.gamma * label2 * torch.log(pt_se) - \
               (1 - alpha) * pt_se ** self.gamma * (1 - label2) * torch.log(1 - pt_se)

        same_label_pos = label1*label2
        same_label_neg = (1-label1)*(1-label2)
        if not self.prob_diff_choice:
            pred_diff = (pred1-pred2)**2
        else:
            pred_diff = (pt-pt_se)**2
        pos_se_loss = same_label_pos*pred_diff
        neg_se_loss = same_label_neg*pred_diff
        loss = alpha*pos_se_loss+(1-alpha)*neg_se_loss
        loss = (loss0+loss1) + self.lambda_val*loss

        if self.reduction == 'elementwise_mean':
            loss = torch.mean(loss)
        elif self.reduction == 'sum':
            loss = torch.sum(loss)
        return loss
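
# A minimal usage sketch for SeLossWithBCEFocalLoss. The pairing convention is an
# assumption inferred from forward(): loss_input is ((pred1, pred2), (label1, label2), extra),
# where pred1/pred2 are logits for two views of the same samples.
#
#     criterion = SeLossWithBCEFocalLoss(gamma=2, alpha=0.25, reduction='sum', lambda_val=1)
#     pred1, pred2 = torch.randn(8), torch.randn(8)
#     label1, label2 = torch.randint(0, 2, (8,)).float(), torch.randint(0, 2, (8,)).float()
#     loss = criterion(((pred1, pred2), (label1, label2), None))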


class PeerLoss(torch.nn.Module):
    def __init__(self, gamma=2, alpha=0.25, reduction='elementwise_mean',balance_alpha=0.05,device=None):
        super().__init__()
        self.gamma = gamma
        self.alpha = alpha
        self.reduction = reduction
        self.balance_alpha = balance_alpha
        self.device=device

    def _calculateLoss(self, _input, target):
        pt = torch.sigmoid(_input)
        pt = torch.clamp(pt, 1e-5, 1.0 - 1e-5)
        alpha = self.alpha
        loss = - alpha * (1 - pt) ** self.gamma * target * torch.log(pt) - \
               (1 - alpha) * pt ** self.gamma * (1 - target) * torch.log(1 - pt)

        if self.reduction == 'elementwise_mean':
            loss = torch.mean(loss)
        elif self.reduction == 'sum':
            loss = torch.sum(loss)
        return loss

    def forward(self, _input, target):
        loss_0 = self._calculateLoss(_input, target)
        # Peer term: the same focal loss evaluated against randomly drawn labels,
        # subtracted with a small weight to discourage fitting label noise.
        random_label = np.random.choice([0, 1], size=_input.size(-1))[np.newaxis, :]
        random_label = torch.from_numpy(random_label)
        device = self.device if self.device is not None else \
            torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
        random_label = random_label.to(device=device, dtype=torch.float32)
        loss_1 = self._calculateLoss(_input, random_label)
        return loss_0 - self.balance_alpha * loss_1
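
# A minimal usage sketch for PeerLoss (logits and targets are assumed to share a shape
# and to live on the device selected inside forward()):
#
#     criterion = PeerLoss(gamma=2, alpha=0.25, balance_alpha=0.05)
#     logits = torch.randn(8, 1)
#     targets = torch.randint(0, 2, (8, 1)).float()
#     loss = criterion(logits, targets)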
    

class MIFocalLoss(torch.nn.Module):
    """
    二分类的Focalloss alpha 固定
    """
    def __init__(self, gamma=2, alpha=0.25, reduction='sum'):
        super().__init__()
        self.gamma = gamma
        self.alpha = alpha
        self.reduction = reduction
 
    def forward(self, _input, target):
        # Shift the logit by log(1 - alpha) before the sigmoid.
        p0 = self.alpha
        log_p1 = math.log(1 - p0)
        pt = torch.sigmoid(_input + log_p1)
        pt = torch.clamp(pt, 1e-5, 1.0 - 1e-5)
        alpha = self.alpha
        loss = - alpha * (1 - pt) ** self.gamma * target * torch.log(pt) - \
               (1 - alpha) * pt ** self.gamma * (1 - target) * torch.log(1 - pt)

        if self.reduction == 'elementwise_mean':
            loss = torch.mean(loss)
        elif self.reduction == 'sum':
            loss = torch.sum(loss)
        # print ('loss is ',loss)
        return loss
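
# A minimal usage sketch for MIFocalLoss, assuming raw logits and float 0/1 targets:
#
#     criterion = MIFocalLoss(gamma=2, alpha=0.25, reduction='sum')
#     logits = torch.randn(8)
#     targets = torch.randint(0, 2, (8,)).float()
#     loss = criterion(logits, targets)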
    
    
    
class SE_FocalLoss(nn.Module):
    def __init__(self, class_num, alpha=None, gamma=2, size_average=False,lambda_val=1,prob_diff=True,gradient_boost=False,k=1,alpha_multi_ratio=None,convert_softmax=True,margin=0,new_se_loss=False):
        super(SE_FocalLoss, self).__init__()
        
        if alpha is None:
            self.alpha = Variable(torch.ones(class_num, 1))
        else:
            if isinstance(alpha, Variable):
                self.alpha = alpha
            else:
                # Invert the per-class counts/weights and normalize them to sum to 1.
                new_alpha = [1.0 / val for val in alpha]
                sum_val = sum(new_alpha)
                new_alpha = [val / float(sum_val) for val in new_alpha]
                if alpha_multi_ratio is not None:
                    new_alpha = [x * y for x, y in zip(new_alpha, alpha_multi_ratio)]
                self.alpha = torch.tensor(new_alpha)
        self.gamma = gamma
        self.class_num = class_num
        self.size_average = size_average
        self.lambda_val = lambda_val
        self.prob_diff = prob_diff
        self.k = k
        self.gradient_boost = gradient_boost
        self.convert_softmax = convert_softmax
        self.margin = margin
        self.new_se_loss = new_se_loss
        
        
        
    def _maskOutEasyClasses(self,probs,label,e_val=1e-6):
        zero_vals = torch.ones_like(probs) *e_val
        label_indices_full = torch.zeros_like(probs)
        ids = label.view(-1, 1)
        label_indices_full.scatter_(1, ids.data, 1.)
        zeroOut_probs = torch.where(label_indices_full == 0, probs, zero_vals)

        # Keep the target class plus the top-k highest-probability competing classes.
        values, indices = zeroOut_probs.topk(k=self.k, dim=1, largest=True)
        indices = indices.view(-1,1)
        label_indices_full.scatter_(1,indices.data,1.)
        
        probs = torch.where(label_indices_full>0,probs,zero_vals)
        return probs
    
    def _exp(self,probs):  
        values,indices = probs.topk(k=self.k,dim=1,largest =True)
        new_probs = probs - values
        exp_probs = torch.exp(new_probs)
        return exp_probs
        
        
    def GetLogProb(self, pred, label):
        N1 = pred.size(0)
        C1 = pred.size(1)
        if not self.gradient_boost:
            if self.convert_softmax:
                P1 = torch.softmax(pred, dim=1)
            else:
                P1 = pred
            P1 = torch.clamp(P1, 1e-5, 1.0 - 1e-5)
        else:
            # Renormalize over the target class and the hardest competing classes only.
            P1 = self._exp(pred)
            zeroOutProbs = self._maskOutEasyClasses(P1, label)
            total_loss = torch.sum(zeroOutProbs, dim=1)
            total_loss = total_loss.repeat((C1, 1)).transpose(0, 1)
            P1 = zeroOutProbs / total_loss

        # One-hot mask selecting the target class; a fresh all-zero mask is also stored
        # in self.class_mask for the consistency term computed later.
        class_mask_1 = pred.data.new(N1, C1).fill_(0)
        class_mask_1 = Variable(class_mask_1)
        self.class_mask = Variable(pred.data.new(N1, C1).fill_(0))
        ids1 = label.view(-1, 1)
        class_mask_1.scatter_(1, ids1.data, 1.)

        probs1 = (P1 * class_mask_1).sum(1).view(-1, 1)
        log_p1 = probs1.log()
        if pred.is_cuda and not self.alpha.is_cuda:
            self.alpha = self.alpha.cuda()
        alpha = self.alpha[ids1.data.view(-1)]
        return ids1, P1, probs1, log_p1.squeeze(-1), alpha
        
    def _CalculateSamePosProbLoss(self, P1, P2, ids1, ids2):
        # Only pairs whose labels agree contribute to the consistency term.
        equal = torch.eq(ids1, ids2)
        ids = torch.zeros_like(ids1)
        ids = torch.where(equal > 0, ids1, ids)
        
        class_mask = self.class_mask
        class_mask.scatter_(1, ids.data, 1.)

        same_probs1 = (P1*class_mask).sum(1).view(-1,1)
        same_probs2 = (P2*class_mask).sum(1).view(-1,1)
        # print('_CalculateSamePosProbLoss -- same_probs1: ', same_probs1)
        # print('_CalculateSamePosProbLoss -- same_probs2: ', same_probs2)

        zero_prob = torch.zeros_like(same_probs1)
        
        same_log_p1 = torch.where(equal>0,same_probs1,zero_prob)
        same_log_p2 = torch.where(equal>0,same_probs2,zero_prob)
        # print('_CalculateSamePosProbLoss -- same_log_p1: ', same_log_p1)
        # print('_CalculateSamePosProbLoss -- same_log_p2: ', same_log_p2)

        diff_prob = (same_log_p1-same_log_p2)**2
        diff_prob = diff_prob.squeeze(1)
        # print('_CalculateSamePosProbLoss -- diff_prob: ', diff_prob)
        
        alpha = self.alpha[ids.data.view(-1)]
        zero_alpha = torch.zeros_like(self.alpha[ids.data.view(-1)])
        equal = equal.squeeze(1)
        
        alpha = torch.where(equal>0,alpha,zero_alpha)
        # print('_CalculateSamePosProbLoss -- alpha: ', alpha)
        
        batch_diff_loss = -alpha*diff_prob
 
        return batch_diff_loss

    def _CalculateSamePosProbLossNew(self,P1,P2,ids1,ids2):
        equal = torch.eq(ids1,ids2)
        
        ########### generate class_mask
        class_mask = self.class_mask
        class_mask.scatter_(1,ids1.data,1)
        
        target_class_prob1 = (P1*class_mask).sum(1).view(-1,1)
        target_class_prob2 = (P2*class_mask).sum(1).view(-1,1)
        
        diff_prob = (target_class_prob1-target_class_prob2)**2
        diff_prob = diff_prob.squeeze(1)
        diff_with_margin = torch.clamp(self.margin-torch.abs(target_class_prob1-target_class_prob2),0)**2
        
        deux_loss = torch.where(equal>0,diff_prob,diff_with_margin)
        return deux_loss
        
        
        
    def forward(self, loss_input):
        preds, labels, _ = loss_input
        pred1, pred2 = preds
        label1, label2 = labels

        ids1, p1, mask_p1, log_p1, alpha1 = self.GetLogProb(pred1, label1)
        ids2, p2, mask_p2, log_p2, alpha2 = self.GetLogProb(pred2, label2)

        if pred1.is_cuda and not self.alpha.is_cuda:
            self.alpha = self.alpha.cuda()

        if not self.new_se_loss:
            if self.prob_diff:
                se_loss = self._CalculateSamePosProbLoss(p1,p2,ids1,ids2)
                # print('se loss: ', se_loss)
            else:
                se_loss = self._CalculateSamePosProbLoss(pred1,pred2,ids1,ids2)
        else:
            if self.prob_diff:
                se_loss = self._CalculateSamePosProbLossNew(p1,p2,ids1,ids2)
            else:
                se_loss = self._CalculateSamePosProbLossNew(pred1,pred2,ids1,ids2)
    
        
        # Focal classification loss on both branches plus the weighted consistency term.
        batch_loss01 = -alpha1 * (torch.pow((1 - mask_p1), self.gamma).squeeze(1)) * log_p1
        batch_loss02 = -alpha2 * (torch.pow((1 - mask_p2), self.gamma).squeeze(1)) * log_p2

        batch_loss = batch_loss01 + batch_loss02 + self.lambda_val * se_loss

        if self.size_average:
            loss = batch_loss.mean()
        else:
            loss = batch_loss.sum()
        return loss
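
# A minimal usage sketch for SE_FocalLoss. The pairing convention is an assumption
# inferred from forward(): loss_input is ((pred1, pred2), (label1, label2), extra),
# with [N, C] logits and [N] integer class ids per branch.
#
#     criterion = SE_FocalLoss(class_num=3, alpha=[500, 200, 100], gamma=2, size_average=True)
#     pred1, pred2 = torch.randn(4, 3), torch.randn(4, 3)
#     label1, label2 = torch.randint(0, 3, (4,)), torch.randint(0, 3, (4,))
#     loss = criterion(((pred1, pred2), (label1, label2), None))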


class FeatureSeLossWithBCEFocalLoss(torch.nn.Module):
    def __init__(self, gamma=2, alpha=0.25, reduction='sum',lambda_val=1,prob_diff=False):
        super().__init__()
        self.gamma = gamma
        self.alpha = alpha
        self.reduction = reduction
        self.lambda_val = lambda_val
    
    def forward(self,loss_input):
        '''
        Focal classification loss, a feature-similarity term for pairs with matching
        labels, and a BCE term on the pair-comparison output.
        '''
        preds,features,cons,labels,labels_diff = loss_input
        pred1,pred2 = preds
        feature1,feature2 = features
        pred0 = cons[0]
        label1,label2 = labels
        label0 = labels_diff[0]
        
        feature_loss_weights = 0.25
        alpha = self.alpha
        ####### single case cls loss
        pt = torch.sigmoid(pred1)
        pt = torch.clamp(pt, 1e-5, 1.0 - 1e-5)
        loss0 = - alpha * (1 - pt) ** self.gamma * label1 * torch.log(pt) - \
               (1 - alpha) * pt ** self.gamma * (1 - label1) * torch.log(1 - pt)
        pt_se = torch.sigmoid(pred2)
        pt_se = torch.clamp(pt_se, 1e-5, 1.0 - 1e-5)
        loss_single = - alpha * (1 - pt_se) ** self.gamma * label2 * torch.log(pt_se) - \
               (1 - alpha) * pt_se ** self.gamma * (1 - label2) * torch.log(1 - pt_se)

        ####### Feature similar loss
        same_label_pos = label1*label2
        same_label_neg = (1-label1)*(1-label2)
        same_label = same_label_pos+same_label_neg
        
        feature_diff = feature1-feature2
        
        base_loss = feature_diff**2
        base_loss = torch.sum(base_loss,dim=1)
        
                      
        feature_sim_loss = same_label*base_loss

        F_cls_loss = F.binary_cross_entropy_with_logits(pred0, label0, reduction='none')
        loss = loss_single + feature_loss_weights*feature_sim_loss+self.lambda_val*(F_cls_loss)

        if self.reduction == 'elementwise_mean':
            loss = torch.mean(loss)
        elif self.reduction == 'sum':
            loss = torch.sum(loss)
        return loss
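
# A minimal usage sketch for FeatureSeLossWithBCEFocalLoss. The packing below is an
# assumption inferred from forward(): loss_input is (preds, features, cons, labels,
# labels_diff); label0 is assumed to mark whether the pair shares a label.
#
#     criterion = FeatureSeLossWithBCEFocalLoss(gamma=2, alpha=0.25, reduction='sum', lambda_val=1)
#     pred1, pred2, pred0 = torch.randn(8), torch.randn(8), torch.randn(8)
#     feat1, feat2 = torch.randn(8, 16), torch.randn(8, 16)
#     label1, label2 = torch.randint(0, 2, (8,)).float(), torch.randint(0, 2, (8,)).float()
#     label0 = (label1 == label2).float()
#     loss = criterion(((pred1, pred2), (feat1, feat2), (pred0,), (label1, label2), (label0,)))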