import torch
import copy
from torch import nn
import torch.nn.functional as F

from layers.roi_align_3d import ROIAlign3D
from backbone.resunet import ResUNet, ResUNet_FPN
from rpn.rpn_head import RpnHead
from rpn.rpn_module import RpnModule
from roi_heads.box_head import BoxHead
from roi_heads.box_module import BoxModule
from util import convert_to_roi_format, pad2factor, norm, add_rcnn_probability, get_anchors
from loss_fun.loss import Loss, Loss_FPN


@torch.jit.script
def make_rpn_windows(fs: torch.Tensor, stride: torch.Tensor, anchors: torch.Tensor) -> torch.Tensor:
    """
    Generate anchor boxes at every voxel of a feature map.

    The center of the anchor boxes placed on a voxel corresponds to the
    center of that voxel's footprint in the original input image, i.e.
    ``(stride - 1) / 2 + stride * index`` along each axis.

    Args:
        fs: feature map; only its shape (``[_, _, D, H, W]``) and device are used.
        stride: scalar tensor, spacing of feature-map voxels in input-image units.
        anchors: 2-D tensor with one anchor per row; columns 0, 1 and 2 are
            copied into the size part of every window.

    Returns:
        windows: tensor of shape ``[D * H * W * num_anchors, 6]`` whose rows
        are ``[z, y, x, d, h, w]``. Rows are ordered with z varying slowest
        and the anchor index varying fastest.
    """
    offset = (stride - 1) / 2
    device = fs.device

    _, _, D, H, W = fs.shape
    oz = torch.arange(offset, offset + stride * (D - 1) + 1, step=stride, dtype=torch.float32, device=device)
    oh = torch.arange(offset, offset + stride * (H - 1) + 1, step=stride, dtype=torch.float32, device=device)
    ow = torch.arange(offset, offset + stride * (W - 1) + 1, step=stride, dtype=torch.float32, device=device)
    oanchor = torch.arange(anchors.size()[0], dtype=torch.float32, device=device)

    # Relies on meshgrid's default matrix ("ij") ordering: first input varies slowest.
    shift_z, shift_y, shift_x, anchor_idx = torch.meshgrid(oz, oh, ow, oanchor)
    shift_z = shift_z.reshape(-1)
    shift_y = shift_y.reshape(-1)
    shift_x = shift_x.reshape(-1)
    anchor_idx = anchor_idx.reshape(-1).long()

    # One anchor row per generated window.
    shift_anchors = torch.index_select(anchors, 0, anchor_idx)

    # Column slices are already 1-D. They must NOT be squeezed: with a single
    # window (1x1x1 feature map and one anchor) squeeze() would turn them into
    # 0-d tensors and the stack below would fail on mismatched shapes.
    shift_anchor_z = shift_anchors[:, 0]
    shift_anchor_y = shift_anchors[:, 1]
    shift_anchor_x = shift_anchors[:, 2]
    windows = torch.stack((shift_z, shift_y, shift_x, shift_anchor_z, shift_anchor_y, shift_anchor_x), dim=1)

    return windows


