|
- # -*- coding: utf-8 -*-
- # file: data_utils.py
- # author: songyouwei <youwei0314@gmail.com>
- # Copyright (C) 2018. All Rights Reserved.
-
- import os
- import pickle
- import numpy as np
- import torch
- from torch.utils.data import Dataset
- # from transformers import BertTokenizer
- from pytorch_pretrained_bert import BertTokenizer
-
-
def build_tokenizer(fnames, max_seq_len, dat_fname):
    """Load a cached word-level Tokenizer, or build one and cache it.

    Each corpus file is read in groups of three lines. Line ``i`` is the raw
    text. Line ``i + 1`` holds an entity, optionally followed by one
    attribute. Line ``i + 2`` is not read here. The text, entity and
    attribute are lower-cased, joined with spaces and fed to
    ``Tokenizer.fit_on_text``.

    Args:
        fnames: paths of the UTF-8 corpus files to build the vocabulary from.
        max_seq_len: sequence length given to the new Tokenizer.
        dat_fname: pickle cache path; loaded if it exists, written otherwise.

    Returns:
        The loaded or freshly fitted tokenizer.
    """
    if os.path.exists(dat_fname):
        print('loading tokenizer:', dat_fname)
        # NOTE: unpickling can execute code; only point this at trusted caches.
        with open(dat_fname, 'rb') as f:
            return pickle.load(f)

    pieces = []
    for fname in fnames:
        with open(fname, 'r', encoding='utf-8', newline='\n', errors='ignore') as fin:
            lines = fin.readlines()
        # Stop one line early so a dangling text line with no entity line
        # (e.g. a trailing blank line) is skipped instead of raising IndexError.
        for i in range(0, len(lines) - 1, 3):
            text_raw = lines[i].lower().strip()
            target = lines[i + 1].lower().strip()
            fields = target.split()
            if len(fields) == 2:
                entity, attribute = fields
            else:
                # Anything other than exactly "<entity> <attribute>" is
                # treated as a single entity string with no attribute.
                entity, attribute = target, ''
            # Keep the parts separate so their words are not fused into
            # bogus vocabulary entries.
            pieces.extend((text_raw, entity, attribute))

    tokenizer = Tokenizer(max_seq_len)
    tokenizer.fit_on_text(' '.join(pieces))
    with open(dat_fname, 'wb') as f:
        pickle.dump(tokenizer, f)
    return tokenizer
-
-
def _load_word_vec(path, word2idx=None, embed_dim=300):
    """Read a whitespace-separated text embedding file.

    Each line is ``<word ...> v1 ... v{embed_dim}``. The last ``embed_dim``
    fields form the vector, and everything before them, re-joined with
    single spaces, is the word. Lines with at most ``embed_dim`` fields have
    no word part (e.g. a ``<count> <dim>`` header line) and are skipped.

    Args:
        path: UTF-8 text file to read.
        word2idx: vocabulary mapping; only words present in it are kept.
            If None, every word in the file is kept.
        embed_dim: number of trailing fields that form each vector.

    Returns:
        dict mapping word -> float32 numpy array.
    """
    word_vec = {}
    with open(path, 'r', encoding='utf-8', newline='\n', errors='ignore') as fin:
        for line in fin:
            tokens = line.rstrip().split()
            if len(tokens) <= embed_dim:
                continue
            word, vec = ' '.join(tokens[:-embed_dim]), tokens[-embed_dim:]
            if word2idx is None or word in word2idx:
                word_vec[word] = np.asarray(vec, dtype='float32')
    return word_vec
