2000亿开源中文预训练语言模型「鹏程·盘古α」
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
 
 
 

46 lines
1.4 KiB

  1. """
  2. TopK for text generation
  3. """
  4. import numpy as np
  5. import mindspore.common.dtype as mstype
  6. from mindspore.common.tensor import Tensor
  def generate(model, origin_inputs, seq_length, end_token=50256):
      """
      Extend a prompt token by token by sampling from the model's candidates.

      Inputs:
          model: object exposing `predict(Tensor)` that returns `(probs, p_args)`.
              Both results are converted with `.asnumpy()` and indexed as
              `[position, :]`. Presumably `probs` holds the top-k probabilities
              and `p_args` the matching token ids for every position -- confirm
              against the model definition.
          origin_inputs: 2-D numpy array of prompt token ids, (batch, prompt_len).
              Only row 0 is extended and returned.
          seq_length: fixed sequence length fed to the model; the prompt is
              right-padded with 0 up to this length.
          end_token: token id that terminates generation (it is not emitted).
      Returns:
          outputs: 1-D numpy array with the prompt ids of row 0 followed by the
              generated ids, without padding. At most `seq_length - 1` ids are
              returned unless the prompt itself already has `seq_length` ids.
      Raises:
          ValueError: if the prompt is longer than `seq_length`.
      """
      _, valid_length = origin_inputs.shape
      pad_length = seq_length - valid_length
      if pad_length < 0:
          raise ValueError(
              "origin_inputs has %d tokens, which exceeds seq_length=%d"
              % (valid_length, seq_length))

      input_ids = np.pad(origin_inputs, ((0, 0), (0, pad_length)), 'constant', constant_values=(0, 0))
      print("input_ids is ", input_ids)

      # The final slot (index seq_length - 1) is never written, so stop as soon
      # as it is the only one left instead of running a forward pass whose
      # sampled token would be thrown away.
      last_slot = seq_length - 1
      while valid_length < last_slot:
          inputs = Tensor(input_ids, mstype.int32)
          probs, p_args = model.predict(inputs)
          # Row `valid_length - 1` is the prediction made at the last real token,
          # i.e. the candidates for the next position.
          candidate_probs = probs.asnumpy()[valid_length - 1, :].astype(np.float64)
          candidate_ids = p_args.asnumpy()[valid_length - 1, :]
          # Renormalise in float64 so np.random.choice gets a distribution that
          # sums to 1 within its tolerance.
          candidate_probs = candidate_probs / candidate_probs.sum()
          target_index = np.random.choice(len(candidate_probs), p=candidate_probs)
          next_token = candidate_ids[target_index]
          if next_token == end_token:
              break
          input_ids[0][valid_length] = next_token
          valid_length += 1

      # Slice by the tracked length rather than counting non-zero ids: token id 0
      # may be a legitimate token, and other batch rows must not affect the count.
      return input_ids[0][:valid_length]