#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates.
# This file has been modified by Megvii ("Megvii Modifications").
# All Megvii Modifications are Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
"""ResNet Series
ResNet: `"Deep Residual Learning for Image Recognition" <https://arxiv.org/abs/1512.03385>`_
ResNet-D: `"Bag of Tricks for Image Classification with Convolutional Neural Networks"
<https://arxiv.org/abs/1812.01187>`_
ResNeXt: `"Aggregated Residual Transformations for Deep Neural Networks"
<https://arxiv.org/abs/1611.05431>`_
SE-ResNet: `"Squeeze-and-Excitation Networks" <https://arxiv.org/abs/1709.01507>`_
Wide ResNet: `"Wide Residual Networks" <https://arxiv.org/abs/1605.07146>`_
References:
https://github.com/facebookresearch/pycls/blob/main/pycls/models/anynet.py
https://github.com/facebookresearch/pycls/blob/main/pycls/models/resnet.py
"""
from functools import partial
from numbers import Real
from typing import Any, Callable, Mapping, Sequence, Union
import megengine.hub as hub
import megengine.module as M
from basecls.layers import (
SE,
DropPath,
activation,
build_head,
conv2d,
init_weights,
make_divisible,
norm2d,
pool2d,
)
from basecls.utils import recursive_update, registers
# Public building blocks of this module. The model factory functions defined further
# down (``resnet18``, ``resnext50_32x4d``, ...) are not listed here; they are exposed
# through the ``registers.models.register()`` and ``hub.pretrained`` decorators.
__all__ = [
    "ResBasicBlock",
    "ResBottleneckBlock",
    "ResDeepStem",
    "ResStem",
    "SimpleStem",
    "AnyStage",
    "ResNet",
]
class ResBasicBlock(M.Module):
    """Residual basic block: ``act(shortcut(x) + f(x))``.

    ``f`` is ``3x3 conv -> BN -> Act -> 3x3 conv -> BN``, optionally followed by SE, then
    drop path. The shortcut is the identity, or a 1x1 projection + BN whenever the width
    or the resolution changes. The final activation is applied *after* the residual sum.

    Args:
        w_in: input channels.
        w_out: output channels.
        stride: stride of the first 3x3 conv and of the shortcut.
        bot_mul: multiplier on ``w_out`` giving the inner width of ``f``.
        se_r: SE ratio; no SE module is built when ``se_r <= 0``.
        avg_down: when downsampling, let an ``AvgPool2d`` do the striding on the
            shortcut before the 1x1 projection (ResNet-D).
        drop_path_prob: drop path probability applied to the residual branch.
        norm_name: normalization layer name.
        act_name: activation name.
    """

    def __init__(
        self,
        w_in: int,
        w_out: int,
        stride: int,
        bot_mul: float,
        se_r: float,
        avg_down: bool,
        drop_path_prob: float,
        norm_name: str,
        act_name: str,
        **kwargs,
    ):
        super().__init__()
        # Shortcut branch: only needed when shapes differ between input and output.
        if stride > 1 or w_in != w_out:
            if stride > 1 and avg_down:
                self.pool = M.AvgPool2d(2, stride)
                self.proj = conv2d(w_in, w_out, 1)
            else:
                self.proj = conv2d(w_in, w_out, 1, stride=stride)
            self.bn = norm2d(norm_name, w_out)

        inner_w = round(w_out * bot_mul)
        se_w = 0
        if se_r > 0.0:
            se_w = make_divisible(w_out * se_r)

        # Residual branch.
        self.a = conv2d(w_in, inner_w, 3, stride=stride)
        self.a_bn = norm2d(norm_name, inner_w)
        self.a_act = activation(act_name)
        self.b = conv2d(inner_w, w_out, 3)
        self.b_bn = norm2d(norm_name, w_out)
        # Flag consumed by ``init_weights`` (zero_init_final_gamma).
        self.b_bn.final_bn = True
        if se_w > 0:
            self.se = SE(w_out, se_w, act_name)
        self.drop_path = DropPath(drop_path_prob)
        self.act = activation(act_name)

    def forward(self, x):
        shortcut = x
        pool = getattr(self, "pool", None)
        if pool is not None:
            shortcut = pool(shortcut)
        proj = getattr(self, "proj", None)
        if proj is not None:
            shortcut = self.bn(proj(shortcut))

        out = self.a_act(self.a_bn(self.a(x)))
        out = self.b_bn(self.b(out))
        se = getattr(self, "se", None)
        if se is not None:
            out = se(out)
        out = self.drop_path(out)
        out += shortcut
        return self.act(out)
