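- # Training / evaluation entry point for the BLSTM_GCN_CRF NER model:
- # BERT-tokenized character sequences plus gazetteer and knowledge-graph
- # features, decoded with a CRF layer.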
- from utils.data import Data
- from utils.batchify import batchify
- from utils.config import get_args
- from utils.metric import get_ner_fmeasure
- from model.bilstm_gcn_crf import BLSTM_GCN_CRF
- from pytorch_pretrained_bert import BertTokenizer
- from pytorch_pretrained_bert.optimization import BertAdam
- import os
- import numpy as np
- import copy
- import pickle
- import torch
- import torch.optim as optim
- import time
- import random
- import sys
- import gc
- import collections
- import h5py
- import warnings
- warnings.filterwarnings("ignore", category=UserWarning)
-
-
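- # Build the Data object (alphabets, gazetteer/KG lexicons, train/dev/test
- # instances, entity-type embeddings) from the raw files, or reload the cached
- # "<dataset_name>_dataset.dset" pickle unless args.refresh is set.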
- def data_initialization(args):
- data_stored_directory = args.data_stored_directory
- file = data_stored_directory + args.dataset_name + "_dataset.dset"
- if os.path.exists(file) and not args.refresh:
- data = load_data_setting(data_stored_directory, args.dataset_name)
- else:
- data = Data()
- data.dataset_name = args.dataset_name
- data.norm_char_emb = args.norm_char_emb
- data.norm_gaz_emb = args.norm_gaz_emb
- data.norm_kg_emb = args.norm_kg_emb
- data.use_bert = args.use_bert
- data.pre_model = args.pre_model
- data.language = args.language
- data.number_normalized = args.number_normalized
- data.max_sentence_length = args.max_sentence_length
- data.token_path = args.token_path
- data.bert_path = args.bert_path
- data.build_gaz_file(args.gaz_file)
- data.build_kg_file(args.kg_file)
- data.generate_instance(args.train_file, "train", False)
- data.generate_instance(args.dev_file, "dev")
- data.generate_instance(args.test_file, "test")
- #data.build_char_pretrain_emb(args.char_embedding_path)
- #data.build_gaz_pretrain_emb(args.gaz_file)
- #data.build_kg_pretrain_emb(args.kg_file)
- data.build_entity_pretrain_emb(args.entity_type_embdding_path)
- data.build_entity_type_id(args.entity_type_embdding_path)
- data.fix_alphabet()
- data.get_tag_scheme()
- save_data_setting(data, data_stored_directory)
- return data
-
-
- def save_data_setting(data, data_stored_directory):
- data.show_data_summary()
- if not os.path.exists(data_stored_directory):
- os.makedirs(data_stored_directory)
- dataset_saved_name = data_stored_directory + data.dataset_name +"_dataset.dset"
- with open(dataset_saved_name, 'wb') as fp:
- pickle.dump(data, fp, pickle.HIGHEST_PROTOCOL)
- # with h5py.File(dataset_saved_name, 'w') as f:
- # f.create_dataset('X', data=data)
- print("Data setting saved to file: ", dataset_saved_name)
-
- def load_data_setting(data_stored_directory, name):
- dataset_saved_name = data_stored_directory + name + "_dataset.dset"
- with open(dataset_saved_name, 'rb') as fp:
- data = pickle.load(fp)
- # with h5py.File(dataset_saved_name, 'r') as f:
- # data = f['X'][:]
- print("Data setting loaded from file: ", dataset_saved_name)
- data.show_data_summary()
- return data
-
-
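- # Exponential learning-rate decay: lr = init_lr * (1 - decay_rate) ** epoch,
- # applied to every parameter group of the optimizer.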
- def lr_decay(optimizer, epoch, decay_rate, init_lr):
- lr = init_lr * ((1-decay_rate)**epoch)
- print(" Learning rate is set to:", lr)
- for param_group in optimizer.param_groups:
- param_group['lr'] = lr
- return optimizer
-
-
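- # Count how many masked token positions were predicted with the correct tag.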
- def predict_check(pred_variable, gold_variable, mask_variable):
- """
- input:
- pred_variable (batch_size, sent_len): pred tag result, in numpy format
- gold_variable (batch_size, sent_len): gold result variable
- mask_variable (batch_size, sent_len): mask variable
- """
- pred = pred_variable.cpu().data.numpy()
- gold = gold_variable.cpu().data.numpy()
- mask = mask_variable.cpu().data.numpy()
- overlapped = (pred == gold)
- right_token = np.sum(overlapped * mask)
- total_token = mask.sum()
- # print("right: %s, total: %s"%(right_token, total_token))
- return right_token, total_token
-
-
- def recover_label(char, pred_variable, gold_variable, mask_variable, label_alphabet, instances_txt, tokenizer, word_recover):
- """
- input:
- pred_variable (batch_size, sent_len): predicted tag indices
- gold_variable (batch_size, sent_len): gold tag indices
- mask_variable (batch_size, sent_len): mask variable
- output:
- per-sentence token lists, predicted and gold label strings (BERT "##"
- word-pieces merged back, [CLS]/[SEP] positions stripped) and the
- corresponding token start/end offsets
- """
- # pred_variable = pred_variable[word_recover]
- # gold_variable = gold_variable[word_recover]
- # mask_variable = mask_variable[word_recover]
- seq_len = gold_variable.size(1)
- mask = mask_variable.cpu().data.numpy()
- pred_tag = pred_variable.cpu().data.numpy()
- gold_tag = gold_variable.cpu().data.numpy()
- batch_size = mask.shape[0]
-
- pred_label = []
- gold_label = []
- new_char = []
- token_starts = []
- token_ends = []
- co = 0
- mis = 0
- #print(instances_txt)
- #print(batch_size)
- for idx in range(batch_size):
- token_start = instances_txt[idx][-2]
- token_end = instances_txt[idx][-1]
- token_start.append(-1)
- token_start.insert(0, -1)
- token_end.append(-1)
- token_end.insert(0, -1)
- txt = tokenizer.convert_ids_to_tokens(char.cpu().numpy()[idx])
- # print(txt)
- # print(token_start)
- # print(token_end)
- # print(mask[idx])
- new_pred_tag = []
- new_gold_tag = []
- new_mask = []
- process_char = []
- new_token_start = []
- new_token_end = []
- total_token = mask[idx].sum()
- # print(total_token)
- # print(len(token_start))
- # print(len(token_end))
- # assert total_token == len(token_start)
- # assert total_token == len(token_end)
- for i, word in enumerate(txt):
- if word.startswith('##'):
- # With BERT WordPiece tokenization, "##" sub-tokens are merged back into
- # the previous token: extend its end offset, then skip the sub-token itself.
- new_token_end[-1] = token_end[i]
- continue
- else:
- new_pred_tag.append(pred_tag[idx][i])
- new_gold_tag.append(gold_tag[idx][i])
- new_mask.append(mask[idx][i])
- process_char.append(word)
- if mask[idx][i]:
- new_token_start.append(token_start[i])
- new_token_end.append(token_end[i])
- co += 1
- # print(len(new_pred_tag))
- # print(len(new_gold_tag))
- # print(new_mask)
- assert len(process_char) == len(new_pred_tag)
- new_seq = len(new_pred_tag)
- pred = [label_alphabet.get_instance(new_pred_tag[idy]) for idy in range(new_seq) if new_mask[idy]]
- gold = [label_alphabet.get_instance(new_gold_tag[idy]) for idy in range(new_seq) if new_mask[idy]]
- chars = [process_char[idy] for idy in range(new_seq) if new_mask[idy]]
- # print(pred)
- # print(chars)
- pred = pred[1:-1]
- gold = gold[1:-1]
- new_token_start = new_token_start[1:-1]
- new_token_end = new_token_end[1:-1]
- chars = chars[1:-1]
-
- pred_label.append(pred)
- gold_label.append(gold)
- new_char.append(chars)
-
- assert len(pred) == len(new_token_start)
- assert len(pred) == len(new_token_end)
- # print(chars)
- # print(gold)
- # print(token_start)
- # print(token_end)
- token_starts.append(new_token_start)
- token_ends.append(new_token_end)
-
- return new_char, pred_label, gold_label, token_starts, token_ends
-
-
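- # Run the model over one split ("train"/"dev"/"test") in mini-batches,
- # recover label strings for every sentence, and score them with
- # get_ner_fmeasure; returns tokens, decoding speed, accuracy, P/R/F and
- # the raw predictions.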
- def evaluate(data, model, tokenizer, args, name):
-
- if name == "train":
- instances = data.train_ids
- instances_txt = data.train_texts
- elif name == "dev":
- instances = data.dev_ids
- instances_txt = data.dev_texts
- elif name == 'test':
- instances = data.test_ids
- instances_txt = data.test_texts
- else:
- raise ValueError("evaluate(): unknown dataset split name: %s" % name)
- pred_results = []
- gold_results = []
- model.eval()
- batch_size = args.batch_size
- start_time = time.time()
- train_num = len(instances)
- token_starts = []
- token_ends = []
- chars = []
- total_batch = train_num//batch_size+1
- for batch_id in range(total_batch):
- start = batch_id*batch_size
- end = (batch_id+1)*batch_size
- if end > train_num:
- end = train_num
- # print(start)
- # print(end)
- instance = instances[start:end]
- instance_txt = instances_txt[start:end]
-
- if not instance:
- continue
- char, c_len, gazs, kgs, bert_seq, bert_mask, mask, label, recover, t_graph, c_graph, l_graph, kg_t_graph, kg_c_graph, kg_l_graph, kg_span_graph, span_c_graph = batchify(instance, data.entity_id, args.use_gpu, args.pre_model)
- tag_seq = model(char, c_len, gazs, kgs, bert_seq, bert_mask, t_graph, c_graph, l_graph, kg_t_graph, kg_c_graph, kg_l_graph, kg_span_graph, span_c_graph, mask)
- # print(char.size())
- char, pred_label, gold_label, token_start, token_end = recover_label(char, tag_seq, label, mask, data.label_alphabet, instance_txt, tokenizer, recover)
- pred_results += pred_label
- gold_results += gold_label
- token_starts += token_start
- chars += char
- token_ends += token_end
- decode_time = time.time() - start_time
- speed = len(instances)/decode_time
- # print(pred_results)
- acc, p, r, f, pred = get_ner_fmeasure(gold_results, pred_results, token_starts, token_ends, data.tagscheme)
- return chars, speed, acc, p, r, f, pred
-
-
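- # Training loop: BertAdam (with warmup) or SGD over differential learning
- # rates, optional gradient clipping and gradient accumulation, per-epoch
- # learning-rate decay, and evaluation on the dev set after every epoch;
- # the checkpoint with the best dev f-score is saved under
- # args.param_stored_directory.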
- def train(data, model, tokenizer, args):
- # param_optimizer = list(model.named_parameters())
- # no_decay = ['bias', 'gamma', 'beta']
- no_decay = ["bias", "LayerNorm.weight"]
-
- model_param = list(model.named_parameters())
-
- bert_param_optimizer = []
- other_param_optimizer = []
-
- for name, para in model_param:
- # print(name)
- space = name.split('.')
- # print(space)
- if space[0] == 'bert_encoder':
- bert_param_optimizer.append((name, para))
- else:
- other_param_optimizer.append((name, para))
-
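- # Build four parameter groups: BERT encoder parameters at args.lr, all other
- # modules at args.lr_other; within each, bias and LayerNorm weights are
- # exempt from weight decay.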
- optimizer_grouped_parameters = [
- # BERT encoder parameters
- {"params": [p for n, p in bert_param_optimizer if not any(nd in n for nd in no_decay)],
- "weight_decay": args.lr_decay, 'lr': args.lr},
- {"params": [p for n, p in bert_param_optimizer if any(nd in n for nd in no_decay)],
- "weight_decay": 0.0, 'lr': args.lr},
-
- # other (non-BERT) modules: separate learning rate
- {"params": [p for n, p in other_param_optimizer if not any(nd in n for nd in no_decay)],
- "weight_decay": args.lr_decay, 'lr': args.lr_other},
- {"params": [p for n, p in other_param_optimizer if any(nd in n for nd in no_decay)],
- "weight_decay": 0.0, 'lr': args.lr_other},
- ] # differential learning rates
- batch_size = args.batch_size
- train_num = len(data.train_ids)
- total_batch = train_num // batch_size + 1
- num_train_optimization_steps = int(total_batch / args.gradient_accumulation_steps) * args.max_epoch
- warmup_steps = int(0.1 * num_train_optimization_steps)
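- # Total optimizer updates = (batches per epoch / accumulation steps) * epochs;
- # warmup_steps (10% of the total) is only consumed by the commented-out
- # WarmupLinearSchedule below, since BertAdam takes warmup as a ratio instead.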
- if args.optimizer == "Adam":
- # optimizer = optim.Adam(optimizer_grouped_parameters, lr=args.lr, weight_decay=args.l2_penalty)
- optimizer = BertAdam(optimizer_grouped_parameters, warmup=0.1, t_total=num_train_optimization_steps, lr=args.lr)
- # optimizer = optim.Adam([{'params': optimizer_grouped_parameters}, {'params':base_parameter, 'lr': args.lr}], lr=args.lr_other, weight_decay=args.l2_penalty)
- elif args.optimizer == "SGD":
- optimizer = optim.SGD(optimizer_grouped_parameters, lr=args.lr, weight_decay=args.l2_penalty)
- # scheduler = WarmupLinearSchedule(optimizer, warmup_steps=warmup_steps, t_total=num_train_optimization_steps)
- best_dev = -1
- for idx in range(args.max_epoch):
- epoch_start = time.time()
- temp_start = epoch_start
- print("Epoch: %s/%s" % (idx, args.max_epoch))
- optimizer = lr_decay(optimizer, idx, args.lr_decay, args.lr)
- instance_count = 0
- sample_loss = 0
- total_loss = 0
- random.shuffle(data.train_ids)
- model.train()
- model.zero_grad()
-
- for batch_id in range(total_batch):
- start = batch_id*batch_size
- end = (batch_id+1)*batch_size
- if end > train_num:
- end = train_num
-
- instance = data.train_ids[start:end]
- if not instance:
- continue
- model.zero_grad()
- char, c_len, gazs, kgs, bert_seq, bert_mask, mask, label, recover, t_graph, c_graph, l_graph, kg_t_graph, kg_c_graph, kg_l_graph, kg_span_graph, span_c_graph = batchify(instance, data.entity_id, args.use_gpu, args.pre_model)
- # print(bert_seq)
- # print(label.size())
- loss = model.neg_log_likelihood(char, c_len, gazs, kgs, bert_seq, bert_mask, t_graph, c_graph, l_graph, kg_t_graph, kg_c_graph, kg_l_graph, kg_span_graph, span_c_graph, mask, label)
- instance_count += 1
- sample_loss += loss.item()
- total_loss += loss.item()
- loss.backward()
- if args.use_clip:
- torch.nn.utils.clip_grad_norm_(model.parameters(), args.clip)
-
- if (batch_id + 1) % args.gradient_accumulation_steps == 0:
- optimizer.step()
- # scheduler.step() # Update learning rate schedule
- model.zero_grad()
-
- if end % 500 == 0:
- temp_time = time.time()
- temp_cost = temp_time - temp_start
- temp_start = temp_time
- print(" Instance: %s; Time: %.2fs; loss: %.4f" % (
- end, temp_cost, sample_loss))
- sys.stdout.flush()
- sample_loss = 0
- temp_time = time.time()
- temp_cost = temp_time - temp_start
- print(" Instance: %s; Time: %.2fs; loss: %.4f" % (end, temp_cost, sample_loss))
- epoch_finish = time.time()
- epoch_cost = epoch_finish - epoch_start
- print("Epoch: %s training finished. Time: %.2fs, speed: %.2fst/s, total loss: %s"%(idx, epoch_cost, train_num/epoch_cost, total_loss))
- chars, speed, acc, p, r, f, _ = evaluate(data, model, tokenizer, args, "dev")
- dev_finish = time.time()
- dev_cost = dev_finish - epoch_finish
- current_score = f
- print(
- "Dev: time: %.2fs, speed: %.2fst/s; acc: %.4f, p: %.4f, r: %.4f, f: %.4f" % (dev_cost, speed, acc, p, r, f))
- if current_score > best_dev:
- print("Exceeded previous best dev f-score:", best_dev)
- if not os.path.exists(args.param_stored_directory + args.dataset_name + "_param"):
- os.makedirs(args.param_stored_directory + args.dataset_name + "_param")
- model_name = "{}epoch_{}_f1_{}.model".format(args.param_stored_directory + args.dataset_name + "_param/", idx, current_score)
- torch.save(model.state_dict(), model_name)
- best_dev = current_score
-
- # evaluate test
- #chars,speed, acc, p, r, f, _ = evaluate(data, model, tokenizer, args, "test")
- #print(
- #"Test: speed: %.2fst/s; acc: %.4f, p: %.4f, r: %.4f, f: %.4f" % (speed, acc, p, r, f))
- gc.collect()
-
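- # Write one predicted label per line with a blank line between sentences
- # (CoNLL-style output); the character column is currently commented out.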
- def write_test_result(chars, pred_result, out_file):
- with open(out_file, 'w') as f:
- for i, r in enumerate(pred_result):
- for j, l in enumerate(r):
- # f.write(chars[i][j].strip() + '\t' + l.strip() + '\n')
- f.write(l.strip() + '\n')
- f.write('\n')
-
-
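- # Script entry point: parse command-line arguments (defined in
- # utils/config.get_args), pin the visible GPU, fix all random seeds,
- # build or load the dataset, then train the model with BERT tokenization.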
- if __name__ == '__main__':
-
- args, unparsed = get_args()
- for arg in vars(args):
- print(arg, ":", getattr(args, arg))
- os.environ["CUDA_VISIBLE_DEVICES"] = str(args.visible_gpu)
- # set random seeds for reproducibility
- seed = args.random_seed
- torch.manual_seed(seed)
- torch.cuda.manual_seed_all(seed)
- np.random.seed(seed)
- random.seed(seed)
- torch.cuda.set_device(int(args.visible_gpu))
-
- torch.backends.cudnn.deterministic = True
- data = data_initialization(args)
-
- tokenizer = BertTokenizer.from_pretrained(data.token_path, do_lower_case=True)
- model = BLSTM_GCN_CRF(data, args)
- train(data, model, tokenizer, args)
-
-
-