import json
import math
import os
import random
import traceback
import warnings

import SimpleITK as sitk
import lmdb
import numpy as np
import torch
from scipy.ndimage import rotate, zoom
from torch.utils.data import Dataset

# Image file name, inside each "<key>.nii.gz" case directory, for every supported DATA_MODE.
_IMAGE_FILENAMES = {
    "mass": 'm_ptrCuttingImage.nii.gz',
    "rib": 'resizedImage.nii.gz',
}


class BBoxes(object):
    """Box array plus a dict of named extra fields.

    ``data`` holds the box rows (6 values each as built by ``BboxReader``:
    three centre coordinates followed by three sizes); ``extra_fields`` maps
    field names such as ``"bboxes_type"`` or ``"filename"`` to arbitrary values.
    """

    def __init__(self, data):
        self.data = data
        self.extra_fields = {}

    def add_field(self, field, field_data):
        """Set (or overwrite) the extra field ``field``."""
        self.extra_fields[field] = field_data

    def get_field(self, field):
        """Return extra field ``field``; raises KeyError if it is missing."""
        return self.extra_fields[field]

    def has_field(self, field):
        """Return True if extra field ``field`` is present."""
        return field in self.extra_fields

    def fields(self):
        """Return the names of all extra fields."""
        return list(self.extra_fields.keys())

    def _copy_extra_fields(self, bbox):
        """Copy every extra field of ``bbox`` into this instance (values are shared, not cloned)."""
        for k, v in bbox.extra_fields.items():
            self.extra_fields[k] = v


