basecls.solver.weight_decay

basecls.solver.weight_decay.get_param_groups(module, weight_decay_policy)

Directly get the optimizer's param_groups, with different weight decays assigned according to the given policy.

cfg.solver.weight_decay can be a float or a sequence of weight decay policies.

For example:

cfg.solver.weight_decay = 1e-5

is equivalent to

cfg.solver.weight_decay = [
    1e-5
]
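A minimal sketch of how the two forms can be handled uniformly, assuming the implementation wraps a bare number into a one-element list before matching. The helper name `normalize_policy` is hypothetical and not part of basecls.

```python
from typing import Sequence, Union


# Hypothetical helper (not a basecls API): wrap a bare float into a
# one-element policy list so both config forms are processed the same way.
def normalize_policy(policy: Union[float, Sequence]) -> list:
    if isinstance(policy, (int, float)):
        return [policy]
    return list(policy)


# Both spellings of cfg.solver.weight_decay normalize to the same list.
assert normalize_policy(1e-5) == normalize_policy([1e-5])
```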

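Weight decay policies are applied in sequential order: for each parameter, the patterns are tried from the first to the last, and the first match determines its weight decay. If no pattern matches, the default weight decay, i.e. the last entry of the sequence (index -1), is applied. For example: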

from basecls.layers import NORM_TYPES
cfg.solver.weight_decay = [
    (1e-5, "bias"),
    (0, NORM_TYPES),
    1e-4,
]

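Each parameter is checked against (1e-5, "bias") before (0, NORM_TYPES), so every bias parameter, including the biases of normalization layers, gets weight decay 1e-5. The remaining normalization-layer parameters get 0, and all other parameters fall back to the default 1e-4.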
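The first-match rule can be sketched in plain Python. This is an illustration under stated assumptions, not the basecls implementation: a string pattern is assumed to match when it is a substring of the parameter name, a tuple of classes is assumed to match via `isinstance` on the layer that owns the parameter, and the last entry is taken as the default. `resolve_weight_decay`, `Conv2d`, `BatchNorm2d` and this `NORM_TYPES` are hypothetical stand-ins.

```python
from typing import Sequence


# Hypothetical stand-ins for real layer classes.
class Conv2d: ...
class BatchNorm2d: ...


NORM_TYPES = (BatchNorm2d,)  # assumption: a tuple of normalization classes


def resolve_weight_decay(param_name: str, owner: object, policy: Sequence) -> float:
    """Return the weight decay of the first matching rule (sketch)."""
    *rules, default = policy  # assumption: the last entry is the default
    for wd, pattern in rules:
        if isinstance(pattern, str):
            matched = pattern in param_name      # name-based rule
        else:
            matched = isinstance(owner, pattern)  # layer-type rule
        if matched:
            return wd  # first match wins; later rules are not consulted
    return default


policy = [(1e-5, "bias"), (0, NORM_TYPES), 1e-4]
bn, conv = BatchNorm2d(), Conv2d()
resolve_weight_decay("bn.bias", bn, policy)      # 1e-05: "bias" is tried first
resolve_weight_decay("bn.weight", bn, policy)    # 0: matched by NORM_TYPES
resolve_weight_decay("conv.weight", conv, policy)  # 0.0001: default
```

Because the rules are tried in order, swapping the first two entries would give `bn.bias` a weight decay of 0 instead.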

For mobile models such as MobileNet and ShuffleNet, you may want to disable weight decay for normalization layers and for all biases. This can be achieved as follows:

from basecls.layers import NORM_TYPES
cfg.solver.weight_decay = [
    (0, "bias"),
    (0, NORM_TYPES),
    4e-5,
]
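To see how this mobile policy partitions a model, here is a self-contained sketch that groups a toy list of parameters into a dict keyed by weight decay, mirroring the Dict[float, ...] shape of the return value. It relies on the same assumptions as any first-match reading of the policy (substring match on names, `isinstance` match on the owning layer, last entry as default); `group_by_weight_decay`, the layer classes and the parameter names are hypothetical, and strings stand in for tensors.

```python
from collections import defaultdict
from typing import Dict, List, Sequence, Tuple


# Hypothetical stand-ins for real layer classes.
class Conv2d: ...
class BatchNorm2d: ...


NORM_TYPES = (BatchNorm2d,)

# Toy stand-in for module.named_parameters(): (name, owning layer) pairs.
conv, bn = Conv2d(), BatchNorm2d()
named_params = [
    ("conv.weight", conv),
    ("conv.bias", conv),
    ("bn.weight", bn),
    ("bn.bias", bn),
]


def group_by_weight_decay(
    params: Sequence[Tuple[str, object]], policy: Sequence
) -> Dict[float, List[str]]:
    *rules, default = policy  # assumption: last entry is the default
    groups = defaultdict(list)
    for name, owner in params:
        wd = default
        for rule_wd, pattern in rules:
            hit = pattern in name if isinstance(pattern, str) else isinstance(owner, pattern)
            if hit:
                wd = rule_wd
                break  # first matching rule decides
        groups[wd].append(name)
    return dict(groups)


mobile_policy = [(0, "bias"), (0, NORM_TYPES), 4e-5]
print(group_by_weight_decay(named_params, mobile_policy))
# {4e-05: ['conv.weight'], 0: ['conv.bias', 'bn.weight', 'bn.bias']}
```

Only the convolution weight keeps a nonzero weight decay; every bias and every normalization parameter lands in the 0 group.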
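
Parameters

module: the module whose parameters are grouped.
weight_decay_policy: a float or a sequence of weight decay policies, as described above.

Return type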

Dict[float, Iterable[Tensor]]
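The returned mapping has one entry per distinct weight decay. A minimal sketch of turning it into a list of optimizer parameter groups, assuming the target optimizer accepts dicts with a "params" key plus a per-group "weight_decay" override, as MegEngine and PyTorch optimizers do. `to_param_groups` is a hypothetical helper, and strings stand in for Tensor objects.

```python
from typing import Dict, Iterable, List


# Hypothetical helper: one optimizer group per distinct weight decay value.
def to_param_groups(groups: Dict[float, Iterable]) -> List[dict]:
    return [{"params": list(params), "weight_decay": wd} for wd, params in groups.items()]


# Strings stand in for Tensor objects in this illustration.
groups = {4e-5: ["conv.weight"], 0: ["conv.bias", "bn.weight", "bn.bias"]}
param_groups = to_param_groups(groups)
```

Materializing each iterable with `list` keeps the groups reusable if the returned values are generators.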