import copy
import torch
from torch.utils.data import Dataset
import pandas as pd
import numpy as np
from functools import lru_cache


def custom_collate_fn_2d(batch):
    """Merge per-sample 2D dataset outputs into one flat batch.

    Args:
        batch: sequence of 5-item samples
            ``(label_id, npy_data_2d, label, z_index, rotate_count)``.
            The metadata entries are iterables; the data and label
            entries are tensors that can be concatenated on dim 0.

    Returns:
        A 5-tuple ``(label_id, npy_data_2d, label, z_index, rotate_count)``
        where the metadata entries are flattened into single lists and the
        tensors are concatenated along dim 0.
    """
    id_groups, data_chunks, label_chunks, z_groups, rotate_groups = zip(*batch)

    flat_ids = []
    for group in id_groups:
        flat_ids.extend(group)

    flat_z = []
    for group in z_groups:
        flat_z.extend(group)

    flat_rotate = []
    for group in rotate_groups:
        flat_rotate.extend(group)

    stacked_data = torch.cat(data_chunks, dim=0)
    stacked_label = torch.cat(label_chunks, dim=0)
    return flat_ids, stacked_data, stacked_label, flat_z, flat_rotate

def custom_collate_fn_2d_error(batch):
    """Placeholder collate function; not implemented.

    Ignores ``batch`` and returns ``None``.
    """
    return None


def custom_collate_fn_3d(batch):
    """Merge per-sample 3D dataset outputs into one flat batch.

    Args:
        batch: sequence of 5-item samples
            ``(label_id, npy_data_3d, label, z_index, rotate_count)``.

    Returns:
        ``(label_id, npy_data_3d, label, z_index, rotate_count)`` with the
        three metadata fields flattened into lists and both tensors
        concatenated along dim 0.
    """
    id_groups, volumes, targets, z_groups, rotate_groups = zip(*batch)

    # Flatten the three metadata columns in their original order.
    flattened = []
    for column in (id_groups, z_groups, rotate_groups):
        merged = []
        for group in column:
            for entry in group:
                merged.append(entry)
        flattened.append(merged)
    label_id, z_index, rotate_count = flattened

    return (
        label_id,
        torch.cat(volumes, dim=0),
        torch.cat(targets, dim=0),
        z_index,
        rotate_count,
    )

def custom_collate_fn_3d_error(batch):
    """Placeholder collate function; not implemented.

    Ignores ``batch`` and returns ``None``.
    """
    return None


def custom_collate_fn_2d3d(batch):
    """Merge paired 2D/3D samples into one flat batch.

    Args:
        batch: sequence of 10-item samples laid out as the five 2D fields
            ``(label_id, data, label, z_index, rotate_count)`` followed by
            the same five fields for the 3D sample.

    Returns:
        A 10-tuple in the same field order. Metadata fields are flattened
        into lists; data and label tensors are concatenated along dim 0.
    """
    (
        ids_2d, chunks_2d, targets_2d, z_2d, rotate_2d,
        ids_3d, chunks_3d, targets_3d, z_3d, rotate_3d,
    ) = zip(*batch)

    def _flatten(groups):
        # Collapse one level of nesting, keeping the sample order.
        merged = []
        for group in groups:
            merged.extend(group)
        return merged

    flat_ids_2d = _flatten(ids_2d)
    flat_ids_3d = _flatten(ids_3d)
    flat_z_2d = _flatten(z_2d)
    flat_z_3d = _flatten(z_3d)
    flat_rotate_2d = _flatten(rotate_2d)
    flat_rotate_3d = _flatten(rotate_3d)

    data_2d = torch.cat(chunks_2d, dim=0)
    data_3d = torch.cat(chunks_3d, dim=0)
    labels_2d = torch.cat(targets_2d, dim=0)
    labels_3d = torch.cat(targets_3d, dim=0)

    return (
        flat_ids_2d, data_2d, labels_2d, flat_z_2d, flat_rotate_2d,
        flat_ids_3d, data_3d, labels_3d, flat_z_3d, flat_rotate_3d,
    )

def custom_collate_fn_2d3d_error(batch):
    """Placeholder collate function; not implemented.

    Ignores ``batch`` and returns ``None``.
    """
    return None

