#!/usr/bin/env python3
# Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
import copy
import time
from typing import Iterable
import megengine as mge
import megengine.amp as amp
import megengine.distributed as dist
import megengine.functional as F
import megengine.module as M
import megengine.optimizer as optim
from basecore.config import ConfigDict
from basecore.engine import BaseHook, BaseTrainer
from basecore.utils import MeterBuffer
from megengine import jit
from basecls.data import DataLoaderType
from basecls.layers import Preprocess, build_loss
from basecls.solver import Solver
from basecls.utils import registers
__all__ = ["ClsTrainer"]
@registers.trainers.register()
class ClsTrainer(BaseTrainer):
    """Classification trainer.

    Args:
        cfg: config for training.
        model: model for training.
        dataloader: dataloader for training.
        solver: solver for training.
        hooks: hooks for training.

    Attributes:
        cfg: config for training.
        model: model for training.
        ema: model exponential moving average.
        dataloader: dataloader for training.
        solver: solver for training.
        progress: object for recording training process.
        loss: loss function for training.
        meter: object for recording metrics.
    """

    def __init__(
        self,
        cfg: ConfigDict,
        model: M.Module,
        dataloader: DataLoaderType,
        solver: Solver,
        hooks: Iterable[BaseHook] = None,
    ):
        super().__init__(model, dataloader, solver, hooks)
        self.cfg = cfg
        # EMA shadow model is a deep copy so its states evolve independently.
        self.ema = copy.deepcopy(model) if cfg.model_ema.enabled else None
        self.preprocess = Preprocess(cfg.preprocess.img_mean, cfg.preprocess.img_std)
        self.loss = build_loss(cfg)
        self.meter = MeterBuffer(cfg.log_every_n_iter)
        if cfg.trace:
            # FIXME: tracing makes the training slower than before, why?
            self.model_step = jit.trace(self.model_step, symbolic=True)

    def train(self):
        """Run the full training loop from epoch 1, iter 1 to the configured max epoch."""
        start_training_info = (1, 1)
        max_iter = len(self.dataloader)
        max_training_info = (self.cfg.solver.max_epoch, max_iter)
        super().train(start_training_info, max_training_info)

    def before_train(self):
        """Hook called once before training starts."""
        super().before_train()

    def before_epoch(self):
        """Hook called before each epoch; (re)creates the dataloader iterator."""
        super().before_epoch()
        self.dataloader_iter = iter(self.dataloader)

    def after_epoch(self):
        """Hook called after each epoch; releases the exhausted dataloader iterator."""
        del self.dataloader_iter
        super().after_epoch()

    def train_one_iter(self):
        """Basic logic of training one iteration."""
        data_tik = time.perf_counter()
        data = next(self.dataloader_iter)
        samples, targets = self.preprocess(data)
        mge._full_sync()  # use full_sync func to sync launch queue for dynamic execution
        data_tok = time.perf_counter()

        train_tik = time.perf_counter()
        losses, accs = self.model_step(samples, targets)
        mge._full_sync()  # use full_sync func to sync launch queue for dynamic execution
        train_tok = time.perf_counter()

        # TODO: stats and accs
        loss_meters = {"loss": losses.item()}
        stat_meters = {"stat_acc@1": accs[0].item() * 100, "stat_acc@5": accs[1].item() * 100}
        time_meters = {"train_time": train_tok - train_tik, "data_time": data_tok - data_tik}
        self.meter.update(**loss_meters, **stat_meters, **time_meters)

    def model_step(self, samples, targets):
        """Forward/backward/update for one batch.

        Supports AMP, gradient accumulation, gradient clipping and model EMA.

        Args:
            samples: preprocessed input batch.
            targets: labels, either class indices or one-hot (ndim == 2).

        Returns:
            ``(losses, accs)`` where ``losses`` is whatever the loss function
            produced (Tensor / dict / sequence) and ``accs`` is top-1/top-5 accuracy.
        """
        optimizer = self.solver.optimizer
        grad_manager = self.solver.grad_manager
        grad_scaler = self.solver.grad_scaler

        with grad_manager:
            with amp.autocast(enabled=self.cfg.amp.enabled):
                outputs = self.model(samples)
                losses = self.loss(outputs, targets)

            if isinstance(losses, mge.Tensor):
                total_loss = losses
            elif isinstance(losses, dict):
                if "total_loss" in losses:
                    total_loss = losses["total_loss"]
                else:
                    # only key contains "loss" will be calculated.
                    total_loss = sum(v for k, v in losses.items() if "loss" in k)
                    losses["total_loss"] = total_loss
            else:
                # list or tuple
                total_loss = sum(losses)

            # Scale down so accumulated gradients average over the accumulation window.
            total_loss = total_loss / self.cfg.solver.accumulation_steps

            # this is made compatible with one hot labels
            if targets.ndim == 2:
                targets = F.argmax(targets, axis=1)
            accs = F.metric.topk_accuracy(outputs, targets, (1, 5))

            if self.cfg.amp.enabled:
                grad_scaler.backward(grad_manager, total_loss)
            else:
                grad_manager.backward(total_loss)

        # Only step the optimizer every `accumulation_steps` iterations.
        if self.progress.iter % self.cfg.solver.accumulation_steps == 0:
            self.modify_grad()
            optimizer.step().clear_grad()
            self.model_ema_step()

        return losses, accs

    def modify_grad(self):
        """Apply gradient clipping according to ``cfg.solver.grad_clip``.

        Raises:
            ValueError: if the configured clip type is not ``None``/``"norm"``/``"value"``.
        """
        grad_cfg = self.cfg.solver.grad_clip
        # TODO: support advanced params for grad clip in the future
        params = self.model.parameters()
        if grad_cfg.name is None:
            return
        elif grad_cfg.name == "norm":
            optim.clip_grad_norm(params, grad_cfg.max_norm)
        elif grad_cfg.name == "value":
            optim.clip_grad_value(params, grad_cfg.lower, grad_cfg.upper)
        else:
            raise ValueError(f"Grad clip type '{grad_cfg.name}' not supported")

    def model_ema_step(self):
        """Implement momentum based Exponential Moving Average (EMA) for model states

        https://github.com/rwightman/pytorch-image-models/blob/master/timm/utils/model_ema.py

        Also inspired by Pycls https://github.com/facebookresearch/pycls/pull/138/, which is more
        flexible and efficient

        Heuristically, one can use a momentum of 0.9999 as used by Tensorflow and 0.9998 as used
        by timm, which updates model ema every iter. To be more efficient, one can set
        ``update_period`` to e.g. 8 or 32 to speed up your training, and decrease your momentum
        at scale: set ``momentum=0.9978`` from 0.9999 (32 times) when you ``update_period=32``.

        Also, to make model EMA really work (improve generalization), one should carefully tune
        the momentum based on various factors, e.g. the learning rate scheduler,
        the total batch size, the training epochs, e.t.c.

        To initialize a momentum in Pycls style, one set ``model_ema.alpha = 1e-5`` instead.
        Momentum will be calculated through ``_calculate_pycls_momentum``.
        """
        if self.ema is None:
            return

        ema_cfg = self.cfg.model_ema
        cur_iter, cur_epoch = self.progress.iter, self.progress.epoch
        if cur_iter % ema_cfg.update_period == 0:
            if cur_epoch > (ema_cfg.start_epoch or self.cfg.solver.warmup_epochs):
                momentum = (
                    ema_cfg.momentum
                    if ema_cfg.alpha is None
                    else _calculate_pycls_momentum(
                        alpha=ema_cfg.alpha,
                        total_batch_size=self.cfg.batch_size * dist.get_world_size(),
                        max_epoch=self.cfg.solver.max_epoch,
                        update_period=ema_cfg.update_period,
                    )
                )
            else:
                # copy model to ema
                momentum = 0.0

            # Cache paired (ema, model) state lists so zipping is not rebuilt every update.
            if not hasattr(self, "_ema_states"):
                self._ema_states = (
                    list(self.ema.parameters()) + list(self.ema.buffers()),
                    list(self.model.parameters()) + list(self.model.buffers()),
                )

            for e, p in zip(*self._ema_states):
                # _inplace_add_(e, p, alpha=mge.tensor(momentum), beta=mge.tensor(1 - momentum))
                e._reset(e * momentum + p * (1 - momentum))
def _calculate_pycls_momentum(
alpha: float, total_batch_size: int, max_epoch: int, update_period: int
):
"""pycls style momentum calculation which uses a relative model_ema to decouple momentum with
other training hyper-parameters e.g.
* training epochs
* interval to update ema
* batch sizes
Usually the alpha is a tiny positive floating number, e.g. 1e-4 or 1e-5,
with ``max_epoch=100``, ``total_batch_size=1024`` and ``update_period=32``, the ema
momentum should be 0.996723175, which has roughly same behavior to the default setting.
i.e. ``momentum=0.9999`` together with ``update_period=1``
"""
return max(0, 1 - alpha * (total_batch_size / max_epoch * update_period))