class ResBottleneckBlock(M.Module):
    """Residual bottleneck block: ``act(shortcut(x) + f(x))``, f = 1x1, 3x3, 1x1 [+SE].

    Each conv in ``f`` is followed by BN. The first two are also followed by the
    activation. The last BN output goes through optional SE and drop path into the
    residual sum, and the activation is applied after the sum. The shortcut is the
    identity, or a 1x1 projection + BN whenever the width or resolution changes.

    Args:
        w_in: input channels.
        w_out: output channels.
        stride: stride of the 3x3 conv and of the shortcut.
        bot_mul: multiplier on ``w_out`` giving the bottleneck width ``w_b``.
        group_w: channels per group of the 3x3 conv, i.e. ``groups = w_b // group_w``.
            ``None`` (or ``0``) means a dense 3x3 conv (``groups = 1``).
        se_r: SE ratio; no SE module is built when ``se_r <= 0``.
        avg_down: when downsampling, let an ``AvgPool2d`` do the striding on the
            shortcut before the 1x1 projection (ResNet-D).
        drop_path_prob: drop path probability applied to the residual branch.
        norm_name: normalization layer name.
        act_name: activation name.
    """

    def __init__(
        self,
        w_in: int,
        w_out: int,
        stride: int,
        bot_mul: float,
        group_w: int,
        se_r: float,
        avg_down: bool,
        drop_path_prob: float,
        norm_name: str,
        act_name: str,
        **kwargs,
    ):
        super().__init__()
        # Shortcut branch: only needed when shapes differ between input and output.
        if w_in != w_out or stride > 1:
            if avg_down and stride > 1:
                self.pool = M.AvgPool2d(2, stride)
                self.proj = conv2d(w_in, w_out, 1)
            else:
                self.proj = conv2d(w_in, w_out, 1, stride=stride)
            self.bn = norm2d(norm_name, w_out)
        w_b = round(w_out * bot_mul)
        w_se = make_divisible(w_out * se_r) if se_r > 0.0 else 0
        # ``ResNet(group_ws=None)`` (its documented default) forwards ``group_w=None``
        # to every block; treat that as "no grouping" instead of crashing on
        # ``w_b // None``.
        groups = w_b // group_w if group_w else 1
        # Residual branch.
        self.a = conv2d(w_in, w_b, 1)
        self.a_bn = norm2d(norm_name, w_b)
        self.a_act = activation(act_name)
        self.b = conv2d(w_b, w_b, 3, stride=stride, groups=groups)
        self.b_bn = norm2d(norm_name, w_b)
        self.b_act = activation(act_name)
        self.c = conv2d(w_b, w_out, 1)
        self.c_bn = norm2d(norm_name, w_out)
        # Flag consumed by ``init_weights`` (zero_init_final_gamma).
        self.c_bn.final_bn = True
        if w_se > 0:
            self.se = SE(w_out, w_se, act_name)
        self.drop_path = DropPath(drop_path_prob)
        self.act = activation(act_name)

    def forward(self, x):
        x_p = x
        if getattr(self, "pool", None) is not None:
            x_p = self.pool(x_p)
        if getattr(self, "proj", None) is not None:
            x_p = self.proj(x_p)
            x_p = self.bn(x_p)
        x = self.a(x)
        x = self.a_bn(x)
        x = self.a_act(x)
        x = self.b(x)
        x = self.b_bn(x)
        x = self.b_act(x)
        x = self.c(x)
        x = self.c_bn(x)
        if getattr(self, "se", None) is not None:
            x = self.se(x)
        x = self.drop_path(x)
        x += x_p
        x = self.act(x)
        return x
class ResDeepStem(M.Module):
    """ResNet-D deep stem: three ``3x3 conv -> BN -> Act`` units, then ``pool2d(3, stride=2)``.

    Only the first conv is strided (stride 2). The first two convs output
    ``w_out // 2`` channels and the third outputs ``w_out``.
    """

    def __init__(self, w_in: int, w_out: int, norm_name: str, act_name: str, **kwargs):
        super().__init__()
        half_w = w_out // 2
        self.a = conv2d(w_in, half_w, 3, stride=2)
        self.a_bn = norm2d(norm_name, half_w)
        self.a_act = activation(act_name)
        self.b = conv2d(half_w, half_w, 3, stride=1)
        self.b_bn = norm2d(norm_name, half_w)
        self.b_act = activation(act_name)
        self.c = conv2d(half_w, w_out, 3, stride=1)
        self.c_bn = norm2d(norm_name, w_out)
        self.c_act = activation(act_name)
        self.pool = pool2d(3, stride=2)

    def forward(self, x):
        units = (
            (self.a, self.a_bn, self.a_act),
            (self.b, self.b_bn, self.b_act),
            (self.c, self.c_bn, self.c_act),
        )
        for conv, bn, act in units:
            x = act(bn(conv(x)))
        return self.pool(x)