-
-
def build_embedding_matrix(word2idx, embed_dim, dat_fname, vec_fname=None):
    """Load a cached embedding matrix, or build one and cache it.

    Row ``i`` holds the pretrained vector of the word whose index in
    ``word2idx`` is ``i``. Row 0 and row ``len(word2idx) + 1`` (the id that
    ``Tokenizer.text_to_sequence`` assigns to unknown words) stay all-zero,
    as do rows of words missing from the vector file.

    Args:
        word2idx: vocabulary mapping word -> row index.
        embed_dim: embedding width; should match the vector file's dimension.
        dat_fname: pickle cache path; loaded if it exists, written otherwise.
        vec_fname: text vector file to read when building. Defaults to the
            project's hard-coded ``baike-50.vec.txt`` path.

    Returns:
        numpy array of shape ``(len(word2idx) + 2, embed_dim)``.
    """
    if os.path.exists(dat_fname):
        print('loading embedding_matrix:', dat_fname)
        # NOTE: unpickling can execute code; only point this at trusted caches.
        with open(dat_fname, 'rb') as f:
            return pickle.load(f)

    print('loading word vectors...')
    if vec_fname is None:
        # The previous code chose between two identical paths based on
        # embed_dim, so this single path is what was always used.
        # NOTE(review): the filename suggests 50-d vectors; with another
        # embed_dim presumably no vectors will match. Confirm.
        vec_fname = ('/home/luowangda/StanceQA/AnswerStance-master/data/nlp_res/'
                     'embeddings/baike/baike-50.vec.txt')

    embedding_matrix = np.zeros((len(word2idx) + 2, embed_dim))
    word_vec = _load_word_vec(vec_fname, word2idx=word2idx, embed_dim=embed_dim)
    print('building embedding_matrix:', dat_fname)
    for word, i in word2idx.items():
        vec = word_vec.get(word)
        if vec is not None:
            embedding_matrix[i] = vec
    with open(dat_fname, 'wb') as f:
        pickle.dump(embedding_matrix, f)
    return embedding_matrix
-
-
def pad_and_truncate(sequence, maxlen, dtype='int64', padding='post', truncating='post', value=0):
    """Fit *sequence* into a fixed-length 1-D numpy array.

    Args:
        sequence: list or 1-D array of values to copy.
        maxlen: length of the returned array.
        dtype: numpy dtype of the result.
        padding: 'post' puts fill values after the data; any other value
            puts them before it.
        truncating: 'pre' keeps the last ``maxlen`` items of an over-long
            sequence; any other value keeps the first ``maxlen``.
        value: fill value for padded positions.

    Returns:
        numpy array of shape ``(maxlen,)`` with dtype *dtype*.
    """
    x = np.full(maxlen, value, dtype=dtype)
    if truncating == 'pre':
        # sequence[-maxlen:] would keep everything when maxlen == 0.
        trunc = sequence[max(len(sequence) - maxlen, 0):]
    else:
        trunc = sequence[:maxlen]
    trunc = np.asarray(trunc, dtype=dtype)
    if padding == 'post':
        x[:len(trunc)] = trunc
    else:
        # x[-len(trunc):] would address the whole array when trunc is empty.
        x[maxlen - len(trunc):] = trunc
    return x
-
-
class Tokenizer(object):
    """Whitespace tokenizer backed by an incrementally built vocabulary.

    Word ids are handed out from 1 in first-seen order, so id 0 is never
    assigned to a word. At encoding time, words outside the vocabulary map
    to ``len(word2idx) + 1``.
    """

    def __init__(self, max_seq_len, lower=True):
        self.lower = lower
        self.max_seq_len = max_seq_len
        self.word2idx = {}
        self.idx2word = {}
        self.idx = 1  # next id to assign

    def fit_on_text(self, text):
        """Register each not-yet-seen whitespace-separated word of *text*."""
        normalized = text.lower() if self.lower else text
        for token in normalized.split():
            if token in self.word2idx:
                continue
            self.word2idx[token] = self.idx
            self.idx2word[self.idx] = token
            self.idx += 1

    def text_to_sequence(self, text, reverse=False, padding='post', truncating='post'):
        """Encode *text* as a fixed-length id array of ``max_seq_len`` items.

        Empty input becomes the single id 0 before padding. With *reverse*,
        the ids are reversed before padding and truncation.
        """
        normalized = text.lower() if self.lower else text
        oov_id = len(self.word2idx) + 1
        ids = [self.word2idx.get(token, oov_id) for token in normalized.split()]
        if not ids:
            ids = [0]
        if reverse:
            ids = list(reversed(ids))
        return pad_and_truncate(ids, self.max_seq_len, padding=padding, truncating=truncating)
-
-
class Tokenizer4Bert:
    """Adapt a pretrained ``BertTokenizer`` to the ``text_to_sequence`` API.

    The interface matches ``Tokenizer``, so dataset code can use either one.
    """

    def __init__(self, max_seq_len, pretrained_bert_name):
        self.tokenizer = BertTokenizer.from_pretrained(pretrained_bert_name)
        self.max_seq_len = max_seq_len

    def text_to_sequence(self, text, reverse=False, padding='post', truncating='post'):
        """Tokenize *text* with BERT and return a fixed-length id array.

        Empty tokenizer output becomes the single id 0 before padding. With
        *reverse*, the ids are reversed before padding and truncation.
        """
        bert = self.tokenizer
        pieces = bert.tokenize(text)
        ids = bert.convert_tokens_to_ids(pieces)
        if not len(ids):
            ids = [0]
        ids = ids[::-1] if reverse else ids
        return pad_and_truncate(ids, self.max_seq_len, padding=padding, truncating=truncating)
