import os
import sys
import copy
import math
import time
import glob
import random
import datetime
import numpy as np
import SimpleITK as sitk
from time import time
from scipy.ndimage import zoom
from skimage.draw import line_aa
# from matplotlib.lines import Line2D 
# from matplotlib import pyplot as plt
from scipy.ndimage import gaussian_filter
from skimage.measure import label,regionprops
from collections import defaultdict,OrderedDict
from scipy.ndimage.measurements import center_of_mass
from skimage.morphology import binary_erosion,binary_dilation
from scipy.ndimage import distance_transform_edt as distance
from .respacing_func import *

def Image_preprocess(image,deltas,ratio=1.0,input_shape=[64,64,64,1]):
    '''
    Crop a (possibly shifted) center patch from a 3D image and zoom it to the network input size.

    1. crop a center patch of size round(ratio * input_shape[:3]), shifted by ``deltas``
    2. linearly zoom (order=1) that patch to input_shape
    @image: 3D array (depth, height, width).
    @deltas: int sequence of length 3. Translation of the crop window along (z, y, x).
    @ratio: float around 1.0. Scales the crop window relative to input_shape.
    @input_shape: int sequence of length >= 3. First 3 elements are (depth, height, width) of the network input.
    @Return: final_image. 5D array of shape (1, *input_shape[:3], 1).
    '''
    center = [int(val / 2) for val in image.shape]
    target_size = [int(round(ratio * val)) for val in input_shape[:3]]

    # One slice per axis, each using its OWN center, size and shift (the x axis
    # previously reused the y-axis center/size). Start indices are clamped at 0;
    # NumPy clamps stop indices that run past the end of the array.
    window = tuple(
        slice(int(max(0, center[axis] - target_size[axis] / 2 + deltas[axis])),
              int(center[axis] + target_size[axis] / 2 + deltas[axis]))
        for axis in range(3)
    )
    image_center_patch = image[window]

    #### zoom image_center_patch if its shape is not equal to input_shape
    zoom_values = [input_shape[idx] / float(image_center_patch.shape[idx])
                   for idx in range(len(image_center_patch.shape))]

    target_region = zoom(image_center_patch, zoom=zoom_values, order=1)
    final_image = target_region[np.newaxis, :, :, :, np.newaxis]
    return final_image

def pad_mask(final_mask):
    '''
    Fill empty slices lying between the first and last non-empty slice of a mask.

    Slices are taken along the first axis of the squeezed mask. Every slice strictly
    inside [first non-empty, last non-empty] whose sum is 0 is replaced by the
    logical OR of its two neighbouring slices (neighbour > 0). Slices are processed
    in increasing order, so within a run of empty slices each one is filled using
    the already-filled slice before it.

    The input array is not modified.

    @final_mask: mask array. A 4D input returns a 4D array with a trailing
        singleton axis; any other rank returns a 5D array (1, ..., 1). The squeezed
        mask is expected to be 3D for those reshapes to apply.
    @Return: the filled mask.
    '''
    length = len(final_mask.shape)
    # np.squeeze returns a view; copy so the caller's array is left untouched.
    filled = np.squeeze(final_mask).copy()

    if filled.shape[0] > 0:
        slice_sums = filled.reshape(filled.shape[0], -1).sum(axis=1)
        nonzero_slices = np.flatnonzero(slice_sums > 0)
        if nonzero_slices.size > 0:
            first, last = nonzero_slices[0], nonzero_slices[-1]
            # first/last are non-empty, so only interior slices can need filling,
            # and both neighbours of an interior slice are always in range.
            for slice_idx in range(first + 1, last):
                if slice_sums[slice_idx] == 0:
                    # OR of thresholded neighbours; summing first could overflow
                    # integer dtypes (e.g. uint8 128 + 128 == 0).
                    filled[slice_idx] = np.logical_or(filled[slice_idx - 1] > 0,
                                                      filled[slice_idx + 1] > 0)

    if length == 4:
        return filled[:, :, :, np.newaxis]
    return filled[np.newaxis, :, :, :, np.newaxis]

