|
- import os
- import time
-
- import numpy as np
- import torch.nn
- from tqdm.auto import tqdm
- from opt_sem_scannet import config_parser
-
- import json, random
- from renderer import *
- from utils.utils import *
- from torch.utils.tensorboard import SummaryWriter
- import datetime
-
- from datasets import dataset_dict, scannet
- import sys
-
- from models.tensorRFSemVMSplit import TensorSemVMSplit
-
-
# Global compute device for every model/tensor built in this script.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")  # input from TensorBase ?

# GPU index passed to the model constructors as ``gpu_ids``.
# With several visible GPUs the script expects a distributed launcher (e.g.
# torchrun) that sets LOCAL_RANK and the env:// rendezvous variables. Without
# it, init_process_group and os.environ['LOCAL_RANK'] would fail at import
# time, so fall back to the single-GPU setup instead.
if torch.cuda.device_count() > 1 and 'LOCAL_RANK' in os.environ:
    torch.distributed.init_process_group(backend='nccl')
    gpu_id = int(os.environ['LOCAL_RANK'])
elif torch.cuda.device_count() >= 1:
    gpu_id = 0
else:
    gpu_id = None


# Renderer used by every training and evaluation call in this script.
sem_renderer = OctreeRender_trilinear_fast_with_sem
-
-
class SimpleSampler:
    """Serve random mini-batches of indices into ``range(total)``.

    Indices come from a random permutation. A batch is only served if it fits
    completely before ``total``. Otherwise a fresh permutation is drawn and
    serving restarts at its beginning, so a trailing partial batch is skipped.
    """

    def __init__(self, total, batch):
        self.total, self.batch = total, batch
        # Start "past the end" so the very first nextids() draws a permutation.
        self.curr = total
        self.ids = None

    def nextids(self):
        """Return the next ``batch`` indices as a slice of a LongTensor permutation."""
        start = self.curr + self.batch
        exhausted = start + self.batch > self.total
        if exhausted:
            self.ids = torch.LongTensor(np.random.permutation(self.total))
            start = 0
        self.curr = start
        stop = start + self.batch
        return self.ids[start:stop]
-
-
class MultiTaskLossWrapper(nn.Module):
    """Combine two task losses with learnable log-variance weights.

    The combined loss is ``exp(-s[0]) * loss_reg + exp(-s[1]) * loss_cls``,
    where ``s = self.log_vars``. ``s[0]`` starts at 0 (weight 1) and ``s[1]``
    at 3.22 (weight ~0.04). Any further entries (``task_num > 2``) start at 0
    and are not used by ``forward``.

    NOTE(review): the usual uncertainty-weighting formulation also adds the
    ``s[i]`` terms themselves (e.g. ``+ 0.5 * s[i]``); they are not applied
    here. Without them d(loss)/d(s[i]) = -exp(-s[i]) * L_i, which is never
    positive for non-negative losses. The optimiser can therefore only shrink
    both weights over training. Confirm this is intended.
    """

    def __init__(self, task_num=2):
        super().__init__()
        if task_num < 2:
            raise ValueError(f'MultiTaskLossWrapper needs task_num >= 2, got {task_num}')
        init = torch.zeros(task_num, dtype=torch.float32)
        init[1] = 3.22
        # Register as a Parameter so it appears in parameters()/state_dict() and
        # follows .to(). Callers can still hand ``log_vars`` to an optimizer
        # param group directly.
        self.log_vars = nn.Parameter(init.to(device))

    def forward(self, loss_reg, loss_cls):
        """Return ``(weighted loss, current log_vars as a Python list)``."""
        precision1 = torch.exp(-self.log_vars[0])
        precision2 = torch.exp(-self.log_vars[1])
        loss = precision1 * loss_reg + precision2 * loss_cls
        return loss, self.log_vars.data.tolist()
-
-
@torch.no_grad()
def export_mesh(args):
    """Load ``args.ckpt`` and export the model's dense alpha field as a PLY mesh.

    The mesh is written next to the checkpoint, with the checkpoint's
    extension replaced by ``.ply`` (``run/exp.pth`` -> ``run/exp.ply``). It is
    extracted at level 0.005 inside the model's ``aabb``.
    """
    # NOTE(review): torch.load unpickles the file and eval() executes
    # args.model_name; only use with trusted checkpoints and command lines.
    ckpt = torch.load(args.ckpt, map_location=device)
    kwargs = ckpt['kwargs']
    kwargs.update({'device': device})
    # Some checkpoints may store the iteration count as 'N_iter' in kwargs;
    # reconstruction_sem drops it before construction, so do the same here.
    kwargs.pop('N_iter', None)
    tensorf = eval(args.model_name)(**kwargs)
    tensorf.load(ckpt)

    alpha, _ = tensorf.getDenseAlpha()
    # splitext works for any extension. This script saves '.pth', and a fixed
    # 3-character slice would leave a stray dot ("exp..ply").
    ply_path = f'{os.path.splitext(args.ckpt)[0]}.ply'
    convert_sdf_samples_to_ply(alpha.cpu(), ply_path, bbox=tensorf.aabb.cpu(), level=0.005)
-
-
@torch.no_grad()
def render_test_with_sem(args):
    """Render a trained semantic model from ``args.ckpt`` on the requested views.

    Depending on ``args.render_train`` / ``args.render_test`` /
    ``args.render_path``, this renders:

    - the train split, via ``evaluation_with_sem``;
    - the test split, via ``evaluation_with_sem``;
    - the test dataset's ``render_path`` trajectory, via ``evaluation_path_with_sem``.

    Outputs go under the checkpoint's directory. If the checkpoint does not
    exist, a message is printed and the function returns early.
    """
    # Check the checkpoint first so a bad path fails before the (slow) dataset load.
    if not os.path.exists(args.ckpt):
        print('the ckpt path does not exists!!')
        return

    # init dataset
    dataset = dataset_dict[args.dataset_name]
    train_dataset = dataset(args.datadir, split='train', downsample=args.downsample_train, is_stack=False)
    test_dataset = dataset(args.datadir, split='test', downsample=args.downsample_train, is_stack=True)
    if args.dataset_name == "replica":
        # Both splits are remapped from the same pair of label-image sets
        # (presumably so they share one label mapping).
        train_dataset.remap_sem_gt_label(train_dataset.sem_samples["sem_img"],
                                         test_dataset.sem_samples["sem_img"],
                                         args.sem_info_path)
        test_dataset.remap_sem_gt_label(train_dataset.sem_samples["sem_img"],
                                        test_dataset.sem_samples["sem_img"],
                                        args.sem_info_path)
    white_bg = test_dataset.white_bg
    ndc_ray = args.ndc_ray

    # NOTE(review): torch.load unpickles and eval() executes args.model_name;
    # only use with trusted inputs.
    ckpt = torch.load(args.ckpt, map_location=device)
    kwargs = ckpt['kwargs']
    kwargs.update({'device': device})
    # Some checkpoints may store the iteration count as 'N_iter' in kwargs;
    # reconstruction_sem drops it before construction, so do the same here.
    kwargs.pop('N_iter', None)
    sem_tensorf = eval(args.model_name)(**kwargs)
    sem_tensorf.load(ckpt)

    logfolder = os.path.dirname(args.ckpt)
    if args.render_train:
        # NOTE(review): train_dataset is built with is_stack=False here, whereas
        # reconstruction_with_sem reloads the train split with is_stack=True
        # before evaluating it. Confirm evaluation_with_sem accepts flat rays.
        out_dir = f'{logfolder}/imgs_train_all'
        os.makedirs(out_dir, exist_ok=True)
        PSNRs_test = evaluation_with_sem(train_dataset, sem_tensorf, args, sem_renderer, f'{out_dir}/',
                                         N_vis=-1, N_samples=-1, white_bg=white_bg, ndc_ray=ndc_ray, device=device)
        print(f'======> {args.expname} train all psnr: {np.mean(PSNRs_test)} <========================')

    # The directory created must be the one the renders are written to.
    if args.render_test:
        out_dir = f'{logfolder}/{args.expname}/imgs_test_all'
        os.makedirs(out_dir, exist_ok=True)
        evaluation_with_sem(test_dataset, sem_tensorf, args, sem_renderer, f'{out_dir}/',
                            N_vis=-1, N_samples=-1, white_bg=white_bg, ndc_ray=ndc_ray, device=device)

    if args.render_path:
        c2ws = test_dataset.render_path
        out_dir = f'{logfolder}/{args.expname}/imgs_path_all'
        os.makedirs(out_dir, exist_ok=True)
        evaluation_path_with_sem(test_dataset, sem_tensorf, c2ws, sem_renderer, f'{out_dir}/',
                                 N_vis=-1, N_samples=-1, white_bg=white_bg, ndc_ray=ndc_ray, device=device)
