# quality_control.py
import SimpleITK as sitk
from pydicom import dicomio
import os
import numpy as np
import glob
from tqdm import tqdm
import argparse
import SimpleITK as sitk

def load_ct_from_dicom(dcm_path, sort_by_distance=True):
    """Read a directory of DICOM slices into one SimpleITK volume.

    Args:
        dcm_path: Directory holding the DICOM files of the series.
        sort_by_distance: If True, every regular file in ``dcm_path`` is opened
            with pydicom. Files whose SOPInstanceUID was already seen are
            skipped (the first one encountered is kept). The remaining slices
            are ordered by the projection of ImagePositionPatient onto the
            slice normal (row cosines x column cosines from
            ImageOrientationPatient). If False, the file list comes from
            ``ImageSeriesReader.GetGDCMSeriesFileNames``.

    Returns:
        The ``sitk.Image`` produced by ``ImageSeriesReader.Execute``.

    Raises:
        AttributeError: (sort_by_distance=True) if a file lacks
            SOPInstanceUID, ImageOrientationPatient or ImagePositionPatient.

    NOTE(review): with sort_by_distance=True no SeriesInstanceUID filtering is
    done, so a folder mixing several series is read as one volume.
    """
    class DcmInfo(object):
        """Header fields of one file needed to order and identify a slice."""

        def __init__(self, dcm_path, series_instance_uid, acquisition_number, sop_instance_uid, instance_number,
                     image_orientation_patient, image_position_patient):
            super(DcmInfo, self).__init__()

            self.dcm_path = dcm_path
            self.series_instance_uid = series_instance_uid
            self.acquisition_number = acquisition_number
            self.sop_instance_uid = sop_instance_uid
            self.instance_number = instance_number
            self.image_orientation_patient = image_orientation_patient
            self.image_position_patient = image_position_patient

            self.slice_distance = self._cal_distance()

        def _cal_distance(self):
            """Return ImagePositionPatient projected onto the slice normal."""
            row = self.image_orientation_patient[:3]
            col = self.image_orientation_patient[3:6]
            # Slice normal = row direction cosines x column direction cosines.
            normal = [row[1] * col[2] - row[2] * col[1],
                      row[2] * col[0] - row[0] * col[2],
                      row[0] * col[1] - row[1] * col[0]]
            return sum(normal[i] * self.image_position_patient[i] for i in range(3))

    reader = sitk.ImageSeriesReader()
    if sort_by_distance:
        dcm_infos = []
        # Set membership is O(1); the previous linear rescan made this loop O(n^2).
        seen_sop_uids = set()

        for file in os.listdir(dcm_path):
            file_path = os.path.join(dcm_path, file)
            if not os.path.isfile(file_path):
                # Sub-directories cannot be parsed as DICOM; skip instead of crashing.
                continue

            dcm = dicomio.read_file(file_path, force=True)
            sop_instance_uid = dcm.SOPInstanceUID
            if sop_instance_uid in seen_sop_uids:
                continue
            seen_sop_uids.add(sop_instance_uid)

            # Only SOPInstanceUID and the two geometry tags are required; the
            # remaining fields are informational, so a missing tag is tolerated.
            dcm_infos.append(DcmInfo(file_path,
                                     dcm.get('SeriesInstanceUID'),
                                     dcm.get('AcquisitionNumber'),
                                     sop_instance_uid,
                                     dcm.get('InstanceNumber'),
                                     dcm.ImageOrientationPatient,
                                     dcm.ImagePositionPatient))

        dcm_infos.sort(key=lambda info: info.slice_distance)
        dcm_series = [info.dcm_path for info in dcm_infos]
    else:
        dcm_series = reader.GetGDCMSeriesFileNames(dcm_path)

    reader.SetFileNames(dcm_series)
    return reader.Execute()

# Modalities accepted as CT.
_CT_MODALITIES = ('CT', 'CTSR', 'SRCT')
# Case-insensitive substrings, at least one of which must occur in a description tag.
_CHEST_KEYWORDS = ('chest', 'thorax', 'lung', 'nodule')
# Header tags searched for the chest keywords.
_DESCRIPTION_TAGS = ('StudyDescription', 'SeriesDescription', 'BodyPartExamined', 'ProtocolName')
# Accepted square in-plane matrix sizes (Rows == Columns == one of these).
_ACCEPTED_MATRIX_SIZES = (512, 1024)
# Expected ImageOrientationPatient (row cosines then column cosines), compared after rounding.
_AXIAL_ORIENTATION = (1, 0, 0, 0, 1, 0)


