import torch
import torch.nn as nn
from torch import Tensor


class MultiHeadAttention(nn.Module):
    """
    This layer applies multi-head self-attention as described in the
    `Attention is all you need <https://arxiv.org/abs/1706.03762>`_ paper.

    Args:
        embed_dim (int): :math:`C_{in}` from an expected input of size :math:`(N, P, C_{in})`
        num_heads (int): Number of heads in multi-head attention
        attn_dropout (float): Attention dropout. Default: 0.0
        bias (bool): Use bias or not. Default: ``True``

    Shape:
        - Input: :math:`(N, P, C_{in})` where :math:`N` is batch size, :math:`P` is number of patches,
          and :math:`C_{in}` is input embedding dim
        - Output: same shape as the input
    """

    def __init__(
        self,
        embed_dim: int,
        num_heads: int,
        attn_dropout: float = 0.0,
        bias: bool = True,
        *args,
        **kwargs
    ) -> None:
        super().__init__()
        if embed_dim % num_heads != 0:
            raise ValueError(
                "Embedding dim must be divisible by number of heads in {}. Got: embed_dim={} and num_heads={}".format(
                    self.__class__.__name__, embed_dim, num_heads
                )
            )

        self.qkv_proj = nn.Linear(in_features=embed_dim, out_features=3 * embed_dim, bias=bias)

        self.attn_dropout = nn.Dropout(p=attn_dropout)
        self.out_proj = nn.Linear(in_features=embed_dim, out_features=embed_dim, bias=bias)

        self.head_dim = embed_dim // num_heads
        self.scaling = self.head_dim ** -0.5
        self.softmax = nn.Softmax(dim=-1)
        self.num_heads = num_heads
        self.embed_dim = embed_dim

    def forward(self, x_q: Tensor) -> Tensor:
        # [N, P, C]
        b_sz, n_patches, in_channels = x_q.shape

        # self-attention
        # [N, P, C] -> [N, P, 3C] -> [N, P, 3, h, c] where C = h * c
        qkv = self.qkv_proj(x_q).reshape(b_sz, n_patches, 3, self.num_heads, -1)

        # [N, P, 3, h, c] -> [N, h, 3, P, c]
        qkv = qkv.transpose(1, 3).contiguous()

        # [N, h, 3, P, c] -> [N, h, P, c] x 3
        query, key, value = qkv[:, :, 0], qkv[:, :, 1], qkv[:, :, 2]

        query = query * self.scaling

        # [N, h, P, c] -> [N, h, c, P]
        key = key.transpose(-1, -2)

        # QK^T
        # [N, h, P, c] x [N, h, c, P] -> [N, h, P, P]
        attn = torch.matmul(query, key)
        attn = self.softmax(attn)
        attn = self.attn_dropout(attn)

        # weighted sum
        # [N, h, P, P] x [N, h, P, c] -> [N, h, P, c]
        out = torch.matmul(attn, value)

        # [N, h, P, c] -> [N, P, h, c] -> [N, P, C]
        out = out.transpose(1, 2).reshape(b_sz, n_patches, -1)
        out = self.out_proj(out)

        return out


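# A minimal usage sketch for the MultiHeadAttention block above (illustrative
# only; the batch size, patch count, and embedding dim are assumptions, not
# values taken from the original code):
#
#   mha = MultiHeadAttention(embed_dim=144, num_heads=4, attn_dropout=0.1)
#   x = torch.randn(2, 256, 144)   # [N, P, C] = [batch, patches, embed_dim]
#   y = mha(x)                     # same shape as the input: [2, 256, 144]

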
class TransformerEncoder(nn.Module):
    """
    This class defines the pre-norm `Transformer encoder <https://arxiv.org/abs/1706.03762>`_.

    Args:
        embed_dim (int): :math:`C_{in}` from an expected input of size :math:`(N, P, C_{in})`
        ffn_latent_dim (int): Inner dimension of the FFN
        num_heads (int): Number of heads in multi-head attention. Default: 8
        attn_dropout (float): Dropout rate for attention in multi-head attention. Default: 0.0
        dropout (float): Dropout rate applied after the attention and FFN blocks. Default: 0.0
        ffn_dropout (float): Dropout between FFN layers. Default: 0.0

    Shape:
        - Input: :math:`(N, P, C_{in})` where :math:`N` is batch size, :math:`P` is number of patches,
          and :math:`C_{in}` is input embedding dim
        - Output: same shape as the input
    """

    def __init__(
        self,
        embed_dim: int,
        ffn_latent_dim: int,
        num_heads: int = 8,
        attn_dropout: float = 0.0,
        dropout: float = 0.0,
        ffn_dropout: float = 0.0,
        *args,
        **kwargs
    ) -> None:

        super().__init__()

        attn_unit = MultiHeadAttention(
            embed_dim,
            num_heads,
            attn_dropout=attn_dropout,
            bias=True
        )

        self.pre_norm_mha = nn.Sequential(
            nn.LayerNorm(embed_dim),
            attn_unit,
            nn.Dropout(p=dropout)
        )

        self.pre_norm_ffn = nn.Sequential(
            nn.LayerNorm(embed_dim),
            nn.Linear(in_features=embed_dim, out_features=ffn_latent_dim, bias=True),
            nn.SiLU(),
            nn.Dropout(p=ffn_dropout),
            nn.Linear(in_features=ffn_latent_dim, out_features=embed_dim, bias=True),
            nn.Dropout(p=dropout)
        )
        self.embed_dim = embed_dim
        self.ffn_dim = ffn_latent_dim
        self.ffn_dropout = ffn_dropout
        self.std_dropout = dropout

    def forward(self, x: Tensor) -> Tensor:
        # multi-head attention with a residual connection
        res = x
        x = self.pre_norm_mha(x)
        x = x + res

        # feed-forward network with a residual connection
        x = x + self.pre_norm_ffn(x)
        return x
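

if __name__ == "__main__":
    # Quick smoke test for TransformerEncoder (a minimal sketch; the sizes
    # below are assumptions, not values from the original code).
    block = TransformerEncoder(embed_dim=144, ffn_latent_dim=288, num_heads=4)
    x = torch.randn(2, 256, 144)   # [N, P, C] = [batch, patches, embed_dim]
    y = block(x)
    assert y.shape == x.shape      # both residual branches preserve the shape
    print(y.shape)                 # torch.Size([2, 256, 144])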