basecls.models.hrnet 源代码

#!/usr/bin/env python3
# Copyright (c) Microsoft
# This file has been modified by Megvii ("Megvii Modifications").
# All Megvii Modifications are Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
"""HRNet Series

HRNet: `"Deep High-Resolution Representation Learning for Visual Recognition"
<https://arxiv.org/abs/1908.07919>`_

References:
    https://github.com/HRNet/HRNet-Image-Classification/blob/master/lib/models/cls_hrnet.py
"""
from collections import OrderedDict
from functools import partial
from typing import Any, List, Mapping, Optional

import megengine.functional as F
import megengine.hub as hub
import megengine.module as M

from basecls.layers import build_head, conv2d, init_weights, norm2d
from basecls.layers.activations import activation
from basecls.utils import recursive_update, registers

from .resnet import ResBasicBlock, ResBottleneckBlock

__all__ = [
    "UpsampleNearest",
    "HRFusion",
    "HRModule",
    "HRTrans",
    "HRStage",
    "HRMerge",
    "HRNet",
]


class UpsampleNearest(M.Module):
    """Nearest-neighbour upsampling layer.

    Every element of the input is duplicated ``scale_factor`` times along
    axis 2 and then along axis 3 (presumably the spatial axes of an NCHW
    tensor -- confirm against callers).

    Args:
        scale_factor: Upsample scale factor.
    """

    def __init__(self, scale_factor: int):
        super().__init__()
        self.scale_factor = scale_factor

    def forward(self, x):
        """Return ``x`` enlarged by ``scale_factor`` along axes 2 and 3."""
        factor = self.scale_factor
        # Stretch axis 2 first, then axis 3, exactly one repeat call per axis.
        stretched = F.repeat(x, factor, axis=2)
        return F.repeat(stretched, factor, axis=3)
def _bottleneck_block(w_out, **kwargs):
    """Build a stride-1 ``ResBottleneckBlock`` that outputs ``4 * w_out`` channels.

    Remaining keyword arguments (e.g. ``w_in``, ``norm_name``, ``act_name``)
    are forwarded to ``ResBottleneckBlock`` unchanged.
    """
    return ResBottleneckBlock(
        w_out=w_out * 4,
        stride=1,
        bot_mul=0.25,
        group_w=w_out,
        se_r=0,
        avg_down=False,
        drop_path_prob=0.0,
        **kwargs,
    )


# Stride-1 ``ResBasicBlock`` factory with SE, avg-down and drop-path disabled.
_basic_block = partial(
    ResBasicBlock,
    stride=1,
    bot_mul=1,
    se_r=0,
    avg_down=False,
    drop_path_prob=0.0,
)