class ResStem(M.Module):
    """Classic ResNet stem: 7x7 stride-2 conv, BN, Act, then ``pool2d(3, stride=2)``.

    The original documentation describes the pool as MaxPool; this assumes that is
    ``pool2d``'s default mode.
    """

    def __init__(self, w_in: int, w_out: int, norm_name: str, act_name: str, **kwargs):
        super().__init__()
        self.conv = conv2d(w_in, w_out, 7, stride=2)
        self.bn = norm2d(norm_name, w_out)
        self.act = activation(act_name)
        self.pool = pool2d(3, stride=2)

    def forward(self, x):
        return self.pool(self.act(self.bn(self.conv(x))))
class SimpleStem(M.Module):
    """Minimal stem: a single 3x3 stride-2 conv, BN and Act (no pooling)."""

    def __init__(self, w_in: int, w_out: int, norm_name: str, act_name: str, **kwargs):
        super().__init__()
        self.conv = conv2d(w_in, w_out, 3, stride=2)
        self.bn = norm2d(norm_name, w_out)
        self.act = activation(act_name)

    def forward(self, x):
        return self.act(self.bn(self.conv(x)))
class AnyStage(M.Module):
    """AnyNet stage: ``depth`` blocks that all produce ``w_out`` channels.

    Only the first block sees ``w_in`` and ``stride``; every later block maps
    ``w_out -> w_out`` with stride 1. Blocks are registered as ``b1``, ``b2``, ...

    Args:
        w_in: input channels of the stage.
        w_out: output channels of every block.
        stride: stride of the first block.
        depth: number of blocks.
        block_func: block constructor ``(w_in, w_out, stride, drop_path_prob=..., **kwargs)``.
        drop_path_prob: per-block drop path probabilities (indexed ``0..depth-1``).
        **kwargs: forwarded to every ``block_func`` call.
    """

    def __init__(
        self,
        w_in: int,
        w_out: int,
        stride: int,
        depth: int,
        block_func: Callable,
        drop_path_prob: Sequence[float],
        **kwargs,
    ):
        super().__init__()
        self.depth = depth
        for idx in range(depth):
            is_first = idx == 0
            block = block_func(
                w_in if is_first else w_out,
                w_out,
                stride if is_first else 1,
                drop_path_prob=drop_path_prob[idx],
                **kwargs,
            )
            setattr(self, f"b{idx + 1}", block)

    def __len__(self):
        return self.depth

    def forward(self, x):
        for idx in range(1, self.depth + 1):
            x = getattr(self, f"b{idx}")(x)
        return x
