# -*- coding: utf-8 -*-
"""Random rotation / reflection ("permute") augmentation for 3-D arrays.

A *permute key* is the tuple
``((rotate_y, rotate_z), flip_z, flip_y, flip_x, transpose)``:

* ``rotate_y`` -- quarter turns passed to ``np.rot90`` over axes (0, 2);
* ``rotate_z`` -- quarter turns passed to ``np.rot90`` over axes (1, 2);
* ``flip_z`` / ``flip_y`` / ``flip_x`` -- reverse axis 0 / 1 / 2;
* ``transpose`` -- swap axes 0 and 2, keeping any trailing axes in place.

Axes 0, 1 and 2 are referred to as z, y and x throughout.
"""
import copy
import itertools
import random

import numpy as np

# Key that leaves data untouched; permute_data / reverse_permute_data
# short-circuit on it and return their input object as-is.
_IDENTITY_KEY = ((0, 0), 0, 0, 0, 0)


def random_permute_key():
    """Return a key chosen uniformly from ``generate_general_permute_keys()``."""
    return random.choice(generate_general_permute_keys())


def generate_permute_keys():
    """Return the full set of 48 permute keys.

    Rotation pairs come from ``combinations_with_replacement(range(2), 2)``,
    i.e. ``(0, 0)``, ``(0, 1)`` and ``(1, 1)``. Each pair is combined with
    every on/off setting of the three flips and the transpose
    (3 * 2**4 = 48 keys). The keys are intended to cover the 48 rotations
    and reflections of a cube-shaped 3-D matrix.
    """
    rotations = itertools.combinations_with_replacement(range(2), 2)
    on_off = range(2)
    return set(itertools.product(rotations, on_off, on_off, on_off, on_off))


def generate_general_indexes():
    """Return eight ``[i, j]`` index pairs.

    Each index ``i`` in 0..7 is paired with ``(i + 2) % 4 + 4 * (i // 4)``,
    i.e. the entry two places further along within its group of four.

    NOTE(review): these presumably index into the list returned by
    ``generate_general_permute_keys`` (paired keys there differ only by
    flipping both y and x). Confirm against callers.
    """
    indexes = [[0, 2], [1, 3], [2, 0], [3, 1],
               [4, 6], [5, 7], [6, 4], [7, 5]]
    return indexes


def generate_general_permute_keys():
    """Return the 8 keys used for random augmentation, in a fixed order.

    Every key has ``rotate_y == 0`` and ``transpose == 0``. The keys
    enumerate ``flip_z`` (outermost), a shared ``flip_y == flip_x`` flag,
    and ``rotate_z`` (innermost), each over {0, 1}.
    """
    rotate_y = 0
    transpose = 0
    keys = [((rotate_y, 0), 0, 0, 0, transpose),
            ((rotate_y, 1), 0, 0, 0, transpose),
            ((rotate_y, 0), 0, 1, 1, transpose),
            ((rotate_y, 1), 0, 1, 1, transpose),
            ((rotate_y, 0), 1, 0, 0, transpose),
            ((rotate_y, 1), 1, 0, 0, transpose),
            ((rotate_y, 0), 1, 1, 1, transpose),
            ((rotate_y, 1), 1, 1, 1, transpose)]
    return keys


def permute_data(data, key):
    """Apply the rotations/reflections described by *key* to *data*.

    Operations run in this order:
    1. rotate over axes (0, 2);
    2. rotate over axes (1, 2);
    3. flip z, then y, then x;
    4. swap axes 0 and 2.

    Args:
        data: array whose first three axes are z, y, x (extra trailing axes
            are carried along), or None.
        key: permute key ``((rotate_y, rotate_z), flip_z, flip_y, flip_x,
            transpose)``.

    Returns:
        The permuted array. For ``None`` data or the identity key, *data*
        itself is returned (not a copy). Otherwise the result is built from
        a copy, so *data* is never modified.
    """
    if data is None or key == _IDENTITY_KEY:
        return data
    data = np.copy(data)
    (rotate_y, rotate_z), flip_z, flip_y, flip_x, transpose = key
    if rotate_y != 0:
        data = np.rot90(data, rotate_y, axes=(0, 2))
    if rotate_z != 0:
        data = np.rot90(data, rotate_z, axes=(1, 2))
    if flip_z:
        data = np.flip(data, axis=0)
    if flip_y:
        data = np.flip(data, axis=1)
    if flip_x:
        data = np.flip(data, axis=2)
    if transpose:
        trailing_axes = list(range(3, data.ndim))
        data = np.transpose(data, axes=[2, 1, 0] + trailing_axes)
    return data


def reverse_permute_data(data, key):
    """Undo ``permute_data(data, key)``.

    Applies the inverse of every step in reverse order. Flips and the
    transpose are their own inverses; rotations use negated turn counts.
    As in ``permute_data``, ``None`` data or the identity key returns
    *data* itself; otherwise a copy is transformed.
    """
    if data is None or key == _IDENTITY_KEY:
        return data
    data = np.copy(data)
    (rotate_y, rotate_z), flip_z, flip_y, flip_x, transpose = reverse_permute_key(key)
    if transpose:
        trailing_axes = list(range(3, data.ndim))
        data = np.transpose(data, axes=[2, 1, 0] + trailing_axes)
    if flip_x:
        data = np.flip(data, axis=2)
    if flip_y:
        data = np.flip(data, axis=1)
    if flip_z:
        data = np.flip(data, axis=0)
    if rotate_z != 0:
        data = np.rot90(data, rotate_z, axes=(1, 2))
    if rotate_y != 0:
        data = np.rot90(data, rotate_y, axes=(0, 2))
    return data


