#!/usr/bin/env python3
# Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
from collections import namedtuple
from typing import Iterable, Union

import megengine as mge
import megengine.distributed as dist
import megengine.module as M
import megengine.optimizer as optim
from basecore.config import ConfigDict
from megengine import Parameter
from megengine.amp import GradScaler
from megengine.autodiff import GradManager
from pkg_resources import packaging

from basecls.utils import registers

from .optimizer import LAMB, LARS, SGD
from .weight_decay import get_param_groups

__all__ = ["Solver", "BaseSolver", "DefaultSolver"]

Solver = namedtuple("Solver", ["optimizer", "grad_manager", "grad_scaler"])


class BaseSolver:
"""Base class for solver factory.
A solver factory should return a :py:class:`~Solver` object, which combines
an :py:class:`~megengine.optimizer.Optimizer` and
a :py:class:`~megengine.autodiff.GradManager`.
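
    A minimal sketch of a custom factory (``MySolver`` is a hypothetical example,
    not part of this package):

    .. code-block:: python

        @registers.solvers.register()
        class MySolver(BaseSolver):

            @classmethod
            def build(cls, cfg: ConfigDict, model: M.Module) -> Solver:
                opt = optim.SGD(model.parameters(), lr=cfg.solver.basic_lr, momentum=0.9)
                gm = GradManager().attach(model.parameters())
                return Solver(opt, gm, GradScaler())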
"""

    @classmethod
def build(cls, cfg: ConfigDict, model: M.Module) -> Solver:
"""Abstract build function
Args:
cfg: config for training.
model: model for training.
Returns:
A solver.
"""
raise NotImplementedError


@registers.solvers.register()
class DefaultSolver(BaseSolver):
"""The default solver factory.
According to ``cfg.reduce_mode``, learning rate and weight decay will be scaled automatically
following the linear scaling rule, see
`"Accurate, Large Minibatch SGD: Training ImageNet in 1 Hour"
<https://arxiv.org/abs/1706.02677>`_ for more details.
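
    For example, with ``cfg.solver.basic_lr = 0.0125`` and a world size of 8, the
    effective learning rate becomes ``0.0125 * 8 = 0.1``.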

    It supports ``"sgd"``, ``"adam"``, ``"adamw"``, ``"lamb"`` and ``"lars"``.

    Note:
        The linear scaling rule only works well with SGD. We are still looking for
        an applicable scaling rule for Adam and AdamW, so we recommend keeping the
        default training settings (like learning rate and world size) when using them.
"""

    @classmethod
def build(cls, cfg: ConfigDict, model: M.Module) -> Solver:
"""Build function with the linear scaling strategy.
Args:
cfg: config for training.
model: model for training.
Returns:
A solver.
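
        Example (a minimal sketch of one training step; ``data``, ``label`` and
        ``loss_fn`` are placeholders):

        .. code-block:: python

            solver = DefaultSolver.build(cfg, model)
            gm, opt = solver.grad_manager, solver.optimizer
            with gm:  # record computations for autodiff
                loss = loss_fn(model(data), label)
                gm.backward(loss)
            opt.step().clear_grad()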
"""
amp_cfg = cfg.amp
cfg = cfg.solver
world_size = dist.get_world_size()
# build optimizer
lr = cfg.basic_lr * world_size # linear scaling rule
optim_params = get_param_groups(model, cfg.weight_decay)
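        # weight decay is already attached to each parameter group by get_param_groups,
        # so pass 0 as the optimizer-level default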
optimizer = cls.build_optimizer(cfg, optim_params, lr, 0)
# build grad_manager
gm = GradManager()
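        # when running distributed, average gradients across all workers on backward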
callbacks = [dist.make_allreduce_cb("mean", dist.WORLD)] if world_size > 1 else None
gm.attach(model.parameters(), callbacks=callbacks)
# build grad_scaler
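        # dynamic mode starts from a large scale and lets it adapt every 2000 steps;
        # otherwise a fixed scale of 128 is kept (growth_interval=0 disables updates)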
scaler = (
GradScaler(init_scale=65536.0, growth_interval=2000)
if amp_cfg.dynamic_scale
else GradScaler(init_scale=128.0, growth_interval=0)
)
return Solver(optimizer, gm, scaler)

    @classmethod
def build_optimizer(
cls, cfg: ConfigDict, params: Union[Iterable[Parameter], dict], lr: float, wd: float
) -> optim.Optimizer:
"""Build optimizer according to training config.
Args:
cfg: config for training.
params: iterable of parameters to optimize or dicts defining parameter groups.
lr: learning rate.
weight_decay: weight decay (L2, penalty).
Returns:
An optimizer.
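
        Example (a minimal sketch, assuming ``ConfigDict`` accepts keyword
        initialization like a plain ``dict``):

        .. code-block:: python

            cfg = ConfigDict(optimizer="sgd", momentum=0.9, nesterov=True)
            opt = DefaultSolver.build_optimizer(
                cfg, model.parameters(), lr=0.1, wd=1e-4
            )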
"""
if cfg.optimizer == "adam":
return optim.Adam(params, lr=lr, weight_decay=wd, betas=cfg.betas)
elif cfg.optimizer == "adamw":
return optim.AdamW(params, lr=lr, weight_decay=wd, betas=cfg.betas)
elif cfg.optimizer == "lamb":
return LAMB(
params, lr=lr, weight_decay=wd, betas=cfg.betas, always_adapt=cfg.always_adapt
)
elif cfg.optimizer == "lars":
return LARS(
params,
lr=lr,
weight_decay=wd,
momentum=cfg.momentum,
nesterov=cfg.nesterov,
always_adapt=cfg.always_adapt,
)
elif cfg.optimizer == "sgd":
if packaging.version.parse(mge.__version__) < packaging.version.parse("1.7.0"):
return SGD(
params, lr=lr, weight_decay=wd, momentum=cfg.momentum, nesterov=cfg.nesterov
)
return optim.SGD(
params, lr=lr, weight_decay=wd, momentum=cfg.momentum, nesterov=cfg.nesterov
)
else:
raise NotImplementedError(f"Optimizer '{cfg.optimizer}' not supported")