import torch import torch.nn as nn class BCEFocalLoss(nn.Module): def __init__(self, gamma=2, alpha=0.5, reduction='elementwise_mean') -> None: super().__init__() self.gamma = gamma self.alpha = alpha self.reduction = reduction def forward(self, _input, target): pt = torch.sigmoid(_input) alpha = self.alpha