|
- # coding=utf-8
- # Copyright 2018 Google AI, Google Brain and Carnegie Mellon University Authors and the HuggingFace Inc. team.
- # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- """ PyTorch Transformer XL model.
- Adapted from https://github.com/kimiyoung/transformer-xl.
- In particular https://github.com/kimiyoung/transformer-xl/blob/master/pytorch/mem_transformer.py
- """
-
- from __future__ import absolute_import, division, print_function, unicode_literals
-
- import os
- import json
- import math
- import logging
- import collections
- import sys
- from io import open
-
- import torch
- import torch.nn as nn
- import torch.nn.functional as F
- from torch.nn import CrossEntropyLoss
- from torch.nn.parameter import Parameter
-
- from .modeling_utils import PreTrainedModel, Conv1D, prune_conv1d_layer, SequenceSummary
- from .configuration_transfo_xl import TransfoXLConfig
- from .modeling_transfo_xl_utilities import ProjectedAdaptiveLogSoftmax, sample_logits
- from .file_utils import add_start_docstrings
-
- logger = logging.getLogger(__name__)
-
# Maps each pretrained-model shortcut name to the URL of its PyTorch weights file.
TRANSFO_XL_PRETRAINED_MODEL_ARCHIVE_MAP = {
    'transfo-xl-wt103': "https://s3.amazonaws.com/models.huggingface.co/bert/transfo-xl-wt103-pytorch_model.bin",
}
-
def build_tf_to_pytorch_map(model, config):
    """ Build a map from TF checkpoint variable names to PyTorch parameters.

        An explicit map is used to keep the PyTorch model as identical to the
        original PyTorch model as possible.

        Args:
            model: either a bare model exposing ``word_emb`` and ``layers`` (plus
                ``r_r_bias``/``r_w_bias`` when the biases are tied), or a head model
                exposing that bare model as ``model.transformer`` and the adaptive
                softmax as ``model.crit``.
            config: configuration object. ``untie_r`` is always read; ``tie_weight``
                and ``tie_projs`` are only read when ``model`` has a ``transformer``
                attribute.

        Returns:
            dict mapping each TF variable name to a PyTorch parameter. The two keys
            ``transformer/r_r_bias`` and ``transformer/r_w_bias`` map to a *list* of
            parameters instead: one per layer when ``config.untie_r`` is true,
            otherwise a single-element list.

        Raises:
            NotImplementedError: when the adaptive softmax must be mapped and
                ``config.tie_weight`` is false.
    """
    tf_to_pt_map = {}

    if hasattr(model, 'transformer'):
        # We are loading in a TransfoXLLMHeadModel => we will load also the Adaptive Softmax
        crit = model.crit
        tf_to_pt_map["transformer/adaptive_softmax/cutoff_0/cluster_W"] = crit.cluster_weight
        tf_to_pt_map["transformer/adaptive_softmax/cutoff_0/cluster_b"] = crit.cluster_bias
        clusters = zip(crit.out_layers, crit.out_projs, config.tie_projs)
        for i, (out_l, proj_l, tie_proj) in enumerate(clusters):
            layer_str = "transformer/adaptive_softmax/cutoff_%d/" % i
            if not config.tie_weight:
                # Untied output embeddings would also need `lookup_table`; the original
                # author did not think the TF code implements that case.
                raise NotImplementedError(
                    "Mapping the adaptive softmax is only supported when config.tie_weight is set")
            tf_to_pt_map[layer_str + 'b'] = out_l.bias
            if not tie_proj:
                tf_to_pt_map[layer_str + 'proj'] = proj_l
        # Now load the rest of the transformer
        model = model.transformer

    # Embeddings
    embeddings = zip(model.word_emb.emb_layers, model.word_emb.emb_projs)
    for i, (embed_l, proj_l) in enumerate(embeddings):
        layer_str = "transformer/adaptive_embed/cutoff_%d/" % i
        tf_to_pt_map[layer_str + 'lookup_table'] = embed_l.weight
        tf_to_pt_map[layer_str + 'proj_W'] = proj_l

    # Transformer blocks
    for i, b in enumerate(model.layers):
        layer_str = "transformer/layer_%d/" % i
        tf_to_pt_map.update({
            layer_str + "rel_attn/LayerNorm/gamma": b.dec_attn.layer_norm.weight,
            layer_str + "rel_attn/LayerNorm/beta": b.dec_attn.layer_norm.bias,
            layer_str + "rel_attn/o/kernel": b.dec_attn.o_net.weight,
            layer_str + "rel_attn/qkv/kernel": b.dec_attn.qkv_net.weight,
            layer_str + "rel_attn/r/kernel": b.dec_attn.r_net.weight,
            layer_str + "ff/LayerNorm/gamma": b.pos_ff.layer_norm.weight,
            layer_str + "ff/LayerNorm/beta": b.pos_ff.layer_norm.bias,
            layer_str + "ff/layer_1/kernel": b.pos_ff.CoreNet[0].weight,
            layer_str + "ff/layer_1/bias": b.pos_ff.CoreNet[0].bias,
            layer_str + "ff/layer_2/kernel": b.pos_ff.CoreNet[3].weight,
            layer_str + "ff/layer_2/bias": b.pos_ff.CoreNet[3].bias,
        })

    # Relative positioning biases: always stored as lists so the loader can
    # handle the per-layer (untied) and the shared (tied) cases.
    if config.untie_r:
        r_r_list = [b.dec_attn.r_r_bias for b in model.layers]
        r_w_list = [b.dec_attn.r_w_bias for b in model.layers]
    else:
        r_r_list = [model.r_r_bias]
        r_w_list = [model.r_w_bias]
    tf_to_pt_map['transformer/r_r_bias'] = r_r_list
    tf_to_pt_map['transformer/r_w_bias'] = r_w_list
    return tf_to_pt_map
-
def load_tf_weights_in_transfo_xl(model, config, tf_path):
    """ Load tf checkpoints in a pytorch model

        Args:
            model: PyTorch model whose parameters are overwritten in place.
            config: configuration forwarded to ``build_tf_to_pytorch_map``.
            tf_path: path of the TensorFlow checkpoint to read.

        Returns:
            The same ``model`` object, with the mapped parameters replaced by the
            checkpoint values.

        Raises:
            ImportError: if NumPy or TensorFlow cannot be imported.
            AssertionError: if a mapped variable is missing from the checkpoint or
                its shape does not match the PyTorch parameter.
    """
    try:
        import numpy as np
        import tensorflow as tf
    except ImportError:
        logger.error("Loading a TensorFlow models in PyTorch, requires TensorFlow to be installed. Please see "
            "https://www.tensorflow.org/install/ for installation instructions.")
        raise
    # Build TF to PyTorch weights loading map
    tf_to_pt_map = build_tf_to_pytorch_map(model, config)

    # Load weights from TF model
    tf_weights = {}
    for name, shape in tf.train.list_variables(tf_path):
        logger.info("Loading TF weight {} with shape {}".format(name, shape))
        tf_weights[name] = tf.train.load_variable(tf_path, name)

    for name, pointer in tf_to_pt_map.items():
        # Explicit raises (instead of `assert`) so the checks survive `python -O`.
        if name not in tf_weights:
            raise AssertionError("TF variable {} not found in checkpoint".format(name))
        array = tf_weights[name]
        if 'kernel' in name or 'proj' in name:
            array = np.transpose(array)

        # The relative biases are always mapped to a *list* of parameters.
        is_rel_bias = 'r_r_bias' in name or 'r_w_bias' in name
        # A single-element list whose parameter already has the array's shape is
        # the tied case: the TF variable is copied as a whole, not split per layer.
        is_tied_bias = is_rel_bias and len(pointer) == 1 and pointer[0].shape == array.shape

        if is_rel_bias and not is_tied_bias:
            # Here we will split the TF weights along their leading (layer) axis
            if len(pointer) != array.shape[0]:
                raise AssertionError(len(pointer), array.shape[0])
            for i, p_i in enumerate(pointer):
                arr_i = array[i, ...]
                if p_i.shape != arr_i.shape:
                    raise AssertionError(p_i.shape, arr_i.shape)
                logger.info("Initialize PyTorch weight {} for layer {}".format(name, i))
                p_i.data = torch.from_numpy(arr_i)
        else:
            if is_tied_bias:
                # Unwrap the list; previously this path crashed on `list.shape`.
                pointer = pointer[0]
            if pointer.shape != array.shape:
                raise AssertionError(pointer.shape, array.shape)
            logger.info("Initialize PyTorch weight {}".format(name))
            pointer.data = torch.from_numpy(array)

        # Drop the variable and its Adam optimizer slots, which are not required
        # for using the pretrained model, so that only genuinely unused variables
        # are reported below.
        tf_weights.pop(name, None)
        tf_weights.pop(name + '/Adam', None)
        tf_weights.pop(name + '/Adam_1', None)

    logger.info("Weights not copied to PyTorch model: {}".format(', '.join(tf_weights.keys())))
    return model
-
-
class PositionalEmbedding(nn.Module):
    """ Sinusoidal embedding of a sequence of positions.

        ``demb`` is the size of the produced embedding: the first half holds the
        sines and the second half the cosines of ``position * inv_freq``.
    """

    def __init__(self, demb):
        super(PositionalEmbedding, self).__init__()

        self.demb = demb

        # One inverse frequency per (sin, cos) pair; kept as a buffer so it
        # follows the module across devices without being a trainable parameter.
        exponents = torch.arange(0.0, demb, 2.0) / demb
        self.register_buffer('inv_freq', 1 / (10000 ** exponents))

    def forward(self, pos_seq, bsz=None):
        """ Embed the 1-D tensor ``pos_seq``.

            Returns a ``[len(pos_seq), 1, demb]`` tensor, or that tensor expanded
            to ``[len(pos_seq), bsz, demb]`` when ``bsz`` is given.
        """
        angles = torch.ger(pos_seq, self.inv_freq)
        pos_emb = torch.cat([angles.sin(), angles.cos()], dim=-1).unsqueeze(1)

        if bsz is None:
            return pos_emb
        return pos_emb.expand(-1, bsz, -1)
-
-
-
class PositionwiseFF(nn.Module):
    """ Position-wise feed-forward block with a residual connection.

        With ``pre_lnorm`` the layer normalization is applied to the input of the
        feed-forward network; otherwise it is applied to the residual sum.
    """

    def __init__(self, d_model, d_inner, dropout, pre_lnorm=False):
        super(PositionwiseFF, self).__init__()

        self.d_model = d_model
        self.d_inner = d_inner
        self.dropout = dropout

        # The positions inside CoreNet matter: the TF weight map addresses the
        # two linear layers as CoreNet[0] and CoreNet[3].
        core_layers = [
            nn.Linear(d_model, d_inner),
            nn.ReLU(inplace=True),
            nn.Dropout(dropout),
            nn.Linear(d_inner, d_model),
            nn.Dropout(dropout),
        ]
        self.CoreNet = nn.Sequential(*core_layers)

        self.layer_norm = nn.LayerNorm(d_model)

        self.pre_lnorm = pre_lnorm

    def forward(self, inp):
        """ Apply the feed-forward block to ``inp`` and return a tensor of the same shape. """
        if self.pre_lnorm:
            # layer normalization + positionwise feed-forward, then residual connection
            return self.CoreNet(self.layer_norm(inp)) + inp

        # positionwise feed-forward, then residual connection + layer normalization
        return self.layer_norm(inp + self.CoreNet(inp))
-
-
-
class MultiHeadAttn(nn.Module):
    """ Standard multi-head attention over the input and optional memory.

        Queries come from the current segment ``h``; keys and values come from
        the concatenation of ``mems`` (when given) and ``h``.
    """

    def __init__(self, n_head, d_model, d_head, dropout, dropatt=0,
                 pre_lnorm=False, r_r_bias=None, r_w_bias=None, output_attentions=False):
        super(MultiHeadAttn, self).__init__()

        self.output_attentions = output_attentions
        self.n_head = n_head
        self.d_model = d_model
        self.d_head = d_head
        self.dropout = dropout

        proj_dim = n_head * d_head
        self.q_net = nn.Linear(d_model, proj_dim, bias=False)
        self.kv_net = nn.Linear(d_model, 2 * proj_dim, bias=False)

        self.drop = nn.Dropout(dropout)
        self.dropatt = nn.Dropout(dropatt)
        self.o_net = nn.Linear(proj_dim, d_model, bias=False)

        self.layer_norm = nn.LayerNorm(d_model)

        self.scale = 1 / (d_head ** 0.5)

        self.pre_lnorm = pre_lnorm

        if r_r_bias is not None and r_w_bias is not None:
            # Biases shared with (owned by) the caller
            self.r_r_bias = r_r_bias
            self.r_w_bias = r_w_bias
        else:
            # Biases are not shared: allocate (uninitialized) per-module parameters
            self.r_r_bias = nn.Parameter(torch.FloatTensor(self.n_head, self.d_head))
            self.r_w_bias = nn.Parameter(torch.FloatTensor(self.n_head, self.d_head))

    def forward(self, h, attn_mask=None, mems=None, head_mask=None):
        """ Attend from ``h`` ([qlen x bsz x d_model]) over ``[mems; h]``.

            Returns a list: the output hidden states, followed by the attention
            probabilities ([qlen x klen x bsz x n_head]) when
            ``output_attentions`` is set.
        """
        context = h if mems is None else torch.cat([mems, h], 0)

        if self.pre_lnorm:
            # layer normalization (of the key/value context only)
            context = self.layer_norm(context)

        head_q = self.q_net(h)
        head_k, head_v = torch.chunk(self.kv_net(context), 2, -1)

        # [len x bsz x n_head x d_head]
        head_q = head_q.view(h.size(0), h.size(1), self.n_head, self.d_head)
        kv_shape = (context.size(0), context.size(1), self.n_head, self.d_head)
        head_k = head_k.view(*kv_shape)
        head_v = head_v.view(*kv_shape)

        # [qlen x klen x bsz x n_head]
        attn_score = torch.einsum('ibnd,jbnd->ijbn', (head_q, head_k))
        attn_score.mul_(self.scale)
        if attn_mask is not None and torch.sum(attn_mask).item():
            attn_mask = (attn_mask == 1)  # Switch to bool
            mask_rank = attn_mask.dim()
            if mask_rank == 2:
                attn_score.masked_fill_(attn_mask[None, :, :, None], -float('inf'))
            elif mask_rank == 3:
                attn_score.masked_fill_(attn_mask[:, :, :, None], -float('inf'))

        # [qlen x klen x bsz x n_head], normalized over the key axis
        attn_prob = self.dropatt(F.softmax(attn_score, dim=1))

        # Mask heads if we want to
        if head_mask is not None:
            attn_prob = attn_prob * head_mask

        # [qlen x klen x bsz x n_head] + [klen x bsz x n_head x d_head] -> [qlen x bsz x n_head x d_head]
        attn_vec = torch.einsum('ijbn,jbnd->ibnd', (attn_prob, head_v))
        attn_vec = attn_vec.contiguous().view(
            attn_vec.size(0), attn_vec.size(1), self.n_head * self.d_head)

        # linear projection
        attn_out = self.drop(self.o_net(attn_vec))

        residual = h + attn_out
        if self.pre_lnorm:
            # residual connection only
            outputs = [residual]
        else:
            # residual connection + layer normalization
            outputs = [self.layer_norm(residual)]

        if self.output_attentions:
            outputs.append(attn_prob)

        return outputs
-
class RelMultiHeadAttn(nn.Module):
    """ Base class of the relative-position multi-head attention variants.

        Owns the shared projections, dropouts, layer norm and relative biases and
        provides the shifting helpers; subclasses implement ``forward``.
        ``tgt_len``, ``ext_len`` and ``mem_len`` are accepted but not used here.
    """

    def __init__(self, n_head, d_model, d_head, dropout, dropatt=0,
                 tgt_len=None, ext_len=None, mem_len=None, pre_lnorm=False,
                 r_r_bias=None, r_w_bias=None, output_attentions=False):
        super(RelMultiHeadAttn, self).__init__()

        self.output_attentions = output_attentions
        self.n_head = n_head
        self.d_model = d_model
        self.d_head = d_head
        self.dropout = dropout

        self.qkv_net = nn.Linear(d_model, 3 * n_head * d_head, bias=False)

        self.drop = nn.Dropout(dropout)
        self.dropatt = nn.Dropout(dropatt)
        self.o_net = nn.Linear(n_head * d_head, d_model, bias=False)

        self.layer_norm = nn.LayerNorm(d_model)

        self.scale = 1 / (d_head ** 0.5)

        self.pre_lnorm = pre_lnorm

        if r_r_bias is None or r_w_bias is None:  # Biases are not shared
            # Allocated uninitialized; expected to be filled by weight init/loading.
            self.r_r_bias = nn.Parameter(torch.FloatTensor(self.n_head, self.d_head))
            self.r_w_bias = nn.Parameter(torch.FloatTensor(self.n_head, self.d_head))
        else:
            self.r_r_bias = r_r_bias
            self.r_w_bias = r_w_bias

    def _parallelogram_mask(self, h, w, left=False):
        """ Return an ``h x w`` byte mask: ones with the strict lower triangle of the
            top-left square and the strict upper triangle of the bottom-right square
            cleared; flipped along dim 0 unless ``left`` is set.
        """
        mask = torch.ones((h, w)).byte()
        m = min(h, w)
        mask[:m, :m] = torch.triu(mask[:m, :m])
        mask[-m:, -m:] = torch.tril(mask[-m:, -m:])

        if left:
            return mask
        return mask.flip(0)

    def _shift(self, x, qlen, klen, mask, left=False):
        """ Zero-pad ``x`` along dim 1, then keep the entries selected by ``mask``
            and reshape them to ``[qlen x klen x x.size(2) x x.size(3)]``.
        """
        if qlen > 1:
            zero_pad = torch.zeros((x.size(0), qlen - 1, x.size(2), x.size(3)),
                                   device=x.device, dtype=x.dtype)
        else:
            zero_pad = torch.zeros(0, device=x.device, dtype=x.dtype)

        if left:
            mask = mask.flip(1)
            x_padded = torch.cat([zero_pad, x], dim=1).expand(qlen, -1, -1, -1)
        else:
            x_padded = torch.cat([x, zero_pad], dim=1).expand(qlen, -1, -1, -1)

        # Switch to bool, as the attention masks in this file do: uint8 masks
        # (what `_parallelogram_mask` returns) are deprecated for masked_select.
        mask = (mask == 1)
        x = x_padded.masked_select(mask[:, :, None, None]) \
                    .view(qlen, klen, x.size(2), x.size(3))

        return x

    def _rel_shift(self, x, zero_triu=False):
        """ Shift the relative-position scores ``x`` by padding one zero column,
            reinterpreting the buffer with the first two sizes swapped and dropping
            the first row. The result has the same shape as ``x``.

            With ``zero_triu`` the entries above diagonal ``x.size(1) - x.size(0)``
            are zeroed as well.
        """
        zero_pad_shape = (x.size(0), 1) + x.size()[2:]
        zero_pad = torch.zeros(zero_pad_shape, device=x.device, dtype=x.dtype)
        x_padded = torch.cat([zero_pad, x], dim=1)

        x_padded_shape = (x.size(1) + 1, x.size(0)) + x.size()[2:]
        x_padded = x_padded.view(*x_padded_shape)

        x = x_padded[1:].view_as(x)

        if zero_triu:
            # Build the triangular mask on x's device and dtype: a default
            # (CPU, float32) mask fails on GPU inputs and changes the dtype.
            ones = torch.ones((x.size(0), x.size(1)), device=x.device, dtype=x.dtype)
            x = x * torch.tril(ones, x.size(1) - x.size(0))[:, :, None, None]

        return x

    def forward(self, w, r, attn_mask=None, mems=None):
        """ Abstract: implemented by the subclasses. """
        raise NotImplementedError
-
class RelPartialLearnableMultiHeadAttn(RelMultiHeadAttn):
    """ Relative attention whose position keys are a learned projection
        (``r_net``) of the position embeddings passed in as ``r``.
    """

    def __init__(self, *args, **kwargs):
        super(RelPartialLearnableMultiHeadAttn, self).__init__(*args, **kwargs)

        self.r_net = nn.Linear(self.d_model, self.n_head * self.d_head, bias=False)

    def forward(self, w, r, attn_mask=None, mems=None, head_mask=None):
        """ Attend from ``w`` ([qlen x bsz x d_model]) over ``[mems; w]`` using the
            position embeddings ``r``.

            Returns a list: the output hidden states, followed by the attention
            probabilities ([qlen x klen x bsz x n_head]) when
            ``output_attentions`` is set.
        """
        qlen, rlen, bsz = w.size(0), r.size(0), w.size(1)

        # Keys and values span the memory plus the current segment.
        hidden = w if mems is None else torch.cat([mems, w], 0)
        if self.pre_lnorm:
            hidden = self.layer_norm(hidden)
        w_heads = self.qkv_net(hidden)
        r_head_k = self.r_net(r)

        w_head_q, w_head_k, w_head_v = torch.chunk(w_heads, 3, dim=-1)
        if mems is not None:
            # Only the current segment issues queries.
            w_head_q = w_head_q[-qlen:]

        klen = w_head_k.size(0)

        w_head_q = w_head_q.view(qlen, bsz, self.n_head, self.d_head)  # qlen x bsz x n_head x d_head
        w_head_k = w_head_k.view(klen, bsz, self.n_head, self.d_head)  # klen x bsz x n_head x d_head
        w_head_v = w_head_v.view(klen, bsz, self.n_head, self.d_head)  # klen x bsz x n_head x d_head

        r_head_k = r_head_k.view(rlen, self.n_head, self.d_head)       # rlen x n_head x d_head

        # compute attention score
        content_q = w_head_q + self.r_w_bias                            # qlen x bsz x n_head x d_head
        AC = torch.einsum('ibnd,jbnd->ijbn', (content_q, w_head_k))     # qlen x klen x bsz x n_head

        position_q = w_head_q + self.r_r_bias
        BD = torch.einsum('ibnd,jnd->ijbn', (position_q, r_head_k))     # qlen x rlen x bsz x n_head
        BD = self._rel_shift(BD)

        # [qlen x klen x bsz x n_head]
        attn_score = AC + BD
        attn_score.mul_(self.scale)

        # compute attention probability
        if attn_mask is not None and torch.sum(attn_mask).item():
            attn_mask = (attn_mask == 1)  # Switch to bool
            mask_rank = attn_mask.dim()
            # The fill is done in float32 and cast back to the score dtype.
            if mask_rank == 2:
                attn_score = attn_score.float().masked_fill(
                    attn_mask[None, :, :, None], -1e30).type_as(attn_score)
            elif mask_rank == 3:
                attn_score = attn_score.float().masked_fill(
                    attn_mask[:, :, :, None], -1e30).type_as(attn_score)

        # [qlen x klen x bsz x n_head], normalized over the key axis
        attn_prob = self.dropatt(F.softmax(attn_score, dim=1))

        # Mask heads if we want to
        if head_mask is not None:
            attn_prob = attn_prob * head_mask

        # compute attention vector
        attn_vec = torch.einsum('ijbn,jbnd->ibnd', (attn_prob, w_head_v))

        # [qlen x bsz x n_head x d_head] -> [qlen x bsz x n_head * d_head]
        attn_vec = attn_vec.contiguous().view(
            attn_vec.size(0), attn_vec.size(1), self.n_head * self.d_head)

        # linear projection
        attn_out = self.drop(self.o_net(attn_vec))

        residual = w + attn_out
        if self.pre_lnorm:
            # residual connection only
            outputs = [residual]
        else:
            # residual connection + layer normalization
            outputs = [self.layer_norm(residual)]

        if self.output_attentions:
            outputs.append(attn_prob)

        return outputs
-
class RelLearnableMultiHeadAttn(RelMultiHeadAttn):
    """ Relative attention with fully learned position embeddings and biases,
        all supplied by the caller on every forward pass.
    """

    def __init__(self, *args, **kwargs):
        super(RelLearnableMultiHeadAttn, self).__init__(*args, **kwargs)

    def forward(self, w, r_emb, r_w_bias, r_bias, attn_mask=None, mems=None, head_mask=None):
        """ Attend from ``w`` ([qlen x bsz x d_model]) over ``[mems; w]``.

            r_emb: [klen, n_head, d_head], used for term B
            r_w_bias: [n_head, d_head], used for term C
            r_bias: [klen, n_head], used for term D

            Returns a list: the output hidden states, followed by the attention
            probabilities when ``output_attentions`` is set.
        """
        qlen, bsz = w.size(0), w.size(1)

        # Keys and values span the memory plus the current segment.
        hidden = w if mems is None else torch.cat([mems, w], 0)
        if self.pre_lnorm:
            hidden = self.layer_norm(hidden)
        w_head_q, w_head_k, w_head_v = torch.chunk(self.qkv_net(hidden), 3, dim=-1)
        if mems is not None:
            # Only the current segment issues queries.
            w_head_q = w_head_q[-qlen:]

        klen = w_head_k.size(0)

        head_shape = (bsz, self.n_head, self.d_head)
        w_head_q = w_head_q.view(qlen, *head_shape)
        w_head_k = w_head_k.view(klen, *head_shape)
        w_head_v = w_head_v.view(klen, *head_shape)

        if klen > r_emb.size(0):
            # Not enough learned positions: repeat the first one at the front.
            emb_pad = r_emb[0:1].expand(klen - r_emb.size(0), -1, -1)
            bias_pad = r_bias[0:1].expand(klen - r_bias.size(0), -1)
            r_emb = torch.cat([emb_pad, r_emb], 0)
            r_bias = torch.cat([bias_pad, r_bias], 0)
        else:
            # Keep only the last klen positions.
            r_emb = r_emb[-klen:]
            r_bias = r_bias[-klen:]

        # compute attention score
        content_q = w_head_q + r_w_bias[None]                           # qlen x bsz x n_head x d_head

        AC = torch.einsum('ibnd,jbnd->ijbn', (content_q, w_head_k))     # qlen x klen x bsz x n_head
        B_ = torch.einsum('ibnd,jnd->ijbn', (w_head_q, r_emb))          # qlen x klen x bsz x n_head
        D_ = r_bias[None, :, None]                                      # 1    x klen x 1   x n_head
        BD = self._rel_shift(B_ + D_)

        # [qlen x klen x bsz x n_head]
        attn_score = AC + BD
        attn_score.mul_(self.scale)

        # compute attention probability
        if attn_mask is not None and torch.sum(attn_mask).item():
            attn_mask = (attn_mask == 1)  # Switch to bool
            mask_rank = attn_mask.dim()
            if mask_rank == 2:
                attn_score.masked_fill_(attn_mask[None, :, :, None], -float('inf'))
            elif mask_rank == 3:
                attn_score.masked_fill_(attn_mask[:, :, :, None], -float('inf'))

        # [qlen x klen x bsz x n_head], normalized over the key axis
        attn_prob = self.dropatt(F.softmax(attn_score, dim=1))

        if head_mask is not None:
            attn_prob = attn_prob * head_mask

        # compute attention vector
        attn_vec = torch.einsum('ijbn,jbnd->ibnd', (attn_prob, w_head_v))

        # [qlen x bsz x n_head x d_head] -> [qlen x bsz x n_head * d_head]
        attn_vec = attn_vec.contiguous().view(
            attn_vec.size(0), attn_vec.size(1), self.n_head * self.d_head)

        # linear projection
        attn_out = self.drop(self.o_net(attn_vec))

        residual = w + attn_out
        if self.pre_lnorm:
            # residual connection only
            outputs = [residual]
        else:
            # residual connection + layer normalization
            outputs = [self.layer_norm(residual)]

        if self.output_attentions:
            outputs.append(attn_prob)

        return outputs
-
-
-
class DecoderLayer(nn.Module):
    """ Decoder layer: standard multi-head attention followed by a
        position-wise feed-forward block.
    """

    def __init__(self, n_head, d_model, d_head, d_inner, dropout, **kwargs):
        super(DecoderLayer, self).__init__()

        # Extra keyword arguments go to the attention module; the feed-forward
        # block only needs `pre_lnorm` (None when it was not supplied).
        self.dec_attn = MultiHeadAttn(n_head, d_model, d_head, dropout, **kwargs)
        self.pos_ff = PositionwiseFF(
            d_model, d_inner, dropout, pre_lnorm=kwargs.get('pre_lnorm'))

    def forward(self, dec_inp, dec_attn_mask=None, mems=None, head_mask=None):
        """ Return ``[hidden_states]`` plus whatever extra outputs (attention
            probabilities) the attention module produced.
        """
        attn_outputs = self.dec_attn(
            dec_inp, attn_mask=dec_attn_mask, mems=mems, head_mask=head_mask)
        attn_hidden, extras = attn_outputs[0], attn_outputs[1:]

        return [self.pos_ff(attn_hidden)] + extras
-
class RelLearnableDecoderLayer(nn.Module):
    """ Decoder layer: learnable relative attention followed by a
        position-wise feed-forward block.
    """

    def __init__(self, n_head, d_model, d_head, d_inner, dropout,
                 **kwargs):
        super(RelLearnableDecoderLayer, self).__init__()

        # Extra keyword arguments go to the attention module; the feed-forward
        # block only needs `pre_lnorm` (None when it was not supplied).
        self.dec_attn = RelLearnableMultiHeadAttn(
            n_head, d_model, d_head, dropout, **kwargs)
        self.pos_ff = PositionwiseFF(
            d_model, d_inner, dropout, pre_lnorm=kwargs.get('pre_lnorm'))

    def forward(self, dec_inp, r_emb, r_w_bias, r_bias, dec_attn_mask=None, mems=None, head_mask=None):
        """ Return ``[hidden_states]`` plus whatever extra outputs (attention
            probabilities) the attention module produced.
        """
        attn_outputs = self.dec_attn(
            dec_inp, r_emb, r_w_bias, r_bias,
            attn_mask=dec_attn_mask, mems=mems, head_mask=head_mask)
        attn_hidden, extras = attn_outputs[0], attn_outputs[1:]

        return [self.pos_ff(attn_hidden)] + extras
-
class RelPartialLearnableDecoderLayer(nn.Module):
    """ Decoder layer: partially-learnable relative attention followed by a
        position-wise feed-forward block.
    """

    def __init__(self, n_head, d_model, d_head, d_inner, dropout,
                 **kwargs):
        super(RelPartialLearnableDecoderLayer, self).__init__()

        # Extra keyword arguments go to the attention module; the feed-forward
        # block only needs `pre_lnorm` (None when it was not supplied).
        self.dec_attn = RelPartialLearnableMultiHeadAttn(
            n_head, d_model, d_head, dropout, **kwargs)
        self.pos_ff = PositionwiseFF(
            d_model, d_inner, dropout, pre_lnorm=kwargs.get('pre_lnorm'))

    def forward(self, dec_inp, r, dec_attn_mask=None, mems=None, head_mask=None):
        """ Return ``[hidden_states]`` plus whatever extra outputs (attention
            probabilities) the attention module produced.
        """
        attn_outputs = self.dec_attn(
            dec_inp, r,
            attn_mask=dec_attn_mask, mems=mems, head_mask=head_mask)
        attn_hidden, extras = attn_outputs[0], attn_outputs[1:]

        return [self.pos_ff(attn_hidden)] + extras
-
-
-
- class AdaptiveEmbedding(nn.Module):
- def __init__(self, n_token, d_embed, d_proj, cutoffs, div_val=1,
- sample_softmax=False):
- super(AdaptiveEmbedding, self).__init__()
-
- self.n_token = n_token
- self.d_embed = d_embed
-
- self.cutoffs = cutoffs + [n_token]
- self.div_val = div_val
- self.d_proj = d_proj
-
- self.emb_scale = d_proj ** 0.5
-
- self.cutoff_ends = [0] + self.cutoffs
-
- self.emb_layers = nn.ModuleList()
- self.emb_projs = nn.ParameterList()
- if div_val == 1:
- self.emb_layers.append(
- nn.Embedding(n_token, d_embed, sparse=sample_softmax>0)
- )
- if d_proj != d_embed:
- self.emb_projs.append(nn.Parameter(torch.FloatTensor(d_proj, d_embed)))
- else:
- for i in range(len(self.cutoffs)):
- l_idx, r_idx = self.cutoff_ends[i], self.cutoff_ends[i+1]
- d_emb_i = d_embed // (div_val ** i)
- self.emb_layers.append(nn.Embedding(r_idx-l_idx, d_emb_i))
- self.emb_projs.append(nn.Parameter(torch.FloatTensor(d_proj, d_emb_i)))
-
- def forward(self, inp):
- if self.div_val == 1:
- embed = self.emb_layers[0](inp)
- if self.d_proj != self.d_embed:
- embed = F.linear(embed, self.emb_projs[0])
- else:
- param = next(self.parameters())
- inp_flat = inp.view(-1)
- emb_flat = torch.zeros([inp_flat.size(0), self.d_proj],
- dtype=param.dtype, device=param.device)
- for i in range(len(self.cutoffs)):
- l_idx, r_idx = self.cutoff_ends[i], self.cutoff_ends[i + 1]
-
- mask_i = (inp_flat >= l_idx) & (inp_flat < r_idx)
- indices_i = mask_i.nonzero().squeeze()
-
- if indices_i.numel() == 0:
- continue
-
- inp_i = inp_flat.index_select(0, indices_i) - l_idx
- emb_i = self.emb_layers[i](inp_i)
- emb_i = F.linear(emb_i, self.emb_projs[i])
-
- emb_flat.index_copy_(0, indices_i, emb_i)
-
- embed_shape = inp.size() + (self.d_proj,)
- embed = emb_flat.view(embed_shape)
-
- embed.mul_(self.emb_scale)
-
- return embed
-
-
class TransfoXLPreTrainedModel(PreTrainedModel):
    """ An abstract class to handle weights initialization and
        a simple interface for downloading and loading pretrained models.
    """
    config_class = TransfoXLConfig
    pretrained_model_archive_map = TRANSFO_XL_PRETRAINED_MODEL_ARCHIVE_MAP
    load_tf_weights = load_tf_weights_in_transfo_xl
    base_model_prefix = "transformer"

    def _init_weight(self, weight):
        """ Initialize ``weight`` in place according to ``config.init``
            ('uniform' or 'normal'); any other value leaves it untouched.
        """
        init_kind = self.config.init
        if init_kind == 'uniform':
            nn.init.uniform_(weight, -self.config.init_range, self.config.init_range)
        elif init_kind == 'normal':
            nn.init.normal_(weight, 0.0, self.config.init_std)

    def _init_bias(self, bias):
        """ Zero ``bias`` in place. """
        nn.init.constant_(bias, 0.0)

    def _init_weights(self, m):
        """ Initialize the weights of module ``m``, dispatching on its class name.

            The branches are tried in order, so 'AdaptiveEmbedding' must be
            checked before the more general 'Embedding'.
        """
        classname = m.__class__.__name__
        if 'Linear' in classname:
            if getattr(m, 'weight', None) is not None:
                self._init_weight(m.weight)
            if getattr(m, 'bias', None) is not None:
                self._init_bias(m.bias)
        elif 'AdaptiveEmbedding' in classname:
            if hasattr(m, 'emb_projs'):
                for proj in m.emb_projs:
                    if proj is not None:
                        nn.init.normal_(proj, 0.0, self.config.proj_init_std)
        elif 'Embedding' in classname:
            if hasattr(m, 'weight'):
                self._init_weight(m.weight)
        elif 'ProjectedAdaptiveLogSoftmax' in classname:
            if getattr(m, 'cluster_weight', None) is not None:
                self._init_weight(m.cluster_weight)
            if getattr(m, 'cluster_bias', None) is not None:
                self._init_bias(m.cluster_bias)
            if hasattr(m, 'out_projs'):
                for proj in m.out_projs:
                    if proj is not None:
                        nn.init.normal_(proj, 0.0, self.config.proj_init_std)
        elif 'LayerNorm' in classname:
            if hasattr(m, 'weight'):
                nn.init.normal_(m.weight, 1.0, self.config.init_std)
            if getattr(m, 'bias', None) is not None:
                self._init_bias(m.bias)
        else:
            # Any other module: initialize the relative-position parameters it owns.
            for attr in ('r_emb', 'r_w_bias', 'r_r_bias'):
                if hasattr(m, attr):
                    self._init_weight(getattr(m, attr))
            if hasattr(m, 'r_bias'):
                self._init_bias(m.r_bias)

    def set_num_special_tokens(self, num_special_tokens):
        """ No-op. """
        pass
-
-
# Shared introduction prepended to the model class docstrings through
# `add_start_docstrings`.
TRANSFO_XL_START_DOCSTRING = r""" The Transformer-XL model was proposed in
`Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context`_
by Zihang Dai*, Zhilin Yang*, Yiming Yang, Jaime Carbonell, Quoc V. Le, Ruslan Salakhutdinov.
It's a causal (uni-directional) transformer with relative positioning (sinusoïdal) embeddings which can reuse
previously computed hidden-states to attend to longer context (memory).
This model also uses adaptive softmax inputs and outputs (tied).

