import torch
from torch import nn

from .center_loss import _sigmoid, FocalLoss, RegL1Loss


class Loss(nn.Module):
    """Sum an RPN loss with an optional second-stage (RCNN) box loss.

    Args:
        rpn_loss_fun: callable invoked as
            ``rpn_loss_fun(rpn_logits_flat, rpn_deltas_flat, rpn_labels,
            rpn_label_weights, rpn_targets, rpn_target_weights)`` and expected
            to return a ``(cls_loss, reg_loss, stats)`` triple.
        bbox_loss_fun: callable invoked as
            ``bbox_loss_fun(rcnn_logits, rcnn_deltas, rcnn_labels, rcnn_targets)``
            and expected to return a ``(cls_loss, reg_loss, stats)`` triple.
            Only called when ``use_rcnn`` is true.
        use_rcnn: whether :meth:`forward` evaluates ``bbox_loss_fun``.
    """

    def __init__(self, rpn_loss_fun, bbox_loss_fun, use_rcnn):
        super().__init__()
        self.rpn_loss_fun = rpn_loss_fun
        self.bbox_loss_fun = bbox_loss_fun
        # Zero placeholders reported (and added to the total) while the RCNN
        # head is disabled. They are placed on CUDA when a GPU is present so
        # GPU users keep the same device, but fall back to CPU instead of
        # crashing on machines without CUDA.
        device = 'cuda' if torch.cuda.is_available() else 'cpu'
        self.rcnn_cls_loss = torch.tensor(0., device=device)
        self.rcnn_reg_loss = torch.tensor(0., device=device)
        self.rcnn_stats = None
        self.use_rcnn = use_rcnn

    def forward(self, rpn_logits_flat, rpn_deltas_flat, rpn_labels,
                rpn_label_weights, rpn_targets, rpn_target_weights,
                rcnn_logits=None, rcnn_deltas=None, rcnn_labels=None, rcnn_targets=None):
        """Compute the combined loss.

        The RPN arguments are forwarded unchanged to ``rpn_loss_fun`` and the
        ``rcnn_*`` arguments to ``bbox_loss_fun`` (only when ``use_rcnn``).

        Returns:
            A tuple ``(total_loss, [rpn_cls_loss, rpn_reg_loss, rcnn_cls_loss,
            rcnn_reg_loss], rpn_stats, rcnn_stats)``. The RCNN entries are the
            zero placeholders and ``None`` stats when the RCNN head has never
            been evaluated.
        """
        rpn_cls_loss, rpn_reg_loss, rpn_stats = self.rpn_loss_fun(
            rpn_logits_flat, rpn_deltas_flat, rpn_labels,
            rpn_label_weights, rpn_targets, rpn_target_weights)

        if self.use_rcnn:
            # Results are cached on the module so they remain readable as
            # attributes after the call.
            self.rcnn_cls_loss, self.rcnn_reg_loss, self.rcnn_stats = self.bbox_loss_fun(
                rcnn_logits, rcnn_deltas, rcnn_labels, rcnn_targets)

        total_loss = rpn_cls_loss + rpn_reg_loss + self.rcnn_cls_loss + self.rcnn_reg_loss

        return total_loss, [rpn_cls_loss, rpn_reg_loss, self.rcnn_cls_loss,
                            self.rcnn_reg_loss], rpn_stats, self.rcnn_stats


class CT_Loss(nn.Module):
    """Center-point detection loss: heatmap loss plus two L1 regression terms.

    The total is ``hm_weight * hm_loss + wh_weight * reg_loss +
    off_weight * off_loss``. The defaults (1, 0.1, 1) reproduce the previously
    hard-coded weighting.

    Args:
        hm_weight: weight of the heatmap term (``FocalLoss``).
        wh_weight: weight of the size/diameter regression term (``RegL1Loss``).
        off_weight: weight of the center-offset regression term (``RegL1Loss``).
    """

    def __init__(self, hm_weight=1.0, wh_weight=0.1, off_weight=1.0):
        super().__init__()
        self.crit = FocalLoss()
        self.crit_reg = RegL1Loss()  # RegLoss()
        self.hm_weight = hm_weight
        self.wh_weight = wh_weight
        self.off_weight = off_weight

    @staticmethod
    def _zero_if_nan(value):
        """Return ``value``, or a graph-free zero on the same device if it is NaN.

        A fresh tensor is returned instead of overwriting ``value.data``.
        Overwriting ``.data`` keeps the original ``grad_fn``, so ``backward()``
        would still propagate NaN gradients from the failed computation.
        """
        if torch.isnan(value.detach()).any():
            return torch.zeros_like(value)
        return value

    def forward(self, rpn_logits, rpn_deltas, rpn_offsets, gaussian_hm, center_idxs, bboxes_diameters, reg_offset,
                reg_mask):
        """Compute the weighted loss.

        Args:
            rpn_logits: raw heatmap logits; ``_sigmoid`` is applied before the
                focal loss. Presumably the same shape as ``gaussian_hm`` —
                TODO confirm against ``FocalLoss``.
            rpn_deltas: predicted sizes, regressed against ``bboxes_diameters``.
            rpn_offsets: predicted center offsets, regressed against ``reg_offset``.
            gaussian_hm: target heatmap for the focal loss.
            center_idxs: indices passed to ``RegL1Loss`` to gather predictions
                (assumed to be flattened center locations — verify).
            bboxes_diameters: size regression targets.
            reg_offset: offset regression targets.
            reg_mask: mask passed to ``RegL1Loss`` selecting valid objects.

        Returns:
            ``(loss, [hm_loss, reg_loss, off_loss])``. A regression term that
            evaluates to NaN (e.g. no valid objects) is replaced by a zero that
            contributes no gradient.
        """
        hm_loss = self.crit(_sigmoid(rpn_logits), gaussian_hm)

        reg_loss = self._zero_if_nan(
            self.crit_reg(rpn_deltas, reg_mask, center_idxs, bboxes_diameters))

        off_loss = self._zero_if_nan(
            self.crit_reg(rpn_offsets, reg_mask, center_idxs, reg_offset))

        loss = self.hm_weight * hm_loss + self.wh_weight * reg_loss + self.off_weight * off_loss

        loss_stats = [hm_loss, reg_loss, off_loss]

        return loss, loss_stats