import json
import os
import sys
import traceback
from multiprocessing import Pool, cpu_count
import SimpleITK as sitk
from pydicom import dicomio
import lmdb
import numpy as np
from datetime import datetime
from resample import ItkResample

BASE_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
# Make the directory two levels above this file importable.
sys.path.append(BASE_DIR)


class MyEncoder(json.JSONEncoder):
    """JSON encoder that also handles NumPy scalars, NumPy arrays and datetimes."""

    def default(self, obj):
        """Convert *obj* to a JSON-serializable value, deferring to the base class otherwise."""
        if isinstance(obj, np.integer):
            return int(obj)
        if isinstance(obj, np.floating):
            return float(obj)
        if isinstance(obj, np.ndarray):
            return obj.tolist()
        if isinstance(obj, datetime):
            return obj.__str__()
        return super(MyEncoder, self).default(obj)


def load_ct_from_dicom(dcm_path, sort_by_distance=True):
    """Read a DICOM series from a directory into a SimpleITK image.

    :param dcm_path: directory containing the DICOM files of one series
    :param sort_by_distance: when True, every file in ``dcm_path`` is read with
        pydicom, files repeating an already-seen SOPInstanceUID are dropped,
        and the remaining slices are ordered by their position projected on
        the slice normal. When False, the file order is delegated to
        ``ImageSeriesReader.GetGDCMSeriesFileNames``.
    :return: the image produced by ``sitk.ImageSeriesReader``
    :raises AttributeError: if a file lacks one of the DICOM attributes read
        below (only when ``sort_by_distance`` is True)
    """

    class DcmInfo(object):
        """Header fields of one DICOM file plus its sort key."""

        def __init__(self, dcm_path, series_instance_uid, acquisition_number,
                     sop_instance_uid, instance_number,
                     image_orientation_patient, image_position_patient):
            super(DcmInfo, self).__init__()
            self.dcm_path = dcm_path
            self.series_instance_uid = series_instance_uid
            self.acquisition_number = acquisition_number
            self.sop_instance_uid = sop_instance_uid
            self.instance_number = instance_number
            self.image_orientation_patient = image_orientation_patient
            self.image_position_patient = image_position_patient
            self.slice_distance = self._cal_distance()

        def _cal_distance(self):
            """Return the slice position projected onto the slice normal.

            The normal is the cross product of the first three and the last
            three components of ImageOrientationPatient.
            """
            iop = self.image_orientation_patient
            ipp = self.image_position_patient
            normal = (
                iop[1] * iop[5] - iop[2] * iop[4],
                iop[2] * iop[3] - iop[0] * iop[5],
                iop[0] * iop[4] - iop[1] * iop[3],
            )
            return sum(normal[axis] * ipp[axis] for axis in range(3))

    reader = sitk.ImageSeriesReader()
    if sort_by_distance:
        dcm_infos = []
        # A set gives O(1) duplicate detection instead of rescanning the list
        # of already accepted slices for every file.
        seen_sop_uids = set()
        for file_name in os.listdir(dcm_path):
            file_path = os.path.join(dcm_path, file_name)
            dcm = dicomio.read_file(file_path, force=True)
            dcm_info = DcmInfo(
                file_path,
                dcm.SeriesInstanceUID,
                dcm.AcquisitionNumber,
                dcm.SOPInstanceUID,
                dcm.InstanceNumber,
                dcm.ImageOrientationPatient,
                dcm.ImagePositionPatient,
            )
            if dcm_info.sop_instance_uid in seen_sop_uids:
                continue
            seen_sop_uids.add(dcm_info.sop_instance_uid)
            dcm_infos.append(dcm_info)
        dcm_infos.sort(key=lambda item: item.slice_distance)
        dcm_series = [item.dcm_path for item in dcm_infos]
    else:
        dcm_series = reader.GetGDCMSeriesFileNames(dcm_path)
    reader.SetFileNames(dcm_series)
    sitk_image = reader.Execute()
    return sitk_image


def _crop_window(center, size, extent):
    """Return ``(start, stop)`` of a ``size``-long window around ``center``.

    The window is shifted as needed to stay inside ``[0, extent]``, so its
    length is always exactly ``size``. Requires ``size <= extent``.
    """
    # (size + 1) // 2 is ceil(size / 2): the start is floor(center - size / 2).
    start = center - (size + 1) // 2
    start = max(0, min(start, extent - size))
    return start, start + size


