import os
import sys

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

# Make the parent directory of this file importable so the project-local
# `net` and `cls_utils` packages below can be found when run as a script.
sys.path.append(
    os.path.normpath(
        os.path.join(os.path.dirname(os.path.abspath(__file__)), os.pardir)
    )
)

from net.component_i_2d import RfbfBlock2d, ResBasicBlock2d  # noqa: E402
from net.component_i_2d import create_conv_block_k3  # noqa: E402
from net.component_c import init_modules_2d  # noqa: E402

# Used for testing (see the optional call at the end of net_test).
from cls_utils.data import test_save_ckpt  # noqa: E402


class FPN(nn.Module):
    """Bottom-up 2-D feature extractor built from eight sequential stages.

    Despite the name, no top-down / lateral pathway is built here:
    ``forward`` returns only the output of the deepest stage
    (``down_level8``).

    Output channel widths per stage, in multiples of ``n_base_filters``:
    1, 2, 4, 8, 16, 16, 16, 16 (stage 5 internally narrows to 8x before
    widening back to 16x). Stages 2-7 pass ``stride=2`` / ``(2, 2)`` to
    their first block, presumably halving H and W each time -- TODO confirm
    against ``ResBasicBlock2d``. Stages 4-7 enable ``se=True`` on every block.

    Args:
        n_channels: number of input channels.
        n_base_filters: base channel width that all stage widths scale from.
        groups: forwarded to every ``RfbfBlock2d`` / ``ResBasicBlock2d``
            (not to the final ``create_conv_block_k3``).
    """

    def __init__(self, n_channels, n_base_filters, groups=1):
        super().__init__()
        self.down_level1 = nn.Sequential(
            RfbfBlock2d(n_channels, n_base_filters, groups=groups)
        )
        self.down_level2 = nn.Sequential(
            ResBasicBlock2d(1 * n_base_filters, 2 * n_base_filters, stride=(2, 2), groups=groups),
            ResBasicBlock2d(2 * n_base_filters, 2 * n_base_filters, groups=groups),
        )
        self.down_level3 = nn.Sequential(
            ResBasicBlock2d(2 * n_base_filters, 4 * n_base_filters, stride=(2, 2), groups=groups),
            ResBasicBlock2d(4 * n_base_filters, 4 * n_base_filters, groups=groups),
        )
        self.down_level4 = nn.Sequential(
            ResBasicBlock2d(4 * n_base_filters, 8 * n_base_filters, stride=2, groups=groups, se=True),
            ResBasicBlock2d(8 * n_base_filters, 8 * n_base_filters, groups=groups, se=True),
            ResBasicBlock2d(8 * n_base_filters, 8 * n_base_filters, groups=groups, se=True),
        )
        self.down_level5 = nn.Sequential(
            ResBasicBlock2d(8 * n_base_filters, 16 * n_base_filters, stride=2, groups=groups, se=True),
            ResBasicBlock2d(16 * n_base_filters, 8 * n_base_filters, groups=groups, se=True),
            ResBasicBlock2d(8 * n_base_filters, 16 * n_base_filters, groups=groups, se=True),
        )
        self.down_level6 = nn.Sequential(
            ResBasicBlock2d(16 * n_base_filters, 16 * n_base_filters, stride=2, groups=groups, se=True)
        )
        self.down_level7 = nn.Sequential(
            ResBasicBlock2d(16 * n_base_filters, 16 * n_base_filters, stride=2, groups=groups, se=True)
        )
        # NOTE(review): unpadded (padding=0) block, presumably a 3x3 conv given
        # the `_k3` name. If down_level1 preserves H/W and each stride-2 stage
        # halves it, a 128x128 input arrives here at 2x2 -- smaller than a 3x3
        # kernel. Confirm against RfbfBlock2d / ResBasicBlock2d.
        self.down_level8 = nn.Sequential(
            create_conv_block_k3(16 * n_base_filters, 16 * n_base_filters, padding=0, bn=False)
        )

    def forward(self, x):
        """Run all eight stages in order and return the last stage's output."""
        x = self.down_level1(x)
        x = self.down_level2(x)
        x = self.down_level3(x)
        x = self.down_level4(x)
        x = self.down_level5(x)
        x = self.down_level6(x)
        x = self.down_level7(x)
        return self.down_level8(x)


class Net(nn.Module):
    """Classifier: FPN backbone -> global average pool -> linear head.

    Args:
        n_channels: per-sample channel count. The backbone is built with
            ``8 * n_channels`` input channels (presumably 8 stacked slices per
            channel -- TODO confirm), so inputs must carry that many channels.
        n_diff_classes: number of output units of the linear head.
        n_base_filters: base width of the FPN. Its final stage has
            ``16 * n_base_filters`` channels, which sizes the linear head.
    """

    def __init__(self, n_channels=1, n_diff_classes=1, n_base_filters=8):
        super().__init__()
        self.fpn = FPN(8 * n_channels, n_base_filters)
        self.diff_classifier = nn.Linear(16 * n_base_filters, n_diff_classes)
        # Initialize model parameters.
        init_modules_2d(self.modules())

    def forward(self, data):
        """Return raw scores for ``data``.

        For a 4-D input ``(batch, 8 * n_channels, H, W)`` (as built in
        ``net_test``) the result has shape ``(batch, n_diff_classes)``. No
        sigmoid/softmax is applied.
        """
        out = self.fpn(data)
        # Collapse the spatial dims to 1x1, then drop them.
        out_avg = F.adaptive_avg_pool2d(out, (1, 1))
        out = torch.flatten(out_avg, 1)
        diff_output = self.diff_classifier(out)
        return diff_output


def net_test():
    """Smoke test: build a Net, run it on a random input, print shapes."""
    cfg = dict()
    cfg['n_channels'] = 1
    cfg['n_diff_classes'] = 1
    cfg['training_crop_size'] = [128, 128]
    cfg['pretrain_ckpt'] = ''  # NOTE: not loaded anywhere in this function.
    batch_size = 1

    # The backbone is built with 8 * n_channels input channels.
    x = torch.rand(
        batch_size,
        cfg['n_channels'] * 8,
        cfg['training_crop_size'][0],
        cfg['training_crop_size'][1],
    )
    print('x.shape:', x.shape)

    model = Net(n_channels=cfg.get('n_channels'), n_diff_classes=cfg.get('n_diff_classes'))
    print('模型结构:')  # "Model structure:"
    print(model)
    print('------------------------------------------')

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(device)
    x = x.to(device)
    print(x.shape, x.device)

    # Move the (freshly initialized, not checkpoint-loaded) model to device.
    model = model.to(device)
    model.eval()
    with torch.no_grad():
        diff_output = model(x)
    print(diff_output.shape)

    # Optional checkpoint-saving check; enable to exercise it:
    # test_save_ckpt(model=model)


if __name__ == '__main__':
    net_test()