def custom_collate_fn_s3d(batch):
    """Merge per-sample S3D dataset outputs into one flat batch.

    Args:
        batch: sequence of 5-item samples
            ``(label_id, npy_data, label, z_index, rotate_count)``.

    Returns:
        The same five fields, with metadata flattened into lists and the
        data/label tensors concatenated along dim 0.
    """
    per_sample_ids, per_sample_data, per_sample_label, per_sample_z, per_sample_rotate = zip(*batch)

    merged_ids = []
    for ids in per_sample_ids:
        for one_id in ids:
            merged_ids.append(one_id)

    merged_z = []
    for zs in per_sample_z:
        for one_z in zs:
            merged_z.append(one_z)

    merged_rotate = []
    for rotates in per_sample_rotate:
        for one_rotate in rotates:
            merged_rotate.append(one_rotate)

    batched_data = torch.cat(per_sample_data, dim=0)
    batched_label = torch.cat(per_sample_label, dim=0)
    return merged_ids, batched_data, batched_label, merged_z, merged_rotate

def custom_collate_fn_s3d_error(batch):
    """Placeholder collate function; not implemented.

    Ignores ``batch`` and returns ``None``.
    """
    return None


def custom_collate_fn_resnet3d(batch):
    """Merge per-sample ResNet3D dataset outputs into one flat batch.

    Args:
        batch: sequence of 5-item samples
            ``(label_id, npy_data, label, z_index, rotate_count)``.

    Returns:
        The same five fields, with metadata flattened into lists and the
        data/label tensors concatenated along dim 0.
    """
    id_parts, data_parts, label_parts, z_parts, rotate_parts = zip(*batch)

    all_ids = []
    for part in id_parts:
        all_ids += part

    all_z = []
    for part in z_parts:
        all_z += part

    all_rotate = []
    for part in rotate_parts:
        all_rotate += part

    return (
        all_ids,
        torch.cat(data_parts, dim=0),
        torch.cat(label_parts, dim=0),
        all_z,
        all_rotate,
    )


def custom_collate_fn_resnet3d_error(batch):
    """Placeholder collate function; not implemented.

    Ignores ``batch`` and returns ``None``.
    """
    return None



def custom_collate_fn_d2d(batch):
    """Merge per-sample D2D dataset outputs into one flat batch.

    Args:
        batch: sequence of 5-item samples
            ``(label_id, npy_data_2d, label, z_index, rotate_count)``.

    Returns:
        The same five fields, with metadata flattened into lists and the
        data/label tensors concatenated along dim 0.
    """
    nested_ids, image_chunks, target_chunks, nested_z, nested_rotate = zip(*batch)

    def _unnest(nested):
        # One level of flattening; sample order is preserved.
        return list(item for inner in nested for item in inner)

    label_id = _unnest(nested_ids)
    z_index = _unnest(nested_z)
    rotate_count = _unnest(nested_rotate)

    images = torch.cat(image_chunks, dim=0)
    targets = torch.cat(target_chunks, dim=0)
    return label_id, images, targets, z_index, rotate_count

def custom_collate_fn_d2d_error(batch):
    """Placeholder collate function; not implemented.

    Ignores ``batch`` and returns ``None``.
    """
    return None


