# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
import torch
from torch import nn
from torch.autograd import Function
from torch.autograd.function import once_differentiable
from torch.nn.modules.utils import _triple

# from ..utils import amp
import apex.amp as amp


class _ROIAlign3D(Function):
    """Autograd wrapper around the custom ``nodulenet`` 3D ROI Align op.

    Before ``forward`` runs, the autograd engine unpacks Variables into plain
    Tensors, so ``input``/``roi`` arrive as Tensors and the remaining
    arguments as plain Python values. ``ctx`` plays the role of ``self``:
    attributes stashed on it in ``forward`` are read back in ``backward``.
    """

    @staticmethod
    def forward(ctx, input, roi, output_size, spatial_scale, sampling_ratio):
        # Save the tensor argument needed by backward via save_for_backward;
        # non-tensor values are stored as plain ctx attributes.
        ctx.save_for_backward(roi)
        ctx.output_size = _triple(output_size)
        ctx.spatial_scale = spatial_scale
        ctx.sampling_ratio = sampling_ratio
        ctx.input_shape = input.size()
        # FIX: index the tripled size rather than the raw argument, so a
        # scalar ``output_size`` (e.g. 7) works the same as (7, 7, 7).
        out_d, out_h, out_w = ctx.output_size
        output = torch.ops.nodulenet.roi_align_3d_forward(
            input, roi, spatial_scale,
            out_d, out_h, out_w, sampling_ratio)
        return output

    @staticmethod
    @once_differentiable  # the custom op does not support double backward
    def backward(ctx, grad_output):
        # With @once_differentiable, grad_output arrives as a plain Tensor
        # (no autograd tracking): it is the upstream gradient.
        rois, = ctx.saved_tensors  # tensor saved in forward()
        out_d, out_h, out_w = ctx.output_size
        spatial_scale = ctx.spatial_scale
        sampling_ratio = ctx.sampling_ratio
        bs, ch, d, h, w = ctx.input_shape
        grad_input = torch.ops.nodulenet.roi_align_3d_backward(
            grad_output, rois, spatial_scale,
            out_d, out_h, out_w,
            bs, ch, d, h, w, sampling_ratio)
        # Gradients are returned in the order of forward()'s inputs; only
        # ``input`` is differentiable, the other four arguments get None.
        return grad_input, None, None, None, None


# Functional alias: Function subclasses are invoked through .apply.
roi_align = _ROIAlign3D.apply


class ROIAlign3D(nn.Module):
    """nn.Module wrapper for 3D ROI Align.

    Args:
        output_size (int or 3-tuple): pooled output size; a single int is
            expanded by ``_triple`` to a cube.
        spatial_scale (float): scale mapping ROI coordinates onto the
            feature map (e.g. 1/stride).
        sampling_ratio (int): sampling points per bin, forwarded to the
            ``nodulenet`` op (exact semantics live in the op itself).
    """

    def __init__(self, output_size, spatial_scale, sampling_ratio):
        super(ROIAlign3D, self).__init__()
        self.output_size = _triple(output_size)
        self.spatial_scale = spatial_scale
        self.sampling_ratio = sampling_ratio

    @amp.float_function  # force fp32 execution under apex mixed precision
    def forward(self, input, rois):
        if torch._C._get_tracing_state():
            # We cannot currently trace through the autograd function, so
            # call the raw op directly while the model is being traced.
            return torch.ops.nodulenet.roi_align_3d_forward(
                input, rois, self.spatial_scale,
                self.output_size[0], self.output_size[1],
                self.output_size[2], self.sampling_ratio)
        return roi_align(
            input, rois, self.output_size,
            self.spatial_scale, self.sampling_ratio)

    def __repr__(self):
        tmpstr = self.__class__.__name__ + "("
        tmpstr += "output_size=" + str(self.output_size)
        tmpstr += ", spatial_scale=" + str(self.spatial_scale)
        tmpstr += ", sampling_ratio=" + str(self.sampling_ratio)
        tmpstr += ")"
        return tmpstr