PengCheng PanGu-α (「鹏程·盘古α」), a 200-billion-parameter open-source Chinese pretrained language model. https://pangu-alpha.openi.org.cn
  1. """PANGUALPHA model"""
  2. import math
  3. import numpy as np
  4. import os
  5. import mindspore.nn as nn
  6. from mindspore.common.tensor import Tensor
  7. from mindspore.common.parameter import Parameter
  8. import mindspore.common.dtype as mstype
  9. from mindspore.common.initializer import initializer, Normal, TruncatedNormal
  10. from mindspore.ops import operations as P
  11. from mindspore.ops import functional as F
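
# Note on the shard() strategies used throughout this file: `config.dp` is the
# data-parallel dimension and `config.mp` is the model-parallel dimension of the
# device mesh, so a strategy such as ((dp, 1), (1, mp)) splits the first input
# along the batch axis and the second input (a weight) along its output axis.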

class LayerNorm(nn.Cell):
    """
    Layer normalization with explicitly sharded primitives:
    output = (x - mean) / sqrt(var + eps) * gamma + beta, computed over the last axis.
    """
    def __init__(self, normalized_shape, dp=4, eps=1e-5):
        super(LayerNorm, self).__init__()
        self.gamma = Parameter(initializer('ones', normalized_shape), name="gamma",
                               parallel_optimizer=False)
        self.beta = Parameter(initializer('zeros', normalized_shape), name="beta",
                              parallel_optimizer=False)
        self.mean = P.ReduceMean(keep_dims=True).shard(((dp, 1, 1),))
        self.square = P.Square().shard(((dp, 1, 1),))
        self.sqrt = P.Sqrt().shard(((dp, 1, 1),))
        self.sub1 = P.Sub().shard(((dp, 1, 1), (dp, 1, 1)))
        self.sub2 = P.Sub().shard(((dp, 1, 1), (dp, 1, 1)))
        self.add = P.TensorAdd().shard(((dp, 1, 1), ()))
        self.eps = eps
        self.mul = P.Mul().shard(((dp, 1, 1), (1,)))
        self.add2 = P.TensorAdd().shard(((dp, 1, 1), (1,)))
        self.real_div = P.RealDiv().shard(((dp, 1, 1), (dp, 1, 1)))

    def construct(self, x):
        mean = self.mean(x, -1)
        diff = self.sub1(x, mean)
        variance = self.mean(self.square(diff), -1)
        variance_eps = self.sqrt(self.add(variance, self.eps))
        output = self.real_div(diff, variance_eps)
        output = self.add2(self.mul(output, self.gamma), self.beta)
        return output

class Mapping(nn.Cell):
    """
    A mapping function with a 3d input
    Args:
        input_size: the size of the last dimension of the input tensor
        output_size: the desired size of the last dimension of the output tensor
        dtype: the compute datatype
        scale: the scale factor for initialization
    Inputs:
        x: the 3d input
    Returns:
        output: Tensor, a 3d tensor after projection
    """
    def __init__(self, config, input_size, output_size, scale=1.0):
        super(Mapping, self).__init__()
        self.output_size = output_size
        self.input_size = input_size
        self.weight = Parameter(initializer(Normal(sigma=0.02 * scale),
                                            [input_size, output_size]),
                                name="mapping_weight")
        self.bias = Parameter(initializer("zeros", [output_size]),
                              name="mapping_bias",
                              parallel_optimizer=False)
        self.dtype = config.compute_dtype
        self.cast = P.Cast()
        #self.cast.add_prim_attr("_side_effect", True)
        self.add = P.TensorAdd().shard(((config.dp, 1), (1,)))
        self.matmul = P.MatMul().shard(
            ((config.dp, config.mp), (config.mp, 1)))
        self.matmul.add_prim_attr("recompute_comm_op", False)

    def construct(self, x):
        out_shape = P.Shape()(x)[:-1] + (self.output_size,)
        x = P.Reshape()(x, (-1, self.input_size))
        weight = self.cast(self.weight, self.dtype)
        x = self.matmul(x, weight)
        x = self.add(x, self.cast(self.bias, self.dtype))
        output = P.Reshape()(x, out_shape)
        return output

class Mapping_output(nn.Cell):
    def __init__(self, config, input_size, output_size, scale=1.0):
        super(Mapping_output, self).__init__()
        self.output_size = output_size
        self.input_size = input_size
        self.weight = Parameter(initializer(Normal(sigma=0.02 * scale),
                                            [input_size, output_size]),
                                name="mapping_weight")
        self.bias = Parameter(initializer("zeros", [output_size]),
                              name="mapping_bias")
        self.dtype = config.compute_dtype
        self.cast = P.Cast()
        #self.cast.add_prim_attr("_side_effect", True)
        self.add = P.TensorAdd().shard(((config.dp, config.mp), (config.mp,)))
        self.matmul = P.MatMul().shard(((config.dp, 1), (1, config.mp)))

    def construct(self, x):
        out_shape = P.Shape()(x)[:-1] + (self.output_size,)
        x = P.Reshape()(x, (-1, self.input_size))
        weight = self.cast(self.weight, self.dtype)
        x = self.matmul(x, weight)
        x = self.add(x, self.cast(self.bias, self.dtype))
        output = P.Reshape()(x, out_shape)
        return output

class Output(nn.Cell):
    """
    The output mapping module for each layer
    Args:
        config(PANGUALPHAConfig): the config of network
        scale: scale factor for initialization
    Inputs:
        x: output of the self-attention module
    Returns:
        output: Tensor, the output of this layer after mapping
    """
    def __init__(self, config, scale=1.0):
        super(Output, self).__init__()
        input_size = config.embedding_size
        output_size = config.embedding_size * config.expand_ratio
        self.mapping = Mapping_output(config, input_size, output_size)
        self.projection = Mapping(config, output_size, input_size, scale)
        self.activation = nn.GELU()
        self.activation.gelu.shard(((config.dp, 1, config.mp),))
        self.dropout = nn.Dropout(1 - config.dropout_rate)
        self.dropout.dropout_gen_mask.shard(((config.dp, 1, 1),))
        self.dropout.dropout_do_mask.shard(((config.dp, 1, 1),))

    def construct(self, x):
        hidden = self.activation(self.mapping(x))
        output = self.projection(hidden)
        output = self.dropout(output)
        return output
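
# The causal attention mask built below is the outer product of the per-token
# validity mask with itself, multiplied by a lower-triangular matrix, i.e.
# attention_mask[b, i, j] = input_mask[b, i] * input_mask[b, j] * tril(ones)[i, j],
# so position i may only attend to valid positions j <= i.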

class AttentionMask(nn.Cell):
    """
    Get the attention matrix for the self-attention module
    Args:
        config(PANGUALPHAConfig): the config of network
    Inputs:
        input_mask: the mask indicating whether each position is a valid input
    Returns:
        attention_mask: the attention mask matrix with shape (batch_size, seq_length, seq_length)
    """
    def __init__(self, config):
        super(AttentionMask, self).__init__()
        self.reshape = P.Reshape()
        self.mul = P.BatchMatMul().shard(
            ((config.dp, 1, 1), (config.dp, 1, 1)))  # yzz: use 64, 1, 1?
        self.expand_dim = P.ExpandDims().shard(((1, 1),))
        ones = np.ones(shape=(config.seq_length, config.seq_length))
        self.lower_triangle_mask = Tensor(np.tril(ones), mstype.float32)
        self.multiply = P.Mul().shard(((config.dp, 1, 1), (1, 1, 1)))

    def construct(self, input_mask):
        input_shape = P.Shape()(input_mask)
        shape_right = (input_shape[0], 1, input_shape[1])
        shape_left = input_shape + (1,)
        mask_left = self.reshape(input_mask, shape_left)
        mask_right = self.reshape(input_mask, shape_right)
        attention_mask = self.mul(mask_left, mask_right)
        lower_triangle = self.expand_dim(self.lower_triangle_mask, 0)
        attention_mask = self.multiply(
            attention_mask, lower_triangle)  # (bs, seq_length, seq_length)
        return attention_mask

class EmbeddingLookupPipeline(nn.Cell):
    """
    The embedding lookup table for vocabulary (pipeline-parallel version)
    Args:
        config(PANGUALPHAConfig): the config of network
    Inputs:
        input_ids: the tokenized inputs with datatype int32
        table: the word embedding table of the vocabulary
    Returns:
        output: Tensor, the embedding vector for the input with shape (batch_size, seq_length, embedding_size)
    """
    def __init__(self, config):
        super(EmbeddingLookupPipeline, self).__init__()
        self.vocab_size = config.vocab_size
        self.embedding_size = config.embedding_size
        if config.word_emb_dp:
            self.gather = P.GatherV2().shard(((1, 1), (config.dp, 1)))
        else:
            self.gather = P.GatherV2().shard(((config.mp, 1), (1, 1)))
            self.gather.add_prim_attr("repeated_calc_num_direction", "left")
            if config.forward_reduce_scatter:
                self.gather.add_prim_attr("forward_type", "ReduceScatter")
        self.gather.add_prim_attr("begin", 0)
        self.shape = (-1, config.seq_length, config.embedding_size)

    def construct(self, input_ids, table):
        output = self.gather(table, input_ids, 0)
        return output

class EmbeddingLookup(nn.Cell):
    """
    The embedding lookup table for vocabulary
    Args:
        config(PANGUALPHAConfig): the config of network
    Inputs:
        input_ids: the tokenized inputs with datatype int32
    Returns:
        output: Tensor, the embedding vector for the input with shape (batch_size, seq_length, embedding_size)
        self.embedding_table: Tensor, the embedding table for the vocabulary
    """
    def __init__(self, config):
        super(EmbeddingLookup, self).__init__()
        self.vocab_size = config.vocab_size
        self.embedding_size = config.embedding_size
        if config.load_ckpt_path:
            # Load the word embedding table from the checkpoint path.
            embedding_path = os.path.join(config.load_ckpt_path, 'word_embedding.npy')
            if os.path.exists(embedding_path):
                e_table = np.load(embedding_path)
                e_table = Tensor(e_table, mstype.float32)
                self.embedding_table = Parameter(e_table, name="embedding_table")
            else:
                raise ValueError(f"{embedding_path} does not exist, please check whether the word_embedding file exists.")
        else:
            self.embedding_table = Parameter(initializer(
                Normal(0.02), [self.vocab_size, self.embedding_size]),
                name="embedding_table")
        if config.word_emb_dp:
            self.gather = P.GatherV2().shard(((1, 1), (config.dp, 1)))
        else:
            self.gather = P.GatherV2().shard(((config.mp, 1), (1, 1)))
            self.gather.add_prim_attr("repeated_calc_num_direction", "left")
            if config.forward_reduce_scatter:
                self.gather.add_prim_attr("forward_type", "ReduceScatter")
        self.shape = (-1, config.seq_length, config.embedding_size)

    def construct(self, input_ids):
        output = self.gather(self.embedding_table, input_ids, 0)
        return output, self.embedding_table
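
# Self-attention below follows the standard scaled dot-product formulation:
# scores = Q @ K^T / sqrt(size_per_head); masked positions receive an additive
# bias of -10000 (via (1 - attention_mask) * multiply_data) before the softmax,
# and the softmax output is multiplied with V to produce the weighted values.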

class Attention(nn.Cell):
    """
    Self-Attention module for each layer
    Args:
        config(PANGUALPHAConfig): the config of network
        scale: scale factor for initialization
        layer_idx: current layer index
    """
    def __init__(self, config, scale=1.0, layer_idx=None):
        super(Attention, self).__init__()
        self.get_attention_mask = AttentionMask(config)
        self.projection = Mapping(config, config.embedding_size,
                                  config.embedding_size, scale)
        self.transpose = P.Transpose().shard(((config.dp, 1, config.mp, 1),))
        self.merger_head_transpose = P.Transpose().shard(
            ((config.dp, config.mp, 1, 1),))
        self.reshape = P.Reshape()
        self.n_head = config.num_heads
        self.size_per_head = config.embedding_size // self.n_head
        self.concat_k = P.Concat(axis=3)
        self.concat_v = P.Concat(axis=2)
        self.multiply_data = Tensor([-10000.0], dtype=mstype.float32)
        self.batch_matmul = P.BatchMatMul().shard(
            ((config.dp, config.mp, 1, 1), (config.dp, config.mp, 1, 1)))
        self.scale = scale
        self.real_div = P.RealDiv().shard(((config.dp, config.mp, 1, 1), ()))
        self.sub = P.Sub().shard(((1,), (config.dp, 1, 1, 1))).add_prim_attr("_side_effect", True)
        self.mul = P.Mul().shard(((config.dp, 1, 1, 1), (1,))).add_prim_attr("_side_effect", True)
        self.add = P.TensorAdd().shard(
            ((config.dp, 1, 1, 1), (config.dp, config.mp, 1, 1)))
        if self.scale:
            self.scale_factor = Tensor(math.sqrt(self.size_per_head))
        if layer_idx is not None:
            self.coeff = math.sqrt(layer_idx * math.sqrt(self.size_per_head))
            self.coeff = Tensor(self.coeff)
        self.use_past = config.use_past
        self.dropout = nn.Dropout(1 - config.dropout_rate)
        self.dropout.dropout_gen_mask.shard(((config.dp, 1, 1),))
        self.dropout.dropout_do_mask.shard(((config.dp, 1, 1),))
        self.prob_dropout = nn.Dropout(1 - config.dropout_rate)
        self.prob_dropout.dropout_gen_mask.shard(((config.dp, config.mp, 1, 1),))
        self.prob_dropout.dropout_do_mask.shard(((config.dp, config.mp, 1, 1),))
        self.softmax = nn.Softmax()
        self.softmax.softmax.shard(((config.dp, config.mp, 1),))
        self.expand_dims = P.ExpandDims().shard(((config.dp, 1, 1),))
        # dense1/dense2/dense3 produce the query, key and value projections.
        self.dense1 = nn.Dense(config.embedding_size,
                               config.embedding_size).to_float(config.compute_dtype)
        self.dense1.matmul.shard(((config.dp, 1), (config.mp, 1)))
        self.dense1.bias_add.shard(((config.dp, config.mp), (config.mp,)))
        self.dense2 = nn.Dense(config.embedding_size,
                               config.embedding_size).to_float(config.compute_dtype)
        self.dense2.matmul.shard(((config.dp, 1), (config.mp, 1)))
        self.dense2.bias_add.shard(((config.dp, config.mp), (config.mp,)))
        self.dense3 = nn.Dense(config.embedding_size,
                               config.embedding_size).to_float(config.compute_dtype)
        self.dense3.matmul.shard(((config.dp, 1), (config.mp, 1)))
        self.dense3.bias_add.shard(((config.dp, config.mp), (config.mp,)))

    def construct(self, x, attention_mask, layer_past=None):
        """
        self-attention
        Inputs:
            x: output of previous layer
            attention_mask: the attention mask matrix with shape (batch_size, 1, seq_length, seq_length)
            layer_past: the previous feature map
        Returns:
            output: Tensor, the output logit of this layer
            layer_present: Tensor, the feature map of current layer
        """
        original_shape = F.shape(x)
        x = F.reshape(x, (-1, original_shape[-1]))
        query = self.dense1(x)
        key = self.dense2(x)
        value = self.dense3(x)
        # (bs, n_head, seq_length, size_per_head)
        query = self.transpose(
            F.reshape(query, (-1, original_shape[1], self.n_head, self.size_per_head)),
            (0, 2, 1, 3))
        # (bs, n_head, size_per_head, seq_length)
        key = self.transpose(
            F.reshape(key, (-1, original_shape[1], self.n_head, self.size_per_head)),
            (0, 2, 3, 1))
        # (bs, n_head, seq_length, size_per_head)
        value = self.transpose(
            F.reshape(value, (-1, original_shape[1], self.n_head, self.size_per_head)),
            (0, 2, 1, 3))
        if self.use_past:
            past_value = layer_past[1]
            past_key = self.transpose(layer_past[0], (0, 1, 3, 2))
            key = self.concat_k((past_key, key))
            value = self.concat_v((past_value, value))
        layer_present = P.Pack()([self.transpose(key, (0, 1, 3, 2)), value])
        attention = self._attn(query, key, value, attention_mask)
        attention_merge = self.merge_heads(attention)
        output = self.projection(attention_merge)
        output = self.dropout(output)
        return output, layer_present

    def split_heads(self, x, transpose):
        """
        split 3d tensor to 4d and switch certain axes
        Inputs:
            x: input tensor
            transpose: tuple, the transpose sequence
        Returns:
            x_transpose: the 4d output
        """
        x_size = P.Shape()(x)
        new_x_shape = x_size[:-1] + (self.n_head, self.size_per_head)
        x = self.reshape(x, new_x_shape)
        x_transpose = self.transpose(x, transpose)
        return x_transpose

    def merge_heads(self, x):
        """
        convert a 4d input to a 3d output
        Inputs:
            x: input tensor
        Returns:
            x_merge: the 3d output
        """
        x = self.merger_head_transpose(x, (0, 2, 1, 3))  # (bs, seq_length, head, size_per_head)
        x_shape = P.Shape()(x)
        new_shape = x_shape[:-2] + (x_shape[-2] * x_shape[-1],)
        x_merge = self.reshape(x, new_shape)
        return x_merge

    def _attn(self, query, key, value, attention_mask):
        """
        Get the weighted score along the seq_length
        Inputs:
            query: the query matrix
            key: the key matrix
            value: the value matrix
            attention_mask: the attention mask matrix with shape (batch_size, 1, seq_length, seq_length)
        Returns:
            weighted_values: Tensor, the weighted sum scores
        """
        if not self.scale:
            query = query / F.cast(self.coeff, F.dtype(query))
            key = key / F.cast(self.coeff, F.dtype(key))
        score = self.batch_matmul(query, key)
        if self.scale:
            score = self.real_div(
                score,
                P.Cast()(self.scale_factor, P.DType()(score)))
        ori_dtype = P.DType()(score)
        score = P.Cast()(score, mstype.float32)
        # Turn the 0/1 mask into an additive bias: masked positions get -10000.
        multiply_out = self.sub(
            P.Cast()(F.tuple_to_array((1.0,)), P.DType()(score)),
            P.Cast()(attention_mask, P.DType()(score)))
        adder = self.mul(multiply_out, self.multiply_data)
        attention_scores = self.add(adder, score)
        shape = F.shape(attention_scores)
        attention_probs = self.softmax(
            F.reshape(attention_scores, (shape[0], -1, shape[-1])))
        attention_probs = P.Cast()(attention_probs, ori_dtype)
        attention_probs = F.reshape(attention_probs, shape)
        attention_probs = self.prob_dropout(attention_probs)
        weighted_values = self.batch_matmul(attention_probs, value)
        return weighted_values
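
# Each transformer Block below is LayerNorm -> Attention -> residual add, then
# LayerNorm -> feed-forward (Output) -> residual add. When
# config.post_layernorm_residual is set, the residual branch uses the normalized
# tensor instead of the raw input.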

class Block(nn.Cell):
    """
    The basic block of PANGUALPHA network
    Args:
        config(PANGUALPHAConfig): the config of network
        layer_idx: current layer index
    Inputs:
        x: the output of the previous layer (the word embedding output for the first layer)
        attention_mask: the attention mask matrix with shape (batch_size, 1, seq_length, seq_length)
        layer_past: the previous feature map
    Returns:
        output: Tensor, the output logit of this layer
        layer_present: Tensor, the feature map of current layer
    """
    def __init__(self, config, layer_idx):
        super(Block, self).__init__()
        scale = 1 / math.sqrt(2.0 * config.num_layers)
        if config.self_layernorm:
            self.layernorm1 = LayerNorm((config.embedding_size,), config.dp).to_float(mstype.float32)
            self.layernorm2 = LayerNorm((config.embedding_size,), config.dp).to_float(mstype.float32)
        else:
            self.layernorm1 = nn.LayerNorm(
                (config.embedding_size,)).to_float(mstype.float32)
            self.layernorm1.layer_norm.shard(((config.dp, 1, 1), (1,), (1,)))
            self.layernorm2 = nn.LayerNorm(
                (config.embedding_size,)).to_float(mstype.float32)
            self.layernorm2.layer_norm.shard(((config.dp, 1, 1), (1,), (1,)))
        self.attention = Attention(config, scale, layer_idx)
        self.output = Output(config, scale)
        self.post_layernorm_residual = config.post_layernorm_residual
        self.add = P.TensorAdd().shard(((config.dp, 1, 1), (config.dp, 1, 1)))
        self.add_last = P.TensorAdd().shard(((config.dp, 1, 1), (config.dp, 1, 1)))
        self.add_last.recompute(False)
        self.dtype = config.compute_dtype

    def construct(self, x, input_mask, layer_past=None):
        input_x = self.layernorm1(x)
        input_x = F.cast(input_x, self.dtype)
        attention, layer_present = self.attention(input_x, input_mask, layer_past)
        if self.post_layernorm_residual:
            x = self.add(input_x, attention)
        else:
            x = self.add(x, attention)
        output_x = self.layernorm2(x)
        output_x = F.cast(output_x, self.dtype)
        mlp_logit = self.output(output_x)
        if self.post_layernorm_residual:
            output = self.add_last(output_x, mlp_logit)
        else:
            output = self.add_last(x, mlp_logit)
        return output, layer_present
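
# PANGUALPHA_EmbeddingPipeLine and PANGUALPHA_Mask below package the embedding
# lookup (word + position) and the attention-mask expansion as separate cells so
# that PANGUALPHA_ModelPipeline can assign them to pipeline stages (the embedding
# cell is pinned to stage 0 there).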

class PANGUALPHA_EmbeddingPipeLine(nn.Cell):
    def __init__(self, config):
        super(PANGUALPHA_EmbeddingPipeLine, self).__init__()
        self.word_embedding = EmbeddingLookupPipeline(config)
        self.position_embedding = nn.Embedding(config.seq_length,
                                               config.embedding_size,
                                               embedding_table=Normal(0.02))
        self.position_embedding.gather.shard(((1, 1), (config.dp,)))
        self.position_embedding.expand.shard(((config.dp, 1),))
        self.add = P.TensorAdd().shard(((config.dp, 1, 1), (config.dp, 1, 1)))
        self.dropout = nn.Dropout(1 - config.dropout_rate)
        self.dropout.dropout_gen_mask.shard(((config.dp, 1, 1),))
        self.dropout.dropout_do_mask.shard(((config.dp, 1, 1),))

    def construct(self, input_ids, table, input_position):
        input_embedding = self.word_embedding(input_ids, table)
        position_embedding = self.position_embedding(input_position)
        hidden_states = self.add(input_embedding, position_embedding)
        hidden_states = self.dropout(hidden_states)
        hidden_states = P.Cast()(hidden_states, mstype.float16)
        return hidden_states

class PANGUALPHA_Mask(nn.Cell):
    def __init__(self, config):
        super(PANGUALPHA_Mask, self).__init__()
        self.get_attention_mask = AttentionMask(config)
        self.dtype = config.compute_dtype
        self.expand_dims = P.ExpandDims().shard(((config.dp, 1, 1),))

    def construct(self, input_mask, attention_mask):
        attention_mask = self.expand_dims(attention_mask, 1)
        return attention_mask

class QueryLayerAttention(Attention):
    """Self-attention where the query is computed from a separate query hidden state."""
    def construct(self, x, query_hidden_state, attention_mask, layer_past=None):
        original_shape = F.shape(x)
        x = F.reshape(x, (-1, original_shape[-1]))
        query_hidden_state = F.reshape(query_hidden_state, (-1, original_shape[-1]))
        query = self.dense1(query_hidden_state)
        key = self.dense2(x)
        value = self.dense3(x)
        query = self.transpose(
            F.reshape(query, (-1, original_shape[1], self.n_head, self.size_per_head)),
            (0, 2, 1, 3))
        key = self.transpose(
            F.reshape(key, (-1, original_shape[1], self.n_head, self.size_per_head)),
            (0, 2, 3, 1))
        value = self.transpose(
            F.reshape(value, (-1, original_shape[1], self.n_head, self.size_per_head)),
            (0, 2, 1, 3))
        if self.use_past:
            past_value = layer_past[1]
            past_key = self.transpose(layer_past[0], (0, 1, 3, 2))
            key = self.concat_k((past_key, key))
            value = self.concat_v((past_value, value))
        layer_present = P.Pack()([self.transpose(key, (0, 1, 3, 2)), value])
        attention = self._attn(query, key, value, attention_mask)
        attention_merge = self.merge_heads(attention)
        output = self.projection(attention_merge)
        output = self.dropout(output)
        return output, layer_present
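
# QueryLayer is the final "top query" layer specific to PanGu-Alpha: it is a
# regular transformer block except that the attention query comes from a learned
# position-wise embedding (the top_query_embedding passed in as
# query_hidden_state) rather than from the hidden states themselves.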

class QueryLayer(nn.Cell):
    def __init__(self, config):
        super(QueryLayer, self).__init__()
        scale = 1 / math.sqrt(2.0 * config.num_layers)
        self.layernorm1 = LayerNorm((config.embedding_size,), config.dp).to_float(mstype.float32)
        self.layernorm2 = LayerNorm((config.embedding_size,), config.dp).to_float(mstype.float32)
        self.layernorm1.gamma.parallel_optimizer = False
        self.layernorm1.beta.parallel_optimizer = False
        self.attention = QueryLayerAttention(config, scale)
        self.layernorm2.gamma.parallel_optimizer = False
        self.layernorm2.beta.parallel_optimizer = False
        self.output = Output(config, scale)
        self.post_layernorm_residual = config.post_layernorm_residual
        self.add = P.TensorAdd().shard(((config.dp, 1, 1), (config.dp, 1, 1)))
        self.last_add = P.TensorAdd().shard(
            ((config.dp, 1, 1), (config.dp, 1, 1))).add_prim_attr("recompute", False)
        self.dtype = config.compute_dtype

    def construct(self, x, query_hidden_state, input_mask, layer_past=None):
        input_x = self.layernorm1(x)
        input_x = F.cast(input_x, self.dtype)
        attention, layer_present = self.attention(input_x, query_hidden_state,
                                                  input_mask, layer_past)
        if self.post_layernorm_residual:
            x = self.add(input_x, attention)
        else:
            x = self.add(x, attention)
        output_x = self.layernorm2(x)
        output_x = F.cast(output_x, self.dtype)
        mlp_logit = self.output(output_x)
        if self.post_layernorm_residual:
            output = self.last_add(output_x, mlp_logit)
        else:
            output = self.last_add(x, mlp_logit)
        return output, layer_present

class PANGUALPHA_ModelPipeline(nn.Cell):
    """
    The backbone of PANGUALPHA network (pipeline-parallel version)
    Args:
        config(PANGUALPHAConfig): the config of network
    Inputs:
        input_ids: the tokenized inputs with datatype int32
        input_mask: the mask indicating whether each position is a valid input
        table: the word embedding table
        input_position: the position indices of the inputs
        attention_mask: the attention mask matrix
        layer_past: the previous feature map
    Returns:
        output_state: Tensor, the output logit of backbone
        present_layer: Tensor, the current feature map
    """
    def __init__(self, config):
        super(PANGUALPHA_ModelPipeline, self).__init__()
        self.pangu_alpha_embedding = PANGUALPHA_EmbeddingPipeLine(config).set_comm_fusion(1)
        self.pangu_alpha_embedding.stage = 0
        self.pangu_alpha_mask = PANGUALPHA_Mask(config)
        self.blocks = nn.CellList()
        dropout_recompute = False
        self.top_query_embedding = nn.Embedding(config.seq_length, config.embedding_size,
                                                embedding_table=TruncatedNormal(0.02))
        self.top_query_embedding.gather.shard(((1, 1), (config.dp,)))
        self.top_query_embedding.expand.shard(((config.dp, 1),))
        for i in range(config.num_layers):
            if i == config.num_layers - 1:
                self.top_query_embedding.set_comm_fusion(2)
                self.top_query_embedding.stage = i * config.stage_num // config.num_layers
                per_block = QueryLayer(config).set_comm_fusion(2)
            else:
                per_block = Block(config, i + 1).set_comm_fusion(2)
            per_block.stage = i * config.stage_num // config.num_layers
            per_block.recompute()
            self.blocks.append(per_block)
            if not dropout_recompute:
                per_block.attention.dropout.dropout_gen_mask.recompute(False).add_prim_attr("_side_effect", True)
                per_block.attention.prob_dropout.dropout_gen_mask.recompute(False).add_prim_attr("_side_effect", True)
                per_block.output.dropout.dropout_gen_mask.recompute(False).add_prim_attr("_side_effect", True)
        if config.self_layernorm:
            self.layernorm = LayerNorm((config.embedding_size,), config.dp).to_float(mstype.float32)
        else:
            self.layernorm = nn.LayerNorm(
                (config.embedding_size,)).to_float(mstype.float32)
            self.layernorm.layer_norm.shard(((config.dp, 1, 1), (1,), (1,)))
        self.layernorm.set_comm_fusion(2)
        #self.layernorm.set_comm_fusion(3)
        self.layernorm.stage = config.stage_num - 1
        self.use_past = config.use_past
        self.past = tuple([None] * config.num_layers)
        self.dtype = config.compute_dtype
        self.num_layers = config.num_layers

    def construct(self, input_ids, input_mask, table, input_position, attention_mask, layer_past=None):
        """PANGUALPHA model"""
        if not self.use_past:
            layer_past = self.past
        hidden_states = self.pangu_alpha_embedding(input_ids, table, input_position)
        attention_mask = self.pangu_alpha_mask(input_mask, attention_mask)
        present_layer = ()
        for i in range(self.num_layers - 1):
            hidden_states, present = self.blocks[i](hidden_states, attention_mask, layer_past)
            present_layer = present_layer + (present,)
        top_query_hidden_states = self.top_query_embedding(input_position)
        hidden_states, present = self.blocks[self.num_layers - 1](hidden_states, top_query_hidden_states,
                                                                  attention_mask, layer_past)
        present_layer = present_layer + (present,)
        output_state = self.layernorm(hidden_states)
        output_state = F.cast(output_state, self.dtype)
        return output_state, present_layer
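
# PANGUALPHA_Model is the non-pipeline backbone: word + position embeddings with
# dropout, followed by (num_layers - 1) Blocks, a final LayerNorm, and the top
# query layer. It also returns the word embedding table so the head can reuse it
# for the output projection (weight tying).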

class PANGUALPHA_Model(nn.Cell):
    """
    The backbone of PANGUALPHA network
    Args:
        config(PANGUALPHAConfig): the config of network
    Inputs:
        input_ids: the tokenized inputs with datatype int32
        input_mask: the mask indicating whether each position is a valid input
        layer_past: the previous feature map
    Returns:
        output_state: Tensor, the output logit of backbone
        present_layer: Tensor, the current feature map
        embedding_table: Tensor, the embedding table for the vocabulary
    """
    def __init__(self, config):
        super(PANGUALPHA_Model, self).__init__()
        self.get_attention_mask = AttentionMask(config)
        self.word_embedding = EmbeddingLookup(config).set_comm_fusion(1)
        self.eod_reset = config.eod_reset
        if config.load_ckpt_path:
            # Load the position embedding table from the checkpoint path.
            embedding_path = os.path.join(config.load_ckpt_path, 'position_embedding.npy')
            if os.path.exists(embedding_path):
                p_table = np.load(embedding_path)
                position_table_param = Tensor(p_table, mstype.float32)
            else:
                raise ValueError(f"{embedding_path} does not exist, please check whether the position_embedding file exists.")
        else:
            position_table_param = TruncatedNormal(0.02)
        self.position_embedding = nn.Embedding(
            config.seq_length,
            config.embedding_size,
            embedding_table=position_table_param).set_comm_fusion(1)
        self.word_embedding.embedding_table.parallel_optimizer = False
        self.position_embedding.embedding_table.parallel_optimizer = False
        self.position_embedding.gather.shard(((1, 1), (config.dp,)))
        self.position_embedding.expand.shard(((config.dp, 1),))
        self.blocks = nn.CellList()
        fusion_group_num = 4
        fusion_group_size = config.num_layers // fusion_group_num
        fusion_group_size = max(fusion_group_size, 1)
        num_layers = config.num_layers - 1
        self.num_layers = num_layers
        for i in range(num_layers):
            per_block = Block(config, i + 1).set_comm_fusion(int(i / fusion_group_size) + 2)
            per_block.recompute()
            per_block.attention.dropout.dropout_gen_mask.recompute(False)
            per_block.attention.prob_dropout.dropout_gen_mask.recompute(False)
            per_block.output.dropout.dropout_gen_mask.recompute(False)
            per_block.attention.dropout.dropout_gen_mask.add_prim_attr("_side_effect", True)
            per_block.attention.prob_dropout.dropout_gen_mask.add_prim_attr("_side_effect", True)
            per_block.output.dropout.dropout_gen_mask.add_prim_attr("_side_effect", True)
            self.blocks.append(per_block)
        if config.self_layernorm:
            self.layernorm = LayerNorm((config.embedding_size,), config.dp).to_float(
                mstype.float32).set_comm_fusion(int((num_layers - 1) / fusion_group_size) + 2)
        else:
            self.layernorm = nn.LayerNorm((config.embedding_size,)).to_float(
                mstype.float32).set_comm_fusion(int((num_layers - 1) / fusion_group_size) + 2)
            self.layernorm.layer_norm.shard(((config.dp, 1, 1), (1,), (1,)))
        self.layernorm.gamma.parallel_optimizer = False
        self.layernorm.beta.parallel_optimizer = False
        self.use_past = config.use_past
        self.past = tuple([None] * config.num_layers)
        self.add = P.TensorAdd().shard(((config.dp, 1, 1), (config.dp, 1, 1)))
        self.expand_dims = P.ExpandDims().shard(((config.dp, 1, 1),))
        self.dtype = config.compute_dtype
        self.dropout = nn.Dropout(1 - config.dropout_rate)
        self.dropout.dropout_gen_mask.shard(((config.dp, 1, 1),))
        self.dropout.dropout_do_mask.shard(((config.dp, 1, 1),))
        if config.load_ckpt_path:
            # Load the top query embedding table from the checkpoint path.
            embedding_path = os.path.join(config.load_ckpt_path, 'top_query_embedding.npy')
            if os.path.exists(embedding_path):
                top_query_table = np.load(embedding_path)
                top_query_table_param = Tensor(top_query_table, mstype.float32)
            else:
                raise ValueError(f"{embedding_path} does not exist, please check whether the top_query_embedding file exists.")
        else:
            top_query_table_param = TruncatedNormal(0.02)
        self.top_query_embedding = nn.Embedding(config.seq_length, config.embedding_size,
                                                embedding_table=top_query_table_param).set_comm_fusion(
                                                    int((config.num_layers - 1) / fusion_group_num) + 2)
        self.top_query_embedding.embedding_table.parallel_optimizer = False
        self.top_query_embedding.gather.shard(((1, 1), (config.dp,)))
        self.top_query_embedding.expand.shard(((config.dp, 1),))
        self.top_query_layer = QueryLayer(config)
        self.top_query_layer.recompute()
        self.top_query_layer.output.dropout.dropout_gen_mask.recompute(False)
        self.top_query_layer.attention.dropout.dropout_gen_mask.recompute(False)
        self.top_query_layer.attention.prob_dropout.dropout_gen_mask.recompute(False)
        self.top_query_layer.output.dropout.dropout_gen_mask.add_prim_attr("_side_effect", True)
        self.top_query_layer.attention.dropout.dropout_gen_mask.add_prim_attr("_side_effect", True)
        self.top_query_layer.attention.prob_dropout.dropout_gen_mask.add_prim_attr("_side_effect", True)
        self.top_query_layer.set_comm_fusion(int((config.num_layers - 1) / fusion_group_num) + 2)

    def construct(self, input_ids, input_mask, input_position=None, attention_mask=None, layer_past=None):
        """PanGu-Alpha model"""
        if not self.use_past:
            layer_past = self.past
        input_embedding, embedding_table = self.word_embedding(input_ids)
        if not self.eod_reset:
            batch_size, seq_length = F.shape(input_ids)
            input_position = F.tuple_to_array(F.make_range(seq_length))
            input_position = P.Tile()(input_position, (batch_size, 1))
            attention_mask = self.get_attention_mask(input_mask)
        position_embedding = self.position_embedding(input_position)
        hidden_states = self.add(input_embedding, position_embedding)
        hidden_states = self.dropout(hidden_states)
        hidden_states = P.Cast()(hidden_states, mstype.float16)
        attention_mask = self.expand_dims(attention_mask, 1)
        present_layer = ()
        for i in range(self.num_layers):
            hidden_states, present = self.blocks[i](hidden_states, attention_mask, layer_past)
            present_layer = present_layer + (present,)
        output_state = self.layernorm(hidden_states)
        output_state = F.cast(output_state, self.dtype)
        top_query_hidden_states = self.top_query_embedding(input_position)
        output_state, present = self.top_query_layer(output_state, top_query_hidden_states,
                                                     attention_mask, layer_past)
        present_layer = present_layer + (present,)
        return output_state, present_layer, embedding_table

class PANGUALPHA_Head(nn.Cell):
    """
    Head for PANGUALPHA to get the logits of each token in the vocab
    Args:
        config(PANGUALPHAConfig): the config of network
    Inputs:
        state: the output of the backbone
        embedding_table: the embedding table of the vocabulary
    Returns:
        logits: Tensor, the logits of the corresponding inputs
    """
    def __init__(self, config):
        super(PANGUALPHA_Head, self).__init__()
        if config.word_emb_dp:
            self.matmul = P.MatMul(transpose_b=True).shard(((config.dp, 1), (1, 1)))
        else:
            self.matmul = P.MatMul(transpose_b=True).shard(((config.dp, 1), (config.mp, 1)))
        self.embedding_size = config.embedding_size
        self.log_softmax = P.LogSoftmax(axis=-1)
        self.dtype = config.compute_dtype
        self.cast = P.Cast()

    def construct(self, state, embedding_table):
        state = P.Reshape()(state, (-1, self.embedding_size))
        # Weight tying: the logits are the product of the hidden states with the
        # (transposed) word embedding table.
        logits = self.matmul(state, self.cast(embedding_table, self.dtype))
        return logits

class PANGUALPHAPipeline(nn.Cell):
    """
    The PANGUALPHA network, consisting of two parts: the backbone and the head
    Args:
        config(PANGUALPHAConfig): the config of network
    Inputs:
        input_ids: the tokenized inputs
        input_mask: the mask indicating whether each position is a valid input
        past: the previous feature map
    Returns:
        logits: Tensor, the logits of the corresponding inputs with shape (batch_size * seq_length, vocab_size)
    """
    def __init__(self, config):
        super(PANGUALPHAPipeline, self).__init__()
        self.backbone = PANGUALPHA_ModelPipeline(config)
        self.head = PANGUALPHA_Head(config)
        self.head.stage = config.stage_num - 1
        self.vocab_size = config.vocab_size
        self.embedding_size = config.embedding_size
        self.embedding_table = Parameter(initializer(
            Normal(0.02), [self.vocab_size, self.embedding_size]),
            name="embedding_table")

    def construct(self, input_ids, input_mask, input_position, attention_mask, past=None):
        output_states, _ = self.backbone(input_ids, input_mask, self.embedding_table,
                                         input_position, attention_mask, past)
        logits = self.head(output_states, self.embedding_table)
        return logits

class PANGUALPHA(nn.Cell):
    """
    The PANGUALPHA network, consisting of two parts: the backbone and the head
    Args:
        config(PANGUALPHAConfig): the config of network
    Inputs:
        input_ids: the tokenized inputs
        input_mask: the mask indicating whether each position is a valid input
        past: the previous feature map
    Returns:
        logits: Tensor, the logits of the corresponding inputs with shape (batch_size * seq_length, vocab_size)
    """
    def __init__(self, config):
        super(PANGUALPHA, self).__init__()
        self.backbone = PANGUALPHA_Model(config)
        self.head = PANGUALPHA_Head(config)

    def construct(self, input_ids, input_mask, input_position=None, attention_mask=None, past=None):
        output_states, _, embedding_table = self.backbone(
            input_ids, input_mask, input_position, attention_mask, past)
        logits = self.head(output_states, embedding_table)
        return logits
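
# The loss below is a numerically stable masked cross entropy:
#   log_softmax(x) = log(exp(x - max(x)) / sum(exp(x - max(x))) + eps)
#   loss = sum(-log_softmax(logits) * onehot(label) * input_mask) / (sum(input_mask) + 1e-5)
# i.e. the mean negative log-likelihood over valid (non-padding) positions.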

class CrossEntropyLoss(nn.Cell):
    """
    Calculate the cross entropy loss
    Args:
        config(PANGUALPHAConfig): the config of the network
    Inputs:
        logits: the output logits of the backbone
        label: the ground truth label of the sample
        input_mask: the mask indicating whether each position is a valid input
    Returns:
        loss: Tensor, the corresponding cross entropy loss
    """
    def __init__(self, config):
        super(CrossEntropyLoss, self).__init__()
        self.mean = P.ReduceMean()
        self.sum = P.ReduceSum().shard(((config.dp, config.mp),))
        self.onehot = P.OneHot().shard(((config.dp, config.mp), (), ()))
        self.on_value = Tensor(1.0, mstype.float32)
        self.off_value = Tensor(0.0, mstype.float32)
        self.vocab_size = config.vocab_size
        self.max = P.ArgMaxWithValue(axis=-1, keep_dims=True).shard(
            ((config.dp, config.mp),))
        self.eps_const = Tensor(1e-24, mstype.float32)
        self.sub = P.Sub().shard(((config.dp, config.mp), (config.dp, 1)))
        self.exp = P.Exp().shard(((config.dp, config.mp),))
        self.div = P.RealDiv().shard(((config.dp, config.mp), (config.dp, 1)))
        self.log = P.Log().shard(((config.dp, config.mp),))
        self.add = P.TensorAdd().shard(((config.dp, config.mp), ()))
        self.mul = P.Mul().shard(
            ((config.dp, config.mp), (config.dp, config.mp)))
        self.neg = P.Neg().shard(((config.dp, config.mp),))
        self.sum2 = P.ReduceSum().shard(((1,),))
        self.mul2 = P.Mul().shard(((1,), (1,)))
        self.add2 = P.TensorAdd()
        self.div2 = P.RealDiv()
        if config.stage_num > 1:
            self.div2.add_prim_attr("end", 0)

    def construct(self, logits, label, input_mask):
        logits = F.cast(logits, mstype.float32)
        # Numerically stable softmax: subtract the row-wise max before exponentiating.
        _, logit_max = self.max(logits)
        logit_sub = self.sub(logits, logit_max)
        logit_exp = self.exp(logit_sub)
        exp_sum = self.sum(logit_exp, -1)
        exp_sum = P.Reshape()(exp_sum, (F.shape(exp_sum)[0], 1))
        softmax_result = self.div(logit_exp, exp_sum)
        log_softmax_result = self.log(self.add(softmax_result, self.eps_const))
        label = P.Reshape()(label, (-1,))
        one_hot_label = self.onehot(label, self.vocab_size, self.on_value,
                                    self.off_value)
        loss = self.mul(log_softmax_result, one_hot_label)
        loss_unsum = self.neg(loss)
        loss_reduce = self.sum(loss_unsum, -1)
        # Average the per-token loss over valid (non-padding) positions only.
        input_mask = P.Reshape()(input_mask, (-1,))
        numerator = self.sum2(self.mul2(loss_reduce, input_mask))
        denominator = self.add2(
            self.sum2(input_mask),
            P.Cast()(F.tuple_to_array((1e-5,)), mstype.float32))
        loss = self.div2(numerator, denominator)
        return loss
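
# MicroBatch slices the global batch along dimension 0 into config.micro_size
# equal micro-batches (tokens, positions, and attention masks) so that the
# pipeline-parallel loss wrapper below can feed one micro-batch at a time.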

class MicroBatch(nn.Cell):
    def __init__(self, config):
        super().__init__()
        self.micro_slice = P.StridedSlice().shard(((1, 1),))
        self.micro_attention_slice = P.StridedSlice().shard(((1, 1),))
        self.shape = P.Shape()
        self.stage_num = config.micro_size
        self.seq_len = config.seq_length
        self.slice_mask = P.StridedSlice().shard(((1, 1, 1),))

    def construct(self, x, i, input_position, attention_mask):
        input_shape = self.shape(x)
        micro_batch_begin = (i * input_shape[0] // self.stage_num, 0)
        micro_batch_end = ((i + 1) * input_shape[0] // self.stage_num, input_shape[1])
        micro_batch_stride = (1, 1)
        micro_input = self.micro_slice(x, micro_batch_begin, micro_batch_end, micro_batch_stride)
        micro_input_position_begin = (i * input_shape[0] // self.stage_num, 0)
        micro_input_position_end = ((i + 1) * input_shape[0] // self.stage_num, self.seq_len)
        micro_input_position = self.micro_attention_slice(input_position, micro_input_position_begin,
                                                          micro_input_position_end, micro_batch_stride)
        micro_attention_mask_begin = (i * input_shape[0] // self.stage_num, 0, 0)
        micro_attention_mask_end = ((i + 1) * input_shape[0] // self.stage_num, self.seq_len, self.seq_len)
        micro_attention_mask_stride = (1, 1, 1)
        micro_attention_mask = self.slice_mask(attention_mask, micro_attention_mask_begin,
                                               micro_attention_mask_end, micro_attention_mask_stride)
        return micro_input, micro_input_position, micro_attention_mask

class PANGUALPHAWithLossPipeline(nn.Cell):
    """
    PANGUALPHA training loss (pipeline-parallel version)
    Args:
        network: backbone network of PANGUALPHA
        loss: loss function, e.g., cross entropy
        eos_token: the end_of_sentence token
    Inputs:
        input_ids: the tokenized inputs
        input_position: the position indices of the inputs
        attention_mask: the attention mask matrix
    Returns:
        output: Tensor, the loss of the network
    """
    def __init__(self, config, network, loss, eos_token=6):
        super(PANGUALPHAWithLossPipeline, self).__init__(auto_prefix=False)
        self.network = network
        self.loss = loss
        self.eos_token = eos_token
        self.slice = P.StridedSlice().shard(((config.dp, 1),))
        self.not_equal = P.NotEqual().shard(((config.dp, 1), ()))
        self.batch_size = config.batch_size
        self.len = config.seq_length
        self.micro_batch_step = config.micro_size
        self.micro_input = nn.CellList()
        self.slice_mask = P.StridedSlice().shard(((config.dp, 1, 1),))
        for i in range(self.micro_batch_step):
            micro = MicroBatch(config)
            micro.micro_slice.add_prim_attr("micro", i)
            micro.micro_slice.add_prim_attr("start", i)
            self.micro_input.append(micro)

    def construct(self, input_ids, input_position, attention_mask):
        ret = None
        # Accumulate the loss over the micro-batches.
        for i in range(self.micro_batch_step):
            micro_input, micro_input_position, micro_attention_mask = self.micro_input[i](
                input_ids, i, input_position, attention_mask)
            # tokens are input_ids[:, :-1]; labels are input_ids[:, 1:] (shifted by one).
            tokens = self.slice(micro_input, (0, 0),
                                (self.batch_size // self.micro_batch_step, -1), (1, 1))
            input_mask = F.cast(self.not_equal(tokens, self.eos_token), mstype.float32)
            logits = self.network(tokens, input_mask, micro_input_position, micro_attention_mask)
            labels = self.slice(micro_input, (0, 1),
                                (self.batch_size // self.micro_batch_step, self.len + 1), (1, 1))
            output = self.loss(logits, labels, input_mask)
            if ret is not None:
                ret = ret + output
            else:
                ret = output
        return ret

class PANGUALPHAWithLoss(nn.Cell):
    """
    PANGUALPHA training loss
    Args:
        network: backbone network of PANGUALPHA
        loss: loss function, e.g., cross entropy
        eos_token: the end_of_sentence token
    Inputs:
        input_ids: the tokenized inputs
        input_position: the position indices of the inputs
        attention_mask: the attention mask matrix
    Returns:
        output: Tensor, the loss of the network
    """
    def __init__(self, config, network, loss, eos_token=6):
        super(PANGUALPHAWithLoss, self).__init__(auto_prefix=False)
        self.network = network
        self.loss = loss
        self.eos_token = eos_token
        self.slice = P.StridedSlice().shard(((config.dp, 1),))
        self.not_equal = P.NotEqual().shard(((config.dp, 1), ()))
        self.batch_size = config.batch_size
        self.len = config.seq_length
        self.slice_mask = P.StridedSlice().shard(((config.dp, 1, 1),))

    def construct(self, input_ids, input_position=None, attention_mask=None):
        # tokens are input_ids[:, :-1]; labels are input_ids[:, 1:] (shifted by one).
        tokens = self.slice(input_ids, (0, 0), (self.batch_size, -1), (1, 1))
        input_position = self.slice(input_position, (0, 0), (self.batch_size, self.len), (1, 1))
        attention_mask = self.slice_mask(attention_mask, (0, 0, 0),
                                         (self.batch_size, self.len, self.len),
                                         (1, 1, 1))
        input_mask = F.cast(self.not_equal(tokens, self.eos_token), mstype.float32)
        logits = self.network(tokens, input_mask, input_position, attention_mask)
        labels = self.slice(input_ids, (0, 1), (self.batch_size, self.len + 1), (1, 1))
        output = self.loss(logits, labels, input_mask)
        return output

class EvalNet(nn.Cell):
    """
    PANGUALPHA evaluation net
    Args:
        backbone: backbone network of PANGUALPHA
        generate: enable generate mode
    Inputs:
        input_ids: the tokenized inputs
    Returns:
        outputs: Tensor, corresponding output for different tasks
    """
    def __init__(self, backbone, generate=False):
        super(EvalNet, self).__init__(auto_prefix=False)
        self.backbone = backbone
        self.argmax = P.ArgMaxWithValue()
        self.generate = generate
        self.topk = P.TopK(sorted=True).shard(((1, 1),))
        self.softmax = P.Softmax(axis=-1)

    def construct(self, input_ids):
        """evaluation net"""
        input_mask = F.cast(F.not_equal(input_ids, 6), mstype.float32)
        logits = self.backbone(input_ids, input_mask)
        # Return the softmax probabilities and indices of the top-5 candidate tokens.
        value, index = self.topk(logits, 5)
        probs = self.softmax(value)
        return probs, index
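
# A minimal usage sketch (illustrative only; it assumes a PANGUALPHAConfig object
# defined elsewhere in this repository with the fields referenced above, such as
# batch_size, seq_length, vocab_size, embedding_size, num_layers, num_heads,
# dp and mp):
#
#   config = PANGUALPHAConfig(...)                 # hypothetical construction
#   pangu_alpha = PANGUALPHA(config)               # backbone + head
#   loss = CrossEntropyLoss(config)
#   train_net = PANGUALPHAWithLoss(config, pangu_alpha, loss)
#   # input_ids has shape (batch_size, seq_length + 1): tokens come from the
#   # first seq_length columns and labels from the last seq_length columns.
#   # loss_value = train_net(input_ids, input_position, attention_mask)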