class ClassificationDataset2d(Dataset):
    """Map-style dataset that yields one 2D ``.npy`` sample per CSV row.

    The CSV is read once in ``__init__``; each row names the ``.npy`` file to
    load plus the node, label id, class label, z index and rotate count.

    Args:
        csv_file: path of the CSV index file (UTF-8, first row is the header).
        node_column: column holding the node identifier.
        label_id_column: column holding the label id.
        npy_file_column: column holding the ``.npy`` file path.
        class_column: column holding the class label.
        z_index_column: column holding the z index.
        rotate_column: column holding the rotate count.
        data_info: tag printed together with the dataset length.
    """

    def __init__(self, csv_file, node_column="node", label_id_column="label_id", npy_file_column="npy_file", class_column="label", z_index_column = "z_index", rotate_column = "rotate_count", data_info="train_dataset"):
        self.df = pd.read_csv(csv_file, header=0, encoding="utf-8")
        print(f"{data_info} lens: {len(self.df)}")
        self.node_column = node_column
        self.label_id_column = label_id_column
        self.npy_file_column = npy_file_column
        self.class_column = class_column
        self.z_index_column = z_index_column
        self.rotate_column = rotate_column
        self.lens = len(self.df)

    def __len__(self):
        """Return the number of CSV rows."""
        return self.lens

    @staticmethod
    @lru_cache(maxsize=10)
    def _load_npy_cached(npy_data_path):
        # Cached on the path only. Caching the instance method instead would
        # put ``self`` in every cache key and keep dataset objects alive.
        data = np.load(npy_data_path)  # original note: [256, 256]
        return data[np.newaxis, np.newaxis, :, :]  # two leading singleton axes

    def load_npy_data(self, npy_data_path):
        """Load ``npy_data_path`` with two leading singleton axes added.

        The returned array is shared with an LRU cache of the 10 most
        recently used paths and must be treated as read-only.
        """
        return self._load_npy_cached(npy_data_path)

    def __getitem__(self, idx):
        """Return ``([label_id_str], data, label, [z_index], [rotate_count])``.

        ``label_id_str`` is ``"{node}_{label_id}"``; ``data`` and ``label`` are
        float32 tensors, with ``label`` holding a single element.
        """
        npy_file = self.df.loc[idx, self.npy_file_column]
        node = self.df.loc[idx, self.node_column]
        label_id = self.df.loc[idx, self.label_id_column]
        label = self.df.loc[idx, self.class_column]
        z_index = self.df.loc[idx, self.z_index_column]
        rotate_count = self.df.loc[idx, self.rotate_column]
        npy_data = self.load_npy_data(npy_file)

        label_id_str = f"{node}_{label_id}"
        # copy=True: never hand out a tensor that aliases the cached array,
        # otherwise an in-place op on a sample would corrupt later reads.
        data_2d = torch.from_numpy(npy_data).to(torch.float32, copy=True)
        label = torch.tensor([label], dtype=torch.float32)
        return [label_id_str], data_2d, label, [z_index], [rotate_count]


class ClassificationDatasetError2d(Dataset):
    """Placeholder ``Dataset`` subclass; defines no members of its own."""


class ClassificationDataset3d(Dataset):
    """Map-style dataset that yields one 3D ``.npy`` volume per CSV row.

    The CSV is read once in ``__init__``; each row names the ``.npy`` file to
    load plus the node, label id, class label, z index and rotate count.

    Args:
        csv_file: path of the CSV index file (UTF-8, first row is the header).
        node_column: column holding the node identifier.
        label_id_column: column holding the label id.
        npy_file_column: column holding the ``.npy`` file path.
        class_column: column holding the class label.
        z_index_column: column holding the z index.
        rotate_column: column holding the rotate count.
        data_info: tag printed together with the dataset length.
    """

    def __init__(self, csv_file, node_column="node", label_id_column="label_id", npy_file_column="npy_file", class_column="label", z_index_column = "z_index", rotate_column = "rotate_count", data_info="train_dataset"):
        self.df = pd.read_csv(csv_file, header=0, encoding="utf-8")
        print(f"{data_info} lens: {len(self.df)}")
        self.node_column = node_column
        self.label_id_column = label_id_column
        self.npy_file_column = npy_file_column
        self.class_column = class_column
        self.z_index_column = z_index_column
        self.rotate_column = rotate_column
        self.lens = len(self.df)

    def __len__(self):
        """Return the number of CSV rows."""
        return self.lens

    @staticmethod
    @lru_cache(maxsize=10)
    def _load_npy_cached(npy_data_path):
        # Cached on the path only. Caching the instance method instead would
        # put ``self`` in every cache key and keep dataset objects alive.
        data = np.load(npy_data_path)  # original note: [48, 256, 256]
        print(f"npy_data_path: {npy_data_path}, data.shape: {data.shape}")
        return data[np.newaxis, np.newaxis, :, :, :]  # two leading singleton axes

    def load_npy_data(self, npy_data_path):
        """Load ``npy_data_path`` with two leading singleton axes added.

        The path and loaded shape are printed on a cache miss. The returned
        array is shared with an LRU cache of the 10 most recently used paths
        and must be treated as read-only.
        """
        return self._load_npy_cached(npy_data_path)

    def __getitem__(self, idx):
        """Return ``([label_id_str], data, label, [z_index], [rotate_count])``.

        ``label_id_str`` is ``"{node}_{label_id}"``; ``data`` and ``label`` are
        float32 tensors, with ``label`` holding a single element.
        """
        npy_file = self.df.loc[idx, self.npy_file_column]
        node = self.df.loc[idx, self.node_column]
        label_id = self.df.loc[idx, self.label_id_column]
        label = self.df.loc[idx, self.class_column]
        z_index = self.df.loc[idx, self.z_index_column]
        rotate_count = self.df.loc[idx, self.rotate_column]
        npy_data = self.load_npy_data(npy_file)

        label_id_str = f"{node}_{label_id}"
        # copy=True: never hand out a tensor that aliases the cached array,
        # otherwise an in-place op on a sample would corrupt later reads.
        data_3d = torch.from_numpy(npy_data).to(torch.float32, copy=True)
        label = torch.tensor([label], dtype=torch.float32)

        return [label_id_str], data_3d, label, [z_index], [rotate_count]


