# -*- coding: utf-8 -*-
import os, sys
import pathlib

current_dir = pathlib.Path(__file__).parent.resolve()
# Walk up from this file's directory to the enclosing "cls_train" project root and put it
# on sys.path so the `data.*` imports below resolve. Stop at the filesystem root (whose
# parent is itself) with a clear error instead of looping forever when no such ancestor exists.
while current_dir.name != "cls_train":
    if current_dir.parent == current_dir:
        raise RuntimeError("no 'cls_train' directory found above "
                           + str(pathlib.Path(__file__).resolve()))
    current_dir = current_dir.parent
sys.path.append(current_dir.as_posix())

import numpy as np
import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
import matplotlib.animation as animation
import matplotlib.patches as patches
import seaborn as sns
import cv2

from data.data_process_utils.test_data_utils import check_and_makedirs
from data.data_process_utils.test_data_utils import get_crop_start_end, get_diameter_pixel_length, get_nodule_rect
from data.data_process_utils.test_data_utils import continuous_2d, var_2d


def plot_image_and_rect(data, rect=None,
                        z_start=0, stride=1, image_title=None,
                        cmap=plt.cm.gray, show=False, save=False, file_path=None):
    """Draw every `stride`-th slice of `data` in a one-column grid, optionally boxing a region.

    Args:
        data: volume indexed as ``data[i]`` -> 2D image; one grid row per index along
            the first axis.
        rect: optional ``((z0, z1), (r0, r1), (c0, c1))`` bounds. When given, every row is
            titled "ground truth" and slices with ``z0 <= i < z1`` get a red box whose
            corner is at column ``c0`` / row ``r0`` and which spans ``c1 - c0`` by ``r1 - r0``.
        z_start: offset added to the slice index in titles (titles show ``z_start + i + 1``).
        stride: step between drawn slice indices; rows that are skipped stay empty.
        image_title: optional figure-level title.
        cmap: colormap passed to ``imshow``; the display window is fixed at [-1350, 150].
        show: call ``plt.show()`` at the end.
        save: save the figure to ``file_path`` (after ``check_and_makedirs`` prepares it)
            and close it unless ``show`` is also set.
        file_path: output path used when ``save`` is True.
    """
    n_rows = max(data.shape[0], 2)  # >= 2 rows keeps `plots` an indexable array of axes
    fig, plots = plt.subplots(n_rows, 1, figsize=(5, n_rows * 5))

    for idx in range(0, data.shape[0], stride):
        axis = plots[idx]
        label = str(z_start + idx + 1)
        if rect is None:
            axis.set_title('original z=' + label)
        else:
            axis.set_title('ground truth z=' + label)
            z_range = rect[0]
            if z_range[0] <= idx < z_range[1]:
                top, bottom = rect[1][0], rect[1][1]
                left, right = rect[2][0], rect[2][1]
                axis.add_patch(patches.Rectangle((left, top), right - left, bottom - top,
                                                 linewidth=0.5, edgecolor='r', facecolor='none'))
        axis.imshow(data[idx], cmap=cmap, vmin=-1350, vmax=150)
        axis.axis('off')

    if image_title is not None:
        fig.suptitle(str(image_title), fontsize=18)
    if save:
        check_and_makedirs(file_path, is_file=True)
        fig.savefig(file_path, dpi=90, bbox_inches='tight')
    if show:
        plt.show()
    elif save:
        plt.close()


def plot_two_image(data1, data2,
                   z_start=0, stride=1, image_title=None,
                   cmap=plt.cm.gray, show=False, save=False, file_path=None):
    """Show two volumes side by side, one slice per row (``data1`` left, ``data2`` right).

    Args:
        data1: volume indexed as ``data1[i]`` -> 2D image; its first dimension sets the
            number of rows.
        data2: second volume, indexed with the same slice indices as ``data1``.
        z_start: offset added to the slice index in titles (titles show ``z_start + i + 1``).
        stride: step between drawn slice indices; skipped rows stay empty.
        image_title: optional figure-level title.
        cmap: colormap passed to ``imshow``; the display window is fixed at [-1350, 150].
        show: call ``plt.show()`` at the end.
        save: save the figure to ``file_path`` and close it unless ``show`` is also set.
        file_path: output path used when ``save`` is True.
    """
    display = {'cmap': cmap, 'vmin': -1350, 'vmax': 150}
    n_rows = max(data1.shape[0], 2)  # >= 2 rows keeps `plots` a 2D array of axes
    fig, plots = plt.subplots(n_rows, 2, figsize=(10, n_rows * 5))

    for idx in range(0, data1.shape[0], stride):
        heading = 'z=' + str(z_start + idx + 1)
        for column, volume in enumerate((data1, data2)):
            axis = plots[idx, column]
            axis.set_title(heading)
            axis.imshow(volume[idx], **display)
            axis.axis('off')

    if image_title is not None:
        fig.suptitle(str(image_title), fontsize=18)
    if save:
        check_and_makedirs(file_path, is_file=True)
        fig.savefig(file_path, dpi=90, bbox_inches='tight')
    if show:
        plt.show()
    elif save:
        plt.close()


def _hu_stats(values):
    """Return ``(min, rounded mean, max)`` of ``values`` as Python ints."""
    return (int(np.min(values)),
            int(np.round(np.mean(values))),
            int(np.max(values)))


