import os
import sys
import glob
import numpy as np
import pandas as pd
import SimpleITK as sitk
from imageIO import loadImage
from respacing_func import *
from skimage.measure import *

def get_bbox(image,val=None):
    '''
    Return the axis-aligned bounding box of the selected voxels.

    Voxels are selected with ``image > 0`` when ``val`` is None, otherwise
    with ``image == val``.

    Returns a flat list ``[min_0, max_0, min_1, max_1, ...]`` holding the
    smallest and largest index of the selection along each axis. If nothing
    is selected, every axis is skipped and the list is empty.
    '''
    selected = image>0 if val is None else image==val
    bbox = []
    for axis_indices in np.nonzero(selected):
        if len(axis_indices)==0:
            # Nothing selected: no extent for this axis.
            continue
        bbox.extend([np.amin(axis_indices),np.amax(axis_indices)])
    return bbox
    



def PadImage(image,offset,pad_val=0):
    '''
    Pad ``image`` by ``offset[i]`` voxels on both sides of every axis ``i``.

    Parameters
    ----------
    image : array-like
        Input volume. If its rank differs from ``len(offset)``, singleton axes
        are squeezed away first, so e.g. a (1, D, H, W) array is treated as
        (D, H, W). An image whose rank already matches ``offset`` (e.g. a
        single-slice (1, H, W) volume with a 3-element ``offset``) is kept
        intact instead of losing its size-1 axis.
    offset : sequence of int
        Padding per axis, one entry per axis of the (squeezed) image.
    pad_val : scalar, optional
        Fill value for the padded border (default 0).

    Returns
    -------
    numpy.ndarray
        float64 array whose axis ``i`` has length
        ``image.shape[i] + 2 * offset[i]``, with ``image`` copied into the
        centre.

    Raises
    ------
    ValueError
        If the image rank still differs from ``len(offset)`` after squeezing.
    '''
    image = np.asarray(image)
    if image.ndim!=len(offset):
        # Strip channel/batch singleton axes only when needed; squeezing
        # unconditionally would also drop a genuine size-1 spatial axis.
        image = np.squeeze(image)
    if image.ndim!=len(offset):
        raise ValueError('image has %d axes after squeezing but offset has %d entries'
                         %(image.ndim,len(offset)))

    new_shape = tuple(size+2*margin for size,margin in zip(image.shape,offset))
    # float64 output, as before (np.ones(...) * pad_val produced float64).
    padded = np.full(new_shape,pad_val,dtype=np.float64)
    interior = tuple(slice(margin,margin+size) for size,margin in zip(image.shape,offset))
    padded[interior] = image
    return padded


def CropPatch(image,center,offset):
    '''
    Crop a patch centred on ``center`` whose size is twice ``offset``.

    ``center`` and ``offset`` are (z, y, x) triples. Each value is rounded
    to the nearest integer first. The patch spans
    ``[center - offset, center + offset)`` on every axis, so no bounds
    checking is performed. Callers are expected to pad ``image`` beforehand.
    '''
    def _nearest_int(value):
        return int(round(float(value)))

    cz,cy,cx = list(map(_nearest_int,center))
    hz,hy,hx = list(map(_nearest_int,offset))
    window = (slice(cz-hz,cz+hz),
              slice(cy-hy,cy+hy),
              slice(cx-hx,cx+hx))
    return image[window]



def _load_working_image(image_path,spacing,kw1,kw2):
    '''
    Load the source image and, when ``spacing`` is given, its resampled copy.

    Returns ``(image_itk_ori, image_itk, image)``:
    - ``image_itk_ori``: the original SimpleITK image.
    - ``image_itk``: the image patches are cut from.
    - ``image``: the voxel array of ``image_itk``.
    Without ``spacing``, the working image is the original one.
    '''
    image_itk_ori,image_ori = loadImage(image_path)
    if spacing is None:
        return image_itk_ori,image_itk_ori,image_ori

    # Prefer a pre-resampled copy on disk, located by substituting kw1 -> kw2
    # in the source path (an earlier version used 'RAW_NII' ->
    # 'SPACING_%.1f_NII'). Skip the lookup when the keywords are missing, and
    # when the substitution leaves the path unchanged: that path is the
    # un-resampled source itself and must not be mistaken for the cached copy.
    if kw1 is not None and kw2 is not None:
        target_image_path = image_path.replace(kw1,kw2)
        print ('target_image_path is ',target_image_path)
        if target_image_path!=image_path and os.path.exists(target_image_path):
            image_itk,image = loadImage(target_image_path)
            return image_itk_ori,image_itk,image

    # No usable cached copy: resample the ORIGINAL image on the fly.
    # NOTE(review): the result is indexed with [0] as in the original code;
    # presumably LinearResample returns a tuple whose first item is the image.
    image_itk = LinearResample(image_itk_ori,spacing)[0]
    return image_itk_ori,image_itk,sitk.GetArrayFromImage(image_itk)


