"""Load the best checkpoint of each classifier and export it to TorchScript.

For every ``Nodule*`` / ``Rib*`` directory under the model root, the script
builds the network described by that classifier's JSON configs, loads the
``CP*`` checkpoint found in its ``best`` sub-directory, optionally runs the
network on a set of NIfTI volumes, and saves a traced ``model.pt`` next to
the checkpoint.
"""
import glob
import os
import sys
import timeit

import numpy as np
import torch

# The package-relative imports below only resolve once the parent
# directories are importable and ``__package__`` is set, so this block must
# run before them.
sys.path.append(os.path.dirname(os.path.abspath(__file__)))
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
sys.path.insert(0, os.path.join(os.path.dirname(os.path.dirname(os.path.abspath(__file__))), ".."))
__package__ = "BaseClassifier"

from .utils.RWconfig import LoadJson  # noqa: E402
from .model.modelBuild import build_model  # noqa: E402

import SimpleITK as sitk  # noqa: E402
from torchsummary import summary  # noqa: E402
from torch.autograd import Variable  # noqa: E402

# To pin the script to a single GPU, uncomment:
# os.environ["CUDA_VISIBLE_DEVICES"] = "6"

# Edge length of the cubic dummy patch used for tracing and for the summary.
PATCH_SIZE = 48


def set_requires_grad(nets, requires_grad=False):
    """Set ``requires_grad`` on every parameter of the given networks.

    With the default ``requires_grad=False`` this avoids unnecessary
    gradient computations.

    Parameters:
        nets (network or list/tuple of networks) -- networks to update;
            ``None`` entries are skipped
        requires_grad (bool) -- whether the networks require gradients or not
    """
    if not isinstance(nets, (list, tuple)):
        nets = [nets]
    for net in nets:
        if net is None:
            continue
        for param in net.parameters():
            param.requires_grad = requires_grad


def _find_checkpoint(best_dir):
    """Return the path of the first ``CP*`` file listed in *best_dir*.

    "First" follows ``os.listdir`` order, exactly as the original lookup did.

    Raises:
        FileNotFoundError -- if no file name in *best_dir* starts with 'CP'
    """
    candidates = [name for name in os.listdir(best_dir) if name.startswith('CP')]
    if not candidates:
        raise FileNotFoundError('no checkpoint starting with "CP" found in %s' % best_dir)
    return os.path.join(best_dir, candidates[0])


def _load_model(dir_name, checkpoint, device):
    """Build the network configured for *dir_name* and load *checkpoint* into it.

    The JSON configs are read from ``../<dir_name>/`` relative to the current
    working directory. The returned model is in eval mode, lives on *device*,
    and has gradients disabled on all parameters.
    """
    config_path = os.path.join('..', dir_name, 'model_parameters.json')
    train_config_path = os.path.join('..', dir_name, 'config.json')

    model_name = LoadJson(train_config_path)['training']['model_name']
    model_param = LoadJson(config_path)[model_name]

    model = build_model(model_name, model_param)
    model.eval()
    model.to(device)
    model.load_state_dict(torch.load(checkpoint, map_location=device))
    set_requires_grad(model)
    return model


def _run_inference(model, device, image_glob):
    """Run *model* on every image matching *image_glob* and print the results."""
    for image_path in sorted(glob.glob(image_glob)):
        print('image_path is', image_path)
        img = sitk.GetArrayFromImage(sitk.ReadImage(image_path))
        # Prepend two singleton axes in front of the image array.
        img = img[np.newaxis, np.newaxis, ...]
        print('image range is', np.amin(img), np.amax(img))
        data = torch.from_numpy(img).to(device=device, dtype=torch.float32)
        # NOTE(review): the model is called with two inputs here but traced
        # with a single input in _export_torchscript -- confirm which
        # signature the networks actually expect.
        output = model(data, data)
        print('output is', output[0][0].item())
        probs = torch.sigmoid(output[0])
        print('result is', probs[0].item())


def _export_torchscript(model, checkpoint, device, patch_size):
    """Trace *model* on an all-zero patch and save ``model.pt`` beside *checkpoint*."""
    input_shape = [1, 1] + [patch_size] * 3
    input_data = torch.zeros(input_shape, dtype=torch.float32, device=device)
    print('type of input_data is', type(input_data))

    model_base_path = os.path.dirname(checkpoint)
    model_path = '%s/model.pt' % model_base_path
    traced_model = torch.jit.trace(model, (input_data,))
    traced_model.save(model_path)
    print('Successful save model to %s' % model_path)


def main():
    """Convert every ``Nodule*`` / ``Rib*`` classifier under the model root."""
    model_root = '/fileser/xupl/lungclassifier3D'
    run_inference = False
    convert_choice = True

    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

    dir_names = [name for name in os.listdir(model_root)
                 if name.startswith(('Nodule', 'Rib'))]
    # dir_names = ['NoduleDensityClassifier']

    for dir_name in dir_names:
        checkpoint = _find_checkpoint(os.path.join(model_root, dir_name, 'best'))
        model = _load_model(dir_name, checkpoint, device)
        print("model successfully loaded")

        if run_inference:
            # The loop variable inside _run_inference is local, so it can no
            # longer overwrite the model root used by later iterations.
            _run_inference(model, device, '/fileser/zrx/Rib/alpha/493/temp/*nii.gz')

        # pth to torchscript
        if convert_choice:
            _export_torchscript(model, checkpoint, device, PATCH_SIZE)

        # Summarise on the device the model already lives on instead of
        # forcing CUDA, which would fail on CPU-only hosts.
        summary(model, (1, PATCH_SIZE, PATCH_SIZE, PATCH_SIZE), device=device.type)


if __name__ == "__main__":
    main()