def reverse_permute_key(key):
    """Return *key* with both rotation counts negated; the flip/transpose flags are kept."""
    rotation = tuple(-turns for turns in key[0])
    return rotation, key[1], key[2], key[3], key[4]


def augment(data, targets=None, mask=None, weight=None):
    """Apply one random permute key consistently to data, targets, mask and weight.

    The key comes from ``random_permute_key()`` with ``rotate_y`` and
    ``transpose`` forced to 0. Each target is permuted, and its vector
    channels are then updated with ``transform_target``.

    Args:
        data: array to permute (see ``permute_data``), or None.
        targets: indexable container (e.g. a list) of target arrays, or None.
            It is not modified; permuted entries go into a shallow copy of
            the same type.
        mask: array to permute, or None.
        weight: array to permute, or None.

    Returns:
        ``(aug_data, aug_targets, aug_mask, aug_weight)``.
    """
    key = random_permute_key()
    (_, rotate_z), flip_z, flip_y, flip_x, _ = key
    # Defensive: never rotate over axes (0, 2) or transpose here, even if the
    # key source starts returning such keys.
    key = (0, rotate_z), flip_z, flip_y, flip_x, 0

    aug_data = permute_data(data, key)
    aug_targets = targets
    if targets is not None:
        # Shallow copy so the caller's container is not overwritten in place.
        aug_targets = copy.copy(targets)
        for i in range(len(aug_targets)):
            aug_targets[i] = permute_data(aug_targets[i], key)
            transform_target(aug_targets[i], key)
    aug_mask = permute_data(mask, key)
    aug_weight = permute_data(weight, key)
    return aug_data, aug_targets, aug_mask, aug_weight


def transform_target(target, key):
    """Update the vector channels of *target* in place to match *key*.

    Call this on the output of ``permute_data(target, key)``: moving voxels
    does not by itself rotate or reflect the vectors stored in them.
    Channels 1, 2 and 3 of the last axis are treated as the z, y and x
    components of a vector; all other channels are left untouched.

    NOTE(review): the channel layout is inferred from the flip handling
    below; confirm against whatever produces the targets.

    Args:
        target: array whose last axis has at least 4 channels; modified in place.
        key: the permute key that was applied to *target*.
    """
    (rotate_y, rotate_z), flip_z, flip_y, flip_x, transpose = key
    # np.rot90 uses its turn count modulo 4 (negative counts turn the other
    # way), so apply one quarter-turn update per effective turn rather than
    # at most one.
    for _ in range(rotate_y % 4):
        # One turn over axes (0, 2): (z, x) components become (-x, z).
        z = target[..., 1].copy()
        x = target[..., 3].copy()
        target[..., 1] = -1.0 * x
        target[..., 3] = z
    for _ in range(rotate_z % 4):
        # One turn over axes (1, 2): (y, x) components become (-x, y).
        y = target[..., 2].copy()
        x = target[..., 3].copy()
        target[..., 2] = -1.0 * x
        target[..., 3] = y
    if flip_z:
        target[..., 1] = -1.0 * target[..., 1]
    if flip_y:
        target[..., 2] = -1.0 * target[..., 2]
    if flip_x:
        target[..., 3] = -1.0 * target[..., 3]
    if transpose:
        # Swapping axes 0 and 2 swaps the z and x components.
        z = target[..., 1].copy()
        x = target[..., 3].copy()
        target[..., 1] = x
        target[..., 3] = z


def _check_general_keys(data, print_result):
    """Debug helper: permute *data* with every general key and print diagnostics.

    For each key it prints the key (and the result when *print_result* is
    set). It reports a result identical to an earlier one, and it reports
    any key whose reverse does not restore *data*.
    """
    print('data', data)
    history = []
    for key in generate_general_permute_keys():
        print('key', key)
        new_data = permute_data(data, key)
        if print_result:
            print('new_data', new_data)
        if any(np.all(old_data == new_data) for old_data in history):
            print('new_data is exist!!!')
        history.append(new_data)
        reverse_data = reverse_permute_data(new_data, key)
        if not np.all(data == reverse_data):
            print('error:', reverse_data)


def data_test():
    """Debug check of the general keys on a (2, 2, 2, 1, 5) array."""
    _check_general_keys(np.arange(40).reshape((2, 2, 2, 1, 5)), print_result=False)


def data_test2():
    """Debug check: 10000 random keys must keep the shape of a (3, 2, 2, 1, 5) array."""
    data = np.arange(60).reshape((3, 2, 2, 1, 5))
    for _ in range(10000):
        key = random_permute_key()
        print('key', key)
        new_data = permute_data(data, key)
        if new_data.shape != data.shape:
            print('error:', new_data.shape)


def data_test3():
    """Debug: print a (2, 2, 2, 1, 5) array after one rotate_z turn plus transform_target."""
    data = np.arange(40).reshape((2, 2, 2, 1, 5))
    print('data', data)
    key = (0, 1), 0, 0, 0, 0
    new_data = permute_data(data, key)
    transform_target(new_data, key)
    print('new_data', new_data)


def data_test4():
    """Debug check of the general keys on a (3, 3, 3) array, printing every result."""
    _check_general_keys(np.arange(27).reshape((3, 3, 3)), print_result=True)


if __name__ == '__main__':
    data_test2()