# Maps a block name to ``(block factory, output channel multiplier)``.
block_dict = {
    "basic": (_basic_block, 1),
    "bottleneck": (_bottleneck_block, 4),
}
class HRFusion(M.Module):
    """HRNet fusion block.

    For every output branch, the features of all input branches are brought to
    that branch's channel count and summed, then passed through an activation.
    Branches with a lower index than the output go through a chain of stride-2
    3x3 convolutions; branches with a higher index go through a 1x1 convolution
    followed by nearest upsampling; the branch with the same index is used as is.

    Args:
        channels: Fusion channels.
        multi_scale_output: Whether output multi-scale features.
        norm_name: Normalization layer.
        act_name: Activation function.
    """

    def __init__(
        self,
        channels: List[int],
        multi_scale_output: bool,
        norm_name: str,
        act_name: str,
    ):
        super().__init__()
        self.multi_scale_output = multi_scale_output

        n_branches = len(channels)
        # With single-scale output only the first branch is produced.
        n_outputs = n_branches if multi_scale_output else 1
        for dst in range(n_outputs):
            for src in range(n_branches):
                if src < dst:
                    layer = self._stride_path(channels, src, dst, norm_name, act_name)
                elif src > dst:
                    layer = self._upsample_path(channels, src, dst, norm_name)
                else:
                    # Same branch: identity, registered as ``None``.
                    layer = None
                setattr(self, f"fuse_{src + 1}_{dst + 1}", layer)
        self.act = activation(act_name)

    @staticmethod
    def _stride_path(channels, src, dst, norm_name, act_name):
        """Chain ``dst - src`` stride-2 convs mapping branch ``src`` to branch ``dst``."""
        n_steps = dst - src
        layers = OrderedDict()
        for step in range(1, n_steps + 1):
            is_last = step == n_steps
            # Only the final conv switches to the destination width.
            width = channels[dst] if is_last else channels[src]
            layers[f"conv_{step}"] = conv2d(w_in=channels[src], w_out=width, k=3, stride=2)
            layers[f"norm_{step}"] = norm2d(norm_name, width)
            if not is_last:
                # The last step has no activation; it is applied after the sum.
                layers[f"act_{step}"] = activation(act_name)
        return M.Sequential(layers)

    @staticmethod
    def _upsample_path(channels, src, dst, norm_name):
        """1x1 conv + norm + nearest upsample mapping branch ``src`` to branch ``dst``."""
        layers = OrderedDict()
        layers["conv_1"] = conv2d(w_in=channels[src], w_out=channels[dst], k=1)
        layers["norm_1"] = norm2d(norm_name, channels[dst])
        layers["upsample_1"] = UpsampleNearest(scale_factor=2 ** (src - dst))
        return M.Sequential(layers)

    def forward(self, x_list):
        """Fuse the branch features in ``x_list`` and return one tensor per output branch."""
        n_outputs = len(x_list) if self.multi_scale_output else 1
        fused = []
        for dst in range(n_outputs):
            total = None
            for src, feat in enumerate(x_list):
                layer = getattr(self, f"fuse_{src + 1}_{dst + 1}", None)
                if layer:
                    feat = layer(feat)
                total = feat if total is None else total + feat
            fused.append(self.act(total))
        return fused
class HRModule(M.Module):
    """HRNet module.

    Holds one sequential stack of blocks per branch and, when there is more
    than one branch, an ``HRFusion`` that mixes the branch outputs.

    Args:
        block_name: Branch block type.
        num_blocks: Number of blocks.
        in_channels: Input channels.
        channels: Output channels.
        multi_scale_output: Whether output multi-scale features.
        norm_name: Normalization layer.
        act_name: Activation function.
    """

    def __init__(
        self,
        block_name: str,
        num_blocks: List[int],
        in_channels: List[int],
        channels: List[int],
        multi_scale_output: bool,
        norm_name: str,
        act_name: str,
    ):
        super().__init__()
        branch_specs = zip(in_channels, channels, num_blocks)
        for idx, (w_in, w_out, depth) in enumerate(branch_specs, start=1):
            stack = self._make_branch(
                w_in=w_in,
                w_out=w_out,
                block_name=block_name,
                num_block=depth,
                norm_name=norm_name,
                act_name=act_name,
            )
            setattr(self, f"branch{idx}", stack)

        if len(channels) > 1:
            # Fusion operates on the block outputs, i.e. widths scaled by the
            # block's channel multiplier.
            out_mul = block_dict[block_name][1]
            self.fusion = HRFusion(
                channels=[out_mul * c for c in channels],
                multi_scale_output=multi_scale_output,
                norm_name=norm_name,
                act_name=act_name,
            )
        else:
            self.fusion = None

    def _make_branch(
        self,
        w_in: int,
        w_out: int,
        block_name: str,
        num_block: int,
        norm_name: str,
        act_name: str,
    ):
        """Build a sequential stack of ``num_block`` blocks named ``block1``, ``block2``, ..."""
        block_fn, out_mul = block_dict[block_name]
        blocks = OrderedDict()
        for idx in range(num_block):
            # Only the first block sees the branch input width; later blocks
            # consume the previous block's output width.
            block_w_in = w_in if idx == 0 else w_out * out_mul
            blocks[f"block{idx + 1}"] = block_fn(
                w_in=block_w_in,
                w_out=w_out,
                norm_name=norm_name,
                act_name=act_name,
            )
        return M.Sequential(blocks)

    def forward(self, x_list):
        """Run each input through its branch, then fuse the results if fusion is enabled."""
        outs = []
        for idx, feat in enumerate(x_list, start=1):
            branch = getattr(self, f"branch{idx}")
            outs.append(branch(feat))
        if self.fusion:
            outs = self.fusion(outs)
        return outs
