basecls.engine.trainer#
- class basecls.engine.trainer.ClsTrainer(cfg, model, dataloader, solver, hooks=None)[source]#
Bases: BaseTrainer
Classification trainer.
- Parameters
  - cfg (ConfigDict) – config for training.
  - model (Module) – model for training.
  - dataloader (Union[DataLoader, FakeDataLoader]) – dataloader for training.
  - solver (Solver) – solver for training.
- cfg#
config for training.
- model#
model for training.
- ema#
model exponential moving average.
- dataloader#
dataloader for training.
- solver#
solver for training.
- progress#
object for recording training process.
- loss#
loss function for training.
- meter#
object for recording metrics.
- train()[source]#
  - Parameters
    - start_info (Iterable) – [epoch, iter] for training start.
    - max_info (Iterable) – [max_epoch, max_iter] for training.
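The [epoch, iter] / [max_epoch, max_iter] convention lets training resume mid-run. A minimal sketch of such a resumable double loop (illustrative only, not the actual BaseTrainer implementation; `train_loop` and `step` are hypothetical names):

```python
def train_loop(start_info, max_info, step):
    """Run step(epoch, it) from [epoch, iter] up to [max_epoch, max_iter].

    Illustrative sketch: `step` stands in for the real per-iteration work.
    """
    start_epoch, start_iter = start_info
    max_epoch, max_iter = max_info
    for epoch in range(start_epoch, max_epoch):
        # When resuming, skip the already-finished iterations of the first epoch.
        first_iter = start_iter if epoch == start_epoch else 0
        for it in range(first_iter, max_iter):
            step(epoch, it)
```

A fresh run passes `[0, 0]`; a resumed run passes the checkpointed `[epoch, iter]`.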
- model_ema_step()[source]#
Implement momentum-based Exponential Moving Average (EMA) for model states, following rwightman/pytorch-image-models. Also inspired by facebookresearch/pycls, which is more flexible and efficient.
Heuristically, one can use a momentum of 0.9999 as used by TensorFlow, or 0.9998 as used by timm, updating the model EMA every iteration. To be more efficient, one can set update_period to e.g. 8 or 32 to speed up training, and decrease the momentum accordingly: e.g. set momentum=0.9978 instead of 0.9999 (32 steps folded into one) when update_period=32.
Also, to make model EMA really work (i.e. improve generalization), one should carefully tune the momentum based on various factors, e.g. the learning rate scheduler, the total batch size, the number of training epochs, etc.
To initialize the momentum in pycls style, set model_ema.alpha = 1e-5 instead; the momentum will then be calculated through _calculate_pycls_momentum.
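The update rule itself is the standard momentum EMA, ema = momentum * ema + (1 - momentum) * model. A minimal sketch over plain parameter dicts (illustrative; the actual method operates on the model's module states, and `ema_update` is a hypothetical name):

```python
def ema_update(ema_state, model_state, momentum=0.9999):
    """One EMA step: blend the current model weights into the EMA copy in place."""
    for name, value in model_state.items():
        ema_state[name] = momentum * ema_state[name] + (1.0 - momentum) * value
```

With update_period > 1, this step simply runs once every update_period iterations, paired with a correspondingly smaller momentum as described above.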