|
- from six import iteritems
-
- import json
- import os
- import multiprocessing
- import numpy as np
- import random
- import math
-
-
class FileDataLoader:
    """Abstract base for iterable data loaders.

    Subclasses must provide __next__ (return the next item) and
    next_batch (return one batch of the given size).
    """

    def __next__(self):
        raise NotImplementedError

    def next(self):
        # Python-2-style spelling; delegates to the iterator protocol so
        # subclass overrides of __next__ are picked up dynamically.
        return self.__next__()

    def next_batch(self, batch_size):
        raise NotImplementedError
-
-
class JsonFileDataLoader(FileDataLoader):
    """Relation-extraction data loader backed by a JSON dataset and six word
    embedding vocabularies (one GloVe table plus five biomedical word2vec
    tables).  Batches are served per instance or per bag, selected by mode.
    """
    MODE_INSTANCE = 0  # One batch contains batch_size instances.
    MODE_ENTPAIR_BAG = 1  # One batch contains batch_size bags, instances in which have the same entity pair (usually for testing).
    MODE_RELFACT_BAG = 2  # One batch contains batch size bags, instances in which have the same relation fact. (usually for training).
-
- def _load_preprocessed_file(self):
- name_prefix = '.'.join(self.file_name.split('/')[-1].split('.')[:-1])
- word_vec_name_prefix = '.'.join(self.word_vec_file_name.split('/')[-1].split('.')[:-1])
- processed_data_dir = 'medicine_rl_pacnn_train_data'
- if not os.path.isdir(processed_data_dir):
- return False
- word_npy_file_name = os.path.join(processed_data_dir, name_prefix + '_word.npy')
- pmc_word_npy_file_name = os.path.join(processed_data_dir, name_prefix + '_pmc_word.npy')
- pubmed_word_npy_file_name = os.path.join(processed_data_dir, name_prefix + '_pubmed_word.npy')
- pubmed_and_pmc_word_npy_file_name = os.path.join(processed_data_dir, name_prefix + '_pubmed_and_pmc_word.npy')
- pubmed_myself_word_npy_file_name = os.path.join(processed_data_dir, name_prefix + '_pubmed_myself_word.npy')
- wiki_pubmed_word_npy_file_name = os.path.join(processed_data_dir, name_prefix + '_wiki_pubmed_word.npy')
-
- pos1_npy_file_name = os.path.join(processed_data_dir, name_prefix + '_pos1.npy')
- pos2_npy_file_name = os.path.join(processed_data_dir, name_prefix + '_pos2.npy')
- rel_npy_file_name = os.path.join(processed_data_dir, name_prefix + '_rel.npy')
- mask_npy_file_name = os.path.join(processed_data_dir, name_prefix + '_mask.npy')
- length_npy_file_name = os.path.join(processed_data_dir, name_prefix + '_length.npy')
- entpair2scope_file_name = os.path.join(processed_data_dir, name_prefix + '_entpair2scope.json')
- relfact2scope_file_name = os.path.join(processed_data_dir, name_prefix + '_relfact2scope.json')
-
- pmc_word_vec_mat_file_name = os.path.join(processed_data_dir, word_vec_name_prefix + '_pmc_mat.npy')
- pubmed_word_vec_mat_file_name = os.path.join(processed_data_dir, word_vec_name_prefix + '_pubmed_mat.npy')
- pubmed_and_pmc_word_vec_mat_file_name = os.path.join(processed_data_dir, word_vec_name_prefix + '_pubmed_and_pmc_mat.npy')
- pubmed_myself_word_vec_mat_file_name = os.path.join(processed_data_dir, word_vec_name_prefix + '_pubmed_myself_mat.npy')
- wiki_pubmed_word_vec_mat_file_name = os.path.join(processed_data_dir, word_vec_name_prefix + '_wiki_pubmed_mat.npy')
- word_vec_mat_file_name = os.path.join(processed_data_dir, word_vec_name_prefix + '_mat.npy')
-
- pmc_word2id_file_name = os.path.join(processed_data_dir, word_vec_name_prefix + '_pmc_word2id.json')
- sentences_file_name = os.path.join(processed_data_dir, name_prefix + '_sentences.json')
- pubmed_word2id_file_name = os.path.join(processed_data_dir, word_vec_name_prefix + '_pubmed_word2id.json')
- pubmed_myself_word2id_file_name = os.path.join(processed_data_dir, word_vec_name_prefix + '_pubmed_myself_word2id.json')
- pubmed_and_pmc_word2id_file_name = os.path.join(processed_data_dir, word_vec_name_prefix + '_pubmed_and_pmc_word2id.json')
- wiki_pubmed_word2id_file_name = os.path.join(processed_data_dir, word_vec_name_prefix + '_wiki_pubmed_word2id.json')
- word2id_file_name = os.path.join(processed_data_dir, word_vec_name_prefix + '_word2id.json')
-
- if not os.path.exists(word_npy_file_name) or \
- not os.path.exists(pmc_word_npy_file_name) or \
- not os.path.exists(pubmed_word_npy_file_name) or \
- not os.path.exists(pubmed_and_pmc_word_npy_file_name) or \
- not os.path.exists(pubmed_myself_word_npy_file_name) or \
- not os.path.exists(wiki_pubmed_word_npy_file_name) or \
- not os.path.exists(pos1_npy_file_name) or \
- not os.path.exists(pos2_npy_file_name) or \
- not os.path.exists(rel_npy_file_name) or \
- not os.path.exists(mask_npy_file_name) or \
- not os.path.exists(length_npy_file_name) or \
- not os.path.exists(entpair2scope_file_name) or \
- not os.path.exists(relfact2scope_file_name) or \
- not os.path.exists(word_vec_mat_file_name) or \
- not os.path.exists(word2id_file_name) or \
- not os.path.exists(sentences_file_name):
- return False
- print("Pre-processed files exist. Loading them...")
- self.data_word = np.load(word_npy_file_name)
- self.sentences = json.load(open(sentences_file_name))
-
- self.data_word_pmc = np.load(pmc_word_npy_file_name)
- self.data_word_pubmed = np.load(pubmed_word_npy_file_name)
- self.data_word_pubmed_and_pmc = np.load(pubmed_and_pmc_word_npy_file_name)
- self.data_word_pubmed_myself = np.load(pubmed_myself_word_npy_file_name)
- self.data_word_wiki_pubmed = np.load(wiki_pubmed_word_npy_file_name)
-
- self.data_pos1 = np.load(pos1_npy_file_name)
- self.data_pos2 = np.load(pos2_npy_file_name)
- self.data_rel = np.load(rel_npy_file_name)
- self.data_mask = np.load(mask_npy_file_name)
- self.data_length = np.load(length_npy_file_name)
- self.entpair2scope = json.load(open(entpair2scope_file_name))
- self.relfact2scope = json.load(open(relfact2scope_file_name))
-
- self.word_vec_mat = np.load(word_vec_mat_file_name)
- self.word2id = json.load(open(word2id_file_name))
-
- self.word_vec_mat_pmc = np.load(pmc_word_vec_mat_file_name)
- self.word2id_pmc = json.load(open(pmc_word2id_file_name))
-
- self.word_vec_mat_pubmed = np.load(pubmed_word_vec_mat_file_name)
- self.word2id_pubmed = json.load(open(pubmed_word2id_file_name))
-
- self.word_vec_mat_pubmed_and_pmc = np.load(pubmed_and_pmc_word_vec_mat_file_name)
- self.word2id_pubmed_and_pmc = json.load(open(pubmed_and_pmc_word2id_file_name))
-
- self.word_vec_mat_pubmed_myself = np.load(pubmed_myself_word_vec_mat_file_name)
- self.word2id_pubmed_myself = json.load(open(pubmed_myself_word2id_file_name))
-
- self.word_vec_mat_wiki_pubmed = np.load(wiki_pubmed_word_vec_mat_file_name)
- self.word2id_wiki_pubmed = json.load(open(wiki_pubmed_word2id_file_name))
-
- if self.data_word.shape[1] != self.max_length \
- or self.data_word_pmc.shape[1] != self.max_length \
- or self.data_word_pubmed.shape[1] != self.max_length \
- or self.data_word_pubmed_and_pmc.shape[1] != self.max_length \
- or self.data_word_pubmed_myself.shape[1] != self.max_length \
- or self.data_word_wiki_pubmed.shape[1] != self.max_length:
- print("Pre-processed files don't match current settings. Reprocessing...")
- return False
- print("Finish loading")
- return True
-
- def get_glove_w2v(self, case_sensitive):
- """
- load glove embedding
- :return:
- """
-
- ori_word_vec = json.load(open(self.word_vec_file_name, "r"))
-
- # Pre-process word vec
- word2id = {}
- word_vec_tot = len(ori_word_vec)
-
- word_vec_dim = len(ori_word_vec[0]['vec'])
-
- UNK = word_vec_tot
- BLANK = word_vec_tot + 1
-
- # 400000, 50
- print("Got {} words of {} dims".format(word_vec_tot, word_vec_dim))
- print("Building word vector matrix and mapping...")
-
- # shape: (400000, 50), 全是0
- word_vec_mat = np.zeros((word_vec_tot, word_vec_dim), dtype=np.float32)
-
- for cur_id, word in enumerate(ori_word_vec):
-
- w = word['word']
- if not case_sensitive:
- w = w.lower()
- word2id[w] = cur_id
- word_vec_mat[cur_id, :] = word['vec']
-
- word2id['UNK'] = UNK
- word2id['BLANK'] = BLANK
-
- return UNK, BLANK, word2id, word_vec_mat
-
- def get_other_mcnn_w2v(self, case_sensitive, w2v_name):
- """
- load mcnn pretrain embedding
- :return:
- """
- ori_word_vec = dict()
- word2id = {}
-
- word_vec_dim = 200
- with open(w2v_name) as f:
- for i, line in enumerate(f):
- words = str(line).strip("\r\n").split(" ")
- w = words[0]
-
- if not case_sensitive:
- w = w.lower()
-
- ori_word_vec[w] = [float(words[i]) for i in range(1, len(words))]
- word2id[w] = i
- if i == 0:
- word_vec_dim = len(ori_word_vec[w])
-
- # Pre-process word vec
-
- word_vec_tot = len(ori_word_vec)
-
- UNK = word_vec_tot
- BLANK = word_vec_tot + 1
- drug1 = word_vec_tot + 2
- drug2 = word_vec_tot + 3
- drug0 = word_vec_tot + 4
-
- word_vec_mat = np.zeros((word_vec_tot, word_vec_dim), dtype=np.float32)
- for k, v in word2id.items():
- print(k, v)
- word_vec_mat[v, :] = ori_word_vec[k]
-
- word2id['UNK'] = UNK
- word2id['BLANK'] = BLANK
- word2id["drug1"] = drug1
- word2id["drug2"] = drug2
- word2id["drug0"] = drug0
-
- return UNK, BLANK, word2id, word_vec_mat
-
    def __init__(self,
                 file_name,
                 word_vec_file_name,
                 word_vec_pmc_name,
                 word_vec_pubmed_name,
                 word_vec_pubmed_and_pmc_name,
                 word_vec_pubmed_myself_name,
                 word_vec_wiki_pubmed_name,
                 rel2id_file_name,
                 mode,
                 shuffle=True,
                 max_length=150,
                 case_sensitive=True,
                 reprocess=False,
                 batch_size=30):
        """Build the loader.

        Either reloads cached pre-processed arrays (see
        _load_preprocessed_file) or tokenizes the raw JSON dataset, maps each
        token into six embedding vocabularies, and caches everything to disk.

        Args:
            file_name: path to the JSON dataset; each instance is a dict with
                'head', 'tail', 'relation', 'sentence' and 'replace_sentence'
                keys ('replace_sentence' tokens are joined by '@@' —
                presumably produced upstream; TODO confirm).
            word_vec_file_name: GloVe-style JSON embedding file.
            word_vec_pmc_name .. word_vec_wiki_pubmed_name: word2vec-style
                text embedding files for the respective corpora.
            rel2id_file_name: JSON mapping relation name -> id; must contain
                the 'NA' relation.
            mode: MODE_INSTANCE, MODE_ENTPAIR_BAG or MODE_RELFACT_BAG.
            shuffle: shuffle iteration order at start and after each epoch.
            max_length: pad/truncate every sentence to this many tokens.
            case_sensitive: if False, lower-case sentences and entity words.
            reprocess: force re-tokenization even if cached files exist.
            batch_size: default batch size used by __next__.
        """
        self.file_name = file_name
        self.word_vec_file_name = word_vec_file_name
        self.word_vec_pmc_name = word_vec_pmc_name
        self.word_vec_pubmed_name = word_vec_pubmed_name
        self.word_vec_pubmed_and_pmc_name = word_vec_pubmed_and_pmc_name
        self.word_vec_pubmed_myself_name = word_vec_pubmed_myself_name
        self.word_vec_wiki_pubmed_name = word_vec_wiki_pubmed_name

        self.case_sensitive = case_sensitive
        self.max_length = max_length
        self.mode = mode
        self.shuffle = shuffle
        self.batch_size = batch_size
        self.rel2id = json.load(open(rel2id_file_name))

        if reprocess or not self._load_preprocessed_file():  # Try to load pre-processed files:
            # Check files
            if file_name is None or not os.path.isfile(file_name):
                raise Exception("[ERROR] Data file doesn't exist")
            if word_vec_file_name is None or not os.path.isfile(word_vec_file_name):
                raise Exception("[ERROR] Word vector file doesn't exist")

            # Load files
            print("Loading data file...")
            self.ori_data = json.load(open(self.file_name, "r"))  # e.g. "./data/medline20191011t/train_treat.json"
            print("Finish loading")

            # Eliminate case sensitive
            if not case_sensitive:
                print("Elimiating case sensitive problem...")
                for i in range(len(self.ori_data)):
                    self.ori_data[i]['sentence'] = self.ori_data[i]['sentence'].lower()
                    self.ori_data[i]['head']['word'] = self.ori_data[i]['head']['word'].lower()
                    self.ori_data[i]['tail']['word'] = self.ori_data[i]['tail']['word'].lower()
                print("Finish eliminating")

            # Sort data by entities and relations so instances of one entity
            # pair / relation fact are contiguous and can be addressed by
            # [begin, end) scopes below.
            print("Sort data...")
            self.ori_data.sort(key=lambda a: a['head']['id'] + '#' + a['tail']['id'] + '#' + a['relation'])
            print("Finish sorting")

            # load w2v: each call returns (UNK id, BLANK id, word->id dict,
            # embedding matrix) for one vocabulary.
            UNK, BLANK, self.word2id, self.word_vec_mat = self.get_glove_w2v(case_sensitive)

            UNK_pmc, BLANK_pmc, self.word2id_pmc, self.word_vec_mat_pmc = self.get_other_mcnn_w2v(case_sensitive,
                                                                                                  self.word_vec_pmc_name)

            UNK_pubmed, BLANK_pubmed, self.word2id_pubmed, self.word_vec_mat_pubmed = self.get_other_mcnn_w2v(case_sensitive,
                                                                                                              self.word_vec_pubmed_name)

            UNK_pubmed_and_pmc, BLANK_pubmed_and_pmc, self.word2id_pubmed_and_pmc, \
                self.word_vec_mat_pubmed_and_pmc = self.get_other_mcnn_w2v(case_sensitive,
                                                                           self.word_vec_pubmed_and_pmc_name)

            UNK_pubmed_myself, BLANK_pubmed_myself, self.word2id_pubmed_myself, \
                self.word_vec_mat_pubmed_myself = self.get_other_mcnn_w2v(case_sensitive,
                                                                          self.word_vec_pubmed_myself_name)

            UNK_wiki_pubmed, BLANK_wiki_pubmed, self.word2id_wiki_pubmed, \
                self.word_vec_mat_wiki_pubmed = self.get_other_mcnn_w2v(case_sensitive,
                                                                        self.word_vec_wiki_pubmed_name)

            # Pre-process data: one row per instance in every array below.
            print("Pre-processing data...")
            self.instance_tot = len(self.ori_data)
            self.entpair2scope = {}  # (head, tail) -> scope
            self.relfact2scope = {}  # (head, tail, relation) -> scope
            self.data_word = np.zeros((self.instance_tot, self.max_length), dtype=np.int32)
            self.sentences = list()
            self.data_word_pmc = np.zeros((self.instance_tot, self.max_length), dtype=np.int32)
            self.data_word_pubmed = np.zeros((self.instance_tot, self.max_length), dtype=np.int32)
            self.data_word_pubmed_and_pmc = np.zeros((self.instance_tot, self.max_length), dtype=np.int32)
            self.data_word_pubmed_myself = np.zeros((self.instance_tot, self.max_length), dtype=np.int32)
            self.data_word_wiki_pubmed = np.zeros((self.instance_tot, self.max_length), dtype=np.int32)

            self.data_pos1 = np.zeros((self.instance_tot, self.max_length), dtype=np.int32)
            self.data_pos2 = np.zeros((self.instance_tot, self.max_length), dtype=np.int32)
            self.data_rel = np.zeros((self.instance_tot), dtype=np.int32)
            self.data_mask = np.zeros((self.instance_tot, self.max_length), dtype=np.int32)
            self.data_length = np.zeros((self.instance_tot), dtype=np.int32)
            last_entpair = ''
            last_entpair_pos = -1
            last_relfact = ''
            last_relfact_pos = -1
            for i in range(self.instance_tot):

                ins = self.ori_data[i]  # e.g. {'head': {'id': ..., 'word': ...}, 'tail': {...}, 'relation': ..., 'sentence': ...}

                #if "drug0" in ins["replace_sentence"]:
                #    continue

                # add filter, if filter_result == true, rel=NA
                # (add_filter is hard-coded True, so heuristically "illegal"
                # pairs are always forced to the NA relation)
                add_filter = True
                filter_result = self.filter(ins)
                if add_filter and filter_result:
                    self.data_rel[i] = self.rel2id['NA']
                else:
                    # haven't used filter part
                    if ins['relation'] in self.rel2id:
                        self.data_rel[i] = self.rel2id[ins['relation']]  # 0, 1, 2 ...
                    else:
                        self.data_rel[i] = self.rel2id['NA']  # 0

                sentence = ins["replace_sentence"]
                self.sentences.append({"sentence": ins["sentence"], "head": ins["head"]["word"], "tail": ins["tail"]["word"]})

                head = ins['head']['word']
                tail = ins['tail']['word']
                cur_entpair = ins['head']['id'] + '#' + ins['tail']['id']
                cur_relfact = ins['head']['id'] + '#' + ins['tail']['id'] + '#' + ins['relation']
                # Since ori_data is sorted, close out the previous scope when
                # the entity pair / relation fact changes.
                if cur_entpair != last_entpair:
                    if last_entpair != '':
                        self.entpair2scope[last_entpair] = [last_entpair_pos, i]  # left closed right open
                    last_entpair = cur_entpair
                    last_entpair_pos = i
                if cur_relfact != last_relfact:
                    if last_relfact != '' and (i - last_relfact_pos > 0):
                        self.relfact2scope[last_relfact] = [last_relfact_pos, i]
                    last_relfact = cur_relfact
                    last_relfact_pos = i

                # p1 = sentence.find(head)
                # p2 = sentence.find(tail)
                # Character offsets of the two entities inside replace_sentence.
                p1 = int(ins['head']['replace_begin'])
                p2 = int(ins['tail']['replace_begin'])

                words = ""
                if '@@@@@' not in sentence:
                    words = sentence.split('@@')
                else:
                    # '@@@@@' means a literal '@' token collided with the '@@'
                    # separator: split around it and keep a single '@' token.
                    err_pos = sentence.index('@@@@@')
                    err_before = sentence[:err_pos]
                    err_before_words = err_before.split('@@')
                    err = sentence[err_pos:err_pos+5]  # NOTE(review): unused
                    err_words = ['@']
                    err_after = sentence[err_pos+5:]
                    err_after_words = err_after.split('@@')
                    words = err_before_words + err_words + err_after_words

                # Row views — writing cur_ref_* writes into the big matrices.
                cur_ref_data_word = self.data_word[i]
                cur_ref_data_word_pmc = self.data_word_pmc[i]
                cur_ref_data_word_pubmed = self.data_word_pubmed[i]
                cur_ref_data_word_pubmed_and_pmc = self.data_word_pubmed_and_pmc[i]
                cur_ref_data_word_pubmed_myself = self.data_word_pubmed_myself[i]
                cur_ref_data_word_wiki_pubmed = self.data_word_wiki_pubmed[i]

                pos1 = -1
                pos2 = -1
                cur_pos = 0  # running character offset of token j

                for j, word in enumerate(words):
                    if j < max_length:
                        # Look the token up in every vocabulary (UNK on miss).
                        if word in self.word2id:
                            cur_ref_data_word[j] = self.word2id[word]
                        else:
                            cur_ref_data_word[j] = UNK

                        if word in self.word2id_pmc:
                            cur_ref_data_word_pmc[j] = self.word2id_pmc[word]
                        else:
                            cur_ref_data_word_pmc[j] = UNK_pmc

                        if word in self.word2id_pubmed:
                            cur_ref_data_word_pubmed[j] = self.word2id_pubmed[word]
                        else:
                            cur_ref_data_word_pubmed[j] = UNK_pubmed

                        if word in self.word2id_pubmed_and_pmc:
                            cur_ref_data_word_pubmed_and_pmc[j] = self.word2id_pubmed_and_pmc[word]
                        else:
                            cur_ref_data_word_pubmed_and_pmc[j] = UNK_pubmed_and_pmc

                        if word in self.word2id_pubmed_myself:
                            cur_ref_data_word_pubmed_myself[j] = self.word2id_pubmed_myself[word]
                        else:
                            cur_ref_data_word_pubmed_myself[j] = UNK_pubmed_myself

                        if word in self.word2id_wiki_pubmed:
                            cur_ref_data_word_wiki_pubmed[j] = self.word2id_wiki_pubmed[word]
                        else:
                            cur_ref_data_word_wiki_pubmed[j] = UNK_wiki_pubmed

                    # Match character offsets against entity begins to recover
                    # the entities' token indices.
                    if cur_pos == p1:
                        pos1 = j
                        p1 = -1

                    if cur_pos == p2:
                        pos2 = j
                        p2 = -1
                    cur_pos += len(word) + 2  # +2 for the '@@' separator

                # Pad the remainder of the row with BLANK.
                # NOTE(review): if words were ever empty, j would be unbound
                # here — presumably upstream data guarantees >= 1 token.
                for num in range(j + 1, max_length):
                    cur_ref_data_word[num] = BLANK
                    cur_ref_data_word_pmc[num] = BLANK_pmc
                    cur_ref_data_word_pubmed[num] = BLANK_pubmed
                    cur_ref_data_word_pubmed_and_pmc[num] = BLANK_pubmed_and_pmc
                    cur_ref_data_word_pubmed_myself[num] = BLANK_pubmed_myself
                    cur_ref_data_word_wiki_pubmed[num] = BLANK_wiki_pubmed

                self.data_length[i] = len(words)
                if len(words) > max_length:
                    self.data_length[i] = max_length

                if pos1 == -1 or pos2 == -1:
                    # Entity offset never matched a token boundary; fall back
                    # to position 0 and report it.
                    if pos1 == -1:
                        pos1 = 0
                    if pos2 == -1:
                        pos2 = 0
                    print("[ERROR] Position error, index = {}, sentence = {}, head = {}, tail = {}".format(i, sentence, head, tail))

                if pos1 >= max_length:
                    pos1 = max_length - 1
                if pos2 >= max_length:
                    pos2 = max_length - 1

                pos_min = min(pos1, pos2)
                pos_max = max(pos1, pos2)

                for j in range(max_length):
                    # Relative positions shifted by max_length to stay >= 0.
                    self.data_pos1[i][j] = j - pos1 + max_length
                    self.data_pos2[i][j] = j - pos2 + max_length
                    if j >= self.data_length[i]:
                        self.data_mask[i][j] = 0  # padding
                    elif j <= pos_min:
                        self.data_mask[i][j] = 1  # before the first entity
                    elif j <= pos_max:
                        self.data_mask[i][j] = 2  # between the entities
                    else:
                        self.data_mask[i][j] = 3  # after the second entity

            # Close the final scopes.
            if last_entpair != '':
                self.entpair2scope[last_entpair] = [last_entpair_pos, self.instance_tot]  # left closed right open
            if last_relfact != '' and (self.instance_tot - last_relfact_pos > 0):
                self.relfact2scope[last_relfact] = [last_relfact_pos, self.instance_tot]

            print("Finish pre-processing")

            print("Storing processed files...")
            name_prefix = '.'.join(os.path.split(file_name)[-1].split('.')[:-1])
            word_vec_name_prefix = '.'.join(os.path.split(word_vec_file_name)[-1].split('.')[:-1])
            processed_data_dir = 'medicine_rl_pacnn_train_data'
            if not os.path.isdir(processed_data_dir):
                os.mkdir(processed_data_dir)
            np.save(os.path.join(processed_data_dir, name_prefix + '_word.npy'), self.data_word)
            np.save(os.path.join(processed_data_dir, name_prefix + '_pmc_word.npy'), self.data_word_pmc)
            np.save(os.path.join(processed_data_dir, name_prefix + '_pubmed_word.npy'), self.data_word_pubmed)
            np.save(os.path.join(processed_data_dir, name_prefix + '_pubmed_and_pmc_word.npy'),
                    self.data_word_pubmed_and_pmc)
            np.save(os.path.join(processed_data_dir, name_prefix + '_pubmed_myself_word.npy'),
                    self.data_word_pubmed_myself)
            np.save(os.path.join(processed_data_dir, name_prefix + '_wiki_pubmed_word.npy'),
                    self.data_word_wiki_pubmed)

            np.save(os.path.join(processed_data_dir, name_prefix + '_pos1.npy'), self.data_pos1)
            np.save(os.path.join(processed_data_dir, name_prefix + '_pos2.npy'), self.data_pos2)
            np.save(os.path.join(processed_data_dir, name_prefix + '_rel.npy'), self.data_rel)
            np.save(os.path.join(processed_data_dir, name_prefix + '_mask.npy'), self.data_mask)
            np.save(os.path.join(processed_data_dir, name_prefix + '_length.npy'), self.data_length)
            json.dump(self.entpair2scope, open(os.path.join(processed_data_dir, name_prefix + '_entpair2scope.json'), 'w'))
            json.dump(self.relfact2scope, open(os.path.join(processed_data_dir, name_prefix + '_relfact2scope.json'), 'w'))

            np.save(os.path.join(processed_data_dir, word_vec_name_prefix + '_mat.npy'), self.word_vec_mat)
            np.save(os.path.join(processed_data_dir, word_vec_name_prefix + '_pmc_mat.npy'), self.word_vec_mat_pmc)
            np.save(os.path.join(processed_data_dir, word_vec_name_prefix + '_pubmed_mat.npy'), self.word_vec_mat_pubmed)
            np.save(os.path.join(processed_data_dir, word_vec_name_prefix + '_pubmed_and_pmc_mat.npy'),
                    self.word_vec_mat_pubmed_and_pmc)
            np.save(os.path.join(processed_data_dir, word_vec_name_prefix + '_pubmed_myself_mat.npy'),
                    self.word_vec_mat_pubmed_myself)
            np.save(os.path.join(processed_data_dir, word_vec_name_prefix + '_wiki_pubmed_mat.npy'),
                    self.word_vec_mat_wiki_pubmed)

            json.dump(self.word2id, open(os.path.join(processed_data_dir, word_vec_name_prefix + '_word2id.json'), 'w'))
            json.dump(self.word2id_pmc, open(os.path.join(processed_data_dir, word_vec_name_prefix + '_pmc_word2id.json'), 'w'))
            json.dump(self.word2id_pubmed, open(os.path.join(processed_data_dir, word_vec_name_prefix + '_pubmed_word2id.json'), 'w'))
            json.dump(self.word2id_pubmed_and_pmc, open(os.path.join(processed_data_dir, word_vec_name_prefix + '_pubmed_and_pmc_word2id.json'), 'w'))
            json.dump(self.word2id_pubmed_myself, open(os.path.join(processed_data_dir, word_vec_name_prefix + '_pubmed_myself_word2id.json'), 'w'))
            json.dump(self.word2id_wiki_pubmed, open(os.path.join(processed_data_dir, word_vec_name_prefix + '_wiki_pubmed_word2id.json'), 'w'))
            json.dump(self.sentences, open(os.path.join(processed_data_dir, name_prefix + '_sentences.json'), 'w'))

            print("Finish storing")

        # Prepare for idx: iteration state shared by both code paths above.
        self.instance_tot = self.data_word.shape[0]
        self.entpair_tot = len(self.entpair2scope)
        self.relfact_tot = 0  # The number of relation facts, without NA.
        for key in self.relfact2scope:
            if key[-2:] != 'NA':
                self.relfact_tot += 1
        self.rel_tot = len(self.rel2id)

        # self.order is the (shufflable) iteration order; in bag modes,
        # self.scope[k] is the [begin, end) instance range of bag k.
        if self.mode == self.MODE_INSTANCE:
            self.order = list(range(self.instance_tot))
        elif self.mode == self.MODE_ENTPAIR_BAG:
            self.order = list(range(len(self.entpair2scope)))
            self.scope_name = []
            self.scope = []
            for key, value in iteritems(self.entpair2scope):
                self.scope_name.append(key)
                self.scope.append(value)
        elif self.mode == self.MODE_RELFACT_BAG:
            self.order = list(range(len(self.relfact2scope)))
            self.scope_name = []
            self.scope = []
            for key, value in iteritems(self.relfact2scope):
                self.scope_name.append(key)
                self.scope.append(value)
        else:
            raise Exception("[ERROR] Invalid mode")
        self.idx = 0

        if self.shuffle:
            random.shuffle(self.order)

        print("Total relation fact: %d" % (self.relfact_tot))
