「鹏程·盘古α」(PengCheng PanGu-α): a 200-billion-parameter open-source Chinese pre-trained language model. https://pangu-alpha.openi.org.cn
  1. """
  2. network config setting, gradient clip function and dynamic learning rate function
  3. """
  4. import numpy as np
  5. from multiprocessing import Process
  6. import mindspore.nn as nn
  7. from mindspore.ops import operations as P
  8. from mindspore.ops import composite as C
  9. from mindspore.ops import functional as F
  10. import mindspore.common.dtype as mstype
  11. from mindspore.common.tensor import Tensor
  12. from mindspore.nn.learning_rate_schedule import LearningRateSchedule, PolynomialDecayLR, WarmUpLR, CosineDecayLR
  13. from mindspore.parallel._auto_parallel_context import auto_parallel_context
  14. from mindspore.communication.management import get_rank, get_group_size, create_group
  15. class PANGUALPHAConfig:
  16. """
  17. PANGUALPHA config class which defines the model size
  18. """
  19. def __init__(self,
  20. data_parallel_num,
  21. model_parallel_num,
  22. batch_size=32,
  23. seq_length=1024,
  24. vocab_size=50257,
  25. embedding_size=768,
  26. num_layers=12,
  27. num_heads=12,
  28. expand_ratio=4,
  29. post_layernorm_residual=False,
  30. dropout_rate=0.1,
  31. compute_dtype=mstype.float16,
  32. use_past=False,
  33. self_layernorm=True,
  34. forward_reduce_scatter=True,
  35. word_emb_dp=True,
  36. stage_num=16,
  37. micro_size=32):
  38. self.batch_size = batch_size
  39. self.seq_length = seq_length
  40. self.vocab_size = vocab_size
  41. self.embedding_size = embedding_size
  42. self.num_layers = num_layers
  43. self.num_heads = num_heads
  44. self.expand_ratio = expand_ratio
  45. self.post_layernorm_residual = post_layernorm_residual
  46. self.dropout_rate = dropout_rate
  47. self.compute_dtype = compute_dtype
  48. self.use_past = use_past
  49. self.dp = data_parallel_num
  50. self.mp = model_parallel_num
  51. self.self_layernorm = self_layernorm
  52. self.forward_reduce_scatter = forward_reduce_scatter
  53. self.stage_num = stage_num
  54. self.micro_size = micro_size
  55. self.word_emb_dp = word_emb_dp
  56. def __str__(self):
  57. info = "[PANGUALPHAConfig]" + '===' * 10 + '\n'
  58. for k, v in self.__dict__.items():
  59. var_info = "{}:{}\n".format(k, v)
  60. info += var_info
  61. info += '=' * 10
  62. return info
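
# Example sketch: building a config for a hypothetical small setting. The
# data/model-parallel split (dp=4, mp=2) and the layer sizes below are
# illustrative assumptions, not values from the PanGu-Alpha training setup.
#
#     config = PANGUALPHAConfig(data_parallel_num=4,
#                               model_parallel_num=2,
#                               batch_size=8,
#                               num_layers=24,
#                               num_heads=16,
#                               embedding_size=1024)
#     print(config)  # __str__ lists every field of the config
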
get_square_sum = C.MultitypeFuncGraph("get_square_sum")


@get_square_sum.register("Tensor", "Number")
def _get_square_sum(grad, value):
    norm = P.ReduceSum(False)(F.square(grad), ()) / value
    norm = F.expand_dims(F.cast(norm, mstype.float32), 0)
    return norm


apply_global_norm = C.MultitypeFuncGraph("apply_global_norm")


@apply_global_norm.register("Tensor", "Tensor", "Tensor")
def _apply_global_norm(clip_norm, global_norm, grad):
    grad = grad * clip_norm / global_norm
    return grad


class GlobalNormPipline(nn.Cell):
    """
    Calculate the global norm value of given tensors under pipeline parallelism
    """
    def __init__(self, params, config):
        super(GlobalNormPipline, self).__init__()
        self.norm = nn.Norm()
        self.hyper_map = C.HyperMap()
        # Parameters replicated across the model-parallel group contribute once
        # to the norm; the sharded ones are divided by config.mp.
        self.allreduce_filter = tuple("projection.bias" not in x.name and "layernorm" not in x.name
                                      and "position_embedding.embedding_table" not in x.name for x in params)
        self.allreduce_group_size = ()
        for item in self.allreduce_filter:
            if item:
                self.allreduce_group_size = self.allreduce_group_size + (1.0,)
            else:
                self.allreduce_group_size = self.allreduce_group_size + (config.mp * 1.0,)
        self.length = len(params)
        group_list, group_name = _get_model_parallel_group(config.mp)
        print("rank_list", group_name)
        print("group_size_list", self.allreduce_group_size)
        create_group(group_name, group_list)
        self.allreduce = P.AllReduce(group=group_name)
        pipeline_group_list, pipeline_group_name = _get_pipeline_group()
        print("pipeline_group_name", pipeline_group_name)
        create_group(pipeline_group_name, pipeline_group_list)
        self.allreduce2 = P.AllReduce(group=pipeline_group_name)

    def construct(self, grads):
        square_sum = self.hyper_map(get_square_sum, grads, self.allreduce_group_size)
        square_reduce_sum = F.addn(square_sum)
        stage_square_reduce_sum = self.allreduce(square_reduce_sum)
        global_square_reduce_sum = self.allreduce2(stage_square_reduce_sum)
        global_norms = F.sqrt(global_square_reduce_sum)
        return global_norms


class GlobalNorm(nn.Cell):
    """
    Calculate the global norm value of given tensors
    """
    def __init__(self, params, config):
        super(GlobalNorm, self).__init__()
        self.norm = nn.Norm()
        self.hyper_map = C.HyperMap()
        self.config = config
        self.allreduce_filter = tuple("projection.bias" not in x.name and "layernorm" not in x.name
                                      and "embedding_table" not in x.name for x in params)
        self.length = len(params)
        self.values = []
        self.group_size = get_group_size()
        for item in self.allreduce_filter:
            if item:
                self.values.append(1.0)
            else:
                self.values.append(self.group_size * 1.0)
        self.values = tuple(self.values)

    def construct(self, grads):
        square_sum_dp = self.hyper_map(get_square_sum, grads, self.values)
        global_norms = F.sqrt(P.AllReduce()(F.addn(square_sum_dp)))
        return global_norms


class ClipByGlobalNorm(nn.Cell):
    """
    Clip grads by global norm
    """
    def __init__(self, params, config, clip_norm=1.0, pipeline=True):
        super(ClipByGlobalNorm, self).__init__()
        if pipeline:
            self.global_norm = GlobalNormPipline(params, config)
        else:
            self.global_norm = GlobalNorm(params, config)
        self.clip_norm = Tensor([clip_norm], mstype.float32)
        self.hyper_map = C.HyperMap()

    def construct(self, grads):
        global_norm_origin = self.global_norm(grads)
        cond = P.GreaterEqual()(global_norm_origin, self.clip_norm)
        global_norm = F.select(cond, global_norm_origin, self.clip_norm)
        grads = self.hyper_map(
            F.partial(apply_global_norm, self.clip_norm, global_norm), grads)
        return grads, global_norm_origin

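
# Example sketch: inside a training wrapper cell, clipping would typically sit
# between gradient computation and the optimizer update. The names network,
# optimizer and grads below are placeholders, not symbols defined in this file.
#
#     clip = ClipByGlobalNorm(network.trainable_params(), config, clip_norm=1.0)
#     ...
#     grads, global_norm = clip(grads)
#     success = optimizer(grads)
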
def _get_model_parallel_group(mp):
    """Return the rank list and group name of the current device's model-parallel group."""
    rank = get_rank()
    stage_nums = auto_parallel_context().get_pipeline_stages()
    device_nums = get_group_size()
    per_stage_device_nums = device_nums // stage_nums
    stage_id = rank // per_stage_device_nums
    local_stage_rank_id = rank % per_stage_device_nums
    index = local_stage_rank_id // mp
    group = range(0, mp)
    rank_str_list = [str(x + index * mp + stage_id * per_stage_device_nums) for x in group]
    rank_list_str = "-".join(rank_str_list)
    rank_list = [x + index * mp + stage_id * per_stage_device_nums for x in group]
    return rank_list, rank_list_str


def _get_pipeline_group():
    """Return the rank list and group name of the current device's pipeline-parallel group."""
    rank = get_rank()
    stage_nums = auto_parallel_context().get_pipeline_stages()
    device_nums = get_group_size()
    per_stage_device_nums = device_nums // stage_nums
    stage_id = rank // per_stage_device_nums
    local_stage_rank_id = rank % per_stage_device_nums
    group = range(0, stage_nums)
    rank_list = [local_stage_rank_id + x * per_stage_device_nums for x in group]
    rank_str_list = [str(local_stage_rank_id + x * per_stage_device_nums) for x in group]
    rank_list_str = "-".join(rank_str_list)
    return rank_list, rank_list_str

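
# Worked example for a hypothetical topology: with 8 devices, 2 pipeline stages
# and mp=2, per_stage_device_nums = 8 // 2 = 4. For rank 5, stage_id = 1 and
# local_stage_rank_id = 1, so _get_model_parallel_group(2) returns ranks [4, 5]
# (group name "4-5"), while _get_pipeline_group() returns ranks [1, 5]
# (group name "1-5").
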
class LearningRate(LearningRateSchedule):
    """
    Warmup-decay learning rate for PANGUALPHA network.
    """
    def __init__(self,
                 learning_rate,
                 end_learning_rate,
                 warmup_steps,
                 decay_steps,
                 power=1.0,
                 use_cosine=True):
        super(LearningRate, self).__init__()
        self.warmup_flag = False
        if warmup_steps > 0:
            self.warmup_flag = True
            self.warmup_lr = WarmUpLR(learning_rate, warmup_steps)
        self.decay_lr = PolynomialDecayLR(learning_rate, end_learning_rate,
                                          decay_steps, power)
        self.cosine_decay_lr = CosineDecayLR(end_learning_rate, learning_rate,
                                             decay_steps)
        self.warmup_steps = Tensor(np.array([warmup_steps]).astype(np.float32))
        self.greater = P.Greater()
        self.one = Tensor(np.array([1.0]).astype(np.float32))
        self.cast = P.Cast()
        self.use_cosine = use_cosine

    def construct(self, global_step):
        """dynamic learning rate"""
        if not self.use_cosine:
            decay_lr = self.decay_lr(global_step)
        else:
            decay_lr = self.cosine_decay_lr(global_step)
        if self.warmup_flag:
            is_warmup = self.cast(self.greater(self.warmup_steps, global_step),
                                  mstype.float32)
            warmup_lr = self.warmup_lr(global_step)
            lr = (self.one - is_warmup) * decay_lr + is_warmup * warmup_lr
        else:
            lr = decay_lr
        return lr
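

# Example sketch: querying the schedule at a few steps with assumed
# hyperparameters (peak lr 1e-4, 2000 warmup steps, 200000 decay steps).
# These numbers are illustrative, not the PanGu-Alpha training schedule.
if __name__ == "__main__":
    lr_schedule = LearningRate(learning_rate=1e-4,
                               end_learning_rate=1e-6,
                               warmup_steps=2000,
                               decay_steps=200000,
                               use_cosine=True)
    for step in (0, 1000, 2000, 100000):
        step_t = Tensor(np.array(step).astype(np.int32))
        # During warmup the lr ramps up linearly; afterwards it follows cosine decay.
        print(step, lr_schedule(step_t).asnumpy())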