def _check_header(dcm):
    """Apply the header-level QC rules to one DICOM dataset.

    Args:
        dcm: Any object with a dict-like ``get(keyword, default=None)``, such as
            a pydicom Dataset.

    Returns:
        The first failing rule's code, checked in this order: 'NOT_CT',
        'MISS_SPECFIC_STR', 'MISS_SLICETHICKNESS', 'ROW_COLUMS',
        'IMAGEORIENTATIONPATIENT_CAUSE'. Returns 'SUCCESS' if all pass.
        (The code strings are kept verbatim, typos included; callers may
        compare against them.)
    """
    if dcm.get('Modality', '') not in _CT_MODALITIES:
        return 'NOT_CT'

    descriptions = [str(dcm.get(tag) or '').lower() for tag in _DESCRIPTION_TAGS]
    if not any(keyword in text for text in descriptions for keyword in _CHEST_KEYWORDS):
        return 'MISS_SPECFIC_STR'

    # A missing/empty tag used to raise TypeError ('' <= 0); treat it as missing.
    try:
        slice_thickness = float(dcm.get('SliceThickness'))
    except (TypeError, ValueError):
        return 'MISS_SLICETHICKNESS'
    if not slice_thickness > 0:  # also rejects NaN
        return 'MISS_SLICETHICKNESS'

    rows, columns = dcm.get('Rows'), dcm.get('Columns')
    if rows != columns or rows not in _ACCEPTED_MATRIX_SIZES:
        return 'ROW_COLUMS'

    # load_ct_from_dicom needs this tag, so a missing value is a failure here.
    try:
        orientation = [float(value) for value in dcm.get('ImageOrientationPatient')]
    except (TypeError, ValueError):
        return 'IMAGEORIENTATIONPATIENT_CAUSE'
    # Round to the nearest integer rather than truncating: int(0.99999) == 0
    # wrongly rejected axial series stored with floating-point noise.
    if (len(orientation) != len(_AXIAL_ORIENTATION)
            or any(round(v) != e for v, e in zip(orientation, _AXIAL_ORIENTATION))):
        return 'IMAGEORIENTATIONPATIENT_CAUSE'

    return 'SUCCESS'


def control_single(target_path):
    """Quality-check one series folder.

    The folder must contain at least 30 entries. The header of one file (the
    first name in sorted order, read without pixel data) must then pass
    ``_check_header``.

    Args:
        target_path: Directory holding the series' DICOM files.

    Returns:
        'LAYERCOUNT_LESS' if fewer than 30 entries are present; otherwise the
        result of ``_check_header`` ('SUCCESS' or a failure code).
    """
    dcms = sorted(os.listdir(target_path))
    if len(dcms) < 30:
        return 'LAYERCOUNT_LESS'

    # Only header tags are inspected, so skip loading the pixel data.
    dcm = dicomio.read_file(os.path.join(target_path, dcms[0]), force=True,
                            stop_before_pixels=True)
    return _check_header(dcm)

def crop_lung(job_data_root):
    """Quality-check every series folder and save a fixed-ratio crop of those that pass.

    Args:
        job_data_root: Job directory. Series folders are read from
            ``<job_data_root>/output/tmp``; entries whose name contains
            ``'rpn'`` are skipped.

    Side effects:
        Prints ``<folder> <qc result>`` for each folder. For folders whose QC
        result is 'SUCCESS', writes the cropped volume (with the source
        spacing/direction and a correspondingly shifted origin) to
        ``<job_data_root>/output/preprocess/qc/<folder>.nii.gz``.
    """
    # Fraction of each axis kept by the crop (numpy axis order is z, y, x).
    # NOTE(review): empirically chosen constants; their origin is not visible here.
    z_min_ratio, z_max_ratio = 0.1330, 0.9143
    y_min_ratio, y_max_ratio = 0.2685, 0.7606
    x_min_ratio, x_max_ratio = 0.1491, 0.8442

    tmp_root = os.path.join(job_data_root, 'output/tmp')
    save_dir = os.path.join(os.path.dirname(tmp_root), 'preprocess', 'qc')

    for file in [name for name in os.listdir(tmp_root) if 'rpn' not in name]:
        path = os.path.join(tmp_root, file)
        result = control_single(path)
        print(file, result)
        if result != 'SUCCESS':
            continue

        sitk_img = load_ct_from_dicom(path)
        np_img = sitk.GetArrayFromImage(sitk_img)
        z, y, x = np_img.shape
        z0, z1 = int(z_min_ratio * z), int(z_max_ratio * z)
        y0, y1 = int(y_min_ratio * y), int(y_max_ratio * y)
        x0, x1 = int(x_min_ratio * x), int(x_max_ratio * x)

        cropped = sitk.GetImageFromArray(np_img[z0:z1, y0:y1, x0:x1])
        # GetImageFromArray produces unit spacing, zero origin and identity
        # direction; restore the source geometry so the crop keeps real voxel
        # sizes and stays in the scanner's physical frame.
        cropped.SetSpacing(sitk_img.GetSpacing())
        cropped.SetDirection(sitk_img.GetDirection())
        # SimpleITK indices are (x, y, z), the reverse of the numpy axis order.
        cropped.SetOrigin(sitk_img.TransformIndexToPhysicalPoint((x0, y0, z0)))

        os.makedirs(save_dir, exist_ok=True)
        sitk.WriteImage(cropped, os.path.join(save_dir, file + '.nii.gz'))

if __name__ == '__main__':
    # Command-line entry point: run QC + cropping on one job directory.
    cli = argparse.ArgumentParser(description='PyTorch qc')
    cli.add_argument('--job_data_root', type=str, default='/data/job_715/job_data_preprocess')
    crop_lung(cli.parse_args().job_data_root)