from typing import Dict, List, Optional

import torch
import torch.nn as nn
from torch import Tensor

from fairseq import utils
from fairseq.modules import (
    LayerNorm,
    MultiheadAttention,
    TransformerDecoderLayer,
    TransformerEncoderLayer,
)
from fairseq.modules.fairseq_dropout import FairseqDropout
from fairseq.modules.quant_noise import quant_noise


class AdapterEncoderLayer(TransformerEncoderLayer):
    """Transformer encoder layer with an additional context ("cxt") adapter
    attention inserted between the self-attention and the feed-forward block."""

    def __init__(self, args):
        super().__init__(args)

    def forward(self, x, cxt, encoder_padding_mask, attn_mask: Optional[Tensor] = None):
        """
        Args:
            x (Tensor): input to the layer of shape `(seq_len, batch, embed_dim)`
            cxt (Tensor): context states of shape `(cxt_len, batch, embed_dim)`
                attended to by the adapter attention. Because the source
                `encoder_padding_mask` and `attn_mask` are reused for the
                adapter, `cxt_len` must equal `seq_len`.
            encoder_padding_mask (ByteTensor): binary ByteTensor of shape
                `(batch, seq_len)` where padding elements are indicated by ``1``.
            attn_mask (ByteTensor): binary tensor of shape `(tgt_len, src_len)`,
                where `tgt_len` is the length of output and `src_len` is the
                length of input, though here both are equal to `seq_len`.
                `attn_mask[tgt_i, src_j] = 1` means that when calculating the
                embedding for `tgt_i`, we exclude (mask out) `src_j`. This is
                useful for strided self-attention.

        Returns:
            encoded output of shape `(seq_len, batch, embed_dim)`
        """
        if attn_mask is not None:
            attn_mask = attn_mask.masked_fill(attn_mask.to(torch.bool), -1e8)

        residual = x
        if self.normalize_before:
            x = self.self_attn_layer_norm(x)
        x, _ = self.self_attn(
            query=x,
            key=x,
            value=x,
            key_padding_mask=encoder_padding_mask,
            attn_mask=attn_mask,
        )
        x = self.dropout_module(x)

        # Context ("cxt") adapter: the self-attention module is reused to
        # attend over the context states. The source padding mask and
        # attn_mask are applied here as well, so cxt must match x in length.
        x, _ = self.self_attn(
            query=x,
            key=cxt,
            value=cxt,
            key_padding_mask=encoder_padding_mask,
            attn_mask=attn_mask,
        )
        x = self.dropout_module(x)

        x = residual + x
        if not self.normalize_before:
            x = self.self_attn_layer_norm(x)

        residual = x
        if self.normalize_before:
            x = self.final_layer_norm(x)

        x = self.activation_fn(self.fc1(x))
        x = self.activation_dropout_module(x)
        x = self.fc2(x)
        x = self.dropout_module(x)
        x = residual + x
        if not self.normalize_before:
            x = self.final_layer_norm(x)
        return x
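

# ---------------------------------------------------------------------------
# Minimal usage sketch: builds the layer from an argparse.Namespace carrying
# the standard fairseq transformer hyper-parameters and runs one forward pass.
# The hyper-parameter names and values below are illustrative assumptions;
# adjust them to your fairseq version and model configuration.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    from argparse import Namespace

    args = Namespace(
        encoder_embed_dim=512,
        encoder_ffn_embed_dim=2048,
        encoder_attention_heads=8,
        attention_dropout=0.1,
        dropout=0.1,
        activation_dropout=0.0,
        activation_fn="relu",
        encoder_normalize_before=False,
    )
    layer = AdapterEncoderLayer(args)
    layer.eval()

    seq_len, batch = 10, 2
    x = torch.randn(seq_len, batch, args.encoder_embed_dim)    # source states (T x B x C)
    cxt = torch.randn(seq_len, batch, args.encoder_embed_dim)  # context states, same length as x
    padding_mask = torch.zeros(batch, seq_len, dtype=torch.bool)

    with torch.no_grad():
        out = layer(x, cxt, padding_mask)
    print(out.shape)  # expected: torch.Size([10, 2, 512])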