import os
import logging
import time
import numpy as np
import scipy.ndimage
import SimpleITK as sitk
import pydicom


from cls_utils.utils import check_and_makedirs

from collections import namedtuple
# Immutable record describing one nodule annotation.
# z / y / x are the centre coordinates (used as voxel indices by the helpers
# in this module), diameter is divided by the voxel spacing to obtain a
# length in voxels (so presumably shares the spacing's physical unit -- TODO
# confirm), uid and dicom_path are carried through unchanged.
NoduleBox = namedtuple('NoduleBox', ['z', 'y', 'x', 'diameter', 'uid', 'dicom_path'])

def get_cts(bundle, dicom_path):
    """Create a ``CTSeries`` and populate it from a DICOM folder.

    Args:
        bundle: Not used by this function; kept for caller compatibility.
        dicom_path: Folder passed to ``CTSeries.load_dicoms``.

    Returns:
        The loaded ``CTSeries`` instance.
    """
    series = CTSeries()
    began_at = time.time()
    # Read the DICOM files.
    series.load_dicoms(dicom_path)
    elapsed = time.time() - began_at
    print('start load dicoms time {:.5f}(s)'.format(elapsed))
    return series

def get_nodule_y_pixel_length(spacing, nodule_box):
    """Return the nodule diameter measured in voxels along the y axis.

    Args:
        spacing: Indexable voxel spacing; only ``spacing[1]`` is read.
        nodule_box: Object exposing a ``diameter`` attribute.

    Returns:
        ``diameter / spacing[1]`` rounded up to the next integer.
    """
    y_spacing = spacing[1]
    voxel_count = np.ceil(nodule_box.diameter / y_spacing)
    return int(voxel_count)


def get_diameter_pixel_length(spacing, diameter):
    """Convert a diameter into a per-axis length in voxels.

    Args:
        spacing: Indexable voxel spacing with at least three entries.
        diameter: Diameter expressed in the same unit as ``spacing``.

    Returns:
        Integer ``numpy`` array of length 3 holding
        ``ceil(diameter / spacing[i])`` for each of the first three axes.
    """
    # The builtin ``int`` replaces the ``np.int`` alias, which was removed in
    # NumPy 1.24; both select the same default integer dtype.
    return np.array(
        [int(np.ceil(diameter / spacing[axis])) for axis in range(3)],
        dtype=int,
    )


def get_nodule_rect(spacing, nodule_box):
    """Compute the voxel bounding box of a nodule.

    Args:
        spacing: Indexable voxel spacing in [z, y, x] order.
        nodule_box: Object exposing ``z``, ``y``, ``x`` and ``diameter``.

    Returns:
        Integer array of shape (3, 2). Row ``i`` holds ``[start, stop]``
        along axis ``i`` (z, y, x), where ``stop - start`` equals the
        nodule diameter in voxels along that axis.
    """
    side_lengths = get_diameter_pixel_length(spacing, nodule_box.diameter)
    center = np.array([nodule_box.z, nodule_box.y, nodule_box.x])
    # The builtin ``int`` replaces the ``np.int`` alias removed in NumPy 1.24.
    rect = np.zeros((3, 2), int)
    # Assigning into the integer array truncates a fractional centre, exactly
    # as the element-wise assignment did before.
    rect[:, 0] = center - side_lengths // 2
    rect[:, 1] = rect[:, 0] + side_lengths
    return rect


def resample(data, spacing, new_spacing=(1.0, 1.0, 1.0), order=1):
    """Resample a volume to (approximately) a new voxel spacing.

    The output shape is rounded to whole voxels, so the spacing actually
    achieved can differ slightly from ``new_spacing``; the achieved value is
    returned alongside the data.

    Args:
        data: Array to resample, or ``None``.
        spacing: Current voxel spacing, one value per axis of ``data``.
            Any sequence or array is accepted.
        new_spacing: Requested voxel spacing, one value per axis.
        order: Spline interpolation order forwarded to ``zoom``.

    Returns:
        ``None`` when ``data`` is ``None`` (note: a bare ``None``, not a
        tuple); otherwise ``(new_data, resample_spacing)``.
    """
    if data is None:
        return None

    # Convert up front so plain lists/tuples work: multiplying the shape
    # tuple by a list raises TypeError.
    shape = np.asarray(data.shape, dtype=float)
    spacing = np.asarray(spacing, dtype=float)
    new_spacing = np.asarray(new_spacing, dtype=float)

    new_shape = np.round(shape * spacing / new_spacing)
    # Spacing that results from snapping the output to whole voxels.
    resample_spacing = spacing * shape / new_shape
    resize_factor = new_shape / shape

    # ``scipy.ndimage.zoom`` is the public entry point; the
    # ``scipy.ndimage.interpolation`` namespace is deprecated.
    new_data = scipy.ndimage.zoom(data, resize_factor, mode='nearest', order=order)

    return new_data, resample_spacing


