import itertools
import logging
import os

from fairseq import utils
from fairseq.data import (
    AppendTokenDataset,
    ConcatDataset,
    PrependTokenDataset,
    StripTokenDataset,
    TruncateDataset,
    data_utils,
    indexed_dataset,
)
from fairseq.tasks import register_task
from fairseq.tasks.translation_from_pretrained_bart import TranslationFromPretrainedBARTTask

from .document_dataset import DocumentDataset

logger = logging.getLogger(__name__)


def load_document_dataset(
    data_path, split,
    src, src_dict,
    tgt, tgt_dict, con,
    combine, dataset_impl, upsample_primary,
    left_pad_source, left_pad_target, max_source_positions,
    max_target_positions, prepend_bos=False, load_alignments=False,
    truncate_source=False, append_source_id=False,
    num_buckets=0,
    shuffle=True,
):

    def split_exists(split, src, tgt, lang, data_path):
        filename = os.path.join(data_path, '{}.{}-{}.{}'.format(split, src, tgt, lang))
        return indexed_dataset.dataset_exists(filename, impl=dataset_impl)

    src_datasets = []
    tgt_datasets = []
    con_datasets = []

    # Load every shard of this split (split, split1, split2, ...) until one is missing.
    for k in itertools.count():
        split_k = split + (str(k) if k > 0 else '')

        # infer langcode
        if split_exists(split_k, src, tgt, src, data_path):
            prefix = os.path.join(data_path, '{}.{}-{}.'.format(split_k, src, tgt))
        elif split_exists(split_k, tgt, src, src, data_path):
            prefix = os.path.join(data_path, '{}.{}-{}.'.format(split_k, tgt, src))
        else:
            if k > 0:
                break
            else:
                raise FileNotFoundError('Dataset not found: {} ({})'.format(split, data_path))

        src_dataset = data_utils.load_indexed_dataset(prefix + src, src_dict, dataset_impl)
        if truncate_source:
            # Strip the trailing EOS, truncate to the position limit, then re-append EOS.
            src_dataset = AppendTokenDataset(
                TruncateDataset(
                    StripTokenDataset(src_dataset, src_dict.eos()),
                    max_source_positions - 1,
                ),
                src_dict.eos(),
            )
        src_datasets.append(src_dataset)

        # The context stream shares the source dictionary.
        con_dataset = data_utils.load_indexed_dataset(prefix + con, src_dict, dataset_impl)
        if con_dataset is not None:
            con_datasets.append(con_dataset)

        tgt_dataset = data_utils.load_indexed_dataset(prefix + tgt, tgt_dict, dataset_impl)
        if tgt_dataset is not None:
            tgt_datasets.append(tgt_dataset)

        logger.info('{} {} {}-{} {} examples'.format(
            data_path, split_k, src, tgt, len(src_datasets[-1])
        ))

        if not combine:
            break

    assert len(src_datasets) == len(tgt_datasets) or len(tgt_datasets) == 0
    # assert len(src_datasets) == len(con_datasets)

    if len(src_datasets) == 1:
        src_dataset = src_datasets[0]
        if len(con_datasets) == 1:
            con_dataset = con_datasets[0]
        tgt_dataset = tgt_datasets[0] if len(tgt_datasets) > 0 else None
    else:
        # Multiple shards: upsample the primary (first) shard and concatenate.
        sample_ratios = [1] * len(src_datasets)
        sample_ratios[0] = upsample_primary
        src_dataset = ConcatDataset(src_datasets, sample_ratios)
        if len(con_datasets) > 1:
            con_dataset = ConcatDataset(con_datasets, sample_ratios)
        if len(tgt_datasets) > 0:
            tgt_dataset = ConcatDataset(tgt_datasets, sample_ratios)
        else:
            tgt_dataset = None

    if prepend_bos:
        assert hasattr(src_dict, "bos_index") and hasattr(tgt_dict, "bos_index")
        src_dataset = PrependTokenDataset(src_dataset, src_dict.bos())
        if con_dataset is not None:
            con_dataset = PrependTokenDataset(con_dataset, src_dict.bos())
        if tgt_dataset is not None:
            tgt_dataset = PrependTokenDataset(tgt_dataset, tgt_dict.bos())

    eos = None
    if append_source_id:
        # Append the mBART-style '[lang]' language-ID token to each stream.
        src_dataset = AppendTokenDataset(src_dataset, src_dict.index('[{}]'.format(src)))
        if con_dataset is not None:
            con_dataset = AppendTokenDataset(con_dataset, src_dict.index('[{}]'.format(con)))
        if tgt_dataset is not None:
            tgt_dataset = AppendTokenDataset(tgt_dataset, tgt_dict.index('[{}]'.format(tgt)))
        eos = tgt_dict.index('[{}]'.format(tgt))

    align_dataset = None
    if load_alignments:
        align_path = os.path.join(data_path, '{}.align.{}-{}'.format(split, src, tgt))
        if indexed_dataset.dataset_exists(align_path, impl=dataset_impl):
            align_dataset = data_utils.load_indexed_dataset(align_path, None, dataset_impl)

    tgt_dataset_sizes = tgt_dataset.sizes if tgt_dataset is not None else None
    con_dataset_sizes = con_dataset.sizes if con_dataset is not None else None
    return DocumentDataset(
        src_dataset, src_dataset.sizes, src_dict,
        tgt_dataset, tgt_dataset_sizes, tgt_dict,
        con_dataset, con_dataset_sizes,
        left_pad_source=left_pad_source,
        left_pad_target=left_pad_target,
        align_dataset=align_dataset, eos=eos,
        num_buckets=num_buckets,
        shuffle=shuffle,
    )


@register_task('adapter_mbart_task')
class AdapterMbartTask(TranslationFromPretrainedBARTTask):

    @staticmethod
    def add_args(parser):
        TranslationFromPretrainedBARTTask.add_args(parser)
        parser.add_argument('-c', '--context', default=None, metavar='CONTEXT',
                            help='suffix of the context stream (also used as its [lang] tag)')

    def load_dataset(self, split, epoch=1, combine=False, **kwargs):
        paths = utils.split_paths(self.args.data)
        assert len(paths) > 0
        data_path = paths[(epoch - 1) % len(paths)]

        # infer langcode
        src, tgt, cxt = self.args.source_lang, self.args.target_lang, self.args.context

        self.datasets[split] = load_document_dataset(
            data_path, split, src, self.src_dict, tgt, self.tgt_dict, cxt,
            combine=combine, dataset_impl=self.args.dataset_impl,
            upsample_primary=self.args.upsample_primary,
            left_pad_source=self.args.left_pad_source,
            left_pad_target=self.args.left_pad_target,
            max_source_positions=getattr(self.args, 'max_source_positions', 1024),
            max_target_positions=getattr(self.args, 'max_target_positions', 1024),
            truncate_source=self.args.truncate_source,
            num_buckets=self.args.num_batch_buckets,
            shuffle=(split != 'test'),
            load_alignments=self.args.load_alignments,
            prepend_bos=getattr(self.args, 'prepend_bos', False),
            append_source_id=True,
        )
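

# Illustrative usage sketch (hypothetical paths and placeholder values, not part of
# the original source). With --context CON, load_document_dataset() looks for a third
# binarized stream per split, '{split}.{src}-{tgt}.{CON}', next to the source and
# target files, tokenized with the source dictionary; because append_source_id=True,
# '[{CON}]' must also exist in that dictionary as a language-ID token. A training
# invocation might then look roughly like:
#
#   fairseq-train data-bin \
#       --user-dir ./adapter_mbart --task adapter_mbart_task \
#       --source-lang en_XX --target-lang de_DE --context <context-suffix> \
#       --arch mbart_large --langs "$LANGS"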