from util import box_transform, box_transform_inv, clip_boxes, convert_zyxdhw, convert_xyxyzz, convert_xyxyzz_to_xyzxyz
# from ..layer.util import box_transform, box_transform_inv, clip_boxes, convert_zyxdhw, convert_xyxyzz, convert_xyxyzz_to_xyzxyz
import torch.nn.functional as F
import torch
from torch.autograd import Variable
# from net.layer.nms_3d import nms_3d
# from ..layer.nms_3d import nms_3d


def box_area(boxes):
    """
    Computes the volume of a set of 3D bounding boxes, which are specified by
    their (x1, y1, z1, x2, y2, z2) coordinates.

    Arguments:
        boxes (Tensor[N, 6]): boxes for which the volume will be computed. They
            are expected to be in (x1, y1, z1, x2, y2, z2) format

    Returns:
        area (Tensor[N]): volume of each box
    """
    return (boxes[:, 3] - boxes[:, 0]) * (boxes[:, 4] - boxes[:, 1]) * (boxes[:, 5] - boxes[:, 2])


def box_iou(boxes1, boxes2):
    """
    Return the pairwise intersection-over-union (Jaccard index) of two sets
    of 3D boxes.

    Both sets of boxes are expected to be in (x1, y1, z1, x2, y2, z2) format.

    Arguments:
        boxes1 (Tensor[N, 6])
        boxes2 (Tensor[M, 6])

    Returns:
        iou (Tensor[N, M]): iou[i, j] is the IoU of boxes1[i] and boxes2[j].
            A pair whose union volume is zero yields NaN (0 / 0).
    """
    area1 = box_area(boxes1)
    area2 = box_area(boxes2)

    # boxes1 gets an extra axis so every box in boxes1 is compared with
    # every box in boxes2 through broadcasting.
    lt = torch.max(boxes1[:, None, :3], boxes2[:, :3])  # [N,M,3]
    rb = torch.min(boxes1[:, None, 3:], boxes2[:, 3:])  # [N,M,3]

    wh = (rb - lt).clamp(min=0)  # [N,M,3]
    inter = wh[:, :, 0] * wh[:, :, 1] * wh[:, :, 2]  # [N,M]

    union = area1[:, None] + area2 - inter

    iou = inter / union
    return iou


def nms_3d_hky(bboxes, scores, threshold=0.5):
    """
    Fully vectorised (loop-free) approximate NMS for 3D boxes.

    A box is kept only when its IoU with EVERY higher-scored box is
    <= threshold, regardless of whether that higher-scored box was itself
    suppressed. This can therefore discard more boxes than the greedy
    procedure in `nms_3d_`.

    Arguments:
        bboxes (Tensor[N, 6]): boxes in the layout accepted by
            `convert_xyxyzz_to_xyzxyz` (presumably x1, y1, x2, y2, z1, z2 -
            confirm against util).
        scores (Tensor[N]): one score per box.
        threshold (float): maximum IoU allowed with a higher-scored box.

    Returns:
        keep (LongTensor[K]): indices into `bboxes` of the kept boxes, in
            descending score order. Always 1-D, also when K is 0 or 1.
    """
    num_boxes = bboxes.shape[0]
    _, order = scores.sort(0, descending=True)  # sort in descending order

    sorted_boxes = convert_xyxyzz_to_xyzxyz(bboxes)[order]
    ious = box_iou(sorted_boxes, sorted_boxes)  # [N,N], rows/cols in score order

    # higher_scored[i, j] is True when box i outranks box j (strict upper
    # triangle); only those pairs may suppress column j.
    rank = torch.arange(num_boxes, device=ious.device)
    higher_scored = rank[:, None] < rank[None, :]

    # Written as "not (iou <= threshold)" so that NaN IoUs count as overlap.
    overlaps = ~(ious <= threshold)
    keep_mask = (overlaps & higher_scored).sum(dim=0) == 0

    keep_idx = keep_mask.nonzero(as_tuple=False).reshape(-1)
    # PyTorch index values must be a LongTensor; `order` already is one.
    return order[keep_idx]


