PengCheng PanGu-α (「鹏程·盘古α」): a 200-billion-parameter open-source Chinese pretrained language model
  1. """PANGUALPHA training wrapper"""
  2. import mindspore.nn as nn
  3. from mindspore.ops import operations as P
  4. from mindspore.ops import composite as C
  5. from mindspore.ops import functional as F
  6. from mindspore import context
  7. from mindspore.context import ParallelMode
  8. from mindspore.nn.wrap.grad_reducer import DistributedGradReducer
  9. from mindspore.communication.management import get_group_size, create_group
  10. from mindspore.common.tensor import Tensor
  11. import mindspore.common.dtype as mstype
  12. from mindspore.common.parameter import Parameter
  13. from mindspore.ops.operations.comm_ops import _VirtualDataset
  14. from utils import ClipByGlobalNorm
  15. GRADIENT_CLIP_TYPE = 1
  16. GRADIENT_CLIP_VALUE = 1.0
  17. clip_grad = C.MultitypeFuncGraph("clip_grad")
@clip_grad.register("Number", "Number", "Tensor")
def _clip_grad(clip_type, clip_value, grad):
    """
    Clip a gradient.

    Inputs:
        clip_type (int): The way to clip, 0 for 'value', 1 for 'norm'.
        clip_value (float): Specifies how much to clip.
        grad (Tensor): A single gradient (the graph is mapped over the gradient tuple).

    Outputs:
        Tensor, the clipped gradient.
    """
    if clip_type not in [0, 1]:
        return grad
    dt = F.dtype(grad)
    if clip_type == 0:
        new_grad = C.clip_by_value(grad, F.cast(F.tuple_to_array((-clip_value,)), dt),
                                   F.cast(F.tuple_to_array((clip_value,)), dt))
    else:
        new_grad = nn.ClipByNorm()(grad, F.cast(F.tuple_to_array((clip_value,)), dt))
    return new_grad

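
For reference, a minimal sketch (not part of the original file) of how the registered clip_grad graph is typically mapped element-wise over a gradient tuple; the `grads` argument below is a hypothetical tuple of gradient Tensors.

def _clip_grads_sketch(grads):
    # Hypothetical helper: maps the registered clip_grad graph over every
    # gradient Tensor, clipping each by norm (GRADIENT_CLIP_TYPE == 1).
    hyper_map = C.HyperMap()
    return hyper_map(F.partial(clip_grad, GRADIENT_CLIP_TYPE, GRADIENT_CLIP_VALUE), grads)
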
grad_scale = C.MultitypeFuncGraph("grad_scale")
reciprocal = P.Reciprocal()

update_accu_grads = C.MultitypeFuncGraph("update_accu_grads")


@update_accu_grads.register("Tensor", "Tensor")
def _update_accu_grads(accu_grad, grad):
    return F.depend(accu_grad, grad)


@grad_scale.register("Tensor", "Tensor", "Tensor")
def tensor_grad_scale_pipeline(scale, grad, accu_grad):
    """Unscale the accumulated gradient and clear the accumulation buffer (pipeline path)."""
    new_grad = accu_grad * reciprocal(scale)
    zeros = F.tensor_mul(accu_grad, 0.0)
    clear = F.assign(accu_grad, zeros)
    F.control_depend(new_grad, clear)
    F.control_depend(grad, new_grad)
    return new_grad


@grad_scale.register("Tensor", "Tensor")
def tensor_grad_scale(scale, grad):
    """Unscale a single gradient by the loss-scale value."""
    return grad * reciprocal(scale)

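
Because MultitypeFuncGraph dispatches on the argument types, the same grad_scale graph serves both the plain and the pipeline (gradient-accumulation) paths. The helper below is a sketch, not in the original file; all of its arguments are hypothetical Tensors.

def _unscale_grads_sketch(scaling_sens, grads, accu_grads=None):
    # Hypothetical helper: passing accu_grads selects the three-Tensor
    # (pipeline) overload registered above, otherwise the two-Tensor one.
    hyper_map = C.HyperMap()
    if accu_grads is None:
        return hyper_map(F.partial(grad_scale, scaling_sens), grads)
    return hyper_map(F.partial(grad_scale, scaling_sens), grads, accu_grads)
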
class VirtualDatasetOneInputCell(nn.Cell):
    """Wrap a backbone with a _VirtualDataset node so that, under (semi-)auto parallel, its inputs are treated as dataset outputs."""

    def __init__(self, backbone):
        super(VirtualDatasetOneInputCell, self).__init__(auto_prefix=False)
        self._backbone = backbone
        self._virtual_dataset = _VirtualDataset()

    def construct(self, *data):
        data_ = self._virtual_dataset(*data)
        return self._backbone(*data_)

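
A minimal usage sketch, assuming a hypothetical net_with_loss Cell and a semi-auto parallel setup; only VirtualDatasetOneInputCell comes from this file.

def build_parallel_net(net_with_loss):
    # Sketch only: pair the virtual-dataset wrapper with semi-auto parallel
    # so the wrapped inputs are treated as dataset outputs when sharding.
    context.set_auto_parallel_context(parallel_mode=ParallelMode.SEMI_AUTO_PARALLEL,
                                      full_batch=True)
    return VirtualDatasetOneInputCell(net_with_loss)
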
class PANGUALPHATrainPipelineWithLossScaleCell(nn.Cell):
    """
    Encapsulation class of PANGUALPHA network training (pipeline parallel path).

    Appends an optimizer to the training network. After that, the construct
    function can be called to create the backward graph.

    Args:
        network (Cell): The training network. Note that the loss function should have been added.
        optimizer (Optimizer): Optimizer for updating the weights.
        config: Model configuration; `micro_size` (the number of pipeline micro-batches) is read from it.
        scale_update_cell (Cell): Cell to do the loss scale. Default: None.
        enable_global_norm (bool): Whether to clip gradients by their global norm. Default: True.
    """
    def __init__(self, network, optimizer, config, scale_update_cell=None, enable_global_norm=True):
        super(PANGUALPHATrainPipelineWithLossScaleCell, self).__init__(auto_prefix=False)
        self.config = config
        self.network = network
        self.network.add_flags(defer_inline=True)
        self.weights = optimizer.parameters
        self.accu_grads = self.weights.clone(prefix="accu_grads", init="zeros")
        self.optimizer = optimizer
        self.enable_global_norm = enable_global_norm
        self.grad = C.GradOperation(get_by_list=True,
                                    sens_param=True)
        self.reducer_flag = False
        self.allreduce = P.AllReduce()
        self.parallel_mode = context.get_auto_parallel_context("parallel_mode")
        if self.parallel_mode in [ParallelMode.DATA_PARALLEL, ParallelMode.HYBRID_PARALLEL]:
            self.reducer_flag = True
        self.grad_reducer = F.identity
        self.degree = 1
        if self.reducer_flag:
            self.degree = get_group_size()
            self.grad_reducer = DistributedGradReducer(optimizer.parameters, False, self.degree)
        self.is_distributed = (self.parallel_mode != ParallelMode.STAND_ALONE)
        self.cast = P.Cast()
        self.alloc_status = P.NPUAllocFloatStatus(False)
        self.get_status = P.NPUGetFloatStatus(False)
        self.clear_before_grad = P.NPUClearFloatStatus(False)
        self.reduce_sum = P.ReduceSum(keep_dims=False)
        self.depend_parameter_use = P.ControlDepend(depend_mode=1)
        self.base = Tensor(1, mstype.float32)
        self.less_equal = P.LessEqual()
        self.hyper_map = C.HyperMap()
        self.loss_scale = None
        self.reshape = P.Reshape()
        self.control = P.ControlDepend(1)
        self.clip_norm = Tensor(1000.0, mstype.float32)
        self.loss_scaling_manager = scale_update_cell
        if scale_update_cell:
            self.loss_scale = Parameter(Tensor(scale_update_cell.get_loss_scale(), dtype=mstype.float32),
                                        name="loss_scale")
        self.clip = ClipByGlobalNorm(self.weights, self.config)
        self.micro_size = config.micro_size

    @C.add_flags(has_effect=True)
    def construct(self,
                  input_ids,
                  input_position,
                  attention_mask,
                  past=None,
                  sens=None):
        """Defines the computation performed."""
        weights = self.weights
        loss = self.network(input_ids, input_position, attention_mask)
        if sens is None:
            scaling_sens = self.loss_scale
            scaling_sens = self.reshape(scaling_sens, (1,))
        else:
            scaling_sens = sens
        # Allocate and clear the float status register right before the grad operation.
        init = self.alloc_status()
        status_clear = self.clear_before_grad(init)
        grads = self.grad(self.network, weights)(input_ids,
                                                 input_position,
                                                 attention_mask,
                                                 self.cast(scaling_sens / self.micro_size,
                                                           mstype.float32))
        get_status = self.get_status(init)
        get_status_depend = F.control_depend(grads, get_status)
        flag_sum = self.reduce_sum(init, (0,))
        flag_sum_depend = F.control_depend(get_status, flag_sum)
        loss = F.depend(loss, status_clear)
        loss = F.depend(loss, get_status_depend)
        loss = F.depend(loss, flag_sum_depend)
        # Apply the grad reducer to the accumulated gradients, then unscale and clip them.
        accu_grads = self.grad_reducer(self.accu_grads)
        grads = self.hyper_map(F.partial(grad_scale, scaling_sens * self.degree), grads, accu_grads)
        grads, global_norms = self.clip(grads)
        global_norm = P.Reshape()(global_norms, (()))
        if self.is_distributed:
            # Sum the overflow flag over devices.
            flag_reduce = self.allreduce(flag_sum)
            cond = self.less_equal(self.base, flag_reduce)
        else:
            cond = self.less_equal(self.base, flag_sum)
        overflow = cond
        if sens is None:
            overflow = self.loss_scaling_manager(self.loss_scale, cond)
        if overflow:
            succ = False
        else:
            succ = self.optimizer(grads)
        ret = (loss, overflow, scaling_sens, global_norm)
        return F.depend(ret, succ)

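
A minimal assembly sketch for the pipeline wrapper, assuming hypothetical pangu_with_loss, optimizer and config objects (config must expose micro_size, as read above) and using MindSpore's DynamicLossScaleUpdateCell as the loss-scale manager.

def build_pipeline_train_cell(pangu_with_loss, optimizer, config):
    # Sketch only: wire network, optimizer, config and a dynamic loss-scale
    # update cell into the pipeline training wrapper defined above.
    update_cell = nn.DynamicLossScaleUpdateCell(loss_scale_value=2 ** 32,
                                                scale_factor=2,
                                                scale_window=1000)
    train_cell = PANGUALPHATrainPipelineWithLossScaleCell(pangu_with_loss, optimizer, config,
                                                          scale_update_cell=update_cell)
    train_cell.set_train()
    return train_cell
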
class PANGUALPHATrainOneStepWithLossScaleCell(nn.Cell):
    """
    Encapsulation class of PANGUALPHA network training (single-step path).

    Appends an optimizer to the training network. After that, the construct
    function can be called to create the backward graph.

    Args:
        network (Cell): The training network. Note that the loss function should have been added.
        optimizer (Optimizer): Optimizer for updating the weights.
        scale_update_cell (Cell): Cell to do the loss scale. Default: None.
        enable_global_norm (bool): Whether to clip gradients by their global norm. Default: True.
        config: Model configuration passed to ClipByGlobalNorm. Default: None.
    """
    def __init__(self,
                 network,
                 optimizer,
                 scale_update_cell=None,
                 enable_global_norm=True,
                 config=None):
        super(PANGUALPHATrainOneStepWithLossScaleCell,
              self).__init__(auto_prefix=False)
        self.network = network
        self.config = config
        self.network.add_flags(defer_inline=True)
        self.weights = optimizer.parameters
        self.optimizer = optimizer
        self.enable_global_norm = enable_global_norm
        self.grad = C.GradOperation(get_by_list=True, sens_param=True)
        self.reducer_flag = False
        self.allreduce = P.AllReduce()
        self.parallel_mode = context.get_auto_parallel_context("parallel_mode")
        if self.parallel_mode in [
                ParallelMode.DATA_PARALLEL, ParallelMode.HYBRID_PARALLEL
        ]:
            self.reducer_flag = True
        self.grad_reducer = F.identity
        self.degree = 1
        if self.reducer_flag:
            self.degree = get_group_size()
            self.grad_reducer = DistributedGradReducer(optimizer.parameters,
                                                       False, self.degree)
        self.is_distributed = (self.parallel_mode != ParallelMode.STAND_ALONE)
        self.cast = P.Cast()
        self.alloc_status = P.NPUAllocFloatStatus()
        self.get_status = P.NPUGetFloatStatus()
        self.clear_before_grad = P.NPUClearFloatStatus()
        self.reduce_sum = P.ReduceSum(keep_dims=False)
        self.depend_parameter_use = P.ControlDepend(depend_mode=1)
        self.base = Tensor(1, mstype.float32)
        self.less_equal = P.LessEqual()
        self.hyper_map = C.HyperMap()
        self.loss_scale = None
        self.loss_scaling_manager = scale_update_cell
        if scale_update_cell:
            self.loss_scale = Parameter(Tensor(
                scale_update_cell.get_loss_scale(), dtype=mstype.float32),
                                        name="loss_scale")
        self.clip = ClipByGlobalNorm(self.weights, self.config, pipeline=False)

    @C.add_flags(has_effect=True)
    def construct(self, input_ids, input_position=None, attention_mask=None, layer_past=None, sens=None):
        """Defines the computation performed."""
        weights = self.weights
        loss = self.network(input_ids, input_position, attention_mask)
        if sens is None:
            scaling_sens = self.loss_scale
        else:
            scaling_sens = sens
        # alloc status and clear should be right before gradoperation
        init = self.alloc_status()
        self.clear_before_grad(init)
        grads = self.grad(self.network,
                          weights)(input_ids,
                                   input_position, attention_mask,
                                   self.cast(scaling_sens, mstype.float32))
        # apply grad reducer on grads
        grads = self.grad_reducer(grads)
        grads = self.hyper_map(
            F.partial(grad_scale, scaling_sens * self.degree), grads)
        grads, global_norms = self.clip(grads)
        global_norm = P.Reshape()(global_norms, (()))
        self.get_status(init)
        flag_sum = self.reduce_sum(init, (0,))
        if self.is_distributed:
            # sum overflow flag over devices
            flag_reduce = self.allreduce(flag_sum)
            cond = self.less_equal(self.base, flag_reduce)
        else:
            cond = self.less_equal(self.base, flag_sum)
        overflow = cond
        if sens is None:
            overflow = self.loss_scaling_manager(self.loss_scale, cond)
        if overflow:
            succ = False
        else:
            succ = self.optimizer(grads)
        ret = (loss, cond, scaling_sens, global_norm)
        return F.depend(ret, succ)
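
A corresponding sketch for the non-pipeline wrapper, again with hypothetical pangu_with_loss, optimizer, config and input objects; each call of the returned cell performs one training step and yields the loss, the overflow flag, the current loss scale and the gradient global norm.

def build_train_cell(pangu_with_loss, optimizer, config):
    # Sketch only: the single-step (non-pipeline) counterpart of the builder above.
    update_cell = nn.DynamicLossScaleUpdateCell(loss_scale_value=2 ** 32,
                                                scale_factor=2,
                                                scale_window=1000)
    train_cell = PANGUALPHATrainOneStepWithLossScaleCell(pangu_with_loss, optimizer,
                                                         scale_update_cell=update_cell,
                                                         config=config)
    train_cell.set_train()
    return train_cell


def train_one_step_sketch(train_cell, input_ids, input_position, attention_mask):
    # Sketch only: a single training step with hypothetical input Tensors.
    loss, overflow, scaling_sens, global_norm = train_cell(input_ids, input_position, attention_mask)
    return loss, overflow, scaling_sens, global_norm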