def GeneratePatchFromDataFrame(uid,df,output_dim,coord_kws,type_kws,pad_val=0.0,mode='image',crop_mode='mask',spacing=None,kw1=None,kw2=None):
    '''
    Crop fixed-size patches around every annotated point of one case.

    Parameters
    ----------
    uid : value matched against ``df['uid']`` to select the case's rows.
    df : pandas.DataFrame with a ``'uid'`` column plus the ``coord_kws`` and
        ``type_kws`` columns.
    output_dim : per-axis (z, y, x) patch size. Each patch is
        ``2 * int(d / 2)`` voxels along an axis, so an odd size yields a
        patch one voxel smaller.
    coord_kws : column names, in order: image path, center z, y, x, ...
        The original notes describe [image path, center (3), long
        diameter (4)]. The first 8 values of each row are copied into the
        info rows.
    type_kws : further columns appended to each info row. In
        ``mode='seg'``, the first of them must hold the mask path.
    pad_val : value used to pad the image so patches near the border keep
        their full size.
    mode : ``'seg'`` also crops the matching mask patch; any other value
        crops image patches only.
    crop_mode : in seg mode, ``'mask'`` recentres the patch on the mask
        instead of the row's (z, y, x) coordinate.
    spacing : if given, the image (and mask) are resampled to this spacing
        before cropping. Voxel coordinates are mapped between the original
        and resampled grids through physical space.
    kw1, kw2 : substrings used to locate a pre-resampled image
        (``image_path.replace(kw1, kw2)``) when ``spacing`` is given. If
        omitted, or if no such file exists, the image is resampled on the fly.

    Returns
    -------
    (data, info, masks)
        data : numpy array of image patches, one per row of the case.
        info : numpy array with one row per patch:
            first 8 row values + working-image spacing (z, y, x)
            + [min, max, mean, std] of the working image + ``type_kws`` values.
        masks : list of mask patches in seg mode, otherwise None.
        ``(None, None, None)`` when ``uid`` has no rows or the image file
        does not exist.
    '''
    info_kws = list(coord_kws)+list(type_kws)
    # Only the raw values are used, so the filtered frame's index is irrelevant.
    records = df[df['uid']==uid][info_kws].values
    if len(records)==0:
        return None,None,None
    image_path = records[0][0]
    if not os.path.exists(image_path):
        return None,None,None

    image_itk_ori,image_itk,image = _load_working_image(image_path,spacing,kw1,kw2)
    # SimpleITK reports spacing as (x, y, z); reverse to match array order.
    image_spacing = image_itk.GetSpacing()[::-1]

    offset = [int(dim/2) for dim in output_dim]
    # Pad once so patches centred near the border still come out full size.
    pad_img = PadImage(image,offset,pad_val)

    min_val,max_val,mean_val,std_val = np.amin(image),np.amax(image),np.mean(image),np.std(image)

    data_list,info_list,mask_list = [],[],[]
    for record in records:
        z,y,x = record[1:4]
        point_px = [int(round(float(val))) for val in (z,y,x)]

        type_infos = list(record[len(coord_kws):])

        pad_mask = None
        if mode=='seg':
            mask_path = type_infos[0]
            print ('mask_path is',mask_path)
            mask_itk,mask = loadImage(mask_path)
            mask = np.array(mask>0).astype(np.uint8)
            if crop_mode=='mask':
                # Centroid of the connected component labelled 1, not of
                # the whole mask. An empty mask raises IndexError here.
                # NOTE(review): confirm this is intended for multi-component masks.
                mask_centroid = regionprops(label(mask))[0].centroid
                point_px = [int(round(float(val))) for val in mask_centroid]

            if spacing is not None:
                # NOTE(review): used directly as an image, unlike
                # LinearResample's result above; verify NearestResample's return type.
                mask_itk = NearestResample(mask_itk,spacing)
                mask = sitk.GetArrayFromImage(mask_itk)

            pad_mask = PadImage(mask,offset,0)

        if spacing is not None:
            # Map the voxel index from the original grid onto the working grid
            # via physical space. SimpleITK indices are (x, y, z), hence the reversals.
            point_phy = image_itk_ori.TransformContinuousIndexToPhysicalPoint(point_px[::-1])
            point_px = image_itk.TransformPhysicalPointToContinuousIndex(point_phy)[::-1]

        # Shift the centre into the padded image's coordinates.
        new_center = [point+margin for point,margin in zip(point_px,offset)]

        data_list.append(CropPatch(pad_img,new_center,offset))
        if mode=='seg':
            mask_list.append(CropPatch(pad_mask,new_center,offset))

        current_info = list(record[:8])+list(image_spacing)+[min_val,max_val,mean_val,std_val]+type_infos
        info_list.append(current_info)

    return np.array(data_list),np.array(info_list),(mask_list if mask_list else None)