|
import os
import random
import sys

import numpy as np
import tensorflow as tf

import module

# The dataset directory is the single required CLI argument; fail fast if
# it does not exist so the loaders below get valid paths.
dataset_dir = sys.argv[1]
if not os.path.isdir(dataset_dir):
    raise Exception("[ERROR] Dataset dir %s doesn't exist!" % (dataset_dir))

# Seed every RNG in play (TF graph ops, Python, NumPy) for reproducibility.
SEED = 1234
tf.compat.v1.random.set_random_seed(SEED)
random.seed(SEED)
np.random.seed(SEED)

BATCH_SIZE = 8
main_model_lr = 1e-2   # learning rate of the main relation-extraction model
agent_model_lr = 1e-3  # learning rate of the RL policy agent
MAX_LEN = 150          # maximum sentence length (tokens) fed to the encoder
-
def _make_loader(data_file, mode, shuffle):
    """Build a JsonFileDataLoader for *data_file* inside ``dataset_dir``.

    All loaders share the same word-vector resources (pmc / pubmed /
    pubmed_and_pmc / pubmed_myself / wiki_pubmed corpora) and the same
    relation-to-id mapping; only the data file, the bag/instance mode and
    the shuffling flag differ between train and test.
    """
    return module.data_loader.JsonFileDataLoader(
        os.path.join(dataset_dir, data_file),
        os.path.join(dataset_dir, 'word_vec.json'),
        os.path.join(dataset_dir, 'pmc'),
        os.path.join(dataset_dir, 'pubmed'),
        os.path.join(dataset_dir, 'pubmed_and_pmc'),
        os.path.join(dataset_dir, 'pubmed_myself'),
        os.path.join(dataset_dir, 'wiki_pubmed'),
        os.path.join(dataset_dir, 'prevent_rel2id.json'),
        mode=mode,
        shuffle=shuffle,
        batch_size=BATCH_SIZE)


# Training data: relation-fact bags (distant supervision), reshuffled.
train_loader = _make_loader(
    'train_prevent_newreplace.json',
    module.data_loader.JsonFileDataLoader.MODE_RELFACT_BAG,
    shuffle=True)

# Evaluation data: manually labelled instances, iterated in fixed order.
test_loader = _make_loader(
    'MAY_PREVENT_MANUALLY_LABELLED_newreplace_changeid.json',
    module.data_loader.JsonFileDataLoader.MODE_INSTANCE,
    shuffle=False)
