A 200-billion-parameter open-source Chinese pre-trained language model,「鹏程·盘古α」(PengCheng PanGu-α) https://pangu-alpha.openi.org.cn
  1. """
  2. PanGu predict run
  3. """
  4. import argparse
  5. from pangu_alpha_config import PANGUALPHAConfig, set_parse
  6. from pangu_alpha_predict import run_predict
  7. if __name__ == "__main__":
  8. """predict function for PANGUALPHA"""
  9. parser = argparse.ArgumentParser(description="PANGUALPHA predicting")
  10. parser.add_argument('--device_id',
  11. type=int,
  12. default=0,
  13. help="Device id, default is 0.")
  14. parser.add_argument("--device_num",
  15. type=int,
  16. default=128,
  17. help="Use device nums, default is 1.")
  18. parser.add_argument("--distribute",
  19. type=str,
  20. default="true",
  21. choices=["true", "false"],
  22. help="Run distribute, default is false.")
  23. parser.add_argument("--seq_length",
  24. type=int,
  25. default=1024,
  26. help="sequence length, default is 1024.")
  27. parser.add_argument("--vocab_size",
  28. type=int,
  29. default=40000,
  30. help="vocabulary size, default is 40000.")
  31. parser.add_argument("--embedding_size",
  32. type=int,
  33. default=16384,
  34. help="embedding table size, default is 16384.")
  35. parser.add_argument("--num_layers",
  36. type=int,
  37. default=64,
  38. help="total layers, default is 64.")
  39. parser.add_argument("--num_heads",
  40. type=int,
  41. default=128,
  42. help="head size, default is 128.")
  43. parser.add_argument("--stage_num",
  44. type=int,
  45. default=4,
  46. help="Pipeline stage num, default is 4.")
  47. parser.add_argument("--micro_size",
  48. type=int,
  49. default=1,
  50. help="Pipeline micro_size, default is 1.")
  51. parser.add_argument("--load_ckpt_name",
  52. type=str,
  53. default='PANGUALPHA3.ckpt',
  54. help="checkpint file name.")
  55. parser.add_argument("--load_ckpt_path",
  56. type=str,
  57. default=None,
  58. help="predict file path.")
  59. parser.add_argument('--data_url',
  60. required=False,
  61. default=None,
  62. help='Location of data.')
  63. parser.add_argument('--train_url',
  64. required=False,
  65. default=None,
  66. help='Location of training outputs.')
  67. parser.add_argument("--run_type",
  68. type=str,
  69. default="predict",
  70. choices=["train", "predict"],
  71. help="The run type")
  72. parser.add_argument("--mode",
  73. type=str,
  74. default="2.6B",
  75. choices=["200B", "13B", "2.6B", "self_define"],
  76. help="The train/eval mode")
  77. parser.add_argument("--strategy_load_ckpt_path",
  78. type=str,
  79. default="",
  80. help="The training prallel strategy for the model.")
  81. parser.add_argument("--tokenizer_path",
  82. type=str,
  83. default="./tokenizer_path",
  84. help="The path where stores vocab and vocab model file")
  85. args_opt = parser.parse_args()
  86. # The ckpt path shoud like args_opt.load_ckpt_path + f"rank_{rank_id}}/" + args_opt.load_ckpt_name, and the rank_id is the training rank_id.
  87. set_parse(args_opt)
  88. run_predict(args_opt)
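
As the comment near the end of the script indicates, the checkpoint directory is expected to contain one rank_{rank_id} sub-directory per training rank, each holding the file named by --load_ckpt_name. A minimal sketch of how such a per-rank path resolves (the helper below is illustrative only and not part of the repository):

import os

# Sketch, not part of the PanGu-alpha repository: resolve the per-rank
# checkpoint file following the layout described in the comment above.
def rank_ckpt_path(load_ckpt_path, load_ckpt_name, rank_id):
    # e.g. load_ckpt_path="/data/ckpt", load_ckpt_name="PANGUALPHA3.ckpt",
    # rank_id=3  ->  "/data/ckpt/rank_3/PANGUALPHA3.ckpt"
    return os.path.join(load_ckpt_path, f"rank_{rank_id}", load_ckpt_name)

print(rank_ckpt_path("/data/ckpt", "PANGUALPHA3.ckpt", rank_id=3))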