def _hu_title(prefix, values, suffix=''):
    """Return ``'<prefix> (<min>/<mean>/<max><suffix>)'`` for the values in ``values``."""
    low, mean, high = _hu_stats(values)
    return prefix + ' (' + str(low) + '/' + str(mean) + '/' + str(high) + suffix + ')'


def _hu_bins(values):
    """Return unit-width histogram edges that give every integer value in ``values`` its own bin.

    Edges run from floor(min) to floor(max) + 1, so float input is accepted, a single-valued
    input still yields one bin, and the top two values are not merged into the closed last bin.
    """
    low = int(np.floor(np.min(values)))
    high = int(np.floor(np.max(values)))
    return list(range(low, high + 2))


def _hu_count_ceiling(values):
    """Return the largest per-value count in ``values`` rounded up to a multiple of 100 (0 if empty)."""
    if len(values) == 0:
        return 0
    _, counts = np.unique(values, return_counts=True)
    return np.ceil(counts.max() / 100) * 100


def _hu_hist_panel(ax, title, bins, layers, y_limit=None):
    """Title ``ax``, optionally fix its y-range to [0, y_limit], then draw each non-empty
    ``(values, color)`` layer as a seaborn count histogram (no KDE, no rug)."""
    ax.set_title(title)
    if y_limit is not None:
        ax.set_ylim([0, y_limit])
    for values, color in layers:
        if len(values) > 0:
            sns.set_palette('hls')
            sns.distplot(values.flatten(), bins=bins, color=color, kde=False, rug=False, ax=ax)


def _max_contour_diameter(slice_mask, pixel_spacing):
    """Return the longest min-area-rectangle side over all contours of the non-zero region of
    ``slice_mask``, multiplied by ``pixel_spacing`` and rounded to one decimal (0 if no contours)."""
    # [-2] selects the contour list under both the OpenCV 3.x (image, contours, hierarchy)
    # and OpenCV 4.x (contours, hierarchy) return conventions. A fresh uint8 copy is passed
    # because findContours needs 8-bit input and some versions modify the image in place.
    contours = cv2.findContours((slice_mask != 0).astype(np.uint8),
                                cv2.RETR_TREE, cv2.CHAIN_APPROX_NONE)[-2]
    max_diameter = 0
    for contour in contours:
        _, (width, height), _ = cv2.minAreaRect(contour)
        max_diameter = max(max_diameter, np.round(max(width, height) * pixel_spacing, 1))
    return max_diameter


