自定义数据增强#

BaseCls 支持自定义数据增强。

实现范式#

  • 数据增强类必须实现 build 类方法,返回一个 VisionTransform 对象。

  • 自定义参数通过模型配置文件 augments 字段传入。

  • 以下字段为保留字段不可使用:

    • augments.name ,BaseCls 用此字段构造数据增强类。

具体步骤#

实现网络并注册#

 1import megengine.data.transform as T
 2import numpy as np
 3from basecls.utils import registers
 4from basecore.config import ConfigDict
 5
 6@registers.augments.register()
 7class YourAugmentBuilder:
 8
 9     @classmethod
10     def build(cls, cfg: ConfigDict) -> T.Transform:
11         return YourAugment(cfg)
12
13class YourAugment(T.VisionTransform):
14
15     def __init__(self, cfg: ConfigDict):
16         pass
17
18     def _apply_image(self, image: np.ndarray) -> np.ndarray:
19         pass

修改模型配置文件#

1_cfg = dict(
2    ...
3    argments=dict(
4        name="YourAugmentBuilder",
5        ...  # 你想传入的自定义参数
6    )
7    ...
8)