自定义网络#
BaseCls 支持接入用户自定义的网络。
实现范式#
网络必须继承自 `Module`。自定义参数通过模型配置文件的 `model` 字段传入。以下字段为保留字段,不可用作自定义参数:
- `model.name`:BaseCls 用此字段构造网络。
- `model.head.name`:BaseCls 用此字段构造分类头。
具体步骤#
实现网络并注册#
1from typing import Any, Mapping
2
3import megengine as mge
4import megengine.functional as F
5import megengine.module as M
6from basecls.layers import build_head
7from basecls.utils import registers
8
class AlexNetHead(M.Module):
    """AlexNet-style classification head: adaptive average pooling plus a 3-layer MLP.

    Args:
        w_in: number of channels of the incoming feature map.
        w_out: number of output units of the last linear layer.
        width: hidden width of the two intermediate linear layers. Default: 4096
        drop_prob: drop probability used by both dropout layers. Default: 0.5
    """

    def __init__(self, w_in: int, w_out: int, width: int = 4096, drop_prob: float = 0.5):
        super().__init__()
        # Pool to a 6x6 grid so the first linear layer receives w_in * 6 * 6 inputs.
        self.avg_pool = M.AdaptiveAvgPool2d((6, 6))
        self.classifier = M.Sequential(
            # MegEngine's Dropout defaults to drop_prob=0.0 (a no-op), so the
            # probability has to be passed explicitly for the layer to regularize.
            M.Dropout(drop_prob),
            M.Linear(w_in * 6 * 6, width),
            M.ReLU(),
            M.Dropout(drop_prob),
            M.Linear(width, width),
            M.ReLU(),
            M.Linear(width, w_out),
        )

    def forward(self, x: mge.Tensor) -> mge.Tensor:
        """Pool ``x``, flatten it from axis 1 onward and apply the classifier."""
        x = self.avg_pool(x)
        x = F.flatten(x, 1)
        x = self.classifier(x)
        return x
29
@registers.models.register()
class AlexNet(M.Module):
    """AlexNet convolutional feature extractor with a configurable head.

    The class is registered in ``registers.models`` through the decorator.

    Args:
        head: head configuration, passed to ``build_head`` together with the
            256-channel width of the last convolution. Default: ``None``
    """

    def __init__(self, head: Mapping[str, Any] = None):
        super().__init__()
        # Five convolutions, each followed by ReLU, with three max-pooling stages.
        backbone_layers = [
            M.Conv2d(3, 64, kernel_size=11, stride=4, padding=2),
            M.ReLU(),
            M.MaxPool2d(kernel_size=3, stride=2),
            M.Conv2d(64, 192, kernel_size=5, padding=2),
            M.ReLU(),
            M.MaxPool2d(kernel_size=3, stride=2),
            M.Conv2d(192, 384, kernel_size=3, padding=1),
            M.ReLU(),
            M.Conv2d(384, 256, kernel_size=3, padding=1),
            M.ReLU(),
            M.Conv2d(256, 256, kernel_size=3, padding=1),
            M.ReLU(),
            M.MaxPool2d(kernel_size=3, stride=2),
        ]
        self.features = M.Sequential(*backbone_layers)
        self.head = build_head(256, head)

    def forward(self, x: mge.Tensor) -> mge.Tensor:
        """Run the feature extractor, then the head when one is present."""
        feature_map = self.features(x)
        # The head attribute may be missing or None; return raw features then.
        head_module = getattr(self, "head", None)
        if head_module is None:
            return feature_map
        return head_module(feature_map)
修改模型配置文件#
# Schematic model configuration; `...` marks parts that are left out.
1_cfg = dict(
2 ...
3 num_classes=1000,
4 model=dict(
5 name="AlexNet",
6 ... # custom arguments you want to pass in
7 head=dict(
8 name=AlexNetHead, # a class can also be passed in directly
9 # w_out=1000, # if this field is not defined, cfg.num_classes is passed in automatically
10 ),
11 )
12 ...
13)