# -*- coding:utf-8 -*-
import os
import sys
import glob
import math
import random
import numpy as np
import pandas as pd
import scipy.ndimage as nd
from scipy.ndimage import zoom
try:
    from keras.utils import to_categorical
except ImportError:
    # keras is only needed for to_categorical; keep the numpy-only patch
    # helpers in this module importable when keras is not installed.
    def to_categorical(*args, **kwargs):
        raise ImportError('keras is required for to_categorical')
from torch.utils.data import Dataset
from collections import Counter, defaultdict


def _patch_origins(volume_shape, patch_shape, stride):
    """Yield the (z, y, x) start corner of every sliding-window patch.

    Corners are produced z-major, then y, then x. SplitPatch and CombinePatch
    both use this generator, so predictions line up with the patches they
    came from. The step along each axis is int(patch_dim / stride), clamped
    to at least 1: a patch edge smaller than the stride would otherwise give
    range() a zero step, which raises ValueError.
    """
    steps = [max(1, int(val / stride)) for val in patch_shape]
    # "+ 1" keeps the window whose end lands exactly on the last voxel.
    z_starts = range(0, volume_shape[0] - patch_shape[0] + 1, steps[0])
    y_starts = range(0, volume_shape[1] - patch_shape[1] + 1, steps[1])
    x_starts = range(0, volume_shape[2] - patch_shape[2] + 1, steps[2])
    for left_z in z_starts:
        for left_y in y_starts:
            for left_x in x_starts:
                yield left_z, left_y, left_x


def _patch_window(left_z, left_y, left_x, patch_shape):
    """Return the slice tuple selecting one patch_shape window at a corner."""
    return (slice(left_z, left_z + patch_shape[0]),
            slice(left_y, left_y + patch_shape[1]),
            slice(left_x, left_x + patch_shape[2]))


def SplitPatch(mask, image, patch_shape, stride=2):
    """Pad a 3-D mask/image pair and cut both into overlapping patches.

    Args:
        mask: 3-D array; its shape decides how much padding is applied.
        image: array padded with the same widths as ``mask``. It is presumably
            the same shape as ``mask`` (TODO confirm with callers).
        patch_shape: (dz, dy, dx) size of each patch.
        stride: overlap factor. Consecutive patches start
            int(patch_dim / stride) voxels apart (at least 1), so the
            default of 2 gives 50% overlap.

    Returns:
        (image_patches, mask_patches, pad_vals, pad_mask):
            - image_patches, mask_patches: lists of patch arrays (views into
              the padded volumes), in the order CombinePatch expects.
            - pad_vals: the per-axis (left, right) zero padding.
            - pad_mask: the padded mask.
        Pass pad_mask and pad_vals to CombinePatch to undo the split.
    """
    # Step 1: zero-pad each axis up to (ceil(dim / patch) + 1) whole patches,
    # split between both sides (an odd extra voxel goes to the right).
    num_patches = [int(math.ceil(dim / float(size))) + 1
                   for dim, size in zip(mask.shape, patch_shape)]
    target_shape = [n * size for n, size in zip(num_patches, patch_shape)]
    pad_vals = []
    for target, dim in zip(target_shape, mask.shape):
        diff = target - dim
        pad_left = int(diff / 2)
        pad_vals.append((pad_left, diff - pad_left))
    pad_mask = np.pad(mask, pad_vals, mode='constant')
    pad_img = np.pad(image, pad_vals, mode='constant')

    # Step 2: slide a patch_shape window over both padded volumes.
    output_img_patches = []
    output_mask_patches = []
    for corner in _patch_origins(target_shape, patch_shape, stride):
        window = _patch_window(corner[0], corner[1], corner[2], patch_shape)
        output_img_patches.append(pad_img[window])
        output_mask_patches.append(pad_mask[window])
    return output_img_patches, output_mask_patches, pad_vals, pad_mask


def CombinePatch(predictions, pad_mask, pad_vals, stride=2, patch_shape=None):
    """Stitch per-patch predictions back into one volume, averaging overlaps.

    Args:
        predictions: indexable sequence of per-patch outputs, in the order
            SplitPatch returned the patches. Each one is ``np.squeeze``-d,
            or reshaped to ``patch_shape`` when that is given.
        pad_mask: padded mask returned by SplitPatch; only its shape is used.
        pad_vals: per-axis (left, right) padding returned by SplitPatch.
        stride: the same stride that was passed to SplitPatch.
        patch_shape: optional (dz, dy, dx). Pass it when a patch axis has
            size 1, because np.squeeze would otherwise drop that axis.

    Returns:
        float32 array cropped back to the unpadded shape. Each voxel is the
        mean of every prediction whose window covered it.

    Raises:
        ValueError: if ``predictions`` is empty, or if its length differs
            from the number of patch positions implied by pad_mask/stride.
    """
    if patch_shape is not None:
        shape = tuple(patch_shape)
    elif len(predictions) > 0:
        shape = np.squeeze(predictions[0]).shape
    else:
        raise ValueError('predictions is empty')

    result = np.zeros(pad_mask.shape, dtype=np.float32)
    cand_results = np.zeros(pad_mask.shape, dtype=np.float32)
    corners = list(_patch_origins(result.shape, shape, stride))
    if len(corners) != len(predictions):
        # A mismatch means split and combine used different stride/shapes;
        # stitching anyway would silently misplace patches.
        raise ValueError('expected %d predictions for this pad_mask/stride, got %d'
                         % (len(corners), len(predictions)))

    for corner, pred in zip(corners, predictions):
        window = _patch_window(corner[0], corner[1], corner[2], shape)
        if patch_shape is None:
            patch = np.squeeze(pred)
        else:
            patch = np.reshape(pred, shape)
        result[window] += patch
        # Count how many windows covered each voxel so overlaps are averaged.
        cand_results[window] += 1.0

    # Undo SplitPatch's padding: keep [left, dim - right) on every axis.
    crop = tuple(slice(left, dim - right)
                 for dim, (left, right) in zip(result.shape, pad_vals))
    result = result[crop]
    cand_results = cand_results[crop]
    print ('unique values of cand_results are ', np.unique(cand_results))
    return result / cand_results