# -*- coding: utf-8 -*-
"""CT series loading, lung segmentation and preprocessing utilities."""
import os
import sys
import pathlib

# Make the enclosing "cls_train" project root importable.  Walk up from this
# file's directory; stop at the filesystem root instead of looping forever
# when no ancestor is called "cls_train" (the root is its own parent).
current_dir = pathlib.Path(__file__).parent.resolve()
for _candidate in [current_dir] + list(current_dir.parents):
    if _candidate.name == "cls_train":
        current_dir = _candidate
        sys.path.append(current_dir.as_posix())
        break
# If no such ancestor exists nothing is appended; the project imports below
# then succeed only if the "data" package is importable some other way.

import time
import logging
import numpy as np
import SimpleITK as sitk
import pydicom
import traceback
from concurrent.futures import ThreadPoolExecutor
from multiprocessing.pool import Pool
from data.data_process_utils.test_box_utils import NoduleBox, nodule_raw2standard
from data.data_process_utils.test_data_utils import check_and_makedirs
from data.data_process_utils.test_data_utils import clip_data, downsample_data, upsample_mask, resample_data, resample_mask
import base64
import cv2
import ast


def _run_on_slices(task, mask, num_slices, pool_size):
    """Run ``task(i, mask[:, :, i])`` for every slice index on a thread pool.

    Worker exceptions are re-raised here; a bare ``executor.submit`` would
    swallow them and silently leave the affected output slices zero-filled.
    """
    with ThreadPoolExecutor(pool_size) as executor:
        futures = [executor.submit(task, i, mask[:, :, i])
                   for i in range(num_slices)]
    # Leaving the ``with`` block waits for completion, so result() only
    # surfaces stored exceptions and never blocks.
    for future in futures:
        future.result()


def lung_segment(data_npy):
    """Segment the lungs of a CT volume with thresholding and 2D morphology.

    data_npy: 3D numpy array; it is processed slice by slice along its first
        axis (presumably [z, y, x] order like the rest of this module).

    Returns a numpy array mask obtained from the SimpleITK pipeline below.
    """

    def morphology_opening_2d(mask, data_npy, pool_size=5):
        # Per-slice opening; the hole-filled version is used as chest mask.
        mask_npy = np.zeros(data_npy.shape, data_npy.dtype)
        chest_mask_npy = np.zeros(data_npy.shape, data_npy.dtype)

        def morphology_opening_2d_task(i, mask_2d):
            mask_slice = sitk.BinaryMorphologicalOpening(mask_2d, 2)
            mask_npy[i] = sitk.GetArrayFromImage(mask_slice)
            chest_mask_npy[i] = sitk.GetArrayFromImage(
                sitk.BinaryFillhole(mask_slice))

        _run_on_slices(morphology_opening_2d_task, mask,
                       data_npy.shape[0], pool_size)
        return (sitk.GetImageFromArray(mask_npy),
                sitk.GetImageFromArray(chest_mask_npy))

    def morphology_closing_2d(mask, data_npy, pool_size=5):
        # Per-slice closing + dilation + hole filling.
        mask_npy = np.zeros(data_npy.shape, data_npy.dtype)

        def morphology_closing_2d_task(i, mask_2d):
            mask_slice = sitk.BinaryMorphologicalClosing(mask_2d, 15)
            mask_slice = sitk.BinaryDilate(mask_slice, 15)
            mask_npy[i] = sitk.GetArrayFromImage(
                sitk.BinaryFillhole(mask_slice))
            # OR each slice with its mirror image along the last axis, which
            # makes the resulting mask left/right symmetric.
            mask_npy[i] = np.bitwise_or(mask_npy[i], mask_npy[i][:, ::-1])

        _run_on_slices(morphology_closing_2d_task, mask,
                       data_npy.shape[0], pool_size)
        return sitk.GetImageFromArray(mask_npy)

    # ``np.int`` was only an alias of the builtin ``int`` and has been
    # removed from NumPy; ``int`` yields the same dtype.
    data_npy = data_npy.astype(int)
    data_npy = clip_data(data_npy)
    data = sitk.GetImageFromArray(data_npy)
    mask = 1 - sitk.OtsuThreshold(data)
    mask, chest_mask = morphology_opening_2d(mask, data_npy)
    lung_mask = sitk.Subtract(chest_mask, mask)

    # Remove areas not in the chest, when CT covers regions below the chest
    eroded_mask = sitk.BinaryErode(lung_mask, 10)
    seed_npy = sitk.GetArrayFromImage(eroded_mask)
    # Reorder the numpy index rows so seeds match the SimpleITK index order.
    seed_npy = np.array(seed_npy.nonzero())[[2, 1, 0]]
    seeds = seed_npy.T.tolist()
    lung_mask = sitk.ConfidenceConnected(lung_mask, seeds, multiplier=2.5)
    lung_mask = morphology_closing_2d(lung_mask, data_npy)
    return sitk.GetArrayFromImage(lung_mask)


