|
- import argparse
- import os,random
- import numpy as np
- import tensorflow as tf
- from base_dataset import BaseDataset
- import datetime
- # import multi_scale_model
- import base_model
- import time
-
class CPrintl():
    """Tee-style logger: echoes every call to stdout and appends it to a log file.

    Args:
        logName (str): path of the log file; its parent directory is created
            on demand.
    """
    def __init__(self, logName) -> None:
        self.log_file = logName
        # Create the log file's parent directory if it does not exist yet.
        log_dir = os.path.dirname(logName)
        if log_dir != '' and not os.path.exists(log_dir):
            os.makedirs(log_dir)

    def __call__(self, *args):
        print(*args)
        # FIX: the original did `print(*args, file=open(self.log_file, 'a'))`,
        # leaking one unclosed file handle per call and relying on the GC to
        # flush it. A context manager closes (and flushes) deterministically.
        with open(self.log_file, 'a') as f:
            print(*args, file=f)
-
def parse_args():
    """Parse the training command-line options.

    Returns:
        argparse.Namespace with ``checkpoints_init`` (pretrained-model dir,
        empty string disables restoring), ``checkpoints_dir`` and ``seed``.
    """
    parser = argparse.ArgumentParser(description='Training')
    parser.add_argument(
        '--checkpoints_init',
        default='',  # e.g. /userhome/postprocess_v1_bek/tensorflow/single_GPU_models/
        help='pretrained models dir')
    parser.add_argument(
        '--checkpoints_dir',
        default='single_GPU_models',
        help='Dir for saving logs and models.')
    parser.add_argument('--seed', type=int, default=0, help='Random seed.')
    return parser.parse_args()
-
def set_random_seed(seed):
    r"""Seed Python's ``random``, NumPy and TensorFlow RNGs for reproducibility.

    Args:
        seed (int): Random seed.
    """
    random.seed(seed)
    np.random.seed(seed)
    tf.random.set_seed(seed)
-
def BCELoss(predict, label):
    """Class-balanced binary cross-entropy.

    The negative and positive terms are each normalized by their own element
    counts, so sparse positives are not drowned out by the background.

    Args:
        predict: predicted probabilities (expected in [0, 1]).
        label: binary ground truth, same shape as ``predict``.

    Returns:
        Scalar tf.Tensor loss.

    NOTE(review): if ``label`` is all zeros or all ones, one denominator is 0
    and the result is NaN — assumed not to happen with this training data.
    """
    eps = 1e-7
    # Clamp away from {0, 1} so both logs below stay finite.
    clipped = tf.keras.backend.clip(predict, min_value=eps, max_value=1 - eps)
    neg_term = tf.reduce_sum((1 - label) * tf.math.log(1 - clipped)) / tf.reduce_sum(1 - label)
    pos_term = tf.reduce_sum(label * tf.math.log(clipped)) / tf.reduce_sum(label)
    return -neg_term - pos_term
