You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

331 lines
16 kB

  1. import time
  2. import tensorflow as tf
  3. import numpy as np
  4. import compressed_sensing as cs
  5. import matplotlib.pyplot as plt
  6. from skimage import exposure
  7. import mymath
  8. import os
  9. from os.path import join
  10. import scipy.io as scio
  11. from utils.metric import complex_psnr
  12. from utils.metric import mse
  13. from skimage.measure import compare_ssim
  14. import inference
  15. import train
  16. from datasets import Test_data
  17. batch_size = 1
  18. input_shape = [batch_size, 6, 117, 120]
  19. model_name = "model.ckpt"
  20. def real2complex(x):
  21. '''
  22. Converts from array of the form ([n, ]nt, nx, ny 2) to ([n, ] nt, nx, ny)
  23. '''
  24. x = np.asarray(x)
  25. if x.shape[0] == 2 and x.shape[1] != 2: # Hacky check
  26. return x[0] + x[1] * 1j
  27. elif x.shape[4] == 2:
  28. y = x[:, :, :, :, 0] + x[:, :, :, :, 1] * 1j
  29. return y
  30. else:
  31. raise ValueError('Invalid dimension')
  32. def performance(xs, y, ys):
  33. base_mse = mse(ys, xs)
  34. test_mse = mse(ys, y)
  35. base_psnr = complex_psnr(ys, xs, peak='max')
  36. test_psnr = complex_psnr(ys, y, peak='max')
  37. batch, nt, nx, ny = y.shape
  38. base_ssim = 0
  39. test_ssim = 0
  40. for i in range(nt):
  41. base_ssim += compare_ssim(np.abs(ys[0][i]).astype('float64'),
  42. np.abs(xs[0][i]).astype('float64'))
  43. test_ssim += compare_ssim(np.abs(ys[0][i]).astype('float64'),
  44. np.abs(y[0][i]).astype('float64'))
  45. base_ssim /= nt
  46. test_ssim /= nt
  47. return base_mse, test_mse, base_psnr, test_psnr, base_ssim, test_ssim
  48. def evaluate(test_data, mask, model_save_path, model_file):
  49. with tf.Graph() .as_default() as g:
  50. #x = tf.placeholder(tf.float32, shape=[None, 6, 117, 120, 2], name='x-input')
  51. y_ = tf.placeholder(tf.float32, shape=[None, 6, 117, 120, 2], name='y-label')
  52. mask_p = tf.placeholder(tf.complex64, shape=[None, 6, 117, 120], name='mask')
  53. kspace_p = tf.placeholder(tf.complex64, shape=[None, 6, 117, 120], name='kspace')
  54. kspace_full = tf.placeholder(tf.complex64, shape=[None, 6, 117, 120], name='kspace_full')
  55. y, block_k_1, block_k_2, block_k_3, block_k_4 = inference.inference(mask_p, kspace_p, None)
  56. loss = tf.reduce_mean(tf.reduce_mean(tf.squared_difference(y, y_), [1, 2, 3, 4]))
  57. with tf.Session() as sess:
  58. ckpt = tf.train.get_checkpoint_state(model_save_path)
  59. saver = tf.train.Saver()
  60. test_case = 'show image'
  61. if ckpt and ckpt.model_checkpoint_path:
  62. saver.restore(sess, ckpt.model_checkpoint_path)
  63. if __name__ == '__main__':
  64. if test_case == 'check_loss':
  65. count = 0
  66. for ys in train.iterate_minibatch(test_data, batch_size, shuffle=True):
  67. xs_l, kspace_l, mask_l, ys_l, k_full_l = train.prep_input(test_data, mask)
  68. loss_value, y_pred = sess.run([loss, y],
  69. feed_dict={y_: ys_l, mask_p: mask_l, kspace_p: kspace_l, kspace_full: k_full_l})
  70. print("The loss of No.{} test data = {}".format(count + 1, loss_value))
  71. y_c = real2complex(y_pred)
  72. xs_c = real2complex(xs_l)
  73. base_mse, test_mse, base_psnr, \
  74. test_psnr, base_ssim, test_ssim = performance(xs_c, y_c, ys)
  75. print("test loss:\t\t{:.6f}".format(loss_value))
  76. print("test psnr:\t\t{:.6f}".format(test_psnr))
  77. print("base psnr:\t\t{:.6f}".format(base_psnr))
  78. print("base mse:\t\t{:.6f}".format(base_mse))
  79. print("test mse:\t\t{:.6f}".format(test_mse))
  80. print("base ssim:\t\t{:.6f}".format(base_ssim))
  81. print("test ssim:\t\t{:.6f}".format(test_ssim))
  82. count += 1
  83. elif test_case == 'show image':
  84. project_root = '.'
  85. figure_save_path = join(project_root, 'result/images/%s' % model_file)
  86. if not os.path.isdir(figure_save_path):
  87. os.makedirs(figure_save_path)
  88. mat_save_path = join(project_root, 'result/mat/%s' % model_file)
  89. if not os.path.isdir(mat_save_path):
  90. os.makedirs(mat_save_path)
  91. quantization_save_path = join(project_root, 'result/quantization/%s' % model_file)
  92. if not os.path.isdir(quantization_save_path):
  93. os.makedirs(quantization_save_path)
  94. Test_MSE = []
  95. Test_PSNR = []
  96. Test_SSIM = []
  97. Base_MSE = []
  98. Base_PSNR = []
  99. Base_SSIM = []
  100. for order in range(0, 100):
  101. ys = test_data[order]
  102. ys = ys[np.newaxis, :]
  103. xs_l, kspace_l, mask_l, ys_l, k_full_l = train.prep_input(ys, mask)
  104. time_start = time.time()
  105. loss_value, y_pred = sess.run([loss, y],
  106. feed_dict={y_: ys_l, mask_p: mask_l,
  107. kspace_p: kspace_l, kspace_full: k_full_l})
  108. time_end = time.time()
  109. y_pred_new = real2complex(y_pred)
  110. xs = real2complex(xs_l)
  111. if order == 0:
  112. order_x = 100
  113. elif order == 1:
  114. order_x = 60
  115. elif order == 2:
  116. order_x = 85
  117. elif order == 6:
  118. order_x = 40
  119. else:
  120. order_x = 55
  121. # order_x = 55 # (order, order_x): (0, 100), (1, 60), (6, 40), (7, 55)
  122. ys_t = ys[:, :, order_x, :]
  123. y_pred_t = y_pred_new[:, :, order_x, :]
  124. xs_t = xs[:, :, order_x, :]
  125. xs_t_error = ys_t - xs_t
  126. y_pred_error = ys_t - y_pred_t
  127. base_mse, test_mse, base_psnr, \
  128. test_psnr, base_ssim, test_ssim = performance(xs, y_pred_new, ys)
  129. print("test time:\t\t{:.6f}".format(time_end - time_start))
  130. print("test loss:\t\t{:.6f}".format(loss_value))
  131. print("test psnr:\t\t{:.6f}".format(test_psnr))
  132. print("base psnr:\t\t{:.6f}".format(base_psnr))
  133. print("base mse:\t\t{:.6f}".format(base_mse))
  134. print("test mse:\t\t{:.6f}".format(test_mse))
  135. print("base ssim:\t\t{:.6f}".format(base_ssim))
  136. print("test ssim:\t\t{:.6f}".format(test_ssim))
  137. base_mse = ("%.6f" % base_mse)
  138. test_mse = ("%.6f" % test_mse)
  139. Test_MSE.append(test_mse)
  140. Test_PSNR.append(test_psnr)
  141. Test_SSIM.append(test_ssim)
  142. Base_MSE.append(base_mse)
  143. Base_PSNR.append(base_psnr)
  144. Base_SSIM.append(base_ssim)
  145. mask_shift = mymath.fftshift(mask, axes=(-1, -2))
  146. gamma = 1
  147. plt.figure(1)
  148. plt.subplot(221)
  149. plt.imshow(exposure.adjust_gamma(np.abs(ys[0][0]), gamma), plt.cm.gray)
  150. plt.xticks([])
  151. plt.yticks([])
  152. plt.title('ground truth')
  153. plt.subplot(222)
  154. plt.imshow(exposure.adjust_gamma(abs(mask_shift[0][0]), gamma), plt.cm.gray)
  155. plt.xticks([])
  156. plt.yticks([])
  157. plt.title('mask')
  158. plt.subplot(223)
  159. plt.imshow(exposure.adjust_gamma(abs(xs[0][0]), gamma), plt.cm.gray)
  160. plt.xticks([])
  161. plt.yticks([])
  162. plt.title("undersampling")
  163. plt.subplot(224)
  164. plt.imshow(exposure.adjust_gamma(abs(y_pred_new[0][0]), gamma), plt.cm.gray)
  165. plt.xticks([])
  166. plt.yticks([])
  167. plt.title("reconstruction")
  168. plt.savefig(join(figure_save_path, 'test%s.tif' % order), dpi=300)
  169. plt.figure(2)
  170. plt.imshow(exposure.adjust_gamma(np.abs(ys[0][0]), gamma), plt.cm.gray)
  171. plt.xticks([])
  172. plt.yticks([])
  173. plt.title('ground truth')
  174. scio.savemat(join(mat_save_path, 'gr%s' % order), {'gr': abs(ys[0][0])})
  175. plt.savefig(join(figure_save_path, 'gr%s.tif' % order), dpi=300)
  176. plt.figure(3)
  177. plt.imshow(exposure.adjust_gamma(abs(xs[0][0]), gamma), plt.cm.gray)
  178. plt.xticks([])
  179. plt.yticks([])
  180. plt.title("undersampling: " + base_mse + ' ' + str(round(base_psnr, 5)) + ' ' + str(
  181. round(base_ssim, 4)))
  182. scio.savemat(join(mat_save_path, 'under%s' % order), {'under': abs(xs[0][0])})
  183. plt.savefig(join(figure_save_path, 'under%s.tif' % order), dpi=300)
  184. plt.figure(4)
  185. plt.imshow(exposure.adjust_gamma(abs(y_pred_new[0][0]), gamma), plt.cm.gray)
  186. plt.xticks([])
  187. plt.yticks([])
  188. plt.title("reconstruction: " + test_mse + ' ' + str(round(test_psnr, 5)) + ' ' + str(
  189. round(test_ssim, 4)))
  190. scio.savemat(join(mat_save_path, 'recon%s' % order), {'recon': abs(y_pred_new[0][0])})
  191. plt.savefig(join(figure_save_path, 'recon%s.tif' % order), dpi=300)
  192. plt.figure(5)
  193. plt.imshow(exposure.adjust_gamma(abs(abs(ys[0][0]) - abs(y_pred_new[0][0])), gamma), vmin=0,
  194. vmax=0.07)
  195. plt.xticks([])
  196. plt.yticks([])
  197. plt.title("error: " + test_mse + ' ' + str(round(test_psnr, 5)) + ' ' + str(
  198. round(test_ssim, 4)))
  199. scio.savemat(join(mat_save_path, 'error%s' % order),
  200. {'error': abs(abs(ys[0][0]) - abs(y_pred_new[0][0]))})
  201. plt.savefig(join(figure_save_path, 'error%s.tif' % order), dpi=300)
  202. plt.figure(6)
  203. plt.subplot(511)
  204. plt.imshow(np.abs(ys_t[0]), plt.cm.gray)
  205. plt.xticks([])
  206. plt.yticks([])
  207. plt.title("gnd_t_y")
  208. plt.subplot(512)
  209. plt.imshow(np.abs(xs_t[0]), plt.cm.gray)
  210. plt.xticks([])
  211. plt.yticks([])
  212. plt.title("under_t_y")
  213. plt.subplot(513)
  214. plt.imshow(np.abs(xs_t_error[0]))
  215. plt.xticks([])
  216. plt.yticks([])
  217. plt.title("under_t_y_error")
  218. plt.subplot(514)
  219. plt.imshow(np.abs(y_pred_t[0]), plt.cm.gray)
  220. plt.xticks([])
  221. plt.yticks([])
  222. plt.title("recon_t_y")
  223. plt.subplot(515)
  224. plt.imshow(np.abs(y_pred_error[0]))
  225. plt.xticks([])
  226. plt.yticks([])
  227. plt.title("recon_t_y_error")
  228. plt.savefig(join(figure_save_path, 't_y%s.tif' % order))
  229. train_plot = np.load(join(project_root, 'models/%s' % model_file, 'train_plot.npy'))
  230. validate_plot = np.load(join(project_root, 'models/%s' % model_file, 'validate_plot.npy'))
  231. [num_train_plot, ] = train_plot.shape
  232. [num_validate_plot, ] = validate_plot.shape
  233. x1 = np.arange(1, num_train_plot + 1)
  234. x2 = np.arange(1, num_validate_plot + 1)
  235. plt.figure(7)
  236. l1, = plt.plot(x1, train_plot)
  237. l2, = plt.plot(x2, validate_plot)
  238. plt.legend(handles=[l1, l2, ], labels=['train loss', 'validation loss'], loc=1)
  239. plt.xlabel('epoch')
  240. plt.ylabel('loss')
  241. plt.title('loss')
  242. if not os.path.exists(join(figure_save_path, 'loss.tif')):
  243. plt.savefig(join(figure_save_path, 'loss.tif'), dpi=300)
  244. #plt.show()
  245. scio.savemat(join(quantization_save_path, 'Test_MSE'), {'test_mse': Test_MSE})
  246. scio.savemat(join(quantization_save_path, 'Test_PSNR'),
  247. {'test_psnr': Test_PSNR})
  248. scio.savemat(join(quantization_save_path, 'Test_SSIM'),
  249. {'test_ssim': Test_SSIM})
  250. scio.savemat(join(quantization_save_path, 'Base_MSE'), {'base_mse': Base_MSE})
  251. scio.savemat(join(quantization_save_path, 'Base_PSNR'),
  252. {'base_psnr': Base_PSNR})
  253. scio.savemat(join(quantization_save_path, 'Base_SSIM'),
  254. {'base_ssim': Base_SSIM})
  255. #elif test_case == "Save image":
  256. else:
  257. print("No checkpoint file found")
  258. def main(argv=None):
  259. test_data = Test_data()
  260. mask = np.load("test_mask.npy")
  261. acc = mask.size / np.sum(mask)
  262. print('Acceleration Rate:{:.2f}'.format(acc))
  263. project_root = '.'
  264. model_file = "model_K5C3_D5C2_TV_iso_e-8_ComplexConv_kLoss1-3_e-4_e-4_e-4_block1-2_e+3"
  265. model = join(project_root, 'models/%s' % model_file)
  266. model_name = "model.ckpt"
  267. evaluate(test_data, mask, model_save_path=model, model_file=model_file)
  268. if __name__ == '__main__':
  269. tf.app.run()

简介

No Description

Python

贡献者 (1)