Browse Source

first commit

master
keziwen 2 years ago
parent
commit
680838fab2
13 changed files with 508 additions and 0 deletions
  1. BIN
      checkpoints/Unsupervised learning via TIS_mask_5t_multi-coil_2149_AMAX/Unsupervised learning via TIS_mask_5t_multi-coil_2149_AMAX.ckpt.data-00000-of-00001
  2. BIN
      checkpoints/Unsupervised learning via TIS_mask_5t_multi-coil_2149_AMAX/Unsupervised learning via TIS_mask_5t_multi-coil_2149_AMAX.ckpt.index
  3. BIN
      checkpoints/Unsupervised learning via TIS_mask_5t_multi-coil_2149_AMAX/Unsupervised learning via TIS_mask_5t_multi-coil_2149_AMAX.ckpt.meta
  4. +2
    -0
      checkpoints/Unsupervised learning via TIS_mask_5t_multi-coil_2149_AMAX/checkpoint
  5. +69
    -0
      dataset.py
  6. +100
    -0
      evaluate_v2.py
  7. BIN
      mask/UIH_TIS_mask_t_192_192_X400_ACS_16_R_3.56.mat
  8. BIN
      mask/UIH_TIS_mask_t_192_192_X800_ACS_16_R_5.42.mat
  9. BIN
      mask/UIH_TIS_random_by_hand_mask_t_192_192_X400_ACS_16_R_3.56.mat
  10. BIN
      mask/random_gauss_mask_t_192_192_16_ACS_16_R_3.64.mat
  11. BIN
      mask/random_gauss_mask_t_192_192_16_ACS_16_R_5.42.mat
  12. +185
    -0
      model.py
  13. +152
    -0
      train_v2.py

BIN
checkpoints/Unsupervised learning via TIS_mask_5t_multi-coil_2149_AMAX/Unsupervised learning via TIS_mask_5t_multi-coil_2149_AMAX.ckpt.data-00000-of-00001 View File


BIN
checkpoints/Unsupervised learning via TIS_mask_5t_multi-coil_2149_AMAX/Unsupervised learning via TIS_mask_5t_multi-coil_2149_AMAX.ckpt.index View File


BIN
checkpoints/Unsupervised learning via TIS_mask_5t_multi-coil_2149_AMAX/Unsupervised learning via TIS_mask_5t_multi-coil_2149_AMAX.ckpt.meta View File


+ 2
- 0
checkpoints/Unsupervised learning via TIS_mask_5t_multi-coil_2149_AMAX/checkpoint View File

@@ -0,0 +1,2 @@
model_checkpoint_path: "Unsupervised learning via TIS_mask_5t_multi-coil_2149_AMAX.ckpt"
all_model_checkpoint_paths: "Unsupervised learning via TIS_mask_5t_multi-coil_2149_AMAX.ckpt"

+ 69
- 0
dataset.py View File

@@ -0,0 +1,69 @@
import h5py
import numpy as np
import os


def get_train_data(data_dir):
    """Load complex training labels and split them into train/validation sets.

    Args:
        data_dir: directory containing ``train_real.h5`` and ``train_imag.h5``.

    Returns:
        (train_label, validate_label): complex arrays of 1900 and 249 samples,
        transposed with (0, 3, 2, 1) — NOTE(review): assuming the files store
        (num, coil, ny, nx), this yields (num, nx, ny, coil); confirm against
        the data dumps.
    """
    # Open read-only: these loaders only consume the datasets, and h5py's
    # historical default mode is not read-only.
    with h5py.File(os.path.join(data_dir, 'train_real.h5'), 'r') as f:
        label_real = f['train_real'][:]
    with h5py.File(os.path.join(data_dir, 'train_imag.h5'), 'r') as f:
        label_imag = f['train_imag'][:]
    label = label_real + 1j * label_imag
    label = np.transpose(label, (0, 3, 2, 1))

    # Fixed split matching the constants used in train_v2.py.
    num_train = 1900
    num_validate = 249
    train_label = label[0:num_train]
    validate_label = label[num_train:num_train + num_validate]
    return train_label, validate_label