@registers.models.register()
class ResNet(M.Module):
    """ResNet model: ``stem -> s1 -> ... -> sN -> head``.

    Each stage ``s{i}`` is an :class:`AnyStage` with ``depths[i - 1]`` blocks.

    Args:
        stem_name: stem name (``"ResDeepStem"``, ``"ResStem"``, ``"SimpleStem"``) or a
            callable ``(w_in, w_out, norm_name, act_name) -> Module``.
        stem_w: stem width.
        block_name: block name (``"ResBasicBlock"``, ``"ResBottleneckBlock"``) or a
            callable with the same signature as those blocks.
        depths: depth for each stage (number of blocks in the stage).
        widths: width for each stage (width of each block in the stage).
        strides: strides for each stage (applies to the first block of each stage).
        bot_muls: bottleneck multipliers for each stage; a single number is used for
            every stage. Default: ``1.0``
        group_ws: group widths for each stage (applies to bottleneck block).
            Default: ``None``
        se_r: Squeeze-and-Excitation (SE) ratio. Default: ``0.0``
        avg_down: downsample shortcuts with average pooling before the 1x1 projection
            (ResNet-D). Default: ``False``
        drop_path_prob: drop path scale; the ``k``-th block (0-indexed) out of ``n``
            blocks in total uses ``k / n * drop_path_prob``. Default: ``0.0``
        zero_init_final_gamma: passed to ``init_weights``; presumably zero-initializes
            the BN layers that blocks flag with ``final_bn``. Default: ``False``
        norm_name: normalization function. Default: ``"BN"``
        act_name: activation function. Default: ``"relu"``
        head: head args. Default: ``None``

    Raises:
        ValueError: if the stem or block name is unknown, or if ``widths``,
            ``strides``, ``bot_muls`` or ``group_ws`` have fewer entries than ``depths``.
    """

    def __init__(
        self,
        stem_name: Union[str, Callable],
        stem_w: int,
        block_name: Union[str, Callable],
        depths: Sequence[int],
        widths: Sequence[int],
        strides: Sequence[int],
        bot_muls: Union[float, Sequence[float]] = 1.0,
        group_ws: Sequence[int] = None,
        se_r: float = 0.0,
        avg_down: bool = False,
        drop_path_prob: float = 0.0,
        zero_init_final_gamma: bool = False,
        norm_name: str = "BN",
        act_name: str = "relu",
        head: Mapping[str, Any] = None,
    ):
        super().__init__()
        self.depths = depths
        stem_func = self.get_stem_func(stem_name)
        self.stem = stem_func(3, stem_w, norm_name, act_name)
        block_func = self.get_block_func(block_name)
        num_stages = len(depths)
        if isinstance(bot_muls, Real):
            bot_muls = [bot_muls] * num_stages
        if group_ws is None:
            group_ws = [None] * num_stages
        # ``zip`` below would silently drop stages if any per-stage list is short, and
        # ``forward`` would then fail on the missing ``s{i}`` attribute; fail fast.
        per_stage = (
            ("widths", widths),
            ("strides", strides),
            ("bot_muls", bot_muls),
            ("group_ws", group_ws),
        )
        for arg_name, values in per_stage:
            if len(values) < num_stages:
                raise ValueError(
                    f"'{arg_name}' has {len(values)} entries but 'depths' has {num_stages}"
                )
        # Linearly increasing drop path probability over all blocks of the network.
        total_blocks = sum(depths)
        drop_path_prob_iter = (i / total_blocks * drop_path_prob for i in range(total_blocks))
        drop_path_probs = [[next(drop_path_prob_iter) for _ in range(d)] for d in depths]
        model_args = [depths, widths, strides, bot_muls, group_ws, drop_path_probs]
        prev_w = stem_w
        for i, (d, w, s, b, g, dp_p) in enumerate(zip(*model_args)):
            stage = AnyStage(
                prev_w,
                w,
                s,
                d,
                block_func,
                bot_mul=b,
                group_w=g,
                se_r=se_r,
                avg_down=avg_down,
                drop_path_prob=dp_p,
                norm_name=norm_name,
                act_name=act_name,
            )
            setattr(self, f"s{i + 1}", stage)
            prev_w = w
        self.head = build_head(prev_w, head, norm_name, act_name)
        self.apply(partial(init_weights, zero_init_final_gamma=zero_init_final_gamma))

    def forward(self, x):
        x = self.stem(x)
        for i in range(len(self.depths)):
            stage = getattr(self, f"s{i + 1}")
            x = stage(x)
        # ``build_head`` may yield no head (e.g. ``head=None``), in which case the
        # stage features are returned directly.
        if getattr(self, "head", None) is not None:
            x = self.head(x)
        return x

    @staticmethod
    def get_stem_func(name: Union[str, Callable]):
        """Retrieves the stem function by name (callables are returned unchanged).

        Raises:
            ValueError: if ``name`` is neither callable nor a known stem name.
        """
        if callable(name):
            return name
        if isinstance(name, str):
            stem_funcs = {
                "ResDeepStem": ResDeepStem,
                "ResStem": ResStem,
                "SimpleStem": SimpleStem,
            }
            if name in stem_funcs:
                return stem_funcs[name]
        raise ValueError(f"Stem '{name}' not supported")

    @staticmethod
    def get_block_func(name: Union[str, Callable]):
        """Retrieves the block function by name (callables are returned unchanged).

        Raises:
            ValueError: if ``name`` is neither callable nor a known block name.
        """
        if callable(name):
            return name
        if isinstance(name, str):
            block_funcs = {
                "ResBasicBlock": ResBasicBlock,
                "ResBottleneckBlock": ResBottleneckBlock,
            }
            if name in block_funcs:
                return block_funcs[name]
        raise ValueError(f"Block '{name}' not supported")
