import os
import sys
import torch
from torch.nn import DataParallel

sys.path.append(os.path.dirname(os.path.abspath(__file__)) + '/../')

#from model.net_cls_1024u_13 import Net
from net.net_cls_1024u_2d import Net

class TorchModel_2d(object):
    """Inference wrapper around the 2D classification ``Net``.

    Loads trained weights from a checkpoint, places the model on CPU or on
    one or more GPUs (via ``DataParallel``) and exposes ``predict`` and an
    intensity ``normalize`` helper.
    """

    def __init__(self, model_path, GPUIndex):
        """Build the network, load trained weights and move it to the device(s).

        Args:
            model_path: path to a checkpoint readable by ``torch.load``; the
                loaded object must contain the weights under ``'state_dict'``.
            GPUIndex: comma-separated GPU ids (e.g. ``"0"`` or ``"0,1"``) that
                is written to ``CUDA_VISIBLE_DEVICES``. An int is also accepted.
        """
        super(TorchModel_2d, self).__init__()
        gpu_index = str(GPUIndex)
        # NOTE: CUDA_VISIBLE_DEVICES only takes effect if CUDA has not been
        # initialised yet in this process.
        os.environ['CUDA_VISIBLE_DEVICES'] = gpu_index
        # Ignore empty entries such as a trailing comma in "0,1,".
        gpus = len([g for g in gpu_index.split(',') if g.strip()])
        self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
        model = Net()

        # Load the trained parameters into the model. Always map to CPU first
        # so a checkpoint saved on a GPU id that is not visible here still
        # loads; the model is moved to the target device below.
        model_param = torch.load(model_path, map_location='cpu')
        print(model_path)
        model.load_state_dict(model_param['state_dict'], strict=True)

        # Put the model on the target device. DataParallel only moves the
        # module itself when given a single device id; with several GPUs its
        # forward raises unless the parameters already live on device_ids[0].
        model = model.to(self.device)

        # Deploy the model across multiple GPUs.
        if 'cuda' == self.device:
            # Never request more devices than torch can actually see.
            n_gpus = max(1, min(gpus, torch.cuda.device_count()))
            device_ids = list(range(n_gpus))
            model = DataParallel(model, device_ids=device_ids)
            print('GPUIndex: ', model.device_ids)

        self.original_model = model

    def predict(self, data):
        """Run the model on ``data`` and return sigmoid scores.

        Args:
            data: input tensor for ``Net``; it is moved to ``self.device``.
                (Expected shape depends on ``Net`` -- not visible here.)

        Returns:
            numpy array holding ``sigmoid`` of the model output.
        """
        self.original_model.eval()
        with torch.no_grad():
            data = data.to(self.device)
            return self.original_model(data).cpu().sigmoid().numpy()

    def normalize(self, data, min_value=-1000, max_value=600):
        """Clip ``data`` to ``[min_value, max_value]`` and rescale to ``[-1, 1]``.

        The input is not modified; clipping happens on a copy.

        Args:
            data: numpy array or torch tensor.
            min_value: lower clipping bound, mapped to -1.
            max_value: upper clipping bound, mapped to 1.

        Returns:
            A new array/tensor with values in ``[-1, 1]``.
        """
        # Copy first: the original aliased ``data`` and clipped the caller's
        # array in place.
        if isinstance(data, torch.Tensor):
            new_data = data.clone()
        else:
            new_data = data.copy()
        new_data[new_data < min_value] = min_value
        new_data[new_data > max_value] = max_value
        # normalize to [-1, 1]
        new_data = 2.0 * (new_data - min_value) / (max_value - min_value) - 1

        return new_data