|
- import os
- import numpy as np
- import torch.nn
- from tqdm.auto import tqdm
- from opt_sem_replica import config_parser
- from opt_sem_kitti_360 import kitti_config_parser
-
- import json, random
- from renderer import *
- from utils.utils import *
- from torch.utils.tensorboard import SummaryWriter
- import datetime
-
- from datasets import dataset_dict
- import sys
-
- from models.tensorRFSemVMSplit import TensorSemVMSplit
-
# Compute device shared by model construction and rendering below:
# CUDA when torch reports it available, otherwise the CPU.
# NOTE(review): original comment asked whether this should come from TensorBase instead — TODO confirm.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# Renderer callable handed to the evaluation helpers; the name presumably comes
# from one of the star imports (`renderer` / `utils.utils`) — verify.
sem_renderer = OctreeRender_trilinear_fast_with_sem
-
-
def _build_sem_datasets(args):
    """Build stacked train/test datasets for ``args.dataset_name`` and align their semantic labels.

    For ``kitti360`` the frame range (``start``/``end``), ``near``/``far`` and
    ``test_ids`` are forwarded to the dataset constructor. For ``replica`` and
    ``replica_dmnerf`` the label remapping and colour map additionally use
    ``args.sem_info_path``. Any other dataset is constructed without
    label remapping.

    Both splits' remapping calls receive the semantic images of *both* splits,
    presumably so that label ids agree between train and test (TODO confirm
    against ``remap_sem_gt_label``).

    Returns:
        (train_dataset, test_dataset)
    """
    dataset = dataset_dict[args.dataset_name]

    if args.dataset_name == "kitti360":
        train_dataset = dataset(args.datadir, split='train', start=args.start, end=args.end,
                                near=args.near, far=args.far,
                                downsample=args.downsample_train, is_stack=True)
        test_dataset = dataset(args.datadir, split='test', start=args.start, end=args.end,
                               near=args.near, far=args.far, test_ids=args.test_ids,
                               downsample=args.downsample_train, is_stack=True)
        # sem_samples["sem_img"] is deliberately re-read for the second call:
        # the first remap may rebind the train entry.
        train_dataset.remap_sem_gt_label(train_dataset.sem_samples["sem_img"],
                                         test_dataset.sem_samples["sem_img"])
        test_dataset.remap_sem_gt_label(train_dataset.sem_samples["sem_img"],
                                        test_dataset.sem_samples["sem_img"])
        train_dataset.set_label_colour_map()
        test_dataset.set_label_colour_map()
        return train_dataset, test_dataset

    train_dataset = dataset(args.datadir, split='train', downsample=args.downsample_train, is_stack=True)
    test_dataset = dataset(args.datadir, split='test', downsample=args.downsample_train, is_stack=True)
    # Membership test: `x == "replica" or "replica_dmnerf"` would always be truthy.
    if args.dataset_name in ("replica", "replica_dmnerf"):
        train_dataset.remap_sem_gt_label(train_dataset.sem_samples["sem_img"],
                                         test_dataset.sem_samples["sem_img"],
                                         args.sem_info_path)
        test_dataset.remap_sem_gt_label(train_dataset.sem_samples["sem_img"],
                                        test_dataset.sem_samples["sem_img"],
                                        args.sem_info_path)
        train_dataset.set_label_colour_map(args.sem_info_path)
        test_dataset.set_label_colour_map(args.sem_info_path)
    return train_dataset, test_dataset


@torch.no_grad()
def render_test_with_sem(args):
    """Restore a semantic model from ``args.ckpt`` and render it on the requested splits.

    The model class is looked up from ``args.model_name``. It is built from
    ``ckpt['kwargs']`` on the module-level ``device``, and ``ckpt`` is then
    loaded into it. Output goes under the checkpoint's directory, and each
    enabled flag renders one set of views:

    * ``args.render_train``: every train view into ``<ckpt dir>/imgs_train_all/``.
      The mean of the PSNRs returned by ``evaluation_with_sem`` is printed.
    * ``args.render_test``: every test view into ``<ckpt dir>/<expname>/imgs_test_all/``.
    * ``args.render_path``: the poses in ``test_dataset.render_path`` into
      ``<ckpt dir>/<expname>/imgs_path_all/``.

    Args:
        args: parsed config namespace.

    Raises:
        FileNotFoundError: if ``args.ckpt`` does not exist. The check runs
            before any dataset is loaded.
    """
    # Explicit check instead of `assert`, which is stripped under `python -O`.
    if not os.path.exists(args.ckpt):
        raise FileNotFoundError(f'the ckpt path does not exists!! ({args.ckpt})')

    train_dataset, test_dataset = _build_sem_datasets(args)

    # NOTE(review): torch.load unpickles the checkpoint and eval() executes
    # args.model_name as Python. Both are only safe on trusted inputs.
    ckpt = torch.load(args.ckpt, map_location=device)
    kwargs = ckpt['kwargs']
    kwargs.update({'device': device})  # force the runtime device over any stored value
    sem_tensorf = eval(args.model_name)(**kwargs)
    sem_tensorf.load(ckpt)

    logfolder = os.path.dirname(args.ckpt)
    # Shared settings for every rendering pass. -1 is forwarded as-is for
    # N_vis / N_samples; presumably it means "all views" / "model default" — verify.
    render_kwargs = dict(N_vis=-1, N_samples=-1, white_bg=test_dataset.white_bg,
                         ndc_ray=args.ndc_ray, device=device)

    if args.render_train:
        train_dir = f'{logfolder}/imgs_train_all/'
        os.makedirs(train_dir, exist_ok=True)
        PSNRs_train = evaluation_with_sem(train_dataset, sem_tensorf, args, sem_renderer,
                                          train_dir, **render_kwargs)
        print(f'======> {args.expname} train all psnr: {np.mean(PSNRs_train)} <========================')

    if args.render_test:
        # Create exactly the directory passed as the save path.
        test_dir = f'{logfolder}/{args.expname}/imgs_test_all/'
        os.makedirs(test_dir, exist_ok=True)
        evaluation_with_sem(test_dataset, sem_tensorf, args, sem_renderer,
                            test_dir, **render_kwargs)

    if args.render_path:
        c2ws = test_dataset.render_path
        path_dir = f'{logfolder}/{args.expname}/imgs_path_all/'
        os.makedirs(path_dir, exist_ok=True)
        evaluation_path_with_sem(test_dataset, sem_tensorf, c2ws, sem_renderer,
                                 path_dir, **render_kwargs)
-
-
if __name__ == '__main__':
    # Fixed seed value for the torch and NumPy random number generators.
    SEED = 20211202
    torch.set_default_dtype(torch.float32)
    torch.manual_seed(SEED)
    np.random.seed(SEED)

    # NOTE(review): only `config_parser` is used, although `kitti_config_parser`
    # is also imported. The kitti360 path reads args.start/end/near/far/test_ids,
    # so confirm that config_parser defines them.
    render_test_with_sem(config_parser())
|