import os
import sys
import copy
import math
import time
import glob
import random
import datetime
import numpy as np
from time import time
from scipy.ndimage import zoom
from skimage.draw import line_aa
# from matplotlib.lines import Line2D 
# from matplotlib import pyplot as plt
from scipy.ndimage import gaussian_filter
from skimage.measure import label,regionprops
from collections import defaultdict,OrderedDict
from scipy.ndimage.measurements import center_of_mass
from skimage.morphology import binary_erosion,binary_dilation

def Image_preprocess(image,deltas,ratio=1.0,input_shape=[64,64,64,1]):
    '''
    Crop a (possibly shifted) centre patch from a volume and resample it to the network input size.

    1. Crop a patch of size round(ratio*input_shape[i]) around the volume centre on each of the
       first three axes, shifted by deltas[i]. The lower bound is clipped at 0; the upper bound
       is not padded, so a patch near the border can come out smaller and is stretched by step 2.
    2. Zoom the patch with linear interpolation (order=1) to input_shape.

    @image: target image (float). The first three axes are cropped as (z, y, x).
    @deltas: int sequence of length 3. Translation (in voxels) on (z, y, x).
    @ratio: float around 1.0. Scales the crop size relative to input_shape.
    @input_shape: int sequence of length >= 3. The first 3 elements are the (depth, height, width)
                  of the network input; one zoom factor is taken per patch axis.
    @Return: final_image. For a 3-D image this is a 5-D array (1, depth, height, width, 1).
    '''
    center = [int(val/2) for val in image.shape]
    target_size = [int(round(ratio*val)) for val in input_shape[:3]]

    # Build one slice per axis so each axis uses its own centre and crop size
    # (the x axis previously reused the y-axis centre/size).
    crop = []
    for axis in range(3):
        start = int(max(0, center[axis] - target_size[axis]/2 + deltas[axis]))
        stop = int(center[axis] + target_size[axis]/2 + deltas[axis])
        crop.append(slice(start, stop))
    image_center_patch = image[tuple(crop)]

    #### zoom image_center_patch if its shape is not equal to input_shape
    zoom_values = [input_shape[idx]/float(image_center_patch.shape[idx]) for idx in range(len(image_center_patch.shape))]

    target_region = zoom(image_center_patch,zoom=zoom_values,order=1)
    final_image = target_region[np.newaxis,:,:,:,np.newaxis]
    return final_image

def pad_mask(final_mask):
    '''
    Fill empty slices that lie between the first and last non-empty slices of a mask.

    The mask is squeezed and slices are taken along its first axis. Every slice strictly
    between the first and last non-empty slice whose sum is 0 is replaced by the logical
    OR (value > 0) of its two neighbours. Slices are processed in increasing order, so in a
    run of consecutive empty slices each one inherits the (already filled) slice below it,
    and the last one in the run also includes the slice above.

    The input array is not modified.

    @final_mask: mask array, e.g. (D, H, W, 1) or (1, D, H, W, 1). Assumes np.squeeze leaves
                 exactly the 3 spatial axes (i.e. no spatial axis has size 1) -- TODO confirm.
    @Return: (D, H, W, 1) if the input had 4 dims, otherwise (1, D, H, W, 1).
    '''
    length = len(final_mask.shape)
    # np.squeeze returns a view; copy so filling slices does not write into the caller's array.
    final_mask = np.squeeze(final_mask).copy()
    nonzero_slices = [idx for idx in range(final_mask.shape[0]) if np.sum(final_mask[idx]) > 0]
    if nonzero_slices:
        min_slice, max_slice = nonzero_slices[0], nonzero_slices[-1]
        # Only interior slices can be blank-between-nonzero, so idx-1 and idx+1 are always valid.
        for slice_idx in range(min_slice + 1, max_slice):
            if np.sum(final_mask[slice_idx]) == 0:
                neighbours = final_mask[slice_idx-1] + final_mask[slice_idx+1]
                final_mask[slice_idx] = neighbours > 0
    if length == 4:
        return final_mask[:,:,:,np.newaxis]
    return final_mask[np.newaxis,:,:,:,np.newaxis]

def _largest_connected_component(mask):
    '''Return a boolean array shaped like `mask` keeping only the largest connected component of mask>0.'''
    mask = mask > 0
    labels = label(np.squeeze(mask))
    regions = regionprops(labels)
    if not regions:
        return mask
    # max() keeps the first region on ties, matching a strict ">" scan.
    largest = max(regions, key=lambda region: region.area)
    # labels was computed on the squeezed mask; restore the original shape.
    return (labels == largest.label).reshape(mask.shape)

def Data_preprocess(image,mask,aug,ratio=1.0,input_shape=[64,64,64,1],shift_ratio=1.2):
    '''
    Preprocess image & mask with the same crop/shift parameters.

    1. Keep only the largest connected component of the (binarized) mask.
    2. If aug, draw a random translation: z shift from randint(-2, 2), y/x shifts from randint(-5, 5).
       np.random.randint excludes `high`, so z is in [-2, 1] and y/x are in [-5, 4].
    3. Crop and zoom image and mask with Image_preprocess using the same deltas,
       re-binarize the mask and fill interior blank slices with pad_mask.

    @image: original image.
    @mask: original mask (binary); must have the same shape as image.
    @aug: bool. Whether to apply the random translation.
    @ratio: crop-size ratio forwarded to Image_preprocess.
    @input_shape: network input shape forwarded to Image_preprocess.
    @shift_ratio: currently unused; the shift range is fixed (see below). Kept for interface compatibility.
    @Return: final_image, final_mask -- both 5-D arrays after preprocessing.
    '''
    final_mask = _largest_connected_component(mask)

    # Fixed shift range in voxels. An adaptive range based on the nodule diameter and
    # shift_ratio used to be computed here but was always overridden by these values.
    z_distance, x_distance = 2, 5

    if aug:
        delta_0 = np.random.randint(low=-z_distance,high=z_distance,size=1)[0]
        deltas = np.random.randint(low=-x_distance,high=x_distance,size=3)
        deltas[0] = delta_0
    else:
        deltas = [0,0,0]

    # Crop the same region on the image and the (largest-component) mask.
    final_image = Image_preprocess(image,deltas,ratio=ratio,input_shape=input_shape)
    final_mask = Image_preprocess(final_mask,deltas,ratio=ratio,input_shape=input_shape)
    final_mask = final_mask>0
    final_mask = pad_mask(final_mask)

    return final_image,final_mask

