import os
import sys
import copy
import math
import time
import glob
import random
import datetime
import numpy as np
import SimpleITK as sitk
from time import time  # NOTE: rebinds the name `time` from the module to the function
from scipy.ndimage import zoom
from skimage.draw import line_aa
# from matplotlib.lines import Line2D
# from matplotlib import pyplot as plt
from scipy.ndimage import gaussian_filter
from skimage.measure import label,regionprops
from collections import defaultdict,OrderedDict
from scipy.ndimage import center_of_mass
from skimage.morphology import binary_erosion,binary_dilation
from scipy.ndimage import distance_transform_edt as distance
from .respacing_func import *


def Image_preprocess(image, deltas, ratio=1.0, input_shape=(64, 64, 64, 1)):
    '''
    Crop a (shifted) centre patch from `image` and resize it to the network input shape.

    1. Crop a patch of size `round(ratio * input_shape[:3])` around the image
       centre, translated by `deltas`.
    2. Linearly zoom the patch so its shape equals `input_shape[:3]`.

    @image: array indexed as (z, y, x).
    @deltas: sequence of 3 ints, translation of the crop window on (z, y, x).
    @ratio: float around 1.0, scales the cropped patch size relative to `input_shape`.
    @input_shape: sequence of length >= 3; the first 3 entries are the
                  (depth, height, width) of the network input.
    @Return: the zoomed patch with a new leading and a new trailing axis
             (5D for a 3D input).
    '''
    target_size = [int(round(ratio * val)) for val in input_shape[:3]]

    crop_window = []
    for axis in range(3):
        # Each axis uses its own centre and patch size. The original code reused
        # the axis-1 values for axis 2, which was wrong when height != width.
        axis_center = int(image.shape[axis] / 2)
        half_size = target_size[axis] / 2
        start = int(max(0, axis_center - half_size + deltas[axis]))  # clamp at the lower border
        stop = int(axis_center + half_size + deltas[axis])           # slicing clips the upper border
        crop_window.append(slice(start, stop))
    image_center_patch = image[tuple(crop_window)]

    # The patch can be smaller than requested near the borders (or when ratio != 1),
    # so it is always rescaled to the exact network input shape.
    zoom_values = [input_shape[idx] / float(size)
                   for idx, size in enumerate(image_center_patch.shape)]
    target_region = zoom(image_center_patch, zoom=zoom_values, order=1)
    return target_region[np.newaxis, :, :, :, np.newaxis]


def pad_mask(final_mask):
    '''
    Fill empty slices that lie between non-empty slices of a mask.

    Walking along axis 0 between the first and the last non-empty slice, every
    slice whose sum is 0 is replaced by `(previous_slice + next_slice) > 0`.
    Slices are processed in order, so a freshly filled slice is used as the
    "previous" slice of the next one.

    NOTE: np.squeeze normally returns a view, so the caller's array is
    modified in place as well.

    @final_mask: mask array; singleton axes are squeezed away before processing.
    @Return: the squeezed mask with a trailing axis added when the input was 4D,
             otherwise with both a leading and a trailing axis added.
    '''
    original_ndim = len(final_mask.shape)
    final_mask = np.squeeze(final_mask)

    nonzero_slices = [slice_idx for slice_idx in range(final_mask.shape[0])
                      if np.sum(final_mask[slice_idx]) > 0]
    if nonzero_slices:
        min_slice, max_slice = min(nonzero_slices), max(nonzero_slices)
        # Only strictly interior slices can be gaps: both ends are non-empty by
        # construction, so slice_idx - 1 and slice_idx + 1 are always valid here.
        for slice_idx in range(min_slice + 1, max_slice):
            if np.sum(final_mask[slice_idx]) == 0:
                neighbours = final_mask[slice_idx - 1] + final_mask[slice_idx + 1]
                final_mask[slice_idx] = neighbours > 0

    if original_ndim == 4:
        return final_mask[:, :, :, np.newaxis]
    return final_mask[np.newaxis, :, :, :, np.newaxis]


