「鹏程·盘古α」: a 200-billion-parameter open-source Chinese pretrained language model
  1. """
  2. Create dataset for training and evaluting
  3. """
  4. import os
  5. import mindspore.dataset as ds
  6. import mindspore.dataset.transforms.c_transforms as C
  7. import mindspore.common.dtype as mstype
  8. import numpy as np
  9. def get_input_data(input_ids, eod_id):
  10. """
  11. Generate position_id and attention_mask according to input_ids considering eod reset
  12. Inputs:
  13. input_ids: the input token ids
  14. eod_token: the id for <EOD>
  15. returns:
  16. input_ids: the input token ids
  17. position_id: the position ids cosidering eod reset
  18. attention_mask: the attention mask considering eod reset
  19. """
  20. seq_length = input_ids.shape[0] - 1
  21. attention_mask = np.tril(np.ones(shape=(seq_length, seq_length)))
  22. position_id = np.arange(seq_length)
  23. eod_index = position_id[input_ids[:-1] == eod_id]
  24. prev_index = 0
  25. for i in range(eod_index.size):
  26. index = eod_index[i]
  27. attention_mask[(index+1):, :(index+1)] = 0
  28. position_id[(index+1):] -= (index + 1 - prev_index)
  29. prev_index = index + 1
  30. return input_ids, position_id, attention_mask
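

# Hedged usage sketch, not part of the original file: the toy ids below and
# the expected outputs in the comments are illustrative assumptions, with
# eod_id=9 as used elsewhere in this file.
def _demo_get_input_data():
    """Show how an <EOD> at position 2 resets positions and blocks attention."""
    ids = np.array([11, 12, 9, 13, 14, 15])
    _, position_id, attention_mask = get_input_data(ids, eod_id=9)
    # position_id -> [0, 1, 2, 0, 1]: positions restart after the <EOD>
    # attention_mask[3:, :3] -> all zeros: tokens after the <EOD> cannot
    # attend to tokens before it
    print(position_id)
    print(attention_mask)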


def get_input_data_from_batch(input_ids, eod_id, rank, dis):
    """
    Generate position_id and attention_mask according to input_ids, considering eod reset
    Inputs:
        input_ids: the batched input token ids
        eod_id: the id for <EOD>
        rank: current rank id
        dis: number of samples this rank takes from the global batch
    Returns:
        batch_input_ids: the input token ids for this rank
        batch_position_ids: the position ids considering eod reset
        batch_attention_mask: the attention mask considering eod reset
    """
    rank = int(rank)
    # Each rank keeps only its own slice of the global batch
    input_ids = input_ids[rank * dis: (rank + 1) * dis]
    seq_length = 1024  # input_ids.shape[1] - 1
    batch_input_ids = input_ids
    batch_position_ids = np.ones((dis, seq_length))
    batch_attention_mask = np.ones((dis, seq_length, seq_length))
    for bs_i in range(0, len(input_ids)):
        local_ids = input_ids[bs_i]
        batch_attention_mask[bs_i] = np.tril(np.ones(shape=(seq_length, seq_length)))
        batch_position_ids[bs_i] = np.arange(seq_length)
        eod_index = batch_position_ids[bs_i, local_ids[:-1] == eod_id].astype(np.int32)
        prev_index = 0
        for i in range(eod_index.size):
            index = eod_index[i]
            batch_attention_mask[bs_i, (index + 1):, :(index + 1)] = 0
            batch_position_ids[bs_i, (index + 1):] -= (index + 1 - prev_index)
            prev_index = index + 1
    return batch_input_ids, batch_position_ids, batch_attention_mask
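

# Hedged sketch, illustrative only: with a global batch of 2 samples of
# length 1025 (seq_length is fixed at 1024 above) and dis=1 sample per rank,
# rank 0 keeps only row 0 of the batch.
def _demo_get_input_data_from_batch():
    batch = np.random.randint(10, 100, size=(2, 1025))
    ids, pos, mask = get_input_data_from_batch(batch, eod_id=9, rank=0, dis=1)
    # Shapes: (1, 1025), (1, 1024), (1, 1024, 1024)
    print(ids.shape, pos.shape, mask.shape)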


def create_dataset(batch_size, data_path, device_num=1, rank=0, drop=True,
                   data_start_index=0, eod_reset=True, eod_id=9):
    """
    Create dataset
    Inputs:
        batch_size: batch size
        data_path: path of your MindRecord files
        device_num: total device number
        rank: current rank id
        drop: whether to drop the remainder batch
        data_start_index: index of the first MindRecord file to read
        eod_reset: whether to enable position reset and attention mask reset
        eod_id: the id for <EOD>
    Returns:
        dataset: the dataset for training or evaluating
    """
    ds.config.set_seed(1)
    home_path = os.path.join(os.getcwd(), data_path)
    files = os.listdir(data_path)
    # Keep only the MindRecord data files, skipping the ".db" index files
    data = [
        os.path.join(home_path, name) for name in files
        if not name.endswith(".db")
    ]
    # Sort files by the numeric suffix that follows "mindrecord" in the name
    data.sort(key=lambda x: int(x[x.find("mindrecord") + 10:]))
    print(data)
    dataset = ds.MindDataset(data[data_start_index:], columns_list=["input_ids"], shuffle=False)
    type_cast_op = C.TypeCast(mstype.int32)
    type_cast_op_float = C.TypeCast(mstype.float16)
    if eod_reset:
        map_func = (lambda input_ids: get_input_data(input_ids, eod_id))
        dataset = dataset.map(operations=map_func, input_columns=["input_ids"],
                              output_columns=["input_ids", "position_id", "attention_mask"],
                              column_order=["input_ids", "position_id", "attention_mask"])
        dataset = dataset.map(input_columns="position_id", operations=type_cast_op)
        dataset = dataset.map(input_columns="attention_mask", operations=type_cast_op_float)
    dataset = dataset.map(input_columns="input_ids", operations=type_cast_op)
    dataset = dataset.batch(batch_size, drop_remainder=drop)
    dataset = dataset.repeat(1)
    return dataset
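

# Hedged usage sketch, not from the original file: "/path/to/mindrecord" is a
# placeholder for a directory of MindRecord files whose names end in
# "mindrecord0", "mindrecord1", ..., as the sort key above expects.
def _demo_create_dataset():
    dataset = create_dataset(batch_size=8, data_path="/path/to/mindrecord")
    for item in dataset.create_dict_iterator(num_epochs=1):
        # With eod_reset=True each item carries "input_ids", "position_id"
        # and "attention_mask"
        print(item["input_ids"].shape)
        break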


def create_dataset_dp(batch_size, data_path, device_num=1, rank=0, drop=True,
                      data_start_index=0, eod_id=9):
    """
    Create dataset using data parallel.
    Inputs:
        batch_size: batch size
        data_path: path of your MindRecord files
        device_num: total device number
        rank: current rank id
        drop: whether to drop the remainder batch
        data_start_index: index of the first MindRecord file to read
        eod_id: the id for <EOD>
    Returns:
        dataset: the dataset for training or evaluating
    """
    ds.config.set_seed(1)
    home_path = os.path.join(os.getcwd(), data_path)
    files = os.listdir(data_path)
    # Per-rank share of the global batch
    dis = int(batch_size / device_num)
    if dis < 1:
        raise ValueError("Batch size / device_num should be positive, but found {}".format(dis))
    data = [
        os.path.join(home_path, name) for name in files
        if not name.endswith(".db")
    ]
    data.sort(key=lambda x: int(x[x.find("mindrecord") + 10:]))
    print(data)
    if data_start_index >= len(data):
        raise ValueError(f"data start index {data_start_index} is larger than dataset length {len(data)}")
    dataset = ds.MindDataset(data[data_start_index:], columns_list=["input_ids"], shuffle=False)
    type_cast_op = C.TypeCast(mstype.int32)
    type_cast_op_float = C.TypeCast(mstype.float16)
    map_func = (lambda input_ids: get_input_data_from_batch(input_ids, eod_id, rank, dis))
    # Batch first, then let map_func slice each global batch per rank
    dataset = dataset.batch(batch_size, drop_remainder=drop)
    dataset = dataset.map(operations=map_func, input_columns=["input_ids"],
                          output_columns=["input_ids", "position_id", "attention_mask"],
                          column_order=["input_ids", "position_id", "attention_mask"])
    dataset = dataset.map(input_columns="position_id", operations=type_cast_op)
    dataset = dataset.map(input_columns="attention_mask", operations=type_cast_op_float)
    dataset = dataset.map(input_columns="input_ids", operations=type_cast_op)
    dataset = dataset.repeat(1)
    return dataset
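

# Hedged sketch, illustrative only: in a real data-parallel run, device_num
# and rank would come from the distributed launcher (e.g. MindSpore's
# get_group_size()/get_rank()); the literals here are placeholders.
def _demo_create_dataset_dp():
    dataset = create_dataset_dp(batch_size=16, data_path="/path/to/mindrecord",
                                device_num=8, rank=0)
    # After the per-rank slice in map_func, each batch yields
    # batch_size / device_num = 2 samples on this rank
    for item in dataset.create_dict_iterator(num_epochs=1):
        print(item["input_ids"].shape)
        break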