import torch
from torch.utils.data import Dataset
import sys
import os
import numpy as np

# Make the parent directory importable so that the project-local ``cls_utils``
# package resolves; this must run before the ``cls_utils`` imports below.
sys.path.append(os.path.dirname(os.path.abspath(__file__)) + '/../')
from cls_utils.data import load_data_csv, get_data_and_label
from cls_utils.utils import normalize, hu_value_to_uint8, hu_normalize


class SubjectDataset(Dataset):
    """Map-style dataset yielding ``(data, label)`` tensor pairs per subject.

    The index of subjects comes from ``load_data_csv(csv_path)``; for each row,
    column 0 is used as the subject id and column 1 is cast to ``int`` as the
    label. The sample data is produced by ``get_data_and_label`` (presumably it
    loads an ``.npy`` array from ``np_path`` -- confirm in ``cls_utils.data``)
    and then passed through ``hu_normalize``.

    Args:
        np_path: Location handed to ``get_data_and_label`` for loading data.
        csv_path: Path of the CSV index read by ``load_data_csv``.
        is_train: Stored as ``self._is_train``; not read anywhere in this class.
        is_2d: Forwarded to ``get_data_and_label``.
        augment: Forwarded to ``get_data_and_label``.
        permute: Forwarded to ``get_data_and_label``.
    """

    def __init__(self, np_path, csv_path, is_train=True, is_2d=False,
                 augment=False, permute=False):
        self._npy_path = np_path
        self._csv_path = csv_path
        self._is_train = is_train
        self._is_2d = is_2d
        self.augment = augment
        self.permute = permute
        self._preprocess()

    def _preprocess(self):
        """Load the subject index table from the CSV file."""
        # Presumably a pandas DataFrame (``.iloc`` is used below) -- verify.
        self.subject_ids = load_data_csv(self._csv_path)

    def __len__(self):
        """Return the number of rows in the subject index."""
        return len(self.subject_ids)

    def _get_data(self, i):
        """Load the raw sample and label for index ``i``.

        ``i`` is wrapped modulo the dataset length, so indices past the end
        (e.g. from an oversampling sampler) map back into range.

        Returns:
            ``(data, label)`` where ``data`` is whatever ``get_data_and_label``
            returns (``__getitem__`` treats ``None`` as "unavailable") and
            ``label`` is a 1-element numpy array holding the integer label.
        """
        idx = i % len(self.subject_ids)
        # Look up the subject id (used to locate its data file) and its label.
        subject_id = self.subject_ids.iloc[idx, 0]
        label = int(self.subject_ids.iloc[idx, 1])
        data = get_data_and_label(self._npy_path, subject_id, self._is_2d,
                                  self.augment, self.permute)
        return data, np.asarray([label])

    def __getitem__(self, idx):
        """Return ``(data_tensor, label_tensor)`` for ``idx``.

        If the requested sample is unavailable (``None``), sample 0 is
        returned instead so batching can proceed.

        Raises:
            RuntimeError: if both the requested sample and the fallback
                sample 0 are unavailable.
        """
        data, label = self._get_data(idx)
        if data is None or label is None:
            data, label = self._get_data(0)
            if data is None or label is None:
                # Fail loudly here rather than passing None into
                # hu_normalize / .copy() and getting an opaque error.
                raise RuntimeError(
                    f'SubjectDataset: sample {idx} is unavailable and the '
                    f'fallback sample 0 is unavailable as well')
        data = hu_normalize(data)
        # An alternative normalization used earlier was
        # normalize(hu_value_to_uint8(data)).
        # .copy() hands torch.from_numpy a fresh array, so a negative-stride or
        # shared view (e.g. from flipping augmentation) cannot cause problems.
        data_tensor = torch.from_numpy(data.copy())
        label_tensor = torch.tensor(label, dtype=torch.float32)
        return data_tensor, label_tensor