import os
import json
import torch
import torch.nn as nn
from torch.cuda.amp import autocast, GradScaler
from tqdm import tqdm
from utils.helper import get_lr, evaluate
from utils import constants
from pytorchtools import EarlyStopping
from devkit.core.dist_utils import average_gradients, broadcast_params


def fit_one_cycle(epochs, start_epoch, max_lr, model, train_loader, val_loader,
                  train_sampler, val_sampler, pathwrapper,
                  weight_decay=0, grad_clip=None, opt_func=torch.optim.SGD,
                  logfunc=None, ddpmodule=False, savepath=None, g_writer=None,
                  last_best_loss=float('inf')):
    if logfunc is None:
        logfunc = print
    torch.cuda.empty_cache()
    scaler = GradScaler()
    history = []
    best_val_acc = 0
    best_epoch = start_epoch
    best_loss = last_best_loss
    # patience = 20
    # early_stopping = EarlyStopping(patience=patience, verbose=True)

    # set up the optimizer with weight decay
    optimizer = opt_func(model.parameters(), max_lr, weight_decay=weight_decay)
    # set up the one-cycle learning rate scheduler
    scheduler = torch.optim.lr_scheduler.OneCycleLR(optimizer, max_lr, epochs=epochs,
                                                    steps_per_epoch=len(train_loader))

    for epoch in range(start_epoch, epochs):
        # Training phase
        model.train()
        ##################################
        if constants.USE_DDP:
            train_sampler.set_epoch(epoch)
            val_sampler.set_epoch(epoch)
        ##################################
        train_losses = []
        train_accs = []
        lrs = []
        for batch in tqdm(train_loader):
            optimizer.zero_grad()
            if constants.MIX_TRAINING:
                ############################
                # run the forward pass inside the autocast context
                with autocast():
                    if ddpmodule:
                        loss, acc = model.module.training_step(batch)
                    else:
                        loss, acc = model.training_step(batch)
                train_losses.append(loss)
                train_accs.append(acc)
                scaler.scale(loss).backward()
                ##########################
                if constants.USE_DDP:
                    average_gradients(model)
                ##########################
                # Gradient clipping (unscale first so the threshold applies to the true gradients)
                if grad_clip:
                    scaler.unscale_(optimizer)
                    nn.utils.clip_grad_value_(model.parameters(), grad_clip)
                # the scaler skips optimizer.step() and shrinks the scale when it finds
                # inf/nan gradients; skip the LR scheduler step in that case
                last_scale = scaler.get_scale()
                scaler.step(optimizer)
                scaler.update()
                skip_lr_sched = (last_scale != scaler.get_scale())
                ##########################
                # Record and update learning rate
                if not skip_lr_sched:
                    lrs.append(get_lr(optimizer))
                    scheduler.step()
                ##########################
            else:
                if ddpmodule:
                    loss, acc = model.module.training_step(batch)
                else:
                    loss, acc = model.training_step(batch)
                train_losses.append(loss)
                train_accs.append(acc)
                ##########################
                loss.backward()
                ##########################
                if constants.USE_DDP:
                    average_gradients(model)
                ##########################
                # Gradient clipping
                if grad_clip:
                    nn.utils.clip_grad_value_(model.parameters(), grad_clip)
                optimizer.step()
                # Record and update learning rate
                lrs.append(get_lr(optimizer))
                scheduler.step()
        current_lr = get_lr(optimizer)

        # Validation phase
        if ddpmodule:
            result = evaluate(model.module, val_loader)
        else:
            result = evaluate(model, val_loader)
        train_losses_mean = torch.stack(train_losses).mean().item()
        train_accs_mean = torch.stack(train_accs).mean().item()
        result['train_loss'] = train_losses_mean
        result['train_acc'] = train_accs_mean
        result['lrs'] = lrs
        if ddpmodule:
            model.module.epoch_end(epoch, result)
        else:
            model.epoch_end(epoch, result)

        # EarlyStopping
        # early_stopping(result['val_loss'], model)
        # if early_stopping.early_stop:
        #     print("Early stopping")
        #     break

        if constants.G_RANK == 0:
            # ==================================
            # report the aggregate metrics for this epoch back to the platform
            # through the writer obtained earlier; the metric names (keys) must
            # stay aligned with the platform side
            if g_writer is not None:
                g_writer.append_one_line([epoch, train_accs_mean, current_lr, train_losses_mean])
            else:
                logfunc(fr'input g_writer is None, writing nothing')
            # ==================================

            # select the best model (by val_loss; the val_acc criterion is kept commented out)
            # is_best = result['val_acc'] > best_val_acc
            is_best = result['val_loss'] < best_loss
            # model_path = constants.CHECKPOINT_PATH if savepath is None else fr'{savepath}/model.pth'
            model_path = constants.CHECKPOINT_PATH if savepath is None else fr'{savepath}/model.pth'
            if is_best:
                best_epoch = epoch
                logfunc(fr'saving best model @ Loss: {result["val_loss"]}, ACC: {best_val_acc} -> {result["val_acc"]}, epoch: {epoch}')
            else:
                logfunc(fr'saving model @ Loss: {result["val_loss"]}, ACC: {result["val_acc"]}, epoch: {epoch}')
            best_val_acc = max(best_val_acc, result['val_acc'])
            best_loss = min(best_loss, result['val_loss'])
            state = {
                'epoch': epoch,
                'state_dict': model.module.state_dict() if ddpmodule else model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'best_loss': result['val_loss'],
                'best_acc': result['val_acc'],
            }
            # save_checkpoint(state=state, is_best=is_best, filename=model_path)
            # ==================================
            # save the training checkpoint (latest, and best when improved)
            logfunc(fr'writing ckpt to {pathwrapper.get_output_train_latestmodel_pth_filepath()}')
            torch.save(state, pathwrapper.get_output_train_latestmodel_pth_filepath())
            if is_best:
                logfunc(fr'writing best ckpt to {pathwrapper.get_output_train_bestmodel_pth_filepath()}')
                torch.save(state, pathwrapper.get_output_train_bestmodel_pth_filepath())
            # ==================================
        history.append(result)
        # ==================================
        # wait for all processes before checking the cancel flag
        torch.distributed.barrier()  # change
        if pathwrapper.get_input_cancle_flag():
            logfunc(fr'epoch end, cancel flag detected, stopping training')
            break
        # ==================================
    return history, best_epoch, best_loss, best_val_acc


def save_checkpoint(state, is_best, filename):
    import shutil
    import os
    torch.save(state, filename)
    if is_best:
        shutil.copyfile(filename, os.path.join(os.path.dirname(filename), 'best_loss.pth'))
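

if __name__ == '__main__':
    # Minimal, self-contained sketch (not project code) illustrating the AMP pattern
    # used in fit_one_cycle: when GradScaler skips optimizer.step() because it found
    # inf/nan gradients, its scale shrinks, and the LR scheduler step is skipped for
    # that iteration so the one-cycle schedule stays aligned with the number of real
    # optimizer steps. The toy model, data, and hyperparameters below are placeholders.
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    toy_model = nn.Linear(8, 1).to(device)
    toy_opt = torch.optim.SGD(toy_model.parameters(), lr=0.1)
    toy_sched = torch.optim.lr_scheduler.OneCycleLR(toy_opt, max_lr=0.1,
                                                    epochs=1, steps_per_epoch=5)
    toy_scaler = GradScaler(enabled=(device == 'cuda'))
    for _ in range(5):
        x = torch.randn(4, 8, device=device)
        y = torch.randn(4, 1, device=device)
        toy_opt.zero_grad()
        with autocast(enabled=(device == 'cuda')):
            pred = toy_model(x)
            toy_loss = nn.functional.mse_loss(pred, y)
        last_scale = toy_scaler.get_scale()
        toy_scaler.scale(toy_loss).backward()
        toy_scaler.step(toy_opt)
        toy_scaler.update()
        # only advance the schedule if the optimizer step was actually taken
        if last_scale == toy_scaler.get_scale():
            toy_sched.step()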