class ClassificationDatasetError3d(Dataset):
    """Placeholder ``Dataset`` subclass; defines no members of its own."""

class ClassificationDataset2d3d(Dataset):
    """Map-style dataset pairing one 2D sample with one 3D sample per index.

    Two CSV index files are read in ``__init__``. The dataset length is the
    larger of the two row counts; the shorter table is reused cyclically
    (``idx % len``) so every index yields both a 2D and a 3D sample.

    Args:
        data_2d_csv_file: CSV index of the 2D samples.
        data_3d_csv_file: CSV index of the 3D samples.
        label_id_column: column holding the label id.
        npy_file_column: column holding the ``.npy`` file path.
        class_column: column holding the class label.
        node_column: column holding the node identifier.
        z_index_column: column holding the z index.
        rotate_column: column holding the rotate count.
        data_info: tag printed together with both table lengths.
    """

    def __init__(self, data_2d_csv_file, data_3d_csv_file, label_id_column="label_id", npy_file_column="npy_file", class_column="label", node_column="node", z_index_column = "z_index", rotate_column = "rotate_count", data_info="train_dataset"):
        self.df_2d = pd.read_csv(data_2d_csv_file, header=0, encoding="utf-8")
        self.df_3d = pd.read_csv(data_3d_csv_file, header=0, encoding="utf-8")
        print(f"{data_info} lens: {len(self.df_2d)}, {len(self.df_3d)}")
        self.label_id_column = label_id_column
        self.npy_file_column = npy_file_column
        self.node_column = node_column
        self.class_column = class_column
        self.z_index_column = z_index_column
        self.rotate_column = rotate_column
        self.lens_data_2d = len(self.df_2d)
        self.lens_data_3d = len(self.df_3d)
        self.lens = max(self.lens_data_2d, self.lens_data_3d)

    def __len__(self):
        """Return the larger of the two table lengths."""
        return self.lens

    @staticmethod
    @lru_cache(maxsize=10)
    def _load_npy_cached(npy_data_path):
        # Cached on the path only. Caching the instance method instead would
        # put ``self`` in every cache key and keep dataset objects alive.
        print(f"npy_data_path: {npy_data_path}")
        return np.load(npy_data_path)

    def load_npy_data(self, npy_data_path):
        """Load ``npy_data_path`` as stored, with no axes added.

        The path is printed on a cache miss. The returned array is shared
        with an LRU cache of the 10 most recently used paths and must be
        treated as read-only.
        """
        return self._load_npy_cached(npy_data_path)

    def __getitem__(self, idx):
        """Return the five 2D fields followed by the five 3D fields.

        Each group is ``([label_id_str], data, label, [z_index],
        [rotate_count])`` with float32 tensors for ``data`` and ``label``.
        """
        # Wrap around the shorter table so both modalities always have a row.
        index_data_2d = idx if idx < self.lens_data_2d else idx % self.lens_data_2d
        index_data_3d = idx if idx < self.lens_data_3d else idx % self.lens_data_3d
        npy_data_path_2d = self.df_2d.loc[index_data_2d, self.npy_file_column]
        npy_data_path_3d = self.df_3d.loc[index_data_3d, self.npy_file_column]
        node_2d = self.df_2d.loc[index_data_2d, self.node_column]
        node_3d = self.df_3d.loc[index_data_3d, self.node_column]
        label_id_2d = self.df_2d.loc[index_data_2d, self.label_id_column]
        label_id_3d = self.df_3d.loc[index_data_3d, self.label_id_column]
        label_2d = self.df_2d.loc[index_data_2d, self.class_column]
        label_3d = self.df_3d.loc[index_data_3d, self.class_column]
        data_2d_z_index = self.df_2d.loc[index_data_2d, self.z_index_column]
        data_2d_rotate_count = self.df_2d.loc[index_data_2d, self.rotate_column]
        data_3d_z_index = self.df_3d.loc[index_data_3d, self.z_index_column]
        data_3d_rotate_count = self.df_3d.loc[index_data_3d, self.rotate_column]
        npy_data_2d = self.load_npy_data(npy_data_path_2d)
        npy_data_3d = self.load_npy_data(npy_data_path_3d)
        # Original notes: [1, 1, 256, 256] and [1, 1, 48, 256, 256].
        npy_data_2d = npy_data_2d[np.newaxis, np.newaxis, :, :]
        npy_data_3d = npy_data_3d[np.newaxis, np.newaxis, :, :, :]

        label_id_2d_str = f"{node_2d}_{label_id_2d}"
        label_id_3d_str = f"{node_3d}_{label_id_3d}"

        # copy=True: never hand out tensors that alias the cached arrays,
        # otherwise an in-place op on a sample would corrupt later reads.
        data_2d = torch.from_numpy(npy_data_2d).to(torch.float32, copy=True)
        data_3d = torch.from_numpy(npy_data_3d).to(torch.float32, copy=True)
        label_2d = torch.tensor([label_2d], dtype=torch.float32)
        label_3d = torch.tensor([label_3d], dtype=torch.float32)

        return [label_id_2d_str], data_2d, label_2d, [data_2d_z_index], [data_2d_rotate_count], [label_id_3d_str], data_3d, label_3d, [data_3d_z_index], [data_3d_rotate_count]



