|
-
# Mapping from C-Eval task identifiers to their Chinese subject names,
# used to build the zero-shot prompt preamble for each exam file.
TASK2DESC = {
    "high_school_physics": "高中物理",
    "fire_engineer": "注册消防工程师",
    "computer_network": "计算机网络",
    "advanced_mathematics": "高等数学",
    "logic": "逻辑学",
    "middle_school_physics": "初中物理",
    "clinical_medicine": "临床医学",
    "probability_and_statistics": "概率统计",
    "ideological_and_moral_cultivation": "思想道德修养与法律基础",
    "operating_system": "操作系统",
    "middle_school_mathematics": "初中数学",
    "chinese_language_and_literature": "中国语言文学",
    "electrical_engineer": "注册电气工程师",
    "business_administration": "工商管理",
    "high_school_geography": "高中地理",
    "modern_chinese_history": "近代史纲要",
    "legal_professional": "法律职业资格",
    "middle_school_geography": "初中地理",
    "middle_school_chemistry": "初中化学",
    "high_school_biology": "高中生物",
    "high_school_chemistry": "高中化学",
    "physician": "医师资格",
    "high_school_chinese": "高中语文",
    "tax_accountant": "税务师",
    "high_school_history": "高中历史",
    "mao_zedong_thought": "毛泽东思想和中国特色社会主义理论概论",
    "high_school_mathematics": "高中数学",
    "professional_tour_guide": "导游资格",
    "veterinary_medicine": "兽医学",
    "environmental_impact_assessment_engineer": "环境影响评价工程师",
    "basic_medicine": "基础医学",
    "education_science": "教育学",
    "urban_and_rural_planner": "注册城乡规划师",
    "middle_school_biology": "初中生物",
    "plant_protection": "植物保护",
    "middle_school_history": "初中历史",
    "high_school_politics": "高中政治",
    "metrology_engineer": "注册计量师",
    "art_studies": "艺术学",
    "college_economics": "大学经济学",
    "college_chemistry": "大学化学",
    "law": "法学",
    "sports_science": "体育学",
    "civil_servant": "公务员",
    "college_programming": "大学编程",
    "middle_school_politics": "初中政治",
    "teacher_qualification": "教师资格",
    "computer_architecture": "计算机组成",
    "college_physics": "大学物理",
    "discrete_mathematics": "离散数学",
    "marxism": "马克思主义基本原理",
    "accountant": "注册会计师",
}
-
- import os
- import sys
- import json
- from concurrent.futures import ThreadPoolExecutor
- from model_url import get_model_resp, get_url_tokenizer
- import pandas as pd
-
def run_predict(url, log_path):
    """Run zero-shot C-Eval prediction against a model endpoint.

    For every ``*_test.csv`` under ``task_dataset/ceval-exam/test``, build four
    prompts per question (the exam question followed by candidate answer A, B,
    C, or D), score all of them via ``get_model_resp`` in a thread pool, pick
    the choice with the highest mean log-probability, and write the per-subject
    predictions to ``<log_path>/ceval_zeroshot_new.json``.

    Args:
        url: Model endpoint, passed straight through to ``get_model_resp``.
        log_path: Directory into which the result JSON file is written.
    """
    import numpy as np

    # The tokenizer object itself is not used below (the token-length
    # bookkeeping that needed it was dead code and has been removed); the call
    # is kept in case it performs required setup side effects.
    # NOTE(review): confirm whether get_url_tokenizer() can be dropped.
    get_url_tokenizer()

    id_label = {0: "A", 1: "B", 2: "C", 3: "D"}
    main_dir = os.path.dirname(os.path.abspath(__file__))
    test_dir = os.path.join(main_dir, "task_dataset", "ceval-exam", "test")
    csv_paths = sorted(
        os.path.join(test_dir, name)
        for name in os.listdir(test_dir)
        if name.lower().endswith(".csv")
    )

    results = {}
    cnt = 0  # running question count across all subjects (for progress logs)

    for csv_path in csv_paths:
        # File names look like "<task>_test.csv"; strip the 9-char suffix
        # "_test.csv" to recover the task identifier.
        class_val = os.path.basename(csv_path)[:-9]
        preamble = f"以下是中国关于{TASK2DESC[class_val]}考试题,请给出正确答案。\n"

        data = pd.read_csv(csv_path)
        # len(data) instead of data.count()[first_col]: count() skips NaN and
        # would desynchronize from positional iloc indexing on missing values.
        prompt_groups = []
        for idx in range(len(data)):
            cnt += 1
            # Row layout assumed: id, question, A, B, C, D — TODO confirm
            # against the dataset files.
            question, a, b, c, d = data.iloc[idx][1:]
            choices = "\n".join(
                f"{label}. {text}" for label, text in zip("ABCD", (a, b, c, d))
            )
            stem = f"{preamble}{question}\n{choices}\n答案:"
            prompt_groups.append([stem + label for label in "ABCD"])

        # Fan out one request per question; each request scores all four
        # candidate continuations at once.
        with ThreadPoolExecutor() as executor:
            model_resp = list(
                executor.map(
                    get_model_resp,
                    [url] * len(prompt_groups),
                    prompt_groups,
                    [0] * len(prompt_groups),
                    [1] * len(prompt_groups),
                    [True] * len(prompt_groups),
                )
            )

        info = {}
        for idx, resp_item in enumerate(model_resp):
            # Mean log-probability of each candidate; highest-scoring wins.
            scores = [sum(logprobs) / len(logprobs) for logprobs in resp_item]
            info[str(idx)] = id_label[int(np.argmax(scores))]

        results[class_val] = info
        print(f"-----------------------{class_val}------------------------")
        print("CEval: ", class_val, ", zero shot", " number: ", cnt)

    out_path = os.path.join(log_path, "ceval_zeroshot_new.json")
    with open(out_path, "w", encoding="utf-8") as sub_fl:
        json.dump(results, sub_fl, ensure_ascii=False)
-
|