class Trans_Dicom_2_NPY(object):
    """Cut fixed-size image/mask patches around every label of a series."""

    def __init__(self, in_path, out_path, datainfos):
        '''
        :param in_path: directory holding one sub-directory of DICOM files per uid
        :param out_path: directory receiving one sub-directory of .npy files per uid
        :param datainfos: sequence of records; for each record ``info``,
            ``info[-2][0]`` is a '/'-separated path whose second-to-last
            component is used as the uid, and ``info[-1]`` is an iterable of
            paths among which the first one containing 'nii.gz' is the mask
        '''
        self.in_path = in_path
        self.out_path = out_path
        self.datainfos = datainfos
        self.itk_resample = ItkResample()
        # Isotropic spacing that both image and mask are resampled to.
        self.target_spacing = 1
        self.suid_list = [info[-2][0].split('/')[-2] for info in datainfos]
        os.makedirs(self.out_path, exist_ok=True)

    def __call__(self):
        """Process every uid directory found in ``in_path``, one after another."""
        uids = os.listdir(self.in_path)
        print(f'num cpus: {cpu_count()}')
        for uid in uids:
            # NOTE(review): nothing visible in this file writes
            # '<uid>_data.npy' / '<uid>_info.npy' into out_path (patches go to
            # out_path/<uid>/), so this guard appears to never skip a uid —
            # confirm before relying on it for resuming.
            data_done = os.path.exists(os.path.join(self.out_path, uid + '_data.npy'))
            info_done = os.path.exists(os.path.join(self.out_path, uid + '_info.npy'))
            if not (data_done and info_done):
                self._single_transform(uid)

    def _single_transform(self, uid):
        """Resample one series and its mask, then save a patch per mask label.

        For every distinct value in the resampled mask, a cube of edge
        ``min(32, smallest image dimension)`` centred on the median voxel of
        that label is cut from image and mask and written to
        ``out_path/<uid>/<label>_{data,mask,info}.npy`` (info is the spacing).
        Series that are not listed in ``datainfos``, have no 'nii.gz' mask
        path, or cannot be read are skipped.

        :param uid: name of the series sub-directory inside ``in_path``
        """
        print('Processing series uid %s' % uid)

        # Resolve the mask first: it is cheap and lets us skip the expensive
        # DICOM load and B-spline resampling for series without a mask.
        try:
            work_index = self.suid_list.index(uid)
        except ValueError:
            # An unlisted directory must not abort the whole batch.
            print('!!!!! uid %s is not listed in datainfos, skipped.' % uid)
            return
        mask_path = next(
            (path for path in self.datainfos[work_index][-1] if 'nii.gz' in path),
            None)
        if mask_path is None:
            return

        dcm_folder = os.path.join(self.in_path, uid)
        try:
            itk_image = load_ct_from_dicom(dcm_folder)
        except Exception as err:
            print('!!!!! Read %s throws exception %s.' % (uid, err))
            return

        target_spacing = tuple([self.target_spacing] * 3)

        print('Resample uid %s to spacing %.1f.' % (uid, self.target_spacing))
        itk_image_resample = self.itk_resample.resample_to_spacing(
            itk_image=itk_image,
            target_spacing=target_spacing,
            interpolator=sitk.sitkBSpline)
        itk_image_resample.SetOrigin(itk_image.GetOrigin())
        itk_image_resample.SetSpacing([self.target_spacing] * 3)
        spacing = itk_image_resample.GetSpacing()
        array_image = sitk.GetArrayFromImage(itk_image_resample)

        print(f'mask path {mask_path}')
        mask = sitk.ReadImage(mask_path)
        # Label maps must be resampled with nearest-neighbour interpolation:
        # B-spline blends and overshoots label values, which creates label ids
        # that do not exist in the original mask.
        itk_mask_resample = self.itk_resample.resample_to_spacing(
            itk_image=mask,
            target_spacing=target_spacing,
            interpolator=sitk.sitkNearestNeighbor)
        itk_mask_resample.SetOrigin(itk_image.GetOrigin())
        itk_mask_resample.SetSpacing([self.target_spacing] * 3)
        spacing_mask = itk_mask_resample.GetSpacing()
        array_mask = sitk.GetArrayFromImage(itk_mask_resample)
        print(f'img shape: {array_image.shape}, spacing {spacing}')
        print(f'mask shape: {array_mask.shape}, mask spacing {spacing_mask}')

        # Cut one patch per mask id.
        # NOTE(review): every distinct mask value gets a patch, including 0 if
        # present — confirm that a background patch is wanted.
        mask_ids = np.unique(array_mask)
        cut_size = min(32, min(array_image.shape))
        save_dir = os.path.join(self.out_path, uid)
        os.makedirs(save_dir, exist_ok=True)

        for mask_id in mask_ids:
            try:
                mask_i = np.where(array_mask == mask_id, 1, 0)
                zs, ys, xs = np.nonzero(mask_i)
                print(f'cut coords: {zs.min()}, {zs.max()}, {ys.min()}, {ys.max()}, {xs.min()}, {xs.max()}')

                # Centre the patch on the median voxel of the label and clamp
                # it to the image extent.
                # NOTE(review): bounds are clamped against the image shape and
                # applied to the mask as well — assumes both arrays have the
                # same shape; confirm.
                window = tuple(
                    slice(*_crop_window(int(np.median(coords)), cut_size, extent))
                    for coords, extent in zip((zs, ys, xs), array_image.shape))
                mask_i = mask_i[window]
                img_i = array_image[window]
                print(f'img_i shape: {img_i.shape}, mask_i shape: {mask_i.shape}')

                save_mask_path = os.path.join(save_dir, str(mask_id) + '_mask.npy')
                save_img_path = os.path.join(save_dir, str(mask_id) + '_data.npy')
                save_info_path = os.path.join(save_dir, str(mask_id) + '_info.npy')
                np.save(save_mask_path, mask_i)
                np.save(save_img_path, img_i)
                np.save(save_info_path, np.array(spacing))
            except Exception as err:
                # Best effort per label: report the cause and continue with the
                # next label, but let KeyboardInterrupt/SystemExit propagate.
                print(f'mask id {mask_id} cutting failed: {err}')