import numpy as np
import torch
import pdb


class SplitComb(object):
    """Split a (C, Z, Y, X) volume into overlapping cubic patches and map
    per-patch results back into whole-volume coordinates.

    Each patch is a ``side_len``-voxel core plus ``margin`` voxels of context
    on every side, i.e. ``side_len + 2 * margin`` voxels along each axis.
    """

    def __init__(self, side_len, stride, margin):
        """
        Args:
            side_len: edge length of each patch's core region (voxels).
            stride: value that ``side_len`` and ``margin`` must both be
                multiples of (presumably the model's total downsampling
                stride -- TODO confirm against caller).
            margin: context added on each side of the core region (voxels).
        """
        self.side_len = side_len
        self.stride = stride
        self.margin = margin
        # Patch-grid shape [nz, ny, nx] from the most recent split(); used by
        # combine() when no grid shape is passed explicitly.
        self.n_zyx = None
        assert (side_len % stride == 0)
        assert (margin % stride == 0)

    def split(self, data, side_len=None, margin=None, pad_value=None):
        """Cut ``data`` into overlapping patches.

        Args:
            data: array of shape (C, Z, Y, X).
            side_len: core patch size; defaults to ``self.side_len``.
            margin: per-side context size; defaults to ``self.margin``.
            pad_value: constant used to pad the volume borders. ``None``
                means 0 (numpy's default constant).

        Returns:
            splits: array of shape (N, C, S, S, S) with
                S = side_len + 2 * margin and N = nz * ny * nx, ordered with
                iz outermost and ix innermost.
            splits_idx: int array of shape (N, 3) holding each patch's
                (iz, iy, ix) grid position, in the same order as ``splits``.
            n_zyx: list [nz, ny, nx]; the same list is stored on
                ``self.n_zyx``.
        """
        if side_len is None:
            side_len = self.side_len
        if margin is None:
            margin = self.margin
        if pad_value is None:
            # Forwarding None to np.pad fills the border with NaN for float
            # input (and raises for integer input); use np.pad's default 0.
            pad_value = 0
        assert (side_len > margin)
        assert (side_len % self.stride == 0)
        assert (margin % self.stride == 0)

        _, z, y, x = data.shape
        nz = int(np.ceil(float(z) / side_len))
        ny = int(np.ceil(float(y) / side_len))
        nx = int(np.ceil(float(x) / side_len))
        self.n_zyx = [nz, ny, nx]

        # Leading pad of `margin`; trailing pad rounds each axis up to a whole
        # number of cores and then adds a trailing `margin`.
        pad = [[0, 0],
               [margin, nz * side_len - z + margin],
               [margin, ny * side_len - y + margin],
               [margin, nx * side_len - x + margin]]
        data = np.pad(data, pad, 'constant', constant_values=pad_value)

        width = side_len + 2 * margin
        patch_indices = list(np.ndindex(nz, ny, nx))
        splits = np.stack([
            data[:,
                 iz * side_len:iz * side_len + width,
                 iy * side_len:iy * side_len + width,
                 ix * side_len:ix * side_len + width]
            for iz, iy, ix in patch_indices
        ])
        splits_idx = np.array(patch_indices, dtype=int).reshape(-1, 3)
        return splits, splits_idx, self.n_zyx

    @staticmethod
    def _to_volume_coords(dets, patch_idx, side_len, margin):
        """Return a copy of ``dets`` with columns 2, 3, 4 shifted from the
        frame of patch ``patch_idx`` = (iz, iy, ix) into the unpadded volume
        frame."""
        iz, iy, ix = patch_idx
        # Copy so the caller's array is never modified in place.
        dets = np.array(dets)
        # Patch i starts at i * side_len in the padded volume, and the padded
        # volume is offset by `margin` from the original one.
        dets[:, 2] = dets[:, 2] + iz * side_len - margin
        dets[:, 3] = dets[:, 3] + iy * side_len - margin
        dets[:, 4] = dets[:, 4] + ix * side_len - margin
        return dets

    def new_combine(self, output, patch_idx, n_zyx=None, side_len=None, margin=None):
        """Shift a single patch's results into whole-volume coordinates.

        Args:
            output: 2-D array of per-patch results. Columns 2, 3, 4 are offset
                by the patch's z, y, x position (presumably z/y/x coordinates
                in patch voxels -- TODO confirm). The input is not modified.
            patch_idx: (iz, iy, ix) of the patch, e.g. a row of ``splits_idx``.
            n_zyx: unused; accepted for signature parity with ``combine``.
            side_len: defaults to ``self.side_len``.
            margin: defaults to ``self.margin``.

        Returns:
            torch.Tensor built from the shifted copy.
        """
        if side_len is None:
            side_len = self.side_len
        if margin is None:
            margin = self.margin
        assert (side_len % self.stride == 0)
        assert (margin % self.stride == 0)
        shifted = self._to_volume_coords(output, patch_idx, side_len, margin)
        return torch.from_numpy(shifted)

    def combine(self, output, n_zyx=None, side_len=None, margin=None):
        """Shift every patch's results into whole-volume coordinates and
        concatenate them.

        Args:
            output: indexable of per-patch 2-D result arrays, in the same order
                ``split`` produced the patches. The input is not modified.
            n_zyx: grid shape (nz, ny, nx); defaults to the shape recorded by
                the last ``split`` call.
            side_len: defaults to ``self.side_len``.
            margin: defaults to ``self.margin``.

        Returns:
            torch.Tensor of all shifted rows concatenated along axis 0.

        Raises:
            ValueError: if ``n_zyx`` is omitted and ``split`` was never called.
        """
        if side_len is None:
            side_len = self.side_len
        if margin is None:
            margin = self.margin
        if n_zyx is None:
            n_zyx = self.n_zyx
            if n_zyx is None:
                raise ValueError("n_zyx not given and split() has not been called")
        nz, ny, nx = (int(n) for n in n_zyx)
        assert (side_len % self.stride == 0)
        assert (margin % self.stride == 0)

        result = [
            self._to_volume_coords(output[idx], patch_idx, side_len, margin)
            for idx, patch_idx in enumerate(np.ndindex(nz, ny, nx))
        ]
        return torch.from_numpy(np.concatenate(result, 0))