"""Load the nodule-segmentation encoder/decoder checkpoint, optionally run it on
NIfTI images, and export it as a TorchScript module via ``torch.jit.trace``."""
import os

# Pin the visible GPU before anything else is imported, so no module below can
# initialize CUDA against a different device. With this set, 'cuda:0' inside
# this process refers to physical GPU 1.
os.environ["CUDA_VISIBLE_DEVICES"] = "1"

import glob
import sys
import timeit  # noqa: F401  (currently unused; kept from the original script)

import numpy as np
import torch

sys.path.append('../../lungNoduleDensityClassification/nodulecls/utils')
sys.path.append('model')

from modelBuild import build_model
from RWconfig import LoadJson, WriteJson  # noqa: F401  (WriteJson currently unused)
import SimpleITK as sitk
from torchsummary import summary  # noqa: F401  (currently unused)
from torch.autograd import Variable  # noqa: F401  (currently unused)


def set_requires_grad(nets, requires_grad=False):
    """Set ``requires_grad`` on every parameter of the given network(s).

    The default (``requires_grad=False``) freezes all parameters, which avoids
    unnecessary gradient bookkeeping during inference or tracing.

    Parameters:
        nets (network or list of networks) -- a single network or a list of
            networks; ``None`` entries in the list are skipped
        requires_grad (bool) -- whether the networks require gradients or not
    """
    if not isinstance(nets, list):
        nets = [nets]
    for net in nets:
        if net is not None:
            for param in net.parameters():
                param.requires_grad = requires_grad


def _load_model(checkpoint, config_path, device,
                encoder_model_name="NoduleSegEncoder_proxima",
                decoder_model_name="NoduleSegDecoder_proxima"):
    """Build the encoder/decoder model from a JSON config and load its weights.

    Returns the model in eval mode, on ``device``, with all parameters frozen.
    """
    config = LoadJson(config_path)  # parse the config file once
    encoder_params = config[encoder_model_name]
    decoder_params = config[decoder_model_name]
    # Build the encoder without its classification head.
    encoder_params['do_cls'] = False
    encoder_model = build_model(encoder_model_name, encoder_params)
    # The decoder is constructed around the already-built encoder instance.
    decoder_params['encoder_model'] = encoder_model
    model = build_model(decoder_model_name, decoder_params)
    model.eval()
    model.to(device)
    model.load_state_dict(torch.load(checkpoint, map_location=device))
    set_requires_grad(model)
    return model


def _run_inference(model, device, pattern):
    """Run ``model`` on every image matching the glob ``pattern`` and print results."""
    for path in sorted(glob.glob(pattern)):
        print('image_path is', path)
        img = sitk.GetArrayFromImage(sitk.ReadImage(path))
        img = img[np.newaxis, np.newaxis, ...]  # prepend two singleton axes
        print('image range is', np.amin(img), np.amax(img))
        data = torch.from_numpy(img).to(device=device, dtype=torch.float32)
        with torch.no_grad():
            # NOTE(review): forward is called with the same tensor twice, while
            # the export below traces with a single input -- confirm the
            # model's forward signature.
            output = model(data, data)
        print('output is', output[0][0].item())
        probs = torch.sigmoid(output[0])
        print('result is', probs[0].item())


def _export_torchscript(model, device, model_base_path, patch_size=48, ratio=0.5):
    """Trace ``model`` on an all-zero dummy input and save it as ``model.pt``.

    The dummy input has shape
    ``[1, 1, patch_size, patch_size, int(patch_size * ratio)]``, i.e. the last
    axis is shortened by ``ratio``. Returns the path of the saved file.
    """
    input_shape = [1, 1, patch_size, patch_size, int(patch_size * ratio)]
    input_data = torch.zeros(input_shape, device=device, dtype=torch.float32)
    print('type of input_data is', type(input_data))

    os.makedirs(model_base_path, exist_ok=True)
    # os.path.join avoids the '//' that '%s/model.pt' produced when the base
    # path already ends with a separator.
    model_path = os.path.join(model_base_path, 'model.pt')
    with torch.no_grad():
        traced_model = torch.jit.trace(model, input_data)
    traced_model.save(model_path)
    return model_path


if __name__ == "__main__":
    infer_choice = False
    convert_choice = True
    checkpoint = '/hdd/disk4/Segmentation/Weights/1030_GroupLungASPPE_E0.1_8_12_3_Zoom48_BS4_Transpose/06_Stride_8_16_Decoder_TLoss/CP_epoch35_loss_0.19.pth'
    config_path = '/hdd/disk4/Segmentation/Weights/1030_GroupLungASPPE_E0.1_8_12_3_Zoom48_BS4_Transpose/06_Stride_8_16_Decoder_TLoss/model_parameters02.json'

    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
    model = _load_model(checkpoint, config_path, device)
    print("model successfully loaded")

    if infer_choice:
        _run_inference(model, device, '/fileser/zrx/Rib/alpha/493/temp/*nii.gz')

    # pth to torchscript
    if convert_choice:
        _export_torchscript(
            model,
            device,
            '/fileser/zrx/alpha/4.103/NoduleSeg/noduleseg_20201102_48/1/',
            patch_size=48,
            ratio=0.5,
        )