「鹏程·盘古α」(PanGu-α): a 200-billion-parameter open-source Chinese pretrained language model
  1. """
  2. PANGUALPHA train script
  3. """
  4. import os
  5. import numpy as np
  6. import time
  7. from mindspore import context, Tensor
  8. from mindspore.train.model import Model
  9. import mindspore.communication.management as D
  10. from mindspore.context import ParallelMode
  11. from mindspore.train.serialization import load_checkpoint, load_param_into_net, load_distributed_checkpoint
  12. import mindspore.common.dtype as mstype
  13. from mindspore.parallel._cost_model_context import _set_multi_subgraphs
  14. from mindspore.parallel import set_algo_parameters
  15. from pangu_alpha import PANGUALPHAPipeline, PANGUALPHA, EvalNet
  16. from pangu_alpha_config import PANGUALPHAConfig
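

# run_predict_pipeline: inference path for the pipeline-parallel model.
# It configures SEMI_AUTO_PARALLEL with `pipeline_stages`, builds
# PANGUALPHAPipeline, infers the prediction layout from a fake input, loads
# the per-rank training checkpoints (directories named rank_<id>, produced by
# a 1024-device / 16-stage training run) that belong to this card's pipeline
# stage, and finally runs a predict on the fake input.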
def run_predict_pipeline(args_opt):
    device_id = int(os.getenv("DEVICE_ID"))
    rank_id_str = os.getenv('RANK_ID', '0')
    rank_id = int(rank_id_str[rank_id_str.rfind('-') + 1:])
    print('rank_id:{}'.format(rank_id), "rank_id str:{}".format(rank_id_str))
    device_id = int(os.getenv('DEVICE_ID'))
    local_rank = rank_id
    print('local_rank:{}, device id:{} start to run...'.format(local_rank, device_id), flush=True)
    context.set_context(save_graphs=False,
                        mode=context.GRAPH_MODE,
                        device_target="Ascend",
                        device_id=device_id)
    context.set_context(variable_memory_max_size="30GB")
    if args_opt.distribute == "true":
        D.init()
        device_num = D.get_group_size()
        rank = D.get_rank()
        print("device_id is {}, rank_id is {}, device_num is {}".format(device_id, rank, device_num))
        context.reset_auto_parallel_context()
        context.set_auto_parallel_context(
            parallel_mode=ParallelMode.SEMI_AUTO_PARALLEL,
            gradients_mean=False,
            device_num=device_num,
            full_batch=True,
            loss_repeated_mean=True,
            enable_parallel_optimizer=False,
            pipeline_stages=args_opt.stage_num)
        set_algo_parameters(elementwise_op_strategy_follow=True)
        _set_multi_subgraphs()
    else:
        rank = 0
        device_num = 1
    model_parallel_num = args_opt.tensor_model_parallel_num
    stage_device_num = int(device_num / args_opt.stage_num)
    data_parallel_num = int(stage_device_num / model_parallel_num)
    per_batch_size = args_opt.per_batch_size
    batch_size = per_batch_size * data_parallel_num * args_opt.micro_size
    config = PANGUALPHAConfig(
        data_parallel_num=data_parallel_num,
        model_parallel_num=model_parallel_num,
        batch_size=batch_size,
        seq_length=args_opt.seq_length,
        vocab_size=args_opt.vocab_size,
        embedding_size=args_opt.embedding_size,
        num_layers=args_opt.num_layers,
        num_heads=args_opt.num_heads,
        expand_ratio=4,
        post_layernorm_residual=False,
        dropout_rate=0.0,
        compute_dtype=mstype.float16,
        use_past=False,
        self_layernorm=True,
        forward_reduce_scatter=True,
        stage_num=args_opt.stage_num,
        micro_size=args_opt.micro_size,
        word_emb_dp=False)
    print("===config is: ", config, flush=True)
    print("=====args_opt is: ", args_opt, flush=True)
    per_stage_layers = config.num_layers // config.stage_num
    per_stage_devices = device_num // config.stage_num
    self_stage = rank_id // per_stage_devices
    # all cards save ckpt during training
    train_stage_num = 16
    train_device_num = 1024
    train_mp = 16
    ckpt_name = args_opt.load_ckpt_name
    train_per_stage_num = train_device_num // train_stage_num
    if config.mp != train_mp:
        raise ValueError("The model parallel num is not equal to the training model parallel num")
    concat_stage_num = train_stage_num // config.stage_num
    pangu_alpha = PANGUALPHAPipeline(config)
    eval_net = EvalNet(pangu_alpha)
    eval_net.set_train(False)
    model_predict = Model(eval_net)
    inputs_np = Tensor(np.ones(shape=(1, config.seq_length)), mstype.int32)
    model_predict.infer_predict_layout(inputs_np)
    print("======start load_distributed checkpoint", flush=True)
    for i in range(self_stage * concat_stage_num, (self_stage + 1) * concat_stage_num):
        stage_position = local_rank % (config.mp * config.dp)
        ckpt_rank = i * train_per_stage_num + stage_position  # rank id used at training time
        ckpt_dir = os.path.join(args_opt.load_ckpt_path, f"rank_{ckpt_rank}")  # ckpt dirs are still named by the training-time rank id
        local_ckpt_file = os.path.join(ckpt_dir, ckpt_name)
        if not os.path.exists(local_ckpt_file):
            raise ValueError("Ckpt file does not exist,", local_ckpt_file)
        params_dict = load_checkpoint(local_ckpt_file, filter_prefix="adam")
        load_param_into_net(eval_net, params_dict)
    print("================load param ok=================", flush=True)
    # here predict with fake input
    model_predict.predict(inputs_np)
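

# run_predict_no_pipeline: inference path without pipeline parallelism
# (stage_num == 1). It builds PANGUALPHA, restores the sliced checkpoint set
# via load_distributed_checkpoint against the inferred predict layout, then
# tokenizes a sample Chinese prompt with the jieba-based tokenizer and prints
# the generated continuation.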
def run_predict_no_pipeline(args_opt):
    device_id = int(os.getenv("DEVICE_ID"))
    rank_id_str = os.getenv('RANK_ID', '0')
    rank_id = int(rank_id_str[rank_id_str.rfind('-') + 1:])
    print('rank_id:{}'.format(rank_id), "rank_id str:{}".format(rank_id_str))
    device_id = int(os.getenv('DEVICE_ID'))
    local_rank = rank_id
    print('local_rank:{}, device id:{} start to run...'.format(local_rank, device_id), flush=True)
    context.set_context(save_graphs=False,
                        mode=context.GRAPH_MODE,
                        device_target="Ascend",
                        device_id=device_id)
    context.set_context(variable_memory_max_size="30GB")
    if args_opt.distribute == "true":
        D.init()
        device_num = D.get_group_size()
        rank = D.get_rank()
        print("device_id is {}, rank_id is {}, device_num is {}".format(device_id, rank, device_num))
        context.reset_auto_parallel_context()
        context.set_auto_parallel_context(
            parallel_mode=ParallelMode.SEMI_AUTO_PARALLEL,
            gradients_mean=False,
            device_num=device_num,
            full_batch=True,
            loss_repeated_mean=True,
            enable_parallel_optimizer=False,
            strategy_ckpt_load_file=args_opt.strategy_load_ckpt_path,
            pipeline_stages=args_opt.stage_num)
        set_algo_parameters(elementwise_op_strategy_follow=True)
        _set_multi_subgraphs()
    else:
        rank = 0
        device_num = 1
    model_parallel_num = args_opt.tensor_model_parallel_num
    data_parallel_num = int(device_num / model_parallel_num)
    per_batch_size = args_opt.per_batch_size
    batch_size = per_batch_size * data_parallel_num
    config = PANGUALPHAConfig(
        data_parallel_num=data_parallel_num,
        model_parallel_num=model_parallel_num,
        batch_size=batch_size,
        seq_length=args_opt.seq_length,
        vocab_size=args_opt.vocab_size,
        embedding_size=args_opt.embedding_size,
        num_layers=args_opt.num_layers,
        num_heads=args_opt.num_heads,
        expand_ratio=4,
        post_layernorm_residual=False,
        dropout_rate=0.0,
        compute_dtype=mstype.float16,
        use_past=False,
        self_layernorm=True,
        forward_reduce_scatter=True,
        stage_num=args_opt.stage_num,
        micro_size=args_opt.micro_size,
        eod_reset=False,
        word_emb_dp=True,
        load_ckpt_path=args_opt.load_ckpt_path)
    print("===config is: ", config, flush=True)
    print("=====args_opt is: ", args_opt, flush=True)
    ckpt_name = args_opt.load_ckpt_name
    pangu_alpha = PANGUALPHA(config)
    eval_net = EvalNet(pangu_alpha)
    eval_net.set_train(False)
    model_predict = Model(eval_net)
    inputs_np = Tensor(np.ones(shape=(config.batch_size, config.seq_length)), mstype.int32)
    predict_layout = model_predict.infer_predict_layout(inputs_np)
    print("======start load_distributed checkpoint", flush=True)
    # For 2.6B and 13B models, the number of ckpt files is 512.
    ckpt_name = 'filerted'  # prefix of the sliced ckpt files on disk (filerted_<rank>.ckpt)
    ckpt_file_list = [os.path.join(args_opt.load_ckpt_path, f"{ckpt_name}_{ckpt_rank}.ckpt")
                      for ckpt_rank in range(0, 512)]
    print(f"Loading from path {ckpt_file_list[0]}", flush=True)
    load_distributed_checkpoint(eval_net, ckpt_file_list, predict_layout)
    print("================load param ok=================", flush=True)
    from tokenization_jieba import JIEBATokenizer
    from generate import generate
    tokenizer = JIEBATokenizer(os.path.join(args_opt.tokenizer_path, 'vocab.vocab'),
                               os.path.join(args_opt.tokenizer_path, 'vocab.model'))
    sample = "今天是一个好天气"  # sample prompt: "It's a nice day today"
    tokenized_token = tokenizer.tokenize(sample)
    start_sentence = tokenizer.convert_tokens_to_ids(tokenized_token)
    input_ids = np.array(start_sentence).reshape(1, -1)
    output_ids = generate(model_predict, input_ids, config.seq_length, 9)
    output_samples = tokenizer.convert_ids_to_tokens(output_ids.tolist())
    print('Output is:', output_samples, flush=True)


def run_predict(args_opt):
    if args_opt.stage_num > 1:
        run_predict_pipeline(args_opt)
    else:
        run_predict_no_pipeline(args_opt)
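

# Entry-point sketch. The functions above read their settings from an
# `args_opt` namespace (distribute, stage_num, micro_size,
# tensor_model_parallel_num, per_batch_size, seq_length, vocab_size,
# embedding_size, num_layers, num_heads, load_ckpt_name, load_ckpt_path,
# strategy_load_ckpt_path, tokenizer_path). The repository builds this
# namespace in its own launcher; the argparse flags and 2.6B-style defaults
# below are illustrative assumptions only.
if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser(description="PANGUALPHA predict")
    parser.add_argument("--distribute", type=str, default="true", choices=["true", "false"])
    parser.add_argument("--stage_num", type=int, default=1)
    parser.add_argument("--micro_size", type=int, default=1)
    parser.add_argument("--tensor_model_parallel_num", type=int, default=8)
    parser.add_argument("--per_batch_size", type=int, default=1)
    parser.add_argument("--seq_length", type=int, default=1024)
    parser.add_argument("--vocab_size", type=int, default=40000)
    parser.add_argument("--embedding_size", type=int, default=2560)
    parser.add_argument("--num_layers", type=int, default=32)
    parser.add_argument("--num_heads", type=int, default=32)
    parser.add_argument("--load_ckpt_name", type=str, default="PANGUALPHA.ckpt")
    parser.add_argument("--load_ckpt_path", type=str, required=True)
    parser.add_argument("--strategy_load_ckpt_path", type=str, default="")
    parser.add_argument("--tokenizer_path", type=str, required=True)
    opt = parser.parse_args()
    run_predict(opt)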