import os

import numpy as np
import torch
from matplotlib import pyplot as plt


class Accumulator:
    """Accumulate running sums over ``n`` variables."""

    def __init__(self, n):
        # One float slot per tracked quantity.
        self.data = [0.0] * n

    def add(self, *args):
        """Add each positional argument to the slot at the same position.

        Arguments beyond ``n`` are ignored. Slots with no matching argument
        keep their current value (previously ``zip`` truncation dropped them,
        silently shrinking ``self.data``).
        """
        summed = [a + float(b) for a, b in zip(self.data, args)]
        self.data = summed + self.data[len(summed):]

    def reset(self):
        """Zero every slot while keeping the number of slots."""
        self.data = [0.0] * len(self.data)

    def __getitem__(self, idx):
        return self.data[idx]


class Animator:
    """Incrementally plot one or more data series in a matplotlib figure.

    Each call to :meth:`add` appends points, redraws every series on the
    first axes, re-applies the axis configuration, then calls ``plt.pause``
    so an interactive backend can refresh the window.
    """

    def __init__(self, xlabel=None, ylabel=None, legend=None, xlim=None,
                 ylim=None, xscale='linear', yscale='linear',
                 fmts=('-', 'm--', 'g-.', 'r:'), nrows=1, ncols=1,
                 figsize=(5, 5), pause_interval=2):
        """Create the figure and remember the axis configuration.

        Args:
            xlabel, ylabel: axis labels.
            legend: legend entries, one per series (``None`` means no legend).
            xlim, ylim: limits passed to ``set_xlim`` / ``set_ylim``.
            xscale, yscale: axis scales, e.g. ``'linear'`` or ``'log'``.
            fmts: matplotlib format strings, matched to series by position;
                series beyond ``len(fmts)`` are not drawn (``zip`` truncation).
            nrows, ncols: subplot grid shape; only the first axes is drawn on.
            figsize: figure size in inches.
            pause_interval: seconds passed to ``plt.pause`` after each redraw
                (default 2, the value that was previously hard-coded).
        """
        if legend is None:
            legend = []
        self.fig, axes = plt.subplots(nrows, ncols, figsize=figsize)
        # plt.subplots returns a bare Axes for a 1x1 grid and an ndarray
        # (2-D for a true grid) otherwise; flatten so self.axes[0] is always
        # an Axes object.
        if nrows * ncols == 1:
            self.axes = [axes]
        else:
            self.axes = axes.ravel()
        # Capture the configuration in a closure so it can be re-applied
        # after every cla() in add().
        self.config_axes = lambda: self.set_axes(
            axes=self.axes[0], xlabel=xlabel, ylabel=ylabel, xlim=xlim,
            ylim=ylim, xscale=xscale, yscale=yscale, legend=legend)
        self.X, self.Y, self.fmts = None, None, fmts
        self.pause_interval = pause_interval

    def set_axes(self, axes, xlabel, ylabel, xlim, ylim, xscale, yscale,
                 legend):
        """Apply labels, scales, limits, optional legend and grid to ``axes``."""
        axes.set_xlabel(xlabel)
        axes.set_ylabel(ylabel)
        axes.set_xscale(xscale)
        axes.set_yscale(yscale)
        axes.set_xlim(xlim)
        axes.set_ylim(ylim)
        if legend:
            axes.legend(legend)
        axes.grid()

    def add(self, x, y):
        """Append one point per series and redraw the figure.

        Args:
            x: a scalar x shared by all series, or a sequence with one x per
                series.
            y: a scalar (single series) or a sequence with one y per series.
                A ``None`` in ``x`` or ``y`` skips that series for this call.
        """
        # hasattr(obj, '__len__') is used to tell sequences from scalars.
        if not hasattr(y, '__len__'):
            y = [y]
        n = len(y)
        if not hasattr(x, '__len__'):
            x = [x] * n
        if not self.X:
            self.X = [[] for _ in range(n)]
        if not self.Y:
            self.Y = [[] for _ in range(n)]
        for i, (a, b) in enumerate(zip(x, y)):
            if a is not None and b is not None:
                self.X[i].append(a)
                self.Y[i].append(b)
        # Clear the axes and redraw every series from scratch.
        ax = self.axes[0]
        ax.cla()
        for xs, ys, fmt in zip(self.X, self.Y, self.fmts):
            ax.plot(xs, ys, fmt)
        self.config_axes()
        # Run the GUI event loop briefly so the window shows the update.
        plt.pause(self.pause_interval)