def get_test_data(data_dir):
    """Load the complex test set.

    Args:
        data_dir: directory containing ``test_real_45.h5`` / ``test_imag_45.h5``.

    Returns:
        Complex array transposed with (0, 4, 3, 2, 1) from the stored
        (num, nc, nt, ny, nx) layout, i.e. (num, nx, ny, nt, nc).
    """
    # Open read-only; these loaders never write to the files.
    with h5py.File(os.path.join(data_dir, 'test_real_45.h5'), 'r') as f:
        test_real = f['test_real'][:]
    with h5py.File(os.path.join(data_dir, 'test_imag_45.h5'), 'r') as f:
        test_imag = f['test_imag'][:]
    test_label = test_real + 1j * test_imag
    return np.transpose(test_label, (0, 4, 3, 2, 1))

def get_fine_tuning_data(data_dir):
    """Load the complex fine-tuning set.

    Args:
        data_dir: directory containing ``fine_tuning_real.h5`` /
            ``fine_tuning_imag.h5``.

    Returns:
        Complex array transposed with (0, 4, 3, 2, 1) from the stored
        (num, nc, nt, ny, nx) layout, i.e. (num, nx, ny, nt, nc).
    """
    # Open read-only; these loaders never write to the files.
    with h5py.File(os.path.join(data_dir, 'fine_tuning_real.h5'), 'r') as f:
        fine_tuning_real = f['fine_tuning_real'][:]
    with h5py.File(os.path.join(data_dir, 'fine_tuning_imag.h5'), 'r') as f:
        fine_tuning_imag = f['fine_tuning_imag'][:]
    fine_tuning_label = fine_tuning_real + 1j * fine_tuning_imag
    return np.transpose(fine_tuning_label, (0, 4, 3, 2, 1))

def get_train_data_UIH(data_dir):
    """Load UIH and GE k-space training data and concatenate them.

    Args:
        data_dir: directory containing ``UIH_Data.h5`` and ``trnData_CUBE.hdf5``.

    Returns:
        (kspace, mask_t): concatenated complex k-space after a (0, 2, 3, 1)
        transpose of each source, and the ``Mask_1D_x4`` mask stored in the
        UIH file.
    """
    # Open read-only; these loaders never write to the files.
    with h5py.File(os.path.join(data_dir, 'UIH_Data.h5'), 'r') as f:
        UIH_real = f['trnData_real'][:]
        UIH_img = f['trnData_img'][:]
        mask_t = f['Mask_1D_x4'][:]
    UIH_data = np.transpose(UIH_real + 1j * UIH_img, (0, 2, 3, 1))

    with h5py.File(os.path.join(data_dir, 'trnData_CUBE.hdf5'), 'r') as f:
        GE_real = f['trnData_real'][:]
        GE_img = f['trnData_img'][:]
    GE_data = np.transpose(GE_real + 1j * GE_img, (0, 2, 3, 1))

    kspace = np.concatenate((UIH_data, GE_data))
    return kspace, mask_t

+ 100
- 0
evaluate_v2.py View File

@@ -0,0 +1,100 @@
"""Learned primal-dual method."""
import os
import h5py
import time
from os.path import join, exists
from skimage import io
import tensorflow as tf
import numpy as np
import scipy.io as sio
from numpy.fft import fft2, ifft2, ifftshift, fftshift
import matplotlib.pyplot as plt
from skimage.measure import compare_ssim as ssim
from skimage.measure import compare_psnr as psnr

from train_v2 import generate_data
from dataset import get_test_data
from model import getMultiCoilImage, getCoilCombineImage

def get_data_sos(label, mask_t, bacth_size_mask=4):
    """Turn one dynamic multi-coil sample into network inputs.

    Args:
        label: complex coil images, (batch, nx, ny, nt, coil). The batch axis
            is squeezed away, so this assumes batch == 1 —
            NOTE(review): confirm against the caller.
        mask_t: per-frame sampling masks, (nx, ny, nt).
        bacth_size_mask: [sic] unused here; kept for signature parity with the
            training-time get_data_sos.

    Returns:
        (k_und_shift, label_sos, mask): undersampled k-space
        (nt, nx, ny, coil), sum-of-squares magnitude label (nt, nx, ny),
        and the 2-D masks (nt, nx, ny).
    """
    coil = label.shape[-1]

    # One 2-D mask per frame, replicated over the coil axis.
    frame_masks = np.transpose(mask_t, (2, 0, 1))
    coil_masks = np.tile(frame_masks[:, :, :, np.newaxis], (1, 1, 1, coil))

    # Frames in (nt, coil, nx, ny) order, then coil-wise FFT.
    frames = np.transpose(np.squeeze(label), (2, 3, 0, 1))
    kspace = np.transpose(fft2(frames, axes=(-2, -1)), (0, 2, 3, 1))
    k_und_shift = kspace * coil_masks

    # Sum-of-squares coil combination (|z**2| == |z|**2).
    label_sos = np.sum(abs(frames ** 2), axis=1) ** (1 / 2)

    return k_und_shift, label_sos, coil_masks[:, :, :, 0]