class NoduleNet(nn.Module):
    """
    Detection network: a ResUNet backbone followed by a region proposal
    network (single-scale, or three-level when ``cfg.model_backbone_fpn`` is
    set) and an RCNN box head.

    The box head is constructed but ``use_rcnn`` is hard-coded to ``False``
    in ``__init__``, so by default only the RPN outputs are produced.

    Intermediate results (windows, logits, deltas, targets, proposals) are
    stored on ``self`` during ``forward``.
    """

    def __init__(self, cfg, mode='train'):
        """
        Args:
            cfg: configuration object read through flat attributes such as
                ``data_process_stride`` and ``model_backbone_fpn``.
            mode: one of 'train', 'valid' or 'test'; selects what ``forward`` returns.
        """
        super(NoduleNet, self).__init__()
        self.mode = mode
        self.cfg = cfg

        self.stride = torch.tensor(cfg.data_process_stride, dtype=torch.float32)

        # NOTE(review): `model_abchor_aspect_ratios` looks like a misspelling of
        # "anchor"; kept verbatim because the cfg definition is not visible here.
        anchor_sets = get_anchors(cfg.model_anchor_bases, cfg.model_abchor_aspect_ratios)
        # One anchor tensor per level; computed once instead of once per level.
        self.anchors = [torch.tensor(anchor_sets[i], dtype=torch.float32) for i in range(3)]
        # Keep the anchors on the GPU as before, but do not crash at
        # construction time on hosts without CUDA.
        if torch.cuda.is_available():
            self.anchors = [level_anchors.cuda() for level_anchors in self.anchors]

        # Use the same flat cfg attribute as every other FPN check in this class
        # (the previous `cfg.MODEL.BACKBONE.FPN` was inconsistent with them).
        if cfg.model_backbone_fpn:
            self.feature_net = ResUNet_FPN(cfg, 1, 128)
            self.rpn_heads = nn.ModuleList([RpnHead(len(cfg.model_anchor_bases[0]), in_channels=64),
                                            RpnHead(len(cfg.model_anchor_bases[1]), in_channels=64),
                                            RpnHead(len(cfg.model_anchor_bases[2]), in_channels=128)])
        else:
            self.feature_net = ResUNet(cfg, 1, 128)
            self.rpn_head = RpnHead(len(cfg.model_anchor_bases), in_channels=128)
        self.box_head = BoxHead(cfg.model_roi_box_head_num_class, cfg.model_roi_box_head_pooler_resolution,
                                in_channels=128)
        self.rpn = RpnModule(cfg, self.mode)
        self.roi_box = BoxModule(cfg)
        self.roi_crop = ROIAlign3D(cfg.model_roi_box_head_pooler_resolution,
                                   spatial_scale=1 / cfg.data_process_stride,
                                   sampling_ratio=cfg.model_roi_box_head_sample_ratio)
        self.use_rcnn = False
        if cfg.model_backbone_fpn:
            self.loss = Loss_FPN(self.rpn.loss, self.roi_box.loss, self.use_rcnn)
        else:
            self.loss = Loss(self.rpn.loss, self.roi_box.loss, self.use_rcnn)

    def _compute_rpn_outputs(self, features):
        """Build the anchor windows and run the RPN head(s), storing results on self."""
        if self.cfg.model_backbone_fpn:
            self.rpn_logits, self.rpn_deltas = [], []
            # Levels 0/1/2 use twice, once and half the base stride respectively.
            self.rpn_windows = [make_rpn_windows(features[0], self.stride * 2, self.anchors[0]),
                                make_rpn_windows(features[1], self.stride, self.anchors[1]),
                                make_rpn_windows(features[2], self.stride // 2, self.anchors[2])]
            for level, head in enumerate(self.rpn_heads):
                level_logits, level_deltas = head(features[level])
                self.rpn_logits.append(level_logits)
                self.rpn_deltas.append(level_deltas)
        else:
            # NOTE(review): self.anchors is a list of three tensors here, while
            # make_rpn_windows takes a single anchor tensor - confirm the
            # non-FPN path is still supported.
            self.rpn_window = make_rpn_windows(features[-1], self.stride, self.anchors)
            self.rpn_logits, self.rpn_deltas = self.rpn_head(features[-1])

    def _nms_proposals(self, inputs):
        """Run RPN NMS on the stored RPN outputs (one result per level when FPN is on)."""
        if self.cfg.model_backbone_fpn:
            return [self.rpn.rpn_nms(inputs, windows, logits, deltas)
                    for windows, logits, deltas in zip(self.rpn_windows, self.rpn_logits, self.rpn_deltas)]
        return self.rpn.rpn_nms(inputs, self.rpn_window, self.rpn_logits, self.rpn_deltas)

    def forward_depoly(self, inputs):
        """
        Inference entry point: pads and normalizes ``inputs`` itself, then
        returns the RPN proposals after NMS, regardless of ``self.mode``.
        The RCNN re-scoring step is not applied on this path.
        """
        inputs = pad2factor(inputs)
        inputs = norm(inputs)
        features, _ = self.feature_net(inputs)
        self._compute_rpn_outputs(features)
        self.rpn_proposals = self._nms_proposals(inputs)
        return self.rpn_proposals

    def forward(self, inputs, truth_bboxes):
        """
        Run the network.

        Returns:
            In 'train'/'valid' mode, the tuple
            ``(rpn_logits, rpn_deltas, rpn_labels, rpn_label_weights,
            rpn_targets, rpn_target_weights, rcnn_logits, rcnn_deltas,
            rcnn_labels, rcnn_targets)``; the four RCNN entries are ``None``
            when ``use_rcnn`` is off. With FPN on, the RPN entries are lists
            with one item per level.

            In 'test' mode, the RPN proposals, or
            ``(rpn_proposals, fpr_res, ensemble_proposals)`` when ``use_rcnn`` is on.

        Raises:
            RuntimeError: in 'test' mode with ``use_rcnn`` on when NMS returns no proposal.
            ValueError: when ``self.mode`` is not 'train', 'valid' or 'test'.
        """
        features, feat_4 = self.feature_net(inputs)
        self._compute_rpn_outputs(features)

        if self.mode in ['train', 'valid']:
            if self.cfg.model_backbone_fpn:
                self.rpn_labels, self.rpn_label_assigns, self.rpn_label_weights, self.rpn_targets, self.rpn_target_weights = \
                    [], [], [], [], []
                for i in range(3):
                    # NOTE(review): this is the only remaining nested-style cfg
                    # access in the class; confirm it exists on the flat cfg.
                    rpn_labels, rpn_label_assigns, rpn_label_weights, rpn_targets, rpn_target_weights = \
                        self.rpn.make_rpn_target(self.rpn_windows[i], truth_bboxes, self.cfg.DATA.DATA_PROCESS.FPN_DIAMETER_RANGE[i])
                    self.rpn_labels.append(rpn_labels)
                    self.rpn_label_assigns.append(rpn_label_assigns)
                    self.rpn_label_weights.append(rpn_label_weights)
                    self.rpn_targets.append(rpn_targets)
                    self.rpn_target_weights.append(rpn_target_weights)
            else:
                self.rpn_labels, self.rpn_label_assigns, self.rpn_label_weights, self.rpn_targets, self.rpn_target_weights = \
                    self.rpn.make_rpn_target(self.rpn_window, truth_bboxes)

            if self.use_rcnn:
                # NOTE(review): reads self.rpn_window, which the FPN path never
                # sets - this branch presumably only works without FPN.
                self.rpn_proposals = self.rpn.rpn_nms(inputs, self.rpn_window, self.rpn_logits, self.rpn_deltas)
                self.rpn_proposals, self.rcnn_labels, self.rcnn_targets = self.roi_box.make_box_target(
                    self.rpn_proposals, truth_bboxes)
                if len(self.rpn_proposals) > 0:
                    self.rpn_proposal_rois = convert_to_roi_format(self.rpn_proposals)
                    self.rcnn_logits, self.rcnn_deltas = self.box_head(self.roi_crop(feat_4, self.rpn_proposal_rois))
            else:
                self.rcnn_logits, self.rcnn_deltas, self.rcnn_labels, self.rcnn_targets = None, None, None, None

            return self.rpn_logits, self.rpn_deltas, \
                   self.rpn_labels, self.rpn_label_weights, self.rpn_targets, self.rpn_target_weights, \
                   self.rcnn_logits, self.rcnn_deltas, self.rcnn_labels, self.rcnn_targets

        elif self.mode in ['test']:
            self.rpn_proposals = self._nms_proposals(inputs)

            if not self.use_rcnn:
                return self.rpn_proposals

            if len(self.rpn_proposals) > 0:
                # Keep an untouched copy so the RPN score can be averaged with the RCNN score.
                self.ensemble_proposals = copy.deepcopy(self.rpn_proposals)
                self.rpn_proposal_rois = convert_to_roi_format(self.rpn_proposals)
                self.rcnn_logits, self.rcnn_deltas = self.box_head(self.roi_crop(feat_4, self.rpn_proposal_rois))
                self.fpr_res = self.roi_box.get_probability(inputs, self.rpn_proposals, self.rcnn_logits,
                                                            self.rcnn_deltas)
                self.ensemble_proposals[:, 1] = (self.ensemble_proposals[:, 1] + self.fpr_res[:, 0]) / 2
                return self.rpn_proposals, self.fpr_res, self.ensemble_proposals

            # This class never defines `pre_nms_score_threshold`; look it up
            # defensively so the intended RuntimeError is raised instead of an
            # AttributeError while formatting the message.
            threshold = getattr(self, 'pre_nms_score_threshold', None)
            if threshold is None:
                threshold = getattr(self.rpn, 'pre_nms_score_threshold', '<unknown>')
            raise RuntimeError("The nms pre score threshold {} is too high, change for a lower one.".format(
                threshold))
        else:
            raise ValueError('rpn_nms(): invalid mode = %s?' % self.mode)

    def set_mode(self, mode):
        """Set the forward mode and switch the module to train() or eval() accordingly."""
        assert mode in ['train', 'valid', 'test']
        self.mode = mode
        # self.rpn.mode is intentionally left untouched here.
        if mode in ['train']:
            self.train()
        else:
            self.eval()