"""Export a trained NoduleNet checkpoint to TorchScript (fp16, traced on one sample)."""
import os
import sys
import argparse
import torch
import horovod.torch as hvd
import numpy as np
import lmdb
import json
import SimpleITK as sitk

if __name__ == "__main__" and __package__ is None:
    # Allow running this file directly as a script: put the directory that
    # contains the ``lungDetection3D`` package on sys.path so the absolute
    # package imports below resolve.
    sys.path.insert(
        0,
        os.path.join(os.path.dirname(os.path.dirname(os.path.abspath(__file__))), ".."),
    )
    __package__ = "lungDetection3D"

from lungDetection3D.BaseDetector.config import get_cfg_defaults
from lungDetection3D.BaseDetector.utils.gpu_utils import set_gpu
from lungDetection3D.NoduleDetector.engine import NoduleDetection3D
from lungDetection3D.NoduleDetector.modeling.detector.nodulenet import NoduleNet
from lungDetection3D.NoduleDetector.data.bbox_reader import BboxReader_Nodule


def set_requires_grad(nets, requires_grad=False):
    """Set ``requires_grad`` on every parameter of the given network(s).

    Called with the default ``requires_grad=False`` to freeze a model so that
    no gradients are tracked for its parameters.

    Parameters:
        nets (module or list of modules) -- a single network or a list of
            networks; ``None`` entries are skipped
        requires_grad (bool) -- value assigned to each ``param.requires_grad``
    """
    if not isinstance(nets, list):
        nets = [nets]
    for net in nets:
        if net is not None:
            for param in net.parameters():
                param.requires_grad = requires_grad


def _parse_args():
    """Parse the command-line options of this export script."""
    parser = argparse.ArgumentParser(description="full functional execute script of Detection module.")
    group = parser.add_mutually_exclusive_group()
    group.add_argument("-t", "--train", help="model training", action="store_true")
    group.add_argument("-i", "--inference", help="model inference", action="store_true")
    parser.add_argument("-c", "--config", type=str, default="config.yaml", help="config file path")
    # Previously hard-coded to '7'; kept as the default for backward compatibility.
    parser.add_argument("-g", "--gpu", type=str, default="7",
                        help="value exported as CUDA_VISIBLE_DEVICES")
    return parser.parse_args()


def _load_model(cfg, device):
    """Build NoduleNet in test mode, load ``cfg.TESTING.WEIGHT``, freeze it and cast to fp16."""
    model = NoduleNet(cfg, mode='test')
    model.eval()
    model.to(device)
    # map_location remaps tensors that were saved on another CUDA index onto the
    # single visible device; without it, loading fails when that index is hidden
    # by CUDA_VISIBLE_DEVICES.
    checkpoint = torch.load(cfg.TESTING.WEIGHT, map_location=device)
    model.load_state_dict(checkpoint['state_dict'])
    print("model successfully loaded")
    set_requires_grad(model)
    return model.half()


def _first_sample(dataset):
    """Return the first item yielded by *dataset* without loading any later one.

    Raises:
        RuntimeError: if the dataset yields no samples.
    """
    for sample in dataset:
        return sample
    raise RuntimeError("dataset is empty; no sample available for tracing")


def main():
    """Load the configured checkpoint and save a TorchScript trace of it."""
    args = _parse_args()
    # Must be set before the first CUDA call so only the chosen GPU is visible.
    os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu

    cfg = get_cfg_defaults()
    cfg.merge_from_file(args.config)
    cfg.freeze()

    device = torch.device("cuda")
    model = _load_model(cfg, device)

    # Only one example input is needed for tracing: take the first sample.
    dataset = BboxReader_Nodule(cfg, mode='test')
    _input, image = _first_sample(dataset)
    if image.ndim != 3:
        raise ValueError(f"expected a 3-D (D, H, W) image, got shape {image.shape}")
    sample_idx = 0
    pid = dataset.sample_bboxes[sample_idx].get_field("filename")
    pid = pid.split('/')[-2].replace('.nii.gz', '')
    print('[%d] Predicting %s' % (sample_idx, pid), image.shape)

    # Add batch and channel axes: (D, H, W) -> (1, 1, D, H, W).
    image = torch.from_numpy(image[np.newaxis, np.newaxis]).to(device)

    model_path = cfg.DEPLOY.TORCHSCRIPT_SAVE_PATH
    with torch.no_grad():
        # The model was cast to fp16, so the example input must be fp16 too.
        traced_model = torch.jit.trace(model, (image.half(),))
    print(traced_model.graph)
    traced_model.save(model_path)


if __name__ == "__main__":
    main()