def nms_3d_(bboxes, scores, threshold=0.5):
    """
    Greedy NMS for 3D boxes, implemented with a Python loop.

    Repeatedly keeps the highest-scored remaining box and drops every
    remaining box whose IoU with it is greater than `threshold`.

    Arguments:
        bboxes (Tensor[N, 6]): boxes with columns (x1, y1, x2, y2, z1, z2).
        scores (Tensor[N]): one score per box.
        threshold (float): boxes with IoU > threshold w.r.t. a kept box are
            suppressed.

    Returns:
        keep (LongTensor[K]): indices into `bboxes` of the kept boxes, in
            descending score order, on the same device as `bboxes`.
    """
    device = bboxes.device
    x1 = bboxes[:, 0]
    y1 = bboxes[:, 1]
    x2 = bboxes[:, 2]
    y2 = bboxes[:, 3]
    z1 = bboxes[:, 4]
    z2 = bboxes[:, 5]
    areas = (x2 - x1) * (y2 - y1) * (z2 - z1)  # [N,] volume of each bbox

    _, order = scores.sort(0, descending=True)  # sort in descending order
    order = order.to(device)

    keep = []
    while order.numel() > 0:  # numel() is the number of remaining candidates
        i = order[0].item()  # keep the remaining box with the highest score
        keep.append(i)
        if order.numel() == 1:  # that was the last remaining box
            break

        rest = order[1:]
        # IoU of box[i] with every other remaining box.
        xx1 = x1[rest].clamp(min=x1[i])  # [N-1,]
        yy1 = y1[rest].clamp(min=y1[i])
        xx2 = x2[rest].clamp(max=x2[i])
        yy2 = y2[rest].clamp(max=y2[i])
        zz1 = z1[rest].clamp(min=z1[i])
        zz2 = z2[rest].clamp(max=z2[i])
        inter = (xx2 - xx1).clamp(min=0) * (yy2 - yy1).clamp(min=0) * (zz2 - zz1).clamp(min=0)  # [N-1,]

        iou = inter / (areas[i] + areas[rest] - inter)  # [N-1,]
        # reshape(-1) keeps the result 1-D even when a single box survives.
        idx = (iou <= threshold).nonzero(as_tuple=False).reshape(-1)
        if idx.numel() == 0:
            break
        # idx indexes `rest` (length N-1), which already skips the kept box.
        order = rest[idx]

    # PyTorch index values must be a LongTensor.
    return torch.tensor(keep, dtype=torch.long, device=device)


# @torch.jit.script
def make_rpn_windows(fs):
    """
    Generating anchor boxes at each voxel on the feature map,
    the center of the anchor box on each voxel corresponds to center
    on the original input image.

    Arguments:
        fs (Tensor[B, C, D, H, W]): feature map; only its spatial shape and
            device are used.

    Returns:
        windows (Tensor[D * H * W * A, 6]): one row per (voxel, anchor) pair,
            laid out as [z, y, x, d, h, w], where A is the number of anchor
            sizes (3). The stride (4) and anchor sizes are fixed in code.
    """
    # stride = cfg['stride']
    stride = torch.tensor(4, dtype=torch.float32)
    offset = (float(stride) - 1) / 2
    device = fs.device
    # anchors = torch.tensor(cfg['anchors'], dtype=torch.float32, device=device)
    anchors = torch.tensor([[3, 3, 3], [10, 10, 10], [30, 30, 30]], dtype=torch.float32, device=device)

    # D, H , W =(i/cfg['stride'] for i in cfg['crop_size'])
    _, _, D, H, W = fs.shape

    # Voxel centres along each spatial axis, expressed in input coordinates.
    oz, oh, ow = (
        torch.arange(offset, offset + stride * (size - 1) + 1, step=stride,
                     dtype=torch.float32, device=device)
        for size in (D, H, W)
    )
    oanchor = torch.arange(anchors.size()[0], dtype=torch.float32, device=device)

    shift_z, shift_y, shift_x, anchor_idx = torch.meshgrid(oz, oh, ow, oanchor)
    shift_z = shift_z.reshape(-1)
    shift_y = shift_y.reshape(-1)
    shift_x = shift_x.reshape(-1)
    anchor_idx = anchor_idx.reshape(-1).long()

    # Look up the (d, h, w) size of the anchor assigned to every row.
    shift_anchors = torch.index_select(anchors, 0, anchor_idx)
    shift_anchor_z = shift_anchors[:, 0]
    shift_anchor_y = shift_anchors[:, 1]
    shift_anchor_x = shift_anchors[:, 2]

    windows = torch.stack(
        (shift_z, shift_y, shift_x, shift_anchor_z, shift_anchor_y, shift_anchor_x), dim=1)

    return windows