-
-
def reconstruction_with_sem(args):
    """Train a semantic model on ScanNet with RGB and semantic supervision.

    Steps:

    - Build the ScanNet train/test splits and remap their semantic labels.
    - Create (or load from ``args.ckpt``) ``args.model_name``.
    - Optimise RGB MSE plus the dataset's semantic loss. Weights are either
      fixed or learned by ``MultiTaskLossWrapper``. Any 3D losses returned by
      the renderer and the ortho/L1/TV regularisers are added on top.
    - Follow the alpha-mask update and grid-upsampling schedules in ``args``
      and log to TensorBoard.
    - Save ``<logfolder>/<expname>.pth`` and optionally render the
      train/test/path views.
    """
    # init dataset
    dataset = scannet.ScanNet

    train_dataset = dataset(args.datadir, split='train', near=args.near, far=args.far,
                            scene_bbox_stretch=args.scene_bbox_stretch, is_stack=False)
    test_dataset = dataset(args.datadir, split='test', near=args.near, far=args.far,
                           scene_bbox_stretch=args.scene_bbox_stretch, is_stack=True)
    # Both splits are remapped from the same pair of label-image sets
    # (presumably so they share one label mapping).
    train_dataset.remap_sem_gt_label(train_dataset.sem_samples['sem_img'],
                                     test_dataset.sem_samples['sem_img'])
    test_dataset.remap_sem_gt_label(train_dataset.sem_samples['sem_img'],
                                    test_dataset.sem_samples['sem_img'])
    train_dataset.set_label2color_map()
    test_dataset.set_label2color_map()

    white_bg = train_dataset.white_bg
    near_far = train_dataset.near_far
    ndc_ray = args.ndc_ray

    upsamp_list = args.upsamp_list
    update_AlphaMask_list = args.update_AlphaMask_list

    n_lamb_sigma = args.n_lamb_sigma
    n_lamb_sh = args.n_lamb_sh
    n_lamb_sem = args.n_lamb_sem

    if args.add_timestamp:
        logfolder = f'{args.basedir}/{args.expname}{datetime.datetime.now().strftime("-%Y%m%d-%H%M%S")}'
    else:
        logfolder = f'{args.basedir}/{args.expname}'

    # init log folders
    os.makedirs(logfolder, exist_ok=True)
    os.makedirs(f'{logfolder}/imgs_vis', exist_ok=True)
    os.makedirs(f'{logfolder}/imgs_rgba', exist_ok=True)
    os.makedirs(f'{logfolder}/rgba', exist_ok=True)
    os.makedirs(f'{logfolder}/sem', exist_ok=True)
    summary_writer = SummaryWriter(logfolder)

    aabb = train_dataset.scene_bbox.to(device)  # adaptive scene_bbox world coordinates, (2, 3)
    reso_cur = N_to_reso(args.N_voxel_init, aabb)  # N_voxel_grids of each dimension
    nSamples = min(args.nSamples, cal_n_samples(reso_cur, args.step_ratio))

    # build the model
    # NOTE(review): torch.load unpickles and eval() executes args.model_name;
    # only use with trusted inputs.
    if args.ckpt is not None:
        ckpt = torch.load(args.ckpt, map_location=device)
        kwargs = ckpt['kwargs']
        kwargs.update({'device': device})
        # Some checkpoints may store the iteration count as 'N_iter' in kwargs;
        # reconstruction_sem drops it before construction, so do the same here.
        kwargs.pop('N_iter', None)
        sem_tensorf = eval(args.model_name)(**kwargs)
        # Call set_sdf_init() before load(), as reconstruction_sem does, so any
        # re-initialisation it performs cannot overwrite the loaded weights.
        if args.use_sdf:
            sem_tensorf.set_sdf_init()
        sem_tensorf.load(ckpt)
    else:
        sem_tensorf = eval(args.model_name)(aabb, reso_cur, device, gpu_ids=gpu_id,
                                            density_n_comp=n_lamb_sigma, appearance_n_comp=n_lamb_sh,
                                            sem_n_comp=n_lamb_sem,
                                            app_dim=args.data_dim_color, sem_dim=train_dataset.num_ins_class,
                                            near_far=near_far,
                                            shadingMode=args.shadingMode, alphaMask_thres=args.alpha_mask_thre,
                                            density_shift=args.density_shift, distance_scale=args.distance_scale,
                                            pos_pe=args.pos_pe, view_pe=args.view_pe, fea_pe=args.fea_pe,
                                            featureC=args.featureC, step_ratio=args.step_ratio,
                                            fea2denseAct=args.fea2denseAct)
        if args.use_sdf:
            sem_tensorf.set_sdf_init()

    # build optimizer
    grad_vars = sem_tensorf.get_optparam_groups(args.lr_init, args.lr_basis)
    if args.lr_decay_iters > 0:
        lr_factor = args.lr_decay_target_ratio ** (1 / args.lr_decay_iters)
    else:
        args.lr_decay_iters = args.n_iters
        lr_factor = args.lr_decay_target_ratio ** (1 / args.n_iters)

    print("lr decay", args.lr_decay_target_ratio, args.lr_decay_iters)
    loss_3d_weight = 1e-4  # 0.1 Manhattan sdf eikonal_weight
    mul_task_lr = 0.001  # learning rate of the learnable loss weights
    if not args.use_mul_task_loss:
        rgb_loss_weight = 1.
        sem_loss_weight = 4e-2
    else:
        criterion_mul_task = MultiTaskLossWrapper()
        grad_vars.append({'params': criterion_mul_task.log_vars, 'lr': mul_task_lr})

    optimizer = torch.optim.Adam(grad_vars, betas=(0.9, 0.99))

    scaler = torch.cuda.amp.GradScaler(enabled=args.fp16)

    # Grid sizes for the upsampling steps, evenly spaced in log space.
    N_voxel_list = (torch.round(torch.exp(torch.linspace(np.log(args.N_voxel_init), np.log(args.N_voxel_final), len(upsamp_list) + 1))).long()).tolist()[1:]

    torch.cuda.empty_cache()
    PSNRs, PSNRs_test = [], [0]

    allrays, allrgbs, allsems = train_dataset.all_rays, train_dataset.all_rgbs, train_dataset.sem_samples["sem_remap"]
    allsems = allsems.reshape(allrgbs.shape[0], 1)  # (num_imgs*h*w, 1)
    allsems = torch.tensor(allsems)

    if not args.ndc_ray:
        allrays, allrgbs, allsems = sem_tensorf.filtering_rays(allrays, allrgbs, allsems, bbox_only=True)
    trainingSampler = SimpleSampler(allrays.shape[0], args.batch_size)

    Ortho_reg_weight = args.Ortho_weight
    print("initial Ortho_reg_weight", Ortho_reg_weight)

    L1_reg_weight = args.L1_weight_inital
    print("initial L1_reg_weight", L1_reg_weight)
    TV_weight_density, TV_weight_app, TV_weight_sem = args.TV_weight_density, args.TV_weight_app, args.TV_weight_sem
    tvreg = TVLoss()
    print(f"initial TV_weight density: {TV_weight_density} appearance: {TV_weight_app} semantic: {TV_weight_sem}")

    pbar = tqdm(range(args.n_iters), miniters=args.progress_refresh_rate, file=sys.stdout)

    start_t = time.time()

    # Each iteration samples a batch of individual rays (pixels), not whole images.
    for iteration in pbar:
        ray_idx = trainingSampler.nextids()
        rays_train, rgb_train, sem_train = allrays[ray_idx], allrgbs[ray_idx].to(device), allsems[ray_idx].to(device)

        rgb_map, alphas_map, sem_map, depth_map, weights, uncertainty, loss_3d = sem_renderer(
            rays_train, sem_tensorf, chunk=args.batch_size, N_samples=nSamples, white_bg=white_bg,
            ndc_ray=ndc_ray, device=device, is_train=True, fp16=args.fp16)
        # Extra 3D losses returned by the renderer, scaled by loss_3d_weight.
        loss_3d_total = 0
        if loss_3d != {}:
            for key in loss_3d.keys():
                loss_3d_total += loss_3d[key] * loss_3d_weight
        rgb_loss = torch.mean((rgb_map - rgb_train) ** 2)

        sem_crossentropy_loss = train_dataset.get_sem_loss(sem_map, sem_train)

        # loss
        if not args.use_mul_task_loss:
            total_loss = rgb_loss * rgb_loss_weight + sem_crossentropy_loss * sem_loss_weight + loss_3d_total
        else:
            total_loss, log_vars = criterion_mul_task(rgb_loss, sem_crossentropy_loss)
            # The 3D losses are outside the learned weighting; add them as in
            # the fixed-weight branch instead of silently dropping them.
            total_loss = total_loss + loss_3d_total
            summary_writer.add_scalar('train/log_var_rgb', log_vars[0], global_step=iteration)
            summary_writer.add_scalar('train/log_var_sem', log_vars[1], global_step=iteration)

        if Ortho_reg_weight > 0:
            loss_reg = sem_tensorf.vector_comp_diffs()
            total_loss += Ortho_reg_weight * loss_reg
            summary_writer.add_scalar('train/reg', loss_reg.detach().item(), global_step=iteration)
        if L1_reg_weight > 0:
            loss_reg_L1 = sem_tensorf.density_L1()
            total_loss += L1_reg_weight * loss_reg_L1
            summary_writer.add_scalar('train/reg_l1', loss_reg_L1.detach().item(), global_step=iteration)

        # TV weights decay by the same per-iteration factor as the learning rate.
        if TV_weight_density > 0:
            TV_weight_density *= lr_factor
            loss_tv = sem_tensorf.TV_loss_density(tvreg) * TV_weight_density
            total_loss = total_loss + loss_tv
            summary_writer.add_scalar('train/reg_tv_density', loss_tv.detach().item(), global_step=iteration)
        if TV_weight_app > 0:
            TV_weight_app *= lr_factor
            loss_tv = sem_tensorf.TV_loss_app(tvreg) * TV_weight_app
            total_loss = total_loss + loss_tv
            summary_writer.add_scalar('train/reg_tv_app', loss_tv.detach().item(), global_step=iteration)
        if TV_weight_sem > 0:
            TV_weight_sem *= lr_factor
            loss_tv = sem_tensorf.TV_loss_sem(tvreg) * TV_weight_sem
            total_loss = total_loss + loss_tv
            summary_writer.add_scalar('train/reg_tv_sem', loss_tv.detach().item(), global_step=iteration)

        optimizer.zero_grad()

        scaler.scale(total_loss).backward()
        scaler.step(optimizer)
        scaler.update()

        rgb_loss = rgb_loss.detach().item()
        sem_loss = sem_crossentropy_loss.detach().item()

        PSNRs.append(-10.0 * np.log(rgb_loss) / np.log(10.0))
        summary_writer.add_scalar('train/PSNR', PSNRs[-1], global_step=iteration)
        summary_writer.add_scalar('train/mse', rgb_loss, global_step=iteration)
        summary_writer.add_scalar('train/sem_CE', sem_loss, global_step=iteration)

        if loss_3d != {}:
            for key in loss_3d.keys():
                summary_writer.add_scalar(f'train/loss_3d_{key}', loss_3d[key], global_step=iteration)
        # Exponential learning-rate decay, applied every iteration.
        for param_group in optimizer.param_groups:
            param_group['lr'] = param_group['lr'] * lr_factor

        # Print the current values of the losses.
        if (iteration + 1) % args.progress_refresh_rate == 0:
            pbar.set_description(
                f'Iter {(iteration+1):05d}:'
                + f' train_psnr:{float(np.mean(PSNRs)):.2f}'
                + f' test_psnr:{float(np.mean(PSNRs_test)):.2f}'
                + f' mse:{rgb_loss:.5f}'
                + f' CE:{sem_loss:.5f}'
                + f' 3d:{loss_3d_total:.5f}'
            )
            PSNRs = []

        # try on test datasets during training
        if iteration % args.vis_every == args.vis_every - 1 and args.N_vis != 0:
            PSNRs_test = evaluation_with_sem(test_dataset, sem_tensorf, args, sem_renderer, f'{logfolder}/imgs_vis/', N_vis=args.N_vis,
                                             prtx=f'{iteration:06d}_', N_samples=nSamples, white_bg=white_bg, ndc_ray=ndc_ray,
                                             compute_extra_metrics=False, fp16=args.fp16, chunk_size=args.batch_size)
            summary_writer.add_scalar('test/psnr', np.mean(PSNRs_test), global_step=iteration)

        if iteration in update_AlphaMask_list:
            if reso_cur[0] * reso_cur[1] * reso_cur[2] < 256 ** 3:  # update volume resolution
                reso_mask = reso_cur
            new_aabb = sem_tensorf.updateAlphaMask(tuple(reso_mask))
            if iteration == update_AlphaMask_list[0]:
                sem_tensorf.shrink(new_aabb)
                L1_reg_weight = args.L1_weight_rest
                print("continuing L1_reg_weight", L1_reg_weight)

            # The length check avoids an IndexError for a single-entry schedule.
            if (not args.ndc_ray and len(update_AlphaMask_list) > 1
                    and iteration == update_AlphaMask_list[1]):
                # filter rays outside the bbox
                allrays, allrgbs, allsems = sem_tensorf.filtering_rays(allrays, allrgbs, allsems)
                trainingSampler = SimpleSampler(allrgbs.shape[0], args.batch_size)

        # enlarge gridSize
        if iteration in upsamp_list:
            n_voxels = N_voxel_list.pop(0)
            reso_cur = N_to_reso(n_voxels, sem_tensorf.aabb)
            nSamples = min(args.nSamples, cal_n_samples(reso_cur, args.step_ratio))
            sem_tensorf.upsample_volume_grid(reso_cur)

            if args.lr_upsample_reset:
                print("reset lr to initial")
                lr_scale = 1  # 0.1 ** (iteration / args.n_iters)
            else:
                lr_scale = args.lr_decay_target_ratio ** (iteration / args.n_iters)
            grad_vars = sem_tensorf.get_optparam_groups(args.lr_init * lr_scale, args.lr_basis * lr_scale)
            if args.use_mul_task_loss:
                # Keep optimising the learnable loss weights after the rebuild.
                grad_vars.append({'params': criterion_mul_task.log_vars, 'lr': mul_task_lr * lr_scale})
            optimizer = torch.optim.Adam(grad_vars, betas=(0.9, 0.99))

    end_t = time.time()

    print(f"Training takes {(end_t - start_t)/ 60:.4f} minutes.")

    sem_tensorf.save(f'{logfolder}/{args.expname}.pth')

    if args.render_train:
        os.makedirs(f'{logfolder}/imgs_train_all', exist_ok=True)
        # Reload the train split stacked per image, with the same scene settings
        # and label remapping used for training.
        train_dataset = dataset(args.datadir, split='train', near=args.near, far=args.far,
                                scene_bbox_stretch=args.scene_bbox_stretch, is_stack=True)
        train_dataset.remap_sem_gt_label(train_dataset.sem_samples['sem_img'],
                                         test_dataset.sem_samples['sem_img'])
        train_dataset.set_label2color_map()
        PSNRs_test = evaluation_with_sem(train_dataset, sem_tensorf, args, sem_renderer, f'{logfolder}/imgs_train_all/',
                                         N_vis=-1, N_samples=-1, white_bg=white_bg, ndc_ray=ndc_ray, device=device,
                                         fp16=args.fp16, chunk_size=args.batch_size)
        print(f'======> {args.expname} train all psnr: {np.mean(PSNRs_test)} <========================')

    if args.render_test:
        os.makedirs(f'{logfolder}/imgs_test_all', exist_ok=True)
        PSNRs_test = evaluation_with_sem(test_dataset, sem_tensorf, args, sem_renderer, f'{logfolder}/imgs_test_all/',
                                         N_vis=-1, N_samples=-1, white_bg=white_bg, ndc_ray=ndc_ray, device=device,
                                         fp16=args.fp16, chunk_size=args.batch_size)
        summary_writer.add_scalar('test/psnr_all', np.mean(PSNRs_test), global_step=iteration)
        print(f'======> {args.expname} test all psnr: {np.mean(PSNRs_test)} <========================')

    # images from new view-points
    if args.render_path:
        c2ws = test_dataset.render_path
        print('========>', c2ws.shape)
        os.makedirs(f'{logfolder}/imgs_path_all', exist_ok=True)
        evaluation_path_with_sem(test_dataset, sem_tensorf, c2ws, sem_renderer, f'{logfolder}/imgs_path_all/',
                                 N_vis=-1, N_samples=-1, white_bg=white_bg, ndc_ray=ndc_ray, device=device,
                                 chunk_size=args.batch_size)

    # Flush and release the TensorBoard event file.
    summary_writer.close()
