You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

85 lines
1.8 kB

  1. import numpy as np
  2. import tensorflow as tf
  3. def complex2real(x):
  4. '''
  5. Parameter
  6. ---------
  7. x: ndarray
  8. assumes at least 2d. Last 2D axes are split in terms of real and imag
  9. 2d/3d/4d complex valued tensor (n, nx, ny) or (n, nx, ny, nt)
  10. Returns
  11. -------
  12. y: 4d tensor (n, 2, nx, ny)
  13. '''
  14. x_real = np.real(x)
  15. x_imag = np.imag(x)
  16. y = np.array([x_real, x_imag]).astype(np.float32)
  17. # re-order in convenient order
  18. """
  19. if x.ndim >= 3:
  20. y = y.swapaxes(0, 1)
  21. """
  22. y = np.transpose(y, [1, 2, 3, 4, 0])
  23. return y
  24. def real2complex(x):
  25. '''
  26. Converts from array of the form ([n, ]2, nx, ny[, nt]) to ([n, ]nx, ny[, nt])
  27. '''
  28. #x = np.asarray(x)
  29. if x.shape[0] == 2 and x.shape[1] != 2: # Hacky check
  30. return x[0] + x[1] * 1j
  31. elif x.shape[1] == 2:
  32. y = x[:, 0] + x[:, 1] * 1j
  33. return y
  34. else:
  35. raise ValueError('Invalid dimension')
  36. def mask_c2r(m):
  37. return complex2real(m * (1+1j))
  38. def mask_r2c(m):
  39. return m[0] if m.ndim == 3 else m[:, 0]
  40. def to_lasagne_format(x, mask=False):
  41. """
  42. Assumes data is of shape (n[, nt], nx, ny).
  43. Reshapes to (n, n_channels, nx, ny[, nt])
  44. Note: Depth must be the last axis, the dimensions will be reordered
  45. """
  46. """
  47. if x.ndim == 4: # n 3D inputs. reorder axes
  48. x = np.transpose(x, (0, 2, 3, 1))
  49. """
  50. if mask: # Hacky solution
  51. x = x*(1+1j)
  52. x = complex2real(x)
  53. return x
  54. def from_lasagne_format(x, mask=False):
  55. """
  56. Assumes data is of shape (n, 2, nx, ny[, nt]).
  57. Reshapes to (n, [nt, ]nx, ny)
  58. """
  59. if x.ndim == 5: # n 3D inputs. reorder axes
  60. x = np.transpose(x, (0, 1, 4, 2, 3))
  61. if mask:
  62. x = mask_r2c(x)
  63. else:
  64. x = real2complex(x)
  65. return x

简介

No Description

Python

贡献者 (1)