import copy import torch from torch.utils.data import Dataset import pandas as pd import numpy as np from functools import lru_cache def custom_collate_fn_2d(batch): label_id, npy_data_2d, label, z_index, rotate_count = zip(*batch) label_id = [idx_label_id for sub_label_id in label_id for idx_label_id in sub_label_id] z_index = [idx_z_index for sub_z_index in z_index for idx_z_index in sub_z_index] rotate_count = [idx_rotate_count for sub_rotate_count in rotate_count for idx_rotate_count in sub_rotate_count] npy_data_2d = torch.cat(npy_data_2d, dim=0) label = torch.cat(label, dim=0) return label_id, npy_data_2d, label, z_index, rotate_count def custom_collate_fn_2d_error(batch): pass def custom_collate_fn_3d(batch): label_id, npy_data_3d, label, z_index, rotate_count = zip(*batch) label_id = [idx_label_id for sub_label_id in label_id for idx_label_id in sub_label_id] z_index = [idx_z_index for sub_z_index in z_index for idx_z_index in sub_z_index] rotate_count = [idx_rotate_count for sub_rotate_count in rotate_count for idx_rotate_count in sub_rotate_count] npy_data_3d = torch.cat(npy_data_3d, dim=0) label = torch.cat(label, dim=0) return label_id, npy_data_3d, label, z_index, rotate_count def custom_collate_fn_3d_error(batch): pass def custom_collate_fn_2d3d(batch): label_id_2d, npy_data_2d, label_2d, z_index_2d, rotate_count_2d, label_id_3d, npy_data_3d, label_3d, z_index_3d, rotate_count_3d = zip(*batch) label_id_2d = [idx_label_id for sub_label_id in label_id_2d for idx_label_id in sub_label_id] label_id_3d = [idx_label_id for sub_label_id in label_id_3d for idx_label_id in sub_label_id] z_index_2d = [idx_z_index for sub_z_index in z_index_2d for idx_z_index in sub_z_index] z_index_3d = [idx_z_index for sub_z_index in z_index_3d for idx_z_index in sub_z_index] rotate_count_2d = [idx_rotate_count for sub_rotate_count in rotate_count_2d for idx_rotate_count in sub_rotate_count] rotate_count_3d = [idx_rotate_count for sub_rotate_count in rotate_count_3d for idx_rotate_count in sub_rotate_count] npy_data_2d = torch.cat(npy_data_2d, dim=0) npy_data_3d = torch.cat(npy_data_3d, dim=0) label_2d = torch.cat(label_2d, dim=0) label_3d = torch.cat(label_3d, dim=0) return label_id_2d, npy_data_2d, label_2d, z_index_2d, rotate_count_2d, label_id_3d, npy_data_3d, label_3d, z_index_3d, rotate_count_3d def custom_collate_fn_2d3d_error(batch): pass def custom_collate_fn_s3d(batch): label_id, npy_data, label, z_index, rotate_count = zip(*batch) label_id = [idx_label_id for sub_label_id in label_id for idx_label_id in sub_label_id] z_index = [idx_z_index for sub_z_index in z_index for idx_z_index in sub_z_index] rotate_count = [idx_rotate_count for sub_rotate_count in rotate_count for idx_rotate_count in sub_rotate_count] npy_data = torch.cat(npy_data, dim=0) label = torch.cat(label, dim=0) return label_id, npy_data, label, z_index, rotate_count def custom_collate_fn_s3d_error(batch): pass def custom_collate_fn_resnet3d(batch): label_id, npy_data, label, z_index, rotate_count = zip(*batch) label_id = [idx_label_id for sub_label_id in label_id for idx_label_id in sub_label_id] z_index = [idx_z_index for sub_z_index in z_index for idx_z_index in sub_z_index] rotate_count = [idx_rotate_count for sub_rotate_count in rotate_count for idx_rotate_count in sub_rotate_count] npy_data = torch.cat(npy_data, dim=0) label = torch.cat(label, dim=0) return label_id, npy_data, label, z_index, rotate_count def custom_collate_fn_resnet3d_error(batch): pass def custom_collate_fn_d2d(batch): label_id, npy_data_2d, label, z_index, rotate_count = zip(*batch) label_id = [idx_label_id for sub_label_id in label_id for idx_label_id in sub_label_id] z_index = [idx_z_index for sub_z_index in z_index for idx_z_index in sub_z_index] rotate_count = [idx_rotate_count for sub_rotate_count in rotate_count for idx_rotate_count in sub_rotate_count] npy_data_2d = torch.cat(npy_data_2d, dim=0) label = torch.cat(label, dim=0) return label_id, npy_data_2d, label, z_index, rotate_count def custom_collate_fn_d2d_error(batch): pass class ClassificationDataset2d(Dataset): def __init__(self, csv_file, node_column="node", label_id_column="label_id", npy_file_column="npy_file", class_column="label", z_index_column = "z_index", rotate_column = "rotate_count", data_info="train_dataset"): self.df = pd.read_csv(csv_file, header=0, encoding="utf-8") print(f"{data_info} lens: {len(self.df)}") self.node_column = node_column self.label_id_column = label_id_column self.npy_file_column = npy_file_column self.class_column = class_column self.z_index_column = z_index_column self.rotate_column = rotate_column self.lens = len(self.df) def __len__(self): return self.lens @lru_cache(maxsize=10) def load_npy_data(self, npy_data_path): data = np.load(npy_data_path) # [256, 256] data = data[np.newaxis, np.newaxis, :, :] # [1, 1,256, 256] return data def __getitem__(self, idx): npy_file = self.df.loc[idx, self.npy_file_column] node = self.df.loc[idx, self.node_column] label_id = self.df.loc[idx, self.label_id_column] label = self.df.loc[idx, self.class_column] z_index = self.df.loc[idx, self.z_index_column] rotate_count = self.df.loc[idx, self.rotate_column] npy_data = self.load_npy_data(npy_file) label_id_str = f"{node}_{label_id}" data_2d = torch.from_numpy(npy_data).to(torch.float32) label = torch.tensor([label], dtype=torch.float32) return [label_id_str], data_2d, label, [z_index], [rotate_count] class ClassificationDatasetError2d(Dataset): pass class ClassificationDataset3d(Dataset): def __init__(self, csv_file, node_column="node", label_id_column="label_id", npy_file_column="npy_file", class_column="label", z_index_column = "z_index", rotate_column = "rotate_count", data_info="train_dataset"): self.df = pd.read_csv(csv_file, header=0, encoding="utf-8") print(f"{data_info} lens: {len(self.df)}") self.node_column = node_column self.label_id_column = label_id_column self.npy_file_column = npy_file_column self.class_column = class_column self.z_index_column = z_index_column self.rotate_column = rotate_column self.lens = len(self.df) def __len__(self): return self.lens @lru_cache(maxsize=10) def load_npy_data(self, npy_data_path): data = np.load(npy_data_path) # [48, 256, 256] print(f"npy_data_path: {npy_data_path}, data.shape: {data.shape}") data = data[np.newaxis, np.newaxis, :, :, :] # [1, 1, 48, 256, 256] return data def __getitem__(self, idx): npy_file = self.df.loc[idx, self.npy_file_column] node = self.df.loc[idx, self.node_column] label_id = self.df.loc[idx, self.label_id_column] label = self.df.loc[idx, self.class_column] z_index = self.df.loc[idx, self.z_index_column] rotate_count = self.df.loc[idx, self.rotate_column] npy_data = self.load_npy_data(npy_file) label_id_str = f"{node}_{label_id}" data_3d = torch.from_numpy(npy_data).to(torch.float32) label = torch.tensor([label], dtype=torch.float32) return [label_id_str], data_3d, label, [z_index], [rotate_count] class ClassificationDatasetError3d(Dataset): pass class ClassificationDataset2d3d(Dataset): def __init__(self, data_2d_csv_file, data_3d_csv_file, label_id_column="label_id", npy_file_column="npy_file", class_column="label", node_column="node", z_index_column = "z_index", rotate_column = "rotate_count", data_info="train_dataset"): self.df_2d = pd.read_csv(data_2d_csv_file, header=0, encoding="utf-8") self.df_3d = pd.read_csv(data_3d_csv_file, header=0, encoding="utf-8") print(f"{data_info} lens: {len(self.df_2d)}, {len(self.df_3d)}") self.label_id_column = label_id_column self.npy_file_column = npy_file_column self.node_column = node_column self.class_column = class_column self.z_index_column = z_index_column self.rotate_column = rotate_column self.lens_data_2d = len(self.df_2d) self.lens_data_3d = len(self.df_3d) self.lens = max(self.lens_data_2d, self.lens_data_3d) def __len__(self): return self.lens @lru_cache(maxsize=10) def load_npy_data(self, npy_data_path): print(f"npy_data_path: {npy_data_path}") return np.load(npy_data_path) def __getitem__(self, idx): index_data_2d = idx if idx < self.lens_data_2d else idx % self.lens_data_2d index_data_3d = idx if idx < self.lens_data_3d else idx % self.lens_data_3d npy_data_path_2d = self.df_2d.loc[index_data_2d, self.npy_file_column] npy_data_path_3d = self.df_3d.loc[index_data_3d, self.npy_file_column] node_2d = self.df_2d.loc[index_data_2d, self.node_column] node_3d = self.df_3d.loc[index_data_3d, self.node_column] label_id_2d = self.df_2d.loc[index_data_2d, self.label_id_column] label_id_3d = self.df_3d.loc[index_data_3d, self.label_id_column] label_2d = self.df_2d.loc[index_data_2d, self.class_column] label_3d = self.df_3d.loc[index_data_3d, self.class_column] data_2d_z_index = self.df_2d.loc[index_data_2d, self.z_index_column] data_2d_rotate_count = self.df_2d.loc[index_data_2d, self.rotate_column] data_3d_z_index = self.df_3d.loc[index_data_3d, self.z_index_column] data_3d_rotate_count = self.df_3d.loc[index_data_3d, self.rotate_column] npy_data_2d = self.load_npy_data(npy_data_path_2d) npy_data_3d = self.load_npy_data(npy_data_path_3d) npy_data_2d = npy_data_2d[np.newaxis, np.newaxis, :, :] # [1, 1, 256, 256] npy_data_3d = npy_data_3d[np.newaxis, np.newaxis, :, :, :] # [1, 1, 48, 256, 256] label_id_2d_str = f"{node_2d}_{label_id_2d}" label_id_3d_str = f"{node_3d}_{label_id_3d}" data_2d = torch.from_numpy(npy_data_2d).to(torch.float32) data_3d = torch.from_numpy(npy_data_3d).to(torch.float32) label_2d = torch.tensor([label_2d], dtype=torch.float32) label_3d = torch.tensor([label_3d], dtype=torch.float32) return [label_id_2d_str], data_2d, label_2d, [data_2d_z_index], [data_2d_rotate_count], [label_id_3d_str], data_3d, label_3d, [data_3d_z_index], [data_3d_rotate_count] class ClassificationDatasetError2d3d(Dataset): pass class ClassificationDatasetS3d(Dataset): def __init__(self, csv_file, node_column="node", label_id_column="label_id", npy_file_column="npy_file", class_column="label", z_index_column = "z_index", rotate_column = "rotate_count", data_info="train_dataset"): self.df = pd.read_csv(csv_file, header=0, encoding="utf-8") print(f"{data_info} lens: {len(self.df)}") self.node_column = node_column self.label_id_column = label_id_column self.npy_file_column = npy_file_column self.class_column = class_column self.z_index_column = z_index_column self.rotate_column = rotate_column self.lens = len(self.df) def __len__(self): return self.lens @lru_cache(maxsize=10) def load_npy_data(self, npy_data_path): npy_data = np.load(npy_data_path) print(f"npy_data_path: {npy_data_path}, npy_data.shape: {npy_data.shape}") npy_data = npy_data[np.newaxis, np.newaxis, :, :, :] # [1, 1, 48, 256, 256] return npy_data def __getitem__(self, idx): npy_file = self.df.loc[idx, self.npy_file_column] node = self.df.loc[idx, self.node_column] label_id = self.df.loc[idx, self.label_id_column] label = self.df.loc[idx, self.class_column] z_index = self.df.loc[idx, self.z_index_column] rotate_count = self.df.loc[idx, self.rotate_column] npy_data = self.load_npy_data(npy_file) label_id_str = f"{node}_{label_id}" data_3d = torch.from_numpy(npy_data).to(torch.float32) label = torch.tensor([label], dtype=torch.float32) return [label_id_str], data_3d, label, [z_index], [rotate_count] class ClassificationDatasetErrorS3d(Dataset): pass class ClassificationDatasetResnet3d(Dataset): def __init__(self, csv_file, node_column="node", label_id_column="label_id", npy_file_column="npy_file", class_column="label", z_index_column = "z_index", rotate_column = "rotate_count", data_info="train_dataset"): self.df = pd.read_csv(csv_file, header=0, encoding="utf-8") print(f"{data_info} lens: {len(self.df)}") self.node_column = node_column self.label_id_column = label_id_column self.npy_file_column = npy_file_column self.class_column = class_column self.z_index_column = z_index_column self.rotate_column = rotate_column self.lens = len(self.df) def __len__(self): return self.lens @lru_cache(maxsize=10) def load_npy_data(self, npy_data_path): npy_data = np.load(npy_data_path) npy_data = npy_data[np.newaxis, np.newaxis, :, :, :] # [1, 1, 48, 256, 256] return npy_data def __getitem__(self, idx): npy_file = self.df.loc[idx, self.npy_file_column] node = self.df.loc[idx, self.node_column] label_id = self.df.loc[idx, self.label_id_column] label = self.df.loc[idx, self.class_column] z_index = self.df.loc[idx, self.z_index_column] rotate_count = self.df.loc[idx, self.rotate_column] npy_data = self.load_npy_data(npy_file) label_id_str = f"{node}_{label_id}" data_3d = torch.from_numpy(npy_data).to(torch.float32) label = torch.tensor([label], dtype=torch.float32) return [label_id_str], data_3d, label, [z_index], [rotate_count] class ClassificationDatasetErrorResnet3d(Dataset): pass class ClassificationDatasetD2d(Dataset): def __init__(self, csv_file, node_column="node", label_id_column="label_id", npy_file_column="npy_file", class_column="label", z_index_column = "z_index", rotate_column = "rotate_count", data_info="train_dataset"): self.df = pd.read_csv(csv_file, header=0, encoding="utf-8") print(f"{data_info} lens: {len(self.df)}") self.node_column = node_column self.label_id_column = label_id_column self.npy_file_column = npy_file_column self.class_column = class_column self.z_index_column = z_index_column self.rotate_column = rotate_column self.lens = len(self.df) def __len__(self): return self.lens @lru_cache(maxsize=10) def load_npy_data(self, npy_data_path): npy_data = np.load(npy_data_path) # [3, 256, 256] print(f"npy_data_path: {npy_data_path}, npy_data.shape: {npy_data.shape}") npy_data = npy_data[np.newaxis, :, :, :] # [1, 3, 256, 256] return npy_data def __getitem__(self, idx): npy_file = self.df.loc[idx, self.npy_file_column] node_time = self.df.loc[idx, self.node_column] label_id = self.df.loc[idx, self.label_id_column] label = self.df.loc[idx, self.class_column] z_index = self.df.loc[idx, self.z_index_column] rotate_count = self.df.loc[idx, self.rotate_column] npy_data = self.load_npy_data(npy_file) label_id_str = f"{node_time}_{label_id}" data_2d = torch.from_numpy(npy_data).to(torch.float32) label = torch.tensor([label], dtype=torch.float32) return [label_id_str], data_2d, label, [z_index], [rotate_count] class ClassificationDatasetErrorD2d(Dataset): pass def cls_report_dict_to_string(report_dict): lines = [] # print(f"cls_report_dict_to_string, report_dict: {report_dict}") # 添加每类的指标 for label, metrics in report_dict.items(): if isinstance(metrics, dict): line = f"{label:<15} {metrics['precision']:<10.2f} {metrics['recall']:<10.2f} {metrics['f1-score']:<10.2f} {metrics['support']:<10.2f}\n" lines.append(line) # 计算总的支持数 total_support = sum(metrics['support'] for metrics in report_dict.values() if isinstance(metrics, dict)) # 添加总体指标 accuracy_line = f"{'accuracy':<15} {'':<10} {'':<10} {report_dict['accuracy']:<10.2f} {total_support:<10.2f}\n" lines.append(accuracy_line) # 添加 macro avg 和 weighted avg for key in ['macro avg', 'weighted avg']: metrics = report_dict[key] line = f"{key:<15} {metrics['precision']:<10.2f} {metrics['recall']:<10.2f} {metrics['f1-score']:<10.2f} {metrics['support']:<10.2f}\n" lines.append(line) line_first = f"{'':<15} {'precision':<10} {'recall':<10} {'f1-score':<10} {'support':<10}\n" lines.insert(0, line_first) return ''.join(lines)