basecls.layers.wrapper
- basecls.layers.wrapper.adjust_block_compatibility(ws, bs, gs)
Adjusts the compatibility of widths, bottlenecks and groups.
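The sketch below shows what such an adjustment can look like. It assumes pycls-style semantics: ws are block widths, bs are bottleneck multipliers (1 or 1/k), gs are group widths, and the goal is for each inner width w * b to be divisible by its group width. The function name adjust_block_compatibility_sketch is hypothetical, and the rounding rule is an illustration, not the basecls implementation.

```python
def adjust_block_compatibility_sketch(ws, bs, gs):
    """Hypothetical sketch: make each inner width w * b a multiple of its group width g."""
    assert len(ws) == len(bs) == len(gs)
    new_ws, new_gs = [], []
    for w, b, g in zip(ws, bs, gs):
        v = max(1, int(w * b))             # inner (bottleneck) width
        g = min(g, v)                      # a group cannot be wider than the inner width
        v = max(g, int(round(v / g) * g))  # round the inner width to a multiple of g
        new_ws.append(int(round(v / b)))   # recover the outer width from the inner width
        new_gs.append(g)
    return new_ws, list(bs), new_gs


ws, bs, gs = adjust_block_compatibility_sketch([100, 48, 16], [1.0, 0.25, 1.0], [24, 8, 32])
print(ws, gs)
```

In this example, width 100 with group width 24 is rounded to 96. The third block's group width is clamped from 32 to 16 because it cannot exceed the inner width.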
- basecls.layers.wrapper.calculate_fan_in_and_fan_out(tensor, pytorch_style=False)
A version of megengine.module.init.calculate_fan_in_and_fan_out() fixed for group conv2d.
Note
The group conv2d kernel shape in MegEngine is (G, O/G, I/G, K, K). By default, this function calculates fan_out = O/G * K * K, whereas PyTorch uses fan_out = O * K * K.
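The sketch below illustrates the two fan-out conventions on plain shape tuples. The function name is hypothetical, and the supported layouts listed in its docstring are assumptions for illustration. It also assumes that pytorch_style=True selects the PyTorch convention, which the parameter name suggests but this page does not state.

```python
import math


def fan_in_and_fan_out_sketch(shape, pytorch_style=False):
    """Hypothetical sketch of fan-in/fan-out for MegEngine-style weight shapes.

    Assumed layouts:
      (O, I)               linear weight
      (O, I, K, K)         conv2d weight
      (G, O/G, I/G, K, K)  group conv2d weight
    """
    if len(shape) == 2:
        out_f, in_f = shape
        return in_f, out_f
    if len(shape) == 4:
        out_c, in_c = shape[0], shape[1]
        receptive = math.prod(shape[2:])
        return in_c * receptive, out_c * receptive
    if len(shape) == 5:
        groups, out_per_group, in_per_group = shape[:3]
        receptive = math.prod(shape[3:])
        fan_in = in_per_group * receptive
        fan_out = out_per_group * receptive  # default: O/G * K * K
        if pytorch_style:
            fan_out *= groups                # PyTorch convention: O * K * K
        return fan_in, fan_out
    raise ValueError(f"unsupported weight shape: {shape}")


print(fan_in_and_fan_out_sketch((4, 8, 16, 3, 3)))
print(fan_in_and_fan_out_sketch((4, 8, 16, 3, 3), pytorch_style=True))
```

For a group conv with G = 4, O/G = 8 and a 3x3 kernel, the two conventions differ by a factor of G in fan_out (72 vs. 288). Fan-in is the same in both.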
- basecls.layers.wrapper.compute_precise_bn_stats(cfg, model, dataloader)
Computes precise BN stats on training data.
References: facebookresearch/pycls
- Parameters
cfg (ConfigDict) – config for precise BN.
model (Module) – model whose BN stats are computed.
dataloader (Union[DataLoader, FakeDataLoader]) – dataloader for precise BN.
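Precise BN replaces the moving averages accumulated during training with statistics averaged over a number of training batches. The sketch below is a minimal NumPy version of that averaging. It assumes that each entry of batches is a precomputed (N, C, H, W) activation array entering one BN layer. The function name is hypothetical, and the sketch does not touch MegEngine modules or any cfg options.

```python
import numpy as np


def precise_bn_stats_sketch(batches):
    """Hypothetical sketch: average per-batch BN mean/variance over many batches.

    Each element of `batches` is an (N, C, H, W) array of activations entering
    one BN layer. Returns per-channel (mean, var) averaged over the batches.
    """
    mean_sum = None
    var_sum = None
    for x in batches:
        axes = (0, 2, 3)                # reduce over batch and spatial dims, keep channels
        mean = x.mean(axis=axes)
        var = x.var(axis=axes, ddof=1)  # unbiased sample variance
        mean_sum = mean if mean_sum is None else mean_sum + mean
        var_sum = var if var_sum is None else var_sum + var
    n = len(batches)
    return mean_sum / n, var_sum / n


b1 = np.array([0.0, 2.0]).reshape(2, 1, 1, 1)  # one channel: mean 1, unbiased var 2
b2 = np.array([4.0, 4.0]).reshape(2, 1, 1, 1)  # one channel: mean 4, unbiased var 0
mean, var = precise_bn_stats_sketch([b1, b2])
print(mean, var)
```

Averaging per-batch variances ignores the spread of the per-batch means. This is a deliberate simplification, and it stays close to the true statistics when many reasonably large batches are used.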
- basecls.layers.wrapper.init_weights(m, pytorch_style=False, zero_init_final_gamma=False)
Performs ResNet-style weight initialization.
About zero-initialization (zero_init_final_gamma): the last BN in each residual branch is zero-initialized. Each residual branch therefore starts at zero, and each residual block initially behaves like an identity. According to https://arxiv.org/abs/1706.02677, this improves the model by 0.2~0.3%.
References: facebookresearch/pycls
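The sketch below is a minimal NumPy version of ResNet-style initialization. It assumes a pycls-like scheme: conv weights are drawn from a normal distribution with std sqrt(2 / fan_out), where fan_out = k * k * out_channels. BN gamma is 1 and beta is 0, except that gamma is 0 for the final BN of each residual branch when zero_init_final_gamma is set. All function names are hypothetical, and plain arrays stand in for MegEngine parameters. The last lines check that a zero-initialized final BN makes the block output equal to its non-negative input.

```python
import numpy as np


def init_conv_weight_sketch(out_channels, in_channels, k, rng):
    """He-normal init in fan-out mode (assumed): std = sqrt(2 / fan_out)."""
    fan_out = k * k * out_channels
    std = np.sqrt(2.0 / fan_out)
    return rng.normal(0.0, std, size=(out_channels, in_channels, k, k))


def init_bn_params_sketch(channels, is_final_bn=False, zero_init_final_gamma=False):
    """gamma = 1 and beta = 0, except gamma = 0 for the last BN of a residual branch."""
    gamma_value = 0.0 if (is_final_bn and zero_init_final_gamma) else 1.0
    return np.full(channels, gamma_value), np.zeros(channels)


def residual_output(x, branch_normalized, gamma, beta):
    """relu(x + BN(branch)); branch_normalized is the already-normalized branch activation."""
    branch = gamma[:, None, None] * branch_normalized + beta[:, None, None]
    return np.maximum(x + branch, 0.0)


rng = np.random.default_rng(0)
w = init_conv_weight_sketch(64, 32, 3, rng)
gamma, beta = init_bn_params_sketch(64, is_final_bn=True, zero_init_final_gamma=True)
x = np.abs(rng.normal(size=(64, 4, 4)))  # non-negative, like a post-ReLU block input
z = rng.normal(size=(64, 4, 4))          # arbitrary normalized branch activation
out = residual_output(x, z, gamma, beta)
print(np.array_equal(out, x))
```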
- basecls.layers.wrapper.init_vit_weights(module)
Initialization for Vision Transformer (ViT).
References: rwightman/pytorch-image-models
- Parameters
module – module to be initialized.
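The sketch below is a minimal NumPy version of a ViT-style scheme, assuming conventions like those in pytorch-image-models. Linear weights come from a normal distribution with std 0.02, and Linear biases are zero. LayerNorm weights are 1 and biases are 0. For this sketch, the normal is truncated at ±2 std by resampling; the exact bounds used upstream may differ. The function names and the (out, in) weight layout are hypothetical.

```python
import numpy as np


def trunc_normal_sketch(shape, std, rng, bound=2.0):
    """Sample N(0, std^2) and resample any value outside [-bound*std, bound*std]."""
    out = rng.normal(0.0, std, size=shape)
    mask = np.abs(out) > bound * std
    while mask.any():
        out[mask] = rng.normal(0.0, std, size=int(mask.sum()))
        mask = np.abs(out) > bound * std
    return out


def init_vit_linear_sketch(in_features, out_features, rng):
    """Truncated-normal weights (std 0.02) and zero bias."""
    weight = trunc_normal_sketch((out_features, in_features), 0.02, rng)
    bias = np.zeros(out_features)
    return weight, bias


def init_vit_layernorm_sketch(dim):
    """LayerNorm scale = 1, shift = 0."""
    return np.ones(dim), np.zeros(dim)


rng = np.random.default_rng(0)
w, b = init_vit_linear_sketch(64, 256, rng)
ln_w, ln_b = init_vit_layernorm_sketch(64)
```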