# 自定义网络

BaseCls 支持接入用户自定义的网络。

## 实现范式

  • 网络必须继承自 megengine.module.Module。

  • 自定义参数通过模型配置文件 model 字段传入。

  • 以下字段为保留字段,不可用于传入自定义参数:

    • model.name:BaseCls 用此字段构造网络。

    • model.head.name:BaseCls 用此字段构造分类头。

## 具体步骤

### 实现网络并注册

 1from typing import Any, Mapping
 2
 3import megengine as mge
 4import megengine.functional as F
 5import megengine.module as M
 6from basecls.layers import build_head
 7from basecls.utils import registers
 8
 9class AlexNetHead(M.Module):
10
11    def __init__(self, w_in: int, w_out: int, width: int = 4096):
12        super().__init__()
13        self.avg_pool = M.AdaptiveAvgPool2d((6, 6))
14        self.classifier = M.Sequential(
15            M.Dropout(),
16            M.Linear(w_in * 6 * 6, width),
17            M.ReLU(),
18            M.Dropout(),
19            M.Linear(width, width),
20            M.ReLU(),
21            M.Linear(width, w_out),
22        )
23
24    def forward(self, x: mge.Tensor) -> mge.Tensor:
25        x = self.avg_pool(x)
26        x = F.flatten(x, 1)
27        x = self.classifier(x)
28        return x
29
30@registers.models.register()
31class AlexNet(M.Module):
32
33    def __init__(self, head: Mapping[str, Any] = None):
34        super().__init__()
35        self.features = M.Sequential(
36            M.Conv2d(3, 64, kernel_size=11, stride=4, padding=2),
37            M.ReLU(),
38            M.MaxPool2d(kernel_size=3, stride=2),
39            M.Conv2d(64, 192, kernel_size=5, padding=2),
40            M.ReLU(),
41            M.MaxPool2d(kernel_size=3, stride=2),
42            M.Conv2d(192, 384, kernel_size=3, padding=1),
43            M.ReLU(),
44            M.Conv2d(384, 256, kernel_size=3, padding=1),
45            M.ReLU(),
46            M.Conv2d(256, 256, kernel_size=3, padding=1),
47            M.ReLU(),
48            M.MaxPool2d(kernel_size=3, stride=2),
49        )
50        self.head = build_head(256, head)
51
52    def forward(self, x: mge.Tensor) -> mge.Tensor:
53        x = self.features(x)
54        if getattr(self, "head", None) is not None:
55            x = self.head(x)
56        return x

### 修改模型配置文件

 1_cfg = dict(
 2    ...
 3    num_classes=1000,
 4    model=dict(
 5        name="AlexNet",
 6        ...  # 你想传入的自定义参数
 7        head=dict(
 8            name=AlexNetHead,  # 也可以直接传入一个类
 9            # w_out=1000,  # 若该字段未定义,会自动传入cfg.num_classes
10        ),
11    )
12    ...
13)