import os
import numpy as np

"""
Describe:
    根据传过来的中心点对数据进行切割
Parameters:
    data: 所有的切面数据
    crop_start: 在切面中选择出来的大小为(48, 256, 256)区域的三个维度的中心点的坐标
"""
def get_crop_data(data, crop_start, crop_size, mode='constant', constant_values=-1000):
    if data is None:
        return data
    data_shape = np.array(data.shape)
    crop_data = data[
        max(crop_start[0], 0):min(crop_start[0] + crop_size[0], data_shape[0]),
        max(crop_start[1], 0):min(crop_start[1] + crop_size[1], data_shape[1]),
        max(crop_start[2], 0):min(crop_start[2] + crop_size[2], data_shape[2])]
    
    # pad = np.zeros((3, 2), np.int)
    pad = np.zeros((3, 2), np.int32)
    #np.maximum(x1, x2) 用于逐个比较其中元素的大小,最后返回的是较大的形成的数据
    #往前需要填充多少数据
    pad[:, 0] = np.maximum(-crop_start, 0)
    pad[:, 1] = np.maximum(crop_start + crop_size - data_shape, 0)

    if not np.all(pad == 0):
        #pad表示在每个维度上要填充的长度
        if mode == 'edge':
            crop_data = np.pad(crop_data, pad, mode=mode)
        else:
            crop_data = np.pad(crop_data, pad, mode=mode, constant_values=constant_values)
    
    return crop_data

