|
- #!/usr/bin/env python
- # -*- coding: utf-8 -*-
-
- # @Date: 2021/4/12
- # @Author: qing
- import os
- import json
- import time
- import traceback
- import numpy as np
- import moxing as mox
-
- from model_path import models_path
- from model_manager import get_args, get_model, get_tokenizer
- from generate import generate, generate_samples
-
-
class TextGenerate(object):
    """Batch text-generation driver for a checkpointed language model.

    Loads model weights/embeddings via ``model_manager`` and runs TOPK
    sampling (``generate``) over a file of prompts, saving the results as
    JSON locally and copying them to OBS via ``moxing``.

    NOTE(review): ``load_model()`` must be called before ``do_generate()``
    or ``generate()`` — they read ``self.model``, ``self.tokenizer``,
    ``self.config``, ``self.rank`` and ``self.args_opt`` set there.
    """

    def __init__(self, max_in_len=768):
        # max_in_len: maximum prompt length; prompts longer than this are
        # truncated to their LAST max_in_len characters (character count,
        # not token count — presumably a rough bound; verify against the
        # tokenizer's limits).
        self.max_in_len = max_in_len
        self.root_path = os.path.dirname(__file__)
        self.data_path = os.path.join(self.root_path, "data")
        self.save_path = os.path.join(self.root_path, "data")

    def load_model(self, model_name):
        """Load checkpoint, embeddings and tokenizer for ``model_name``.

        ``model_name`` must be a key of ``models_path``; the mapped prefix
        is combined with fixed suffixes for the checkpoint parts and the
        word/position/top-query embedding .npy files.

        Returns (model, config, rank, tokenizer) and also stores them on
        ``self`` for later use by do_generate()/generate().
        """
        args_opt = get_args()

        args_opt.ckpt_path = models_path[model_name] + "part"
        args_opt.word_embedding_path = models_path[model_name] + "_word_embedding.npy"
        args_opt.position_embedding_path = models_path[model_name] + "_position_embedding.npy"
        args_opt.top_query_embedding_path = models_path[model_name] + "_top_query_embedding.npy"

        model, config, rank = get_model(args_opt)
        tokenizer = get_tokenizer(args_opt)

        self.args_opt = args_opt
        self.model = model
        self.config = config
        self.rank = rank
        self.tokenizer = tokenizer

        return model, config, rank, tokenizer

    def do_generate(self, input_text):
        """Generate a continuation for one prompt.

        Returns the generated text (prompt tokens stripped), or "" if any
        step fails — errors are logged with a traceback rather than raised,
        so a single bad prompt does not abort a batch run.
        """
        output_text, input_ids, output_list = "", "", ""

        try:
            tokenized_text = self.tokenizer.tokenize(input_text)
            start_sentence = self.tokenizer.convert_tokens_to_ids(tokenized_text)
            input_ids = np.array(start_sentence).reshape(1, -1)

            # NOTE(review): generate_samples(..., top_p=0.75) is an
            # alternative nucleus-sampling path available in `generate`.
            outputs = generate(self.model, input_ids, self.config.seq_length, end_token=self.tokenizer.eot_id, TOPK=5, max_num=1024)
            output_list = outputs.tolist()
            # Drop the prompt tokens, keep only the continuation.
            # Assumes `outputs` is a flat 1-D sequence of token ids —
            # TODO confirm generate()'s return shape.
            output_list = output_list[input_ids.shape[-1]:]
            output_text = "".join(self.tokenizer.convert_ids_to_tokens(output_list))

        except Exception as err:
            # Deliberate best-effort: log and return "" so batch processing
            # continues past a failing prompt.
            print("Process do generate exception, info: ", err)
            traceback.print_exc()
            print("input_ids:", input_ids)
            print("output_ids:", output_list)

        return output_text

    def generate(self, data_file, save_file):
        """Generate continuations for every line of ``data_file``.

        Reads prompts from ``self.data_path/data_file`` (one per line),
        generates a continuation for each, and — on rank 0 only — writes a
        JSON report to ``self.save_path/save_file`` and copies it to OBS
        under ``self.args_opt.save_path``.
        """
        print("Text generate start!")
        start_time = time.time()
        with open(os.path.join(self.data_path, data_file), "r", encoding="utf-8") as f:
            input_texts = f.readlines()

        generate_results, cnt = {}, len(input_texts)
        for i, input_text in enumerate(input_texts):
            # Keep the tail of over-long prompts (most recent context).
            if len(input_text) > self.max_in_len:
                input_text = input_text[-self.max_in_len:]

            input_text = input_text.rstrip()
            output_text = self.do_generate(input_text)
            output_text_post_processed = output_text.strip()
            # Post-processing keeps only the first generated line.
            output_text_post_processed = output_text_post_processed.split("\n")[0]
            text_id = "text" + str(i+1)
            generate_results[text_id] = {"输入": input_text, "生成": output_text, "生成后处理": output_text_post_processed}

            if self.rank == 0: print(f"\n输入:{input_text}\n生成:{output_text}\n")

        if self.rank == 0:
            save_path_local = os.path.join(self.save_path, save_file)
            with open(save_path_local, "w", encoding="utf-8") as f:
                json.dump(generate_results, f, indent=2, ensure_ascii=False)

            save_path_obs = os.path.join(self.args_opt.save_path, "textgenerate", save_file)
            print("Copy the output file {} to the obs:{}".format(save_path_local, save_path_obs))
            mox.file.copy(save_path_local, save_path_obs)

        end_time = time.time()
        # BUGFIX: this summary previously said "start"; it reports completion.
        print(f"Text generate done, test num: {cnt}, cost time: {end_time - start_time}!")
-
def generate_test():
    """Run the generation pipeline once for each configured model checkpoint."""
    for name in models_path:
        runner = TextGenerate()
        runner.load_model(name)
        runner.generate("generate_in.txt", f"{name}_generate_out.json")
-
-
if __name__ == "__main__":
    # Script entry point: generate outputs for every model in models_path.
    generate_test()
|