You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.
 
 

100 lines
4.4 kB

  1. """Learned primal-dual method."""
  2. import os
  3. import h5py
  4. import time
  5. from os.path import join, exists
  6. from skimage import io
  7. import tensorflow as tf
  8. import numpy as np
  9. import scipy.io as sio
  10. from numpy.fft import fft2, ifft2, ifftshift, fftshift
  11. import matplotlib.pyplot as plt
  12. from skimage.measure import compare_ssim as ssim
  13. from skimage.measure import compare_psnr as psnr
  14. from train_v2 import generate_data
  15. from dataset import get_test_data
  16. from model import getMultiCoilImage, getCoilCombineImage
  17. def get_data_sos(label, mask_t, bacth_size_mask=4):
  18. batch, nx, ny, nt, coil = label.shape
  19. nx, ny, nt = mask_t.shape
  20. mask_t = np.transpose(mask_t, (2, 0, 1))
  21. #mask = mask_t[0:bacth_size_mask, ...]
  22. mask = np.tile(mask_t[:, :, :, np.newaxis], (1, 1, 1, coil)) #nt, nx, ny, coil
  23. label = np.squeeze(label)
  24. label = np.transpose(label, (2, 3, 0, 1)) # nt, coil, nx, ny
  25. k_full_shift = fft2(label, axes=(-2, -1)) # batch, coil, nx, ny
  26. #k_full_shift = np.tile(k_full_shift, (bacth_size_mask, 1, 1, 1))
  27. k_full_shift = np.transpose(k_full_shift, (0, 2, 3, 1)) # nt, nx, ny, coil
  28. k_und_shift = k_full_shift * mask
  29. label_sos = np.sum(abs(label**2), axis=1)**(1/2)
  30. #label_sos = np.tile(label_sos, [bacth_size_mask, 1, 1])
  31. mask = mask[:, :, :, 0]
  32. return k_und_shift, label_sos, mask
  33. def evaluate(test_data, mask_t, model_save_path, model_file):
  34. result_dir = os.path.join('results', model_file+'_test_on_uniform_mask' )
  35. if not os.path.exists(result_dir):
  36. os.makedirs(result_dir)
  37. with tf.Graph() .as_default() as g:
  38. y_m = tf.placeholder(tf.complex64, (None, 192, 192, 20), "y_m")
  39. mask = tf.placeholder(tf.complex64, (None, 192, 192), "mask")
  40. x_true = tf.placeholder(tf.float32, (None, 192, 192), "x_true")
  41. x_pred = getCoilCombineImage(y_m, mask, n_iter=8)
  42. residual = x_pred - x_true
  43. #residual = tf.stack([tf.real(residual), tf.imag(residual)], axis=4)
  44. loss = tf.reduce_mean(residual ** 2)
  45. with tf.Session() as sess:
  46. #ckpt = tf.train.get_checkpoint_state(model_save_path)
  47. saver = tf.train.Saver()
  48. #if ckpt and ckpt.model_checkpoint_path:
  49. saver.restore(sess, model_save_path)
  50. count = 0
  51. recon_total = np.zeros((test_data.shape[0], test_data.shape[3], test_data.shape[1], test_data.shape[2]))
  52. for ys in generate_data(test_data, BATCH_SIZE=1, shuffle=False):
  53. und_kspace, label, mask_d = get_data_sos(ys, mask_t)
  54. im_start = time.time()
  55. loss_value, pred = sess.run([loss, x_pred],
  56. feed_dict={y_m: und_kspace,
  57. mask: mask_d,
  58. x_true: label})
  59. recon_total[count, ...] = pred
  60. count += 1
  61. sio.savemat(join(result_dir, 'recon_%d.mat' % count), {'im_recon': pred})
  62. print("The loss of No.{} test data = {}".format(count, loss_value))
  63. #sio.savemat(join(result_dir, 'recon_total.mat'), {'recon': recon_total})
  64. def main(argv=None):
  65. os.environ['CUDA_VISIBLE_DEVICES'] = '0'
  66. data_dir = '/data0/ziwen/data/h5/multi_coil_strategy_v0'
  67. test_data = get_test_data(data_dir)
  68. #mask_t = sio.loadmat(join('mask', 'UIH_TIS_mask_t_192_192_X400_ACS_16_R_3.56.mat'))['mask_t']
  69. #mask_t = sio.loadmat(join('mask', 'random_gauss_mask_t_192_192_16_ACS_16_R_3.64.mat'))['mask_t']
  70. #mask_t = sio.loadmat(join('mask', 'random_gauss_mask_t_192_192_16_ACS_16_R_5.42.mat'))['mask_t']
  71. mask_t = sio.loadmat(join('mask', 'UIH_TIS_mask_t_192_192_X400_ACS_16_R_3.56.mat'))['mask_t']
  72. mask_t = np.fft.fftshift(mask_t, axes=(0, 1))
  73. acc = mask_t.size / np.sum(mask_t)
  74. print('Acceleration Rate:{:.2f}'.format(acc))
  75. project_root = '.'
  76. model_file = "Unsupervised learning via TIS_mask_5t_multi-coil_2149_train_on_random_mask_AMAX"
  77. model_name = "Unsupervised learning via TIS_mask_5t_multi-coil_2149_train_on_random_mask_AMAX.ckpt"
  78. model = join(project_root, 'checkpoints/%s' % model_file, model_name)
  79. evaluate(test_data, mask_t, model_save_path=model, model_file=model_file)
  80. if __name__ == '__main__':
  81. tf.app.run()

简介 (Description)

No Description

Unity3D Asset other

贡献者 / Contributors (1)