2000亿开源中文预训练语言模型「鹏程·盘古α」
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
 
 
 

124 lines
5.0 KiB

  1. """
  2. PanGu predict run
  3. """
  4. import argparse
  5. from pangu_alpha_config import PANGUALPHAConfig, set_parse
  6. from pangu_alpha_train import run_train
  7. if __name__ == "__main__":
  8. """train function for PanGu-Alpha"""
  9. parser = argparse.ArgumentParser(description="PanGu training")
  10. parser.add_argument('--train_url',
  11. required=False,
  12. default=None,
  13. help='Location of training outputs.')
  14. parser.add_argument('--data_url',
  15. required=False,
  16. default="/cache_pangu_alpha/V1-sample60-baike-math-bpe-1024",
  17. help='Location of data.')
  18. parser.add_argument("--distribute",
  19. type=str,
  20. default="true",
  21. choices=["true", "false"],
  22. help="Run distribute, default is false.")
  23. parser.add_argument("--optimizer",
  24. type=str,
  25. default="adam",
  26. choices=["adam", "lamb"],
  27. help="select which optimizer to be used, default adam")
  28. parser.add_argument("--epoch_size",
  29. type=int,
  30. default=1,
  31. help="Epoch size, default is 1.")
  32. parser.add_argument("--warmup_step",
  33. type=int,
  34. default=2000,
  35. help="Warmup step, default is 2000.")
  36. parser.add_argument("--decay_steps",
  37. type=int,
  38. default=80000,
  39. help="Learning rate decay step, default is 80000.")
  40. parser.add_argument("--start_lr",
  41. type=float,
  42. default="6e-5",
  43. help="Start learning rate, default is 6e-5.")
  44. parser.add_argument("--end_lr",
  45. type=float,
  46. default="6e-6",
  47. help="End learning rate, default is 6e-6.")
  48. parser.add_argument("--sink_size",
  49. type=int,
  50. default=2,
  51. help="Sink size for every iteration, default is 2")
  52. parser.add_argument("--weight_decay",
  53. type=float,
  54. default=1e-1,
  55. help="weight decay of optimizer")
  56. parser.add_argument('--ckpt_save_sir',
  57. required=False,
  58. default="/cache/ckpt/",
  59. help='Dir to save ckpt.')
  60. parser.add_argument("--seq_length",
  61. type=int,
  62. default=1024,
  63. help="sequence length, default is 1024.")
  64. parser.add_argument("--vocab_size",
  65. type=int,
  66. default=40000,
  67. help="vocabulary size, default is 40000.")
  68. parser.add_argument("--embedding_size",
  69. type=int,
  70. default=16384,
  71. help="embedding table size, default is 16384.")
  72. parser.add_argument("--num_layers",
  73. type=int,
  74. default=64,
  75. help="total layers, default is 64.")
  76. parser.add_argument("--num_heads",
  77. type=int,
  78. default=128,
  79. help="head size, default is 128.")
  80. parser.add_argument("--optimizer_shard",
  81. type=int,
  82. default=0,
  83. choices=[0, 1],
  84. help="enable optimizer shard.")
  85. parser.add_argument("--stage_num",
  86. type=int,
  87. default=16,
  88. help="Pipeline stage num, default is 16.")
  89. parser.add_argument("--micro_size",
  90. type=int,
  91. default=32,
  92. help="Pipeline micro_size, default is 32.")
  93. parser.add_argument("--tensor_model_parallel_num",
  94. type=int,
  95. default=16,
  96. help="The model parallel dim of slicing tensor.")
  97. parser.add_argument("--per_batch_size",
  98. type=int,
  99. default=1,
  100. help="The batch size of each card.")
  101. parser.add_argument("--save_steps",
  102. type=int,
  103. default=1000,
  104. help="Checkpoint save steps, default is 2000.")
  105. parser.add_argument("--run_type",
  106. type=str,
  107. default="train",
  108. choices=["train", "predict"],
  109. help="The run type")
  110. parser.add_argument("--mode",
  111. type=str,
  112. default="2.6B",
  113. choices=["200B", "13B", "2.6B", "self_define"],
  114. help="The train/eval mode")
  115. args_opt = parser.parse_args()
  116. # set the input configs by train_mode
  117. set_parse(args_opt)
  118. run_train(args_opt)