def plot_image_and_mask(data, mask, truth=None, rect=None,
                        spacing=None, show_whole_hist=False,
                        z_start=0, stride=1, image_title=None,
                        cmap=plt.cm.gray, show=False, save=False, file_path=None):
    """Plot each slice of a volume with its mask outlines and a per-slice HU histogram.

    Columns per slice row:
      0. the raw slice;
      1. the slice with mask outlines (green = any label >= 1, red = label 1); when
         ``spacing`` is given the title adds point count, max diameter and volumes;
      2. (only if ``truth`` or ``rect`` is given) the ground truth: ``truth`` label-1 outline
         and/or the ``rect`` box;
      last. a histogram of the label-1 values on that slice.

    With ``show_whole_hist`` three extra rows compare the label-1 voxels (blue) with the
    expanded mask that also counts label-99 voxels (green), both for the slice with the most
    label-1 voxels and for the whole volume.

    Args:
        data: volume indexed as ``data[i]`` -> 2D slice; same shape as ``mask``.
        mask: label volume; label 1 is the nodule, label 99 is merged in for the expanded
            mask in the whole-volume histograms.
        truth: optional label volume; its label-1 outline is drawn in the ground-truth column.
        rect: optional ``((z0, z1), (r0, r1), (c0, c1))`` box drawn on slices ``z0 <= i < z1``.
        spacing: optional voxel spacing; the product of its three entries is the per-voxel
            volume and ``spacing[1]`` scales contour widths to diameters (assumes isotropic
            in-plane spacing — TODO confirm).
        show_whole_hist: add the three whole-volume histogram rows.
        z_start: offset added to slice indices in titles (titles show ``z_start + i + 1``).
        stride: step between drawn slice indices.
        image_title: optional figure-level title.
        cmap: colormap passed to ``imshow``.
        show: call ``plt.show()`` at the end.
        save: save the figure to ``file_path`` and close it unless ``show`` is also set.
        file_path: output path used when ``save`` is True.
    """
    vmin = -1350
    vmax = 150

    if np.sum(mask == 1) > 500:
        # Large nodule: narrow the display window to its value range +/-100, within the default.
        total_nodule_data = data[mask == 1]
        vmin = max(int(np.min(total_nodule_data)) - 100, -1350)
        vmax = min(int(np.max(total_nodule_data)) + 100, 150)
    if spacing is not None:
        single_volume = spacing[0] * spacing[1] * spacing[2]
        total_volume = int(np.sum(mask == 1) * single_volume)
    else:
        single_volume = 1
        total_volume = 0

    has_truth_col = truth is not None or rect is not None
    col = 4 if has_truth_col else 3
    row = max(data.shape[0], 2)  # >= 2 rows keeps `plots` a 2D array of axes
    if show_whole_hist:
        row += 3  # three extra rows for the whole-volume histograms
    fig, plots = plt.subplots(row, col, figsize=(col * 5, row * 5))

    for i in range(0, data.shape[0], stride):
        z = z_start + i + 1
        slice_mask = mask[i]

        # Raw slice.
        ax = plots[i, 0]
        ax.set_title('original z=' + str(z))
        ax.imshow(data[i], cmap=cmap, vmin=vmin, vmax=vmax)
        ax.axis('off')

        # Slice with mask outlines.
        ax = plots[i, 1]
        ax.set_title('contour z=' + str(z))
        if spacing is not None and np.any(slice_mask == 1):
            mask_point_count = np.sum(slice_mask == 1)
            max_diameter = _max_contour_diameter(slice_mask, spacing[1])
            volume = int(mask_point_count * single_volume)
            ax.set_title('contour z=' + str(z) + ' (p' + str(mask_point_count) + '/d' + str(max_diameter) + '/v'
                         + str(volume) + '/v' + str(total_volume) + ')')
        # Drawn independently of label 1 so slices holding only other labels still get outlined.
        if np.any(slice_mask >= 1):
            ax.contour(slice_mask >= 1, [0.5], colors='g', linewidths=1, alpha=0.75)
        if np.any(slice_mask == 1):
            ax.contour(slice_mask == 1, [0.5], colors='r', linewidths=1, alpha=0.75)
        ax.imshow(data[i], cmap=cmap, vmin=vmin, vmax=vmax)
        ax.axis('off')

        # Optional ground-truth column.
        if has_truth_col:
            ax = plots[i, 2]
            ax.set_title('ground truth z=' + str(z))
            if truth is not None and np.any(truth[i] == 1):
                ax.contour(truth[i] == 1, [0.5], colors='r', linewidths=1, alpha=0.75)
            if rect is not None and rect[0][0] <= i < rect[0][1]:
                ax.add_patch(patches.Rectangle((rect[2][0], rect[1][0]),
                                               rect[2][1] - rect[2][0],
                                               rect[1][1] - rect[1][0],
                                               linewidth=1, edgecolor='r', facecolor='none'))
            ax.imshow(data[i], cmap=cmap, vmin=vmin, vmax=vmax)
            ax.axis('off')

        # Histogram of the label-1 values on this slice.
        nodule_data = data[i][slice_mask == 1]
        if len(nodule_data) > 0:
            _hu_hist_panel(plots[i, col - 1], _hu_title('HU histogram z=' + str(z), nodule_data),
                           _hu_bins(nodule_data), [(nodule_data, 'blue')])

    if show_whole_hist:
        # Slice with the most label-1 voxels.
        z_max = np.argmax(np.sum(mask == 1, axis=(1, 2)))
        z_prefix = 'HU histogram z=' + str(z_start + z_max + 1)
        volume_suffix = '/v' + str(total_volume)

        # The expanded ("new") mask also counts voxels labelled 99.
        new_mask = mask.copy()
        new_mask[new_mask == 99] = 1
        nodule_data_z_max = data[z_max][mask[z_max] == 1]
        nodule_data_z_max_new = data[z_max][new_mask[z_max] == 1]
        nodule_data_total = data[mask == 1]
        nodule_data_total_new = data[new_mask == 1]

        # Every other set is a subset of nodule_data_total_new, so when it is empty there is
        # nothing to draw (computing the limits would otherwise fail on an empty max()).
        if len(nodule_data_total_new) > 0:
            y_max_z_max = _hu_count_ceiling(nodule_data_z_max_new)
            y_max = _hu_count_ceiling(nodule_data_total_new)
            bins = _hu_bins(nodule_data_total_new)

            # Row row-3: label-1 voxels only (blue).
            if len(nodule_data_z_max) > 0:
                _hu_hist_panel(plots[row - 3, col - 2], _hu_title(z_prefix, nodule_data_z_max),
                               bins, [(nodule_data_z_max, 'blue')], y_max_z_max)
            if len(nodule_data_total) > 0:
                _hu_hist_panel(plots[row - 3, col - 1],
                               _hu_title('HU histogram', nodule_data_total, volume_suffix),
                               bins, [(nodule_data_total, 'blue')], y_max)

            # Row row-2: expanded mask (green).
            if len(nodule_data_z_max_new) > 0:
                _hu_hist_panel(plots[row - 2, col - 2], _hu_title(z_prefix, nodule_data_z_max_new),
                               bins, [(nodule_data_z_max_new, 'green')], y_max_z_max)
            _hu_hist_panel(plots[row - 2, col - 1],
                           _hu_title('HU histogram', nodule_data_total_new, volume_suffix),
                           bins, [(nodule_data_total_new, 'green')], y_max)

            # Row row-1: expanded mask (green) overlaid with label-1 voxels (blue).
            if len(nodule_data_z_max_new) > 0:
                _hu_hist_panel(plots[row - 1, col - 2], z_prefix, bins,
                               [(nodule_data_z_max_new, 'green'), (nodule_data_z_max, 'blue')],
                               y_max_z_max)
            _hu_hist_panel(plots[row - 1, col - 1], 'HU histogram', bins,
                           [(nodule_data_total_new, 'green'), (nodule_data_total, 'blue')],
                           y_max)

    if image_title is not None:
        fig.suptitle(str(image_title), fontsize=18)
    if save:
        check_and_makedirs(file_path, is_file=True)
        fig.savefig(file_path, dpi=90, bbox_inches='tight')
        if not show:
            plt.close()
    if show:
        plt.show()