def Data_preprocess(image, mask, aug, ratio=1.0, input_shape=(64, 64, 64, 1), shift_ratio=1.2):
    '''
    Preprocess an image and its mask with the same crop / zoom parameters.

    @image: original image.
    @mask: original mask with the same shape as `image`; it is binarised with `> 0`.
    @aug: bool. When true a random translation is applied to the crop window.
    @ratio, @input_shape: forwarded to Image_preprocess.
    @shift_ratio: kept for interface compatibility; currently has no effect.
    @Return: (final_image, final_mask), both produced by Image_preprocess; the
             mask is re-binarised and passed through pad_mask.

    NOTE(review): the original implementation also extracted the largest
    connected component of the mask and derived a shift range from the nodule
    diameter, but neither result was used: the shift range was overwritten by
    the constants below and the crop was taken from the full mask. That dead
    code was removed; outputs are unchanged. Confirm whether the
    largest-component mask was meant to be cropped instead.
    '''
    mask = mask > 0

    # Fixed shift range: +-2 along z, +-5 in-plane (upper bound exclusive).
    z_distance, x_distance = 2, 5
    if aug:
        delta_0 = np.random.randint(low=-z_distance, high=z_distance, size=1)[0]
        deltas = np.random.randint(low=-x_distance, high=x_distance, size=3)
        deltas[0] = delta_0
    else:
        deltas = [0, 0, 0]

    # Crop the same region from the image and from the mask.
    final_image = Image_preprocess(image, deltas, ratio=ratio, input_shape=input_shape)
    final_mask = Image_preprocess(mask, deltas, ratio=ratio, input_shape=input_shape)
    final_mask = final_mask > 0
    final_mask = pad_mask(final_mask)
    return final_image, final_mask


def DataNormlization(data, info, HU_min=-1024, HU_max=1024, mode='mean'):
    '''
    Normalise every sample of `data` as `(sample - offset) / scale`.

    @data: batch of samples, iterated along axis 0.
    @info: per-sample array whose last four entries are unpacked as
           (max, mean, std, min).
    @HU_min, @HU_max: clipping window used by mode 'HU'.
    @mode: 'mean' -> offset/scale are the mean/std taken from `info`
           'HU'   -> clip to [HU_min, HU_max], offset HU_min, scale HU_max - HU_min
           'self' -> offset/scale are the mean/std of the sample itself
           other  -> min-max scaling with the min/max taken from `info`
    @Return: array shaped like `data`. Integer and boolean inputs get a float32
             result so the normalised values are not truncated.

    NOTE: a zero scale (e.g. std == 0) is not guarded against.
    '''
    result = np.zeros_like(data)
    if result.dtype.kind in 'biu':
        # Writing normalised floats into an integer buffer would truncate them.
        result = result.astype(np.float32)

    for data_idx, (current_data, current_info) in enumerate(zip(data, info)):
        max_value, mean_value, std_value, min_value = current_info[-4:].astype('float')
        if mode == 'mean':
            current_mean, current_std = mean_value, std_value
        elif mode == 'HU':
            current_data = np.clip(current_data, HU_min, HU_max)
            current_mean, current_std = HU_min, float(HU_max - HU_min)
        elif mode == 'self':
            current_mean, current_std = np.mean(current_data), np.std(current_data)
        else:
            current_mean, current_std = min_value, (max_value - min_value)
        result[data_idx] = (current_data - current_mean) / current_std
    return result


