import json
import math
import os
import random
import time
import traceback
import warnings
import copy
import torch
import SimpleITK as sitk
import lmdb
import numpy as np
from skimage.measure import regionprops, label
from scipy.ndimage import zoom
from scipy.ndimage.interpolation import rotate
from .bbox_reader import BboxReader, BBoxes, norm, pad2factor


class CenterReader(BboxReader):
    """BboxReader that also loads a mask and builds heatmap/regression targets.

    In 'train'/'val' mode an item is a list of tensors:
    ``[sample, gaussian_hm, center_idxs, bboxes_diameters, reg_offset, reg_mask]``.
    In 'test' mode it is ``[normalised padded image tensor, original image]``.

    NOTE(review): ``mode``, ``list_IDs``, ``targets_idx``, ``sample_bboxes``,
    ``balanced_sampling_flag``, ``crop_size``, ``augtype``, ``pad_value``,
    ``data_dir`` and ``_load_img`` come from BboxReader, which is not visible
    here — their exact semantics are assumed, confirm against bbox_reader.
    """

    def __init__(self, cfg, mode="train"):
        super(CenterReader, self).__init__(cfg, mode)
        self.crop = Crop_with_mask(cfg)
        self.mask_dir = cfg.DATA.DATA_LOADER.MASK_DIR
        self.max_objs = cfg.DATA.DATA_PROCESS.MAX_OBJS
        self.downsampling_ratio = cfg.DATA.DATA_PROCESS.DOWNSAMPLING_RATIO
        self.min_overlap = cfg.DATA.DATA_PROCESS.MIN_OVERLAP

    def _mask_path_for(self, filename):
        """Map an image path to its mask path.

        Replaces ``self.data_dir`` with ``self.mask_dir`` and removes the
        ``/resizedImage.nii.gz`` substring.
        """
        mask_filename = filename.replace(self.data_dir, self.mask_dir)
        return mask_filename.replace('/resizedImage.nii.gz', '')

    def _assert_crop_shape(self, patch, description, filename):
        """Assert that dims 1..3 of ``patch`` equal ``self.crop_size[0..2]``."""
        assert patch.shape[1] == self.crop_size[0] and patch.shape[2] == self.crop_size[1] \
            and patch.shape[3] == self.crop_size[2], \
            '{} of {} has illegal shape: {}'.format(description, filename, patch.shape)

    def __getitem__(self, idx):
        """Return one training/validation/test item (see class docstring).

        In 'train'/'val' mode, indices ``>= total_len`` are wrapped modulo
        ``total_len`` and produce a random crop. Half of those (by coin flip)
        crop a randomly chosen image instead of the indexed target. Returns
        None when cropping the indexed target fails.
        """
        is_random_img = False
        if self.mode in ['train', 'val']:
            if self.balanced_sampling_flag:
                total_len = len(self.list_IDs)
            else:
                total_len = len(self.targets_idx)
            if idx >= total_len:
                is_random_crop = True
                idx = idx % total_len
                is_random_img = np.random.randint(2)
            else:
                is_random_crop = False
        else:
            is_random_crop = False

        if self.balanced_sampling_flag:
            idx = self.list_IDs[idx]

        if self.mode in ['train', 'val']:
            if not is_random_img:
                # target_idx is (image index, box index within that image).
                target_idx = self.targets_idx[idx]
                bboxes = self.sample_bboxes[target_idx[0]]
                new_bboxes = BBoxes(np.copy(bboxes.data))
                new_bboxes._copy_extra_fields(bboxes)
                target_bbox = new_bboxes.data[target_idx[1]]
                filename = new_bboxes.get_field("filename")
                imgs = self._load_img(filename)
                mask = self._load_mask(self._mask_path_for(filename))
                try:
                    sample, mask, new_bboxes = self.crop(imgs, mask, target_bbox, new_bboxes, is_random_crop)
                except Exception:
                    # Crop_with_mask returns None on failure, which makes the
                    # unpacking above raise; the sample is skipped (None returned).
                    print('Crop fail in %s.' % filename)
                    return
                self._assert_crop_shape(sample, 'crop patch', filename)
                self._assert_crop_shape(mask, 'crop mask patch', filename)
                if self.mode == 'train' and not is_random_crop:
                    sample, mask, new_bboxes.data = augment_with_mask(
                        sample, mask, target_bbox, new_bboxes.data,
                        do_flip=self.augtype.FLIP,
                        do_rotate=self.augtype.ROTATE,
                        do_swap=self.augtype.SWAP)
            else:
                randimid = np.random.randint(len(self.sample_bboxes))
                bboxes = self.sample_bboxes[randimid]
                new_bboxes = BBoxes(np.copy(bboxes.data))
                new_bboxes._copy_extra_fields(bboxes)
                filename = new_bboxes.get_field("filename")
                imgs = self._load_img(filename)
                mask = self._load_mask(self._mask_path_for(filename))
                sample, mask, new_bboxes = self.crop(imgs, mask, [], new_bboxes, isRand=True)
                self._assert_crop_shape(sample, 'patch', filename)
                self._assert_crop_shape(mask, 'mask patch', filename)

            gaussian_hm, center_idxs, bboxes_diameters, reg_offset, reg_mask = \
                self.draw_gaussian_from_bbox_and_mask(new_bboxes, mask)
            sample = norm(sample)
            return [torch.from_numpy(sample), torch.from_numpy(gaussian_hm),
                    torch.from_numpy(center_idxs), torch.from_numpy(bboxes_diameters),
                    torch.from_numpy(reg_offset), torch.from_numpy(reg_mask)]

        if self.mode in ['test']:
            image = self._load_img(self.sample_bboxes[idx].get_field("filename"))
            original_image = image[0]
            image = pad2factor(image[0], pad_value=self.pad_value)
            image = np.expand_dims(image, 0)
            norm_image = norm(image)
            return [torch.from_numpy(norm_image).float(), original_image]

    def _load_mask(self, path_to_img):
        """Read a mask with SimpleITK and return its array with a leading axis of size 1."""
        itk_image = sitk.ReadImage(path_to_img)
        img = sitk.GetArrayFromImage(itk_image)[np.newaxis, ...]
        return img

    def draw_gaussian_from_bbox_and_mask(self, bboxes, mask):
        """Build heatmap and regression targets at ``1 / downsampling_ratio`` resolution.

        Boxes are read as rows ``(z, y, x, d, h, w)`` (centre then size, per the
        indexing below) and scaled down by ``downsampling_ratio``. Boxes with
        label > 0 get a 3D gaussian peak plus size/offset targets. Heatmap
        voxels where the downsampled mask is 0 are zeroed. The extent of boxes
        with label -2 is then set to -2.0.

        Args:
            bboxes: BBoxes with ``data`` and a ``bboxes_label`` field.
            mask: mask array with a leading channel axis (only ``mask[0]`` is used).

        Returns:
            ``(gaussian_hm, center_idxs, bboxes_diameters, reg_offset, reg_mask)``:
            heatmap of shape ``(1, *hm_size)``, flat centre indices
            ``(max_objs,)``, sizes ``(max_objs, 3)``, sub-voxel centre offsets
            ``(max_objs, 3)`` and a validity mask ``(max_objs,)``.
        """
        hm_size = [int(i / self.downsampling_ratio) for i in self.crop_size]
        hm_h, hm_w = hm_size[1], hm_size[2]
        gaussian_hm = np.zeros(hm_size, dtype=np.float32)
        center_idxs = np.zeros((self.max_objs), dtype=np.int64)
        reg_offset = np.zeros((self.max_objs, 3), dtype=np.float32)
        bboxes_diameters = np.zeros((self.max_objs, 3), dtype=np.float32)
        reg_mask = np.zeros((self.max_objs), dtype=np.uint8)

        bboxes_label = bboxes.get_field("bboxes_label")
        raw_bboxes = bboxes.data / self.downsampling_ratio
        # order=0 (nearest neighbour): the mask holds labels, which must not be
        # spline-interpolated. `mode` only controls out-of-bounds handling.
        mask = zoom(mask[0, :, :, :], 1 / self.downsampling_ratio, order=0, mode='nearest')
        truth_bboxes = raw_bboxes[bboxes_label > 0]
        irrelevant_bboxes = raw_bboxes[bboxes_label == -2]

        for i, bbox in enumerate(truth_bboxes):
            radius = gaussian_radius_3d(bbox[3:], self.min_overlap)
            radius = max(0, int(radius))
            draw_umich_gaussian(gaussian_hm, bbox[:3], radius)
            if i >= self.max_objs:
                # The regression buffers hold at most max_objs entries; extra
                # objects still appear in the heatmap but get no size/offset target.
                continue
            bbox_int = bbox.astype(np.int32)
            # Row-major flat index into the (D, H, W) heatmap.
            center_idxs[i] = hm_h * hm_w * bbox_int[0] + hm_w * bbox_int[1] + bbox_int[2]
            reg_offset[i] = bbox[:3] - bbox_int[:3]
            bboxes_diameters[i] = bbox[3:]
            reg_mask[i] = 1

        gaussian_hm = gaussian_hm[np.newaxis, ...]
        mask = mask[np.newaxis, ...]
        gaussian_hm[mask == 0] = 0.0

        for bbox in irrelevant_bboxes:
            # Clamp lower corners at 0: a negative slice start would wrap to the
            # far end of the axis and silently yield an empty region.
            x_sp = max(0, int(round(bbox[2] - bbox[5] / 2)))
            x_ep = int(round(bbox[2] + bbox[5] / 2))
            y_sp = max(0, int(round(bbox[1] - bbox[4] / 2)))
            y_ep = int(round(bbox[1] + bbox[4] / 2))
            z_sp = max(0, int(round(bbox[0] - bbox[3] / 2)))
            z_ep = int(round(bbox[0] + bbox[3] / 2))
            gaussian_hm[0, z_sp:z_ep, y_sp:y_ep, x_sp:x_ep] = -2.0
        return gaussian_hm, center_idxs, bboxes_diameters, reg_offset, reg_mask