class HRTrans(M.Module):
    """HRNet transition block.

    Adapts the branches of the previous stage to the branch layout of the
    current stage. Existing branches get a 3x3 conv only when their channel
    count changes; additional branches are derived from the last input branch
    through a chain of stride-2 3x3 convolutions.

    Args:
        in_chs: Input channels.
        out_chs: Output channels.
        norm_name: Normalization layer.
        act_name: Activation function.
    """

    def __init__(
        self,
        in_chs: List[int],
        out_chs: List[int],
        norm_name: str,
        act_name: str,
    ):
        super().__init__()
        n_in = len(in_chs)
        self.num_trans = len(out_chs)
        for idx, w_target in enumerate(out_chs):
            if idx < n_in:
                layer = self._adapt_branch(in_chs[idx], w_target, norm_name, act_name)
            else:
                layer = self._spawn_branch(
                    in_chs[-1], w_target, idx - n_in + 1, norm_name, act_name
                )
            setattr(self, f"trans_{idx + 1}", layer)

    @staticmethod
    def _adapt_branch(w_in, w_out, norm_name, act_name):
        """Return a conv-norm-act stack, or ``None`` when the width is already right."""
        if w_in == w_out:
            return None
        layers = OrderedDict()
        layers["conv_1"] = conv2d(w_in=w_in, w_out=w_out, k=3)
        layers["norm_1"] = norm2d(norm_name, w_out)
        layers["act_1"] = activation(act_name)
        return M.Sequential(layers)

    @staticmethod
    def _spawn_branch(w_src, w_target, n_steps, norm_name, act_name):
        """Chain ``n_steps`` stride-2 conv-norm-act steps starting from width ``w_src``."""
        layers = OrderedDict()
        for step in range(1, n_steps + 1):
            # Intermediate steps keep the source width; the last one switches
            # to the target width.
            width = w_target if step == n_steps else w_src
            layers[f"conv_{step}"] = conv2d(w_in=w_src, w_out=width, k=3, stride=2)
            layers[f"norm_{step}"] = norm2d(norm_name, width)
            layers[f"act_{step}"] = activation(act_name)
        return M.Sequential(layers)

    def forward(self, x_list):
        """Return one feature per output branch, reusing the last input for new branches."""
        missing = self.num_trans - len(x_list)
        # New branches all start from the last available feature.
        padded = x_list + [x_list[-1]] * missing
        outs = []
        for idx, feat in enumerate(padded, start=1):
            layer = getattr(self, f"trans_{idx}", None)
            if layer:
                feat = layer(feat)
            outs.append(feat)
        return outs
