import torch
import torch.nn.functional as F
from torch import nn
def weighted_focal_loss_for_cross_entropy(logits, labels, weights, gamma=2.):
    log_probs = F.log_softmax(logits, dim=1).gather(1, labels)
    # Detach the focal modulating factor so it carries no gradient,
    # matching weighted_focal_loss_with_logits below.
    probs = F.softmax(logits, dim=1).gather(1, labels).detach()
    loss = -log_probs * (1 - probs) ** gamma
    # Weighted mean over samples; the epsilon guards against all-zero weights.
    return (weights * loss).sum() / (weights.sum() + 1e-12)
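# A minimal call sketch. The argument shapes are assumptions read off the
# gather(1, labels) indexing above; they are not documented elsewhere:
#   logits  -- (N, C) raw class scores
#   labels  -- (N, 1) int64 true-class indices
#   weights -- (N, 1) per-sample weights
# e.g.: weighted_focal_loss_for_cross_entropy(
#           torch.randn(4, 3), torch.randint(3, (4, 1)), torch.ones(4, 1))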
def binary_cross_entropy_with_hard_negative_mining(logits, labels, weights, batch_size, num_hard=2, balanced=True):
    classify_loss = nn.BCELoss() if balanced else nn.BCELoss(reduction='sum')
    probs = torch.sigmoid(logits)[:, 0].view(-1, 1)
    pos_idcs = labels[:, 0] == 1
    pos_prob = probs[pos_idcs, 0]
    pos_labels = labels[pos_idcs, 0]
    # Anchors with zero weight are excluded for one of two reasons:
    # 1. they were not among the negative boxes randomly pre-sampled for OHEM, or
    # 2. they overlap a ground-truth box; to keep sensitivity high we do not
    #    count those as negatives, so they contribute nothing to the loss.
    neg_idcs = (labels[:, 0] == 0) & (weights[:, 0] != 0)
    neg_prob = probs[neg_idcs, 0]
    neg_labels = labels[neg_idcs, 0]
    if num_hard > 0:
        # Keep only the hardest negatives: num_hard per positive, or
        # num_hard per batch element when there are no positives.
        if len(pos_prob) > 0:
            neg_prob, neg_labels = OHEM(neg_prob, neg_labels, num_hard * len(pos_prob))
        else:
            neg_prob, neg_labels = OHEM(neg_prob, neg_labels, num_hard * batch_size)
    pos_correct = 0
    pos_total = 0
    if len(pos_prob) > 0:
        cls_loss = 0.5 * classify_loss(pos_prob, pos_labels.float()) \
                 + 0.5 * classify_loss(neg_prob, neg_labels.float())
        pos_correct = (pos_prob >= 0.5).sum()
        pos_total = len(pos_prob)
        if not balanced:
            cls_loss = cls_loss / ((num_hard + 1) * len(pos_prob))
    else:
        cls_loss = 0.5 * classify_loss(neg_prob, neg_labels.float())
        cls_loss = cls_loss / batch_size
    neg_correct = (neg_prob < 0.5).sum()
    neg_total = len(neg_prob)
    return cls_loss, pos_correct, pos_total, neg_correct, neg_total
def OHEM(neg_output, neg_labels, num_hard):
    # Online hard example mining: keep the num_hard negatives with the
    # highest predicted probability, i.e. the most confidently wrong ones.
    _, idcs = torch.topk(neg_output, min(num_hard, len(neg_output)))
    neg_output = torch.index_select(neg_output, 0, idcs)
    neg_labels = torch.index_select(neg_labels, 0, idcs)
    return neg_output, neg_labels
def weighted_focal_loss_with_logits(logits, labels, weights, gamma=2.):
    log_probs = F.logsigmoid(logits)
    probs = torch.sigmoid(logits)
    pos_logprobs = log_probs[labels == 1]
    # log(1 - sigmoid(x)) == logsigmoid(-x); the latter is numerically stable.
    neg_logprobs = F.logsigmoid(-logits[labels == 0])
    pos_probs = probs[labels == 1]
    neg_probs = 1 - probs[labels == 0]
    pos_weights = weights[labels == 1]
    neg_weights = weights[labels == 0]
    # Detach the focal modulating factors so they carry no gradient.
    pos_probs = pos_probs.detach()
    neg_probs = neg_probs.detach()
    pos_loss = -pos_logprobs * (1 - pos_probs) ** gamma
    neg_loss = -neg_logprobs * (1 - neg_probs) ** gamma
    loss = ((pos_loss * pos_weights).sum() + (neg_loss * neg_weights).sum()) / (weights.sum() + 1e-12)
    pos_correct = (probs[labels != 0] > 0.5).sum()
    pos_total = (labels != 0).sum()
    neg_correct = (probs[labels == 0] < 0.5).sum()
    neg_total = (labels == 0).sum()
    return loss, pos_correct, pos_total, neg_correct, neg_total
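# Usage sketch (a smoke test, not part of the original module). The tensor
# shapes are assumptions inferred from the indexing above: one logit per
# anchor, with labels and weights shaped (N, 1).
if __name__ == '__main__':
    torch.manual_seed(0)
    logits = torch.randn(8, 1)
    labels = torch.tensor([[1.], [1.], [0.], [0.], [0.], [0.], [0.], [0.]])
    weights = torch.ones(8, 1)
    loss, pc, pt, nc, nt = binary_cross_entropy_with_hard_negative_mining(
        logits, labels, weights, batch_size=2)
    print(f'BCE+OHEM: loss={loss.item():.4f}, pos {int(pc)}/{int(pt)}, neg {int(nc)}/{int(nt)}')
    loss, pc, pt, nc, nt = weighted_focal_loss_with_logits(logits, labels, weights)
    print(f'focal:    loss={loss.item():.4f}, pos {int(pc)}/{int(pt)}, neg {int(nc)}/{int(nt)}')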