-
    def __iter__(self):
        # The loader is its own iterator.
        return self

    def __next__(self):
        # One iteration step yields one batch of self.batch_size.
        return self.next_batch(self.batch_size)
-
    def next_batch(self, batch_size):
        """Assemble and return the next batch as a dict of arrays.

        Keys: 'word', '<corpus>_word' (one per embedding vocabulary), 'pos1',
        'pos2', 'mask', 'rel', 'ins_rel', 'length', 'scope', 'sentences';
        MODE_ENTPAIR_BAG additionally yields 'multi_rel' and 'entpair'.
        Short final batches are zero-padded up to batch_size.

        Raises:
            StopIteration: at the end of an epoch (after rewinding the cursor
                and reshuffling when self.shuffle is set).
        """
        if self.idx >= len(self.order):
            # Epoch boundary: reset state for the next epoch before stopping.
            self.idx = 0
            if self.shuffle:
                random.shuffle(self.order)
            raise StopIteration

        batch_data = {}

        if self.mode == self.MODE_INSTANCE:
            idx0 = self.idx
            idx1 = self.idx + batch_size
            if idx1 > len(self.order):
                idx1 = len(self.order)
            self.idx = idx1
            batch_data['word'] = self.data_word[idx0:idx1]
            batch_data['sentences'] = self.sentences[idx0:idx1]
            batch_data['pmc_word'] = self.data_word_pmc[idx0:idx1]
            batch_data['pubmed_word'] = self.data_word_pubmed[idx0:idx1]
            batch_data['pubmed_and_pmc_word'] = self.data_word_pubmed_and_pmc[idx0:idx1]
            batch_data['pubmed_myself_word'] = self.data_word_pubmed_myself[idx0:idx1]
            batch_data['wiki_pubmed_word'] = self.data_word_wiki_pubmed[idx0:idx1]

            batch_data['pos1'] = self.data_pos1[idx0:idx1]
            batch_data['pos2'] = self.data_pos2[idx0:idx1]
            batch_data['rel'] = self.data_rel[idx0:idx1]
            batch_data['mask'] = self.data_mask[idx0:idx1]
            batch_data['length'] = self.data_length[idx0:idx1]
            # In instance mode every "bag" is one instance: scope k = [k, k+1).
            batch_data['scope'] = np.stack([list(range(batch_size)), list(range(1, batch_size + 1))], axis=1)
            batch_data['ins_rel'] = self.data_rel[idx0:idx1]
            if idx1 - idx0 < batch_size:
                # Zero-pad the tail batch so every batch has batch_size rows.
                padding = batch_size - (idx1 - idx0)
                batch_data['word'] = np.concatenate([batch_data['word'], np.zeros((padding, self.data_word.shape[-1]), dtype=np.int32)])
                for _ in range(padding):
                    batch_data['sentences'].append({"head": "", "tail": "", "sentence": ""})

                batch_data['pmc_word'] = np.concatenate([batch_data['pmc_word'], np.zeros((padding, self.data_word_pmc.shape[-1]), dtype=np.int32)])
                batch_data['pubmed_word'] = np.concatenate([batch_data['pubmed_word'], np.zeros((padding, self.data_word_pubmed.shape[-1]), dtype=np.int32)])
                batch_data['pubmed_and_pmc_word'] = np.concatenate([batch_data['pubmed_and_pmc_word'], np.zeros((padding, self.data_word_pubmed_and_pmc.shape[-1]), dtype=np.int32)])
                batch_data['pubmed_myself_word'] = np.concatenate([batch_data['pubmed_myself_word'], np.zeros((padding, self.data_word_pubmed_myself.shape[-1]), dtype=np.int32)])
                batch_data['wiki_pubmed_word'] = np.concatenate([batch_data['wiki_pubmed_word'], np.zeros((padding, self.data_word_wiki_pubmed.shape[-1]), dtype=np.int32)])

                batch_data['pos1'] = np.concatenate([batch_data['pos1'], np.zeros((padding, self.data_pos1.shape[-1]), dtype=np.int32)])
                batch_data['pos2'] = np.concatenate([batch_data['pos2'], np.zeros((padding, self.data_pos2.shape[-1]), dtype=np.int32)])
                batch_data['mask'] = np.concatenate([batch_data['mask'], np.zeros((padding, self.data_mask.shape[-1]), dtype=np.int32)])
                batch_data['rel'] = np.concatenate([batch_data['rel'], np.zeros((padding), dtype=np.int32)])
                batch_data['ins_rel'] = np.concatenate([batch_data['ins_rel'], np.zeros((padding), dtype=np.int32)])
                batch_data['length'] = np.concatenate([batch_data['length'], np.zeros((padding), dtype=np.int32)])
        elif self.mode == self.MODE_ENTPAIR_BAG or self.mode == self.MODE_RELFACT_BAG:
            idx0 = self.idx
            idx1 = self.idx + batch_size
            if idx1 > len(self.order):
                idx1 = len(self.order)
            self.idx = idx1
            # Per-bag pieces, concatenated along axis 0 at the end; _scope
            # records each bag's [begin, end) row range inside the batch.
            _word = []
            _sentences = []
            _pmc_word = []
            _pubmed_word = []
            _pubmed_and_pmc_word = []
            _pubmed_myself_word = []
            _wiki_pubmed_word = []

            _pos1 = []
            _pos2 = []
            _mask = []
            _rel = []
            _ins_rel = []
            _multi_rel = []
            _entpair = []
            _length = []
            _scope = []
            cur_pos = 0
            for i in range(idx0, idx1):
                _pmc_word.append(self.data_word_pmc[self.scope[self.order[i]][0]:self.scope[self.order[i]][1]])
                _pubmed_and_pmc_word.append(self.data_word_pubmed_and_pmc[self.scope[self.order[i]][0]:self.scope[self.order[i]][1]])
                _pubmed_word.append(self.data_word_pubmed[self.scope[self.order[i]][0]:self.scope[self.order[i]][1]])
                _pubmed_myself_word.append(self.data_word_pubmed_myself[self.scope[self.order[i]][0]:self.scope[self.order[i]][1]])
                _wiki_pubmed_word.append(self.data_word_wiki_pubmed[self.scope[self.order[i]][0]:self.scope[self.order[i]][1]])
                _word.append(self.data_word[self.scope[self.order[i]][0]:self.scope[self.order[i]][1]])
                _sentences.extend(self.sentences[self.scope[self.order[i]][0]:self.scope[self.order[i]][1]])

                _pos1.append(self.data_pos1[self.scope[self.order[i]][0]:self.scope[self.order[i]][1]])
                _pos2.append(self.data_pos2[self.scope[self.order[i]][0]:self.scope[self.order[i]][1]])
                _mask.append(self.data_mask[self.scope[self.order[i]][0]:self.scope[self.order[i]][1]])
                # Bag label = label of the bag's first instance.
                _rel.append(self.data_rel[self.scope[self.order[i]][0]])
                _ins_rel.append(self.data_rel[self.scope[self.order[i]][0]:self.scope[self.order[i]][1]])
                _length.append(self.data_length[self.scope[self.order[i]][0]:self.scope[self.order[i]][1]])
                bag_size = self.scope[self.order[i]][1] - self.scope[self.order[i]][0]
                _scope.append([cur_pos, cur_pos + bag_size])
                cur_pos = cur_pos + bag_size
                if self.mode == self.MODE_ENTPAIR_BAG:
                    # Multi-hot label vector over all relations in the bag.
                    _one_multi_rel = np.zeros((self.rel_tot), dtype=np.int32)
                    for j in range(self.scope[self.order[i]][0], self.scope[self.order[i]][1]):
                        _one_multi_rel[self.data_rel[j]] = 1
                    _multi_rel.append(_one_multi_rel)
                    _entpair.append(self.scope_name[self.order[i]])
            # Pad the tail batch with empty single-instance bags.
            for i in range(batch_size - (idx1 - idx0)):
                _pmc_word.append(np.zeros((1, self.data_word_pmc.shape[-1]), dtype=np.int32))
                _pubmed_and_pmc_word.append(np.zeros((1, self.data_word_pubmed_and_pmc.shape[-1]), dtype=np.int32))
                _pubmed_word.append(np.zeros((1, self.data_word_pubmed.shape[-1]), dtype=np.int32))
                _pubmed_myself_word.append(np.zeros((1, self.data_word_pubmed_myself.shape[-1]), dtype=np.int32))
                _wiki_pubmed_word.append(np.zeros((1, self.data_word_wiki_pubmed.shape[-1]), dtype=np.int32))
                _word.append(np.zeros((1, self.data_word.shape[-1]), dtype=np.int32))
                _sentences.append({"head": "", "tail": "", "sentence": ""})

                _pos1.append(np.zeros((1, self.data_pos1.shape[-1]), dtype=np.int32))
                _pos2.append(np.zeros((1, self.data_pos2.shape[-1]), dtype=np.int32))
                _mask.append(np.zeros((1, self.data_mask.shape[-1]), dtype=np.int32))
                _rel.append(0)
                _ins_rel.append(np.zeros((1), dtype=np.int32))
                _length.append(np.zeros((1), dtype=np.int32))
                _scope.append([cur_pos, cur_pos + 1])
                cur_pos += 1
                if self.mode == self.MODE_ENTPAIR_BAG:
                    _multi_rel.append(np.zeros((self.rel_tot), dtype=np.int32))
                    _entpair.append('None#None')
            batch_data['pmc_word'] = np.concatenate(_pmc_word)
            batch_data['pubmed_word'] = np.concatenate(_pubmed_word)
            batch_data['pubmed_and_pmc_word'] = np.concatenate(_pubmed_and_pmc_word)
            batch_data['pubmed_myself_word'] = np.concatenate(_pubmed_myself_word)
            batch_data['wiki_pubmed_word'] = np.concatenate(_wiki_pubmed_word)
            batch_data['word'] = np.concatenate(_word)
            batch_data['sentences'] = _sentences

            batch_data['pos1'] = np.concatenate(_pos1)
            batch_data['pos2'] = np.concatenate(_pos2)
            batch_data['mask'] = np.concatenate(_mask)
            batch_data['rel'] = np.stack(_rel)
            batch_data['ins_rel'] = np.concatenate(_ins_rel)
            if self.mode == self.MODE_ENTPAIR_BAG:
                batch_data['multi_rel'] = np.stack(_multi_rel)
                batch_data['entpair'] = _entpair
            batch_data['length'] = np.concatenate(_length)
            batch_data['scope'] = np.stack(_scope)

        # Hard cap on batch rows (a single bag can be huge).
        # NOTE(review): 'rel', 'multi_rel', 'entpair' and per-bag scopes are
        # NOT truncated consistently here (scope is collapsed to one bag) —
        # presumably only instance-level keys are consumed downstream; verify.
        len_thre = 1000
        if len(batch_data['word']) > len_thre:
            print('*' * 100)
            batch_data['pmc_word'] = batch_data['pmc_word'][:len_thre]
            batch_data['pubmed_word'] = batch_data['pubmed_word'][:len_thre]
            batch_data['pubmed_and_pmc_word'] = batch_data['pubmed_and_pmc_word'][:len_thre]
            batch_data['pubmed_myself_word'] = batch_data['pubmed_myself_word'][:len_thre]
            batch_data['wiki_pubmed_word'] = batch_data['wiki_pubmed_word'][:len_thre]
            batch_data['word'] = batch_data['word'][:len_thre]
            batch_data['sentences'] = batch_data['sentences'][:len_thre]

            batch_data['pos1'] = batch_data['pos1'][:len_thre]
            batch_data['pos2'] = batch_data['pos2'][:len_thre]
            batch_data['mask'] = batch_data['mask'][:len_thre]
            batch_data['ins_rel'] = batch_data['ins_rel'][:len_thre]
            batch_data['length'] = batch_data['length'][:len_thre]
            batch_data['scope'] = np.array([[0, len_thre]])

        return batch_data
