# Source listing of module: basecls.models.snet

#!/usr/bin/env python3
# Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
"""Shufflenet Series

ShufflenetV2: `"ShuffleNet V2: Practical Guidelines for Efficient CNN Architecture Design"
<https://arxiv.org/abs/1807.11164>`_

References:
    https://github.com/megvii-model/ShuffleNet-Series/blob/master/ShuffleNetV2/network.py
"""
from typing import Any, Callable, Mapping, Sequence

import megengine.functional as F
import megengine.hub as hub
import megengine.module as M

from basecls.layers import SE, DropPath, activation, conv2d, norm2d
from basecls.layers.heads import build_head
from basecls.models.resnet import SimpleStem
from basecls.utils import recursive_update, registers

__all__ = ["SNV2Block", "SNV2XceptionBlock", "SNetV2"]


def channel_shuffle(x):
    """Interleave channels across two groups and split the tensor in half.

    Reshapes ``(N, C, H, W)`` input so that the two C/2-channel halves are
    interleaved channel-wise, then returns the two shuffled halves as a pair
    of ``(N, C/2, H, W)`` tensors. ``C`` must be divisible by 4.
    """
    n, c, h, w = x.shape
    assert c % 4 == 0
    flat = x.reshape(n * c // 2, 2, h * w)
    flat = flat.transpose(1, 0, 2)
    halves = flat.reshape(2, -1, c // 2, h, w)
    return halves[0], halves[1]


class SNV2Block(M.Module):
    """ShuffleNet V2 basic block.

    The output width is split between a pass-through/projection branch of
    width ``w_in`` and a main branch of width ``w_out - w_in``.

    Args:
        w_in: width of the block input (and of the pass-through branch).
        w_out: total output width; the main branch produces ``w_out - w_in``.
        w_mid: hidden width of the main branch.
        kernel: depthwise conv kernel size.
        stride: block stride; 1 keeps resolution (channel shuffle + split),
            2 downsamples and adds a projection branch.
        norm_name: normalization layer name passed to ``norm2d``.
        act_name: activation name passed to ``activation``.
        se_r: SE ratio on the main-branch output; 0 disables SE.
        drop_path_prob: stochastic-depth probability; 0 disables DropPath.
    """

    def __init__(
        self,
        w_in: int,
        w_out: int,
        w_mid: int,
        *,
        kernel: int,
        stride: int,
        norm_name: str,
        act_name: str,
        se_r: float,
        drop_path_prob: float,
        **kwargs,
    ):
        super().__init__()
        # No extra kwargs are accepted; subclasses forward everything explicitly.
        assert not kwargs
        self.stride = stride
        # Main branch produces the channels not covered by the shortcut branch.
        w_main_out = w_out - w_in
        w_se = int(w_main_out * se_r)
        self.branch_main = M.Sequential(
            # pw1
            conv2d(w_in, w_mid, 1),
            norm2d(norm_name, w_mid),
            activation(act_name),
            # dw
            conv2d(w_mid, w_mid, kernel, stride=stride, groups=w_mid),
            norm2d(norm_name, w_mid),
            # pw2
            conv2d(w_mid, w_main_out, 1),
            norm2d(norm_name, w_main_out),
            activation(act_name),
        )
        # NOTE: se mid_act is RELU
        self.se = SE(w_main_out, w_se, act_name) if w_se else None
        self.drop_path = DropPath(drop_path_prob) if drop_path_prob > 0 else None
        # Projection branch only exists for stride-2 (downsampling) blocks.
        self.branch_proj = (
            M.Sequential(
                # dw
                conv2d(w_in, w_in, kernel, stride=stride, groups=w_in),
                norm2d(norm_name, w_in),
                # pw
                conv2d(w_in, w_in, 1),
                norm2d(norm_name, w_in),
                activation(act_name),
            )
            if stride == 2
            else None
        )

    def forward(self, old_x):
        # stride 1: shuffle channels and split into shortcut / main inputs.
        # stride 2: both branches consume the full input.
        # NOTE(review): any other stride leaves x_proj/x unbound (NameError);
        # callers are expected to pass stride in (1, 2) only.
        if self.stride == 1:
            x_proj, x = channel_shuffle(old_x)
        elif self.stride == 2:
            x_proj = x = old_x
        if self.branch_proj:
            x_proj = self.branch_proj(x_proj)
        x = self.branch_main(x)
        if self.se:
            x = self.se(x)
        if self.drop_path:
            x = self.drop_path(x)
        # Concatenate shortcut and main branches along the channel axis.
        return F.concat((x_proj, x), 1)
class SNV2XceptionBlock(SNV2Block):
    """ShuffleNet V2 Xception-style block used in ShuffleNet V2+.

    Behaves exactly like :class:`SNV2Block` for an integer ``kernel``.
    For ``kernel == "x"`` it first builds the parent with ``kernel=3`` and
    then replaces ``branch_main`` with a deeper Xception-style stack of
    three depthwise/pointwise pairs.

    Raises:
        ValueError: if ``kernel`` is neither an int nor the string ``"x"``.
    """

    def __init__(
        self,
        w_in: int,
        w_out: int,
        w_mid: int,
        *,
        kernel: int,
        stride: int,
        norm_name: str,
        act_name: str,
        se_r: float,
        drop_path_prob: float,
        **kwargs,
    ):
        if isinstance(kernel, int):
            # Plain ShuffleNet V2 block with the given depthwise kernel.
            super().__init__(
                w_in,
                w_out,
                w_mid,
                kernel=kernel,
                stride=stride,
                norm_name=norm_name,
                act_name=act_name,
                se_r=se_r,
                drop_path_prob=drop_path_prob,
                **kwargs,
            )
        elif kernel == "x":
            # Build the parent (kernel=3) to get se/drop_path/branch_proj,
            # then swap in the deeper Xception-style main branch.
            super().__init__(
                w_in,
                w_out,
                w_mid,
                kernel=3,
                stride=stride,
                norm_name=norm_name,
                act_name=act_name,
                se_r=se_r,
                drop_path_prob=drop_path_prob,
                **kwargs,
            )
            w_main_out = w_out - w_in
            self.branch_main = M.Sequential(
                # dw
                conv2d(w_in, w_in, 3, stride=stride, groups=w_in),
                norm2d(norm_name, w_in),
                # pw
                conv2d(w_in, w_mid, 1),
                norm2d(norm_name, w_mid),
                activation(act_name),
                # dw
                conv2d(w_mid, w_mid, 3, stride=1, groups=w_mid),
                norm2d(norm_name, w_mid),
                # pw
                conv2d(w_mid, w_mid, 1),
                norm2d(norm_name, w_mid),
                activation(act_name),
                # dw
                conv2d(w_mid, w_mid, 3, stride=1, groups=w_mid),
                norm2d(norm_name, w_mid),
                # pw
                conv2d(w_mid, w_main_out, 1),
                norm2d(norm_name, w_main_out),
                activation(act_name),
            )
        else:
            raise ValueError(f"unidentified kernel={kernel}, should be (3, 5, 7, 'x')")
@registers.models.register()
class SNetV2(M.Module):
    """ShufflenetV2 model.

    Args:
        block: building block to use, ``SNV2XceptionBlock`` for v2+.
        stem_w: width for stem layer.
        depths: depth for each stage (number of blocks in the stage).
        widths: width for each stage (width of each block in the stage).
        strides: strides for each stage (applies to the first block of each stage).
        kernels: kernel sizes for each stage.
        use_maxpool: whether use maxpool stride 2 after stem. Default: ``True``
        se_r: Squeeze-and-Excitation (SE) ratio. Default: ``0.0``
        drop_path_prob: drop path probability. Default: ``0.0``
        norm_name: normalization function. Default: ``"BN"``
        act_name: activation function. Default: ``"relu"``
        head: head args. Default: ``None``
    """

    def __init__(
        self,
        block: Callable,
        stem_w: int,
        depths: Sequence[int],
        widths: Sequence[int],
        strides: Sequence[int],
        kernels: Sequence[int],
        use_maxpool: bool = True,
        se_r: float = 0.0,
        drop_path_prob: float = 0.0,
        norm_name: str = "BN",
        act_name: str = "relu",
        head: Mapping[str, Any] = None,
    ):
        super().__init__()
        self.stem = SimpleStem(
            3,
            stem_w,
            norm_name,
            act_name,
        )
        self.maxpool = M.MaxPool2d(3, 2, 1) if use_maxpool else None
        prev_w = stem_w
        # stochastic depth: drop-path probability grows linearly with block index
        iblock = 0
        total_blocks = sum(depths)
        self.num_stages = len(depths)
        for idxstage in range(self.num_stages):
            blocks = []
            depth = depths[idxstage]
            width = widths[idxstage]
            stride = strides[idxstage]
            kernel = kernels[idxstage]
            # A scalar kernel applies to every block of the stage.
            if isinstance(kernel, int):
                kernel = [kernel] * depth
            for i in range(depth):
                iblock += 1
                blocks.append(
                    block(
                        # Non-first blocks see only the shuffled half of the channels.
                        prev_w // 2 if i else prev_w,
                        width,
                        width // 2,
                        kernel=kernel[i],
                        stride=1 if i else stride,
                        # Stage 0 always uses ReLU; SE only from stage 2 onwards.
                        act_name=act_name if idxstage >= 1 else "relu",
                        norm_name=norm_name,
                        se_r=se_r if idxstage >= 2 else 0.0,
                        drop_path_prob=drop_path_prob * iblock / total_blocks,
                    )
                )
                prev_w = width
            setattr(self, f"stage{idxstage}", M.Sequential(*blocks))
        self.head = build_head(
            prev_w,
            head_args=head,
            norm_name=norm_name,
            act_name=act_name,
        )
        self._initialize_weights()

    def _initialize_weights(self):
        # Custom init matching the reference ShuffleNet-Series implementation:
        # stem / head-SE convs use N(0, 0.01); other convs scale by fan-in.
        for name, m in self.named_modules():
            if isinstance(m, M.Conv2d):
                if "stem" in name or "head.se" in name:
                    M.init.normal_(m.weight, 0, 0.01)
                else:
                    M.init.normal_(m.weight, 0, 1.0 / m.weight.shape[1])
                if m.bias is not None:
                    M.init.fill_(m.bias, 0)
            elif isinstance(m, M.batchnorm._BatchNorm):
                M.init.fill_(m.weight, 1)
                if m.bias is not None:
                    # NOTE: small non-zero BN bias, as in the reference code.
                    M.init.fill_(m.bias, 0.0001)
                M.init.fill_(m.running_mean, 0)
            elif isinstance(m, M.Linear):
                M.init.normal_(m.weight, 0, 0.01)
                if m.bias is not None:
                    M.init.fill_(m.bias, 0)

    def forward(self, x):
        x = self.stem(x)
        if self.maxpool:
            x = self.maxpool(x)
        for i in range(self.num_stages):
            stage = getattr(self, f"stage{i}")
            x = stage(x)
        x = self.head(x)
        return x
def _build_snetv2(**kwargs):
    """Build a plain ShuffleNet V2; ``kwargs`` recursively override the defaults."""
    args = {
        "block": SNV2Block,
        "stem_w": 24,
        "depths": [4, 8, 4],
        "strides": [2, 2, 2],
        "kernels": [3, 3, 3],
        "head": {"name": "ClsHead", "width": 1024, "bias": False},
    }
    recursive_update(args, kwargs)
    return SNetV2(**args)


def _build_snetv2_plus(**kwargs):
    """Build a ShuffleNet V2+; ``kwargs`` recursively override the defaults."""
    args = {
        "block": SNV2XceptionBlock,
        "stem_w": 16,
        "use_maxpool": False,
        "depths": [4, 4, 8, 4],
        "strides": [2, 2, 2, 2],
        "kernels": [
            [3, 3, "x", 5],
            [5, 5, 3, 3],
            [7, 3, 7, 5, 5, 3, 7, 3],
            [7, 5, "x", 7],
        ],
        "se_r": 0.25,
        "act_name": "hswish",
        "head": {
            "name": "MBV3Head",
            "w_h": 1280,
            "width": 1280,
            "se_r": 0.25,
            "dropout_prob": 0.2,
            "bias": False,
        },
    }
    recursive_update(args, kwargs)
    return SNetV2(**args)


@registers.models.register()
@hub.pretrained(
    "https://data.megengine.org.cn/research/basecls/models/snet/snetv2_x050/snetv2_x050.pkl"
)
def snetv2_x050(**kwargs):
    """ShuffleNet V2 x0.5 model."""
    args = {"widths": [48, 96, 192]}
    recursive_update(args, kwargs)
    return _build_snetv2(**args)


@registers.models.register()
@hub.pretrained(
    "https://data.megengine.org.cn/research/basecls/models/snet/snetv2_x100/snetv2_x100.pkl"
)
def snetv2_x100(**kwargs):
    """ShuffleNet V2 x1.0 model."""
    args = {"widths": [116, 232, 464]}
    recursive_update(args, kwargs)
    return _build_snetv2(**args)


@registers.models.register()
@hub.pretrained(
    "https://data.megengine.org.cn/research/basecls/models/snet/snetv2_x150/snetv2_x150.pkl"
)
def snetv2_x150(**kwargs):
    """ShuffleNet V2 x1.5 model."""
    args = {"widths": [176, 352, 704]}
    recursive_update(args, kwargs)
    return _build_snetv2(**args)


@registers.models.register()
@hub.pretrained(
    "https://data.megengine.org.cn/research/basecls/models/snet/snetv2_x200/snetv2_x200.pkl"
)
def snetv2_x200(**kwargs):
    """ShuffleNet V2 x2.0 model (wider head with dropout)."""
    args = {"widths": [244, 488, 976], "head": {"dropout_prob": 0.2, "width": 2048}}
    recursive_update(args, kwargs)
    return _build_snetv2(**args)


@registers.models.register()
@hub.pretrained(
    "https://data.megengine.org.cn/research/basecls/models/snet/snetv2p_x075/snetv2p_x075.pkl"
)
def snetv2p_x075(**kwargs):
    """ShuffleNet V2+ x0.75 model."""
    args = {"widths": [36, 104, 208, 416]}
    recursive_update(args, kwargs)
    return _build_snetv2_plus(**args)


@registers.models.register()
@hub.pretrained(
    "https://data.megengine.org.cn/research/basecls/models/snet/snetv2p_x100/snetv2p_x100.pkl"
)
def snetv2p_x100(**kwargs):
    """ShuffleNet V2+ x1.0 model."""
    args = {"widths": [48, 128, 256, 512]}
    recursive_update(args, kwargs)
    return _build_snetv2_plus(**args)


@registers.models.register()
@hub.pretrained(
    "https://data.megengine.org.cn/research/basecls/models/snet/snetv2p_x125/snetv2p_x125.pkl"
)
def snetv2p_x125(**kwargs):
    """ShuffleNet V2+ x1.25 model."""
    args = {"widths": [68, 168, 336, 672]}
    recursive_update(args, kwargs)
    return _build_snetv2_plus(**args)