"""Training entry point for the FN model (TensorFlow / TensorLayer).

Reads ``configs/fn.yaml``, builds train/val dataloaders, and runs the
training loop with periodic console/file logging, checkpointing, and
validation. The best model (lowest validation metric) is saved to
``out/fn/model_best.h5`` and the latest checkpoint to ``out/fn/model.h5``.
"""
import argparse
import os
import pickle
import time

import numpy as np
import tensorflow as tf
import tensorlayer as tl
from tensorboardX import SummaryWriter
from tensorlayer.dataflow import Dataloader

from fn import config, datacore
from fn.trainer import Trainer

# Pin training to the first GPU.
os.environ["CUDA_VISIBLE_DEVICES"] = '0'

if __name__ == '__main__':
    # Configuration.
    cfg = config.load_config('configs/fn.yaml')

    # Wall-clock start, useful when eyeballing run duration.
    t0 = time.time()

    # Output locations.
    out_dir = 'out/fn'
    # BUGFIX: the log file lives inside out_dir, so the directory must be
    # created *before* the file is opened (the original opened the file
    # first and crashed with FileNotFoundError on a fresh checkout).
    if not os.path.exists(out_dir):
        os.makedirs(out_dir)
    logfile = open('out/fn/log.txt', 'a')

    batch_size = cfg['training']['batch_size']

    # Datasets and loaders. Training shuffles and drops the last partial
    # batch; validation keeps order and uses every sample.
    train_dataset = config.get_dataset('train', cfg)
    val_dataset = config.get_dataset('val', cfg)

    train_dataset = tl.dataflow.FromGenerator(
        train_dataset, output_types=([tl.float32, tl.float32]))
    train_loader = Dataloader(train_dataset, batch_size=batch_size,
                              shuffle=True, drop_last=True,
                              shuffle_buffer_size=1000)

    eval_dataset = tl.dataflow.FromGenerator(
        val_dataset, output_types=([tl.float32, tl.float32]))
    val_loader = Dataloader(eval_dataset, batch_size=batch_size,
                            shuffle=False)

    # Model, optimizer, trainer.
    model = config.get_model(cfg)
    optimizer = tf.keras.optimizers.Adam(learning_rate=0.0001,
                                         beta_1=0.9, beta_2=0.999)
    trainer = Trainer(model, optimizer)

    # Counters (checkpoint resume was removed; always start from scratch).
    epoch_it = 0            # epoch counter
    it = 0                  # global iteration counter
    metric_val_best = np.inf  # best (lowest) validation metric seen so far

    # Build the network so summary() can report parameter shapes.
    # NOTE(review): input_shape assumed to be (batch, 16, 100, 3) — confirm
    # against the dataset's sample layout.
    model.build(input_shape=(4, 16, 100, 3))
    model.summary()

    print_every = cfg['training']['print_every']
    checkpoint_every = cfg['training']['checkpoint_every']
    validate_every = cfg['training']['validate_every']

    max_epochs = 20000
    try:
        # Same schedule as the original open-ended loop: exactly
        # `max_epochs` training epochs are executed.
        while epoch_it < max_epochs:
            epoch_it += 1
            # Make the previous epoch's log lines visible on disk.
            logfile.flush()

            for points, outputs in train_loader:
                it += 1

                # A size-1 batch breaks batch statistics; skip it.
                if points.shape[0] == 1:
                    continue
                loss = trainer.train_step(points, outputs)

                # Periodic progress report to console and log file.
                if print_every > 0 and (it % print_every) == 0:
                    msg = ('[Epoch %02d] it=%03d, loss=%.6f'
                           % (epoch_it, it, loss))
                    logfile.write(msg + '\n')
                    print(msg)

                # Save checkpoint.
                if checkpoint_every > 0 and (it % checkpoint_every) == 0:
                    # BUGFIX: the original wrote this message without a
                    # trailing newline, corrupting the log layout.
                    logfile.write('Saving checkpoint\n')
                    model.save_weights('./out/fn/model.h5')

                # Run validation and track the best model.
                if validate_every > 0 and (it % validate_every) == 0:
                    metric_val = trainer.evaluate(val_loader)
                    logfile.write('Validation metric : %.6f\n'
                                  % (metric_val))
                    if metric_val < metric_val_best:
                        metric_val_best = metric_val
                        logfile.write('New best model (loss %.6f)\n'
                                      % metric_val_best)
                        model.save_weights('./out/fn/model_best.h5')
    finally:
        # BUGFIX: the original closed the log file only on the normal
        # break path; close it on any exit (exceptions, Ctrl-C) too.
        logfile.close()