class BboxReader(Dataset):
    """Dataset of CT patches and boxes read from an LMDB label database.

    In ``'train'``/``'val'`` mode an item is ``[patch_tensor, BBoxes]``, cropped
    around a target box or at a random location. In ``'test'`` mode an item is
    ``[normalised padded volume tensor, original volume]``.
    """

    def __init__(self, cfg, mode='train'):
        """
        Args:
            cfg: config node providing the ``DATA.DATA_PROCESS`` and
                ``DATA.DATA_LOADER`` keys read below.
            mode: ``'train'``, ``'val'`` or ``'test'``; selects the LMDB file.

        Raises:
            ValueError: if ``mode`` is not one of the supported values, or
                ``DATA_MODE`` is not ``'mass'``/``'rib'``.
        """
        self.crop = Crop(cfg)
        self.mode = mode
        self.crop_size = cfg.DATA.DATA_PROCESS.CROP_SIZE
        self.bbox_border = cfg.DATA.DATA_PROCESS.BBOX_BORDER
        self.r_rand = cfg.DATA.DATA_LOADER.RAND_CROP_RATIO
        self.augtype = cfg.DATA.DATA_PROCESS.AUGTYPE
        self.pad_value = cfg.DATA.DATA_PROCESS.PAD_VALUE
        self.data_dir = cfg.DATA.DATA_LOADER.DATA_DIR
        self.stride = cfg.DATA.DATA_PROCESS.STRIDE
        self.balanced_sampling_flag = cfg.DATA.DATA_LOADER.BALANCED_SAMPLING
        self.pos_target_range = cfg.DATA.DATA_LOADER.POS_TARGET_RANGE
        self.neg_target_range = cfg.DATA.DATA_LOADER.NEG_TARGET_RANGE
        self.irrelevant_target_range = cfg.DATA.DATA_LOADER.IRRELEVANT_TARGET_RANGE
        self.hu_min = cfg.DATA.DATA_PROCESS.HU_MIN
        self.hu_max = cfg.DATA.DATA_PROCESS.HU_MAX
        self.data_mode = cfg.DATA.DATA_LOADER.DATA_MODE

        # Only the config key for the requested mode is read.
        if mode == 'train':
            db_file = cfg.DATA.DATA_LOADER.TRAIN_DB
        elif mode == 'val':
            db_file = cfg.DATA.DATA_LOADER.VALIDATE_DB
        elif mode == 'test':
            db_file = cfg.DATA.DATA_LOADER.TEST_DB
        else:
            # Fail fast: a half-initialised dataset would only break later on attribute access.
            raise ValueError('Invalid mode with %s in BboxReader!' % mode)

        self.sample_bboxes = self._read_data(db_file)
        self.targets_idx = self._get_targets_idx()
        if self.balanced_sampling_flag:
            self.count_dict, self.resample_dict, self.index_dict = self._get_resample_dict()
            # Indices into targets_idx that make up one epoch of (re)sampled targets.
            self.list_IDs = self._get_list_ids()

    def _read_data(self, db_file):
        """Read every case of the LMDB at ``db_file`` into a list of BBoxes.

        Each LMDB key names a case; each value is UTF-8 JSON (see ``_parse_case``).

        Raises:
            ValueError: if ``self.data_mode`` has no known image file name.
        """
        if self.data_mode not in _IMAGE_FILENAMES:
            raise ValueError("Data mode %s is not implemented." % self.data_mode)
        image_name = _IMAGE_FILENAMES[self.data_mode]

        bboxes_num = 0
        sample_bboxes = []
        env = lmdb.open(db_file)
        try:
            with env.begin() as txn:
                for key, value in txn.cursor():
                    key = str(key, encoding='utf-8')
                    label_info = json.loads(str(value, encoding='utf-8'))
                    bboxes = self._parse_case(label_info)
                    bboxes_num += len(bboxes.data)
                    bboxes.add_field("filename", os.path.join(self.data_dir, key + '.nii.gz', image_name))
                    sample_bboxes.append(bboxes)
        finally:
            # Release the environment even if a record fails to parse.
            env.close()
        print('In mode %s, num of ct is %d, num of bboxes: %d.' % (self.mode, len(sample_bboxes), bboxes_num))
        return sample_bboxes

    @staticmethod
    def _parse_case(label_info):
        """Build a BBoxes (without filename) from one decoded LMDB record.

        The record's box list is ``'frac_info'`` if present, else ``'nidi_info'``.
        Each entry contributes the row ``index_point`` reversed followed by
        ``index_diameter`` reversed (presumably xyz -> zyx to match the image
        array order; TODO confirm), its type (``'frac_type'`` or
        ``'density_type'``; strings truncated to 2 chars) and ``max(index_diameter)``.
        """
        nidi_info = label_info['frac_info'] if 'frac_info' in label_info else label_info['nidi_info']
        bboxes_data = []
        bboxes_type = []
        bboxes_diameter = []
        for nidus_info in nidi_info:
            index_point = nidus_info['index_point']
            index_diameter = nidus_info['index_diameter']
            density_type = nidus_info['frac_type'] if 'frac_type' in nidus_info else nidus_info['density_type']
            bboxes_data.append(np.array([index_point[2], index_point[1], index_point[0],
                                         index_diameter[2], index_diameter[1], index_diameter[0]]))
            bboxes_diameter.append(max(index_diameter))
            if isinstance(density_type, str):
                bboxes_type.append(density_type[:2])
            else:
                bboxes_type.append(density_type)
        bboxes = BBoxes(np.array(bboxes_data))
        bboxes.add_field("bboxes_type", np.array(bboxes_type))
        bboxes.add_field("bboxes_diameter", np.array(bboxes_diameter))
        return bboxes

    def _get_targets_idx(self):
        """Label every box and collect the ``(image_index, box_index)`` pairs of targets.

        Boxes whose ``int(type)`` is in POS_TARGET_RANGE get label 1, those in
        NEG_TARGET_RANGE get -1. Both kinds become targets. All other boxes keep
        label 0 and are not targets. Adds ``"bboxes_label"`` to every non-empty BBoxes.

        Returns:
            Integer array of shape ``(M, 2)``; ``(0, 2)`` when there are no targets.
        """
        targets_idx = []
        for i, bboxes in enumerate(self.sample_bboxes):
            if len(bboxes.data) <= 0:
                continue
            labels = np.zeros(len(bboxes.data), dtype=np.int32)
            bboxes.add_field("bboxes_label", labels)
            for j, bbox_type in enumerate(bboxes.get_field("bboxes_type")):
                type_id = int(bbox_type)
                if type_id in self.pos_target_range:
                    labels[j] = 1
                elif type_id in self.neg_target_range:
                    labels[j] = -1
                else:
                    continue
                targets_idx.append(np.array([i, j]))
        if not targets_idx:
            # np.vstack raises on an empty list (e.g. a DB without labelled targets).
            return np.zeros((0, 2), dtype=int)
        return np.vstack(targets_idx)

    def __len__(self):
        """Items per epoch.

        In train mode the target count is inflated by ``1 / (1 - RAND_CROP_RATIO)``.
        The extra indices become random crops in ``__getitem__``. Val has one item
        per target; test has one item per image.
        """
        if self.mode == 'train':
            if self.balanced_sampling_flag:
                return int(len(self.list_IDs) / (1 - self.r_rand))
            else:
                return int(len(self.targets_idx) / (1 - self.r_rand))
        elif self.mode == 'val':
            return len(self.targets_idx)
        else:
            return len(self.sample_bboxes)

    def __getitem__(self, idx):
        """Return one sample.

        train/val: ``[patch_tensor, BBoxes]``. Returns ``None`` if cropping the
        target patch fails, so the collate function must tolerate ``None``.
        test: ``[float tensor of the normalised, padded volume, original volume]``.
        """
        if self.mode == 'test':
            return self._get_test_item(idx)
        return self._get_train_val_item(idx)

    @staticmethod
    def _copy_bboxes(bboxes):
        """Return a BBoxes with a copied data array and shared extra fields."""
        new_bboxes = BBoxes(np.copy(bboxes.data))
        new_bboxes._copy_extra_fields(bboxes)
        return new_bboxes

    def _get_train_val_item(self, idx):
        """Build a ``[patch_tensor, BBoxes]`` training/validation sample (None on crop failure)."""
        is_random_img = False
        total_len = len(self.list_IDs) if self.balanced_sampling_flag else len(self.targets_idx)
        if idx >= total_len:
            # Indices past the target count (see __len__) request random crops;
            # half of them are taken from a random image instead of a target.
            is_random_crop = True
            idx = idx % total_len
            is_random_img = np.random.randint(2)
        else:
            is_random_crop = False
        if self.balanced_sampling_flag:
            idx = self.list_IDs[idx]

        if not is_random_img:
            target_idx = self.targets_idx[idx]
            new_bboxes = self._copy_bboxes(self.sample_bboxes[target_idx[0]])
            target_bbox = new_bboxes.data[target_idx[1]]
            new_bboxes.add_field("target_idx", target_idx[1])
            filename = new_bboxes.get_field("filename")
            imgs = self._load_img(filename)
            is_scale = self.augtype.SCALE and (self.mode == 'train')
            try:
                sample, new_bboxes.data = self.crop(imgs, target_bbox, new_bboxes.data, is_scale, is_random_crop)
            except Exception:
                # Crop returns None on failure, which makes the unpacking raise TypeError.
                print('Crop fail in %s.' % filename)
                return None
            assert sample.shape[1] == self.crop_size[0] and sample.shape[2] == self.crop_size[1] \
                and sample.shape[3] == self.crop_size[2], 'crop patch of {} has illegal shape: {}'.format(
                    filename, sample.shape)
            if self.mode == 'train' and not is_random_crop:
                # Re-read the target row: crop() shifted/scaled the boxes in place.
                target_bbox = new_bboxes.data[target_idx[1]]
                sample, new_bboxes.data = augment(sample, target_bbox, new_bboxes.data,
                                                  do_flip=self.augtype.FLIP, do_rotate=self.augtype.ROTATE,
                                                  do_swap=self.augtype.SWAP)
        else:
            randimid = np.random.randint(len(self.sample_bboxes))
            new_bboxes = self._copy_bboxes(self.sample_bboxes[randimid])
            filename = new_bboxes.get_field("filename")
            imgs = self._load_img(filename)
            sample, new_bboxes.data = self.crop(imgs, [], new_bboxes.data, isScale=False, isRand=True)
            assert sample.shape[1] == self.crop_size[0] and sample.shape[2] == self.crop_size[1] \
                and sample.shape[3] == self.crop_size[2], 'patch of {} has illegal shape: {}'.format(
                    filename, sample.shape)

        sample = norm(sample, self.hu_min, self.hu_max)
        new_bboxes = self._filter_bboxes(new_bboxes)
        if len(new_bboxes.data) > 0:
            # Enlarge the size part of every kept box by the configured border.
            new_bboxes.data[:, -3:] = new_bboxes.data[:, -3:] + self.bbox_border
        return [torch.from_numpy(sample), new_bboxes]

    def _get_test_item(self, idx):
        """Return ``[normalised padded volume as float tensor, original volume]`` for image ``idx``."""
        image = self._load_img(self.sample_bboxes[idx].get_field("filename"))
        original_image = image[0]
        padded = pad2factor(image[0], pad_value=self.pad_value)
        normed = norm(np.expand_dims(padded, 0), self.hu_min, self.hu_max)
        return [torch.from_numpy(normed).float(), original_image]

    def _load_img(self, path_to_img):
        """Load an image with a leading channel axis of size 1.

        For the two known NIfTI names, a sibling ``.npy`` file with the same
        stem is used when it exists; otherwise the file is read via SimpleITK.
        """
        if 'resizedImage.nii.gz' in path_to_img:
            path_to_npy = path_to_img.replace('resizedImage.nii.gz', 'resizedImage.npy')
        elif 'm_ptrCuttingImage.nii.gz' in path_to_img:
            path_to_npy = path_to_img.replace('m_ptrCuttingImage.nii.gz', 'm_ptrCuttingImage.npy')
        else:
            path_to_npy = ''
        if os.path.exists(path_to_npy):
            return np.load(path_to_npy)[np.newaxis, ...]
        itk_image = sitk.ReadImage(path_to_img)
        return sitk.GetArrayFromImage(itk_image)[np.newaxis, ...]

    def _filter_bboxes(self, bboxes):
        """Keep only boxes lying strictly inside the crop.

        A box is kept when ``centre - size/2 > 0`` and ``centre + size/2 < crop_size``
        on every axis. The type/diameter/label fields are filtered in step. If the
        box named by ``"target_idx"`` survives, ``"target_idx"`` is re-indexed
        into the kept boxes. ``"filename"`` is not carried over.
        """
        kept_boxes = []
        bboxes_type = []
        bboxes_diameter = []
        bboxes_label = []
        old_target = bboxes.get_field("target_idx") if bboxes.has_field("target_idx") else None
        target_idx = None
        for i, box in enumerate(bboxes.data):
            if not (np.all(box[:3] - box[-3:] / 2 > 0) and np.all(box[:3] + box[-3:] / 2 < self.crop_size)):
                continue
            if old_target is not None and old_target == i:
                target_idx = len(kept_boxes)
            kept_boxes.append(box)
            bboxes_type.append(bboxes.get_field("bboxes_type")[i])
            bboxes_diameter.append(bboxes.get_field("bboxes_diameter")[i])
            bboxes_label.append(bboxes.get_field("bboxes_label")[i])

        new_bboxes = BBoxes(np.array(kept_boxes))
        new_bboxes.add_field("bboxes_type", np.array(bboxes_type))
        new_bboxes.add_field("bboxes_diameter", np.array(bboxes_diameter))
        new_bboxes.add_field("bboxes_label", np.array(bboxes_label))
        if target_idx is not None:
            new_bboxes.add_field("target_idx", target_idx)
        return new_bboxes

    @staticmethod
    def _resample_max(key):
        """Target sample count for a diameter bucket: 500 for bucket 14, otherwise 100."""
        return 500 if key == 14 else 100

    def _get_resample_dict(self):
        """Bucket targets by integer max size (clamped to [14, 60]) and choose per-bucket multipliers.

        Multiplier per bucket with ``count`` targets and ``ratio = resample_max / count``:
        0 if empty, 5 if ``ratio > 5``, ``round(ratio)`` if ``1 <= ratio <= 5``.
        Otherwise it is ``resample_max`` itself, a marker meaning "subsample down
        to resample_max". The largest bucket is then forced to 4.

        Returns:
            ``(count_dict, resample_dict, index_dict)``: per-bucket counts,
            multipliers, and lists of indices into ``targets_idx``.
        """
        min_d = 14
        max_d = 60
        count_dict = {d: 0 for d in range(min_d, max_d + 1)}
        resample_dict = {d: 0 for d in range(min_d, max_d + 1)}
        index_dict = {d: [] for d in range(min_d, max_d + 1)}

        # Count targets per bucket and record their indices.
        for i, target_idx in enumerate(self.targets_idx):
            target_bbox = self.sample_bboxes[target_idx[0]].data[target_idx[1]]
            key = int(max(target_bbox[3:]))
            assert (key != 0)
            key = min(max(key, min_d), max_d)
            count_dict[key] += 1
            index_dict[key].append(i)

        # Resampling multiplier per bucket.
        for key, count in count_dict.items():
            if count == 0:
                resample_dict[key] = 0
                continue
            resample_max = self._resample_max(key)
            ratio = resample_max / count
            if ratio > 5:
                resample_dict[key] = 5
            elif 1 <= ratio <= 5:
                resample_dict[key] = round(ratio)
            else:
                resample_dict[key] = resample_max
        resample_dict[max_d] = 4
        return count_dict, resample_dict, index_dict

    def _generate_resample_indexes(self):
        """Expand the per-bucket multipliers into a flat list of target indices.

        A bucket whose multiplier equals its ``resample_max`` is randomly
        subsampled to ``resample_max`` indices; otherwise its indices are
        repeated ``multiplier`` times.
        """
        result_index = []
        for key, multiplier in self.resample_dict.items():
            resample_max = self._resample_max(key)
            if multiplier == resample_max:
                result_index.extend(random.sample(self.index_dict[key], resample_max))
            else:
                for _ in range(multiplier):
                    result_index.extend(self.index_dict[key])
        return result_index

    def _get_list_ids(self):
        """Resampled target indices in train mode, every target index otherwise."""
        if self.mode == 'train':
            return self._generate_resample_indexes()
        else:
            return np.arange(len(self.targets_idx))


