- import json
-
- import pandas as pd
- import torch
- # from datasets import load_dataset  # shadowed by the local load_dataset below
- from torch.utils.data import Dataset
- from tqdm import tqdm
-
- def load_dataset(path, split):
-     """Load web_text_zh_{split}.json and build prompt/label pairs."""
-     discard = 0
-     datasets = []
-     with open(f"{path}/web_text_zh_{split}.json", encoding="utf-8") as f:
-         for i, line in tqdm(enumerate(f)):
-             item = json.loads(line)
-             # "模型回答:" ("model answer:") marks where the response should begin.
-             sample = {"prompt": item["title"] + item["desc"] + "模型回答:",
-                       "label": item["content"]}
-             # Discard samples whose combined prompt + label exceeds 500 characters.
-             if len(sample["prompt"] + sample["label"]) > 500:
-                 discard += 1
-             else:
-                 datasets.append(sample)
-             # Cap the validation split at the first 2000 lines.
-             if split == "valid" and i == 2000:
-                 print(f"File: {path}/web_text_zh_{split}.json Num of over-length: {discard}")
-                 return datasets
-     print(f"File: {path}/web_text_zh_{split}.json Num of over-length: {discard}")
-     return datasets
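-
- # Usage sketch (the data directory below is hypothetical; any folder holding
- # web_text_zh_train.json / web_text_zh_valid.json in the jsonl layout above works):
- #   train_data = load_dataset("data/webtext2019zh", split="train")
- #   print(train_data[0]["prompt"], train_data[0]["label"])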
-
- def get_dataset_from_jsonl(jsonl_file, return_summary=True):
-     # If return_summary is True, return a list of posts with the summary concatenated.
-     # If return_summary is False, return a list of posts and a list of summaries separately.
-     with open(jsonl_file, "r") as f:
-         dataset = [json.loads(line) for line in f]
-     post_list = []
-     summary_list = []
-     for d in dataset:
-         if return_summary:
-             post = f"SUBREDDIT: r/{d['subreddit']}\nTITLE: {d['title']}\nPOST: {d['post']}\nTL;DR: {d['summary']}"
-         else:
-             post = f"SUBREDDIT: r/{d['subreddit']}\nTITLE: {d['title']}\nPOST: {d['post']}\nTL;DR: "
-             summary_list.append(d["summary"])
-         post_list.append(post)
-     if not return_summary:
-         return post_list, summary_list
-     return post_list
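-
- # Usage sketch (file paths are illustrative): each jsonl record needs the
- # "subreddit", "title", "post" and "summary" fields used above.
- #   sft_texts = get_dataset_from_jsonl("data/tldr_train.jsonl")
- #   prompts, summaries = get_dataset_from_jsonl("data/tldr_valid.jsonl", return_summary=False)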
-
-
- class TLDRDataset(Dataset):
-     def __init__(self, train_path, tokenizer, split, max_length=550):
-         self.post_list = []
-         dataset = load_dataset(train_path, split=split)
-         for sample in dataset:
-             # Wrap each prompt/label pair with BOS and EOS tokens.
-             txt = tokenizer.bos_token + sample["prompt"] + sample["label"] + tokenizer.eos_token
-             self.post_list.append(txt)
-         if split == "valid":
-             self.post_list = self.post_list[0:2000]
-         self.tokenizer = tokenizer
-         self.max_length = max_length
-         self.input_ids = []
-         self.attn_masks = []
-
-         # Print a few samples as a quick sanity check.
-         for k in range(5):
-             print(f"TLDRDataset sample-{k}\n: {dataset[k]}")
-
-     def __len__(self):
-         return len(self.post_list)
-
-     def __getitem__(self, idx):
-         txt = self.post_list[idx]
-         encodings_dict = self.tokenizer(txt, truncation=True, max_length=self.max_length, padding="max_length")
-         input_ids = torch.tensor(encodings_dict["input_ids"])
-         attn_masks = torch.tensor(encodings_dict["attention_mask"])
-
-         return {
-             "input_ids": input_ids,
-             "attention_mask": attn_masks,
-             "labels": input_ids,
-         }
-
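- # Usage sketch (model name and data directory are illustrative, not fixed by this file).
- # GPT-2 style tokenizers ship without a pad token, so one must be assigned before
- # padding="max_length" in __getitem__ can work:
- #   from transformers import AutoTokenizer
- #   tokenizer = AutoTokenizer.from_pretrained("gpt2")
- #   tokenizer.pad_token = tokenizer.eos_token
- #   train_set = TLDRDataset("data/webtext2019zh", tokenizer, split="train", max_length=550)
- #   item = train_set[0]  # dict of input_ids / attention_mask / labels tensors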
-
- class ComparisonDataset(Dataset):
-     def __init__(self, comparison_path, tokenizer, max_length=550):
-         with open(comparison_path, "r") as f:
-             dataset = [json.loads(line) for line in f]
-
-         self.tokenizer = tokenizer
-         self.post_list = []
-         self.summaries_0 = []
-         self.summaries_1 = []
-         self.labels = []
-         self.max_length = max_length
-
-         def make_text(post, summary):
-             return f"SUBREDDIT: r/{post['subreddit']}\nTITLE: {post['title']}\nPOST: {post['post']}\nTL;DR: {summary}"
-
-         # NOTE: The chosen summary is always stored first, i.e. in `self.summaries_0`.
-         for sample in dataset:
-             self.post_list.append(sample["info"]["post"])
-             if sample["choice"] == 0:
-                 self.summaries_0.append(make_text(sample["info"], sample["summaries"][0]["text"]))
-                 self.summaries_1.append(make_text(sample["info"], sample["summaries"][1]["text"]))
-             else:
-                 self.summaries_0.append(make_text(sample["info"], sample["summaries"][1]["text"]))
-                 self.summaries_1.append(make_text(sample["info"], sample["summaries"][0]["text"]))
-             self.labels.append(0)
-
-     def __len__(self):
-         return len(self.post_list)
-
-     def __getitem__(self, idx):
-         summ0 = self.summaries_0[idx]
-         summ1 = self.summaries_1[idx]
-         encodings_dict = self.tokenizer(
-             [summ0, summ1],
-             truncation=True,
-             max_length=self.max_length,
-             padding="max_length",
-         )
-         input_ids = torch.tensor(encodings_dict["input_ids"])
-         attention_mask = torch.tensor(encodings_dict["attention_mask"])
-         return {"input_ids": input_ids, "attention_mask": attention_mask}
-
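- # Usage sketch (the comparisons file path is illustrative). Each item stacks the
- # preferred and the rejected summary, so "input_ids" has shape [2, max_length]
- # with the chosen summary in row 0, the layout a pairwise reward model expects.
- #   pairs = ComparisonDataset("data/comparisons_train.jsonl", tokenizer)
- #   chosen_ids, rejected_ids = pairs[0]["input_ids"]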
-
- class AllSummDataset(Dataset):
-     def __init__(self, train_path, tokenizer, split, max_length=1024):
-         df = pd.read_parquet(train_path)
-         if split == "valid":
-             df = df.sample(n=5000)
-         self.summarizes = []
-         for i, row in df.iterrows():
-             self.summarizes.append(f"Summarize: {row['text']}. TL;DR: {row['summary']}")
-         self.tokenizer = tokenizer
-         self.max_length = max_length
-         self.input_ids = []
-         self.attn_masks = []
-
-     def __len__(self):
-         return len(self.summarizes)
-
-     def __getitem__(self, idx):
-         txt = self.summarizes[idx]
-         encodings_dict = self.tokenizer(txt, truncation=True, max_length=self.max_length, padding="max_length")
-         input_ids = torch.tensor(encodings_dict["input_ids"])
-         attn_masks = torch.tensor(encodings_dict["attention_mask"])
-
-         return {
-             "input_ids": input_ids,
-             "attention_mask": attn_masks,
-             "labels": input_ids,
-         }
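-
- # Usage sketch (the parquet path is hypothetical; the tokenizer needs a pad token
- # as noted above). Wrapping the dataset in a DataLoader yields ready-to-train batches:
- #   from torch.utils.data import DataLoader
- #   summ_set = AllSummDataset("data/all_summaries.parquet", tokenizer, split="train")
- #   loader = DataLoader(summ_set, batch_size=8, shuffle=True)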