"PengCheng PanGu-α" (鹏程·盘古α): a 200-billion-parameter open-source Chinese pretrained language model. https://pangu-alpha.openi.org.cn
  1. """
  2. PanGu train script
  3. """
  4. import os
  5. import math
  6. from pathlib2 import Path
  7. from mindspore import context
  8. from mindspore.train.model import Model
  9. import mindspore.communication.management as D
  10. from mindspore.context import ParallelMode
  11. import mindspore.nn as nn
  12. from mindspore.train.callback import TimeMonitor, ModelCheckpoint, CheckpointConfig, Callback
  13. from mindspore.nn.wrap.loss_scale import DynamicLossScaleUpdateCell
  14. import mindspore.common.dtype as mstype
  15. from mindspore.parallel._cost_model_context import _set_multi_subgraphs
  16. from mindspore.parallel import set_algo_parameters
  17. import mindspore.dataset as de
  18. from dataset import create_dataset, create_dataset_dp
  19. from pangu_alpha import PANGUALPHAPipeline, PANGUALPHA, PANGUALPHAWithLossPipeline, PANGUALPHAWithLoss, CrossEntropyLoss
  20. from pangu_alpha_wrapcell import PANGUALPHATrainPipelineWithLossScaleCell, PANGUALPHATrainOneStepWithLossScaleCell, \
  21. VirtualDatasetOneInputCell
  22. from utils import LearningRate
  23. from pangu_alpha_config import PANGUALPHAConfig, set_parse
class LossCallBack(Callback):
    """
    Monitor the loss in training.
    """
    def __init__(self, dataset_size=-1, local_rank=0, scale=1):
        super(LossCallBack, self).__init__()
        self._dataset_size = dataset_size
        self.local_rank = local_rank
        self.scale = scale

    def step_end(self, run_context):
        """
        Print loss after each step
        """
        cb_params = run_context.original_args()
        # NOTE: We send the data after sending twice sink size
        # where sink size is equal to the dataset_size (a fake one) here
        de.config.set_sending_batches(cb_params.cur_step_num + 2 * self._dataset_size)
        if self._dataset_size > 0 and self.local_rank % 8 == 0:
            percent, epoch_num = math.modf(cb_params.cur_step_num /
                                           self._dataset_size)
            if percent == 0:
                percent = 1
                epoch_num -= 1
            print(
                "local_rank: {}, epoch: {}, step: {}, output is {}, overflow is {}, scale is {}"
                .format(int(self.local_rank), int(epoch_num),
                        cb_params.cur_step_num,
                        cb_params.net_outputs[0].asnumpy() / self.scale,
                        cb_params.net_outputs[1].asnumpy(),
                        cb_params.net_outputs[2].asnumpy()))
            if len(cb_params.net_outputs) > 3:
                print("global norm is: ", cb_params.net_outputs[3].asnumpy())


def run_train_pipeline(args_opt):
    device_id = int(os.getenv("DEVICE_ID"))
    rank_id = int(os.getenv("RANK_ID"))
    local_rank = rank_id
    print('local_rank:{}, device id:{} start to run...'.format(local_rank, device_id),
          flush=True)
    context.set_context(save_graphs=False,
                        mode=context.GRAPH_MODE,
                        device_target="Ascend",
                        device_id=device_id)
    context.set_context(variable_memory_max_size="31GB")
    strategy_ckpt_save_file = "/cache/" + "strategy" + str(local_rank) + ".ckpt"
    if args_opt.distribute == "true":
        D.init()
        device_num = D.get_group_size()
        rank = D.get_rank()
        print("device_id is {}, rank_id is {}, device_num is {}".format(
            device_id, rank, device_num))
        context.reset_auto_parallel_context()
        context.set_auto_parallel_context(
            parallel_mode=ParallelMode.SEMI_AUTO_PARALLEL,
            gradients_mean=False,
            device_num=device_num,
            full_batch=True,
            loss_repeated_mean=True,
            enable_parallel_optimizer=bool(args_opt.optimizer_shard),
            pipeline_stages=args_opt.stage_num,
            strategy_ckpt_save_file=strategy_ckpt_save_file)
        set_algo_parameters(elementwise_op_strategy_follow=True)
        _set_multi_subgraphs()
    else:
        rank = 0
        device_num = 1
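    # Parallel layout for the pipeline path: devices are first split into
    # `stage_num` pipeline stages, and each stage is further split into
    # model-parallel and data-parallel groups. With full_batch=True, the batch
    # size configured below is the global batch:
    # per-device batch * data-parallel degree * number of pipeline micro-batches.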
    model_parallel_num = args_opt.tensor_model_parallel_num
    stage_device_num = int(device_num / args_opt.stage_num)
    data_parallel_num = int(stage_device_num / model_parallel_num)
    per_batch_size = args_opt.per_batch_size
    batch_size = per_batch_size * data_parallel_num * args_opt.micro_size
    config = PANGUALPHAConfig(
        data_parallel_num=data_parallel_num,
        model_parallel_num=model_parallel_num,
        batch_size=batch_size,
        seq_length=args_opt.seq_length,
        vocab_size=args_opt.vocab_size,
        embedding_size=args_opt.embedding_size,
        num_layers=args_opt.num_layers,
        num_heads=args_opt.num_heads,
        expand_ratio=4,
        post_layernorm_residual=False,
        dropout_rate=0.1,
        compute_dtype=mstype.float16,
        use_past=False,
        self_layernorm=True,
        forward_reduce_scatter=True,
        stage_num=args_opt.stage_num,
        micro_size=args_opt.micro_size,
        word_emb_dp=False)
    print("===config is: ", config, flush=True)
    pangu_alpha = PANGUALPHAPipeline(config)
    loss = CrossEntropyLoss(config)
    pangu_alpha_with_loss = PANGUALPHAWithLossPipeline(config, pangu_alpha, loss)
    pangu_alpha_with_loss = VirtualDatasetOneInputCell(pangu_alpha_with_loss)
    print("=====args_opt is: ", args_opt, flush=True)
    lr = LearningRate(learning_rate=args_opt.start_lr,
                      end_learning_rate=args_opt.end_lr,
                      warmup_steps=args_opt.warmup_step,
                      decay_steps=args_opt.decay_steps)
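    # Each pipeline stage only optimizes the parameters it hosts: its own slice
    # of transformer blocks plus, on the first stage, the word and position
    # embeddings, and, on the last stage, the shared embedding table (presumably
    # reused by the output projection), the final layernorm and the top query
    # embedding.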
    per_stage_layers = config.num_layers // config.stage_num
    per_stage_devices = device_num // config.stage_num
    self_stage = rank_id // per_stage_devices
    range_min = self_stage * per_stage_layers
    range_max = range_min + per_stage_layers
    if self_stage == 0:
        params = [pangu_alpha.embedding_table]
        params.extend(pangu_alpha.backbone.pangu_alpha_embedding.position_embedding.trainable_params())
    elif self_stage == config.stage_num - 1:
        params = [pangu_alpha.embedding_table]
        params.extend(pangu_alpha.backbone.layernorm.trainable_params())
        params.extend(pangu_alpha.backbone.top_query_embedding.trainable_params())
    else:
        params = []
    for i in range(range_min, range_max):
        params.extend(pangu_alpha.backbone.blocks[i].trainable_params())
    decay_filter = lambda x: 'layernorm' not in x.name.lower() and "bias" not in x.name.lower()
    decay_params = list(filter(decay_filter, params))
    other_params = list(filter(lambda x: not decay_filter(x), params))
    group_params = [{
        'params': decay_params,
        'weight_decay': args_opt.weight_decay
    }, {
        'params': other_params,
        'weight_decay': 0.0
    }, {
        'order_params': params
    }]
    if args_opt.optimizer == "lamb":
        optimizer = nn.Lamb(group_params, learning_rate=lr)
    else:
        optimizer = nn.AdamWeightDecay(group_params, learning_rate=lr, beta1=0.9, beta2=0.95, eps=1e-8)
    save_steps = args_opt.save_steps
    ckpt_dir = os.path.join(args_opt.ckpt_save_sir, f"rank_{str(local_rank)}")
    if not os.path.exists(ckpt_dir):
        Path(ckpt_dir).mkdir(parents=True, exist_ok=True)
    ds = create_dataset(config.batch_size, data_path=args_opt.data_url, data_start_index=0)
    epoch_num = args_opt.epoch_size
    step_per_epoch = ds.get_dataset_size()
    callback_size = args_opt.sink_size
    actual_epoch_num = int(epoch_num * step_per_epoch / callback_size)
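    # model.train below runs with sink_size=callback_size, so one "epoch" from
    # Model's point of view is callback_size steps; actual_epoch_num converts the
    # requested number of dataset epochs into that unit.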
    callback = [
        TimeMonitor(callback_size),
        LossCallBack(callback_size, local_rank, config.stage_num)
    ]
    config_ck = CheckpointConfig(save_checkpoint_steps=save_steps,
                                 keep_checkpoint_max=1,
                                 integrated_save=False,
                                 filter_prefix="accu_grads")
    ckpoint_cb = ModelCheckpoint(prefix="PanguAlpha",
                                 directory=ckpt_dir,
                                 config=config_ck)
    callback.append(ckpoint_cb)
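    # Dynamic loss scaling: start from a large scale (2**32). MindSpore's
    # DynamicLossScaleUpdateCell divides the scale by scale_factor on overflow
    # and multiplies it back after scale_window consecutive overflow-free steps.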
    loss_scale_value = math.pow(2, 32)
    update_cell = DynamicLossScaleUpdateCell(loss_scale_value=loss_scale_value,
                                             scale_factor=2,
                                             scale_window=1000)
    pangu_alpha_with_grads = PANGUALPHATrainPipelineWithLossScaleCell(
        pangu_alpha_with_loss, optimizer=optimizer, config=config, scale_update_cell=update_cell)
    model = Model(pangu_alpha_with_grads)
    de.config.set_sending_batches(2 * args_opt.sink_size)
    model.train(actual_epoch_num,
                ds,
                callbacks=callback,
                sink_size=callback_size,
                dataset_sink_mode=True)


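# The non-pipeline path below mirrors run_train_pipeline, with a few differences:
# full_batch=False and the dataset is sharded per rank via create_dataset_dp,
# word_emb_dp=True in the model config, the optimizer covers all trainable
# parameters rather than a per-stage subset, and the training wrapper is
# PANGUALPHATrainOneStepWithLossScaleCell instead of the pipeline cell.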
def run_train_no_pipeline(args_opt):
    device_id = int(os.getenv("DEVICE_ID"))
    rank_id = int(os.getenv("RANK_ID"))
    local_rank = rank_id
    print('local_rank:{}, device id:{} start to run...'.format(local_rank, device_id),
          flush=True)
    save_graphs_path = "/var/log/npu/slog/device-" + str(local_rank) + "/"
    context.set_context(save_graphs=False,
                        save_graphs_path=save_graphs_path,
                        mode=context.GRAPH_MODE,
                        device_target="Ascend",
                        device_id=device_id)
    context.set_context(variable_memory_max_size="31GB")
    strategy_ckpt_save_file = "/cache/" + "strategy" + str(local_rank) + ".ckpt"
    if args_opt.distribute == "true":
        D.init()
        device_num = D.get_group_size()
        rank = D.get_rank()
        print("device_id is {}, rank_id is {}, device_num is {}".format(
            device_id, rank, device_num))
        context.reset_auto_parallel_context()
        context.set_auto_parallel_context(
            parallel_mode=ParallelMode.SEMI_AUTO_PARALLEL,
            gradients_mean=False,
            device_num=device_num,
            full_batch=False,
            loss_repeated_mean=True,
            enable_parallel_optimizer=bool(args_opt.optimizer_shard),
            pipeline_stages=args_opt.stage_num,
            strategy_ckpt_save_file=strategy_ckpt_save_file)
        set_algo_parameters(elementwise_op_strategy_follow=True)
        _set_multi_subgraphs()
    else:
        rank = 0
        device_num = 1
    model_parallel_num = args_opt.tensor_model_parallel_num
    data_parallel_num = int(device_num / model_parallel_num)
    per_batch_size = args_opt.per_batch_size
    batch_size = per_batch_size * device_num
    config = PANGUALPHAConfig(
        data_parallel_num=data_parallel_num,
        model_parallel_num=model_parallel_num,
        batch_size=batch_size,
        seq_length=args_opt.seq_length,
        vocab_size=args_opt.vocab_size,
        embedding_size=args_opt.embedding_size,
        num_layers=args_opt.num_layers,
        num_heads=args_opt.num_heads,
        expand_ratio=4,
        post_layernorm_residual=False,
        dropout_rate=0.1,
        compute_dtype=mstype.float16,
        use_past=False,
        self_layernorm=True,
        forward_reduce_scatter=True,
        stage_num=args_opt.stage_num,
        micro_size=args_opt.micro_size,
        word_emb_dp=True)
  249. print("===config is: ", config, flush=True)
  250. pangu_alpha = PANGUALPHA(config)
  251. loss = CrossEntropyLoss(config)
  252. pangu_alpha_with_loss = PANGUALPHAWithLoss(config, pangu_alpha, loss)
  253. pangu_alpha_with_loss = VirtualDatasetOneInputCell(pangu_alpha_with_loss)
  254. print("=====args_opt is: ", args_opt, flush=True)
  255. lr = LearningRate(learning_rate=args_opt.start_lr,
  256. end_learning_rate=args_opt.end_lr,
  257. warmup_steps=args_opt.warmup_step,
  258. decay_steps=args_opt.decay_steps)
  259. decay_filter = lambda x: 'layernorm' not in x.name.lower() and "bias" not in x.name.lower()
  260. params = pangu_alpha.trainable_params()
  261. decay_params = list(filter(decay_filter, params))
  262. other_params = list(filter(lambda x: not decay_filter(x), params))
  263. group_params = [{
  264. 'params': decay_params,
  265. 'weight_decay': args_opt.weight_decay
  266. }, {
  267. 'params': other_params,
  268. 'weight_decay': 0.0
  269. }, {
  270. 'order_params': params
  271. }]
  272. if args_opt.optimizer == "lamb":
  273. optimizer = nn.Lamb(group_params, learning_rate=lr)
  274. else:
  275. optimizer = nn.AdamWeightDecay(group_params, learning_rate=lr, beta1=0.9, beta2=0.95, eps=1e-8)
  276. save_steps = args_opt.save_steps
  277. ckpt_dir = os.path.join(args_opt.ckpt_save_sir, f"rank_{str(local_rank)}")
  278. if not os.path.exists(ckpt_dir):
  279. Path(ckpt_dir).mkdir(parents=True, exist_ok=True)
  280. ds = create_dataset_dp(config.batch_size, data_path=args_opt.data_url, data_start_index=0, device_num=device_num, rank=rank)
  281. epoch_num = args_opt.epoch_size
  282. step_per_epoch = ds.get_dataset_size()
  283. callback_size = args_opt.sink_size
  284. actual_epoch_num = int(epoch_num * step_per_epoch / callback_size)
  285. callback = [
  286. TimeMonitor(callback_size),
  287. LossCallBack(callback_size, local_rank)
  288. ]
  289. config_ck = CheckpointConfig(save_checkpoint_steps=save_steps,
  290. keep_checkpoint_max=1,
  291. integrated_save=False)
  292. ckpoint_cb = ModelCheckpoint(prefix="PanguAlpha",
  293. directory=ckpt_dir,
  294. config=config_ck)
  295. callback.append(ckpoint_cb)
  296. loss_scale_value = math.pow(2, 32)
  297. update_cell = DynamicLossScaleUpdateCell(loss_scale_value=loss_scale_value,
  298. scale_factor=2,
  299. scale_window=1000)
  300. pangu_alpha_with_grads = PANGUALPHATrainOneStepWithLossScaleCell(
  301. pangu_alpha_with_loss, optimizer=optimizer, config=config, scale_update_cell=update_cell)
  302. model = Model(pangu_alpha_with_grads)
  303. de.config.set_sending_batches(2*args_opt.sink_size)
  304. model.train(actual_epoch_num,
  305. ds,
  306. callbacks=callback,
  307. sink_size=callback_size,
  308. dataset_sink_mode=True)
def run_train(args_opt):
    if args_opt.stage_num > 1:
        run_train_pipeline(args_opt)
    else:
        run_train_no_pipeline(args_opt)
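

# ---------------------------------------------------------------------------
# The script as shown defines run_train() but no "__main__" entry point; the
# real launcher lives elsewhere in the repository. The block below is a
# hypothetical, minimal driver added for illustration only: the argument names
# mirror the args_opt attributes read above, and the default values are
# placeholders rather than the project's actual training configuration.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser(description="PanGu-alpha training (illustrative driver)")
    # Distributed / parallelism layout
    parser.add_argument("--distribute", type=str, default="true", choices=["true", "false"])
    parser.add_argument("--optimizer_shard", type=int, default=1)
    parser.add_argument("--stage_num", type=int, default=1)  # >1 selects the pipeline path
    parser.add_argument("--micro_size", type=int, default=1)
    parser.add_argument("--tensor_model_parallel_num", type=int, default=8)
    parser.add_argument("--per_batch_size", type=int, default=1)
    # Model shape (placeholder values, not the released PanGu-alpha sizes)
    parser.add_argument("--seq_length", type=int, default=1024)
    parser.add_argument("--vocab_size", type=int, default=40000)
    parser.add_argument("--embedding_size", type=int, default=2560)
    parser.add_argument("--num_layers", type=int, default=32)
    parser.add_argument("--num_heads", type=int, default=32)
    # Optimization schedule
    parser.add_argument("--optimizer", type=str, default="adam", choices=["adam", "lamb"])
    parser.add_argument("--start_lr", type=float, default=1e-4)
    parser.add_argument("--end_lr", type=float, default=1e-6)
    parser.add_argument("--warmup_step", type=int, default=2000)
    parser.add_argument("--decay_steps", type=int, default=200000)
    parser.add_argument("--weight_decay", type=float, default=0.01)
    # Data and checkpointing
    parser.add_argument("--data_url", type=str, required=True)
    parser.add_argument("--ckpt_save_sir", type=str, default="/cache/ckpt")  # attribute name kept exactly as read above
    parser.add_argument("--save_steps", type=int, default=2000)
    parser.add_argument("--epoch_size", type=int, default=1)
    parser.add_argument("--sink_size", type=int, default=2)
    args_opt = parser.parse_args()
    set_parse(args_opt)  # imported above; assumed here to post-process the parsed arguments
    run_train(args_opt)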