def augment_with_mask(sample, mask, target, bboxes, do_flip=True, do_rotate=True, do_swap=True):
    """Randomly rotate, axis-permute and flip a crop together with its mask and boxes.

    Args:
        sample, mask: arrays with a leading channel axis and three spatial axes.
        target: the box ``(z, y, x, d, h, w)`` the crop was centred on; used to
            reject rotations that would push it out of the crop.
        bboxes: ``(N, 6)`` box array, modified in place and returned.
        do_flip, do_rotate, do_swap: enable the corresponding augmentation.

    Returns:
        ``(sample, mask, bboxes)`` after augmentation.
    """
    if do_rotate:
        valid_rotate = False
        counter = 0
        while not valid_rotate:
            angle = float(np.random.rand() * 180)
            theta = angle / 180 * np.pi
            rotate_mat = np.array([[np.cos(theta), -np.sin(theta)],
                                   [np.sin(theta), np.cos(theta)]])
            # Compute where the target centre lands after rotating in the (y, x) plane.
            new_target = np.copy(target)
            size = np.array(sample.shape[2:4]).astype('float')
            new_target[1:3] = np.dot(rotate_mat, target[1:3] - size / 2) + size / 2
            # Accept the angle only if the rotated target is still fully inside the sample.
            if np.all(new_target[:3] > max(target[3:])) \
                    and np.all(new_target[:3] < np.array(sample.shape[1:4]) - max(new_target[3:])):
                valid_rotate = True
                sample = rotate(sample, angle, axes=(2, 3), reshape=False)
                # order=0 so the label mask keeps its discrete values.
                mask = rotate(mask, angle, axes=(2, 3), reshape=False, order=0)
                for box in bboxes:
                    box[1:3] = np.dot(rotate_mat, box[1:3] - size / 2) + size / 2
            else:
                # Give up on rotation after 3 rejected angles.
                counter += 1
                if counter == 3:
                    break

    if do_swap:
        # Axis permutation is only shape-preserving for cubic crops.
        if sample.shape[1] == sample.shape[2] and sample.shape[1] == sample.shape[3]:
            axis_order = np.random.permutation(3)
            sample = np.transpose(sample, np.concatenate([[0], axis_order + 1]))
            mask = np.transpose(mask, np.concatenate([[0], axis_order + 1]))
            bboxes[:, :3] = bboxes[:, :3][:, axis_order]
            bboxes[:, 3:] = bboxes[:, 3:][:, axis_order]

    if do_flip:
        # Only the last two spatial axes are flipped; the first is never flipped.
        flip_id = np.array([1, np.random.randint(2), np.random.randint(2)]) * 2 - 1
        sample = np.ascontiguousarray(sample[:, ::flip_id[0], ::flip_id[1], ::flip_id[2]])
        mask = np.ascontiguousarray(mask[:, ::flip_id[0], ::flip_id[1], ::flip_id[2]])
        for ax in range(3):
            if flip_id[ax] == -1:
                bboxes[:, ax] = np.array(sample.shape[ax + 1]) - bboxes[:, ax]
    return sample, mask, bboxes


