Source code for basecls.layers.losses

#!/usr/bin/env python3
# Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
import megengine as mge
import megengine.functional as F
import megengine.module as M
from basecore.config import ConfigDict

__all__ = ["build_loss", "BinaryCrossEntropy", "CrossEntropy"]


def build_loss(cfg: ConfigDict) -> M.Module:
    """The factory function to build loss.

    Args:
        cfg: config for building loss function.

    Returns:
        A loss function.
    """
    loss_args = cfg.loss.to_dict()
    loss_name = loss_args.pop("name", None)
    if loss_name is None:
        raise ValueError("Loss name is missing")
    if callable(loss_name):
        return loss_name(**loss_args)
    if isinstance(loss_name, str):
        loss_funcs = {
            "BinaryCrossEntropy": BinaryCrossEntropy,
            "CrossEntropy": CrossEntropy,
        }
        if loss_name in loss_funcs:
            return loss_funcs[loss_name](**loss_args)
    raise ValueError(f"Loss '{loss_name}' not supported")
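A minimal usage sketch (not part of the original module): ``build_loss`` only requires a config whose ``loss`` node carries a ``name`` plus keyword arguments. The construction of ``ConfigDict`` from a plain nested dict and the field values below are assumptions for illustration.

    # Illustrative only: build a CrossEntropy loss from a config.
    from basecore.config import ConfigDict

    cfg = ConfigDict(dict(loss=dict(name="CrossEntropy", label_smooth=0.1)))
    criterion = build_loss(cfg)  # CrossEntropy(axis=1, label_smooth=0.1)

Passing a callable as ``name`` (e.g. ``dict(name=CrossEntropy, axis=1)``) is also accepted, in which case the callable is invoked directly with the remaining arguments.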
class BinaryCrossEntropy(M.Module):
    """The module for binary cross entropy.

    See :py:func:`~megengine.functional.loss.binary_cross_entropy` for more details.
    """

    def __init__(self, **kwargs):
        super().__init__()

    def forward(self, x: mge.Tensor, y: mge.Tensor) -> mge.Tensor:
        return F.loss.binary_cross_entropy(x, y)
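A short calling sketch; the tensor shapes and values are illustrative. The predictions are assumed to be raw scores, since recent MegEngine versions of ``binary_cross_entropy`` apply a sigmoid internally by default; if your version differs, pass probabilities instead.

    # Illustrative only: binary classification with logits and {0, 1} targets.
    import megengine as mge

    logits = mge.tensor([[0.3], [-1.2], [2.0]])   # raw scores, shape (N, 1)
    targets = mge.tensor([[1.0], [0.0], [1.0]])   # binary labels, same shape
    loss = BinaryCrossEntropy()(logits, targets)  # scalar tensor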
class CrossEntropy(M.Module):
    """The module for cross entropy.

    It supports both categorical labels and one-hot labels.
    See :py:func:`~megengine.functional.loss.cross_entropy` for more details.

    Args:
        axis: reduced axis. Default: ``1``
        label_smooth: label smooth factor. Default: ``0.0``
    """

    def __init__(self, axis: int = 1, label_smooth: float = 0.0):
        super().__init__()
        self.axis = axis
        self.label_smooth = label_smooth

    def forward(self, x: mge.Tensor, y: mge.Tensor) -> mge.Tensor:
        if x.ndim == y.ndim + 1:
            # Categorical (index) labels: delegate to the built-in cross entropy.
            return F.loss.cross_entropy(x, y, axis=self.axis, label_smooth=self.label_smooth)
        else:
            # One-hot (or soft) labels: apply label smoothing manually,
            # then compute the cross entropy against the log-softmax of the logits.
            assert x.ndim == y.ndim
            if self.label_smooth != 0:
                y = y * (1 - self.label_smooth) + self.label_smooth / y.shape[self.axis]
            return (-y * F.logsoftmax(x, axis=self.axis)).sum(self.axis).mean()
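A sketch of the two label formats the module accepts; the batch size, class count, and label values below are illustrative assumptions.

    # Illustrative only: the same logits scored against categorical and one-hot labels.
    import numpy as np
    import megengine as mge

    logits = mge.tensor(np.random.randn(4, 10).astype("float32"))      # (N, C)
    indices = mge.tensor(np.array([1, 3, 0, 7], dtype="int32"))        # categorical labels, (N,)
    one_hot = mge.tensor(np.eye(10, dtype="float32")[[1, 3, 0, 7]])    # one-hot labels, (N, C)

    criterion = CrossEntropy(label_smooth=0.1)
    loss_a = criterion(logits, indices)  # x.ndim == y.ndim + 1: built-in cross entropy path
    loss_b = criterion(logits, one_hot)  # x.ndim == y.ndim: manual label-smoothing path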