|
- import os
- import sys
- import json
- from model_url import get_model_resp, get_url_tokenizer
-
def _mean_scored_logprob(token_logprobs, mask_length, input_length):
    """Return the mean log-probability over the scored span of one prompt.

    The span is ``token_logprobs[mask_length - 1:input_length - 1]``, with
    both bounds clamped at zero.

    Args:
        token_logprobs: Per-token log-probabilities returned for one prompt.
            Assumed to hold ``input_length - 1`` entries (entry ``i`` scoring
            token ``i + 1``) -- TODO confirm against ``get_model_resp``.
        mask_length: Number of leading tokens excluded from scoring.
        input_length: Number of tokens in the full prompt.

    Returns:
        The mean of the selected log-probabilities, or ``-inf`` when the
        span is empty so that such a candidate can never win the argmax.
    """
    # Clamp the start: when the masked prefix encodes to zero tokens,
    # ``mask_length - 1`` is -1, and a negative slice start would keep only
    # the last token instead of the whole sequence.
    start = max(mask_length - 1, 0)
    end = max(input_length - 1, 0)
    scored = token_logprobs[start:end]
    if not scored:
        # Avoid ZeroDivisionError on an empty span.
        return float("-inf")
    return sum(scored) / len(scored)


def run_predict(url, log_path, few_shot=True):
    """Evaluate the CSL keyword-judgement task by comparing prompt log-probs.

    For every record of ``task_dataset/csl/test_public.json`` (one JSON object
    per line, located next to this file) two prompts are built: candidate 0 is
    the abstract alone, candidate 1 is the abstract followed by the keywords.
    Both are sent to the model with ``tokens_to_generate=0`` and
    ``logprobs=True``; the candidate with the higher mean token log-prob is
    the prediction, which is compared with ``int(record["label"])``.

    The running accuracy is printed after each record, and the final accuracy
    is written to ``CSL_fewshot.txt`` or ``CSL_zeroshot.txt`` inside
    ``log_path``.

    Args:
        url: Endpoint passed through to ``get_model_resp``.
        log_path: Existing directory that receives the result file.
        few_shot: Selects the result file name and its "few shot" /
            "zero shot" tag.
            NOTE(review): no few-shot demonstrations are built here, so the
            prompts are identical for both values of this flag.

    Raises:
        OSError: If the dataset file or the result file cannot be opened.
        KeyError: If a record lacks "abst", "keyword" or "label".
    """
    import numpy as np

    tokenizer = get_url_tokenizer()

    main_dir = os.path.dirname(os.path.abspath(__file__))
    test_file = os.path.join(main_dir, "task_dataset", "csl", "test_public.json")

    # The masked prefix is the empty string for this prompt style, so its
    # token length is the same for every record: compute it once.
    example = ""
    mask_length = len(tokenizer.encode(example))

    count = 0
    correct_num = 0
    acc = 0

    with open(test_file, "r", encoding="utf8") as f:
        for raw_line in f:
            if not raw_line.strip():
                # Tolerate blank / trailing lines in the JSONL file.
                continue
            count += 1
            record = json.loads(raw_line)
            abst, keyword, label = record["abst"], record["keyword"], record["label"]

            # List index doubles as the predicted label (0 or 1).
            input_str = [
                f"摘要:{abst}",
                f"摘要:{abst}\n关键词:{keyword}",
            ]
            input_length_list = [len(tokenizer.encode(prompt)) for prompt in input_str]

            model_resp = get_model_resp(
                url=url,
                input_str=input_str,
                tokens_to_generate=0,
                top_k=1,
                logprobs=True,
            )

            pred_list = [
                _mean_scored_logprob(resp_item, mask_length, input_length)
                for resp_item, input_length in zip(model_resp, input_length_list)
            ]
            answers_pred = int(np.argmax(pred_list))

            if answers_pred == int(label):
                correct_num += 1
            acc = correct_num / count
            print(f"CSL, 准确率Acc:{acc}, number: {count}")

    if not few_shot:
        result_name = "CSL_zeroshot.txt"
        summary = f"CSL, zero shot , Acc: {acc}, number: {count}"
    else:
        result_name = "CSL_fewshot.txt"
        summary = f"CSL, few shot , Acc: {acc}, number: {count}"

    with open(os.path.join(log_path, result_name), "w", encoding="utf8") as file:
        file.write(summary)
|