-
-
def reconstruction_with_rawfeas(args):
    """Train a model on ScanNet with RGB supervision and raw-feature distillation.

    Steps:

    - Build the ScanNet splits; the train split is built with
      ``dino_feature_dir`` and has ``load_dino_features()`` called.
    - Create (or load from ``args.ckpt``) ``args.model_name`` with
      ``use_rgbs`` and ``use_raw_semfeas`` enabled.
    - Optimise RGB MSE plus an MSE between rendered features and the dataset's
      ``all_semfeas``. Weights are either fixed or learned by
      ``MultiTaskLossWrapper``. Any 3D losses returned by the renderer and the
      ortho/L1/TV regularisers are added on top.
    - Follow the alpha-mask update and grid-upsampling schedules in ``args``
      and log to TensorBoard.
    - Save ``<logfolder>/<expname>.pth`` together with the last iteration
      index, and optionally render the train/test/path views.
    """
    # init dataset
    dataset = scannet.ScanNet

    train_dataset = dataset(args.datadir, near=args.near, far=args.far, scene_bbox_stretch=args.scene_bbox_stretch,
                            split='train', is_stack=False, use_sem=True,
                            dino_feature_dir=args.dino_feature_dir)
    test_dataset = dataset(args.datadir, near=args.near, far=args.far, scene_bbox_stretch=args.scene_bbox_stretch,
                           split='test', is_stack=True, use_sem=True)
    train_dataset.get_sem_label_num(train_dataset.sem_samples["sem_img"],
                                    test_dataset.sem_samples["sem_img"],
                                    args.sem_info_path)

    train_dataset.load_dino_features()

    white_bg = train_dataset.white_bg
    near_far = train_dataset.near_far
    ndc_ray = args.ndc_ray

    upsamp_list = args.upsamp_list
    update_AlphaMask_list = args.update_AlphaMask_list

    n_lamb_sigma = args.n_lamb_sigma
    n_lamb_sh = args.n_lamb_sh
    n_lamb_sem = args.n_lamb_sem

    if args.add_timestamp:
        logfolder = f'{args.basedir}/{args.expname}{datetime.datetime.now().strftime("-%Y%m%d-%H%M%S")}'
    else:
        logfolder = f'{args.basedir}/{args.expname}'

    # init log folders
    os.makedirs(logfolder, exist_ok=True)
    os.makedirs(f'{logfolder}/imgs_vis', exist_ok=True)
    os.makedirs(f'{logfolder}/imgs_rgba', exist_ok=True)
    os.makedirs(f'{logfolder}/rgba', exist_ok=True)
    summary_writer = SummaryWriter(logfolder)

    aabb = train_dataset.scene_bbox.to(device)  # adaptive scene_bbox world coordinates, (2, 3)
    reso_cur = N_to_reso(args.N_voxel_init, aabb)  # N_voxel_grids of each dimension
    nSamples = min(args.nSamples, cal_n_samples(reso_cur, args.step_ratio))

    # build the model
    # NOTE(review): torch.load unpickles and eval() executes args.model_name;
    # only use with trusted inputs.
    if args.ckpt is not None:
        ckpt = torch.load(args.ckpt, map_location=device)
        kwargs = ckpt['kwargs']
        kwargs.update({'device': device})
        # Checkpoints saved with an iteration count (as this function does) may
        # carry it as 'N_iter' in kwargs. reconstruction_sem drops it before
        # construction, so do the same here.
        kwargs.pop('N_iter', None)
        sem_tensorf = eval(args.model_name)(**kwargs)
        # Call set_sdf_init() before load(), as reconstruction_sem does, so any
        # re-initialisation it performs cannot overwrite the loaded weights.
        if args.use_sdf:
            sem_tensorf.set_sdf_init()
        sem_tensorf.load(ckpt)
    else:
        sem_tensorf = eval(args.model_name)(aabb, reso_cur, device, gpu_ids=gpu_id, fp16=args.fp16,
                                            density_n_comp=n_lamb_sigma, appearance_n_comp=n_lamb_sh,
                                            sem_n_comp=n_lamb_sem, n_dino_fea=args.n_dino_fea,
                                            app_dim=args.data_dim_color,
                                            # sem_dim: default, will be updated in sem rendering
                                            sem_dim=train_dataset.num_valid_semantic_class,
                                            use_rgbs=True, use_raw_semfeas=True,
                                            near_far=near_far, shadingMode=args.shadingMode,
                                            alphaMask_thres=args.alpha_mask_thre,
                                            density_shift=args.density_shift, distance_scale=args.distance_scale,
                                            pos_pe=args.pos_pe, view_pe=args.view_pe, fea_pe=args.fea_pe,
                                            featureC=args.featureC, step_ratio=args.step_ratio,
                                            fea2denseAct=args.fea2denseAct)
        if args.use_sdf:
            sem_tensorf.set_sdf_init()
    start_iter = 0

    # build optimizer
    grad_vars = sem_tensorf.get_optparam_groups(args.lr_init, args.lr_basis)
    if args.lr_decay_iters > 0:
        lr_factor = args.lr_decay_target_ratio ** (1 / args.lr_decay_iters)
    else:
        args.lr_decay_iters = args.n_iters
        lr_factor = args.lr_decay_target_ratio ** (1 / args.n_iters)

    print("lr decay", args.lr_decay_target_ratio, args.lr_decay_iters)

    # Defined for both loss modes: the 3D-loss accumulation below uses it either way.
    loss_3d_weight = 1e-4  # 0.1 Manhattan sdf eikonal_weight
    mul_task_lr = 0.001  # learning rate of the learnable loss weights
    if args.use_mul_task_loss:
        criterion_mul_task = MultiTaskLossWrapper()
        grad_vars.append({'params': criterion_mul_task.log_vars, 'lr': mul_task_lr})
    else:
        rgb_loss_weight = 1.
        distill_loss_weight = 0.1

    optimizer = torch.optim.Adam(grad_vars, betas=(0.9, 0.99))

    scaler = torch.cuda.amp.GradScaler(enabled=args.fp16)

    # Grid sizes for the upsampling steps, linear in logarithmic space.
    N_voxel_list = (torch.round(torch.exp(torch.linspace(np.log(args.N_voxel_init), np.log(args.N_voxel_final), len(upsamp_list) + 1))).long()).tolist()[1:]

    torch.cuda.empty_cache()
    PSNRs, PSNRs_test = [], [0]

    allrays, allrgbs, allsemfeas = train_dataset.all_rays, train_dataset.all_rgbs, train_dataset.all_semfeas

    if not args.ndc_ray:
        allrays, allrgbs, allsemfeas = sem_tensorf.filtering_rays(allrays, allrgbs, allsemfeas, bbox_only=True)
    trainingSampler = SimpleSampler(allrays.shape[0], args.batch_size)

    Ortho_reg_weight = args.Ortho_weight
    print("initial Ortho_reg_weight", Ortho_reg_weight)

    L1_reg_weight = args.L1_weight_inital
    print("initial L1_reg_weight", L1_reg_weight)
    TV_weight_density, TV_weight_app, TV_weight_semfea = args.TV_weight_density, args.TV_weight_app, args.TV_weight_semfea
    tvreg = TVLoss()
    print(f"initial TV_weight density: {TV_weight_density} appearance: {TV_weight_app} raw_sem_features: {TV_weight_semfea}")

    pbar = tqdm(range(start_iter, start_iter + args.n_iters), miniters=args.progress_refresh_rate, file=sys.stdout)

    if args.fp16:
        # NOTE: process-wide setting; it stays in force after this function returns.
        torch.set_default_tensor_type('torch.cuda.HalfTensor')

    start_t = time.time()

    # Each iteration samples a batch of individual rays (pixels), not whole images.
    for iteration in pbar:
        ray_idx = trainingSampler.nextids()
        rays_train, rgbs_train, semfeas_train = allrays[ray_idx], allrgbs[ray_idx].to(device), allsemfeas[ray_idx].to(device)

        rgb_map, alphas_map, sem_map, semfea_map, depth_map, weights, uncertainty, loss_3d = sem_renderer(
            rays_train, sem_tensorf, chunk=args.batch_size, N_samples=nSamples, white_bg=white_bg,
            ndc_ray=ndc_ray, device=device, is_train=True,
            # fp16=args.fp16,
            use_rgbs=sem_tensorf.use_rgbs, use_raw_semfeas=sem_tensorf.use_raw_semfeas)
        # Extra 3D losses returned by the renderer, already scaled by loss_3d_weight.
        loss_3d_total = 0
        if loss_3d != {}:
            for key in loss_3d.keys():
                loss_3d_total += loss_3d[key] * loss_3d_weight

        # loss
        if sem_tensorf.use_rgbs:
            rgb_loss = torch.mean((rgb_map - rgbs_train) ** 2)
        if sem_tensorf.use_raw_semfeas:
            distill_loss = torch.mean((semfea_map - semfeas_train) ** 2)

        if not args.use_mul_task_loss:
            # loss_3d_total is already weighted above; don't scale it a second
            # time. Out-of-place adds keep loss_3d_total unchanged for logging.
            total_loss = loss_3d_total
            if sem_tensorf.use_rgbs:
                total_loss = total_loss + rgb_loss * rgb_loss_weight
            if sem_tensorf.use_raw_semfeas:
                total_loss = total_loss + distill_loss * distill_loss_weight
        else:
            total_loss, _ = criterion_mul_task(rgb_loss, distill_loss)
            # The 3D losses are outside the learned weighting; add them as in
            # the fixed-weight branch instead of silently dropping them.
            total_loss = total_loss + loss_3d_total

        if sem_tensorf.use_rgbs:
            if Ortho_reg_weight > 0:
                loss_reg = sem_tensorf.vector_comp_diffs()
                total_loss += Ortho_reg_weight * loss_reg
                summary_writer.add_scalar('train/reg', loss_reg.detach().item(), global_step=iteration)
            if L1_reg_weight > 0:
                loss_reg_L1 = sem_tensorf.density_L1()
                total_loss += L1_reg_weight * loss_reg_L1
                summary_writer.add_scalar('train/reg_l1', loss_reg_L1.detach().item(), global_step=iteration)
            # TV weights decay by the same per-iteration factor as the learning rate.
            if TV_weight_density > 0:
                TV_weight_density *= lr_factor
                loss_tv = sem_tensorf.TV_loss_density(tvreg) * TV_weight_density
                total_loss = total_loss + loss_tv
                summary_writer.add_scalar('train/reg_tv_density', loss_tv.detach().item(), global_step=iteration)
            if TV_weight_app > 0:
                TV_weight_app *= lr_factor
                loss_tv = sem_tensorf.TV_loss_app(tvreg) * TV_weight_app
                total_loss = total_loss + loss_tv
                summary_writer.add_scalar('train/reg_tv_app', loss_tv.detach().item(), global_step=iteration)
        if sem_tensorf.use_raw_semfeas:
            if TV_weight_semfea > 0:
                TV_weight_semfea *= lr_factor
                loss_tv = sem_tensorf.TV_loss_sem(tvreg) * TV_weight_semfea
                total_loss = total_loss + loss_tv
                summary_writer.add_scalar('train/reg_tv_semfea', loss_tv.detach().item(), global_step=iteration)

        optimizer.zero_grad()

        scaler.scale(total_loss).backward()
        scaler.step(optimizer)
        scaler.update()

        # Exponential learning-rate decay, applied every iteration.
        for param_group in optimizer.param_groups:
            param_group['lr'] = param_group['lr'] * lr_factor

        if sem_tensorf.use_rgbs:
            rgb_loss = rgb_loss.detach().item()

            PSNRs.append(-10.0 * np.log(rgb_loss) / np.log(10.0))
            summary_writer.add_scalar('train/PSNR', PSNRs[-1], global_step=iteration)
            summary_writer.add_scalar('train/mse', rgb_loss, global_step=iteration)

        if sem_tensorf.use_raw_semfeas:
            distill_loss = distill_loss.detach().item()

            summary_writer.add_scalar('train/distill_mse', distill_loss, global_step=iteration)

        if loss_3d != {}:
            for key in loss_3d.keys():
                summary_writer.add_scalar(f'train/loss_3d_{key}', loss_3d[key], global_step=iteration)

        # Print the current values of the losses.
        if (iteration + 1) % args.progress_refresh_rate == 0:
            desc = f'Iteration {(iteration + 1):05d}:' \
                   + f'3d {loss_3d_total:.2f}'
            if sem_tensorf.use_rgbs:
                desc += (f' train_psnr = {float(np.mean(PSNRs)):.2f}'
                         + f' mse = {rgb_loss:.6f}')
            if sem_tensorf.use_raw_semfeas:
                desc += f' distill_mse = {distill_loss:.6f}'

            pbar.set_description(desc)
            pbar.set_postfix({'iter': f'{iteration}'})
            PSNRs = []

        # try on test datasets during training
        if iteration % args.vis_every == args.vis_every - 1 and args.N_vis != 0:
            PSNRs_test = evaluation_with_semfea(test_dataset, sem_tensorf, args, sem_renderer, f'{logfolder}/imgs_vis/', N_vis=args.N_vis,
                                                prtx=f'{iteration:06d}_', N_samples=nSamples, white_bg=white_bg, ndc_ray=ndc_ray,
                                                compute_extra_metrics=False, fp16=args.fp16, chunk_size=args.batch_size,
                                                use_rgbs=sem_tensorf.use_rgbs, use_raw_semfeas=sem_tensorf.use_raw_semfeas)
            if PSNRs_test is not None:
                summary_writer.add_scalar('test/psnr', np.mean(PSNRs_test), global_step=iteration)

        if iteration in update_AlphaMask_list:
            if reso_cur[0] * reso_cur[1] * reso_cur[2] < 256 ** 3:  # update volume resolution
                reso_mask = reso_cur
            new_aabb = sem_tensorf.updateAlphaMask(tuple(reso_mask))
            if iteration == update_AlphaMask_list[0]:
                sem_tensorf.shrink(new_aabb)
                L1_reg_weight = args.L1_weight_rest
                print("continuing L1_reg_weight", L1_reg_weight)

            # The length check avoids an IndexError for a single-entry schedule.
            if (not args.ndc_ray and len(update_AlphaMask_list) > 1
                    and iteration == update_AlphaMask_list[1]):
                # filter rays outside the bbox
                allrays, allrgbs, allsemfeas = sem_tensorf.filtering_rays(allrays, allrgbs, allsemfeas)
                trainingSampler = SimpleSampler(allrgbs.shape[0], args.batch_size)

        # enlarge gridSize
        if iteration in upsamp_list:
            n_voxels = N_voxel_list.pop(0)
            reso_cur = N_to_reso(n_voxels, sem_tensorf.aabb)
            nSamples = min(args.nSamples, cal_n_samples(reso_cur, args.step_ratio))
            sem_tensorf.upsample_volume_grid(reso_cur)

            if args.lr_upsample_reset:
                print("reset lr to initial")
                lr_scale = 1  # 0.1 ** (iteration / args.n_iters)
            else:
                lr_scale = args.lr_decay_target_ratio ** (iteration / args.n_iters)
            grad_vars = sem_tensorf.get_optparam_groups(args.lr_init * lr_scale, args.lr_basis * lr_scale)
            if args.use_mul_task_loss:
                # Keep optimising the learnable loss weights after the rebuild.
                grad_vars.append({'params': criterion_mul_task.log_vars, 'lr': mul_task_lr * lr_scale})
            optimizer = torch.optim.Adam(grad_vars, betas=(0.9, 0.99))

    end_t = time.time()

    print(f"Training {args.expname} takes {(end_t - start_t)/ 60:.4f} minutes.")

    sem_tensorf.save(f'{logfolder}/{args.expname}.pth', iteration)

    if args.render_train:
        os.makedirs(f'{logfolder}/imgs_train_all', exist_ok=True)
        # Reload the train split stacked per image, with the same scene settings as training.
        train_dataset = dataset(args.datadir, near=args.near, far=args.far, scene_bbox_stretch=args.scene_bbox_stretch,
                                split='train', is_stack=True, use_sem=True)
        PSNRs_test = evaluation_with_semfea(train_dataset, sem_tensorf, args, sem_renderer, f'{logfolder}/imgs_train_all/',
                                            N_vis=-1, N_samples=-1, white_bg=white_bg, ndc_ray=ndc_ray, device=device,
                                            fp16=args.fp16, chunk_size=args.batch_size,
                                            use_rgbs=sem_tensorf.use_rgbs, use_raw_semfeas=sem_tensorf.use_raw_semfeas)
        # evaluation_with_semfea can return None (the training loop checks for it too).
        if PSNRs_test is not None:
            print(f'======> {args.expname} train all psnr: {np.mean(PSNRs_test)} <========================')

    if args.render_test:
        os.makedirs(f'{logfolder}/imgs_test_all', exist_ok=True)
        PSNRs_test = evaluation_with_semfea(test_dataset, sem_tensorf, args, sem_renderer, f'{logfolder}/imgs_test_all/',
                                            N_vis=-1, N_samples=-1, white_bg=white_bg, ndc_ray=ndc_ray, device=device,
                                            fp16=args.fp16, chunk_size=args.batch_size,
                                            use_rgbs=sem_tensorf.use_rgbs, use_raw_semfeas=sem_tensorf.use_raw_semfeas)
        if PSNRs_test is not None:
            summary_writer.add_scalar('test/psnr_all', np.mean(PSNRs_test), global_step=iteration)
            print(f'======> {args.expname} test all psnr: {np.mean(PSNRs_test)} <========================')

    # todo: fix
    # images from new view-points
    if args.render_path:
        c2ws = test_dataset.render_path
        print('========>', c2ws.shape)
        os.makedirs(f'{logfolder}/imgs_path_all', exist_ok=True)
        evaluation_path_with_sem(test_dataset, sem_tensorf, c2ws, sem_renderer, f'{logfolder}/imgs_path_all/',
                                 N_vis=-1, N_samples=-1, white_bg=white_bg, ndc_ray=ndc_ray, device=device,
                                 fp16=args.fp16, chunk_size=args.batch_size)

    # Flush and release the TensorBoard event file.
    summary_writer.close()