def lung_segment_enhance(data):
    """Segment the lungs on a downsampled copy and upsample the mask back.

    data: 3D numpy array.

    Returns a mask padded (edge mode) at the end of each axis so that it has
    the same shape as ``data``.
    """
    z_max = max(int(np.ceil(data.shape[0] / 100)), 2)
    if data.shape[1] <= 512 and data.shape[2] <= 512:
        scale = (z_max, 2, 2)
    else:
        scale = (z_max, 4, 4)

    # Crop to a multiple of the scale before downsampling.
    new_data_shape = np.array(data.shape) // scale * scale
    new_data = data[:new_data_shape[0],
                    :new_data_shape[1],
                    :new_data_shape[2]]
    new_data = downsample_data(new_data, scale)
    new_data = new_data.astype(data.dtype)

    new_mask = lung_segment(new_data)
    mask = upsample_mask(new_mask, scale)

    # Restore the voxels cropped above by padding at the end of each axis.
    pad = np.zeros((3, 2), int)
    pad[:, 1] = np.array(data.shape) - np.array(mask.shape)
    mask = np.pad(mask, pad, mode='edge')
    return mask


def get_lung_mask_and_box(data, uid, segment=False, segment_data=False,
                          segment_margin=(0, 0, 0), min_points=10000):
    """Compute a lung mask and, optionally, crop the data to the lung box.

    data: 3D numpy array.
    uid: identifier used only in log messages.
    segment: run lung segmentation; when False an all-ones mask is returned.
    segment_data: when True and lungs were found, crop ``data`` and the mask
        to the lung bounding box.
    segment_margin: per-axis margin added around the bounding box.
    min_points: minimum number of mask voxels equal to 1 for the
        segmentation to count as successful.

    Returns (data, lung_mask, lung_box); lung_box is a (3, 2) integer array
    of [start, stop) per axis, defaulting to the full extent of ``data``.
    Segmentation errors are printed and result in the all-ones fallback mask.
    """
    lung_box = np.zeros((3, 2), int)
    lung_box[:, 1] = data.shape
    found_lung = False
    if segment:
        try:
            start_time = time.time()
            lung_mask = lung_segment_enhance(data)
            logging.info('{}, {}, Lung segment run time {:.2f}(s)'.format(
                time.strftime("%Y-%m-%d %H:%M:%S"), uid,
                (time.time() - start_time)))
            if np.sum(lung_mask == 1) > min_points:
                found_lung = True
            if found_lung and segment_data:
                segment_margin = np.array(segment_margin)
                coords = np.asarray(np.where(lung_mask == 1))
                lung_box[:, 0] = np.maximum(
                    coords.min(axis=1) - segment_margin, 0)
                lung_box[:, 1] = np.minimum(
                    coords.max(axis=1) + segment_margin + 1, data.shape)
                crop = tuple(slice(lo, hi) for lo, hi in lung_box)
                data = data[crop]
                lung_mask = lung_mask[crop]
        except Exception:
            # Best effort: report the failure and use the fallback below.
            traceback.print_exc()
    if not found_lung:
        lung_mask = np.ones(data.shape, np.uint8)
    return data, lung_mask, lung_box


def transform_file_type(input_file, output_file):
    """Read an image with SimpleITK and write it back to ``output_file``."""
    image = sitk.ReadImage(input_file)
    sitk.WriteImage(image, output_file)


def load_single_dicom(input_file):
    """Read one image file and return its first slice as a 2D array."""
    image = sitk.ReadImage(input_file)
    return sitk.GetArrayFromImage(image)[0, :, :]


