basecls.solver.build

class basecls.solver.build.Solver(optimizer, grad_manager, grad_scaler)

Bases: tuple

grad_manager

Alias for field number 1

grad_scaler

Alias for field number 2

optimizer

Alias for field number 0
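
Since Solver is a plain named tuple, its parts can be accessed by attribute or unpacked positionally. A minimal sketch, assuming MegEngine's GradManager and SGD; passing None in the grad_scaler slot is an illustrative assumption, not documented behavior:

    import megengine.module as M
    from megengine.autodiff import GradManager
    from megengine.optimizer import SGD

    from basecls.solver.build import Solver

    model = M.Linear(8, 2)
    gm = GradManager().attach(model.parameters())
    opt = SGD(model.parameters(), lr=0.1)

    # Field order follows the class signature: optimizer, grad_manager, grad_scaler.
    solver = Solver(optimizer=opt, grad_manager=gm, grad_scaler=None)  # None is an assumed placeholder

    # Attribute access and positional access are equivalent.
    assert solver.optimizer is solver[0]
    optimizer, grad_manager, grad_scaler = solver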

class basecls.solver.build.BaseSolver[source]

Bases: object

Base class for solver factories.

A solver factory should return a Solver object, which bundles an Optimizer, a GradManager and a GradScaler.

classmethod build(cfg, model)[source]

Abstract build function.

Parameters
  • cfg (ConfigDict) – config for training.

  • model (Module) – model for training.

Return type

Solver

Returns

A solver.
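
For illustration, a hypothetical subclass might implement build as follows. This is a sketch only: the cfg fields (cfg.lr, cfg.weight_decay) and the bare-SGD choice are assumptions, not part of the documented interface:

    from megengine.autodiff import GradManager
    from megengine.optimizer import SGD

    from basecls.solver.build import BaseSolver, Solver

    class PlainSGDSolver(BaseSolver):
        """A hypothetical factory returning an unscaled SGD solver."""

        @classmethod
        def build(cls, cfg, model):
            # Attach all trainable parameters so gradients are recorded for them.
            gm = GradManager().attach(model.parameters())
            opt = SGD(
                model.parameters(),
                lr=cfg.lr,                      # assumed config field
                weight_decay=cfg.weight_decay,  # assumed config field
            )
            return Solver(optimizer=opt, grad_manager=gm, grad_scaler=None)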

class basecls.solver.build.DefaultSolver[source]

Bases: BaseSolver

The default solver factory.

According to cfg.reduce_mode, the learning rate and weight decay will be scaled automatically following the linear scaling rule; see “Accurate, Large Minibatch SGD: Training ImageNet in 1 Hour” for more details.

It supports "sgd", "adam" and "adamw".

Note

The linear scaling rule is only known to work well with SGD; an applicable scaling rule for Adam and AdamW is still an open question. We therefore recommend keeping the default training settings (such as learning rate and world size) when using Adam and AdamW.
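
As a back-of-the-envelope illustration of the rule from the cited paper: with world_size workers each consuming batch_size samples per step, the effective batch size grows linearly, and the base learning rate is scaled by the same factor. How basecls derives the factor from cfg.reduce_mode is not spelled out here, so the reference batch size of 256 and the exact form below are assumptions:

    # Linear scaling rule (sketch): lr grows with the effective batch size.
    # Per the docstring above, weight decay is scaled analogously.
    def scaled_lr(base_lr: float, batch_size: int, world_size: int,
                  reference_batch_size: int = 256) -> float:  # 256 is an assumed reference
        effective_batch_size = batch_size * world_size
        return base_lr * effective_batch_size / reference_batch_size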

classmethod build(cfg, model)[source]

Build function with the linear scaling strategy.

Parameters
  • cfg (ConfigDict) – config for training.

  • model (Module) – model for training.

Return type

Solver

Returns

A solver.
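
A typical call site, sketched under the assumption that cfg is a prepared basecls ConfigDict (its construction is elided here); the trivial mean "loss" stands in for a real criterion:

    import numpy as np
    import megengine
    import megengine.functional as F
    import megengine.module as M

    from basecls.solver.build import DefaultSolver

    model = M.Linear(8, 2)
    solver = DefaultSolver.build(cfg, model)  # cfg: a prepared ConfigDict (assumed)

    inputs = megengine.Tensor(np.random.rand(4, 8).astype("float32"))

    # One training step driven by the returned parts.
    with solver.grad_manager.record():
        loss = F.mean(model(inputs))
        solver.grad_manager.backward(loss)
    solver.optimizer.step()
    solver.optimizer.clear_grad()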

classmethod build_optimizer(cfg, params, lr, wd)[source]

Build an optimizer according to the training config.

Parameters
  • cfg (ConfigDict) – config for training.

  • params (Union[Iterable[Parameter], dict]) – iterable of parameters to optimize or dicts defining parameter groups.

  • lr (float) – learning rate.

  • wd (float) – weight decay (L2 penalty).

Return type

Optimizer

Returns

An optimizer.
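
The dispatch over the three supported optimizer names plausibly looks like the sketch below. It is not the actual implementation: reading the name from a plain string argument (rather than from cfg) and the SGD momentum value are simplifying assumptions.

    import megengine.optimizer as optim

    # Sketch of name-based dispatch over "sgd", "adam" and "adamw".
    def build_optimizer_sketch(name: str, params, lr: float, wd: float):
        if name == "sgd":
            return optim.SGD(params, lr=lr, momentum=0.9, weight_decay=wd)  # momentum is assumed
        if name == "adam":
            return optim.Adam(params, lr=lr, weight_decay=wd)
        if name == "adamw":
            return optim.AdamW(params, lr=lr, weight_decay=wd)
        raise NotImplementedError(f"Unsupported optimizer: {name}")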