import copy

import torch
from torch import nn

from .loss import Loss
from ..backbone import ResUNet
from ..roi_heads import BoxHead, BoxModule
from ..rpn import RpnHead, RpnModule
from ..util import convert_to_roi_format, pad2factor, norm, get_anchors, add_rcnn_probability
from ...layers.roi_align_3d import ROIAlign3D


@torch.jit.script
def make_rpn_windows(fs, stride, anchors):
    """
    Generate anchor boxes at every voxel of the feature map; the center of the
    anchor box at each voxel corresponds to its center on the original input image.

    Return:
        windows: anchor boxes, one row of [z, y, x, d, h, w] per (voxel, anchor) pair
    """
    offset = (stride - 1) / 2
    device = fs.device
    _, _, D, H, W = fs.shape

    # Centers of the receptive fields on the input image, one per feature voxel.
    oz = torch.arange(offset, offset + stride * (D - 1) + 1, step=stride, dtype=torch.float32, device=device)
    oh = torch.arange(offset, offset + stride * (H - 1) + 1, step=stride, dtype=torch.float32, device=device)
    ow = torch.arange(offset, offset + stride * (W - 1) + 1, step=stride, dtype=torch.float32, device=device)
    oanchor = torch.arange(anchors.size(0), dtype=torch.float32, device=device)

    shift_z, shift_y, shift_x, anchor_idx = torch.meshgrid(oz, oh, ow, oanchor)
    shift_z = shift_z.reshape(-1)
    shift_y = shift_y.reshape(-1)
    shift_x = shift_x.reshape(-1)
    anchor_idx = anchor_idx.reshape(-1).long()

    # Broadcast each anchor size [d, h, w] across every voxel center.
    shift_anchors = torch.index_select(anchors, 0, anchor_idx)
    shift_anchor_z = shift_anchors[:, 0].squeeze()
    shift_anchor_y = shift_anchors[:, 1].squeeze()
    shift_anchor_x = shift_anchors[:, 2].squeeze()

    windows = torch.stack((shift_z, shift_y, shift_x, shift_anchor_z, shift_anchor_y, shift_anchor_x), dim=1)
    return windows
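
# A minimal usage sketch for make_rpn_windows. The shapes and values below are
# illustrative assumptions, not taken from this repo's configs:
#
#   fs = torch.zeros(1, 128, 32, 32, 32)                      # [N, C, D, H, W] feature map
#   stride = torch.tensor(4.0)                                 # feature-map stride w.r.t. the input
#   anchors = torch.tensor([[5., 5., 5.], [10., 10., 10.]])    # two cubic anchors, rows of [d, h, w]
#   windows = make_rpn_windows(fs, stride, anchors)
#   # windows.shape == (32 * 32 * 32 * 2, 6), rows of [z, y, x, d, h, w];
#   # the first anchor center sits at offset (stride - 1) / 2 = 1.5 on every axis.
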
class NoduleNet(nn.Module):
    def __init__(self, cfg, mode='train'):
        super(NoduleNet, self).__init__()
        self.cfg = cfg
        self.mode = mode

        if cfg.MODEL.BACKBONE.CONV_BODY == 'ResUNet':
            self.feature_net = ResUNet(cfg, 1, 128)
        else:
            raise ValueError('Detector backbone %s is not implemented.'
                             % self.cfg.MODEL.BACKBONE.CONV_BODY)

        self.stride = torch.tensor(cfg.DATA.DATA_PROCESS.STRIDE, dtype=torch.float32)
        self.anchors = torch.tensor(get_anchors(cfg.MODEL.ANCHOR.BASES, cfg.MODEL.ANCHOR.ASPECT_RATIOS),
                                    dtype=torch.float32).cuda()

        self.rpn_head = RpnHead(len(cfg.MODEL.ANCHOR.BASES), in_channels=128)
        self.box_head = BoxHead(cfg.MODEL.ROI_BOX_HEAD.NUM_CLASS, cfg.MODEL.ROI_BOX_HEAD.POOLER_RESOLUTION,
                                in_channels=128)
        self.rpn = RpnModule(cfg, self.mode)
        self.roi_box = BoxModule(cfg)
        self.roi_crop = ROIAlign3D(cfg.MODEL.ROI_BOX_HEAD.POOLER_RESOLUTION,
                                   spatial_scale=1 / cfg.DATA.DATA_PROCESS.STRIDE,
                                   sampling_ratio=cfg.MODEL.ROI_BOX_HEAD.SAMPLING_RATIO)
        self.use_rcnn = False
        self.loss = Loss(self.rpn.loss, self.roi_box.loss, self.use_rcnn)

    def forward(self, inputs, truth_bboxes=None):
        # No ground truth means inference-only deployment.
        if truth_bboxes is None:
            return self.forward_deploy(inputs)

        features, feat_4 = self.feature_net(inputs)
        self.rpn_window = make_rpn_windows(features[-1], self.stride, self.anchors)
        self.rpn_logits, self.rpn_deltas = self.rpn_head(features[-1])

        if self.mode in ['train', 'valid']:
            self.rpn_labels, self.rpn_label_assigns, self.rpn_label_weights, \
                self.rpn_targets, self.rpn_target_weights = self.rpn.make_rpn_target(self.rpn_window, truth_bboxes)

            if self.use_rcnn:
                self.rpn_proposals = self.rpn.rpn_nms(inputs, self.rpn_window, self.rpn_logits, self.rpn_deltas)
                self.rpn_proposals, self.rcnn_labels, self.rcnn_targets = self.roi_box.make_box_target(
                    self.rpn_proposals, truth_bboxes)
                if len(self.rpn_proposals) > 0:
                    self.rpn_proposal_rois = convert_to_roi_format(self.rpn_proposals)
                    self.rcnn_logits, self.rcnn_deltas = self.box_head(self.roi_crop(feat_4, self.rpn_proposal_rois))
            else:
                self.rcnn_logits, self.rcnn_deltas, self.rcnn_labels, self.rcnn_targets = None, None, None, None

            return self.rpn_logits, self.rpn_deltas, \
                self.rpn_labels, self.rpn_label_weights, self.rpn_targets, self.rpn_target_weights, \
                self.rcnn_logits, self.rcnn_deltas, self.rcnn_labels, self.rcnn_targets

        elif self.mode in ['test']:
            self.rpn_proposals = self.rpn.rpn_nms(inputs, self.rpn_window, self.rpn_logits, self.rpn_deltas)

            if self.use_rcnn:
                if len(self.rpn_proposals) > 0:
                    self.ensemble_proposals = copy.deepcopy(self.rpn_proposals)
                    self.rpn_proposal_rois = convert_to_roi_format(self.rpn_proposals)
                    self.rcnn_logits, self.rcnn_deltas = self.box_head(self.roi_crop(feat_4, self.rpn_proposal_rois))
                    self.fpr_res = self.roi_box.get_probability(inputs, self.rpn_proposals, self.rcnn_logits,
                                                                self.rcnn_deltas)
                    # Ensemble: average the RPN score (column 1) with the RCNN probability.
                    self.ensemble_proposals[:, 1] = (self.ensemble_proposals[:, 1] + self.fpr_res[:, 0]) / 2
                    return self.rpn_proposals, self.fpr_res, self.ensemble_proposals
                else:
                    raise RuntimeError('No proposals survived RPN NMS; the pre-NMS score threshold '
                                       'is too high, choose a lower one.')
            else:
                return self.rpn_proposals
        else:
            raise ValueError('forward(): invalid mode = %s?' % self.mode)
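
    # A sketch of the RCNN branch wiring used in forward() above. Proposal rows
    # are assumed to be [batch_idx, score, z, y, x, d, h, w] (column 1 holding
    # the score is implied by the ensemble averaging in test mode); the exact
    # ROI layout produced by convert_to_roi_format is an assumption here:
    #
    #   rois = convert_to_roi_format(proposals)   # e.g. [K, 7] rows of [batch_idx, z, y, x, d, h, w]
    #   crops = self.roi_crop(feat_4, rois)       # 3D ROIAlign -> [K, C, R, R, R] pooled features
    #   logits, deltas = self.box_head(crops)     # per-ROI class scores and box refinements
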
    def forward_deploy(self, inputs):
        inputs = pad2factor(inputs)
        inputs = norm(inputs)

        features, feat_4 = self.feature_net(inputs)
        self.rpn_window = make_rpn_windows(features[-1], self.stride, self.anchors)
        self.rpn_logits, self.rpn_deltas = self.rpn_head(features[-1])
        self.rpn_proposals = self.rpn.rpn_nms_deploy(inputs, self.rpn_window, self.rpn_logits, self.rpn_deltas)

        # The RCNN refinement branch is currently disabled in deploy mode:
        # self.rpn_proposal_rois = convert_to_roi_format(self.rpn_proposals)
        # self.rcnn_logits, self.rcnn_deltas = self.box_head(
        #     self.roi_crop(feat_4.float(), self.rpn_proposal_rois.float()).half())
        # self.rpn_proposals = add_rcnn_probability(self.rpn_proposals, self.rcnn_logits)

        return self.rpn_proposals

    def set_mode(self, mode):
        assert mode in ['train', 'valid', 'test']
        self.mode = mode
        self.rpn.update_mode(self.mode)
        # model.eval() is used at test time: it freezes BatchNorm and dropout so they
        # use the statistics learned during training instead of batch statistics.
        # model.train() is used during training: dropout and BatchNorm stay active,
        # which helps prevent the network from overfitting.
        if mode in ['train']:
            self.train()
        else:
            self.eval()
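

# A hedged smoke test for the deploy path. Constructing `cfg` is repo-specific
# and assumed here (it must provide the MODEL/DATA fields read in __init__),
# and the input volume size is illustrative:
#
#   net = NoduleNet(cfg).cuda()
#   net.set_mode('test')
#   with torch.no_grad():
#       # No truth_bboxes -> forward() dispatches to forward_deploy().
#       proposals = net(torch.randn(1, 1, 128, 128, 128).cuda())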