|
- import os
- import sys
- import json
- from model_url import get_model_resp, get_url_tokenizer
-
-
-
def run_predict(url, log_path, few_shot=True):
    """Evaluate a served model on the CMNLI dev set and log the accuracy.

    For every labelled example in ``task_dataset/cmnli_public/dev.json``
    (resolved relative to this file) a prompt is built and completed with
    each of the three candidate answers (矛盾 / 蕴含 / 无关). The three
    full strings are sent to ``get_model_resp`` with ``logprobs=True`` and
    ``tokens_to_generate=0``; the candidate whose answer tokens have the
    highest mean value in the response is taken as the prediction.

    Args:
        url: Endpoint forwarded unchanged to ``get_model_resp``.
        log_path: Directory in which the result file is written
            (``cmnli_fewshot.txt`` or ``cmnli_zeroshot.txt``).
        few_shot: When true, five fixed demonstrations are prepended to
            every prompt; otherwise the prompt holds only the query.

    Returns:
        None. The running accuracy is printed after every scored example
        and the final accuracy is written to the result file.

    Raises:
        ValueError: If a response does not contain exactly one value per
            input token after the first, i.e. the tokenizer and the
            server disagree on the input length.
    """
    import numpy as np

    tokenizer = get_url_tokenizer()

    # (answer text appended to the prompt, dataset label it stands for).
    # The order here defines the index returned by argmax below.
    candidates = (
        ("矛盾", "contradiction"),
        ("蕴含", "entailment"),
        ("无关", "neutral"),
    )

    main_dir = os.path.dirname(os.path.abspath(__file__))
    data_path = os.path.join(main_dir, "task_dataset", "cmnli_public", "dev.json")

    demonstrations = [
        ["新的权利已经足够好了", "每个人都很喜欢最新的福利", "无关"],
        ["他犹豫不决,一个人沉醉于这个村子的宁静之中。", "他喜欢这个村庄是多么的宁静。", "蕴含"],
        ["他以身殉职,终年59岁", "他是在今年去世的", "无关"],
        ["对不起事情就是这样。","事情就是这样,不需要道歉。", "矛盾"],
        ["提供自助餐和菜单。", "有自助餐。", "蕴含"],
    ]
    pre = ""
    if few_shot:
        pre = "".join(
            f"语句一:“{sentence1}”\n语句二:“{sentence2}”\n请问这两句话是什么关系?{label}\n"
            for sentence1, sentence2, label in demonstrations
        )

    count = 0
    correct_num = 0
    acc = 0

    with open(data_path, "r", encoding="utf8") as f:
        # Stream the file line by line instead of loading it all at once.
        for raw_line in f:
            if not raw_line.strip():
                # Tolerate blank lines (e.g. a trailing newline at EOF).
                continue
            record = json.loads(raw_line)
            sentence1 = record["sentence1"]
            sentence2 = record["sentence2"]
            label = record["label"]
            if label == '-':
                # Examples without a gold label are not scored.
                continue
            count += 1

            example = f"{pre}语句一:“{sentence1}”\n语句二:“{sentence2}”\n请问这两句话是什么关系?"
            input_str = [f"{example}{answer}" for answer, _ in candidates]

            # The prompt is shared by all candidates, so tokenize it once.
            mask_length = len(tokenizer.encode(example))
            input_lengths = [len(tokenizer.encode(text)) for text in input_str]

            model_resp = get_model_resp(url=url, input_str=input_str, tokens_to_generate=0, top_k=1, logprobs=True)

            pred_list = []
            for resp_item, input_length in zip(model_resp, input_lengths):
                # Explicit check rather than ``assert`` so it still runs
                # under ``python -O``; a mismatch would mis-slice the span.
                if len(resp_item) != input_length - 1:
                    raise ValueError(
                        f"expected {input_length - 1} logprobs from the model, "
                        f"got {len(resp_item)}"
                    )
                # Keep only the positions belonging to the answer tokens;
                # entry i of the response is paired with input token i + 1.
                logprobs = resp_item[mask_length - 1:input_length - 1]
                pred_list.append(sum(logprobs) / len(logprobs))

            answers_pred = int(np.argmax(pred_list))

            if candidates[answers_pred][1] == label:
                correct_num += 1
            acc = correct_num / count
            print(f"cmnli, 准确率Acc:{acc}, number: {count}")

    if not few_shot:
        with open(os.path.join(log_path, 'cmnli_zeroshot.txt'), 'w') as file:
            file.write(f"cmnli, zero shot , Acc: {acc}, number: {count}")
    else:
        with open(os.path.join(log_path, 'cmnli_fewshot.txt'), 'w') as file:
            file.write(f"cmnli, few shot , Acc: {acc}, number: {count}")
|