basecls.solver.optimizer

class basecls.solver.optimizer.LAMB(params, lr, betas=(0.9, 0.999), eps=1e-8, bias_correction=True, weight_decay=0.0, always_adapt=False)

Bases: Optimizer

Implements the LAMB algorithm.

LAMB was proposed in “Large Batch Optimization for Deep Learning: Training BERT in 76 minutes”.

Parameters
  • params (Union[Iterable[Parameter], dict]) – iterable of parameters to optimize or dicts defining parameter groups.

  • lr (float) – learning rate.

  • betas (Tuple[float, float]) – coefficients used for computing running averages of the gradient and its square. Default: (0.9, 0.999)

  • eps (float) – term added to the denominator to improve numerical stability. Default: 1e-8

  • bias_correction (bool) – enables bias correction, which divides each running average by 1 - beta ** step. Default: True

  • weight_decay (float) – weight decay (L2 penalty). Default: 0.0

  • always_adapt (bool) – if True, also apply the adaptive learning rate to parameters whose weight decay is 0.0. Default: False
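The update these parameters control can be sketched in plain Python for a single flattened parameter. This is a minimal sketch, assuming the formulation from the LAMB paper: Adam-style moments, weight decay added to the Adam direction, and a layer-wise trust ratio ‖w‖ / ‖update‖ with no extra clipping function. It also assumes that always_adapt only decides whether the trust ratio is computed for zero-decay parameters. The helper name `lamb_step` and its list-based state are hypothetical and not part of basecls; the actual MegEngine implementation may differ in details such as where eps is added.

```python
import math
from typing import List, Tuple


def lamb_step(
    param: List[float],
    grad: List[float],
    exp_avg: List[float],
    exp_avg_sq: List[float],
    step: int,
    lr: float,
    betas: Tuple[float, float] = (0.9, 0.999),
    eps: float = 1e-8,
    bias_correction: bool = True,
    weight_decay: float = 0.0,
    always_adapt: bool = False,
) -> None:
    """One LAMB update on a flat parameter, modified in place (hypothetical sketch)."""
    beta1, beta2 = betas
    # Update the running averages of the gradient and its square.
    for i, g in enumerate(grad):
        exp_avg[i] = beta1 * exp_avg[i] + (1 - beta1) * g
        exp_avg_sq[i] = beta2 * exp_avg_sq[i] + (1 - beta2) * g * g
    # Bias correction divides by 1 - beta ** step (step counts from 1).
    bc1 = 1 - beta1 ** step if bias_correction else 1.0
    bc2 = 1 - beta2 ** step if bias_correction else 1.0
    # Adam direction plus weight decay (assumed to be added to the update, as in the paper).
    update = [
        (m / bc1) / (math.sqrt(v / bc2) + eps) + weight_decay * w
        for m, v, w in zip(exp_avg, exp_avg_sq, param)
    ]
    # Assumed trust ratio ||w|| / ||update||, skipped for zero-decay
    # parameters unless always_adapt is set; falls back to 1.0 on zero norms.
    trust_ratio = 1.0
    if weight_decay != 0.0 or always_adapt:
        w_norm = math.sqrt(sum(w * w for w in param))
        u_norm = math.sqrt(sum(u * u for u in update))
        if w_norm > 0 and u_norm > 0:
            trust_ratio = w_norm / u_norm
    for i, u in enumerate(update):
        param[i] -= lr * trust_ratio * u
```

Because the trust ratio rescales the update by ‖w‖ / ‖update‖, the step actually taken has norm lr · ‖w‖ whenever adaptation applies, independent of the gradient's scale; this is what makes the effective step size layer-wise.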

class basecls.solver.optimizer.LARS(params, lr, momentum=0.0, nesterov=False, weight_decay=0.0, always_adapt=False)

Bases: Optimizer

Implements the LARS algorithm.

LARS was proposed in “Large Batch Training of Convolutional Networks”.

Parameters
  • params (Union[Iterable[Parameter], dict]) – iterable of parameters to optimize or dicts defining parameter groups.

  • lr (float) – learning rate.

  • momentum (float) – momentum factor. Default: 0.0

  • nesterov (bool) – enables Nesterov momentum. Default: False

  • weight_decay (float) – weight decay (L2 penalty). Default: 0.0

  • always_adapt (bool) – if True, also apply the adaptive learning rate to parameters whose weight decay is 0.0. Default: False
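A comparable sketch for LARS, again on a flat list of floats. It assumes the trust ratio is ‖w‖ / ‖g + weight_decay · w‖, applied to the decayed gradient before momentum, with no separate trust coefficient. The LARS paper writes the local rate as η‖w‖ / (‖∇L(w)‖ + β‖w‖), with η a trust coefficient and β the weight decay, so treat this form, the hypothetical helper `lars_step`, and its momentum handling as assumptions rather than basecls's exact code.

```python
import math
from typing import List, Optional


def lars_step(
    param: List[float],
    grad: List[float],
    momentum_buf: Optional[List[float]],
    lr: float,
    momentum: float = 0.0,
    nesterov: bool = False,
    weight_decay: float = 0.0,
    always_adapt: bool = False,
) -> Optional[List[float]]:
    """One LARS update on a flat parameter, modified in place (hypothetical sketch).

    Returns the momentum buffer, or None when momentum is 0.
    """
    # Gradient plus the L2 weight-decay term.
    d = [g + weight_decay * w for g, w in zip(grad, param)]
    # Assumed layer-wise trust ratio ||w|| / ||g + wd * w||, skipped for
    # zero-decay parameters unless always_adapt is set.
    if weight_decay != 0.0 or always_adapt:
        w_norm = math.sqrt(sum(w * w for w in param))
        d_norm = math.sqrt(sum(x * x for x in d))
        if w_norm > 0 and d_norm > 0:
            q = w_norm / d_norm
            d = [x * q for x in d]
    if momentum != 0.0:
        if momentum_buf is None:
            momentum_buf = [0.0] * len(d)
        for i, x in enumerate(d):
            momentum_buf[i] = momentum * momentum_buf[i] + x
        if nesterov:
            # Look ahead along the updated momentum buffer.
            d = [x + momentum * b for x, b in zip(d, momentum_buf)]
        else:
            d = list(momentum_buf)
    for i, x in enumerate(d):
        param[i] -= lr * x
    return momentum_buf
```

As with LAMB, when adaptation applies and momentum is 0, each step has norm lr · ‖w‖, so layers with small weights take proportionally small steps.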

class basecls.solver.optimizer.SGD(params, lr, momentum=0.0, nesterov=False, weight_decay=0.0)

Bases: Optimizer

Implements stochastic gradient descent.

Nesterov momentum is based on the formula from “On the importance of initialization and momentum in deep learning”.

Parameters
  • params (Union[Iterable[Parameter], dict]) – iterable of parameters to optimize or dicts defining parameter groups.

  • lr (float) – learning rate.

  • momentum (float) – momentum factor. Default: 0.0

  • nesterov (bool) – enables Nesterov momentum. Default: False

  • weight_decay (float) – weight decay (L2 penalty). Default: 0.0
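The SGD update can be sketched the same way. This sketch assumes the buffer-based form used by PyTorch-style SGD (buf = momentum · buf + g; Nesterov uses g + momentum · buf), which is a common variant of, not a literal transcription of, the update in Sutskever et al.; that basecls follows exactly this form is an assumption. The helpers `sgd_step` and `minimize_quadratic` are hypothetical, and the latter runs the update on f(w) = w² / 2, whose gradient is w.

```python
from typing import List, Optional


def sgd_step(
    param: List[float],
    grad: List[float],
    momentum_buf: Optional[List[float]],
    lr: float,
    momentum: float = 0.0,
    nesterov: bool = False,
    weight_decay: float = 0.0,
) -> Optional[List[float]]:
    """One SGD update on a flat parameter, modified in place (hypothetical sketch)."""
    # L2 penalty: add weight_decay * w to the gradient.
    d = [g + weight_decay * w for g, w in zip(grad, param)]
    if momentum != 0.0:
        if momentum_buf is None:
            momentum_buf = [0.0] * len(d)
        for i, x in enumerate(d):
            momentum_buf[i] = momentum * momentum_buf[i] + x
        if nesterov:
            # Assumed PyTorch-style Nesterov: gradient plus a look-ahead along the buffer.
            d = [x + momentum * b for x, b in zip(d, momentum_buf)]
        else:
            d = list(momentum_buf)
    for i, x in enumerate(d):
        param[i] -= lr * x
    return momentum_buf


def minimize_quadratic(steps: int, **kwargs) -> float:
    """Run sgd_step on f(w) = 0.5 * w**2 (gradient = w), starting from w = 1.0."""
    w: List[float] = [1.0]
    buf: Optional[List[float]] = None
    for _ in range(steps):
        buf = sgd_step(w, [w[0]], buf, **kwargs)
    return w[0]
```

With lr=0.1 and momentum=0.9, one step of plain momentum moves w from 1.0 to 0.9 and a second step reaches 0.72, while the Nesterov look-ahead already reaches 0.81 after the first step.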