def plot_mask(data, mask, spacing=None, show_whole_hist=False, nodule_peak_info=None,
              z_start=0, stride=1, image_title=None,
              cmap=plt.cm.gray, show=False, save=False, file_path=None):
    """Plot per-value count curves of the nodule mask and its expanded version in a 4-row figure.

    Curves are drawn only when ``show_whole_hist`` is True; otherwise the figure holds four
    empty axes. Each curve plots, for every distinct value inside a mask, how many voxels have
    that value. Blue = voxels labelled 1 in ``mask``; green = voxels labelled 1 or 99.

    Rows: 0 blue curve with (min/mean/max/volume) in the title; 1 green and blue overlaid;
    2 and 3 repeat rows 0 and 1 and additionally mark the interval from ``nodule_peak_info``.

    Args:
        data: volume of HU values, same shape as ``mask``.
        mask: label volume; label 1 is the nodule, label 99 is merged in for the green curve.
        spacing: optional voxel spacing; the product of its three entries is the per-voxel
            volume used for the ``/v`` figure in the row-0 title (0 when None).
        show_whole_hist: draw the curves (see above).
        nodule_peak_info: optional sequence, read as ``[0]`` x position of a red marker (also
            printed), ``[1]`` a count that sets the marker height, ``[2]``/``[3]`` x bounds
            drawn as gray dashed lines in rows 2-3 when more than 50 apart, and ``[4]`` a fraction
            printed as a percentage between them. NOTE(review): semantics inferred only from
            how the values are drawn — confirm against the producer.
        z_start, stride, cmap: unused; kept for signature parity with the other plot helpers.
        image_title: optional figure-level title.
        show: call ``plt.show()`` at the end.
        save: save the figure to ``file_path`` and close it unless ``show`` is also set.
        file_path: output path used when ``save`` is True.
    """
    if spacing is not None:
        total_volume = int(np.sum(mask == 1) * (spacing[0] * spacing[1] * spacing[2]))
    else:
        total_volume = 0

    fig, plots = plt.subplots(4, 1, figsize=(8, 4 * 5))
    fig.subplots_adjust(left=0.5, top=0.5, hspace=1)

    if show_whole_hist:
        new_mask = mask.copy()
        new_mask[new_mask == 99] = 1
        nodule_data_total = data[mask == 1]
        nodule_data_total_new = data[new_mask == 1]

        # An empty expanded mask has nothing to plot; computing the limits below would fail
        # on the max() of an empty array.
        if len(nodule_data_total_new) > 0:
            _, counts = np.unique(nodule_data_total_new, return_counts=True)
            # 30% head-room above the tallest count, rounded up to a multiple of 100.
            y_max = np.ceil(1.3 * counts.max() / 100) * 100
            # x-range: -1000 up to the largest value in the expanded mask, but at least -200.
            x_range = [-1000, max(int(np.max(nodule_data_total_new)), -200)]

            def draw_panel(ax, title, layers, mark_interval):
                """Title and limit ``ax``, add the peak annotations, then plot each
                ``(values, color)`` layer as a value -> count curve."""
                ax.set_title(title)
                ax.set_ylim([0, y_max])
                ax.set_xlim(x_range)
                if nodule_peak_info is not None:
                    peak_x = nodule_peak_info[0]
                    # Marker height as a fraction of the axis: the peak count plus 10% head-room.
                    y_max_ratio = nodule_peak_info[1] / y_max + 0.1
                    ax.axvline(peak_x, ymax=y_max_ratio, color='red', ls='-', lw=0.5)
                    ax.text(peak_x - 30, y_max * y_max_ratio + 20,
                            str(peak_x), fontsize=8, color='r')
                    if mark_interval and nodule_peak_info[3] - nodule_peak_info[2] > 50:
                        ax.axvline(nodule_peak_info[2], color='gray', ls='--', lw=0.5)
                        ax.axvline(nodule_peak_info[3], color='gray', ls='--', lw=0.5)
                        ax.text(int((nodule_peak_info[2] + nodule_peak_info[3]) / 2) - 20, 10,
                                str(int(np.round(nodule_peak_info[4] * 100))) + '%',
                                fontsize=8, color='gray')
                for values, color in layers:
                    hu_values, hu_counts = np.unique(values, return_counts=True)
                    ax.plot(hu_values, hu_counts, color=color, linewidth=0.5)

            blue = (nodule_data_total, 'blue')
            green = (nodule_data_total_new, 'green')

            if len(nodule_data_total) > 0:
                z_min_density = int(np.min(nodule_data_total))
                z_average_density = int(np.round(np.mean(nodule_data_total)))
                z_max_density = int(np.max(nodule_data_total))
                draw_panel(plots[0],
                           'HU Curve (' + str(z_min_density) + '/' + str(z_average_density)
                           + '/' + str(z_max_density) + '/v' + str(total_volume) + ')',
                           [blue], mark_interval=False)
            draw_panel(plots[1], 'HU Curve', [green, blue], mark_interval=False)
            if len(nodule_data_total) > 0:
                draw_panel(plots[2], 'HU Curve', [blue], mark_interval=True)
            draw_panel(plots[3], 'HU Curve', [green, blue], mark_interval=True)

    if image_title is not None:
        fig.suptitle(str(image_title), fontsize=18)
    if save:
        check_and_makedirs(file_path, is_file=True)
        fig.savefig(file_path, dpi=90, bbox_inches='tight')
        if not show:
            plt.close()
    if show:
        plt.show()


