You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

225 lines
5.5 kB

  1. import numpy as np
  2. import mymath
  3. from numpy.lib.stride_tricks import as_strided
  4. def soft_thresh(u, lmda):
  5. """Soft-threshing operator for complex valued input"""
  6. Su = (abs(u) - lmda) / abs(u) * u
  7. Su[abs(u) < lmda] = 0
  8. return Su
  9. def normal_pdf(length, sensitivity):
  10. return np.exp(-sensitivity * (np.arange(length) - length / 2)**2)
  11. def var_dens_mask(shape, ivar, sample_high_freq=True):
  12. """Variable Density Mask (2D undersampling)"""
  13. if len(shape) == 4:
  14. Num, Nt, Nx, Ny = shape
  15. else:
  16. Nx, Ny = shape
  17. Nt = 1
  18. pdf_x = normal_pdf(Nx, ivar)
  19. pdf_y = normal_pdf(Ny, ivar)
  20. pdf = np.outer(pdf_x, pdf_y)
  21. size = pdf.itemsize
  22. strided_pdf = as_strided(pdf, (Nt, Nx, Ny), (0, Ny * size, size))
  23. # this must be false if undersampling rate is very low (around 90%~ish)
  24. if sample_high_freq:
  25. strided_pdf = strided_pdf / 1.25 + 0.02
  26. mask = np.random.binomial(1, strided_pdf)
  27. xc = Nx / 2
  28. yc = Ny / 2
  29. mask[:, xc - 4:xc + 5, yc - 4:yc + 5] = True
  30. if Nt == 1:
  31. return mask.reshape((Nx, Ny))
  32. mask_4D = mask[np.newaxis, :, :, :]
  33. mask_4D = np.tile(mask_4D, (Num, 1, 1, 1))
  34. return mask_4D
  35. def cartesian_mask(shape, ivar, centred=False,
  36. sample_high_freq=True, sample_centre=True, sample_n=4):
  37. """
  38. undersamples along Nx
  39. :param shape: tuple - [nt, nx, ny]
  40. :param ivar: sensitivity parameter for Gaussian distribution
  41. """
  42. if len(shape) == 4:
  43. Num, Nt, Nx, Ny = shape
  44. else:
  45. Nx, Ny = shape
  46. Nt = 1
  47. pdf_x = normal_pdf(Nx, ivar)
  48. if sample_high_freq:
  49. pdf_x = pdf_x / 1.25 + 0.02
  50. size = pdf_x.itemsize
  51. stride_pdf = as_strided(pdf_x, (Nt, Nx, 1), (0, size, 0))
  52. mask = np.random.binomial(1, stride_pdf)
  53. size = mask.itemsize
  54. mask = as_strided(mask, (Nt, Nx, Ny), (size * Nx, size, 0))
  55. if sample_centre:
  56. s = sample_n / 2
  57. xc = Nx / 2
  58. mask[:, xc - s: xc + s, :] = True
  59. if not centred:
  60. mask = mymath.ifftshift(mask, axes=(-1, -2))
  61. mask_4D = mask[np.newaxis, :, :, :]
  62. mask_4D = np.tile(mask_4D, (Num, 1, 1, 1))
  63. return mask_4D
  64. def shear_grid_mask(shape, acceleration_rate, sample_low_freq=True,
  65. centred=False, sample_n=10):
  66. '''
  67. Creates undersampling mask which samples in sheer grid
  68. Parameters
  69. ----------
  70. shape: (nt, nx, ny)
  71. acceleration_rate: int
  72. Returns
  73. -------
  74. array
  75. '''
  76. Nt, Nx, Ny = shape
  77. start = np.random.randint(0, acceleration_rate)
  78. mask = np.zeros((Nt, Nx))
  79. for t in xrange(Nt):
  80. mask[t, (start+t)%acceleration_rate::acceleration_rate] = 1
  81. xc = Nx / 2
  82. xl = sample_n / 2
  83. if sample_low_freq and centred:
  84. xh = xl
  85. if sample_n % 2 == 0:
  86. xh += 1
  87. mask[:, xc - xl:xc + xh+1] = 1
  88. elif sample_low_freq:
  89. xh = xl
  90. if sample_n % 2 == 1:
  91. xh -= 1
  92. if xl > 0:
  93. mask[:, :xl] = 1
  94. if xh > 0:
  95. mask[:, -xh:] = 1
  96. mask_rep = np.repeat(mask[..., np.newaxis], Ny, axis=-1)
  97. return mask_rep
  98. def perturbed_shear_grid_mask(shape, acceleration_rate, sample_low_freq=True,
  99. centred=False,
  100. sample_n=10):
  101. Nt, Nx, Ny = shape
  102. start = np.random.randint(0, acceleration_rate)
  103. mask = np.zeros((Nt, Nx))
  104. for t in xrange(Nt):
  105. mask[t, (start+t)%acceleration_rate::acceleration_rate] = 1
  106. # brute force
  107. rand_code = np.random.randint(0, 3, size=Nt*Nx)
  108. shift = np.array([-1, 0, 1])[rand_code]
  109. new_mask = np.zeros_like(mask)
  110. for t in xrange(Nt):
  111. for x in xrange(Nx):
  112. if mask[t, x]:
  113. new_mask[t, (x + shift[t*x])%Nx] = 1
  114. xc = Nx / 2
  115. xl = sample_n / 2
  116. if sample_low_freq and centred:
  117. xh = xl
  118. if sample_n % 2 == 0:
  119. xh += 1
  120. new_mask[:, xc - xl:xc + xh+1] = 1
  121. elif sample_low_freq:
  122. xh = xl
  123. if sample_n % 2 == 1:
  124. xh -= 1
  125. new_mask[:, :xl] = 1
  126. new_mask[:, -xh:] = 1
  127. mask_rep = np.repeat(new_mask[..., np.newaxis], Ny, axis=-1)
  128. return mask_rep
  129. def undersample(x, mask, centred=False, norm='ortho'):
  130. '''
  131. Undersample x. FFT2 will be applied to the last 2 axis
  132. Parameters
  133. ----------
  134. x: array_like
  135. data
  136. mask: array_like
  137. undersampling mask in fourier domain
  138. Returns
  139. -------
  140. xu: array_like
  141. undersampled image in image domain. Note that it is complex valued
  142. x_fu: array_like
  143. undersampled data in kspace
  144. '''
  145. assert x.shape == mask.shape
  146. if centred:
  147. x_f = mymath.fft2c(x, norm=norm)
  148. x_fu = x_f * mask
  149. x_u = mymath.ifft2c(x_fu, norm=norm)
  150. return x_u, x_fu
  151. else:
  152. x_f = mymath.fft2(x, norm=norm)
  153. x_fu = x_f * mask
  154. x_u = mymath.ifft2(x_fu, norm=norm)
  155. return x_u, x_fu, x_f
  156. def data_consistency(x, y, mask, centered=False, norm='ortho'):
  157. '''
  158. x is in image space,
  159. y is in k-space
  160. '''
  161. if centered:
  162. xf = mymath.fft2c(x, norm=norm)
  163. xm = (1 - mask) * xf + y
  164. xd = mymath.ifft2c(xm, norm=norm)
  165. else:
  166. xf = mymath.fft2(x, norm=norm)
  167. xm = (1 - mask) * xf + y
  168. xd = mymath.ifft2(xm, norm=norm)
  169. return xd
  170. def get_phase(x):
  171. xr = np.real(x)
  172. xi = np.imag(x)
  173. phase = np.arctan(xi / (xr + 1e-12))
  174. return phase

简介

No Description

Python

贡献者 (1)