class CTSeries(object):
    """Container for a CT series: raw volume, metadata, preprocessed data
    and nodule labels."""

    def __init__(self):
        self._PatientID = None
        self._SeriesInstanceUID = None
        self._SOPInstanceUIDs = None
        self._ReconstructionDiameter = None
        self._Rows = None
        self._Columns = None
        self._AcquisitionDate = None
        self._Manufacturer = None
        self._InstitutionName = None

        self._raw_data = None
        self._raw_spacing = None
        self._raw_origin = None
        self._raw_direction = None
        self._dicoms_is_loaded = False

        self._lung_mask = None
        self._lung_box = None
        self._standard_data = None
        self._standard_spacing = None
        self._dicoms_is_preprocessed = False

        self._raw_labels = None
        self._standard_labels = None
        self._label_is_loaded = False
        self._standard_is_loaded = False

    def _log_info(self, message):
        """Log ``message`` prefixed with a timestamp and the series UID."""
        logging.info('{}, {}, {}'.format(
            time.strftime("%Y-%m-%d %H:%M:%S"),
            self._SeriesInstanceUID, message))

    def _read_series_headers(self, dicom_dir_path):
        """Read all DICOM headers in a directory and store series metadata.

        Shared by load_dicoms and load_dicoms_mp.  Returns
        (dicoms, dicom_paths): the headers in directory-listing order and the
        file paths as a numpy array sorted by z coordinate, large to small.
        """
        logging.info('{}, Loading dicoms from {}...'.format(
            time.strftime('%Y-%m-%d %H:%M:%S'), dicom_dir_path))
        dicom_names = [f for f in os.listdir(dicom_dir_path)
                       if '.xml' not in f]
        dicom_paths = [os.path.join(dicom_dir_path, name)
                       for name in dicom_names]
        # ``pydicom.read_file`` is deprecated; ``dcmread`` is its replacement.
        dicoms = [pydicom.dcmread(path, stop_before_pixels=True)
                  for path in dicom_paths]
        try:
            slice_locations = [float(d.ImagePositionPatient[2])
                               for d in dicoms]
        except AttributeError:
            slice_locations = [float(d.SliceLocation) for d in dicoms]

        # sort slices by their z coordinates from large to small
        # NOTE(review): the original code branched on PatientPosition
        # ('FFP'/'FFS' vs. others) but both branches sorted identically, so
        # the branch was dropped. Confirm feet-first scans need no other order.
        idx_z_sorted = np.argsort(slice_locations)[::-1]
        dicom_paths = np.asarray(dicom_paths)[idx_z_sorted]

        first = dicoms[0]
        self._SeriesInstanceUID = first.SeriesInstanceUID
        self._SOPInstanceUIDs = np.array(
            [d.SOPInstanceUID for d in dicoms])[idx_z_sorted]
        # NOTE(review): a single missing optional tag resets *all* of these
        # fields to None (including Rows, which is_hrct depends on).  Kept
        # as-is because downstream code may rely on it - confirm intent.
        try:
            self._PatientID = first.PatientID
            self._ReconstructionDiameter = first.ReconstructionDiameter
            self._Rows = first.Rows
            self._Columns = first.Columns
            self._AcquisitionDate = first.AcquisitionDate
            self._Manufacturer = first.Manufacturer
            self._InstitutionName = first.InstitutionName
        except AttributeError:
            self._PatientID = None
            self._ReconstructionDiameter = None
            self._Rows = None
            self._Columns = None
            self._AcquisitionDate = None
            self._Manufacturer = None
            self._InstitutionName = None
        return dicoms, dicom_paths

    def load_dicoms(self, dicom_dir_path):
        """Load a DICOM series from a directory with SimpleITK's series reader.

        Files whose name contains '.xml' are ignored.
        """
        _, dicom_paths = self._read_series_headers(dicom_dir_path)

        reader = sitk.ImageSeriesReader()
        # Pass plain Python strings rather than a numpy string array.
        reader.SetFileNames(dicom_paths.tolist())
        image_itk = reader.Execute()

        # all in [z, y, x] order
        self._raw_data = sitk.GetArrayFromImage(image_itk)
        self._raw_spacing = np.array(list(reversed(image_itk.GetSpacing())))
        self._raw_origin = np.array(list(reversed(image_itk.GetOrigin())))
        self._raw_direction = image_itk.GetDirection()
        self._dicoms_is_loaded = True

    def load_dicoms_mp(self, dicom_dir_path, processes=20):
        """Load a DICOM series, reading the slices in a process pool.

        dicom_dir_path: directory with the DICOM files.
        processes: number of worker processes (default 20, as before).

        NOTE(review): unlike load_dicoms, the volume is stored as float64
        and spacing/direction come from the first slice only, so the z
        spacing may not reflect the real slice distance - confirm.
        """
        dicoms, dicom_paths = self._read_series_headers(dicom_dir_path)

        pool = Pool(processes=processes)
        try:
            slice_list = pool.map(load_single_dicom, dicom_paths)
        finally:
            # Always release the workers, even when a slice fails to load.
            pool.close()
            pool.join()

        raw_data = np.zeros(
            (len(dicom_paths), dicoms[0].Rows, dicoms[0].Columns))
        for idx, slice_data in enumerate(slice_list):
            raw_data[idx] = slice_data

        first_path = str(dicom_paths[0])
        image0 = sitk.ReadImage(first_path)
        dicom0 = pydicom.dcmread(first_path, stop_before_pixels=True)

        # all in [z, y, x] order
        self._raw_data = raw_data
        self._raw_spacing = np.array(list(reversed(image0.GetSpacing())))
        self._raw_origin = np.array(
            list(reversed(dicom0.ImagePositionPatient)))
        self._raw_direction = image0.GetDirection()
        self._dicoms_is_loaded = True

    def load_single_file(self, file_path, uid=None):
        """Load a volume from a single image file.

        uid: series identifier; defaults to the file name without extension.
        """
        logging.info('{}, Loading file from {}...'.format(
            time.strftime('%Y-%m-%d %H:%M:%S'), file_path))
        self._SeriesInstanceUID = uid
        if self._SeriesInstanceUID is None:
            self._SeriesInstanceUID = os.path.splitext(
                os.path.basename(file_path))[0]
        image_itk = sitk.ReadImage(file_path)

        # all in [z, y, x] order
        self._raw_data = sitk.GetArrayFromImage(image_itk)
        self._raw_spacing = np.array(list(reversed(image_itk.GetSpacing())))
        self._raw_origin = np.array(list(reversed(image_itk.GetOrigin())))
        self._raw_direction = image_itk.GetDirection()
        self._dicoms_is_loaded = True

    def load_file(self, file_path, uid=None):
        """Load a '.mhd' file, or otherwise treat the path as a DICOM
        directory."""
        extension = os.path.splitext(os.path.basename(file_path))[-1]
        if extension in ['.mhd']:
            self.load_single_file(file_path, uid)
        else:
            self.load_dicoms(file_path)

    def is_hrct(self):
        """Return True for a high-resolution scan (at least 1024 rows)."""
        return self._Rows is not None and self._Rows >= 1024

    def is_ultra_hrct(self, max_reconstruction_diameter=250):
        """Return True for a high-resolution targeted scan: at least 1024
        rows and a reconstruction diameter below the given maximum."""
        if not self.is_hrct():
            return False
        if self._ReconstructionDiameter is None:
            return False
        return bool(self._ReconstructionDiameter < max_reconstruction_diameter)

    def get_patient_id(self):
        return self._PatientID

    def get_series_instance_uid(self):
        return self._SeriesInstanceUID

    def get_reconstruction_diameter(self):
        return self._ReconstructionDiameter

    def get_rows(self):
        return self._Rows

    def get_columns(self):
        return self._Columns

    def get_acquisition_date(self):
        return self._AcquisitionDate

    def get_manufacturer(self):
        return self._Manufacturer

    def get_institution_name(self):
        return self._InstitutionName

    def get_raw_data(self):
        return self._raw_data

    def get_raw_spacing(self):
        return self._raw_spacing

    def get_raw_origin(self):
        return self._raw_origin

    def get_raw_direction(self):
        return self._raw_direction

    def get_lung_mask(self):
        return self._lung_mask

    def get_lung_box(self):
        return self._lung_box

    def get_standard_data(self):
        return self._standard_data

    def get_standard_spacing(self):
        return self._standard_spacing

    def get_raw_labels(self):
        return self._raw_labels

    def get_standard_labels(self):
        return self._standard_labels

    def save_raw_data(self, output_file):
        """Write the raw volume to ``output_file``.

        Does nothing when no data has been loaded.
        """
        if not self._dicoms_is_loaded:
            return
        image = sitk.GetImageFromArray(self._raw_data)
        # Stored values are in [z, y, x] order; SimpleITK expects the reverse.
        image.SetSpacing(self._raw_spacing[::-1].tolist())
        image.SetOrigin(self._raw_origin[::-1].tolist())
        image.SetDirection(self._raw_direction)
        check_and_makedirs(output_file, is_file=True)
        sitk.WriteImage(image, output_file)

    def world_to_voxel_coord(self, world_coord):
        """Convert world coordinates to (fractional) voxel coordinates."""
        offset = np.absolute(world_coord - self._raw_origin)
        return offset / self._raw_spacing

    def voxel_to_world_coord(self, voxel_coord):
        """Convert voxel coordinates to world coordinates."""
        return voxel_coord * self._raw_spacing + self._raw_origin

    def _standardize(self, tag, resample, in_plane_spacing, segment, scale):
        """Shared body of the preprocess* methods.

        Either resamples the raw data to a z spacing of 1.0 and the given
        in-plane spacing (each multiplied by ``scale``) or just downsamples
        it by ``scale``; then computes the lung mask/box and stores results.
        ``tag`` is only used in the resample log message.
        """
        if resample:
            self._log_info('Preprocessing {} resample data...'.format(tag))
            new_spacing = np.array(
                [1.0, in_plane_spacing, in_plane_spacing]) * np.array(scale)
            standard_data, standard_spacing = resample_data(
                self._raw_data, self._raw_spacing, new_spacing)
        else:
            standard_data = downsample_data(self._raw_data, scale)
            standard_spacing = self._raw_spacing * np.array(scale)

        standard_data, lung_mask, lung_box = get_lung_mask_and_box(
            standard_data, self.get_series_instance_uid(), segment=segment)

        self._lung_mask = lung_mask
        self._lung_box = lung_box
        self._standard_data = standard_data
        self._standard_spacing = standard_spacing
        self._dicoms_is_preprocessed = True

    def preprocess(self, segment=False, scale=(1, 1, 1)):
        """Downsample the raw data by ``scale`` and compute the lung mask."""
        self._log_info('Preprocessing ...')
        self._standardize(None, False, None, segment, scale)

    def preprocess_1024u(self, check_spacing=False, segment=False,
                         scale=(1, 1, 1)):
        """Preprocess with the '1024u' settings.

        With ``check_spacing``: if the z spacing exceeds 1.5 mm, or the
        reconstruction diameter (FOV) is missing or exceeds 190 mm, resample
        to 1 mm z spacing and 0.17578125 mm in-plane spacing (180 mm over
        1024 pixels).
        """
        self._log_info('Preprocessing 1024u...')
        resample = check_spacing and (
            float(self._raw_spacing[0]) > 1.5 or
            self._ReconstructionDiameter is None or
            float(self._ReconstructionDiameter) > 190)
        self._standardize('1024u', resample, 0.17578125, segment, scale)

    def preprocess_1024(self, check_spacing=False, segment=False,
                        scale=(1, 1, 1)):
        """Preprocess with the '1024' settings.

        With ``check_spacing``: if the z spacing exceeds 1.5 mm, resample to
        1 mm z spacing and 0.3515625 mm in-plane spacing.
        """
        self._log_info('Preprocessing 1024...')
        resample = check_spacing and float(self._raw_spacing[0]) > 1.5
        self._standardize('1024', resample, 0.3515625, segment, scale)

    def preprocess_512(self, check_spacing=False, segment=False,
                       scale=(1, 1, 1)):
        """Preprocess with the '512' settings.

        With ``check_spacing``: if the z spacing exceeds 1.5 mm, resample to
        1 mm z spacing and 0.703125 mm in-plane spacing.
        """
        self._log_info('Preprocessing 512...')
        resample = check_spacing and float(self._raw_spacing[0]) > 1.5
        self._standardize('512', resample, 0.703125, segment, scale)

    def _require_preprocessed_dicoms(self):
        """Raise unless data has been loaded and preprocessed."""
        if not self._dicoms_is_loaded:
            raise Exception('DICOM files have not been loaded yet')
        if not self._dicoms_is_preprocessed:
            raise Exception('DICOM files have not been preprocessed yet')

    def _set_raw_labels(self, nodule_boxes):
        """Store raw labels and derive the standard-space labels."""
        self._raw_labels = nodule_boxes
        self._standard_labels = nodule_raw2standard(
            self._raw_labels, self._raw_spacing, self._standard_spacing,
            start=self._lung_box[:, 0])
        self._label_is_loaded = True

    def load_labels(self, label_path):
        """
        Load labels from label_path.
        label_path: path to the label file, which is a csv with 5 fields:
            [z, y, x, diameter, is_pos]

        Raises Exception when the data has not been loaded and preprocessed.
        Blank lines in the file are skipped.
        """
        self._require_preprocessed_dicoms()
        logging.info('{}, {}, Loading labels from {}...'.format(
            time.strftime("%Y-%m-%d %H:%M:%S"),
            self._SeriesInstanceUID, label_path))
        nodule_boxes = []
        with open(label_path) as f:
            for line in f:
                if not line.strip():
                    # e.g. a trailing empty line; would break the unpacking.
                    continue
                z, y, x, diameter, is_pos = \
                    line.strip('\n').replace('"', '').split(',')[0:5]
                nodule_boxes.append(
                    NoduleBox(float(z), float(y), float(x),
                              float(diameter), float(is_pos), 1.0))
        self._set_raw_labels(nodule_boxes)

    def load_nodule_boxes(self, nodule_boxes):
        """
        Load labels from nodule_boxes.

        Raises Exception when the data has not been loaded and preprocessed.
        """
        self._require_preprocessed_dicoms()
        self._set_raw_labels(nodule_boxes)

    def save_standard_npy(self, npy_path, uid):
        """
        Save *_standard data in numpy format.
        npy_path: str, path to save.
        uid: str, prefix for the files.

        Labels outside the volume are logged as errors but still saved.
        When no labels have been loaded an empty label array is written.
        """
        if not self._dicoms_is_preprocessed:
            raise Exception('DICOM files have not been preprocessed yet')
        check_and_makedirs(npy_path)
        np.save(os.path.join(npy_path, uid + '_standard_data.npy'),
                self._standard_data.astype(np.float32))
        np.save(os.path.join(npy_path, uid + '_standard_spacing.npy'),
                self._standard_spacing.astype(np.float32))

        # Check whether any annotation lies outside the volume.
        max_index = np.array(self._standard_data.shape) - 1
        labels = self._standard_labels
        if labels is None:
            labels = []
        standard_labels = []
        for i, nodule_box in enumerate(labels):
            inside = (0 <= nodule_box.z <= max_index[0] and
                      0 <= nodule_box.y <= max_index[1] and
                      0 <= nodule_box.x <= max_index[2])
            if not inside:
                logging.error('{}, {}, label index={} error.'.format(
                    time.strftime("%Y-%m-%d %H:%M:%S"), uid, i))
            standard_labels.append(np.array(nodule_box))
        standard_labels = np.array(standard_labels)
        np.save(os.path.join(npy_path, uid + '_standard_labels.npy'),
                standard_labels.astype(np.float32))

    def load_standard_npy(self, npy_path, uid, mmap_mode='r'):
        """
        Load *_standard data in numpy format.
        npy_path: str, path to load.
        uid: str, prefix for the files.
        mmap_mode: passed to np.load for the data array.

        Raises Exception when this object has already been preprocessed.
        """
        if self._dicoms_is_preprocessed:
            raise Exception('DICOM files have already been preprocessed')
        self._SeriesInstanceUID = uid
        self._standard_data = np.load(
            os.path.join(npy_path, uid + '_standard_data.npy'),
            mmap_mode=mmap_mode)
        self._standard_spacing = np.load(
            os.path.join(npy_path, uid + '_standard_spacing.npy'))
        standard_labels = np.load(
            os.path.join(npy_path, uid + '_standard_labels.npy'))
        self._standard_labels = []
        for row in standard_labels:
            z, y, x, diameter, is_pos = row[0:5]
            self._standard_labels.append(
                NoduleBox(z, y, x, diameter, is_pos, 1.0))
        self._standard_is_loaded = True

    def save_standard_mask_npy(self, npy_path, uid, nodule_index, mask):
        """
        Save a nodule mask.

        The mask is resampled to the standard data shape when the standard
        and raw spacings differ, cropped to the lung box and saved as uint8.
        """
        check_and_makedirs(npy_path)
        standard_mask = mask
        if np.any(self._standard_spacing != self._raw_spacing):
            standard_mask = resample_mask(mask, self._standard_data.shape)
        lung_box = self._lung_box
        standard_mask = standard_mask[lung_box[0, 0]:lung_box[0, 1],
                                      lung_box[1, 0]:lung_box[1, 1],
                                      lung_box[2, 0]:lung_box[2, 1]]
        file_name = uid + '_' + str(nodule_index) + '_standard_mask.npy'
        np.save(os.path.join(npy_path, file_name),
                standard_mask.astype(np.uint8))

    def load_standard_mask_npy(self, npy_path, uid, nodule_index,
                               mmap_mode='r'):
        """
        Load a nodule mask saved by save_standard_mask_npy.
        """
        file_name = uid + '_' + str(nodule_index) + '_standard_mask.npy'
        return np.load(os.path.join(npy_path, file_name),
                       mmap_mode=mmap_mode)


