import os
import numpy as np
import SimpleITK as sitk
import cv2

import imageio
from tqdm import tqdm
import sys

import argparse

def pic2gif(path, output_name, mdim=512):
    """Assemble every file found in ``path`` into one animated GIF.

    Args:
        path: directory whose entries are all read as image frames.
        output_name: destination file name of the GIF.
        mdim: not used by this function; kept so existing callers still work.

    Frames are ordered by file name with the last four characters (e.g. a
    ``.png`` suffix) stripped, compared as plain strings.
    """
    # String ordering only matches numeric ordering for zero-padded names.
    ordered_names = sorted(os.listdir(path), key=lambda name: name[:-4])

    frames = [imageio.imread(path + '/' + name) for name in tqdm(ordered_names)]

    frame_rate = 24.0
    # NOTE(review): whether `duration` is read as seconds or milliseconds
    # depends on the installed imageio version/plugin - confirm.
    imageio.mimsave(output_name, frames, 'GIF', duration=1 / frame_rate)

def norm(image, hu_min=-1000.0, hu_max=600.0):
    """Window an intensity array to ``[hu_min, hu_max]`` and rescale it to 0-255.

    Args:
        image: NumPy array of intensities (it must provide ``astype``).
        hu_min: lower bound of the window; values below map to 0.
        hu_max: upper bound of the window; values above map to 255.

    Returns:
        A floating-point array with the same shape as ``image`` whose values
        lie in ``[0, 255]``. The result is not converted to an integer type.
    """
    windowed = np.clip(image.astype(np.float32), hu_min, hu_max)
    window_width = float(hu_max - hu_min)
    unit_scaled = (windowed - hu_min) / window_width
    return unit_scaled * 255.

def _draw_rpn_boxes(slice_bgr, boxes):
    """Outline each detection on one axial slice, modifying the slice in place.

    Args:
        slice_bgr: 3-channel image of a single slice; rows are taken as the
            y axis and columns as the x axis.
        boxes: iterable of ``(z, y, x, dim_z, dim_y, dim_x)`` rows, read as a
            box centre followed by its full extent along each axis.
    """
    height, width = slice_bgr.shape[:2]
    for _, y, x, _, dim_y, dim_x in boxes:
        # cv2 points are (x, y) == (column, row). The slice index z is not an
        # in-plane coordinate, so only x/y and their extents are used here.
        left = max(0, x - dim_x / 2)
        right = min(width, x + dim_x / 2)
        top = max(0, y - dim_y / 2)
        bottom = min(height, y + dim_y / 2)
        cv2.rectangle(slice_bgr, (int(left), int(top)), (int(right), int(bottom)), (0, 0, 255), 2)


def img2gif(tmp_path, raw_path):
    """Draw detection boxes on CT slices and export one GIF per case.

    For every entry ``<case>`` of ``tmp_path`` this reads
    ``tmp_path/<case>/all_rpns.npy`` and ``raw_path/<case>.nii.gz``, writes the
    annotated slices as ``save_images/<case>/NNN.png`` and finally turns every
    sub-directory of ``save_images`` into ``save_gifs/<name>.gif``. Both output
    directories are created under ``os.path.dirname(tmp_path)``.

    Args:
        tmp_path: directory containing one sub-directory per case, each holding
            an ``all_rpns.npy`` file.
        raw_path: directory containing the ``<case>.nii.gz`` volumes.

    Notes:
        Assumes each ``all_rpns.npy`` row is
        ``[_, z, y, x, dim_z, dim_y, dim_x]`` (as the unpacking names suggest)
        and that the volume array is indexed ``[z, y, x]`` - confirm against
        the producer of these files.
        Cases whose ``all_rpns.npy`` is empty are skipped with a message.
    """
    target_path = os.path.join(os.path.dirname(tmp_path), 'save_images')
    if not os.path.exists(target_path):
        os.makedirs(target_path)

    for file in os.listdir(tmp_path):
        print('deal with {}'.format(file))
        npy_info = np.load(os.path.join(tmp_path, file, 'all_rpns.npy'))
        if npy_info.size == 0:
            # min()/max() below would raise on an empty detection list.
            print('no boxes for {}, skipped'.format(file))
            continue

        raw_img = sitk.GetArrayFromImage(
            sitk.ReadImage(os.path.join(raw_path, file + '.nii.gz')))

        all_z = [int(z) for z in npy_info[:, 1]]
        # Keep the slice range inside the volume: a negative index would
        # silently wrap to the end, one past the end would raise IndexError.
        first_z = max(0, min(all_z) - 2)
        last_z = min(raw_img.shape[0], max(all_z) + 2)

        save_path = os.path.join(target_path, file)
        for frame_idx, it_z in enumerate(range(first_z, last_z)):
            # A box is shown on slices z-3 .. z+2 (half-open window, kept
            # exactly as before). NOTE(review): the window is asymmetric and
            # ignores dim_z - confirm this is intended.
            visible = [k for k, z in enumerate(all_z) if z - 3 <= it_z < z + 3]
            print('have {} box'.format(len(visible)))

            frame = cv2.cvtColor(norm(raw_img[it_z, :, :]), cv2.COLOR_GRAY2BGR)
            _draw_rpn_boxes(frame, (npy_info[k][1:] for k in visible))

            # Created lazily so a case without frames leaves no empty folder
            # behind for the GIF stage to trip over.
            if not os.path.exists(save_path):
                os.makedirs(save_path)
            cv2.imwrite(os.path.join(save_path, '{:03d}.png'.format(frame_idx)), frame)

    output_gif_path = os.path.join(os.path.dirname(target_path), 'save_gifs')
    if not os.path.exists(output_gif_path):
        os.makedirs(output_gif_path)

    for file in os.listdir(target_path):
        frame_dir = os.path.join(target_path, file)
        if not os.path.isdir(frame_dir):
            # Stray files in save_images are not frame folders.
            continue
        pic2gif(frame_dir, os.path.join(output_gif_path, file + '.gif'))

if __name__ == '__main__':
    # Command-line entry point: locate the detection results and the
    # preprocessed volumes under the job data root, then render the GIFs.
    parser = argparse.ArgumentParser(
        description='Render detection boxes on CT slices and export them as GIFs')
    parser.add_argument(
        '--job_data_root',
        default='/data/job_678/job_data_test',
        type=str,
        help='job data root containing output/tmp/rpns and '
             'output/preprocess/preprocess_file_details/nii')
    args = parser.parse_args()

    # Per-case all_rpns.npy files live here.
    tmp_path = os.path.join(args.job_data_root, 'output/tmp/rpns')
    # Matching <case>.nii.gz volumes live here.
    raw_path = os.path.join(args.job_data_root, 'output/preprocess/preprocess_file_details/nii')

    img2gif(tmp_path, raw_path)