「鹏程·盘古α」(PengCheng PanGu-α): a 200-billion-parameter open-source Chinese pretrained language model
  1. """
  2. network config setting
  3. """
  4. import mindspore.common.dtype as mstype
  5. class PANGUALPHAConfig:
  6. """
  7. PANGUALPHA config class which defines the model size
  8. """
  9. def __init__(self,
  10. data_parallel_num,
  11. model_parallel_num,
  12. batch_size=32,
  13. seq_length=1024,
  14. vocab_size=50257,
  15. embedding_size=768,
  16. num_layers=12,
  17. num_heads=12,
  18. expand_ratio=4,
  19. post_layernorm_residual=False,
  20. dropout_rate=0.1,
  21. compute_dtype=mstype.float16,
  22. use_past=False,
  23. self_layernorm=True,
  24. forward_reduce_scatter=True,
  25. word_emb_dp=True,
  26. stage_num=16,
  27. eod_reset=True,
  28. micro_size=32,
  29. load_ckpt_path=None):
  30. self.batch_size = batch_size
  31. self.seq_length = seq_length
  32. self.vocab_size = vocab_size
  33. self.embedding_size = embedding_size
  34. self.num_layers = num_layers
  35. self.num_heads = num_heads
  36. self.expand_ratio = expand_ratio
  37. self.post_layernorm_residual = post_layernorm_residual
  38. self.dropout_rate = dropout_rate
  39. self.compute_dtype = compute_dtype
  40. self.use_past = use_past
  41. self.dp = data_parallel_num
  42. self.mp = model_parallel_num
  43. self.self_layernorm = self_layernorm
  44. self.forward_reduce_scatter = forward_reduce_scatter
  45. self.stage_num = stage_num
  46. self.micro_size = micro_size
  47. self.word_emb_dp = word_emb_dp
  48. self.eod_reset = eod_reset
  49. # Used for loading embedding tables
  50. self.load_ckpt_path = load_ckpt_path
  51. def __str__(self):
  52. info = "[PANGUALPHAConfig]" + '===' * 10 + '\n'
  53. for k, v in self.__dict__.items():
  54. var_info = "{}:{}\n".format(k, v)
  55. info += var_info
  56. info += '=' * 10
  57. return info
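
# A minimal usage sketch (illustrative, not part of the original file): building a
# 2.6B-style config by hand. The hyper-parameters mirror the "2.6B" preset in
# set_parse below; the parallelism split (dp=4, mp=8) is an assumption chosen for
# the example, not a value mandated by the training scripts.
def _example_pangu_alpha_2b6_config():
    return PANGUALPHAConfig(
        data_parallel_num=4,      # assumed data-parallel dimension
        model_parallel_num=8,     # matches tensor_model_parallel_num for "2.6B"
        batch_size=2,
        seq_length=1024,
        vocab_size=40000,
        embedding_size=2560,
        num_layers=32,
        num_heads=32,
        stage_num=1,
        micro_size=1)
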
def set_parse(args_opt):
    """
    Overwrite the parsed arguments with the preset hyper-parameters of the
    chosen model size ("200B", "13B" or "2.6B") and run type ("train" or "predict").
    """
    if args_opt.mode == "200B":
        args_opt.seq_length = 1024
        args_opt.vocab_size = 40000
        args_opt.embedding_size = 16384
        args_opt.num_layers = 64
        args_opt.num_heads = 128
        if args_opt.run_type == "train":
            args_opt.start_lr = 6e-5
            args_opt.end_lr = 6e-6
            args_opt.optimizer_shard = False
            args_opt.stage_num = 16
            args_opt.micro_size = 32
            args_opt.tensor_model_parallel_num = 16
            args_opt.per_batch_size = 1
        elif args_opt.run_type == "predict":
            args_opt.stage_num = 4
            args_opt.micro_size = 1
            args_opt.per_batch_size = 1
    elif args_opt.mode == "13B":
        args_opt.seq_length = 1024
        args_opt.vocab_size = 40000
        args_opt.embedding_size = 5120
        args_opt.num_layers = 40
        args_opt.num_heads = 40
        args_opt.tensor_model_parallel_num = 8
        if args_opt.run_type == "train":
            args_opt.start_lr = 5e-5
            args_opt.end_lr = 1e-6
            args_opt.optimizer_shard = True
            args_opt.stage_num = 1
            args_opt.micro_size = 1
            args_opt.per_batch_size = 16
        elif args_opt.run_type == "predict":
            args_opt.stage_num = 1
            args_opt.micro_size = 1
            args_opt.per_batch_size = 1
    elif args_opt.mode == "2.6B":
        args_opt.seq_length = 1024
        args_opt.vocab_size = 40000
        args_opt.embedding_size = 2560
        args_opt.num_layers = 32
        args_opt.num_heads = 32
        args_opt.tensor_model_parallel_num = 8
        if args_opt.run_type == "train":
            args_opt.start_lr = 1e-4
            args_opt.end_lr = 1e-6
            args_opt.optimizer_shard = True
            args_opt.stage_num = 1
            args_opt.micro_size = 1
            args_opt.per_batch_size = 2
        elif args_opt.run_type == "predict":
            args_opt.stage_num = 1
            args_opt.micro_size = 1
            args_opt.per_batch_size = 1
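
# Illustrative sketch (not part of the original file): how set_parse might be driven
# from a command-line namespace. The argparse flags and the data_parallel_num value
# below are assumptions made for this example; the real training scripts define
# their own argument parser.
if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser(description="PanGu-Alpha config example")
    parser.add_argument("--mode", default="2.6B", choices=["200B", "13B", "2.6B"])
    parser.add_argument("--run_type", default="train", choices=["train", "predict"])
    args = parser.parse_args()

    # Fill in the preset hyper-parameters for the chosen model size.
    set_parse(args)

    # Build a model config from the preset values.
    config = PANGUALPHAConfig(
        data_parallel_num=1,  # assumed for a single-device example
        model_parallel_num=args.tensor_model_parallel_num,
        batch_size=args.per_batch_size,
        seq_length=args.seq_length,
        vocab_size=args.vocab_size,
        embedding_size=args.embedding_size,
        num_layers=args.num_layers,
        num_heads=args.num_heads,
        stage_num=args.stage_num,
        micro_size=args.micro_size)
    print(config)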