import copy

import torch
from torch import nn
import torch.nn.functional as F

from layers.roi_align_3d import ROIAlign3D
from backbone.resunet import ResUNet, ResUNet_FPN
from rpn.rpn_head import RpnHead
from rpn.rpn_module import RpnModule
from roi_heads.box_head import BoxHead
from roi_heads.box_module import BoxModule
from util import convert_to_roi_format, pad2factor, norm, add_rcnn_probability, get_anchors
from loss_fun.loss import Loss, Loss_FPN


@torch.jit.script
def make_rpn_windows(fs, stride, anchors):
    """Generate anchor boxes at every voxel of a feature map.

    The centre of the anchor box placed on each feature-map voxel corresponds
    to that voxel's centre in the original input image.

    Args:
        fs: feature map; only its device and its last three dimensions
            (unpacked as ``_, _, D, H, W``) are used.
        stride: downsampling factor between the input image and ``fs``.
        anchors: anchor sizes indexed along dim 0; columns 0-2 give the
            (d, h, w) of each anchor.

    Returns:
        windows: tensor of shape (D * H * W * num_anchors, 6) whose rows are
            [z, y, x, d, h, w].
    """
    offset = (stride - 1) / 2
    device = fs.device
    _, _, D, H, W = fs.shape
    oz = torch.arange(offset, offset + stride * (D - 1) + 1, step=stride, dtype=torch.float32, device=device)
    oh = torch.arange(offset, offset + stride * (H - 1) + 1, step=stride, dtype=torch.float32, device=device)
    ow = torch.arange(offset, offset + stride * (W - 1) + 1, step=stride, dtype=torch.float32, device=device)
    oanchor = torch.arange(anchors.size()[0], dtype=torch.float32, device=device)
    shift_z, shift_y, shift_x, anchor_idx = torch.meshgrid(oz, oh, ow, oanchor)
    shift_z = shift_z.reshape(-1)
    shift_y = shift_y.reshape(-1)
    shift_x = shift_x.reshape(-1)
    anchor_idx = anchor_idx.reshape(-1).long()
    shift_anchors = torch.index_select(anchors, 0, anchor_idx)
    # Column indexing already yields 1-D tensors. A ``.squeeze()`` here would
    # collapse them to 0-d when there is exactly one window, and torch.stack
    # would then fail against the 1-D shift vectors.
    windows = torch.stack((shift_z, shift_y, shift_x,
                           shift_anchors[:, 0], shift_anchors[:, 1], shift_anchors[:, 2]), dim=1)
    return windows


