|
- # -*- coding: UTF-8 -*-
- """
- -----------------------------------
- @Author : Encore
- @Date : 2022/9/5
- -----------------------------------
- """
- from tqdm import tqdm
- import json
-
- import torch
- from transformers import BertTokenizerFast, RobertaTokenizerFast
- from torch.utils.data.dataset import Dataset
- from torch.utils.data.dataloader import DataLoader
-
- from config import args, logger
-
-
class TextClassificationDataset(Dataset):
    """Binary text-classification dataset backed by a tab-separated file.

    Rows of ``text<TAB>label`` are tokenized eagerly in :meth:`read_file`
    and cached as padded tensors; :meth:`collate` stacks the per-example
    tensors into batch tensors for a ``DataLoader``.
    """

    def __init__(self, pretrain_dir):
        """
        :param pretrain_dir: directory containing the pretrained
            BERT tokenizer files (vocab etc.).
        """
        super().__init__()
        self.input_ids = []       # per example: (1, max_length) int64 tensor
        self.line_ids = []        # per example: original row index, for tracing
        self.labels = []          # per example: (1,) int64 label-id tensor
        self.attention_mask = []  # per example: (1, max_length) mask tensor

        self.tokenizer = BertTokenizerFast.from_pretrained(pretrain_dir)
        # self.tokenizer = RobertaTokenizerFast.from_pretrained("/userhome/pretrain/roberta-base")
        self.label2id = {}
        self.id2label = {}

    def build_label_index(self):
        """Build the label <-> id maps for the two-class task.

        ``id2label`` is constructed as the exact inverse of ``label2id``
        (id -> label string). The previous code mapped ids onto the ints
        0/1, so the two maps were not inverses of each other.
        """
        self.label2id = {'0': 0, '1': 1}
        self.id2label = {idx: label for label, idx in self.label2id.items()}

    def read_file(self, data_path, max_length):
        """Load, tokenize, and cache every example from ``data_path``.

        The file is expected to have a header row (skipped) followed by
        ``text<TAB>label`` rows. Rows without a tab separator or with an
        unknown label are counted and skipped rather than crashing
        (``line.split('\\t')`` previously raised ``ValueError`` on them).

        :param data_path: path to the UTF-8 TSV data file.
        :param max_length: tokenizer truncation / padding length.
        """
        data = []
        with open(data_path, encoding='utf-8') as f:
            for line in f:  # stream; no need to materialize readlines()
                line = line.strip()
                if not line:
                    continue
                data.append(line)
        if data:
            data.pop(0)  # drop the header row (guard: empty file would raise)

        error_line_nums = 0
        for line in tqdm(data, desc="processing data", leave=False):
            # Split from the right at most once: the label is the last
            # field, so tabs inside the text no longer break parsing.
            parts = line.rsplit('\t', 1)
            if len(parts) != 2 or parts[1] not in self.label2id:
                error_line_nums += 1
                continue
            text, label = parts

            encode_text = self.tokenizer(text,
                                         max_length=max_length,
                                         padding='max_length',
                                         truncation=True,
                                         return_tensors='pt')

            self.input_ids.append(encode_text['input_ids'])
            self.attention_mask.append(encode_text['attention_mask'])

            self.labels.append(torch.tensor([self.label2id[label]], dtype=torch.int64))

        # line_ids index the *kept* examples, 0..N-1
        self.line_ids = list(range(len(self.input_ids)))
        logger.info(f"共{error_line_nums}条数据标签错误")

    def __getitem__(self, idx):
        """Return one example as a dict of tensors (plus its line id)."""
        return {
            "input_ids": self.input_ids[idx],
            "line_ids": self.line_ids[idx],
            "attention_mask": self.attention_mask[idx],
            "labels": self.labels[idx],
        }

    @property
    def label_nums(self):
        """Number of distinct labels (size of the classification head)."""
        return len(self.label2id)

    def __len__(self):
        return len(self.input_ids)

    @staticmethod
    def collate(batch_data):
        """Stack a list of :meth:`__getitem__` dicts into batch tensors.

        Tensors are concatenated along dim 0 (each stored tensor already
        has a leading batch dim of 1); line ids stay a plain Python list.
        """
        batch_input_ids = torch.cat([data["input_ids"] for data in batch_data], dim=0)
        batch_attention_mask = torch.cat([data["attention_mask"] for data in batch_data], dim=0)
        batch_line_ids = [data["line_ids"] for data in batch_data]
        batch_labels = torch.cat([data["labels"] for data in batch_data], dim=0)

        return {
            "batch_input_ids": batch_input_ids,
            "batch_line_ids": batch_line_ids,
            "batch_labels": batch_labels,
            "batch_attention_mask": batch_attention_mask,
        }
-
-
if __name__ == '__main__':
    # Smoke-test: build the dataset from the configured paths and
    # print the contents of a single shuffled batch.
    dataset = TextClassificationDataset(args.pretrain_dir)
    dataset.build_label_index()
    dataset.read_file(args.train_path, max_length=30)
    loader = DataLoader(dataset, batch_size=3, collate_fn=dataset.collate, shuffle=True)

    for first_batch in loader:
        print(first_batch["batch_input_ids"])
        print(first_batch["batch_line_ids"])
        break
|