class HRStage(M.Module):
    """HRNet stage.

    A stage is an optional ``HRTrans`` transition followed by ``num_modules``
    ``HRModule`` instances applied in sequence.

    Args:
        num_modules: Number of modules.
        num_blocks: Number of blocks for each module.
        block_name: Branch block type.
        pre_channels: Channels of previous stage (an empty list for the first stage).
        cur_channels: Channels of current stage.
        multi_scale_output: Whether output multi-scale features.
        w_fst: Width of stem for the first stage (``None`` for other stages).
        norm_name: Normalization layer.
        act_name: Activation function.
    """

    def __init__(
        self,
        num_modules: int,
        num_blocks: List[int],
        block_name: str,
        pre_channels: List[int],
        cur_channels: List[int],
        multi_scale_output: bool,
        w_fst: Optional[int],
        norm_name: str,
        act_name: str,
    ):
        super().__init__()
        out_mul = block_dict[block_name][1]
        mid_channels = [out_mul * c for c in cur_channels]

        # A truthy ``w_fst`` marks the first stage, which is fed directly by the
        # stem and therefore needs no transition.
        is_first_stage = bool(w_fst)
        if is_first_stage:
            self.transition = None
        else:
            self.transition = HRTrans(
                in_chs=pre_channels,
                out_chs=mid_channels,
                norm_name=norm_name,
                act_name=act_name,
            )

        self.num_modules = num_modules
        for idx in range(num_modules):
            if is_first_stage and idx == 0:
                module_in = [w_fst]
            else:
                module_in = mid_channels
            is_last = idx == num_modules - 1
            module = HRModule(
                block_name=block_name,
                num_blocks=num_blocks,
                in_channels=module_in,
                channels=cur_channels,
                # Every module except the last always emits all scales.
                multi_scale_output=multi_scale_output or not is_last,
                norm_name=norm_name,
                act_name=act_name,
            )
            setattr(self, f"module{idx + 1}", module)

    def forward(self, x_list):
        """Apply the transition (if any) and then each module in order."""
        if self.transition:
            x_list = self.transition(x_list)
        for idx in range(1, self.num_modules + 1):
            x_list = getattr(self, f"module{idx}")(x_list)
        return x_list
class HRMerge(M.Module):
    """HRNet merge block.

    Each scale is passed through one head block; the running result is then
    reduced with a stride-2 3x3 convolution and added to the next scale's head
    block output, until a single tensor remains.

    Args:
        block_name: Head block type.
        pre_channels: Channels of the last stage.
        channels: Channels of each scale to merge.
        norm_name: Normalization layer.
        act_name: Activation function.
    """

    def __init__(
        self,
        block_name: str,
        pre_channels: List[int],
        channels: List[int],
        norm_name: str,
        act_name: str,
    ):
        super().__init__()
        block_fn, out_mul = block_dict[block_name]

        for idx, (w_in, w_out) in enumerate(zip(pre_channels, channels), start=1):
            head_block = block_fn(
                w_in=w_in,
                w_out=w_out,
                norm_name=norm_name,
                act_name=act_name,
            )
            setattr(self, f"branch{idx}", head_block)

        # Actual widths produced by the head blocks.
        widths = [c * out_mul for c in channels]
        for idx, (w_from, w_to) in enumerate(zip(widths, widths[1:]), start=1):
            layers = OrderedDict()
            layers["conv"] = conv2d(w_in=w_from, w_out=w_to, k=3, stride=2, bias=True)
            layers["norm"] = norm2d(norm_name, w_to)
            layers["act"] = activation(act_name)
            setattr(self, f"dnsample{idx}", M.Sequential(layers))

    def forward(self, x_list):
        """Merge the multi-scale features in ``x_list`` into a single tensor."""
        merged = self.branch1(x_list[0])
        for idx, feat in enumerate(x_list[1:], start=1):
            reduce = getattr(self, f"dnsample{idx}")
            head_block = getattr(self, f"branch{idx + 1}")
            merged = reduce(merged) + head_block(feat)
        return merged
