import torch.distributed
from BaseDetector.data.bbox_reader import *
import glob


class BboxReader_Nodule(BboxReader):
    """Bbox dataset reader for lung-nodule detection.

    Reads per-CT nodule annotations from an LMDB database (train/val/test) or
    scans a directory of ``*.nii.gz`` volumes (infer), and yields either
    cropped/augmented training patches or split sub-volumes for sliding-window
    inference.

    Modes:
        'train' / 'val' / 'test': annotations loaded from the configured LMDB.
        'infer': no annotations; volumes discovered via glob on ``data_dir``.
        'deploy': referenced only to disable balanced sampling — presumably
            handled by the base class; TODO confirm.
    """

    def __init__(self, cfg, mode='train', log_fun=None, split_combine=None):
        """
        Args:
            cfg: config object; reads ``cfg.data['data_process']`` and
                ``cfg.data['data_loader']`` sub-dicts.
            mode: one of 'train', 'val', 'test', 'infer' (see class docstring).
            log_fun: logging callable; defaults to ``print``.
            split_combine: splitter used in test/infer mode to tile a volume
                into patches (must provide ``split``, ``side_len``, ``margin``).
        """
        self.spacing = cfg.data['data_process']['spacing']
        self.crop = Crop(cfg)
        self.mode = mode
        self.crop_size = cfg.data['data_process']['crop_size']
        self.bbox_border = cfg.data['data_process']['bbox_border']
        self.r_rand = cfg.data['data_loader']['rand_crop_ratio']
        self.augtype = cfg.data['data_process']['augtype']
        self.pad_value = cfg.data['data_process']['pad_value']
        self.data_dir = cfg.data['data_loader']['data_dir']
        self.stride = cfg.data['data_process']['stride']
        self.balanced_sampling_flag = cfg.data['data_loader']['balanced_sampling']
        self.pos_target_range = cfg.data['data_loader']['pos_target_range']
        self.neg_target_range = cfg.data['data_loader']['neg_target_range']
        if log_fun is None:
            self.log_fun = print
        else:
            self.log_fun = log_fun
        self.split_combine = split_combine

        if mode == 'train':
            db_file = cfg.data['data_loader']['train_db']
        elif mode == 'val':
            db_file = cfg.data['data_loader']['validate_db']
        elif mode == 'test':
            db_file = cfg.data['data_loader']['test_db']
        elif mode == 'infer':
            # No labels in infer mode: just enumerate the volumes on disk.
            self.data_nii_gz = glob.glob(os.path.join(self.data_dir, '*.nii.gz'))
        # NOTE(review): unknown modes fall through silently (original had the
        # validation commented out); 'deploy' relies on this fall-through.

        if mode in ('train', 'val', 'test'):
            self.sample_bboxes = self._read_data(db_file)
            self.targets_idx = self._get_targets_idx()
        if self.mode in ['test', 'deploy', 'infer']:
            # Balanced sampling only makes sense when iterating labeled targets.
            self.balanced_sampling_flag = False
        if self.balanced_sampling_flag:
            # Resample negatives per-diameter to at most half the positive count.
            self.count_dict_pos, self.index_dict_pos, \
                self.count_dict_neg, self.resample_dict_neg, self.index_dict_neg = self._get_resample_dict()
            # Build the (resampled) list of target ids.
            # NOTE(review): defined in the base class; assumed to consume the
            # resample dicts above — TODO confirm.
            self.list_IDs = self._get_list_ids()

    def _read_data(self, db_file):
        """Load annotations from the LMDB at ``db_file``.

        Each LMDB value is a JSON document with a ``nidi_info`` list; each
        nodule entry provides ``index_point`` / ``index_diameter`` (stored
        x, y, z — converted here to z, y, x) and a ``density_type``.

        Returns:
            list[BBoxes]: one BBoxes per CT, with fields ``bboxes_type``,
            ``bboxes_diameter`` and ``filename``.
        """
        bboxes_num = 0
        sample_bboxes = []
        env = lmdb.open(db_file)
        txn = env.begin()
        for key, value in txn.cursor():
            key = str(key, encoding='utf-8')
            value = str(value, encoding='utf-8')
            label_info = json.loads(value)
            nidi_info = label_info['nidi_info']
            bboxes_data = []
            bboxes_type = []
            bboxes_diameter = []
            for nidus_info in nidi_info:
                index_point = nidus_info['index_point']
                index_diameter = nidus_info['index_diameter']
                density_type = nidus_info['density_type']
                # Reverse xyz -> zyx for both center and diameter.
                bboxes_data.append(np.array(
                    [index_point[2], index_point[1], index_point[0],
                     index_diameter[2], index_diameter[1], index_diameter[0]]))
                bboxes_diameter.append(max(index_diameter))
                if isinstance(density_type, str):
                    # Only the first character encodes the type when given as str.
                    bboxes_type.append(density_type[0])
                else:
                    bboxes_type.append(density_type)
            bboxes_num += len(bboxes_data)
            bboxes = BBoxes(np.array(bboxes_data))
            bboxes.add_field("bboxes_type", np.array(bboxes_type))
            bboxes.add_field("bboxes_diameter", np.array(bboxes_diameter))
            bboxes.add_field("filename", os.path.join(self.data_dir, key + '.nii.gz'))
            sample_bboxes.append(bboxes)
        txn.commit()
        env.close()
        self.log_fun('In mode %s, num of ct is %d, num of bboxes: %d.'
                     % (self.mode, len(sample_bboxes), bboxes_num))
        return sample_bboxes

    def _get_targets_idx(self):
        """Build the flat list of (ct_index, bbox_index) training targets.

        Positives (type in ``pos_target_range`` and in-plane diameter in
        [8, 100] mm) are oversampled in train mode: 5x for types 2/4, 2x for
        type 3, and an extra 10x for lesions >= 40 mm. Everything else gets
        label -1. Also writes a per-bbox ``bboxes_label`` field (1 / -1).

        Returns:
            np.ndarray of shape (num_targets, 2).
        """
        targets_idx = []
        print(f'pos_target_range: {self.pos_target_range}')
        for i, bboxes in enumerate(self.sample_bboxes):
            if len(bboxes.data) > 0:
                bboxes.add_field("bboxes_label", np.zeros(len(bboxes.data), dtype=np.int32))
                for j, bbox_type in enumerate(bboxes.get_field("bboxes_type")):
                    bbox_data = bboxes.data[j]
                    # Max in-plane (y/x) physical diameter in mm; bbox layout is
                    # [z, y, x, dz, dy, dx] so [-3:] are the extents.
                    diameter_mm = max((bbox_data[-3:] * self.spacing)[1:3])
                    if int(bbox_type) in self.pos_target_range and 8 <= diameter_mm <= 100:
                        if self.mode == 'train':
                            # Oversample under-represented density types.
                            if int(bbox_type) == 2 or int(bbox_type) == 4:
                                for _ in range(5):
                                    targets_idx.append(np.array([i, j]))
                            elif int(bbox_type) == 3:
                                for _ in range(2):
                                    targets_idx.append(np.array([i, j]))
                            else:
                                targets_idx.append(np.array([i, j]))
                            # Large lesions are rare: oversample heavily.
                            if diameter_mm >= 40:
                                for _ in range(10):
                                    targets_idx.append(np.array([i, j]))
                        else:
                            targets_idx.append(np.array([i, j]))
                        bboxes.get_field("bboxes_label")[j] = 1
                    elif int(bbox_type) in self.neg_target_range and 4 <= diameter_mm <= 100:
                        targets_idx.append(np.array([i, j]))
                        bboxes.get_field("bboxes_label")[j] = -1
                    else:
                        targets_idx.append(np.array([i, j]))
                        bboxes.get_field("bboxes_label")[j] = -1
        return np.vstack(targets_idx)

    def __getitem__(self, idx):
        """Return a training/validation patch or a tiled test/infer volume.

        train/val: ``[tensor patch (1, D, H, W), BBoxes]``. Indices past the
        nominal length produce random crops (and, half the time, a crop from a
        random image). Returns ``None`` if cropping fails.

        test/infer: the volume is center-cropped by fixed body-region ratios,
        padded to a multiple of ``stride``, tiled via ``split_combine``, and
        returned with patch indices, tile counts, a normalized coordinate grid
        and the pre-pad shape (test mode additionally returns the BBoxes).
        """
        is_random_img = False
        if self.mode in ['train', 'val']:
            if self.balanced_sampling_flag:
                total_len = len(self.list_IDs)
            else:
                total_len = len(self.targets_idx)
            if idx >= total_len:
                # Extra indices (see __len__) become random crops; half of
                # those come from a completely random image.
                is_random_crop = True
                idx = idx % total_len
                is_random_img = np.random.randint(2)
            else:
                is_random_crop = False
        else:
            is_random_crop = False
        if self.balanced_sampling_flag:
            idx = self.list_IDs[idx]

        if self.mode in ['train', 'val']:
            if not is_random_img:
                target_idx = self.targets_idx[idx]
                bboxes = self.sample_bboxes[target_idx[0]]
                new_bboxes = BBoxes(np.copy(bboxes.data))
                new_bboxes._copy_extra_fields(bboxes)
                target_bbox = new_bboxes.data[target_idx[1]]
                new_bboxes.add_field("target_idx", target_idx[1])
                filename = new_bboxes.get_field("filename")
                imgs = self._load_img(filename)
                isScale = self.augtype['scale'] and (self.mode == 'train')
                try:
                    sample, new_bboxes.data = self.crop(imgs, target_bbox, new_bboxes.data,
                                                        isScale, is_random_crop)
                except Exception:
                    # Best-effort: a failed crop yields None; the collate fn is
                    # presumably expected to drop it — TODO confirm.
                    print('Crop fail in %s.' % filename)
                    return
                assert sample.shape[1] == self.crop_size[0] and sample.shape[2] == self.crop_size[1] \
                    and sample.shape[3] == self.crop_size[2], \
                    'crop patch of {} has illegal shape: {}'.format(filename, sample.shape)
                if self.mode == 'train' and not is_random_crop:
                    # Re-read the target: crop() may have shifted coordinates.
                    target_bbox = new_bboxes.data[target_idx[1]]
                    sample, new_bboxes.data = augment(sample, target_bbox, new_bboxes.data,
                                                      do_flip=self.augtype['flip'],
                                                      do_rotate=self.augtype['rotate'],
                                                      do_swap=self.augtype['swap'])
            else:
                # Random-image branch: random crop with no specific target.
                randimid = np.random.randint(len(self.sample_bboxes))
                bboxes = self.sample_bboxes[randimid]
                new_bboxes = BBoxes(np.copy(bboxes.data))
                new_bboxes._copy_extra_fields(bboxes)
                filename = new_bboxes.get_field("filename")
                imgs = self._load_img(filename)
                # (Original computed `self.augtype.SCALE` here, which would
                # raise AttributeError on a dict and was unused anyway.)
                sample, new_bboxes.data = self.crop(imgs, [], new_bboxes.data,
                                                    isScale=False, isRand=True)
                assert sample.shape[1] == self.crop_size[0] and sample.shape[2] == self.crop_size[1] \
                    and sample.shape[3] == self.crop_size[2], \
                    'patch of {} has illegal shape: {}'.format(filename, sample.shape)

            # NOTE(review): module-level `norm` (star import), not self.norm —
            # assumed intentional; verify they agree.
            sample = norm(sample, hu_min=-1000.0, hu_max=600.0)
            new_bboxes = self._filter_bboxes(new_bboxes)
            if len(new_bboxes.data) > 0:
                # Enlarge boxes by a fixed border on each extent.
                new_bboxes.data[:, -3:] = new_bboxes.data[:, -3:] + self.bbox_border
            return [torch.from_numpy(sample), new_bboxes]

        if self.mode in ['test', 'infer']:
            print('# bbox_reader-212: ', self.mode)
            # Fixed body-region crop ratios (fractions of each axis).
            x_min_ratio, x_max_ratio = 0.1491, 0.8442
            y_min_ratio, y_max_ratio = 0.2685, 0.7606
            z_min_ratio, z_max_ratio = 0.1330, 0.9143
            if self.mode == 'infer':
                image = self._load_img(self.data_nii_gz[idx])
            elif self.mode == 'test':
                bboxes = self.sample_bboxes[idx]
                filename = bboxes.get_field("filename")
                image = self._load_img(filename)
            z, y, x = image.shape[1:]
            print(z, y, x)
            image = image[:,
                          int(z_min_ratio * z): int(z_max_ratio * z),
                          int(y_min_ratio * y): int(y_max_ratio * y),
                          int(x_min_ratio * x): int(x_max_ratio * x)]
            nz, ny, nx = image.shape[1:]
            # Pad each spatial dim up to the next multiple of the stride.
            pz = int(np.ceil(float(nz) / self.stride)) * self.stride
            py = int(np.ceil(float(ny) / self.stride)) * self.stride
            px = int(np.ceil(float(nx) / self.stride)) * self.stride
            image = np.pad(image,
                           [[0, 0], [0, pz - nz], [0, py - ny], [0, px - nx]],
                           'constant', constant_values=self.pad_value)
            # Coordinate grid fed to the model: relative position within the
            # (downsampled-by-stride) volume, values in [-0.5, 0.5].
            zz, yy, xx = np.meshgrid(np.linspace(-0.5, 0.5, int(image.shape[1] / self.stride)),
                                     np.linspace(-0.5, 0.5, int(image.shape[2] / self.stride)),
                                     np.linspace(-0.5, 0.5, int(image.shape[3] / self.stride)),
                                     indexing='ij')
            coord = np.concatenate([zz[np.newaxis, ...], yy[np.newaxis, ...], xx[np.newaxis, ...]],
                                   axis=0).astype('float32')
            # Tile image and coord grid identically.
            image, patch_idx, n_zyx = self.split_combine.split(image, pad_value=self.pad_value)
            coord, _, n_zyx_coord = self.split_combine.split(
                coord,
                side_len=int(self.split_combine.side_len / self.stride),
                margin=int(self.split_combine.margin / self.stride),
                pad_value=0)
            assert np.all(n_zyx == n_zyx_coord)
            input_patches = self.norm(image, hu_min=-1000.0, hu_max=600.0)
            if self.mode == 'infer':
                return [torch.from_numpy(input_patches), torch.from_numpy(patch_idx),
                        torch.from_numpy(np.array(n_zyx)).repeat(input_patches.shape[0], 1),
                        torch.from_numpy(coord), [nz, ny, nx]]
            elif self.mode == 'test':
                return [torch.from_numpy(input_patches), torch.from_numpy(patch_idx),
                        torch.from_numpy(np.array(n_zyx)).repeat(input_patches.shape[0], 1),
                        torch.from_numpy(coord), [nz, ny, nx], bboxes]

    def _load_img(self, path_to_img):
        """Read a NIfTI volume and return it as a (1, Z, Y, X) array.

        Raises:
            FileNotFoundError: if ``path_to_img`` does not exist (the original
            logged and then crashed with UnboundLocalError).
        """
        if not os.path.exists(path_to_img):
            self.log_fun('{} is not exists.'.format(path_to_img))
            raise FileNotFoundError(path_to_img)
        itk_image = sitk.ReadImage(path_to_img)
        return sitk.GetArrayFromImage(itk_image)[np.newaxis, ...]

    def _get_resample_dict_pos(self):
        """Bucket positive targets by integer diameter (mm), clamped to 11.

        Buckets outside [4, 11) are marked -1 (keep all); in-range buckets get
        a resample count proportional to the >=11mm / 4-11mm ratio.

        Returns:
            (count_dict, resample_dict, index_dict) keyed by diameter bucket.
        """
        min_d = 1
        max_d = 11
        count_dict = dict()
        resample_dict = dict()
        index_dict = dict()
        for d in range(min_d, max_d + 1):
            count_dict[d] = 0
            resample_dict[d] = 0
            index_dict[d] = list()
        count_4_11 = 0
        count_11_more = 0
        for i, target_idx in enumerate(self.targets_idx):
            bboxes = self.sample_bboxes[target_idx[0]]
            target_bbox = bboxes.data[target_idx[1]]
            diameter = max((target_bbox[3:] * np.array(self.spacing))[1:3])
            key = int(diameter)
            assert (key != 0)
            if key >= max_d:
                key = max_d
            count_dict[key] += 1
            index_dict[key].append(i)
            if 4 <= int(key) < 11:
                count_4_11 += 1
            elif int(key) >= 11:
                count_11_more += 1
        # NOTE(review): raises ZeroDivisionError if no 4-11mm targets exist.
        ratio = count_11_more / count_4_11
        for key in count_dict.keys():
            if count_dict[key] == 0:
                resample_dict[key] = 0
            if int(key) < 4 or int(key) >= 11:
                resample_dict[key] = -1  # -1 means "keep all indices"
            else:
                resample_dict[key] = count_dict[key] * ratio
        resample_dict[max_d] = -1
        return count_dict, resample_dict, index_dict

    def _generate_resample_indexes_pos(self):
        """Draw per-bucket random subsets of positive target indices."""
        result_index = []
        for key in self.resample_dict.keys():
            if self.resample_dict[key] == -1:
                result_index.extend(self.index_dict[key])
            else:
                result_index.extend(
                    random.sample(self.index_dict[key], int(self.resample_dict[key])))
        return result_index

    def _get_resample_dict(self):
        """Bucket pos/neg targets by integer diameter (mm), clamped to 60.

        Negative buckets larger than half the positive bucket are capped at
        ``count_pos // 2``; otherwise (and for the last bucket) all negatives
        are kept (-1 sentinel).

        Returns:
            (count_dict_pos, index_dict_pos, count_dict_neg,
             resample_dict_neg, index_dict_neg)
        """
        min_d = 1
        max_d = 60
        count_dict_pos = dict()
        count_dict_neg = dict()
        resample_dict_neg = dict()
        index_dict_pos = dict()
        index_dict_neg = dict()
        for d in range(min_d, max_d + 1):
            count_dict_pos[d] = 0
            count_dict_neg[d] = 0
            resample_dict_neg[d] = 0
            index_dict_pos[d] = list()
            index_dict_neg[d] = list()
        for i, target_idx in enumerate(self.targets_idx):
            bboxes = self.sample_bboxes[target_idx[0]]
            bboxes_label = bboxes.get_field("bboxes_label")
            target_bbox = bboxes.data[target_idx[1]]
            target_c_label = bboxes_label[target_idx[1]]
            diameter = max((target_bbox[3:] * np.array(self.spacing))[1:3])
            key = int(diameter)
            assert (key != 0)
            if key >= max_d:
                key = max_d
            if int(target_c_label) == 1:
                count_dict_pos[key] += 1
                index_dict_pos[key].append(i)
            elif int(target_c_label) == -1:
                try:
                    count_dict_neg[key] += 1
                    index_dict_neg[key].append(i)
                except KeyError:
                    # Diagnostic only: an out-of-range bucket is logged, not fatal.
                    print(key)
        for key in count_dict_pos.keys():
            if count_dict_neg[key] == 0:
                resample_dict_neg[key] = -1
            else:
                if count_dict_pos[key] // 2 < count_dict_neg[key]:
                    resample_dict_neg[key] = count_dict_pos[key] // 2
                else:
                    resample_dict_neg[key] = -1
        resample_dict_neg[max_d] = -1
        return count_dict_pos, index_dict_pos, count_dict_neg, resample_dict_neg, index_dict_neg

    def _generate_resample_indexes(self):
        """Keep all positives; subsample negatives per resample_dict_neg."""
        result_index = []
        for key in self.resample_dict_neg.keys():
            result_index.extend(self.index_dict_pos[key])
            if self.resample_dict_neg[key] == -1:
                result_index.extend(self.index_dict_neg[key])
            else:
                result_index.extend(
                    random.sample(self.index_dict_neg[key], self.resample_dict_neg[key]))
        return result_index

    def norm(self, image, hu_min=-300.0, hu_max=1000.0):
        """Clip to [hu_min, hu_max] and rescale to roughly [-1, 1].

        Maps through an intermediate 0-255 range, then centers at 128 — note
        the result max is (255-128)/128, slightly below 1.
        """
        image = (np.clip(image.astype(np.float32), hu_min, hu_max) - hu_min) / float(hu_max - hu_min)
        image = image * 255
        return (image.astype(np.float32) - 128) / 128

    def __len__(self):
        """Dataset length; train mode is inflated by 1/(1 - r_rand) so the
        surplus indices trigger random crops in __getitem__."""
        if self.mode == 'train':
            if self.balanced_sampling_flag:
                return int(len(self.list_IDs) / (1 - self.r_rand))
            return int(len(self.targets_idx) / (1 - self.r_rand))
        elif self.mode == 'val':
            return len(self.targets_idx)
        elif self.mode == 'test':
            return len(self.sample_bboxes)
        else:
            return len(self.data_nii_gz)