def _build_resnet(**kwargs):
    """Build a :class:`ResNet` from shared defaults, overridden by ``kwargs``.

    Defaults: a :class:`ResStem` stem of width 64 and a ``"ClsHead"`` head. ``kwargs``
    are merged into these defaults with ``recursive_update`` before construction.
    """
    defaults = {"stem_name": ResStem, "stem_w": 64, "head": {"name": "ClsHead"}}
    recursive_update(defaults, kwargs)
    return ResNet(**defaults)
@registers.models.register()
@hub.pretrained(
    "https://data.megengine.org.cn/research/basecls/models/resnet/resnet18/resnet18.pkl"
)
def resnet18(**kwargs):
    """ResNet-18: basic blocks, depths ``[2, 2, 2, 2]``; ``kwargs`` override the config."""
    cfg = {
        "block_name": ResBasicBlock,
        "depths": [2, 2, 2, 2],
        "widths": [64, 128, 256, 512],
        "strides": [1, 2, 2, 2],
    }
    recursive_update(cfg, kwargs)
    return _build_resnet(**cfg)
@registers.models.register()
@hub.pretrained(
    "https://data.megengine.org.cn/research/basecls/models/resnet/resnet34/resnet34.pkl"
)
def resnet34(**kwargs):
    """ResNet-34: basic blocks, depths ``[3, 4, 6, 3]``; ``kwargs`` override the config."""
    cfg = {
        "block_name": ResBasicBlock,
        "depths": [3, 4, 6, 3],
        "widths": [64, 128, 256, 512],
        "strides": [1, 2, 2, 2],
    }
    recursive_update(cfg, kwargs)
    return _build_resnet(**cfg)
@registers.models.register()
@hub.pretrained(
    "https://data.megengine.org.cn/research/basecls/models/resnet/resnet50/resnet50.pkl"
)
def resnet50(**kwargs):
    """ResNet-50: bottleneck blocks, depths ``[3, 4, 6, 3]``.

    The bottleneck width is 1/4 of the block width, and the group widths equal it, so
    the 3x3 convs are dense. ``kwargs`` override the config.
    """
    cfg = {
        "block_name": ResBottleneckBlock,
        "depths": [3, 4, 6, 3],
        "widths": [256, 512, 1024, 2048],
        "strides": [1, 2, 2, 2],
        "bot_muls": [0.25, 0.25, 0.25, 0.25],
        "group_ws": [64, 128, 256, 512],
    }
    recursive_update(cfg, kwargs)
    return _build_resnet(**cfg)
@registers.models.register()
@hub.pretrained(
    "https://data.megengine.org.cn/research/basecls/models/resnet/resnet101/resnet101.pkl"
)
def resnet101(**kwargs):
    """ResNet-101: bottleneck blocks, depths ``[3, 4, 23, 3]``; ``kwargs`` override the config."""
    cfg = {
        "block_name": ResBottleneckBlock,
        "depths": [3, 4, 23, 3],
        "widths": [256, 512, 1024, 2048],
        "strides": [1, 2, 2, 2],
        "bot_muls": [0.25, 0.25, 0.25, 0.25],
        "group_ws": [64, 128, 256, 512],
    }
    recursive_update(cfg, kwargs)
    return _build_resnet(**cfg)
@registers.models.register()
@hub.pretrained(
    "https://data.megengine.org.cn/research/basecls/models/resnet/resnet152/resnet152.pkl"
)
def resnet152(**kwargs):
    """ResNet-152: bottleneck blocks, depths ``[3, 8, 36, 3]``; ``kwargs`` override the config."""
    cfg = {
        "block_name": ResBottleneckBlock,
        "depths": [3, 8, 36, 3],
        "widths": [256, 512, 1024, 2048],
        "strides": [1, 2, 2, 2],
        "bot_muls": [0.25, 0.25, 0.25, 0.25],
        "group_ws": [64, 128, 256, 512],
    }
    recursive_update(cfg, kwargs)
    return _build_resnet(**cfg)
@registers.models.register()
@hub.pretrained(
    "https://data.megengine.org.cn/research/basecls/models/resnet/resnet18d/resnet18d.pkl"
)
def resnet18d(**kwargs):
    """ResNet-18-D: ResNet-18 with the deep stem and average-pool shortcut downsampling."""
    overrides = {"stem_name": ResDeepStem, "avg_down": True}
    recursive_update(overrides, kwargs)
    return resnet18(**overrides)
