# Source code for basecls.engine.build

#!/usr/bin/env python3
# Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
import os
from typing import List

import megengine.distributed as dist
from basecore.config import ConfigDict
from basecore.engine import BaseHook
from basecore.utils import str_timestamp

from basecls.utils import registers

from .hooks import (
    CheckpointHook,
    EvalHook,
    LoggerHook,
    LRSchedulerHook,
    PreciseBNHook,
    ResumeHook,
    TensorboardHook,
)

__all__ = ["DefaultHooks"]


@registers.hooks.register()
class DefaultHooks:
    """Factory producing the default training hook pipeline.

    The hooks are returned in this order:
    :py:class:`~basecls.engine.LRSchedulerHook` ->
    :py:class:`~basecls.engine.PreciseBNHook` ->
    :py:class:`~basecls.engine.ResumeHook` ->
    :py:class:`~basecls.engine.TensorboardHook` ->
    :py:class:`~basecls.engine.LoggerHook` ->
    :py:class:`~basecls.engine.CheckpointHook` ->
    :py:class:`~basecls.engine.EvalHook`.

    ``TensorboardHook``, ``LoggerHook`` and ``CheckpointHook`` are only created on the
    process whose distributed rank is 0; all other hooks are created on every rank.
    """

    @classmethod
    def build(cls, cfg: ConfigDict) -> List[BaseHook]:
        """Assemble the default hook list from ``cfg``.

        Args:
            cfg: config for setting hooks. Fields read: ``output_dir``, ``resume``,
                ``bn.precise_every_n_epoch``, ``tb_every_n_iter``, ``log_every_n_iter``,
                ``save_every_n_epoch`` and ``eval_every_n_epoch``.

        Returns:
            A hook list, ordered as described in the class docstring.
        """
        out_dir = cfg.output_dir

        # Hooks created on every rank.
        hooks: List[BaseHook] = [LRSchedulerHook()]
        hooks.append(PreciseBNHook(cfg.bn.precise_every_n_epoch))
        hooks.append(ResumeHook(out_dir, cfg.resume))

        is_main_process = dist.get_rank() == 0
        if is_main_process:
            # LoggerHook resets values, so TensorboardHook has to be registered before it.
            # Each run writes its event files into a fresh timestamped sub-directory.
            tb_dir = os.path.join(out_dir, "tensorboard", str_timestamp())
            hooks += [
                TensorboardHook(tb_dir, cfg.tb_every_n_iter),
                LoggerHook(cfg.log_every_n_iter),
                CheckpointHook(out_dir, cfg.save_every_n_epoch),
            ]

        # EvalHook is kept after CheckpointHook and is not restricted to rank 0.
        hooks.append(EvalHook(out_dir, cfg.eval_every_n_epoch))
        return hooks