You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-'), and can be up to 35 characters long.

274 lines
9.5 kB

  1. import tensorflow as tf
  2. from tensorflow.keras import layers
  3. import os
  4. import numpy as np
  5. import time
  6. from tools.tools import tempfft, fft2c_mri, ifft2c_mri, Emat_xyt
  7. class CNNLayer(tf.keras.layers.Layer):
  8. def __init__(self, n_f=32):
  9. super(CNNLayer, self).__init__()
  10. self.mylayers = []
  11. self.mylayers.append(tf.keras.layers.Conv3D(n_f, 3, strides=1, padding='same', use_bias=False))
  12. self.mylayers.append(tf.keras.layers.LeakyReLU())
  13. self.mylayers.append(tf.keras.layers.Conv3D(n_f, 3, strides=1, padding='same', use_bias=False))
  14. self.mylayers.append(tf.keras.layers.LeakyReLU())
  15. self.mylayers.append(tf.keras.layers.Conv3D(2, 3, strides=1, padding='same', use_bias=False))
  16. self.seq = tf.keras.Sequential(self.mylayers)
  17. def call(self, input):
  18. if len(input.shape) == 4:
  19. input2c = tf.stack([tf.math.real(input), tf.math.imag(input)], axis=-1)
  20. else:
  21. input2c = tf.concat([tf.math.real(input), tf.math.imag(input)], axis=-1)
  22. res = self.seq(input2c)
  23. res = tf.complex(res[:,:,:,:,0], res[:,:,:,:,1])
  24. return res
  25. class CONV_OP(tf.keras.layers.Layer):
  26. def __init__(self, n_f=32, ifactivate=False):
  27. super(CONV_OP, self).__init__()
  28. self.mylayers = []
  29. self.mylayers.append(tf.keras.layers.Conv3D(n_f, 3, strides=1, padding='same', use_bias=False))
  30. if ifactivate == True:
  31. self.mylayers.append(tf.keras.layers.ReLU())
  32. self.seq = tf.keras.Sequential(self.mylayers)
  33. def call(self, input):
  34. res = self.seq(input)
  35. return res
  36. class SLR_Net(tf.keras.Model):
  37. def __init__(self, mask, niter, learned_topk=False):
  38. super(SLR_Net, self).__init__(name='SLR_Net')
  39. self.niter = niter
  40. self.E = Emat_xyt(mask)
  41. self.learned_topk = learned_topk
  42. self.celllist = []
  43. def build(self, input_shape):
  44. for i in range(self.niter-1):
  45. self.celllist.append(SLRCell(input_shape, self.E, learned_topk=self.learned_topk))
  46. self.celllist.append(SLRCell(input_shape, self.E, learned_topk=self.learned_topk, is_last=True))
  47. def call(self, d, csm):
  48. """
  49. d: undersampled k-space
  50. csm: coil sensitivity map
  51. """
  52. if csm == None:
  53. nb, nt, nx, ny = d.shape
  54. else:
  55. nb, nc, nt, nx, ny = d.shape
  56. X_SYM = []
  57. x_rec = self.E.mtimes(d, inv=True, csm=csm)
  58. t = tf.zeros_like(x_rec)
  59. beta = tf.zeros_like(x_rec)
  60. x_sym = tf.zeros_like(x_rec)
  61. data = [x_rec, x_sym, beta, t, d, csm]
  62. for i in range(self.niter):
  63. data = self.celllist[i](data, d.shape)
  64. x_sym = data[1]
  65. X_SYM.append(x_sym)
  66. x_rec = data[0]
  67. return x_rec, X_SYM