def pad2factor(image, factor=16, pad_value=0):
    """Pad a 3-D ``image`` at the end of each axis up to the next multiple of ``factor``."""
    depth, height, width = image.shape
    d = int(math.ceil(depth / float(factor))) * factor
    h = int(math.ceil(height / float(factor))) * factor
    w = int(math.ceil(width / float(factor))) * factor
    pad = [[0, d - depth], [0, h - height], [0, w - width]]
    return np.pad(image, pad, 'constant', constant_values=pad_value)


def augment(sample, target, bboxes, do_flip=True, do_rotate=True, do_swap=True):
    """Randomly rotate, swap axes and flip ``sample`` together with ``bboxes``.

    Args:
        sample: 4-D array indexed ``[channel, a0, a1, a2]`` (per the indexing below).
        target: 6-value box (centre then size) used to validate the rotation.
        bboxes: 2-D array of 6-value boxes; modified in place.
        do_flip: flip axes a1/a2 with probability 1/2 each (a0 is never flipped).
        do_rotate: try up to 3 random angles in [0, 180) degrees in the a1/a2
            plane. The first angle that keeps ``target`` away from the borders
            is applied; if none does, no rotation is applied.
        do_swap: randomly permute the three spatial axes (only for cubic patches).

    Returns:
        ``(sample, bboxes)`` after augmentation.
    """
    if do_rotate:
        size = np.array(sample.shape[2:4]).astype('float')
        margin = max(target[3:])
        for _ in range(3):
            angle = float(np.random.rand() * 180)
            theta = angle / 180 * np.pi
            rotate_mat = np.array([[np.cos(theta), -np.sin(theta)],
                                   [np.sin(theta), np.cos(theta)]])
            # Target position after rotating about the plane centre.
            new_target = np.copy(target)
            new_target[1:3] = np.dot(rotate_mat, target[1:3] - size / 2) + size / 2
            # Only accept the angle if the whole target stays inside the sample.
            if np.all(new_target[:3] > margin) \
                    and np.all(new_target[:3] < np.array(sample.shape[1:4]) - margin):
                sample = rotate(sample, angle, axes=(2, 3), reshape=False)
                for box in bboxes:
                    box[1:3] = np.dot(rotate_mat, box[1:3] - size / 2) + size / 2
                break
    if do_swap:
        if sample.shape[1] == sample.shape[2] and sample.shape[1] == sample.shape[3]:
            axis_order = np.random.permutation(3)
            sample = np.transpose(sample, np.concatenate([[0], axis_order + 1]))
            bboxes[:, :3] = bboxes[:, :3][:, axis_order]
            bboxes[:, 3:] = bboxes[:, 3:][:, axis_order]
    if do_flip:
        # Only flip the last two spatial axes.
        flip_id = np.array([1, np.random.randint(2), np.random.randint(2)]) * 2 - 1
        sample = np.ascontiguousarray(sample[:, ::flip_id[0], ::flip_id[1], ::flip_id[2]])
        for ax in range(3):
            if flip_id[ax] == -1:
                # NOTE(review): maps c -> n - c (not n - 1 - c); consistent with the
                # n/2 rotation centre above — confirm the coordinate convention.
                bboxes[:, ax] = np.array(sample.shape[ax + 1]) - bboxes[:, ax]
    return sample, bboxes


