|
- import os
- import sys
- import json
- from model_url import get_model_resp, get_url_tokenizer, process_inputstr, json_out_stream
- import pandas as pd
-
# Maps each CMMLU subject identifier (the stem of a dataset CSV file name)
# to the Chinese subject name that is interpolated into the prompt header.
cmmlu_subject_mapping = {
    'agronomy': '农学',
    'anatomy': '解剖学',
    'ancient_chinese': '古汉语',
    'arts': '艺术学',
    'astronomy': '天文学',
    'business_ethics': '商业伦理',
    'chinese_civil_service_exam': '中国公务员考试',
    'chinese_driving_rule': '中国驾驶规则',
    'chinese_food_culture': '中国饮食文化',
    'chinese_foreign_policy': '中国外交政策',
    'chinese_history': '中国历史',
    'chinese_literature': '中国文学',
    'chinese_teacher_qualification': '中国教师资格',
    'clinical_knowledge': '临床知识',
    'college_actuarial_science': '大学精算学',
    'college_education': '大学教育学',
    'college_engineering_hydrology': '大学工程水文学',
    'college_law': '大学法律',
    'college_mathematics': '大学数学',
    'college_medical_statistics': '大学医学统计',
    'college_medicine': '大学医学',
    'computer_science': '计算机科学',
    'computer_security': '计算机安全',
    'conceptual_physics': '概念物理学',
    'construction_project_management': '建设工程管理',
    'economics': '经济学',
    'education': '教育学',
    'electrical_engineering': '电气工程',
    'elementary_chinese': '小学语文',
    'elementary_commonsense': '小学常识',
    'elementary_information_and_technology': '小学信息技术',
    'elementary_mathematics': '初等数学',
    'ethnology': '民族学',
    'food_science': '食品科学',
    'genetics': '遗传学',
    'global_facts': '全球事实',
    'high_school_biology': '高中生物',
    'high_school_chemistry': '高中化学',
    'high_school_geography': '高中地理',
    'high_school_mathematics': '高中数学',
    'high_school_physics': '高中物理学',
    'high_school_politics': '高中政治',
    'human_sexuality': '人类性行为',
    'international_law': '国际法学',
    'journalism': '新闻学',
    'jurisprudence': '法理学',
    'legal_and_moral_basis': '法律与道德基础',
    'logical': '逻辑学',
    'machine_learning': '机器学习',
    'management': '管理学',
    'marketing': '市场营销',
    'marxist_theory': '马克思主义理论',
    'modern_chinese': '现代汉语',
    'nutrition': '营养学',
    'philosophy': '哲学',
    'professional_accounting': '专业会计',
    'professional_law': '专业法学',
    'professional_medicine': '专业医学',
    'professional_psychology': '专业心理学',
    'public_relations': '公共关系',
    'security_study': '安全研究',
    'sociology': '社会学',
    'sports_science': '体育学',
    'traditional_chinese_medicine': '中医中药',
    'virology': '病毒学',
    'world_history': '世界历史',
    'world_religions': '世界宗教',
}
-
-
def _format_cmmlu_item(header, row):
    """Render one CMMLU row as header + question + four options + answer.

    ``row`` must unpack into exactly seven values: a leading identifier
    (ignored), the question, options A-D, and the answer. The returned
    string has no trailing newline.
    """
    _, question, opt_a, opt_b, opt_c, opt_d, answer = row
    return (
        f"{header}题目:{question}\n"
        f"A. {opt_a}\nB. {opt_b}\nC. {opt_c}\nD. {opt_d}\n"
        f"答案是: {answer}"
    )


def run_predict(url, log_path, few_shot = True):
    """Export every CMMLU test question as a prompt-with-answer text record.

    For each ``*.csv`` file under ``task_dataset/cmmlu/test`` (processed in
    sorted order) one record ``{"text": ..., "src": "cmmlu"}`` is built per
    row. The text is the subject header, the question, the four options and
    the row's answer. All records are handed to ``json_out_stream`` with the
    target path ``log_path + '/cmmlu.json'``.

    Args:
        url: Not used by this function; kept so existing callers that pass
            it keep working.
        log_path: Directory prefix for the output path given to
            ``json_out_stream``.
        few_shot: When true, each record is prefixed with the first four
            rows of ``task_dataset/cmmlu/dev/<subject>_dev.csv`` in a random
            order drawn once per subject from NumPy's global RNG (seed it
            for reproducible output).

    Raises:
        KeyError: If a test file's stem is not a key of
            ``cmmlu_subject_mapping``.
        ValueError: If a CSV row does not have exactly seven columns.
    """
    import numpy as np

    dataset_root = "task_dataset/cmmlu"
    test_root = dataset_root + "/test"
    # Match on the real extension so stray files such as "x.csv.bak" are
    # not picked up and parsed as datasets.
    test_names = sorted(
        name for name in os.listdir(test_root) if name.lower().endswith('.csv')
    )

    results = []
    for name in test_names:
        subject = name.split(".")[0]
        subject_zh = cmmlu_subject_mapping[subject]
        header = f"以下是关于{subject_zh}的单项选择题,请直接给出正确答案的选项。\n"

        prefix = ""
        if few_shot:
            dev_data = pd.read_csv(dataset_root + "/dev/" + subject + "_dev.csv")
            # Random permutation of the first four dev rows.
            shot_order = np.random.choice([0, 1, 2, 3], size=4, replace=False)
            prefix = "".join(
                _format_cmmlu_item(header, dev_data.iloc[i]) + "\n"
                for i in shot_order
            )

        data = pd.read_csv(test_root + "/" + name)
        # Walk every row. Counting non-NA cells of the first column (as
        # DataFrame.count does) would skip trailing rows whenever that
        # column has a missing value.
        for _, row in data.iterrows():
            results.append(
                {"text": prefix + _format_cmmlu_item(header, row), "src": "cmmlu"}
            )

    json_out_stream(log_path + '/cmmlu.json', results)
|