# Source code for basecls.data.build

#!/usr/bin/env python3
# Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
from typing import Union

import megengine.data as data
import megengine.data.transform as T
from basecore.config import ConfigDict

from basecls.utils import registers

from .dataloader import build_dataloader
from .fake_data import FakeDataLoader

# Public names exported by this module.
__all__ = ["DataLoaderType", "FakeData", "FolderLoader"]

# Type alias covering both loader kinds used in this module: a MegEngine
# ``data.DataLoader`` or the project's ``FakeDataLoader``.
DataLoaderType = Union[data.DataLoader, FakeDataLoader]


@registers.dataloaders.register()
class FakeData:
    """Fake data useful for benchmark.

    Registered in ``registers.dataloaders``; :meth:`build` returns a
    ``FakeDataLoader`` configured from the given config.
    """

    @classmethod
    def build(
        cls, cfg: ConfigDict, train: bool = True, augments: T.Transform = None
    ) -> DataLoaderType:
        """Build a ``FakeDataLoader`` from ``cfg``.

        Args:
            cfg: config object. The fields read here are ``batch_size``,
                ``preprocess.img_size``, ``preprocess.img_color_space`` and
                ``num_classes``.
            train: not used by this builder; accepted so the signature matches
                ``FolderLoader.build``.
            augments: not used by this builder; accepted for the same reason.

        Returns:
            A ``FakeDataLoader`` constructed with ``length=200``.
        """
        # "GRAY" color space maps to 1 channel; every other value maps to 3.
        channels = 1 if cfg.preprocess.img_color_space == "GRAY" else 3
        return FakeDataLoader(
            batch_size=cfg.batch_size,
            img_size=cfg.preprocess.img_size,
            channels=channels,
            length=200,
            num_classes=cfg.num_classes,
        )
@registers.dataloaders.register()
class FolderLoader:
    """Local dataloader factory.

    The source is the local folder.
    """

    @classmethod
    def build(
        cls, cfg: ConfigDict, train: bool = True, augments: T.Transform = None
    ) -> data.DataLoader:
        """Build a dataloader by delegating to ``build_dataloader``.

        Args:
            cfg: config object, forwarded unchanged.
            train: forwarded unchanged as the second positional argument.
            augments: forwarded unchanged as the third positional argument.

        Returns:
            Whatever ``build_dataloader`` returns for ``mode="folder"``.
        """
        return build_dataloader(cfg, train, augments, mode="folder")