-
- import pandas as pd
- import jieba
-
class QADataset(Dataset):
    """Question/answer stance dataset built from a CSV file.

    The CSV must have ``ques``, ``answ`` and ``label`` columns, with labels
    in {-1, 0, 1}. It may also have a ``con_label`` column. A missing column
    or an empty ``con_label`` cell is treated as 0, and empty ``ques``/``answ``
    cells become ''. Every row is encoded eagerly in ``__init__`` into a
    dict of index arrays plus the labels and the truncated raw texts.
    """

    def __init__(self, fname, tokenizer):
        # Raw stance labels -1 / 0 / 1 are remapped to class ids 0 / 1 / 2.
        stance_dic = {-1: 0, 0: 1, 1: 2}

        data_pd = pd.read_csv(fname)
        # Column presence is a property of the whole file, so check it once.
        has_con_label = 'con_label' in data_pd.columns

        all_data = []
        for _, item in data_pd.iterrows():  # one question/answer pair per row
            # Hard-coded character budgets for question and answer text.
            ques = self._to_text(item['ques'])[:25]
            answ = self._to_text(item['answ'])[:45]

            label = stance_dic[item['label']]
            con_label = item['con_label'] if has_con_label else 0
            # Empty cells arrive as NaN, which int() rejects; treat them as 0.
            con_label = 0 if pd.isna(con_label) else int(con_label)

            data = self._encode_pair(ques, answ, tokenizer)
            data.update({
                'label': label,
                'con_label': con_label,
                'ques': ques,
                'answ': answ,
            })
            all_data.append(data)
        self.data = all_data

    @staticmethod
    def _to_text(value):
        """Return a CSV cell as str, mapping empty (NaN) cells to ''."""
        return '' if pd.isna(value) else str(value)

    @staticmethod
    def _encode_pair(ques, answ, tokenizer):
        """Encode one question/answer pair into the model-input index arrays.

        Returns the feature dict without labels or raw texts.
        """
        ques_indices = tokenizer.text_to_sequence(ques)
        answ_indices = tokenizer.text_to_sequence(answ)

        # jieba word-segmented, space-joined variants (Tokenizer splits on
        # whitespace, so this is what gives it word boundaries).
        text_indices = tokenizer.text_to_sequence(" ".join(jieba.cut(ques + " " + answ)))
        q_indices = tokenizer.text_to_sequence(" ".join(jieba.cut(ques)))
        a_indices = tokenizer.text_to_sequence(" ".join(jieba.cut(answ)))

        # BERT-style inputs with explicit [CLS] / [SEP] markers.
        concat_bert_indices = tokenizer.text_to_sequence('[CLS] ' + ques + ' [SEP] ' + answ + " [SEP]")
        ques_bert_indices = tokenizer.text_to_sequence("[CLS] " + ques + " [SEP]")
        answ_bert_indices = tokenizer.text_to_sequence("[CLS] " + answ + " [SEP]")

        # Non-zero id counts of the unmarked encodings size the segment ids
        # and masks below (assumes id 0 only appears as padding).
        qa_len = np.sum(ques_indices != 0)
        an_len = np.sum(answ_indices != 0)
        # Segment ids: 0 over "[CLS] ques [SEP]", 1 over "answ [SEP]".
        concat_segments_indices = [0] * (qa_len + 2) + [1] * (an_len + 1)
        # Masks cover the "[CLS] text [SEP]" span of each single-text input.
        ques_mask_indices = [1] * (qa_len + 2)
        answ_mask_indices = [1] * (an_len + 2)

        max_len = tokenizer.max_seq_len
        return {
            'text_indices': text_indices,
            # NOTE: misspelled key kept as-is; consumers may look it up by this name.
            'q_indiccs': q_indices,
            'a_indices': a_indices,
            'concat_bert_indices': concat_bert_indices,
            'concat_segments_indices': pad_and_truncate(concat_segments_indices, max_len),
            'ques_bert_indices': ques_bert_indices,
            'answ_bert_indices': answ_bert_indices,
            'ques_mask_indices': pad_and_truncate(ques_mask_indices, max_len),
            'answ_mask_indices': pad_and_truncate(answ_mask_indices, max_len),
        }

    def __getitem__(self, index):
        return self.data[index]

    def __len__(self):
        return len(self.data)
|