class NoduleNet(nn.Module):
    """Nodule detector: ResUNet / ResUNet-FPN backbone feeding an RPN.

    An RCNN box head is also built, but it is only used when ``self.use_rcnn``
    is True (it is set to False in ``__init__``).
    """

    def __init__(self, cfg, mode='train'):
        super(NoduleNet, self).__init__()
        self.mode = mode
        self.cfg = cfg
        self.stride = torch.tensor(cfg.data_process_stride, dtype=torch.float32)
        # NOTE(review): ``model_abchor_aspect_ratios`` looks like a misspelling
        # of "anchor"; it is kept because it must match the config attribute.
        all_anchors = get_anchors(cfg.model_anchor_bases, cfg.model_abchor_aspect_ratios)
        # NOTE(review): the anchors are a plain list (not buffers), so
        # ``model.to(device)`` will not move them; they are pinned to CUDA here.
        self.anchors = [torch.tensor(all_anchors[i], dtype=torch.float32).cuda() for i in range(3)]

        # Use the same flag as the loss selection below, so the RPN heads and
        # the loss can never disagree about whether FPN is enabled.
        if cfg.model_backbone_fpn:
            self.feature_net = ResUNet_FPN(cfg, 1, 128)
            self.rpn_heads = nn.ModuleList([RpnHead(len(cfg.model_anchor_bases[0]), in_channels=64),
                                            RpnHead(len(cfg.model_anchor_bases[1]), in_channels=64),
                                            RpnHead(len(cfg.model_anchor_bases[2]), in_channels=128)])
        else:
            self.feature_net = ResUNet(cfg, 1, 128)
            self.rpn_head = RpnHead(len(cfg.model_anchor_bases), in_channels=128)
        self.box_head = BoxHead(cfg.model_roi_box_head_num_class, cfg.model_roi_box_head_pooler_resolution,
                                in_channels=128)
        self.rpn = RpnModule(cfg, self.mode)
        self.roi_box = BoxModule(cfg)
        self.roi_crop = ROIAlign3D(cfg.model_roi_box_head_pooler_resolution,
                                   spatial_scale=1 / cfg.data_process_stride,
                                   sampling_ratio=cfg.model_roi_box_head_sample_ratio)
        self.use_rcnn = False
        if cfg.model_backbone_fpn:
            self.loss = Loss_FPN(self.rpn.loss, self.roi_box.loss, self.use_rcnn)
        else:
            self.loss = Loss(self.rpn.loss, self.roi_box.loss, self.use_rcnn)

    def _run_rpn_heads(self, features):
        """Build anchor windows and run the RPN head(s); results are stored on ``self``."""
        if self.cfg.model_backbone_fpn:
            # Pyramid levels 0/1/2 use 2x, 1x and 1/2x the base stride.
            self.rpn_logits, self.rpn_deltas = [], []
            self.rpn_windows = [make_rpn_windows(features[0], self.stride * 2, self.anchors[0]),
                                make_rpn_windows(features[1], self.stride, self.anchors[1]),
                                make_rpn_windows(features[2], self.stride // 2, self.anchors[2])]
            for i in range(3):
                rpn_logits, rpn_deltas = self.rpn_heads[i](features[i])
                self.rpn_logits.append(rpn_logits)
                self.rpn_deltas.append(rpn_deltas)
        else:
            # NOTE(review): self.anchors is a list of three tensors, while
            # make_rpn_windows expects a single tensor; this non-FPN path
            # looks broken - verify before enabling it.
            self.rpn_window = make_rpn_windows(features[-1], self.stride, self.anchors)
            self.rpn_logits, self.rpn_deltas = self.rpn_head(features[-1])

    def _rpn_nms_proposals(self, inputs):
        """Run RPN NMS over the stored logits/deltas; store and return the proposals."""
        if self.cfg.model_backbone_fpn:
            self.rpn_proposals = [
                self.rpn.rpn_nms(inputs, self.rpn_windows[i], self.rpn_logits[i], self.rpn_deltas[i])
                for i in range(3)]
        else:
            self.rpn_proposals = self.rpn.rpn_nms(inputs, self.rpn_window, self.rpn_logits, self.rpn_deltas)
        return self.rpn_proposals

    def _make_rpn_targets(self, truth_bboxes):
        """Assign RPN labels/targets for the current windows; results are stored on ``self``."""
        if self.cfg.model_backbone_fpn:
            self.rpn_labels, self.rpn_label_assigns, self.rpn_label_weights, self.rpn_targets, \
                self.rpn_target_weights = [], [], [], [], []
            for i in range(3):
                # NOTE(review): nested cfg access, unlike the flat
                # ``cfg.model_*`` / ``cfg.data_*`` style used elsewhere - verify.
                rpn_labels, rpn_label_assigns, rpn_label_weights, rpn_targets, rpn_target_weights = \
                    self.rpn.make_rpn_target(self.rpn_windows[i], truth_bboxes,
                                             self.cfg.DATA.DATA_PROCESS.FPN_DIAMETER_RANGE[i])
                self.rpn_labels.append(rpn_labels)
                self.rpn_label_assigns.append(rpn_label_assigns)
                self.rpn_label_weights.append(rpn_label_weights)
                self.rpn_targets.append(rpn_targets)
                self.rpn_target_weights.append(rpn_target_weights)
        else:
            self.rpn_labels, self.rpn_label_assigns, self.rpn_label_weights, self.rpn_targets, \
                self.rpn_target_weights = self.rpn.make_rpn_target(self.rpn_window, truth_bboxes)

    def forward_depoly(self, inputs):
        """Deployment inference: pad and normalize ``inputs``, then return RPN proposals.

        The "depoly" spelling is kept for compatibility with existing callers.
        """
        inputs = pad2factor(inputs)
        inputs = norm(inputs)
        features, _ = self.feature_net(inputs)
        self._run_rpn_heads(features)
        return self._rpn_nms_proposals(inputs)

    def forward(self, inputs, truth_bboxes):
        """Run the backbone and RPN, then the mode-specific part of the pipeline.

        Args:
            inputs: input batch, passed straight to ``self.feature_net``.
            truth_bboxes: ground-truth boxes; only used in 'train'/'valid' mode.

        Returns:
            'train'/'valid': the 10-tuple (rpn_logits, rpn_deltas, rpn_labels,
                rpn_label_weights, rpn_targets, rpn_target_weights, rcnn_logits,
                rcnn_deltas, rcnn_labels, rcnn_targets). The rcnn_* entries are
                None when the RCNN head is disabled or receives no proposals.
            'test': the RPN proposals, or (rpn_proposals, fpr_res,
                ensemble_proposals) when ``use_rcnn`` is enabled.

        Raises:
            RuntimeError: in 'test' mode with ``use_rcnn`` enabled and no proposals.
            ValueError: for any mode other than 'train', 'valid' or 'test'.
        """
        features, feat_4 = self.feature_net(inputs)
        self._run_rpn_heads(features)

        if self.mode in ['train', 'valid']:
            self._make_rpn_targets(truth_bboxes)

            # Reset first so a batch without RCNN outputs never returns
            # values left over from a previous forward pass.
            self.rcnn_logits, self.rcnn_deltas, self.rcnn_labels, self.rcnn_targets = None, None, None, None
            if self.use_rcnn:
                # NOTE(review): self.rpn_window only exists on the non-FPN path.
                self.rpn_proposals = self.rpn.rpn_nms(inputs, self.rpn_window, self.rpn_logits, self.rpn_deltas)
                proposals, rcnn_labels, rcnn_targets = self.roi_box.make_box_target(
                    self.rpn_proposals, truth_bboxes)
                self.rpn_proposals = proposals
                if len(proposals) > 0:
                    self.rcnn_labels, self.rcnn_targets = rcnn_labels, rcnn_targets
                    self.rpn_proposal_rois = convert_to_roi_format(proposals)
                    self.rcnn_logits, self.rcnn_deltas = self.box_head(
                        self.roi_crop(feat_4, self.rpn_proposal_rois))
            return self.rpn_logits, self.rpn_deltas, \
                self.rpn_labels, self.rpn_label_weights, self.rpn_targets, self.rpn_target_weights, \
                self.rcnn_logits, self.rcnn_deltas, self.rcnn_labels, self.rcnn_targets

        elif self.mode in ['test']:
            self._rpn_nms_proposals(inputs)
            if not self.use_rcnn:
                return self.rpn_proposals
            if len(self.rpn_proposals) == 0:
                # NoduleNet never defines pre_nms_score_threshold itself; look it
                # up defensively so the intended RuntimeError is what surfaces,
                # not an AttributeError.
                # NOTE(review): assumes the RPN module may carry the threshold - verify.
                threshold = getattr(self, 'pre_nms_score_threshold',
                                    getattr(self.rpn, 'pre_nms_score_threshold', 'unknown'))
                raise RuntimeError("The nms pre score threshold {} is too high, change for a lower one.".format(
                    threshold))
            self.ensemble_proposals = copy.deepcopy(self.rpn_proposals)
            self.rpn_proposal_rois = convert_to_roi_format(self.rpn_proposals)
            self.rcnn_logits, self.rcnn_deltas = self.box_head(self.roi_crop(feat_4, self.rpn_proposal_rois))
            self.fpr_res = self.roi_box.get_probability(inputs, self.rpn_proposals, self.rcnn_logits,
                                                        self.rcnn_deltas)
            # Ensemble score: mean of the RPN score (column 1) and the RCNN
            # probability (fpr_res column 0).
            self.ensemble_proposals[:, 1] = (self.ensemble_proposals[:, 1] + self.fpr_res[:, 0]) / 2
            return self.rpn_proposals, self.fpr_res, self.ensemble_proposals

        else:
            raise ValueError('rpn_nms(): invalid mode = %s?' % self.mode)

    def set_mode(self, mode):
        """Switch between 'train', 'valid' and 'test'; only 'train' enables training mode."""
        assert mode in ['train', 'valid', 'test']
        self.mode = mode
        # NOTE(review): self.rpn was constructed with the mode given to
        # __init__ and is not updated here.
        if mode in ['train']:
            self.train()
        else:
            self.eval()