import torch import pdb def train_collate(batch): batch_size = len(batch) inputs = torch.stack([batch[b][0] for b in range(batch_size)], 0) bboxes = [batch[b][1] for b in range(batch_size)] return [inputs, bboxes] def test_ddp_collate(batch): batch_size = len(batch) inputs = torch.stack([batch[b][0] for b in range(batch_size)], 0) patch_idx = [batch[b][1] for b in range(batch_size)] n_zyx = [batch[b][2] for b in range(batch_size)] return [inputs, patch_idx, n_zyx] def train_collate_center(batch): batch_size = len(batch) inputs = torch.stack([batch[b][0] for b in range(batch_size)], 0) gaussian_hm = torch.stack([batch[b][1] for b in range(batch_size)], 0) center_idxs = torch.stack([batch[b][2] for b in range(batch_size)], 0) bboxes_diameters = torch.stack([batch[b][3] for b in range(batch_size)], 0) reg_offset = torch.stack([batch[b][4] for b in range(batch_size)], 0) reg_mask = torch.stack([batch[b][5] for b in range(batch_size)], 0) return [inputs, gaussian_hm, center_idxs, bboxes_diameters, reg_offset, reg_mask] def test_collate(batch): batch_size = len(batch) for b in range(batch_size): inputs = torch.stack([batch[b][0]for b in range(batch_size)], 0) images = [batch[b][1] for b in range(batch_size)] return [inputs, images]