// Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
#pragma once

#ifdef WITH_CUDA
#include "cuda/vision.h"
#endif

// Interface for Python
/// Forward pass of 3D deformable convolution (Python-facing dispatcher).
///
/// Dispatches on the device of `input`:
///  - CUDA tensor and built with WITH_CUDA: every argument is forwarded
///    unchanged to deform_conv3d_forward_cuda().
///  - CUDA tensor but built without WITH_CUDA: AT_ERROR("Not compiled with GPU support").
///  - Any other device: AT_ERROR("Not implemented on the CPU").
///
/// NOTE(review): this function returns void, so `output` is presumably a
/// pre-allocated tensor that the CUDA implementation fills in place — confirm
/// against deform_conv3d_forward_cuda.
/// NOTE(review): the *_h / *_w / *_l suffixes are assumed to name the three
/// spatial axes of the kernel — TODO confirm the axis order with the caller.
///
/// @param group, deformable_group, in_step, with_bias  passed through verbatim
///        to the CUDA implementation; their semantics are defined there.
// `inline` is required because this function is *defined* in a header: without
// it, every translation unit including this file emits its own definition and
// the link fails with duplicate-symbol errors (ODR violation).
inline void deform_conv3d_forward(at::Tensor input, at::Tensor weight, at::Tensor bias,
                                  at::Tensor offset, at::Tensor output,
                                  const int64_t kernel_h, const int64_t kernel_w, const int64_t kernel_l,
                                  const int64_t stride_h, const int64_t stride_w, const int64_t stride_l,
                                  const int64_t pad_h, const int64_t pad_w, const int64_t pad_l,
                                  const int64_t dilation_h, const int64_t dilation_w, const int64_t dilation_l,
                                  const int64_t group, const int64_t deformable_group, const int64_t in_step, const bool with_bias) {
    // Tensor::is_cuda() replaces the deprecated Tensor::type().is_cuda().
    if (input.is_cuda()) {
#ifdef WITH_CUDA
        return deform_conv3d_forward_cuda(input, weight, bias, offset, output,
                                          kernel_h, kernel_w, kernel_l,
                                          stride_h, stride_w, stride_l,
                                          pad_h, pad_w, pad_l,
                                          dilation_h, dilation_w, dilation_l,
                                          group, deformable_group, in_step, with_bias);
#else
        AT_ERROR("Not compiled with GPU support");
#endif
    }
    // No CPU kernel exists; any non-CUDA input ends up here.
    AT_ERROR("Not implemented on the CPU");
}

/// Backward pass of 3D deformable convolution (Python-facing dispatcher).
///
/// Dispatches on the device of `grad_input`:
///  - CUDA tensor and built with WITH_CUDA: every argument is forwarded
///    unchanged to deform_conv3d_backward_cuda().
///  - CUDA tensor but built without WITH_CUDA: AT_ERROR("Not compiled with GPU support").
///  - Any other device: AT_ERROR("Not implemented on the CPU").
///
/// NOTE(review): this function returns void, so grad_input / grad_weight /
/// grad_bias / grad_offset are presumably pre-allocated tensors that the CUDA
/// implementation writes in place — confirm against deform_conv3d_backward_cuda.
/// NOTE(review): unlike the forward dispatcher (which checks `input`), this
/// checks `grad_input`; all tensors are presumably on the same device — verify.
///
/// @param group, deformable_group, in_step, with_bias  passed through verbatim
///        to the CUDA implementation; their semantics are defined there.
// `inline` is required because this function is *defined* in a header: without
// it, every translation unit including this file emits its own definition and
// the link fails with duplicate-symbol errors (ODR violation).
inline void deform_conv3d_backward(at::Tensor input, at::Tensor weight, at::Tensor bias, at::Tensor offset,
                                   at::Tensor grad_input, at::Tensor grad_weight, at::Tensor grad_bias,
                                   at::Tensor grad_offset, at::Tensor grad_output,
                                   const int64_t kernel_h, const int64_t kernel_w, const int64_t kernel_l,
                                   const int64_t stride_h, const int64_t stride_w, const int64_t stride_l,
                                   const int64_t pad_h, const int64_t pad_w, const int64_t pad_l,
                                   const int64_t dilation_h, const int64_t dilation_w, const int64_t dilation_l,
                                   const int64_t group, const int64_t deformable_group, const int64_t in_step, const bool with_bias) {
    // Tensor::is_cuda() replaces the deprecated Tensor::type().is_cuda().
    if (grad_input.is_cuda()) {
#ifdef WITH_CUDA
        return deform_conv3d_backward_cuda(input, weight, bias, offset,
                                           grad_input, grad_weight, grad_bias,
                                           grad_offset, grad_output,
                                           kernel_h, kernel_w, kernel_l,
                                           stride_h, stride_w, stride_l,
                                           pad_h, pad_w, pad_l,
                                           dilation_h, dilation_w, dilation_l,
                                           group, deformable_group, in_step, with_bias);
#else
        AT_ERROR("Not compiled with GPU support");
#endif
    }
    // No CPU kernel exists; any non-CUDA gradient tensor ends up here.
    AT_ERROR("Not implemented on the CPU");
}