-
if __name__ == '__main__':
    # NOTE(review): indentation of this script was reconstructed from a
    # whitespace-mangled paste — verify the nesting of the logging /
    # snapshot / counter-reset section against the original.
    # ---- training options (hard-coded; only checkpoints & seed come from CLI) ----
    cube_size = 64
    batch_size=32
    train_path = '/userhome/postprocess_v1_bek/traindata/train.txt'
    max_epoch = 100
    logging_iter = 30  # iterations between log lines (150 ≈ logging every 5 minutes)
    snapshot_save_iter = 30*2  # iterations between checkpoint-save attempts
    args = parse_args()
    set_random_seed(args.seed)

    os.makedirs(args.checkpoints_dir, exist_ok=True)
    printl = CPrintl(args.checkpoints_dir+'/log_TF.txt')
    printl(datetime.datetime.now().strftime('\r\n%Y-%m-%d:%H:%M:%S'))
    # create a model
    net_G = base_model.Generator(base_channel=16,num_layers=4)
    # assumes NCDHW layout with a single 64^3 cube — TODO confirm against base_model
    net_G.build(input_shape = (1,1,64,64,64))
    for v in net_G.trainable_variables:
        print(v.name,v.shape,v.dtype)
    # NOTE(review): `decay=` is only accepted by the legacy Keras optimizer API;
    # confirm the TF version in use still supports it.
    opt_G = tf.keras.optimizers.Adam(learning_rate=0.001, beta_1=0.9, beta_2=0.99, epsilon=1e-08, decay=0.0)
    # net_G, net_D, net_G_ema, opt_G, opt_D, sch_G, sch_D \
    # = get_model_optimizer_and_scheduler(opt)  # opt = optimizer, sch = lr scheduler; unused here: sch_G (would adjust lr, not needed), sch_D, opt_D, net_D, net_G_ema

    checkpoint = tf.train.Checkpoint(model_net=net_G,model_opt=opt_G)
    if args.checkpoints_init != '':
        # NOTE(review): tf.train.latest_checkpoint() returns None when no
        # checkpoint file is found, which would make the string concatenation
        # below raise TypeError — confirm the init dir always holds a checkpoint.
        latest_ckpt = tf.train.latest_checkpoint(args.checkpoints_init)
        checkpoint.restore(latest_ckpt)
        printl('loading checkpoint from '+latest_ckpt)

    # Start training.
    best_loss = None  # best windowed-average loss seen so far (gates snapshots)
    for epoch in range(max_epoch):
        printl('Epoch {} starts...'.format(epoch))
        start_epoch_time = time.time()
        # create a dataset (rebuilt every epoch)
        train_dataset = BaseDataset(train_path, cube_size=cube_size, batch_size=batch_size, is_inference=False)
        elapsed_iteration_time = 0.
        num = 0.          # iterations accumulated since the last log line
        total_loss = 0.   # summed loss since the last log line
        print("len(train_dataset.dataset):",len(train_dataset.dataset))
        # NOTE: `iter` shadows the builtin of the same name inside this loop.
        for iter, (decompressed,gt) in enumerate(train_dataset.dataset):  # ~3034 iterations per epoch
            start_iteration_time = time.time()
            # print("decompressed.shape:",decompressed.shape)
            with tf.GradientTape() as tape:
                predict = net_G(decompressed, training_flag=True)  # decompressed.shape: (32, 1, 64, 64, 64)
                gen_losses = BCELoss(predict, gt)
            grads = tape.gradient(gen_losses, net_G.trainable_variables)
            # clip gradient magnitude, otherwise training does not converge
            clip_grads = [tf.clip_by_value(grad, -1.0, 1.0) for grad in grads]
            opt_G.apply_gradients(zip(clip_grads, net_G.trainable_variables))
            total_loss += gen_losses.numpy()
            num += 1
            elapsed_iteration_time += time.time() - start_iteration_time
            if (iter+1) % logging_iter == 0:
                total_loss /= num  # turn the running sum into a window average
                ave_t = elapsed_iteration_time / num
                printl(datetime.datetime.now().strftime('\r\n%Y-%m-%d:%H:%M:%S')+': Epoch: {} ,Iteration: {}, average iter time: {:3f}.'.format(epoch, iter, ave_t))
                printl('total_loss:' + str(total_loss))
                # snapshot_save_iter is a multiple of logging_iter, so
                # total_loss holds the freshly computed average when this fires.
                if (iter+1) % snapshot_save_iter == 0:
                    # save only when the windowed average loss improved
                    if best_loss is None or best_loss > total_loss:
                        checkpoint.save(args.checkpoints_dir + '/model.ckpt')
                        best_loss = total_loss
                        printl('epoch:'+str(epoch)+', Iteration:' + str(iter) + " ,save model!")
                # reset the running statistics for the next logging window
                num = 0.
                total_loss = 0.
                elapsed_iteration_time = 0
        elapsed_epoch_time = time.time() - start_epoch_time
        printl('Epoch: {}, total time: {:3f}'.format(epoch, elapsed_epoch_time))
|