#!/usr/bin/env python3
# Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
import bisect
import datetime
import math
import os
import pickle
import time
from typing import Optional
import megengine as mge
import megengine.distributed as dist
import megengine.module as M
from basecore.config import ConfigDict
from basecore.engine import BaseHook, BaseTrainer
from basecore.utils import (
Checkpoint,
MeterBuffer,
cached_property,
ensure_dir,
get_last_call_deltatime,
)
from loguru import logger
from tensorboardX import SummaryWriter
from basecls.layers import compute_precise_bn_stats
from basecls.models import sync_model
from basecls.utils import default_logging, registers
from .tester import ClsTester
__all__ = [
"CheckpointHook",
"EvalHook",
"LoggerHook",
"LRSchedulerHook",
"PreciseBNHook",
"ResumeHook",
"TensorboardHook",
]
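

# How these hooks are assembled is up to the training script; the sketch below
# is illustrative only, and ``trainer.register_hooks`` is an assumed
# registration API of the basecore ``BaseTrainer``, not something defined here:
#
#   hooks = [
#       ResumeHook(save_dir="ckpt", resume=True),
#       LRSchedulerHook(),
#       PreciseBNHook(),
#       EvalHook(save_dir="ckpt"),
#       CheckpointHook(save_dir="ckpt"),
#       LoggerHook(log_every_n_iter=20),
#       TensorboardHook(log_dir="tb_logs"),
#   ]
#   trainer.register_hooks(hooks)
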
def _create_checkpoint(trainer: BaseTrainer, save_dir: str) -> Checkpoint:
"""Create a checkpoint for save and resume"""
model = trainer.model
ema = trainer.ema
ckpt_kws = {"ema": ema} if ema is not None else {}
optim = trainer.solver.optimizer
scaler = trainer.solver.grad_scaler
progress = trainer.progress
ckpt = Checkpoint(
save_dir,
model,
tag_file=None,
optimizer=optim,
scaler=scaler,
progress=progress,
**ckpt_kws,
)
return ckpt


class CheckpointHook(BaseHook):
    """Hook for managing checkpoints during training.

    Takes effect during the ``after_epoch`` and ``after_train`` procedures.

    Args:
        save_dir: checkpoint directory.
        save_every_n_epoch: interval for saving checkpoints. Default: ``1``
    """

    def __init__(self, save_dir: Optional[str] = None, save_every_n_epoch: int = 1):
super().__init__()
ensure_dir(save_dir)
self.save_dir = save_dir
self.save_every_n_epoch = save_every_n_epoch

    def after_epoch(self):
progress = self.trainer.progress
ckpt = _create_checkpoint(self.trainer, self.save_dir)
ckpt.save("latest.pkl")
if progress.epoch % self.save_every_n_epoch == 0:
progress_str = progress.progress_str_list()
save_name = "_".join(progress_str[:-1]) + ".pkl"
ckpt.save(save_name)
logger.info(f"Save checkpoint {save_name} to {self.save_dir}")

    def after_train(self):
        # NOTE: the final EMA weights are usually not the best, so we don't save them
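        # only the raw model weights are dumped here (no optimizer or progress state)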
mge.save(
{"state_dict": self.trainer.model.state_dict()},
os.path.join(self.save_dir, "dumped_model.pkl"),
pickle_protocol=pickle.DEFAULT_PROTOCOL,
)


class EvalHook(BaseHook):
    """Hook for evaluating during training.

    Takes effect during the ``after_epoch`` and ``after_train`` procedures.

    Args:
        save_dir: checkpoint directory.
        eval_every_n_epoch: interval for evaluating. Default: ``1``
    """

    def __init__(self, save_dir: Optional[str] = None, eval_every_n_epoch: int = 1):
super().__init__()
ensure_dir(save_dir)
self.save_dir = save_dir
self.eval_every_n_epoch = eval_every_n_epoch
self.best_acc1 = 0
self.best_ema_acc1 = 0

    def after_epoch(self):
trainer = self.trainer
cfg = trainer.cfg
model = trainer.model
ema = trainer.ema
progress = trainer.progress
if progress.epoch % self.eval_every_n_epoch == 0 and progress.epoch != progress.max_epoch:
self.test(cfg, model, ema)

    def after_train(self):
trainer = self.trainer
cfg = trainer.cfg
model = trainer.model
ema = trainer.ema
        # TODO: this may be unnecessary when precise BN is enabled
sync_model(model)
if ema is not None:
sync_model(ema)
self.test(cfg, model, ema)

    def test(self, cfg: ConfigDict, model: M.Module, ema: Optional[M.Module] = None):
dataloader = registers.dataloaders.get(cfg.data.name).build(cfg, False)
# FIXME: need atomic user_pop, maybe in MegEngine 1.5?
# tester = BaseTester(model, dataloader, AccEvaluator())
tester = ClsTester(cfg, model, dataloader)
acc1, _ = tester.test()
if acc1 > self.best_acc1:
self.best_acc1 = acc1
if dist.get_rank() == 0:
mge.save(
{"state_dict": model.state_dict(), "acc1": self.best_acc1},
os.path.join(self.save_dir, "best_model.pkl"),
pickle_protocol=pickle.DEFAULT_PROTOCOL,
)
logger.info(
f"Epoch: {self.trainer.progress.epoch}, Test Acc@1: {acc1:.3f}, "
f"Best Test Acc@1: {self.best_acc1:.3f}"
)
if ema is None:
return
tester_ema = ClsTester(cfg, ema, dataloader)
ema_acc1, _ = tester_ema.test()
if ema_acc1 > self.best_ema_acc1:
self.best_ema_acc1 = ema_acc1
if dist.get_rank() == 0:
mge.save(
{"state_dict": ema.state_dict(), "acc1": self.best_ema_acc1},
os.path.join(self.save_dir, "best_ema_model.pkl"),
pickle_protocol=pickle.DEFAULT_PROTOCOL,
)
logger.info(
f"Epoch: {self.trainer.progress.epoch}, EMA Acc@1: {ema_acc1:.3f}, "
f"Best EMA Acc@1: {self.best_ema_acc1:.3f}"
)


class LoggerHook(BaseHook):
    """Hook for logging during training.

    Takes effect during the ``before_train``, ``after_train``, ``before_iter``
    and ``after_iter`` procedures.

    Args:
        log_every_n_iter: interval for logging. Default: ``20``
    """

    def __init__(self, log_every_n_iter: int = 20):
super().__init__()
self.log_every_n_iter = log_every_n_iter
self.meter = MeterBuffer(self.log_every_n_iter)

    def before_train(self):
trainer = self.trainer
progress = trainer.progress
default_logging(trainer.cfg, trainer.model)
logger.info(f"Starting training from epoch {progress.epoch}, iteration {progress.iter}")
self.start_training_time = time.perf_counter()

    def after_train(self):
total_training_time = time.perf_counter() - self.start_training_time
total_time_str = str(datetime.timedelta(seconds=total_training_time))
logger.info(
"Total training time: {} ({:.4f} s / iter)".format(
total_time_str, self.meter["iters_time"].global_avg
)
)

    def before_iter(self):
self.iter_start_time = time.perf_counter()

    def after_iter(self):
single_iter_time = time.perf_counter() - self.iter_start_time
delta_time = get_last_call_deltatime()
if delta_time is None:
delta_time = single_iter_time
self.meter.update(
{
"iters_time": single_iter_time, # to get global average iter time
"eta_iter_time": delta_time, # to get ETA time
"extra_time": delta_time - single_iter_time, # to get extra time
}
)
trainer = self.trainer
progress = trainer.progress
epoch_id, iter_id = progress.epoch, progress.iter
max_epoch, max_iter = progress.max_epoch, progress.max_iter
if iter_id % self.log_every_n_iter == 0 or (iter_id == 1 and epoch_id == 1):
log_str_list = []
# step info string
log_str_list.append(str(progress))
# loss string
log_str_list.append(self.get_loss_str(trainer.meter))
# stat string
log_str_list.append(self.get_stat_str(trainer.meter))
# other training info like learning rate.
log_str_list.append(self.get_train_info_str())
            # memory usage.
log_str_list.append(self.get_memory_str(trainer.meter))
# time string
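            # iterations left = rest of the current epoch + all iterations of later epochs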
left_iters = max_iter - iter_id + (max_epoch - epoch_id) * max_iter
time_str = self.get_time_str(left_iters)
log_str_list.append(time_str)
# filter empty strings
log_str_list = [s for s in log_str_list if len(s) > 0]
log_str = ", ".join(log_str_list)
logger.info(log_str)
# reset meters in trainer
trainer.meter.reset()

    def get_loss_str(self, meter):
        """Get loss information during the training process."""
loss_dict = meter.get_filtered_meter(filter_key="loss")
loss_str = ", ".join(
[f"{name}:{value.latest:.3f}({value.avg:.3f})" for name, value in loss_dict.items()]
)
return loss_str

    def get_stat_str(self, meter):
        """Get stat information during the training process."""
stat_dict = meter.get_filtered_meter(filter_key="stat")
stat_str = ", ".join(
[f"{name}:{value.latest:.3f}({value.avg:.3f})" for name, value in stat_dict.items()]
)
return stat_str

    def get_memory_str(self, meter):
        """Get memory information during the training process."""

def mem_in_Mb(mem_value):
return math.ceil(mem_value / 1024 / 1024)
mem_dict = meter.get_filtered_meter(filter_key="memory")
mem_str = ", ".join(
[
f"{name}:{mem_in_Mb(value.latest)}({mem_in_Mb(value.avg)})Mb"
for name, value in mem_dict.items()
]
)
return mem_str

    def get_train_info_str(self):
"""Get training process related information such as learning rate."""
# extra info to display, such as learning rate
trainer = self.trainer
lr = trainer.solver.optimizer.param_groups[0]["lr"]
lr_str = f"lr:{lr:.3e}"
loss_scale = trainer.solver.grad_scaler.scale_factor
loss_scale_str = f", amp_loss_scale:{loss_scale:.1f}" if trainer.cfg.amp.enabled else ""
return lr_str + loss_scale_str

    def get_time_str(self, left_iters: int) -> str:
        """Get time-related information such as data time, train time, and ETA."""
# time string
trainer = self.trainer
time_dict = trainer.meter.get_filtered_meter(filter_key="time")
train_time_str = ", ".join(
[f"{name}:{value.avg:.3f}s" for name, value in time_dict.items()]
)
train_time_str += ", extra_time:{:.3f}s, ".format(self.meter["extra_time"].avg)
eta_seconds = self.meter["eta_iter_time"].global_avg * left_iters
eta_string = "ETA:{}".format(datetime.timedelta(seconds=int(eta_seconds)))
time_str = train_time_str + eta_string
return time_str


class LRSchedulerHook(BaseHook):
    """Hook for learning rate scheduling during training.

    Takes effect during the ``before_epoch`` procedure.
    """

    def before_epoch(self):
trainer = self.trainer
epoch_id = trainer.progress.epoch
cfg = trainer.cfg.solver
lr_factor = self.get_lr_factor(cfg, epoch_id)
if epoch_id <= cfg.warmup_epochs:
alpha = (epoch_id - 1) / cfg.warmup_epochs
lr_factor *= cfg.warmup_factor * (1 - alpha) + alpha
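            # worked example with illustrative values: warmup_epochs=5 and
            # warmup_factor=0.1 give a multiplier of 0.1 at epoch 1 and 0.82 at
            # epoch 5, blending linearly from warmup_factor toward 1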
scaled_lr = self.total_lr * lr_factor
for param_group in trainer.solver.optimizer.param_groups:
param_group["lr"] = scaled_lr

    def get_lr_factor(self, cfg: ConfigDict, epoch_id: int) -> float:
        """Calculate the learning rate factor.

        It supports the ``"step"``, ``"linear"``, ``"cosine"``, ``"exp"``, and
        ``"rel_exp"`` schedules.

        Args:
            cfg: config for training.
            epoch_id: current epoch.

        Returns:
            Learning rate factor.
        """
if cfg.lr_schedule == "step":
return cfg.lr_decay_factor ** bisect.bisect_left(cfg.lr_decay_steps, epoch_id)
elif cfg.lr_schedule == "linear":
alpha = 1 - (epoch_id - 1) / cfg.max_epoch
return (1 - cfg.lr_min_factor) * alpha + cfg.lr_min_factor
elif cfg.lr_schedule == "cosine":
alpha = 0.5 * (1 + math.cos(math.pi * (epoch_id - 1) / cfg.max_epoch))
return (1 - cfg.lr_min_factor) * alpha + cfg.lr_min_factor
elif cfg.lr_schedule == "exp":
return cfg.lr_decay_factor ** (epoch_id - 1)
elif cfg.lr_schedule == "rel_exp":
if cfg.lr_min_factor <= 0:
raise ValueError(
"Exponential lr schedule requires lr_min_factor to be greater than 0"
)
return cfg.lr_min_factor ** ((epoch_id - 1) / cfg.max_epoch)
else:
raise NotImplementedError(f"Learning rate schedule '{cfg.lr_schedule}' not supported")

    @cached_property
def total_lr(self) -> float:
"""Total learning rate."""
cfg = self.trainer.cfg.solver
total_lr = cfg.basic_lr * dist.get_world_size() # linear scaling rule
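        # e.g. basic_lr=0.0125 on 8 workers gives total_lr=0.1 (illustrative values)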
return total_lr


class PreciseBNHook(BaseHook):
    """Hook for recomputing precise BN statistics during training.

    Takes effect during the ``before_train`` and ``after_epoch`` procedures.

    Args:
        precise_every_n_epoch: interval for recomputing precise BN statistics.
            ``-1`` means only once, after the final epoch. Default: ``1``
    """

    def __init__(self, precise_every_n_epoch: int = 1):
super().__init__()
self.precise_every_n_epoch = precise_every_n_epoch

    def before_train(self):
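        # precise_every_n_epoch == -1 means running precise BN only once, after the final epoch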
if self.precise_every_n_epoch == -1:
self.precise_every_n_epoch = self.trainer.progress.max_epoch

    def after_epoch(self):
trainer = self.trainer
if (
trainer.progress.epoch % self.precise_every_n_epoch == 0
and trainer.cfg.bn.num_samples_precise > 0
):
logger.info(f"Apply Precising BN at epoch{trainer.progress.epoch}")
compute_precise_bn_stats(trainer.cfg, trainer.model, trainer.dataloader)
if trainer.ema is not None:
logger.info(f"Apply Precising BN for EMA at epoch{trainer.progress.epoch}")
compute_precise_bn_stats(trainer.cfg, trainer.ema, trainer.dataloader)


class ResumeHook(BaseHook):
    """Hook for resuming the training process.

    Takes effect during the ``before_train`` procedure.

    Args:
        save_dir: checkpoint directory.
        resume: whether to resume training. Default: ``False``
    """

    def __init__(self, save_dir: Optional[str] = None, resume: bool = False):
super().__init__()
ensure_dir(save_dir)
self.save_dir = save_dir
self.resume = resume

    def before_train(self):
trainer = self.trainer
if self.resume:
progress = trainer.progress
ckpt = _create_checkpoint(self.trainer, self.save_dir)
filename = ckpt.get_checkpoint_file("latest.pkl")
logger.info(f"Load checkpoint from {filename}")
ckpt.resume(filename)
            # since the checkpoint is dumped after every epoch, resuming starts
            # from the next epoch, with the iteration counter reset to 1
progress.epoch += 1
progress.iter = 1


class TensorboardHook(BaseHook):
    """Hook for writing tensorboard summaries during training.

    Takes effect during the ``before_train``, ``after_train`` and ``after_iter`` procedures.

    Args:
        log_dir: tensorboard directory.
        log_every_n_iter: interval for logging. Default: ``20``
        scalar_type: statistic to record, supports ``"latest"``, ``"avg"``, ``"global_avg"`` and
            ``"median"``. Default: ``"latest"``
    """

    def __init__(self, log_dir: str, log_every_n_iter: int = 20, scalar_type: str = "latest"):
super().__init__()
if scalar_type not in ("latest", "avg", "global_avg", "median"):
raise ValueError(f"Tensorboard scalar type '{scalar_type}' not supported")
ensure_dir(log_dir)
self.log_dir = log_dir
self.log_every_n_iter = log_every_n_iter
self.scalar_type = scalar_type

    def before_train(self):
self.writer = SummaryWriter(self.log_dir)

    def after_train(self):
self.writer.close()

    def after_iter(self):
trainer = self.trainer
epoch_id, iter_id = trainer.progress.epoch, trainer.progress.iter
if iter_id % self.log_every_n_iter == 0 or (iter_id == 1 and epoch_id == 1):
self.write(context=trainer)

    def write(self, context):
cur_iter = self.calc_iter(context.progress)
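        # group scalars into tensorboard tabs by key prefix,
        # e.g. a key containing "loss" is written as "loss/<key>"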
for key, meter in context.meter.items():
value = getattr(meter, self.scalar_type, meter.latest)
for prefix in ("loss", "stat", "time", "memory"):
if prefix in key:
key = f"{prefix}/{key}"
break
self.writer.add_scalar(key, value, cur_iter)
# write lr into tensorboard
lr = context.solver.optimizer.param_groups[0]["lr"]
self.writer.add_scalar("lr", lr, cur_iter)
# write loss_scale into tensorboard
if context.cfg.amp.enabled:
loss_scale = context.solver.grad_scaler.scale_factor
self.writer.add_scalar("amp_loss_scale", loss_scale, cur_iter)

    @classmethod
def calc_iter(cls, progress):
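        # convert the 1-based (epoch, iter) progress into a 0-based global step,
        # e.g. epoch 2, iter 1 with max_iter=5000 maps to step 5000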
return (progress.epoch - 1) * progress.max_iter + progress.iter - 1