#!/usr/bin/env python3
# Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
"""RepVGG Series

RepVGG: `"RepVGG: Making VGG-style ConvNets Great Again" <https://arxiv.org/abs/2101.03697>`_

References:
    https://github.com/DingXiaoH/RepVGG/blob/main/repvgg.py
"""
from numbers import Real
from typing import Any, List, Mapping, Sequence, Union

import megengine
import megengine.functional as F
import megengine.hub as hub
import megengine.module as M
import numpy as np

from basecls.layers import SE, activation, init_weights
from basecls.layers.heads import build_head
from basecls.utils import recursive_update, registers

__all__ = ["RepVGGBlock", "RepVGG"]


class RepVGGBlock(M.Module):
    """RepVGG Reparamed Block.

    Args:
        w_in: number of input channels.
        w_out: number of output channels.
        stride: stride of the 2D conv. Default: ``1``
        groups: number of groups of the 2D conv. Default: ``1``
        se_r: Squeeze-and-Excitation (SE) ratio. Default: ``0.0``
        act_name: activation function. Default: ``"relu"``
        deploy: fuse branches into a plain ``Conv2d`` layer. Default: ``False``
    """

    def __init__(
        self,
        w_in: int,
        w_out: int,
        stride: int = 1,
        groups: int = 1,
        se_r: float = 0.0,
        act_name: str = "relu",
        deploy: bool = False,
    ):
        super().__init__()
        self.w_in = w_in
        self.w_out = w_out
        self.stride = stride
        self.groups = groups
        self.act_name = act_name
        self.se_r = se_r
        self.deploy = deploy

        self.nonlinearity = activation(act_name)
        self.se = SE(w_out, int(w_out * se_r), act_name) if se_r > 0 else M.Identity()

        if deploy:
            self.rbr_reparam = M.Conv2d(w_in, w_out, 3, stride, 1, groups=groups, bias=True)
        else:
            self.rbr_identity = M.BatchNorm2d(w_in) if w_out == w_in and stride == 1 else None
            self.rbr_dense = M.ConvBn2d(w_in, w_out, 3, stride, 1, groups=groups, bias=False)
            self.rbr_1x1 = M.ConvBn2d(w_in, w_out, 1, stride, 0, groups=groups, bias=False)

    def forward(self, inputs):
        if hasattr(self, "rbr_reparam"):
            return self.nonlinearity(self.se(self.rbr_reparam(inputs)))

        if self.rbr_identity is None:
            id_out = 0
        else:
            id_out = self.rbr_identity(inputs)

        return self.nonlinearity(self.se(self.rbr_dense(inputs) + self.rbr_1x1(inputs) + id_out))

    def _get_equivalent_kernel_bias(self):
        """
        This func derives the equivalent kernel and bias in a DIFFERENTIABLE way.
        You can get the equivalent kernel and bias at any time and do whatever you want,
        for example, apply some penalties or constraints during training, just like you
        do to the other models. May be useful for quantization or pruning.
        """
        kernel_3x3, bias_3x3 = self._fuse_bn_tensor(self.rbr_dense)
        kernel_1x1, bias_1x1 = self._fuse_bn_tensor(self.rbr_1x1)
        kernel_id, bias_id = self._fuse_bn_tensor(self.rbr_identity)
        return (
            kernel_3x3 + self._pad_1x1_to_3x3_tensor(kernel_1x1) + kernel_id,
            bias_3x3 + bias_1x1 + bias_id,
        )

    def _pad_1x1_to_3x3_tensor(self, kernel_1x1):
        if kernel_1x1 is None:
            return 0
        else:
            kernel = F.zeros((*kernel_1x1.shape[:-2], 3, 3))
            kernel[..., 1:2, 1:2] = kernel_1x1
            return kernel  # torch.nn.functional.pad(kernel_1x1, [1,1,1,1])

    def _fuse_bn_tensor(self, branch):
        if branch is None:
            return 0, 0
        if isinstance(branch, M.ConvBn2d):
            kernel = branch.conv.weight
            running_mean = branch.bn.running_mean
            running_var = branch.bn.running_var
            gamma = branch.bn.weight
            beta = branch.bn.bias
            eps = branch.bn.eps
        else:
            assert isinstance(branch, M.BatchNorm2d)
            if not hasattr(self, "id_tensor"):
                input_dim = self.w_in // self.groups
                kernel_value = np.zeros((self.w_in, input_dim, 3, 3), dtype=np.float32)
                for i in range(self.w_in):
                    kernel_value[i, i % input_dim, 1, 1] = 1
                if self.groups > 1:
                    kernel_value = kernel_value.reshape(self.groups, -1, *kernel_value.shape[-3:])
                self.id_tensor = megengine.tensor(kernel_value)
            kernel = self.id_tensor
            running_mean = branch.running_mean
            running_var = branch.running_var
            gamma = branch.weight
            beta = branch.bias
            eps = branch.eps
        std = F.sqrt(running_var + eps)
        t = (gamma / std).reshape(*kernel.shape[:-3], 1, 1, 1)
        return kernel * t, beta - running_mean * gamma / std
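
    # The fusion implemented by ``_fuse_bn_tensor`` above: a conv with weight ``W``
    # followed by BatchNorm (running mean ``mu``, running variance ``var``, scale
    # ``gamma``, shift ``beta``, epsilon ``eps``) is equivalent to a single conv with
    #
    #     W' = W * gamma / sqrt(var + eps)
    #     b' = beta - mu * gamma / sqrt(var + eps)
    #
    # The identity branch is first expressed as a 3x3 conv whose kernel is 1 at the
    # center of its own channel (``id_tensor``), and the 1x1 branch is zero-padded to
    # 3x3 by ``_pad_1x1_to_3x3_tensor``, so the three branches are merged by simply
    # summing their fused kernels and biases.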

    @classmethod
    def convert_to_deploy(cls, module) -> M.Module:
        """Convert/fuse a reparameterized :py:class:`RepVGGBlock` for deploy."""
        module_output = module
        if isinstance(module, RepVGGBlock):
            if module.deploy:
                # already in deploy mode
                return module_output
            module_output = RepVGGBlock(
                module.w_in,
                module.w_out,
                module.stride,
                module.groups,
                module.se_r,
                module.act_name,
                deploy=True,
            )
            kernel, bias = module._get_equivalent_kernel_bias()
            module_output.rbr_reparam.weight[:] = kernel
            module_output.rbr_reparam.bias[:] = bias
            if isinstance(module.se, SE):
                module_output.se.load_state_dict(module.se.state_dict())
        for name, child in module.named_children():
            setattr(module_output, name, cls.convert_to_deploy(child))
        del module
        return module_output
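

# An illustrative sketch (not part of the library API): it checks that a training-mode
# ``RepVGGBlock`` and the deploy-mode block produced by ``convert_to_deploy`` compute the
# same function once BN statistics are frozen with ``eval()``. The helper name, the
# channel count and the input shape are arbitrary choices for the example; the function
# is provided for reference only and is never called by this module.
def _example_repvgg_block_fusion():
    block = RepVGGBlock(16, 16, stride=1, deploy=False)
    block.eval()  # use running BN statistics so the fusion is exact
    x = megengine.tensor(np.random.rand(2, 16, 32, 32).astype(np.float32))
    y_ref = block(x)                              # 3x3 + 1x1 + identity branches
    fused = RepVGGBlock.convert_to_deploy(block)  # single 3x3 conv with bias
    y_fused = fused(x)
    assert np.allclose(y_ref.numpy(), y_fused.numpy(), atol=1e-5)
    return fused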


@registers.models.register()
class RepVGG(M.Module):
    """RepVGG Model.

    Use :py:meth:`RepVGG.convert_to_deploy` to convert a training :py:class:`RepVGG` to deploy:

    .. code::

        model = RepVGG(..., deploy=False)
        model.load_state_dict(...)
        _ = RepVGG.convert_to_deploy(model)

    Args:
        num_blocks: RepVGG depths.
        width_multiplier: RepVGG widths, ``base_width`` is ``[64, 128, 256, 512]``.
        head: head args. Default: ``None``
        groups: number of groups for blocks. Default: ``1``
        se_r: Squeeze-and-Excitation (SE) ratio. Default: ``0.0``
        act_name: activation function. Default: ``"relu"``
        deploy: switch a reparamed RepVGG into deploy mode. Default: ``False``
    """

    def __init__(
        self,
        num_blocks: Sequence[int],
        width_multiplier: Sequence[int],
        head: Mapping[str, Any] = None,
        groups: Union[int, List[Union[int, List[int]]]] = 1,
        se_r: float = 0.0,
        act_name: str = "relu",
        deploy: bool = False,
    ):
        super().__init__()
        assert len(width_multiplier) == 4
        self.deploy = deploy
        self.se_r = se_r

        if isinstance(groups, Real):
            groups = [groups] * len(num_blocks)

        self.w_in = min(64, int(64 * width_multiplier[0]))
        self.stage0 = RepVGGBlock(3, self.w_in, 2, 1, self.se_r, act_name, self.deploy)
        self.stage1 = self._make_stage(
            int(64 * width_multiplier[0]), num_blocks[0], 2, groups[0], act_name
        )
        self.stage2 = self._make_stage(
            int(128 * width_multiplier[1]), num_blocks[1], 2, groups[1], act_name
        )
        self.stage3 = self._make_stage(
            int(256 * width_multiplier[2]), num_blocks[2], 2, groups[2], act_name
        )
        self.stage4 = self._make_stage(
            int(512 * width_multiplier[3]), num_blocks[3], 2, groups[3], act_name
        )
        self.head = build_head(int(512 * width_multiplier[3]), head, act_name=act_name)

        self.apply(init_weights)

    def _make_stage(self, w_out, num_blocks, stride, groups, act_name):
        blocks = []
        if isinstance(groups, Real):
            groups = [groups] * num_blocks
        for i in range(num_blocks):
            blocks.append(
                RepVGGBlock(self.w_in, w_out, stride, groups[i], self.se_r, act_name, self.deploy)
            )
            stride = 1
            self.w_in = w_out
        return M.Sequential(*blocks)

    def forward(self, x):
        out = self.stage0(x)
        out = self.stage1(out)
        out = self.stage2(out)
        out = self.stage3(out)
        out = self.stage4(out)
        if self.head:
            out = self.head(out)
        return out

    @classmethod
    def convert_to_deploy(cls, module) -> M.Module:
        return RepVGGBlock.convert_to_deploy(module)
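

# A minimal sketch of the training-to-deploy workflow from the class docstring, using the
# ``repvgg_a0`` configuration defined below. The helper name and the 224x224 input shape
# are illustrative assumptions; the function is provided for reference only and is never
# called by this module.
def _example_repvgg_deploy():
    model = RepVGG(
        num_blocks=[2, 4, 14, 1],
        width_multiplier=[0.75, 0.75, 0.75, 2.5],
        head=dict(name="ClsHead"),
        deploy=False,
    )
    # In practice trained weights would be loaded here via ``model.load_state_dict(...)``.
    model.eval()
    deploy_model = RepVGG.convert_to_deploy(model)
    x = megengine.tensor(np.zeros((1, 3, 224, 224), dtype=np.float32))
    return deploy_model(x)  # class scores from the ClsHead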


@registers.models.register()
@hub.pretrained(
    "https://data.megengine.org.cn/research/basecls/models/repvgg/repvgg_a0/repvgg_a0.pkl"
)
def repvgg_a0(**kwargs):
    model_args = dict(
        num_blocks=[2, 4, 14, 1],
        width_multiplier=[0.75, 0.75, 0.75, 2.5],
        head=dict(name="ClsHead"),
    )
    recursive_update(model_args, kwargs)
    return RepVGG(**model_args)


@registers.models.register()
@hub.pretrained(
    "https://data.megengine.org.cn/research/basecls/models/repvgg/repvgg_a1/repvgg_a1.pkl"
)
def repvgg_a1(**kwargs):
    model_args = dict(
        num_blocks=[2, 4, 14, 1],
        width_multiplier=[1, 1, 1, 2.5],
        head=dict(name="ClsHead"),
    )
    recursive_update(model_args, kwargs)
    return RepVGG(**model_args)


@registers.models.register()
@hub.pretrained(
    "https://data.megengine.org.cn/research/basecls/models/repvgg/repvgg_a2/repvgg_a2.pkl"
)
def repvgg_a2(**kwargs):
    model_args = dict(
        num_blocks=[2, 4, 14, 1],
        width_multiplier=[1.5, 1.5, 1.5, 2.75],
        head=dict(name="ClsHead"),
    )
    recursive_update(model_args, kwargs)
    return RepVGG(**model_args)


@registers.models.register()
@hub.pretrained(
    "https://data.megengine.org.cn/research/basecls/models/repvgg/repvgg_b0/repvgg_b0.pkl"
)
def repvgg_b0(**kwargs):
    model_args = dict(
        num_blocks=[4, 6, 16, 1],
        width_multiplier=[1, 1, 1, 2.5],
        head=dict(name="ClsHead"),
    )
    recursive_update(model_args, kwargs)
    return RepVGG(**model_args)


@registers.models.register()
@hub.pretrained(
    "https://data.megengine.org.cn/research/basecls/models/repvgg/repvgg_b1/repvgg_b1.pkl"
)
def repvgg_b1(**kwargs):
    model_args = dict(
        num_blocks=[4, 6, 16, 1],
        width_multiplier=[2, 2, 2, 4],
        head=dict(name="ClsHead"),
    )
    recursive_update(model_args, kwargs)
    return RepVGG(**model_args)


@registers.models.register()
@hub.pretrained(
    "https://data.megengine.org.cn/research/basecls/models/repvgg/repvgg_b1g2/repvgg_b1g2.pkl"
)
def repvgg_b1g2(**kwargs):
    model_args = dict(
        num_blocks=[4, 6, 16, 1],
        width_multiplier=[2, 2, 2, 4],
        groups=[
            [1, 2, 1, 2],
            [1, 2, 1, 2, 1, 2],
            [1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2],
            [1],
        ],
        head=dict(name="ClsHead"),
    )
    recursive_update(model_args, kwargs)
    return RepVGG(**model_args)


@registers.models.register()
@hub.pretrained(
    "https://data.megengine.org.cn/research/basecls/models/repvgg/repvgg_b1g4/repvgg_b1g4.pkl"
)
def repvgg_b1g4(**kwargs):
    model_args = dict(
        num_blocks=[4, 6, 16, 1],
        width_multiplier=[2, 2, 2, 4],
        groups=[
            [1, 4, 1, 4],
            [1, 4, 1, 4, 1, 4],
            [1, 4, 1, 4, 1, 4, 1, 4, 1, 4, 1, 4, 1, 4, 1, 4],
            [1],
        ],
        head=dict(name="ClsHead"),
    )
    recursive_update(model_args, kwargs)
    return RepVGG(**model_args)


@registers.models.register()
@hub.pretrained(
    "https://data.megengine.org.cn/research/basecls/models/repvgg/repvgg_b2/repvgg_b2.pkl"
)
def repvgg_b2(**kwargs):
    model_args = dict(
        num_blocks=[4, 6, 16, 1],
        width_multiplier=[2.5, 2.5, 2.5, 5],
        head=dict(name="ClsHead"),
    )
    recursive_update(model_args, kwargs)
    return RepVGG(**model_args)


@registers.models.register()
@hub.pretrained(
    "https://data.megengine.org.cn/research/basecls/models/repvgg/repvgg_b2g2/repvgg_b2g2.pkl"
)
def repvgg_b2g2(**kwargs):
    model_args = dict(
        num_blocks=[4, 6, 16, 1],
        width_multiplier=[2.5, 2.5, 2.5, 5],
        groups=[
            [1, 2, 1, 2],
            [1, 2, 1, 2, 1, 2],
            [1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2],
            [1],
        ],
        head=dict(name="ClsHead"),
    )
    recursive_update(model_args, kwargs)
    return RepVGG(**model_args)


@registers.models.register()
@hub.pretrained(
    "https://data.megengine.org.cn/research/basecls/models/repvgg/repvgg_b2g4/repvgg_b2g4.pkl"
)
def repvgg_b2g4(**kwargs):
    model_args = dict(
        num_blocks=[4, 6, 16, 1],
        width_multiplier=[2.5, 2.5, 2.5, 5],
        groups=[
            [1, 4, 1, 4],
            [1, 4, 1, 4, 1, 4],
            [1, 4, 1, 4, 1, 4, 1, 4, 1, 4, 1, 4, 1, 4, 1, 4],
            [1],
        ],
        head=dict(name="ClsHead"),
    )
    recursive_update(model_args, kwargs)
    return RepVGG(**model_args)


@registers.models.register()
@hub.pretrained(
    "https://data.megengine.org.cn/research/basecls/models/repvgg/repvgg_b3/repvgg_b3.pkl"
)
def repvgg_b3(**kwargs):
    model_args = dict(
        num_blocks=[4, 6, 16, 1],
        width_multiplier=[3, 3, 3, 5],
        head=dict(name="ClsHead"),
    )
    recursive_update(model_args, kwargs)
    return RepVGG(**model_args)


@registers.models.register()
@hub.pretrained(
    "https://data.megengine.org.cn/research/basecls/models/repvgg/repvgg_b3g2/repvgg_b3g2.pkl"
)
def repvgg_b3g2(**kwargs):
    model_args = dict(
        num_blocks=[4, 6, 16, 1],
        width_multiplier=[3, 3, 3, 5],
        groups=[
            [1, 2, 1, 2],
            [1, 2, 1, 2, 1, 2],
            [1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2],
            [1],
        ],
        head=dict(name="ClsHead"),
    )
    recursive_update(model_args, kwargs)
    return RepVGG(**model_args)


@registers.models.register()
@hub.pretrained(
    "https://data.megengine.org.cn/research/basecls/models/repvgg/repvgg_b3g4/repvgg_b3g4.pkl"
)
def repvgg_b3g4(**kwargs):
    model_args = dict(
        num_blocks=[4, 6, 16, 1],
        width_multiplier=[3, 3, 3, 5],
        groups=[
            [1, 4, 1, 4],
            [1, 4, 1, 4, 1, 4],
            [1, 4, 1, 4, 1, 4, 1, 4, 1, 4, 1, 4, 1, 4, 1, 4],
            [1],
        ],
        head=dict(name="ClsHead"),
    )
    recursive_update(model_args, kwargs)
    return RepVGG(**model_args)


@registers.models.register()
@hub.pretrained(
    "https://data.megengine.org.cn/research/basecls/models/repvgg/repvgg_d2/repvgg_d2.pkl"
)
def repvgg_d2(**kwargs):
    model_args = dict(
        num_blocks=[8, 14, 24, 1],
        width_multiplier=[2.5, 2.5, 2.5, 5],
        se_r=0.0625,
        head=dict(name="ClsHead"),
    )
    recursive_update(model_args, kwargs)
    return RepVGG(**model_args)
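

# A minimal usage sketch for the factory functions above. It assumes that the
# ``pretrained`` switch exposed by ``megengine.hub.pretrained`` downloads and loads the
# released weights and that network access is available; the helper name is an
# illustrative choice and the function is never called by this module.
def _example_pretrained_repvgg_a0():
    model = repvgg_a0(pretrained=True)       # build repvgg_a0 and load its weights
    model.eval()
    model = RepVGG.convert_to_deploy(model)  # fuse branches for faster inference
    x = megengine.tensor(np.zeros((1, 3, 224, 224), dtype=np.float32))
    return model(x)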