def base64_to_list(base64_str):
    """Decode a base64 image and list its non-zero pixels.

    base64_str: string whose first 22 characters are skipped before base64
        decoding (presumably a 'data:image/png;base64,' prefix - confirm).

    Returns (points, indexs, img_np):
        points: list of {'x': ..., 'y': ...} dicts, one per non-zero pixel.
        indexs: '{[x,y],[x,y],...}' string, or '' when there are no points.
        img_np: the decoded image with non-zero pixels set to 1, or None when
            the input is empty or cannot be decoded.
    """
    indexs = ''
    points = []
    img_np = None
    if base64_str:
        time_now = time.time()
        img_data = base64.b64decode(base64_str[22:])
        # frombuffer replaces the deprecated binary mode of np.fromstring.
        nparr = np.frombuffer(img_data, np.uint8)
        # Flag 0 makes OpenCV decode to a single-channel grayscale image.
        img_np = cv2.imdecode(nparr, 0) if nparr.size else None
        if img_np is None:
            # Undecodable payload: report it and return the empty result
            # instead of failing on the None image below.
            logging.warning('base64_to_list: image data could not be decoded')
            return points, indexs, img_np
        img_np[img_np != 0] = 1
        y_list, x_list = np.where(img_np != 0)
        if len(y_list) > 0:
            points = [{'x': raw_x, 'y': raw_y}
                      for raw_x, raw_y in zip(x_list, y_list)]
            indexs = '{' + ','.join(
                '[%s,%s]' % (p['x'], p['y']) for p in points) + '}'
        print('run time {:.5f}(s)'.format(time.time() - time_now))
    return points, indexs, img_np


