from torch.utils.data import Dataset


class SubReader(Dataset):
    """Dataset wrapper that pairs each sample with three per-sample metadata values.

    Item ``idx`` is built by indexing ``sub_dataset``, ``patch_idx``,
    ``n_zyx`` and ``coord`` with the same ``idx``.  The reported length
    comes from ``sub_dataset`` alone; the other three containers are not
    length-checked here, so they must support every index that is valid
    for ``sub_dataset``.
    """

    def __init__(self, sub_dataset, patch_idx, n_zyx, coord) -> None:
        """Store the wrapped dataset and its parallel metadata containers.

        Args:
            sub_dataset: Indexable, sized collection that yields the inputs.
            patch_idx: Indexable collection, looked up with the sample index.
                Presumably a patch identifier per sample -- confirm with caller.
            n_zyx: Indexable collection, looked up with the sample index.
                Presumably per-sample counts along z/y/x -- confirm with caller.
            coord: Indexable collection, looked up with the sample index.
                Presumably a per-sample coordinate -- confirm with caller.
        """
        super().__init__()
        self.sub_dataset = sub_dataset
        self.patch_idx = patch_idx
        self.n_zyx = n_zyx
        self.coord = coord

    def __len__(self):
        """Return the number of samples, i.e. ``len(sub_dataset)``."""
        return len(self.sub_dataset)

    def __getitem__(self, idx):
        """Return ``(inputs, patch_idx, n_zyx, coord)`` for position ``idx``.

        All four containers are indexed with the same ``idx``; any error
        raised by one of them (e.g. ``IndexError``) propagates unchanged.
        """
        inputs = self.sub_dataset[idx]
        p_idx = self.patch_idx[idx]
        n_zyx = self.n_zyx[idx]
        coord = self.coord[idx]
        return inputs, p_idx, n_zyx, coord