# Source code for basecls.models.vgg

#!/usr/bin/env python3
# Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
"""VGG Series

VGG: `"Very Deep Convolutional Networks for Large-Scale Image Recognition"
<https://arxiv.org/abs/1409.1556>`_
"""
from typing import Any, Mapping, Sequence

import megengine as mge
import megengine.hub as hub
import megengine.module as M

from basecls.layers import activation, build_head, conv2d, init_weights, norm2d
from basecls.utils import recursive_update, registers

__all__ = ["VGGStage", "VGG"]


class VGGStage(M.Module):
    """VGG stage (sequence of blocks w/ the same output shape).

    The stage stacks ``depth`` blocks named ``b1`` ... ``b{depth}``. Each block is
    ``M.Sequential(conv2d(c_in, w_out, 3), norm2d(norm_name, w_out), activation(act_name))``.
    The stage ends with a max pool (``kernel_size=2, stride=2``).

    Only the first block reads ``w_in`` channels; every later block reads and writes
    ``w_out`` channels.

    Args:
        w_in: number of input channels of the stage.
        w_out: number of output channels of every block in the stage.
        depth: number of blocks in the stage.
        norm_name: normalization name forwarded to ``norm2d``.
        act_name: activation name forwarded to ``activation``.
    """

    def __init__(self, w_in: int, w_out: int, depth: int, norm_name: str, act_name: str):
        super().__init__()
        self.depth = depth
        channels_in = w_in
        for idx in range(1, depth + 1):
            layers = (
                conv2d(channels_in, w_out, 3),
                norm2d(norm_name, w_out),
                activation(act_name),
            )
            setattr(self, f"b{idx}", M.Sequential(*layers))
            # After the first block, the channel count is fixed at w_out.
            channels_in = w_out
        self.max_pool = M.MaxPool2d(kernel_size=2, stride=2)

    def __len__(self):
        """Number of blocks in the stage (the pooling layer is not counted)."""
        return self.depth

    def forward(self, x: mge.Tensor) -> mge.Tensor:
        """Run ``b1`` ... ``b{depth}`` in order, then max-pool the result."""
        blocks = (getattr(self, f"b{idx}") for idx in range(1, self.depth + 1))
        for block in blocks:
            x = block(x)
        return self.max_pool(x)
@registers.models.register()
class VGG(M.Module):
    """VGG model.

    The network is a chain of :py:class:`VGGStage` modules named ``s1`` ... ``sN``,
    one per ``(depth, width)`` pair, followed by an optional head.

    Args:
        depths: depth for each stage (number of blocks in the stage).
        widths: width for each stage (width of each block in the stage).
        norm_name: normalization function. Default: ``None``
        act_name: activation function. Default: ``"relu"``
        head: head args. Default: ``None``

    Raises:
        ValueError: if ``depths`` and ``widths`` have different lengths.
    """

    def __init__(
        self,
        depths: Sequence[int],
        widths: Sequence[int],
        norm_name: str = None,
        act_name: str = "relu",
        head: Mapping[str, Any] = None,
    ):
        super().__init__()
        # forward() walks len(depths) stages. A length mismatch would either drop
        # configured widths silently (zip truncates) or make forward() look up a
        # stage that was never built. Reject it here, at construction time.
        if len(depths) != len(widths):
            raise ValueError(
                f"depths and widths must have the same length, "
                f"got {len(depths)} and {len(widths)}"
            )
        self.depths = depths
        prev_w = 3  # channel count fed to the first stage (3-channel input assumed)
        for i, (d, w) in enumerate(zip(depths, widths)):
            stage = VGGStage(prev_w, w, d, norm_name, act_name)
            setattr(self, f"s{i + 1}", stage)
            prev_w = w
        # NOTE(review): the third positional argument is passed as None; it presumably
        # selects the head's normalization. Confirm against build_head's signature.
        self.head = build_head(prev_w, head, None, act_name)
        self.apply(init_weights)

    def forward(self, x: mge.Tensor) -> mge.Tensor:
        """Run every stage in order, then the head if one was built."""
        for i in range(len(self.depths)):
            stage = getattr(self, f"s{i + 1}")
            x = stage(x)
        if getattr(self, "head", None) is not None:
            x = self.head(x)
        return x
