from typing import Any, Dict, List, Optional, Tuple

import torch
import torch.nn as nn
from torch import Tensor

from fairseq import utils
from fairseq.models import register_model, register_model_architecture
from fairseq.models.bart.model import BARTModel
from fairseq.models.fairseq_encoder import EncoderOut
from fairseq.models.transformer import (
    Embedding,
    TransformerDecoder,
    TransformerEncoder,
    TransformerModel,
    base_architecture,
)
from fairseq.modules import (
    AdaptiveSoftmax,
    FairseqDropout,
    LayerDropModuleList,
    LayerNorm,
    PositionalEmbedding,
    SinusoidalPositionalEmbedding,
    TransformerDecoderLayer,
    TransformerEncoderLayer,
)

from .adapter_encoder_layer import AdapterEncoderLayer
from .adapter_decoder_layer import AdapterDecoderLayer

DEFAULT_MAX_SOURCE_POSITIONS = 1024
DEFAULT_MAX_TARGET_POSITIONS = 1024

# Note: the register_model decorator should immediately precede the
# definition of the Model class.

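# The adapter layer classes imported above live in sibling modules that are not shown
# in this file. From the call sites below, AdapterEncoderLayer is assumed to take
# (x, cxt_encoder_out, encoder_padding_mask) and return the updated features, and
# AdapterDecoderLayer is assumed to take the usual TransformerDecoderLayer inputs plus
# the context encoder output and its padding mask. These signatures are inferred from
# usage here, not from the sibling modules themselves.
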

@register_model('adapter_transformer')
class AdapterTransformer(BARTModel):
    """BART-style sequence-to-sequence model with an additional context encoder.

    A plain TransformerEncoder (``cxt_encoder``) encodes the context tokens
    (``cxt_tokens``), and its output is fed to both the adapter encoder layers
    and the adapter decoder layers.
    """

    def __init__(self, args, encoder, decoder, cxt_encoder):
        super().__init__(args, encoder, decoder)
        # keep the auxiliary context encoder alongside the standard encoder/decoder
        self.cxt_encoder = cxt_encoder

    def forward(
        self,
        src_tokens,
        src_lengths,
        cxt_tokens,
        prev_output_tokens,
        classification_head_name: Optional[str] = None,
        token_embeddings: Optional[torch.Tensor] = None,
        return_all_hiddens: bool = True,
        features_only: bool = False,
        alignment_layer: Optional[int] = None,
        alignment_heads: Optional[int] = None,
    ):
        if classification_head_name is not None:
            features_only = True

        # encode the context tokens with the plain context encoder
        cxt_encoder_out = self.cxt_encoder(
            cxt_tokens, return_all_hiddens=return_all_hiddens
        )
        # the adapter encoder consumes the source tokens plus the context encoding
        encoder_out = self.encoder(
            src_tokens,
            src_lengths=src_lengths,
            cxt_encoder_out=cxt_encoder_out,
            token_embeddings=token_embeddings,
            return_all_hiddens=return_all_hiddens,
        )
        x, extra = self.decoder(
            prev_output_tokens,
            encoder_out=encoder_out,
            cxt_encoder_out=cxt_encoder_out,
            features_only=features_only,
            alignment_layer=alignment_layer,
            alignment_heads=alignment_heads,
            src_lengths=src_lengths,
            return_all_hiddens=return_all_hiddens,
        )
        eos: int = self.eos
        if classification_head_name is not None:
            sentence_representation = x[
                src_tokens.eq(eos), :
            ].view(x.size(0), -1, x.size(-1))[:, -1, :]
            for k, head in self.classification_heads.items():
                # TorchScript only supports iterating over a ModuleDict, not indexing it
                if k == classification_head_name:
                    x = head(sentence_representation)
                    break

        return x, extra

    @classmethod
    def build_model(cls, args, task):
        """Build a new model instance."""

        # make sure all arguments are present in older models
        base_architecture(args)

        if args.encoder_layers_to_keep:
            args.encoder_layers = len(args.encoder_layers_to_keep.split(","))
        if args.decoder_layers_to_keep:
            args.decoder_layers = len(args.decoder_layers_to_keep.split(","))

        if getattr(args, "max_source_positions", None) is None:
            args.max_source_positions = DEFAULT_MAX_SOURCE_POSITIONS
        if getattr(args, "max_target_positions", None) is None:
            args.max_target_positions = DEFAULT_MAX_TARGET_POSITIONS

        src_dict, tgt_dict = task.source_dictionary, task.target_dictionary

        if args.share_all_embeddings:
            if src_dict != tgt_dict:
                raise ValueError(
                    "--share-all-embeddings requires a joined dictionary"
                )
            if args.encoder_embed_dim != args.decoder_embed_dim:
                raise ValueError(
                    "--share-all-embeddings requires --encoder-embed-dim to match --decoder-embed-dim"
                )
            if args.decoder_embed_path and (
                args.decoder_embed_path != args.encoder_embed_path
            ):
                raise ValueError(
                    "--share-all-embeddings not compatible with --decoder-embed-path"
                )
            encoder_embed_tokens = cls.build_embedding(
                args, src_dict, args.encoder_embed_dim, args.encoder_embed_path
            )
            # the context encoder keeps its own (untied) embedding table in both branches
            cxt_encoder_embed_tokens = cls.build_embedding(
                args, src_dict, args.encoder_embed_dim, args.encoder_embed_path
            )
            decoder_embed_tokens = encoder_embed_tokens
            args.share_decoder_input_output_embed = True
        else:
            encoder_embed_tokens = cls.build_embedding(
                args, src_dict, args.encoder_embed_dim, args.encoder_embed_path
            )
            cxt_encoder_embed_tokens = cls.build_embedding(
                args, src_dict, args.encoder_embed_dim, args.encoder_embed_path
            )
            decoder_embed_tokens = cls.build_embedding(
                args, tgt_dict, args.decoder_embed_dim, args.decoder_embed_path
            )

        encoder = cls.build_encoder(args, src_dict, encoder_embed_tokens)
        cxt_encoder = cls.build_cxt_encoder(args, src_dict, cxt_encoder_embed_tokens)
        decoder = cls.build_decoder(args, tgt_dict, decoder_embed_tokens)
        return cls(args, encoder, decoder, cxt_encoder)

    @classmethod
    def build_embedding(cls, args, dictionary, embed_dim, path=None):
        num_embeddings = len(dictionary)
        padding_idx = dictionary.pad()

        emb = Embedding(num_embeddings, embed_dim, padding_idx)
        # if provided, load from preloaded dictionaries
        if path:
            embed_dict = utils.parse_embedding(path)
            utils.load_embedding(embed_dict, dictionary, emb)
        return emb

    @classmethod
    def build_encoder(cls, args, src_dict, embed_tokens):
        return AdapterTransformerEncoder(args, src_dict, embed_tokens)

    @classmethod
    def build_cxt_encoder(cls, args, src_dict, embed_tokens):
        return TransformerEncoder(args, src_dict, embed_tokens)

    @classmethod
    def build_decoder(cls, args, tgt_dict, embed_tokens):
        return AdapterTransformerDecoder(
            args,
            tgt_dict,
            embed_tokens,
            no_encoder_attn=getattr(args, "no_cross_attention", False),
        )


class AdapterTransformerEncoder(TransformerEncoder):
    def __init__(self, args, dictionary, embed_tokens):
        super().__init__(args, dictionary, embed_tokens)
        # two-way segment embedding (not used elsewhere in this module)
        self.seg_emb = Embedding(2, args.encoder_embed_dim, None)

        # adapter layers are used in forward() in place of the inherited self.layers
        if self.encoder_layerdrop > 0.0:
            self.adapter_layers = LayerDropModuleList(p=self.encoder_layerdrop)
        else:
            self.adapter_layers = nn.ModuleList([])
        self.adapter_layers.extend(
            [self.build_adapter_encoder_layer(args) for _ in range(args.encoder_layers)]
        )

    def build_adapter_encoder_layer(self, args):
        return AdapterEncoderLayer(args)

    def forward_embedding(self, src_tokens, token_embeddings: Optional[Tensor] = None):
        # embed tokens and positions
        if token_embeddings is None:
            token_embeddings = self.embed_tokens(src_tokens)
        x = embed = self.embed_scale * token_embeddings
        if self.embed_positions is not None:
            x = embed + self.embed_positions(src_tokens)
        if self.layernorm_embedding is not None:
            x = self.layernorm_embedding(x)
        x = self.dropout_module(x)
        if self.quant_noise is not None:
            x = self.quant_noise(x)
        return x, embed

    def forward(
        self,
        src_tokens,
        src_lengths,
        cxt_encoder_out,
        token_embeddings: Optional[Tensor] = None,
        return_all_hiddens: bool = False,
    ):
        src, encoder_embedding = self.forward_embedding(src_tokens, token_embeddings)
        encoder_padding_mask = src_tokens.eq(self.padding_idx)
        encoder_states = [] if return_all_hiddens else None

        # B x T x C -> T x B x C
        src = src.transpose(0, 1)

        # each adapter layer additionally conditions on the context encoder output
        for layer in self.adapter_layers:
            src = layer(src, cxt_encoder_out, encoder_padding_mask)
            if return_all_hiddens:
                assert encoder_states is not None
                encoder_states.append(src)

        if self.layer_norm is not None:
            src = self.layer_norm(src)

        return EncoderOut(
            encoder_out=src,  # T x B x C
            encoder_padding_mask=encoder_padding_mask,  # B x T
            encoder_embedding=encoder_embedding,  # B x T x C
            encoder_states=encoder_states,  # List[T x B x C]
            src_tokens=None,
            src_lengths=None,
        )


class AdapterTransformerDecoder(TransformerDecoder):

    def build_decoder_layer(self, args, no_encoder_attn=False):
        return AdapterDecoderLayer(args, no_encoder_attn)

    def forward(
        self,
        prev_output_tokens,
        encoder_out: Optional[EncoderOut] = None,
        cxt_encoder_out: Optional[EncoderOut] = None,
        incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]] = None,
        features_only: bool = False,
        alignment_layer: Optional[int] = None,
        alignment_heads: Optional[int] = None,
        src_lengths: Optional[Any] = None,
        return_all_hiddens: bool = False,
    ):
        """
        Args:
            prev_output_tokens (LongTensor): previous decoder outputs of shape
                `(batch, tgt_len)`, for teacher forcing
            encoder_out (optional): output from the source encoder, used for
                encoder-side attention
            cxt_encoder_out (optional): output from the context encoder, passed
                to every adapter decoder layer
            incremental_state (dict): dictionary used for storing state during
                :ref:`Incremental decoding`
            features_only (bool, optional): only return features without
                applying the output layer (default: False)

        Returns:
            tuple:
                - the decoder's output of shape `(batch, tgt_len, vocab)`
                - a dictionary with any model-specific outputs
        """
        x, extra = self.extract_features(
            prev_output_tokens,
            encoder_out=encoder_out,
            cxt_encoder_out=cxt_encoder_out,
            incremental_state=incremental_state,
            alignment_layer=alignment_layer,
            alignment_heads=alignment_heads,
        )
        if not features_only:
            x = self.output_layer(x)
        return x, extra

    def extract_features(
        self,
        prev_output_tokens,
        encoder_out: Optional[EncoderOut] = None,
        cxt_encoder_out: Optional[EncoderOut] = None,
        incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]] = None,
        full_context_alignment: bool = False,
        alignment_layer: Optional[int] = None,
        alignment_heads: Optional[int] = None,
    ):
        return self.extract_features_scriptable(
            prev_output_tokens,
            encoder_out,
            cxt_encoder_out,
            incremental_state,
            full_context_alignment,
            alignment_layer,
            alignment_heads,
        )

    def extract_features_scriptable(
        self,
        prev_output_tokens,
        encoder_out: Optional[EncoderOut] = None,
        cxt_encoder_out: Optional[EncoderOut] = None,
        incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]] = None,
        full_context_alignment: bool = False,
        alignment_layer: Optional[int] = None,
        alignment_heads: Optional[int] = None,
    ):
        if alignment_layer is None:
            alignment_layer = self.num_layers - 1

        # embed positions
        positions = (
            self.embed_positions(
                prev_output_tokens, incremental_state=incremental_state
            )
            if self.embed_positions is not None
            else None
        )

        if incremental_state is not None:
            prev_output_tokens = prev_output_tokens[:, -1:]
            if positions is not None:
                positions = positions[:, -1:]

        # embed tokens and positions
        x = self.embed_scale * self.embed_tokens(prev_output_tokens)

        if self.quant_noise is not None:
            x = self.quant_noise(x)

        if self.project_in_dim is not None:
            x = self.project_in_dim(x)

        if positions is not None:
            x += positions

        if self.layernorm_embedding is not None:
            x = self.layernorm_embedding(x)

        x = self.dropout_module(x)

        # B x T x C -> T x B x C
        x = x.transpose(0, 1)

        self_attn_padding_mask: Optional[Tensor] = None
        if self.cross_self_attention or prev_output_tokens.eq(self.padding_idx).any():
            self_attn_padding_mask = prev_output_tokens.eq(self.padding_idx)

        # decoder layers
        attn: Optional[Tensor] = None
        inner_states: List[Optional[Tensor]] = [x]
        for idx, layer in enumerate(self.layers):
            if incremental_state is None and not full_context_alignment:
                self_attn_mask = self.buffered_future_mask(x)
            else:
                self_attn_mask = None

            x, layer_attn, _ = layer(
                x,
                encoder_out.encoder_out if encoder_out is not None else None,
                encoder_out.encoder_padding_mask if encoder_out is not None else None,
                cxt_encoder_out.encoder_out if cxt_encoder_out is not None else None,
                cxt_encoder_out.encoder_padding_mask if cxt_encoder_out is not None else None,
                incremental_state,
                self_attn_mask=self_attn_mask,
                self_attn_padding_mask=self_attn_padding_mask,
                need_attn=bool((idx == alignment_layer)),
                need_head_weights=bool((idx == alignment_layer)),
            )
            inner_states.append(x)
            if layer_attn is not None and idx == alignment_layer:
                attn = layer_attn.float().to(x)

        if attn is not None:
            if alignment_heads is not None:
                attn = attn[:alignment_heads]

            # average probabilities over heads
            attn = attn.mean(dim=0)

        if self.layer_norm is not None:
            x = self.layer_norm(x)

        # T x B x C -> B x T x C
        x = x.transpose(0, 1)

        if self.project_out_dim is not None:
            x = self.project_out_dim(x)

        return x, {"attn": [attn], "inner_states": inner_states}


@register_model_architecture('adapter_transformer', 'adapter_bart_large')
def bart_large_architecture(args):
    args.encoder_embed_path = getattr(args, 'encoder_embed_path', None)
    args.encoder_embed_dim = getattr(args, 'encoder_embed_dim', 1024)
    args.encoder_ffn_embed_dim = getattr(args, 'encoder_ffn_embed_dim', 4 * 1024)
    args.encoder_layers = getattr(args, 'encoder_layers', 12)
    args.encoder_attention_heads = getattr(args, 'encoder_attention_heads', 16)
    args.encoder_normalize_before = getattr(args, 'encoder_normalize_before', False)
    args.encoder_learned_pos = getattr(args, 'encoder_learned_pos', True)
    args.decoder_embed_path = getattr(args, 'decoder_embed_path', None)
    args.decoder_embed_dim = getattr(args, 'decoder_embed_dim', args.encoder_embed_dim)
    args.decoder_ffn_embed_dim = getattr(args, 'decoder_ffn_embed_dim', args.encoder_ffn_embed_dim)
    args.decoder_layers = getattr(args, 'decoder_layers', 12)
    args.decoder_attention_heads = getattr(args, 'decoder_attention_heads', 16)
    args.decoder_normalize_before = getattr(args, 'decoder_normalize_before', False)
    args.decoder_learned_pos = getattr(args, 'decoder_learned_pos', True)
    args.attention_dropout = getattr(args, 'attention_dropout', 0.0)
    args.relu_dropout = getattr(args, 'relu_dropout', 0.0)
    args.dropout = getattr(args, 'dropout', 0.1)
    args.max_target_positions = getattr(args, 'max_target_positions', 1024)
    args.max_source_positions = getattr(args, 'max_source_positions', 1024)
    args.adaptive_softmax_cutoff = getattr(args, 'adaptive_softmax_cutoff', None)
    args.adaptive_softmax_dropout = getattr(args, 'adaptive_softmax_dropout', 0)
    args.share_decoder_input_output_embed = getattr(args, 'share_decoder_input_output_embed', True)
    args.share_all_embeddings = getattr(args, 'share_all_embeddings', True)

    args.decoder_output_dim = getattr(args, 'decoder_output_dim', args.decoder_embed_dim)
    args.decoder_input_dim = getattr(args, 'decoder_input_dim', args.decoder_embed_dim)

    args.no_scale_embedding = getattr(args, 'no_scale_embedding', True)
    args.layernorm_embedding = getattr(args, 'layernorm_embedding', True)

    args.activation_fn = getattr(args, 'activation_fn', 'gelu')
    args.pooler_activation_fn = getattr(args, 'pooler_activation_fn', 'tanh')
    args.pooler_dropout = getattr(args, 'pooler_dropout', 0.0)


@register_model_architecture("adapter_transformer", "adapter_mbart_large")
def mbart_large_architecture(args):
    args.no_scale_embedding = getattr(args, "no_scale_embedding", False)
    bart_large_architecture(args)


@register_model_architecture("adapter_transformer", "adapter_bart_base")
def bart_base_architecture(args):
    args.encoder_embed_dim = getattr(args, "encoder_embed_dim", 768)
    args.encoder_ffn_embed_dim = getattr(args, "encoder_ffn_embed_dim", 4 * 768)
    args.encoder_layers = getattr(args, "encoder_layers", 6)
    args.encoder_attention_heads = getattr(args, "encoder_attention_heads", 12)
    args.decoder_layers = getattr(args, "decoder_layers", 6)
    args.decoder_attention_heads = getattr(args, "decoder_attention_heads", 12)
    bart_large_architecture(args)


@register_model_architecture("adapter_transformer", "adapter_mbart_base")
def mbart_base_architecture(args):
    args.no_scale_embedding = getattr(args, "no_scale_embedding", False)
    bart_base_architecture(args)
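
# Usage sketch (illustrative, not part of this module): with this file placed in a
# --user-dir package together with the adapter_encoder_layer / adapter_decoder_layer
# modules, the registered architectures can be selected by name. The dataset path,
# package directory, and task name below are placeholders; the model's forward() also
# expects cxt_tokens, so the task is assumed to supply context tokens in each batch.
#
#   fairseq-train data-bin/my-dataset \
#       --user-dir ./adapter_transformer_package \
#       --task my_context_translation_task \
#       --arch adapter_bart_large \
#       --criterion label_smoothed_cross_entropy
#
# fairseq-train, --user-dir, --task, --arch, and --criterion are standard fairseq
# options; the dataset, package, and task names above are hypothetical.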