basecls.layers.wrapper

class basecls.layers.wrapper.Preprocess(mean, std)

Bases: Module

forward(inputs)
Return type

Tuple[Tensor, Tensor]

basecls.layers.wrapper.adjust_block_compatibility(ws, bs, gs)

Adjusts the compatibility of widths, bottlenecks and groups.

Parameters
Return type

Tuple[List[int], ...]

Returns

The adjusted widths, bottlenecks and groups.
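To make the adjustment concrete, here is a minimal pure-Python sketch modeled on the pycls-style algorithm this module references; it is not necessarily basecls's exact code. The name `adjust_block_compatibility_sketch` is hypothetical, bottleneck multipliers `bs` are assumed to be floats such as 1.0 or 0.25 (or integers when greater than 1), and the goal is that each bottleneck width `w * b` ends up divisible by its group count `g`.

```python
import math
from typing import List, Sequence, Tuple


def adjust_block_compatibility_sketch(
    ws: Sequence[int], bs: Sequence[float], gs: Sequence[int]
) -> Tuple[List[int], List[float], List[int]]:
    """Hypothetical pycls-style sketch: make each bottleneck width divisible by its groups."""
    assert len(ws) == len(bs) == len(gs)
    assert all(w > 0 and b > 0 and g > 0 for w, b, g in zip(ws, bs, gs))
    # Bottleneck (inner) widths, at least 1 channel.
    vs = [int(max(1, w * b)) for w, b in zip(ws, bs)]
    # A block cannot have more groups than bottleneck channels.
    gs = [int(min(g, v)) for g, v in zip(gs, vs)]
    # Required multiple m: lcm(g, b) when b > 1 (assumed integral), else g.
    ms = [g * int(b) // math.gcd(g, int(b)) if b > 1 else g for g, b in zip(gs, bs)]
    # Round each bottleneck width to the nearest multiple of m, never below m.
    vs = [max(m, int(round(v / m) * m)) for v, m in zip(vs, ms)]
    # Recover block widths from the adjusted bottleneck widths.
    ws = [int(v / b) for v, b in zip(vs, bs)]
    assert all(w * b % g == 0 for w, b, g in zip(ws, bs, gs))
    return ws, list(bs), gs


print(adjust_block_compatibility_sketch([48, 100], [1.0, 1.0], [16, 24]))
```

With these inputs the width 100 is not divisible by 24 groups, so it is rounded to the nearest multiple, 96; a group count larger than the bottleneck width is clamped down instead.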

basecls.layers.wrapper.calculate_fan_in_and_fan_out(tensor, pytorch_style=False)

A fixed version of megengine.module.init.calculate_fan_in_and_fan_out() for group conv2d.

Note

The group conv2d kernel shape in MegEngine is (G, O/G, I/G, K, K). By default this function computes fan_out = O/G * K * K, whereas PyTorch uses fan_out = O * K * K.

Parameters
  • tensor (Tensor) – tensor to be initialized.

  • pytorch_style (bool) – use the PyTorch-style fan_out (O * K * K) for group conv. Default: False
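The note above can be checked with a small sketch that works on a plain shape tuple instead of a MegEngine Tensor. The function name is hypothetical, and the 2-D case assumes MegEngine's (out_features, in_features) Linear weight layout; only the group conv2d and ordinary conv2d layouts come from the text.

```python
from typing import Sequence, Tuple


def fan_in_and_fan_out_sketch(
    shape: Sequence[int], pytorch_style: bool = False
) -> Tuple[int, int]:
    """Hypothetical sketch: (fan_in, fan_out) from a MegEngine-style weight shape."""
    if len(shape) == 5:
        # Group conv2d kernel: (G, O/G, I/G, K, K).
        groups, out_per_group, in_per_group, kh, kw = shape
        fan_in = in_per_group * kh * kw
        fan_out = out_per_group * kh * kw  # default: O/G * K * K
        if pytorch_style:
            fan_out *= groups  # PyTorch convention: O * K * K
    elif len(shape) == 4:
        # Ordinary conv2d kernel: (O, I, K, K).
        out_c, in_c, kh, kw = shape
        fan_in = in_c * kh * kw
        fan_out = out_c * kh * kw
    elif len(shape) == 2:
        # Linear weight, assumed to be (out_features, in_features).
        fan_out, fan_in = shape
    else:
        raise ValueError(f"unsupported weight shape: {tuple(shape)}")
    return fan_in, fan_out


# 32 output and 64 input channels split into 4 groups, 3x3 kernel.
print(fan_in_and_fan_out_sketch((4, 8, 16, 3, 3)))                      # (144, 72)
print(fan_in_and_fan_out_sketch((4, 8, 16, 3, 3), pytorch_style=True))  # (144, 288)
```

With `pytorch_style=True`, the result matches what PyTorch reports for the equivalent (O, I/G, K, K) = (32, 16, 3, 3) kernel.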

basecls.layers.wrapper.compute_precise_bn_stats(cfg, model, dataloader)

Computes precise BN stats on training data.

References: facebookresearch/pycls
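The core idea of precise BN in pycls can be sketched without a model. Instead of relying on the exponential moving average accumulated during training, the per-batch means and variances from a number of training batches are averaged with equal weight. In the sketch below, each batch is a flat list of one channel's activations. The function name is hypothetical, and using the population variance of each batch is a simplifying assumption. The real function runs the model over `dataloader` and writes the averages into every BN layer.

```python
from statistics import fmean, pvariance
from typing import Sequence, Tuple


def precise_bn_stats_sketch(batches: Sequence[Sequence[float]]) -> Tuple[float, float]:
    """Hypothetical sketch: equal-weight average of per-batch BN statistics for one channel."""
    batch_means = [fmean(b) for b in batches]
    # Population variance per batch (simplifying assumption for this sketch).
    batch_vars = [pvariance(b) for b in batches]
    # Every batch contributes equally, unlike an exponential moving average.
    return fmean(batch_means), fmean(batch_vars)


print(precise_bn_stats_sketch([[1.0, 2.0, 3.0], [3.0, 4.0, 5.0]]))
```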

Parameters
basecls.layers.wrapper.init_weights(m, pytorch_style=False, zero_init_final_gamma=False)

Performs ResNet-style weight initialization.

About zero-initialization: when zero_init_final_gamma is enabled, the last BN in each residual branch is zero-initialized, so that each residual branch initially outputs zeros and each residual block starts out behaving like an identity. According to https://arxiv.org/abs/1706.02677, this improves the model's accuracy by 0.2~0.3%.

References: facebookresearch/pycls

Parameters
  • m (Module) – module to be initialized.

  • pytorch_style (bool) – use PyTorch-style initialization for group conv. Default: False

  • zero_init_final_gamma (bool) – whether to zero-initialize the last BN in each residual branch. Default: False
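Two pieces of the ResNet-style scheme can be shown with plain numbers: the conv weight standard deviation and the zero-gamma identity effect. The sketch assumes the pycls convention of He (Kaiming) normal initialization in fan_out mode, std = sqrt(2 / fan_out). The helper names are hypothetical, and the BN step is reduced to its affine part applied to already-normalized values.

```python
import math
import random
from typing import List, Sequence


def he_normal_std(fan_out: int) -> float:
    """Std of fan_out-mode He normal init (assumed pycls convention): sqrt(2 / fan_out)."""
    return math.sqrt(2.0 / fan_out)


def bn_affine(x_hat: Sequence[float], gamma: float, beta: float) -> List[float]:
    """Affine step of BN applied to already-normalized values."""
    return [gamma * v + beta for v in x_hat]


# A 3x3 conv with 64 output channels and no groups: fan_out = 64 * 3 * 3 = 576.
std = he_normal_std(64 * 3 * 3)
weights = [random.gauss(0.0, std) for _ in range(64 * 3 * 3)]

# Zero-initialized final BN (gamma = 0, beta = 0): the residual branch F(x) outputs zeros,
# so x + F(x), before any final activation, equals the block input x.
x = [0.5, -1.0, 2.0]
branch = bn_affine([1.2, -0.3, 0.7], gamma=0.0, beta=0.0)
y = [xi + fi for xi, fi in zip(x, branch)]
print(round(std, 4), y == x)
```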

basecls.layers.wrapper.init_vit_weights(module)

Initialization for Vision Transformer (ViT).

References: rwightman/pytorch-image-models

Parameters

module – module to be initialized.

basecls.layers.wrapper.trunc_normal_(tensor, mean=0.0, std=1.0, a=-2.0, b=2.0)
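A truncated normal fill can be sketched with inverse-CDF sampling from the standard library. The sketch assumes, as in the pytorch-image-models `trunc_normal_`, that `a` and `b` are absolute cutoff values rather than multiples of `std`. It returns a list instead of filling a tensor in place. The function name and the `seed` parameter are hypothetical.

```python
import random
from statistics import NormalDist
from typing import List, Optional


def trunc_normal_samples(
    n: int,
    mean: float = 0.0,
    std: float = 1.0,
    a: float = -2.0,
    b: float = 2.0,
    seed: Optional[int] = None,
) -> List[float]:
    """Hypothetical sketch: n samples of N(mean, std^2) truncated to [a, b] via inverse CDF."""
    rng = random.Random(seed)
    dist = NormalDist(mean, std)
    lo, hi = dist.cdf(a), dist.cdf(b)  # CDF values at the cutoffs
    samples = []
    for _ in range(n):
        u = rng.uniform(lo, hi)  # uniform over the allowed CDF range
        x = dist.inv_cdf(u)  # map back through the inverse CDF
        samples.append(min(max(x, a), b))  # guard against float round-off at the edges
    return samples


print(trunc_normal_samples(3, seed=0))
```

Sampling uniformly between the two CDF values draws only from the allowed range, so no samples are rejected. This assumes the cutoffs are not so far in the tails that their CDF values round to exactly 0 or 1.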
basecls.layers.wrapper.lecun_normal_(tensor)
basecls.layers.wrapper.make_divisible(value, divisor=8, min_value=None, round_limit=0.0)
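Channel-count rounding of this kind can be sketched as below. The sketch is modeled on the pytorch-image-models `make_divisible` and is an assumption about basecls's behavior, not its confirmed code. With the default `round_limit=0.0` shown in the signature, the "bump up one divisor if rounding lost too much" step never fires. The function name is hypothetical.

```python
from typing import Optional


def make_divisible_sketch(
    value: float, divisor: int = 8, min_value: Optional[int] = None, round_limit: float = 0.0
) -> int:
    """Hypothetical timm-style sketch: round value to a multiple of divisor."""
    min_value = min_value or divisor
    # Round half up to the nearest multiple of divisor, but never below min_value.
    new_value = max(min_value, int(value + divisor / 2) // divisor * divisor)
    # If rounding down lost more than allowed by round_limit, add one divisor back.
    if new_value < round_limit * value:
        new_value += divisor
    return new_value


print(make_divisible_sketch(30))                   # 32
print(make_divisible_sketch(11))                   # 8
print(make_divisible_sketch(11, round_limit=0.9))  # 16
```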