class Crop_with_mask(object):
    """Crop fixed-size patches from an image and its mask, padding where needed."""

    def __init__(self, cfg):
        self.pad_value = cfg.DATA.DATA_PROCESS.PAD_VALUE
        self.crop_size = cfg.DATA.DATA_PROCESS.CROP_SIZE
        self.bound_size = cfg.DATA.DATA_PROCESS.BOUND_SIZE

    def __call__(self, imgs, mask, target, bboxes, isRand=False):
        """Crop ``crop_size`` patches from ``imgs`` and ``mask``.

        With ``isRand=False`` the crop start is sampled so that ``target``
        (``(z, y, x, d, h, w)``) lies inside the crop with a ``bound_size``
        margin. With ``isRand=True`` the start is sampled from the whole volume.
        The image is padded with ``pad_value`` and the mask with 0 where the
        crop exceeds the volume.

        Box centres in ``bboxes.data`` are shifted in place into crop
        coordinates. Boxes not fully inside the crop get label -2.

        Returns:
            ``(crop, crop_mask, bboxes)``, or None if slicing raised.
        """
        crop_size = self.crop_size
        bbox_data = bboxes.data

        start = []
        for i in range(3):
            if not isRand:
                r = target[3 + i] / 2
                s = np.floor(target[i] - r) + 1 - self.bound_size
                e = np.ceil(target[i] + r) + 1 + self.bound_size - crop_size[i]
            else:
                s = np.max([imgs.shape[i + 1] - crop_size[i] / 2, imgs.shape[i + 1] / 2 + self.bound_size])
                e = np.min([crop_size[i] / 2, imgs.shape[i + 1] / 2 - self.bound_size])

            if s > e:
                start.append(np.random.randint(e, s))  # !
            else:
                start.append(int(target[i]) - int(crop_size[i] / 2)
                             + np.random.randint(-self.bound_size / 2, self.bound_size / 2))

        # Padding for the parts of the crop that fall outside the volume. The
        # right pad is computed separately for the mask in case its shape differs.
        pad, mask_pad = [[0, 0]], [[0, 0]]
        for i in range(3):
            leftpad = max(0, -start[i])
            rightpad = max(0, start[i] + crop_size[i] - imgs.shape[i + 1])
            rightpad_mask = max(0, start[i] + crop_size[i] - mask.shape[i + 1])
            pad.append([leftpad, rightpad])
            mask_pad.append([leftpad, rightpad_mask])

        try:
            crop = imgs[:,
                        max(start[0], 0):min(start[0] + crop_size[0], imgs.shape[1]),
                        max(start[1], 0):min(start[1] + crop_size[1], imgs.shape[2]),
                        max(start[2], 0):min(start[2] + crop_size[2], imgs.shape[3])]
            crop_mask = mask[:,
                             max(start[0], 0):min(start[0] + crop_size[0], mask.shape[1]),
                             max(start[1], 0):min(start[1] + crop_size[1], mask.shape[2]),
                             max(start[2], 0):min(start[2] + crop_size[2], mask.shape[3])]
        except Exception as err:
            traceback.print_exc()
            print(
                'crop fail with err %s, start %d, %d, %d, in crop_size %d, %d, %d, in image size %d, %d, %d, '
                'in target %d, %d, %d, %d, %d, %d.' % (
                    err, start[0], start[1], start[2], crop_size[0], crop_size[1], crop_size[2],
                    imgs.shape[1], imgs.shape[2], imgs.shape[3],
                    target[0], target[1], target[2], target[3], target[4], target[5]))
            return

        crop = np.pad(crop, pad, 'constant', constant_values=self.pad_value)
        crop_mask = np.pad(crop_mask, mask_pad, 'constant', constant_values=0)

        # Shift box centres into crop coordinates (in place).
        for box in bbox_data:
            for j in range(3):
                box[j] = box[j] - start[j]

        # NOTE(review): this writes into the array returned by get_field(); if
        # BBoxes._copy_extra_fields shares that array with the cached dataset
        # boxes, the -2 labels persist into later samples — confirm.
        for i, box in enumerate(bboxes.data):
            if not (np.all(box[:3] - box[-3:] / 2 > 0) and np.all(box[:3] + box[-3:] / 2 < self.crop_size)):
                bboxes.get_field("bboxes_label")[i] = -2
        return crop, crop_mask, bboxes