class ClassificationDatasetError2d3d(Dataset):
    """Placeholder ``Dataset`` subclass; defines no members of its own."""



class ClassificationDatasetS3d(Dataset):
    """Map-style dataset that yields one 3D ``.npy`` volume per CSV row.

    The CSV is read once in ``__init__``; each row names the ``.npy`` file to
    load plus the node, label id, class label, z index and rotate count.

    Args:
        csv_file: path of the CSV index file (UTF-8, first row is the header).
        node_column: column holding the node identifier.
        label_id_column: column holding the label id.
        npy_file_column: column holding the ``.npy`` file path.
        class_column: column holding the class label.
        z_index_column: column holding the z index.
        rotate_column: column holding the rotate count.
        data_info: tag printed together with the dataset length.
    """

    def __init__(self, csv_file, node_column="node", label_id_column="label_id", npy_file_column="npy_file", class_column="label", z_index_column = "z_index", rotate_column = "rotate_count", data_info="train_dataset"):
        self.df = pd.read_csv(csv_file, header=0, encoding="utf-8")
        print(f"{data_info} lens: {len(self.df)}")
        self.node_column = node_column
        self.label_id_column = label_id_column
        self.npy_file_column = npy_file_column
        self.class_column = class_column
        self.z_index_column = z_index_column
        self.rotate_column = rotate_column
        self.lens = len(self.df)

    def __len__(self):
        """Return the number of CSV rows."""
        return self.lens

    @staticmethod
    @lru_cache(maxsize=10)
    def _load_npy_cached(npy_data_path):
        # Cached on the path only. Caching the instance method instead would
        # put ``self`` in every cache key and keep dataset objects alive.
        npy_data = np.load(npy_data_path)
        print(f"npy_data_path: {npy_data_path}, npy_data.shape: {npy_data.shape}")
        # Original note: result is [1, 1, 48, 256, 256].
        return npy_data[np.newaxis, np.newaxis, :, :, :]

    def load_npy_data(self, npy_data_path):
        """Load ``npy_data_path`` with two leading singleton axes added.

        The path and loaded shape are printed on a cache miss. The returned
        array is shared with an LRU cache of the 10 most recently used paths
        and must be treated as read-only.
        """
        return self._load_npy_cached(npy_data_path)

    def __getitem__(self, idx):
        """Return ``([label_id_str], data, label, [z_index], [rotate_count])``.

        ``label_id_str`` is ``"{node}_{label_id}"``; ``data`` and ``label`` are
        float32 tensors, with ``label`` holding a single element.
        """
        npy_file = self.df.loc[idx, self.npy_file_column]
        node = self.df.loc[idx, self.node_column]
        label_id = self.df.loc[idx, self.label_id_column]
        label = self.df.loc[idx, self.class_column]
        z_index = self.df.loc[idx, self.z_index_column]
        rotate_count = self.df.loc[idx, self.rotate_column]
        npy_data = self.load_npy_data(npy_file)

        label_id_str = f"{node}_{label_id}"
        # copy=True: never hand out a tensor that aliases the cached array,
        # otherwise an in-place op on a sample would corrupt later reads.
        data_3d = torch.from_numpy(npy_data).to(torch.float32, copy=True)
        label = torch.tensor([label], dtype=torch.float32)
        return [label_id_str], data_3d, label, [z_index], [rotate_count]





