You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number; they can include dashes ('-') and can be up to 35 characters long.

579 lines
23 kB

  1. import tensorflow as tf
  2. from tensorflow.keras import layers
  3. import os
  4. import numpy as np
  5. import time
  6. from tools.tools import tempfft, fft2c_mri, ifft2c_mri, Emat_xyt
  7. class CNNLayer(tf.keras.layers.Layer):
  8. def __init__(self, n_f=32, n_out=2):
  9. super(CNNLayer, self).__init__()
  10. self.mylayers = []
  11. self.mylayers.append(tf.keras.layers.Conv3D(n_f, 3, strides=1, padding='same', use_bias=False))
  12. self.mylayers.append(tf.keras.layers.LeakyReLU())
  13. self.mylayers.append(tf.keras.layers.Conv3D(n_f, 3, strides=1, padding='same', use_bias=False))
  14. self.mylayers.append(tf.keras.layers.LeakyReLU())
  15. self.mylayers.append(tf.keras.layers.Conv3D(n_f, 3, strides=1, padding='same', use_bias=False))
  16. self.mylayers.append(tf.keras.layers.LeakyReLU())
  17. self.mylayers.append(tf.keras.layers.Conv3D(n_f, 3, strides=1, padding='same', use_bias=False))
  18. self.mylayers.append(tf.keras.layers.LeakyReLU())
  19. self.mylayers.append(tf.keras.layers.Conv3D(n_out, 3, strides=1, padding='same', use_bias=False))
  20. self.seq = tf.keras.Sequential(self.mylayers)
  21. def call(self, input):
  22. if len(input.shape) == 4:
  23. input2c = tf.stack([tf.math.real(input), tf.math.imag(input)], axis=-1)
  24. else:
  25. input2c = tf.concat([tf.math.real(input), tf.math.imag(input)], axis=-1)
  26. res = self.seq(input2c)
  27. res = tf.complex(res[:,:,:,:,0], res[:,:,:,:,1])
  28. return res
  29. class CONV_OP(tf.keras.layers.Layer):
  30. def __init__(self, n_f=32, ifactivate=False):
  31. super(CONV_OP, self).__init__()
  32. self.mylayers = []
  33. self.mylayers.append(tf.keras.layers.Conv3D(n_f, 3, strides=1, padding='same', use_bias=False))
  34. if ifactivate == True:
  35. self.mylayers.append(tf.keras.layers.ReLU())
  36. self.seq = tf.keras.Sequential(self.mylayers)
  37. def call(self, input):
  38. res = self.seq(input)
  39. return res
  40. class SLR_Net(tf.keras.Model):
  41. def __init__(self, mask, niter, learned_topk=False):
  42. super(SLR_Net, self).__init__(name='SLR_Net')
  43. self.niter = niter
  44. self.E = Emat_xyt(mask)
  45. self.learned_topk = learned_topk
  46. self.celllist = []
  47. def build(self, input_shape):
  48. for i in range(self.niter-1):
  49. self.celllist.append(SLRCell(input_shape, self.E, learned_topk=self.learned_topk))
  50. self.celllist.append(SLRCell(input_shape, self.E, learned_topk=self.learned_topk, is_last=True))
  51. def call(self, d, csm):
  52. """
  53. d: undersampled k-space
  54. csm: coil sensitivity map
  55. """
  56. if csm == None:
  57. nb, nt, nx, ny = d.shape
  58. else:
  59. nb, nc, nt, nx, ny = d.shape
  60. X_SYM = []
  61. x_rec = self.E.mtimes(d, inv=True, csm=csm)
  62. t = tf.zeros_like(x_rec)
  63. beta = tf.zeros_like(x_rec)
  64. x_sym = tf.zeros_like(x_rec)
  65. data = [x_rec, x_sym, beta, t, d, csm]
  66. for i in range(self.niter):
  67. data = self.celllist[i](data, d.shape)
  68. x_sym = data[1]
  69. X_SYM.append(x_sym)
  70. x_rec = data[0]
  71. return x_rec, X_SYM
