|
- import os
- import numpy as np
- import torch.nn
- from tqdm.auto import tqdm
- from opt_sem_replica import config_parser
- from opt_sem_kitti_360 import kitti_config_parser
-
- import json, random
- from renderer import *
- from utils.utils import *
- from torch.utils.tensorboard import SummaryWriter
- import datetime
-
- from datasets import dataset_dict
- import sys
-
- from models.tensorRFSemVMSplit import TensorSemVMSplit
-
# Torch device shared by checkpoint loading, model construction and evaluation:
# CUDA when torch reports it as available, otherwise the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Renderer callable handed to ``evaluation_with_sem`` for the test-set pass.
sem_renderer = OctreeRender_trilinear_fast_with_sem
-
-
@torch.no_grad()
def tps_render_testset(args):
    """Evaluate the test split of every ``case_<i>`` with its trained checkpoint.

    For each case in the hard-coded range ``start_case..end_case`` (inclusive)
    this builds the test dataset, restores the model from
    ``<args.ckpt>/case_<i>/train_tps_tensorf/train_tps_tensorf.pth``, runs
    ``evaluation_with_sem`` on it, and writes the outputs to an
    ``imgs_test_all`` folder next to the checkpoint.

    Args:
        args: Parsed configuration namespace. Attributes read here:
            ``dataset_name``, ``datadir``, ``ckpt``, ``model_name``, ``near``,
            ``far``, ``downsample_train``, ``batch_size``, ``ndc_ray``,
            ``expname``; additionally ``start``, ``end``, ``test_ids`` for
            ``kitti360``, and ``scene_bbox_stretch``, ``ins2label_path``,
            ``sem_info_path``, ``label2color_path`` for the other datasets.

    Raises:
        AssertionError: If the checkpoint file of a case does not exist.
    """
    # Dataset class registered under the configured dataset name.
    dataset = dataset_dict[args.dataset_name]

    # vclab / office_0
    # todo: already: train_set 0, 11; test_set 0(vclab)
    # start_case = 1
    # end_case = 9

    # shb / room_0
    start_case = 1
    end_case = 10

    for case in range(start_case, end_case + 1):
        datadir = os.path.join(args.datadir, f"case_{case}")

        if args.dataset_name == "kitti360":
            # NOTE(review): this branch reads args.datadir rather than the
            # per-case ``datadir`` -- confirm that is intended.
            train_dataset = dataset(args.datadir, split='train', start=args.start, end=args.end,
                                    near=args.near, far=args.far,
                                    downsample=args.downsample_train, is_stack=True)
            test_dataset = dataset(args.datadir, split='test', start=args.start, end=args.end,
                                   near=args.near, far=args.far, test_ids=args.test_ids,
                                   downsample=args.downsample_train, is_stack=True)
            # Both splits are remapped from the same pair of semantic images.
            train_dataset.remap_sem_gt_label(train_dataset.sem_samples["sem_img"],
                                             test_dataset.sem_samples["sem_img"])
            test_dataset.remap_sem_gt_label(train_dataset.sem_samples["sem_img"],
                                            test_dataset.sem_samples["sem_img"])
            train_dataset.set_label_colour_map()
            test_dataset.set_label_colour_map()
        else:
            test_dataset = dataset(datadir, near=args.near, far=args.far,
                                   scene_bbox_stretch=args.scene_bbox_stretch,
                                   split='test', downsample=args.downsample_train, is_stack=True,
                                   use_tps_dataset=True)
            # Membership test: the previous ``== "replica" or "replica_dmnerf"``
            # was always truthy because a non-empty string is the second operand.
            if args.dataset_name in ("replica", "replica_dmnerf"):
                test_dataset.remap_sem_gt_label(load_map=True, ins2label_path=args.ins2label_path)
                # NOTE(review): kept inside the replica branch since it consumes
                # the replica label files -- confirm against the intended nesting.
                test_dataset.set_label_colour_map(sem_info_path=args.sem_info_path,
                                                  label2color_path=args.label2color_path)

        white_bg = test_dataset.white_bg
        ndc_ray = args.ndc_ray

        # load ckpt
        # todo: change if basedir changed
        ckpt_path = os.path.join(args.ckpt, f"case_{case}", "train_tps_tensorf", "train_tps_tensorf.pth")

        assert os.path.exists(ckpt_path), 'the ckpt path does not exists!!'

        ckpt = torch.load(ckpt_path, map_location=device)
        kwargs = ckpt['kwargs']
        kwargs.update({'device': device})

        # SECURITY NOTE: ``eval`` resolves the model class from a config string;
        # only run this with trusted configuration values.
        sem_tensorf = eval(args.model_name)(**kwargs)
        sem_tensorf.load(ckpt)

        # Results are written alongside the checkpoint of the current case.
        logfolder = os.path.dirname(ckpt_path)

        os.makedirs(f'{logfolder}/imgs_test_all', exist_ok=True)
        PSNRs_test = evaluation_with_sem(test_dataset, sem_tensorf, args, sem_renderer,
                                         f'{logfolder}/imgs_test_all/',
                                         N_vis=-1, N_samples=-1, white_bg=white_bg,
                                         ndc_ray=ndc_ray, device=device,
                                         chunk_size=args.batch_size)
        print(f'======> {args.expname} test all psnr: {np.mean(PSNRs_test)} <========================')

        print(f'======> render case_{case} done<========================')
-
-
if __name__ == '__main__':
    # Seed NumPy and PyTorch with the same fixed value and use float32 as the
    # default tensor dtype before anything is constructed.
    np.random.seed(20211202)
    torch.manual_seed(20211202)
    torch.set_default_dtype(torch.float32)

    # Parse the configuration and run the test-set pass over all cases.
    tps_render_testset(config_parser())
|