@registers.models.register()
class HRNet(M.Module):
    """HRNet model.

    The network is a two-conv stem, a sequence of ``HRStage`` stages, an
    ``HRMerge`` block that collapses the multi-scale features into one tensor,
    and a head built by ``build_head``.

    Args:
        stage_modules: Number of modules for each stage.
        stage_blocks: Number of blocks for each module in stages.
        stage_block_names: Branch block types for each stage.
        stage_channels: Number of channels for each stage.
        w_stem: Stem width. Default: ``64``
        multi_scale_output: Whether output multi-scale features. Default: ``True``
        merge_block_name: Merge block type. Default: ``"bottleneck"``
        merge_channels: Channels of each scale in merge block.
            Default: ``[32, 64, 128, 256]``
        norm_name: Normalization layer. Default: ``"BN"``
        act_name: Activation function. Default: ``"relu"``
        head: head args. Default: ``None``
    """

    def __init__(
        self,
        stage_modules: List[int],
        stage_blocks: List[List[int]],
        stage_block_names: List[str],
        stage_channels: List[List[int]],
        w_stem: int = 64,
        multi_scale_output: bool = True,
        merge_block_name: str = "bottleneck",
        merge_channels: List[int] = [32, 64, 128, 256],  # read-only in this class
        norm_name: str = "BN",
        act_name: str = "relu",
        head: Optional[Mapping[str, Any]] = None,
        **kwargs,
    ):
        super().__init__()
        self.stem = M.Sequential(
            OrderedDict(
                [
                    ("conv_1", conv2d(w_in=3, w_out=w_stem, k=3, stride=2)),
                    ("norm_1", norm2d(norm_name, w_in=w_stem)),
                    ("act_1", activation(act_name)),
                    # The second conv consumes the first conv's output, so its
                    # input width must follow ``w_stem`` rather than a fixed 64;
                    # otherwise any non-default stem width breaks the stem.
                    ("conv_2", conv2d(w_in=w_stem, w_out=w_stem, k=3, stride=2)),
                    ("norm_2", norm2d(norm_name, w_in=w_stem)),
                    ("act_2", activation(act_name)),
                ]
            )
        )

        self.num_stages = len(stage_modules)
        # Output widths of the previous stage; empty before the first stage.
        pre_channels = []
        for i in range(self.num_stages):
            stage = HRStage(
                num_modules=stage_modules[i],
                num_blocks=stage_blocks[i],
                block_name=stage_block_names[i],
                pre_channels=pre_channels,
                cur_channels=stage_channels[i],
                multi_scale_output=multi_scale_output,
                # Only the first stage is fed directly by the stem.
                w_fst=None if i else w_stem,
                norm_name=norm_name,
                act_name=act_name,
            )
            setattr(self, f"stage{i + 1}", stage)
            _, out_mul = block_dict[stage_block_names[i]]
            pre_channels = [out_mul * c for c in stage_channels[i]]

        self.merge = HRMerge(
            block_name=merge_block_name,
            pre_channels=pre_channels,
            channels=merge_channels,
            norm_name=norm_name,
            act_name=act_name,
        )
        # Width of the merged feature: last merge width times the merge
        # block's channel multiplier.
        w_merge = merge_channels[-1] * block_dict[merge_block_name][1]
        self.head = build_head(w_merge, head, norm_name, act_name)

        self.apply(init_weights)

    def forward(self, x):
        """Run stem, all stages, the merge block and (if present) the head."""
        x = self.stem(x)
        x_list = [x]
        for i in range(self.num_stages):
            stage = getattr(self, f"stage{i + 1}")
            x_list = stage(x_list)
        x = self.merge(x_list)
        if getattr(self, "head", None) is not None:
            x = self.head(x)
        return x
def _hrnet_stage_channels(width, first=64):
    """Return per-stage channel lists: ``[first]`` then ``width * 2**k`` per branch."""
    return [[first]] + [[width * 2 ** k for k in range(n)] for n in (2, 3, 4)]


def _merge_overrides(defaults, overrides):
    """Apply ``overrides`` onto ``defaults`` via ``recursive_update`` and return ``defaults``."""
    recursive_update(defaults, overrides)
    return defaults


def _build_hrnet(**kwargs):
    """Build an ``HRNet`` from the shared default arguments, updated with ``kwargs``."""
    defaults = {
        "stage_modules": [1, 1, 4, 3],
        "stage_blocks": [[4], [4, 4], [4, 4, 4], [4, 4, 4, 4]],
        "stage_block_names": ["bottleneck", "basic", "basic", "basic"],
        "head": {"name": "ClsHead", "width": 2048},
    }
    return HRNet(**_merge_overrides(defaults, kwargs))


