- # Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
-
- """several datasets with preset arguments"""
- from .datasets import json_dataset, csv_dataset
- import os
- import json
- import random
- import tqdm
- from multiprocessing import Queue, Process
- from queue import Empty
- from collections import defaultdict
- from torch.utils import data
- from .lazy_loader import LazyLoader
- from utils import print_rank_0
-
- NUM_PROCESSES = 100
-
-
- def punctuation_standardization(string: str):
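- """Replace common Unicode punctuation (curly quotes, apostrophes, en dash) with ASCII equivalents."""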
- punctuation_dict = {"\u201c": "\"", "\u201d": "\"", "\u2019": "'", "\u2018": "'", "\u2013": "-"}
- for key, value in punctuation_dict.items():
- string = string.replace(key, value)
- return string
-
-
- class KeyDataset(data.Dataset):
- def __init__(self, text_loader, mask_loader, **kwargs):
- self.texts = text_loader
- self.masks = mask_loader
- self.is_lazy = False
- if isinstance(self.texts, LazyLoader) and isinstance(self.masks, LazyLoader):
- self.text_lens = self.texts.lens
- self.is_lazy = True
-
- def get_text_len(self, idx):
- return self.text_lens[idx]
-
- def __getitem__(self, index):
- text = self.texts[index]
- mask_length = self.masks[index]
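- # mask_length is a run-length encoding of the loss mask: even-indexed runs are masked out (loss 0), odd-indexed runs contribute to the loss (loss 1)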
- mask = []
- for i, length in enumerate(mask_length):
- if i % 2 == 0:
- mask += [0] * length
- else:
- mask += [1] * length
- assert len(text) == len(mask)
- return {"tokens": text, "loss_masks": mask}
-
- def __len__(self):
- return len(self.texts)
-
-
- class PromptDataset(data.Dataset):
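- """Dataset of (prompt, text) pairs: tokens are the concatenation of prompt and text, and the loss mask covers only the text part."""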
- def __init__(self, prompt_loader, text_loader, tokenizer=None, to_tokenize=False, **kwargs):
- self.prompts = prompt_loader
- self.texts = text_loader
- self.tokenizer = tokenizer
- self.to_tokenize = to_tokenize
- # default to non-lazy so the attribute always exists (mirrors KeyDataset); switched on below
- self.is_lazy = False
- if isinstance(self.prompts, LazyLoader) and isinstance(self.texts, LazyLoader):
- self.prompt_lens = self.prompts.lens
- self.text_lens = self.texts.lens
- self.is_lazy = True
-
- def get_text_len(self, idx):
- return self.prompt_lens[idx] + self.text_lens[idx]
-
- def __getitem__(self, index):
- prompt = self.prompts[index]
- text = self.texts[index]
- if self.to_tokenize:
- prompt = self.tokenizer.EncodeAsIds(prompt).tokenization
- text = self.tokenizer.EncodeAsIds(text).tokenization
- return {"tokens": prompt + text, "loss_masks": [0] * len(prompt) + [1] * len(text)}
-
- def __len__(self):
- return len(self.prompts)
-
-
- class DataReader:
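- """Base class for multiprocess corpus preprocessing: a reader process streams raw input lines into a task queue, NUM_PROCESSES worker processes tokenize them, and the main process drains the done queue into the writers."""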
- PATH = None
- assert_str = None
- reserve_punct = False
-
- def tokenize_worker(self, input, output, info, tokenizer, tokenize):
- raise NotImplementedError
-
- def print_info(self, info):
- pass
-
- def __init__(self, writers, tokenizer=None, tokenize=False, **kwargs):
- assert os.path.exists(self.PATH), self.assert_str
- print_rank_0(f"Creating dataset from {self.PATH}")
- self.tokenizer = tokenizer
- self.tokenize = tokenize
- self.writers = writers
-
- def process(self):
- if os.path.isdir(self.PATH):
- paths = [entry.path for entry in os.scandir(self.PATH) if
- not entry.is_dir() and not entry.name.endswith("bz2")]
- else:
- paths = [self.PATH]
- task_queue, done_queue, info_queue = Queue(maxsize=10000000), Queue(maxsize=10000000), Queue()
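- # task_queue carries raw input lines, done_queue carries tokenized results, info_queue carries per-worker statistics (used e.g. by Pile.print_info)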
- processes = []
- for i in range(NUM_PROCESSES):
- process = Process(target=self.tokenize_worker,
- args=(task_queue, done_queue, info_queue, self.tokenizer, self.tokenize))
- process.start()
- processes.append(process)
-
- def read_input_to_queue():
- for path in paths:
- print_rank_0(f"Start reading {path}")
- with open(path) as file:
- for row in file:
- task_queue.put(row)
- print_rank_0("Read input complete")
- for i in range(len(processes)):
- task_queue.put('STOP')
-
- process = Process(target=read_input_to_queue)
- process.start()
- count = len(processes)
- progress_bar = tqdm.tqdm()
- while True:
- data = done_queue.get()
- if data == 'COMPLETE':
- count -= 1
- if count == 0:
- break
- else:
- self.write_result(data, self.writers)
- progress_bar.update()
- progress_bar.close()
- self.print_info(info_queue)
-
- @staticmethod
- def write_result(data, writers):
- raise NotImplementedError
-
- @staticmethod
- def get_token_count(contents):
- return sum(map(len, contents))
-
- @classmethod
- def process_sample(cls, text, tokenizer, tokenize):
- if isinstance(text, str) and tokenize:
- if not cls.reserve_punct:
- text = punctuation_standardization(text)
- text = tokenizer.EncodeAsIds(text).tokenization if text else []
- return text
-
- @staticmethod
- def trim_field(content, max_length):
- if len(content) > max_length:
- content = content[:max_length]
- content += "......"
- return content
-
- def process_line(self, data, tokenizer, tokenize):
- raise NotImplementedError
-
-
- class PromptReader(DataReader):
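- """DataReader whose workers emit (prompt, text) pairs; subclasses implement process_line for a specific corpus, and write_result sends results to the 'prompt' and 'text' writers."""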
- is_json = True
-
- def tokenize_worker(self, input, output, info, tokenizer, tokenize):
- for row in iter(input.get, 'STOP'):
- row = row.rstrip()
- if row:
- if self.is_json:
- row = json.loads(row)
- prompts, texts = self.process_line(row, tokenizer, tokenize)
- for prompt, text in zip(prompts, texts):
- output.put((prompt, text))
- output.put("COMPLETE")
-
- @staticmethod
- def write_result(data, writers):
- prompt, text = data
- writers['prompt'].write(prompt)
- writers['text'].write(text)
-
-
- class KeyReader(DataReader):
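- """Reader for keyed wikipedia data: builds a "Summary: key1 key2 ..." sequence and an interleaved key/content sequence (with 'eop' appended to each content), each paired with a run-length loss mask in the format KeyDataset expands."""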
- PATH = '/root/data/wikipedia/wiki-key.txt'
- assert_str = "make sure to set PATH for wikipedia data_utils/corpora.py"
-
- def process_line(self, data, tokenizer, tokenize):
- keys, contents = data['key'], data["content"]
- assert len(keys) == len(contents)
- for i in range(1, len(keys)):
- keys[i] = " " + keys[i]
- contents = [" " + content for content in contents]
- keys = [tokenizer.EncodeAsIds(key).tokenization for key in keys]
- contents = [tokenizer.EncodeAsIds(content).tokenization for content in contents]
- summary = sum(keys, [])
- summary_prefix = self.process_sample("Summary: ", tokenizer, tokenize)
- summary_mask = [len(summary_prefix), len(summary)]
- summary = summary_prefix + summary
- text, text_mask = [], []
- for key, content in zip(keys, contents):
- content = content + [tokenizer.get_command('eop').Id]
- text += key
- text += content
- text_mask.append(len(key))
- text_mask.append(len(content))
- return (summary, summary_mask), (text, text_mask)
-
- def tokenize_worker(self, input, output, info, tokenizer, tokenize):
- for row in iter(input.get, 'STOP'):
- data = json.loads(row)
- summary, content = self.process_line(data, tokenizer, tokenize)
- output.put((summary, content))
- output.put("COMPLETE")
-
- @staticmethod
- def write_result(data, writers):
- summary, content = data
- writers['text'].write(summary[0])
- writers['mask'].write(summary[1])
- writers['text'].write(content[0])
- writers['mask'].write(content[1])
-
-
- class zhihu(PromptReader):
- PATH = "/root/data/zhihu/zhihu"
- reserve_punct = True
- assert_str = "make sure to set PATH for zhihu data_utils/corpora.py"
- qtitle_prefix = "问题:"
- qcontent_prefix = "问题描述:"
- user_prefix = "回答用户:"
- answer_prefix = " 回答:"
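- # prefix strings above are Chinese for "Question:", "Question description:", "Answering user:" and " Answer:"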
-
- # qtitle_prefix = []
- # qcontent_prefix = []
- # user_prefix = []
- # answer_prefix = []
-
- def process_line(self, data, tokenizer, tokenize):
- prompts, texts = [], []
- ans_length = len(data.get("ans-content", ""))
- ans_up = data.get("ans-up-num", "")
- ans_up = int(ans_up) if ans_up else 0
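- # keep only substantive answers: longer than 100 characters or with more than 1000 upvotes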
- if ans_length > 100 or ans_up > 1000:
- qtitle = data["q_title"]
- qcontent = data["q-content"]
- if qcontent is None:
- qcontent = ""
- qcontent = self.trim_field(qcontent, max_length=100)
- user = data.get("user-signature", "")
- prompt = self.qtitle_prefix + qtitle + self.qcontent_prefix + qcontent + self.user_prefix + user + self.answer_prefix
- text = data["ans-content"]
- prompt, text = self.process_sample(prompt, tokenizer, tokenize), self.process_sample(text, tokenizer, tokenize)
- prompts.append(prompt)
- texts.append(text)
- # prompt = data["q_title"] + data["q-content"] + data["user-signature"]
- # text = data["ans-content"]
- # prompts.append(prompt)
- # texts.append(text)
- return prompts, texts
-
-
- class zhidao(PromptReader):
- PATH = "/root/data/zhidao/zhidao"
- reserve_punct = True
- assert_str = "make sure to set PATH for zhidao data_utils/corpora.py"
- qtitle_prefix = "问题:"
- qcontent_prefix = "问题描述:"
- answer_prefix = "回答:"
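- # prefix strings above are Chinese for "Question:", "Question description:" and "Answer:"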
-
- def process_line(self, data, tokenizer, tokenize):
- if "title" not in data:
- return [], []
- prompts, texts = [], []
- qtitle = data["title"]
- qcontent = data.get("content", "")
- qcontent = self.trim_field(qcontent, max_length=100)
- prompt = self.qtitle_prefix + qtitle + self.qcontent_prefix + qcontent + self.answer_prefix
- prompt = self.process_sample(prompt, tokenizer, tokenize)
- if "best_answer" in data:
- text = data["best_answer"]["content"]
- if len(text) > 10:
- text = self.process_sample(text, tokenizer, tokenize)
- prompts.append(prompt)
- texts.append(text)
- for answer in data.get("other_answers", []):
- text = answer["content"]
- if len(text) > 100:
- text = self.process_sample(text, tokenizer, tokenize)
- prompts.append(prompt)
- texts.append(text)
- return prompts, texts
-
-
- class baike(PromptReader):
- PATH = "/root/data/baike/baike"
- reserve_punct = True
- assert_str = "make sure to set PATH for baike data_utils/corpora.py"
-
- def process_line(self, data, tokenizer, tokenize):
- prompts, texts = [], []
- text = data.get("title", "") + data.get("abstract", "") + data.get("content", "")
- if text:
- p, t = self.process_sample("", tokenizer, tokenize), self.process_sample(text, tokenizer, tokenize)
- prompts.append(p)
- texts.append(t)
- return prompts, texts
-
-
- class wikipedia(PromptReader):
- """
- dataset for wikipedia with arguments configured for convenience
-
- command line usage: `--train-data wikipedia`
- """
- # PATH = '/dataset/data/wiki.txt'
- PATH = '/root/data/bert_data/wiki.txt'
- assert_str = "make sure to set PATH for wikipedia data_utils/corpora.py"
-
- def process_line(self, data, tokenizer, tokenize):
- text = data['text']
- prompt, text = self.process_sample("", tokenizer, tokenize), self.process_sample(text, tokenizer, tokenize)
- return [prompt], [text]
-
-
- class TestDataset(PromptReader):
- PATH = '/root/data/test.json'
- assert_str = "make sure to set PATH for wikipedia data_utils/corpora.py"
-
- def process_line(self, data, tokenizer, tokenize):
- prompt, text = data['prompt'], data['text']
- prompt, text = self.process_sample(prompt, tokenizer, tokenize), self.process_sample(text, tokenizer, tokenize)
- return [prompt], [text]
-
-
- class OpenWebText(PromptReader):
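- """OpenWebText2 reader that keeps only documents longer than 100 characters which fastText identifies as English ('__label__en')."""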
- PATH = '/dataset/fd5061f6/english_data/openwebtext2'
- assert_str = "make sure to set PATH for openwebtext data_utils/corpora.py"
-
- def __init__(self, *args, **kwargs):
- import fasttext
- super().__init__(*args, **kwargs)
- self.model = fasttext.load_model('/dataset/fd5061f6/english_data/lid.176.bin')
- print_rank_0("Load language detection model")
-
- def process_line(self, data, tokenizer, tokenize):
- text = data['text']
- if len(text) > 100:
- lang = self.model.predict(text.replace('\n', ''))[0][0]
- if lang == '__label__en':
- prompt, text = self.process_sample("", tokenizer, tokenize), self.process_sample(text, tokenizer,
- tokenize)
- return [prompt], [text]
- return [], []
-
-
- class CCNews(PromptReader):
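- """CC-News reader that concatenates title, description and main text (skipping the description when the main text already starts with it) and keeps only documents longer than 100 characters."""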
- PATH = "/mnt/cc_news.json"
- assert_str = "make sure to set PATH for cc-news data_utils/corpora.py"
-
- def process_line(self, data, tokenizer, tokenize):
- text = ""
- title = data.get("title", None)
- description = data.get("description", None)
- maintext = data.get("maintext", None)
- if title:
- text += title.strip() + " "
- if description and (not maintext or not maintext.startswith(description)):
- text += description.strip() + " "
- if maintext:
- text += maintext
- if len(text) > 100:
- prompt, text = self.process_sample("", tokenizer, tokenize), self.process_sample(text, tokenizer, tokenize)
- return [prompt], [text]
- else:
- return [], []
-
-
- class BertData(PromptReader):
- is_json = False
- PATH = '/dataset/fd5061f6/english_data/wikibook'
-
- def process_line(self, data, tokenizer, tokenize):
- if data:
- prompt, text = "", data
- prompt, text = self.process_sample(prompt, tokenizer, tokenize), self.process_sample(text, tokenizer, tokenize)
- return [prompt], [text]
- else:
- return [], []
-
-
- class Pile(PromptReader):
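- """Reader for the Pile: drops documents whose source is in filtered_sources, keeps only a random fraction (the value in downsample_sources) of documents from downsampled sources, and reports per-source sample lengths through the info queue."""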
- is_json = True
- PATH = "/mnt/train"
- filtered_sources = ["Github", "StackExchange", "DM Mathematics", "Ubuntu IRC", "EuroParl", "YoutubeSubtitles",
- "Enron Emails"]
- downsample_sources = {"PubMed Central": 0.3, "ArXiv": 0.3, "FreeLaw": 0.3}
-
- def print_info(self, info):
- total_dict = defaultdict(int)
- while True:
- try:
- source_dict = info.get(block=False)
- for source, length in source_dict.items():
- total_dict[source] += length
- except Empty:
- break
- print_rank_0(total_dict)
-
- def tokenize_worker(self, input, output, info, tokenizer, tokenize):
- source_dict = defaultdict(int)
- for row in iter(input.get, 'STOP'):
- row = row.rstrip()
- if row:
- if self.is_json:
- row = json.loads(row)
- prompts, texts, source = self.process_line(row, tokenizer, tokenize)
- length = 0
- for prompt, text in zip(prompts, texts):
- length += len(text)
- output.put((prompt, text))
- if source:
- source_dict[source] += length
- output.put("COMPLETE")
- info.put(source_dict)
-
- def process_line(self, data, tokenizer, tokenize):
- source = data["meta"].get("pile_set_name", None)
- text = data.get("text", None)
- if source and text:
- if source in self.filtered_sources:
- return [], [], None
- elif source in self.downsample_sources and random.random() > self.downsample_sources[source]:
- return [], [], None
- else:
- prompt, text = self.process_sample("", tokenizer, tokenize), self.process_sample(text, tokenizer,
- tokenize)
- return [prompt], [text], source
- else:
- return [], [], None
-
-
- class Stories(PromptReader):
- is_json = True
- PATH = "/dataset/fd5061f6/english_data/stories_31G.jsonl"
-
- def process_line(self, data, tokenizer, tokenize):
- text = data.get("text", None)
- if text:
- prompt, text = self.process_sample("", tokenizer, tokenize), self.process_sample(text, tokenizer,
- tokenize)
- return [prompt], [text]
- else:
- return [], []
-
-
- class BertBaseData(BertData):
- PATH = '/root/data/formatted_one_article_per_line'
-
-
- class BertLargeData(BertData):
- PATH = '/dataset/c07bd62b/cognitive/zhengxiao/formatted_one_article_per_line_large'
-
-
- NAMED_CORPORA = {
- 'wikipedia': wikipedia,
- 'wikipedia-key': KeyReader,
- 'openwebtext': OpenWebText,
- 'zhihu': zhihu,
- 'zhidao': zhidao,
- 'baike': baike,
- 'test': TestDataset,
- 'wikibook': BertData,
- 'bert-base': BertBaseData,
- 'bert-large': BertLargeData,
- 'cc-news': CCNews,
- 'pile': Pile,
- 'stories': Stories
- }