def DataNormlization(data,info,HU_min=-1024,HU_max=1024,mode='mean'):
    '''
    Normalize each sample of `data` independently.

    @data: array whose first axis indexes samples.
    @info: per-sample metadata rows aligned with data; the last four entries of each row
           are read (as float) in the order (max, mean, std, min).
    @HU_min, HU_max: clipping window used by mode 'HU'.
    @mode:
        'mean' -> (x - mean) / std, using the stats stored in info
        'HU'   -> clip to [HU_min, HU_max], then (x - HU_min) / (HU_max - HU_min)
        'self' -> (x - mean(x)) / std(x), computed from the sample itself
        other  -> (x - min) / (max - min), using the stats stored in info
    @Return: array with data's shape. Float input keeps its dtype; non-float (e.g. integer)
             input yields float32 so fractional values are not truncated.
    NOTE(review): a zero std / range produces inf or nan; no guard is applied.
    '''
    data = np.asarray(data)
    out_dtype = data.dtype if np.issubdtype(data.dtype, np.floating) else np.float32
    result = np.zeros(data.shape, dtype=out_dtype)
    for data_idx, (current_data,current_info) in enumerate(zip(data,info)):
        max_value,mean_value,std_value,min_value = current_info[-4:].astype('float')
        if mode == 'mean':
            current_mean,current_std = mean_value,std_value
        elif mode == 'HU':
            current_data = np.clip(current_data,HU_min,HU_max)
            current_mean,current_std = HU_min,float(HU_max-HU_min)
        elif mode == 'self':
            current_mean,current_std = np.mean(current_data),np.std(current_data)
        else:
            current_mean,current_std = min_value,(max_value-min_value)

        result[data_idx] = (current_data - current_mean)/current_std
    return result


def CalculateMaxDisatnce(mask):
    '''
    Calculate the maximum in-plane extent of a mask measured from the slice centre.

    For all voxels with mask > 0, the (y, x) positions are collected over every slice.
    The result is max(y_max-cy+1, cy-y_min+1, x_max-cx+1, cx-x_min+1), where (cy, cx) is
    the in-plane centre. Projecting along axis 0 first gives the same value as taking this
    maximum slice by slice.

    @mask: array with axis 0 = slice and axes 1, 2 = (y, x); trailing axes (e.g. a channel
           of size 1) are allowed. Assumed binary / non-negative.
    @Return: int distance in pixels; 0 for an empty mask.
    '''
    mask = np.asarray(mask)
    center_y = int(mask.shape[1]/2)
    center_x = int(mask.shape[2]/2)
    # Union of all slices: the per-slice extremes are bounded by the union's extremes.
    projection = np.any(mask > 0, axis=0)
    if not projection.any():
        return 0
    points = np.where(projection)
    y_min, y_max = np.amin(points[0]), np.amax(points[0])
    x_min, x_max = np.amin(points[1]), np.amax(points[1])
    return int(max(y_max-center_y+1, center_y-y_min+1, x_max-center_x+1, center_x-x_min+1))

def CalculatePatchSizeWithDiameter(mask,stride_ratio=8,diameter_enlarge_ratio=1.5,direction_z_ratio=1,max_patch_size=64,min_patch_size=8,aug=True):
    '''
    Calculate the patch size for a nodule from its in-plane extent (CalculateMaxDisatnce).

    1. stride_ratio: the patch size is rounded up to a multiple of stride_ratio (should match
       the network stride).
    2. diameter_enlarge_ratio: how much larger than the nodule extent the patch should be.
    3. The in-plane size is clipped to [lower, max_patch_size], where lower is min_patch_size
       rounded up to a multiple of stride_ratio and never below stride_ratio.
       With the defaults this is [8, 64].
    4. aug: if True, the extent is jittered by a random factor in [0.90, 1.19].

    @Return: [depth, size, size, 1] with depth = size*direction_z_ratio.
    '''
    diameter_px = CalculateMaxDisatnce(mask)
    # min_patch_size used to be overwritten by stride_ratio; honour it, but keep the result a
    # stride multiple. Equals stride_ratio whenever min_patch_size <= stride_ratio.
    lower_bound = max(stride_ratio, int(math.ceil(min_patch_size/float(stride_ratio)))*stride_ratio)
    if aug:
        diameter_ratio = np.random.randint(90,120,size=1)[0]/100.0
    else:
        diameter_ratio = 1.0

    diameter_px *= diameter_ratio
    target_patch_size = math.ceil(diameter_px/float(stride_ratio) * diameter_enlarge_ratio )*stride_ratio
    target_inplane_size = min(max_patch_size,max(lower_bound,int(target_patch_size)))

    return [(target_inplane_size*direction_z_ratio),target_inplane_size,target_inplane_size,1]