import warnings
warnings.filterwarnings('ignore')
import os
import copy
import sys
import math
import random
import shutil
import numpy as np
#from scipy.misc import comb
from sklearn import metrics


def data_augmentation(x, y, prob=0.5):
    """Randomly mirror a paired image/label along the same axes.

    Both inputs are squeezed first. Then, at most three times and only while a
    fresh ``random.random()`` draw is below ``prob``, one axis chosen from
    (0, 1, 2) is flipped in both arrays. Finally a leading and a trailing
    singleton axis are added back to each result.

    Returns:
        ``(x_aug, y_aug)`` with shape ``(1, *squeezed_shape, 1)``.
    """
    vol = np.squeeze(x)
    lbl = np.squeeze(y)
    flips_left = 3
    while True:
        # Draw from the RNG *before* checking the budget so the random stream
        # is consumed exactly as the original loop condition did.
        if not (random.random() < prob):
            break
        if flips_left <= 0:
            break
        axis = random.choice([0, 1, 2])
        vol, lbl = np.flip(vol, axis=axis), np.flip(lbl, axis=axis)
        flips_left -= 1
    return vol[np.newaxis, ..., np.newaxis], lbl[np.newaxis, ..., np.newaxis]

def random_add(image):
    """Return ``image`` shifted by a single random offset drawn from [0, 0.01).

    One value is drawn from NumPy's global RNG and added to every element.
    """
    # Unpacking the length-1 sample keeps the NumPy scalar type of the draw.
    (draw,) = np.random.random_sample(size=1)
    return image + draw * 0.01


def cropWithDiameter(image_single, diameter, zoom=False, input_shape=(32, 48, 48),
                     patch_max_sizes=(48, 64, 64), margin_z=4, min_size_z=16):
    '''
    Center-crop a 3-D array to a box whose size scales with ``diameter``.

    Along each axis the crop size is ``diameter / spacing * 2 + margin``,
    clamped to ``[min_size, patch_max_sizes[axis]]`` and truncated to int.
    Axis 0 uses ``margin_z`` and a spacing of ``0.67 * ratio``; axes 1 and 2 use
    spacing 0.67 and a margin / minimum size scaled by ``ratio``, where
    ``ratio = input_shape[0] / input_shape[1]``. When ``zoom`` is true the
    crop is resampled with ``scipy.ndimage.zoom`` (order 2) toward ``input_shape``.

    TODO: change enlarged_spacing_x & enlarged_spacing_y.

    Args:
        image_single: array indexed as ``[axis0, axis1, axis2]`` (presumably
            ``(z, y, x)`` -- confirm against callers).
        diameter: nodule diameter; its units must match the hard-coded 0.67
            spacing constant (not established here -- confirm).
        zoom: if true, resample the crop to ``input_shape``.
        input_shape: shape to resample to; also defines ``ratio``.
        patch_max_sizes: upper bound of the crop size on each axis.
        margin_z: margin added on axis 0 (axes 1/2 use ``round(margin_z * ratio)``).
        min_size_z: currently ignored; the axis-0 minimum is always
            ``int(patch_max_sizes[0] / 2)`` (see NOTE below).

    Returns:
        The cropped (and optionally zoomed) array. If the image is smaller
        than the target on some axis, the whole extent of that axis is kept.
    '''
    ratio = input_shape[0] * 1.0 / input_shape[1]
    enlarged_spacing_x = 0.67
    enlarged_spacing_y = enlarged_spacing_x
    enlarged_spacing_z = enlarged_spacing_x * ratio

    # NOTE(review): the ``min_size_z`` argument is overridden here (as in the
    # original code); honoring it would change the default output (16 vs 24).
    min_size_z = int(patch_max_sizes[0] / 2)
    min_size_y = int(round(min_size_z * ratio))
    min_size_x = min_size_y

    # ``margin_z`` is now honored instead of being reset to 4 (its default).
    margin_y = int(round(margin_z * ratio))
    margin_x = margin_y

    target_size = [
        min(max(diameter / enlarged_spacing_z * 2 + margin_z, min_size_z), patch_max_sizes[0]),
        min(max(diameter / enlarged_spacing_y * 2 + margin_y, min_size_y), patch_max_sizes[1]),
        min(max(diameter / enlarged_spacing_x * 2 + margin_x, min_size_x), patch_max_sizes[2]),
    ]
    target_size = [int(s) for s in target_size]

    # Centered crop. Clamp the start at 0: when the image is smaller than the
    # target, a negative start would make Python slicing wrap around to the
    # end of the axis and return a wrong, too-short slab.
    starts = [max((image_single.shape[i] - target_size[i]) // 2, 0) for i in range(3)]
    image_single = image_single[starts[0]:starts[0] + target_size[0],
                                starts[1]:starts[1] + target_size[1],
                                starts[2]:starts[2] + target_size[2]]

    if zoom:
        # The boolean ``zoom`` parameter shadows any module-level ``zoom``
        # function, so the resampler is imported under a different name.
        from scipy.ndimage import zoom as nd_zoom
        # Use the actual crop shape (equal to target_size unless the image
        # was smaller than the target) so the output reaches input_shape.
        zoom_values = [input_shape[i] / float(image_single.shape[i]) for i in range(3)]
        image_single = nd_zoom(image_single, zoom_values, order=2)
    return image_single