@registers.models.register()
@hub.pretrained(
    "https://data.megengine.org.cn/research/basecls/models/resnet/resnet34d/resnet34d.pkl"
)
def resnet34d(**kwargs):
    """ResNet-34-D: ResNet-34 with the deep stem and average-pool shortcut downsampling."""
    overrides = {"stem_name": ResDeepStem, "avg_down": True}
    recursive_update(overrides, kwargs)
    return resnet34(**overrides)
@registers.models.register()
@hub.pretrained(
    "https://data.megengine.org.cn/research/basecls/models/resnet/resnet50d/resnet50d.pkl"
)
def resnet50d(**kwargs):
    """ResNet-50-D: ResNet-50 with the deep stem and average-pool shortcut downsampling."""
    overrides = {"stem_name": ResDeepStem, "avg_down": True}
    recursive_update(overrides, kwargs)
    return resnet50(**overrides)
@registers.models.register()
@hub.pretrained(
    "https://data.megengine.org.cn/research/basecls/models/resnet/resnet101d/resnet101d.pkl"
)
def resnet101d(**kwargs):
    """ResNet-101-D: ResNet-101 with the deep stem and average-pool shortcut downsampling."""
    overrides = {"stem_name": ResDeepStem, "avg_down": True}
    recursive_update(overrides, kwargs)
    return resnet101(**overrides)
@registers.models.register()
@hub.pretrained(
    "https://data.megengine.org.cn/research/basecls/models/resnet/resnet152d/resnet152d.pkl"
)
def resnet152d(**kwargs):
    """ResNet-152-D: ResNet-152 with the deep stem and average-pool shortcut downsampling."""
    overrides = {"stem_name": ResDeepStem, "avg_down": True}
    recursive_update(overrides, kwargs)
    return resnet152(**overrides)
@registers.models.register()
@hub.pretrained(
    "https://data.megengine.org.cn/research/basecls/models/"
    "resnet/resnext50_32x4d/resnext50_32x4d.pkl"
)
def resnext50_32x4d(**kwargs):
    """ResNeXt-50 32x4d: ResNet-50 with bottleneck width ``w / 2`` split into 32 groups.

    The group width is 4 in the first stage and doubles with each later stage.
    """
    overrides = {"bot_muls": [0.5, 0.5, 0.5, 0.5], "group_ws": [4, 8, 16, 32]}
    recursive_update(overrides, kwargs)
    return resnet50(**overrides)
@registers.models.register()
@hub.pretrained(
    "https://data.megengine.org.cn/research/basecls/models/"
    "resnet/resnext101_32x4d/resnext101_32x4d.pkl"
)
def resnext101_32x4d(**kwargs):
    """ResNeXt-101 32x4d: ResNet-101 with bottleneck width ``w / 2`` split into 32 groups."""
    overrides = {"bot_muls": [0.5, 0.5, 0.5, 0.5], "group_ws": [4, 8, 16, 32]}
    recursive_update(overrides, kwargs)
    return resnet101(**overrides)
@registers.models.register()
@hub.pretrained(
    "https://data.megengine.org.cn/research/basecls/models/"
    "resnet/resnext101_32x8d/resnext101_32x8d.pkl"
)
def resnext101_32x8d(**kwargs):
    """ResNeXt-101 32x8d: ResNet-101 with bottleneck width ``w`` split into 32 groups."""
    overrides = {"bot_muls": [1.0, 1.0, 1.0, 1.0], "group_ws": [8, 16, 32, 64]}
    recursive_update(overrides, kwargs)
    return resnet101(**overrides)
@registers.models.register()
@hub.pretrained(
    "https://data.megengine.org.cn/research/basecls/models/"
    "resnet/resnext101_64x4d/resnext101_64x4d.pkl"
)
def resnext101_64x4d(**kwargs):
    """ResNeXt-101 64x4d: ResNet-101 with bottleneck width ``w`` split into 64 groups."""
    overrides = {"bot_muls": [1.0, 1.0, 1.0, 1.0], "group_ws": [4, 8, 16, 32]}
    recursive_update(overrides, kwargs)
    return resnet101(**overrides)
@registers.models.register()
@hub.pretrained(
    "https://data.megengine.org.cn/research/basecls/models/"
    "resnet/resnext152_32x4d/resnext152_32x4d.pkl"
)
def resnext152_32x4d(**kwargs):
    """ResNeXt-152 32x4d: ResNet-152 with bottleneck width ``w / 2`` split into 32 groups."""
    overrides = {"bot_muls": [0.5, 0.5, 0.5, 0.5], "group_ws": [4, 8, 16, 32]}
    recursive_update(overrides, kwargs)
    return resnet152(**overrides)
