# Source code for basecls.data.fake_data

#!/usr/bin/env python3
# Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
from numbers import Real
from typing import Tuple, Union

import numpy as np


class FakeDataLoader:
    """Iterable that yields randomly generated image/label batches.

    Each call to ``__iter__`` draws one random batch of images and labels and
    then yields that *same* pair ``length`` times, so iterating costs a single
    random draw per pass rather than one per step.

    Args:
        batch_size: batch size
        img_size: height and width. A single number is used for both
            dimensions. Default: 224
        channels: color channels. Default: 3
        length: loader length, i.e. the number of batches yielded per
            iteration pass. Default: 100
        num_classes: number of classes; labels are drawn from
            ``[0, num_classes)``. Default: 1000
    """

    def __init__(
        self,
        batch_size: int,
        img_size: Union[int, Tuple[int, int]] = 224,
        channels: int = 3,
        length: int = 100,
        num_classes: int = 1000,
    ) -> None:
        self.batch_size = batch_size
        self.channels = channels
        # A scalar size describes a square image: expand it to (height, width).
        if isinstance(img_size, Real):
            img_size = (img_size, img_size)
        self.img_size = img_size
        self.length = length
        self.num_classes = num_classes

    def __len__(self) -> int:
        """Return the number of batches yielded per iteration pass."""
        return self.length

    def __iter__(self):
        """Yield ``(images, labels)`` pairs, ``len(self)`` times.

        ``images`` is a ``uint8`` array of shape
        ``(batch_size, channels, *img_size)`` with values in ``[0, 256)``;
        ``labels`` is an ``int32`` array of shape ``(batch_size,)`` with
        values in ``[0, num_classes)``.
        """
        # Draw the random batch once per pass; every step below reuses the
        # very same array objects.
        images = np.random.randint(
            256, dtype=np.uint8, size=(self.batch_size, self.channels, *self.img_size)
        )
        labels = np.random.randint(self.num_classes, dtype=np.int32, size=(self.batch_size,))
        for _ in range(len(self)):
            yield images, labels