- import tensorflow as tf
- import os
- import sklearn.metrics
- import numpy as np
- import sys
- import math
- import time
- import json
-
- from copy import deepcopy
- from . import framework
- from . import network
-
-
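- # MIN_ADD_NEG_NUM: a new negative bag is only created when more than this many
- # agent-flagged instances accumulate in one bag (see add_negative_bag below);
- # NUM_ADD is a running count of negative bags added so far.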
- MIN_ADD_NEG_NUM = 4
- NUM_ADD = 0
-
-
- class PolicyAgent(framework.RelationModel):
- def __init__(self, train_data_loader, batch_size, max_length=150):
- framework.RelationModel.__init__(self, train_data_loader, batch_size, max_length)
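- # Scalar loss weight fed at run time; during RL training it carries
- # reward * alpha so the agent's update scales with the reward (see train()).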
- self.weights = tf.placeholder(tf.float32, shape=(), name="weights_scalar")
-
- x = network.embedding.origin_word_position_embedding(self.word, self.word_vec_mat, self.pos1, self.pos2)
- x_train = network.encoder.cnn(x, keep_prob=0.5)
- x_test = network.encoder.cnn(x, keep_prob=1.0)
-
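- # Bag-level self-attention for the training graph: a content term (similarity
- # masked to a lower triangle) plus a position term derived from the mask's
- # top-k index order.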
- # (N, m) * (m, N) = (N, N)
- weight_x_train = tf.matmul(x_train, tf.transpose(x_train))
- att_mat_train = tf.linalg.band_part(tf.ones_like(weight_x_train), -1, 0)
- att_x_train = tf.nn.softmax(weight_x_train * att_mat_train, -1)
- att_dis_train = tf.math.top_k(att_mat_train, tf.shape(weight_x_train)[0]).indices
- att_dis_train = tf.cast(att_dis_train, dtype=tf.float32)
- att_dis_train = tf.nn.softmax(att_dis_train, -1)
-
- att_all_train = att_dis_train + att_x_train
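- # Re-weight the attention by the agent's instance labels: positions labeled 1
- # get weight 0.4, NA positions get weight 0.6.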
- agent_weight = tf.tile(tf.expand_dims(self.ins_label, 0), [tf.shape(weight_x_train)[0], 1])
- true_weight = tf.zeros([tf.shape(weight_x_train)[0], tf.shape(weight_x_train)[0]]) + 0.4
- false_weight = tf.zeros([tf.shape(weight_x_train)[0], tf.shape(weight_x_train)[0]]) + 0.6
- agent_weight = tf.where(tf.equal(agent_weight, 1), true_weight, false_weight)
- att_all_train = tf.nn.softmax(att_all_train * agent_weight, -1)
- x_train = tf.matmul(att_all_train, x_train)
-
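- # The same attention construction, repeated for the test-time graph
- # (keep_prob=1.0, no dropout).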
- weight_x_test = tf.matmul(x_test, tf.transpose(x_test))
- att_mat_test = tf.linalg.band_part(tf.ones_like(weight_x_test), -1, 0)
- att_x_test = tf.nn.softmax(weight_x_test * att_mat_test, -1)
- att_dis_test = tf.math.top_k(att_mat_test, tf.shape(weight_x_test)[0]).indices
- att_dis_test = tf.cast(att_dis_test, dtype=tf.float32)
- att_dis_test = tf.nn.softmax(att_dis_test, -1)
-
- att_all_test = att_dis_test + att_x_test
- agent_weight_test = tf.tile(tf.expand_dims(self.ins_label, 0), [tf.shape(weight_x_test)[0], 1])
- true_weight_test = tf.zeros([tf.shape(weight_x_test)[0], tf.shape(weight_x_test)[0]]) + 0.4
- false_weight_test = tf.zeros([tf.shape(weight_x_test)[0], tf.shape(weight_x_test)[0]]) + 0.6
- agent_weight_test = tf.where(tf.equal(agent_weight_test, 1), true_weight_test, false_weight_test)
- att_all_test = tf.nn.softmax(att_all_test * agent_weight_test, -1)
- x_test = tf.matmul(att_all_test, x_test)
-
- self._train_logit = network.selector.instance(x_train, 2, keep_prob=0.5)
- self._test_logit = network.selector.instance(x_test, 2, keep_prob=1.0)
- self._loss = network.classifier.softmax_cross_entropy(self._train_logit, self.ins_label, 2, weights=self.weights)
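- # The agent emits binary per-instance logits: argmax 0 flags an instance as
- # NA, argmax 1 keeps it as a positive instance.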
-
- def loss(self):
- return self._loss
-
- def train_logit(self):
- return self._train_logit
-
- def test_logit(self):
- return self._test_logit
-
-
- class RlRelationFramework(framework.RelationFramework):
- def __init__(self, train_data_loader, test_data_loader, max_length=150, batch_size=24):
- framework.RelationFramework.__init__(self, train_data_loader, test_data_loader, max_length, batch_size)
-
- def agent_one_step(self, sess, agent_model, batch_data, run_array, weights=1):
- feed_dict = {
- agent_model.word: batch_data['word'],
- agent_model.pos1: batch_data['pos1'],
- agent_model.pos2: batch_data['pos2'],
- agent_model.ins_label: batch_data['agent_label'],
- agent_model.length: batch_data['length'],
- agent_model.weights: weights
- }
- if 'mask' in batch_data and hasattr(agent_model, "mask"):
- feed_dict.update({agent_model.mask: batch_data['mask']})
- result = sess.run(run_array, feed_dict)
- return result
-
- def pretrain_main_model(self, max_epoch):
- for epoch in range(max_epoch):
- tot_correct = 0
- tot_not_na_correct = 0
- tot = 0
- tot_not_na = 0
- i = 0
- time_sum = 0
-
- for i, batch_data in enumerate(self.train_data_loader):
- time_start = time.time()
- iter_loss, iter_logit, _train_op = self.train_one_step(self.sess, self.model, batch_data, [self.model.loss(), self.model.train_logit(), self.train_op])
- time_end = time.time()
- t = time_end - time_start
- time_sum += t
- iter_output = iter_logit.argmax(-1)
- iter_label = batch_data['rel']
- iter_correct = (iter_output == iter_label).sum()
- iter_not_na_correct = np.logical_and(iter_output == iter_label, iter_label != 0).sum()
- tot_correct += iter_correct
- tot_not_na_correct += iter_not_na_correct
- tot += iter_label.shape[0]
- tot_not_na += (iter_label != 0).sum()
- sys.stdout.write("[pretrain main model training:] epoch %d step %d time %.2f | loss: %f, not NA accuracy: %f, accuracy: %f\r" % (epoch, i, t, iter_loss, float(tot_not_na_correct) / tot_not_na, float(tot_correct) / tot))
- sys.stdout.flush()
- i += 1
- print("\nAverage iteration time: %f" % (time_sum / i))
-
- def pretrain_agent_model(self, max_epoch):
- # Pre-train policy agent
- for epoch in range(max_epoch):
- tot_correct = 0
- tot_not_na_correct = 0
- tot = 0
- tot_not_na = 0
- time_sum = 0
-
- for i, batch_data in enumerate(self.train_data_loader):
- time_start = time.time()
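- # Binarize the distant labels for the agent: any non-NA relation becomes
- # class 1 (the + 0 forces a copy so ins_rel itself is untouched)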
- batch_data['agent_label'] = batch_data['ins_rel'] + 0
- batch_data['agent_label'][batch_data['agent_label'] > 0] = 1
- iter_loss, iter_logit, _train_op = self.agent_one_step(self.sess, self.agent_model, batch_data, [self.agent_model.loss(), self.agent_model.train_logit(), self.agent_train_op])
- time_end = time.time()
- t = time_end - time_start
- time_sum += t
- iter_output = iter_logit.argmax(-1)
- # Compare against the binarized agent labels: the agent is a binary classifier
- iter_label = batch_data['agent_label']
- iter_correct = (iter_output == iter_label).sum()
- iter_not_na_correct = np.logical_and(iter_output == iter_label, iter_label != 0).sum()
- tot_correct += iter_correct
- tot_not_na_correct += iter_not_na_correct
- tot += iter_label.shape[0]
- tot_not_na += (iter_label != 0).sum()
- if tot_not_na > 0:
- sys.stdout.write("[rl training:] epoch %d step %d time %.2f | loss: %f, not NA accuracy: %f, accuracy: %f\r" % (epoch, i, t, iter_loss, float(tot_not_na_correct) / tot_not_na, float(tot_correct) / tot))
-
- sys.stdout.flush()
- i += 1
- print("\nAverage iteration time: %f" % (time_sum / i))
-
- def add_negative_bag(self, batch_data, action_result):
- """
- Add a new negative (NA) bag according to the agent's action result.
-
- For each bag, collect the instances that the agent relabels as NA; when
- more than MIN_ADD_NEG_NUM such instances accumulate in one bag, append
- them to the batch as a new bag with relation label 0. batch_data is
- updated in place.
- """
- origin_scope = deepcopy(batch_data["scope"])
- scope = deepcopy(batch_data["scope"])
- ins_rel = batch_data["ins_rel"]
- rel = batch_data["rel"]
- word = batch_data["word"]
- pmc_word = batch_data['pmc_word']
- pubmed_word = batch_data['pubmed_word']
- pubmed_and_pmc_word = batch_data['pubmed_and_pmc_word']
- pubmed_myself_word = batch_data['pubmed_myself_word']
- wiki_pubmed_word = batch_data['wiki_pubmed_word']
- pos1 = batch_data['pos1']
- pos2 = batch_data['pos2']
- mask = batch_data['mask']
- length = batch_data['length']
-
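- # Instances the agent relabels as NA (action 0) among those whose
- # instance label is 1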
- not_na_changed_index = (action_result == 0) & (ins_rel == 1)
-
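- # Iterate over one copy of the scopes while origin_scope accumulates the
- # newly appended bags, so new bags are never revisited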
- for item in scope:
- begin = item[0]
- end = item[1]
-
- bag_selected = not_na_changed_index[begin: end]
- bag_ins_rel = list()
- bag_rel = list()
- bag_word = list()
- bag_pmc_word = list()
- bag_pubmed_word = list()
- bag_pubmed_and_pmc_word = list()
- bag_pubmed_myself_word = list()
- bag_wiki_pubmed_word = list()
- bag_pos1 = list()
- bag_pos2 = list()
- bag_mask = list()
- bag_length = list()
- fp_sentences = list()
-
- for i, ind in enumerate(bag_selected):
- if ind > 0:
- location_index = begin + i
- bag_ins_rel.append(ins_rel[location_index])
- bag_word.append(word[location_index])
- fp_sentences.append(batch_data["sentences"][location_index])
-
- bag_pmc_word.append(pmc_word[location_index])
- bag_pubmed_word.append(pubmed_word[location_index])
- bag_pubmed_and_pmc_word.append(pubmed_and_pmc_word[location_index])
- bag_pubmed_myself_word.append(pubmed_myself_word[location_index])
- bag_wiki_pubmed_word.append(wiki_pubmed_word[location_index])
- bag_pos1.append(pos1[location_index])
- bag_pos2.append(pos2[location_index])
- bag_mask.append(mask[location_index])
- bag_length.append(length[location_index])
-
- # add new negative bag
- if len(bag_ins_rel) > MIN_ADD_NEG_NUM:
- global NUM_ADD
- NUM_ADD += 1
-
- origin_scope = np.concatenate([origin_scope, [(len(ins_rel), len(ins_rel) + len(bag_ins_rel))]])
- ins_rel = np.concatenate([ins_rel, bag_ins_rel])
- rel = np.concatenate([rel, [0]])
- word = np.concatenate([word, bag_word])
- pmc_word = np.concatenate([pmc_word, bag_pmc_word])
- pubmed_word = np.concatenate([pubmed_word, bag_pubmed_word])
- pubmed_and_pmc_word = np.concatenate([pubmed_and_pmc_word, bag_pubmed_and_pmc_word])
- pubmed_myself_word = np.concatenate([pubmed_myself_word, bag_pubmed_myself_word])
- wiki_pubmed_word = np.concatenate([wiki_pubmed_word, bag_wiki_pubmed_word])
- pos1 = np.concatenate([pos1, bag_pos1])
- pos2 = np.concatenate([pos2, bag_pos2])
- mask = np.concatenate([mask, bag_mask])
- length = np.concatenate([length, bag_length])
-
- # update batch_data
- batch_data["scope"] = origin_scope
- batch_data["ins_rel"] = ins_rel
- batch_data["rel"] = rel
- batch_data["word"] = word
- batch_data['pmc_word'] = pmc_word
- batch_data['pubmed_word'] = pubmed_word
- batch_data['pubmed_and_pmc_word'] = pubmed_and_pmc_word
- batch_data['pubmed_myself_word'] = pubmed_myself_word
- batch_data['wiki_pubmed_word'] = wiki_pubmed_word
- batch_data['pos1'] = pos1
- batch_data['pos2'] = pos2
- batch_data['mask'] = mask
- batch_data['length'] = length
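- # A worked example of the bookkeeping above (hypothetical numbers): with
- # scopes [(0, 3), (3, 5)] over 5 instances and MIN_ADD_NEG_NUM small enough,
- # if the agent flags instances 0 and 2 as NA, copies of them are appended as
- # a third bag with scope (5, 7) and bag-level relation 0.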
-
- def train(self,
- model, # The main model
- agent_model, # The model of policy agent
- model_name,
- ckpt_dir='./checkpoint',
- summary_dir='./summary',
- test_result_dir='./test_result',
- learning_rate=0.001,
- learning_rate_agent=1e-10,
- max_epoch=600,
- pretrain_agent_epoch=1,
- pretrain_model=None,
- test_epoch=1,
- optimizer=tf.train.GradientDescentOptimizer):
-
- print("Start training...")
-
- # Init
- self.model = model(self.train_data_loader, self.train_data_loader.batch_size, self.train_data_loader.max_length)
- model_optimizer = optimizer(learning_rate)
- grads = model_optimizer.compute_gradients(self.model.loss())
- self.train_op = model_optimizer.apply_gradients(grads)
-
- # Init policy agent
- self.agent_model = agent_model(self.train_data_loader, self.train_data_loader.batch_size, self.train_data_loader.max_length)
- agent_optimizer = optimizer(learning_rate_agent)
- agent_grads = agent_optimizer.compute_gradients(self.agent_model.loss())
- self.agent_train_op = agent_optimizer.apply_gradients(agent_grads)
-
- # Session, writer and saver
- config = tf.ConfigProto(allow_soft_placement=True) # fall back to CPU if no GPU is available
- config.gpu_options.per_process_gpu_memory_fraction = 0.7 # cap per-process GPU memory at 70%
- config.gpu_options.allow_growth = True
- self.sess = tf.Session(config=config)
- summary_writer = tf.summary.FileWriter(summary_dir, self.sess.graph)
- saver = tf.train.Saver(max_to_keep=None)
-
- if pretrain_model is None:
- self.sess.run(tf.global_variables_initializer())
- else:
- model_file = tf.train.latest_checkpoint(pretrain_model)
- saver.restore(self.sess, model_file)
- print("restore suceeded ..............................")
-
- # pretrain main model
- self.pretrain_main_model(max_epoch=20)
-
- # pretrain policy agent
- self.pretrain_agent_model(max_epoch=20)
-
- # Train
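- # Each epoch alternates two passes over the data: first the policy agent is
- # updated with a reward-weighted loss, then the main model is trained on
- # batches augmented by the agent's actions.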
- tot_delete = 0
- batch_count = 0
- reward = 0.0
- for epoch in range(max_epoch):
- print('###### Epoch ' + str(epoch) + ' ######')
- tot_correct = 0
- tot_not_na_correct = 0
- tot = 0
- tot_not_na = 0
- i = 0
- time_sum = 0
- batch_stack = []
- # Update policy agent
- for i, batch_data in enumerate(self.train_data_loader):
- # Make action
- batch_data['agent_label'] = batch_data['ins_rel'] + 0
- batch_data['agent_label'][batch_data['agent_label'] > 0] = 1
- batch_stack.append(batch_data)
- iter_logit = self.agent_one_step(self.sess, self.agent_model, batch_data, [self.agent_model.train_logit()])[0]
- action_result = iter_logit.argmax(-1)
- if np.sum(action_result) == 0:
- continue
-
- # According to the agent's actions, build a new negative (NA) bag from the
- # instances inside non-NA bags that were flagged as NA
-
- new_batch_data = deepcopy(batch_data)
- self.add_negative_bag(new_batch_data, action_result)
-
- # Calculate reward
- batch_delete = np.sum(np.logical_and(batch_data['ins_rel'] != 0, action_result == 0))
- batch_data['agent_label'][action_result == 0] = 0
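- # Mark the removed instances as NA in agent_label; batch_stack holds
- # references to these batches, so the agent update below sees the new labels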
-
- # train main model with new batch_data
- iter_loss = self.train_one_step(self.sess, self.model, new_batch_data, [self.model.loss()])[0]
- reward += iter_loss
- tot_delete += batch_delete
- batch_count += 1
-
- # Update parameters of policy agent
- alpha = 0.1
- try:
- if batch_count == 1:
- reward = reward / float(batch_count)
- average_loss = reward
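- # Map the average loss to a reward: a small loss yields a large positive
- # reward, a large loss yields a reward near zero; a zero loss raises a
- # math domain error that the surrounding try/except absorbs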
- reward = -math.log(1 - math.exp(-reward))
- sys.stdout.write('tot delete : %d | reward : %f | average loss : %f\r' % (tot_delete, reward, average_loss))
- sys.stdout.flush()
- for batch_data in batch_stack:
- self.agent_one_step(self.sess, self.agent_model, batch_data, [self.agent_train_op], weights=reward * alpha)
- batch_count = 0
- reward = 0
- tot_delete = 0
- batch_stack = []
- except Exception as e:
- print(e)
- continue
- i += 1
-
- # Train the main model
- for i, batch_data in enumerate(self.train_data_loader):
- batch_data['agent_label'] = batch_data['ins_rel'] + 0
- batch_data['agent_label'][batch_data['agent_label'] > 0] = 1
- time_start = time.time()
-
- # Make actions
- iter_logit = self.agent_one_step(self.sess, self.agent_model, batch_data, [self.agent_model.train_logit()])[0]
- action_result = iter_logit.argmax(-1)
-
- # add new negative bag
- new_batch_data = deepcopy(batch_data)
- self.add_negative_bag(new_batch_data, action_result)
- # batch_data['ins_rel'][action_result == 0] = 0
-
- # Real training
- iter_loss, iter_logit, _train_op = self.train_one_step(self.sess, self.model, new_batch_data, [self.model.loss(), self.model.train_logit(), self.train_op])
-
- time_end = time.time()
- t = time_end - time_start
- time_sum += t
- iter_output = iter_logit.argmax(-1)
- iter_label = new_batch_data['rel']
- iter_correct = (iter_output == iter_label).sum()
- iter_not_na_correct = np.logical_and(iter_output == iter_label, iter_label != 0).sum()
- tot_correct += iter_correct
- tot_not_na_correct += iter_not_na_correct
- tot += iter_label.shape[0]
- tot_not_na += (iter_label != 0).sum()
- if tot_not_na > 0:
- sys.stdout.write("epoch %d step %d time %.2f | loss: %f, not NA accuracy: %f, accuracy: %f\r" % (epoch, i, t, iter_loss, float(tot_not_na_correct) / tot_not_na, float(tot_correct) / tot))
- sys.stdout.flush()
- i += 1
- print("\nAverage iteration time: %f" % (time_sum / i))
-
- if (epoch + 1) % test_epoch == 0:
- self.test(model)
-
- # save model
- try:
- print("save model to %s" % (os.path.join(ckpt_dir + "_rl", model_name)))
- saver.save(self.sess, os.path.join(ckpt_dir + "_rl", model_name))
- except Exception as e:
- print("saving the model checkpoint failed: %s" % e)
-
- print("######")
- print("Finish training " + model_name)