-
    # Heuristics that decide whether an entity pair is "illegal" (trivial or
    # list-like) and should be forced to the NA relation.
-
- def filter(self, ins):
- e1_name = ins['head']['word'].lower()
- e2_name = ins['tail']['word'].lower()
-
- return self.filter_1(e1_name, e2_name) \
- or self.filter_2(e1_name, e2_name) \
- or self.filter_3(ins) \
- or self.filter_4(ins)
-
-
- def filter_1(self, e1_name, e2_name):
- return e1_name == e2_name
-
-
- def filter_2(self, e1_name, e2_name):
- if len(str(e1_name).split(" ")) > 1:
- if len(str(e2_name).split(" ")) == 1:
- split_words = str(e1_name).split(" ")
- line = "".join([word[0] for word in split_words if str(word).rstrip() != ""])
- return line == e2_name
-
- if len(str(e2_name).split(" ")) > 1:
- if len(str(e1_name).split(" ")) == 1:
- split_words = str(e2_name).split(" ")
- # print "split words", split_words, "\t", e1_name, e2_name
- line = "".join([word[0] for word in split_words if str(word).rstrip() != ""])
- return line == e1_name
-
-
- def filter_3(self, ins):
- new_pair = ins['new_pair']
- new_pair_words = new_pair.split('@@')
- e1_pos = 0
- e2_pos = len(new_pair_words) - 1
- if math.fabs(e2_pos - e1_pos) == 1:
- return True
-
- if math.fabs(e2_pos - e1_pos) == 2:
-
- between = new_pair_words[1].lower()
- if between == "or" or between == "," or between == "(" or between == "-":
- return True
-
- if math.fabs(e2_pos - e1_pos) == 3:
- word = str(" ".join(new_pair_words[1: 3])).lower()
- if word == ", or" or word == "such as":
- return True
-
    # Example of the pattern filter_4 targets: "a, b, c, and d" — an
    # enumeration where the two entities are just co-listed items.
-
    def filter_4(self, ins):
        """Illegal when every token strictly between the two entities is
        list punctuation / a conjunction (e.g. "a, b, c, and d").

        The entities are assumed to be the first and last '@@'-separated
        tokens of ins['new_pair'].  Any other token keeps the pair legal
        (return False).  NOTE(review): once an "and" has been seen, a short
        span (entities <= 4 tokens apart) also returns False — presumably to
        keep simple conjunctions like "a and b"; confirm against callers.
        """
        # Tokens allowed inside an enumeration of entities.
        except_words = [",", 'drug0', 'or', '(', '[', ')', ']', "and"]
        flags = False  # set once an "and" is seen between the entities
        new_pair = ins['new_pair']
        new_pair_words = new_pair.split('@@')
        e1_pos = 0
        e2_pos = len(new_pair_words) - 1

        # print sequence
        for i in range(e1_pos + 1, e2_pos):
            word = str(new_pair_words[i]).lower()
            if word not in except_words:
                return False
            else:
                if word == "and":
                    flags = True
                if flags is True:
                    if e2_pos - e1_pos <= 4:
                        return False
        return True
|