def CalculateMaxDisatnce(mask):
    '''
    Return the largest in-plane bounding-box side length of the mask, in pixels.

    For every slice along axis 0 that contains positive pixels, the bounding
    box of those pixels is computed and the longer of its two sides
    (inclusive, hence the +1) is taken; the maximum over all slices is returned.

    @mask: mask array iterated slice by slice along axis 0.
    @Return: the maximum side length, or 0 when the mask has no positive pixel.
    '''
    max_extent = 0
    for mask_slice in mask:
        # The bounding box is measured on the current slice. The original code
        # looked at the whole volume here, mixing up the z/y and y/x axes.
        points = np.where(mask_slice > 0)
        if points[0].size == 0:
            continue
        y_min, y_max = np.amin(points[0]), np.amax(points[0])
        x_min, x_max = np.amin(points[1]), np.amax(points[1])
        max_extent = max(max_extent, max(y_max - y_min, x_max - x_min) + 1)
    return max_extent


def _sample_diameter_ratio(aug, low_percent, high_percent):
    '''Return a random scale in [low_percent, high_percent) / 100 when `aug` is true, else 1.0.'''
    if aug:
        return np.random.randint(low_percent, high_percent, size=1)[0] / 100.0
    return 1.0


def _clamped_patch_shape(target_patch_size, direction_z_ratio, max_patch_size, min_patch_size):
    '''Clamp the in-plane size to [min_patch_size, max_patch_size] and build [z, y, x, 1].'''
    target_inplane_size = min(max_patch_size, max(min_patch_size, int(target_patch_size)))
    return [(target_inplane_size * direction_z_ratio), target_inplane_size, target_inplane_size, 1]


def CalculatePatchSizeWithDiameter(mask, stride_ratio=8, diameter_enlarge_ratio=1.5, direction_z_ratio=1,
                                   max_patch_size=64, min_patch_size=8, aug=True, diameter_px=None):
    '''
    Calculate the patch size of a nodule from its diameter.

    1. stride_ratio: the patch size is a multiple of stride_ratio (rounded up);
       it should match the stride of the network.
    2. diameter_enlarge_ratio: how much larger than the diameter the patch is.
    3. The in-plane size is clamped to [min_patch_size, max_patch_size].

    @mask: used to measure the diameter when `diameter_px` is None.
    @direction_z_ratio: multiplier applied to the in-plane size to get the z size.
    @aug: when true the diameter is randomly scaled by a factor in [0.90, 1.20).
    @diameter_px: optional precomputed diameter in pixels.
    @Return: [z_size, inplane_size, inplane_size, 1].
    '''
    if diameter_px is None:
        diameter_px = CalculateMaxDisatnce(mask)
    diameter_px = diameter_px * _sample_diameter_ratio(aug, 90, 120)
    target_patch_size = math.ceil(diameter_px / float(stride_ratio) * diameter_enlarge_ratio) * stride_ratio
    return _clamped_patch_shape(target_patch_size, direction_z_ratio, max_patch_size, min_patch_size)


def CalculatePatchSizeWithDiameterV2(mask, stride_ratio=8, diameter_enlarge_ratio=1.5, direction_z_ratio=1,
                                     max_patch_size=64, min_patch_size=8, aug=True, diameter_px=None):
    '''
    Calculate the patch size of a nodule from its diameter (variant 2).

    Differences from CalculatePatchSizeWithDiameter:
    - the augmentation factor is drawn from [0.90, 1.50);
    - the patch size is rounded to the nearest multiple of the stride instead
      of rounded up;
    - for diameters of at least 2 * min_patch_size the stride is raised to at
      least min_patch_size.

    @mask: used to measure the diameter when `diameter_px` is None.
    @direction_z_ratio: multiplier applied to the in-plane size to get the z size.
    @diameter_px: optional precomputed diameter in pixels.
    @Return: [z_size, inplane_size, inplane_size, 1], in-plane size clamped to
             [min_patch_size, max_patch_size].
    '''
    if diameter_px is None:
        diameter_px = CalculateMaxDisatnce(mask)

    # During training the patch size is jittered to cope with the mismatch
    # between the long diameter computed from the mask and the long diameter
    # supplied as input at the alpha segmentation stage.
    diameter_ratio = _sample_diameter_ratio(aug, 90, 150)

    # Small nodules use a small stride, large nodules use a large stride.
    # The decision is made on the diameter before augmentation scaling.
    if diameter_px >= 2 * min_patch_size:
        stride_ratio = max(min_patch_size, stride_ratio)

    diameter_px = diameter_px * diameter_ratio
    target_patch_size = int(round(diameter_px / float(stride_ratio) * diameter_enlarge_ratio) * stride_ratio)
    return _clamped_patch_shape(target_patch_size, direction_z_ratio, max_patch_size, min_patch_size)


