# -*- coding: UTF-8 -*-
"""
-----------------------------------
@Author : Encore
@Date : 2022/9/5
-----------------------------------
"""
from tqdm import tqdm
import json

import torch
from torch.utils.data.dataset import Dataset
from torch.utils.data.dataloader import DataLoader
from transformers import DataCollatorForSeq2Seq
from transformers import MBartForConditionalGeneration, MBart50TokenizerFast

from config import args, logger


# dataset of tokenized parallel sentence pairs for translation fine-tuning
class TranslationDataset(Dataset):
    def __init__(self, tokenizer):
        super().__init__()
        self.input_ids = []
        self.attention_mask = []
        self.labels = []

        self.tokenizer = tokenizer
        # pads input_ids / attention_mask with the pad token and labels with -100 at batch time
        self.data_collator = DataCollatorForSeq2Seq(self.tokenizer)

    def read_file(self, src_path, tgt_path):
        src = []
        tgt = []
        with open(src_path, encoding='utf-8') as f:
            for line in f:
                line = line.strip()
                if not line:
                    continue
                src.append(line)

        with open(tgt_path, encoding='utf-8') as f:
            for line in f:
                line = line.strip()
                if not line:
                    continue
                tgt.append(line)

        # the two files are expected to be line-aligned parallel text
        assert len(src) == len(tgt), "source and target files must contain the same number of non-empty lines"

        over_size = 0
        for src_text, tgt_text in tqdm(zip(src, tgt), total=len(src), desc="processing data", leave=False):
            encode_text = self.tokenizer(src_text, text_target=tgt_text)

            # drop pairs longer than mBART's 1024-token position limit
            if len(encode_text['input_ids']) > 1024 or len(encode_text['labels']) > 1024:
                over_size += 1
                logger.warning(f"skipping over-length pair: {src_text}")
                continue

            self.input_ids.append(encode_text['input_ids'])
            self.attention_mask.append(encode_text['attention_mask'])
            self.labels.append(encode_text['labels'])

        if over_size:
            logger.info(f"skipped {over_size} over-length pairs")

    def __getitem__(self, idx):
        return {
            "input_ids": self.input_ids[idx],
            "attention_mask": self.attention_mask[idx],
            "labels": self.labels[idx]
        }

    def __len__(self):
        return len(self.input_ids)

    # def collate(self, features):
    #
    #     labels = [feature["labels"] for feature in features]
    #     # We have to pad the labels before calling `tokenizer.pad` as this method won't pad them and needs them of the
    #     # same length to return tensors.
    #     max_label_length = max(len(l) for l in labels)
    #
    #     for feature in features:
    #         remainder = [-100] * (max_label_length - len(feature["labels"]))
    #         feature["labels"] = (
    #             feature["labels"] + remainder
    #         )
    #
    #     features = self.tokenizer.pad(features, return_tensors="pt")
    #
    #     batch_input_ids = torch.cat([feature["input_ids"] for feature in features], dim=0)
    #     batch_attention_mask = torch.cat([feature["attention_mask"] for feature in features], dim=0)
    #     batch_labels = torch.cat([feature["labels"] for feature in features], dim=0)
    #
    #     return {
    #         "batch_input_ids": batch_input_ids,
    #         "batch_attention_mask": batch_attention_mask,
    #         "batch_labels": batch_labels,
    #     }
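    # Note: this hand-rolled collate is kept only for reference; DataCollatorForSeq2Seq, used as
    # collate_fn below, already pads input_ids/attention_mask with the tokenizer's pad token and
    # pads labels with -100 so padded positions are ignored by the loss.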


if __name__ == '__main__':
    tokenizer = MBart50TokenizerFast.from_pretrained(args.pretrain_dir, src_lang="zh_CN", tgt_lang="en_XX")
    train_data = TranslationDataset(tokenizer)
    train_data.read_file(args.src_path, args.tgt_path)
    train_loader = DataLoader(train_data, batch_size=2, collate_fn=train_data.data_collator)

    # print(len(train_data))
    # print(len(train_loader))

    # smoke test: inspect one collated batch
    for batch in train_loader:
        # print(batch["batch_input_ids"])
        print(batch)
        break
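
    # A minimal sketch (left commented out) of how these batches could drive fine-tuning with the
    # already-imported MBartForConditionalGeneration; illustration only, assuming args.pretrain_dir
    # points at an mBART-50 checkpoint.
    # model = MBartForConditionalGeneration.from_pretrained(args.pretrain_dir)
    # optimizer = torch.optim.AdamW(model.parameters(), lr=3e-5)
    # model.train()
    # for batch in train_loader:
    #     outputs = model(**batch)  # with labels in the batch, the model returns the LM loss
    #     outputs.loss.backward()
    #     optimizer.step()
    #     optimizer.zero_grad()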