#!/usr/bin/env python3
# Copyright (c) 2021 Facebook, Inc. and its affiliates.
# Copyright (c) 2020 Ross Wightman
# This file has been modified by Megvii ("Megvii Modifications").
# All Megvii Modifications are Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
"""AutoAugment and RandAugment
AutoAugment: `"AutoAugment: Learning Augmentation Policies from Data"
<https://arxiv.org/abs/1805.09501>`_
RandAugment: `"RandAugment: Practical automated data augmentation with a reduced search space"
<https://arxiv.org/abs/1909.13719>`_
References:
https://github.com/facebookresearch/pycls/blob/main/pycls/datasets/augment.py
https://github.com/rwightman/pytorch-image-models/blob/master/timm/data/auto_augment.py
"""
import random
from numbers import Real
from typing import List, Tuple, Union
import numpy as np
from PIL import Image, ImageEnhance, ImageOps
# Names exported by a star-import of this module.
__all__ = ["MAX_LEVEL", "POSTERIZE_MIN", "TorchAutoAugment", "TorchRandAugment", "WARP_PARAMS"]
# This signifies the max integer that the controller RNN could predict for the augmentation scheme.
# apply_op divides a magnitude by this value to map it onto each op's (min, max) range.
MAX_LEVEL = 10
# Minimum value for posterize (0 in EfficientNet implementation).
# The "posterize" and "posterize_inc" ops use it as the lower bound of the bit count.
POSTERIZE_MIN = 1
# Parameters for affine warping and rotation.
# Forwarded as keyword arguments by affine_warp (img.transform) and by the "rotate" op (x.rotate).
WARP_PARAMS = {"fillcolor": (128, 128, 128), "resample": Image.BILINEAR}
def affine_warp(img: Image, mat) -> Image:
    """Warp an image with an affine transform while keeping its size.

    Args:
        img: PIL image to transform.
        mat: affine coefficients handed to ``img.transform`` together with
            ``Image.AFFINE``; the ops in this file pass 6-element tuples.

    Returns:
        The transformed image. Fill colour and resampling filter are taken
        from ``WARP_PARAMS``.
    """
    # The output canvas has the same dimensions as the input.
    output_size = img.size
    warped = img.transform(output_size, Image.AFFINE, mat, **WARP_PARAMS)
    return warped