def rpn_nms_(cfg, mode, inputs, window, logits, deltas):
    """
    Turn raw RPN outputs into NMS-filtered proposals for a whole batch.

    Candidate selection before NMS depends on `mode`:
      * 'train' / 'valid': anchors whose sigmoid score exceeds
        cfg['rpn_train_nms_pre_score_threshold'].
      * 'eval' / 'test': the cfg['rpn_test_nms_pre_topk'] highest-scored
        anchors (clamped to the number of anchors).

    Arguments:
        cfg (dict): configuration; reads the keys above plus
            'rpn_train_nms_overlap_threshold' / 'rpn_test_nms_overlap_threshold'
            and 'box_reg_weight'.
        mode (str): one of 'train', 'valid', 'eval', 'test'.
        inputs (Tensor): network input; `inputs.shape[2:]` is passed to
            `clip_boxes`.
        window (Tensor[A, 6]): anchor boxes, e.g. from `make_rpn_windows`.
        logits (Tensor[B, A, 1]): raw objectness logits.
        deltas (Tensor[B, A, 6]): box regression deltas.

    Returns:
        proposals (Tensor[K, 8]): rows of [batch_index, score, box(6)], with
            the box in the layout produced by `convert_zyxdhw`.

    Raises:
        ValueError: if `mode` is not one of the supported values.
    """
    if mode in ['train', 'valid']:
        nms_pre_score_threshold = cfg['rpn_train_nms_pre_score_threshold']
        nms_overlap_threshold = cfg['rpn_train_nms_overlap_threshold']
        pre_nms_top_k = None  # training pre-filters by score, not by top-k
    elif mode in ['eval', 'test']:
        nms_pre_score_threshold = None
        nms_overlap_threshold = cfg['rpn_test_nms_overlap_threshold']
        pre_nms_top_k = cfg['rpn_test_nms_pre_topk']
    else:
        raise ValueError('rpn_nms(): invalid mode = %s?' % mode)

    device = logits.device
    logits = torch.sigmoid(logits)
    batch_size, num_anchors, _ = logits.size()

    proposals = []
    for b in range(batch_size):
        ps = logits[b, :, 0]  # [A,]
        ds = deltas[b, :, :]

        if pre_nms_top_k is None:
            index = (ps > nms_pre_score_threshold).nonzero(as_tuple=False).reshape(-1)
        else:
            # topk raises when k exceeds the number of anchors.
            top_k = min(int(pre_nms_top_k), num_anchors)
            _, index = ps.topk(top_k, dim=0, sorted=True)

        p = torch.index_select(ps, 0, index)
        d = torch.index_select(ds, 0, index)
        w = torch.index_select(window, 0, index)

        box = rpn_decode(w, d, cfg['box_reg_weight'])
        box = clip_boxes(box, inputs.shape[2:])
        box = convert_xyxyzz(box)

        # The compiled `nms_3d` op is not imported in this file, so the
        # pure-PyTorch greedy implementation is used instead.
        keep = nms_3d_(box, p, nms_overlap_threshold)

        res_box = torch.index_select(box, 0, keep)
        res_p = torch.index_select(p, 0, keep)
        res_box = convert_zyxdhw(res_box)
        res_p = torch.unsqueeze(res_p, 1)

        # Same dtype as the boxes so the concatenation does not mix dtypes.
        b_tensor = torch.full((res_box.size()[0], 1), b, device=device, dtype=res_box.dtype)
        prop = torch.cat((b_tensor, res_p, res_box), dim=1)
        proposals.append(prop)

    proposals = torch.cat(proposals, dim=0)

    return Variable(proposals)


def rpn_nms(cfg, mode, inputs, window, logits, deltas):
    """
    Fixed-configuration variant of `rpn_nms_` using the vectorised NMS.

    NOTE: `cfg`, `mode` and `inputs` are accepted for interface compatibility
    but are not used. Only batch element 0 is processed, the IoU threshold is
    fixed at 0.1, at most 1000 top-scored anchors enter NMS, boxes are not
    clipped, and the regression weights are all 1.

    Arguments:
        cfg, mode, inputs: unused.
        window (Tensor[A, 6]): anchor boxes.
        logits (Tensor[B, A, 1]): raw objectness logits.
        deltas (Tensor[B, A, 6]): box regression deltas.

    Returns:
        prop (Tensor[K, 8]): rows of [batch_index (always 0), score, box(6)],
            with the box in the layout produced by `convert_zyxdhw`.
    """
    nms_overlap_threshold = 0.1
    pre_nms_top_k = 1000
    b = 0

    device = logits.device
    scores = torch.sigmoid(logits)[b, :, 0]  # [A,]
    ds = deltas[b, :, :]

    # Clamp so that inputs with fewer than 1000 anchors do not make topk fail.
    top_k = min(pre_nms_top_k, scores.shape[0])
    p, index = scores.topk(top_k, dim=0, sorted=True)

    d = torch.index_select(ds, 0, index)
    w = torch.index_select(window, 0, index)

    box = box_transform_inv(w, d, [1., 1., 1., 1., 1., 1.])
    # box = clip_boxes(box, inputs)
    box = convert_xyxyzz(box)

    keep = nms_3d_hky(box, p, nms_overlap_threshold)

    res_box = torch.index_select(box, 0, keep)
    res_p = torch.index_select(p, 0, keep)
    res_box = convert_zyxdhw(res_box)
    res_p = torch.unsqueeze(res_p, 1)

    b_tensor = torch.full((res_box.size()[0], 1), b, dtype=torch.float32, device=device)
    prop = torch.cat((b_tensor, res_p, res_box), dim=1)

    return Variable(prop)


def rpn_encode(window, truth_box, weight):
    """Encode `truth_box` relative to `window`; delegates to `box_transform`."""
    return box_transform(window, truth_box, weight)


def rpn_decode(window, delta, weight):
    """Apply `delta` to `window`; delegates to `box_transform_inv`."""
    return box_transform_inv(window, delta, weight)


# Manual comparison of the NMS implementations (kept for reference):
#
# bboxes = torch.rand((1000, 6)).cuda()
# prob = torch.rand(1000).cuda()
#
# keep_hky = nms_3d_hky(bboxes, prob, 0.1)
# print('*********************************')
# print("keep_hky", keep_hky.shape)
#
# keep_cpu = nms_3d_(bboxes, prob, 0.1)
# print("keep_cpu", keep_cpu.shape)
#
# # keep_op = nms_3d(bboxes, prob, 0.1)
# # print("keep_op", keep_op.shape)