class ClassificationDatasetErrorS3d(Dataset):
    """Placeholder ``Dataset`` subclass; defines no members of its own."""


class ClassificationDatasetResnet3d(Dataset):
    """Map-style dataset that yields one 3D ``.npy`` volume per CSV row.

    The CSV is read once in ``__init__``; each row names the ``.npy`` file to
    load plus the node, label id, class label, z index and rotate count.

    Args:
        csv_file: path of the CSV index file (UTF-8, first row is the header).
        node_column: column holding the node identifier.
        label_id_column: column holding the label id.
        npy_file_column: column holding the ``.npy`` file path.
        class_column: column holding the class label.
        z_index_column: column holding the z index.
        rotate_column: column holding the rotate count.
        data_info: tag printed together with the dataset length.
    """

    def __init__(self, csv_file, node_column="node", label_id_column="label_id", npy_file_column="npy_file", class_column="label", z_index_column = "z_index", rotate_column = "rotate_count", data_info="train_dataset"):
        self.df = pd.read_csv(csv_file, header=0, encoding="utf-8")
        print(f"{data_info} lens: {len(self.df)}")
        self.node_column = node_column
        self.label_id_column = label_id_column
        self.npy_file_column = npy_file_column
        self.class_column = class_column
        self.z_index_column = z_index_column
        self.rotate_column = rotate_column
        self.lens = len(self.df)

    def __len__(self):
        """Return the number of CSV rows."""
        return self.lens

    @staticmethod
    @lru_cache(maxsize=10)
    def _load_npy_cached(npy_data_path):
        # Cached on the path only. Caching the instance method instead would
        # put ``self`` in every cache key and keep dataset objects alive.
        npy_data = np.load(npy_data_path)
        # Original note: result is [1, 1, 48, 256, 256].
        return npy_data[np.newaxis, np.newaxis, :, :, :]

    def load_npy_data(self, npy_data_path):
        """Load ``npy_data_path`` with two leading singleton axes added.

        The returned array is shared with an LRU cache of the 10 most
        recently used paths and must be treated as read-only.
        """
        return self._load_npy_cached(npy_data_path)

    def __getitem__(self, idx):
        """Return ``([label_id_str], data, label, [z_index], [rotate_count])``.

        ``label_id_str`` is ``"{node}_{label_id}"``; ``data`` and ``label`` are
        float32 tensors, with ``label`` holding a single element.
        """
        npy_file = self.df.loc[idx, self.npy_file_column]
        node = self.df.loc[idx, self.node_column]
        label_id = self.df.loc[idx, self.label_id_column]
        label = self.df.loc[idx, self.class_column]
        z_index = self.df.loc[idx, self.z_index_column]
        rotate_count = self.df.loc[idx, self.rotate_column]
        npy_data = self.load_npy_data(npy_file)
        label_id_str = f"{node}_{label_id}"
        # copy=True: never hand out a tensor that aliases the cached array,
        # otherwise an in-place op on a sample would corrupt later reads.
        data_3d = torch.from_numpy(npy_data).to(torch.float32, copy=True)
        label = torch.tensor([label], dtype=torch.float32)
        return [label_id_str], data_3d, label, [z_index], [rotate_count]


class ClassificationDatasetErrorResnet3d(Dataset):
    """Placeholder ``Dataset`` subclass; defines no members of its own."""




