# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
import math

import torch
from torch import nn
from torch.autograd import Function
from torch.autograd.function import once_differentiable
from torch.nn.modules.utils import _triple

from ..utils import amp


class _DeformConv3D(Function):
    '''
    int deform_conv3d_forward_cuda(
        at::Tensor input, at::Tensor weight, at::Tensor bias, at::Tensor offset, at::Tensor output,
        const int kernel_h, const int kernel_w, const int kernel_l,
        const int stride_h, const int stride_w, const int stride_l,
        const int pad_h, const int pad_w, const int pad_l,
        const int dilation_h, const int dilation_w, const int dilation_l,
        const int group, const int deformable_group,
        const int in_step, const bool with_bias);
    '''

    @staticmethod
    @amp.half_function
    def forward(ctx, input, offset, weight, bias=None, stride=1, padding=0,
                dilation=1, groups=1, deformable_groups=1, in_step=64):
        ctx.stride = _triple(stride)
        ctx.padding = _triple(padding)
        ctx.dilation = _triple(dilation)
        ctx.groups = groups
        ctx.deformable_groups = deformable_groups
        ctx.in_step = in_step
        ctx.with_bias = bias is not None
        if not ctx.with_bias:
            bias = input.new_empty(0)  # fake tensor
        if not input.is_cuda:
            raise NotImplementedError
        if weight.requires_grad or offset.requires_grad or input.requires_grad:
            ctx.save_for_backward(input, offset, weight, bias)
        output = input.new_empty(_DeformConv3D._infer_shape(ctx, input, weight))
        torch.ops.nodulenet.deform_conv3d_forward(
            input, weight, bias, offset, output,
            weight.shape[2], weight.shape[3], weight.shape[4],
            ctx.stride[0], ctx.stride[1], ctx.stride[2],
            ctx.padding[0], ctx.padding[1], ctx.padding[2],
            ctx.dilation[0], ctx.dilation[1], ctx.dilation[2],
            ctx.groups, ctx.deformable_groups, ctx.in_step, ctx.with_bias)
        return output

    '''
    int deform_conv3d_backward_cuda(
        at::Tensor input, at::Tensor weight, at::Tensor bias, at::Tensor offset,
        at::Tensor grad_input, at::Tensor grad_weight, at::Tensor grad_bias,
        at::Tensor grad_offset, at::Tensor grad_output,
        const int kernel_h, const int kernel_w, const int kernel_l,
        const int stride_h, const int stride_w, const int stride_l,
        const int pad_h, const int pad_w, const int pad_l,
        const int dilation_h, const int dilation_w, const int dilation_l,
        const int group, const int deformable_group,
        const int in_step, const bool with_bias);
    '''

    @staticmethod
    @once_differentiable
    @amp.half_function
    def backward(ctx, grad_output):
        grad_output = grad_output.contiguous()
        if not grad_output.is_cuda:
            raise NotImplementedError
        input, offset, weight, bias = ctx.saved_tensors
        grad_input = torch.zeros_like(input)
        grad_offset = torch.zeros_like(offset)
        grad_weight = torch.zeros_like(weight)
        grad_bias = torch.zeros_like(bias)
        torch.ops.nodulenet.deform_conv3d_backward(
            input, weight, bias, offset,
            grad_input, grad_weight, grad_bias, grad_offset, grad_output,
            weight.shape[2], weight.shape[3], weight.shape[4],
            ctx.stride[0], ctx.stride[1], ctx.stride[2],
            ctx.padding[0], ctx.padding[1], ctx.padding[2],
            ctx.dilation[0], ctx.dilation[1], ctx.dilation[2],
            ctx.groups, ctx.deformable_groups, ctx.in_step, ctx.with_bias)
        if not ctx.with_bias:
            grad_bias = None
        return (grad_input, grad_offset, grad_weight, grad_bias,
                None, None, None, None, None, None)

    @staticmethod
    def _infer_shape(ctx, input, weight):
        n = input.size(0)
        channels_out = weight.size(0)
        height, width, length = input.shape[2:5]
        kernel_h, kernel_w, kernel_l = weight.shape[2:5]
        height_out = (height + 2 * ctx.padding[0] -
                      (ctx.dilation[0] * (kernel_h - 1) + 1)) // ctx.stride[0] + 1
        width_out = (width + 2 * ctx.padding[1] -
                     (ctx.dilation[1] * (kernel_w - 1) + 1)) // ctx.stride[1] + 1
        length_out = (length + 2 * ctx.padding[2] -
                      (ctx.dilation[2] * (kernel_l - 1) + 1)) // ctx.stride[2] + 1
        return n, channels_out, height_out, width_out, length_out


deform_conv3d = _DeformConv3D.apply


class DeformConv3d(nn.Module):

    def __init__(self, in_channels, out_channels, kernel_size, stride=1,
                 padding=0, dilation=1, groups=1, deformable_groups=1,
                 bias=False, in_step=256):
        super(DeformConv3d, self).__init__()
        assert in_channels % groups == 0, (
            'in_channels {} is not divisible by groups {}'.format(in_channels, groups))
        assert out_channels % groups == 0, (
            'out_channels {} is not divisible by groups {}'.format(out_channels, groups))
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = _triple(kernel_size)
        self.stride = _triple(stride)
        self.padding = _triple(padding)
        self.dilation = _triple(dilation)
        self.groups = groups
        self.deformable_groups = deformable_groups
        self.in_step = in_step
        self.weight = nn.Parameter(
            torch.Tensor(out_channels, in_channels // self.groups, *self.kernel_size))
        self.with_bias = bias
        if self.with_bias:
            self.bias = nn.Parameter(torch.Tensor(out_channels))
        else:
            self.register_parameter('bias', None)
        # nn.init.kaiming_uniform_(self.weight, nonlinearity="relu")
        # if self.bias is not None:
        #     nn.init.constant_(self.bias, 0)
        self.reset_parameters()

    def reset_parameters(self):
        n = self.in_channels
        for k in self.kernel_size:
            n *= k
        stdv = 1. / math.sqrt(n)
        self.weight.data.uniform_(-stdv, stdv)
        if self.with_bias:
            self.bias.data.fill_(0)

    def forward(self, x):
        # The base class only holds the parameters; DeformConv3dPack below
        # predicts the offsets and runs the deformable convolution.
        raise NotImplementedError(
            'DeformConv3d does not implement forward; use DeformConv3dPack.')


class DeformConv3dPack(DeformConv3d):

    def __init__(self, *args, **kwargs):
        super(DeformConv3dPack, self).__init__(*args, **kwargs)
        # A regular Conv3d predicts one 3-D displacement per kernel sampling
        # location for every deformable group.
        self.conv_offset = nn.Conv3d(
            self.in_channels,
            self.deformable_groups * 3 * self.kernel_size[0] *
            self.kernel_size[1] * self.kernel_size[2],
            kernel_size=self.kernel_size,
            stride=self.stride,
            padding=self.padding,
            dilation=self.dilation,
            bias=True)
        self.init_offset()

    def init_offset(self):
        n = self.in_channels
        for k in self.kernel_size:
            n *= k
        stdv = 1. / math.sqrt(n)
        self.conv_offset.weight.data.uniform_(-stdv, stdv)
        self.conv_offset.bias.data.zero_()

    def _infer_shape(self, x):
        n = x.size(0)
        channels_out = self.out_channels
        height, width, length = x.shape[2:5]
        kernel_h, kernel_w, kernel_l = self.kernel_size
        height_out = (height + 2 * self.padding[0] -
                      (self.dilation[0] * (kernel_h - 1) + 1)) // self.stride[0] + 1
        width_out = (width + 2 * self.padding[1] -
                     (self.dilation[1] * (kernel_w - 1) + 1)) // self.stride[1] + 1
        length_out = (length + 2 * self.padding[2] -
                      (self.dilation[2] * (kernel_l - 1) + 1)) // self.stride[2] + 1
        return n, channels_out, height_out, width_out, length_out

    def forward(self, x):
        offset = self.conv_offset(x)
        if torch._C._get_tracing_state():
            # We cannot currently trace through the autograd function, so call
            # the op directly on a preallocated output tensor that matches the
            # input's dtype and device.
            output = x.new_empty(self._infer_shape(x))
            bias = self.bias if self.with_bias else x.new_empty(0)
            torch.ops.nodulenet.deform_conv3d_forward(
                x, self.weight, bias, offset, output,
                self.kernel_size[0], self.kernel_size[1], self.kernel_size[2],
                self.stride[0], self.stride[1], self.stride[2],
                self.padding[0], self.padding[1], self.padding[2],
                self.dilation[0], self.dilation[1], self.dilation[2],
                self.groups, self.deformable_groups, self.in_step, self.with_bias)
            return output
        return deform_conv3d(x, offset, self.weight, self.bias, self.stride,
                             self.padding, self.dilation, self.groups,
                             self.deformable_groups, self.in_step)
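

# --- Minimal usage sketch (not part of the original module) ---
# It assumes the compiled `nodulenet` extension exposing the
# deform_conv3d_forward/backward ops has been loaded, that a CUDA device is
# available, and that the amp.half_function wrapper passes fp32 tensors
# through unchanged when mixed precision is not enabled. The shapes below are
# illustrative only; DeformConv3dPack predicts its own offsets, so it behaves
# as a drop-in replacement for nn.Conv3d.
if __name__ == '__main__':
    layer = DeformConv3dPack(in_channels=8, out_channels=16, kernel_size=3,
                             stride=1, padding=1).cuda()
    x = torch.randn(2, 8, 24, 24, 24, device='cuda')

    # The internal offset conv emits deformable_groups * 3 * kH * kW * kL
    # channels: one 3-D displacement per kernel sampling location.
    offset = layer.conv_offset(x)
    print(offset.shape)  # torch.Size([2, 81, 24, 24, 24])

    y = layer(x)
    print(y.shape)       # torch.Size([2, 16, 24, 24, 24])
    y.sum().backward()   # gradients reach both layer.weight and layer.conv_offset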