import random
import torch
import numpy as np
import torch.nn.functional as F
from loss import binary_cross_entropy_with_hard_negative_mining, weighted_focal_loss_for_cross_entropy
from ..util import box_transform_numpy, convert_xyxyzz, box_transform_inv, clip_boxes, convert_zyxdhw
from layers.overlap_3d import overlap_3d
from layers.nms_3d import nms_3d
# from custom.util import box_iou as overlap_3d
# from custom.rpn_nms import nms_3d_hky as nms_3d
from torch.autograd import Variable


class RpnModule(object):
    """Region-proposal-network helpers: proposal NMS, RPN loss and RPN target assignment.

    Boxes are described by ``make_one_rpn_target`` as ``[z, y, x, d, h, w]``.
    ``convert_xyxyzz`` is applied before ``overlap_3d`` / ``nms_3d`` and
    ``convert_zyxdhw`` is applied to NMS survivors before they are returned.
    """

    # Sure-background IoU ceiling used by ``make_one_rpn_target`` for each
    # supported ``diameter_range``. Anchors whose best IoU with every
    # ground-truth box is below this value are labelled background.
    _BG_THRESH_HIGH_BY_DIAMETER = {
        (20, 99999): 0.02,
        (12, 32): 0.1,
        (8, 16): 0.2,
    }

    def __init__(self, cfg, mode):
        """Read RPN settings from ``cfg`` for the given ``mode``.

        Args:
            cfg: config object; the ``DATA.DATA_PROCESS.*`` and ``MODEL.*``
                attributes read below must exist.
            mode: one of 'train', 'valid', 'test', 'deploy' (the historical
                spelling 'depoly' is also accepted).

        Raises:
            ValueError: for any other ``mode``.
        """
        self.mode = mode
        self.spacing = cfg.DATA.DATA_PROCESS.SPACING
        if self.mode in ('train', 'valid'):
            self.pre_nms_score_threshold = cfg.MODEL.RPN.TRAIN_PRE_NMS_SCORE_THRESH
        elif self.mode in ('test', 'depoly', 'deploy'):
            # 'depoly' looks like a misspelling of 'deploy'; both are accepted so
            # existing callers keep working.
            self.pre_nms_score_threshold = cfg.MODEL.RPN.TEST_PRE_NMS_SCORE_THRESH
            # Only needed by rpn_nms (inference path).
            self.pre_nms_top_k = cfg.MODEL.RPN.DEPLOY_PRE_NMS_TOP_K
        else:
            raise ValueError('invalid mode = %s?' % self.mode)
        self.nms_overlap_threshold = cfg.MODEL.RPN.NMS_IOU_THRESH
        self.box_reg_weight = cfg.MODEL.BBOX_REG_WEIGHT
        self.num_hard = cfg.DATA.DATA_PROCESS.NUM_HARD
        self.num_neg = cfg.DATA.DATA_PROCESS.NUM_NEG
        self.rpn_train_bg_thresh_high = cfg.MODEL.RPN.BG_THRESH_HIGH
        self.rpn_train_fg_thresh_low = cfg.MODEL.RPN.FG_THRESH_LOW

    def rpn_nms_train(self, inputs, window, logits, deltas):
        """Score-threshold, decode, clip and NMS the RPN output for every sample.

        Args:
            inputs: network input; ``inputs.shape[2:]`` is passed to
                ``clip_boxes`` as the clipping extent.
            window: anchor boxes, indexed along dim 0 in the same order as the
                anchors of ``logits`` / ``deltas``.
            logits: raw scores of shape (batch, num_anchors, C); channel 0 is
                passed through a sigmoid and used as the objectness score.
            deltas: regression deltas of shape (batch, num_anchors, D).

        Returns:
            Tensor of shape (num_proposals, 8), one row per kept box:
            ``[batch_index, score, box (6 values from convert_zyxdhw)]``.
        """
        device = logits.device
        scores_all = torch.sigmoid(logits)
        batch_size, _, _ = scores_all.size()
        proposals = []
        for b in range(batch_size):
            # Seed with an empty (0, 8) tensor so the cat below also works when
            # no anchor of this sample passes the score threshold.
            proposal = [torch.empty((0, 8), dtype=torch.float32, device=device)]
            ps = scores_all[b, :, 0].reshape(-1, 1)
            ds = deltas[b, :, :]
            # Keep the index strictly 1-D (possibly empty), even for 0 or 1 hits.
            index = (ps.reshape(-1) >= self.pre_nms_score_threshold).nonzero().reshape(-1)
            if index.nelement() > 0:
                p = torch.index_select(ps, 0, index).squeeze(1)
                d = torch.index_select(ds, 0, index)
                w = torch.index_select(window, 0, index)
                box = box_transform_inv(w, d, self.box_reg_weight)
                box = clip_boxes(box, inputs.shape[2:])
                box = convert_xyxyzz(box)
                keep = nms_3d(box, p, self.nms_overlap_threshold)
                res_box = convert_zyxdhw(torch.index_select(box, 0, keep))
                res_p = torch.index_select(p, 0, keep).unsqueeze(1)
                # Explicit float dtype: an int fill value would otherwise make
                # torch.full produce int64 and rely on promotion inside cat.
                b_tensor = torch.full((res_box.size(0), 1), float(b),
                                      dtype=res_box.dtype, device=device)
                proposal.append(torch.cat((b_tensor, res_p, res_box), dim=1))
            proposals.append(torch.cat(proposal, dim=0))
        return Variable(torch.cat(proposals, dim=0))

    def rpn_nms(self, inputs, window, logits, deltas):
        """Top-k, decode, clip and NMS for the first sample of the batch (inference path).

        Keeps the ``pre_nms_top_k`` highest-scoring anchors, or all of them if
        there are fewer. ``pre_nms_top_k`` is only set by ``__init__`` in the
        'test' / 'deploy' modes.

        Returns:
            Half-precision tensor of shape (num_proposals, 8):
            ``[batch_index (always 0), score, box (6 values)]``.
        """
        b = 0
        device = logits.device
        scores = torch.sigmoid(logits[b, :, 0]).reshape(-1)
        ds = deltas[b, :, :]
        # topk raises if k exceeds the number of candidates, so clamp it.
        k = min(self.pre_nms_top_k, scores.numel())
        p, index = scores.topk(k, dim=0, sorted=True)
        d = torch.index_select(ds, 0, index)
        w = torch.index_select(window, 0, index)
        box = box_transform_inv(w, d, self.box_reg_weight)
        # NOTE(review): rpn_nms_train passes ``inputs.shape[2:]`` to clip_boxes
        # while this passes ``inputs``; confirm which argument clip_boxes expects.
        box = clip_boxes(box, inputs)
        box = convert_xyxyzz(box)
        keep = nms_3d(box.float(), p.float(), self.nms_overlap_threshold)
        res_box = convert_zyxdhw(torch.index_select(box, 0, keep))
        res_p = torch.index_select(p, 0, keep).unsqueeze(1)
        b_tensor = torch.full((res_box.size(0), 1), float(b), device=device)
        prop = torch.cat((b_tensor.half(), res_p.half(), res_box.half()), dim=1)
        return Variable(prop)

    def loss(self, logits, deltas, labels, label_weights, targets, target_weights, delta_sigma=3.0):
        """RPN classification loss (with hard-negative mining) plus smooth-L1 regression loss.

        Args:
            logits: classification logits of shape (batch, num_windows, num_classes).
            deltas: predicted deltas, viewed as (batch * num_windows, 6).
            labels: per-anchor labels; anchors with a non-zero label are the
                only ones used for the regression loss.
            label_weights: per-anchor classification weights.
            targets: regression targets, viewed as (batch * num_windows, 6).
            target_weights: not used here (kept for interface compatibility).
            delta_sigma: not used here (kept for interface compatibility).

        Returns:
            ``(rpn_cls_loss, rpn_reg_loss, stats)`` where ``stats`` is
            ``[pos_correct, pos_total, neg_correct, neg_total, reg_0, ..., reg_5]``
            and each ``reg_i`` is a Python float.
        """
        batch_size_k, num_windows, num_classes = logits.size()
        num_samples = batch_size_k * num_windows
        logits = logits.view(num_samples, num_classes)
        labels = labels.long().view(num_samples, 1)
        label_weights = label_weights.view(num_samples, 1)

        # OHEM should only run in training mode; elsewhere num_hard is set so
        # large that the cap presumably never binds. A local variable is used so
        # the configured self.num_hard is not overwritten.
        num_hard = self.num_hard if self.mode == 'train' else 10000000
        rpn_cls_loss, pos_correct, pos_total, neg_correct, neg_total = \
            binary_cross_entropy_with_hard_negative_mining(
                logits, labels, label_weights, batch_size_k, num_hard)
        # Alternative classification loss (currently unused):
        # weighted_focal_loss_for_cross_entropy(logits, labels, label_weights)

        # Regression over anchors with a non-zero label only.
        deltas = deltas.view(num_samples, 6)
        targets = targets.view(num_samples, 6)
        index = (labels != 0).nonzero()[:, 0]
        deltas = deltas[index]
        targets = targets[index]

        rpn_reg_loss = 0
        reg_losses = []
        for i in range(6):
            component_loss = F.smooth_l1_loss(deltas[:, i], targets[:, i])
            rpn_reg_loss += component_loss
            reg_losses.append(component_loss.item())

        # A NaN here presumably means no anchor was selected (mean over an empty
        # input); report zero on the loss's own device/dtype, and plain floats
        # for the per-component stats like the normal path does.
        if torch.isnan(rpn_reg_loss.detach()).item():
            rpn_reg_loss.data = torch.zeros((), dtype=rpn_reg_loss.dtype,
                                            device=rpn_reg_loss.device)
            reg_losses = [0.0] * 6

        return rpn_cls_loss, rpn_reg_loss, [pos_correct, pos_total, neg_correct, neg_total,
                                            reg_losses[0], reg_losses[1], reg_losses[2],
                                            reg_losses[3], reg_losses[4], reg_losses[5]]

    def _resolve_bg_thresh_high(self, diameter_range):
        """Return the sure-background IoU ceiling for ``diameter_range``.

        ``None`` uses the configured ``MODEL.RPN.BG_THRESH_HIGH``.

        Raises:
            ValueError: if ``diameter_range`` is not one of the supported ranges.
        """
        if diameter_range is None:
            return self.rpn_train_bg_thresh_high
        try:
            return self._BG_THRESH_HIGH_BY_DIAMETER[tuple(diameter_range)]
        except KeyError:
            raise ValueError('unsupported diameter_range = %r' % (diameter_range,)) from None

    def _target_in_diameter_range(self, target_c, diameter_range):
        """Whether the target box's size lies within ``diameter_range`` (always true for ``None``).

        The size is the larger of the ``h`` and ``w`` columns (per the
        ``[z, y, x, d, h, w]`` layout) after multiplying ``d, h, w`` by
        ``self.spacing``.
        """
        if diameter_range is None:
            return True
        size = max((target_c[0, -3:] * self.spacing)[1:3])
        return diameter_range[0] <= size <= diameter_range[1]

    def make_one_rpn_target(self, window, truth_bbox, diameter_range=None):
        """Generate region-proposal targets for one sample.

        Args:
            window: numpy array of anchor boxes, one ``[z, y, x, d, h, w]`` per row.
            truth_bbox: box container exposing ``data`` (ground-truth boxes,
                ``[z, y, x, d, h, w]`` per row), a ``bboxes_label`` field and an
                optional ``target_idx`` field that selects the current target box.
            diameter_range: ``None`` or one of ``[20, 99999]``, ``[12, 32]``,
                ``[8, 16]``. It selects the sure-background IoU ceiling and
                gates whether the target box receives foreground anchors.

        Returns:
            CUDA tensors ``(label, label_assign, label_weight, target, target_weight)``:

            - label: 1 for positive, 0 for negative anchors.
            - label_assign: per-anchor assignment index, -1 by default.
            - label_weight: per-anchor class weight; zero means the anchor is
              protected and does not contribute to the loss.
            - target: (num_window, 6) bounding-box regression terms.
            - target_weight: per-anchor regression weight.

        Raises:
            ValueError: if ground-truth boxes are present and ``diameter_range``
                is neither ``None`` nor a supported range.
        """
        num_window = len(window)
        label = np.zeros((num_window,), np.float32)
        label_assign = np.zeros((num_window,), np.int32) - 1
        label_weight = np.zeros((num_window,), np.float32)
        target = np.zeros((num_window, 6), np.float32)
        target_weight = np.zeros((num_window,), np.float32)

        # Optional "current target" box selected through the target_idx field.
        target_c_flag = False
        bboxes_label = truth_bbox.get_field("bboxes_label")
        if truth_bbox.has_field("target_idx"):
            target_idx = truth_bbox.get_field("target_idx")
            target_c = truth_bbox.data[target_idx].reshape(1, -1)
            target_c_label = bboxes_label[target_idx]
            target_c_flag = True

        # Only boxes labelled > -1 define where background may NOT be sampled.
        # A negatively-labelled target is appended too (when other boxes remain)
        # so anchors overlapping it are also excluded from sure background.
        truth_bbox = truth_bbox.data
        truth_bbox = truth_bbox[bboxes_label > -1]
        if target_c_flag and target_c_label < 0 and len(truth_bbox):
            truth_bbox = np.r_[truth_bbox, target_c]

        num_truth_bbox = len(truth_bbox)
        if num_truth_bbox:
            bg_thresh_high = self._resolve_bg_thresh_high(diameter_range)

            # Upload/convert the anchors once; reused for the target overlap below.
            window_xyxyzz = convert_xyxyzz(torch.from_numpy(window).float().cuda())
            # Overlap matrix: rows are anchors, columns are ground-truth boxes.
            overlap = overlap_3d(window_xyxyzz,
                                 convert_xyxyzz(torch.from_numpy(truth_bbox).float().cuda()))
            overlap = overlap.cpu().data.numpy()

            # An anchor is sure background if its best IoU with every
            # ground-truth box is below the threshold.
            max_overlap = overlap.max(axis=1)
            bg_index = np.where(max_overlap < bg_thresh_high)[0]
            label[bg_index] = 0
            label_weight[bg_index] = 1

            if self.mode == 'train':
                # Random sample num_neg negative anchor boxes first.
                # This is very strange, but it works well in practice: it makes
                # the use of hard negative example mining loss, not actually
                # hard negative example mining.
                bg_index = np.where((label_weight != 0) & (label == 0))[0]
                label_weight[bg_index] = 0
                idx = random.sample(range(len(bg_index)), min(self.num_neg, len(bg_index)))
                label_weight[bg_index[idx]] = 1

            if target_c_flag and self._target_in_diameter_range(target_c, diameter_range):
                # Foreground anchors for the target box. A positive target yields
                # positive anchors; any other label marks them as weighted negatives.
                overlap = overlap_3d(
                    window_xyxyzz,
                    convert_xyxyzz(torch.from_numpy(target_c).float().cuda())).cpu().data.numpy()
                argmax_overlap = np.argmax(overlap, 0)
                max_overlap = overlap[argmax_overlap, np.arange(len(target_c))]
                target_label_value = 1 if target_c_label > 0 else 0
                if max_overlap < self.rpn_train_fg_thresh_low:
                    # No anchor reaches the threshold: fall back to the single best one.
                    fg_index = argmax_overlap
                    label[fg_index] = target_label_value
                    label_weight[fg_index] = 1
                    # NOTE(review): this broadcasts the best anchor index into
                    # label_assign for *every* anchor; label_assign[fg_index] = 0
                    # may have been intended. Behavior kept as-is.
                    label_assign[...] = argmax_overlap
                else:
                    fg_index, _ = np.where(overlap >= self.rpn_train_fg_thresh_low)
                    label[fg_index] = target_label_value
                    label_weight[fg_index] = 1

                if target_c_label > 0:
                    # In case one ground truth box has way too many positive
                    # anchors, which may affect sampling in the loss function,
                    # keep only one random positive anchor.
                    fg_index = np.where(label != 0)[0]
                    idx = random.sample(range(len(fg_index)), 1)
                    label[fg_index] = 0
                    label_weight[fg_index] = 0
                    fg_index = fg_index[idx]
                    label[fg_index] = 1
                    label_weight[fg_index] = 1

                    # Regression terms for the kept positive anchor.
                    target[fg_index] = box_transform_numpy(window[fg_index], target_c,
                                                           self.box_reg_weight)
                    target_weight[fg_index] = 1

            if self.mode == 'train':
                # Class balance: the weighted negatives share a total weight equal
                # to the number of positives (at least 1).
                fg_index = np.where((label_weight != 0) & (label != 0))[0]
                bg_index = np.where((label_weight != 0) & (label == 0))[0]
                num_fg = max(1, len(fg_index))
                num_bg = max(1, len(bg_index))
                label_weight[bg_index] = float(num_fg) / num_bg
                target_weight[fg_index] = label_weight[fg_index]
        else:
            # No ground truth box in this sample: every anchor is a negative.
            label_weight[...] = 1
            if self.mode == 'train':
                bg_index = np.where((label_weight != 0) & (label == 0))[0]
                label_weight[bg_index] = 0
                idx = random.sample(range(len(bg_index)), min(self.num_neg, len(bg_index)))
                bg_index = bg_index[idx]
                # max() avoids ZeroDivisionError when nothing was sampled
                # (num_neg == 0 or an empty window); the assignment is then a no-op.
                label_weight[bg_index] = 1.0 / max(1, len(bg_index))

        label = Variable(torch.from_numpy(label)).cuda()
        label_assign = Variable(torch.from_numpy(label_assign)).cuda()
        label_weight = Variable(torch.from_numpy(label_weight)).cuda()
        target = Variable(torch.from_numpy(target)).cuda()
        target_weight = Variable(torch.from_numpy(target_weight)).cuda()
        return label, label_assign, label_weight, target, target_weight

    def make_rpn_target(self, window, truth_bboxes, diameter_range=None):
        """Run make_one_rpn_target for every sample and stack the results.

        Returns:
            ``(rpn_labels, rpn_label_assigns, rpn_label_weights, rpn_targets,
            rpn_targets_weights)``, each with a leading batch dimension;
            ``rpn_targets`` has shape (batch, num_window, 6).
        """
        rpn_labels = []
        rpn_label_assigns = []
        rpn_label_weights = []
        rpn_targets = []
        rpn_targets_weights = []

        window = window.cpu().data.numpy()
        for b in range(len(truth_bboxes)):
            rpn_label, rpn_label_assign, rpn_label_weight, rpn_target, rpn_targets_weight = \
                self.make_one_rpn_target(window, truth_bboxes[b], diameter_range)
            rpn_labels.append(rpn_label.view(1, -1))
            rpn_label_assigns.append(rpn_label_assign.view(1, -1))
            rpn_label_weights.append(rpn_label_weight.view(1, -1))
            rpn_targets.append(rpn_target.view(1, -1, 6))
            rpn_targets_weights.append(rpn_targets_weight.view(1, -1))

        rpn_labels = torch.cat(rpn_labels, 0)
        rpn_label_assigns = torch.cat(rpn_label_assigns, 0)
        rpn_label_weights = torch.cat(rpn_label_weights, 0)
        rpn_targets = torch.cat(rpn_targets, 0)
        rpn_targets_weights = torch.cat(rpn_targets_weights, 0)
        return rpn_labels, rpn_label_assigns, rpn_label_weights, rpn_targets, rpn_targets_weights