def CalculatePatchSizeWithDiameterV3(mask, stride_ratio=8, diameter_enlarge_ratio=1.5, direction_z_ratio=1,
                                     max_patch_size=64, min_patch_size=8, aug=True, diameter_px=None):
    '''
    Calculate the patch size of a nodule from its diameter (variant 3).

    Same as variant 2 (augmentation factor in [0.90, 1.50), rounding to the
    nearest multiple of the stride) but the stride is never adjusted.

    @mask: used to measure the diameter when `diameter_px` is None.
    @direction_z_ratio: multiplier applied to the in-plane size to get the z size.
    @diameter_px: optional precomputed diameter in pixels.
    @Return: [z_size, inplane_size, inplane_size, 1], in-plane size clamped to
             [min_patch_size, max_patch_size].
    '''
    if diameter_px is None:
        diameter_px = CalculateMaxDisatnce(mask)
    diameter_px = diameter_px * _sample_diameter_ratio(aug, 90, 150)
    target_patch_size = int(round(diameter_px / float(stride_ratio) * diameter_enlarge_ratio) * stride_ratio)
    return _clamped_patch_shape(target_patch_size, direction_z_ratio, max_patch_size, min_patch_size)


def calc_dist_map(seg):
    '''
    Compute a signed distance map of a segmentation.

    Background voxels get their Euclidean distance to the foreground;
    foreground voxels get `-(distance to the background - 1)`, so values grow
    when moving away from the object boundary towards the outside.

    @seg: segmentation array; non-zero entries are foreground.
    @Return: the distance map, or an all-zero array like `seg` when `seg` has
             no foreground.
    '''
    res = np.zeros_like(seg)
    posmask = seg.astype(bool)  # np.bool was removed in NumPy 1.24
    if posmask.any():
        negmask = ~posmask
        res = distance(negmask) * negmask - (distance(posmask) - 1) * posmask
    return res


def calc_dist_map_batch(y_true):
    '''Apply calc_dist_map to every element of `y_true` and stack the results as float32.'''
    return np.array([calc_dist_map(y) for y in y_true]).astype(np.float32)


def RespacingImage(image, mask, target_spacing=(4, 4, 4)):
    '''
    Resample an image / mask pair to `target_spacing`.

    The image is resampled with LinearResample and the mask with
    NearestResample (both expected to come from the star import of
    `.respacing_func`); the resampled mask then takes over the direction,
    origin and spacing of the resampled image.

    @image: SimpleITK image (passed straight to LinearResample).
    @mask: SimpleITK image or numpy array. The original code converted the mask
           to a numpy array and then called SetSpacing on it, which always
           failed; a SimpleITK image is now built (or copied) instead.
    @Return: (itkImage, itkLung, OriginalSpacing, OriginalSize).
    '''
    if isinstance(mask, sitk.Image):
        itk_mask = sitk.Image(mask)  # copy so the caller's image metadata is untouched
    else:
        itk_mask = sitk.GetImageFromArray(mask)
    itkImage, OriginalSpacing, OriginalSize = LinearResample(image, target_spacing)
    itk_mask.SetSpacing(OriginalSpacing)
    itkLung = NearestResample(itk_mask, target_spacing)
    itkLung.SetDirection(itkImage.GetDirection())
    itkLung.SetOrigin(itkImage.GetOrigin())
    itkLung.SetSpacing(itkImage.GetSpacing())
    return itkImage, itkLung, OriginalSpacing, OriginalSize


