@@ -0,0 +1,234 @@ | |||||
# Copyright 2022 Huawei Technologies Co., Ltd | |||||
# | |||||
# Licensed under the Apache License, Version 2.0 (the "License"); | |||||
# you may not use this file except in compliance with the License. | |||||
# You may obtain a copy of the License at | |||||
# | |||||
# http://www.apache.org/licenses/LICENSE-2.0 | |||||
# | |||||
# Unless required by applicable law or agreed to in writing, software | |||||
# distributed under the License is distributed on an "AS IS" BASIS, | |||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
# See the License for the specific language governing permissions and | |||||
# limitations under the License. | |||||
# ============================================================================ | |||||
"""Data operations, will be used in train.py.""" | |||||
import os | |||||
from typing import Dict | |||||
from enum import Enum | |||||
import numpy as np | |||||
from src.config import config | |||||
np.random.seed(config.random_seed) | |||||
class BatchType(Enum): | |||||
HEAD_BATCH = 0 | |||||
TAIL_BATCH = 1 | |||||
SINGLE = 2 | |||||
class ModeType(Enum): | |||||
TRAIN = 0 | |||||
VALID = 1 | |||||
TEST = 2 | |||||
class DataReader: | |||||
""" | |||||
Read data class | |||||
Args: | |||||
data_path: data path. | |||||
""" | |||||
def __init__(self, data_path): | |||||
entity_dict_path = os.path.join(data_path, 'entities.dict') | |||||
relation_dict_path = os.path.join(data_path, 'relations.dict') | |||||
train_data_path = os.path.join(data_path, 'train.txt') | |||||
valid_data_path = os.path.join(data_path, 'valid.txt') | |||||
test_data_path = os.path.join(data_path, 'test.txt') | |||||
self.entity_dict = self.read_dict(entity_dict_path) | |||||
self.relation_dict = self.read_dict(relation_dict_path) | |||||
self.train_data = self.read_data(train_data_path, self.entity_dict, self.relation_dict) | |||||
self.valid_data = self.read_data(valid_data_path, self.entity_dict, self.relation_dict) | |||||
self.test_data = self.read_data(test_data_path, self.entity_dict, self.relation_dict) | |||||
def read_dict(self, dict_path: str): | |||||
""" | |||||
Read entity / relation dict. | |||||
Format: dict({id: entity / relation}) | |||||
""" | |||||
element_dict = {} | |||||
with open(dict_path, 'r') as f: | |||||
for line in f: | |||||
id_, element = line.strip().split('\t') | |||||
element_dict[element] = int(id_) | |||||
return element_dict | |||||
def read_data(self, data_path: str, entity_dict: Dict[str, int], relation_dict: Dict[str, int]): | |||||
""" | |||||
Read train / valid / test data. | |||||
""" | |||||
triples = [] | |||||
with open(data_path, 'r') as f: | |||||
for line in f: | |||||
head, relation, tail = line.strip().split('\t') | |||||
triples.append((entity_dict[head], relation_dict[relation], entity_dict[tail])) | |||||
return triples | |||||
class TrainDataset: | |||||
""" | |||||
create training data | |||||
Args: | |||||
data_reader: data reader class. | |||||
neg_size: negative sample size | |||||
batch_type: batch type. HEAD or TAIL | |||||
""" | |||||
def __init__(self, data_reader: DataReader, neg_size: int, batch_type: BatchType): | |||||
self.triples = data_reader.train_data | |||||
self.len = len(self.triples) | |||||
self.num_entity = len(data_reader.entity_dict) | |||||
self.num_relation = len(data_reader.relation_dict) | |||||
self.neg_size = neg_size | |||||
self.batch_type = batch_type | |||||
self.hr_map, self.tr_map, self.hr_freq, self.tr_freq = self.two_tuple_count() | |||||
def __getitem__(self, idx): | |||||
""" | |||||
Returns a positive sample and `self.neg_size` negative samples. | |||||
""" | |||||
pos_triple = self.triples[idx] | |||||
head, rel, tail = pos_triple | |||||
subsampling_weight = self.hr_freq[(head, rel)] + self.tr_freq[(tail, rel)] | |||||
subsampling_weight = np.sqrt(1 / np.array([subsampling_weight])) | |||||
neg_triples = [] | |||||
neg_size = 0 | |||||
while neg_size < self.neg_size: | |||||
neg_triples_tmp = np.random.randint(self.num_entity, size=self.neg_size * 2) | |||||
if self.batch_type == BatchType.HEAD_BATCH: | |||||
mask = np.in1d( | |||||
neg_triples_tmp, | |||||
self.tr_map[(tail, rel)], | |||||
assume_unique=True, | |||||
invert=True | |||||
) | |||||
elif self.batch_type == BatchType.TAIL_BATCH: | |||||
mask = np.in1d( | |||||
neg_triples_tmp, | |||||
self.hr_map[(head, rel)], | |||||
assume_unique=True, | |||||
invert=True | |||||
) | |||||
else: | |||||
raise ValueError('Invalid BatchType: {}'.format(self.batch_type)) | |||||
neg_triples_tmp = neg_triples_tmp[mask] | |||||
neg_triples.append(neg_triples_tmp) | |||||
neg_size += neg_triples_tmp.size | |||||
neg_triples = np.concatenate(neg_triples)[:self.neg_size] | |||||
pos_triple = np.array(pos_triple) | |||||
neg_triples = np.array(neg_triples) | |||||
return pos_triple, neg_triples, subsampling_weight | |||||
def __len__(self): | |||||
return self.len | |||||
def two_tuple_count(self): | |||||
""" | |||||
Return two dict: | |||||
dict({(h, r): [t1, t2, ...]}), | |||||
dict({(t, r): [h1, h2, ...]}), | |||||
""" | |||||
hr_map = {} | |||||
hr_freq = {} | |||||
tr_map = {} | |||||
tr_freq = {} | |||||
init_cnt = 3 | |||||
for head, rel, tail in self.triples: | |||||
if (head, rel) not in hr_map.keys(): | |||||
hr_map[(head, rel)] = set() | |||||
if (tail, rel) not in tr_map.keys(): | |||||
tr_map[(tail, rel)] = set() | |||||
if (head, rel) not in hr_freq.keys(): | |||||
hr_freq[(head, rel)] = init_cnt | |||||
if (tail, rel) not in tr_freq.keys(): | |||||
tr_freq[(tail, rel)] = init_cnt | |||||
hr_map[(head, rel)].add(tail) | |||||
tr_map[(tail, rel)].add(head) | |||||
hr_freq[(head, rel)] += 1 | |||||
tr_freq[(tail, rel)] += 1 | |||||
for key in tr_map: | |||||
tr_map[key] = np.array(list(tr_map[key])) | |||||
for key in hr_map: | |||||
hr_map[key] = np.array(list(hr_map[key])) | |||||
return hr_map, tr_map, hr_freq, tr_freq | |||||
class TestDataset: | |||||
""" | |||||
create test data | |||||
Args: | |||||
data_reader: data reader class. | |||||
mode: test model. VALID or TEST. | |||||
batch_type: batch type. HEAD or TAIL. | |||||
""" | |||||
def __init__(self, data_reader: DataReader, mode: ModeType, batch_type: BatchType): | |||||
self.triple_set = set(data_reader.train_data + data_reader.valid_data + data_reader.test_data) | |||||
if mode == ModeType.VALID: | |||||
self.triples = data_reader.valid_data | |||||
elif mode == ModeType.TEST: | |||||
self.triples = data_reader.test_data | |||||
self.len = len(self.triples) | |||||
self.num_entity = len(data_reader.entity_dict) | |||||
self.num_relation = len(data_reader.relation_dict) | |||||
self.mode = mode | |||||
self.batch_type = batch_type | |||||
def __len__(self): | |||||
return self.len | |||||
def __getitem__(self, idx): | |||||
head, relation, tail = self.triples[idx] | |||||
if self.batch_type == BatchType.HEAD_BATCH: | |||||
tmp = [(0, rand_head) if (rand_head, relation, tail) not in self.triple_set | |||||
else (-1, head) for rand_head in range(self.num_entity)] | |||||
tmp[head] = (0, head) | |||||
elif self.batch_type == BatchType.TAIL_BATCH: | |||||
tmp = [(0, rand_tail) if (head, relation, rand_tail) not in self.triple_set | |||||
else (-1, tail) for rand_tail in range(self.num_entity)] | |||||
tmp[tail] = (0, tail) | |||||
else: | |||||
raise ValueError('negative batch type {} not supported'.format(self.mode)) | |||||
tmp = np.array(tmp) | |||||
filter_bias = tmp[:, 0] | |||||
negative_sample = tmp[:, 1] | |||||
positive_sample = np.array((head, relation, tail)) | |||||
return positive_sample, negative_sample, filter_bias |
Dear OpenI User
Thank you for your continuous support to the Openl Qizhi Community AI Collaboration Platform. In order to protect your usage rights and ensure network security, we updated the Openl Qizhi Community AI Collaboration Platform Usage Agreement in January 2024. The updated agreement specifies that users are prohibited from using intranet penetration tools. After you click "Agree and continue", you can continue to use our services. Thank you for your cooperation and understanding.
For more agreement content, please refer to the《Openl Qizhi Community AI Collaboration Platform Usage Agreement》