import os
import sys

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

# Make the parent of this file's directory importable so that the `net.*`
# modules below resolve when this file is run directly as a script.
sys.path.append(os.path.dirname(os.path.abspath(__file__)) + '/../')

from net.component_i_231025 import RfbfBlock3d, ResBasicBlock3d
from net.component_i_231025 import create_conv_block_k3
from net.component_c import init_modules  # used for testing


class FPN(nn.Module):
    """Bottom-up 3D convolutional encoder made of eight sequential stages.

    Despite the name, only the bottom-up pathway is implemented: there are no
    top-down or lateral connections, and ``forward`` returns only the output
    of the last stage (``down_level8``).

    Args:
        n_channels: Number of channels of the input volume.
        n_base_filters: Base width; every stage width is a multiple of it and
            the last stage is built with ``16 * n_base_filters`` output
            channels.
        groups: Forwarded as ``groups`` to the RFB / residual blocks.
    """

    def __init__(self, n_channels, n_base_filters, groups=1):
        super().__init__()
        # NOTE: every stage is wrapped in nn.Sequential (even single-module
        # ones) so the state_dict keys (`down_levelN.<idx>.*`) stay stable for
        # existing checkpoints.
        self.down_level1 = nn.Sequential(
            RfbfBlock3d(n_channels, n_base_filters, groups=groups)
        )
        # Levels 2-3 are built with stride (1, 2, 2): presumably they
        # downsample only the two trailing axes and keep depth -- TODO confirm
        # against ResBasicBlock3d.
        self.down_level2 = nn.Sequential(
            ResBasicBlock3d(1 * n_base_filters, 2 * n_base_filters, stride=(1, 2, 2), groups=groups),
            ResBasicBlock3d(2 * n_base_filters, 2 * n_base_filters, groups=groups),
        )
        self.down_level3 = nn.Sequential(
            ResBasicBlock3d(2 * n_base_filters, 4 * n_base_filters, stride=(1, 2, 2), groups=groups),
            ResBasicBlock3d(4 * n_base_filters, 4 * n_base_filters, groups=groups),
        )
        # From level 4 on, blocks use stride=2 and se=True (presumably a
        # squeeze-and-excitation option of ResBasicBlock3d -- verify).
        self.down_level4 = nn.Sequential(
            ResBasicBlock3d(4 * n_base_filters, 8 * n_base_filters, stride=2, groups=groups, se=True),
            ResBasicBlock3d(8 * n_base_filters, 8 * n_base_filters, groups=groups, se=True),
            ResBasicBlock3d(8 * n_base_filters, 8 * n_base_filters, groups=groups, se=True),
        )
        # Widths go 8x -> 16x -> 8x -> 16x here; kept as-is so that the
        # parameter shapes match existing checkpoints.
        self.down_level5 = nn.Sequential(
            ResBasicBlock3d(8 * n_base_filters, 16 * n_base_filters, stride=2, groups=groups, se=True),
            ResBasicBlock3d(16 * n_base_filters, 8 * n_base_filters, groups=groups, se=True),
            ResBasicBlock3d(8 * n_base_filters, 16 * n_base_filters, groups=groups, se=True),
        )
        self.down_level6 = nn.Sequential(
            ResBasicBlock3d(16 * n_base_filters, 16 * n_base_filters, stride=2, groups=groups, se=True)
        )
        self.down_level7 = nn.Sequential(
            ResBasicBlock3d(16 * n_base_filters, 16 * n_base_filters, stride=2, groups=groups, se=True)
        )
        # Unpadded conv block without batch norm; presumably a 3x3x3 kernel
        # given the helper's name -- TODO confirm.
        self.down_level8 = nn.Sequential(
            create_conv_block_k3(16 * n_base_filters, 16 * n_base_filters, padding=0, bn=False)
        )

    def forward(self, x):
        """Run ``x`` through all eight stages and return the last stage's output."""
        down_out1 = self.down_level1(x)
        down_out2 = self.down_level2(down_out1)
        down_out3 = self.down_level3(down_out2)
        down_out4 = self.down_level4(down_out3)
        down_out5 = self.down_level5(down_out4)
        down_out6 = self.down_level6(down_out5)
        down_out7 = self.down_level7(down_out6)
        down_out8 = self.down_level8(down_out7)
        return down_out8


