// Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. #pragma once #ifdef WITH_CUDA #include "cuda/vision.h" #endif // int_terface for Python void deform_conv3d_forward(at::Tensor input, at::Tensor weight, at::Tensor bias, at::Tensor offset, at::Tensor output, const int64_t kernel_h, const int64_t kernel_w, const int64_t kernel_l, const int64_t stride_h, const int64_t stride_w, const int64_t stride_l, const int64_t pad_h, const int64_t pad_w, const int64_t pad_l, const int64_t dilation_h, const int64_t dilation_w, const int64_t dilation_l, const int64_t group, const int64_t deformable_group, const int64_t in_step, const bool with_bias) { if (input.type().is_cuda()) { #ifdef WITH_CUDA return deform_conv3d_forward_cuda(input, weight, bias, offset, output, kernel_h, kernel_w, kernel_l, stride_h, stride_w, stride_l, pad_h, pad_w, pad_l, dilation_h, dilation_w, dilation_l, group, deformable_group, in_step, with_bias); #else AT_ERROR("Not compiled with GPU support"); #endif } AT_ERROR("Not implemented on the CPU"); } void deform_conv3d_backward(at::Tensor input, at::Tensor weight, at::Tensor bias, at::Tensor offset, at::Tensor grad_input, at::Tensor grad_weight, at::Tensor grad_bias, at::Tensor grad_offset, at::Tensor grad_output, const int64_t kernel_h, const int64_t kernel_w, const int64_t kernel_l, const int64_t stride_h, const int64_t stride_w, const int64_t stride_l, const int64_t pad_h, const int64_t pad_w, const int64_t pad_l, const int64_t dilation_h, const int64_t dilation_w, const int64_t dilation_l, const int64_t group, int64_t deformable_group, const int64_t in_step, const bool with_bias) { if (grad_input.type().is_cuda()) { #ifdef WITH_CUDA return deform_conv3d_backward_cuda(input, weight, bias, offset, grad_input, grad_weight, grad_bias, grad_offset, grad_output, kernel_h, kernel_w, kernel_l, stride_h, stride_w, stride_l, pad_h, pad_w, pad_l, dilation_h, dilation_w, dilation_l, group, deformable_group, in_step, with_bias); #else AT_ERROR("Not compiled with GPU support"); #endif } AT_ERROR("Not implemented on the CPU"); }