def Data_preprocess(image,mask,aug,ratio=1.0,input_shape=[64,64,64,1],shift_ratio=1.2):
    '''
    Preprocess an image and its mask with the same crop/zoom parameters.

    Only the largest connected component of ``mask`` is kept before cropping.
    image and mask should have the same shape.

    @image: original image.
    @mask: original mask; any value > 0 counts as foreground.
    @aug: bool. If True, the crop window is randomly translated.
    @ratio: forwarded to Image_preprocess (crop size relative to input_shape).
    @input_shape: network input shape; first 3 entries are (depth, height, width).
    @shift_ratio: currently has no effect (the shift range is fixed, see below);
        kept for API compatibility.
    @Return: final_image, final_mask. Both 5D arrays after preprocessing.
    '''
    mask = mask > 0

    # Keep only the largest connected component. Labels are computed on the
    # squeezed mask, so the selected component is reshaped back to mask.shape.
    # (Previously the boolean mask itself was compared with the label id, and the
    # result was then ignored when cropping.)
    labels = label(np.squeeze(mask))
    final_mask = mask
    max_area = 0
    for region in regionprops(labels, intensity_image=None):
        if region.area > max_area:
            max_area = region.area
            final_mask = (labels == region.label).reshape(mask.shape)

    # Fixed random-shift ranges (voxels) for z and for the (z, y, x) draw.
    # NOTE(review): an earlier version derived these from the nodule diameter and
    # shift_ratio, but that value was always overwritten by these constants.
    z_distance, x_distance = 2, 5

    if aug:
        # np.random.randint's `high` is exclusive, so shifts fall in [-d, d-1].
        # NOTE(review): confirm whether a symmetric range [-d, d] was intended.
        delta_0 = np.random.randint(low=-z_distance, high=z_distance, size=1)[0]
        deltas = np.random.randint(low=-x_distance, high=x_distance, size=3)
        deltas[0] = delta_0
    else:
        deltas = [0, 0, 0]

    # Crop the same region from the image and the selected mask component.
    final_image = Image_preprocess(image, deltas, ratio=ratio, input_shape=input_shape)
    final_mask = Image_preprocess(final_mask, deltas, ratio=ratio, input_shape=input_shape)
    final_mask = final_mask > 0
    final_mask = pad_mask(final_mask)

    return final_image, final_mask

def DataNormlization(data,info,HU_min=-1024,HU_max=1024,mode='mean'):
    '''
    Normalize each sample in ``data`` using its matching record in ``info``.

    @data: array (or sequence) of samples; data[i] is normalized with info[i].
    @info: per-sample records supporting ``.astype``; the last 4 entries of each
        record are read as (max, mean, std, min).
    @HU_min, HU_max: clipping window used by mode 'HU'.
    @mode:
        'mean' -> (x - mean) / std, with mean/std taken from info
        'HU'   -> clip to [HU_min, HU_max], then (x - HU_min) / (HU_max - HU_min)
        'self' -> (x - x.mean()) / x.std() computed on the sample itself
        other  -> (x - min) / (max - min), with min/max taken from info
    @Return: floating-point array with the same shape as data. Samples beyond
        len(info) are left as zeros. No guard is applied for a zero std/range.
    '''
    data = np.asarray(data)
    # The normalized values are fractional; allocating with data's own dtype would
    # silently truncate them when data is an integer array (e.g. raw HU values).
    result = np.zeros(data.shape, dtype=np.result_type(data.dtype, np.float32))
    for data_idx, (current_data, current_info) in enumerate(zip(data, info)):
        max_value, mean_value, std_value, min_value = current_info[-4:].astype('float')
        if mode == 'mean':
            current_mean, current_std = mean_value, std_value
        elif mode == 'HU':
            current_data = np.clip(current_data, HU_min, HU_max)
            current_mean, current_std = HU_min, float(HU_max - HU_min)
        elif mode == 'self':
            current_mean, current_std = np.mean(current_data), np.std(current_data)
        else:
            current_mean, current_std = min_value, (max_value - min_value)

        result[data_idx] = (current_data - current_mean) / current_std
    return result