def plot_hu_curve(data,
                  z_start=0, stride=1, image_title=None,
                  cmap=plt.cm.gray, show=False, save=False, file_path=None):
    """Pair each slice image with a line plot of its ``continuous_2d`` values above -200.

    Args:
        data: volume indexed as ``data[i]`` -> 2D slice; one grid row per slice.
        z_start: offset added to the slice index in titles (titles show ``z_start + i + 1``).
        stride: step between drawn slice indices; skipped rows stay empty.
        image_title: optional figure-level title.
        cmap: colormap for the slice image; the display window is fixed at [-1350, 150].
        show: call ``plt.show()`` at the end.
        save: save the figure to ``file_path`` and close it unless ``show`` is also set.
        file_path: output path used when ``save`` is True.
    """
    n_rows = max(data.shape[0], 2)  # >= 2 rows keeps `plots` a 2D array of axes
    fig, plots = plt.subplots(n_rows, 2, figsize=(2 * 5, n_rows * 5),
                              gridspec_kw={'width_ratios': [1, 8]})

    for idx in range(0, data.shape[0], stride):
        image_ax = plots[idx, 0]
        image_ax.set_title('original z=' + str(z_start + idx + 1))
        image_ax.imshow(data[idx], cmap=cmap, vmin=-1350, vmax=150)
        image_ax.axis('off')

        samples = continuous_2d(data[idx]).flatten()
        samples = samples[samples > -200]
        if len(samples) == 0:
            continue
        curve_ax = plots[idx, 1]
        curve_ax.set_title('HU Curve')
        curve_ax.set_xlim([0, len(samples)])
        curve_ax.plot(list(range(len(samples))), samples, color='blue', linewidth=0.5)

    if image_title is not None:
        fig.suptitle(str(image_title), fontsize=18)
    if save:
        check_and_makedirs(file_path, is_file=True)
        fig.savefig(file_path, dpi=90, bbox_inches='tight')
    if show:
        plt.show()
    elif save:
        plt.close()


def plot_hu_curve2(data,
                   z_start=0, stride=1, image_title=None,
                   cmap=plt.cm.gray, show=False, save=False, file_path=None):
    """Compare low-variance pixel maps at several kernel sizes next to each slice's HU curve.

    Each drawn row holds: the raw slice; ten overlays that mark in red the pixels whose
    ``var_2d`` value (for a given kernel size) is <= a threshold; and a wide panel plotting,
    in order, the ``continuous_2d`` values of the slice that are above -200.

    Args:
        data: volume indexed as ``data[i]`` -> 2D slice; one grid row per slice.
        z_start: offset added to the slice index in titles (titles show ``z_start + i + 1``).
        stride: step between drawn slice indices; skipped rows stay empty.
        image_title: optional figure-level title.
        cmap: colormap for the slice images; the display window is fixed at [-1350, 150].
        show: call ``plt.show()`` at the end.
        save: save the figure to ``file_path`` and close it unless ``show`` is also set.
        file_path: output path used when ``save`` is True.
    """
    vmin = -1350
    vmax = 150
    # (kernel_size, variance threshold) for the overlay panels, left to right.
    variance_settings = ((3, 9), (3, 18), (5, 25), (5, 50), (7, 49),
                         (7, 98), (9, 81), (9, 162), (11, 121), (11, 242))
    col = len(variance_settings) + 2  # raw slice + overlays + HU curve
    row = max(data.shape[0], 2)  # >= 2 rows keeps `plots` a 2D array of axes
    fig, plots = plt.subplots(row, col, figsize=(col * 5, row * 5),
                              gridspec_kw={'width_ratios': [1] * (col - 1) + [8]})
    for i in range(0, data.shape[0], stride):
        z = z_start + i + 1
        slice_data = data[i]

        ax = plots[i, 0]
        ax.set_title('original z=' + str(z))
        ax.imshow(slice_data, cmap=cmap, vmin=vmin, vmax=vmax)
        ax.axis('off')

        # Each kernel size is shared by two thresholds, so compute its variance map once
        # (assumes var_2d does not modify its input — TODO confirm).
        variance_maps = {}
        for panel, (kernel_size, var) in enumerate(variance_settings, start=1):
            if kernel_size not in variance_maps:
                variance_maps[kernel_size] = var_2d(slice_data, kernel_size=kernel_size)
            ax = plots[i, panel]
            ax.set_title('z=' + str(z) + ' k=' + str(kernel_size) + ' var<=' + str(var))
            change_coords = np.asarray(np.where(variance_maps[kernel_size] <= var))
            # Row indices are the y coordinates, column indices the x coordinates.
            ax.scatter(change_coords[1, :], change_coords[0, :], color='r', s=1, alpha=1)
            ax.imshow(slice_data, cmap=cmap, vmin=vmin, vmax=vmax)
            ax.axis('off')

        ax = plots[i, col - 1]
        nodule_data = continuous_2d(slice_data).flatten()
        nodule_data = nodule_data[nodule_data > -200]
        if len(nodule_data) > 0:
            ax.set_title('HU Curve')
            ax.set_xlim([0, len(nodule_data)])
            ax.plot(np.arange(len(nodule_data)), nodule_data, color='blue', linewidth=0.5)

    if image_title is not None:
        fig.suptitle(str(image_title), fontsize=18)
    if save:
        check_and_makedirs(file_path, is_file=True)
        fig.savefig(file_path, dpi=90, bbox_inches='tight')
        if not show:
            plt.close()
    if show:
        plt.show()


