from torch.utils.data import Dataset

class SubReader(Dataset):
    """Pair each item of a wrapped dataset with per-item patch metadata.

    Indexing ``reader[i]`` yields the 4-tuple
    ``(sub_dataset[i], patch_idx[i], n_zyx[i], coord[i])``. The length is
    taken from ``sub_dataset`` alone, so the three metadata sequences must be
    indexable at every position ``0 <= i < len(sub_dataset)``.

    Args:
        sub_dataset: Sized, indexable source of inputs (e.g. another
            ``Dataset``).
        patch_idx: Per-item patch index, indexed in parallel with
            ``sub_dataset``.
        n_zyx: Per-item value indexed in parallel with ``sub_dataset``;
            presumably patch counts along z/y/x — TODO confirm with callers.
        coord: Per-item coordinate, indexed in parallel with ``sub_dataset``.
    """

    def __init__(self, sub_dataset, patch_idx, n_zyx, coord) -> None:
        super().__init__()
        # Attribute names mirror the constructor arguments; external code
        # may access them directly, so they are kept unchanged.
        self.sub_dataset = sub_dataset
        self.patch_idx = patch_idx
        self.n_zyx = n_zyx
        self.coord = coord

    def __len__(self):
        """Return the number of items, as reported by ``sub_dataset``."""
        return len(self.sub_dataset)

    def __getitem__(self, idx):
        """Return ``(input, patch_idx, n_zyx, coord)`` for position ``idx``."""
        # A tuple display evaluates left to right, so the wrapped dataset is
        # still indexed before the metadata sequences.
        return (
            self.sub_dataset[idx],
            self.patch_idx[idx],
            self.n_zyx[idx],
            self.coord[idx],
        )