自定义数据源#

BaseCls 支持接入第三方数据源。

实现范式#

  • 数据源类必须实现 build 类方法,返回一个 Iterable 对象(实现了 __iter__ 方法)。

  • 自定义参数通过模型配置文件 data 字段传入。

  • 以下字段为保留字段不可使用:

    • data.name ,BaseCls 用此字段构造数据源。

具体步骤#

实现数据源并注册#

 1from basecls.utils import registers
 2from basecore.config import ConfigDict
 3
 4@registers.dataloaders.register()
 5class YourDataSourceBuilder:
 6
 7    @classmethod
 8    def build(cls, cfg: ConfigDict, augments):
 9        return YourDataSource(cfg, augments)
10
11class YourDataSource:
12
13    def __init__(self, cfg: ConfigDict, augments):
14        pass
15
16    def __iter__(self):
17        pass

修改模型配置文件#

1_cfg = dict(
2    ...
3    data=dict(
4        name="YourDataSourceBuilder",
5        ...  # 你想传入的自定义参数
6    )
7    ...
8)