def evaluate(test_data, mask_t, model_save_path, model_file):
    """Restore a trained model and reconstruct every test sample.

    Args:
        test_data: complex test array, iterated one sample at a time.
        mask_t: sampling masks, (nx, ny, nt).
        model_save_path: checkpoint path passed to ``Saver.restore``.
        model_file: checkpoint directory name; only used to name the result
            folder.

    Side effects: writes one ``recon_%d.mat`` per sample under ``results/``
    and prints the per-sample loss.
    """
    result_dir = os.path.join('results', model_file + '_test_on_uniform_mask')
    if not os.path.exists(result_dir):
        os.makedirs(result_dir)

    # Build a fresh graph so repeated calls don't accumulate nodes.
    with tf.Graph().as_default() as g:
        # NOTE(review): sizes are hard-coded to 192x192 with 20 coils —
        # confirm they match test_data.
        y_m = tf.placeholder(tf.complex64, (None, 192, 192, 20), "y_m")
        mask = tf.placeholder(tf.complex64, (None, 192, 192), "mask")
        x_true = tf.placeholder(tf.float32, (None, 192, 192), "x_true")

        x_pred = getCoilCombineImage(y_m, mask, n_iter=8)

        # MSE between the predicted magnitude image and the SOS label.
        residual = x_pred - x_true
        #residual = tf.stack([tf.real(residual), tf.imag(residual)], axis=4)
        loss = tf.reduce_mean(residual ** 2)

        with tf.Session() as sess:

            #ckpt = tf.train.get_checkpoint_state(model_save_path)
            saver = tf.train.Saver()

            #if ckpt and ckpt.model_checkpoint_path:
            saver.restore(sess, model_save_path)

            count = 0
            # Buffer holding every reconstruction: (num, nt, nx, ny).
            recon_total = np.zeros((test_data.shape[0], test_data.shape[3], test_data.shape[1], test_data.shape[2]))

            for ys in generate_data(test_data, BATCH_SIZE=1, shuffle=False):
                und_kspace, label, mask_d = get_data_sos(ys, mask_t)
                im_start = time.time()
                loss_value, pred = sess.run([loss, x_pred],
                                            feed_dict={y_m: und_kspace,
                                                       mask: mask_d,
                                                       x_true: label})
                recon_total[count, ...] = pred
                count += 1
                sio.savemat(join(result_dir, 'recon_%d.mat' % count), {'im_recon': pred})

                print("The loss of No.{} test data = {}".format(count, loss_value))
            #sio.savemat(join(result_dir, 'recon_total.mat'), {'recon': recon_total})


def main(argv=None):
    """Load the test set and a sampling mask, then run evaluation."""
    os.environ['CUDA_VISIBLE_DEVICES'] = '0'

    data_dir = '/data0/ziwen/data/h5/multi_coil_strategy_v0'
    test_data = get_test_data(data_dir)

    # Load the (nx, ny, nt) mask and fftshift it to match the unshifted
    # fft2 layout used in get_data_sos.
    mask_path = join('mask', 'UIH_TIS_mask_t_192_192_X400_ACS_16_R_3.56.mat')
    mask_t = np.fft.fftshift(sio.loadmat(mask_path)['mask_t'], axes=(0, 1))
    acc = mask_t.size / np.sum(mask_t)
    print('Acceleration Rate:{:.2f}'.format(acc))

    model_file = "Unsupervised learning via TIS_mask_5t_multi-coil_2149_train_on_random_mask_AMAX"
    model_name = "Unsupervised learning via TIS_mask_5t_multi-coil_2149_train_on_random_mask_AMAX.ckpt"
    checkpoint = join('.', 'checkpoints/%s' % model_file, model_name)
    evaluate(test_data, mask_t, model_save_path=checkpoint, model_file=model_file)


if __name__ == '__main__':
    # TF1-style entry point: tf.app.run() parses flags, then calls main(argv).
    tf.app.run()

BIN
mask/UIH_TIS_mask_t_192_192_X400_ACS_16_R_3.56.mat View File


BIN
mask/UIH_TIS_mask_t_192_192_X800_ACS_16_R_5.42.mat View File


