import torch
import pdb

def train_collate(batch):
    """Collate training samples into a stacked input tensor plus a list of box annotations.

    Args:
        batch: sequence of samples; each sample is indexable, with the input
            tensor at position 0 and its annotations at position 1.

    Returns:
        ``[inputs, bboxes]``:
        - ``inputs`` is the position-0 tensors stacked along a new leading
          dimension (so all of them must share one shape);
        - ``bboxes`` is a plain list of the position-1 entries. It is not
          stacked, presumably because the box count varies per sample
          (TODO confirm).
    """
    images = [sample[0] for sample in batch]
    stacked = torch.stack(images, dim=0)
    boxes = [sample[1] for sample in batch]
    return [stacked, boxes]

def test_ddp_collate(batch):
    """Collate test-time samples that carry patch bookkeeping alongside the input.

    Args:
        batch: sequence of samples; each sample is indexable with the input
            tensor at position 0, a patch index at position 1 and an
            ``n_zyx`` value at position 2. Judging by the name, ``n_zyx`` may be
            a per-axis patch count (TODO confirm against the dataset).

    Returns:
        ``[inputs, patch_idx, n_zyx]``:
        - ``inputs`` is the position-0 tensors stacked along a new leading
          dimension;
        - ``patch_idx`` and ``n_zyx`` are plain lists of the position-1 and
          position-2 entries, in batch order.
    """
    inputs = torch.stack([sample[0] for sample in batch], 0)
    patch_idx = [sample[1] for sample in batch]
    n_zyx = [sample[2] for sample in batch]
    return [inputs, patch_idx, n_zyx]


# The name matches pytest's ``test*`` discovery pattern. Without this, a test
# module that imports the function would get it collected as a test, which then
# errors with "fixture 'batch' not found".
test_ddp_collate.__test__ = False

def train_collate_center(batch):
    """Collate center-based training samples by stacking each of six fields.

    Each sample is indexable. Positions 0..5 hold, in order: inputs,
    gaussian_hm, center_idxs, bboxes_diameters, reg_offset and reg_mask. For
    every position, the samples' tensors are stacked along a new leading
    dimension, so within a field all samples must share one shape. Any
    entries beyond position 5 are ignored.

    Returns:
        ``[inputs, gaussian_hm, center_idxs, bboxes_diameters, reg_offset,
        reg_mask]``, each a stacked tensor.
    """
    field_count = 6  # inputs, gaussian_hm, center_idxs, bboxes_diameters, reg_offset, reg_mask
    return [
        torch.stack([sample[field] for sample in batch], 0)
        for field in range(field_count)
    ]


def test_collate(batch):
    """Collate test-time samples into a stacked input tensor plus a list of images.

    Args:
        batch: non-empty sequence of samples; each sample is indexable with
            the input tensor at position 0 and an image entry at position 1.
            The image entries are passed through untouched, and their type is
            not established here (TODO confirm).

    Returns:
        ``[inputs, images]``:
        - ``inputs`` is the position-0 tensors stacked along a new leading
          dimension;
        - ``images`` is a plain list of the position-1 entries, in batch order.

    Raises:
        ValueError: if ``batch`` is empty. The previous version failed here
            with an ``UnboundLocalError``.
    """
    if len(batch) == 0:
        raise ValueError("test_collate received an empty batch")
    # Build each result exactly once. The previous version wrapped these two
    # lines in a redundant per-sample loop, redoing identical O(n) work n times.
    inputs = torch.stack([sample[0] for sample in batch], 0)
    images = [sample[1] for sample in batch]
    return [inputs, images]


# The name matches pytest's ``test*`` discovery pattern. Without this, a test
# module that imports the function would get it collected as a test, which then
# errors with "fixture 'batch' not found".
test_collate.__test__ = False