import copy
import datetime
import glob
import math
import os
import random
import sys
import time
from collections import OrderedDict, defaultdict
# NOTE: this rebinds the module-level name `time` to the function time.time,
# shadowing the `time` module imported above.
from time import time

import numpy as np
from matplotlib import pyplot as plt
from matplotlib.lines import Line2D
# `scipy.ndimage.measurements` is a deprecated alias namespace; the public
# location of center_of_mass is `scipy.ndimage`.
from scipy.ndimage import center_of_mass, gaussian_filter, zoom
from skimage.draw import line_aa
from skimage.measure import label, regionprops
from skimage.morphology import binary_dilation, binary_erosion


def get_bbox(mask_slice):
    """Return the inclusive bounding box of the positive pixels of a 2-D mask.

    Returns:
        ``(y_min, y_max, x_min, x_max)``. An axis with no positive pixels
        falls back to ``(0, 1)``.
    """
    target_region = np.asarray(np.where(mask_slice > 0))
    try:
        y_min, y_max = np.amin(target_region[0]), np.amax(target_region[0])
    except (ValueError, IndexError):
        # np.amin/np.amax raise ValueError on an empty selection.
        y_min, y_max = 0, 1
    try:
        x_min, x_max = np.amin(target_region[1]), np.amax(target_region[1])
    except (ValueError, IndexError):
        x_min, x_max = 0, 1
    return y_min, y_max, x_min, x_max


def get_max_slice(mask_result):
    """Find the slice with the largest pixel sum among the five around the middle.

    Only slices ``center-2 .. center+2`` along axis 0 (clamped to the volume)
    are examined, where ``center = int(mask_result.shape[0] / 2)``.

    Returns:
        ``(delta, current_center_slice)``: the chosen slice index and its
        offset from ``center``.
    """
    num_slices = mask_result.shape[0]
    center_slice = int(num_slices / 2)
    # Slices outside the window keep a count of 0.
    num_pixel_array = [0] * num_slices
    # Clamp the window so volumes with fewer than 5 slices neither wrap around
    # through negative indices nor run past the last slice.
    for slice_idx in range(max(0, center_slice - 2), min(num_slices, center_slice + 3)):
        num_pixel_array[slice_idx] = np.sum(mask_result[slice_idx])
    # NOTE(review): if every window slice sums to 0, argmax returns index 0
    # rather than the center slice -- confirm callers expect that.
    current_center_slice = np.argmax(num_pixel_array)
    delta = current_center_slice - center_slice
    return delta, current_center_slice


def get_lines(start_point, dire, inp_shape):
    """Return integer ``(y, x)`` points on a line through ``start_point``.

    Points are ``round(start) + k * dire`` for ``k`` in
    ``range(int(-voxel_size / 2), int(voxel_size / 2 + 1))`` with
    ``voxel_size = int(inp_shape[-1] / 2)``; only points whose coordinates
    both lie in ``[0, voxel_size)`` are kept.
    """
    # NOTE(review): the bounds check uses half of inp_shape[-1], not the
    # slice's own height/width, so points in the lower/right half of a slice
    # of size inp_shape[-1] are discarded. Confirm this is intended.
    voxel_size = int(inp_shape[-1] / 2)
    dy, dx = dire
    start_y, start_x = start_point
    start_y, start_x = int(round(start_y)), int(round(start_x))
    return [
        (start_y + k * dy, start_x + k * dx)
        for k in range(int(-voxel_size / 2), int(voxel_size / 2 + 1))
        if 0 <= start_y + k * dy < voxel_size and 0 <= start_x + k * dx < voxel_size
    ]


def get_maxLength(target_slice, final_center, inp_shape):
    """Sum foreground along four lines through the center and pick the longest.

    The lines pass through ``(final_center[1], final_center[2])`` (rounded)
    in directions ``(-1,-1)``, ``(-1,0)``, ``(-1,1)`` and ``(0,1)``; a line's
    "length" is the sum of ``target_slice`` values at its points.

    Returns:
        ``(max_len, min_len)``: the largest line sum and the sum along the
        direction perpendicular to it (the last maximal direction wins ties).
    """
    all_direction = ((-1, -1), (-1, 0), (-1, 1), (0, 1))
    # Each direction mapped to its perpendicular.
    direction_map = {(-1, -1): (-1, 1), (-1, 0): (0, 1), (-1, 1): (-1, -1), (0, 1): (-1, 0)}
    center_yx = (int(round(final_center[1])), int(round(final_center[2])))
    max_len = 0
    line_length_map = OrderedDict()
    for dire in all_direction:
        current_line_points = get_lines(center_yx, dire, inp_shape)
        line_length = np.sum([target_slice[y, x] for (y, x) in current_line_points])
        max_len = max(max_len, line_length)
        line_length_map[dire] = line_length
    # Initialised so the function cannot fail with an unbound local when no
    # direction reaches max_len (only possible with negative pixel values).
    min_len = 0
    for key, length in line_length_map.items():
        if length == max_len:
            min_len = line_length_map[direction_map[key]]
    return max_len, min_len


