|
- import time, os, sys, glob, argparse
- import importlib
- import numpy as np
- import torch
- import h5py
- import random
- random.seed() #这句对下面的random.sample没影响
- from data_loader import PCDataset, make_data_loader
- from trainer import Trainer
- from pcc_model import PCCModel
-
def parse_args(argv=None):
    """Parse the command-line options of the training script.

    Args:
        argv: Argument strings to parse. ``None`` (the default) parses
            ``sys.argv[1:]`` exactly as before; passing an explicit list lets
            tests and notebooks call this without touching ``sys.argv``.

    Returns:
        An ``argparse.Namespace`` with the dataset, loss-weight, optimiser,
        checkpoint and batch-size settings declared below.
    """
    parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter)

    parser.add_argument(
        "--dataset", type=str,
        default='/userhome/PCGCv1/training-data-set/points64_part1/',
        help="directory containing the training *.h5 files.")
    # argparse does not run non-string defaults through ``type``, so the old
    # float default 2.8e5 reached callers as 280000.0 while a command-line
    # value arrived as an int; an int literal keeps both paths consistent.
    # Try a small value first for a quick trial run (the full set is
    # 300,001 files).
    parser.add_argument(
        "--dataset_num", type=int, default=280000,
        help="maximum number of (sorted) .h5 files to use.")
    parser.add_argument(
        "--alpha", type=float, default=10,
        help="Weight for distortion.")
    parser.add_argument(
        "--beta", type=float, default=3.,
        help="Weight for empty position.")
    parser.add_argument(
        "--gamma", type=float, default=1.3,
        help="Weight for hyper likelihoods.")
    parser.add_argument(
        "--delta", type=float, default=3.,
        help="Weight for latent likelihoods.")
    parser.add_argument(
        "--lr", type=float, default=2e-4,
        help="learning rate.")
    parser.add_argument(
        "--epoch", type=int, default=14,
        help="number of training epochs.")
    parser.add_argument(
        "--prefix", type=str, default='hyper_mgpu4',
        help="prefix of checkpoints/logger.")
    parser.add_argument(
        "--init_ckpt", type=str,
        default='/userhome/PCGCv1/pytorch3/ckpts/hyper_mgpu3/epoch_12_11099.pth',
        help='checkpoint file to initialise the model from.')
    parser.add_argument(
        "--lower_bound", type=float, default=1e-9,
        help="lower bound of scale. 1e-5 or 1e-9")
    parser.add_argument(
        "--batch_size", type=int, default=32,  # 48 runs out of GPU memory
        help='batch size.')

    return parser.parse_args(argv)
-
class TrainingConfig:
    """Hyper-parameters and output locations handed to ``Trainer``.

    Constructing an instance creates ``logdir`` and ``ckptdir`` (including any
    missing parents) if they do not exist yet.

    Args:
        logdir: directory for training logs.
        ckptdir: directory where the models of the current run are saved.
        init_ckpt: checkpoint used to initialise the model; stored as given
            (presumably loaded by ``Trainer`` — confirm there).
        alpha: weight for distortion.
        beta: weight for empty positions.
        gamma: weight for the hyper-prior likelihoods.
        delta: weight for the latent-representation likelihoods.
        lr: learning rate.
    """

    def __init__(self, logdir, ckptdir, init_ckpt, alpha, beta, gamma, delta, lr):
        # exist_ok=True replaces the exists()-then-makedirs() pattern, which
        # races (FileExistsError) when several processes share one prefix.
        self.logdir = logdir
        os.makedirs(self.logdir, exist_ok=True)
        self.ckptdir = ckptdir
        os.makedirs(self.ckptdir, exist_ok=True)
        self.init_ckpt = init_ckpt
        self.alpha = alpha
        self.beta = beta
        self.lr = lr
        self.gamma = gamma  # weight of hyper prior.
        self.delta = delta  # weight of latent representation.
-
def split_train_eval(filedirs, ratio_eval):
    """Split a file list into ``(eval_pool, train_pool)``.

    The first ``len(filedirs) // ratio_eval`` entries form the evaluation pool
    and the remaining entries the training pool, so the two never overlap.
    The caller is expected to pass a sorted list so the split is stable.

    Args:
        filedirs: sequence of file paths.
        ratio_eval: roughly one file in ``ratio_eval`` goes to evaluation.

    Returns:
        Tuple ``(eval_pool, train_pool)`` of slices of ``filedirs``.

    Raises:
        ValueError: if either pool would be empty (e.g. fewer than
            ``ratio_eval`` files were found), so a bad ``--dataset`` fails
            up front instead of deep inside ``random.sample`` or training.
    """
    n_eval = len(filedirs) // ratio_eval
    eval_pool, train_pool = filedirs[:n_eval], filedirs[n_eval:]
    if not eval_pool or not train_pool:
        raise ValueError(
            f"need at least {ratio_eval} files to build non-empty train/eval "
            f"splits, found {len(filedirs)}")
    return eval_pool, train_pool


def sample_files(pool, k):
    """Return ``min(k, len(pool))`` distinct items of ``pool`` in random order.

    ``random.sample`` raises ``ValueError`` when ``k`` exceeds the population,
    which happened whenever ``--dataset_num`` was reduced for a quick trial;
    clamping makes a small dataset simply use every available file.
    """
    return random.sample(pool, min(k, len(pool)))


if __name__ == '__main__':
    args = parse_args()
    # The first 1/RATIO_EVAL of the sorted files is held out for evaluation.
    RATIO_EVAL = 9

    # List and split the dataset first so a wrong --dataset fails before the
    # model is built. os.path.join works with or without a trailing separator.
    filedirs = sorted(glob.glob(os.path.join(args.dataset, '*.h5')))[:int(args.dataset_num)]
    eval_pool, train_pool = split_train_eval(filedirs, RATIO_EVAL)

    training_config = TrainingConfig(
        logdir=os.path.join('./logs', args.prefix),
        ckptdir=os.path.join('./ckpts', args.prefix),  # models of this run are saved here
        init_ckpt=args.init_ckpt,  # checkpoint to initialise the model from
        alpha=args.alpha,
        beta=args.beta,
        gamma=args.gamma,
        delta=args.delta,
        lr=args.lr)
    model = PCCModel(lower_bound=args.lower_bound)
    trainer = Trainer(config=training_config, model=model)

    for epoch in range(args.epoch):
        if epoch > 0:
            # Request half of trainer.config.lr, floored at 1e-5.
            # NOTE(review): the halving only compounds across epochs if
            # Trainer.update_lr writes the new value back into
            # trainer.config.lr — confirm in trainer.py.
            trainer.update_lr(lr=max(trainer.config.lr / 2, 1e-5))
        # Nominally 1000 iterations per epoch; resampled every epoch, so a
        # file may be seen in more than one epoch.
        train_list = sample_files(train_pool, 1000 * args.batch_size)
        train_dataloader = make_data_loader(
            dataset=PCDataset(train_list), batch_size=args.batch_size,
            shuffle=True, num_workers=6, repeat=False)
        trainer.train(train_dataloader)
        # Evaluate on (up to) 10 batches drawn from the held-out pool.
        eval_list = sample_files(eval_pool, 10 * args.batch_size)
        test_dataloader = make_data_loader(
            dataset=PCDataset(eval_list), batch_size=args.batch_size,
            shuffle=False, num_workers=3, repeat=False)
        trainer.test(test_dataloader, 'Test')
|