# The input data has already had HU values outside the lung mask processed upstream.
def norm(image, hu_min=-300.0, hu_max=1000.0):
    """Clip ``image`` to ``[hu_min, hu_max]`` and rescale linearly to ``[-1, 1]`` as float32."""
    image = (np.clip(image.astype(np.float32), hu_min, hu_max) - hu_min) / float(hu_max - hu_min)
    return (image - 0.5) * 2.0


def norm_old(image, hu_min=-1000.0, hu_max=282.0):
    """Legacy normalisation: zero values above ``hu_max``, then clip and rescale to ``[0, 1]``."""
    image = image.astype(np.float32)
    bones = image > hu_max
    image[bones] = 0.0
    image = (np.clip(image, hu_min, hu_max) - hu_min) / float(hu_max - hu_min)
    return image


class Crop(object):
    """Cut a CROP_SIZE patch around a target box, or at a random position, with optional rescaling."""

    def __init__(self, cfg):
        self.crop_size = cfg.DATA.DATA_PROCESS.CROP_SIZE
        self.bound_size = cfg.DATA.DATA_PROCESS.BOUND_SIZE
        self.stride = cfg.DATA.DATA_PROCESS.STRIDE
        self.pad_value = cfg.DATA.DATA_PROCESS.PAD_VALUE
        self.diameterLim = cfg.DATA.DATA_PROCESS.DIAMETERLIM
        self.scaleLim = cfg.DATA.DATA_PROCESS.SCALELIM

    def __call__(self, imgs, target, bboxes, isScale=False, isRand=False):
        """Crop ``imgs`` (indexed ``[channel, a0, a1, a2]``) and shift ``bboxes`` into patch coordinates.

        Args:
            imgs: 4-D image array.
            target: 6-value box (centre then size) to crop around; unused when
                ``isRand`` is True and may then be empty.
            bboxes: array of 6-value boxes; modified in place.
            isScale: randomly rescale the patch. The scale range is derived from
                DIAMETERLIM/SCALELIM and the target size.
            isRand: choose the crop start randomly instead of around ``target``.

        Returns:
            ``(crop, bboxes)``, or ``None`` if slicing the image raised. The error is printed.
        """
        if isScale:
            target_size = max(target[3:])
            scaleRange = [np.min([np.max([(self.diameterLim[0] / target_size), self.scaleLim[0]]), 1]),
                          np.max([np.min([(self.diameterLim[1] / target_size), self.scaleLim[1]]), 1])]
            scale = np.random.rand() * (scaleRange[1] - scaleRange[0]) + scaleRange[0]
            # Crop a larger/smaller region so the zoomed result has CROP_SIZE.
            crop_size = (np.array(self.crop_size).astype('float') / scale).astype('int')
        else:
            crop_size = self.crop_size

        start = []
        for i in range(3):
            if not isRand:
                # Start range keeping the whole target plus bound_size inside the crop.
                r = target[3 + i] / 2
                s = np.floor(target[i] - r) + 1 - self.bound_size
                e = np.ceil(target[i] + r) + 1 + self.bound_size - crop_size[i]
            else:
                # NOTE(review): this draws the start from roughly [crop/2, shape - crop/2),
                # which looks like a centre range rather than a start range — confirm intent.
                s = np.max([imgs.shape[i + 1] - crop_size[i] / 2, imgs.shape[i + 1] / 2 + self.bound_size])
                e = np.min([crop_size[i] / 2, imgs.shape[i + 1] / 2 - self.bound_size])
            if s > e:
                start.append(np.random.randint(e, s))
            else:
                start.append(int(target[i]) - int(crop_size[i] / 2)
                             + np.random.randint(-self.bound_size / 2, self.bound_size / 2))

        # Padding needed where the crop window extends outside the image.
        pad = [[0, 0]]
        for i in range(3):
            leftpad = max(0, -start[i])
            rightpad = max(0, start[i] + crop_size[i] - imgs.shape[i + 1])
            pad.append([leftpad, rightpad])
        try:
            crop = imgs[:,
                        max(start[0], 0):min(start[0] + crop_size[0], imgs.shape[1]),
                        max(start[1], 0):min(start[1] + crop_size[1], imgs.shape[2]),
                        max(start[2], 0):min(start[2] + crop_size[2], imgs.shape[3])]
        except Exception as err:
            traceback.print_exc()
            print(
                'crop fail with err %s, start %d, %d, %d, in crop_size %d, %d, %d, in image size %d, %d, %d, '
                'in target %d, %d, %d, %d, %d, %d.' % (
                    err, start[0], start[1], start[2], crop_size[0], crop_size[1], crop_size[2],
                    imgs.shape[1], imgs.shape[2], imgs.shape[3],
                    target[0], target[1], target[2], target[3], target[4], target[5]))
            return None

        crop = np.pad(crop, pad, 'constant', constant_values=self.pad_value)
        # Move box centres into patch coordinates.
        for i in range(len(bboxes)):
            for j in range(3):
                bboxes[i][j] = bboxes[i][j] - start[j]

        if isScale:
            with warnings.catch_warnings():
                warnings.simplefilter("ignore")
                crop = zoom(crop, [1, scale, scale, scale], order=1)
            # Trim or pad back to CROP_SIZE (assumes a cubic CROP_SIZE).
            newpad = self.crop_size[0] - crop.shape[1:][0]
            if newpad < 0:
                crop = crop[:, :newpad, :newpad, :newpad]
            elif newpad > 0:
                pad2 = [[0, 0], [0, newpad], [0, newpad], [0, newpad]]
                crop = np.pad(crop, pad2, 'constant', constant_values=self.pad_value)
            for i in range(len(bboxes)):
                for j in range(6):
                    bboxes[i][j] = bboxes[i][j] * scale
        return crop, bboxes