from __future__ import annotations

from collections.abc import Sequence

import torch.nn as nn
import torch.nn.functional as F
from torch.nn import LayerNorm

from monai.networks.layers import Conv
from monai.utils import ensure_tuple_rep, optional_import

SUPPORTED_DROPOUT_MODE = {"vit", "swin"}
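# NOTE: SUPPORTED_DROPOUT_MODE is presumably consumed by the ViT-style patch embedding
# block defined elsewhere in this module; it is not referenced by PatchEmbed below.
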
class PatchEmbed(nn.Module):
    """
    Patch embedding block based on: "Liu et al.,
    Swin Transformer: Hierarchical Vision Transformer using Shifted Windows
    <https://arxiv.org/abs/2103.14030>"
    https://github.com/microsoft/Swin-Transformer

    Unlike the ViT patch embedding block: (1) the input is padded to satisfy the window size requirements,
    (2) the output is normalized if a norm layer is specified, and (3) no position embedding is used.

    Example::

        >>> from monai.networks.blocks import PatchEmbed
        >>> PatchEmbed(patch_size=2, in_chans=1, embed_dim=48, norm_layer=nn.LayerNorm, spatial_dims=3)
    """

    def __init__(
        self,
        patch_size: Sequence[int] | int = 2,
        in_chans: int = 1,
        embed_dim: int = 48,
        norm_layer: type[LayerNorm] = nn.LayerNorm,
        spatial_dims: int = 3,
    ) -> None:
        """
        Args:
            patch_size: size of the patch in each spatial dimension.
            in_chans: number of input channels.
            embed_dim: number of linear projection output channels.
            norm_layer: normalization layer applied to the projected patches.
            spatial_dims: number of spatial dimensions (2 or 3).
        """

        super().__init__()

        if spatial_dims not in (2, 3):
            raise ValueError("spatial dimension should be 2 or 3.")

        patch_size = ensure_tuple_rep(patch_size, spatial_dims)  # e.g. 2 -> (2, 2, 2) for spatial_dims=3
        self.patch_size = patch_size
        self.embed_dim = embed_dim
        # Non-overlapping projection: kernel size and stride both equal the patch size,
        # so each output position corresponds to exactly one patch of the input.
        self.proj = Conv[Conv.CONV, spatial_dims](
            in_channels=in_chans, out_channels=embed_dim, kernel_size=patch_size, stride=patch_size
        )
        if norm_layer is not None:
            self.norm = norm_layer(embed_dim)
        else:
            self.norm = None

    def forward(self, x):
        x_shape = x.size()
        if len(x_shape) == 5:
            # 3D input (B, C, D, H, W): pad each spatial dim up to a multiple of the patch
            # size. F.pad takes padding for the last dimension first, hence the ordering.
            _, _, d, h, w = x_shape
            if w % self.patch_size[2] != 0:
                x = F.pad(x, (0, self.patch_size[2] - w % self.patch_size[2]))
            if h % self.patch_size[1] != 0:
                x = F.pad(x, (0, 0, 0, self.patch_size[1] - h % self.patch_size[1]))
            if d % self.patch_size[0] != 0:
                x = F.pad(x, (0, 0, 0, 0, 0, self.patch_size[0] - d % self.patch_size[0]))

        elif len(x_shape) == 4:
            # 2D input (B, C, H, W): same padding logic with one fewer spatial dimension.
            _, _, h, w = x_shape
            if w % self.patch_size[1] != 0:
                x = F.pad(x, (0, self.patch_size[1] - w % self.patch_size[1]))
            if h % self.patch_size[0] != 0:
                x = F.pad(x, (0, 0, 0, self.patch_size[0] - h % self.patch_size[0]))

        x = self.proj(x)
        if self.norm is not None:
            # Flatten the spatial dims into a token axis, normalize over the embedding
            # dimension, then restore the channels-first spatial layout.
            x_shape = x.size()
            x = x.flatten(2).transpose(1, 2)
            x = self.norm(x)
            if len(x_shape) == 5:
                d, wh, ww = x_shape[2], x_shape[3], x_shape[4]
                x = x.transpose(1, 2).view(-1, self.embed_dim, d, wh, ww)
            elif len(x_shape) == 4:
                wh, ww = x_shape[2], x_shape[3]
                x = x.transpose(1, 2).view(-1, self.embed_dim, wh, ww)
        return x
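

# ---------------------------------------------------------------------------
# Illustrative usage sketch (an addition, not part of the upstream module): it
# exercises the padding path above on a 3D volume whose depth is not divisible
# by the patch size. Shape comments assume the defaults shown in the class
# docstring (patch_size=2, in_chans=1, embed_dim=48, spatial_dims=3).
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    import torch

    embed = PatchEmbed(patch_size=2, in_chans=1, embed_dim=48, norm_layer=nn.LayerNorm, spatial_dims=3)
    volume = torch.randn(1, 1, 31, 32, 32)  # depth 31 is not a multiple of the patch size
    tokens = embed(volume)
    # depth is padded 31 -> 32, then every spatial dim is halved by the stride-2
    # projection, giving torch.Size([1, 48, 16, 16, 16])
    print(tokens.shape)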