@registers.models.register()
@hub.pretrained(
    "https://data.megengine.org.cn/research/basecls/models/"
    "resnet/resnext152_32x8d/resnext152_32x8d.pkl"
)
def resnext152_32x8d(**kwargs):
    """ResNeXt-152 32x8d: ResNet-152 with bottleneck width ``w`` split into 32 groups."""
    overrides = {"bot_muls": [1.0, 1.0, 1.0, 1.0], "group_ws": [8, 16, 32, 64]}
    recursive_update(overrides, kwargs)
    return resnet152(**overrides)
@registers.models.register()
@hub.pretrained(
    "https://data.megengine.org.cn/research/basecls/models/"
    "resnet/resnext152_64x4d/resnext152_64x4d.pkl"
)
def resnext152_64x4d(**kwargs):
    """ResNeXt-152 64x4d: ResNet-152 with bottleneck width ``w`` split into 64 groups."""
    overrides = {"bot_muls": [1.0, 1.0, 1.0, 1.0], "group_ws": [4, 8, 16, 32]}
    recursive_update(overrides, kwargs)
    return resnet152(**overrides)
@registers.models.register()
@hub.pretrained(
    "https://data.megengine.org.cn/research/basecls/models/resnet/se_resnet18/se_resnet18.pkl"
)
def se_resnet18(**kwargs):
    """SE-ResNet-18: ResNet-18 with SE modules (``se_r=0.0625``, i.e. 1/16)."""
    overrides = {"se_r": 0.0625}
    recursive_update(overrides, kwargs)
    return resnet18(**overrides)
@registers.models.register()
@hub.pretrained(
    "https://data.megengine.org.cn/research/basecls/models/resnet/se_resnet34/se_resnet34.pkl"
)
def se_resnet34(**kwargs):
    """SE-ResNet-34: ResNet-34 with SE modules (``se_r=0.0625``, i.e. 1/16)."""
    overrides = {"se_r": 0.0625}
    recursive_update(overrides, kwargs)
    return resnet34(**overrides)
@registers.models.register()
@hub.pretrained(
    "https://data.megengine.org.cn/research/basecls/models/resnet/se_resnet50/se_resnet50.pkl"
)
def se_resnet50(**kwargs):
    """SE-ResNet-50: ResNet-50 with SE modules (``se_r=0.0625``, i.e. 1/16)."""
    overrides = {"se_r": 0.0625}
    recursive_update(overrides, kwargs)
    return resnet50(**overrides)
@registers.models.register()
@hub.pretrained(
    "https://data.megengine.org.cn/research/basecls/models/resnet/se_resnet101/se_resnet101.pkl"
)
def se_resnet101(**kwargs):
    """SE-ResNet-101: ResNet-101 with SE modules (``se_r=0.0625``, i.e. 1/16)."""
    overrides = {"se_r": 0.0625}
    recursive_update(overrides, kwargs)
    return resnet101(**overrides)
@registers.models.register()
@hub.pretrained(
    "https://data.megengine.org.cn/research/basecls/models/resnet/se_resnet152/se_resnet152.pkl"
)
def se_resnet152(**kwargs):
    """SE-ResNet-152: ResNet-152 with SE modules (``se_r=0.0625``, i.e. 1/16)."""
    overrides = {"se_r": 0.0625}
    recursive_update(overrides, kwargs)
    return resnet152(**overrides)
@registers.models.register()
@hub.pretrained(
    "https://data.megengine.org.cn/research/basecls/models/"
    "resnet/se_resnext50_32x4d/se_resnext50_32x4d.pkl"
)
def se_resnext50_32x4d(**kwargs):
    """SE-ResNeXt-50 32x4d: ResNeXt-50 32x4d with SE modules (``se_r=0.0625``)."""
    overrides = {"se_r": 0.0625}
    recursive_update(overrides, kwargs)
    return resnext50_32x4d(**overrides)
@registers.models.register()
@hub.pretrained(
    "https://data.megengine.org.cn/research/basecls/models/"
    "resnet/se_resnext101_32x4d/se_resnext101_32x4d.pkl"
)
def se_resnext101_32x4d(**kwargs):
    """SE-ResNeXt-101 32x4d: ResNeXt-101 32x4d with SE modules (``se_r=0.0625``)."""
    overrides = {"se_r": 0.0625}
    recursive_update(overrides, kwargs)
    return resnext101_32x4d(**overrides)