class ClassificationDatasetD2d(Dataset):
    """Map-style dataset that yields one multi-channel 2D sample per CSV row.

    The CSV is read once in ``__init__``; each row names the ``.npy`` file to
    load plus the node, label id, class label, z index and rotate count.

    Args:
        csv_file: path of the CSV index file (UTF-8, first row is the header).
        node_column: column holding the node identifier.
        label_id_column: column holding the label id.
        npy_file_column: column holding the ``.npy`` file path.
        class_column: column holding the class label.
        z_index_column: column holding the z index.
        rotate_column: column holding the rotate count.
        data_info: tag printed together with the dataset length.
    """

    def __init__(self, csv_file, node_column="node", label_id_column="label_id", npy_file_column="npy_file", class_column="label", z_index_column = "z_index", rotate_column = "rotate_count", data_info="train_dataset"):
        self.df = pd.read_csv(csv_file, header=0, encoding="utf-8")
        print(f"{data_info} lens: {len(self.df)}")
        self.node_column = node_column
        self.label_id_column = label_id_column
        self.npy_file_column = npy_file_column
        self.class_column = class_column
        self.z_index_column = z_index_column
        self.rotate_column = rotate_column
        self.lens = len(self.df)

    def __len__(self):
        """Return the number of CSV rows."""
        return self.lens

    @staticmethod
    @lru_cache(maxsize=10)
    def _load_npy_cached(npy_data_path):
        # Cached on the path only. Caching the instance method instead would
        # put ``self`` in every cache key and keep dataset objects alive.
        npy_data = np.load(npy_data_path)  # original note: [3, 256, 256]
        print(f"npy_data_path: {npy_data_path}, npy_data.shape: {npy_data.shape}")
        return npy_data[np.newaxis, :, :, :]  # one leading singleton axis

    def load_npy_data(self, npy_data_path):
        """Load ``npy_data_path`` with one leading singleton axis added.

        The path and loaded shape are printed on a cache miss. The returned
        array is shared with an LRU cache of the 10 most recently used paths
        and must be treated as read-only.
        """
        return self._load_npy_cached(npy_data_path)

    def __getitem__(self, idx):
        """Return ``([label_id_str], data, label, [z_index], [rotate_count])``.

        ``label_id_str`` is ``"{node}_{label_id}"``; ``data`` and ``label`` are
        float32 tensors, with ``label`` holding a single element.
        """
        npy_file = self.df.loc[idx, self.npy_file_column]
        node_time = self.df.loc[idx, self.node_column]
        label_id = self.df.loc[idx, self.label_id_column]
        label = self.df.loc[idx, self.class_column]
        z_index = self.df.loc[idx, self.z_index_column]
        rotate_count = self.df.loc[idx, self.rotate_column]
        npy_data = self.load_npy_data(npy_file)
        label_id_str = f"{node_time}_{label_id}"
        # copy=True: never hand out a tensor that aliases the cached array,
        # otherwise an in-place op on a sample would corrupt later reads.
        data_2d = torch.from_numpy(npy_data).to(torch.float32, copy=True)
        label = torch.tensor([label], dtype=torch.float32)
        return [label_id_str], data_2d, label, [z_index], [rotate_count]


class ClassificationDatasetErrorD2d(Dataset):
    """Placeholder ``Dataset`` subclass; defines no members of its own."""







def cls_report_dict_to_string(report_dict):
    """Render a classification-report dictionary as an aligned text table.

    Args:
        report_dict: mapping of row name to metrics. Per-class and average
            rows map to a dict with ``precision``, ``recall``, ``f1-score``
            and ``support``; the optional ``accuracy`` entry is a plain
            number. This looks like the ``output_dict=True`` form of
            scikit-learn's ``classification_report`` - confirm with caller.

    Returns:
        A multi-line string: header, one row per class, the accuracy row
        (when present) and then each average row that is present. Average
        rows are emitted exactly once, and the support shown on the accuracy
        row is the sum of the per-class supports only.
    """
    average_keys = ('micro avg', 'macro avg', 'weighted avg', 'samples avg')

    def format_row(name, metrics):
        return f"{name:<15} {metrics['precision']:<10.2f} {metrics['recall']:<10.2f} {metrics['f1-score']:<10.2f} {metrics['support']:<10.2f}\n"

    lines = [f"{'':<15} {'precision':<10} {'recall':<10} {'f1-score':<10} {'support':<10}\n"]

    # Per-class rows only: the average rows are dicts too, so they must be
    # skipped here or they would be printed twice and inflate the support.
    total_support = 0
    for label, metrics in report_dict.items():
        if not isinstance(metrics, dict) or label in average_keys:
            continue
        lines.append(format_row(label, metrics))
        total_support += metrics['support']

    # Overall accuracy; absent when the report carries 'micro avg' instead.
    if 'accuracy' in report_dict:
        lines.append(f"{'accuracy':<15} {'':<10} {'':<10} {report_dict['accuracy']:<10.2f} {total_support:<10.2f}\n")

    # Average rows, in a fixed order, for whichever ones are present.
    for key in average_keys:
        if key in report_dict:
            lines.append(format_row(key, report_dict[key]))

    return ''.join(lines)