You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

62 lines
1.9 kB

  1. from os.path import join
  2. import numpy as np
  3. import time
  4. import h5py
  5. import tensorflow as tf
  6. def Train_data():
  7. """
  8. Creates dataset for training,validating,testing
  9. :return:
  10. train_data : training dataset
  11. validate_data : validation dataset
  12. test_data : test dataset
  13. """
  14. print ("loading train data ...")
  15. time_start = time.time()
  16. data_root = '/media/keziwen/86AA9651AA963E1D'
  17. with h5py.File(join(data_root, './data/train_real2.h5')) as f:
  18. data_real = f['train_real'][:]
  19. num, nt, ny, nx = data_real.shape
  20. data_real = np.transpose(data_real, (0, 1, 3, 2))
  21. with h5py.File(join(data_root, './data/train_imag2.h5')) as f:
  22. data_imag = f['train_imag'][:]
  23. num, nt, ny, nx = data_imag.shape
  24. data_imag = np.transpose(data_imag, (0, 1, 3, 2))
  25. data = data_real+1j*data_imag
  26. num_train = 15000
  27. num_validate = 2000
  28. train_data = data[0:num_train]
  29. validate_data = data[num_train:num_train+num_validate]
  30. train_data = np.random.permutation(train_data)
  31. time_end = time.time()
  32. print ('dataset has been created using {}s'.format(time_end-time_start))
  33. return train_data, validate_data
  34. def Test_data():
  35. """
  36. Creates dataset for training,validating,testing
  37. :return:
  38. test_data : test dataset
  39. """
  40. print ("loading test data ...")
  41. time_start = time.time()
  42. data_root = '/media/keziwen/86AA9651AA963E1D'
  43. with h5py.File(join(data_root, './data/test_real2.h5')) as f:
  44. test_real = f['test_real'][:]
  45. with h5py.File(join(data_root, './data/test_imag2.h5')) as f:
  46. test_imag = f['test_imag'][:]
  47. test_real = np.transpose(test_real, (0, 1, 3, 2))
  48. test_imag = np.transpose(test_imag, (0, 1, 3, 2))
  49. test_data = test_real+1j*test_imag
  50. time_end = time.time()
  51. print ('dataset has been created using {}s'.format(time_end - time_start))
  52. return test_data

简介

No Description

Python

贡献者 (1)