@registers.models.register()
@hub.pretrained(
    "https://data.megengine.org.cn/research/basecls/models/"
    "resnet/se_resnext101_32x8d/se_resnext101_32x8d.pkl"
)
def se_resnext101_32x8d(**kwargs):
    """SE-ResNeXt-101 32x8d: ResNeXt-101 32x8d with SE modules (``se_r=0.0625``)."""
    overrides = {"se_r": 0.0625}
    recursive_update(overrides, kwargs)
    return resnext101_32x8d(**overrides)
@registers.models.register()
@hub.pretrained(
    "https://data.megengine.org.cn/research/basecls/models/"
    "resnet/se_resnext101_64x4d/se_resnext101_64x4d.pkl"
)
def se_resnext101_64x4d(**kwargs):
    """SE-ResNeXt-101 64x4d: ResNeXt-101 64x4d with SE modules (``se_r=0.0625``)."""
    overrides = {"se_r": 0.0625}
    recursive_update(overrides, kwargs)
    return resnext101_64x4d(**overrides)
@registers.models.register()
@hub.pretrained(
    "https://data.megengine.org.cn/research/basecls/models/"
    "resnet/se_resnext152_32x4d/se_resnext152_32x4d.pkl"
)
def se_resnext152_32x4d(**kwargs):
    """SE-ResNeXt-152 32x4d: ResNeXt-152 32x4d with SE modules (``se_r=0.0625``)."""
    overrides = {"se_r": 0.0625}
    recursive_update(overrides, kwargs)
    return resnext152_32x4d(**overrides)
@registers.models.register()
@hub.pretrained(
    "https://data.megengine.org.cn/research/basecls/models/"
    "resnet/se_resnext152_32x8d/se_resnext152_32x8d.pkl"
)
def se_resnext152_32x8d(**kwargs):
    """SE-ResNeXt-152 32x8d: ResNeXt-152 32x8d with SE modules (``se_r=0.0625``)."""
    overrides = {"se_r": 0.0625}
    recursive_update(overrides, kwargs)
    return resnext152_32x8d(**overrides)
@registers.models.register()
@hub.pretrained(
    "https://data.megengine.org.cn/research/basecls/models/"
    "resnet/se_resnext152_64x4d/se_resnext152_64x4d.pkl"
)
def se_resnext152_64x4d(**kwargs):
    """SE-ResNeXt-152 64x4d: ResNeXt-152 64x4d with SE modules (``se_r=0.0625``)."""
    overrides = {"se_r": 0.0625}
    recursive_update(overrides, kwargs)
    return resnext152_64x4d(**overrides)
@registers.models.register()
@hub.pretrained(
    "https://data.megengine.org.cn/research/basecls/models/"
    "resnet/wide_resnet50_2/wide_resnet50_2.pkl"
)
def wide_resnet50_2(**kwargs):
    """Wide ResNet-50-2: ResNet-50 with a 2x wider (dense) bottleneck.

    ``bot_muls=0.5`` doubles the bottleneck width to ``[128, 256, 512, 1024]``.
    ``group_ws`` must grow with it so that ``groups = w_b // group_w`` stays 1. Inheriting
    ResNet-50's ``group_ws=[64, 128, 256, 512]`` would silently turn every 3x3 conv into
    a 2-group conv. ``kwargs`` override the config.
    """
    model_args = dict(bot_muls=[0.5, 0.5, 0.5, 0.5], group_ws=[128, 256, 512, 1024])
    recursive_update(model_args, kwargs)
    return resnet50(**model_args)
@registers.models.register()
@hub.pretrained(
    "https://data.megengine.org.cn/research/basecls/models/"
    "resnet/wide_resnet101_2/wide_resnet101_2.pkl"
)
def wide_resnet101_2(**kwargs):
    """Wide ResNet-101-2: ResNet-101 with a 2x wider (dense) bottleneck.

    ``bot_muls=0.5`` doubles the bottleneck width to ``[128, 256, 512, 1024]``.
    ``group_ws`` must grow with it so that ``groups = w_b // group_w`` stays 1. Inheriting
    ResNet-101's ``group_ws=[64, 128, 256, 512]`` would silently turn every 3x3 conv into
    a 2-group conv. ``kwargs`` override the config.
    """
    model_args = dict(bot_muls=[0.5, 0.5, 0.5, 0.5], group_ws=[128, 256, 512, 1024])
    recursive_update(model_args, kwargs)
    return resnet101(**model_args)