from utils.data import Data
from utils.batchify import batchify
from utils.config import get_args
from utils.metric import get_ner_fmeasure
from model.bilstm_gcn_crf import BLSTM_GCN_CRF
from pytorch_pretrained_bert import BertTokenizer
from utils.tokenization import *
import os
import numpy as np
import pickle
import torch
import time
import random
import json
import warnings

warnings.filterwarnings("ignore", category=UserWarning)


def data_initialization(args):
    data_stored_directory = args.data_stored_directory
    dataset_file = os.path.join(data_stored_directory, args.dataset_name + "_dataset.dset")
    if os.path.exists(dataset_file) and not args.refresh:
        data = load_data_setting(data_stored_directory, args.dataset_name)
    else:
        data = Data()
        data.dataset_name = args.dataset_name
        data.norm_char_emb = args.norm_char_emb
        data.norm_gaz_emb = args.norm_gaz_emb
        data.norm_kg_emb = args.norm_kg_emb
        data.use_bert = args.use_bert
        data.pre_model = args.pre_model
        data.number_normalized = args.number_normalized
        data.max_sentence_length = args.max_sentence_length
        data.build_gaz_file(args.gaz_file)
        data.build_kg_file(args.kg_file)
        data.generate_instance(args.train_file, "train", False)
        data.generate_instance(args.dev_file, "dev")
        data.generate_instance(args.test_file, "test")
        data.build_char_pretrain_emb(args.char_embedding_path)
        data.build_gaz_pretrain_emb(args.gaz_file)
        data.build_kg_pretrain_emb(args.kg_file)
        data.build_entity_pretrain_emb(args.entity_type_embdding_path)
        data.build_entity_type_id(args.entity_type_embdding_path)

        data.fix_alphabet()
        data.get_tag_scheme()
        # save_data_setting(data, data_stored_directory)
    return data


def load_data_setting(data_stored_directory, name):
    dataset_saved_name = os.path.join(data_stored_directory, name + "_dataset.dset")
    with open(dataset_saved_name, 'rb') as fp:
        data = pickle.load(fp)
    print("Data setting loaded from file: ", dataset_saved_name)
    data.show_data_summary()
    return data
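

# The commented-out save_data_setting call in data_initialization suggests a
# pickle-based save counterpart. A minimal sketch, assuming the Data object is
# picklable and the same "<name>_dataset.dset" naming convention (the real
# utils implementation may differ):
def save_data_setting(data, data_stored_directory):
    if not os.path.exists(data_stored_directory):
        os.makedirs(data_stored_directory)
    dataset_saved_name = os.path.join(data_stored_directory, data.dataset_name + "_dataset.dset")
    with open(dataset_saved_name, 'wb') as fp:
        pickle.dump(data, fp)
    print("Data setting saved to file: ", dataset_saved_name)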


def read_sentences(tokenizer_new, tokenizer, content, data, language='CN'):
    """Split a raw document into sentences, tokenize them, and build the
    gazetteer/KG match features and id sequences expected by batchify."""
    sentences = seg_sentence(content, data.max_sentence_length, language=language)

    all_token, all_pos = tokenizer_new.tokenize_doc(sentences)
    instance_texts = []
    instance_ids = []
    for k, sen_token in enumerate(all_token):
        gazs = []
        kgs = []
        gaz_ids = []
        kg_ids = []
        for i in range(len(sen_token)):
            matched_list = data.gaz.enumerateMatchList(sen_token[i:])
            kg_list = data.kg.enumerateMatchList(sen_token[i:])
            if language == 'EN':
                matched_length = [len(a.split('&')) for a in matched_list]
                matched_length_kg = [len(a.split('&')) for a in kg_list]
            elif language == 'CN':
                matched_length = [len(a) for a in matched_list]
                matched_length_kg = [len(a) for a in kg_list]
            else:
                raise ValueError("Unsupported language: %s" % language)
            gazs.append(matched_list)
            kgs.append(kg_list)
            matched_id = [data.gaz_alphabet.get_index(entity) for entity in matched_list]
            matched_id_kg = [data.kg_alphabet.get_index(entity) for entity in kg_list]
            if matched_id:
                gaz_ids.append([matched_id, matched_length])
            else:
                gaz_ids.append([])

            if matched_id_kg:
                kg_ids.append([matched_id_kg, matched_length_kg])
            else:
                kg_ids.append([])

        token_starts = [pos[0] for pos in all_pos[k]]
        token_ends = [pos[1] for pos in all_pos[k]]
        bert_char_ids = tokenizer.convert_tokens_to_ids(sen_token)
        # Unlabeled input: fill with 'O' so downstream code has gold ids to carry.
        labels = ['O'] * len(bert_char_ids)
        bert_label_ids = [data.label_alphabet.get_index(label) for label in labels]

        instance_texts.append([sen_token, gazs, kgs, labels, token_starts, token_ends])
        instance_ids.append([bert_char_ids, gaz_ids, kg_ids, bert_label_ids])
    return instance_texts, instance_ids
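
# Layout of the per-position features built above (hedged illustration with
# made-up ids): each position i holds either [] or a pair
# [matched_entity_ids, matched_entity_lengths], e.g. a position whose suffix
# matches lexicon entries "左肺" and "左肺癌" might yield
#     gaz_ids[i] == [[1052, 2311], [2, 3]]
# kg_ids uses the same layout for knowledge-graph entities.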


def predict_check(pred_variable, gold_variable, mask_variable):
    """
    Count how many tokens the prediction gets right under the mask.
    input:
        pred_variable (batch_size, sent_len): predicted tag ids
        gold_variable (batch_size, sent_len): gold tag ids
        mask_variable (batch_size, sent_len): padding mask
    """
    pred = pred_variable.cpu().data.numpy()
    gold = gold_variable.cpu().data.numpy()
    mask = mask_variable.cpu().data.numpy()
    overlapped = (pred == gold)
    right_token = np.sum(overlapped * mask)
    total_token = mask.sum()
    return right_token, total_token
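
# Quick sanity check with toy tensors (hedged example, not run by the script):
#     pred = torch.tensor([[1, 2, 3]])
#     gold = torch.tensor([[1, 2, 0]])
#     mask = torch.tensor([[1, 1, 0]])
#     predict_check(pred, gold, mask)  # -> (2, 2): both unmasked tokens match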


def recover_label(char, pred_variable, gold_variable, mask_variable, label_alphabet, instances_txt, tokenizer, word_recover=None):
    """
    input:
        pred_variable (batch_size, sent_len): pred tag result
        gold_variable (batch_size, sent_len): gold result variable
        mask_variable (batch_size, sent_len): mask variable
    """
    # pred_variable = pred_variable[word_recover]
    # gold_variable = gold_variable[word_recover]
    # mask_variable = mask_variable[word_recover]
    mask = mask_variable.cpu().data.numpy()
    pred_tag = pred_variable.cpu().data.numpy()
    gold_tag = gold_variable.cpu().data.numpy()
    batch_size = mask.shape[0]

    pred_label = []
    gold_label = []
    new_char = []
    token_starts = []
    token_ends = []
    for idx in range(batch_size):
        # Copy so repeated calls don't keep mutating the cached instance lists,
        # then pad the offsets with sentinels for the [CLS]/[SEP] positions.
        token_start = list(instances_txt[idx][-2])
        token_end = list(instances_txt[idx][-1])
        token_start.append(-1)
        token_start.insert(0, -1)
        token_end.append(-1)
        token_end.insert(0, -1)
        txt = tokenizer.convert_ids_to_tokens(char.cpu().numpy()[idx])
        new_pred_tag = []
        new_gold_tag = []
        new_mask = []
        process_char = []
        new_token_start = []
        new_token_end = []
        for i, word in enumerate(txt):
            if word.startswith('##'):
                # BERT WordPiece: '##' continuation pieces can be skipped when
                # restoring English tokens; extend the previous token's end
                # offset to cover the subword instead.
                new_token_end[-1] = token_end[i]
                continue
            new_pred_tag.append(pred_tag[idx][i])
            new_gold_tag.append(gold_tag[idx][i])
            new_mask.append(mask[idx][i])
            process_char.append(word)
            if mask[idx][i]:
                new_token_start.append(token_start[i])
                new_token_end.append(token_end[i])
        assert len(process_char) == len(new_pred_tag)
        new_seq = len(new_pred_tag)
        pred = [label_alphabet.get_instance(new_pred_tag[idy]) for idy in range(new_seq) if new_mask[idy]]
        gold = [label_alphabet.get_instance(new_gold_tag[idy]) for idy in range(new_seq) if new_mask[idy]]
        chars = [process_char[idy] for idy in range(new_seq) if new_mask[idy]]

        # Drop the [CLS]/[SEP] positions.
        pred = pred[1:-1]
        gold = gold[1:-1]
        new_token_start = new_token_start[1:-1]
        new_token_end = new_token_end[1:-1]
        chars = chars[1:-1]

        pred_label.append(pred)
        gold_label.append(gold)
        new_char.append(chars)

        assert len(pred) == len(new_token_start)
        assert len(pred) == len(new_token_end)
        token_starts.append(new_token_start)
        token_ends.append(new_token_end)

    return new_char, pred_label, gold_label, token_starts, token_ends
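

# Standalone toy demo of the '##' merge performed in recover_label (hedged,
# hypothetical helper with made-up offsets; not used by the pipeline):
def _merge_wordpieces_demo(tokens, ends):
    kept, kept_ends = [], []
    for tok, end in zip(tokens, ends):
        if tok.startswith('##'):
            # fold the subword into the previous token and extend its end offset
            kept[-1] += tok[2:]
            kept_ends[-1] = end
            continue
        kept.append(tok)
        kept_ends.append(end)
    return kept, kept_ends

# _merge_wordpieces_demo(['lung', 'can', '##cer'], [4, 8, 11])
# -> (['lung', 'cancer'], [4, 11])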


def evaluate(data, model, tokenizer, args, name):
    if name == "train":
        instances = data.train_ids
        instances_txt = data.train_texts
    elif name == "dev":
        instances = data.dev_ids
        instances_txt = data.dev_texts
    elif name == 'test':
        instances = data.test_ids
        instances_txt = data.test_texts
    else:
        raise ValueError("Error: wrong evaluate name: %s" % name)
    pred_results = []
    gold_results = []
    model.eval()
    batch_size = args.batch_size
    start_time = time.time()
    train_num = len(instances)
    token_starts = []
    token_ends = []
    chars = []
    total_batch = train_num // batch_size + 1
    for batch_id in range(total_batch):
        start = batch_id * batch_size
        end = min((batch_id + 1) * batch_size, train_num)
        instance = instances[start:end]
        instance_txt = instances_txt[start:end]
        if not instance:
            continue
        char, c_len, gazs, kgs, bert_seq, bert_mask, mask, label, t_graph, c_graph, l_graph, kg_t_graph, kg_c_graph, kg_l_graph, kg_span_graph, span_c_graph = batchify(instance, data.entity_id, args.use_gpu, args.pre_model)
        tag_seq = model(char, c_len, gazs, kgs, bert_seq, bert_mask, t_graph, c_graph, l_graph, kg_t_graph, kg_c_graph, kg_l_graph, kg_span_graph, span_c_graph, mask)
        char, pred_label, gold_label, token_start, token_end = recover_label(char, tag_seq, label, mask, data.label_alphabet, instance_txt, tokenizer)
        pred_results += pred_label
        gold_results += gold_label
        token_starts += token_start
        token_ends += token_end
        chars += char
    decode_time = time.time() - start_time
    speed = len(instances) / decode_time
    acc, p, r, f, pred = get_ner_fmeasure(gold_results, pred_results, token_starts, token_ends, data.tagscheme)
    return chars, speed, acc, p, r, f, pred


# test document
def test_sentences(data, doc, model, tokenizer, new_tokenizer, args):
    instances_txt, instances_ids = read_sentences(new_tokenizer, tokenizer, doc, data)
    pred_results = []
    gold_results = []
    model.eval()
    batch_size = args.batch_size
    start_time = time.time()
    train_num = len(instances_ids)
    total_batch = train_num // batch_size + 1
    token_starts = []
    token_ends = []
    chars = []
    for batch_id in range(total_batch):
        start = batch_id * batch_size
        end = min((batch_id + 1) * batch_size, train_num)
        instance = instances_ids[start:end]
        instance_txt = instances_txt[start:end]
        if not instance_txt:
            continue
        char, c_len, gazs, kgs, bert_seq, bert_mask, mask, label, recover, t_graph, c_graph, l_graph, kg_t_graph, kg_c_graph, kg_l_graph, kg_span_graph, span_c_graph = batchify(instance, data.entity_id, args.use_gpu, args.pre_model)
        tag_seq = model(char, c_len, gazs, kgs, bert_seq, bert_mask, t_graph, c_graph, l_graph, kg_t_graph, kg_c_graph, kg_l_graph, kg_span_graph, span_c_graph, mask)
        char, pred_label, gold_label, token_start, token_end = recover_label(char, tag_seq, label, mask, data.label_alphabet, instance_txt, tokenizer, recover)
        pred_results += pred_label
        gold_results += gold_label
        token_starts += token_start
        token_ends += token_end
        chars += char
    decode_time = time.time() - start_time
    speed = len(instances_ids) / decode_time
    acc, p, r, f, preds = get_ner_fmeasure(gold_results, pred_results, token_starts, token_ends, data.tagscheme)
    return chars, speed, acc, p, r, f, preds


# test file
def load_file(filename, data, model, tokenizer, new_tokenizer, args):
    with open(filename, 'r', encoding='utf-8') as f:
        lines = f.readlines()
    result = {}
    for line in lines:
        # One document per line: "<doc_id>\t<document text>".
        id_, doc = line.rstrip('\n').split('\t', 1)
        _, speed, _, _, _, _, pred = test_sentences(data, doc, model, tokenizer, new_tokenizer, args)
        result[id_] = (doc, pred)
    return result
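
# Hedged example of the expected input file (made-up id, truncated text):
#     001\t患者因“左肺癌术后2月余,来院化疗”入院。...
# Each line yields one (doc, predictions) entry keyed by the document id.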


def write_file_result(result, out_file):
    with open(out_file, 'w', encoding='utf-8') as f:
        for k, v in result.items():
            doc = v[0]
            temp = {}  # label -> mentions, reset per document
            for sen_num, entities in v[1].items():
                for ne in entities:
                    if ne[-1] not in temp:
                        temp[ne[-1]] = []
                    temp[ne[-1]].append(doc[ne[0]:ne[1]])
            for label, mention in temp.items():
                f.write(k + '\t' + label + '\t' + '\t'.join(mention) + '\n')


def write_doc_result(result, out_file):
    with open(out_file, 'w', encoding='utf-8') as f:
        for k, v in result.items():
            sen = {}
            doc = v[0]
            sen['annotations'] = []
            sen['text'] = doc
            sen['id'] = k
            for sen_num, entities in v[1].items():
                for en in entities:
                    entity = {}
                    entity['start_offset'] = en[0]
                    entity['end_offset'] = en[1]
                    entity['label'] = en[2]
                    entity['mention'] = doc[en[0]:en[1]]
                    sen['annotations'].append(entity)
            json.dump(sen, f, ensure_ascii=False)
            f.write('\n')
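
# write_doc_result emits one JSON object per document (JSON Lines). A hedged
# example with a made-up label:
#     {"annotations": [{"start_offset": 3, "end_offset": 6, "label": "DISEASE",
#                       "mention": "左肺癌"}], "text": "患者因...", "id": "001"}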


if __name__ == '__main__':
    args, unparsed = get_args()
    for arg in vars(args):
        print(arg, ":", getattr(args, arg))
    os.environ["CUDA_VISIBLE_DEVICES"] = str(args.visible_gpu)

    # set random seeds for reproducibility
    seed = args.random_seed
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    np.random.seed(seed)
    random.seed(seed)
    torch.cuda.set_device(int(args.visible_gpu))
    torch.backends.cudnn.deterministic = True

    data = data_initialization(args)
    model = BLSTM_GCN_CRF(data, args)

    # load tokenizers and trained model weights
    tokenization_path = args.token_path
    tokenizer = BertTokenizer.from_pretrained(tokenization_path, do_lower_case=True)
    new_tokenizer = GeneralTokenizer(language='CN', type='bert', cased=False)
    model_path = '/home/xy/xy_pro/TextAnalysis/flat-ner/data/model_param/meizhou_all_param/epoch_7_f1_0.9565095999032099.model'
    model.load_state_dict(torch.load(model_path))

    # test file
    filename = '1.txt'
    result = load_file(filename, data, model, tokenizer, new_tokenizer, args)
    write_doc_result(result, 'test.json')

    # test document
    start = time.time()
    doc = '1、一般情况 The general situation :患者因“左肺癌术后2月余,来院化疗”入院。2、病例特点 The case characteristics :患者因“左肺癌术后2月余,来院化疗”入院。患者约2月余前因“咳嗽咳痰月余”就诊于我院,查胸部CT示:左肺下叶肺癌伴阻塞性肺炎首先考虑,左下肺炎伴脓肿形成待除外,请结合临床及进一步检查。为进一步明确诊断,患者就诊于宁波市明州医院,查PET-CT示:1、左肺下叶肿块,FDG代谢异常增高,考虑肺癌可能性大伴远段阻塞性肺炎,请结合支气管镜病理;左肺上叶前段微小结节影,FDG代谢未见明显增高,CT定期复查;双肺少许纤维灶。2、纵隔左侧气管隆突下及左肺门区肿大淋巴结影,FDG代谢异常增高,考虑转移性淋巴结可能性大;余纵隔多发小淋巴结影,FDG代谢不同程度增高,考虑炎性增生淋巴结可能性大,不除外转移性淋巴结,请结合临床CT密切随诊;右肺门区钙化淋巴结。建议转上级医院进一步治疗,患者遂于2月前至上海胸科医院就诊,于2015.11.30在全麻下行“左下肺切除+系统淋巴结清扫术”,手术过程顺利,术后病理提示:左肺下叶外基底段及后基底段角化型鳞状细胞癌,大小6.5*4.5*7cm,肿瘤侵及脏层胸膜。支气管切断未见癌累及。淋巴结2+/9组见癌转移。管口淋巴结2枚伴纤维结节形成。于2016.01.09来我院行“顺铂针40mgD1-3+吉西他滨针1.6D1、8”方案化疗,后因白细胞下降未行吉西他滨针1.6D8化疗,升白治疗后白细胞恢复正常,为求行第2次减量化疗再来我院就诊。颈软,颈静脉无怒张,气管居中,左胸部有长约15cm手术疤痕,愈合可。左下肺未闻及呼吸音,左侧呼吸运动减弱,双肺呼吸音稍粗,未及明显啰音,心律齐,未及杂音,腹平,未见胃型,全腹触之软,无压痛无反跳痛,未及包块,肝脾肋下未及,移动性浊音阴性,肠鸣音4次/分,双肾区无叩击痛,四肢关节无畸形无活动受限,肌力肌张力正常。'  # change the input document here
    chars, speed, acc, p, r, f, pred1 = test_sentences(data, doc, model, tokenizer, new_tokenizer, args)
    # chars, speed, acc, p, r, f, pred1 = evaluate(data, model, tokenizer, args, 'dev')
    end = time.time()
    print(end - start)
    # for k, v in pred1.items():
    #     for ne in v:
    #         print('item: ' + doc[ne[0]:ne[1]] + ' ' + ne[-1])
    # print(pred1)
-