|
- import os
- import sys
- import json
- from model_url import get_model_resp, get_url_tokenizer
-
-
def _load_c3m_rows(dataset_path):
    """Flatten the C3 JSON file at *dataset_path* into one record per question.

    Each top-level entry is read as ``[paragraphs, questions]``; every question
    is a dict with the keys ``'question'``, ``'choice'`` and ``'answer'``.

    Returns:
        A list of dicts with the keys ``'content'``, ``'question'``,
        ``'choices'`` (exactly four candidate answers) and ``'label'`` (index
        of the gold answer within the question's choice list).

    Raises:
        ValueError: if a question's answer is not among its choices.
    """
    with open(dataset_path, "r", encoding="utf8") as f:
        data = json.load(f)

    rows = []
    for row in data:
        content_str = ' '.join(''.join(paragraph) for paragraph in row[0])
        for question in row[1]:
            choices = question['choice']
            label = choices.index(question['answer'])
            if len(choices) < 4:
                # Pad up to four candidates by repeating the first choice.
                # A new list is built so the loaded JSON is left untouched.
                choices = choices + [choices[0]] * (4 - len(choices))
            rows.append({
                'content': content_str,
                'question': question['question'],
                # Only the first four candidates are ever scored.
                'choices': choices[:4],
                'label': label,
            })
    return rows


def run_predict(url, log_path, few_shot = True):
    """Evaluate the model served at *url* on the C3 ``m-dev`` split.

    For every question the four candidate answers are appended to the same
    prompt, the model is queried for log-probabilities, and the candidate with
    the highest mean value over its answer slice is taken as the prediction.
    The running accuracy is printed after each question and the final accuracy
    is written to a text file inside *log_path*.

    Args:
        url: Endpoint forwarded unchanged to ``get_model_resp``.
        log_path: Existing directory that receives the report file.
        few_shot: Only selects the report name and label:
            ``c3m_fewshot.txt`` when truthy, ``c3m_zeroshot.txt`` otherwise.
            The prompt itself is the same in both cases.

    Returns:
        None.
    """
    import numpy as np

    tokenizer = get_url_tokenizer()
    main_dir = os.path.dirname(os.path.abspath(__file__))
    dataset_path = os.path.join(main_dir, "task_dataset", "C3", "m-dev.json")
    rows = _load_c3m_rows(dataset_path)

    count = 0
    correct_num = 0
    acc = 0  # stays 0 when the dataset contains no questions

    for info in rows:
        count += 1
        example = f"文章:{info['content']}\n问题:{info['question']}\n答案:"
        input_str = [f"{example}{choice}" for choice in info['choices']]

        # The prompt prefix is identical for all candidates, so tokenize it once.
        mask_length = len(tokenizer.encode(example))
        input_lengths = [len(tokenizer.encode(text)) for text in input_str]

        # NOTE(review): presumably returns one sequence of per-token
        # log-probabilities per input string - confirm against model_url.
        model_resp = get_model_resp(url=url, input_str=input_str, tokens_to_generate=0, top_k=1, logprobs=True)

        pred_list = []
        for resp_item, input_length in zip(model_resp, input_lengths):
            # Keep only the entries after the shared prompt prefix. The -1
            # offsets assume entry i scores token i + 1 - TODO confirm.
            logprobs = resp_item[mask_length - 1:input_length - 1]
            if len(logprobs) > 0:
                pred_list.append(sum(logprobs) / len(logprobs))
            else:
                # An empty slice (e.g. an empty candidate string) cannot be
                # averaged; rank it last instead of dividing by zero.
                pred_list.append(float("-inf"))

        answers_pred = int(np.argmax(pred_list))
        if answers_pred == info['label']:
            correct_num += 1
        acc = correct_num / count
        print(f"c3m, 准确率Acc:{acc}, number: {count}")

    shot_name = "few" if few_shot else "zero"
    report_path = os.path.join(log_path, f"c3m_{shot_name}shot.txt")
    with open(report_path, 'w', encoding="utf8") as file:
        file.write(f"c3m, {shot_name} shot , Acc: {acc}, number: {count}")
|