def CalculateMaxDisatnce(mask):
    '''
    Return the largest in-plane bounding-box extent of ``mask`` over all slices.

    For every slice along axis 0 that contains foreground (> 0), the bounding box
    of its foreground pixels is computed on the slice's first two axes (y, x), and
    the larger side length (inclusive: max - min + 1) is taken. The maximum over
    all slices is returned, or 0 if the mask has no foreground.
    '''
    max_extent = 0
    for mask_slice in mask:
        # Measure THIS slice only; the whole volume was used previously, which
        # mixed the slice (z) axis into the y/x measurement.
        points = np.nonzero(mask_slice > 0)
        if points[0].size == 0:
            continue
        ys, xs = points[0], points[1]
        extent = max(ys.max() - ys.min(), xs.max() - xs.min()) + 1
        max_extent = max(max_extent, int(extent))
    return max_extent

def CalculatePatchSizeWithDiameter(mask,stride_ratio=8,diameter_enlarge_ratio=1.5,direction_z_ratio=1,max_patch_size=64,min_patch_size=8,aug=True,diameter_px=None):
    '''
    Choose a patch size for a nodule from its diameter.

    size = ceil(diameter * jitter / stride_ratio * diameter_enlarge_ratio) * stride_ratio,
    clipped to [min_patch_size, max_patch_size].
    - stride_ratio: the patch size is a multiple of this (should match the network stride).
    - diameter_enlarge_ratio: how much context around the nodule is kept in the patch.
    - jitter: random factor in [0.90, 1.19] when ``aug`` is True, otherwise 1.0.
    - diameter_px: if None, it is measured from ``mask`` with CalculateMaxDisatnce.
    @Return: [size * direction_z_ratio, size, size, 1]
    '''
    if diameter_px is None:
        diameter_px = CalculateMaxDisatnce(mask)

    jitter = np.random.randint(90, 120, size=1)[0] / 100.0 if aug else 1.0
    scaled_diameter = diameter_px * jitter

    n_strides = math.ceil(scaled_diameter / float(stride_ratio) * diameter_enlarge_ratio)
    inplane = int(n_strides * stride_ratio)
    inplane = max(min_patch_size, inplane)
    inplane = min(max_patch_size, inplane)

    return [inplane * direction_z_ratio, inplane, inplane, 1]


def CalculatePatchSizeWithDiameterV2(mask,stride_ratio=8,diameter_enlarge_ratio=1.5,direction_z_ratio=1,max_patch_size=64,min_patch_size=8,aug=True,diameter_px=None):
    '''
    Choose a patch size for a nodule from its diameter (variant 2).

    size = round(diameter * jitter / stride * diameter_enlarge_ratio) * stride,
    clipped to [min_patch_size, max_patch_size]. Python's round() is used
    (ties go to the even integer).
    - stride_ratio: base stride; the patch size is a multiple of the stride used.
    - diameter_enlarge_ratio: how much context around the nodule is kept in the patch.
    - diameter_px: if None, it is measured from ``mask`` with CalculateMaxDisatnce.
    @Return: [size * direction_z_ratio, size, size, 1]
    '''
    if diameter_px is None:
        diameter_px = CalculateMaxDisatnce(mask)

    # During training, randomly vary the patch size (jitter in [0.90, 1.49]) to
    # cope with the difference between the long diameter computed from the mask
    # and the long diameter supplied as input at the alpha-segmentation stage.
    if aug:
        jitter = np.random.randint(90, 150, size=1)[0] / 100.0
    else:
        jitter = 1.0

    # Small nodules use the small stride; large nodules (diameter >= 2 * min_patch_size,
    # judged before jitter) use a stride of at least min_patch_size.
    step = stride_ratio if diameter_px < 2 * min_patch_size else max(min_patch_size, stride_ratio)

    scaled_diameter = diameter_px * jitter
    patch = int(round(scaled_diameter / float(step) * diameter_enlarge_ratio) * step)
    patch = min(max_patch_size, max(min_patch_size, int(patch)))

    return [patch * direction_z_ratio, patch, patch, 1]