def plot_hu_3d(data):
    """Render a 2D array as a 3D surface with contour lines, then call ``plt.show()``.

    Args:
        data: 2D array; ``data[y, x]`` is used as the surface height at grid point (x, y).
    """
    fig = plt.figure()
    # ``Axes3D(fig)`` no longer attaches itself to the figure in recent Matplotlib
    # (deprecated in 3.4, auto-add removed in 3.6), which leaves the figure blank;
    # ``add_subplot(projection='3d')`` is the supported way to create 3D axes.
    ax = fig.add_subplot(111, projection='3d')
    height, width = data.shape[0], data.shape[1]
    xpos, ypos = np.meshgrid(np.linspace(0, width - 1, width),
                             np.linspace(0, height - 1, height))
    rainbow = plt.get_cmap('rainbow')
    ax.plot_surface(xpos, ypos, data, rstride=1, cstride=1, cmap=rainbow)
    ax.contour(xpos, ypos, data, zdir='z', cmap=rainbow)
    ax.set_xlabel('x')
    ax.set_ylabel('y')
    ax.set_zlabel('hu')
    plt.show()


def plot_hu_hist(nodule_data):
    """Histogram ``nodule_data`` on the current pyplot axes with unit bins over [-1000, -200].

    Args:
        nodule_data: array of HU values; it is flattened before binning.
    """
    edges = list(range(-1000, -199))  # -1000, -999, ..., -200
    plt.hist(nodule_data.flatten(), bins=edges, color='c', alpha=0.75)
    plt.xlabel('Hounsfield Units (HU)')
    plt.ylabel('Count')
    plt.show()


def plot_image_multi_mask(data, mask, mask1=None, mask2=None, mask3=None,
                          spacing=None,
                          z_start=0, stride=1, image_title=None,
                          cmap=plt.cm.gray, show=False, save=False, file_path=None):
    """Plot every ``stride``-th slice of ``data`` with up to four mask overlays.

    Column 0 ("original") always shows ``mask``. One further column is added,
    left to right, for each of ``mask1`` ("contour"), ``mask2`` ("extend") and
    ``mask3`` ("shrink") that is not None.

    Overlays per column:
      - every column: pixels equal to 1 are outlined in red;
      - "extend": pixels >= 1 are additionally outlined in green;
      - "extend"/"shrink": pixels where ``mask`` is exactly 1 higher than the
        column's mask are dotted blue, and pixels where it is exactly 1 lower
        are dotted green.

    Args:
        data: slice stack indexed along its first axis (assumed [D, H, W] HU
            values -- TODO confirm). Must have the same shape as ``mask``,
            because ``data[mask == 1]`` is used for the display window.
        mask: reference label volume.
        mask1, mask2, mask3: optional label volumes indexed like ``mask``.
        spacing: optional per-axis voxel spacing. When given, titles also show
            the slice volume and the whole-volume volume (point count times
            ``spacing[0] * spacing[1] * spacing[2]``).
        z_start: offset added to the slice index in titles (titles are 1-based).
        stride: step between plotted slices.
        image_title: optional figure suptitle.
        cmap: colormap used for the image.
        show: call ``plt.show()`` at the end.
        save: save the figure to ``file_path``, creating parent directories.
        file_path: destination used when ``save`` is true.
    """
    vmin = -1350
    vmax = 150

    # With enough masked voxels, narrow the display window to the masked HU
    # range widened by 100 on each side, still clipped to [-1350, 150].
    if np.sum(mask == 1) > 500:
        total_nodule_data = data[mask == 1]
        min_density = int(np.min(total_nodule_data))
        max_density = int(np.max(total_nodule_data))
        vmin = max(min_density - 100, -1350)
        vmax = min(max_density + 100, 150)
    if spacing is not None:
        single_volume = spacing[0] * spacing[1] * spacing[2]
    else:
        single_volume = 1

    # One column per supplied mask, packed left to right. Packing (rather than
    # fixed column numbers) lets e.g. mask3 be passed without mask1/mask2.
    # Entry: (title label, mask volume, green ">= 1" outline, mark +/-1 changes).
    panels = [('original', mask, False, False)]
    if mask1 is not None:
        panels.append(('contour', mask1, False, False))
    if mask2 is not None:
        panels.append(('extend', mask2, True, True))
    if mask3 is not None:
        panels.append(('shrink', mask3, False, True))

    col = len(panels)
    row = max(data.shape[0], 2)
    # squeeze=False keeps ``plots`` 2-D even when col == 1; otherwise
    # ``plots[i, 0]`` raises IndexError when only ``mask`` is given.
    fig, plots = plt.subplots(row, col, figsize=(col * 5, row * 5), squeeze=False)
    # Whole-volume totals do not depend on the slice, so compute them once.
    total_volumes = [int(np.sum(panel_mask == 1) * single_volume)
                     for _, panel_mask, _, _ in panels]

    for i in range(0, data.shape[0], stride):
        z = z_start + i + 1
        for j, (label, panel_mask, outline_extended, mark_changes) in enumerate(panels):
            ax = plots[i, j]
            mask_point_count = np.sum(panel_mask[i] == 1)
            if spacing is not None:
                volume = int(mask_point_count * single_volume)
                ax.set_title(label + ' z=' + str(z) + ' (p' + str(mask_point_count) + '/v'
                             + str(volume) + '/v' + str(total_volumes[j]) + ')')
            else:
                ax.set_title(label + ' z=' + str(z) + ' (p' + str(mask_point_count) + ')')

            if outline_extended and np.any(panel_mask[i] >= 1):
                ax.contour(panel_mask[i] >= 1, [0.5], colors='g', linewidths=1, alpha=0.75)
            if np.any(panel_mask[i] == 1):
                ax.contour(panel_mask[i] == 1, [0.5], colors='r', linewidths=1, alpha=0.75)
            if mark_changes:
                # Signed difference: the former ``mask - other == 255`` test
                # only detected "-1" through uint8 wrap-around and missed it
                # for signed, float or bool masks.
                change_mask = mask[i].astype(np.int32) - panel_mask[i].astype(np.int32)
                change_coords = np.asarray(np.where(change_mask == 1))
                ax.scatter(change_coords[1, :], change_coords[0, :], color='b', s=1, alpha=1)
                change_coords = np.asarray(np.where(change_mask == -1))
                ax.scatter(change_coords[1, :], change_coords[0, :], color='g', s=1, alpha=1)
            ax.imshow(data[i], cmap=cmap, vmin=vmin, vmax=vmax)
            ax.axis('off')

    if image_title is not None:
        fig.suptitle(str(image_title), fontsize=18)
    if save:
        check_and_makedirs(file_path, is_file=True)
        fig.savefig(file_path, dpi=90, bbox_inches='tight')
        if not show:
            # Close this figure explicitly rather than whatever is "current".
            plt.close(fig)
    if show:
        plt.show()


