basecls.engine

class basecls.engine.DefaultHooks[source]

Bases: object

The default hooks factory.

It combines LRSchedulerHook -> PreciseBNHook -> ResumeHook -> TensorboardHook -> LoggerHook -> CheckpointHook -> EvalHook.

classmethod build(cfg)[source]

Build function with a simple strategy.

Parameters

cfg (ConfigDict) – config for setting hooks.

Return type

List[BaseHook]

Returns

A hook list.
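
A hypothetical usage sketch, assuming cfg, model, dataloader, and solver have been built elsewhere in an experiment (they are placeholders here); it simply wires the default hook list into the ClsTrainer documented below.

    from basecls.engine import ClsTrainer, DefaultHooks

    # `cfg`, `model`, `dataloader` and `solver` are placeholders assumed to be
    # built elsewhere in the experiment; they are not defined in this snippet.
    hooks = DefaultHooks.build(cfg)  # List[BaseHook], ordered as documented above
    trainer = ClsTrainer(cfg, model, dataloader, solver, hooks=hooks)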

class basecls.engine.AccEvaluator[source]

Bases: BaseEvaluator

Classification evaluator with top-1 and top-5 accuracy.

ResultType

alias of Tuple[int, float, float]

preprocess(input_data)[source]

Preprocess input data per batch.

Parameters

input_data (Sequence[ndarray]) – input data.

Return type

Tensor

Returns

Preprocessed input data.

postprocess(model_outputs, input_data)[source]

Postprocess model outputs with input data per batch.

Parameters

  • model_outputs – model outputs.

  • input_data – input data.

Return type

Tuple[int, float, float]

Returns

A tuple of (batch size, top-1 accuracy per batch, top-5 accuracy per batch).

evaluate(results)[source]

Evaluation function.

Parameters

results (Iterable[Tuple[int, float, float]]) – all results.
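
Since evaluate consumes the per-batch ResultType tuples above, a natural reduction is a batch-size-weighted average. The sketch below illustrates those tuple semantics; it is not the actual body of evaluate.

    from typing import Iterable, Tuple

    ResultType = Tuple[int, float, float]  # (batch size, top-1 acc, top-5 acc)

    def aggregate(results: Iterable[ResultType]) -> Tuple[float, float]:
        # Weight each per-batch accuracy by its batch size, then normalize.
        total, top1, top5 = 0, 0.0, 0.0
        for n, acc1, acc5 in results:
            total += n
            top1 += n * acc1
            top5 += n * acc5
        return top1 / total, top5 / total

    print(aggregate([(64, 0.75, 0.92), (32, 0.81, 0.95)]))  # (0.77, 0.93)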

class basecls.engine.ClsTester(cfg, model, dataloader)[source]

Bases: BaseTester

Classification tester.

test(warm_iters=5, log_seconds=5)[source]
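
A hypothetical usage sketch; cfg, model, and dataloader are placeholders built elsewhere.

    tester = ClsTester(cfg, model, dataloader)
    # Presumably: 5 warm-up iterations, then evaluation with logs every 5 seconds.
    tester.test(warm_iters=5, log_seconds=5)
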
class basecls.engine.ClsTrainer(cfg, model, dataloader, solver, hooks=None)[source]

Bases: BaseTrainer

Classification trainer.

Parameters

  • cfg (ConfigDict) – config for training.

  • model – model for training.

  • dataloader – dataloader for training.

  • solver – solver for training.

  • hooks – hooks for training, defaults to None.

cfg

config for training.

model

model for training.

ema

model exponential moving average.

dataloader

dataloader for training.

solver

solver for training.

progress

object for recording training process.

loss

loss function for training.

meter

object for recording metrics.

train()[source]

Parameters

  • start_info (Iterable) – [epoch, iter] for training start.

  • max_info (Iterable) – [max_epoch, max_iter] for training.

before_train()[source]

before_epoch()[source]

after_epoch()[source]

train_one_iter()[source]

Basic logic of training one iteration.
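
A hypothetical sketch of how one iteration could chain the methods documented below; the actual order and responsibilities live in the basecls source, and next_batch/record are illustrative names rather than real APIs.

    def train_one_iter(self):
        # Hypothetical decomposition; see the basecls source for the real logic.
        samples, targets = self.next_batch()      # illustrative helper, not a real API
        loss = self.model_step(samples, targets)  # forward/backward + solver update
        self.model_ema_step()                     # refresh the EMA copy of the weights
        self.meter.record(loss)                   # illustrative metric recording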

model_step(samples, targets)[source]

modify_grad()[source]

model_ema_step()[source]

Implements momentum-based Exponential Moving Average (EMA) for model states, following rwightman/pytorch-image-models.

Also inspired by Pycls (facebookresearch/pycls), which is more flexible and efficient.

Heuristically, one can use a momentum of 0.9999 as used by TensorFlow, or 0.9998 as used by timm, both of which update the model EMA every iteration. To be more efficient, one can set update_period to e.g. 8 or 32 to speed up training, and decrease the momentum at scale accordingly: e.g. set momentum=0.9978 instead of 0.9999 (32 times fewer updates) when update_period=32.

Also, to make model EMA really work (i.e. improve generalization), one should carefully tune the momentum based on various factors, e.g. the learning rate schedule, the total batch size, and the number of training epochs.

To initialize the momentum in Pycls style, one can set model_ema.alpha = 1e-5 instead. The momentum will then be calculated through _calculate_pycls_momentum.
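
For reference, a minimal NumPy sketch of a momentum-based EMA update with an update period; it illustrates the update rule and the momentum/update_period trade-off described above, not the basecls implementation.

    import numpy as np

    def ema_step(params, ema_params, momentum=0.9999):
        # In-place convex combination: ema <- m * ema + (1 - m) * param
        for p, p_ema in zip(params, ema_params):
            p_ema *= momentum
            p_ema += (1.0 - momentum) * p

    params = [np.ones(4)]       # stand-in for model weights
    ema_params = [np.zeros(4)]  # stand-in for EMA weights
    for it in range(1, 101):
        # ... a real training step would update `params` here ...
        if it % 32 == 0:        # update_period=32: EMA updated 32x less often,
            ema_step(params, ema_params, momentum=0.9978)  # hence smaller momentum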