class SLRCell(layers.Layer):
    """One unrolled iteration of the sparse + low-rank reconstruction.

    Each call runs three sub-steps in order:
      1. sparse():  gradient step on the data term plus a learned
         CNN-transform soft-thresholding (ISTA-style) update;
      2. lowrank(): SVD-based singular-value thresholding of the
         time-by-space Casorati matrix;
      3. beta_mid(): Lagrange-multiplier update.
    """

    def __init__(self, input_shape, E, learned_topk=False, is_last=False):
        super(SLRCell, self).__init__()
        # input_shape is (nb, nt, nx, ny) for single-coil data or
        # (nb, nc, nt, nx, ny) for multi-coil data.
        if len(input_shape) == 4:
            self.nb, self.nt, self.nx, self.ny = input_shape
        else:
            self.nb, nc, self.nt, self.nx, self.ny = input_shape
        self.E = E  # encoding operator (FFT + mask, optionally coil maps)
        self.learned_topk = learned_topk
        if self.learned_topk:
            # The last cell freezes the learned threshold and step size.
            if is_last:
                self.thres_coef = tf.Variable(tf.constant(-2, dtype=tf.float32), trainable=False, name='thres_coef')
                self.eta = tf.Variable(tf.constant(0.01, dtype=tf.float32), trainable=False, name='eta')
            else:
                self.thres_coef = tf.Variable(tf.constant(-2, dtype=tf.float32), trainable=True, name='thres_coef')
                self.eta = tf.Variable(tf.constant(0.01, dtype=tf.float32), trainable=True, name='eta')
        # conv_1..conv_3 implement the forward transform D; conv_4..conv_6
        # implement the (approximate) inverse transform D^T.
        self.conv_1 = CONV_OP(n_f=16, ifactivate=True)
        self.conv_2 = CONV_OP(n_f=16, ifactivate=True)
        self.conv_3 = CONV_OP(n_f=16, ifactivate=False)
        self.conv_4 = CONV_OP(n_f=16, ifactivate=True)
        self.conv_5 = CONV_OP(n_f=16, ifactivate=True)
        self.conv_6 = CONV_OP(n_f=2, ifactivate=False)
        #self.conv_7 = CONV_OP(n_f=16, ifactivate=True)
        #self.conv_8 = CONV_OP(n_f=16, ifactivate=True)
        #self.conv_9 = CONV_OP(n_f=16, ifactivate=True)
        #self.conv_10 = CONV_OP(n_f=16, ifactivate=True)
        # Learned gradient step sizes and soft-threshold level.
        self.lambda_step = tf.Variable(tf.constant(0.1, dtype=tf.float32), trainable=True, name='lambda_1')
        self.lambda_step_2 = tf.Variable(tf.constant(0.1, dtype=tf.float32), trainable=True, name='lambda_2')
        self.soft_thr = tf.Variable(tf.constant(0.1, dtype=tf.float32), trainable=True, name='soft_thr')

    def call(self, data, input_shape):
        """Run one SLR iteration; data = [x_rec, x_sym, beta, t, d, csm]."""
        if len(input_shape) == 4:
            self.nb, self.nt, self.nx, self.ny = input_shape
        else:
            self.nb, nc, self.nt, self.nx, self.ny = input_shape
        x_rec, x_sym, beta, t, d, csm = data
        x_rec, x_sym = self.sparse(x_rec, d, t, beta, csm)
        t = self.lowrank(x_rec)
        beta = self.beta_mid(beta, x_rec, t)
        data[0] = x_rec
        data[1] = x_sym
        data[2] = beta
        data[3] = t
        return data

    def sparse(self, x_rec, d, t, beta, csm):
        """Gradient step + learned-transform soft-thresholding update.

        Returns the updated image and the symmetry residual x_sym used
        for the D^T(D(x)) ~= x training constraint.
        """
        lambda_step = tf.cast(tf.nn.relu(self.lambda_step), tf.complex64)
        lambda_step_2 = tf.cast(tf.nn.relu(self.lambda_step_2), tf.complex64)
        # Gradient of the data-consistency term: E^H (E x - d).
        ATAX_cplx = self.E.mtimes(self.E.mtimes(x_rec, inv=False, csm=csm) - d, inv=True, csm=csm)
        r_n = x_rec - tf.math.scalar_mul(lambda_step, ATAX_cplx) +\
            tf.math.scalar_mul(lambda_step_2, x_rec + beta - t)
        # D_T(soft(D_r_n))
        if len(r_n.shape) == 4:
            # Split complex into real/imag channels for the CNNs.
            r_n = tf.stack([tf.math.real(r_n), tf.math.imag(r_n)], axis=-1)
        x_1 = self.conv_1(r_n)
        x_2 = self.conv_2(x_1)
        x_3 = self.conv_3(x_2)
        # Element-wise soft-thresholding in the learned transform domain.
        x_soft = tf.math.multiply(tf.math.sign(x_3), tf.nn.relu(tf.abs(x_3) - self.soft_thr))
        x_4 = self.conv_4(x_soft)
        x_5 = self.conv_5(x_4)
        x_6 = self.conv_6(x_5)
        # Residual connection around the transform/threshold path.
        x_rec = x_6 + r_n
        # Symmetry path: D^T(D(r_n)) without thresholding, for the loss.
        x_1_sym = self.conv_4(x_3)
        x_1_sym = self.conv_5(x_1_sym)
        x_1_sym = self.conv_6(x_1_sym)
        #x_sym_1 = self.conv_10(x_1_sym)
        x_sym = x_1_sym - r_n
        x_rec = tf.complex(x_rec[:, :, :, :, 0], x_rec[:, :, :, :, 1])
        return x_rec, x_sym

    def lowrank(self, x_rec):
        """Singular-value thresholding of the (Nt x Nx*Ny) Casorati matrix.

        With learned_topk, singular values are shrunk by a learned fraction
        of the largest one; otherwise only the top-1 singular value is kept.
        """
        [batch, Nt, Nx, Ny] = x_rec.get_shape()
        M = tf.reshape(x_rec, [batch, Nt, Nx*Ny])
        St, Ut, Vt = tf.linalg.svd(M)
        if self.learned_topk:
            #tf.print(tf.sigmoid(self.thres_coef))
            # Threshold = sigmoid(thres_coef) * largest singular value.
            thres = tf.sigmoid(self.thres_coef) * St[:, 0]
            thres = tf.expand_dims(thres, -1)
            St = tf.nn.relu(St - thres)
        else:
            # Hard truncation: keep only the leading singular value.
            top1_mask = np.concatenate(
                [np.ones([self.nb, 1], dtype=np.float32), np.zeros([self.nb, self.nt - 1], dtype=np.float32)], 1)
            top1_mask = tf.constant(top1_mask)
            St = St * top1_mask
        St = tf.linalg.diag(St)
        St = tf.dtypes.cast(St, tf.complex64)
        # Reassemble M = U * S * V^H.
        Vt_conj = tf.transpose(Vt, perm=[0, 2, 1])
        Vt_conj = tf.math.conj(Vt_conj)
        US = tf.linalg.matmul(Ut, St)
        M = tf.linalg.matmul(US, Vt_conj)
        x_rec = tf.reshape(M, [batch, Nt, Nx, Ny])
        return x_rec

    def beta_mid(self, beta, x_rec, t):
        """Multiplier update: beta <- beta + eta * (x_rec - t)."""
        # NOTE(review): self.eta exists only when learned_topk=True —
        # confirm this cell is never built with learned_topk=False.
        eta = tf.cast(tf.nn.relu(self.eta), tf.complex64)
        return beta + tf.multiply(eta, x_rec - t)
  164. ###### DC-CNN ######
  165. class DC_CNN_LR(tf.keras.Model):
  166. def __init__(self, mask, niter, learned_topk=False):
  167. super(DC_CNN_LR, self).__init__(name='DC_CNN_LR')
  168. self.niter = niter
  169. self.E = Emat_xyt(mask)
  170. self.mask = mask
  171. self.learned_topk = learned_topk
  172. self.celllist = []
  173. def build(self, input_shape):
  174. for i in range(self.niter-1):
  175. self.celllist.append(DNCell(input_shape, self.E, self.mask, learned_topk=self.learned_topk))
  176. self.celllist.append(DNCell(input_shape, self.E, self.mask, learned_topk=self.learned_topk, is_last=True))
  177. def call(self, d, csm):
  178. """
  179. d: undersampled k-space
  180. csm: coil sensitivity map
  181. """
  182. if csm == None:
  183. nb, nt, nx, ny = d.shape
  184. else:
  185. nb, nc, nt, nx, ny = d.shape
  186. x_rec = self.E.mtimes(d, inv=True, csm=csm)
  187. for i in range(self.niter):
  188. x_rec = self.celllist[i](x_rec, d, d.shape)
  189. return x_rec
