# net_cls_1024u_2d.py (4.25 KB)
import os
import sys

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

sys.path.append(os.path.dirname(os.path.abspath(__file__)) + '/../')

from net.component_i_2d import RfbfBlock2d, ResBasicBlock2d
from net.component_i_2d import create_conv_block_k3
from net.component_c import init_modules_2d

# Used only for testing (see net_test)
from cls_utils.data import test_save_ckpt

class FPN(nn.Module):
    """Eight-stage convolutional encoder used as the backbone of ``Net``.

    Despite the name, only a bottom-up path is implemented: each
    ``down_levelK`` stage consumes the previous stage's output, and only the
    final stage's feature map is returned (no top-down merging, no lateral
    outputs).

    Output channel widths, in multiples of ``n_base_filters``:
    level1 -> 1, level2 -> 2, level3 -> 4, level4 -> 8,
    level5 -> 16 (internally 16 -> 8 -> 16), level6/7/8 -> 16.

    Levels 2-7 each open with a block built with stride 2 (presumably halving
    H and W -- confirm in ``ResBasicBlock2d``). Level 8 is
    ``create_conv_block_k3(..., padding=0, bn=False)``, presumably an
    unpadded 3x3 conv that shrinks each spatial dim by 2.
    NOTE(review): with small inputs the map reaching level 8 may be smaller
    than 3x3; confirm against the intended crop size.

    Args:
        n_channels: number of channels of the input tensor.
        n_base_filters: base width that every stage width is scaled from.
        groups: forwarded to every ``RfbfBlock2d`` / ``ResBasicBlock2d``
            (not to the final ``create_conv_block_k3``); presumably the
            convolution group count -- confirm in ``net.component_i_2d``.
    """

    def __init__(self, n_channels, n_base_filters, groups=1):
        super().__init__()
        nf = n_base_filters

        # Stages are constructed strictly in order level1 -> level8 so that
        # parameter creation (and any RNG use during construction) keeps the
        # same order.
        self.down_level1 = nn.Sequential(
            RfbfBlock2d(n_channels, nf, groups=groups),
        )
        self.down_level2 = nn.Sequential(
            ResBasicBlock2d(nf, 2 * nf, stride=(2, 2), groups=groups),
            ResBasicBlock2d(2 * nf, 2 * nf, groups=groups),
        )
        self.down_level3 = nn.Sequential(
            ResBasicBlock2d(2 * nf, 4 * nf, stride=(2, 2), groups=groups),
            ResBasicBlock2d(4 * nf, 4 * nf, groups=groups),
        )
        # From level 4 on, every residual block is built with se=True.
        self.down_level4 = nn.Sequential(
            ResBasicBlock2d(4 * nf, 8 * nf, stride=2, groups=groups, se=True),
            ResBasicBlock2d(8 * nf, 8 * nf, groups=groups, se=True),
            ResBasicBlock2d(8 * nf, 8 * nf, groups=groups, se=True),
        )
        # Widens to 16*nf, narrows back to 8*nf, then widens to 16*nf again.
        self.down_level5 = nn.Sequential(
            ResBasicBlock2d(8 * nf, 16 * nf, stride=2, groups=groups, se=True),
            ResBasicBlock2d(16 * nf, 8 * nf, groups=groups, se=True),
            ResBasicBlock2d(8 * nf, 16 * nf, groups=groups, se=True),
        )
        self.down_level6 = nn.Sequential(
            ResBasicBlock2d(16 * nf, 16 * nf, stride=2, groups=groups, se=True),
        )
        self.down_level7 = nn.Sequential(
            ResBasicBlock2d(16 * nf, 16 * nf, stride=2, groups=groups, se=True),
        )
        self.down_level8 = nn.Sequential(
            create_conv_block_k3(16 * nf, 16 * nf, padding=0, bn=False),
        )

    def forward(self, x):
        """Apply ``down_level1`` ... ``down_level8`` in order; return the last output."""
        stages = (
            self.down_level1, self.down_level2, self.down_level3, self.down_level4,
            self.down_level5, self.down_level6, self.down_level7, self.down_level8,
        )
        feat = x
        for stage in stages:
            feat = stage(feat)
        return feat

class Net(nn.Module):
    """Classifier: ``FPN`` backbone -> global average pool -> linear head.

    Args:
        n_channels: channel multiplier; the backbone is built for
            ``8 * n_channels`` input channels (why 8 is not visible here --
            presumably 8 stacked slices/views per channel; confirm with the
            data pipeline).
        n_diff_classes: number of outputs of the linear head.
        n_base_filters: base width forwarded to ``FPN``. The head's input
            size is ``16 * n_base_filters`` to match FPN's final stage width.
    """

    def __init__(self, n_channels=1, n_diff_classes=1, n_base_filters=8):
        super().__init__()

        in_ch = 8 * n_channels
        self.fpn = FPN(in_ch, n_base_filters)
        self.diff_classifier = nn.Linear(16 * n_base_filters, n_diff_classes)

        # Hand every module in the tree (self included) to the project's
        # parameter initialiser.
        init_modules_2d(self.modules())

    def forward(self, data):
        """Return raw ``diff_classifier`` outputs.

        For a 4-D (batch, channels, H, W) input, the output has shape
        (batch, n_diff_classes). No activation is applied.
        """
        feats = self.fpn(data)
        # Global average pool to (B, C, 1, 1), then collapse to (B, C).
        pooled = F.adaptive_avg_pool2d(feats, (1, 1))
        return self.diff_classifier(pooled.flatten(start_dim=1))

def net_test(batch_size=1, crop_size=(128, 128), n_channels=1, n_diff_classes=1):
    """Smoke-test ``Net``: build it, run one random batch through it, print shapes.

    All parameters default to the values previously hard-coded here, so
    ``net_test()`` behaves as before.

    Args:
        batch_size: number of random samples in the test batch.
        crop_size: (H, W) spatial size of the random input.
        n_channels: forwarded to ``Net``. The input gets ``8 * n_channels``
            channels, matching what ``Net`` builds its backbone for.
        n_diff_classes: forwarded to ``Net``.

    Returns:
        The model output tensor, on the selected device.
    """
    height, width = crop_size
    # Input is drawn before the model is built, keeping the original RNG order.
    x = torch.rand(batch_size, n_channels * 8, height, width)
    print('x.shape:', x.shape)
    model = Net(n_channels=n_channels, n_diff_classes=n_diff_classes)

    print('模型结构:')  # "Model structure:"
    print(model)
    print('------------------------------------------')

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(device)
    x = x.to(device)
    print(x.shape, x.device)

    # No checkpoint is loaded; the model keeps its freshly initialised weights.
    model = model.to(device)

    model.eval()
    with torch.no_grad():
        diff_output = model(x)
        print(diff_output.shape)

    # Optional: test_save_ckpt(model=model) exercises checkpoint saving.
    return diff_output

if __name__ == '__main__':
    # Run the model smoke test only when executed as a script, not on import.
    net_test()