def generate_video_one_image(data, interval=200, file_path=None):
    """
    Given CT img, return an animation across axial slice
    data: [D,H,W] or [D,H,W,3]; one frame per entry along the first axis
    interval: delay between frames in milliseconds, default 200
    file_path: path to save the animation if not None, default None.
        Saved with the ffmpeg writer (bitrate 1800) at 1000 / interval fps,
        so the file plays at the same speed as the on-screen animation.

    return: matplotlib.animation.ArtistAnimation
    """
    fig = plt.figure()
    frames = []
    for slice_img in data:
        frames.append([plt.imshow(slice_img, animated=True)])
    anim = animation.ArtistAnimation(fig, frames, interval=interval, blit=True, repeat_delay=1000)
    if file_path:
        Writer = animation.writers['ffmpeg']
        # The writer used to be built but never passed to save(), so its
        # bitrate and metadata were ignored. Its fps matches ``interval``
        # (ms per frame), as save() itself would choose by default.
        writer = Writer(fps=1000.0 / interval, metadata=dict(artist='Me'), bitrate=1800)
        anim.save(file_path, writer=writer)

    return anim


def generate_video_two_image(data1, title1, data2, title2, interval=200, file_path=None):
    """Animate two slice stacks side by side, one slice per frame.

    Both panels use a grayscale window of [-1350, 150].

    Args:
        data1, data2: slice stacks indexed along the first axis. The frame
            count is ``len(data1)``, so ``data2`` needs at least that many slices.
        title1, title2: titles of the left and right panels.
        interval: delay between frames in milliseconds.
        file_path: if truthy, save with the ffmpeg writer (bitrate 1800) at
            1000 / interval fps, matching on-screen playback speed.

    Returns:
        The ``matplotlib.animation.ArtistAnimation``.
    """
    vmin = -1350
    vmax = 150
    fig, axes = plt.subplots(1, 2)
    # Titles are the same for every frame, so set them once.
    axes[0].set_title(title1)
    axes[1].set_title(title2)
    frames = []
    for i in range(len(data1)):
        img1 = axes[0].imshow(data1[i], cmap='gray', vmin=vmin, vmax=vmax, animated=True)
        img2 = axes[1].imshow(data2[i], cmap='gray', vmin=vmin, vmax=vmax, animated=True)
        frames.append([img1, img2])

    anim = animation.ArtistAnimation(fig, frames, interval=interval, blit=True, repeat_delay=1000)
    if file_path:
        Writer = animation.writers['ffmpeg']
        # Pass the configured writer to save() (it was previously unused);
        # fps follows ``interval`` so playback speed is unchanged.
        writer = Writer(fps=1000.0 / interval, metadata=dict(artist='Me'), bitrate=1800)
        anim.save(file_path, writer=writer)

    return anim


def generate_video_predict(data, mask, truth, interval=200, file_path=None):
    """Animate prediction vs. ground-truth overlays on the same slices.

    Left panel ('predict'): pixels where ``mask[i] == 1`` are dotted red.
    Right panel ('ground truth'): pixels where ``truth[i] == 1`` are dotted
    cyan. Both panels show ``data[i]`` in grayscale with a [-1350, 150] window.

    Args:
        data: slice stack indexed along the first axis; one frame per slice.
        mask: predicted label stack indexed like ``data``.
        truth: ground-truth label stack indexed like ``data``.
        interval: delay between frames in milliseconds.
        file_path: if truthy, save with the ffmpeg writer (bitrate 1800) at
            1000 / interval fps, matching on-screen playback speed.

    Returns:
        The ``matplotlib.animation.ArtistAnimation``.
    """
    vmin = -1350
    vmax = 150
    fig, axes = plt.subplots(1, 2)
    pred_ax, gt_ax = axes[0], axes[1]
    # Titles are the same for every frame, so set them once.
    pred_ax.set_title('predict')
    gt_ax.set_title('ground truth')
    frames = []
    for i in range(len(data)):
        coords = np.asarray(np.where(mask[i] == 1))
        predict_mask = pred_ax.scatter(coords[1, :], coords[0, :], color='r', s=1, alpha=0.5, animated=True)
        predict_img = pred_ax.imshow(data[i], cmap='gray', vmin=vmin, vmax=vmax, animated=True)

        coords = np.asarray(np.where(truth[i] == 1))
        gt_mask = gt_ax.scatter(coords[1, :], coords[0, :], color='c', s=1, alpha=0.5, animated=True)
        gt_img = gt_ax.imshow(data[i], cmap='gray', vmin=vmin, vmax=vmax, animated=True)
        frames.append([predict_img, predict_mask, gt_img, gt_mask])

    anim = animation.ArtistAnimation(fig, frames, interval=interval, blit=True, repeat_delay=1000)
    if file_path:
        Writer = animation.writers['ffmpeg']
        # Pass the configured writer to save() (it was previously unused);
        # fps follows ``interval`` so playback speed is unchanged.
        writer = Writer(fps=1000.0 / interval, metadata=dict(artist='Me'), bitrate=1800)
        anim.save(file_path, writer=writer)

    return anim


