|
- # Copyright (c) Facebook, Inc. and its affiliates.
- #
- # This source code is licensed under the MIT license found in the
- # LICENSE file in the root directory of this source tree.
-
- import os
- from collections import Counter
-
- import torch
- from fairseq.file_io import PathManager
- from fairseq.tokenizer import tokenize_line
- from typing import List, Dict
-
-
- def safe_readline(f):
- pos = f.tell()
- while True:
- try:
- return f.readline()
- except UnicodeDecodeError:
- pos -= 1
- f.seek(pos) # search where this character begins
-
-
- class Binarizer:
- @staticmethod
- def binarize(
- filename,
- dict,
- consumer,
- tokenize=tokenize_line,
- append_eos=True,
- reverse_order=False,
- offset=0,
- end=-1,
- already_numberized=False,
- ) -> Dict[str, int]:
- nseq, ntok = 0, 0
- replaced = Counter()
-
- def replaced_consumer(word, idx):
- if idx == dict.unk_index and word != dict.unk_word:
- replaced.update([word])
-
- with open(PathManager.get_local_path(filename), "r", encoding="utf-8") as f:
- f.seek(offset)
- # next(f) breaks f.tell(), hence readline() must be used
- line = safe_readline(f)
- while line:
- # f.tell() does not always give the byte position in the file
- # sometimes it skips to a very large number
- # it is unlikely that through a normal read we go from
- # end bytes to end + 2**32 bytes (4 GB) and this makes it unlikely
- # that the procedure breaks by the undeterministic behavior of
- # f.tell()
- if end > 0 and f.tell() > end and f.tell() < end + 2 ** 32:
- break
- if already_numberized:
- id_strings = line.strip().split()
- id_list = [int(id_string) for id_string in id_strings]
- if reverse_order:
- id_list.reverse()
- if append_eos:
- id_list.append(dict.eos())
- ids = torch.IntTensor(id_list)
- else:
- ids = dict.encode_line(
- line=line,
- line_tokenizer=tokenize,
- add_if_not_exist=False,
- consumer=replaced_consumer,
- append_eos=append_eos,
- reverse_order=reverse_order,
- )
- nseq += 1
- ntok += len(ids)
- consumer(ids)
- line = f.readline()
- return {
- "nseq": nseq,
- "nunk": sum(replaced.values()),
- "ntok": ntok,
- "replaced": replaced,
- }
-
- @staticmethod
- def binarize_alignments(
- filename, alignment_parser, consumer, offset=0, end=-1
- ) -> Dict[str, int]:
- nseq = 0
-
- with open(PathManager.get_local_path(filename), "r") as f:
- f.seek(offset)
- line = safe_readline(f)
- while line:
- if end > 0 and f.tell() > end:
- break
- ids = alignment_parser(line)
- nseq += 1
- consumer(ids)
- line = f.readline()
- return {"nseq": nseq}
-
- @staticmethod
- def find_offsets(filename, num_chunks) -> List[int]:
- with open(PathManager.get_local_path(filename), "r", encoding="utf-8") as f:
- size = os.fstat(f.fileno()).st_size
- chunk_size = size // num_chunks
- offsets = [0 for _ in range(num_chunks + 1)]
- for i in range(1, num_chunks):
- f.seek(chunk_size * i)
- safe_readline(f)
- offsets[i] = f.tell()
- return offsets
|