|
- import os
- import torch
- import numpy as np
- from renderer import *
- from datasets import dataset_dict
- from opt_sem_replica import config_parser
- from utils.tps.defomer_utils import TPS_Deformer
-
- from models.tensorRFSemVMSplit import TensorSemVMSplit
-
# Prefer GPU when available; the checkpoint and all model tensors are mapped to this device.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Renderer callable used for every TPS render below; provided by the `renderer`
# star-import — presumably renders RGB plus semantic logits (verify in renderer.py).
sem_renderer = OctreeRender_trilinear_fast_with_sem
-
-
@torch.no_grad()
def tps_warp(args):
    """Render TPS-warped (thin-plate-spline deformed) views of a trained semantic TensoRF.

    For each warp case index in ``[start_warp_case, end_warp_case]`` this builds a
    ``TPS_Deformer`` over the model's grid, persists its parameters once per case
    (so train/test renders and any re-run use the identical deformation), and
    renders the requested dataset splits through the deformed field into
    ``{logfolder}/case_{i}/{train,test}_dataset``.

    Args:
        args: namespace from ``config_parser``; must provide dataset options
              (``dataset_name``, ``datadir``, ``near``/``far``, ...), the model
              checkpoint (``ckpt``, ``model_name``), render flags
              (``render_train``/``render_test``), and ``tps_scale``/``batch_size``.

    Raises:
        FileNotFoundError: if ``args.ckpt`` does not exist.
    """
    # Dataset class looked up by name (e.g. "replica", "kitti360").
    dataset = dataset_dict[args.dataset_name]

    # Output root for the warped renderings.
    # NOTE(review): hard-coded machine-specific path ("vclab" office_0 setup);
    # consider deriving it from the checkpoint instead, e.g. os.path.dirname(args.ckpt).
    logfolder = "/media/alpha4TB/exchange/nerf_db/datasets/nerf_replica/replica_ins/office_0/tps_renderings/"

    if args.dataset_name == "kitti360":
        train_dataset = dataset(args.datadir, split='train', start=args.start, end=args.end,
                                near=args.near, far=args.far, downsample=args.downsample_train, is_stack=True)
        test_dataset = dataset(args.datadir, split='test', start=args.start, end=args.end,
                               near=args.near, far=args.far, test_ids=args.test_ids,
                               downsample=args.downsample_train, is_stack=True)
        # Remap raw semantic ids into one contiguous label space shared by both splits.
        train_dataset.remap_sem_gt_label(train_dataset.sem_samples["sem_img"],
                                         test_dataset.sem_samples["sem_img"])
        test_dataset.remap_sem_gt_label(train_dataset.sem_samples["sem_img"],
                                        test_dataset.sem_samples["sem_img"])
        train_dataset.set_label_colour_map()
        test_dataset.set_label_colour_map()
    else:
        train_dataset = dataset(args.datadir, near=args.near, far=args.far, scene_bbox_stretch=args.scene_bbox_stretch,
                                split='train', downsample=args.downsample_train, is_stack=True)
        test_dataset = dataset(args.datadir, near=args.near, far=args.far, scene_bbox_stretch=args.scene_bbox_stretch,
                               split='test', downsample=args.downsample_train, is_stack=True)
        # BUG FIX: original condition was
        #   `args.dataset_name == "replica" or "replica_dmnerf"`
        # which is always truthy ("replica_dmnerf" is a non-empty string), so the
        # remap ran for every non-kitti360 dataset. Test membership instead.
        if args.dataset_name in ("replica", "replica_dmnerf"):
            train_dataset.remap_sem_gt_label(train_dataset.sem_samples["sem_img"],
                                             test_dataset.sem_samples["sem_img"],
                                             args.sem_info_path)
            test_dataset.remap_sem_gt_label(train_dataset.sem_samples["sem_img"],
                                            test_dataset.sem_samples["sem_img"],
                                            args.sem_info_path)
            train_dataset.set_label_colour_map(sem_info_path=args.sem_info_path)
            test_dataset.set_label_colour_map(sem_info_path=args.sem_info_path)

    # Load the trained model checkpoint.
    # Raise instead of assert: asserts are stripped under `python -O`.
    if not os.path.exists(args.ckpt):
        raise FileNotFoundError(f'the ckpt path does not exists!!: {args.ckpt}')

    ckpt = torch.load(args.ckpt, map_location=device)
    kwargs = ckpt['kwargs']
    kwargs.update({'device': device})

    # SECURITY NOTE(review): eval() on a CLI-supplied name — acceptable for a
    # research script, but unsafe on untrusted input. args.model_name must be a
    # class imported in this module (e.g. TensorSemVMSplit).
    sem_tensorf = eval(args.model_name)(**kwargs)
    sem_tensorf.load(ckpt)

    white_bg = train_dataset.white_bg
    ndc_ray = args.ndc_ray

    # Range of warp cases to render ("vclab" office_0 setting; inclusive bounds).
    start_warp_case = 855
    end_warp_case = 860

    for case in range(start_warp_case, end_warp_case + 1):
        deformer = TPS_Deformer(sem_tensorf.gridSize, device, args.tps_scale)
        exp_folder = f'{logfolder}/case_{case}'
        os.makedirs(exp_folder, exist_ok=True)

        # Persist the TPS parameters once per case so that train and test renders
        # (and any re-run of this script) apply exactly the same deformation.
        tps_params_path = os.path.join(exp_folder, "tps_params.pkl")
        if not os.path.exists(tps_params_path):
            deformer.save_tps_params(tps_save_path=tps_params_path)

        if args.render_train:
            print("========> tps on train_dataset")
            deformer.load_tps_params(tps_file_path=tps_params_path)  # guarantee saved == loaded params
            tps_rendering(train_dataset, sem_tensorf, args, sem_renderer, f'{exp_folder}/train_dataset',
                          N_vis=-1, N_samples=-1, white_bg=white_bg, ndc_ray=ndc_ray, device=device,
                          tps_function=deformer.get_deform, chunk_size=args.batch_size)
            train_dataset.save_c2ws(c2w_save_path=f"{exp_folder}/train_dataset")

            print(f"========>tps warp case {case} train_dataset done")

        if args.render_test:
            print("========> tps on test_dataset")
            deformer.load_tps_params(tps_file_path=tps_params_path)
            tps_rendering(test_dataset, sem_tensorf, args, sem_renderer, f'{exp_folder}/test_dataset',
                          N_vis=-1, N_samples=-1, white_bg=white_bg, ndc_ray=ndc_ray, device=device,
                          tps_function=deformer.get_deform, chunk_size=args.batch_size)
            test_dataset.save_c2ws(c2w_save_path=f'{exp_folder}/test_dataset')

            print(f"========>tps warp case {case} test_dataset done")
-
-
if __name__ == '__main__':
    # Keep all model/render math in float32 by default.
    torch.set_default_dtype(torch.float32)
    tps_warp(config_parser())
|