import os
import sys

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

sys.path.append(os.path.join(os.path.dirname(os.path.abspath(__file__)), '..'))

from net.component_i_231025 import RfbfBlock3d, ResBasicBlock3d
from net.component_i_231025 import create_conv_block_k3
from net.component_c import init_modules

# For testing

class FPN(nn.Module):
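    """Downsampling backbone (encoder only, despite the FPN name).

    Eight sequential stages raise the channel count up to 16 * n_base_filters while
    reducing spatial resolution. Assuming RfbfBlock3d preserves resolution (its
    definition is not shown here), levels 2-3 downsample only in-plane (stride (1, 2, 2)),
    levels 4-7 downsample all three axes (stride 2), and level 8 is an unpadded k3 conv block.
    """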
    def __init__(self, n_channels, n_base_filters, groups=1):
        super(FPN, self).__init__()

        self.down_level1 = nn.Sequential(
            RfbfBlock3d(n_channels, n_base_filters, groups=groups)
        )

        self.down_level2 = nn.Sequential(
            ResBasicBlock3d(1 * n_base_filters, 2 * n_base_filters, stride=(1, 2, 2), groups=groups),
            ResBasicBlock3d(2 * n_base_filters, 2 * n_base_filters, groups=groups)
        )

        self.down_level3 = nn.Sequential(
            ResBasicBlock3d(2 * n_base_filters, 4 * n_base_filters, stride=(1, 2, 2), groups=groups),
            ResBasicBlock3d(4 * n_base_filters, 4 * n_base_filters, groups=groups)
        )

        self.down_level4 = nn.Sequential(
            ResBasicBlock3d(4 * n_base_filters, 8 * n_base_filters, stride=2, groups=groups, se=True),
            ResBasicBlock3d(8 * n_base_filters, 8 * n_base_filters, groups=groups, se=True),
            ResBasicBlock3d(8 * n_base_filters, 8 * n_base_filters, groups=groups, se=True)
        )

        self.down_level5 = nn.Sequential(
            ResBasicBlock3d(8 * n_base_filters, 16 * n_base_filters, stride=2, groups=groups, se=True),
            ResBasicBlock3d(16 * n_base_filters, 8 * n_base_filters, groups=groups, se=True),
            ResBasicBlock3d(8 * n_base_filters, 16 * n_base_filters, groups=groups, se=True)
        )

        self.down_level6 = nn.Sequential(
            ResBasicBlock3d(16 * n_base_filters, 16 * n_base_filters, stride=2, groups=groups, se=True)
        )

        self.down_level7 = nn.Sequential(
            ResBasicBlock3d(16 * n_base_filters, 16 * n_base_filters, stride=2, groups=groups, se=True)
        )

        self.down_level8 = nn.Sequential(
            create_conv_block_k3(16 * n_base_filters, 16 * n_base_filters, padding=0, bn=False)
        )
    
    def forward(self, x):
        down_out1 = self.down_level1(x)
        down_out2 = self.down_level2(down_out1)
        down_out3 = self.down_level3(down_out2)
        down_out4 = self.down_level4(down_out3)
        down_out5 = self.down_level5(down_out4)
        down_out6 = self.down_level6(down_out5)
        down_out7 = self.down_level7(down_out6)
        down_out8 = self.down_level8(down_out7)

        return down_out8

class Net(nn.Module):
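    """Binary classifier: FPN encoder, global max pooling, and a linear head with a sigmoid output."""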
    def __init__(self, n_channels=1, n_diff_classes=1, n_base_filters=8):
        super(Net, self).__init__()

        self.fpn = FPN(n_channels, n_base_filters)

        self.diff_classifier = nn.Linear(16 * n_base_filters, n_diff_classes)

        # Initialize model parameters
        init_modules(self.modules())
    
    def forward(self, data):

        out = self.fpn(data)
        # Global max pooling over the spatial dimensions (switched from average pooling)
        out_max = F.adaptive_max_pool3d(out, (1, 1, 1))
        out = torch.flatten(out_max, 1)
        diff_output = self.diff_classifier(out)
        diff_output = torch.sigmoid(diff_output)  # F.sigmoid is deprecated; use torch.sigmoid

        return diff_output

def net_test():
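    # Smoke test: build the model, run a forward pass on a random volume, and export it to ONNX.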
    cfg = dict()
    cfg['n_channels'] = 1
    cfg['n_diff_classes'] = 1
    cfg['training_crop_size'] = [48, 256, 256]
    cfg['pretrain_ckpt'] = ''
    batch_size = 1

    x = torch.rand(batch_size, cfg['n_channels'],
                   cfg['training_crop_size'][0], cfg['training_crop_size'][1], cfg['training_crop_size'][2])
    
    print(x.shape)

    model = Net(n_channels=cfg.get('n_channels'), n_diff_classes=cfg.get('n_diff_classes'))

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(device)
    x = x.to(device)
    print(x.shape, x.device)

    # Optionally load pretrained model parameters
    """pretrain_ckpt_path = './cls_train/best_cls/cls_model218_1024u_20230407.ckpt'
    model_param = torch.load(pretrain_ckpt_path)
    model.load_state_dict(model_param['state_dict'])"""
    model = model.to(device)
    #print('parameters loaded successfully')

    model.eval()
    with torch.no_grad():
        diff_output = model(x)
        print(diff_output.shape)

    dummy_input = torch.randn(1, 1, 48, 256, 256).to(device)  # adjust to the actual input size
    torch.onnx.export(model,
                      dummy_input,
                      'cls_train_net_cls_1024u_231025_20241113.onnx',
                      input_names=['input'],
                      output_names=['output'],
                      dynamic_axes={'input': {0: 'batch_size'},
                                    'output': {0: 'batch_size'}})
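
    # Optional sanity check of the exported graph. This is a minimal sketch that assumes
    # the onnxruntime package is available; it is skipped if the import fails.
    try:
        import onnxruntime as ort
        sess = ort.InferenceSession('cls_train_net_cls_1024u_231025_20241113.onnx',
                                    providers=['CPUExecutionProvider'])
        onnx_out = sess.run(None, {'input': dummy_input.cpu().numpy()})[0]
        print('onnxruntime output shape:', onnx_out.shape)
    except ImportError:
        print('onnxruntime not installed, skipping ONNX verification')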

if __name__ == '__main__':
    net_test()