def get_crop_data_padding_2d_opt(ct_data=None, select_box=None, crop_size=None, mode='constant', constant_values=-1000):
    # 获取原始数据
    data = ct_data.get_raw_image()
    data_shape = np.array(data.shape)
    crop_size = np.array(crop_size)

    crop_start = np.int0((select_box[:, 0] + select_box[:, 1])//2 - crop_size // 2)

    crop_data = get_crop_data(data=data, 
                             crop_start=crop_start, 
                             crop_size=crop_size, 
                             mode=mode, 
                             constant_values=constant_values)
    
    return crop_data



"""
只保留一层,将该层数据放置中心点,别的层都填充为-1000
crop_size=[48, 256, 256]
"""
def get_crop_data_padding(ct_data=None, select_box=None, crop_size=None, mode='constant', constant_values=-1000):
    data = ct_data.get_raw_image()
    data_shape = np.array(data.shape)
    crop_size = np.array(crop_size)
    z_index = int(select_box[0][0])

    crop_start = np.int0((select_box[1:, 0] + select_box[1:, 1])//2 - crop_size[1:] // 2)
    crop_data = data[z_index:z_index+1,
                     max(crop_start[0], 0) : min(crop_start[0] + crop_size[1], data_shape[1]),
                     max(crop_start[1], 0) : min(crop_start[1] + crop_size[2], data_shape[2])]

    # pad = np.zeros((3, 2), np.int)
    pad = np.zeros((3, 2), np.int32)

    pad[1:, 0] = np.maximum(-crop_start, 0)
    pad[1:, 1] = np.maximum(crop_start + crop_size[1:] - data_shape[1:], 0)

    if not np.all(pad == 0):
        crop_data = np.pad(crop_data, pad, mode=mode, constant_values=constant_values)
    
    result_data = np.full(crop_size, -1000)
    result_data[crop_size[0]//2] = crop_data

    return result_data


#从当前(1024, 1024)的图片数据中获取到指定坐标区域的数据,如果所在区域不够(256, 256),则需要相对应的补充-1000
#select_box=[[z_index, z_index], [y_min, y_max], [x_min, x_max]]
#crop_size=[256. 256]
def get_crop_data_2d(data=None, select_box=None, crop_size=None, mode='constant', constant_values=-1000):
    data_shape = np.array(data.shape)
    crop_size = np.array(crop_size)

    crop_start = np.int0((select_box[1:, 0] + select_box[1:, 1])//2 - crop_size[:]//2)
    crop_data = data[max(crop_start[0], 0) : min(crop_start[0] + crop_size[0], data_shape[0]),
                     max(crop_start[1], 0) : min(crop_start[1] + crop_size[1], data_shape[1])]
    
    
    # pad = np.zeros((2, 2), np.int)
    pad = np.zeros((2, 2), np.int32)
    
    pad[:, 0] = np.maximum(-crop_start, 0)
    pad[:, 1] = np.maximum(crop_start + crop_size[:] - data_shape[:], 0)

    if not np.all(pad == 0):
        crop_data = np.pad(crop_data, pad, mode=mode, constant_values=constant_values)

    """result_data = np.full(crop_size, -1000)
    result_data[crop_size[0]//2] = crop_data"""
    result_data = crop_data

    return result_data
    

'''
#将对切面数据继续处理,返回(48, 256, 256)大小的数据
def crop_ct_data(ct_data=None, select_box=None, crop_size=None):
    data = ct_data.get_raw_image()
    #print(data.shape)
    spacing = ct_data.get_raw_spacing()

    z_center = int((select_box[0, 0] + select_box[0, 1]) // 2)
    
    crop_size = np.array(crop_size)
    #以下操作包含原box的情况下,并z_center处于中间位置
    #z_diff=crop_size[0] - (select_box[0, 1] - select_box[0, 0])
    z_diff = crop_size[0] - (select_box[0, 1] - select_box[0, 0]) - 1
    #所选择的层数小于48
    if z_diff > 0:
        #新增前面层数
        add_front_z = 0
        #新增后面层数
        add_behind_z = 0
        #select_box所对应的中间层距离z_min相差的层数
        z_front_num = z_center - select_box[0, 0]
        #select_box所对应的中间层距离z_max相差的层数
        z_behind_num = select_box[0, 1] - z_center
        #这里add_front_z为1或0
        add_front_z = min(z_behind_num - z_front_num, z_diff)
        #当z_diff != 1 时
        if add_front_z < z_diff:
            add_front_z = add_front_z + (z_diff - add_front_z) // 2
            add_behind_z = z_diff - add_front_z
        else:
            add_front_z = add_front_z
            add_behind_z = 0
        select_box[0, 0] = select_box[0, 0] - add_front_z
        select_box[0, 1] = select_box[0, 1] + add_behind_z
    scale = (1, 1, 1)
    #scale = (1, 1, 1) if spacing[0] > 0.5 else (2, 1, 1)
    crop_size = np.int0(crop_size * scale)
    crop_start = np.int0((select_box[:, 0] + select_box[:, 1])//2 - crop_size // 2)
    #根据尺寸大小对数据进行切割
    original_data = get_crop_data(data=data, crop_start=crop_start, crop_size=crop_size)
    return original_data
'''

def crop_ct_data(ct_data=None, select_box=None, crop_size=None):
    data = ct_data.get_raw_image()
    #print(data.shape)
    spacing = ct_data.get_raw_spacing()

    z_center = int((select_box[0, 0] + select_box[0, 1]) // 2)
    
    crop_size = np.array(crop_size)
    #以下操作包含原box的情况下,并z_center处于中间位置
    #z_diff=crop_size[0] - (select_box[0, 1] - select_box[0, 0])
    z_diff = crop_size[0] - (select_box[0, 1] - select_box[0, 0])
    #所选择的层数小于48
    if z_diff > 0:
        #新增前面层数
        add_front_z = 0
        #新增后面层数
        add_behind_z = 0
        #select_box所对应的中间层距离z_min相差的层数
        z_front_num = z_center - select_box[0, 0]
        #select_box所对应的中间层距离z_max相差的层数
        z_behind_num = select_box[0, 1] - z_center
        if z_behind_num >= z_front_num > 0:
            add_front_z = min(z_behind_num - z_front_num, z_diff)
            if add_front_z < z_diff:
                add_front_z = add_front_z + (z_diff - add_front_z) // 2
                add_behind_z = z_diff - add_front_z
        elif z_front_num > z_behind_num > 0:
            add_behind_z = min(z_front_num - z_behind_num, z_diff)
            if add_behind_z < z_diff:
                add_behind_z = add_behind_z + (z_diff - add_behind_z) // 2
                add_front_z = z_diff - add_behind_z
        select_box[0, 0] = select_box[0, 0] - add_front_z
        select_box[0, 1] = select_box[0, 1] + add_behind_z
    scale = (1, 1, 1)
    #scale = (1, 1, 1) if spacing[0] > 0.5 else (2, 1, 1)
    crop_size = np.int0(crop_size * scale)
    crop_start = np.int0((select_box[:, 0] + select_box[:, 1])//2 - crop_size // 2)

    #根据尺寸大小对数据进行切割
    original_data = get_crop_data(data=data, crop_start=crop_start, crop_size=crop_size)
    return original_data

#将自动分割出来的数据都将其每个切面都作为中心面,然后其余地方进行填充成(48, 256, 256)的数据进行预测