def get_angle(ang):
    """Return the angle perpendicular to ``ang`` (degrees, within [0, 180))."""
    return ang + 90 if ang < 90 else ang - 90


def _line_overlap(center_slice, start_y, start_x, end_y, end_x):
    """Return an image holding only the pixels of ``center_slice`` under the line.

    The line is rasterised with ``skimage.draw.line_aa``; an out-of-range
    pixel index raises ``IndexError``.
    """
    rr, cc, _ = line_aa(start_y, start_x, end_y, end_x)
    img = copy.deepcopy(center_slice)
    img[rr, cc] = 0
    # img <= center_slice element-wise, so the uint8 subtraction cannot wrap.
    return center_slice - img


def _axis_aligned_scan(center_slice, bbox, vertical):
    """Longest row (``vertical=False``) or column (``vertical=True``) chord in ``bbox``.

    A chord's length is the number of foreground pixels in that row/column;
    the recorded line spans from its first to its last foreground pixel.

    Returns:
        ``(max_length, line_mask, line_points)``; ``line_mask`` and
        ``line_points`` (``[x0, x1, y0, y1]``) are ``None`` if no line was
        recorded.
    """
    y_min, y_max, x_min, x_max = bbox
    max_length = 0
    best_mask, best_points = None, None
    outer = range(x_min, x_max + 1) if vertical else range(y_min, y_max + 1)
    for fixed in outer:
        if vertical:
            hits = [y for y in range(y_min, y_max + 1) if center_slice[y, fixed] > 0]
        else:
            hits = [x for x in range(x_min, x_max + 1) if center_slice[fixed, x] > 0]
        if len(hits) <= max_length:
            continue
        max_length = len(hits)
        if vertical:
            start_y, start_x, end_y, end_x = min(hits), fixed, max(hits), fixed
        else:
            start_y, start_x, end_y, end_x = fixed, min(hits), fixed, max(hits)
        diff_img = _line_overlap(center_slice, start_y, start_x, end_y, end_x)
        if np.sum(diff_img) > 0:
            best_mask = diff_img
            best_points = [start_x, end_x, start_y, end_y]
    return max_length, best_mask, best_points


def _oblique_scan(center_slice, bbox, tang):
    """Longest chord for lines of slope ``tang`` (dy/dx) inside ``bbox``.

    Each row pair ``start_y < end_y`` of the bounding box defines a line that
    starts at the left (``tang > 0``) or right (``tang < 0``) edge of the box.
    The chord length is the diagonal of the bounding box of the foreground
    pixels that line covers.

    Returns:
        ``(max_length, line_mask, line_points)`` with the same meaning as in
        ``_axis_aligned_scan``.
    """
    y_min, y_max, x_min, x_max = bbox
    max_length = 0
    best_mask, best_points = None, None
    for start_y in range(y_min, y_max):
        for end_y in range(start_y + 1, y_max + 1):
            diff_x = (end_y - start_y) / tang
            if diff_x < 0:
                start_x, end_x = x_max, max(x_min, int(round(x_max + diff_x)))
            else:
                start_x, end_x = x_min, min(x_max, int(round(x_min + diff_x)))
            try:
                diff_img = _line_overlap(center_slice, start_y, start_x, end_y, end_x)
            except IndexError:
                # Lines whose rasterisation leaves the slice are skipped.
                continue
            if np.sum(diff_img) == 0:
                continue
            line_y_min, line_y_max, line_x_min, line_x_max = get_bbox(diff_img)
            line_length = float((line_y_max - line_y_min) ** 2 + (line_x_max - line_x_min) ** 2) ** 0.5
            if line_length > max_length:
                max_length = line_length
                best_mask = diff_img
                if tang < 0:
                    best_points = [line_x_max, line_x_min, line_y_min, line_y_max]
                else:
                    best_points = [line_x_min, line_x_max, line_y_min, line_y_max]
    return max_length, best_mask, best_points


