|
- # from genericpath import exists
- import os
- import time
- import argparse
- import numpy as np
- import subprocess
- from tqdm import tqdm
- import tensorflow as tf
- import base_model
-
- # import torch
-
- # import data as Dataset
- # from config import Config
- # from util.distributed import init_dist
- # from util.trainer import get_model_optimizer_and_scheduler, set_random_seed, get_trainer
from demo_point_cloud_dataset import DemoPointCloudDataset  # already ported (from the PyTorch version)
-
def parse_args(argv=None):
    """Parse command-line options for the evaluation script.

    Parameters
    ----------
    argv : list[str] | None
        Argument strings to parse. ``None`` (the default) reads
        ``sys.argv[1:]``, preserving the original call signature.

    Returns
    -------
    argparse.Namespace
        Parsed options.
    """
    parser = argparse.ArgumentParser(description='Training')
    parser.add_argument('--checkpoints_dir',
                        default='/userhome/postprocess_v1_bek/tensorflow/models/',
                        help='Dir for saving logs and models.')
    parser.add_argument('--output_dir', type=str, default='eval_result')
    # Binarization strategy for the network output.
    parser.add_argument('--threshold', type=str, default='adaptive',
                        choices=['adaptive', 'fixed'])
    parser.add_argument('--threshold_value', type=float, default=0.98)
    parser.add_argument('--test_list', type=str, default='./test_10bit.txt')
    # Fixed: action='store_true' combined with default=True could never be
    # switched off. BooleanOptionalAction keeps the old default (True) and
    # the old --calculate_psnr flag, and adds --no-calculate_psnr to disable.
    parser.add_argument('--calculate_psnr', default=True,
                        action=argparse.BooleanOptionalAction)

    args = parser.parse_args(argv)
    return args
-
-
if __name__ == '__main__':
    # ---- setup -----------------------------------------------------------
    args = parse_args()

    # Build the generator and restore weights from the newest checkpoint.
    net_G = base_model.Generator(base_channel=16, num_layers=4)
    net_G.build(input_shape=(1, 1, 64, 64, 64))
    checkpoint = tf.train.Checkpoint(model_net=net_G)
    latest_ckpt = tf.train.latest_checkpoint(args.checkpoints_dir)
    if latest_ckpt is None:
        # Fixed: the original crashed with a TypeError on the print below
        # when no checkpoint was present; fail with an explicit error.
        raise FileNotFoundError('no checkpoint found in ' + args.checkpoints_dir)
    checkpoint.restore(latest_ckpt)
    print('loading checkpoint from ' + latest_ckpt)

    if not os.path.exists(args.output_dir):
        os.makedirs(args.output_dir)

    # Data loader ported from the PyTorch version.
    data_loader = DemoPointCloudDataset(batch_size=32)

    # Test list file: one "input_path,gt_path" pair per line.
    test_dict = {}
    with open(args.test_list, 'r') as f:
        lines = [item.strip().split(',') for item in f.readlines()]
    test_dict['input'] = [item[0] for item in lines]
    test_dict['gt'] = [item[1] for item in lines]
    num_files = len(test_dict['input'])

    # ---- inference -------------------------------------------------------
    for index in tqdm(range(num_files)):
        start_time = time.time()
        cubes_input, cubes_gt = data_loader.load_item(
            test_dict['input'][index], test_dict['gt'][index])
        # onehot: per-cube occupancy volumes, shape (num_cubes, 64, 64, 64);
        # global_id_list[i] is the (x, y, z) grid position of onehot[i].
        # (Renamed from 'input', which shadowed the builtin.)
        onehot, global_id_list = data_loader.convert_to_onehot(cubes_input)
        num_cubes = onehot.shape[0]
        predict = []
        for i in range(0, num_cubes, data_loader.batch_size):
            # Insert the channel axis -> (batch, 1, 64, 64, 64).
            input_batch = onehot[i:i + data_loader.batch_size, None, ...]
            predict_batch = net_G(input_batch, training_flag=False)
            predict.append(predict_batch)
        predict = tf.concat(predict, 0)  # (num_cubes, 1, 64, 64, 64)

        # ---- binarize each cube and collect occupied voxel coordinates ---
        final_cubes = {}
        for i, global_id in enumerate(global_id_list):
            item_predict = predict[i, 0]
            if args.threshold == 'fixed':
                item_predict = (item_predict > args.threshold_value).numpy().astype(np.int32)
            elif args.threshold == 'adaptive':
                # Adaptive: keep exactly as many voxels as the ground-truth
                # cube has points, by thresholding at the num_point-th
                # largest predicted score.
                num_point = len(cubes_gt[global_id])
                if num_point != 0:
                    threshold_value = tf.sort(
                        tf.reshape(item_predict, [-1]), direction='ASCENDING')[-num_point]
                else:
                    # Empty GT cube: threshold of 1 presumably keeps nothing
                    # (scores look like probabilities <= 1 — TODO confirm).
                    threshold_value = 1
                item_predict = (item_predict > threshold_value).numpy().astype(np.int32)
            else:
                # Fixed: was `assert False`, which disappears under `python -O`.
                raise ValueError('unknown threshold mode: %s' % args.threshold)
            # (N, 3) integer coordinates of occupied voxels.
            points = np.array(np.where(item_predict > 0)).transpose((1, 0))
            final_cubes[global_id] = points
        end_time = time.time()
        # Fixed: print() was given the elapsed time as a second positional
        # argument instead of %-formatting it into the message.
        print('process needs: %.2f s' % (end_time - start_time))
        write_name = os.path.join(args.output_dir,
                                  os.path.basename(test_dict['input'][index]))
        data_loader.write_to_ply(final_cubes, write_name)

    # ---- optional D1-PSNR evaluation via the external pc_error tool ------
    if args.calculate_psnr:
        for index in range(num_files):
            input_name, gt_name = test_dict['input'][index], test_dict['gt'][index]
            write_name = os.path.join(args.output_dir, os.path.basename(input_name))
            # Peak value passed to pc_error depends on the geometry bit depth
            # encoded in the test-list filename.
            if '9bit' in args.test_list:
                resolution = '511'
            elif '11bit' in args.test_list:
                resolution = '2047'
            else:
                resolution = '1023'

            # NOTE(review): shell=True with interpolated paths; acceptable
            # for trusted local files, but prefer a list argv if inputs can
            # ever be untrusted.
            command = '/userhome/pcc_geo_cnn_v2/pc_error_d' \
                + ' -a ' + gt_name \
                + ' -b ' + write_name \
                + ' -r ' + resolution
            proc = subprocess.Popen(command, stdout=subprocess.PIPE, shell=True)
            # Fixed: wait() + read() with a PIPE can deadlock when the child
            # fills the pipe buffer; communicate() drains stdout while waiting.
            stdout_data, _ = proc.communicate()
            str_record = stdout_data.decode('utf8')
            # pc_error reports the D1 PSNR on a different output line for
            # 11-bit clouds — presumably due to extra printed metrics;
            # verify against the tool's output format.
            if '11bit' in args.test_list:
                d1_PSNR = float(str_record.split('\n')[-5][24:])
            else:
                d1_PSNR = float(str_record.split('\n')[-3][24:])
            print(os.path.basename(write_name), d1_PSNR)
|