@registers.models.register()
@hub.pretrained(
    "https://data.megengine.org.cn/research/basecls/models/"
    "hrnet/hrnet_w18_small_v1/hrnet_w18_small_v1.pkl"
)
def hrnet_w18_small_v1(**kwargs):
    """HRNet-W18-Small-V1: one module per stage, base width 16, first stage width 32."""
    defaults = {
        "stage_modules": [1, 1, 1, 1],
        "stage_blocks": [[1], [2, 2], [2, 2, 2], [2, 2, 2, 2]],
        "stage_channels": _hrnet_stage_channels(16, first=32),
    }
    return _build_hrnet(**_merge_overrides(defaults, kwargs))


@registers.models.register()
@hub.pretrained(
    "https://data.megengine.org.cn/research/basecls/models/"
    "hrnet/hrnet_w18_small_v2/hrnet_w18_small_v2.pkl"
)
def hrnet_w18_small_v2(**kwargs):
    """HRNet-W18-Small-V2: stage modules ``[1, 1, 3, 2]``, two blocks per branch, base width 18."""
    defaults = {
        "stage_modules": [1, 1, 3, 2],
        "stage_blocks": [[2], [2, 2], [2, 2, 2], [2, 2, 2, 2]],
        "stage_channels": _hrnet_stage_channels(18),
    }
    return _build_hrnet(**_merge_overrides(defaults, kwargs))


@registers.models.register()
@hub.pretrained(
    "https://data.megengine.org.cn/research/basecls/models/hrnet/hrnet_w18/hrnet_w18.pkl"
)
def hrnet_w18(**kwargs):
    """HRNet-W18: default layout with base width 18."""
    defaults = {"stage_channels": _hrnet_stage_channels(18)}
    return _build_hrnet(**_merge_overrides(defaults, kwargs))


@registers.models.register()
@hub.pretrained(
    "https://data.megengine.org.cn/research/basecls/models/hrnet/hrnet_w30/hrnet_w30.pkl"
)
def hrnet_w30(**kwargs):
    """HRNet-W30: default layout with base width 30."""
    defaults = {"stage_channels": _hrnet_stage_channels(30)}
    return _build_hrnet(**_merge_overrides(defaults, kwargs))


@registers.models.register()
@hub.pretrained(
    "https://data.megengine.org.cn/research/basecls/models/hrnet/hrnet_w32/hrnet_w32.pkl"
)
def hrnet_w32(**kwargs):
    """HRNet-W32: default layout with base width 32."""
    defaults = {"stage_channels": _hrnet_stage_channels(32)}
    return _build_hrnet(**_merge_overrides(defaults, kwargs))


@registers.models.register()
@hub.pretrained(
    "https://data.megengine.org.cn/research/basecls/models/hrnet/hrnet_w40/hrnet_w40.pkl"
)
def hrnet_w40(**kwargs):
    """HRNet-W40: default layout with base width 40."""
    defaults = {"stage_channels": _hrnet_stage_channels(40)}
    return _build_hrnet(**_merge_overrides(defaults, kwargs))


@registers.models.register()
@hub.pretrained(
    "https://data.megengine.org.cn/research/basecls/models/hrnet/hrnet_w44/hrnet_w44.pkl"
)
def hrnet_w44(**kwargs):
    """HRNet-W44: default layout with base width 44."""
    defaults = {"stage_channels": _hrnet_stage_channels(44)}
    return _build_hrnet(**_merge_overrides(defaults, kwargs))


@registers.models.register()
@hub.pretrained(
    "https://data.megengine.org.cn/research/basecls/models/hrnet/hrnet_w48/hrnet_w48.pkl"
)
def hrnet_w48(**kwargs):
    """HRNet-W48: default layout with base width 48."""
    defaults = {"stage_channels": _hrnet_stage_channels(48)}
    return _build_hrnet(**_merge_overrides(defaults, kwargs))


@registers.models.register()
@hub.pretrained(
    "https://data.megengine.org.cn/research/basecls/models/hrnet/hrnet_w64/hrnet_w64.pkl"
)
def hrnet_w64(**kwargs):
    """HRNet-W64: default layout with base width 64."""
    defaults = {"stage_channels": _hrnet_stage_channels(64)}
    return _build_hrnet(**_merge_overrides(defaults, kwargs))