class Net(nn.Module):
    """Volume classifier: FPN encoder -> global max pooling -> linear -> sigmoid.

    Args:
        n_channels: Number of input channels.
        n_diff_classes: Number of independent sigmoid outputs.
        n_base_filters: Base width of the encoder; the classifier input width
            is ``16 * n_base_filters``.
    """

    def __init__(self, n_channels=1, n_diff_classes=1, n_base_filters=8):
        super().__init__()
        self.fpn = FPN(n_channels, n_base_filters)
        self.diff_classifier = nn.Linear(16 * n_base_filters, n_diff_classes)
        # Initialize the parameters of every submodule via the project helper.
        init_modules(self.modules())

    def forward(self, data):
        """Return per-class probabilities of shape ``(batch, n_diff_classes)``.

        A sigmoid is applied here, so the output is a probability, not a
        logit. Do not feed it to a loss that applies its own sigmoid
        (e.g. ``BCEWithLogitsLoss``).
        """
        out = self.fpn(data)
        # Global max pooling over the three trailing axes (used instead of
        # average pooling), then flatten to (batch, channels).
        out_max = F.adaptive_max_pool3d(out, (1, 1, 1))
        out = torch.flatten(out_max, 1)
        diff_output = self.diff_classifier(out)
        # torch.sigmoid replaces the deprecated F.sigmoid (same result).
        diff_output = torch.sigmoid(diff_output)
        return diff_output


def net_test(onnx_path='cls_train_net_cls_1024u_231025_20241113.onnx'):
    """Smoke-test ``Net`` on a random volume and optionally export it to ONNX.

    Builds the model from a small config and, when ``cfg['pretrain_ckpt']``
    is non-empty, loads weights from it. It then runs one forward pass under
    ``torch.no_grad()`` and prints shapes/devices along the way.

    Args:
        onnx_path: Destination file for the ONNX export. Pass ``None`` or an
            empty string to skip the export.
    """
    cfg = {
        'n_channels': 1,
        'n_diff_classes': 1,
        'training_crop_size': [48, 256, 256],
        'pretrain_ckpt': '',
    }
    batch_size = 1
    # Single source of truth for the input shape, shared by the test tensor
    # and the ONNX tracing input.
    input_shape = (batch_size, cfg['n_channels'], *cfg['training_crop_size'])

    x = torch.rand(*input_shape)
    print(x.shape)

    model = Net(n_channels=cfg['n_channels'], n_diff_classes=cfg['n_diff_classes'])

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(device)
    x = x.to(device)
    print(x.shape, x.device)

    # Optionally load pretrained weights (checkpoint is expected to hold a
    # 'state_dict' entry, as in the previously commented-out loader).
    if cfg['pretrain_ckpt']:
        # NOTE(security): torch.load unpickles the file; only load trusted
        # checkpoints.
        checkpoint = torch.load(cfg['pretrain_ckpt'], map_location='cpu')
        model.load_state_dict(checkpoint['state_dict'])
        print('Checkpoint loaded successfully')

    model = model.to(device)
    model.eval()
    with torch.no_grad():
        diff_output = model(x)
    print(diff_output.shape)

    if onnx_path:
        # Trace with an input shaped like the test volume; only the batch
        # axis is declared dynamic in the exported graph.
        dummy_input = torch.randn(*input_shape).to(device)
        torch.onnx.export(
            model,
            dummy_input,
            onnx_path,
            input_names=['input'],
            output_names=['output'],
            dynamic_axes={'input': {0: 'batch_size'}, 'output': {0: 'batch_size'}},
        )


if __name__ == '__main__':
    net_test()