class SLRCell(layers.Layer):
    """One unrolled iteration of the SLR algorithm.

    Each call performs, in order:
      1. sparse():  data-fidelity gradient step followed by a learned
                    soft-threshold CNN proximal operator,
      2. lowrank(): singular-value thresholding of the Casorati matrix
                    (time x space) of the current reconstruction,
      3. beta_mid(): Lagrange-multiplier update.
    """

    def __init__(self, input_shape, E, learned_topk=False, is_last=False):
        super(SLRCell, self).__init__()
        if len(input_shape) == 4:
            # single-coil: (batch, time, x, y)
            self.nb, self.nt, self.nx, self.ny = input_shape
        else:
            # multi-coil: (batch, coils, time, x, y); the coil dim is not stored
            self.nb, nc, self.nt, self.nx, self.ny = input_shape
        self.E = E
        self.learned_topk = learned_topk
        if self.learned_topk:
            # NOTE(review): self.eta / self.thres_coef are only created on
            # this branch; beta_mid()/lowrank() will raise AttributeError on
            # self.eta when learned_topk=False — confirm intended usage.
            if is_last:
                # Last cell: freeze threshold and dual step size.
                self.thres_coef = tf.Variable(tf.constant(-2, dtype=tf.float32), trainable=False, name='thres_coef')
                self.eta = tf.Variable(tf.constant(0.01, dtype=tf.float32), trainable=False, name='eta')
            else:
                self.thres_coef = tf.Variable(tf.constant(-2, dtype=tf.float32), trainable=True, name='thres_coef')
                self.eta = tf.Variable(tf.constant(0.01, dtype=tf.float32), trainable=True, name='eta')
        # conv_1..3: forward transform D; conv_4..6: transpose transform D^T.
        # conv_4..6 are reused for both the reconstruction path and the
        # symmetry (D^T D ~ I) path in sparse().
        self.conv_1 = CONV_OP(n_f=16, ifactivate=True)
        self.conv_2 = CONV_OP(n_f=16, ifactivate=True)
        self.conv_3 = CONV_OP(n_f=16, ifactivate=False)
        self.conv_4 = CONV_OP(n_f=16, ifactivate=True)
        self.conv_5 = CONV_OP(n_f=16, ifactivate=True)
        self.conv_6 = CONV_OP(n_f=2, ifactivate=False)
        #self.conv_7 = CONV_OP(n_f=16, ifactivate=True)
        #self.conv_8 = CONV_OP(n_f=16, ifactivate=True)
        #self.conv_9 = CONV_OP(n_f=16, ifactivate=True)
        #self.conv_10 = CONV_OP(n_f=16, ifactivate=True)
        # Gradient-step and coupling weights (ReLU-clipped before use).
        self.lambda_step = tf.Variable(tf.constant(0.1, dtype=tf.float32), trainable=True, name='lambda_1')
        self.lambda_step_2 = tf.Variable(tf.constant(0.1, dtype=tf.float32), trainable=True, name='lambda_2')
        self.soft_thr = tf.Variable(tf.constant(0.1, dtype=tf.float32), trainable=True, name='soft_thr')

    def call(self, data, input_shape):
        """Run one SLR iteration; `data` is [x_rec, x_sym, beta, t, d, csm]."""
        if len(input_shape) == 4:
            self.nb, self.nt, self.nx, self.ny = input_shape
        else:
            self.nb, nc, self.nt, self.nx, self.ny = input_shape
        x_rec, x_sym, beta, t, d, csm = data
        x_rec, x_sym = self.sparse(x_rec, d, t, beta, csm)
        t = self.lowrank(x_rec)
        beta = self.beta_mid(beta, x_rec, t)
        data[0] = x_rec
        data[1] = x_sym
        data[2] = beta
        data[3] = t
        return data

    def sparse(self, x_rec, d, t, beta, csm):
        """Gradient step on the data term plus learned proximal update."""
        lambda_step = tf.cast(tf.nn.relu(self.lambda_step), tf.complex64)
        lambda_step_2 = tf.cast(tf.nn.relu(self.lambda_step_2), tf.complex64)
        # A^H (A x - d): gradient of the data-fidelity term.
        ATAX_cplx = self.E.mtimes(self.E.mtimes(x_rec, inv=False, csm=csm) - d, inv=True, csm=csm)
        r_n = x_rec - tf.math.scalar_mul(lambda_step, ATAX_cplx) +\
            tf.math.scalar_mul(lambda_step_2, x_rec + beta - t)
        # D_T(soft(D_r_n))
        # Split complex r_n into a trailing real/imag channel axis.
        # NOTE(review): only the 4-D case is converted; a 5-D complex r_n
        # would reach the real-valued convs unconverted — confirm callers.
        if len(r_n.shape) == 4:
            r_n = tf.stack([tf.math.real(r_n), tf.math.imag(r_n)], axis=-1)
        x_1 = self.conv_1(r_n)
        x_2 = self.conv_2(x_1)
        x_3 = self.conv_3(x_2)
        # Soft-thresholding in the learned transform domain.
        x_soft = tf.math.multiply(tf.math.sign(x_3), tf.nn.relu(tf.abs(x_3) - self.soft_thr))
        x_4 = self.conv_4(x_soft)
        x_5 = self.conv_5(x_4)
        x_6 = self.conv_6(x_5)
        # Residual connection around the CNN pair.
        x_rec = x_6 + r_n
        # Symmetry path: D^T(D r_n) - r_n, using the same conv_4..6 weights
        # but without thresholding; used as a training-loss term.
        x_1_sym = self.conv_4(x_3)
        x_1_sym = self.conv_5(x_1_sym)
        x_1_sym = self.conv_6(x_1_sym)
        #x_sym_1 = self.conv_10(x_1_sym)
        x_sym = x_1_sym - r_n
        # Recombine the two real channels into a complex tensor.
        x_rec = tf.complex(x_rec[:, :, :, :, 0], x_rec[:, :, :, :, 1])
        return x_rec, x_sym

    def lowrank(self, x_rec):
        """Low-rank projection of the Casorati matrix via SVD."""
        [batch, Nt, Nx, Ny] = x_rec.get_shape()
        # Casorati matrix: time x (flattened space).
        M = tf.reshape(x_rec, [batch, Nt, Nx*Ny])
        St, Ut, Vt = tf.linalg.svd(M)
        if self.learned_topk:
            #tf.print(tf.sigmoid(self.thres_coef))
            # Singular-value soft threshold, scaled by the largest value.
            thres = tf.sigmoid(self.thres_coef) * St[:, 0]
            thres = tf.expand_dims(thres, -1)
            St = tf.nn.relu(St - thres)
        else:
            # Fixed rank-1 projection: keep only the top singular value.
            top1_mask = np.concatenate(
                [np.ones([self.nb, 1], dtype=np.float32), np.zeros([self.nb, self.nt - 1], dtype=np.float32)], 1)
            top1_mask = tf.constant(top1_mask)
            St = St * top1_mask
        # Rebuild M = U * diag(S) * V^H and restore the image shape.
        St = tf.linalg.diag(St)
        St = tf.dtypes.cast(St, tf.complex64)
        Vt_conj = tf.transpose(Vt, perm=[0, 2, 1])
        Vt_conj = tf.math.conj(Vt_conj)
        US = tf.linalg.matmul(Ut, St)
        M = tf.linalg.matmul(US, Vt_conj)
        x_rec = tf.reshape(M, [batch, Nt, Nx, Ny])
        return x_rec

    def beta_mid(self, beta, x_rec, t):
        """Dual update: beta <- beta + eta * (x_rec - t)."""
        eta = tf.cast(tf.nn.relu(self.eta), tf.complex64)
        return beta + tf.multiply(eta, x_rec - t)
  160. class S_Net(tf.keras.Model):
  161. def __init__(self, mask, niter):
  162. super(S_Net, self).__init__(name='S_Net')
  163. self.niter = niter
  164. self.E = Emat_xyt(mask)
  165. self.celllist = []
  166. def build(self, input_shape):
  167. for i in range(self.niter-1):
  168. self.celllist.append(SCell_learned_step(input_shape, self.E, is_last=False))
  169. self.celllist.append(SCell_learned_step(input_shape, self.E, is_last=True))
  170. def call(self, d):
  171. nb, nt, nx, ny = d.shape
  172. Spre = tf.reshape(self.E.mtimes(d, inv=True), [nb, nt, nx*ny])
  173. Mpre = Spre
  174. data = [Spre, Mpre, d]
  175. for i in range(self.niter):
  176. data = self.celllist[i](data)
  177. S, M, _ = data
  178. #M = tf.reshape(M, [nb, nt, nx, ny])
  179. S = tf.reshape(S, [nb, nt, nx, ny])
  180. return S
  181. class SCell_learned_step(layers.Layer):
  182. def __init__(self, input_shape, E, is_last):
  183. super(SCell_learned_step, self).__init__()
  184. self.nb, self.nt, self.nx, self.ny = input_shape
  185. self.E = E
  186. self.sconv = CNNLayer(n_f=32)
  187. self.is_last = is_last
  188. if not is_last:
  189. self.gamma = tf.Variable(tf.constant(1, dtype=tf.float32), trainable=True)
  190. def call(self, data):
  191. Spre, Mpre, d = data
  192. S = self.sparse(Mpre)
  193. dc = self.dataconsis(S, d)
  194. if not self.is_last:
  195. gamma = tf.cast(tf.nn.relu(self.gamma), tf.complex64)
  196. else:
  197. gamma = tf.cast(1.0, tf.complex64)
  198. M = S - gamma * dc
  199. data[0] = S
  200. data[1] = M
  201. return data
  202. def sparse(self, S):
  203. S = tf.reshape(S, [self.nb, self.nt, self.nx, self.ny])
  204. S = self.sconv(S)
  205. S = tf.reshape(S, [self.nb, self.nt, self.nx*self.ny])
  206. return S
  207. def dataconsis(self, LS, d):
  208. resk = self.E.mtimes(tf.reshape(LS, [self.nb, self.nt, self.nx, self.ny]), inv=False) - d
  209. dc = tf.reshape(self.E.mtimes(resk, inv=True), [self.nb, self.nt, self.nx*self.ny])
  210. return dc

简介

No Description

Python

贡献者 (1)