basecls.engine.trainer

class basecls.engine.trainer.ClsTrainer(cfg, model, dataloader, solver, hooks=None)[source]

Bases: BaseTrainer

Classification trainer.

Parameters

  • cfg – config for training.

  • model – model for training.

  • dataloader – dataloader for training.

  • solver – solver for training.

  • hooks – hooks for training (default: None).

Attributes

  • ema – model exponential moving average.

  • progress – object for recording training process.

  • loss – loss function for training.

  • meter – object for recording metrics.
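
For orientation, a minimal construction sketch. Only ClsTrainer and its import path come from this page; load_config, build_model, build_dataloader and build_solver are hypothetical placeholder helpers, not BaseCls's actual API:

    from basecls.engine.trainer import ClsTrainer

    # The four helpers below are hypothetical placeholders for however the
    # surrounding codebase builds each component; only ClsTrainer is documented here.
    cfg = load_config("path/to/config.py")
    model = build_model(cfg)
    dataloader = build_dataloader(cfg)
    solver = build_solver(cfg, model)

    trainer = ClsTrainer(cfg, model, dataloader, solver, hooks=None)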

train()[source]
Parameters
  • start_info (Iterable) – [epoch, iter] for training start.

  • max_info (Iterable) – [max_epoch, max_iter] for training.
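
As an illustration, a hypothetical invocation continuing the construction sketch above (the argument values and keyword form are assumptions, not documented defaults):

    # Hypothetical call: start from epoch 0 / iter 0 and train for 100 epochs,
    # assuming max_iter counts iterations per epoch (semantics inferred from above).
    trainer.train(start_info=(0, 0), max_info=(100, len(dataloader)))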

before_train()[source]
before_epoch()[source]
after_epoch()[source]
train_one_iter()[source]

Basic logic of training one iteration.

model_step(samples, targets)[source]
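
To make the step logic concrete, a simplified MegEngine sketch of what a model_step like this typically does (toy model and loss are stand-ins; the real method also goes through the solver abstraction and modify_grad()):

    import numpy as np
    import megengine as mge
    import megengine.functional as F
    import megengine.module as M
    import megengine.optimizer as optim
    from megengine.autodiff import GradManager

    model = M.Linear(8, 4)                             # toy stand-in for self.model
    optimizer = optim.SGD(model.parameters(), lr=0.1)  # stand-in for the solver
    gm = GradManager().attach(model.parameters())

    def model_step(samples, targets):
        # Forward and backward inside the GradManager scope, then update weights.
        with gm:
            outputs = model(samples)
            loss = F.nn.cross_entropy(outputs, targets)  # stand-in for self.loss
            gm.backward(loss)
        optimizer.step().clear_grad()
        return loss

    samples = mge.tensor(np.random.rand(2, 8).astype("float32"))
    targets = mge.tensor(np.array([0, 1], dtype="int32"))
    model_step(samples, targets)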
modify_grad()[source]
model_ema_step()[source]

Implement momentum-based Exponential Moving Average (EMA) for model states, following rwightman/pytorch-image-models.

Also inspired by pycls (facebookresearch/pycls), which is more flexible and efficient.

Heuristically, one can use a momentum of 0.9999 as used by TensorFlow or 0.9998 as used by timm, both of which update the model EMA every iteration. To be more efficient, one can set update_period to e.g. 8 or 32 to speed up training, and decrease the momentum at scale: set momentum=0.9978 instead of 0.9999 when update_period=32 (one EMA update standing in for 32).
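
As a self-contained illustration of the update rule and the update_period trade-off (plain Python over a toy state dict, not BaseCls's implementation):

    # ema = momentum * ema + (1 - momentum) * model, applied every update_period iters.
    def ema_step(ema_state, model_state, momentum=0.9999):
        for name, value in model_state.items():
            ema_state[name] = momentum * ema_state[name] + (1.0 - momentum) * value

    update_period = 32   # update EMA every 32 iterations ...
    momentum = 0.9978    # ... with the momentum decreased accordingly (see above)
    ema_state, model_state = {"w": 0.0}, {"w": 1.0}
    for cur_iter in range(1, 1001):
        # (a real loop would update model_state here via an optimizer step)
        if cur_iter % update_period == 0:
            ema_step(ema_state, model_state, momentum)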

Also, to make model EMA really work (i.e. improve generalization), one should carefully tune the momentum based on various factors, e.g. the learning rate scheduler, the total batch size, the number of training epochs, etc.

To initialize the momentum in pycls style, one can set model_ema.alpha = 1e-5 instead; the momentum will then be calculated through _calculate_pycls_momentum.