import os
import sys

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

# Make the parent directory importable so the project-local `net` and
# `cls_utils` packages resolve when this file is run directly as a script.
sys.path.append(os.path.dirname(os.path.abspath(__file__)) + '/../')

from net.component_i import RfbfBlock3d, ResBasicBlock3d
from net.component_i import create_conv_block_k3
from net.component_c import init_modules
# Used for testing (only referenced by the commented-out save step in net_test).
from cls_utils.data import test_save_ckpt


class FPN(nn.Module):
    """Eight-stage 3D convolutional encoder used as the backbone of `Net`.

    Despite the name, only the output of the deepest stage is returned; this
    class has no lateral / top-down pathway.

    Channel widths per stage (in multiples of ``n_base_filters``):
    1, 2, 4, 8, 16, 16, 16, 16.  Stages 2-3 use stride ``(1, 2, 2)`` and
    stages 4-7 use stride 2 in their first block; stage 8 is a
    ``create_conv_block_k3`` call with ``padding=0`` and ``bn=False``.

    Args:
        n_channels: Number of input channels to the first stage.
        n_base_filters: Base channel width that all stage widths scale from.
        groups: Forwarded to every ``RfbfBlock3d`` / ``ResBasicBlock3d``
            (semantics defined in ``net.component_i``).
    """

    def __init__(self, n_channels, n_base_filters, groups=1):
        super(FPN, self).__init__()
        # NOTE: attribute names and the nn.Sequential wrappers (even around a
        # single module) determine state_dict keys; keep them stable so that
        # existing checkpoints still load.
        self.down_level1 = nn.Sequential(
            RfbfBlock3d(n_channels, n_base_filters, groups=groups)
        )
        # Stride (1, 2, 2) leaves the first spatial axis untouched and halves
        # the other two (presumably depth is kept, H/W halved - confirm the
        # axis order expected by ResBasicBlock3d).
        self.down_level2 = nn.Sequential(
            ResBasicBlock3d(1 * n_base_filters, 2 * n_base_filters, stride=(1, 2, 2), groups=groups),
            ResBasicBlock3d(2 * n_base_filters, 2 * n_base_filters, groups=groups)
        )
        self.down_level3 = nn.Sequential(
            ResBasicBlock3d(2 * n_base_filters, 4 * n_base_filters, stride=(1, 2, 2), groups=groups),
            ResBasicBlock3d(4 * n_base_filters, 4 * n_base_filters, groups=groups)
        )
        self.down_level4 = nn.Sequential(
            ResBasicBlock3d(4 * n_base_filters, 8 * n_base_filters, stride=2, groups=groups, se=True),
            ResBasicBlock3d(8 * n_base_filters, 8 * n_base_filters, groups=groups, se=True),
            ResBasicBlock3d(8 * n_base_filters, 8 * n_base_filters, groups=groups, se=True)
        )
        # NOTE(review): width goes 16x -> 8x -> 16x inside this stage. It looks
        # like a deliberate bottleneck, but confirm; it cannot be changed
        # without invalidating trained checkpoints.
        self.down_level5 = nn.Sequential(
            ResBasicBlock3d(8 * n_base_filters, 16 * n_base_filters, stride=2, groups=groups, se=True),
            ResBasicBlock3d(16 * n_base_filters, 8 * n_base_filters, groups=groups, se=True),
            ResBasicBlock3d(8 * n_base_filters, 16 * n_base_filters, groups=groups, se=True)
        )
        self.down_level6 = nn.Sequential(
            ResBasicBlock3d(16 * n_base_filters, 16 * n_base_filters, stride=2, groups=groups, se=True)
        )
        self.down_level7 = nn.Sequential(
            ResBasicBlock3d(16 * n_base_filters, 16 * n_base_filters, stride=2, groups=groups, se=True)
        )
        # Final conv without batch norm and without padding, so it shrinks the
        # spatial extent (exact amount depends on create_conv_block_k3).
        self.down_level8 = nn.Sequential(
            create_conv_block_k3(16 * n_base_filters, 16 * n_base_filters, padding=0, bn=False)
        )

    def forward(self, x):
        """Run ``x`` through all eight stages in order and return the last output."""
        stages = (
            self.down_level1, self.down_level2, self.down_level3, self.down_level4,
            self.down_level5, self.down_level6, self.down_level7, self.down_level8,
        )
        for stage in stages:
            x = stage(x)
        return x


class Net(nn.Module):
    """Classifier: FPN backbone -> global average pooling -> linear head.

    Args:
        n_channels: Channel multiplier; the backbone is built for
            ``8 * n_channels`` input channels (``net_test`` feeds exactly that).
        n_diff_classes: Number of output units of the linear head.
        n_base_filters: Base width of the backbone; the head receives
            ``16 * n_base_filters`` pooled features.
    """

    def __init__(self, n_channels=1, n_diff_classes=1, n_base_filters=8):
        super(Net, self).__init__()
        self.fpn = FPN(8 * n_channels, n_base_filters)
        self.diff_classifier = nn.Linear(16 * n_base_filters, n_diff_classes)
        # Initialize model parameters with the project's init helper.
        init_modules(self.modules())

    def forward(self, data):
        """Return raw head outputs of shape ``(batch, n_diff_classes)``.

        No activation (sigmoid/softmax) is applied here.
        """
        features = self.fpn(data)
        # Collapse all three spatial axes to 1 so any input size reaching the
        # head yields a fixed-length feature vector.
        pooled = F.adaptive_avg_pool3d(features, (1, 1, 1))
        flat = torch.flatten(pooled, 1)
        diff_output = self.diff_classifier(flat)
        return diff_output


def net_test():
    """Smoke test: build `Net`, load a checkpoint, run one forward pass.

    The checkpoint path is ``cfg['pretrain_ckpt']`` when non-empty, otherwise
    the hard-coded default below.  The file must exist and hold a dict with a
    ``'state_dict'`` entry.  Prints the device, the input shape/device, a load
    confirmation and the output shape.
    """
    cfg = dict()
    cfg['n_channels'] = 1
    cfg['n_diff_classes'] = 1
    cfg['training_crop_size'] = [48, 256, 256]
    cfg['pretrain_ckpt'] = ''
    batch_size = 1

    depth, height, width = cfg['training_crop_size']
    x = torch.rand(batch_size, cfg['n_channels'] * 8, depth, height, width)
    model = Net(n_channels=cfg.get('n_channels'), n_diff_classes=cfg.get('n_diff_classes'))

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(device)
    x = x.to(device)
    print(x.shape, x.device)

    # Load model parameters. Honor cfg['pretrain_ckpt'] (previously ignored)
    # and fall back to the original default path.
    default_ckpt_path = './cls_train/best_cls/cls_model218_1024u_20230407.ckpt'
    pretrain_ckpt_path = cfg['pretrain_ckpt'] or default_ckpt_path
    # map_location='cpu' lets a GPU-saved checkpoint load on a CPU-only host;
    # the model is moved to `device` right after.
    model_param = torch.load(pretrain_ckpt_path, map_location='cpu')
    model.load_state_dict(model_param['state_dict'])
    model = model.to(device)
    print('参数加载成功')

    model.eval()
    with torch.no_grad():
        diff_output = model(x)
    print(diff_output.shape)
    # Optional: re-save the loaded model via the project helper.
    #test_save_ckpt(model=model)


if __name__ == '__main__':
    net_test()