import os
import sys
import datetime
import json
import random

import numpy as np
import torch
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
from tqdm.auto import tqdm

from opt_gan import config_parser
from renderer import *
from utils.utils import *
from datasets import dataset_dict
from datasets.gan import GANDataset
from models.tensorRFSemVMSplit import TensorSemVMSplit
from models.GAN_model import GANModel


device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Volume renderer used for every ray batch (defined in renderer.py).
renderer = OctreeRender_trilinear_fast


@torch.no_grad()
def export_mesh(args):
    """Load a trained checkpoint and export its density field as a .ply mesh."""
    ckpt = torch.load(args.ckpt, map_location=device)
    kwargs = ckpt['kwargs']
    kwargs.update({'device': device})
    # Instantiate the model class named in the config (e.g. TensorSemVMSplit).
    tensorf = eval(args.model_name)(**kwargs)
    tensorf.load(ckpt)

    # Densely sample the opacity grid and run marching cubes at a low threshold.
    alpha, _ = tensorf.getDenseAlpha()
    convert_sdf_samples_to_ply(alpha.cpu(), f'{args.ckpt[:-3]}.ply', bbox=tensorf.aabb.cpu(), level=0.005)
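
# A minimal export sketch (flag and script names assumed, not confirmed by
# opt_gan's config_parser; adjust to the actual options):
#   python train_gan.py --config configs/scene.txt --ckpt log/exp/exp.th --export_mesh 1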


def train(args):
    # init dataset
    train_dataset = GANDataset(dataset_name=args.dataset_name,
                               datadir=args.datadir,
                               near=args.near, far=args.far, scene_bbox_stretch=args.scene_bbox_stretch,
                               downsample_train=args.downsample_train,
                               nCases=args.nCases,
                               use_same_rays=args.use_same_rays, load_colored_sem=args.load_colored_sem)

    # Choose this process's GPU and a matching sampler: DistributedSampler for
    # multi-GPU DDP (rank taken from LOCAL_RANK), RandomSampler on a single
    # GPU, and the DataLoader default on CPU.
    if torch.cuda.device_count() > 1:
        assert torch.cuda.is_available()
        torch.distributed.init_process_group(backend='nccl')

        gpu_id = int(os.environ['LOCAL_RANK'])
        train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset)
    elif torch.cuda.device_count() == 1:
        gpu_id = 0
        train_sampler = torch.utils.data.RandomSampler(train_dataset)
    else:
        gpu_id = None
        train_sampler = None

    # Shuffling is handled by the sampler chosen above; DataLoader disallows
    # shuffle=True when an explicit sampler is supplied.
    train_loader = DataLoader(
        train_dataset,
        batch_size=args.batch_size,
        pin_memory=True,  # faster host-to-GPU copies
        sampler=train_sampler
    )
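
    # With more than one visible GPU this script is assumed to be launched via
    # torchrun, which sets the LOCAL_RANK read above, e.g. (script and flag
    # names illustrative):
    #   torchrun --nproc_per_node=2 train_gan.py --config configs/scene.txt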

    # init model
    gan = GANModel(args=args, device=device, gpu_ids=gpu_id)
    if gpu_id is not None:
        gan.cuda(gpu_id)

    gan.get_loss()

    grad_var_G, grad_var_D = gan.get_optparam_groups(lr_gan=args.lr_gan_init)
    # betas=(0.5, 0.999) is the customary Adam setting for GAN training.
    optimizer_G = torch.optim.Adam(grad_var_G, betas=(0.5, 0.999))
    optimizer_D = torch.optim.Adam(grad_var_D, betas=(0.5, 0.999))

    is_distributed = torch.cuda.device_count() > 1
    if is_distributed:
        gan = torch.nn.parallel.DistributedDataParallel(gan, device_ids=[gpu_id])

    gan.train()

    old_gan_lr = args.lr_gan_init

    if args.add_timestamp:
        logfolder = f'{args.basedir}/{args.expname}{datetime.datetime.now().strftime("-%Y%m%d-%H%M%S")}'
    else:
        logfolder = f'{args.basedir}/{args.expname}'
    os.makedirs(logfolder, exist_ok=True)
    iter_path = os.path.join(logfolder, 'iter.txt')

    if args.use_same_rays:
        rays = train_dataset.rays

    aabb = train_dataset.aabb.to(device)
    reso_cur = N_to_reso(args.N_voxel_init, aabb)  # voxel-grid resolution per dimension
    nSamples = min(args.nSamples, cal_n_samples(reso_cur, args.step_ratio))
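
    # cal_n_samples follows the TensoRF convention of tying the per-ray sample
    # count to the voxel-grid diagonal divided by step_ratio; the bounding box
    # and nSamples are recomputed per case inside the loop when rays differ.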
    for epoch in range(args.nEpoches):
        if is_distributed:
            train_sampler.set_epoch(epoch)  # reshuffle the per-process shards each epoch
        pbar = tqdm(train_loader)
        for iteration, case_sample in enumerate(pbar):
            # todo: set changeable dataset ?
            # todo: concatenate planes and lines from different cases ?
            # todo: more than one case for one batch ?
            if not args.use_same_rays:
                rays = case_sample["rays"]

                aabb = case_sample["aabb"].to(device)
                reso_cur = N_to_reso(args.N_voxel_init, aabb)  # voxel-grid resolution per dimension
                nSamples = min(args.nSamples, cal_n_samples(reso_cur, args.step_ratio))

            case_idx = case_sample['case_idx'].item()
            colored_sem_maps, imgs = case_sample['colored_sem_maps'], case_sample['imgs']
            ckpt = train_dataset.ckpts[case_idx]  # per-case pretrained TensoRF checkpoint

            kwargs = ckpt['kwargs']
            kwargs.update({'device': device})
            kwargs.update({'gpu_ids': gpu_id})
            kwargs.update({'aabb': aabb})
            # Rebuild this case's semantic TensoRF from its checkpoint kwargs.
            sem_tensorf = eval(args.model_name)(**kwargs)

            losses, generated = gan(colored_sem_maps, imgs, rays, ckpt, sem_tensorf, renderer, nSamples)

            # Average each loss across devices (ints are placeholders for disabled terms).
            losses = [torch.mean(x).unsqueeze(0) if not isinstance(x, int) else x for x in losses]
            loss_names = gan.module.loss_names if is_distributed else gan.loss_names
            loss_dict = dict(zip(loss_names, losses))

            # Combine the generator terms; average the real/fake discriminator terms.
            loss_G = loss_dict['G_GAN'] + loss_dict.get('G_GAN_Feat', 0) + loss_dict.get('G_VGG', 0)
            loss_D = (loss_dict['D_fake'] + loss_dict['D_real']) * 0.5
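
            # Loss naming follows pix2pixHD: an adversarial term plus optional
            # feature-matching (G_GAN_Feat) and VGG perceptual (G_VGG) terms for
            # the generator; any per-term weights are assumed to be applied
            # inside GANModel.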

            # update generator weights
            optimizer_G.zero_grad()
            loss_G.backward()
            optimizer_G.step()

            # update discriminator weights
            optimizer_D.zero_grad()
            loss_D.backward()
            optimizer_D.step()
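
            # One generator step, then one discriminator step per batch. The
            # discriminator terms are assumed to be computed on detached fakes
            # inside GANModel, so the two backward passes use separate graphs.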

            # Report the current loss values.
            G_GAN_loss = loss_dict['G_GAN'].item()
            G_GAN_Feat_loss = loss_dict['G_GAN_Feat'].item() if 'G_GAN_Feat' in loss_dict else 0
            G_VGG_loss = loss_dict['G_VGG'].item() if 'G_VGG' in loss_dict else 0
            D_fake, D_real = loss_dict['D_fake'].item(), loss_dict['D_real'].item()

            if iteration % args.progress_refresh_rate == 0:
                pbar.set_description(
                    f'Iteration {iteration:03d}:'
                    + f' G_GAN = {G_GAN_loss:.2f}'
                    + f' G_GAN_Feat = {G_GAN_Feat_loss:.2f}'
                    + f' G_VGG = {G_VGG_loss:.2f}'
                    + f' G = {loss_G.item():.2f}'
                    + f' D_fake = {D_fake:.2f}'
                    + f' D_real = {D_real:.2f}'
                    + f' D = {loss_D.item():.2f}'
                )

            ### display output images

            ### linearly decay the GAN learning rate after lr_decay_iter_start iterations
            lrd = args.lr_gan_init / args.lr_decay_iters
            if iteration > args.lr_decay_iter_start:
                lr = old_gan_lr - lrd

                # Lower the rate in place on both optimizers (assuming every
                # parameter group shares the single lr_gan rate); this keeps
                # Adam's moment estimates intact across the decay.
                for param_group in optimizer_G.param_groups:
                    param_group['lr'] = lr
                for param_group in optimizer_D.param_groups:
                    param_group['lr'] = lr
                print(f'======> update GAN learning rate: {old_gan_lr} -> {lr} <========================')

                old_gan_lr = lr
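                # At this rate the lr reaches zero lr_decay_iters steps after
                # the decay starts: a pix2pixHD-style linear schedule.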

            ### save model for certain iteration
            if iteration % args.save_iter_freq == 0:
                if is_distributed:
                    gan.module.save(logfolder, iteration)
                else:
                    gan.save(logfolder, iteration)
                np.savetxt(iter_path, (iteration + 1, 0), delimiter=',', fmt='%d')
                print(f'======> saved the model at epoch {epoch}, iteration {iteration} <=======================')


if __name__ == '__main__':
    torch.set_default_dtype(torch.float32)
    torch.manual_seed(20211202)
    np.random.seed(20211202)
    random.seed(20211202)

    args = config_parser()

    if args.export_mesh:
        export_mesh(args)
        sys.exit(0)  # mesh export is a standalone mode; skip training

    train(args)
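
# Typical single-GPU training launch (flag names illustrative; see opt_gan.py
# for the real options):
#   python train_gan.py --config configs/scene.txt --expname my_exp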