def plot_shrink_extend_mask(data, one_nodule_mask, result_mask, extend_mask, shrink_mask, spacing, file_path):
    """Save a figure comparing one nodule's mask with result/extend/shrink masks.

    The bounding box of voxels equal to 1 in ``one_nodule_mask`` is passed to
    ``get_crop_start_end``. The requested crop shape is (box depth + 4,
    side, side), where ``side`` is the larger in-plane box extent but at
    least 200. All five volumes are cropped to the same window and drawn by
    ``plot_image_multi_mask``, which saves the figure to ``file_path``.
    Returns without plotting when the mask contains no voxel equal to 1.
    """
    nodule_voxels = np.asarray(np.where(one_nodule_mask == 1))
    if len(nodule_voxels[0]) == 0:
        return

    bbox_lo = nodule_voxels.min(axis=1)
    bbox_hi = nodule_voxels.max(axis=1) + 1
    depth = bbox_hi[0] - bbox_lo[0] + 4
    side = max(bbox_hi[1] - bbox_lo[1], bbox_hi[2] - bbox_lo[2], 200)
    crop_start, crop_end = get_crop_start_end(bbox_lo, bbox_hi, (depth, side, side))

    # A single (z, y, x) window applied identically to every volume.
    window = tuple(slice(crop_start[axis], crop_end[axis]) for axis in range(3))

    plot_image_multi_mask(data[window],
                          one_nodule_mask[window],
                          result_mask[window],
                          extend_mask[window],
                          shrink_mask[window],
                          spacing=spacing,
                          z_start=crop_start[0],
                          stride=1,
                          show=False,
                          save=True,
                          file_path=file_path)


def plot_nodule_box_mask(data, one_nodule_mask, nodule_box, spacing, file_path):
    """Save a per-slice figure of one nodule's mask plus a box, cropped around it.

    The bounding box of voxels equal to 1 in ``one_nodule_mask`` is passed to
    ``get_crop_start_end``. The requested crop shape is (z_num, side, side):
      - ``z_num`` is the box depth + 4, enlarged to at least the pixel length
        of ``nodule_box.diameter`` (via ``get_diameter_pixel_length`` with
        ``spacing[0]``) + 4 when ``nodule_box`` is given;
      - ``side`` is the larger in-plane box extent, but at least 200.

    The box passed as ``rect`` comes from ``get_nodule_rect(nodule_box, spacing)``
    when ``nodule_box`` is given, otherwise from the mask's own bounding box.
    It is a (3, 2) array of [start, end) per axis, shifted into crop coordinates.
    Returns without plotting when the mask contains no voxel equal to 1.

    Args:
        data: volume to display (assumed [D, H, W] HU values -- TODO confirm).
        one_nodule_mask: label volume aligned with ``data``.
        nodule_box: optional object with a ``diameter`` attribute, or None.
        spacing: voxel spacing, forwarded to the helpers and the plot.
        file_path: where the figure is saved.
    """
    coords = np.asarray(np.where(one_nodule_mask == 1))
    if len(coords[0]) == 0:
        return

    coord_start = coords.min(axis=1)
    coord_end = coords.max(axis=1) + 1
    z_num = coord_end[0] - coord_start[0] + 4
    if nodule_box is not None:
        z_pixel = get_diameter_pixel_length(nodule_box.diameter, spacing[0])
        z_num = max(z_num, z_pixel + 4)
    image_height = max(coord_end[1] - coord_start[1], coord_end[2] - coord_start[2], 200)
    crop_start, crop_end = get_crop_start_end(coord_start, coord_end, (z_num, image_height, image_height))

    # Same (z, y, x) window for the image and the mask.
    crop = tuple(slice(crop_start[axis], crop_end[axis]) for axis in range(3))
    show_data = data[crop]
    show_mask = one_nodule_mask[crop]

    if nodule_box is not None:
        box = get_nodule_rect(nodule_box, spacing)
        box[:, 0] = box[:, 0] - crop_start
        box[:, 1] = box[:, 1] - crop_start
    else:
        # ``np.int`` was removed in NumPy 1.24 (AttributeError); the builtin
        # ``int`` is the dtype it aliased.
        box = np.zeros((3, 2), dtype=int)
        box[:, 0] = coord_start - crop_start
        box[:, 1] = coord_end - crop_start
    plot_image_and_mask(show_data,
                        show_mask,
                        rect=box,
                        spacing=spacing,
                        show_whole_hist=True,
                        z_start=crop_start[0],
                        stride=1,
                        show=False,
                        save=True,
                        file_path=file_path)