import os
import time
from os.path import join

import numpy as np
import tensorflow as tf

import compressed_sensing as cs
import inference
from datasets import Train_data
from helpers import to_lasagne_format

# Training configuration
batch_size = 10
learning_rate_base = 0.0001
learning_rate_decay = 0.95
regularization_rate = 1e-7
Train_steps = 30          # number of full passes over the training data
num_train = 15000
num_validate = 2000
input_shape = [batch_size, 6, 117, 120]  # shape used for the sampling mask in main()

project_root = '.'
model_file = "model_K5_D3C4_TV_iso_e-8_ComplexConv_kLoss_e-1_block1-3_e+3_e+3_e+3"
model_save_path = join(project_root, 'models/%s' % model_file)
if not os.path.isdir(model_save_path):
    os.makedirs(model_save_path)
model_name = "model.ckpt"

def iterate_minibatch(data, batch_size, shuffle=False):
    n = len(data)
    if shuffle:
        data = np.random.permutation(data)
    for i in range(0, n, batch_size):
        yield data[i:i + batch_size]
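
# Note: a trailing partial batch is yielded as-is, so the fixed-size mask
# built in main() ([batch_size, 6, 117, 120]) assumes len(data) is divisible
# by batch_size; with num_train = 15000 and batch_size = 10 that holds here.
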
def prep_input(ys, mask):
    """Undersample the batch with the given k-space mask, then reformat it
    into the real/imaginary-stacked layout the network accepts.

    Parameters
    ----------
    ys : complex image batch of shape [batch, 6, 117, 120]
    mask : k-space sampling mask broadcastable to ys.shape
    """
    # The mask used to be generated here per batch (gauss_ivar: higher value,
    # more undersampling); it is now built once in main() and passed in:
    # mask = cs.cartesian_mask(ys.shape, gauss_ivar,
    #                          centred=False,
    #                          sample_high_freq=True,
    #                          sample_centre=True,
    #                          sample_n=6)
    xs, k_und, k_full = cs.undersample(ys, mask, centred=False, norm=None)
    ys_l = to_lasagne_format(ys)
    xs_l = to_lasagne_format(xs)
    mask = mask.astype(np.complex64)  # match the complex64 placeholders in train()
    return xs_l, k_und, mask, ys_l, k_full
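
# prep_input performs retrospective undersampling. Assuming the usual
# convention for cs.undersample, with F the 2-D FFT and M the sampling mask:
#   k_full = F(ys),  k_und = M * k_full,  xs = F^{-1}(k_und),
# and to_lasagne_format stacks real and imaginary parts into the float32
# [..., 2] layout expected by the placeholders in train().
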
def TV(f, case=1):
    # Circularly shifted index vectors [1, 2, ..., N-1, 0], so that
    # tf.gather(f, indices, axis) - f is a forward difference with a
    # periodic boundary along that axis.
    indices_x = np.roll(np.arange(117), -1)
    indices_y = np.roll(np.arange(120), -1)
    f_x = tf.gather(f, indices=indices_x, axis=2) - f
    f_y = tf.gather(f, indices=indices_y, axis=3) - f
    if case == 1:
        # anisotropic TV
        TV_f = tf.reduce_mean(tf.reduce_sum(tf.abs(f_x) + tf.abs(f_y), [1, 2, 3, 4]))
        print("Using anisotropic TV")
    elif case == 2:
        # isotropic TV
        TV_f = tf.reduce_mean(tf.reduce_sum(tf.sqrt(tf.square(f_x) + tf.square(f_y)), [1, 2, 3, 4]))
        print("Using isotropic TV")
    else:
        raise ValueError("case must be 1 (anisotropic) or 2 (isotropic)")
    return TV_f
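
# The two cases above are the standard discrete TV penalties, built from the
# periodic forward differences f_x, f_y:
#   anisotropic (case=1): sum_ij |f_x(i, j)| + |f_y(i, j)|
#   isotropic   (case=2): sum_ij sqrt(f_x(i, j)^2 + f_y(i, j)^2)
# summed per sample and averaged over the batch.
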
def HDTV(f, case=1, degree=2):
    if case != 1 or degree not in (2, 3):
        raise ValueError("only case == 1 with degree in (2, 3) is implemented")
    # Same circular forward differences as in TV().
    indices_x = np.roll(np.arange(117), -1)
    indices_y = np.roll(np.arange(120), -1)
    f_x = tf.gather(f, indices=indices_x, axis=2) - f
    f_y = tf.gather(f, indices=indices_y, axis=3) - f
    if degree == 2:
        f_xx_n = tf.gather(f_x, indices=indices_x, axis=2) - f_x
        f_yy_n = tf.gather(f_y, indices=indices_y, axis=3) - f_y
        f_xy_n = tf.gather(f_x, indices=indices_y, axis=3) - f_x
        HDTV_f = tf.reduce_mean(tf.reduce_sum(
            tf.sqrt((3 * tf.square(f_xx_n) + 3 * tf.square(f_yy_n)
                     + 4 * tf.square(f_xy_n) + tf.multiply(f_xx_n, f_yy_n)) / 8),
            [1, 2, 3, 4]))
    if degree == 3:
        f_xx = tf.gather(f_x, indices=indices_x, axis=2) - f_x
        f_yy = tf.gather(f_y, indices=indices_y, axis=3) - f_y
        f_xy = tf.gather(f_x, indices=indices_y, axis=3) - f_x
        f_xx_n = tf.gather(f_xx, indices=indices_x, axis=2) - f_xx
        f_yy_n = tf.gather(f_yy, indices=indices_y, axis=3) - f_yy
        f_xxy_n = tf.gather(f_xx, indices=indices_y, axis=3) - f_xx
        f_xyy_n = tf.gather(f_xy, indices=indices_y, axis=3) - f_xy
        HDTV_f = tf.reduce_mean(tf.reduce_sum(
            tf.sqrt((5 * (tf.square(f_xx_n) + tf.square(f_yy_n))
                     + 3 * (tf.multiply(f_xx_n, f_xyy_n) + tf.multiply(f_yy_n, f_xxy_n))
                     + 9 * (tf.square(f_xxy_n) + tf.square(f_xyy_n))) / (4 * np.sqrt(2))),
            [1, 2, 3, 4]))
    return HDTV_f
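
# HDTV ("higher-degree TV") replaces the first-order differences of TV with
# second- (degree=2) or third-order (degree=3) mixed differences combined
# under fixed weights. total_loss() below uses plain isotropic TV (case=2);
# HDTV is kept as an alternative regularizer but is not wired in.
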
def real2complex(input_op, inv=False):
    """Convert the stacked real/imaginary layout [..., 2] to a complex tensor,
    or (inv=True) a complex tensor back to the stacked layout."""
    if not inv:
        return tf.complex(input_op[:, :, :, :, 0], input_op[:, :, :, :, 1])
    input_real = tf.cast(tf.real(input_op), dtype=tf.float32)
    input_imag = tf.cast(tf.imag(input_op), dtype=tf.float32)
    return tf.stack([input_real, input_imag], axis=4)
def total_loss(y, y_, block_1, block_2, block_3, block_k_1, k_full):
    lambda_TV = 1e-8
    lambda_mse_image2 = 1e+3
    lambda_kspace1 = 1e-1
    loss_mse_image = tf.reduce_mean(tf.reduce_mean(tf.squared_difference(y, y_), [1, 2, 3, 4]))
    loss_mse_image1 = tf.reduce_mean(tf.reduce_mean(tf.squared_difference(block_1, y_), [1, 2, 3, 4]))
    loss_mse_image2 = tf.reduce_mean(tf.reduce_mean(tf.squared_difference(block_2, y_), [1, 2, 3, 4]))
    loss_mse_image3 = tf.reduce_mean(tf.reduce_mean(tf.squared_difference(block_3, y_), [1, 2, 3, 4]))
    # loss_mse_image4 = tf.reduce_mean(tf.reduce_mean(tf.squared_difference(block_4, y_), [1, 2, 3, 4]))
    kspace_full_real = real2complex(k_full, inv=True)
    loss_mse_kspace1 = tf.reduce_mean(tf.reduce_mean(tf.squared_difference(block_k_1, kspace_full_real), [1, 2, 3, 4]))
    loss_TV = TV(y, case=2)
    loss = (loss_mse_image + lambda_TV * loss_TV + lambda_kspace1 * loss_mse_kspace1
            + lambda_mse_image2 * (loss_mse_image1 + loss_mse_image2 + loss_mse_image3))
    tf.add_to_collection('losses', loss)
    return tf.add_n(tf.get_collection('losses'))
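
# The resulting objective combines, with the hard-coded weights above:
#   - image-domain MSE of the final reconstruction y against the label y_,
#   - 1e+3-weighted MSE for each intermediate block output (block_1..block_3),
#   - 1e-1-weighted k-space MSE of block_k_1 against the full k-space,
#   - 1e-8-weighted isotropic TV of y,
# plus whatever L2 weight-decay terms inference() presumably adds to the
# 'losses' collection via the regularizer passed in from train().
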
def train(train_data, validate_data, mask):
    print("compiling...")
    train_plot = []
    validate_plot = []
    # x = tf.placeholder(tf.float32, shape=[None, 6, 117, 120, 2], name='x-input')
    y_ = tf.placeholder(tf.float32, shape=[None, 6, 117, 120, 2], name='y-label')
    mask_p = tf.placeholder(tf.complex64, shape=[None, 6, 117, 120], name='mask')
    kspace_p = tf.placeholder(tf.complex64, shape=[None, 6, 117, 120], name='kspace')
    kspace_full = tf.placeholder(tf.complex64, shape=[None, 6, 117, 120], name='kspace_full')
    regularizer = tf.contrib.layers.l2_regularizer(regularization_rate)
    y, block_1, block_2, block_3, block_k_1 = inference.inference(mask_p, kspace_p, regularizer)
    global_step = tf.Variable(0, trainable=False)
    loss = total_loss(y, y_, block_1, block_2, block_3, block_k_1, kspace_full)
    learning_rate = tf.train.exponential_decay(learning_rate_base,
                                               global_step=global_step,
                                               decay_steps=num_train / batch_size,
                                               decay_rate=learning_rate_decay)
    train_step = tf.train.AdamOptimizer(learning_rate).minimize(loss, global_step=global_step)
    saver = tf.train.Saver()
    with tf.Session() as sess:
        tf.global_variables_initializer().run()
        train_data_per_num = 5000
        # Get the initial value of the loss before any training step.
        count_train = 0
        loss_sum_train = 0.0
        for ys_train in iterate_minibatch(train_data, batch_size, shuffle=True):
            _, kspace_l, mask_l, ys_l, k_full_l = prep_input(ys_train, mask)
            im_start = time.time()
            loss_value_train = sess.run(loss, feed_dict={y_: ys_l, mask_p: mask_l,
                                                         kspace_p: kspace_l, kspace_full: k_full_l})
            im_end = time.time()
            loss_sum_train += loss_value_train
            count_train += 1
            print("{}/{} of train loss (just get loss):\t\t{:.6f} \t using :{:.4f}s"
                  .format(count_train, int(num_train / batch_size),
                          loss_sum_train / count_train, im_end - im_start))
        count_validate = 0
        loss_sum_validate = 0.0
        for ys_validate in iterate_minibatch(validate_data, batch_size, shuffle=True):
            _, kspace_l, mask_l, ys_l, k_full_l = prep_input(ys_validate, mask)
            im_start = time.time()
            loss_value_validate = sess.run(loss, feed_dict={y_: ys_l, mask_p: mask_l,
                                                            kspace_p: kspace_l, kspace_full: k_full_l})
            im_end = time.time()
            loss_sum_validate += loss_value_validate
            count_validate += 1
            print("{}/{} of validation loss:\t\t{:.6f} \t using :{:.4f}s"
                  .format(count_validate, int(num_validate / batch_size),
                          loss_sum_validate / count_validate, im_end - im_start))
        train_plot.append(loss_sum_train / count_train)
        validate_plot.append(loss_sum_validate / count_validate)
        for i in range(Train_steps):
            j = 0
            for train_data_per in iterate_minibatch(train_data, batch_size=train_data_per_num, shuffle=True):
                count_train = 0
                loss_sum_train = 0.0
                for ys in iterate_minibatch(train_data_per, batch_size, shuffle=False):
                    _, kspace_l, mask_l, ys_l, k_full_l = prep_input(ys, mask)
                    im_start = time.time()
                    _, loss_value, step = sess.run([train_step, loss, global_step],
                                                   feed_dict={y_: ys_l, mask_p: mask_l,
                                                              kspace_p: kspace_l, kspace_full: k_full_l})
                    im_end = time.time()
                    loss_sum_train += loss_value
                    print("{}/{}/{}/{} of training loss:\t\t{:.6f} \t using :{:.4f}s"
                          .format(i + 1, j + 1, count_train + 1, int(train_data_per_num / batch_size),
                                  loss_sum_train / (count_train + 1), im_end - im_start))
                    count_train += 1
                # Validate and re-measure the training loss on this chunk.
                count_train_per = 0
                loss_sum_train_per = 0.0
                for ys_train in iterate_minibatch(train_data_per, batch_size, shuffle=True):
                    _, kspace_l, mask_l, ys_l, k_full_l = prep_input(ys_train, mask)
                    im_start = time.time()
                    loss_value_train_per = sess.run(loss, feed_dict={y_: ys_l, mask_p: mask_l,
                                                                     kspace_p: kspace_l, kspace_full: k_full_l})
                    im_end = time.time()
                    loss_sum_train_per += loss_value_train_per
                    count_train_per += 1
                    print("{}/{}/{}/{} of train loss (just get loss):\t\t{:.6f} \t using :{:.4f}s"
                          .format(i + 1, j + 1, count_train_per, int(train_data_per_num / batch_size),
                                  loss_sum_train_per / count_train_per, im_end - im_start))
                count_validate = 0
                loss_sum_validate = 0.0
                for ys_validate in iterate_minibatch(validate_data, batch_size, shuffle=True):
                    _, kspace_l, mask_l, ys_l, k_full_l = prep_input(ys_validate, mask)
                    im_start = time.time()
                    loss_value_validate = sess.run(loss, feed_dict={y_: ys_l, mask_p: mask_l,
                                                                    kspace_p: kspace_l, kspace_full: k_full_l})
                    im_end = time.time()
                    loss_sum_validate += loss_value_validate
                    count_validate += 1
                    print("{}/{}/{}/{} of validation loss:\t\t{:.6f} \t using :{:.4f}s"
                          .format(i + 1, j + 1, count_validate, int(num_validate / batch_size),
                                  loss_sum_validate / count_validate, im_end - im_start))
                train_plot.append(loss_sum_train_per / count_train_per)
                validate_plot.append(loss_sum_validate / count_validate)
                j += 1
            print("After %d train epochs, loss on training batch is %g\n model has been saved in %s"
                  % (i, loss_sum_train / count_train, model_save_path))
            saver.save(sess, os.path.join(model_save_path, model_name), global_step=global_step)
        train_plot_name = 'train_plot.npy'
        np.save(join(model_save_path, train_plot_name), train_plot)
        validate_plot_name = 'validate_plot.npy'
        np.save(join(model_save_path, validate_plot_name), validate_plot)
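
# Training schedule: Train_steps (30) passes over the data; each pass draws
# shuffled chunks of train_data_per_num = 5000 samples, trains on them in
# minibatches, then re-measures train and validation loss per chunk for the
# learning curves saved above. A checkpoint is written at the end of every pass.
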
def main(argv=None):
    # Empirical ivar -> acceleration-rate mapping for this mask generator:
    # ivar = 0.003, acc = 4
    # ivar = 0.008, acc = 6
    # ivar = 0.015, acc = 8
    # ivar = 0.030, acc = 10
    # ivar = 0.070, acc = 12
    mask = cs.cartesian_mask(input_shape, ivar=0.003,
                             centred=False,
                             sample_high_freq=True,
                             sample_centre=True,
                             sample_n=6)
    acc = mask.size / np.sum(mask)
    print('Acceleration Rate: {:.2f}'.format(acc))
    train_data, validate_data = Train_data()
    train(train_data, validate_data, mask)


if __name__ == '__main__':
    tf.app.run()
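
# Note: this script targets TensorFlow 1.x; tf.contrib, tf.placeholder,
# tf.Session and tf.app were removed from the top-level namespace in
# TensorFlow 2.x, so it will not run unmodified on a 2.x install.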
