You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-'), and can be up to 35 characters long.

143 lines
5.7 kB

  1. import tensorflow as tf
  2. import os
  3. from model_net_v3 import Manifold_Net
  4. from dataset_tfrecord import get_dataset
  5. import argparse
  6. import scipy.io as scio
  7. import mat73
  8. import numpy as np
  9. from datetime import datetime
  10. import time
  11. from tools.tools import video_summary, mse, tempfft
  12. if __name__ == "__main__":
  13. parser = argparse.ArgumentParser()
  14. parser.add_argument('--mode', metavar='str', nargs=1, default=['test'], help='training or test')
  15. parser.add_argument('--batch_size', metavar='int', nargs=1, default=['1'], help='batch size')
  16. parser.add_argument('--niter', metavar='int', nargs=1, default=['5'], help='number of network iterations')
  17. parser.add_argument('--acc', metavar='int', nargs=1, default=['8'], help='accelerate rate')
  18. parser.add_argument('--mask_pattern', metavar='str', nargs=1, default=['cartesian'], help='mask pattern: cartesian, radial, spiral, vsita')
  19. parser.add_argument('--net', metavar='str', nargs=1, default=['Manifold_Net'], help='Manifold_Net')
  20. parser.add_argument('--weight', metavar='str', nargs=1, default=['models/stable/2021-02-28T13-44-00_Manifold_Net_v3_correct_dc_v1_d3c5_acc_8_lr_0.001_N_factor_1.05_rank_17_cartesian/epoch-60/ckpt'], help='modeldir in ./models')
  21. parser.add_argument('--gpu', metavar='int', nargs=1, default=['2'], help='GPU No.')
  22. parser.add_argument('--data', metavar='str', nargs=1, default=['DYNAMIC_V2'], help='dataset name')
  23. parser.add_argument('--learnedSVT', metavar='bool', nargs=1, default=['True'], help='Learned SVT threshold or not')
  24. parser.add_argument('--SVT_favtor', metavar='float', nargs=1, default=['1.05'], help='SVT factor')
  25. args = parser.parse_args()
  26. # GPU setup
  27. os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu[0]
  28. GPUs = tf.config.experimental.list_physical_devices('GPU')
  29. tf.config.experimental.set_memory_growth(GPUs[0], True)
  30. dataset_name = args.data[0].upper()
  31. mode = args.mode[0]
  32. batch_size = int(args.batch_size[0])
  33. niter = int(args.niter[0])
  34. acc = int(args.acc[0])
  35. mask_pattern = args.mask_pattern[0]
  36. net_name = args.net[0]
  37. weight_file = args.weight[0]
  38. learnedSVT = bool(args.learnedSVT[0])
  39. N_factor = float(args.SVT_favtor[0])
  40. print('network: ', net_name)
  41. print('acc: ', acc)
  42. print('load weight file from: ', weight_file)
  43. result_dir = os.path.join('results/stable', weight_file.split('/')[2])
  44. if not os.path.isdir(result_dir):
  45. os.makedirs(result_dir)
  46. logdir = './logs'
  47. TIMESTAMP = "{0:%Y-%m-%dT%H-%M-%S}".format(datetime.now())
  48. summary_writer = tf.summary.create_file_writer(os.path.join(logdir, mode, TIMESTAMP + net_name + str(acc) + '/'))
  49. # prepare undersampling mask
  50. if dataset_name == 'DYNAMIC_V2':
  51. multi_coil = False
  52. mask_size = '18_192_192'
  53. elif dataset_name == 'DYNAMIC_V2_MULTICOIL':
  54. multi_coil = True
  55. mask_size = '18_192_192'
  56. elif dataset_name == 'FLOW':
  57. multi_coil = False
  58. mask_size = '20_180_180'
  59. if acc == 8:
  60. mask = scio.loadmat('/data1/wenqihuang/LplusSNet/mask_newdata/'+mask_pattern + '_' + mask_size + '_acc8.mat')['mask']
  61. elif acc == 10:
  62. mask = scio.loadmat('/data1/wenqihuang/LplusSNet/mask_newdata/cartesian_' + mask_size + '_acs4_acc10.mat')['mask']
  63. elif acc == 12:
  64. mask = scio.loadmat('/data1/wenqihuang/LplusSNet/mask_newdata/'+mask_pattern + '_' + mask_size + '_acc12.mat')['mask']
  65. """
  66. if acc == 8:
  67. mask = mat73.loadmat('/data1/wenqihuang/LplusSNet/mask_newdata/vista_' + mask_size + '_acc_8.mat')['mask']
  68. elif acc == 10:
  69. mask = mat73.loadmat('/data1/wenqihuang/LplusSNet/mask_newdata/vista_' + mask_size + '_acc_10.mat')['mask']
  70. elif acc == 12:
  71. mask = mat73.loadmat('/data1/wenqihuang/LplusSNet/mask_newdata/vista_' + mask_size + '_acc_12.mat')['mask']
  72. elif acc == 16:
  73. mask = mat73.loadmat('/data1/wenqihuang/LplusSNet/mask_newdata/vista_' + mask_size + '_acc_16.mat')['mask']
  74. elif acc == 20:
  75. mask = mat73.loadmat('/data1/wenqihuang/LplusSNet/mask_newdata/vista_' + mask_size + '_acc_20.mat')['mask']
  76. elif acc == 24:
  77. mask = mat73.loadmat('/data1/wenqihuang/LplusSNet/mask_newdata/vista_' + mask_size + '_acc_24.mat')['mask']
  78. """
  79. mask = tf.cast(tf.constant(mask), tf.complex64)
  80. # prepare dataset
  81. dataset = get_dataset(mode, dataset_name, batch_size, shuffle=False)
  82. # initialize network
  83. if net_name == 'Manifold_Net':
  84. net = Manifold_Net(mask, niter, learnedSVT, N_factor)
  85. net.load_weights(weight_file)
  86. # Iterate over epochs.
  87. for i, sample in enumerate(dataset):
  88. # forward
  89. k0 = None
  90. csm = None
  91. #with tf.GradientTape() as tape:
  92. if multi_coil:
  93. k0, label, csm = sample
  94. else:
  95. k0, label = sample
  96. label_abs = tf.abs(label)
  97. k0 = k0 * mask
  98. t0 = time.time()
  99. recon = net(k0, csm)
  100. t1 = time.time()
  101. recon_abs = tf.abs(recon)
  102. loss_total = mse(recon, label)
  103. tf.print(i, 'mse =', loss_total.numpy(), 'time = ', t1-t0)
  104. result_file = os.path.join(result_dir, 'recon_'+str(i+1)+'.mat')
  105. datadict = {'recon': np.squeeze(tf.transpose(recon, [0,2,3,1]).numpy())}
  106. scio.savemat(result_file, datadict)
  107. # record gif
  108. with summary_writer.as_default():
  109. combine_video = tf.concat([label_abs[0:1,:,:,:], recon_abs[0:1,:,:,:]], axis=0).numpy()
  110. combine_video = np.expand_dims(combine_video, -1)
  111. video_summary('convin-'+str(i+1), combine_video, step=1, fps=10)

简介

No Description

Python

贡献者 (1)