def resample_nodule_box(nodule_box, spacing, new_spacing=[1.0, 1.0, 1.0]):
    """Map a nodule centre from one voxel spacing to another.

    Args:
        nodule_box: ``NoduleBox`` to convert, or ``None``.
        spacing: Current voxel spacing in [z, y, x] order.
        new_spacing: Target voxel spacing in [z, y, x] order.

    Returns:
        ``None`` when ``nodule_box`` is ``None``; otherwise a new
        ``NoduleBox`` whose z/y/x are scaled by ``spacing / new_spacing`` and
        rounded up, with diameter, uid and dicom_path copied unchanged.
    """
    if nodule_box is None:
        return None

    old_center = (nodule_box.z, nodule_box.y, nodule_box.x)
    scaled_center = [
        int(np.ceil(coord * spacing[axis] / new_spacing[axis]))
        for axis, coord in enumerate(old_center)
    ]
    return NoduleBox(
        scaled_center[0],
        scaled_center[1],
        scaled_center[2],
        nodule_box.diameter,
        nodule_box.uid,
        nodule_box.dicom_path,
    )


def maxip(input, level_num=2):
    """Sliding maximum-intensity projection along the first axis.

    For every index ``i`` the output holds the element-wise maximum over the
    slices ``i - level_num`` .. ``i + level_num``, clipped to the bounds of
    the array.

    Args:
        input: Array whose first axis is the slice axis.
        level_num: Number of neighbouring slices taken on each side.

    Returns:
        Array with the same shape and dtype as ``input``.
    """
    projected = np.zeros(input.shape, input.dtype)
    reach_after = level_num + 1
    for idx in range(len(input)):
        # Clamp the window start at the first slice.
        first = max(idx - level_num, 0)
        window = input[first:idx + reach_after]
        projected[idx] = window.max(axis=0)
    return projected


def minip(input, level_num=2):
    """Sliding minimum-intensity projection along the first axis.

    For every index ``i`` the output holds the element-wise minimum over the
    slices ``i - level_num`` .. ``i + level_num``, clipped to the bounds of
    the array.

    Args:
        input: Array whose first axis is the slice axis.
        level_num: Number of neighbouring slices taken on each side.

    Returns:
        Array with the same shape and dtype as ``input``.
    """
    projected = np.zeros(input.shape, input.dtype)
    reach_after = level_num + 1
    for idx in range(len(input)):
        # Clamp the window start at the first slice.
        first = max(idx - level_num, 0)
        window = input[first:idx + reach_after]
        projected[idx] = window.min(axis=0)
    return projected