def draw_umich_gaussian(heatmap, center, radius, k=1):
    """Splat an isotropic 3D gaussian peak of height ``k`` into ``heatmap`` in place.

    ``heatmap`` is indexed ``[z, y, x]`` and ``center`` is ``(z, y, x)``
    (truncated to int). The peak covers a ``2 * radius + 1`` cube clipped to
    the heatmap bounds and is merged with ``np.maximum``. Returns ``heatmap``.
    """
    diameter = 2 * radius + 1
    gaussian = gaussian3D_isotropic((diameter, diameter, diameter), sigma=diameter / 3)

    x, y, z = int(center[2]), int(center[1]), int(center[0])
    depth, height, width = heatmap.shape[0:3]

    left, right = min(x, radius), min(width - x, radius + 1)
    front, back = min(y, radius), min(height - y, radius + 1)
    top, bottom = min(z, radius), min(depth - z, radius + 1)

    masked_heatmap = heatmap[z - top:z + bottom, y - front:y + back, x - left:x + right]
    masked_gaussian = gaussian[radius - top:radius + bottom,
                               radius - front:radius + back,
                               radius - left:radius + right]
    if min(masked_gaussian.shape) > 0 and min(masked_heatmap.shape) > 0:  # TODO debug
        # masked_heatmap is a view, so this updates heatmap directly.
        np.maximum(masked_heatmap, masked_gaussian * k, out=masked_heatmap)
    return heatmap


