|
- # Copyright (c) Facebook, Inc. and its affiliates.
- #
- # This source code is licensed under the MIT license found in the
- # LICENSE file in the root directory of this source tree.
-
- import logging
- from typing import DefaultDict
-
- import numpy as np
- import torch
- from collections import defaultdict
- from fairseq.data import data_utils, FairseqDataset, LanguagePairDataset
-
-
- logger = logging.getLogger(__name__)
-
-
def collate(
    samples,
    pad_idx,
    eos_idx,
    left_pad_source=True,
    left_pad_target=False,
    input_feeding=True,
    pad_to_length=None,
):
    """Merge a list of samples into a padded mini-batch.

    Each sample is a dict with keys ``'id'``, ``'source'``, ``'context'`` and
    optionally ``'target'``, ``'prev_output_tokens'``, ``'alignment'`` and
    ``'constraints'``. Samples are sorted by descending source length before
    batching.

    Args:
        samples (List[dict]): samples to collate
        pad_idx (int): padding symbol index
        eos_idx (int): end-of-sentence symbol index
        left_pad_source (bool, optional): pad source (and context) tensors on
            the left side (default: True)
        left_pad_target (bool, optional): pad target tensors on the left side
            (default: False)
        input_feeding (bool, optional): when no 'prev_output_tokens' key is
            present, create a right-shifted copy of the target for teacher
            forcing (default: True)
        pad_to_length (dict, optional): {'source': int, 'target': int} minimum
            lengths to pad to

    Returns:
        dict: the collated mini-batch, or an empty dict if ``samples`` is empty
    """
    if len(samples) == 0:
        return {}

    def merge(key, left_pad, move_eos_to_beginning=False, pad_to_length=None):
        # Stack the per-sample 1D tensors stored under ``key`` into one
        # padded 2D tensor.
        return data_utils.collate_tokens(
            [s[key] for s in samples],
            pad_idx, eos_idx, left_pad, move_eos_to_beginning,
            pad_to_length=pad_to_length,
        )

    def check_alignment(alignment, src_len, tgt_len):
        # Reject empty alignments, or alignments that reference positions
        # outside the (unpadded) source/target sentences.
        if alignment is None or len(alignment) == 0:
            return False
        if alignment[:, 0].max().item() >= src_len - 1 or alignment[:, 1].max().item() >= tgt_len - 1:
            logger.warning("alignment size mismatch found, skipping alignment!")
            return False
        return True

    def compute_alignment_weights(alignments):
        """
        Given a tensor of shape [:, 2] containing the source-target indices
        corresponding to the alignments, a weight vector containing the
        inverse frequency of each target index is computed.
        For e.g. if alignments = [[5, 7], [2, 3], [1, 3], [4, 2]], then
        a tensor containing [1., 0.5, 0.5, 1] should be returned (since target
        index 3 is repeated twice)
        """
        align_tgt = alignments[:, 1]
        _, align_tgt_i, align_tgt_c = torch.unique(align_tgt, return_inverse=True, return_counts=True)
        align_weights = align_tgt_c[align_tgt_i[np.arange(len(align_tgt))]]
        return 1. / align_weights.float()

    id = torch.LongTensor([s['id'] for s in samples])

    src_tokens = merge(
        'source', left_pad=left_pad_source,
        pad_to_length=pad_to_length['source'] if pad_to_length is not None else None
    )
    # Context is padded to 3x the source pad length — presumably the context
    # stream holds up to three concatenated sentences; TODO confirm against
    # the dataset that produces 'context'.
    con_tokens = merge(
        'context', left_pad=left_pad_source,
        pad_to_length=pad_to_length['source'] * 3 if pad_to_length is not None else None
    )
    # sort by descending source length
    src_lengths = torch.LongTensor([
        s['source'].ne(pad_idx).long().sum() for s in samples
    ])
    src_lengths, sort_order = src_lengths.sort(descending=True)
    id = id.index_select(0, sort_order)
    src_tokens = src_tokens.index_select(0, sort_order)
    con_tokens = con_tokens.index_select(0, sort_order)

    prev_output_tokens = None
    target = None
    if samples[0].get('target', None) is not None:
        target = merge(
            'target', left_pad=left_pad_target,
            pad_to_length=pad_to_length['target'] if pad_to_length is not None else None,
        )
        target = target.index_select(0, sort_order)
        tgt_lengths = torch.LongTensor([
            s['target'].ne(pad_idx).long().sum() for s in samples
        ]).index_select(0, sort_order)
        ntokens = tgt_lengths.sum().item()

        if samples[0].get('prev_output_tokens', None) is not None:
            prev_output_tokens = merge('prev_output_tokens', left_pad=left_pad_target)
        elif input_feeding:
            # we create a shifted version of targets for feeding the
            # previous output token(s) into the next decoder step
            prev_output_tokens = merge(
                'target',
                left_pad=left_pad_target,
                move_eos_to_beginning=True,
                pad_to_length=pad_to_length['target'] if pad_to_length is not None else None,
            )
    else:
        # No targets: count source tokens instead.
        ntokens = src_lengths.sum().item()

    batch = {
        'id': id,
        'nsentences': len(samples),
        'ntokens': ntokens,
        'net_input': {
            'src_tokens': src_tokens,
            'src_lengths': src_lengths,
            'cxt_tokens': con_tokens
        },
        'target': target,
    }
    if prev_output_tokens is not None:
        # Reorder to match the length-sorted batch.
        batch['net_input']['prev_output_tokens'] = prev_output_tokens.index_select(0, sort_order)

    if samples[0].get('alignment', None) is not None:
        bsz, tgt_sz = batch['target'].shape
        src_sz = batch['net_input']['src_tokens'].shape[1]

        # Per-sample offsets that map sentence-local alignment indices into
        # flattened batch positions (accounting for left padding).
        offsets = torch.zeros((len(sort_order), 2), dtype=torch.long)
        offsets[:, 1] += (torch.arange(len(sort_order), dtype=torch.long) * tgt_sz)
        if left_pad_source:
            offsets[:, 0] += (src_sz - src_lengths)
        if left_pad_target:
            offsets[:, 1] += (tgt_sz - tgt_lengths)

        alignments = [
            alignment + offset
            for align_idx, offset, src_len, tgt_len in zip(sort_order, offsets, src_lengths, tgt_lengths)
            for alignment in [samples[align_idx]['alignment'].view(-1, 2)]
            if check_alignment(alignment, src_len, tgt_len)
        ]

        if len(alignments) > 0:
            alignments = torch.cat(alignments, dim=0)
            align_weights = compute_alignment_weights(alignments)

            batch['alignments'] = alignments
            batch['align_weights'] = align_weights

    if samples[0].get("constraints", None) is not None:
        # Collate the packed constraints across the samples, padding to
        # the length of the longest sample.
        lens = [sample.get("constraints").size(0) for sample in samples]
        max_len = max(lens)
        constraints = torch.zeros((len(samples), max_len)).long()
        for i, sample in enumerate(samples):
            constraints[i, 0:lens[i]] = sample.get("constraints")
        batch["constraints"] = constraints

    return batch
