Source code for basecls.models.vit

#!/usr/bin/env python3
# Hacked together by / Copyright 2021 Ross Wightman
# This file has been modified by Megvii ("Megvii Modifications").
# All Megvii Modifications are Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
"""Vision Transformer (ViT)

ViT: `"An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale"
<https://arxiv.org/abs/2010.11929>`_

DeiT: `"Training data-efficient image transformers & distillation through attention"
<https://arxiv.org/abs/2012.12877>`_

References:
    https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py
"""
import math
from collections import OrderedDict
from typing import Callable, Optional, Union

import cv2
import megengine as mge
import megengine.functional as F
import megengine.hub as hub
import megengine.module as M
import numpy as np
from loguru import logger
from megengine.utils.tuple_function import _pair as to_2tuple

from basecls.layers import DropPath, activation, init_vit_weights, norm2d, trunc_normal_
from basecls.utils import recursive_update, registers

__all__ = ["PatchEmbed", "Attention", "FFN", "EncoderBlock", "ViT"]


class PatchEmbed(M.Module):
    """Image to Patch Embedding

    Args:
        img_size: Image size. Default: ``224``
        patch_size: Patch token size. Default: ``16``
        in_chans: Number of input image channels. Default: ``3``
        embed_dim: Number of linear projection output channels. Default: ``768``
        flatten: Flatten embedding. Default: ``True``
        norm_name: Normalization layer. Default: ``None``
    """

    def __init__(
        self,
        img_size: int = 224,
        patch_size: int = 16,
        in_chans: int = 3,
        embed_dim: int = 768,
        flatten: bool = True,
        norm_name: str = None,
        **kwargs,
    ):
        super().__init__()
        img_size = to_2tuple(img_size)
        patch_size = to_2tuple(patch_size)
        self.img_size = img_size
        self.patch_size = patch_size
        self.grid_size = (img_size[0] // patch_size[0], img_size[1] // patch_size[1])
        self.num_patches = self.grid_size[0] * self.grid_size[1]
        self.flatten = flatten

        self.proj = M.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
        self.norm = norm2d(norm_name, embed_dim) if norm_name else None
    def forward(self, x):
        B, C, H, W = x.shape
        assert H == self.img_size[0] and W == self.img_size[1], (
            f"Input image size ({H}*{W}) doesn't match model "
            f"({self.img_size[0]}*{self.img_size[1]})."
        )
        x = self.proj(x)
        if self.flatten:
            x = F.flatten(x, 2).transpose(0, 2, 1)
        if self.norm:
            x = self.norm(x)
        return x
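
# Illustrative shape check (not part of the original module): with the default
# 224x224 input and 16x16 patches, PatchEmbed produces 14 * 14 = 196 tokens of
# dimension 768. `_demo_patch_embed` is a hypothetical helper added for clarity.
def _demo_patch_embed():
    embed = PatchEmbed(img_size=224, patch_size=16, in_chans=3, embed_dim=768)
    x = F.zeros([2, 3, 224, 224])  # dummy batch of 2 RGB images
    tokens = embed(x)              # shape (2, 196, 768)
    return tokens.shape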
class Attention(M.Module):
    """Self-Attention block.

    Args:
        dim: Number of input channels.
        num_heads: Number of attention heads. Default: ``8``
        qkv_bias: If True, add a learnable bias to query, key, value. Default: ``False``
        qk_scale: Override default qk scale of ``head_dim ** -0.5`` if set.
        attn_drop: Dropout ratio of attention weight. Default: ``0.0``
        proj_drop: Dropout ratio of output. Default: ``0.0``
    """

    def __init__(
        self,
        dim: int,
        num_heads: int = 8,
        qkv_bias: bool = False,
        qk_scale: float = None,
        attn_drop: float = 0.0,
        proj_drop: float = 0.0,
    ):
        super().__init__()
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = qk_scale or head_dim ** -0.5

        self.qkv = M.Linear(dim, dim * 3, bias=qkv_bias)
        self.softmax = M.Softmax(axis=-1)
        self.attn_drop = M.Dropout(attn_drop)
        self.proj = M.Linear(dim, dim)
        self.proj_drop = M.Dropout(proj_drop)
    def forward(self, x):
        B, N, C = x.shape
        qkv = (
            self.qkv(x)
            .reshape(B, N, 3, self.num_heads, C // self.num_heads)
            .transpose(2, 0, 3, 1, 4)
        )
        q, k, v = qkv[0], qkv[1], qkv[2]

        attn = F.matmul(q, k.transpose(0, 1, 3, 2)) * self.scale
        attn = self.softmax(attn)
        attn = self.attn_drop(attn)

        x = F.matmul(attn, v).transpose(0, 2, 1, 3).reshape(B, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x
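
# Illustrative sketch (not part of the original module): scaled dot-product
# attention over N tokens. With dim=192 and num_heads=3, each head operates on
# 64-dimensional queries/keys/values and the attention map has shape
# (B, num_heads, N, N). `_demo_attention` is a hypothetical helper.
def _demo_attention():
    attn = Attention(dim=192, num_heads=3, qkv_bias=True)
    tokens = F.zeros([2, 197, 192])  # batch of 2, 196 patch tokens + 1 cls token
    out = attn(tokens)               # same shape as the input: (2, 197, 192)
    return out.shape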
class FFN(M.Module):
    """FFN for ViT

    Args:
        in_features: Number of input features.
        hidden_features: Number of hidden features. Default: ``None``
        out_features: Number of output features. Default: ``None``
        drop: Dropout ratio. Default: ``0.0``
        act_name: Activation function. Default: ``"gelu"``
    """

    def __init__(
        self,
        in_features: int,
        hidden_features: int = None,
        out_features: int = None,
        drop: float = 0.0,
        act_name: str = "gelu",
    ):
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        self.fc1 = M.Linear(in_features, hidden_features)
        self.act = activation(act_name)
        self.fc2 = M.Linear(hidden_features, out_features)
        self.drop = M.Dropout(drop)
    def forward(self, x):
        x = self.fc1(x)
        x = self.act(x)
        x = self.drop(x)
        x = self.fc2(x)
        x = self.drop(x)
        return x
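
# Illustrative sketch (not part of the original module): the FFN is a two-layer
# MLP applied to every token independently; with ffn_ratio=4.0 a 192-dim
# embedding is expanded to 768 hidden units and projected back.
# `_demo_ffn` is a hypothetical helper.
def _demo_ffn():
    ffn = FFN(in_features=192, hidden_features=768)
    tokens = F.zeros([2, 197, 192])
    return ffn(tokens).shape  # (2, 197, 192)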
class EncoderBlock(M.Module):
    """Transformer Encoder block.

    Args:
        dim: Number of input channels.
        num_heads: Number of attention heads.
        ffn_ratio: Ratio of ffn hidden dim to embedding dim. Default: ``4.0``
        qkv_bias: If True, add a learnable bias to query, key, value. Default: ``False``
        qk_scale: Override default qk scale of ``head_dim ** -0.5`` if set.
        drop: Dropout ratio of non-attention weight. Default: ``0.0``
        attn_drop: Dropout ratio of attention weight. Default: ``0.0``
        drop_path: Stochastic depth rate. Default: ``0.0``
        norm_name: Normalization layer. Default: ``"LN"``
        act_name: Activation layer. Default: ``"gelu"``
    """

    def __init__(
        self,
        dim: int,
        num_heads: int,
        ffn_ratio: float = 4.0,
        qkv_bias: bool = False,
        qk_scale: float = None,
        attn_drop: float = 0.0,
        drop: float = 0.0,
        drop_path: float = 0.0,
        norm_name: str = "LN",
        act_name: str = "gelu",
        **kwargs,
    ):
        super().__init__()
        self.norm1 = norm2d(norm_name, dim)
        self.attn = Attention(
            dim,
            num_heads=num_heads,
            qkv_bias=qkv_bias,
            qk_scale=qk_scale,
            attn_drop=attn_drop,
            proj_drop=drop,
        )
        # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
        self.drop_path = DropPath(drop_path) if drop_path > 0.0 else None
        self.norm2 = norm2d(norm_name, dim)
        ffn_hidden_dim = int(dim * ffn_ratio)
        self.ffn = FFN(
            in_features=dim, hidden_features=ffn_hidden_dim, drop=drop, act_name=act_name
        )
    def forward(self, x):
        if self.drop_path:
            x = x + self.drop_path(self.attn(self.norm1(x)))
            x = x + self.drop_path(self.ffn(self.norm2(x)))
        else:
            x = x + self.attn(self.norm1(x))
            x = x + self.ffn(self.norm2(x))
        return x
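
# Illustrative sketch (not part of the original module): EncoderBlock applies
# pre-norm residual sub-layers, i.e. x = x + Attn(LN(x)) followed by
# x = x + FFN(LN(x)), optionally wrapped in DropPath for stochastic depth.
# `_demo_encoder_block` is a hypothetical helper.
def _demo_encoder_block():
    block = EncoderBlock(dim=192, num_heads=3, qkv_bias=True, drop_path=0.1)
    tokens = F.zeros([2, 197, 192])
    return block(tokens).shape  # (2, 197, 192)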
@registers.models.register()
class ViT(M.Module):
    """ViT model.

    Args:
        img_size: Input image size. Default: ``224``
        patch_size: Patch token size. Default: ``16``
        in_chans: Number of input image channels. Default: ``3``
        embed_dim: Number of linear projection output channels. Default: ``768``
        depth: Depth of Transformer Encoder layer. Default: ``12``
        num_heads: Number of attention heads. Default: ``12``
        ffn_ratio: Ratio of ffn hidden dim to embedding dim. Default: ``4.0``
        qkv_bias: If True, add a learnable bias to query, key, value. Default: ``True``
        qk_scale: Override default qk scale of ``head_dim ** -0.5`` if set. Default: ``None``
        representation_size: Size of representation layer (pre-logits). Default: ``None``
        distilled: Includes a distillation token and head. Default: ``False``
        drop_rate: Dropout rate. Default: ``0.0``
        attn_drop_rate: Attention dropout rate. Default: ``0.0``
        drop_path_rate: Stochastic depth rate. Default: ``0.0``
        embed_layer: Patch embedding layer. Default: :py:class:`PatchEmbed`
        norm_name: Normalization layer. Default: ``"LN"``
        act_name: Activation function. Default: ``"gelu"``
        num_classes: Number of classes. Default: ``1000``
    """

    def __init__(
        self,
        img_size: int = 224,
        patch_size: int = 16,
        in_chans: int = 3,
        embed_dim: int = 768,
        depth: int = 12,
        num_heads: int = 12,
        ffn_ratio: float = 4.0,
        qkv_bias: bool = True,
        qk_scale: float = None,
        representation_size: int = None,
        distilled: bool = False,
        drop_rate: float = 0.0,
        attn_drop_rate: float = 0.0,
        drop_path_rate: float = 0.0,
        embed_layer: M.Module = PatchEmbed,
        norm_name: str = "LN",
        act_name: str = "gelu",
        num_classes: int = 1000,
        **kwargs,
    ):
        super().__init__()
        # Patch Embedding
        self.embed_dim = embed_dim
        self.patch_embed = embed_layer(img_size, patch_size, in_chans, embed_dim)
        num_patches = self.patch_embed.num_patches
        # CLS & DST Tokens
        self.cls_token = mge.Parameter(F.zeros([1, 1, embed_dim]))
        self.dist_token = mge.Parameter(F.zeros([1, 1, embed_dim])) if distilled else None
        self.num_tokens = 2 if distilled else 1
        # Pos Embedding
        self.pos_embed = mge.Parameter(F.zeros([1, num_patches + self.num_tokens, embed_dim]))
        self.pos_drop = M.Dropout(drop_rate)
        # Blocks
        dpr = [
            x.item() for x in F.linspace(0, drop_path_rate, depth)
        ]  # stochastic depth decay rule
        self.blocks = M.Sequential(
            *[
                EncoderBlock(
                    dim=embed_dim,
                    num_heads=num_heads,
                    ffn_ratio=ffn_ratio,
                    qkv_bias=qkv_bias,
                    qk_scale=qk_scale,
                    drop=drop_rate,
                    attn_drop=attn_drop_rate,
                    drop_path=dpr[i],
                    norm_name=norm_name,
                    act_name=act_name,
                )
                for i in range(depth)
            ]
        )
        self.norm = norm2d(norm_name, embed_dim)
        # Representation layer
        if representation_size and not distilled:
            self.num_features = representation_size
            self.pre_logits = M.Sequential(
                OrderedDict(
                    [("fc", M.Linear(embed_dim, representation_size)), ("act", activation("tanh"))]
                )
            )
        else:
            self.pre_logits = None
        # Classifier head(s)
        self.head = M.Linear(self.embed_dim, num_classes) if num_classes > 0 else None
        self.head_dist = None
        if distilled:
            self.head_dist = M.Linear(self.embed_dim, num_classes) if num_classes > 0 else None
        # Init
        self.init_weights()
    def init_weights(self):
        trunc_normal_(self.pos_embed, std=0.02)
        if self.dist_token is not None:
            trunc_normal_(self.dist_token, std=0.02)
        trunc_normal_(self.cls_token, std=0.02)
        self.apply(init_vit_weights)
    def forward(self, x):
        x = self.patch_embed(x)
        cls_token = F.broadcast_to(self.cls_token, (x.shape[0], 1, self.cls_token.shape[-1]))
        if self.dist_token is None:
            x = F.concat((cls_token, x), axis=1)
        else:
            dist_token = F.broadcast_to(
                self.dist_token, (x.shape[0], 1, self.dist_token.shape[-1])
            )
            x = F.concat((cls_token, dist_token, x), axis=1)
        x = self.pos_drop(x + self.pos_embed)
        x = self.blocks(x)
        x = self.norm(x)
        if self.dist_token is None:
            x = x[:, 0]
            if self.pre_logits:
                x = self.pre_logits(x)
        else:
            x = x[:, 0], x[:, 1]
        if self.head_dist is not None:
            x_cls, x_dist = x
            if self.head:
                x_cls = self.head(x_cls)
            if self.head_dist:
                x_dist = self.head_dist(x_dist)
            if self.training:
                return x_cls, x_dist
            else:
                # during inference, return the average of both classifier predictions
                return (x_cls + x_dist) / 2
        elif self.head:
            x = self.head(x)
        return x
    def load_state_dict(
        self,
        state_dict: Union[dict, Callable[[str, mge.Tensor], Optional[np.ndarray]]],
        strict=True,
    ):
        if "pos_embed" in state_dict:
            old_pos_embed = state_dict["pos_embed"]
            old_n_patches = old_pos_embed.shape[1] - self.num_tokens
            old_gs = int(math.sqrt(old_n_patches + 0.5))
            new_n_patches = self.pos_embed.shape[1] - self.num_tokens
            new_gs = int(math.sqrt(new_n_patches + 0.5))

            logger.info("Position embedding grid-size from {} to {}", [old_gs] * 2, [new_gs] * 2)
            logger.info(
                "Resized position embedding: {} to {}", old_pos_embed.shape, self.pos_embed.shape
            )

            # remember whether the checkpoint stored a Tensor, so the resized
            # embedding can be converted back to the same type afterwards
            was_tensor = isinstance(old_pos_embed, mge.Tensor)
            if was_tensor:
                old_pos_embed = old_pos_embed.numpy()
            pos_emb_tok, old_pos_emb_grid = np.split(old_pos_embed, [self.num_tokens], axis=1)
            old_pos_emb_grid = old_pos_emb_grid.reshape(old_gs, old_gs, -1).transpose(2, 0, 1)
            new_pos_embed_grid = (
                np.stack(
                    [
                        cv2.resize(c, (new_gs, new_gs), interpolation=cv2.INTER_CUBIC)
                        for c in old_pos_emb_grid
                    ]
                )
                .transpose(1, 2, 0)
                .reshape(1, new_gs ** 2, -1)
            )
            new_pos_embed = np.concatenate([pos_emb_tok, new_pos_embed_grid], axis=1)
            if was_tensor:
                new_pos_embed = mge.Parameter(new_pos_embed)
            state_dict["pos_embed"] = new_pos_embed

        super().load_state_dict(state_dict, strict)
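
# Illustrative sketches (not part of the original module). `_demo_vit` builds a
# small ViT and runs a dummy forward pass: in eval mode the output is a
# (2, 1000) logits tensor; with distilled=True, training mode returns a
# (cls, dist) pair and eval mode returns their average. `_demo_resize_pos_embed`
# reproduces the position-embedding resize used in load_state_dict above for a
# 14x14 -> 24x24 grid (224 -> 384 input at patch size 16). Both helpers are
# hypothetical and added only for clarity.
def _demo_vit():
    model = ViT(img_size=224, patch_size=16, embed_dim=192, depth=12, num_heads=3)
    model.eval()
    images = F.zeros([2, 3, 224, 224])
    logits = model(images)  # shape (2, 1000)
    return logits.shape


def _demo_resize_pos_embed(old_gs: int = 14, new_gs: int = 24, dim: int = 192):
    grid = np.random.rand(1, old_gs * old_gs, dim).astype("float32")
    grid = grid.reshape(old_gs, old_gs, dim).transpose(2, 0, 1)  # (C, gs, gs)
    resized = np.stack(
        [cv2.resize(c, (new_gs, new_gs), interpolation=cv2.INTER_CUBIC) for c in grid]
    )
    return resized.transpose(1, 2, 0).reshape(1, new_gs ** 2, dim)  # (1, 576, 192)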
def _build_vit(**kwargs):
    model_args = dict(depth=12, drop_path_rate=0.1)
    recursive_update(model_args, kwargs)
    return ViT(**model_args)


@registers.models.register()
@hub.pretrained(
    "https://data.megengine.org.cn/research/basecls/models/"
    "vit/vit_tiny_patch16_224/vit_tiny_patch16_224.pkl"
)
def vit_tiny_patch16_224(**kwargs):
    model_args = dict(patch_size=16, embed_dim=192, num_heads=3)
    recursive_update(model_args, kwargs)
    return _build_vit(**model_args)


@registers.models.register()
@hub.pretrained(
    "https://data.megengine.org.cn/research/basecls/models/"
    "vit/vit_tiny_patch16_384/vit_tiny_patch16_384.pkl"
)
def vit_tiny_patch16_384(**kwargs):
    model_args = dict(img_size=384)
    recursive_update(model_args, kwargs)
    return vit_tiny_patch16_224(**model_args)


@registers.models.register()
@hub.pretrained(
    "https://data.megengine.org.cn/research/basecls/models/"
    "vit/vit_small_patch16_224/vit_small_patch16_224.pkl"
)
def vit_small_patch16_224(**kwargs):
    model_args = dict(patch_size=16, embed_dim=384, num_heads=6)
    recursive_update(model_args, kwargs)
    return _build_vit(**model_args)


@registers.models.register()
@hub.pretrained(
    "https://data.megengine.org.cn/research/basecls/models/"
    "vit/vit_small_patch16_384/vit_small_patch16_384.pkl"
)
def vit_small_patch16_384(**kwargs):
    model_args = dict(img_size=384)
    recursive_update(model_args, kwargs)
    return vit_small_patch16_224(**model_args)


@registers.models.register()
@hub.pretrained(
    "https://data.megengine.org.cn/research/basecls/models/"
    "vit/vit_small_patch32_224/vit_small_patch32_224.pkl"
)
def vit_small_patch32_224(**kwargs):
    model_args = dict(patch_size=32, embed_dim=384, num_heads=6)
    recursive_update(model_args, kwargs)
    return _build_vit(**model_args)


@registers.models.register()
@hub.pretrained(
    "https://data.megengine.org.cn/research/basecls/models/"
    "vit/vit_small_patch32_384/vit_small_patch32_384.pkl"
)
def vit_small_patch32_384(**kwargs):
    model_args = dict(img_size=384)
    recursive_update(model_args, kwargs)
    return vit_small_patch32_224(**model_args)


@registers.models.register()
@hub.pretrained(
    "https://data.megengine.org.cn/research/basecls/models/"
    "vit/vit_base_patch16_224/vit_base_patch16_224.pkl"
)
def vit_base_patch16_224(**kwargs):
    model_args = dict(patch_size=16, embed_dim=768, num_heads=12)
    recursive_update(model_args, kwargs)
    return _build_vit(**model_args)


@registers.models.register()
@hub.pretrained(
    "https://data.megengine.org.cn/research/basecls/models/"
    "vit/vit_base_patch16_384/vit_base_patch16_384.pkl"
)
def vit_base_patch16_384(**kwargs):
    model_args = dict(img_size=384)
    recursive_update(model_args, kwargs)
    return vit_base_patch16_224(**model_args)


@registers.models.register()
@hub.pretrained(
    "https://data.megengine.org.cn/research/basecls/models/"
    "vit/vit_base_patch32_224/vit_base_patch32_224.pkl"
)
def vit_base_patch32_224(**kwargs):
    model_args = dict(patch_size=32, embed_dim=768, num_heads=12)
    recursive_update(model_args, kwargs)
    return _build_vit(**model_args)


@registers.models.register()
@hub.pretrained(
    "https://data.megengine.org.cn/research/basecls/models/"
    "vit/vit_base_patch32_384/vit_base_patch32_384.pkl"
)
def vit_base_patch32_384(**kwargs):
    model_args = dict(img_size=384)
    recursive_update(model_args, kwargs)
    return vit_base_patch32_224(**model_args)
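
# Illustrative sketch (not part of the original module): the registered builder
# functions construct preconfigured ViT variants; pretrained weights are handled
# by the megengine.hub.pretrained decorator (passing pretrained=True is assumed
# to download the published checkpoint). `_demo_builder` is a hypothetical helper.
def _demo_builder():
    model = vit_tiny_patch16_224()  # randomly initialized ViT-Tiny (embed_dim=192, 3 heads)
    model.eval()
    images = F.zeros([1, 3, 224, 224])
    return model(images).shape  # (1, 1000)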