-
-
-
class model(module.framework.RelationModel):
    """Bag-level relation extraction model (PCNN encoder + selective attention).

    The encoder and selector are chosen through the class attributes below.
    Train and test graphs are built side by side over the same inputs; the
    train path applies dropout (keep_prob=0.5) while the test path does not
    (keep_prob=1.0).
    """

    encoder = "pcnn"   # sentence encoder: "pcnn" | "cnn" | "rnn" | "birnn"
    selector = "att"   # bag selector:     "att" | "ave" | "one" | "cross_max"

    def __init__(self, train_data_loader, batch_size, max_length=MAX_LEN):
        module.framework.RelationModel.__init__(self, train_data_loader, batch_size, max_length=max_length)
        # Piecewise mask used by PCNN pooling, one int per token position.
        # NOTE(review): switched to the compat.v1 API for consistency with the
        # seeding call at the top of the file — `tf.placeholder` is TF1-only
        # and would fail under TF2; `tf.compat.v1.placeholder` is identical
        # under TF >= 1.14.
        self.mask = tf.compat.v1.placeholder(dtype=tf.int32, shape=[None, max_length], name="mask")

        # Embedding: word vectors drawn from several biomedical corpora plus
        # relative-position embeddings for the two entity mentions.
        with tf.name_scope('embedding'):
            e1, e2, x = module.network.embedding.word_position_embedding(
                batch_size, self.word, self.word_vec_mat,
                self.pmc_word, self.word_vec_mat_pmc,
                self.pubmed_word, self.word_vec_mat_pubmed,
                self.pubmed_and_pmc_word, self.word_vec_mat_pubmed_and_pmc,
                self.pubmed_myself_word, self.word_vec_mat_pubmed_myself,
                self.wiki_pubmed_word, self.word_vec_mat_wiki_pubmed,
                self.pos1, self.pos2)

        # Encoder: two graphs over the same input, differing only in dropout.
        with tf.name_scope('encoder'):
            if model.encoder == "pcnn":
                x_train = module.network.encoder.pcnn(x, self.mask, keep_prob=0.5)
                x_test = module.network.encoder.pcnn(x, self.mask, keep_prob=1.0)
            elif model.encoder == "cnn":
                x_train = module.network.encoder.cnn(x, keep_prob=0.5)
                x_test = module.network.encoder.cnn(x, keep_prob=1.0)
            elif model.encoder == "rnn":
                x_train = module.network.encoder.rnn(x, self.length, keep_prob=0.5)
                x_test = module.network.encoder.rnn(x, self.length, keep_prob=1.0)
            elif model.encoder == "birnn":
                x_train = module.network.encoder.birnn(x, self.length, keep_prob=0.5)
                x_test = module.network.encoder.birnn(x, self.length, keep_prob=1.0)
            else:
                raise NotImplementedError

        # Selector: aggregate instance representations into bag-level logits.
        # Only the non-"att" selectors apply an explicit softmax to the test
        # logits here (attention presumably handles it internally — verify
        # against module.network.selector.bag_attention).
        with tf.name_scope('selector'):
            if model.selector == "att":
                self._train_logit, train_repre = module.network.selector.bag_attention(
                    e1, e2, x_train, self.scope, self.ins_label, self.rel_tot,
                    True, keep_prob=0.5)
                self._test_logit, test_repre = module.network.selector.bag_attention(
                    e1, e2, x_test, self.scope, self.ins_label, self.rel_tot,
                    False, keep_prob=1.0)
            elif model.selector == "ave":
                self._train_logit, train_repre = module.network.selector.bag_average(
                    x_train, self.scope, self.rel_tot, keep_prob=0.5)
                self._test_logit, test_repre = module.network.selector.bag_average(
                    x_test, self.scope, self.rel_tot, keep_prob=1.0)
                self._test_logit = tf.nn.softmax(self._test_logit)
            elif model.selector == "one":
                self._train_logit, train_repre = module.network.selector.bag_one(
                    x_train, self.scope, self.label, self.rel_tot, True, keep_prob=0.5)
                self._test_logit, test_repre = module.network.selector.bag_one(
                    x_test, self.scope, self.label, self.rel_tot, False, keep_prob=1.0)
                self._test_logit = tf.nn.softmax(self._test_logit)
            elif model.selector == "cross_max":
                self._train_logit, train_repre = module.network.selector.bag_cross_max(
                    x_train, self.scope, self.rel_tot, keep_prob=0.5)
                self._test_logit, test_repre = module.network.selector.bag_cross_max(
                    x_test, self.scope, self.rel_tot, keep_prob=1.0)
                self._test_logit = tf.nn.softmax(self._test_logit)
            else:
                raise NotImplementedError

        # Classifier: weighted softmax cross-entropy over bag labels.
        with tf.name_scope('classifier'):
            self._loss = module.network.classifier.softmax_cross_entropy(
                self._train_logit, self.label, self.rel_tot,
                weights_table=self.get_weights())

    def loss(self):
        """Return the training loss tensor."""
        return self._loss

    def train_logit(self):
        """Return the bag-level logits tensor of the training graph."""
        return self._train_logit

    def test_logit(self):
        """Return the bag-level logits tensor of the test graph."""
        return self._test_logit

    def get_weights(self):
        """Build a per-relation loss-weight table from training frequencies.

        Each relation r gets weight 1 / count(r)**0.05, stored in a
        non-trainable TF variable so rare relations are slightly up-weighted.

        NOTE(review): a relation with zero training instances yields a
        division by zero (inf weight) — confirm every relation id occurs in
        the training data.
        """
        with tf.compat.v1.variable_scope("weights_table", reuse=tf.compat.v1.AUTO_REUSE):
            print("Calculating weights_table...")
            _weights_table = np.zeros((self.rel_tot), dtype=np.float32)
            for i in range(len(self.train_data_loader.data_rel)):
                _weights_table[self.train_data_loader.data_rel[i]] += 1.0
            _weights_table = 1 / (_weights_table ** 0.05)
            weights_table = tf.compat.v1.get_variable(
                name='weights_table', dtype=tf.float32, trainable=False,
                initializer=_weights_table)
            print("Finish calculating")
            return weights_table
-
-
- if __name__ == "__main__":
- r_p_model = module.rl_pacnn.RlRelationFramework(train_loader, test_loader, max_length=MAX_LEN,
- batch_size=BATCH_SIZE)
-
- # 指定自己的模型保存路径和相应模型加载路径,如pretrain_model参数 代表模型预加载路径,若重新训练设置为None
- # 模型保存路径由ckpt_dir和model_name参数共同确定,使用者可自行指定自己的路径名称
- r_p_model.train(model, module.rl_pacnn.PolicyAgent, model_name="medline_test",
- max_epoch=0,
- ckpt_dir="./ckpt/prevent",
- pretrain_model="./ckpt/prevent",
- learning_rate=main_model_lr,
- learning_rate_agent=agent_model_lr)
|