from bbox_reader import *


class BboxReader_Nodule(BboxReader):
    """Reader that serves cropped CT patches together with their nodule boxes.

    Annotations are read from an LMDB database whose values are JSON records
    with a ``nidi_info`` list. Every eligible box becomes one (or, for
    oversampled boxes in train mode, several) entry of ``self.targets_idx``,
    a ``(num_targets, 2)`` array of ``[sample index, box index]`` pairs.
    """

    def __init__(self, cfg, mode='train'):
        """Read the configuration, load the annotation DB and index the targets.

        Args:
            cfg: Configuration object. The ``cfg.DATA.DATA_PROCESS`` and
                ``cfg.DATA.DATA_LOADER`` entries read below must exist.
            mode: ``'train'``, ``'val'`` or ``'test'``. Any other value prints
                a message and returns early, leaving the instance without
                ``sample_bboxes`` / ``targets_idx``.
        """
        data_process = cfg.DATA.DATA_PROCESS
        data_loader = cfg.DATA.DATA_LOADER

        self.spacing = data_process.SPACING
        self.crop = Crop(cfg)
        self.mode = mode
        self.crop_size = data_process.CROP_SIZE
        self.bbox_border = data_process.BBOX_BORDER
        self.r_rand = data_loader.RAND_CROP_RATIO
        self.augtype = data_process.AUGTYPE
        self.pad_value = data_process.PAD_VALUE
        self.data_dir = data_loader.DATA_DIR
        self.stride = data_process.STRIDE
        self.balanced_sampling_flag = data_loader.BALANCED_SAMPLING
        self.pos_target_range = data_loader.POS_TARGET_RANGE
        self.neg_target_range = data_loader.NEG_TARGET_RANGE

        # Only the DB attribute of the requested mode is read, so a config
        # that lacks the other two keeps working.
        db_attr = {'train': 'TRAIN_DB', 'val': 'VALIDATE_DB', 'test': 'TEST_DB'}.get(mode)
        if db_attr is None:
            print('Invalid mode with %s in BboxReader!' % mode)
            return
        db_file = getattr(data_loader, db_attr)

        self.sample_bboxes = self._read_data(db_file)
        self.targets_idx = self._get_targets_idx()

        # Balanced sampling is never used outside train/val.
        if self.mode in ['test', 'deploy']:
            self.balanced_sampling_flag = False

        if self.balanced_sampling_flag:
            # Per-size statistics used to subsample the negative targets.
            (self.count_dict_pos, self.index_dict_pos,
             self.count_dict_neg, self.resample_dict_neg,
             self.index_dict_neg) = self._get_resample_dict()
            # Get the list of data ids.
            # NOTE(review): the original indentation was lost; this call is
            # assumed to belong to the balanced-sampling branch, which is the
            # only place __getitem__ reads list_IDs -- confirm.
            self.list_IDs = self._get_list_ids()

    def _read_data(self, db_file):
        """Load every annotation record of the LMDB database ``db_file``.

        Each box is stored as the reversed ``index_point`` followed by the
        reversed ``index_diameter`` of its record.

        Returns:
            A list with one ``BBoxes`` object per database entry, carrying the
            extra fields ``bboxes_type``, ``bboxes_diameter`` and ``filename``.
        """
        bboxes_num = 0
        sample_bboxes = []
        env = lmdb.open(db_file)
        try:
            txn = env.begin()
            for key, value in txn.cursor():
                key = str(key, encoding='utf-8')
                label_info = json.loads(str(value, encoding='utf-8'))

                bboxes_data = []
                bboxes_type = []
                bboxes_diameter = []
                for nidus_info in label_info['nidi_info']:
                    index_point = nidus_info['index_point']
                    index_diameter = nidus_info['index_diameter']
                    density_type = nidus_info['density_type']
                    bboxes_data.append(np.array(
                        [index_point[2], index_point[1], index_point[0],
                         index_diameter[2], index_diameter[1], index_diameter[0]]))
                    bboxes_diameter.append(max(index_diameter))
                    # String types keep only their first character.
                    if isinstance(density_type, str):
                        bboxes_type.append(density_type[0])
                    else:
                        bboxes_type.append(density_type)

                bboxes_num += len(bboxes_data)
                bboxes = BBoxes(np.array(bboxes_data))
                bboxes.add_field("bboxes_type", np.array(bboxes_type))
                bboxes.add_field("bboxes_diameter", np.array(bboxes_diameter))
                bboxes.add_field("filename", os.path.join(
                    self.data_dir, key + '.nii.gz', 'm_ptrCuttingImage.nii.gz'))
                sample_bboxes.append(bboxes)
        finally:
            # Release the environment even when a record fails to parse.
            env.close()

        print('In mode %s, num of ct is %d, num of bboxes: %d.'
              % (self.mode, len(sample_bboxes), bboxes_num))
        return sample_bboxes

    def _get_targets_idx(self):
        """Label every box and build the list of sampling targets.

        A box whose type is in ``pos_target_range`` and whose extent lies in
        [8, 100] gets label 1; a box whose type is in ``neg_target_range`` and
        whose extent lies in [4, 100] gets label -1; every other box keeps
        label 0 and is not a target. In train mode positive boxes are
        repeated: 5 times for types 2 and 4, twice for type 3, once
        otherwise, plus 10 more times when the extent is at least 40.

        Returns:
            An integer array of shape ``(num_targets, 2)`` whose rows are
            ``[sample index, box index]``; shape ``(0, 2)`` when no box
            qualifies.
        """
        targets_idx = []
        for i, bboxes in enumerate(self.sample_bboxes):
            if len(bboxes.data) == 0:
                continue
            bboxes.add_field("bboxes_label", np.zeros(len(bboxes.data), dtype=np.int32))
            for j, bbox_type in enumerate(bboxes.get_field("bboxes_type")):
                kind = int(bbox_type)
                # Largest of components 1 and 2 of the scaled box size.
                extent = max((bboxes.data[j][-3:] * self.spacing)[1:3])

                if kind in self.pos_target_range and 8 <= extent <= 100:
                    repeats = 1
                    if self.mode == 'train':
                        # NOTE(review): the original indentation was lost; the
                        # large-box oversampling is assumed to apply to every
                        # positive type, not only the default one -- confirm.
                        if kind == 2 or kind == 4:
                            repeats = 5
                        elif kind == 3:
                            repeats = 2
                        if extent >= 40:
                            repeats += 10
                    targets_idx.extend(np.array([i, j]) for _ in range(repeats))
                    bboxes.get_field("bboxes_label")[j] = 1
                elif kind in self.neg_target_range and 4 <= extent <= 100:
                    targets_idx.append(np.array([i, j]))
                    bboxes.get_field("bboxes_label")[j] = -1

        if not targets_idx:
            # np.vstack raises on an empty list (e.g. a DB without labels).
            return np.empty((0, 2), dtype=int)
        return np.vstack(targets_idx)

    def __getitem__(self, idx):
        """Return the sample for ``idx``.

        In train/val mode an index at or beyond the number of targets wraps
        around and requests a random crop; in that case a random image (not
        tied to any target) is used with probability 1/2.

        Returns:
            * train/val: ``[patch tensor, BBoxes]``, or ``None`` when cropping
              around the target fails.
            * test: ``[padded and normalised float tensor, original image]``.
        """
        is_random_img = False
        is_random_crop = False
        if self.mode in ['train', 'val']:
            if self.balanced_sampling_flag:
                total_len = len(self.list_IDs)
            else:
                total_len = len(self.targets_idx)
            if idx >= total_len:
                is_random_crop = True
                idx = idx % total_len
                is_random_img = np.random.randint(2)

        if self.balanced_sampling_flag:
            # Map the position in the resampled list to a target index.
            idx = self.list_IDs[idx]

        if self.mode in ['train', 'val']:
            if not is_random_img:
                target_idx = self.targets_idx[idx]
                bboxes = self.sample_bboxes[target_idx[0]]
                # Work on a copy so the stored annotations are never modified.
                new_bboxes = BBoxes(np.copy(bboxes.data))
                new_bboxes._copy_extra_fields(bboxes)
                target_bbox = new_bboxes.data[target_idx[1]]
                new_bboxes.add_field("target_idx", target_idx[1])
                filename = new_bboxes.get_field("filename")
                imgs = self._load_img(filename)

                isScale = self.augtype.SCALE and (self.mode == 'train')
                try:
                    sample, new_bboxes.data = self.crop(
                        imgs, target_bbox, new_bboxes.data, isScale, is_random_crop)
                except Exception:
                    # Best effort: report the failing file and yield nothing.
                    print('Crop fail in %s.' % filename)
                    return None
                assert sample.shape[1] == self.crop_size[0] \
                    and sample.shape[2] == self.crop_size[1] \
                    and sample.shape[3] == self.crop_size[2], \
                    'crop patch of {} has illegal shape: {}'.format(filename, sample.shape)

                if self.mode == 'train' and not is_random_crop:
                    # Re-read the target box: crop() returned updated boxes.
                    target_bbox = new_bboxes.data[target_idx[1]]
                    sample, new_bboxes.data = augment(
                        sample, target_bbox, new_bboxes.data,
                        do_flip=self.augtype.FLIP,
                        do_rotate=self.augtype.ROTATE,
                        do_swap=self.augtype.SWAP)
            else:
                randimid = np.random.randint(len(self.sample_bboxes))
                bboxes = self.sample_bboxes[randimid]
                new_bboxes = BBoxes(np.copy(bboxes.data))
                new_bboxes._copy_extra_fields(bboxes)
                filename = new_bboxes.get_field("filename")
                imgs = self._load_img(filename)
                # Random images are cropped at a random position, unscaled.
                sample, new_bboxes.data = self.crop(
                    imgs, [], new_bboxes.data, isScale=False, isRand=True)
                assert sample.shape[1] == self.crop_size[0] \
                    and sample.shape[2] == self.crop_size[1] \
                    and sample.shape[3] == self.crop_size[2], \
                    'patch of {} has illegal shape: {}'.format(filename, sample.shape)

            sample = norm(sample, hu_min=-1000.0, hu_max=600.0)
            new_bboxes = self._filter_bboxes(new_bboxes)
            if len(new_bboxes.data) > 0:
                # Enlarge the last three box components by the border.
                new_bboxes.data[:, -3:] = new_bboxes.data[:, -3:] + self.bbox_border
            return [torch.from_numpy(sample), new_bboxes]

        if self.mode in ['test']:
            image = self._load_img(self.sample_bboxes[idx].get_field("filename"))
            original_image = image[0]
            image = pad2factor(image[0], pad_value=self.pad_value)
            image = np.expand_dims(image, 0)
            net_input = self.norm(image, hu_min=-1000.0, hu_max=600.0)
            return [torch.from_numpy(net_input).float(), original_image]

    def _load_img(self, path_to_img):
        """Load an image with a new leading axis.

        A ``m_ptrCuttingImage.npy`` file next to ``path_to_img`` is used when
        it exists; otherwise the image itself is read through SimpleITK.
        """
        path_to_npy = path_to_img.replace('m_ptrCuttingImage.nii.gz', 'm_ptrCuttingImage.npy')
        if os.path.exists(path_to_npy):
            volume = np.load(path_to_npy)
        else:
            volume = sitk.GetArrayFromImage(sitk.ReadImage(path_to_img))
        return volume[np.newaxis, ...]

    def _get_resample_dict_pos(self):
        """Compute per-size counts and resampling quotas over all targets.

        Targets are binned by the integer part of their extent, capped at 11.
        Bins below 4 and the bin 11 get quota -1 (meaning "keep all"); bins
        4..10 get ``count * (count of bin 11 / count of bins 4..10)``.

        Returns:
            ``(count_dict, resample_dict, index_dict)`` keyed by bin.
        """
        min_d = 1
        max_d = 11
        # Initialise the dicts.
        bins = range(min_d, max_d + 1)
        count_dict = {d: 0 for d in bins}
        resample_dict = {d: 0 for d in bins}
        index_dict = {d: [] for d in bins}

        count_4_11 = 0
        count_11_more = 0
        for i, target_idx in enumerate(self.targets_idx):
            bboxes = self.sample_bboxes[target_idx[0]]
            target_bbox = bboxes.data[target_idx[1]]
            diameter = max((target_bbox[3:] * np.array(self.spacing))[1:3])
            key = int(diameter)
            assert (key != 0)
            key = min(key, max_d)
            count_dict[key] += 1
            index_dict[key].append(i)
            if 4 <= key < 11:
                count_4_11 += 1
            elif key >= 11:
                count_11_more += 1

        # Without any target in bins 4..10 every quota below is 0 anyway, so
        # fall back to 0.0 instead of dividing by zero.
        ratio = count_11_more / count_4_11 if count_4_11 else 0.0
        for key in count_dict:
            if key < 4 or key >= 11:
                resample_dict[key] = -1
            else:
                resample_dict[key] = count_dict[key] * ratio
        resample_dict[max_d] = -1
        return count_dict, resample_dict, index_dict

    def _generate_resample_indexes_pos(self):
        """Draw target indexes according to ``self.resample_dict``.

        Bins with quota -1 are taken completely; other bins contribute a
        random sample of the quota size, clamped to the bin population.

        NOTE(review): reads ``self.resample_dict`` and ``self.index_dict``,
        which ``__init__`` does not set in this class -- confirm the caller
        provides them.
        """
        result_index = []
        for key, quota in self.resample_dict.items():
            pool = self.index_dict[key]
            if quota == -1:
                result_index.extend(pool)
            else:
                # random.sample raises when asked for more than the population.
                result_index.extend(random.sample(pool, min(int(quota), len(pool))))
        return result_index

    def _get_resample_dict(self):
        """Compute per-size statistics for positives and negatives.

        Targets are binned by the integer part of their extent, capped at 60.
        A bin's negative quota is half the bin's positive count (floored) when
        that is smaller than the bin's negative count, otherwise -1 ("keep
        all"). The last bin always keeps all negatives.

        Returns:
            ``(count_dict_pos, index_dict_pos, count_dict_neg,
            resample_dict_neg, index_dict_neg)`` keyed by bin.
        """
        min_d = 1
        max_d = 60
        # Initialise the dicts.
        bins = range(min_d, max_d + 1)
        count_dict_pos = {d: 0 for d in bins}
        count_dict_neg = {d: 0 for d in bins}
        resample_dict_neg = {d: 0 for d in bins}
        index_dict_pos = {d: [] for d in bins}
        index_dict_neg = {d: [] for d in bins}

        for i, target_idx in enumerate(self.targets_idx):
            bboxes = self.sample_bboxes[target_idx[0]]
            target_c_label = int(bboxes.get_field("bboxes_label")[target_idx[1]])
            target_bbox = bboxes.data[target_idx[1]]
            diameter = max((target_bbox[3:] * np.array(self.spacing))[1:3])
            key = int(diameter)
            assert (key != 0)
            key = min(key, max_d)
            if target_c_label == 1:
                count_dict_pos[key] += 1
                index_dict_pos[key].append(i)
            elif target_c_label == -1:
                try:
                    count_dict_neg[key] += 1
                    index_dict_neg[key].append(i)
                except KeyError:
                    # Report bins outside the initialised range, keep going.
                    print(key)

        # Negative resample quotas.
        for key in count_dict_pos:
            half_pos = count_dict_pos[key] // 2
            if count_dict_neg[key] != 0 and half_pos < count_dict_neg[key]:
                resample_dict_neg[key] = half_pos
            else:
                resample_dict_neg[key] = -1
        resample_dict_neg[max_d] = -1
        return count_dict_pos, index_dict_pos, count_dict_neg, resample_dict_neg, index_dict_neg

    def _generate_resample_indexes(self):
        """Return all positive target indexes plus the resampled negatives.

        Per bin, every positive index is kept; negatives are kept completely
        when the quota is -1 and randomly sampled down to the quota otherwise.
        """
        result_index = []
        for key, quota in self.resample_dict_neg.items():
            result_index.extend(self.index_dict_pos[key])
            if quota == -1:
                result_index.extend(self.index_dict_neg[key])
            else:
                result_index.extend(random.sample(self.index_dict_neg[key], quota))
        return result_index

    def norm(self, image, hu_min=-300.0, hu_max=1000.0):
        """Clip ``image`` to ``[hu_min, hu_max]`` and rescale it to ``[-1, 1]``."""
        image = (np.clip(image.astype(np.float32), hu_min, hu_max) - hu_min) / float(hu_max - hu_min)
        return (image - 0.5) * 2.0