OP_FUNCTIONS = {
    # Maps an op name to a callable (x, v): x is the input image and v the
    # level computed by apply_op. Every callable returns the augmented image.
    #
    # Ops that ignore the level.
    "auto_contrast": lambda x, _: ImageOps.autocontrast(x),
    "equalize": lambda x, _: ImageOps.equalize(x),
    "invert": lambda x, _: ImageOps.invert(x),
    # Rotation by v degrees, using the shared fill colour / resampling.
    "rotate": lambda x, v: x.rotate(v, **WARP_PARAMS),
    # Posterize keeps int(v) bits (or 4 - int(v) for the "_inc" variant),
    # never fewer than POSTERIZE_MIN.
    "posterize": lambda x, v: ImageOps.posterize(x, max(POSTERIZE_MIN, int(v))),
    "posterize_inc": lambda x, v: ImageOps.posterize(x, max(POSTERIZE_MIN, 4 - int(v))),
    # Solarize with threshold int(v); noted upstream as equivalent to
    #   x.point(lambda i: i if i < int(v) else 255 - i)
    "solarize": lambda x, v: ImageOps.solarize(x, int(v)),
    # Solarize with threshold 256 - int(v); noted upstream as equivalent to
    #   x.point(lambda i: i if i < 256 - int(v) else 255 - i)
    "solarize_inc": lambda x, v: ImageOps.solarize(x, 256 - int(v)),
    # Adds int(v) (saturating at 255) to pixel values below 128 only.
    "solarize_add": lambda x, v: x.point(lambda i: min(255, i + int(v)) if i < 128 else i),
    # Enhancers called with the factor v directly.
    "color": lambda x, v: ImageEnhance.Color(x).enhance(v),
    "contrast": lambda x, v: ImageEnhance.Contrast(x).enhance(v),
    "brightness": lambda x, v: ImageEnhance.Brightness(x).enhance(v),
    "sharpness": lambda x, v: ImageEnhance.Sharpness(x).enhance(v),
    # "_inc" enhancers use 1 + v, so v is an offset from the factor 1.
    "color_inc": lambda x, v: ImageEnhance.Color(x).enhance(1 + v),
    "contrast_inc": lambda x, v: ImageEnhance.Contrast(x).enhance(1 + v),
    "brightness_inc": lambda x, v: ImageEnhance.Brightness(x).enhance(1 + v),
    "sharpness_inc": lambda x, v: ImageEnhance.Sharpness(x).enhance(1 + v),
    # Geometric ops built on affine_warp; translations scale v by the
    # image width (x.size[0]) or height (x.size[1]).
    "shear_x": lambda x, v: affine_warp(x, (1, v, 0, 0, 1, 0)),
    "shear_y": lambda x, v: affine_warp(x, (1, 0, 0, v, 1, 0)),
    "trans_x": lambda x, v: affine_warp(x, (1, 0, v * x.size[0], 0, 1, 0)),
    "trans_y": lambda x, v: affine_warp(x, (1, 0, 0, 0, 1, v * x.size[1])),
}
OP_RANGES = {
    # Maps an op name to (min, max, negate). apply_op interpolates the
    # magnitude between min and max, and when negate is True it flips the
    # sign of the result at random.
    #
    # Level-independent ops.
    "auto_contrast": (0, 1, False),
    "equalize": (0, 1, False),
    "invert": (0, 1, False),
    # Rotation.
    "rotate": (0.0, 30.0, True),
    # Posterize / solarize family.
    "posterize": (0, 4, False),
    "posterize_inc": (0, 4, False),
    "solarize": (0, 256, False),
    "solarize_inc": (0, 256, False),
    "solarize_add": (0, 110, False),
    # Enhancers: plain variants use an absolute factor, "_inc" variants a
    # signed offset.
    "color": (0.1, 1.9, False),
    "color_inc": (0, 0.9, True),
    "contrast": (0.1, 1.9, False),
    "contrast_inc": (0, 0.9, True),
    "brightness": (0.1, 1.9, False),
    "brightness_inc": (0, 0.9, True),
    "sharpness": (0.1, 1.9, False),
    "sharpness_inc": (0, 0.9, True),
    # Geometric ops.
    "shear_x": (0.0, 0.3, True),
    "shear_y": (0.0, 0.3, True),
    "trans_x": (0.0, 0.45, True),
    "trans_y": (0.0, 0.45, True),
}
AUTOAUG_POLICY = [
    # AutoAugment "policy_v0": each sub-policy is a pair of
    # (op, prob, magnitude) steps.
    [("equalize", 0.8, 1), ("shear_y", 0.8, 4)],
    [("color", 0.4, 9), ("equalize", 0.6, 3)],
    [("color", 0.4, 1), ("rotate", 0.6, 8)],
    [("solarize", 0.8, 3), ("equalize", 0.4, 7)],
    [("solarize", 0.4, 2), ("solarize", 0.6, 2)],
    [("color", 0.2, 0), ("equalize", 0.8, 8)],
    [("equalize", 0.4, 8), ("solarize_add", 0.8, 3)],
    [("shear_x", 0.2, 9), ("rotate", 0.6, 8)],
    [("color", 0.6, 1), ("equalize", 1.0, 2)],
    [("invert", 0.4, 9), ("rotate", 0.6, 0)],
    [("equalize", 1.0, 9), ("shear_y", 0.6, 3)],
    [("color", 0.4, 7), ("equalize", 0.6, 0)],
    [("posterize", 0.4, 6), ("auto_contrast", 0.4, 7)],
    [("solarize", 0.6, 8), ("color", 0.6, 9)],
    [("solarize", 0.2, 4), ("rotate", 0.8, 9)],
    [("rotate", 1.0, 7), ("trans_y", 0.8, 9)],
    [("shear_x", 0.0, 0), ("solarize", 0.8, 4)],
    [("shear_y", 0.8, 0), ("color", 0.6, 4)],
    [("color", 1.0, 0), ("rotate", 0.6, 2)],
    [("equalize", 0.8, 4), ("equalize", 0.0, 8)],
    [("equalize", 1.0, 4), ("auto_contrast", 0.6, 2)],
    [("shear_y", 0.4, 7), ("solarize_add", 0.6, 7)],
    [("posterize", 0.8, 2), ("solarize", 0.6, 10)],
    [("solarize", 0.6, 8), ("equalize", 0.6, 1)],
    [("color", 0.8, 6), ("rotate", 0.4, 5)],
]
RANDAUG_OPS = [
    # Default op pool for RandAugment; uses the "increasing" ("_inc")
    # variants of the posterize, solarize and enhancer transforms.
    #
    # Level-independent ops and rotation.
    "auto_contrast",
    "equalize",
    "invert",
    "rotate",
    # Posterize / solarize family.
    "posterize_inc",
    "solarize_inc",
    "solarize_add",
    # Enhancers.
    "color_inc",
    "contrast_inc",
    "brightness_inc",
    "sharpness_inc",
    # Geometric ops.
    "shear_x",
    "shear_y",
    "trans_x",
    "trans_y",
]
def apply_op(
    img: Image,
    op: str,
    prob: Union[float, Tuple[float, float]],
    magnitude: Real,
    magnitude_std: float = 0.0,
) -> Image:
    """Apply the selected op to image with given probability and magnitude.

    Args:
        img: input PIL image.
        op: operation name; must be a key of both ``OP_RANGES`` and
            ``OP_FUNCTIONS``.
        prob: probability of applying the op, or a ``(low, high)`` tuple from
            which the probability is drawn uniformly on every call.
        magnitude: strength of the op; ``0`` maps to the op's minimum value and
            ``MAX_LEVEL`` to its maximum.
        magnitude_std: ``0`` keeps ``magnitude`` fixed, ``inf`` samples the
            magnitude uniformly from ``[0, magnitude]``, and any other positive
            value samples it from a Gaussian centred on ``magnitude`` (clipped
            below at 0).

    Returns:
        ``img`` unchanged when the op is skipped, otherwise the augmented image.

    Raises:
        ValueError: if ``op`` is unknown, or ``prob`` is a tuple whose length
            is not 2.
    """
    # Both tables are indexed below, so the op must be present in each of them.
    if op not in OP_RANGES or op not in OP_FUNCTIONS:
        raise ValueError(f"Operation '{op}' not supported")

    if isinstance(prob, tuple):
        if len(prob) != 2:
            raise ValueError(f"prob tuple must have exactly 2 elements, got {len(prob)}")
        # Positional unpacking: ``**`` would require a mapping and fail on a tuple.
        prob = random.uniform(*prob)

    # Skip the op with probability (1 - prob).
    if random.random() > prob:
        return img

    if magnitude_std == float("inf"):
        magnitude = random.uniform(0, magnitude)
    elif magnitude_std > 0.0:
        # NOTE: only the lower bound is clipped; a sampled magnitude may
        # exceed MAX_LEVEL.
        magnitude = max(0, random.gauss(magnitude, magnitude_std))

    min_v, max_v, negate = OP_RANGES[op]
    # The magnitude is converted to an absolute value v for an op (some ops use -v or v)
    v = magnitude / MAX_LEVEL * (max_v - min_v) + min_v
    if negate and random.random() > 0.5:
        v = -v
    return OP_FUNCTIONS[op](img, v)