BIN
mask/UIH_TIS_random_by_hand_mask_t_192_192_X400_ACS_16_R_3.56.mat View File


BIN
mask/random_gauss_mask_t_192_192_16_ACS_16_R_3.64.mat View File


BIN
mask/random_gauss_mask_t_192_192_16_ACS_16_R_5.42.mat View File


+ 185
- 0
model.py View File

@@ -0,0 +1,185 @@
import tensorflow as tf

def apply_conv(x, n_out, name):
    """3x3 2-D convolution with bias (SAME padding, stride 1, no activation)."""
    in_channels = x.get_shape()[-1].value

    with tf.name_scope(name) as scope:
        # Xavier-initialized kernel, shared via get_variable under the scope.
        weights = tf.compat.v1.get_variable(scope + "w",
                                            shape=[3, 3, in_channels, n_out],
                                            dtype=tf.float32,
                                            initializer=tf.contrib.layers.xavier_initializer())
        out = tf.nn.conv2d(x, weights, strides=[1, 1, 1, 1], padding='SAME')
        bias = tf.Variable(tf.constant(0.0, dtype=tf.float32, shape=[n_out]),
                           trainable=True, name='b')
        return tf.nn.bias_add(out, bias)


def apply_conv_3D(x, n_out, name):
    """3x3x3 3-D convolution with bias (SAME padding, stride 1, no activation).

    Args:
        x: float32 tensor, (batch, d, h, w, channels).
        n_out: number of output channels.
        name: name scope; also prefixes the kernel variable name.

    Returns:
        The biased convolution output.
    """
    n_in = x.get_shape()[-1].value

    with tf.name_scope(name) as scope:
        # tf.compat.v1.get_variable for consistency with apply_conv; the bare
        # tf.get_variable alias is TF1-only.
        kernel = tf.compat.v1.get_variable(scope + "w",
                                           shape=[3, 3, 3, n_in, n_out],
                                           dtype=tf.float32,
                                           initializer=tf.contrib.layers.xavier_initializer())
        conv = tf.nn.conv3d(x, kernel, strides=[1, 1, 1, 1, 1], padding='SAME')
        bias_init_var = tf.constant(0.0, dtype=tf.float32, shape=[n_out])
        biases = tf.Variable(bias_init_var, trainable=True, name='b')
        z = tf.nn.bias_add(conv, biases)
    return z

def conv_op(input_op, name, kh, kw, n_out, dh, dw, ifactivate):
    """General 2-D convolution layer with optional ReLU.

    Args:
        input_op: float32 tensor, (batch, h, w, channels).
        name: name scope; also prefixes the kernel variable name.
        kh, kw: kernel height and width.
        n_out: number of output channels.
        dh, dw: strides along height and width.
        ifactivate: if True, apply ReLU to the biased output.

    Returns:
        The (optionally activated) convolution output.
    """
    n_in = input_op.get_shape()[-1].value
    with tf.name_scope(name) as scope:
        # tf.compat.v1.get_variable for consistency with apply_conv; the bare
        # tf.get_variable alias is TF1-only.
        kernel = tf.compat.v1.get_variable(scope + 'w', shape=[kh, kw, n_in, n_out], dtype=tf.float32,
                                           initializer=tf.contrib.layers.xavier_initializer())

        conv = tf.nn.conv2d(input_op, kernel, strides=[1, dh, dw, 1], padding='SAME')
        bias_init_var = tf.constant(0.0, dtype=tf.float32, shape=[n_out])
        biases = tf.Variable(bias_init_var, trainable=True, name='b')
        z = tf.nn.bias_add(conv, biases)
        # Idiomatic truthiness test instead of `is True`.
        if ifactivate:
            activation = tf.nn.relu(z, name=scope)
        else:
            activation = z
    return activation

def real2complex(input_op, inv=False):
    """Convert between 2-channel real and complex tensor representations.

    Args:
        input_op: with inv=False, a real tensor whose channels 0/1 hold the
            real/imaginary parts; with inv=True, a complex tensor.
        inv: direction flag (False: real channels -> complex;
            True: complex -> real channels stacked on axis 3).

    Returns:
        The converted tensor.
    """
    if not inv:
        return tf.complex(input_op[:, :, :, 0], input_op[:, :, :, 1])
    # tf.math.real/imag for consistency with the rest of the file
    # (bare tf.real/tf.imag are deprecated TF1 aliases).
    input_real = tf.cast(tf.math.real(input_op), dtype=tf.float32)
    input_imag = tf.cast(tf.math.imag(input_op), dtype=tf.float32)
    return tf.stack([input_real, input_imag], axis=3)