class DNCell(layers.Layer):
    """One DC-CNN stage: CNN denoiser, optional low-rank projection,
    and a hard k-space data-consistency layer.
    """

    def __init__(self, input_shape, E, mask, learned_topk=False, is_last=False):
        super(DNCell, self).__init__()
        # input_shape is (nb, nt, nx, ny) for single-coil data or
        # (nb, nc, nt, nx, ny) for multi-coil data.
        if len(input_shape) == 4:
            self.nb, self.nt, self.nx, self.ny = input_shape
        else:
            self.nb, nc, self.nt, self.nx, self.ny = input_shape
        self.E = E
        self.mask = mask  # k-space sampling mask used by dc_layer
        self.learned_topk = learned_topk
        if self.learned_topk:
            # The last cell freezes the learned singular-value threshold.
            if is_last:
                self.thres_coef = tf.Variable(tf.constant(-2, dtype=tf.float32), trainable=False, name='thres_coef')
            else:
                self.thres_coef = tf.Variable(tf.constant(-2, dtype=tf.float32), trainable=True, name='thres_coef')
        # Five-layer residual CNN denoiser (2 output channels: real/imag).
        self.conv_1 = CONV_OP(n_f=16, ifactivate=True)
        self.conv_2 = CONV_OP(n_f=16, ifactivate=True)
        self.conv_3 = CONV_OP(n_f=16, ifactivate=True)
        self.conv_4 = CONV_OP(n_f=16, ifactivate=True)
        self.conv_5 = CONV_OP(n_f=2, ifactivate=False)

    def call(self, x_rec, d, input_shape):
        """Run one denoise + (optional) low-rank + data-consistency stage."""
        if len(input_shape) == 4:
            self.nb, self.nt, self.nx, self.ny = input_shape
        else:
            self.nb, nc, self.nt, self.nx, self.ny = input_shape
        x_rec = self.sparse(x_rec, d)
        return x_rec

    def sparse(self, x_rec, d):
        """Residual CNN denoising followed by low-rank and DC steps."""
        # Split complex image into real/imag channels for the CNNs.
        r_n = tf.stack([tf.math.real(x_rec), tf.math.imag(x_rec)], axis=-1)
        x_1 = self.conv_1(r_n)
        x_2 = self.conv_2(x_1)
        x_3 = self.conv_3(x_2)
        x_4 = self.conv_4(x_3)
        x_5 = self.conv_5(x_4)
        # Residual connection around the CNN.
        x_rec = x_5 + r_n
        x_rec = tf.complex(x_rec[:, :, :, :, 0], x_rec[:, :, :, :, 1])
        if self.learned_topk:
            x_rec = self.lowrank(x_rec)
        x_rec = self.dc_layer(x_rec, d)
        return x_rec

    def lowrank(self, x_rec):
        """Singular-value thresholding of the (Nt x Nx*Ny) Casorati matrix.

        Same scheme as SLRCell.lowrank: learned shrinkage relative to the
        largest singular value, or top-1 truncation otherwise.
        """
        [batch, Nt, Nx, Ny] = x_rec.get_shape()
        M = tf.reshape(x_rec, [batch, Nt, Nx*Ny])
        St, Ut, Vt = tf.linalg.svd(M)
        if self.learned_topk:
            #tf.print(tf.sigmoid(self.thres_coef))
            # Threshold = sigmoid(thres_coef) * largest singular value.
            thres = tf.sigmoid(self.thres_coef) * St[:, 0]
            thres = tf.expand_dims(thres, -1)
            St = tf.nn.relu(St - thres)
        else:
            # Hard truncation: keep only the leading singular value.
            top1_mask = np.concatenate(
                [np.ones([self.nb, 1], dtype=np.float32), np.zeros([self.nb, self.nt - 1], dtype=np.float32)], 1)
            top1_mask = tf.constant(top1_mask)
            St = St * top1_mask
        St = tf.linalg.diag(St)
        St = tf.dtypes.cast(St, tf.complex64)
        # Reassemble M = U * S * V^H.
        Vt_conj = tf.transpose(Vt, perm=[0, 2, 1])
        Vt_conj = tf.math.conj(Vt_conj)
        US = tf.linalg.matmul(Ut, St)
        M = tf.linalg.matmul(US, Vt_conj)
        x_rec = tf.reshape(M, [batch, Nt, Nx, Ny])
        return x_rec

    def dc_layer(self, x_rec, d):
        """Hard data consistency: replace sampled k-space with measured d."""
        k_rec = fft2c_mri(x_rec)
        k_rec = (1 - self.mask) * k_rec + self.mask * d
        x_rec = ifft2c_mri(k_rec)
        return x_rec
  257. ###### Manifold_Net ######
  258. class Manifold_Net(tf.keras.Model):
  259. def __init__(self, mask, niter, learned_topk=False, N_factor=1):
  260. super(Manifold_Net, self).__init__(name='Manifold_Net')
  261. self.niter = niter
  262. self.E = Emat_xyt(mask)
  263. self.mask = mask
  264. self.learned_topk = learned_topk
  265. self.N_factor = N_factor
  266. self.celllist = []
  267. def build(self, input_shape):
  268. for i in range(self.niter-1):
  269. self.celllist.append(ManifoldCell(input_shape, self.E, self.mask, learned_topk=self.learned_topk, N_factor=self.N_factor))
  270. self.celllist.append(ManifoldCell(input_shape, self.E, self.mask, learned_topk=self.learned_topk, N_factor=self.N_factor, is_last=True))
  271. def call(self, d, csm):
  272. """
  273. d: undersampled k-space
  274. csm: coil sensitivity map
  275. """
  276. if csm == None:
  277. nb, nt, nx, ny = d.shape
  278. else:
  279. nb, nc, nt, nx, ny = d.shape
  280. x_rec = self.E.mtimes(d, inv=True, csm=csm)
  281. for i in range(self.niter):
  282. x_rec = self.celllist[i](x_rec, d, d.shape)
  283. return x_rec