def CalculatePatchSizeWithDiameterV3(mask,stride_ratio=8,diameter_enlarge_ratio=1.5,direction_z_ratio=1,max_patch_size=64,min_patch_size=8,aug=True,diameter_px=None):
    '''
    Choose a patch size for a nodule from its diameter (variant 3).

    size = round(diameter * jitter / stride_ratio * diameter_enlarge_ratio) * stride_ratio,
    clipped to [min_patch_size, max_patch_size]. Python's round() is used
    (ties go to the even integer), and the stride is never adjusted.
    - jitter: random factor in [0.90, 1.49] when ``aug`` is True, otherwise 1.0.
    - diameter_px: if None, it is measured from ``mask`` with CalculateMaxDisatnce.
    @Return: [size * direction_z_ratio, size, size, 1]
    '''
    if diameter_px is None:
        diameter_px = CalculateMaxDisatnce(mask)

    jitter = 1.0
    if aug:
        jitter = np.random.randint(90, 150, size=1)[0] / 100.0

    steps = round(diameter_px * jitter / float(stride_ratio) * diameter_enlarge_ratio)
    side = int(steps * stride_ratio)
    side = min(max_patch_size, max(min_patch_size, int(side)))

    return [side * direction_z_ratio, side, side, 1]



def calc_dist_map(seg):
    '''
    Signed Euclidean distance map of a segmentation.

    Background voxels get their distance to the nearest foreground voxel (>= 1).
    Foreground voxels get -(distance to the nearest background voxel - 1), which is
    0 on the boundary and increasingly negative toward the interior.
    If ``seg`` has no foreground, an all-zero array with seg's dtype is returned.
    '''
    res = np.zeros_like(seg)
    # Builtin bool: the `np.bool` alias was removed in NumPy 1.24.
    posmask = seg.astype(bool)
    if posmask.any():
        negmask = ~posmask
        neg_distance = distance(negmask) * negmask
        pos_distance = -(distance(posmask) - 1) * posmask
        # Reuse the two maps instead of running the distance transform twice more.
        res = neg_distance + pos_distance
    return res

def calc_dist_map_batch(y_true):
    '''
    Compute calc_dist_map for every item along the first axis of ``y_true``
    and stack the results into one float32 array.
    '''
    stacked = np.array([calc_dist_map(sample) for sample in y_true])
    return stacked.astype(np.float32)


def RespacingImage(image,mask,target_spacing = (4,4,4)):
    '''
    Resample an image and its mask to ``target_spacing``.

    The image goes through LinearResample and the mask through NearestResample
    (both from respacing_func). The resampled mask is then given the resampled
    image's direction, origin and spacing so that the two stay aligned.

    @image: SimpleITK image, passed unchanged to LinearResample.
    @mask: SimpleITK image, or a numpy array (presumably z, y, x) on the same voxel
        grid as image.
    @target_spacing: spacing passed to both resamplers.
    @Return: (itkImage, itkLung, OriginalSpacing, OriginalSize): the resampled image,
        the resampled mask, and the spacing/size values returned by LinearResample
        (presumably those of the input image; confirm in respacing_func).
    '''
    # The mask must be a SimpleITK image so its spacing can be set. The previous
    # code called sitk.GetArrayFromImage here, which produced a numpy array with
    # no SetSpacing method, so the function always raised.
    if isinstance(mask, sitk.Image):
        # Copy so that SetSpacing below does not alter the caller's image.
        itk_mask = sitk.Image(mask)
    else:
        mask_array = np.asarray(mask)
        if mask_array.dtype == bool:
            # SimpleITK's numpy bridge may not accept boolean arrays.
            mask_array = mask_array.astype(np.uint8)
        itk_mask = sitk.GetImageFromArray(mask_array)

    itkImage, OriginalSpacing, OriginalSize = LinearResample(image, target_spacing)

    # Align the mask's spacing with the image's original spacing before resampling.
    itk_mask.SetSpacing(OriginalSpacing)

    itkLung = NearestResample(itk_mask, target_spacing)
    itkLung.SetDirection(itkImage.GetDirection())
    itkLung.SetOrigin(itkImage.GetOrigin())
    itkLung.SetSpacing(itkImage.GetSpacing())
    return itkImage, itkLung, OriginalSpacing, OriginalSize