def getADMM_2D(y_m, mask, n_iter, n_coil):
    """Unrolled ADMM-style reconstruction network for one coil's 2-D k-space.

    Args:
        y_m: complex undersampled k-space, (batch, nx, ny).
        mask: complex sampling mask multiplied against the k-space.
        n_iter: number of unrolled iterations.
        n_coil: coil index; only used to namespace the per-coil recon scopes.

    Returns:
        Complex image estimate, (batch, nx, ny).
    """
    # Work in a 2-channel real representation (real, imag stacked on axis 3).
    kdata = tf.stack([tf.math.real(y_m), tf.math.imag(y_m)], axis=3)
    beta = tf.zeros_like(kdata)  # multiplier-like variable
    z = tf.zeros_like(kdata)     # auxiliary (denoised) variable
    x = tf.zeros_like(kdata)     # image estimate

    for iter in range(n_iter):
        # x-update: learned correction driven by the forward model.
        with tf.compat.v1.variable_scope('recon_layer_{}_{}'.format(n_coil, iter)):
            # y_cplx = tf.complex(y_m[:, :, :, 0], y_m[:, :, :, 1])
            # evalop_cplx = tf.ifft2d(y_cplx)
            # evalop = tf.stack([tf.real(evalop_cplx), tf.imag(evalop_cplx)], axis=3)
            # update = tf.concat([evalop, z-beta], axis=-1)

            # A x = mask * F x, then a small CNN on [A x, measured k-space].
            x_cplx = tf.complex(x[..., 0], x[..., 1])
            Ax = tf.signal.fft2d(x_cplx) * mask
            evalop_k = tf.stack([tf.math.real(Ax), tf.math.imag(Ax)], axis=3)
            update = tf.concat([evalop_k, kdata], axis=-1)
            update = tf.nn.relu(apply_conv(update, n_out=16, name='update1'), name='relu_1')
            update = apply_conv(update, n_out=2, name='update2')

            # Back-project the masked k-space correction to image space.
            update_cplx = tf.complex(update[:, :, :, 0], update[:, :, :, 1])
            input1_cplx = tf.signal.ifft2d(update_cplx * mask)
            input1 = tf.stack([tf.math.real(input1_cplx), tf.math.imag(input1_cplx)], axis=3)

            # Combine (z - beta), current x, and the back-projection.
            v = z - beta
            update = tf.concat([v, x, input1], axis=-1)

            update = tf.nn.relu(apply_conv(update, n_out=16, name='update3'), name='relu_1')
            update = tf.nn.relu(apply_conv(update, n_out=16, name='update4'), name='relu_2')
            update = apply_conv(update, n_out=2, name='update5')

            x = x + update

        # z-update: residual CNN denoiser applied to x + beta.
        with tf.compat.v1.variable_scope('denoise_layer_{}'.format(iter)):
            update = tf.nn.relu(apply_conv(x + beta, n_out=8, name='update6'), name='relu_1')
            update = tf.nn.relu(apply_conv(update, n_out=8, name='update7'), name='relu_2')
            update = apply_conv(update, n_out=2, name='update8')
            z = x + beta + update

        # Multiplier update with a learnable step size eta.
        with tf.compat.v1.variable_scope('update_layer_{}'.format(iter)):
            eta = tf.Variable(tf.constant(1, dtype=tf.float32), name='eta')
            beta = beta + tf.multiply(eta, x - z)
    output = tf.complex(x[..., 0], x[..., 1])
    return output

def DC_CNN_2D(input_image_Net, mask, kspace):
    """D5C5 cascade: five residual CNN blocks, each closed by data consistency."""
    temp = input_image_Net
    for i in range(5):
        # Four 3x3 ReLU conv layers with 16 feature maps...
        features = temp
        for j in (1, 2, 3, 4):
            features = conv_op(features, name='conv' + str(i + 1) + '_' + str(j),
                               kh=3, kw=3, n_out=16, dh=1, dw=1, ifactivate=True)
        # ...then a linear 2-channel projection forming the block residual.
        residual = conv_op(features, name='conv' + str(i + 1) + '_5',
                           kh=3, kw=3, n_out=2, dh=1, dw=1, ifactivate=False)
        # Residual connection followed by a k-space data-consistency step.
        temp = dc_DCCNN(temp + residual, ku_complex=kspace, mask=mask)
    return temp

