import matplotlib.pyplot as plt
import numpy as np


def plot_accuracies(history):
    """Plot validation accuracy per history entry, save it, and show it.

    Args:
        history: Iterable of mappings; each one must contain a 'val_acc'
            key (a missing key raises KeyError). Entries are plotted in
            order along an x-axis labelled 'epoch'.

    Side effects:
        Draws onto the current pyplot figure, writes 'accuracy.png' to the
        current working directory, then calls ``plt.show()``.
    """
    accuracies = [entry['val_acc'] for entry in history]

    plt.plot(accuracies, '-x')
    plt.xlabel('epoch')
    plt.ylabel('accuracy')
    plt.title('Accuracy vs. No. of epochs')
    # Save before show(): some backends discard the figure once it is shown.
    plt.savefig('accuracy.png')
    plt.show()


def plot_losses(history):
    """Plot training and validation loss per history entry, save, and show.

    Args:
        history: Iterable of mappings; each one must contain 'val_loss'.
            'train_loss' is optional -- an entry without it contributes
            ``None`` to the training series.

    Side effects:
        Draws onto the current pyplot figure, writes 'loss.png' to the
        current working directory, then calls ``plt.show()``.
    """
    train_losses = []
    val_losses = []
    for entry in history:
        train_losses.append(entry.get('train_loss'))
        val_losses.append(entry['val_loss'])

    plt.plot(train_losses, '-bx')
    plt.plot(val_losses, '-rx')
    plt.xlabel('epoch')
    plt.ylabel('loss')
    # Legend order must match the order of the two plot() calls above.
    plt.legend(['Training', 'Validation'])
    plt.title('Loss vs. No. of epochs')
    plt.savefig('loss.png')
    plt.show()


def plot_lrs(history):
    """Plot the concatenated learning-rate values, save the plot, and show it.

    Args:
        history: Iterable of mappings. Each entry's optional 'lrs' value is
            a sequence of learning rates (entries without the key
            contribute nothing). The sequences are joined in order and
            plotted against an x-axis labelled 'Batch no.'.

    Side effects:
        Draws onto the current pyplot figure, writes 'learning_rate.png' to
        the current working directory, then calls ``plt.show()``.
    """
    lr_chunks = [entry.get('lrs', []) for entry in history]
    # np.concatenate raises ValueError when given zero arrays, so an empty
    # history falls back to an empty series and yields an empty plot instead.
    lrs = np.concatenate(lr_chunks) if lr_chunks else np.array([])

    plt.plot(lrs)
    plt.xlabel('Batch no.')
    plt.ylabel('Learning rate')
    plt.title('Learning Rate vs. Batch no.')
    plt.savefig('learning_rate.png')
    plt.show()