def auto_augment(img: Image, policy: List[Tuple] = None) -> Image:
    """Augment an image with one randomly chosen AutoAugment sub-policy.

    Args:
        img: input PIL image.
        policy: list of sub-policies, each an iterable of
            ``(op, prob, magnitude)`` steps. ``AUTOAUG_POLICY`` is used when
            this is ``None`` or empty.

    Returns:
        The image after every step of the chosen sub-policy has been passed
        through ``apply_op``.
    """
    sub_policy = random.choice(policy or AUTOAUG_POLICY)
    for step in sub_policy:
        op_name, op_prob, op_magnitude = step
        img = apply_op(img, op_name, op_prob, op_magnitude)
    return img
def rand_augment(
    img: Image,
    magnitude: Real,
    magnitude_std: float = 0.0,
    prob: Union[float, Tuple[float, float]] = 0.5,
    n_ops: int = 2,
    ops: List[str] = None,
) -> Image:
    """Augment an image with ``n_ops`` randomly selected operations.

    Args:
        img: input PIL image.
        magnitude: magnitude forwarded to ``apply_op`` for every selected op.
        magnitude_std: magnitude spread forwarded to ``apply_op``.
        prob: per-op application probability (or range) forwarded to
            ``apply_op``.
        n_ops: number of ops drawn with ``np.random.choice``.
        ops: candidate op names; ``RANDAUG_OPS`` is used when this is ``None``
            or empty.

    Returns:
        The image after the selected ops have been applied in sequence.
    """
    candidates = ops or RANDAUG_OPS
    selected = np.random.choice(candidates, n_ops)
    for op_name in selected:
        img = apply_op(img, op_name, prob, magnitude, magnitude_std)
    return img
class TorchAutoAugment:
    """Callable transform that applies ``auto_augment`` with its default policy."""

    def __call__(self, img: Image) -> Image:
        """Return ``img`` augmented by ``auto_augment``."""
        return auto_augment(img)
class TorchRandAugment:
    """Callable transform that applies ``rand_augment`` with stored settings.

    Args:
        magnitude: magnitude passed to ``rand_augment``.
        magnitude_std: magnitude spread passed to ``rand_augment``.
        prob: per-op application probability, or a ``(low, high)`` range.
        n_ops: number of ops applied to each image.
    """

    def __init__(
        self,
        magnitude: Real,
        magnitude_std: float = 0.0,
        prob: Union[float, Tuple[float, float]] = 0.5,
        n_ops: int = 2,
    ):
        self.magnitude = magnitude
        self.magnitude_std = magnitude_std
        self.prob = prob
        self.n_ops = n_ops

    def __call__(self, img: Image) -> Image:
        """Return ``img`` augmented by ``rand_augment`` using the stored settings."""
        return rand_augment(img, self.magnitude, self.magnitude_std, self.prob, self.n_ops)