def getMultiCoilImage(y_m_multicoil, mask, n_iter, n_coils=20):
    """Reconstruct each coil independently with the unrolled ADMM network.

    Args:
        y_m_multicoil: complex undersampled k-space, (batch, nx, ny, n_coils).
        mask: complex sampling mask shared by all coils.
        n_iter: unrolled iterations per coil.
        n_coils: number of coil channels. Defaults to 20 (the original
            hard-coded value), so existing callers are unaffected.

    Returns:
        Complex per-coil images stacked on the last axis.
    """
    x = []
    for c in range(n_coils):
        y_m = y_m_multicoil[:, :, :, c]
        output_c = getADMM_2D(y_m, mask, n_iter, c)
        #output_c = dc(output_c, y_m, mask)
        x.append(output_c)
    # tf.stack accepts the list directly; no need to re-index it.
    output = tf.stack(x, axis=-1)
    return output

def getCoilCombineImage(y_m_multicoil, mask, n_iter):
    """Per-coil ADMM recon plus data consistency, then a CNN coil combination."""
    nCoil = y_m_multicoil.shape[-1]
    coil_images = []
    for c in range(nCoil):
        coil_kspace = y_m_multicoil[:, :, :, c]
        recon = getADMM_2D(coil_kspace, mask, n_iter, c)
        coil_images.append(dc(recon, coil_kspace, mask))
    # Complex tensor: batch, nx, ny, nCoil.
    combined = tf.stack(coil_images, axis=-1)

    # Stack real/imag channels and learn the coil combination.
    feats = tf.concat([tf.math.real(combined), tf.math.imag(combined)], axis=-1)
    feats = tf.nn.relu(apply_conv(feats, n_out=32, name='recon_conv1'))
    feats = tf.nn.relu(apply_conv(feats, n_out=32, name='recon_conv2'))
    feats = apply_conv(feats, n_out=2, name='recon_conv3')
    return tf.abs(tf.complex(feats[..., 0], feats[..., 1]))

def getCoilCombineImage_DCCNN(y_m_multicoil, mask, n_iter):
    """Per-coil D5C5 (DC-CNN) recon, then a CNN coil combination.

    NOTE(review): n_iter is unused — the DC-CNN cascade depth is fixed at 5
    inside DC_CNN_2D; kept for signature parity with getCoilCombineImage.
    """
    nCoil = y_m_multicoil.shape[-1]
    coil_images = []
    for c in range(nCoil):
        coil_kspace = y_m_multicoil[:, :, :, c]
        # Zero-filled initial image in the 2-channel real representation.
        zero_filled = real2complex(tf.signal.ifft2d(coil_kspace), inv=True)
        recon = DC_CNN_2D(zero_filled, mask, coil_kspace)
        coil_images.append(real2complex(recon))
    # Complex tensor: batch, nx, ny, nCoil.
    combined = tf.stack(coil_images, axis=-1)

    # Stack real/imag channels and learn the coil combination.
    feats = tf.concat([tf.math.real(combined), tf.math.imag(combined)], axis=-1)
    feats = tf.nn.relu(apply_conv(feats, n_out=32, name='recon_conv1'))
    feats = tf.nn.relu(apply_conv(feats, n_out=32, name='recon_conv2'))
    feats = apply_conv(feats, n_out=2, name='recon_conv3')
    return tf.abs(tf.complex(feats[..., 0], feats[..., 1]))

def dc(x0_complex, ku_complex, mask):
    """Hard data consistency: keep measured k-space samples, predict the rest."""
    predicted_k = tf.signal.fft2d(x0_complex, 'fft2')
    merged_k = tf.multiply((1 - mask), predicted_k) + ku_complex
    return tf.signal.ifft2d(merged_k)

def dc_DCCNN(input_op, ku_complex, mask):
    """Data consistency for inputs in the 2-channel real representation."""
    image = real2complex(input_op)
    predicted_k = tf.signal.fft2d(image, 'fft2')
    merged_k = tf.multiply((1 - mask), predicted_k) + ku_complex
    return real2complex(tf.signal.ifft2d(merged_k), inv=True)








+ 152
- 0
train_v2.py View File

@@ -0,0 +1,152 @@
import tensorflow as tf
import numpy as np
from numpy.fft import fft2, ifft2, ifftshift, fftshift

import os
from skimage import io
import time

from dataset import get_train_data
from model import getMultiCoilImage, getCoilCombineImage, getCoilCombineImage_DCCNN
import scipy.io as sio
from os.path import join