-
-
- def reconstruction_sem(args):
- # init dataset
- dataset = scannet.ScanNet
-
- train_dataset = dataset(args.datadir, near=args.near, far=args.far, scene_bbox_stretch=args.scene_bbox_stretch,
- split='train', is_stack=True, use_sem=True, sem_interval=args.label_interval,
- dino_feature_dir=args.dino_feature_dir)
- test_dataset = dataset(args.datadir, near=args.near, far=args.far, scene_bbox_stretch=args.scene_bbox_stretch,
- split='test', is_stack=True, use_sem=True, sem_interval=args.label_interval)
- train_dataset.get_sem_label_num(train_dataset.sem_samples["sem_img"], test_dataset.sem_samples["sem_img"],
- args.sem_info_path)
- test_dataset.get_sem_label_num(train_dataset.sem_samples["sem_img"], test_dataset.sem_samples["sem_img"],
- args.sem_info_path)
-
- train_dataset.set_label2color_map(args.color_dict)
- test_dataset.set_label2color_map(args.color_dict)
- # train_dataset.save_sem_color_map(resize=True)
- # test_dataset.save_sem_color_map(resize=True)
- train_dataset.select_sems()
- '''sparse 20'''
- # 10_20
- # train_dataset.select_sems(select_idx=torch.tensor([40, 67, 72, 80,100, 120,
- # 140,
- # 151,
- # 160,
- # 168,
- # 180,
- # 200,
- # 210,
- # 220,
- # 240,
- # 260,
- # 264]))
- # 12_20
- # train_dataset.select_sems(select_idx=torch.tensor([20, 40, 60, 82, 90, 107, 120, 140, 160, 195,
- # 210, 220, 245, 267, 291, 300, 329]))
- # 24_20
- # train_dataset.select_sems(select_idx=torch.tensor([40, 60, 80, 100, 113, 120, 131, 152, 160,
- # 171,
- # 180,
- # 192,
- # 200,
- # 220,
- # 240,
- # 254]))
- # 33_20
- # train_dataset.select_sems(select_idx=torch.tensor([0,
- # 40,
- # 60,
- # 80,
- # 100,
- # 107,
- # 120,
- # 127,
- # 137,
- # 140,
- # 160,
- # 180,
- # 200,
- # 220,
- # 230,
- # 240,
- # 260,
- # 272]))
- # 88_20
- # train_dataset.select_sems(select_idx=torch.tensor([0, 76, 84, 95, 102, 114, 121, 133,
- # 140,
- # 152,
- # 160,
- # 180,
- # 200,
- # 230,
- # 240,
- # 260,
- # 280,
- # 300]))
- '''sparse 10'''
- # 10_10
- # train_dataset.select_sems(select_idx=torch.tensor([20, 30, 60, 67, 70, 75, 80, 90, 100,
- # 110,
- # 120,
- # 130,
- # 134,
- # 140,
- # 144, 148, 150, 160, 170, 180, 190, 200,
- # 210, 215,
- # 220,
- # 230,
- # 240,
- # 250,
- # 256,
- # 260,
- # 270,
- # 280,
- # 300,
- # 310]))
- # 12_10,
- # train_dataset.select_sems(select_idx=torch.tensor([0, 10, 20, 30, 40, 50, 60, 70, 82, 90,
- # 100, 110, 120, 130, 140, 150, 160, 170, 180, 190,
- # 200, 210, 220, 240, 245, 250, 260, 270, 280, 290,
- # 300, 310, 320, 330]))
- # 24_10
- # train_dataset.select_sems(select_idx=torch.tensor([0,
- # 10,
- # 30,
- # 40,
- # 44,
- # 50,
- # 60,
- # 63,
- # 70,
- # 80,
- # 90,
- # 110,
- # 113,
- # 120,
- # 126,
- # 131,
- # 135,
- # 140,
- # 150,
- # 160,
- # 164,
- # 170,
- # 180,
- # 190,
- # 192,
- # 200,
- # 210,
- # 220,
- # 230,
- # 240,
- # 245,
- # 253]))
- # 33_10
- # train_dataset.select_sems(select_idx=torch.tensor([0,
- # 40,
- # 50,
- # 60,
- # 70,
- # 80,
- # 80,
- # 100,
- # 107,
- # 110,
- # 115,
- # 120,
- # 130,
- # 140,
- # 150,
- # 155,
- # 160,
- # 165,
- # 170,
- # 180,
- # 196,
- # 200,
- # 210,
- # 220,
- # 230,
- # 235,
- # 240,
- # 250,
- # 260,
- # 265,
- # 270,
- # 280,
- # 300,
- # 320,
- # 350]))
- # 88_10
- # train_dataset.select_sems(select_idx=torch.tensor([20, 30, 50, 70, 73, 77, 80, 90, 95, 100,
- # 110,
- # 115,
- # 120,
- # 130,
- # 140,
- # 150,
- # 155,
- # 160,
- # 180,
- # 190,
- # 200,
- # 210,
- # 220,
- # 235,
- # 240,
- # 245,
- # 250,
- # 260,
- # 265,
- # 270,
- # 280,
- # 290,
- # 320,
- # 330,
- # 340,
- # 360]))
-
- assert hasattr(train_dataset, 'selected_rays')
-
- white_bg = train_dataset.white_bg
- near_far = train_dataset.near_far
- ndc_ray = args.ndc_ray
-
- # init resolution ?
- upsamp_list = args.upsamp_list
- update_AlphaMask_list = args.update_AlphaMask_list
-
- n_lamb_sigma = args.n_lamb_sigma
- n_lamb_sh = args.n_lamb_sh
- n_lamb_sem = args.n_lamb_sem
-
- if args.add_timestamp:
- logfolder = f'{args.basedir}/{args.expname}{datetime.datetime.now().strftime("-%Y%m%d-%H%M%S")}'
- else:
- logfolder = f'{args.basedir}/{args.expname}'
-
- # init log file
- os.makedirs(logfolder, exist_ok=True)
- os.makedirs(f'{logfolder}/imgs_vis', exist_ok=True)
- os.makedirs(f'{logfolder}/imgs_rgba', exist_ok=True)
- os.makedirs(f'{logfolder}/rgba', exist_ok=True)
- os.makedirs(f'{logfolder}/sem', exist_ok=True) # ?
- summary_writer = SummaryWriter(logfolder)
-
- # init parameters
- # tensorVM, renderer = init_parameters(args, train_dataset.scene_bbox.to(device), reso_list[0])
- aabb = train_dataset.scene_bbox.to(device) # adaptive scene_bbox world coordinates, (2, 3)
- reso_cur = N_to_reso(args.N_voxel_init, aabb) # N_voxel_grids of each dimension
- nSamples = min(args.nSamples, cal_n_samples(reso_cur, args.step_ratio))
-
- # todo: update sem_dim
- # build the model
- if args.ckpt is not None:
- ckpt = torch.load(args.ckpt, map_location=device)
- kwargs = ckpt['kwargs']
- kwargs.update({'device': device})
- start_iter = kwargs.pop('N_iter')
-
- sem_tensorf = eval(args.model_name)(**kwargs)
- if args.use_sdf:
- sem_tensorf.set_sdf_init()
- sem_tensorf.load(ckpt)
- elif args.rawfea_ckpt is not None:
- rawfea_ckpt = torch.load(args.rawfea_ckpt, map_location=device)
- kwargs = rawfea_ckpt['kwargs']
- kwargs.update({'n_dino_fea': 64})
- kwargs.update({'use_raw_semfeas': False})
- kwargs.update({'use_rgbs': True})
- # kwargs.update({'sem_dim': train_dataset.num_semantic_class})
- kwargs.update({'use_pe_objdecoder': args.use_pe_objdecoder})
- kwargs.update({'device': device})
- start_iter = kwargs.pop('N_iter') + 1
-
- sem_tensorf = eval(args.model_name)(**kwargs)
- if args.use_sdf:
- sem_tensorf.set_sdf_init()
- sem_tensorf.load(rawfea_ckpt)
- else:
- sem_tensorf = eval(args.model_name)(aabb, reso_cur, device, gpu_ids=gpu_id,
- density_n_comp=n_lamb_sigma, appearance_n_comp=n_lamb_sh, sem_n_comp=n_lamb_sem,
- app_dim=args.data_dim_color, sem_dim=train_dataset.num_valid_semantic_class,
- near_far=near_far,
- shadingMode=args.shadingMode, alphaMask_thres=args.alpha_mask_thre,
- density_shift=args.density_shift, distance_scale=args.distance_scale,
- pos_pe=args.pos_pe, view_pe=args.view_pe, fea_pe=args.fea_pe,
- featureC=args.featureC, step_ratio=args.step_ratio,
- fea2denseAct=args.fea2denseAct)
- start_iter = 0
- if args.use_sdf:
- sem_tensorf.set_sdf_init()
-
- if args.warmup:
- sem_tensorf.use_rgbs = True
- sem_tensorf.use_raw_semfeas = True
-
- train_dataset.load_dino_features()
- grad_vars = sem_tensorf.get_optparam_groups(args.lr_init, args.lr_basis)
- if args.lr_decay_iters > 0:
- lr_factor = args.lr_decay_target_ratio ** (1 / args.lr_decay_iters)
- else:
- args.lr_decay_iters = args.n_iters
- lr_factor = args.lr_decay_target_ratio ** (1 / args.n_iters)
-
- print("lr decay", args.lr_decay_target_ratio, args.lr_decay_iters)
-
- rgb_loss_weight = 1.
- distill_loss_weight = 0.1
- loss_3d_weight = 1e-4 # 0.1 Manhattan sdf eikonal_weight
-
- optimizer = torch.optim.Adam(grad_vars, betas=(0.9, 0.99))
-
- scaler = torch.cuda.amp.GradScaler(enabled=args.fp16)
-
- # linear in logrithmic space
- N_voxel_list = (torch.round(torch.exp(torch.linspace(np.log(args.N_voxel_init), np.log(args.N_voxel_final), len(upsamp_list) + 1))).long()).tolist()[1:]
-
- torch.cuda.empty_cache()
- PSNRs, PSNRs_test = [], [0]
-
- allrays, allrgbs, allsemfeas = train_dataset.all_rays, train_dataset.all_rgbs, train_dataset.all_semfeas
- allrays = allrays.reshape(-1, allrays.shape[-1])
- allrgbs = allrgbs.reshape(-1, allrgbs.shape[-1])
- allsemfeas = allsemfeas.reshape(-1, allsemfeas.shape[-1])
-
- if not args.ndc_ray:
- allrays, allrgbs, allsemfeas = sem_tensorf.filtering_rays(allrays, allrgbs, allsemfeas, bbox_only=True)
- trainingSampler = SimpleSampler(allrays.shape[0], args.batch_size)
-
- Ortho_reg_weight = args.Ortho_weight
- print("initial Ortho_reg_weight", Ortho_reg_weight)
-
- L1_reg_weight = args.L1_weight_inital
- print("initial L1_reg_weight", L1_reg_weight)
- TV_weight_density, TV_weight_app, TV_weight_semfea = args.TV_weight_density, args.TV_weight_app, args.TV_weight_semfea
- tvreg = TVLoss()
- print(f"initial TV_weight density: {TV_weight_density} appearance: {TV_weight_app} raw_sem_features: {TV_weight_semfea}")
-
- pbar = tqdm(range(start_iter, start_iter + args.warmup_iters), miniters=args.progress_refresh_rate, file=sys.stdout)
-
- if args.fp16:
- torch.set_default_tensor_type('torch.cuda.HalfTensor')
- # use values on each pixel, not each image
- for iteration in pbar:
- ray_idx = trainingSampler.nextids() # Tensor, ()
- rays_train, rgbs_train, semfeas_train = allrays[ray_idx], allrgbs[ray_idx].to(device), allsemfeas[
- ray_idx].to(device)
-
- # rgb_map, alphas_map, sem_map, depth_map, weights, uncertainty
- rgb_map, alphas_map, sem_map, semfea_map, depth_map, weights, uncertainty, loss_3d = sem_renderer(
- rays_train,
- sem_tensorf,
- chunk=args.batch_size,
- N_samples=nSamples,
- white_bg=white_bg,
- ndc_ray=ndc_ray,
- device=device,
- is_train=True,
- # fp16=args.fp16,
- use_rgbs=sem_tensorf.use_rgbs,
- use_raw_semfeas=sem_tensorf.use_raw_semfeas)
-
- loss_3d_total = 0
- if loss_3d != {}:
- for key in loss_3d.keys():
- loss_3d_total += loss_3d[key] * loss_3d_weight
-
- # loss
- if sem_tensorf.use_rgbs:
- rgb_loss = torch.mean((rgb_map - rgbs_train) ** 2)
- if sem_tensorf.use_raw_semfeas:
- distill_loss = torch.mean((semfea_map - semfeas_train) ** 2)
-
- total_loss = loss_3d_total * loss_3d_weight
- if sem_tensorf.use_rgbs:
- total_loss += rgb_loss * rgb_loss_weight
- if sem_tensorf.use_raw_semfeas:
- total_loss += distill_loss * distill_loss_weight
-
- if sem_tensorf.use_rgbs:
- if Ortho_reg_weight > 0:
- loss_reg = sem_tensorf.vector_comp_diffs()
- total_loss += Ortho_reg_weight * loss_reg # ?
- summary_writer.add_scalar('train/reg', loss_reg.detach().item(), global_step=iteration)
- if L1_reg_weight > 0:
- loss_reg_L1 = sem_tensorf.density_L1()
- total_loss += L1_reg_weight * loss_reg_L1
- summary_writer.add_scalar('train/reg_l1', loss_reg_L1.detach().item(), global_step=iteration)
- if TV_weight_density > 0:
- TV_weight_density *= lr_factor
- loss_tv = sem_tensorf.TV_loss_density(tvreg) * TV_weight_density
- total_loss = total_loss + loss_tv
- summary_writer.add_scalar('train/reg_tv_density', loss_tv.detach().item(), global_step=iteration)
- if TV_weight_app > 0:
- TV_weight_app *= lr_factor
- loss_tv = sem_tensorf.TV_loss_app(tvreg) * TV_weight_app
- total_loss = total_loss + loss_tv
- summary_writer.add_scalar('train/reg_tv_app', loss_tv.detach().item(), global_step=iteration)
- if sem_tensorf.use_raw_semfeas:
- if TV_weight_semfea > 0:
- TV_weight_semfea *= lr_factor
- loss_tv = sem_tensorf.TV_loss_sem(tvreg) * TV_weight_semfea
- total_loss = total_loss + loss_tv
- summary_writer.add_scalar('train/reg_tv_semfea', loss_tv.detach().item(), global_step=iteration)
-
- optimizer.zero_grad()
-
- scaler.scale(total_loss).backward()
- scaler.step(optimizer)
- scaler.update()
-
- for param_group in optimizer.param_groups:
- param_group['lr'] = param_group['lr'] * lr_factor
-
- if sem_tensorf.use_rgbs:
- rgb_loss = rgb_loss.detach().item()
-
- PSNRs.append(-10.0 * np.log(rgb_loss) / np.log(10.0))
- summary_writer.add_scalar('train/PSNR', PSNRs[-1], global_step=iteration)
- summary_writer.add_scalar('train/mse', rgb_loss, global_step=iteration)
-
- if sem_tensorf.use_raw_semfeas:
- distill_loss = distill_loss.detach().item()
-
- summary_writer.add_scalar('train/distill_mse', distill_loss, global_step=iteration)
-
- if loss_3d != {}:
- for key in loss_3d.keys():
- summary_writer.add_scalar(f'train/loss_3d_{key}', loss_3d[key], global_step=iteration)
-
- # Print the current values of the losses.
- if (iteration + 1) % args.progress_refresh_rate == 0:
- desc = f'Iteration {(iteration + 1):05d}:'
- if sem_tensorf.use_rgbs:
- desc += (f' train_psnr = {float(np.mean(PSNRs)):.2f}'
- + f' mse = {rgb_loss:.6f}')
- if sem_tensorf.use_raw_semfeas:
- desc += f' distill_mse = {distill_loss:.6f}'
- desc += f' 3d:{loss_3d_total:.2f}'
-
- pbar.set_description(desc)
- pbar.set_postfix({'iter': f'{iteration}'})
- PSNRs = []
-
- if iteration in update_AlphaMask_list:
- if reso_cur[0] * reso_cur[1] * reso_cur[2] < 256 ** 3: # update volume resolution
- reso_mask = reso_cur
- new_aabb = sem_tensorf.updateAlphaMask(tuple(reso_mask))
- if iteration == update_AlphaMask_list[0]:
- sem_tensorf.shrink(new_aabb)
- # tensorVM.alphaMask = None
- L1_reg_weight = args.L1_weight_rest
- print("continuing L1_reg_weight", L1_reg_weight)
-
- if not args.ndc_ray and iteration == update_AlphaMask_list[1]:
- # filter rays outside the bbox
- allrays, allrgbs, allsemfeas = sem_tensorf.filtering_rays(allrays, allrgbs, allsemfeas)
- trainingSampler = SimpleSampler(allrgbs.shape[0], args.batch_size)
-
- # enlarge gridSize
- if iteration in upsamp_list:
- n_voxels = N_voxel_list.pop(0)
- reso_cur = N_to_reso(n_voxels, sem_tensorf.aabb)
- nSamples = min(args.nSamples, cal_n_samples(reso_cur, args.step_ratio))
- sem_tensorf.upsample_volume_grid(reso_cur)
-
- if args.lr_upsample_reset:
- print("reset lr to initial")
- lr_scale = 1 # 0.1 ** (iteration / args.n_iters)
- else:
- lr_scale = args.lr_decay_target_ratio ** (iteration / args.n_iters)
- grad_vars = sem_tensorf.get_optparam_groups(args.lr_init * lr_scale, args.lr_basis * lr_scale)
- optimizer = torch.optim.Adam(grad_vars, betas=(0.9, 0.99))
-
- sem_tensorf.save(f'{logfolder}/{args.expname}_warmup.pth', iteration)
- torch.set_default_tensor_type('torch.FloatTensor')
- sem_tensorf.use_rgbs = False
- sem_tensorf.use_raw_semfeas = False
- start_iter = iteration + 1
- print("\nFinish warming up")
-
- # build optimizer
- grad_vars = sem_tensorf.get_optparam_groups(args.lr_init, args.lr_basis)
- if args.lr_decay_iters > 0:
- lr_factor = args.lr_decay_target_ratio ** (1 / args.lr_decay_iters)
- else:
- args.lr_decay_iters = args.n_iters
- lr_factor = args.lr_decay_target_ratio ** (1 / args.n_iters)
-
- print("lr decay", args.lr_decay_target_ratio, args.lr_decay_iters)
- loss_3d_weight = 1e-4 # 0.1 Manhattan sdf eikonal_weight
- if not args.use_mul_task_loss:
- rgb_loss_weight = 1.
- sem_loss_weight = 4e-2
- else:
- # rgb_loss_weight = torch.nn.Parameter(torch.FloatTensor(1), requires_grad=True).to(device)
- # rgb_loss_weight.data.fill_(1.)
- # grad_vars.append({'params': rgb_loss_weight, 'lr': 0.001})
- #
- # sem_loss_weight = torch.nn.Parameter(torch.FloatTensor(1), requires_grad=True).to(device)
- # sem_loss_weight.data.fill_(4e-2)
- # grad_vars.append({'params': sem_loss_weight, 'lr': 0.001})
- criterion_mul_task = MultiTaskLossWrapper()
- grad_vars.append({'params': criterion_mul_task.log_vars, 'lr': 0.001})
-
- optimizer = torch.optim.Adam(grad_vars, betas=(0.9, 0.99))
-
- scaler = torch.cuda.amp.GradScaler(enabled=args.fp16)
-
- # linear in logrithmic space
- N_voxel_list = (torch.round(torch.exp(torch.linspace(np.log(args.N_voxel_init), np.log(args.N_voxel_final), len(upsamp_list) + 1))).long()).tolist()[1:]
-
- torch.cuda.empty_cache()
- PSNRs, PSNRs_test = [], [0]
-
- selectedrays, selectedrgbs, selectedsems = train_dataset.selected_rays, train_dataset.selected_rgbs, train_dataset.selected_sems
- selectedrays = selectedrays.reshape(-1, selectedrays.shape[-1]) # (num_imgs*h*w, 6)
- selectedrgbs = selectedrgbs.reshape(-1, selectedrgbs.shape[-1]) # (num_imgs*h*w, 3)
- selectedsems = selectedsems.reshape(-1, selectedsems.shape[-1]) # (num_imgs*h*w, 1)
-
- '''freeze & use selected rgbs'''
- if not args.ndc_ray:
- selectedrays, selectedrgbs, selectedsems = sem_tensorf.filtering_rays(selectedrays, all_rgbs=selectedrgbs, all_sems=selectedsems, bbox_only=True)
-
- trainingSampler = SimpleSampler(selectedrays.shape[0], args.batch_size)
- '''all rgb'''
- # allrays, allrgbs = train_dataset.all_rays, train_dataset.all_rgbs
- # allrays = allrays.reshape(-1, allrays.shape[-1]) # (all_imgs*h*w, 6)
- # allrgbs = allrgbs.reshape(-1, allrgbs.shape[-1]) # (all_imgs*h*w, 3)
- #
- # if not args.ndc_ray:
- # selectedrays, selectedrgbs, selectedsems = sem_tensorf.filtering_rays(selectedrays, all_rgbs=selectedrgbs, all_sems=selectedsems, bbox_only=True)
- # allrays, allrgbs, allsems = sem_tensorf.filtering_rays(selectedrays, all_rgbs=selectedrgbs, bbox_only=True)
- #
- # semtrainingSampler = SimpleSampler(selectedrays.shape[0], args.batch_size) # sem sampler
- # rgbtrainSampler = SimpleSampler(allrays.shape[0], args.batch_size)
-
- Ortho_reg_weight = args.Ortho_weight
- print("initial Ortho_reg_weight", Ortho_reg_weight)
-
- # L1_reg_weight = args.L1_weight_inital
- # print("initial L1_reg_weight", L1_reg_weight)
- TV_weight_density, TV_weight_app, TV_weight_sem = args.TV_weight_density, args.TV_weight_app, args.TV_weight_sem
- tvreg = TVLoss()
- # print(f"initial TV_weight density: {TV_weight_density} appearance: {TV_weight_app} semantic: {TV_weight_sem}")
- print(f"initial TV_weight semantic: {TV_weight_sem}")
-
- pbar = tqdm(range(start_iter, start_iter+args.n_iters), miniters=args.progress_refresh_rate, file=sys.stdout)
-
- '''freeze'''
- set_requires_grad(sem_tensorf, keys_excl=['ins_basis_mat'], requires_grad=False)
- '''use selected rgbs'''
- # set_requires_grad(sem_tensorf, keys_incl=['sem_plane', 'sem_line', 'semfea_basis_mat'], requires_grad=False)
-
- if args.fp16:
- torch.set_default_tensor_type('torch.cuda.HalfTensor')
-
- start_t = time.time()
-
- # use values on each pixel, not each image
- for iteration in pbar:
- ray_idx = trainingSampler.nextids() # Tensor, ()
- rays_train, sem_train = selectedrays[ray_idx], selectedsems[ray_idx].to(device)
- if sem_tensorf.use_rgbs:
- rgb_train = selectedrgbs[ray_idx].to(device)
-
- # rgb_map, alphas_map, sem_map, depth_map, weights, uncertainty
- rgb_map, alphas_map, sem_map, semfea_map, depth_map, weights, uncertainty, loss_3d = sem_renderer(rays_train,
- sem_tensorf,
- chunk=args.batch_size,
- N_samples=nSamples,
- white_bg=white_bg,
- ndc_ray=ndc_ray,
- device=device,
- is_train=True,
- fp16=args.fp16,
- use_rgbs=sem_tensorf.use_rgbs,
- use_raw_semfeas=sem_tensorf.use_raw_semfeas)
- loss_3d_total = 0
- if loss_3d != {}:
- for key in loss_3d.keys():
- loss_3d_total += loss_3d[key] * loss_3d_weight
-
- # loss
- sem_crossentropy_loss = train_dataset.get_sem_loss(sem_map, sem_train)
- total_loss = sem_crossentropy_loss * sem_loss_weight + loss_3d_total
- if sem_tensorf.use_rgbs:
- rgb_loss = torch.mean((rgb_map - rgb_train) ** 2)
- total_loss += rgb_loss * rgb_loss_weight
-
- # if Ortho_reg_weight > 0:
- # loss_reg = sem_tensorf.vector_comp_diffs()
- # total_loss += Ortho_reg_weight * loss_reg # ?
- # summary_writer.add_scalar('train/reg', loss_reg.detach().item(), global_step=iteration)
- if Ortho_reg_weight > 0:
- loss_sem_reg = sem_tensorf.vector_sem_comp_diffs()
- total_loss += Ortho_reg_weight * loss_sem_reg
- summary_writer.add_scalar('train/sem_reg', loss_sem_reg.detach().item(), global_step=iteration)
- # if L1_reg_weight > 0:
- # loss_reg_L1 = sem_tensorf.density_L1()
- # total_loss += L1_reg_weight * loss_reg_L1
- # summary_writer.add_scalar('train/reg_l1', loss_reg_L1.detach().item(), global_step=iteration)
-
- # if TV_weight_density > 0:
- # TV_weight_density *= lr_factor
- # loss_tv = sem_tensorf.TV_loss_density(tvreg) * TV_weight_density
- # total_loss = total_loss + loss_tv
- # summary_writer.add_scalar('train/reg_tv_density', loss_tv.detach().item(), global_step=iteration)
- # if TV_weight_app > 0:
- # TV_weight_app *= lr_factor
- # loss_tv = sem_tensorf.TV_loss_app(tvreg) * TV_weight_app
- # total_loss = total_loss + loss_tv
- # summary_writer.add_scalar('train/reg_tv_app', loss_tv.detach().item(), global_step=iteration)
- if TV_weight_sem > 0:
- TV_weight_sem *= lr_factor
- loss_tv = sem_tensorf.TV_loss_sem(tvreg) * TV_weight_sem
- total_loss = total_loss + loss_tv
- summary_writer.add_scalar('train/reg_tv_sem', loss_tv.detach().item(), global_step=iteration)
-
- optimizer.zero_grad()
-
- scaler.scale(total_loss).backward()
- scaler.step(optimizer)
- scaler.update()
-
- sem_loss = sem_crossentropy_loss.detach().item()
- if sem_tensorf.use_rgbs:
- rgb_loss = rgb_loss.detach().item()
-
- PSNRs.append(-10.0 * np.log(rgb_loss) / np.log(10.0))
- summary_writer.add_scalar('train/PSNR', PSNRs[-1], global_step=iteration)
- summary_writer.add_scalar('train/mse', rgb_loss, global_step=iteration)
- PSNRs.append(-10.0 * np.log(rgb_loss) / np.log(10.0))
-
- summary_writer.add_scalar('train/sem_CE', sem_loss, global_step=iteration)
-
- if loss_3d != {}:
- for key in loss_3d.keys():
- summary_writer.add_scalar(f'train/loss_3d_{key}', loss_3d[key], global_step=iteration)
-
- for param_group in optimizer.param_groups:
- param_group['lr'] = param_group['lr'] * lr_factor
-
- # Print the current values of the losses.
- if (iteration + 1) % args.progress_refresh_rate == 0:
- desc = f'Iteration {(iteration + 1):05d}:' \
- + f' 3d:{loss_3d_total:.5f}'\
- + f' CE:{sem_loss:.5f}'
- if sem_tensorf.use_rgbs:
- desc += (f' train_psnr = {float(np.mean(PSNRs)):.2f}'
- + f' mse = {rgb_loss:.6f}')
-
- pbar.set_description(desc)
-
- # try on test datasets during training
- # evaluation_sem(test_dataset, sem_tensorf, args, sem_renderer, f'{logfolder}/imgs_vis/',
- # N_vis=args.N_vis,
- # prtx=f'{iteration:06d}_', N_samples=nSamples, white_bg=white_bg,
- # ndc_ray=ndc_ray,
- # compute_extra_metrics=False, fp16=args.fp16, chunk_size=args.batch_size)
-
- if iteration % args.vis_every == args.vis_every - 1 and args.N_vis != 0:
- evaluation_sem(test_dataset, sem_tensorf, args, sem_renderer, f'{logfolder}/imgs_vis/',
- N_vis=args.N_vis,
- prtx=f'{iteration:06d}_', N_samples=nSamples, white_bg=white_bg,
- ndc_ray=ndc_ray,
- compute_extra_metrics=False, fp16=args.fp16, chunk_size=args.batch_size)
- # evaluation_sem(train_dataset, sem_tensorf, args, sem_renderer, f'{logfolder}/imgs_vis_train/',
- # N_vis=args.N_vis,
- # prtx=f'{iteration:06d}_', N_samples=nSamples, white_bg=white_bg,
- # ndc_ray=ndc_ray,
- # compute_extra_metrics=False, fp16=args.fp16, chunk_size=args.batch_size)
-
- if iteration in update_AlphaMask_list:
- if reso_cur[0] * reso_cur[1] * reso_cur[2] < 256 ** 3: # update volume resolution
- reso_mask = reso_cur
- new_aabb = sem_tensorf.updateAlphaMask(tuple(reso_mask))
- if iteration == update_AlphaMask_list[0]:
- sem_tensorf.shrink(new_aabb)
- # tensorVM.alphaMask = None
- L1_reg_weight = args.L1_weight_rest
- print("continuing L1_reg_weight", L1_reg_weight)
-
- if not args.ndc_ray and iteration == update_AlphaMask_list[1]:
- # filter rays outside the bbox
- selectedrays, allrgbs, selectedsems = sem_tensorf.filtering_rays(selectedrays, all_sems=selectedsems)
- trainingSampler = SimpleSampler(selectedsems.shape[0], args.batch_size)
-
- # enlarge gridSize
- if iteration in upsamp_list:
- n_voxels = N_voxel_list.pop(0)
- reso_cur = N_to_reso(n_voxels, sem_tensorf.aabb)
- nSamples = min(args.nSamples, cal_n_samples(reso_cur, args.step_ratio))
- sem_tensorf.upsample_volume_grid(reso_cur)
-
- if args.lr_upsample_reset:
- print("reset lr to initial")
- lr_scale = 1 # 0.1 ** (iteration / args.n_iters)
- else:
- lr_scale = args.lr_decay_target_ratio ** (iteration / args.n_iters)
- grad_vars = sem_tensorf.get_optparam_groups(args.lr_init * lr_scale, args.lr_basis * lr_scale)
- optimizer = torch.optim.Adam(grad_vars, betas=(0.9, 0.99))
-
- '''allrgb'''
- # for iteration in pbar:
- # sem_ray_idx = semtrainingSampler.nextids() # Tensor, ()
- # rgb_ray_idx = rgbtrainSampler.nextids()
- #
- # rays_sem_train, sem_train = selectedrays[sem_ray_idx], selectedsems[sem_ray_idx].to(device)
- # rays_rgb_train, rgb_train = allrays, allrgbs
- #
- # # rgb_map, alphas_map, sem_map, depth_map, weights, uncertainty
- # rgb_map, alphas_map, sem_map, semfea_map, depth_map, weights, uncertainty, loss_3d = sem_renderer(rays_train,
- # sem_tensorf,
- # chunk=args.batch_size,
- # N_samples=nSamples,
- # white_bg=white_bg,
- # ndc_ray=ndc_ray,
- # device=device,
- # is_train=True,
- # fp16=args.fp16,
- # use_rgbs=sem_tensorf.use_rgbs,
- # use_raw_semfeas=sem_tensorf.use_raw_semfeas)
- # loss_3d_total = 0
- # if loss_3d != {}:
- # for key in loss_3d.keys():
- # loss_3d_total += loss_3d[key] * loss_3d_weight
- #
- # # loss
- # sem_crossentropy_loss = train_dataset.get_sem_loss(sem_map, sem_train)
- # total_loss = sem_crossentropy_loss * sem_loss_weight + loss_3d_total
- # if sem_tensorf.use_rgbs:
- # rgb_loss = torch.mean((rgb_map - rgb_train) ** 2)
- # total_loss += rgb_loss * rgb_loss_weight
- #
- # # if Ortho_reg_weight > 0:
- # # loss_reg = sem_tensorf.vector_comp_diffs()
- # # total_loss += Ortho_reg_weight * loss_reg # ?
- # # summary_writer.add_scalar('train/reg', loss_reg.detach().item(), global_step=iteration)
- # if Ortho_reg_weight > 0:
- # loss_sem_reg = sem_tensorf.vector_sem_comp_diffs()
- # total_loss += Ortho_reg_weight * loss_sem_reg
- # summary_writer.add_scalar('train/sem_reg', loss_sem_reg.detach().item(), global_step=iteration)
- # # if L1_reg_weight > 0:
- # # loss_reg_L1 = sem_tensorf.density_L1()
- # # total_loss += L1_reg_weight * loss_reg_L1
- # # summary_writer.add_scalar('train/reg_l1', loss_reg_L1.detach().item(), global_step=iteration)
- #
- # # if TV_weight_density > 0:
- # # TV_weight_density *= lr_factor
- # # loss_tv = sem_tensorf.TV_loss_density(tvreg) * TV_weight_density
- # # total_loss = total_loss + loss_tv
- # # summary_writer.add_scalar('train/reg_tv_density', loss_tv.detach().item(), global_step=iteration)
- # # if TV_weight_app > 0:
- # # TV_weight_app *= lr_factor
- # # loss_tv = sem_tensorf.TV_loss_app(tvreg) * TV_weight_app
- # # total_loss = total_loss + loss_tv
- # # summary_writer.add_scalar('train/reg_tv_app', loss_tv.detach().item(), global_step=iteration)
- # if TV_weight_sem > 0:
- # TV_weight_sem *= lr_factor
- # loss_tv = sem_tensorf.TV_loss_sem(tvreg) * TV_weight_sem
- # total_loss = total_loss + loss_tv
- # summary_writer.add_scalar('train/reg_tv_sem', loss_tv.detach().item(), global_step=iteration)
- #
- # optimizer.zero_grad()
- #
- # scaler.scale(total_loss).backward()
- # scaler.step(optimizer)
- # scaler.update()
- #
- # sem_loss = sem_crossentropy_loss.detach().item()
- # if sem_tensorf.use_rgbs:
- # rgb_loss = rgb_loss.detach().item()
- #
- # PSNRs.append(-10.0 * np.log(rgb_loss) / np.log(10.0))
- # summary_writer.add_scalar('train/PSNR', PSNRs[-1], global_step=iteration)
- # summary_writer.add_scalar('train/mse', rgb_loss, global_step=iteration)
- # PSNRs.append(-10.0 * np.log(rgb_loss) / np.log(10.0))
- #
- # summary_writer.add_scalar('train/sem_CE', sem_loss, global_step=iteration)
- #
- # if loss_3d != {}:
- # for key in loss_3d.keys():
- # summary_writer.add_scalar(f'train/loss_3d_{key}', loss_3d[key], global_step=iteration)
- #
- # for param_group in optimizer.param_groups:
- # param_group['lr'] = param_group['lr'] * lr_factor
- #
- # # Print the current values of the losses.
- # if (iteration + 1) % args.progress_refresh_rate == 0:
- # desc = f'Iteration {(iteration + 1):05d}:' \
- # + f' CE:{sem_loss:.5f}'\
- # + f' 3d:{loss_3d_total:.5f}'
- # if sem_tensorf.use_rgbs:
- # desc += (f' train_psnr = {float(np.mean(PSNRs)):.2f}'
- # + f' mse = {rgb_loss:.6f}')
- #
- # pbar.set_description(desc)
- #
- # # try on test datasets during training
- # if iteration % args.vis_every == args.vis_every - 1 and args.N_vis != 0:
- # evaluation_sem(test_dataset, sem_tensorf, args, sem_renderer, f'{logfolder}/imgs_vis/',
- # N_vis=args.N_vis,
- # prtx=f'{iteration:06d}_', N_samples=nSamples, white_bg=white_bg,
- # ndc_ray=ndc_ray,
- # compute_extra_metrics=False, fp16=args.fp16, chunk_size=args.batch_size)
- #
- # # PSNRs_test = evaluation_with_sem(test_dataset, sem_tensorf, args, sem_renderer, f'{logfolder}/imgs_vis/',
- # # N_vis=args.N_vis,
- # # prtx=f'{iteration:06d}_', N_samples=nSamples, white_bg=white_bg,
- # # ndc_ray=ndc_ray,
- # # compute_extra_metrics=False, fp16=args.fp16, chunk_size=args.batch_size)
- # # summary_writer.add_scalar('test/psnr', np.mean(PSNRs_test), global_step=iteration)
- #
- # if iteration in update_AlphaMask_list:
- # if reso_cur[0] * reso_cur[1] * reso_cur[2] < 256 ** 3: # update volume resolution
- # reso_mask = reso_cur
- # new_aabb = sem_tensorf.updateAlphaMask(tuple(reso_mask))
- # if iteration == update_AlphaMask_list[0]:
- # sem_tensorf.shrink(new_aabb)
- # # tensorVM.alphaMask = None
- # L1_reg_weight = args.L1_weight_rest
- # print("continuing L1_reg_weight", L1_reg_weight)
- #
- # if not args.ndc_ray and iteration == update_AlphaMask_list[1]:
- # # filter rays outside the bbox
- # selectedrays, allrgbs, selectedsems = sem_tensorf.filtering_rays(selectedrays, all_sems=selectedsems)
- # trainingSampler = SimpleSampler(selectedsems.shape[0], args.batch_size)
- #
- # # enlarge gridSize
- # if iteration in upsamp_list:
- # n_voxels = N_voxel_list.pop(0)
- # reso_cur = N_to_reso(n_voxels, sem_tensorf.aabb)
- # nSamples = min(args.nSamples, cal_n_samples(reso_cur, args.step_ratio))
- # sem_tensorf.upsample_volume_grid(reso_cur)
- #
- # if args.lr_upsample_reset:
- # print("reset lr to initial")
- # lr_scale = 1 # 0.1 ** (iteration / args.n_iters)
- # else:
- # lr_scale = args.lr_decay_target_ratio ** (iteration / args.n_iters)
- # grad_vars = sem_tensorf.get_optparam_groups(args.lr_init * lr_scale, args.lr_basis * lr_scale)
- # optimizer = torch.optim.Adam(grad_vars, betas=(0.9, 0.99))
-
- end_t = time.time()
-
- print(f"Training {args.expname} takes {(end_t - start_t)/ 60:.4f} minutes.")
-
- sem_tensorf.save(f'{logfolder}/{args.expname}.pth', iteration)
-
- if args.render_train:
- os.makedirs(f'{logfolder}/imgs_train_all', exist_ok=True)
- train_dataset = dataset(args.datadir, split='train', downsample=args.downsample_train, is_stack=True)
- PSNRs_test = evaluation_with_sem(train_dataset, sem_tensorf, args, sem_renderer, f'{logfolder}/imgs_train_all/',
- N_vis=-1, N_samples=-1, white_bg=white_bg, ndc_ray=ndc_ray, device=device,
- fp16=args.fp16, chunk_size=args.batch_size)
- print(f'======> {args.expname} test all psnr: {np.mean(PSNRs_test)} <========================')
-
- if args.render_test:
- os.makedirs(f'{logfolder}/imgs_test_all', exist_ok=True)
- PSNRs_test = evaluation_with_sem(test_dataset, sem_tensorf, args, sem_renderer, f'{logfolder}/imgs_test_all/',
- N_vis=-1, N_samples=-1, white_bg=white_bg, ndc_ray=ndc_ray, device=device,
- fp16=args.fp16, chunk_size=args.batch_size)
- summary_writer.add_scalar('test/psnr_all', np.mean(PSNRs_test), global_step=iteration)
- print(f'======> {args.expname} test all psnr: {np.mean(PSNRs_test)} <========================')
-
- # images from new view-points
- if args.render_path:
- c2ws = test_dataset.render_path
- # c2ws = test_dataset.poses
- print('========>', c2ws.shape)
- os.makedirs(f'{logfolder}/imgs_path_all', exist_ok=True)
- evaluation_path_with_sem(test_dataset, sem_tensorf, c2ws, sem_renderer, f'{logfolder}/imgs_path_all/',
- N_vis=-1, N_samples=-1, white_bg=white_bg, ndc_ray=ndc_ray, device=device,
- fp16=args.fp16, chunk_size=args.batch_size)
-
-
if __name__ == '__main__':
    # Script entry point: configure global torch defaults, parse the CLI
    # config, then dispatch to exactly one of the run modes below.

    # Reset the default tensor type to CPU float32; the training functions
    # may switch it to 'torch.cuda.HalfTensor' when args.fp16 is set.
    torch.set_default_tensor_type('torch.FloatTensor')
    # Fixed seeds so ray sampling (SimpleSampler uses np.random) and torch
    # initialisation are reproducible across runs.
    torch.manual_seed(20230404)
    np.random.seed(20230404)

    args = config_parser()

    # Mesh export is an independent action and may be combined with the
    # modes below.
    if args.export_mesh:
        export_mesh(args)

    # Render-only, feature-distillation reconstruction, and semantic
    # reconstruction are mutually exclusive. Previously a render-only request
    # fell through into a full training run, which re-trained the model and
    # overwrote the saved checkpoint.
    if args.render_only and (args.render_test or args.render_path):
        render_test_with_sem(args)
    elif args.distill_active:
        reconstruction_with_rawfeas(args)
    else:
        reconstruction_sem(args)
|