# Channel widths shared by every standard VGG variant (one entry per stage).
_VGG_WIDTHS = (64, 128, 256, 512, 512)


def _build_vgg(**kwargs):
    """Instantiate :py:class:`VGG` with a ``VGGHead`` (``dropout_prob=0.5``) as the default head.

    ``kwargs`` are merged over that default with ``recursive_update`` before
    being forwarded to ``VGG``.
    """
    config = {"head": {"name": "VGGHead", "dropout_prob": 0.5}}
    recursive_update(config, kwargs)
    return VGG(**config)


def _plain_vgg(depths, kwargs):
    """Build a VGG with per-stage ``depths`` and the standard widths.

    Values in ``kwargs`` are merged over these defaults.
    """
    config = {"depths": list(depths), "widths": list(_VGG_WIDTHS)}
    recursive_update(config, kwargs)
    return _build_vgg(**config)


def _bn_vgg(builder, kwargs):
    """Call ``builder`` with ``norm_name="BN"`` as a default that ``kwargs`` may override."""
    config = {"norm_name": "BN"}
    recursive_update(config, kwargs)
    return builder(**config)


# NOTE(review): hub.pretrained presumably wraps each factory with a ``pretrained``
# flag that loads the checkpoint at the given URL; confirm in MegEngine docs.


@registers.models.register()
@hub.pretrained("https://data.megengine.org.cn/research/basecls/models/vgg/vgg11/vgg11.pkl")
def vgg11(**kwargs):
    """VGG-11: stage depths ``[1, 1, 2, 2, 2]``."""
    return _plain_vgg((1, 1, 2, 2, 2), kwargs)


@registers.models.register()
@hub.pretrained("https://data.megengine.org.cn/research/basecls/models/vgg/vgg11_bn/vgg11_bn.pkl")
def vgg11_bn(**kwargs):
    """VGG-11 with ``norm_name="BN"``."""
    return _bn_vgg(vgg11, kwargs)


@registers.models.register()
@hub.pretrained("https://data.megengine.org.cn/research/basecls/models/vgg/vgg13/vgg13.pkl")
def vgg13(**kwargs):
    """VGG-13: stage depths ``[2, 2, 2, 2, 2]``."""
    return _plain_vgg((2, 2, 2, 2, 2), kwargs)


@registers.models.register()
@hub.pretrained("https://data.megengine.org.cn/research/basecls/models/vgg/vgg13_bn/vgg13_bn.pkl")
def vgg13_bn(**kwargs):
    """VGG-13 with ``norm_name="BN"``."""
    return _bn_vgg(vgg13, kwargs)


@registers.models.register()
@hub.pretrained("https://data.megengine.org.cn/research/basecls/models/vgg/vgg16/vgg16.pkl")
def vgg16(**kwargs):
    """VGG-16: stage depths ``[2, 2, 3, 3, 3]``."""
    return _plain_vgg((2, 2, 3, 3, 3), kwargs)


@registers.models.register()
@hub.pretrained("https://data.megengine.org.cn/research/basecls/models/vgg/vgg16_bn/vgg16_bn.pkl")
def vgg16_bn(**kwargs):
    """VGG-16 with ``norm_name="BN"``."""
    return _bn_vgg(vgg16, kwargs)


@registers.models.register()
@hub.pretrained("https://data.megengine.org.cn/research/basecls/models/vgg/vgg19/vgg19.pkl")
def vgg19(**kwargs):
    """VGG-19: stage depths ``[2, 2, 4, 4, 4]``."""
    return _plain_vgg((2, 2, 4, 4, 4), kwargs)


@registers.models.register()
@hub.pretrained("https://data.megengine.org.cn/research/basecls/models/vgg/vgg19_bn/vgg19_bn.pkl")
def vgg19_bn(**kwargs):
    """VGG-19 with ``norm_name="BN"``."""
    return _bn_vgg(vgg19, kwargs)