from typing import Dict, List, Optional

import torch
import torch.nn as nn
from torch import Tensor

from fairseq import utils
from fairseq.modules import (
    LayerNorm,
    MultiheadAttention,
    TransformerDecoderLayer,
    TransformerEncoderLayer,
)
from fairseq.modules.fairseq_dropout import FairseqDropout
from fairseq.modules.quant_noise import quant_noise


class AdapterDecoderLayer(TransformerDecoderLayer):
    """Transformer decoder layer with an additional context ("cxt") adapter
    attention block applied after the decoder self-attention."""

    def __init__(self, args):
        # The adapter reuses the parent layer's modules (including
        # ``self.self_attn`` for the context attention), so no extra
        # parameters are introduced here.
        super().__init__(args)

    def forward(
        self,
        x,
        encoder_out: Optional[torch.Tensor] = None,
        encoder_padding_mask: Optional[torch.Tensor] = None,
        cxt_out: Optional[torch.Tensor] = None,
        cxt_padding_mask: Optional[torch.Tensor] = None,
        incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]] = None,
        prev_self_attn_state: Optional[List[torch.Tensor]] = None,
        prev_attn_state: Optional[List[torch.Tensor]] = None,
        self_attn_mask: Optional[torch.Tensor] = None,
        self_attn_padding_mask: Optional[torch.Tensor] = None,
        need_attn: bool = False,
        need_head_weights: bool = False,
    ):
- """
- Args:
- x (Tensor): input to the layer of shape `(seq_len, batch, embed_dim)`
- encoder_padding_mask (ByteTensor, optional): binary
- ByteTensor of shape `(batch, src_len)` where padding
- elements are indicated by ``1``.
- need_attn (bool, optional): return attention weights
- need_head_weights (bool, optional): return attention weights
- for each head (default: return average over heads).
-
- Returns:
- encoded output of shape `(seq_len, batch, embed_dim)`
- """
        if need_head_weights:
            need_attn = True

        residual = x
        if self.normalize_before:
            x = self.self_attn_layer_norm(x)
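        # If an explicit self-attention key/value state was passed in, write it
        # into the incremental-state buffer before attending.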
        if prev_self_attn_state is not None:
            prev_key, prev_value = prev_self_attn_state[:2]
            saved_state: Dict[str, Optional[Tensor]] = {
                "prev_key": prev_key,
                "prev_value": prev_value,
            }
            if len(prev_self_attn_state) >= 3:
                saved_state["prev_key_padding_mask"] = prev_self_attn_state[2]
            assert incremental_state is not None
            self.self_attn._set_input_buffer(incremental_state, saved_state)
        _self_attn_input_buffer = self.self_attn._get_input_buffer(incremental_state)
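        # With cross-self-attention, encoder states are prepended to the
        # self-attention keys/values (and the masks padded to match), except
        # when cached keys already exist during incremental decoding.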
        if self.cross_self_attention and not (
            incremental_state is not None
            and _self_attn_input_buffer is not None
            and "prev_key" in _self_attn_input_buffer
        ):
            if self_attn_mask is not None:
                assert encoder_out is not None
                self_attn_mask = torch.cat(
                    (x.new_zeros(x.size(0), encoder_out.size(0)), self_attn_mask), dim=1
                )
            if self_attn_padding_mask is not None:
                if encoder_padding_mask is None:
                    assert encoder_out is not None
                    encoder_padding_mask = self_attn_padding_mask.new_zeros(
                        encoder_out.size(1), encoder_out.size(0)
                    )
                self_attn_padding_mask = torch.cat(
                    (encoder_padding_mask, self_attn_padding_mask), dim=1
                )
            assert encoder_out is not None
            y = torch.cat((encoder_out, x), dim=0)
        else:
            y = x

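        # Decoder self-attention (keys/values may include the prepended
        # encoder states when cross-self-attention is enabled).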
        x, attn = self.self_attn(
            query=x,
            key=y,
            value=y,
            key_padding_mask=self_attn_padding_mask,
            incremental_state=incremental_state,
            need_weights=False,
            attn_mask=self_attn_mask,
        )
        x = self.dropout_module(x)

        # cxt adapter: attend from the self-attention output to the context
        # states, reusing the self-attention module's weights. The context
        # gets its own padding mask and no causal mask, and the incremental
        # buffer is left untouched so the cached self-attention keys are not
        # overwritten with context keys.
        x, attn = self.self_attn(
            query=x,
            key=cxt_out,
            value=cxt_out,
            key_padding_mask=cxt_padding_mask,
            incremental_state=None,
            need_weights=False,
            attn_mask=None,
        )
        x = self.dropout_module(x)

        x = residual + x
        if not self.normalize_before:
            x = self.self_attn_layer_norm(x)

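        # Encoder-decoder cross-attention over the encoder output (with
        # optional cached key/value state for incremental decoding).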
        if self.encoder_attn is not None:
            residual = x
            if self.normalize_before:
                x = self.encoder_attn_layer_norm(x)
            if prev_attn_state is not None:
                prev_key, prev_value = prev_attn_state[:2]
                saved_state: Dict[str, Optional[Tensor]] = {
                    "prev_key": prev_key,
                    "prev_value": prev_value,
                }
                if len(prev_attn_state) >= 3:
                    saved_state["prev_key_padding_mask"] = prev_attn_state[2]
                assert incremental_state is not None
                self.encoder_attn._set_input_buffer(incremental_state, saved_state)

            x, attn = self.encoder_attn(
                query=x,
                key=encoder_out,
                value=encoder_out,
                key_padding_mask=encoder_padding_mask,
                incremental_state=incremental_state,
                static_kv=True,
                need_weights=need_attn or (not self.training and self.need_attn),
                need_head_weights=need_head_weights,
            )
            x = self.dropout_module(x)
            x = residual + x
            if not self.normalize_before:
                x = self.encoder_attn_layer_norm(x)

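        # Position-wise feed-forward block.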
        residual = x
        if self.normalize_before:
            x = self.final_layer_norm(x)

        x = self.activation_fn(self.fc1(x))
        x = self.activation_dropout_module(x)
        x = self.fc2(x)
        x = self.dropout_module(x)
        x = residual + x
        if not self.normalize_before:
            x = self.final_layer_norm(x)
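        # For ONNX export during incremental decoding, also return the
        # self-attention key/value state.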
        if self.onnx_trace and incremental_state is not None:
            saved_state = self.self_attn._get_input_buffer(incremental_state)
            assert saved_state is not None
            if self_attn_padding_mask is not None:
                self_attn_state = [
                    saved_state["prev_key"],
                    saved_state["prev_value"],
                    saved_state["prev_key_padding_mask"],
                ]
            else:
                self_attn_state = [saved_state["prev_key"], saved_state["prev_value"]]
            return x, attn, self_attn_state
        return x, attn, None
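

# A minimal usage sketch (not part of the original module), assuming the layer
# is stacked inside a context-aware decoder: `apply_adapter_layers`, the tensor
# names, and the shapes below are illustrative assumptions, not the original
# API. Tensors are time-major, as elsewhere in fairseq: x is
# (tgt_len, batch, embed_dim), encoder_out is (src_len, batch, embed_dim),
# cxt_out is (cxt_len, batch, embed_dim), and padding masks are (batch, len).
def apply_adapter_layers(
    layers: List[AdapterDecoderLayer],
    x: Tensor,
    encoder_out: Tensor,
    encoder_padding_mask: Optional[Tensor],
    cxt_out: Tensor,
    cxt_padding_mask: Optional[Tensor],
    self_attn_mask: Optional[Tensor] = None,
) -> Tensor:
    # Each layer applies self-attention, the context adapter attention over
    # cxt_out, encoder-decoder attention, and the feed-forward block.
    for layer in layers:
        x, _, _ = layer(
            x,
            encoder_out=encoder_out,
            encoder_padding_mask=encoder_padding_mask,
            cxt_out=cxt_out,
            cxt_padding_mask=cxt_padding_mask,
            self_attn_mask=self_attn_mask,
        )
    return x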