class ManifoldCell(layers.Layer):
    """One iteration of Riemannian gradient descent on a low-rank tensor
    manifold: CNN preprocessing, projection of the data-consistency
    gradient onto the tangent space, retraction back to the manifold via
    truncated Tucker/HOSVD factors, and a hard k-space DC step.
    """

    def __init__(self, input_shape, E, mask, learned_topk=False, N_factor=1, is_last=False):
        super(ManifoldCell, self).__init__()
        # input_shape is (nb, nt, nx, ny) for single-coil data or
        # (nb, nc, nt, nx, ny) for multi-coil data.
        if len(input_shape) == 4:
            self.nb, self.nt, self.nx, self.ny = input_shape
        else:
            self.nb, nc, self.nt, self.nx, self.ny = input_shape
        self.E = E
        self.mask = mask
        # Per-mode rank-truncation factors (kept rank = dim / factor).
        self.Nx_factor = N_factor
        self.Ny_factor = N_factor
        self.Nt_factor = N_factor
        self.learned_topk = learned_topk
        if self.learned_topk:
            # Learned retraction step size.
            self.eta = tf.Variable(tf.constant(0.01, dtype=tf.float32), trainable=True, name='eta')
            #self.lambda_sparse = tf.Variable(tf.constant(0.01, dtype=tf.float32), trainable=True, name='lambda')
        self.conv_1 = CNNLayer(n_f=16, n_out=2)
        #self.conv_2 = CNNLayer(n_f=16, n_out=2)
        #self.conv_3 = CNNLayer(n_f=16, n_out=2)
        #self.conv_D = CNNLayer(n_f=16, n_out=2)
        #self.conv_transD = CNNLayer(n_f=16, n_out=2)

    def call(self, x_rec, d, input_shape):
        """Run one manifold-descent iteration on x_rec given k-space d."""
        if len(input_shape) == 4:
            self.nb, self.nt, self.nx, self.ny = input_shape
        else:
            self.nb, nc, self.nt, self.nx, self.ny = input_shape
        x_k = self.conv_1(x_rec)
        #grad_sparse = self.conv_transD(self.conv_D(x_k))
        #grad_sparse = tf.stack([tf.math.real(grad_sparse), tf.math.imag(grad_sparse)], axis=-1)
        #grad_sparse = tf.multiply(self.lambda_sparse, grad_sparse)
        #grad_sparse = tf.complex(grad_sparse[..., 0], grad_sparse[..., 1])
        #grad_dc = ifft2c_mri((fft2c_mri(x_k) * self.mask - d) * self.mask)
        #g_k = grad_dc + grad_sparse
        # Euclidean gradient of the data-consistency term.
        g_k = ifft2c_mri((fft2c_mri(x_k) * self.mask - d) * self.mask)
        #g_k = self.E.mtimes(self.E.mtimes(x_k, inv=False, csm=csm) - d, inv=True, csm=csm)
        t_k = self.Tangent_Module(g_k, x_k)
        x_k = self.Retraction_Module(x_k, t_k)
        #x_k = self.conv_3(x_k)
        x_k = self.dc_layer(x_k, d)
        return x_k

    def Tangent_Module(self, g_k, x_k):
        """Project gradient g_k onto the tangent space of the Tucker
        manifold at x_k (first term plus one correction term per mode).
        """
        batch, Nt, Nx, Ny = x_k.shape
        x_k = tf.transpose(x_k, [0, 2, 3, 1]) # batch, Nx, Ny, Nt
        g_k = tf.transpose(g_k, [0, 2, 3, 1]) # batch, Nx, Ny, Nt
        Ux, Uy, Ut = self.Mode(x_k)
        # First term: g_k projected onto the span of (Ux, Uy, Ut).
        first_term = self.Mode_Multiply(g_k, tf.transpose(Ux, [0, 2, 1], conjugate=True), mode_n=1)
        first_term = self.Mode_Multiply(first_term, tf.transpose(Uy, [0, 2, 1], conjugate=True), mode_n=2)
        first_term = self.Mode_Multiply(first_term, tf.transpose(Ut, [0, 2, 1], conjugate=True), mode_n=3)
        first_term = self.Mode_Multiply(first_term, Ux, mode_n=1)
        first_term = self.Mode_Multiply(first_term, Uy, mode_n=2)
        first_term = self.Mode_Multiply(first_term, Ut, mode_n=3)
        C_mode_x, C_mode_y, C_mode_t = self.Core_C(x_k, Ux, Uy, Ut)
        # Correction term for mode x (complement projection on Ux).
        second_term_1 = self.Mode_Multiply(g_k, tf.transpose(Uy, [0, 2, 1], conjugate=True), mode_n=2)
        second_term_1 = self.Mode_Multiply(second_term_1, tf.transpose(Ut, [0, 2, 1], conjugate=True), mode_n=3)
        second_term_1 = tf.reshape(second_term_1, [batch, Nx, Ny * Nt])
        second_term_1 = self.Projector(second_term_1, Ux)
        second_term_1 = self.Core_Multiply(second_term_1, C_mode_x)
        second_term_1 = tf.linalg.matmul(second_term_1, C_mode_x)
        second_term_1 = tf.reshape(second_term_1, [batch, Nx, Ny, Nt])
        second_term_1 = self.Mode_Multiply(second_term_1, Uy, mode_n=2)
        second_term_1 = self.Mode_Multiply(second_term_1, Ut, mode_n=3)
        # Correction term for mode y.
        second_term_2 = self.Mode_Multiply(g_k, tf.transpose(Ux, [0, 2, 1], conjugate=True), mode_n=1)
        second_term_2 = self.Mode_Multiply(second_term_2, tf.transpose(Ut, [0, 2, 1], conjugate=True), mode_n=3)
        second_term_2 = tf.reshape(tf.transpose(second_term_2, [0, 2, 1, 3]), [batch, Ny, Nx*Nt])
        second_term_2 = self.Projector(second_term_2, Uy)
        second_term_2 = self.Core_Multiply(second_term_2, C_mode_y)
        second_term_2 = tf.linalg.matmul(second_term_2, C_mode_y)
        second_term_2 = tf.transpose(tf.reshape(second_term_2, [batch, Ny, Nx, Nt]), [0, 2, 1, 3])
        second_term_2 = self.Mode_Multiply(second_term_2, Ux, mode_n=1)
        second_term_2 = self.Mode_Multiply(second_term_2, Ut, mode_n=3)
        # Correction term for mode t.
        second_term_3 = self.Mode_Multiply(g_k, tf.transpose(Ux, [0, 2, 1], conjugate=True), mode_n=1)
        second_term_3 = self.Mode_Multiply(second_term_3, tf.transpose(Uy, [0, 2, 1], conjugate=True), mode_n=2)
        second_term_3 = tf.reshape(tf.transpose(second_term_3, [0, 3, 1, 2]), [batch, Nt, Nx * Ny])
        second_term_3 = self.Projector(second_term_3, Ut)
        second_term_3 = self.Core_Multiply(second_term_3, C_mode_t)
        second_term_3 = tf.linalg.matmul(second_term_3, C_mode_t)
        second_term_3 = tf.transpose(tf.reshape(second_term_3, [batch, Nt, Nx, Ny]), [0, 2, 3, 1])
        second_term_3 = self.Mode_Multiply(second_term_3, Ux, mode_n=1)
        second_term_3 = self.Mode_Multiply(second_term_3, Uy, mode_n=2)
        t_k = first_term + second_term_1 + second_term_2 + second_term_3
        # Back to (batch, Nt, Nx, Ny) ordering.
        t_k = tf.transpose(t_k, [0, 3, 1, 2])
        return t_k

    def Retraction_Module(self, x_k, t_k):
        """Step along -eta * t_k, then retract onto the manifold by
        truncated Tucker factorization (HOSVD with rank-masked factors).
        """
        # NOTE(review): self.eta exists only when learned_topk=True —
        # confirm this cell is never built with learned_topk=False.
        x_k = tf.stack([tf.math.real(x_k), tf.math.imag(x_k)], axis=-1)
        t_k = tf.stack([tf.math.real(t_k), tf.math.imag(t_k)], axis=-1)
        x_k = x_k - tf.multiply(self.eta, t_k)
        x_k = tf.complex(x_k[..., 0], x_k[..., 1])
        batch, Nt, Nx, Ny = x_k.shape
        x_k = tf.transpose(x_k, [0, 2, 3, 1]) # batch, Nx, Ny, Nt
        Ux, Uy, Ut = self.Mode(x_k)
        # Zero out all but the leading dim/N_factor columns of each factor.
        Ux = self.SVT_U(Ux, top_kth= int(Nx / self.Nx_factor))
        Uy = self.SVT_U(Uy, top_kth= int(Ny / self.Ny_factor))
        Ut = self.SVT_U(Ut, top_kth= int(Nt / self.Nt_factor))
        """
        Ux = self.SVT_U(Ux, top_kth= Nx // self.Nx_factor)
        Uy = self.SVT_U(Uy, top_kth= Ny // self.Ny_factor)
        Ut = self.SVT_U(Ut, top_kth= Nt // self.Nt_factor)
        """
        # Core tensor C = x_k x1 Ux^H x2 Uy^H x3 Ut^H, then reconstruct.
        C = self.Mode_Multiply(x_k, tf.transpose(Ux, [0, 2, 1], conjugate=True), mode_n=1)
        C = self.Mode_Multiply(C, tf.transpose(Uy, [0, 2, 1], conjugate=True), mode_n=2)
        C = self.Mode_Multiply(C, tf.transpose(Ut, [0, 2, 1], conjugate=True), mode_n=3)
        x_k = self.Mode_Multiply(C, Ux, mode_n=1)
        x_k = self.Mode_Multiply(x_k, Uy, mode_n=2)
        x_k = self.Mode_Multiply(x_k, Ut, mode_n=3)
        x_k = tf.transpose(x_k, [0, 3, 1, 2])
        return x_k

    def Mode(self, x_k):
        """Return the left singular vectors of the three mode unfoldings."""
        batch, Nx, Ny, Nt = x_k.shape
        mode_x = tf.reshape(x_k, [batch, Nx, Ny*Nt])
        mode_y = tf.reshape(tf.transpose(x_k, [0, 2, 1, 3]), [batch, Ny, Nx*Nt])
        mode_t = tf.reshape(tf.transpose(x_k, [0, 3, 1, 2]), [batch, Nt, Nx*Ny])
        Sx, Ux, Vx = tf.linalg.svd(mode_x) # Ux: batch, 192, 192
        Sy, Uy, Vy = tf.linalg.svd(mode_y) # Uy: batch, 192, 192
        St, Ut, Vt = tf.linalg.svd(mode_t)
        return Ux, Uy, Ut

    def Mode_Multiply(self, A, U, mode_n=1):
        """
        Mode-n product A x_n U.

        A: batch, Nx, Ny, Nt
        U: batch, Nx, Ny
        return: batch, Nx, Ny, Nt
        """
        batch, Nx, Ny, Nt = A.shape
        if mode_n == 1:
            out = tf.linalg.matmul(U, tf.reshape(A, [batch, Nx, Ny*Nt])) # batch, Nx, Ny*Nt
            out= tf.reshape(out, [batch, Nx, Ny, Nt])
        elif mode_n == 2:
            out = tf.linalg.matmul(U, tf.reshape(tf.transpose(A, [0, 2, 1, 3]), [batch, Ny, Nx * Nt])) # batch, Ny, Nx*Nt
            out = tf.transpose(tf.reshape(out, [batch, Ny, Nx, Nt]), [0, 2, 1, 3])
        elif mode_n == 3:
            out = tf.linalg.matmul(U, tf.reshape(tf.transpose(A, [0, 3, 1, 2]), [batch, Nt, Nx * Ny])) # batch, Nt, Nx*Ny
            out = tf.transpose(tf.reshape(out, [batch, Nt, Nx, Ny]), [0, 2, 3, 1])
        return out

    def Core_C(self, x_k, Ux, Uy, Ut):
        """Compute the Tucker core of x_k and return its three unfoldings."""
        batch, Nx, Ny, Nt = x_k.shape
        C = self.Mode_Multiply(x_k, tf.transpose(Ux, [0, 2, 1], conjugate=True), mode_n=1)
        C = self.Mode_Multiply(C, tf.transpose(Uy, [0, 2, 1], conjugate=True), mode_n=2)
        C = self.Mode_Multiply(C, tf.transpose(Ut, [0, 2, 1], conjugate=True), mode_n=3)
        C_mode_x = tf.reshape(C, [batch, Nx, Ny * Nt])
        C_mode_y = tf.reshape(tf.transpose(C, [0, 2, 1, 3]), [batch, Ny, Nx * Nt])
        C_mode_t = tf.reshape(tf.transpose(C, [0, 3, 1, 2]), [batch, Nt, Nx * Ny])
        return C_mode_x, C_mode_y, C_mode_t

    def Projector(self, second_term, U):
        """Project onto the orthogonal complement of span(U): (I - U U^H) x."""
        second_term = second_term - tf.linalg.matmul(
            tf.linalg.matmul(U,
                             tf.transpose(U, [0, 2, 1], conjugate=True)),
            second_term)
        return second_term

    def Core_Multiply(self, second_term, C_mode):
        """Right-multiply by the pseudo-inverse factor C^H (C C^H)^-1."""
        second_term = tf.linalg.matmul(second_term,
                                       tf.linalg.matmul(tf.transpose(C_mode, [0, 2, 1], conjugate=True),
                                                        tf.linalg.inv(tf.linalg.matmul(C_mode,
                                                                                       tf.transpose(C_mode, [0, 2, 1], conjugate=True)))))
        return second_term

    def SVT_U(self, Uk, top_kth):
        """Keep the first `top_kth` columns of Uk; zero out the rest."""
        [batch, Nx, Ny] = Uk.get_shape()
        mask_1 = tf.ones([batch, Nx, top_kth])
        mask_2 = tf.zeros([batch, Nx, Ny - top_kth])
        mask_top_k = tf.concat([mask_1, mask_2], axis=-1)
        mask_top_k = tf.cast(mask_top_k, dtype=Uk.dtype)
        Uk = tf.multiply(Uk, mask_top_k)
        return Uk

    def dc_layer(self, x_rec, d):
        """Hard data consistency: replace sampled k-space with measured d."""
        k_rec = fft2c_mri(x_rec)
        k_rec = (1 - self.mask) * k_rec + self.mask * d
        x_rec = ifft2c_mri(k_rec)
        return x_rec

    def dc_layer_v2(self, x_rec, d):
        """Soft data consistency: subtract the masked k-space residual."""
        x_rec = x_rec - ifft2c_mri(fft2c_mri(x_rec) * self.mask - d)
        return x_rec

About

No description

Python

Contributors (1)