|
-
-
# Mapping from C-Eval subject identifier (the stem of the exam CSV file name,
# e.g. "civil_servant" for "civil_servant_test.csv") to the subject's Chinese
# display name. Entries are kept in their original insertion order.
TASK2DESC = dict(
    high_school_physics="高中物理",
    fire_engineer="注册消防工程师",
    computer_network="计算机网络",
    advanced_mathematics="高等数学",
    logic="逻辑学",
    middle_school_physics="初中物理",
    clinical_medicine="临床医学",
    probability_and_statistics="概率统计",
    ideological_and_moral_cultivation="思想道德修养与法律基础",
    operating_system="操作系统",
    middle_school_mathematics="初中数学",
    chinese_language_and_literature="中国语言文学",
    electrical_engineer="注册电气工程师",
    business_administration="工商管理",
    high_school_geography="高中地理",
    modern_chinese_history="近代史纲要",
    legal_professional="法律职业资格",
    middle_school_geography="初中地理",
    middle_school_chemistry="初中化学",
    high_school_biology="高中生物",
    high_school_chemistry="高中化学",
    physician="医师资格",
    high_school_chinese="高中语文",
    tax_accountant="税务师",
    high_school_history="高中历史",
    mao_zedong_thought="毛泽东思想和中国特色社会主义理论概论",
    high_school_mathematics="高中数学",
    professional_tour_guide="导游资格",
    veterinary_medicine="兽医学",
    environmental_impact_assessment_engineer="环境影响评价工程师",
    basic_medicine="基础医学",
    education_science="教育学",
    urban_and_rural_planner="注册城乡规划师",
    middle_school_biology="初中生物",
    plant_protection="植物保护",
    middle_school_history="初中历史",
    high_school_politics="高中政治",
    metrology_engineer="注册计量师",
    art_studies="艺术学",
    college_economics="大学经济学",
    college_chemistry="大学化学",
    law="法学",
    sports_science="体育学",
    civil_servant="公务员",
    college_programming="大学编程",
    middle_school_politics="初中政治",
    teacher_qualification="教师资格",
    computer_architecture="计算机组成",
    college_physics="大学物理",
    discrete_mathematics="离散数学",
    marxism="马克思主义基本原理",
    accountant="注册会计师",
)
-
- import os
- import sys
- import json
- from model_url import get_model_resp, get_url_tokenizer
- import pandas as pd
-
def _subject_from_filename(filename):
    """Return the subject id for a test CSV file name ("x_test.csv" -> "x")."""
    suffix = "_test.csv"
    if filename.endswith(suffix):
        return filename[: -len(suffix)]
    # Unexpectedly named file: just drop the extension instead of blindly
    # chopping a fixed number of characters.
    return os.path.splitext(filename)[0]


def _cell_text(value):
    """Return a CSV cell as text; missing (NaN) cells become the empty string.

    pandas may parse purely numeric cells as numbers and empty cells as NaN,
    neither of which supports the string methods used below.
    """
    return "" if pd.isna(value) else str(value)


def _fill_blanks(question, option):
    """Fill the "____" blanks of *question* with the words of *option*.

    The option's space-separated words are substituted left to right, one
    word per blank. Blanks beyond the option's word count are left untouched,
    and surplus words are ignored once no blank remains.
    """
    for word in option.split(" "):
        question = question.replace("____", word, 1)
    return question


def _mean_logprob(logprobs):
    """Length-normalised score of one candidate: the mean token log-prob.

    Returns -inf when no tokens were scored, so that candidate can never win
    (and no ZeroDivisionError is raised).
    """
    n = len(logprobs)
    return sum(logprobs) / n if n else float("-inf")


def run_predict(url, log_path, few_shot=True, subjects=("civil_servant",)):
    """Predict answers to C-Eval fill-in-the-blank test questions by likelihood.

    Every ``*.csv`` file under ``task_dataset/ceval-exam/test`` (next to this
    module) whose subject is selected is read. For each question row, the
    "____" blanks are filled with the words of option A, B, C and D in turn.
    The four completed sentences are scored through ``get_model_resp``, and the
    option whose sentence has the highest mean token log-probability is
    recorded as the prediction.

    Predictions are written as JSON to ``<log_path>/ceval_civil_old.json``,
    shaped ``{subject: {row_index_as_str: "A" | "B" | "C" | "D"}}``.

    Args:
        url: Model-server endpoint forwarded to ``get_model_resp``.
        log_path: Existing directory that receives the JSON output file.
        few_shot: Currently unused; the prompt is zero-shot. Kept so existing
            callers that pass it keep working.
        subjects: Subject ids (CSV names without "_test.csv") to evaluate.
            Defaults to ``("civil_servant",)``, the only subject the original
            implementation handled. ``None`` evaluates every test file.
    """
    import numpy as np

    id_label = {0: "A", 1: "B", 2: "C", 3: "D"}
    main_dir = os.path.dirname(os.path.abspath(__file__))
    test_dir = os.path.join(main_dir, "task_dataset", "ceval-exam", "test")
    csv_paths = sorted(
        os.path.join(test_dir, name)
        for name in os.listdir(test_dir)
        if name.lower().endswith(".csv")
    )
    wanted = None if subjects is None else set(subjects)

    results = {}
    cnt = 0  # running question count across all evaluated subjects
    for csv_path in csv_paths:
        class_val = _subject_from_filename(os.path.basename(csv_path))
        if wanted is not None and class_val not in wanted:
            continue

        data = pd.read_csv(csv_path)
        info = {}
        for idx, row in enumerate(data.itertuples(index=False, name=None)):
            cnt += 1
            # Columns after the leading id column: question, A, B, C, D.
            question, a, b, c, d = (_cell_text(v) for v in row[1:])

            # Drop the trailing instruction so its own "____" is not filled.
            question = question.replace("填入画横线部分最恰当的一项是____。", "").strip()
            # Each option is filled with its OWN words (the previous version
            # indexed B/C/D with A's word count, which could raise IndexError
            # or drop words).
            candidates = [_fill_blanks(question, opt) for opt in (a, b, c, d)]

            model_resp = get_model_resp(url=url, input_str=candidates, tokens_to_generate=0, top_k=1,
                                        logprobs=True)
            # NOTE(review): assumes model_resp yields one per-token log-prob
            # sequence per candidate, in the same order as `candidates`.
            scores = [_mean_logprob(logprobs) for logprobs in model_resp]
            info[str(idx)] = id_label[int(np.argmax(scores))]

        results[class_val] = info
        print(f"-----------------------{class_val}------------------------")
        print("CEval civil: ", class_val, ", few shot", " number: ", cnt)

    with open(os.path.join(log_path, 'ceval_civil_old.json'), 'w', encoding='utf-8') as sub_fl:
        sub_fl.write(json.dumps(results, ensure_ascii=False))
|