This model is a PyTorch `torch.nn.Module`_ sub-class. Use it as a regular PyTorch Module and
refer to the PyTorch documentation for all matter related to general usage and behavior.

.. _`Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context`:
https://arxiv.org/abs/1901.02860

.. _`torch.nn.Module`:
https://pytorch.org/docs/stable/nn.html#module

Parameters:
config (:class:`~pytorch_transformers.TransfoXLConfig`): Model configuration class with all the parameters of the model.
Initializing with a config file does not load the weights associated with the model, only the configuration.
Check out the :meth:`~pytorch_transformers.PreTrainedModel.from_pretrained` method to load the model weights.
"""
-
# Shared description of the model inputs, prepended to the model class
# docstrings through `add_start_docstrings`.
TRANSFO_XL_INPUTS_DOCSTRING = r"""
Inputs:
**input_ids**: ``torch.LongTensor`` of shape ``(batch_size, sequence_length)``:
Indices of input sequence tokens in the vocabulary.
Transformer-XL is a model with relative position embeddings so you can either pad the inputs on
the right or on the left.
Indices can be obtained using :class:`pytorch_transformers.TransfoXLTokenizer`.
See :func:`pytorch_transformers.PreTrainedTokenizer.encode` and
:func:`pytorch_transformers.PreTrainedTokenizer.convert_tokens_to_ids` for details.
**mems**: (`optional`)
list of ``torch.FloatTensor`` (one for each layer):
that contains pre-computed hidden-states (key and values in the attention blocks) as computed by the model
(see `mems` output below). Can be used to speed up sequential decoding and attend to longer context.
**head_mask**: (`optional`) ``torch.FloatTensor`` of shape ``(num_heads,)`` or ``(num_layers, num_heads)``:
Mask to nullify selected heads of the self-attention modules.
Mask values selected in ``[0, 1]``:
``1`` indicates the head is **not masked**, ``0`` indicates the head is **masked**.
"""
-
- @add_start_docstrings("The bare Bert Model transformer outputing raw hidden-states without any specific head on top.",
- TRANSFO_XL_START_DOCSTRING, TRANSFO_XL_INPUTS_DOCSTRING)
- class TransfoXLModel(TransfoXLPreTrainedModel):
- r"""
- Outputs: `Tuple` comprising various elements depending on the configuration (config) and inputs:
- **last_hidden_state**: ``torch.FloatTensor`` of shape ``(batch_size, sequence_length, hidden_size)``
- Sequence of hidden-states at the last layer of the model.
- **mems**:
- list of ``torch.FloatTensor`` (one for each layer):
- that contains pre-computed hidden-states (key and values in the attention blocks) as computed by the model
- (see `mems` input above). Can be used to speed up sequential decoding and attend to longer context.
- **hidden_states**: (`optional`, returned when ``config.output_hidden_states=True``)
- list of ``torch.FloatTensor`` (one for the output of each layer + the output of the embeddings)
- of shape ``(batch_size, sequence_length, hidden_size)``:
- Hidden-states of the model at the output of each layer plus the initial embedding outputs.
- **attentions**: (`optional`, returned when ``config.output_attentions=True``)
- list of ``torch.FloatTensor`` (one for each layer) of shape ``(batch_size, num_heads, sequence_length, sequence_length)``:
- Attentions weights after the attention softmax, used to compute the weighted average in the self-attention heads.
-
- Examples::
-
- tokenizer = TransfoXLTokenizer.from_pretrained('transfo-xl-wt103')
- model = TransfoXLModel.from_pretrained('transfo-xl-wt103')
- input_ids = torch.tensor(tokenizer.encode("Hello, my dog is cute")).unsqueeze(0) # Batch size 1
- outputs = model(input_ids)
- last_hidden_states, mems = outputs[:2]
-
- """
- def __init__(self, config):
- super(TransfoXLModel, self).__init__(config)
- self.output_attentions = config.output_attentions
- self.output_hidden_states = config.output_hidden_states
-
- self.n_token = config.n_token
-
- self.d_embed = config.d_embed
- self.d_model = config.d_model
- self.n_head = config.n_head
- self.d_head = config.d_head
-
- self.word_emb = AdaptiveEmbedding(config.n_token, config.d_embed, config.d_model, config.cutoffs,
- div_val=config.div_val)
-
- self.drop = nn.Dropout(config.dropout)
-
- self.n_layer = config.n_layer
-
- self.tgt_len = config.tgt_len
- self.mem_len = config.mem_len
- self.ext_len = config.ext_len
- self.max_klen = config.tgt_len + config.ext_len + config.mem_len
-
- self.attn_type = config.attn_type
-
- if not config.untie_r:
- self.r_w_bias = nn.Parameter(torch.FloatTensor(self.n_head, self.d_head))
- self.r_r_bias = nn.Parameter(torch.FloatTensor(self.n_head, self.d_head))
-
- self.layers = nn.ModuleList()
- if config.attn_type == 0: # the default attention
- for i in range(config.n_layer):
- self.layers.append(
- RelPartialLearnableDecoderLayer(
- config.n_head, config.d_model, config.d_head, config.d_inner, config.dropout,
- tgt_len=config.tgt_len, ext_len=config.ext_len, mem_len=config.mem_len,
- dropatt=config.dropatt, pre_lnorm=config.pre_lnorm,
- r_w_bias=None if config.untie_r else self.r_w_bias,
- r_r_bias=None if config.untie_r else self.r_r_bias,
- output_attentions=self.output_attentions)
- )
- elif config.attn_type == 1: # learnable embeddings
- for i in range(config.n_layer):
- self.layers.append(
- RelLearnableDecoderLayer(
- config.n_head, config.d_model, config.d_head, config.d_inner, config.dropout,
- tgt_len=config.tgt_len, ext_len=config.ext_len, mem_len=config.mem_len,
- dropatt=config.dropatt, pre_lnorm=config.pre_lnorm,
- r_w_bias=None if config.untie_r else self.r_w_bias,
- r_r_bias=None if config.untie_r else self.r_r_bias,
- output_attentions=self.output_attentions)
- )
- elif config.attn_type in [2, 3]: # absolute embeddings
- for i in range(config.n_layer):
- self.layers.append(
- DecoderLayer(
- config.n_head, config.d_model, config.d_head, config.d_inner, config.dropout,
- dropatt=config.dropatt, pre_lnorm=config.pre_lnorm,
- r_w_bias=None if config.untie_r else self.r_w_bias,
- r_r_bias=None if config.untie_r else self.r_r_bias,
- output_attentions=self.output_attentions)
- )
-
- self.same_length = config.same_length
- self.clamp_len = config.clamp_len
-
- if self.attn_type == 0: # default attention
- self.pos_emb = PositionalEmbedding(self.d_model)
- elif self.attn_type == 1: # learnable
- self.r_emb = nn.Parameter(torch.FloatTensor(
- self.n_layer, self.max_klen, self.n_head, self.d_head))
- self.r_bias = nn.Parameter(torch.FloatTensor(
- self.n_layer, self.max_klen, self.n_head))
- elif self.attn_type == 2: # absolute standard
- self.pos_emb = PositionalEmbedding(self.d_model)
- elif self.attn_type == 3: # absolute deeper SA
- self.r_emb = nn.Parameter(torch.FloatTensor(
- self.n_layer, self.max_klen, self.n_head, self.d_head))
-
- self.init_weights()
-
def _resize_token_embeddings(self, new_num_tokens):
    """Return the model's word-embedding module without resizing it.

    ``new_num_tokens`` is accepted for interface compatibility but is not
    used: the existing ``self.word_emb`` object is handed back as-is.
    """
    word_embeddings = self.word_emb
    return word_embeddings
-
def backward_compatible(self):
    """Set ``self.sample_softmax`` to ``-1``.

    Elsewhere in this file a strictly positive ``sample_softmax`` selects the
    sampled-softmax code path, so ``-1`` is a value that does not select it.
    """
    setattr(self, "sample_softmax", -1)
-
def reset_length(self, tgt_len, ext_len, mem_len):
    """Overwrite the stored ``tgt_len``, ``ext_len`` and ``mem_len`` attributes.

    The three values are stored verbatim; no validation is performed.
    """
    self.tgt_len, self.mem_len, self.ext_len = tgt_len, mem_len, ext_len
-
def _prune_heads(self, heads):
    """Log that head pruning is unsupported; ``heads`` is ignored and nothing is pruned."""
    message = "Head pruning is not implemented for Transformer-XL model"
    logger.info(message)
-
def init_mems(self, data):
    """Create zero-filled memory tensors, one per layer.

    Args:
        data: tensor whose second dimension (``data.size(1)``) is used as the
            middle dimension of every memory tensor.

    Returns:
        ``None`` when ``self.mem_len`` is not positive; otherwise a list of
        ``self.n_layer`` distinct zero tensors of shape
        ``(self.mem_len, data.size(1), self.config.d_model)`` that share the
        dtype and device of the model's first parameter.
    """
    if not self.mem_len > 0:
        return None

    # Use the first parameter as the reference for dtype/device placement.
    reference = next(self.parameters())
    return [
        torch.zeros(self.mem_len, data.size(1), self.config.d_model,
                    dtype=reference.dtype, device=reference.device)
        for _ in range(self.n_layer)
    ]
-
def _update_mems(self, hids, mems, qlen, mlen):
    """Build the next set of memory tensors from the current hidden states.

    Args:
        hids: list of hidden-state tensors, one per entry of ``mems``.
        mems: list of current memory tensors, or ``None``.
        qlen: length of the current segment (dimension 0 of each ``hids`` entry).
        mlen: length of the current memory (dimension 0 of each ``mems`` entry).

    Returns:
        ``None`` when ``mems`` is ``None``; otherwise a list of detached
        tensors, each a slice along dimension 0 of ``cat([mems[i], hids[i]])``.

    Note:
        The parameter order is ``qlen`` then ``mlen``; passing both by keyword
        avoids accidentally swapping them.
    """
    # A missing memory is propagated untouched.
    if mems is None:
        return None

    assert len(hids) == len(mems), 'len(hids) != len(mems)'

    # There are `mlen + qlen` steps that could be cached. The last `ext_len`
    # of the `qlen` new steps are kept out of the cache (they serve as extended
    # context for the next segment), so the cached window is
    # [mlen + qlen - ext_len - mem_len, mlen + qlen - ext_len).
    with torch.no_grad():
        stop = mlen + max(0, qlen - self.ext_len)
        start = max(0, stop - self.mem_len)
        return [
            torch.cat([previous, current], dim=0)[start:stop].detach()
            for previous, current in zip(mems, hids)
        ]
-
def _forward(self, dec_inp, mems=None, head_mask=None):
    """Run the decoder stack on one segment and refresh the memory.

    Args:
        dec_inp: 2-D tensor of token ids laid out as ``(qlen, bsz)``.
        mems: optional list with one memory tensor per layer; dimension 0 of
            ``mems[0]`` is taken as the memory length ``mlen``.
        head_mask: optional 1-D (``[n_head]``) or 2-D (``[n_layer, n_head]``)
            tensor; 1.0 keeps a head. ``None`` disables head masking.

    Returns:
        A list ``[last_hidden, new_mems, (all_hidden_states), (all_attentions)]``
        where ``last_hidden`` has dimensions 0 and 1 swapped back relative to
        the internal layout, and the two optional entries are present only when
        ``self.output_hidden_states`` / ``self.output_attentions`` are set.

    Raises:
        ValueError: if ``self.attn_type`` is not one of 0, 1, 2, 3.
    """
    # Unpacking into two names also enforces that `dec_inp` is 2-D.
    qlen, _bsz = dec_inp.size()

    # Prepare head mask if needed
    # 1.0 in head_mask indicate we keep the head
    # attention_probs has shape bsz x n_heads x N x N
    # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads] (a head_mask for each layer)
    # and head_mask is converted to shape [num_hidden_layers x qlen x klen x bsz x n_head]
    if head_mask is not None:
        if head_mask.dim() == 1:
            head_mask = head_mask.unsqueeze(0).unsqueeze(0).unsqueeze(0).unsqueeze(0)
            head_mask = head_mask.expand(self.n_layer, -1, -1, -1, -1)
        elif head_mask.dim() == 2:
            head_mask = head_mask.unsqueeze(1).unsqueeze(1).unsqueeze(1)
        # switch to float if needed + fp16 compatibility
        head_mask = head_mask.to(dtype=next(self.parameters()).dtype)
    else:
        head_mask = [None] * self.n_layer

    word_emb = self.word_emb(dec_inp)

    mlen = mems[0].size(0) if mems is not None else 0
    klen = mlen + qlen
    if self.same_length:
        all_ones = word_emb.new_ones((qlen, klen), dtype=torch.uint8)
        mask_len = klen - self.mem_len
        if mask_len > 0:
            mask_shift_len = qlen - mask_len
        else:
            mask_shift_len = qlen
        # Upper triangle hides future positions; lower triangle hides keys
        # further back than the fixed attention span.
        dec_attn_mask = (torch.triu(all_ones, 1 + mlen)
                         + torch.tril(all_ones, -mask_shift_len))[:, :, None]  # -1
    else:
        dec_attn_mask = torch.triu(
            word_emb.new_ones((qlen, klen), dtype=torch.uint8), diagonal=1 + mlen)[:, :, None]

    hids = []
    attentions = []
    if self.attn_type == 0:  # default
        pos_seq = torch.arange(klen - 1, -1, -1.0, device=word_emb.device,
                               dtype=word_emb.dtype)
        if self.clamp_len > 0:
            pos_seq.clamp_(max=self.clamp_len)
        pos_emb = self.pos_emb(pos_seq)

        core_out = self.drop(word_emb)
        pos_emb = self.drop(pos_emb)

        for i, layer in enumerate(self.layers):
            hids.append(core_out)
            mems_i = None if mems is None else mems[i]
            layer_outputs = layer(core_out, pos_emb, dec_attn_mask=dec_attn_mask,
                                  mems=mems_i, head_mask=head_mask[i])
            core_out = layer_outputs[0]
            if self.output_attentions:
                attentions.append(layer_outputs[1])
    elif self.attn_type == 1:  # learnable
        core_out = self.drop(word_emb)
        for i, layer in enumerate(self.layers):
            hids.append(core_out)
            if self.clamp_len > 0:
                r_emb = self.r_emb[i][-self.clamp_len:]
                r_bias = self.r_bias[i][-self.clamp_len:]
            else:
                r_emb, r_bias = self.r_emb[i], self.r_bias[i]

            mems_i = None if mems is None else mems[i]
            layer_outputs = layer(core_out, r_emb, self.r_w_bias[i],
                                  r_bias, dec_attn_mask=dec_attn_mask,
                                  mems=mems_i, head_mask=head_mask[i])
            core_out = layer_outputs[0]
            if self.output_attentions:
                attentions.append(layer_outputs[1])
    elif self.attn_type == 2:  # absolute
        pos_seq = torch.arange(klen - 1, -1, -1.0, device=word_emb.device,
                               dtype=word_emb.dtype)
        if self.clamp_len > 0:
            pos_seq.clamp_(max=self.clamp_len)
        pos_emb = self.pos_emb(pos_seq)

        core_out = self.drop(word_emb + pos_emb[-qlen:])

        for i, layer in enumerate(self.layers):
            hids.append(core_out)
            mems_i = None if mems is None else mems[i]
            if mems_i is not None and i == 0:
                # NOTE(review): in-place add mutates the caller's `mems[0]`
                # (and therefore what gets cached below); kept as-is to
                # preserve existing numerics.
                mems_i += pos_emb[:mlen]
            layer_outputs = layer(core_out, dec_attn_mask=dec_attn_mask,
                                  mems=mems_i, head_mask=head_mask[i])
            core_out = layer_outputs[0]
            if self.output_attentions:
                attentions.append(layer_outputs[1])
    elif self.attn_type == 3:
        core_out = self.drop(word_emb)

        for i, layer in enumerate(self.layers):
            hids.append(core_out)
            mems_i = None if mems is None else mems[i]
            if mems_i is not None and mlen > 0:
                cur_emb = self.r_emb[i][:-qlen]
                cur_size = cur_emb.size(0)
                if cur_size < mlen:
                    # Not enough learned positions: repeat the first one to pad.
                    cur_emb_pad = cur_emb[0:1].expand(mlen - cur_size, -1, -1)
                    cur_emb = torch.cat([cur_emb_pad, cur_emb], 0)
                else:
                    cur_emb = cur_emb[-mlen:]
                # NOTE(review): in-place add mutates the caller's memory tensor.
                mems_i += cur_emb.view(mlen, 1, -1)
            # NOTE(review): in-place add also alters the tensor just stored in `hids`.
            core_out += self.r_emb[i][-qlen:].view(qlen, 1, -1)

            layer_outputs = layer(core_out, dec_attn_mask=dec_attn_mask,
                                  mems=mems_i, head_mask=head_mask[i])
            core_out = layer_outputs[0]
            if self.output_attentions:
                attentions.append(layer_outputs[1])
    else:
        # Fail with a clear message instead of an UnboundLocalError on `core_out`.
        raise ValueError("Unsupported attn_type: {}".format(self.attn_type))

    core_out = self.drop(core_out)

    # `_update_mems` is declared as (hids, mems, qlen, mlen). Pass the lengths
    # by keyword so they cannot be swapped; the two orders only coincide when
    # `ext_len == 0`.
    new_mems = self._update_mems(hids, mems, qlen=qlen, mlen=mlen)

    # We transpose back here to shape [bsz, len, hidden_dim]
    outputs = [core_out.transpose(0, 1).contiguous(), new_mems]
    if self.output_hidden_states:
        # Add last layer and transpose to library standard shape [bsz, len, hidden_dim]
        hids.append(core_out)
        hids = [t.transpose(0, 1).contiguous() for t in hids]
        outputs.append(hids)
    if self.output_attentions:
        # Transpose to library standard shape [bsz, n_heads, query_seq_len, key_seq_len]
        attentions = [t.permute(2, 3, 0, 1).contiguous() for t in attentions]
        outputs.append(attentions)
    return outputs  # last hidden state, new_mems, (all hidden states), (all attentions)
-
def forward(self, input_ids, mems=None, head_mask=None):
    """Public entry point: adapt the library's batch-first layout and delegate.

    Args:
        input_ids: tensor laid out as ``[bsz, len]``.
        mems: optional memory list; when ``None`` a fresh one is requested
            from ``self.init_mems``.
        head_mask: forwarded unchanged to ``self._forward``.

    Returns:
        Whatever ``self._forward`` returns:
        last hidden state, new_mems, (all hidden states), (all attentions).
    """
    # The original Transformer-XL code works on [len, bsz]; the library exposes
    # [bsz, len], so swap the first two dimensions before going any further.
    dec_inp = input_ids.transpose(0, 1).contiguous()

    memories = self.init_mems(dec_inp) if mems is None else mems
    return self._forward(dec_inp, mems=memories, head_mask=head_mask)
-
-
@add_start_docstrings("""The Transformer-XL Model with a language modeling head on top
    (adaptive softmax with weights tied to the adaptive input embeddings)""",
    TRANSFO_XL_START_DOCSTRING, TRANSFO_XL_INPUTS_DOCSTRING)
class TransfoXLLMHeadModel(TransfoXLPreTrainedModel):
    r"""
        **labels**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size, sequence_length)``:
            Labels for language modeling.
            Note that the labels **are shifted** inside the model, i.e. you can set ``labels = input_ids``
            Indices are selected in ``[-1, 0, ..., config.vocab_size]``
            All labels set to ``-1`` are ignored (masked), the loss is only
            computed for labels in ``[0, ..., config.vocab_size]``

    Outputs: `Tuple` comprising various elements depending on the configuration (config) and inputs:
        **loss**: (`optional`, returned when ``labels`` is provided) ``torch.FloatTensor`` of shape ``(batch_size, sequence_length)``:
            Language modeling loss (one value per position; no reduction is applied here).
        **prediction_scores**: ``None`` if ``labels`` is provided else ``torch.FloatTensor`` of shape ``(batch_size, sequence_length, config.vocab_size)``
            Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
            We don't output them when the loss is computed to speedup adaptive softmax decoding.
        **mems**:
            list of ``torch.FloatTensor`` (one for each layer):
            that contains pre-computed hidden-states (key and values in the attention blocks) as computed by the model
            (see `mems` input above). Can be used to speed up sequential decoding and attend to longer context.
        **hidden_states**: (`optional`, returned when ``config.output_hidden_states=True``)
            list of ``torch.FloatTensor`` (one for the output of each layer + the output of the embeddings)
            of shape ``(batch_size, sequence_length, hidden_size)``:
            Hidden-states of the model at the output of each layer plus the initial embedding outputs.
        **attentions**: (`optional`, returned when ``config.output_attentions=True``)
            list of ``torch.FloatTensor`` (one for each layer) of shape ``(batch_size, num_heads, sequence_length, sequence_length)``:
            Attentions weights after the attention softmax, used to compute the weighted average in the self-attention heads.

    Examples::

        tokenizer = TransfoXLTokenizer.from_pretrained('transfo-xl-wt103')
        model = TransfoXLLMHeadModel.from_pretrained('transfo-xl-wt103')
        input_ids = torch.tensor(tokenizer.encode("Hello, my dog is cute")).unsqueeze(0)  # Batch size 1
        outputs = model(input_ids)
        prediction_scores, mems = outputs[:2]

    """
    def __init__(self, config):
        """Build the base transformer and the output head selected by ``config.sample_softmax``."""
        super(TransfoXLLMHeadModel, self).__init__(config)
        self.transformer = TransfoXLModel(config)
        self.sample_softmax = config.sample_softmax
        if config.sample_softmax > 0:
            # Sampled softmax. `LogUniformSampler` is not among this file's
            # top-level imports, so it is imported here from the same
            # utilities module that provides `sample_logits`.
            from .modeling_transfo_xl_utilities import LogUniformSampler

            self.out_layer = nn.Linear(config.d_model, config.n_token)
            self.sampler = LogUniformSampler(config.n_token, config.sample_softmax)
        else:
            # Adaptive softmax (including standard softmax).
            self.crit = ProjectedAdaptiveLogSoftmax(config.n_token, config.d_embed, config.d_model,
                                                    config.cutoffs, div_val=config.div_val)
        self.init_weights()
        self.tie_weights()

    def tie_weights(self):
        """
        Run this to be sure output and input (adaptive) softmax weights are tied
        """
        # sampled softmax
        if self.sample_softmax > 0:
            if self.config.tie_weight:
                self.out_layer.weight = self.transformer.word_emb.weight
            return

        # adaptive softmax (including standard softmax)
        if self.config.tie_weight:
            for i in range(len(self.crit.out_layers)):
                self._tie_or_clone_weights(self.crit.out_layers[i],
                                           self.transformer.word_emb.emb_layers[i])
        if self.config.tie_projs:
            for i, tie_proj in enumerate(self.config.tie_projs):
                if not tie_proj:
                    continue
                if self.config.div_val == 1 and self.config.d_model != self.config.d_embed:
                    # With div_val == 1 every output projection is tied to input projection 0.
                    source_proj = self.transformer.word_emb.emb_projs[0]
                elif self.config.div_val != 1:
                    # Otherwise each output projection is tied to the input projection of the same index.
                    source_proj = self.transformer.word_emb.emb_projs[i]
                else:
                    continue
                if self.config.torchscript:
                    # Under torchscript the projection is cloned rather than shared.
                    self.crit.out_projs[i] = nn.Parameter(source_proj.clone())
                else:
                    self.crit.out_projs[i] = source_proj

    def reset_length(self, tgt_len, ext_len, mem_len):
        """Forward the new length settings to the wrapped transformer."""
        self.transformer.reset_length(tgt_len, ext_len, mem_len)

    def init_mems(self, data):
        """Return the initial memory produced by the wrapped transformer for ``data``."""
        return self.transformer.init_mems(data)

    def forward(self, input_ids, labels=None, mems=None, head_mask=None):
        """Run the transformer and apply the language-modeling head.

        Args:
            input_ids: tensor whose first two dimensions are ``(bsz, tgt_len)``.
            labels: optional targets passed to the output head.
            mems: optional memory forwarded to the transformer.
            head_mask: optional head mask forwarded to the transformer.

        Returns:
            A list. Without ``labels``: ``[scores, new_mems, ...]`` with
            ``scores`` viewed as ``(bsz, tgt_len, -1)``. With ``labels`` (adaptive
            softmax path): ``[loss, None, new_mems, ...]`` with ``loss`` viewed as
            ``(bsz, tgt_len)``. The trailing entries are whatever the transformer
            returned after its last hidden state.

        Raises:
            NotImplementedError: on the sampled-softmax training path when
                ``labels`` is given.
        """
        bsz, tgt_len = input_ids.size(0), input_ids.size(1)

        transformer_outputs = self.transformer(input_ids, mems=mems, head_mask=head_mask)

        last_hidden = transformer_outputs[0]
        pred_hid = last_hidden[:, -tgt_len:]
        outputs = transformer_outputs[1:]
        if self.sample_softmax > 0 and self.training:
            assert self.config.tie_weight
            logit = sample_logits(self.transformer.word_emb, self.out_layer.bias, labels, pred_hid, self.sampler)
            softmax_output = -F.log_softmax(logit, -1)[:, :, 0]
            outputs = [softmax_output] + outputs
            if labels is not None:
                # TODO: This is not implemented
                raise NotImplementedError
        else:
            softmax_output = self.crit(pred_hid.view(-1, pred_hid.size(-1)), labels)
            if labels is None:
                softmax_output = softmax_output.view(bsz, tgt_len, -1)
                outputs = [softmax_output] + outputs
            else:
                softmax_output = softmax_output.view(bsz, tgt_len)
                outputs = [softmax_output, None] + outputs

        return outputs  # (loss), logits or None if labels is not None (speed up adaptive softmax), new_mems, (all hidden states), (all attentions)
|