basecls.data.dataloader 源代码

#!/usr/bin/env python3
# Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
import megengine.data as data
import megengine.data.transform as T
from basecore.config import ConfigDict

from .dataset import build_dataset
from .transform import build_mixup, build_transform

__all__ = ["build_dataloader"]


def build_dataloader(
    cfg: ConfigDict,
    train: bool,
    augments: T.Transform = None,
    mode: str = "folder",
    infinite: bool = False,
    rank: int = None,
) -> data.DataLoader:
    """Build function for MegEngine dataloader.

    Args:
        cfg: config for building dataloader. This function reads ``cfg.batch_size``,
            ``cfg.seed`` and ``cfg.data.num_workers``; the whole config is also passed on
            to ``build_dataset``, ``build_transform`` and ``build_mixup``.
        train: train set or test set.
        augments: augments for building dataloader, forwarded to ``build_transform``.
            Default: ``None``
        mode: dataset mode, forwarded unchanged to ``build_dataset``. Default: ``"folder"``
        infinite: make dataloader infinite or not. Only consulted when ``train`` is true.
            Default: ``False``
        rank: machine rank, only useful for infinite dataloader, where it is required.
            Default: ``None``

    Returns:
        A dataloader.

    Raises:
        ValueError: if ``train`` and ``infinite`` are both true but ``rank`` is ``None``.
    """
    dataset = build_dataset(cfg, train, mode)

    if train:
        if infinite:
            # for DPFlow producer
            # Explicit check instead of ``assert``: asserts are stripped under ``python -O``,
            # which would defer the failure to an opaque TypeError in ``cfg.seed + rank``.
            if rank is None:
                raise ValueError("`rank` must be provided when building an infinite dataloader")
            # The sampler is pinned to world_size=1 / rank=0 and the seed is offset by
            # ``rank`` instead, so each rank gets a differently-seeded sampler.
            sampler = data.Infinite(
                data.RandomSampler(
                    dataset,
                    cfg.batch_size,
                    drop_last=True,
                    world_size=1,
                    rank=0,
                    seed=cfg.seed + rank,
                )
            )
        else:
            sampler = data.RandomSampler(dataset, cfg.batch_size, drop_last=True, seed=cfg.seed)
    else:
        # Evaluation uses a fixed batch size of 25 rather than cfg.batch_size.
        sampler = data.SequentialSampler(dataset, 25)  # can divide 50000 / 8 = 6250

    transform = build_transform(cfg, train, augments)
    # The mixup object is plugged in as the collator of the dataloader.
    mixup = build_mixup(cfg, train)

    dataloader = data.DataLoader(
        dataset,
        sampler=sampler,
        transform=transform,
        num_workers=cfg.data.num_workers,
        collator=mixup,
    )
    return dataloader