from transformers import AutoModelForCausalLM, AutoTokenizer
import torch

def get_local_tokenizer():
    # Load the tokenizer from the local checkpoint directory.
    tokenizer = AutoTokenizer.from_pretrained("/userhome/baichuan-7B", trust_remote_code=True)
    return tokenizer

def get_model():
    # Load the model in fp16 and move it to the GPU.
    model = AutoModelForCausalLM.from_pretrained("/userhome/baichuan-7B", trust_remote_code=True).half().cuda()
    return model

tokenizer = get_local_tokenizer()
model = get_model()
model.eval()
def get_local_model_resp_one_item(input_str, tokens_to_generate, top_k=3, logprobs=False):
    batch = tokenizer(input_str, return_tensors='pt').to('cuda')

    if logprobs:
        # Score the prompt itself instead of generating: token i is predicted
        # from tokens 0..i-1, so shift the labels left by one and drop the
        # logits at the last position.
        labels = batch["input_ids"][:, 1:]
        with torch.no_grad():
            output = model(input_ids=batch["input_ids"], attention_mask=batch["attention_mask"])
        logits = output.logits[:, :-1, :]
        # cross_entropy with reduction='none' gives the per-token negative
        # log-likelihood; negating it yields per-token log-probabilities.
        logprobs_list = -torch.nn.functional.cross_entropy(
            logits.reshape(-1, logits.size(-1)), labels.reshape(-1), reduction='none')
        return logprobs_list.cpu().numpy().tolist()
    else:
        # top_k only takes effect when sampling is enabled, so turn
        # do_sample on explicitly.
        pred = model.generate(**batch, do_sample=True, top_k=top_k,
                              max_new_tokens=tokens_to_generate, repetition_penalty=1.1)
        output_str = tokenizer.decode(pred[0].cpu(), skip_special_tokens=True)
        return output_str
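
# Why cross_entropy works above: -cross_entropy(..., reduction='none') equals
# log_softmax over the vocabulary followed by gathering the log-probability
# of each realized next token. A minimal equivalence sketch reusing the
# tokenizer/model globals above; the helper name logprobs_via_gather is an
# illustration, not part of the original script.
import torch.nn.functional as F

def logprobs_via_gather(input_str):
    batch = tokenizer(input_str, return_tensors='pt').to('cuda')
    with torch.no_grad():
        logits = model(input_ids=batch["input_ids"],
                       attention_mask=batch["attention_mask"]).logits
    # Normalize over the vocabulary, then pick the log-probability of each
    # actual next token: the same numbers as the cross_entropy route.
    token_logprobs = F.log_softmax(logits[:, :-1, :], dim=-1)
    labels = batch["input_ids"][:, 1:]
    picked = token_logprobs.gather(-1, labels.unsqueeze(-1)).squeeze(-1)
    return picked[0].cpu().numpy().tolist()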

if __name__ == '__main__':
    # Prompt template: "阅读:" (Read:) + context + "问:" (Question:) + question + "?答:" (? Answer:).
    # The context says the issue may be a CAD system fault and suggests re-downloading;
    # the question is "CAD won't snap to points, it keeps jumping around".
    input_str = "阅读:" + "是不是cad系统毛病 建议重新下载" + "问:" + "cad捕捉不到点 一直跳来跳去" + "?答:"
    # Generate an answer of up to 100 new tokens.
    output = get_local_model_resp_one_item(input_str, 100, logprobs=False)
    print(output)
    # Score the prompt tokens instead (tokens_to_generate is ignored in this mode).
    output = get_local_model_resp_one_item(input_str, 0, logprobs=True)
    print(output)
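    # A minimal usage sketch: pair each returned log-probability with the
    # token it scores (the scores align with input_ids[1:], since the first
    # token has no preceding context to be predicted from).
    token_ids = tokenizer(input_str)["input_ids"][1:]
    for tok, lp in zip(tokenizer.convert_ids_to_tokens(token_ids), output):
        print(tok, round(lp, 4))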