You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
 
 
 
 

47 lines
2.2 KiB

  1. # Copyright 2021 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ============================================================================
  15. """GRU preprocess script."""
  16. import os
  17. import argparse
  18. from src.dataset import create_gru_dataset
  19. from src.config import config
  20. parser = argparse.ArgumentParser(description='GRU preprocess')
  21. parser.add_argument("--dataset_path", type=str, default="",
  22. help="Dataset path, default: f`sns.")
  23. parser.add_argument('--device_num', type=int, default=1, help='Use device nums, default is 1')
  24. parser.add_argument('--result_path', type=str, default='./preprocess_Result/', help='result path')
  25. args = parser.parse_args()
  26. if __name__ == "__main__":
  27. mindrecord_file = args.dataset_path
  28. if not os.path.exists(mindrecord_file):
  29. print("dataset file {} not exists, please check!".format(mindrecord_file))
  30. raise ValueError(mindrecord_file)
  31. dataset = create_gru_dataset(epoch_count=config.num_epochs, batch_size=config.eval_batch_size, \
  32. dataset_path=mindrecord_file, rank_size=args.device_num, rank_id=0, do_shuffle=False, is_training=False)
  33. source_ids_path = os.path.join(args.result_path, "00_data")
  34. target_ids_path = os.path.join(args.result_path, "01_data")
  35. os.makedirs(source_ids_path)
  36. os.makedirs(target_ids_path)
  37. for i, data in enumerate(dataset.create_dict_iterator(output_numpy=True, num_epochs=1)):
  38. file_name = "gru_bs" + str(config.eval_batch_size) + "_" + str(i) + ".bin"
  39. data["source_ids"].tofile(os.path.join(source_ids_path, file_name))
  40. data["target_ids"].tofile(os.path.join(target_ids_path, file_name))
  41. print("="*20, "export bin files finished", "="*20)