You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-') and can be up to 35 characters long.

136 lines
5.4 kB

  1. import tensorflow as tf
  2. import os
  3. from model import LplusS_Net, S_Net, SLR_Net
  4. import argparse
  5. import scipy.io as scio
  6. import mat73
  7. import numpy as np
  8. from datetime import datetime
  9. import time
  10. from tools.tools import video_summary, mse, tempfft
  11. if __name__ == "__main__":
  12. parser = argparse.ArgumentParser()
  13. parser.add_argument('--mode', metavar='str', nargs=1, default=['test'], help='training or test')
  14. parser.add_argument('--batch_size', metavar='int', nargs=1, default=['1'], help='batch size')
  15. parser.add_argument('--niter', metavar='int', nargs=1, default=['10'], help='number of network iterations')
  16. parser.add_argument('--acc', metavar='int', nargs=1, default=['9'], help='accelerate rate')
  17. parser.add_argument('--net', metavar='str', nargs=1, default=['SLRNET'], help='SLR Net or S Net')
  18. parser.add_argument('--weight', metavar='str', nargs=1, default=['models/stable/2020-11-05T19-31-19SLRNET_OCMR8_epoch_50_lr_0.0001_ocmr_fine_tuning/epoch-50/ckpt'], help='modeldir in ./models')
  19. parser.add_argument('--gpu', metavar='int', nargs=1, default=['0'], help='GPU No.')
  20. parser.add_argument('--data', metavar='str', nargs=1, default=['OCMR'], help='dataset name')
  21. parser.add_argument('--learnedSVT', metavar='bool', nargs=1, default=['True'], help='Learned SVT threshold or not')
  22. args = parser.parse_args()
  23. # GPU setup
  24. os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu[0]
  25. GPUs = tf.config.experimental.list_physical_devices('GPU')
  26. tf.config.experimental.set_memory_growth(GPUs[0], True)
  27. dataset_name = args.data[0].upper()
  28. mode = args.mode[0]
  29. batch_size = int(args.batch_size[0])
  30. niter = int(args.niter[0])
  31. acc = int(args.acc[0])
  32. net_name = args.net[0].upper()
  33. weight_file = args.weight[0]
  34. learnedSVT = bool(args.learnedSVT[0])
  35. print('network: ', net_name)
  36. print('acc: ', acc)
  37. print('load weight file from: ', weight_file)
  38. result_dir = os.path.join('results/stable/prospective/huang', weight_file.split('/')[2] + net_name + str(acc))
  39. if not os.path.isdir(result_dir):
  40. os.makedirs(result_dir)
  41. #logdir = './logs'
  42. TIMESTAMP = "{0:%Y-%m-%dT%H-%M-%S}".format(datetime.now())
  43. #summary_writer = tf.summary.create_file_writer(os.path.join(logdir, mode, TIMESTAMP + net_name + str(acc) + '/'))
  44. # prepare undersampling mask
  45. #if dataset_name == 'DYNAMIC_V2':
  46. # multi_coil = False
  47. # mask_size = '18_192_192'
  48. #elif dataset_name == 'DYNAMIC_V2_MULTICOIL':
  49. # multi_coil = True
  50. # mask_size = '18_192_192'
  51. #elif dataset_name == 'FLOW':
  52. # multi_coil = False
  53. # mask_size = '20_180_180'
  54. #if acc == 8:
  55. # mask = scio.loadmat('mask_newdata/cartesian_' + mask_size + '_acs4_acc8.mat')['mask']
  56. #elif acc == 10:
  57. # mask = scio.loadmat('mask_newdata/cartesian_' + mask_size + '_acs4_acc10.mat')['mask']
  58. #elif acc == 12:
  59. # mask = scio.loadmat('mask_newdata/cartesian_' + mask_size + '_acs4_acc12.mat')['mask']
  60. #mask = tf.cast(tf.constant(mask), tf.complex64)
  61. # prepare dataset
  62. #dataset = get_dataset(mode, dataset_name, batch_size, shuffle=False)
  63. for i in range(2,8):
  64. #k0 = scio.loadmat('ku'+str(i)+'.mat')['ku'] # nx,ny,nt,nc
  65. k0 = mat73.loadmat('/data1/wenqihuang/LplusSNet/data/prospective/ku'+str(i)+'.mat')['ku'] # nx,ny,nt,nc
  66. #csm = mat73.loadmat('data/prospective/csm_adaptive.mat')['csm']
  67. #csm = scio.loadmat('csm'+str(i)+'.mat')['csm'] # nx, ny, nc
  68. csm = mat73.loadmat('/data1/wenqihuang/LplusSNet/data/prospective/csm'+str(i)+'.mat')['csm'] # nx, ny, nc
  69. #csm = mat73.loadmat('data/prospective/csm1.mat')['csm']
  70. k0 = k0 * 420
  71. k0 = tf.convert_to_tensor(k0, dtype=tf.complex64)
  72. csm = tf.convert_to_tensor(csm, dtype=tf.complex64)
  73. csm = tf.expand_dims(csm, 2)
  74. k0 = tf.expand_dims(k0, 0) #batch
  75. csm = tf.expand_dims(csm, 0) #batch
  76. k0 = tf.transpose(k0, [0,4,3,1,2]) # nb, nx, ny, nt, nc -> nb, nc, nt, nx, ny
  77. csm = tf.transpose(csm, [0,4,3,1,2])
  78. #k0 = k0[:,:,0:18,:,:]
  79. #csm = csm[:,:,0:18,:,:]
  80. mask = tf.cast(tf.abs(k0) > 0, tf.complex64)
  81. # initialize network
  82. net = SLR_Net(mask, niter, learnedSVT)
  83. net.load_weights(weight_file)
  84. # Iterate over epochs.
  85. # forward
  86. #with tf.GradientTape() as tape:
  87. t0 = time.time()
  88. recon, X_SYM = net(k0, csm)
  89. t1 = time.time()
  90. recon_abs = tf.abs(recon)
  91. #loss_total = mse(LSrecon, LplusS_label)
  92. #tf.print(i, 'mse =', loss_total.numpy(), 'time = ', t1-t0)
  93. result_file = os.path.join(result_dir, 'recon_'+str(i)+'.mat')
  94. datadict = {
  95. 'recon': np.squeeze(tf.transpose(recon, [0,2,3,1]).numpy())
  96. }
  97. scio.savemat(result_file, datadict)
  98. # record gif
  99. #with summary_writer.as_default():
  100. # if net_name[0:4] == 'SNET':
  101. # combine_video = tf.concat([LplusS_label_abs[0:1,:,:,:], recon_abs[0:1,:,:,:]], axis=0).numpy()
  102. # else:
  103. # combine_video = tf.concat([LplusS_label_abs[0:1,:,:,:], recon_abs[0:1,:,:,:], L_recon_abs[0:1,:,:,:], S_recon_abs[0:1,:,:,:]], axis=0).numpy()
  104. # combine_video = np.expand_dims(combine_video, -1)
  105. # video_summary('convin-'+str(i+1), combine_video, step=1, fps=10)

简介

No Description

Python

贡献者 (1)