def CalculateNoduleLength(center_slice, final_center, inp_shape, debug=False):
    """Measure the longest chord of a 2-D mask and the chord perpendicular to it.

    For every angle in ``range(0, 180, 5)`` the longest chord is found: at 0
    and 90 degrees it is the largest foreground-pixel count of a row/column of
    the bounding box; at other angles it is the bounding-box diagonal of the
    foreground pixels covered by a line of that slope (see ``_oblique_scan``).

    Args:
        center_slice: 2-D mask; cast to ``uint8``.
        final_center: point passed to ``get_maxLength`` (only used when the
            mask is empty).
        inp_shape: shape passed to ``get_maxLength`` (only used when the mask
            is empty).
        debug: if true, plot the two chords with matplotlib.

    Returns:
        ``(NoduleMaxLength, minLength)``: the maximum chord length over all
        angles, and the largest chord length among the angles perpendicular to
        a maximising angle.
    """
    center_slice = center_slice.astype(np.uint8)
    unit = 5
    height, width = np.squeeze(center_slice).shape
    if np.sum(center_slice) == 0:
        # Nothing to scan: use the line-sum fallback directly instead of
        # running every angle scan over an empty mask.
        return get_maxLength(center_slice, final_center, inp_shape)

    bbox = get_bbox(center_slice)
    line_mask_image = np.zeros((int(180 / unit), height, width))
    angle_MaxLengthMap = {}
    line_point_map = defaultdict(list)
    for angle in range(0, 180, unit):
        if angle == 0:
            max_length, line_mask, points = _axis_aligned_scan(center_slice, bbox, vertical=False)
        elif angle == 90:
            max_length, line_mask, points = _axis_aligned_scan(center_slice, bbox, vertical=True)
        else:
            max_length, line_mask, points = _oblique_scan(
                center_slice, bbox, math.tan(math.radians(angle)))
        if line_mask is not None:
            line_mask_image[angle // unit] = np.squeeze(line_mask)
            line_point_map[angle] = points
        angle_MaxLengthMap[angle] = max_length

    NoduleMaxLength = np.amax(list(angle_MaxLengthMap.values()))
    maxLength_angles = sorted(a for a, length in angle_MaxLengthMap.items() if length == NoduleMaxLength)
    minLength_angles = sorted(get_angle(a) for a in maxLength_angles)
    minLengthAngleDiameters = [angle_MaxLengthMap[a] for a in minLength_angles]
    minLength_angle = minLength_angles[int(np.argmax(minLengthAngleDiameters))]
    # get_angle is its own inverse on [0, 180), so this is a maximising angle.
    maxLength_angle = get_angle(minLength_angle)
    minLength = max(minLengthAngleDiameters)

    if debug:
        # Plot MaxDiam and MinDiam over the center slice.
        print('Plot MaxDiam and MinDiam over center slice')
        try:
            max_line_para = line_point_map[maxLength_angle]
            min_line_para = line_point_map[minLength_angle]
            max_l_m = Line2D(max_line_para[:2], max_line_para[2:], color='r')
            min_l_m = Line2D(min_line_para[:2], min_line_para[2:], color='b')
            _, axs = plt.subplots(1, 4)
            axs[0].imshow(center_slice, cmap='bone')
            axs[1].imshow(center_slice)
            axs[1].add_line(max_l_m)
            axs[1].add_line(min_l_m)
            # Integer division: a float index into the ndarray raises on Python 3.
            axs[2].imshow(line_mask_image[minLength_angle // unit])
            axs[3].imshow(line_mask_image[maxLength_angle // unit])
            plt.show()
        except Exception as err:
            # Plotting is a best-effort debug aid; report instead of hiding it.
            print('Failed to plot diameters: %s' % err)
    return NoduleMaxLength, minLength


def post_process(mask_image, spacing_z, inp_shape, debug):
    """Keep the largest connected component whose centroid is near the middle slice.

    A labelled component qualifies when
    ``|centroid[0] - inp_shape[0] / 2| <= max(3, major_axis_length / spacing_z / 3)``;
    among qualifying components the one with the largest area wins (first in
    label order on ties).

    Args:
        mask_image: mask volume; singleton axes are squeezed before labelling.
        spacing_z: divisor applied to the component's major axis length
            (presumably the slice spacing -- TODO confirm units).
        inp_shape: shape whose first three entries give the default center.
        debug: currently unused.

    Returns:
        ``(center_slice, final_center, final_mask, max_len, min_len)``. If no
        component qualifies, ``center_slice`` is the middle index of axis 0,
        ``final_center`` is ``inp_shape[:3]`` halved and ``final_mask`` is the
        input as an array. ``max_len`` and ``min_len`` are always 0: diameter
        measurement is not performed here.
    """
    squeezed = np.squeeze(mask_image)
    lb = label(squeezed)
    center_slice = int(squeezed.shape[0] / 2)
    final_mask = np.asarray(mask_image)
    final_center = [int(val / 2) for val in inp_shape[:3]]
    num_slice = inp_shape[0]
    max_area = 0
    for region in regionprops(lb, intensity_image=None):
        centroid = region.centroid
        # Allowed distance (in slices) of the centroid from the middle slice.
        z_tolerance = max(3, region.major_axis_length / spacing_z * 1 / 3.0)
        if abs(centroid[0] - num_slice / 2) <= z_tolerance and region.area > max_area:
            center_slice = int(round(centroid[0]))
            final_mask = lb == region.label
            max_area = region.area
            final_center = centroid
    max_len, min_len = 0, 0
    return center_slice, final_center, final_mask, max_len, min_len