basecls.data.mixup

Mixup and CutMix

Mixup: “Mixup: Beyond Empirical Risk Minimization”

CutMix: “CutMix: Regularization Strategy to Train Strong Classifiers with Localizable Features”

References

rwightman/pytorch-image-models

class basecls.data.mixup.MixupCutmixTransform(mixup_alpha=1.0, cutmix_alpha=0.0, cutmix_minmax=None, prob=1.0, switch_prob=0.5, mode='batch', data_format='HWC', num_classes=1000, calibrate_cutmix_lambda=True, calibrate_mixup_lambda=False, permute=False, *, order=None)[source]

Bases: VisionTransform

Implements Mixup and CutMix as a VisionTransform.

Note

When this transform is used inside Compose, batch_compose must be set to True.

Parameters
  • mixup_alpha (float) – mixup alpha value; mixup is active if > 0. Default: 1.0

  • cutmix_alpha (float) – cutmix alpha value; cutmix is active if > 0. Default: 0.0

  • cutmix_minmax (Optional[List[float]]) – cutmix min/max image ratio; if not None, cutmix is active and this range is used instead of cutmix_alpha. Default: None

  • prob (float) – probability of applying mixup or cutmix per batch or element. Default: 1.0

  • switch_prob (float) – probability of switching to cutmix instead of mixup when both are active. Default: 0.5

  • mode (str) – how mixup/cutmix parameters are applied; supports "batch" (per batch), "pair" (per pair of elements) and "elem" (per element). Default: "batch"

  • data_format (str) – "CHW" or "HWC"; use "HWC" if this transform is applied before T.ToMode(). Default: "HWC"

  • num_classes (int) – number of classes for target. Default: 1000

  • calibrate_cutmix_lambda (bool) – apply lambda correction when the cutmix bbox is clipped by the image borders. The correction is based on the clipped area. Default: True

  • calibrate_mixup_lambda (bool) – force the mixup lambda to be greater than 0.5; only makes a difference in "elem" mode. Default: False

  • permute (bool) – whether to mix with permuted samples instead of flipped samples. Default: False
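How these parameters interact can be illustrated with a small pure-Python sketch. This is not the basecls implementation: it assumes the behaviour follows the referenced rwightman/pytorch-image-models code, the helper names sample_params and cutmix_bbox are hypothetical, and cutmix_minmax is left out for brevity.

```python
import math
import random


def sample_params(mixup_alpha=1.0, cutmix_alpha=0.0, prob=1.0,
                  switch_prob=0.5, calibrate_mixup_lambda=False, rng=random):
    """Return (lam, use_cutmix); lam == 1.0 means the input is left unmixed."""
    if rng.random() >= prob:
        return 1.0, False  # mixing is skipped with probability 1 - prob
    if mixup_alpha > 0 and cutmix_alpha > 0:
        # Both are active: pick cutmix with probability switch_prob.
        use_cutmix = rng.random() < switch_prob
    elif cutmix_alpha > 0:
        use_cutmix = True
    elif mixup_alpha > 0:
        use_cutmix = False
    else:
        return 1.0, False  # assumption: nothing is active, so nothing is mixed
    alpha = cutmix_alpha if use_cutmix else mixup_alpha
    lam = rng.betavariate(alpha, alpha)  # lam ~ Beta(alpha, alpha)
    if calibrate_mixup_lambda and not use_cutmix:
        lam = max(lam, 1.0 - lam)  # keep the original sample dominant
    return lam, use_cutmix


def cutmix_bbox(height, width, lam, calibrate_cutmix_lambda=True, rng=random):
    """Sample a box covering about (1 - lam) of the image, clipped to its borders."""
    ratio = math.sqrt(1.0 - lam)
    cut_h, cut_w = int(height * ratio), int(width * ratio)
    cy, cx = rng.randrange(height), rng.randrange(width)
    top, bottom = max(cy - cut_h // 2, 0), min(cy + cut_h // 2, height)
    left, right = max(cx - cut_w // 2, 0), min(cx + cut_w // 2, width)
    if calibrate_cutmix_lambda:
        # Recompute lam from the area that actually survived clipping.
        lam = 1.0 - (bottom - top) * (right - left) / (height * width)
    return (top, bottom, left, right), lam
```

With calibrate_cutmix_lambda enabled, the returned lambda matches the fraction of the image that is not replaced, so the soft label weights stay consistent with the pixels even when the box is cut off at a border.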

apply_batch(inputs)[source]

Apply transform on batch input data.
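As a rough illustration of what a batch-level mixup step does, the following sketch mixes every sample with a partner from the same batch using one shared lambda and builds soft targets over num_classes. It is not the actual apply_batch code: the input layout (flat image lists plus integer labels), the function name mixup_batch, and the reading of "flipped" as "the batch in reversed order" are assumptions made for the example.

```python
import random


def one_hot(label, num_classes):
    return [1.0 if i == label else 0.0 for i in range(num_classes)]


def mixup_batch(images, labels, lam, num_classes, permute=False, rng=random):
    """Mix each sample with a partner from the same batch using a shared lam.

    images: list of equally sized flat lists of floats; labels: list of ints.
    """
    n = len(images)
    if permute:
        partner = list(range(n))
        rng.shuffle(partner)  # random partner for every sample
    else:
        partner = list(range(n - 1, -1, -1))  # flipped: first pairs with last
    mixed_images, mixed_targets = [], []
    for i, j in enumerate(partner):
        mixed_images.append(
            [lam * a + (1.0 - lam) * b for a, b in zip(images[i], images[j])]
        )
        ti, tj = one_hot(labels[i], num_classes), one_hot(labels[j], num_classes)
        mixed_targets.append([lam * a + (1.0 - lam) * b for a, b in zip(ti, tj)])
    return mixed_images, mixed_targets


# Two tiny "images" with two pixels each, three classes.
images = [[0.0, 0.0], [1.0, 1.0]]
labels = [0, 2]
mixed, targets = mixup_batch(images, labels, lam=0.75, num_classes=3)
print(mixed[0], targets[0])
```

Each soft target sums to 1, with weight lam on the sample's own class and 1 - lam on its partner's class.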

class basecls.data.mixup.MixupCutmixCollator(*args, **kwargs)[source]

Bases: Collator

A faster version implemented as a collator.

apply(inputs)[source]