|
|
@@ -0,0 +1,234 @@ |
|
|
|
# Copyright 2022 Huawei Technologies Co., Ltd
|
|
|
|
#
|
|
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
|
|
# you may not use this file except in compliance with the License.
|
|
|
|
# You may obtain a copy of the License at
|
|
|
|
#
|
|
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
|
|
#
|
|
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
|
|
# See the License for the specific language governing permissions and
|
|
|
|
# limitations under the License.
|
|
|
|
# ============================================================================
|
|
|
|
"""Data operations, will be used in train.py."""
|
|
|
|
|
|
|
|
import os
|
|
|
|
from typing import Dict
|
|
|
|
from enum import Enum
|
|
|
|
import numpy as np
|
|
|
|
from src.config import config
|
|
|
|
|
|
|
|
np.random.seed(config.random_seed)
|
|
|
|
|
|
|
|
|
|
|
|
class BatchType(Enum):
|
|
|
|
HEAD_BATCH = 0
|
|
|
|
TAIL_BATCH = 1
|
|
|
|
SINGLE = 2
|
|
|
|
|
|
|
|
|
|
|
|
class ModeType(Enum):
|
|
|
|
TRAIN = 0
|
|
|
|
VALID = 1
|
|
|
|
TEST = 2
|
|
|
|
|
|
|
|
|
|
|
|
class DataReader:
|
|
|
|
"""
|
|
|
|
Read data class
|
|
|
|
Args:
|
|
|
|
data_path: data path.
|
|
|
|
"""
|
|
|
|
|
|
|
|
def __init__(self, data_path):
|
|
|
|
entity_dict_path = os.path.join(data_path, 'entities.dict')
|
|
|
|
relation_dict_path = os.path.join(data_path, 'relations.dict')
|
|
|
|
train_data_path = os.path.join(data_path, 'train.txt')
|
|
|
|
valid_data_path = os.path.join(data_path, 'valid.txt')
|
|
|
|
test_data_path = os.path.join(data_path, 'test.txt')
|
|
|
|
|
|
|
|
self.entity_dict = self.read_dict(entity_dict_path)
|
|
|
|
self.relation_dict = self.read_dict(relation_dict_path)
|
|
|
|
|
|
|
|
self.train_data = self.read_data(train_data_path, self.entity_dict, self.relation_dict)
|
|
|
|
self.valid_data = self.read_data(valid_data_path, self.entity_dict, self.relation_dict)
|
|
|
|
self.test_data = self.read_data(test_data_path, self.entity_dict, self.relation_dict)
|
|
|
|
|
|
|
|
def read_dict(self, dict_path: str):
|
|
|
|
"""
|
|
|
|
Read entity / relation dict.
|
|
|
|
Format: dict({id: entity / relation})
|
|
|
|
"""
|
|
|
|
|
|
|
|
element_dict = {}
|
|
|
|
with open(dict_path, 'r') as f:
|
|
|
|
for line in f:
|
|
|
|
id_, element = line.strip().split('\t')
|
|
|
|
element_dict[element] = int(id_)
|
|
|
|
|
|
|
|
return element_dict
|
|
|
|
|
|
|
|
def read_data(self, data_path: str, entity_dict: Dict[str, int], relation_dict: Dict[str, int]):
|
|
|
|
"""
|
|
|
|
Read train / valid / test data.
|
|
|
|
"""
|
|
|
|
triples = []
|
|
|
|
with open(data_path, 'r') as f:
|
|
|
|
for line in f:
|
|
|
|
head, relation, tail = line.strip().split('\t')
|
|
|
|
triples.append((entity_dict[head], relation_dict[relation], entity_dict[tail]))
|
|
|
|
return triples
|
|
|
|
|
|
|
|
|
|
|
|
class TrainDataset:
|
|
|
|
"""
|
|
|
|
create training data
|
|
|
|
Args:
|
|
|
|
data_reader: data reader class.
|
|
|
|
neg_size: negative sample size
|
|
|
|
batch_type: batch type. HEAD or TAIL
|
|
|
|
"""
|
|
|
|
|
|
|
|
def __init__(self, data_reader: DataReader, neg_size: int, batch_type: BatchType):
|
|
|
|
self.triples = data_reader.train_data
|
|
|
|
self.len = len(self.triples)
|
|
|
|
self.num_entity = len(data_reader.entity_dict)
|
|
|
|
self.num_relation = len(data_reader.relation_dict)
|
|
|
|
self.neg_size = neg_size
|
|
|
|
self.batch_type = batch_type
|
|
|
|
|
|
|
|
self.hr_map, self.tr_map, self.hr_freq, self.tr_freq = self.two_tuple_count()
|
|
|
|
|
|
|
|
def __getitem__(self, idx):
|
|
|
|
"""
|
|
|
|
Returns a positive sample and `self.neg_size` negative samples.
|
|
|
|
"""
|
|
|
|
pos_triple = self.triples[idx]
|
|
|
|
head, rel, tail = pos_triple
|
|
|
|
|
|
|
|
subsampling_weight = self.hr_freq[(head, rel)] + self.tr_freq[(tail, rel)]
|
|
|
|
subsampling_weight = np.sqrt(1 / np.array([subsampling_weight]))
|
|
|
|
|
|
|
|
neg_triples = []
|
|
|
|
neg_size = 0
|
|
|
|
|
|
|
|
while neg_size < self.neg_size:
|
|
|
|
neg_triples_tmp = np.random.randint(self.num_entity, size=self.neg_size * 2)
|
|
|
|
if self.batch_type == BatchType.HEAD_BATCH:
|
|
|
|
mask = np.in1d(
|
|
|
|
neg_triples_tmp,
|
|
|
|
self.tr_map[(tail, rel)],
|
|
|
|
assume_unique=True,
|
|
|
|
invert=True
|
|
|
|
)
|
|
|
|
elif self.batch_type == BatchType.TAIL_BATCH:
|
|
|
|
mask = np.in1d(
|
|
|
|
neg_triples_tmp,
|
|
|
|
self.hr_map[(head, rel)],
|
|
|
|
assume_unique=True,
|
|
|
|
invert=True
|
|
|
|
)
|
|
|
|
else:
|
|
|
|
raise ValueError('Invalid BatchType: {}'.format(self.batch_type))
|
|
|
|
|
|
|
|
neg_triples_tmp = neg_triples_tmp[mask]
|
|
|
|
neg_triples.append(neg_triples_tmp)
|
|
|
|
neg_size += neg_triples_tmp.size
|
|
|
|
|
|
|
|
neg_triples = np.concatenate(neg_triples)[:self.neg_size]
|
|
|
|
|
|
|
|
pos_triple = np.array(pos_triple)
|
|
|
|
neg_triples = np.array(neg_triples)
|
|
|
|
|
|
|
|
return pos_triple, neg_triples, subsampling_weight
|
|
|
|
|
|
|
|
def __len__(self):
|
|
|
|
return self.len
|
|
|
|
|
|
|
|
def two_tuple_count(self):
|
|
|
|
"""
|
|
|
|
Return two dict:
|
|
|
|
dict({(h, r): [t1, t2, ...]}),
|
|
|
|
dict({(t, r): [h1, h2, ...]}),
|
|
|
|
"""
|
|
|
|
hr_map = {}
|
|
|
|
hr_freq = {}
|
|
|
|
tr_map = {}
|
|
|
|
tr_freq = {}
|
|
|
|
|
|
|
|
init_cnt = 3
|
|
|
|
for head, rel, tail in self.triples:
|
|
|
|
if (head, rel) not in hr_map.keys():
|
|
|
|
hr_map[(head, rel)] = set()
|
|
|
|
|
|
|
|
if (tail, rel) not in tr_map.keys():
|
|
|
|
tr_map[(tail, rel)] = set()
|
|
|
|
|
|
|
|
if (head, rel) not in hr_freq.keys():
|
|
|
|
hr_freq[(head, rel)] = init_cnt
|
|
|
|
|
|
|
|
if (tail, rel) not in tr_freq.keys():
|
|
|
|
tr_freq[(tail, rel)] = init_cnt
|
|
|
|
|
|
|
|
hr_map[(head, rel)].add(tail)
|
|
|
|
tr_map[(tail, rel)].add(head)
|
|
|
|
hr_freq[(head, rel)] += 1
|
|
|
|
tr_freq[(tail, rel)] += 1
|
|
|
|
|
|
|
|
for key in tr_map:
|
|
|
|
tr_map[key] = np.array(list(tr_map[key]))
|
|
|
|
|
|
|
|
for key in hr_map:
|
|
|
|
hr_map[key] = np.array(list(hr_map[key]))
|
|
|
|
|
|
|
|
return hr_map, tr_map, hr_freq, tr_freq
|
|
|
|
|
|
|
|
|
|
|
|
class TestDataset:
|
|
|
|
"""
|
|
|
|
create test data
|
|
|
|
Args:
|
|
|
|
data_reader: data reader class.
|
|
|
|
mode: test model. VALID or TEST.
|
|
|
|
batch_type: batch type. HEAD or TAIL.
|
|
|
|
"""
|
|
|
|
def __init__(self, data_reader: DataReader, mode: ModeType, batch_type: BatchType):
|
|
|
|
self.triple_set = set(data_reader.train_data + data_reader.valid_data + data_reader.test_data)
|
|
|
|
if mode == ModeType.VALID:
|
|
|
|
self.triples = data_reader.valid_data
|
|
|
|
elif mode == ModeType.TEST:
|
|
|
|
self.triples = data_reader.test_data
|
|
|
|
|
|
|
|
self.len = len(self.triples)
|
|
|
|
|
|
|
|
self.num_entity = len(data_reader.entity_dict)
|
|
|
|
self.num_relation = len(data_reader.relation_dict)
|
|
|
|
|
|
|
|
self.mode = mode
|
|
|
|
self.batch_type = batch_type
|
|
|
|
|
|
|
|
def __len__(self):
|
|
|
|
return self.len
|
|
|
|
|
|
|
|
def __getitem__(self, idx):
|
|
|
|
head, relation, tail = self.triples[idx]
|
|
|
|
if self.batch_type == BatchType.HEAD_BATCH:
|
|
|
|
tmp = [(0, rand_head) if (rand_head, relation, tail) not in self.triple_set
|
|
|
|
else (-1, head) for rand_head in range(self.num_entity)]
|
|
|
|
tmp[head] = (0, head)
|
|
|
|
elif self.batch_type == BatchType.TAIL_BATCH:
|
|
|
|
tmp = [(0, rand_tail) if (head, relation, rand_tail) not in self.triple_set
|
|
|
|
else (-1, tail) for rand_tail in range(self.num_entity)]
|
|
|
|
tmp[tail] = (0, tail)
|
|
|
|
else:
|
|
|
|
raise ValueError('negative batch type {} not supported'.format(self.mode))
|
|
|
|
|
|
|
|
tmp = np.array(tmp)
|
|
|
|
filter_bias = tmp[:, 0]
|
|
|
|
negative_sample = tmp[:, 1]
|
|
|
|
|
|
|
|
positive_sample = np.array((head, relation, tail))
|
|
|
|
|
|
|
|
return positive_sample, negative_sample, filter_bias
|