# Default input locations used by the scratch helpers below.
_DEFAULT_3D_NPY = (r'/home/lung/project/ai-project/cls_train/data/train_data/'
                   r'plus_3d_0818/npy_data/cls_2047/495.npy')
# NOTE(review): this path lacks the 'project/' segment that the other paths
# have — confirm it is intentional.
_DEFAULT_2D_NPY = (r'/home/lung/ai-project/cls_train/data/train_data/'
                   r'plus_0512/npy_data/cls_2021/1893_10.npy')
_DEFAULT_LOG_DIR = '/home/lung/project/ai-project/cls_train/log'


def load_npy_to_data(input_file=None):
    """Load and return the array stored in the ``.npy`` file ``input_file``.

    Raises:
        TypeError: if ``input_file`` is not given.
    """
    if input_file is None:
        raise TypeError('load_npy_to_data() requires an input_file path')
    return np.load(input_file)


def show_img(input_file=_DEFAULT_3D_NPY, index=29):
    """Print and display slice ``index`` of a stacked ``.npy`` volume in gray.

    NOTE(review): an original comment states that index 24 is the center
    slice, yet this default is 29 — confirm which slice is intended.
    """
    data = load_npy_to_data(input_file)
    data = data[index].astype(np.float32)
    print(data)
    plt.imshow(data, cmap='gray')
    plt.show()


def show_img_2d(input_file=_DEFAULT_2D_NPY):
    """Display an ``.npy`` array directly as a grayscale image."""
    data = load_npy_to_data(input_file).astype(np.float32)
    plt.imshow(data, cmap='gray')
    plt.show()


def run():
    """Scratch check: binarise a small tensor against the threshold 0.5."""
    test = torch.tensor([[0.1, 0.2], [0.3, 0.6]])
    thred = 0.5
    result = test > thred
    print(result.float())


def test_sum():
    """Scratch check: print the sum of a 2x2x3 array along axis 0."""
    test = np.array([[[1, 1, 2], [2, 3, 4]], [[2, 0, 0], [9, 9, 9]]])
    print(np.sum(test, axis=0))


def test():
    """Scratch check: indices of True entries in a 3-D mask, and their bounds.

    Prints the (3, k) index array from ``np.where`` and then its per-axis
    minimum and maximum, i.e. the bounding box of the True region.
    """
    mask = [[[True, True, True], [False, False, False], [True, True, True]],
            [[True, False, True], [False, False, False], [True, True, True]]]
    mask = np.array(mask)
    indices = np.asarray(np.where(mask == 1))
    print(indices)
    print(indices.min(axis=1))
    print(indices.max(axis=1))


def check_image(first_dir=f'{_DEFAULT_LOG_DIR}/npy/01',
                second_dir=f'{_DEFAULT_LOG_DIR}/npy/03',
                output_dir=f'{_DEFAULT_LOG_DIR}/image/temp_02',
                indices=range(22, 37)):
    """Save ``second - first`` difference images for matching ``.npy`` files.

    For each ``index`` in ``indices``, loads ``<first_dir>/<index>.npy`` and
    ``<second_dir>/<index>.npy`` and writes the grayscale difference image
    to ``<output_dir>/<index>.png``. The output directory is created if it
    does not exist.
    """
    os.makedirs(output_dir, exist_ok=True)
    for index in indices:
        data_1 = load_npy_to_data(os.path.join(first_dir, f'{index}.npy'))
        data_2 = load_npy_to_data(os.path.join(second_dir, f'{index}.npy'))
        # Cast first: unsigned-integer inputs would wrap around on
        # subtraction, and boolean inputs do not support '-' at all.
        data = data_2.astype(np.float64) - data_1.astype(np.float64)
        plt.imsave(os.path.join(output_dir, f'{index}.png'), data,
                   cmap='gray')


if __name__ == '__main__':
    # Other scratch entry points: test_sum(), check_image().
    show_img()