def imagepair_respacing(image, mask, ori_spacing, target_spacing):
    '''
    Resample a numpy image / mask pair to `target_spacing`.

    @image, @mask: numpy arrays converted with sitk.GetImageFromArray.
    @ori_spacing: spacing in array axis order; it is reversed before being set
                  on the SimpleITK images.
    @target_spacing: passed unchanged to LinearResample / NearestResample.
    @Return: (new_itkImage, new_itkMask) as SimpleITK images; the mask is
             resampled to the size of the resampled image and takes its spacing.
    '''
    itk_image = sitk.GetImageFromArray(image)
    itk_mask = sitk.GetImageFromArray(mask)
    itk_image.SetSpacing(ori_spacing[::-1])
    itk_mask.SetSpacing(ori_spacing[::-1])

    new_itkImage, OriginalSpacing, OriginalSize = LinearResample(itk_image, target_spacing)
    itk_mask.SetSpacing(OriginalSpacing)
    new_itkMask = NearestResample(itk_mask, target_spacing, new_itkImage.GetSize())
    new_itkMask.SetSpacing(new_itkImage.GetSpacing())
    return new_itkImage, new_itkMask


def PadToTargetShape(image, target_shape):
    '''
    Zero-pad `image` so every axis is at least as long as in `target_shape`.

    The padding is split between both sides; when it is odd the extra element
    goes to the trailing side. Axes that are already long enough are untouched.
    '''
    diff = [max(0, target - current) for target, current in zip(target_shape, image.shape)]
    left_pad_values = [int(val / 2) for val in diff]
    pad_values = [(left, total - left) for left, total in zip(left_pad_values, diff)]
    return np.pad(image, pad_values, mode='constant')


def CropToTargetShape(image, target_shape):
    '''
    Centre-crop the first three axes of `image` to at most `target_shape`.

    Axes shorter than the target are returned unchanged.
    '''
    diff = [max(0, current - target) for current, target in zip(image.shape, target_shape)]
    left_diff = [int(val / 2) for val in diff]
    return image[left_diff[0]:target_shape[0] + left_diff[0],
                 left_diff[1]:target_shape[1] + left_diff[1],
                 left_diff[2]:target_shape[2] + left_diff[2]]


def GetTargetShapePatch(image, target_shape):
    '''Pad and then centre-crop `image` so that its shape equals `target_shape`.'''
    image = PadToTargetShape(image, target_shape)
    image = CropToTargetShape(image, target_shape)
    return image


def imagepari_target_shape_respacing(image, mask, ori_spacing, target_shape):
    '''
    Resample an image / mask pair so that it ends up with shape `target_shape`.

    The spacing needed to reach `target_shape` is derived per axis as
    `ori_spacing * image.shape / target_shape`, the pair is resampled with
    imagepair_respacing, and the result is padded / cropped to the exact shape.

    @image, @mask: arrays; singleton axes are squeezed and the mask is cast to uint8.
    @ori_spacing: spacing in array axis order.
    @target_shape: desired output shape.
    @Return: (image, mask, new_spacing), with new_spacing in array axis order.
    '''
    image = np.squeeze(image)
    mask = np.squeeze(mask).astype(np.uint8)

    ratios = [current / float(target) for current, target in zip(image.shape, target_shape)]
    new_spacing = [spac_val * ratio_val for spac_val, ratio_val in zip(ori_spacing, ratios)]
    target_spacing = new_spacing[::-1]  # reversed for the SimpleITK side

    new_itkImage, new_itkMask = imagepair_respacing(image, mask, ori_spacing, target_spacing)
    new_image = sitk.GetArrayFromImage(new_itkImage)
    new_mask = sitk.GetArrayFromImage(new_itkMask)
    image = GetTargetShapePatch(new_image, target_shape)
    mask = GetTargetShapePatch(new_mask, target_shape)
    return image, mask, new_spacing