#!/usr/bin/env python3
# Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
import copy
from typing import Optional, Sequence
import cv2
import megengine.data as data
import megengine.data.transform as T
import numpy as np
from basecore.config import ConfigDict
from loguru import logger
from basecls.utils import registers
from .augment import WARP_PARAMS, TorchAutoAugment, TorchRandAugment
from .const import CV2_INTERP, PIL_INTERP
from .mixup import MixupCutmixCollator
from .rand_erase import RandomErasing
# Public names exported by this module (what ``from <module> import *`` exposes).
# NOTE(review): "build_transform" is exported here, but its definition is not
# in the visible part of this file — confirm it is defined elsewhere in the module.
__all__ = [
"build_transform",
"AutoAugment",
"SimpleAugment",
"ColorAugment",
"RandAugment",
"build_mixup",
]
class ToColorSpace(T.VisionTransform):
    """Convert an image from BGR into a configured target color space.

    The conversions used are ``cv2.COLOR_BGR2RGB`` and ``cv2.COLOR_BGR2GRAY``,
    so the incoming image is treated as BGR.

    Args:
        color_space: target color space, one of ``"BGR"``, ``"RGB"`` or ``"GRAY"``.
        order: passed through unchanged to the parent transform's constructor.

    Raises:
        ValueError: if ``color_space`` is not one of the supported names.
    """

    def __init__(self, color_space: str, *, order: Sequence = None):
        super().__init__(order)
        supported = ("BGR", "RGB", "GRAY")
        if color_space not in supported:
            raise ValueError(f"Color space '{color_space}' not supported")
        self.color_space = color_space

    def _apply_image(self, image: np.ndarray) -> np.ndarray:
        """Return ``image`` converted to ``self.color_space``.

        ``"BGR"`` hands back the input object untouched; ``"GRAY"`` appends a
        trailing axis to the single-channel result of ``cv2.cvtColor``.

        Raises:
            ValueError: if ``self.color_space`` holds an unsupported name
                (possible only if the attribute was changed after construction).
        """
        target = self.color_space
        if target == "BGR":
            return image
        if target == "RGB":
            return cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        if target == "GRAY":
            gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
            # cvtColor drops the channel axis for grayscale; add it back as the last axis.
            return gray[..., np.newaxis]
        raise ValueError(f"Color space '{self.color_space}' not supported")
@registers.augments.register()
class SimpleAugment:
    """Simple augmentation.

    Registered in ``registers.augments``; its builder returns MegEngine's
    ``T.PseudoTransform`` and does not read the config.
    """

    @classmethod
    def build(cls, cfg: ConfigDict) -> T.Transform:
        """Build the transform for this augmentation.

        Args:
            cfg: config object; accepted for a uniform builder signature but
                not read here.

        Returns:
            A new ``T.PseudoTransform`` instance.
        """
        return T.PseudoTransform()
@registers.augments.register()
class ColorAugment:
    """Color augmentation.

    Registered in ``registers.augments``; its builder chains ``T.ColorJitter``
    and ``T.Lighting``, both configured from ``cfg.augments.color_aug``.
    """

    @classmethod
    def build(cls, cfg: ConfigDict) -> T.Transform:
        """Build the color-jitter + lighting transform.

        Args:
            cfg: config object; ``cfg.augments.color_aug.to_dict()`` must
                contain a ``"lighting"`` entry. Every other entry is forwarded
                to ``T.ColorJitter`` as a keyword argument.

        Returns:
            ``T.Compose`` of ``T.ColorJitter`` followed by ``T.Lighting``.
        """
        jitter_kwargs = cfg.augments.color_aug.to_dict()
        # "lighting" configures T.Lighting, not ColorJitter, so it is removed
        # before the remaining entries are forwarded.
        # NOTE(review): presumably to_dict() returns a fresh dict so this pop
        # does not alter cfg — confirm.
        lighting_scale = jitter_kwargs.pop("lighting")
        return T.Compose([T.ColorJitter(**jitter_kwargs), T.Lighting(lighting_scale)])
@registers.augments.register()
class AutoAugment:
    """AutoAugment.

    Registered in ``registers.augments``; its builder wraps a
    default-constructed ``TorchAutoAugment`` in ``T.TorchTransformCompose``.
    """

    @classmethod
    def build(cls, cfg: ConfigDict) -> T.Transform:
        """Build the AutoAugment transform.

        Args:
            cfg: config object; accepted for a uniform builder signature but
                not read here.

        Returns:
            ``T.TorchTransformCompose`` containing a single ``TorchAutoAugment``.
        """
        return T.TorchTransformCompose([TorchAutoAugment()])
@registers.augments.register()
class RandAugment:
    """Random augmentation.

    Registered in ``registers.augments``; its builder wraps a
    ``TorchRandAugment`` configured from ``cfg.augments.rand_aug`` in
    ``T.TorchTransformCompose``.
    """

    @classmethod
    def build(cls, cfg: ConfigDict) -> T.Transform:
        """Build the RandAugment transform.

        Args:
            cfg: config object; every entry of
                ``cfg.augments.rand_aug.to_dict()`` is passed to
                ``TorchRandAugment`` as a keyword argument.

        Returns:
            ``T.TorchTransformCompose`` containing a single ``TorchRandAugment``.
        """
        rand_aug_kwargs = cfg.augments.rand_aug.to_dict()
        return T.TorchTransformCompose([TorchRandAugment(**rand_aug_kwargs)])
def build_mixup(cfg: ConfigDict, train: bool = True) -> Optional[data.Collator]:
    """Build (optionally) Mixup/CutMix augment.

    A collator is built only for the train set, and only when at least one of
    ``mixup_alpha > 0``, ``cutmix_alpha > 0`` or ``cutmix_minmax is not None``
    holds in ``cfg.augments.mixup``.

    Args:
        cfg: config for building Mixup/CutMix collator. Reads
            ``cfg.augments.mixup`` and, when a collator is built,
            ``cfg.num_classes``.
        train: train set or test set. Default: ``True``

    Returns:
        :py:class:`~basecls.data.mixup.MixupCutmixCollator` or ``None``
    """
    mixup_cfg = cfg.augments.mixup
    if not train:
        return None

    mixing_enabled = (
        mixup_cfg.mixup_alpha > 0.0
        or mixup_cfg.cutmix_alpha > 0.0
        or mixup_cfg.cutmix_minmax is not None
    )
    if not mixing_enabled:
        return None

    # All mixup settings are forwarded as keyword arguments; the class count
    # comes from the top-level config.
    mixup_collator = MixupCutmixCollator(**mixup_cfg.to_dict(), num_classes=cfg.num_classes)
    logger.info(f"Using mixup with configuration:\n{mixup_cfg}")
    return mixup_collator