自定义数据增强#
BaseCls 支持自定义数据增强。
实现范式#
数据增强类必须实现
build
类方法,返回一个VisionTransform
对象。自定义参数通过模型配置文件
augments
字段传入。以下字段为保留字段不可使用:
augments.name
,BaseCls 用此字段构造数据增强类。
具体步骤#
实现网络并注册#
1import megengine.data.transform as T
2import numpy as np
3from basecls.utils import registers
4from basecore.config import ConfigDict
5
6@registers.augments.register()
7class YourAugmentBuilder:
8
9 @classmethod
10 def build(cls, cfg: ConfigDict) -> T.Transform:
11 return YourAugment(cfg)
12
13class YourAugment(T.VisionTransform):
14
15 def __init__(self, cfg: ConfigDict):
16 pass
17
18 def _apply_image(self, image: np.ndarray) -> np.ndarray:
19 pass
修改模型配置文件#
1_cfg = dict(
2 ...
3 argments=dict(
4 name="YourAugmentBuilder",
5 ... # 你想传入的自定义参数
6 )
7 ...
8)