|
import argparse
import os

# Command-line configuration shared by both train and test modes.
parser = argparse.ArgumentParser(description="Arg parser")
parser.add_argument('--gpu', type=str, default='0', help='specify gpu')
parser.add_argument('--mode', default='train', help='train | test')
parser.add_argument('--wgt_dir', default='weights/punet', help='Weight dir [default: weights/punet]')
parser.add_argument('--rst_dir', default='results/punet', help='Result dir [default: results/punet]')
parser.add_argument('--model', default='punet', help='model to train or test')

parser.add_argument('--npoint', type=int, default=1024, help='Point Number [1024/2048] [default: 1024]')
parser.add_argument('--up_ratio', type=int, default=4, help='Upsampling Ratio [default: 4]')
parser.add_argument('--max_epoch', type=int, default=120, help='Epochs to run [default: 120]')
parser.add_argument('--batch_size', type=int, default=28, help='Batch Size during training')
parser.add_argument('--alpha', type=float, default=1.0, help='weight of the repulsion loss term')
parser.add_argument('--lr', type=float, default=0.001, help='initial learning rate for Adam')
parser.add_argument('--weight_decay', type=float, default=0.00005, help='Adam weight decay')
parser.add_argument('--workers', type=int, default=4, help='number of DataLoader worker processes')

args = parser.parse_args()
print(args)
# Must be set BEFORE torch is imported (imports below), so torch only
# ever sees the GPUs selected on the command line.
os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu
-
- import torch
- import torch.nn as nn
- from torch.utils.data import DataLoader
- import numpy as np
- import importlib
-
- from punet import PUNet as Model
- from losses import UpsampleLoss
- from dataset import PUNET_Dataset
-
- from utils.ply_utils import save_ply
- from utils.utils import save_xyz_file
- from dataset import PUNET_Dataset_Whole
-
-
-
-
def train_net():
    """Train the PU-Net upsampling model.

    All hyper-parameters are read from the module-level ``args`` namespace.
    Trains for ``args.max_epoch`` epochs on ``PUNET_Dataset`` and writes a
    checkpoint (``{'epoch', 'model_state'}``) into ``args.wgt_dir`` every
    20 epochs as ``punet_epoch_<epoch>.pth``.
    """
    # exist_ok avoids the check-then-create race of os.path.exists + makedirs.
    os.makedirs(os.path.join('./', args.wgt_dir), exist_ok=True)

    train_dataset = PUNET_Dataset(npoint=args.npoint,
                        use_random=True, use_norm=True, split='train', is_training=True)
    train_loader = DataLoader(train_dataset, batch_size=args.batch_size,
                        shuffle=True, pin_memory=True, num_workers=args.workers)

    print(" >>> begin to training " + args.model)
    model = Model(npoint=args.npoint, up_ratio=args.up_ratio,
                use_normal=False, use_bn=False, use_res=False)
    model = nn.DataParallel(model).cuda()

    optimizer = torch.optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.weight_decay)
    # alpha weights the repulsion term relative to the EMD term.
    loss_func = UpsampleLoss(alpha=args.alpha)

    model.train()
    for epoch in range(args.max_epoch):
        loss_list = []
        emd_loss_list = []
        rep_loss_list = []
        for batch in train_loader:
            optimizer.zero_grad()
            input_data, gt_data, radius_data = batch

            input_data = input_data.float().cuda()
            gt_data = gt_data.float().cuda()
            # Keep xyz coordinates only; any extra channels (e.g. normals)
            # are dropped from the ground truth before the loss.
            gt_data = gt_data[..., :3].contiguous()
            radius_data = radius_data.float().cuda()

            preds = model(input_data)
            emd_loss, rep_loss = loss_func(preds, gt_data, radius_data)
            loss = emd_loss + rep_loss

            loss.backward()
            optimizer.step()

            loss_list.append(loss.item())
            emd_loss_list.append(emd_loss.item())
            rep_loss_list.append(rep_loss.item())

        # Per-epoch averages over all batches, plus the current lr.
        print(' -- epoch {}, loss {:.4f}, weighted emd loss {:.4f}, repulsion loss {:.4f}, lr {}.'.format(
            epoch, np.mean(loss_list), np.mean(emd_loss_list), np.mean(rep_loss_list),
            optimizer.state_dict()['param_groups'][0]['lr']))

        if (epoch + 1) % 20 == 0:
            state = {'epoch': epoch, 'model_state': model.state_dict()}
            save_path = os.path.join(args.wgt_dir, 'punet_epoch_{}.pth'.format(epoch))
            torch.save(state, save_path)
-
def test_net():
    """Upsample every test cloud with the final training checkpoint.

    Loads ``punet_epoch_<max_epoch-1>.pth`` from ``args.wgt_dir`` (the last
    checkpoint train_net saves), runs the model over ``PUNET_Dataset_Whole``,
    and writes input/output .ply and .xyz files into ``args.rst_dir``.
    """
    # exist_ok avoids the check-then-create race of os.path.exists + makedirs.
    os.makedirs(os.path.join('./', args.rst_dir), exist_ok=True)

    print(" >>> begin to testing " + args.model)

    # args.npoint (default 1024) keeps this consistent with train_net;
    # the previous hard-coded 1024 silently ignored --npoint at test time.
    model = Model(npoint=args.npoint, up_ratio=args.up_ratio,
                use_normal=False, use_bn=False, use_res=False)
    model = nn.DataParallel(model).cuda()

    # train_net saves at epochs 19, 39, ... max_epoch-1, so this is the
    # last checkpoint written by a full training run.
    resume_dir = os.path.join(args.wgt_dir, "punet_epoch_{}.pth".format(args.max_epoch - 1))
    checkpoint = torch.load(resume_dir)
    model.load_state_dict(checkpoint['model_state'])
    print(" >>> Load model weights from " + resume_dir)
    model.eval()

    eval_dst = PUNET_Dataset_Whole(data_dir='./datas/test_data/our_collected_data/MC_5k')
    # batch_size=1 and shuffle=False: batch index itr must line up with names.
    eval_loader = DataLoader(eval_dst, batch_size=1,
                        shuffle=False, pin_memory=True, num_workers=0)

    names = eval_dst.names
    for itr, batch in enumerate(eval_loader):
        name = names[itr]
        points = batch.float().cuda()
        preds = model(points, npoint=points.shape[1])

        preds = preds.data.cpu().numpy()
        points = points.data.cpu().numpy()
        save_ply(os.path.join(args.rst_dir, '{}_input.ply'.format(name)), points[0, :, :3])
        save_ply(os.path.join(args.rst_dir, '{}.ply'.format(name)), preds[0])
        save_xyz_file(preds[0], os.path.join(args.rst_dir, '{}.xyz'.format(name)))
        print('{} with shape {}, output shape {}'.format(name, points.shape, preds.shape))
-
-
if __name__ == '__main__':
    # Dispatch on the requested mode; anything unrecognized gets a hint.
    mode_runners = {'train': train_net, 'test': test_net}
    runner = mode_runners.get(args.mode)
    if runner is not None:
        runner()
    else:
        print('please switch to train or test !')
|