// Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
#include <ATen/ATen.h>
#include <ATen/ceil_div.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>

#include <cstdint>

// Threads per block, and also the number of boxes2 entries staged into shared
// memory per block (64 = bit width of unsigned long long).
int const threadsPerBlock = sizeof(unsigned long long) * 8;

// Volumetric IoU of two axis-aligned boxes, each stored as 6 floats laid out
// (left, top, right, bottom, front, back), i.e. index pairs (0,2), (1,3), (4,5)
// give the extent along each axis. Extents use the inclusive "+1" convention,
// so a box with min == max still has size 1 along that axis.
// NOTE(review): degenerate inputs where Sa + Sb - interS == 0 yield inf/NaN.
__device__ inline float devIoU(float const * const a, float const * const b) {
  float left = max(a[0], b[0]), right = min(a[2], b[2]);
  float top = max(a[1], b[1]), bottom = min(a[3], b[3]);
  float front = max(a[4], b[4]), back = min(a[5], b[5]);
  float width = max(right - left + 1, 0.f), height = max(bottom - top + 1, 0.f),
        depth = max(back - front + 1, 0.f);
  float interS = width * height * depth;
  float Sa = (a[2] - a[0] + 1) * (a[3] - a[1] + 1) * (a[5] - a[4] + 1);
  float Sb = (b[2] - b[0] + 1) * (b[3] - b[1] + 1) * (b[5] - b[4] + 1);
  return interS / (Sa + Sb - interS);
}

// Fills overlap[i * n_boxes2 + j] = IoU(boxes1[i], boxes2[j]).
//
// Launch layout:
//   blockDim.x == threadsPerBlock (1D blocks),
//   gridDim.x  == ceil(n_boxes2 / threadsPerBlock)  -- column tiles of boxes2,
//   gridDim.y  == ceil(n_boxes1 / threadsPerBlock)  -- row tiles of boxes1.
// Each block stages one tile of boxes2 (threadsPerBlock * 6 floats, static
// shared memory); each thread then owns one boxes1 row of the tile.
// Inputs must be contiguous, row-major (n, 6) float buffers.
__global__ void overlap_3d_kernel(const int n_boxes1, const float* __restrict__ boxes1,
                                  const int n_boxes2, const float* __restrict__ boxes2,
                                  float* __restrict__ overlap) {
  const int tid = static_cast<int>(threadIdx.x);
  const int row_start = blockIdx.y;
  const int col_start = blockIdx.x;

  // Tail tiles may hold fewer than threadsPerBlock boxes.
  const int row_size = min(n_boxes1 - row_start * threadsPerBlock, threadsPerBlock);
  const int col_size = min(n_boxes2 - col_start * threadsPerBlock, threadsPerBlock);

  // 64-bit base offsets so that large N * M products do not overflow int.
  const int64_t col_base = static_cast<int64_t>(threadsPerBlock) * col_start;
  const int64_t row_base = static_cast<int64_t>(threadsPerBlock) * row_start;

  __shared__ float block_boxes[threadsPerBlock * 6];
  if (tid < col_size) {
    const float* src = boxes2 + (col_base + tid) * 6;
#pragma unroll
    for (int k = 0; k < 6; ++k) {
      block_boxes[tid * 6 + k] = src[k];
    }
  }
  // Barrier is outside any divergent branch: every thread must reach it.
  __syncthreads();

  if (tid < row_size) {
    const int64_t box1_idx = row_base + tid;
    const float* box1 = boxes1 + box1_idx * 6;
    float* out_row = overlap + box1_idx * n_boxes2 + col_base;
    for (int i = 0; i < col_size; ++i) {
      out_row[i] = devIoU(box1, block_boxes + i * 6);
    }
  }
}

// Pairwise 3D IoU between two sets of axis-aligned boxes.
//
// boxes_1: (N, 6) float32 CUDA tensor; boxes_2: (M, 6) float32 CUDA tensor on
// the same device. Each row is (left, top, right, bottom, front, back) in the
// inclusive "+1" extent convention used by devIoU.
// Returns an (N, M) tensor with result[i][j] = IoU(boxes_1[i], boxes_2[j]).
// The kernel runs asynchronously on the current PyTorch CUDA stream of the
// inputs' device. Throws c10::Error on invalid inputs or a failed launch.
at::Tensor overlap_3d_cuda(const at::Tensor boxes_1, const at::Tensor boxes_2) {
  using scalar_t = float;
  TORCH_CHECK(boxes_1.is_cuda(), "boxes_1 must be a CUDA tensor");
  TORCH_CHECK(boxes_2.is_cuda(), "boxes_2 must be a CUDA tensor");
  TORCH_CHECK(boxes_1.device() == boxes_2.device(),
              "boxes_1 and boxes_2 must be on the same device");
  TORCH_CHECK(boxes_1.scalar_type() == at::kFloat, "boxes_1 must be a float32 tensor");
  TORCH_CHECK(boxes_2.scalar_type() == at::kFloat, "boxes_2 must be a float32 tensor");
  TORCH_CHECK(boxes_1.dim() == 2 && boxes_1.size(1) == 6,
              "boxes_1 must have shape (N, 6)");
  TORCH_CHECK(boxes_2.dim() == 2 && boxes_2.size(1) == 6,
              "boxes_2 must have shape (M, 6)");

  // Make the inputs' device current so allocation, stream lookup and launch
  // all target the right GPU.
  c10::cuda::CUDAGuard device_guard(boxes_1.device());

  // The kernel indexes raw row-major memory, so force a dense layout.
  const at::Tensor b1 = boxes_1.contiguous();
  const at::Tensor b2 = boxes_2.contiguous();

  const int boxes1_num = static_cast<int>(b1.size(0));
  const int boxes2_num = static_cast<int>(b2.size(0));

  auto overlap = at::empty({boxes1_num, boxes2_num}, b1.options());
  // A zero-sized grid is an invalid launch configuration; nothing to compute.
  if (boxes1_num == 0 || boxes2_num == 0) {
    return overlap;
  }

  dim3 blocks(at::ceil_div(boxes2_num, threadsPerBlock),
              at::ceil_div(boxes1_num, threadsPerBlock));
  dim3 threads(threadsPerBlock);
  cudaStream_t stream = at::cuda::getCurrentCUDAStream();

  overlap_3d_kernel<<<blocks, threads, 0, stream>>>(
      boxes1_num, b1.data_ptr<scalar_t>(),
      boxes2_num, b2.data_ptr<scalar_t>(),
      overlap.data_ptr<scalar_t>());
  AT_CUDA_CHECK(cudaGetLastError());
  return overlap;
}