# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
  15. """Transformer evaluation script."""
import os
import argparse

import mindspore.common.dtype as mstype
from mindspore import log as logger
from mindspore.common.tensor import Tensor
from mindspore.train.model import Model
from mindspore.train.serialization import load_checkpoint, load_param_into_net
from mindspore import context

from src.dataset import create_gru_dataset
from src.seq2seq import Seq2Seq
from src.gru_for_infer import GRUInferCell
from src.config import config


def run_gru_eval():
    """GRU evaluation."""
    parser = argparse.ArgumentParser(description='GRU eval')
    parser.add_argument("--device_target", type=str, default="Ascend",
                        help="device where the code will be implemented, default is Ascend")
    parser.add_argument('--device_id', type=int, default=0, help='device id of GPU or Ascend, default is 0')
    parser.add_argument('--device_num', type=int, default=1, help='number of devices to use, default is 1')
    parser.add_argument('--ckpt_file', type=str, default="", help='checkpoint file path')
    parser.add_argument("--dataset_path", type=str, default="",
                        help="path of the evaluation MindRecord dataset file")
    args = parser.parse_args()

    context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target,
                        reserve_class_name_in_scope=False, device_id=args.device_id, save_graphs=False)
    # GPU currently only supports fp32 for this model; fall back if another compute type is configured.
    if args.device_target == "GPU":
        if config.compute_type != mstype.float32:
            logger.warning('GPU only supports fp32 for now, running with fp32.')
            config.compute_type = mstype.float32

    # The evaluation dataset is a single MindRecord file.
    mindrecord_file = args.dataset_path
    if not os.path.exists(mindrecord_file):
        print("dataset file {} does not exist, please check!".format(mindrecord_file))
        raise ValueError(mindrecord_file)
    dataset = create_gru_dataset(epoch_count=config.num_epochs, batch_size=config.eval_batch_size,
                                 dataset_path=mindrecord_file, rank_size=args.device_num, rank_id=0,
                                 do_shuffle=False, is_training=False)
    dataset_size = dataset.get_dataset_size()
    print("dataset size is {}".format(dataset_size))

    # Build the seq2seq network, wrap it in the inference cell and load the checkpoint.
    network = Seq2Seq(config, is_training=False)
    network = GRUInferCell(network)
    network.set_train(False)
    if args.ckpt_file != "":
        parameter_dict = load_checkpoint(args.ckpt_file)
        load_param_into_net(network, parameter_dict)
    model = Model(network)

    # Run prediction batch by batch and collect the predicted and reference token ids.
    predictions = []
    source_sents = []
    target_sents = []
    eval_text_len = 0
    for batch in dataset.create_dict_iterator(output_numpy=True, num_epochs=1):
        source_sents.append(batch["source_ids"])
        target_sents.append(batch["target_ids"])
        source_ids = Tensor(batch["source_ids"], mstype.int32)
        target_ids = Tensor(batch["target_ids"], mstype.int32)
        predicted_ids = model.predict(source_ids, target_ids)
        print("predicts is ", predicted_ids.asnumpy())
        print("target_ids is ", target_ids)
        predictions.append(predicted_ids.asnumpy())
        eval_text_len = eval_text_len + 1

    # Write the predicted and reference token ids, one sentence per line.
    with open(config.output_file, 'w') as f_output, open(config.target_file, 'w') as f_target:
        for batch_out, true_sentence in zip(predictions, target_sents):
            for i in range(config.eval_batch_size):
                target_ids = [str(x) for x in true_sentence[i].tolist()]
                f_target.write(" ".join(target_ids) + "\n")
                token_ids = [str(x) for x in batch_out[i].tolist()]
                f_output.write(" ".join(token_ids) + "\n")


if __name__ == "__main__":
    run_gru_eval()
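A hypothetical invocation of the script, assuming it is saved as eval.py and that src/ and its config are on the Python path; the checkpoint and MindRecord paths below are placeholders, not taken from the repository:

python eval.py --device_target=Ascend --device_id=0 --device_num=1 --ckpt_file=/path/to/gru.ckpt --dataset_path=/path/to/eval.mindrecord

The predicted and reference token ids are then written to config.output_file and config.target_file, one space-separated sentence per line, for downstream scoring.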