def generate_data(x, BATCH_SIZE=1, shuffle=False):
    """Yield successive minibatches of x along its first axis.

    (The old docstring claimed this "generates random data"; it only slices
    x, optionally in a shuffled order.)

    Args:
        x: indexable dataset (array-like); sliced as x[j:j+BATCH_SIZE].
        BATCH_SIZE: samples per batch; the final batch may be smaller.
        shuffle: if True, iterate over a randomly permuted copy of x.

    Yields:
        Consecutive slices of x covering all n samples.
    """
    n = len(x)
    if shuffle:
        # np.random.permutation returns a shuffled copy, so the caller's
        # array is left untouched.
        x = np.random.permutation(x)

    for j in range(0, n, BATCH_SIZE):
        yield x[j:j + BATCH_SIZE]


def get_data_sos(label, mask_t, bacth_size_mask=4):
    """Build undersampled k-space, SOS label, and masks for one sample.

    The single input frame is replicated across the first `bacth_size_mask`
    temporal masks, producing one training pair per mask.

    Args:
        label: complex coil images, (batch, nx, ny, coil).
            NOTE(review): the tiling below only lines up for batch == 1 —
            confirm against the caller (BATCH_SIZE = 1 in training).
        mask_t: sampling masks, (nx, ny, nt) with nt >= bacth_size_mask.
        bacth_size_mask: [sic] number of temporal masks to use.

    Returns:
        (k_und_shift, label_sos, mask): undersampled k-space
        (bacth_size_mask, nx, ny, coil), sum-of-squares magnitude label
        (bacth_size_mask, nx, ny), and the 2-D masks (bacth_size_mask, nx, ny).
    """
    coil = label.shape[-1]

    # First bacth_size_mask masks, replicated over the coil axis.
    masks = np.transpose(mask_t, (2, 0, 1))[0:bacth_size_mask, ...]
    coil_masks = np.tile(masks[:, :, :, np.newaxis], (1, 1, 1, coil))

    # Coil-wise FFT of the label, replicated once per mask.
    coil_images = np.transpose(label, (0, 3, 1, 2))  # batch, coil, nx, ny
    full_kspace = np.tile(fft2(coil_images, axes=(-2, -1)), (bacth_size_mask, 1, 1, 1))
    full_kspace = np.transpose(full_kspace, (0, 2, 3, 1))  # ..., nx, ny, coil
    k_und_shift = full_kspace * coil_masks

    # Sum-of-squares coil combination (|z**2| == |z|**2), one copy per mask.
    label_sos = np.sum(abs(coil_images ** 2), axis=1) ** (1 / 2)
    label_sos = np.tile(label_sos, [bacth_size_mask, 1, 1])

    return k_und_shift, label_sos, coil_masks[:, :, :, 0]

