|
- import os
- import datetime
- import torch
- import numpy as np
- from renderer import *
- from utils.utils import *
- from datasets import replica_dmderf
- from opt_sem_replica import config_parser
-
- from utils.mani import manipulator, pose_generator
-
- from models.tensorRFSemVMSplit import TensorSemVMSplit
-
# Torch device used for loading the checkpoint and building the model:
# CUDA when it is available, otherwise the CPU.
# NOTE(review): the original author asked whether this should instead come
# from TensorBase — confirm.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# Renderer handed to manipulator.manipulator_eval; the name is brought in by
# one of the star imports (presumably `from renderer import *`).
sem_renderer = OctreeRender_trilinear_fast_mani
-
-
def _load_test_dataset(args):
    """Build the ``ReplicaDatasetDMNeRF`` test split used for manipulation.

    The split is constructed with ``is_stack=True``, ``gen_rays=False`` and
    ``use_sem=True``. Its ground-truth labels are then remapped using
    ``args.ins2label_path``, its label colour map is set from
    ``args.sem_info_path``, and ``plot_label_colormap()`` is called on it.
    """
    dataset_cls = replica_dmderf.ReplicaDatasetDMNeRF
    test_dataset = dataset_cls(
        args.datadir,
        near=args.near,
        far=args.far,
        scene_bbox_stretch=args.scene_bbox_stretch,
        split='test',
        # NOTE(review): the *train* downsample factor is reused for the test
        # split, exactly as in the original script — confirm this is intended.
        downsample=args.downsample_train,
        is_stack=True,
        gen_rays=False,
        use_sem=True,
    )
    test_dataset.remap_sem_gt_label(load_map=True, ins2label_path=args.ins2label_path)
    test_dataset.set_label_colour_map(sem_info_path=args.sem_info_path)
    test_dataset.plot_label_colormap()
    return test_dataset


def _load_model(args):
    """Instantiate ``args.model_name`` from the checkpoint at ``args.ckpt``.

    Returns:
        tuple: ``(model, fp16)`` — the model after ``load(ckpt)`` and
        ``eval()``, and the ``fp16`` value stored in the checkpoint's
        constructor kwargs.
    """
    ckpt = torch.load(args.ckpt, map_location=device)
    kwargs = ckpt['kwargs']
    # 'N_iter' is removed so it is not forwarded to the model constructor.
    kwargs.pop('N_iter')
    fp16 = kwargs['fp16']
    # Overrides applied on top of the constructor kwargs saved in the ckpt.
    kwargs.update({
        'device': device,
        'n_dino_fea': 64,
        'use_rgbs': True,
        'use_raw_semfeas': False,
        'use_pe_objdecoder': True,
    })
    # NOTE(review): eval() on a config-supplied string executes arbitrary
    # code; only run this with trusted configs (a name lookup in a fixed
    # registry such as {'TensorSemVMSplit': TensorSemVMSplit} would be safer).
    model = eval(args.model_name)(**kwargs)
    model.load(ckpt)
    model.eval()
    return model, fp16


def _make_logfolder(args):
    """Create (if needed) and return ``<basedir>/<expname>[-YYYYmmdd-HHMMSS]``.

    The timestamp suffix (local time) is appended only when
    ``args.add_timestamp`` is truthy.
    """
    logfolder = f'{args.basedir}/{args.expname}'
    if args.add_timestamp:
        logfolder += datetime.datetime.now().strftime("-%Y%m%d-%H%M%S")
    os.makedirs(logfolder, exist_ok=True)
    return logfolder


@torch.no_grad()
def mani(args):
    """Run manipulation evaluation for a trained semantic model.

    Steps: load the test split, load the model from ``args.ckpt``, create the
    log folder, generate manipulation poses around the centre returned by
    ``pose_generator.get_scene_center`` (written to and re-read from
    ``<logfolder>/transformation_matrix.json``), then call
    ``manipulator.manipulator_eval`` with ``save_dir`` set to the log folder.

    Args:
        args: options from ``config_parser()``. Fields read here: ``ckpt``,
            ``model_name``, ``datadir``, ``near``, ``far``,
            ``scene_bbox_stretch``, ``downsample_train``, ``ins2label_path``,
            ``sem_info_path``, ``basedir``, ``expname``, ``add_timestamp``,
            ``mani_mode``, ``use_inpaint``. ``args`` is also forwarded whole to
            ``pose_generator`` and ``manipulator``.

    Raises:
        FileNotFoundError: if ``args.ckpt`` does not exist. This is an explicit
            check rather than an ``assert``, so it survives ``python -O``, and
            it runs before the comparatively slow dataset load.
    """
    if not os.path.exists(args.ckpt):
        raise FileNotFoundError(f'the ckpt path does not exist: {args.ckpt}')

    test_dataset = _load_test_dataset(args)
    white_bg = test_dataset.white_bg
    gt_rgbs = test_dataset.all_rgbs
    # NOTE(review): int8 only holds -128..127; confirm every remapped label
    # (including any ignore value) fits, otherwise the cast wraps around.
    gt_ins_labels = torch.tensor(test_dataset.sem_samples['sem_remap'], dtype=torch.int8)

    sem_tensorf, fp16 = _load_model(args)
    logfolder = _make_logfolder(args)

    print('Manipulating', args.mani_mode, '......')
    # NOTE: this sequence of operations may be redesigned (original author's note).
    # Presumably the scene_bbox centre; it is computed from the test poses.
    mani_center = pose_generator.get_scene_center(test_dataset.poses)
    mani_pose_save_path = os.path.join(logfolder, 'transformation_matrix.json')
    pose_generator.generate_poses_eval(args, mani_center=mani_center, save_path=mani_pose_save_path)
    trans_dicts = pose_generator.load_mani_poses(args, load_path=mani_pose_save_path)

    manipulator.manipulator_eval(
        model=sem_tensorf,
        renderer=sem_renderer,
        ori_poses=test_dataset.poses,
        directions=test_dataset.directions,
        img_wh=test_dataset.img_wh,
        trans_dicts=trans_dicts,
        save_dir=logfolder,
        label_color_map=test_dataset.label_color_map,
        white_bg=white_bg,
        fp16=fp16,
        device=device,
        args=args,
        gt_rgbs=gt_rgbs,
        gt_labels=gt_ins_labels,
        use_inpaint=args.use_inpaint,
    )

    print(f'Manipulating {args.expname} Done')
-
-
if __name__ == '__main__':
    # Fixed seed shared by torch and numpy so repeated runs are reproducible.
    seed = 20230414

    # NOTE(review): set_default_tensor_type is deprecated in recent PyTorch
    # releases; torch.set_default_dtype(torch.float32) is the suggested
    # replacement — confirm the installed torch version before switching.
    torch.set_default_tensor_type('torch.FloatTensor')
    torch.manual_seed(seed)
    np.random.seed(seed)

    mani(config_parser())
|