|
- import os
- import sys
- import json
- import jpype
- import time
-
# JVM option that sets the Java classpath: the project's Search.jar plus the
# Lucene, IKAnalyzer and MySQL connector jars it is launched with.
jpath = (
    "-Djava.class.path=./update/Search.jar:./search/lib/lucene-core-8.3.1.jar"
    ":./search/lib/lucene-queryparser-8.3.1.jar:./search/lib/IKAnalyzer-5.0.jar:./search/lib/mysql-connector-java-8.0.18.jar"
)

# Only start a JVM if this process does not already have one running.
if not jpype.isJVMStarted():
    jpype.startJVM(jpype.getDefaultJVMPath(), "-ea", jpath, convertStrings=False)

# Shared handle to the Java-side searcher (class demo.TextIndex) used by score().
QSearch = jpype.JClass("demo.TextIndex")
qs = QSearch()
-
def score(q, topk=1):
    """Run a query against the Java index and unpack the hits.

    Args:
        q: Query text forwarded to ``qs.getText``.
        topk: Hit count forwarded to ``qs.getText``.

    Returns:
        A tuple ``(scores, text, touch, part_text)`` of parallel lists with
        one entry per returned bean: ``getScore()``, ``getText()``,
        ``getTouch()`` and ``qs.getChapContent(<touch>)`` respectively.
    """
    hit_texts, hit_scores, hit_keys, chapters = [], [], [], []
    for hit in qs.getText(topk, q):
        hit_texts.append(hit.getText())
        hit_scores.append(hit.getScore())
        key = hit.getTouch()
        hit_keys.append(key)
        # The touch value is what the index uses to look up chapter content.
        chapters.append(qs.getChapContent(key))
    return hit_scores, hit_texts, hit_keys, chapters
-
def main():
    """Measure top-1 retrieval accuracy on single-answer, five-option questions.

    For every usable question in the test file, each option is appended to the
    question text and sent to ``score``; the option whose best hit has the
    highest score is taken as the prediction. A 1/0 correctness label per
    evaluated question is dumped to ``tmp.json`` and the totals are printed.

    Questions are skipped when they have no (non-blank) answer, do not have
    exactly five options, are multiple-choice ("多项选择题"), or when scoring
    failed for every option.

    Raises:
        OSError: if an input file cannot be opened.
        ValueError: if a line of ``all_q.txt`` does not have five tab-separated
            fields, a test line is not valid JSON, or the first answer is not
            an integer string.
    """
    import traceback

    # Option-string -> logic column of all_q.txt. Nothing below reads this
    # mapping at the moment; it is retained for the disabled reverse-sort
    # experiment and so the script still validates the file as before.
    options_to_logic = {}
    with open("./data/all_q.txt", "r", encoding="utf-8") as f:
        for line in f:
            _, logic, _, _, opt = line.strip().split("\t")
            options_to_logic[opt] = logic

    labels = []
    right = 0
    cnt = 0
    with open("/data/ldf/baidu/data/pan/test.json", "r", encoding="utf-8") as f, \
            open("tmp.json", "w") as nf:
        for line in f:
            # Each line is one JSON object with the fields: questionType,
            # questionId, questionText, optionImg, questionImg, backgroundText,
            # answer (list), audiourl, subject, option.
            instance = json.loads(line.strip())
            answer = instance["answer"]
            if not answer or answer[0].strip() == "":
                continue
            if len(instance["option"]) != 5:
                continue
            if instance["questionType"] == "多项选择题":
                continue

            instance["context"] = []
            q = (instance["backgroundText"] + instance["questionText"]).replace("\n", " ")

            # One (option index, score, text, touch, option) tuple per option
            # that was scored successfully.
            tuples = []
            for i, opt in enumerate(instance["option"]):
                query = q + opt
                start = time.time()
                try:
                    output = score(query)
                    print("time: ", time.time() - start)
                    tuples.append((i, output[0][0], output[1][0], output[2][0], opt))
                    instance["context"].append(str(output[2][0][0]))
                except Exception:
                    # Best effort: a failed lookup (including no hits) leaves an
                    # empty context for this option. Catching Exception rather
                    # than everything keeps Ctrl-C and SystemExit working.
                    instance["context"].append("")
                    traceback.print_exc()

            if not tuples:
                continue

            # Highest score first; the stable sort keeps option order on ties.
            tuples.sort(key=lambda t: t[1], reverse=True)

            # Answers are 1-based option numbers stored as strings.
            is_right = tuples[0][0] == int(answer[0]) - 1
            if is_right:
                right += 1
            else:
                print(q)
                print(tuples[:5])
                print()
            labels.append(int(is_right))
            cnt += 1

        json.dump(labels, nf)

    print("right: ", right)
    print("cnt: ", cnt)
    # Avoid ZeroDivisionError when every question was filtered out.
    print("ACC: ", 1.0 * right / cnt if cnt else 0.0)
-
# Run the evaluation only when this file is executed as a script.
if __name__ == "__main__":
    main()
- # topk = 3
- # option = ['阿糖腺苷', '恩替卡韦', '泛昔洛韦', '利巴韦林', '膦甲酸钠']
-
- # for o in option:
- # output = score("患者,女,27岁,确诊慢性乙型肝炎3年,近日化验结果:HBV-DNA2X105copies/ml,ALT122U/L。拟予以抗病毒治疗,首选的药物是"+o, topk=topk)
- # tuples = [] ################葡萄糖注射液双歧三联活菌制剂
- # for i in range(topk):
- # tuples.append((i, output[0][i], output[1][i], output[2][i]))
- # tuples = sorted(tuples, key=lambda x: x[1], reverse=True)
- # for i in range(topk):
-
-
- # print(i)
- # print(output[0][i])
- # print(output[1][i])
- # print(output[2][i])
- # print(output[3][i])
- # print()
|