def gaussian2D(shape, sigma=1):
    """Return an unnormalised 2D gaussian of ``shape``; values below eps * max are set to 0."""
    m, n = [(ss - 1.) / 2. for ss in shape]
    y, x = np.ogrid[-m:m + 1, -n:n + 1]
    g = np.exp(-(x * x + y * y) / (2 * sigma * sigma))
    g[g < np.finfo(g.dtype).eps * g.max()] = 0
    return g


def gaussian3D_isotropic(shape, sigma=1):
    """Return an unnormalised isotropic 3D gaussian of ``shape`` (peak 1 at the centre).

    Values below ``eps * max`` are set to 0.
    """
    d, h, w = [(ss - 1.) / 2. for ss in shape]
    z, y, x = np.ogrid[-d:d + 1, -h:h + 1, -w:w + 1]
    g = np.exp(-(x * x + y * y + z * z) / (2 * sigma * sigma))
    g[g < np.finfo(g.dtype).eps * g.max()] = 0
    return g


def _unique_real_root(coeffs, lower, upper=np.inf):
    """Return the single real root of polynomial ``coeffs`` strictly inside (lower, upper).

    ``coeffs`` are ordered from the highest power down, as for ``np.roots``.
    Returns None when there is no admissible root or more than one, so the
    caller can fall back.
    """
    roots = np.roots(coeffs)
    # np.roots returns complex values; keep those that are numerically real.
    is_real = np.abs(roots.imag) <= 1e-6 * np.maximum(1.0, np.abs(roots.real))
    real_roots = roots.real[is_real]
    admissible = real_roots[(real_roots > lower) & (real_roots < upper)]
    if admissible.size == 1:
        return float(admissible[0])
    return None


def gaussian_radius_3d_anisotropic(det_size, min_overlap=0.1):
    """Same computation as :func:`gaussian_radius_3d` (kept for API compatibility)."""
    return gaussian_radius_3d(det_size, min_overlap)


def gaussian_radius_3d(det_size, min_overlap=0.1):
    """Return a radius ``r`` such that perturbed boxes keep IoU ``min_overlap``.

    For a box of size ``(d, h, w)`` and ``m = min_overlap``, three cubics are
    solved and the smallest radius is returned:

    1. box shrunk by ``r`` on every face, ``(d-2r)(h-2r)(w-2r) = m*dhw``;
       root taken in ``(0, min(d, h, w) / 2)``.
    2. box grown by ``r`` on every face, ``dhw = m*(d+2r)(h+2r)(w+2r)``;
       root taken in ``(0, inf)``.
    3. box shifted by ``r`` along every axis,
       ``(1+m)(d-r)(h-r)(w-r) = 2m*dhw``; root taken in ``(0, min(d, h, w))``.

    If a case does not have exactly one admissible root, it uses
    ``min((d-1)/2, (h-1)/2, (w-1)/2)`` instead.
    """
    d, h, w = det_size
    volume = d * h * w
    edge_sum = d + h + w
    face_sum = w * h + w * d + h * d
    fallback = min((d - 1) / 2, (h - 1) / 2, (w - 1) / 2)

    r1 = _unique_real_root(
        [8, -4 * edge_sum, 2 * face_sum, (min_overlap - 1) * volume],
        0, min(d / 2, h / 2, w / 2))
    r2 = _unique_real_root(
        [8 * min_overlap, 4 * edge_sum * min_overlap, 2 * face_sum * min_overlap,
         (min_overlap - 1) * volume],
        0)
    r3 = _unique_real_root(
        [1 + min_overlap, -1 * (1 + min_overlap) * edge_sum, (1 + min_overlap) * face_sum,
         (min_overlap - 1) * volume],
        0, min(d, h, w))

    return min(fallback if r is None else r for r in (r1, r2, r3))