basecls.solver.weight_decay 源代码

#!/usr/bin/env python3
# Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
import numbers
import re
from collections import defaultdict
from typing import Dict, Iterable, Optional, Sequence, Tuple, Type, Union

import megengine.module as M
from loguru import logger
from megengine.tensor import Tensor

__all__ = ["get_param_groups"]

PatternType = Union[str, Type, Sequence[Type]]


[文档]def get_param_groups( module: M.Module, weight_decay_policy: Union[float, Sequence[Tuple[float, PatternType]]] ) -> Dict[float, Iterable[Tensor]]: """Directly get optimizer's param_groups with different weight decays given policy. ``cfg.solver.weight_decay`` can be a float or a sequence of weight decay policies. For example: .. code:: cfg.solver.weight_decay = 1e-5 is equivalent to .. code:: cfg.solver.weight_decay = [ 1e-5 ] Weight decay policy works in sequential order, i.e., for each parameter, we try patterns from the beginning to the end. If unmatched, default weight decay (-1) will be applied. For example: .. code:: from basecls.layers import NORM_TYPES cfg.solver.weight_decay = [ (1e-5, "bias"), (0, NORM_TYPES), 1e-4, ] The parameter will first match ``(1e-5, "bias")`` then ``(0, NORM_TYPES)``, so any bias parameter, including the bias of normalization layers, will have weight decay ``1e-5``. For mobile models, e.g. mobilenet and shufflenet, you may want to disable weight decay for normalization layers and any bias. This can be achieved by the following: .. code:: from basecls.layers import NORM_TYPES cfg.solver.weight_decay = [ (0, "bias"), (0, NORM_TYPES), 4e-5, ] Args: module: training model weight_decay_policy: weight decay policy defined in ``cfg.solver.weight_decay`` """ weight_decay_policy = _parse_weight_decay_policy(weight_decay_policy) # build mapping from weight_decay to parameters wd2params = defaultdict(list) wd2param_names = defaultdict(list) for n, p, m in module.named_parameters(with_parent=True): wd = _get_weight_decay(n, m, weight_decay_policy) if wd < 0: wd = weight_decay_policy[-1] assert wd >= 0, "weight decay should be non-negative" wd2params[wd].append(p) wd2param_names[wd].append(n) if len(wd2param_names) > 1: # more than one param groups for wd, param_names in wd2param_names.items(): logger.info(f"Following params has weight_decay = {wd}\n\t{param_names}") optim_params = [] for group_wd, params in wd2params.items(): optim_params.append({"params": params, "weight_decay": group_wd}) return optim_params
def _parse_weight_decay_policy(policy): """Parse and verify weight decay policy""" if isinstance(policy, numbers.Number): policy = [policy] for p in policy[:-1]: assert ( isinstance(p, Sequence) and len(p) == 2 ), "weight decay policy should be a tuple of (value, pattern), get {p}" assert isinstance(p[0], numbers.Number), "value should be a number" f", get {p[0]}" assert isinstance( p[1], (str, Type, Sequence) ), "pattern should be string, type, or list of types, get {p[1]}" p = policy[-1] assert isinstance( p, numbers.Number ), f"weight decay should fallback into default value e.g. 1e-4, get {p}" return policy def _get_weight_decay( name: str, module: M.Module, policy: Sequence[Tuple[Optional[float], Union[PatternType, Sequence[PatternType]]]], ) -> Optional[float]: """Get weight decay based on given weight decay policy. Args: name: name of parameter module: the module owning the parameter policy: weight decay policy defined in ``cfg.solver.weight_decay`` """ for p in policy: if isinstance(p, numbers.Number): # touch fallback return p else: wd, pattern = p if isinstance(pattern, str): if re.findall(pattern, name): return wd elif isinstance(pattern, (Type, Sequence)): if isinstance(module, pattern): return wd else: raise TypeError(f"Unexpected pattern {pattern}") raise RuntimeError(f"Why parameter {name} excapes weight decay assignment?")