def imagepair_respacing(image,mask,ori_spacing,target_spacing):
    '''
    Resample a numpy image/mask pair from ``ori_spacing`` to ``target_spacing``.

    @image, mask: numpy arrays, converted with sitk.GetImageFromArray.
    @ori_spacing: spacing in the arrays' axis order. It is reversed before being
        handed to SimpleITK, whose spacing order is the reverse of the numpy axis order.
    @target_spacing: spacing passed to LinearResample / NearestResample (SimpleITK order).
    @Return: (resampled image, resampled mask) as SimpleITK images. The mask is
        resampled with the resampled image's size and given its spacing.
    '''
    sitk_spacing = ori_spacing[::-1]

    itk_image = sitk.GetImageFromArray(image)
    itk_mask = sitk.GetImageFromArray(mask)
    itk_image.SetSpacing(sitk_spacing)
    itk_mask.SetSpacing(sitk_spacing)

    resampled_image, spacing_before, _ = LinearResample(itk_image, target_spacing)

    # Re-sync the mask with the spacing LinearResample reports for the source image.
    itk_mask.SetSpacing(spacing_before)

    resampled_mask = NearestResample(itk_mask, target_spacing, resampled_image.GetSize())
    resampled_mask.SetSpacing(resampled_image.GetSpacing())
    return resampled_image, resampled_mask

def PadToTargetShape(image,target_shape):
    '''
    Zero-pad ``image`` so each axis paired with ``target_shape`` is at least that long.

    The padding is split between both sides; for an odd amount the extra voxel goes
    to the end. Axes that are already long enough are left unchanged.
    '''
    pad_values = []
    for wanted, current in zip(target_shape, image.shape):
        missing = max(0, wanted - current)
        before = int(missing / 2)
        pad_values.append((before, missing - before))
    return np.pad(image, pad_values, mode='constant')

def CropToTargetShape(image,target_shape):
    '''
    Center-crop the first three axes of ``image`` to ``target_shape``.

    Axes shorter than the target are kept whole (no padding). When the excess is
    odd, the extra voxel is removed from the end.
    '''
    offsets = [int(max(0, have - want) / 2) for have, want in zip(image.shape, target_shape)]
    window = tuple(slice(offsets[axis], target_shape[axis] + offsets[axis]) for axis in range(3))
    return image[window]

    
def GetTargetShapePatch(image,target_shape):
    '''
    Bring ``image`` to ``target_shape``: zero-pad axes that are too short, then
    center-crop the first three axes. For a 3D image the result has shape
    target_shape[:3].
    '''
    padded = PadToTargetShape(image, target_shape)
    return CropToTargetShape(padded, target_shape)

def imagepari_target_shape_respacing(image,mask,ori_spacing,target_shape):
    '''
    Resample an image/mask pair toward ``target_shape`` voxels, then pad/crop to it.

    Per axis, new_spacing = ori_spacing * (current size / target size), so the
    physical extent is kept while the voxel count moves toward target_shape. The
    resampled arrays are then padded/cropped with GetTargetShapePatch.

    @image, mask: numpy arrays; singleton axes are squeezed away and mask is cast to uint8.
    @ori_spacing: spacing in the same axis order as the arrays.
    @target_shape: desired output shape (zip ignores entries beyond the array rank).
    @Return: (image, mask, new_spacing): the reshaped numpy arrays and the new
        spacing in array axis order.
    '''
    image = np.squeeze(image)
    mask = np.squeeze(mask).astype(np.uint8)

    new_spacing = [spacing * (size / float(wanted))
                   for spacing, size, wanted in zip(ori_spacing, image.shape, target_shape)]

    # imagepair_respacing expects the target spacing in SimpleITK (reversed) order.
    itk_image, itk_mask = imagepair_respacing(image, mask, ori_spacing, new_spacing[::-1])
    resampled_image = sitk.GetArrayFromImage(itk_image)
    resampled_mask = sitk.GetArrayFromImage(itk_mask)

    return (GetTargetShapePatch(resampled_image, target_shape),
            GetTargetShapePatch(resampled_mask, target_shape),
            new_spacing)