import torch
import numpy as np


def box_transform(windows, targets, weight):
    """
    Calculate regression terms: dz, dy, dx, dd, dh, dw.
    The number of windows should equal the number of targets.
    windows: [num_window, z, y, x, D, H, W]
    targets: [num_target, z, y, x, D, H, W]
    """
    wz, wy, wx, wd, wh, ww = weight
    bz, by, bx = windows[:, 0], windows[:, 1], windows[:, 2]
    bd, bh, bw = windows[:, 3], windows[:, 4], windows[:, 5]
    tz, ty, tx = targets[:, 0], targets[:, 1], targets[:, 2]
    td, th, tw = targets[:, 3], targets[:, 4], targets[:, 5]
    dz = wz * (tz - bz) / bd
    dy = wy * (ty - by) / bh
    dx = wx * (tx - bx) / bw
    dd = wd * np.log(td / bd)
    dh = wh * np.log(th / bh)
    dw = ww * np.log(tw / bw)
    deltas = np.vstack((dz, dy, dx, dd, dh, dw)).transpose()
    return deltas


def box_transform_inv(windows, deltas, weight):
    """
    Apply regression terms to predicted bboxes.
    windows: [num_window, z, y, x, D, H, W]
    deltas:  [num_window, dz, dy, dx, dd, dh, dw]
    return:  predictions: [num_window, z, y, x, D, H, W]
    """
    wz, wy, wx, wd, wh, ww = weight
    bz = torch.unsqueeze(windows[:, 0], 1)
    by = torch.unsqueeze(windows[:, 1], 1)
    bx = torch.unsqueeze(windows[:, 2], 1)
    bd = torch.unsqueeze(windows[:, 3], 1)
    bh = torch.unsqueeze(windows[:, 4], 1)
    bw = torch.unsqueeze(windows[:, 5], 1)
    dz = torch.unsqueeze(deltas[:, 0] / wz, 1)
    dy = torch.unsqueeze(deltas[:, 1] / wy, 1)
    dx = torch.unsqueeze(deltas[:, 2] / wx, 1)
    dd = torch.unsqueeze(deltas[:, 3] / wd, 1)
    dh = torch.unsqueeze(deltas[:, 4] / wh, 1)
    dw = torch.unsqueeze(deltas[:, 5] / ww, 1)
    z = dz * bd + bz
    y = dy * bh + by
    x = dx * bw + bx
    d = torch.exp(dd) * bd
    h = torch.exp(dh) * bh
    w = torch.exp(dw) * bw
    predictions = torch.cat((z, y, x, d, h, w), dim=1)
    return predictions
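

# Illustrative sketch (not part of the original module): a quick round-trip
# check showing that box_transform (NumPy) and box_transform_inv (torch) are
# inverse operations when all weights are 1. The helper name
# `_example_transform_roundtrip` and the box values are hypothetical and only
# meant as a usage example.
def _example_transform_roundtrip():
    windows = np.array([[10., 20., 30., 4., 6., 8.]], dtype=np.float32)
    targets = np.array([[12., 18., 33., 5., 6., 7.]], dtype=np.float32)
    weight = (1., 1., 1., 1., 1., 1.)
    deltas = box_transform(windows, targets, weight)           # numpy, [1, 6]
    recovered = box_transform_inv(torch.from_numpy(windows),
                                  torch.from_numpy(deltas),
                                  weight)                      # torch, [1, 6]
    assert torch.allclose(recovered, torch.from_numpy(targets), atol=1e-4)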

def convert_zyxdhw(bbox):
    # xyxyzz -> zyxdhw
    z = (bbox[:, 4] + bbox[:, 5]) / 2
    y = (bbox[:, 1] + bbox[:, 3]) / 2
    x = (bbox[:, 0] + bbox[:, 2]) / 2
    d = bbox[:, 5] - bbox[:, 4]
    h = bbox[:, 3] - bbox[:, 1]
    w = bbox[:, 2] - bbox[:, 0]
    z = torch.unsqueeze(z, 1)
    y = torch.unsqueeze(y, 1)
    x = torch.unsqueeze(x, 1)
    d = torch.unsqueeze(d, 1)
    h = torch.unsqueeze(h, 1)
    w = torch.unsqueeze(w, 1)
    new_bbox = torch.cat((z, y, x, d, h, w), dim=1)
    return new_bbox


def convert_xyxyzz(bbox):
    # zyxdhw -> xyxyzz
    x1 = bbox[:, 2] - bbox[:, 5] / 2
    y1 = bbox[:, 1] - bbox[:, 4] / 2
    x2 = bbox[:, 2] + bbox[:, 5] / 2
    y2 = bbox[:, 1] + bbox[:, 4] / 2
    z1 = bbox[:, 0] - bbox[:, 3] / 2
    z2 = bbox[:, 0] + bbox[:, 3] / 2
    x1 = torch.unsqueeze(x1, 1)
    x2 = torch.unsqueeze(x2, 1)
    y1 = torch.unsqueeze(y1, 1)
    y2 = torch.unsqueeze(y2, 1)
    z1 = torch.unsqueeze(z1, 1)
    z2 = torch.unsqueeze(z2, 1)
    new_bbox = torch.cat((x1, y1, x2, y2, z1, z2), dim=1)
    return new_bbox


def convert_xyzxyz(bbox):
    # zyxdhw -> xyzxyz
    x1 = bbox[:, 2] - bbox[:, 5] / 2
    y1 = bbox[:, 1] - bbox[:, 4] / 2
    x2 = bbox[:, 2] + bbox[:, 5] / 2
    y2 = bbox[:, 1] + bbox[:, 4] / 2
    z1 = bbox[:, 0] - bbox[:, 3] / 2
    z2 = bbox[:, 0] + bbox[:, 3] / 2
    x1 = torch.unsqueeze(x1, 1)
    x2 = torch.unsqueeze(x2, 1)
    y1 = torch.unsqueeze(y1, 1)
    y2 = torch.unsqueeze(y2, 1)
    z1 = torch.unsqueeze(z1, 1)
    z2 = torch.unsqueeze(z2, 1)
    new_bbox = torch.cat((x1, y1, z1, x2, y2, z2), dim=1)
    return new_bbox


def clip_boxes(boxes, img_size):
    """
    Clip box centers that fall outside the image; all boxes follow the
    [z, y, x, d, h, w] format.
    """
    depth, height, width = img_size
    boxes[:, 0] = torch.clamp(boxes[:, 0], 0, depth - 1)
    boxes[:, 1] = torch.clamp(boxes[:, 1], 0, height - 1)
    boxes[:, 2] = torch.clamp(boxes[:, 2], 0, width - 1)
    return boxes


def convert_to_roi_format(proposals):
    # bpzyxdhw -> bxyxyzz
    b = proposals[:, 0]
    x1 = proposals[:, 4] - proposals[:, 7] / 2
    y1 = proposals[:, 3] - proposals[:, 6] / 2
    x2 = proposals[:, 4] + proposals[:, 7] / 2
    y2 = proposals[:, 3] + proposals[:, 6] / 2
    z1 = proposals[:, 2] - proposals[:, 5] / 2
    z2 = proposals[:, 2] + proposals[:, 5] / 2
    b = torch.unsqueeze(b, 1)
    x1 = torch.unsqueeze(x1, 1)
    x2 = torch.unsqueeze(x2, 1)
    y1 = torch.unsqueeze(y1, 1)
    y2 = torch.unsqueeze(y2, 1)
    z1 = torch.unsqueeze(z1, 1)
    z2 = torch.unsqueeze(z2, 1)
    rois = torch.cat((b, x1, y1, x2, y2, z1, z2), dim=1)
    return rois


def box_area(boxes):
    """
    Computes the area (volume in 3D) of a set of bounding boxes specified by
    their (x1, y1, z1, x2, y2, z2) coordinates.
    Arguments:
        boxes (Tensor[N, 6]): boxes for which the area will be computed. They
            are expected to be in (x1, y1, z1, x2, y2, z2) format.
    Returns:
        area (Tensor[N]): area for each box
    """
    return (boxes[:, 5] - boxes[:, 2]) * (boxes[:, 4] - boxes[:, 1]) * (boxes[:, 3] - boxes[:, 0])


def box_iou(boxes1, boxes2):
    """
    Return intersection-over-union (Jaccard index) of boxes.
    Both sets of boxes are expected to be in (x1, y1, z1, x2, y2, z2) format.
    Arguments:
        boxes1 (Tensor[N, 6])
        boxes2 (Tensor[M, 6])
    Returns:
        iou (Tensor[N, M]): the NxM matrix containing the pairwise IoU values
            for every element in boxes1 and boxes2
    """
    area1 = box_area(boxes1)
    area2 = box_area(boxes2)

    lt = torch.max(boxes1[:, None, :3], boxes2[:, :3])  # [N, M, 3]
    rb = torch.min(boxes1[:, None, 3:], boxes2[:, 3:])  # [N, M, 3]

    wh = (rb - lt).clamp(min=0)  # [N, M, 3]
    inter = wh[:, :, 0] * wh[:, :, 1] * wh[:, :, 2]  # [N, M]

    union = area1[:, None] + area2 - inter
    iou = inter / union
    return iou
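

# Illustrative sketch (not part of the original module): a minimal demo of the
# coordinate conversions and box_iou on made-up boxes given in
# [z, y, x, d, h, w] format.
if __name__ == "__main__":
    boxes_zyxdhw = torch.tensor([[10., 20., 30., 4., 6., 8.],
                                 [11., 21., 31., 4., 6., 8.]])
    boxes_xyzxyz = convert_xyzxyz(boxes_zyxdhw)   # (x1, y1, z1, x2, y2, z2)
    print(box_iou(boxes_xyzxyz, boxes_xyzxyz))    # pairwise IoU, shape [2, 2]
    # Converting zyxdhw -> xyxyzz -> zyxdhw should recover the original boxes.
    assert torch.allclose(convert_zyxdhw(convert_xyxyzz(boxes_zyxdhw)), boxes_zyxdhw)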