-
-
class DocumentDataset(LanguagePairDataset):
    """
    A pair of torch.utils.data.Datasets with an additional context stream.

    Args:
        src (torch.utils.data.Dataset): source dataset to wrap
        src_sizes (List[int]): source sentence lengths
        src_dict (~fairseq.data.Dictionary): source vocabulary
        tgt (torch.utils.data.Dataset, optional): target dataset to wrap
        tgt_sizes (List[int], optional): target sentence lengths
        tgt_dict (~fairseq.data.Dictionary, optional): target vocabulary
        con (torch.utils.data.Dataset, optional): context dataset to wrap;
            collated batches expose it as 'net_input.cxt_tokens'
        con_sizes (List[int], optional): context sentence lengths
        left_pad_source (bool, optional): pad source tensors on the left side
            (default: True).
        left_pad_target (bool, optional): pad target tensors on the left side
            (default: False).
        shuffle (bool, optional): shuffle dataset elements before batching
            (default: True).
        input_feeding (bool, optional): create a shifted version of the targets
            to be passed into the model for teacher forcing (default: True).
        remove_eos_from_source (bool, optional): if set, removes eos from end
            of source if it's present (default: False).
        append_eos_to_target (bool, optional): if set, appends eos to end of
            target if it's absent (default: False).
        align_dataset (torch.utils.data.Dataset, optional): dataset
            containing alignments.
        constraints (Tensor, optional): 2d tensor with a concatenated, zero-
            delimited list of constraints for each sentence.
        append_bos (bool, optional): if set, appends bos to the beginning of
            source/target sentence.
        eos (int, optional): index to use as the end-of-sentence symbol
            (default: ``src_dict.eos()``).
        num_buckets (int, optional): if set to a value greater than 0, then
            batches will be bucketed into the given number of batch shapes.
        src_lang_id (int, optional): source language ID, if set, the collated batch
            will contain a field 'src_lang_id' in 'net_input' which indicates the
            source language of the samples.
        tgt_lang_id (int, optional): target language ID, if set, the collated batch
            will contain a field 'tgt_lang_id' which indicates the target language
            of the samples.
    """

    def __init__(
        self, src, src_sizes, src_dict,
        tgt=None, tgt_sizes=None, tgt_dict=None,
        con=None, con_sizes=None,
        left_pad_source=True, left_pad_target=False,
        shuffle=True, input_feeding=True,
        remove_eos_from_source=False, append_eos_to_target=False,
        align_dataset=None,
        constraints=None,
        append_bos=False, eos=None,
        num_buckets=0,
        src_lang_id=None,
        tgt_lang_id=None,
    ):
        if tgt_dict is not None:
            # Source and target vocabularies must agree on the special symbols.
            assert src_dict.pad() == tgt_dict.pad()
            assert src_dict.eos() == tgt_dict.eos()
            assert src_dict.unk() == tgt_dict.unk()
        if tgt is not None:
            assert len(src) == len(tgt), "Source and target must contain the same number of examples"
        # assert len(src) == len(con)

        self.src = src
        self.tgt = tgt
        self.con = con
        self.src_sizes = np.array(src_sizes)
        self.con_sizes = np.array(con_sizes)
        self.tgt_sizes = np.array(tgt_sizes) if tgt_sizes is not None else None
        self.src_dict = src_dict
        self.tgt_dict = tgt_dict
        self.left_pad_source = left_pad_source
        self.left_pad_target = left_pad_target
        self.shuffle = shuffle
        self.input_feeding = input_feeding
        self.remove_eos_from_source = remove_eos_from_source
        self.append_eos_to_target = append_eos_to_target
        self.align_dataset = align_dataset
        if self.align_dataset is not None:
            assert self.tgt_sizes is not None, "Both source and target needed when alignments are provided"
        self.constraints = constraints
        self.append_bos = append_bos
        self.eos = (eos if eos is not None else src_dict.eos())
        self.src_lang_id = src_lang_id
        self.tgt_lang_id = tgt_lang_id
        if num_buckets > 0:
            from fairseq.data import BucketPadLengthDataset
            self.src = BucketPadLengthDataset(
                self.src,
                sizes=self.src_sizes,
                num_buckets=num_buckets,
                pad_idx=self.src_dict.pad(),
                left_pad=self.left_pad_source,
            )
            self.src_sizes = self.src.sizes
            logger.info('bucketing source lengths: {}'.format(list(self.src.buckets)))
            if self.tgt is not None:
                self.tgt = BucketPadLengthDataset(
                    self.tgt,
                    sizes=self.tgt_sizes,
                    num_buckets=num_buckets,
                    pad_idx=self.tgt_dict.pad(),
                    left_pad=self.left_pad_target,
                )
                self.tgt_sizes = self.tgt.sizes
                logger.info('bucketing target lengths: {}'.format(list(self.tgt.buckets)))

            # determine bucket sizes using self.num_tokens, which will return
            # the padded lengths (thanks to BucketPadLengthDataset)
            # NOTE: np.long was removed in NumPy 1.24; np.int64 is the
            # equivalent fixed-width dtype.
            num_tokens = np.vectorize(self.num_tokens, otypes=[np.int64])
            self.bucketed_num_tokens = num_tokens(np.arange(len(self.src)))
            self.buckets = [
                (None, num_tokens)
                for num_tokens in np.unique(self.bucketed_num_tokens)
            ]
        else:
            self.buckets = None

    def get_batch_shapes(self):
        """Return the fixed batch shapes when bucketing is enabled, else None."""
        return self.buckets

    def __getitem__(self, index):
        """Return a single example dict with 'id', 'source', 'context',
        'target' and optionally 'alignment'/'constraints'."""
        tgt_item = self.tgt[index] if self.tgt is not None else None
        src_item = self.src[index]
        con_item = self.con[index]
        # Append EOS to end of tgt sentence if it does not have an EOS and remove
        # EOS from end of src sentence if it exists. This is useful when we use
        # use existing datasets for opposite directions i.e., when we want to
        # use tgt_dataset as src_dataset and vice versa
        if self.append_eos_to_target:
            eos = self.tgt_dict.eos() if self.tgt_dict else self.src_dict.eos()
            if self.tgt and self.tgt[index][-1] != eos:
                tgt_item = torch.cat([self.tgt[index], torch.LongTensor([eos])])

        if self.append_bos:
            bos = self.tgt_dict.bos() if self.tgt_dict else self.src_dict.bos()
            if self.tgt and self.tgt[index][0] != bos:
                tgt_item = torch.cat([torch.LongTensor([bos]), self.tgt[index]])

            bos = self.src_dict.bos()
            if self.src[index][0] != bos:
                src_item = torch.cat([torch.LongTensor([bos]), self.src[index]])
                con_item = torch.cat([torch.LongTensor([bos]), self.con[index]])

        if self.remove_eos_from_source:
            eos = self.src_dict.eos()
            # Operate on src_item/con_item (not self.src[index]/self.con[index])
            # so that a BOS prepended above is not discarded when both
            # append_bos and remove_eos_from_source are set.
            if src_item[-1] == eos:
                src_item = src_item[:-1]
            if con_item[-1] == eos:
                con_item = con_item[:-1]

        example = {
            'id': index,
            'source': src_item,
            'context': con_item,
            'target': tgt_item,
        }
        if self.align_dataset is not None:
            example['alignment'] = self.align_dataset[index]
        if self.constraints is not None:
            example["constraints"] = self.constraints[index]
        return example

    def __len__(self):
        return len(self.src)

    def collater(self, samples, pad_to_length=None):
        """Merge a list of samples to form a mini-batch.

        Args:
            samples (List[dict]): samples to collate
            pad_to_length (dict, optional): a dictionary of
                {'source': source_pad_to_length, 'target': target_pad_to_length}
                to indicate the max length to pad to in source and target respectively.

        Returns:
            dict: a mini-batch with the following keys:

                - `id` (LongTensor): example IDs in the original input order
                - `ntokens` (int): total number of tokens in the batch
                - `net_input` (dict): the input to the Model, containing keys:

                  - `src_tokens` (LongTensor): a padded 2D Tensor of tokens in
                    the source sentence of shape `(bsz, src_len)`. Padding will
                    appear on the left if *left_pad_source* is ``True``.
                  - `src_lengths` (LongTensor): 1D Tensor of the unpadded
                    lengths of each source sentence of shape `(bsz)`
                  - `cxt_tokens` (LongTensor): a padded 2D Tensor of context
                    tokens, padded like the source.
                  - `prev_output_tokens` (LongTensor): a padded 2D Tensor of
                    tokens in the target sentence, shifted right by one
                    position for teacher forcing, of shape `(bsz, tgt_len)`.
                    This key will not be present if *input_feeding* is
                    ``False``. Padding will appear on the left if
                    *left_pad_target* is ``True``.
                  - `src_lang_id` (LongTensor): a long Tensor which contains source
                    language IDs of each sample in the batch

                - `target` (LongTensor): a padded 2D Tensor of tokens in the
                  target sentence of shape `(bsz, tgt_len)`. Padding will appear
                  on the left if *left_pad_target* is ``True``.
                - `tgt_lang_id` (LongTensor): a long Tensor which contains target language
                  IDs of each sample in the batch
        """
        res = collate(
            samples,
            pad_idx=self.src_dict.pad(),
            eos_idx=self.eos,
            left_pad_source=self.left_pad_source,
            left_pad_target=self.left_pad_target,
            input_feeding=self.input_feeding,
            pad_to_length=pad_to_length,
        )
        if self.src_lang_id is not None or self.tgt_lang_id is not None:
            src_tokens = res['net_input']['src_tokens']
            bsz = src_tokens.size(0)
            if self.src_lang_id is not None:
                res['net_input']['src_lang_id'] = torch.LongTensor(
                    [[self.src_lang_id]]
                ).expand(bsz, 1).to(src_tokens)
            if self.tgt_lang_id is not None:
                res['tgt_lang_id'] = torch.LongTensor(
                    [[self.tgt_lang_id]]
                ).expand(bsz, 1).to(src_tokens)
        return res

    def num_tokens(self, index):
        """Return the number of tokens in a sample. This value is used to
        enforce ``--max-tokens`` during batching."""
        return max(self.src_sizes[index], self.tgt_sizes[index] if self.tgt_sizes is not None else 0)

    def size(self, index):
        """Return an example's size as a float or tuple. This value is used when
        filtering a dataset with ``--max-positions``."""
        return (self.src_sizes[index], self.tgt_sizes[index] if self.tgt_sizes is not None else 0)

    def ordered_indices(self):
        """Return an ordered list of indices. Batches will be constructed based
        on this order."""
        if self.shuffle:
            indices = np.random.permutation(len(self)).astype(np.int64)
        else:
            indices = np.arange(len(self), dtype=np.int64)
        if self.buckets is None:
            # sort by target length, then source length
            if self.tgt_sizes is not None:
                indices = indices[
                    np.argsort(self.tgt_sizes[indices], kind='mergesort')
                ]
            return indices[np.argsort(self.src_sizes[indices], kind='mergesort')]
        else:
            # sort by bucketed_num_tokens, which is:
            # max(padded_src_len, padded_tgt_len)
            return indices[
                np.argsort(self.bucketed_num_tokens[indices], kind='mergesort')
            ]

    def filter_indices_by_size(self, indices, max_sizes):
        """Filter a list of sample indices. Remove those that are longer
        than specified in max_sizes.

        Args:
            indices (np.array): original array of sample indices
            max_sizes (int or list[int] or tuple[int]): max sample size,
                can be defined separately for src and tgt (then list or tuple)

        Returns:
            np.array: filtered sample array
            list: list of removed indices
        """
        if max_sizes is None:
            return indices, []
        if isinstance(max_sizes, (int, float)):
            # A single scalar applies to both source and target.
            max_src_size, max_tgt_size = max_sizes, max_sizes
        else:
            max_src_size, max_tgt_size = max_sizes
        if self.tgt_sizes is None:
            ignored = indices[self.src_sizes[indices] > max_src_size]
        else:
            ignored = indices[(self.src_sizes[indices] > max_src_size) |
                              (self.tgt_sizes[indices] > max_tgt_size)]
        if len(ignored) > 0:
            if self.tgt_sizes is None:
                indices = indices[self.src_sizes[indices] <= max_src_size]
            else:
                indices = indices[(self.src_sizes[indices] <= max_src_size) &
                                  (self.tgt_sizes[indices] <= max_tgt_size)]
        return indices, ignored.tolist()
|