import random

import numpy as np
import torch
from torch.autograd import Variable

from net.layer.rpn_nms import rpn_encode
from net.layer.overlap_3d import overlap_3d
from net.layer.util import convert_xyxyzz, convert_xyzxyz, box_iou


def _anchor_truth_overlap(window, truth_box):
    """Return the IoU between every anchor and every ground truth box.

    window and truth_box are numpy arrays; the result is a numpy array that
    the caller indexes as [anchor, truth_box].
    The computation runs on the GPU when one is available and falls back to
    the CPU if the GPU attempt raises a RuntimeError (e.g. out of memory).
    """
    anchors = torch.from_numpy(window).float()
    truths = torch.from_numpy(truth_box).float()

    if torch.cuda.is_available():
        try:
            overlap = box_iou(convert_xyzxyz(anchors.cuda()),
                              convert_xyzxyz(truths.cuda()))
            return overlap.cpu().data.numpy()
        except RuntimeError:
            # Best-effort fallback: redo the computation on the CPU below.
            pass

    overlap = box_iou(convert_xyzxyz(anchors), convert_xyzxyz(truths))
    return overlap.data.numpy()


def _subsample_negatives(label_weight, bg_index, num_neg):
    """Keep at most num_neg randomly chosen entries of bg_index weighted 1.

    label_weight is modified in place: every entry listed in bg_index is
    zeroed first, then the randomly selected ones are set back to 1.
    """
    label_weight[bg_index] = 0
    keep = random.sample(range(len(bg_index)), min(num_neg, len(bg_index)))
    label_weight[bg_index[keep]] = 1


def _to_output_tensor(array):
    """Wrap a numpy array as a torch Variable, on the GPU when available."""
    tensor = Variable(torch.from_numpy(array))
    return tensor.cuda() if torch.cuda.is_available() else tensor


def make_one_rpn_target(cfg, mode, window, truth_box, truth_label):
    """
    Generate region proposal targets for one sample of the batch

    cfg: dict, for hyper-parameters. Keys read here: 'num_neg',
        'rpn_train_bg_thresh_high', 'rpn_train_fg_thresh_low', 'box_reg_weight'
    mode: string, which phase/mode is used currently. Negatives are randomly
        subsampled down to cfg['num_neg'] only for 'train' and 'valid'
    window: numpy array of anchor bounding boxes, [z, y, x, d, h, w]
    truth_box: list/array of ground truth bounding boxes, [z, y, x, d, h, w]
    truth_label: list/array of ground truth class labels, one for each object
        in the corresponding truth_box; only boxes labelled 1 create positives

    return torch tensors (placed on the GPU when CUDA is available)
    label: positive or negative (1 or 0) for each anchor box
    label_assign: index of the ground truth box to which the anchor box is
        matched, -1 for anchors that are not matched to any box
    label_weight: class weight for each sample, zero means current sample is
        protected, and won't contribute to loss. Only negatives are given a
        weight here; the weight of positive anchors is left untouched
    target: bounding box regression terms
    target_weight: 1 for anchors that carry a regression target, 0 otherwise
    """
    num_neg = cfg['num_neg']
    num_window = len(window)
    label = np.zeros((num_window,), np.float32)
    label_assign = np.zeros((num_window,), np.int32) - 1
    label_weight = np.zeros((num_window,), np.float32)
    target = np.zeros((num_window, 6), np.float32)
    target_weight = np.zeros((num_window,), np.float32)

    assert (len(truth_label) == len(truth_box))
    subsample = mode in ['train', 'valid']

    if len(truth_box):
        # Accept plain lists as promised by the docstring: comparing a list
        # with `== 1` would silently yield False and drop every positive.
        truth_box = np.asarray(truth_box)
        truth_label = np.asarray(truth_label)

        overlap = _anchor_truth_overlap(window, truth_box)

        # negative condition
        # The anchor box is a sure background, if its largest IoU with any
        # ground truth box is less than a threshold
        max_overlap = overlap[np.arange(num_window), np.argmax(overlap, 1)]
        bg_index = np.where(max_overlap < cfg['rpn_train_bg_thresh_high'])[0]
        label[bg_index] = 0
        label_weight[bg_index] = 1

        # positive condition
        fg_thresh = cfg['rpn_train_fg_thresh_low']
        for i in np.where(truth_label == 1)[0]:
            truth_box_overlap = overlap[:, i]
            best_anchor = np.argmax(truth_box_overlap, 0)
            if truth_box_overlap[best_anchor] < fg_thresh:
                # No anchor reaches the threshold: fall back to the single
                # best-overlapping anchor so the box still gets a positive.
                fg_index = np.array([best_anchor])
            else:
                # Every anchor whose IoU reaches the threshold is positive.
                fg_index = np.where(truth_box_overlap >= fg_thresh)[0]

            # classification label for anchors
            label[fg_index] = 1
            # Record which ground truth box each positive anchor belongs to.
            label_assign[fg_index] = i

            # regression label for anchors
            target_window = window[fg_index].reshape(-1, 6)
            target_truth_box = truth_box[i].reshape(-1, 6)
            target[fg_index] = rpn_encode(target_window, target_truth_box,
                                          cfg['box_reg_weight'])
            target_weight[fg_index] = 1

        if subsample:
            bg_index = np.where((label_weight != 0) & (label == 0))[0]
            _subsample_negatives(label_weight, bg_index, num_neg)
    else:
        # if there is no ground truth box in this sample, every anchor is
        # a negative
        print('no way !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!')
        label_weight[...] = 1
        if subsample:
            # Zero the weights before sampling, as in the branch above;
            # otherwise the random selection has no effect.
            _subsample_negatives(label_weight, np.where(label == 0)[0],
                                 num_neg)

    return (_to_output_tensor(label),
            _to_output_tensor(label_assign),
            _to_output_tensor(label_weight),
            _to_output_tensor(target),
            _to_output_tensor(target_weight))


def make_rpn_target(cfg, mode, window, truth_boxes, truth_labels):
    """
    Generate region proposal targets for every sample of a batch

    cfg: dict, for hyper-parameters, forwarded to make_one_rpn_target
    mode: string, which phase/mode is used currently
    window: torch tensor of anchor boxes, shared by all samples
    truth_boxes: per-sample ground truth boxes, indexed by batch position
    truth_labels: per-sample ground truth labels, indexed by batch position

    return the five outputs of make_one_rpn_target, each concatenated along a
    new leading batch dimension: (batch, num_window) for the labels, label
    assignments and weights, and (batch, num_window, 6) for the targets
    """
    window = window.cpu().data.numpy()

    # View applied to each per-sample output so it gains a batch dimension,
    # in the order returned by make_one_rpn_target.
    shapes = ((1, -1), (1, -1), (1, -1), (1, -1, 6), (1, -1))
    collected = [[] for _ in shapes]

    for b, truth_box in enumerate(truth_boxes):
        outputs = make_one_rpn_target(cfg, mode, window, truth_box,
                                      truth_labels[b])
        for bucket, output, shape in zip(collected, outputs, shapes):
            bucket.append(output.view(*shape))

    rpn_labels, rpn_label_assigns, rpn_label_weights, rpn_targets, \
        rpn_targets_weights = [torch.cat(bucket, 0) for bucket in collected]

    return rpn_labels, rpn_label_assigns, rpn_label_weights, rpn_targets, rpn_targets_weights