# trtserver_test.py (3.19 KB)
# from tensorrtserver.api import *
import os
import sys
import argparse
import torch
import horovod.torch as hvd
import numpy as np
import lmdb
import json
import SimpleITK as sitk
import time


if __name__ == "__main__" and __package__ is None:
    # Executed directly as a script (no parent package): make the absolute
    # `lungDetection3D.*` imports below resolvable by putting the directory two
    # levels above this file's folder at the front of sys.path, and declare the
    # package name. Assumes this file lives at <root>/lungDetection3D/<subdir>/
    # — TODO confirm if the layout changes.
    _here = os.path.dirname(os.path.abspath(__file__))
    _search_root = os.path.join(os.path.dirname(_here), "..")
    sys.path.insert(0, _search_root)
    __package__ = "lungDetection3D"

from lungDetection3D.BaseDetector.config import get_cfg_defaults
from lungDetection3D.BaseDetector.utils.gpu_utils import set_gpu
from lungDetection3D.NoduleDetector.engine import NoduleDetection3D
from lungDetection3D.NoduleDetector.modeling.detector.nodulenet import NoduleNet
from lungDetection3D.NoduleDetector.data.bbox_reader import BboxReader_Nodule

def parse_args(argv=None):
    """Parse command-line options for this TorchScript inference smoke test.

    ``-t/--train`` and ``-i/--inference`` are accepted (mutually exclusive)
    for CLI compatibility but are not read by :func:`main`.

    Args:
        argv: Argument list to parse; ``None`` means ``sys.argv[1:]``.

    Returns:
        argparse.Namespace with ``train``, ``inference``, ``config``,
        ``model``, ``gpu`` and ``sleep`` attributes. Defaults reproduce the
        previously hard-coded values.
    """
    parser = argparse.ArgumentParser(description="full functional execute script of Detection module.")
    group = parser.add_mutually_exclusive_group()
    group.add_argument("-t", "--train", help="model training", action="store_true")
    group.add_argument("-i", "--inference", help="model inference", action="store_true")
    parser.add_argument("-c", "--config", type=str, default="config.yaml", help="config file path")
    parser.add_argument("-m", "--model", type=str,
                        default='/data/huangky/FPN/20201119/NoduleNet_FPN_half.pt',
                        help="TorchScript model file to load")
    parser.add_argument("-g", "--gpu", type=str, default='7',
                        help="value written to CUDA_VISIBLE_DEVICES")
    parser.add_argument("--sleep", type=float, default=5.0,
                        help="seconds to pause after each sample (0 disables)")
    return parser.parse_args(argv)


def extract_pid(filename):
    """Return the patient id for a sample from its ``filename`` field.

    Takes the second-to-last ``/``-separated path component and removes any
    ``.nii.gz`` substring from it.
    """
    return filename.split('/')[-2].replace('.nii.gz', '')


def fpn_outputs_to_numpy(rpns):
    """Move each per-level model output to a NumPy array on the host.

    Column 0 is dropped from every non-empty output (presumably a batch/index
    column — confirm against the model's export). Empty outputs are returned
    unchanged.

    Args:
        rpns: Iterable of tensors returned by the TorchScript model.

    Returns:
        list of numpy arrays, one per input tensor, in the same order.
    """
    outputs = []
    for fpn in rpns:
        arr = fpn.cpu().numpy()
        if len(arr):
            arr = arr[:, 1:]
        outputs.append(arr)
    return outputs


def main(argv=None):
    """Run the TorchScript NoduleNet model over the test split and print its outputs."""
    args = parse_args(argv)

    cfg = get_cfg_defaults()
    cfg.merge_from_file(args.config)
    cfg.freeze()

    # Restrict visible GPUs before any CUDA work below (model load / .cuda()),
    # otherwise the setting has no effect on this process.
    os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu

    # The model file is a TorchScript export. It was presumably produced from a
    # NoduleNet checkpoint roughly like:
    #     with torch.no_grad():
    #         traced_model = torch.jit.trace(model, (input))
    #     traced_model.save(model_path)
    model = torch.jit.load(args.model)
    model.eval()

    dataset = BboxReader_Nodule(cfg, mode='test')
    for i, (volume, image) in enumerate(dataset):
        # Assumes dataset iteration order matches `sample_bboxes` indexing.
        pid = extract_pid(dataset.sample_bboxes[i].get_field("filename"))
        print('[%d] Predicting %s' % (i, pid), image.shape)
        # Add a batch dimension and cast to fp16 (the default model file name
        # suggests a half-precision export).
        volume = volume.cuda().unsqueeze(0).half()

        with torch.no_grad():
            rpns = model(volume)
        if args.sleep > 0:
            time.sleep(args.sleep)

        if cfg.MODEL.BACKBONE.FPN:
            for level, fpn in enumerate(fpn_outputs_to_numpy(rpns)):
                print("-------------------torchscript infer-------------------")
                print('fpns_%s' % str(level))
                print("fpns: {}".format(fpn))
        torch.cuda.empty_cache()


if __name__ == "__main__":
    main()