if __name__ == "__main__":
    # ---- Hyper-parameters ----
    lr_base = 1e-03          # initial learning rate
    BATCH_SIZE = 1
    lr_decay_rate = 0.98     # decay applied once per num_train/BATCH_SIZE steps
    # EPOCHS = 200
    num_epoch = 200
    num_train = 1900         # must match the split in dataset.get_train_data
    num_validate = 249
    Nx = 192                 # image height
    Ny = 192                 # image width
    Nc = 20                  # number of coil channels

    os.environ['CUDA_VISIBLE_DEVICES'] = '0'

    # ---- Checkpoint locations ----
    base_dir = '.'
    name = 'Unsupervised learning via TIS_mask_5t_multi-coil_2149_train_on_random_4X_DC_CNN_AMAX'
    # name = os.path.splitext(os.path.basename(__file__))[0]
    # model_save_path = os.path.join(base_dir, name)
    # if not os.path.isdir(model_save_path):
    #     os.makedirs(model_save_path)

    checkpoint_dir = os.path.join(base_dir, 'checkpoints/%s' % name)
    if not os.path.exists(checkpoint_dir):
        os.makedirs(checkpoint_dir)

    checkpoint_path = os.path.join(checkpoint_dir, '{}.ckpt'.format(name))

    # data for train
    data_dir = '/data0/ziwen/data/h5/multi_coil_strategy_v0'
    train_label, validate_label = get_train_data(data_dir)
    #train_label = np.zeros([100, 192, 192, 20])
    #validate_label = np.zeros([100, 192, 192, 20])

    # Sampling mask, fftshifted to match the unshifted fft2 in get_data_sos.
    mask_t = sio.loadmat(join('mask', 'random_gauss_mask_t_192_192_16_ACS_16_R_3.64.mat'))['mask_t']
    mask_t = np.fft.fftshift(mask_t, axes=(0, 1))
    print('Acceleration factor: {}'.format(mask_t.size/float(mask_t.sum())))
    # mk = sio.loadmat('Random1D_256_256_R6.mat')
    # mask_t = np.fft.fftshift(mk['mask'], axes=(-1, -2))

    # ---- Graph inputs ----
    y_m = tf.compat.v1.placeholder(tf.complex64, (None, Nx, Ny, Nc), "y_m")
    mask = tf.compat.v1.placeholder(tf.complex64, (None, Nx, Ny), "mask")
    x_true = tf.compat.v1.placeholder(tf.float32, (None, Nx, Ny), "x_true")

    x_pred = getCoilCombineImage_DCCNN(y_m, mask, n_iter=8)

    with tf.name_scope("loss"):
        # MSE between the predicted magnitude image and the SOS label.
        residual = x_pred - x_true
        #residual = tf.stack([tf.real(residual_cplx), tf.imag(residual_cplx)], axis=4)
        Y = tf.reduce_mean(residual ** 2)
        loss = Y

    # Exponentially decayed learning rate driven by the optimizer's step.
    global_step = tf.Variable(0., trainable=False)
    lr = tf.compat.v1.train.exponential_decay(lr_base,
                                              global_step=global_step,
                                              decay_steps=num_train // BATCH_SIZE,
                                              decay_rate=lr_decay_rate,
                                              staircase=False)
    with tf.name_scope("train"):
        train_step = tf.compat.v1.train.AdamOptimizer(lr).minimize(loss, global_step=global_step)

    saver = tf.compat.v1.train.Saver()
    # NOTE(review): plain tf.Session while the rest of the file uses
    # tf.compat.v1 — this runs on TF1.x builds only; confirm the target TF.
    with tf.Session() as sess:

        init = tf.compat.v1.global_variables_initializer()
        sess.run(init)

        # saver = tf.train.Saver()
        # if ckpt and ckpt.model_checkpoint_path:
        #saver.restore(sess, checkpoint_path)
        train_plot = []
        validate_plot = []

        # train the network
        for i in range(num_epoch):
            count_train = 0
            loss_sum_train = 0.0
            for ys in generate_data(train_label, BATCH_SIZE=BATCH_SIZE, shuffle=True):
                train, label, mask_d = get_data_sos(ys, mask_t)
                im_start = time.time()
                _, loss_value, step, pred = sess.run([train_step, loss, global_step, x_pred],
                                                     feed_dict={y_m: train,
                                                                mask: mask_d,
                                                                x_true: label})
                im_end = time.time()
                loss_sum_train += loss_value
                # Prints the running-mean training loss for the epoch.
                print("{}\{}\{} of training loss:\t\t{:.6f} \t using :{:.4f}s".
                      format(i + 1, count_train + 1, int(num_train / BATCH_SIZE),
                             loss_sum_train / (count_train + 1), im_end - im_start))
                count_train += 1

            # Validation pass: forward/loss only, no optimizer step.
            count_validate = 0
            loss_sum_validate = 0.0
            for ys_validate in generate_data(validate_label, shuffle=True):
                y_rt_validate, x_true_validate, mask_validate = get_data_sos(ys_validate, mask_t)
                im_start = time.time()
                loss_value_validate = sess.run(loss, feed_dict={y_m: y_rt_validate,
                                                                mask: mask_validate,
                                                                x_true: x_true_validate})
                im_end = time.time()
                loss_sum_validate += loss_value_validate
                count_validate += 1
                print("{}\{}\{} of validation loss:\t\t{:.6f} \t using :{:.4f}s".
                      format(i + 1, count_validate, int(num_validate / BATCH_SIZE),
                             loss_sum_validate / count_validate, im_end - im_start))
            # train_plot.append(loss_sum_train / count_train_per)
            # validate_plot.append(loss_sum_validate / count_validate)
            saver.save(sess, checkpoint_path)  # checkpoint once per epoch
        # train_plot_name = 'train_plot.npy'
        # np.save(os.path.join(checkpoint_dir, train_plot_name), train_plot)
        # validate_plot_name = 'validate_plot.npy'
        # np.save(os.path.join(checkpoint_dir, validate_plot_name), validate_plot)

Loading…
Cancel
Save