class CTSeries(object):
    """CT volume loaded from a DICOM folder or from a single image file.

    The voxel array, origin and spacing are stored in [z, y, x] order (the
    reverse of what SimpleITK reports); the direction is stored exactly as
    SimpleITK returns it.
    """

    def __init__(self):
        # SeriesInstanceUID of the first DICOM file read by load_dicoms.
        self._SeriesInstanceUID = None
        # SOPInstanceUID of every slice, in the sorted slice order.
        self._SOPInstanceUIDs = None
        # Voxel data and geometry; origin and spacing are in [z, y, x] order.
        self._raw_image = None
        self._raw_origin = None
        self._raw_spacing = None
        self._raw_direction = None
        # Set once load_dicoms or load_single_file has completed.
        self._dicoms_is_loaded = False

    @staticmethod
    def _locate_slices(dicoms):
        """Return ``(kept_indices, z_positions)`` for a list of datasets.

        Tries ``ImagePositionPatient[2]`` for every dataset, then
        ``SliceLocation`` for every dataset. If neither is available on all
        of them, falls back to ``ImagePositionPatient[2]`` per dataset and
        leaves out (and prints) the ones that lack it. ``kept_indices`` lists
        the positions in ``dicoms`` that ``z_positions`` refers to, so
        callers can keep datasets and file paths aligned.
        """
        everything = list(range(len(dicoms)))
        try:
            return everything, [float(ds.ImagePositionPatient[2]) for ds in dicoms]
        except AttributeError:
            pass
        try:
            return everything, [float(ds.SliceLocation) for ds in dicoms]
        except AttributeError:
            pass

        kept, positions = [], []
        for i, ds in enumerate(dicoms):
            try:
                z = float(ds.ImagePositionPatient[2])
            except AttributeError:
                print(i, ds.SeriesInstanceUID)
                continue
            kept.append(i)
            positions.append(z)
        return kept, positions

    def _store_image(self, image_itk):
        """Copy voxel data and geometry out of a SimpleITK image."""
        # all in [z, y, x] order
        self._raw_image = sitk.GetArrayFromImage(image_itk)
        self._raw_origin = np.array(list(reversed(image_itk.GetOrigin())))
        self._raw_spacing = np.array(list(reversed(image_itk.GetSpacing())))
        self._raw_direction = image_itk.GetDirection()
        self._dicoms_is_loaded = True

    def load_dicoms(self, folder_path):
        """Load a DICOM series from ``folder_path``.

        Every directory entry whose name does not contain ``'.xml'`` is read
        as a DICOM file. Slices are ordered by decreasing z position before
        being handed to SimpleITK.

        Args:
            folder_path: Directory containing the DICOM files.

        Raises:
            IndexError: If the folder contains no candidate files.
            RuntimeError: If no file provides a usable slice position.
        """
        logging.info('{}, Loading dicoms from {}...'.format(
            time.strftime('%Y-%m-%d %H:%M:%S'), folder_path))

        dicom_names = [f for f in os.listdir(folder_path) if '.xml' not in f]
        dicom_paths = [os.path.join(folder_path, name) for name in dicom_names]
        # Only header tags are needed here; SimpleITK reads the pixel data
        # below. ``dcmread`` replaces the deprecated ``read_file`` alias.
        dicoms = [pydicom.dcmread(path, stop_before_pixels=True)
                  for path in dicom_paths]

        if not dicoms:
            # Same exception type the unguarded dicoms[0] lookup raised.
            raise IndexError('no DICOM files found in {}'.format(folder_path))

        self._SeriesInstanceUID = dicoms[0].SeriesInstanceUID

        kept, slice_locations = self._locate_slices(dicoms)
        if not kept:
            raise RuntimeError(
                'no slice position found for any DICOM file in {}'.format(folder_path))

        # Sort slices by their z coordinates from large to small.
        # NOTE(review): the original code branched on PatientPosition
        # ('FFP'/'FFS' vs. other) but both branches used this same descending
        # order -- confirm whether one orientation should be ascending.
        idx_z_sorted = np.argsort(slice_locations)[::-1]

        # Reorder datasets and paths together. Indices go through ``kept`` so
        # they stay aligned when some files were left out for lacking a
        # position.
        sorted_dicoms = [dicoms[kept[i]] for i in idx_z_sorted]
        sorted_paths = [str(dicom_paths[kept[i]]) for i in idx_z_sorted]

        self._SOPInstanceUIDs = np.array([ds.SOPInstanceUID for ds in sorted_dicoms])

        reader = sitk.ImageSeriesReader()
        reader.SetFileNames(sorted_paths)
        self._store_image(reader.Execute())

    def load_single_file(self, file_path):
        """Load a volume from a single image file readable by SimpleITK.

        Args:
            file_path: Path passed to ``sitk.ReadImage``.
        """
        logging.info('{}, Loading file from {}...'.format(
            time.strftime('%Y-%m-%d %H:%M:%S'), file_path))

        self._store_image(sitk.ReadImage(file_path))

    def get_raw_image(self):
        """Return the voxel array ([z, y, x] order), or None if not loaded."""
        return self._raw_image

    def set_raw_image(self, data):
        """Replace the stored voxel array; the geometry is left untouched."""
        self._raw_image = data

    def get_raw_origin(self):
        """Return the origin in [z, y, x] order, or None if not loaded."""
        return self._raw_origin

    def get_raw_spacing(self):
        """Return the voxel spacing in [z, y, x] order, or None if not loaded."""
        return self._raw_spacing

    def get_raw_image_affine(self):
        """Return a 4x4 matrix built from the stored spacing and origin.

        The spacing fills the diagonal and the origin fills the first three
        rows of the last column, both in [z, y, x] order. The stored
        direction is not included.
        """
        affine = np.diag(list(self._raw_spacing) + [1])
        affine[:3, 3] = np.array(list(self._raw_origin))
        return affine

    def save_raw_image(self, output_file):
        """Write the stored voxel array to ``output_file``.

        Does nothing when no image has been loaded.
        """
        if self._dicoms_is_loaded:
            self.save_image(self._raw_image, output_file)

    def save_image(self, data, output_file):
        """Write ``data`` to ``output_file`` using the stored geometry.

        Does nothing when no image has been loaded.

        Args:
            data: Array in [z, y, x] order to write.
            output_file: Destination path passed to ``sitk.WriteImage``.
        """
        if not self._dicoms_is_loaded:
            return
        image = sitk.GetImageFromArray(data)
        # SimpleITK expects [x, y, z]; hand it plain Python floats.
        image.SetSpacing([float(v) for v in reversed(self._raw_spacing)])
        image.SetOrigin([float(v) for v in reversed(self._raw_origin)])
        image.SetDirection(self._raw_direction)
        # Project helper; presumably creates the destination directory --
        # TODO confirm against cls_utils.utils.
        check_and_makedirs(output_file)
        sitk.WriteImage(image, output_file)

    def transform_file_type(self, input_file, output_file):
        """Read ``input_file`` and write it back out as ``output_file``.

        Does not read or modify the instance's state.
        """
        image = sitk.ReadImage(input_file)
        sitk.WriteImage(image, output_file)