def meta_to_list(meta, img_np=None):
    """Expand a bit-packed delineation into a list of pixel coordinates.

    meta: string containing a Python literal dict with keys 'x', 'y', 'w',
        'h' and 'delineation'; each delineation entry contributes its low
        32 bits (most significant first) to a row-major bit mask of width w
        anchored at (x, y).
    img_np: optional image; when it has more than one row, the set pixels
        are written into it as 1.  Defaults to a fresh 1x1 zero array.

    Returns (points, indexs, img_np).
    NOTE(review): unlike base64_to_list, ``indexs`` keeps its trailing comma
    and is never closed with '}'.  Left unchanged because callers may depend
    on the exact string - confirm.
    """
    if img_np is None:
        # Fresh array per call instead of a shared mutable default argument.
        img_np = np.zeros((1, 1))
    meta_dict = ast.literal_eval(meta)
    x = meta_dict['x']
    y = meta_dict['y']
    w = meta_dict['w']
    h = meta_dict['h']
    delineation = meta_dict['delineation']

    mask_str = ''.join(
        bin(((1 << 32) - 1) & int(e))[2:].zfill(32) for e in delineation)

    points = []
    parts = ['{']
    write_image = len(img_np) > 1
    for idx, bit in enumerate(mask_str):
        if bit != '1':
            continue
        x_index = x + int(idx % w)
        y_index = y + int(idx / w)
        if write_image:
            img_np[y_index][x_index] = 1
        points.append({'x': x_index, 'y': y_index})
        parts.append('[%s,%s],' % (x_index, y_index))
    indexs = ''.join(parts)
    return points, indexs, img_np