import torch import torch.nn as nn def _gather_feat(feat, ind, mask=None): dim = feat.size(2) ind = ind.unsqueeze(2).expand(ind.size(0), ind.size(1), dim) feat = feat.gather(1, ind) if mask is not None: mask = mask.unsqueeze(2).expand_as(feat) feat = feat[mask] feat = feat.view(-1, dim) return feat def _transpose_and_gather_feat(feat, ind): feat = feat.permute(0, 2, 3, 4, 1).contiguous() feat = feat.view(feat.size(0), -1, feat.size(4)) feat = _gather_feat(feat, ind) return feat def _topk(scores, K=40): batch, cat, depth, height, width = scores.size() topk_scores, topk_inds = torch.topk(scores.view(batch, cat, -1), K) topk_inds = topk_inds % (depth * height * width) topk_xs = (topk_inds % width).int().float() topk_ys = (topk_inds // width % height).int().float() topk_zs = (topk_inds // width // height % depth).int().float() topk_score, topk_ind = torch.topk(topk_scores.view(batch, -1), K) topk_clses = (topk_ind / K).int() topk_inds = _gather_feat( topk_inds.view(batch, -1, 1), topk_ind).view(batch, K) topk_zs = _gather_feat(topk_zs.view(batch, -1, 1), topk_ind).view(batch, K) topk_ys = _gather_feat(topk_ys.view(batch, -1, 1), topk_ind).view(batch, K) topk_xs = _gather_feat(topk_xs.view(batch, -1, 1), topk_ind).view(batch, K) return topk_score, topk_inds, topk_clses, topk_zs, topk_ys, topk_xs def _nms(heat, kernel=3): pad = (kernel - 1) // 2 hmax = nn.functional.max_pool3d( heat, (kernel, kernel, kernel), stride=1, padding=pad) keep = (hmax == heat).float() return heat * keep