- #!/usr/bin/env python
- # -*- coding: utf-8 -*-
- """
- @Author: Linyao Gao
- @Contact: linyaog@sjtu.edu.cn
- @File: test.py
- @Time: 2021/01/04
- """
- import collections
-
- import open3d
- import os
- import numpy as np
- import torch
- from dataset import Dataset
- from torch.utils.data import DataLoader
- import matplotlib.pyplot as plt
- from utils.pc_error_wrapper import pc_error
- # from iostream import IOStream
- import time
- import importlib
- import sys
- import argparse
-
- import logging
-
- BASE_DIR = os.path.dirname(os.path.abspath(__file__))
- ROOT_DIR = BASE_DIR
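- # make models/ importable so importlib.import_module(model_name) below can find the model definition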
- sys.path.append(os.path.join(ROOT_DIR, 'models'))
- os.environ["CUDA_VISIBLE_DEVICES"] = '0'
-
- # Test-metric helpers
-
- def parse_args():
- '''PARAMETERS'''
- parser = argparse.ArgumentParser('point_based_PCGC')
- parser.add_argument('--dataset_path', type=str, default='dataset') # ../pointcloud_compression/PointCloudDatasets
- parser.add_argument('--model', default='NGS_PCC', help='model name')
-
- return parser.parse_args()
-
- def cal_bpp(likelihood, device, num_points, batch_size):
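- # Bits per point: -sum(log2(likelihood)) averaged over num_points * batch_size points.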
- log2 = torch.log(torch.tensor(2.0, device=device))
- bpp = -torch.sum(torch.log(likelihood)) / log2 / (float(num_points) * float(batch_size))
-
- return bpp
-
- def cal_d1(pc_gt, decoder_output, step, name, checkpoint_path):
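- # D1 metric: point-to-point MSE/PSNR computed by the external pc_error tool on saved .ply files.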
- # Write the original point cloud to a .ply file
- ori_pcd = open3d.geometry.PointCloud() # create an empty point cloud
- ori_pcd.points = open3d.utility.Vector3dVector(np.squeeze(pc_gt)) # point coordinates, [N, 3]
- os.makedirs(checkpoint_path + '/pc_files', exist_ok=True) # make sure the output directory exists
- orifile = checkpoint_path+'/pc_files/'+'d1_ori_'+str(step)+"_"+name+'.ply' # save path
- open3d.io.write_point_cloud(orifile, ori_pcd, write_ascii=True)
- # Write the reconstructed point cloud to a .ply file
- rec_pcd = open3d.geometry.PointCloud()
- decoder_output = decoder_output.reshape(-1, 3) # flatten to [N, 3]
- rec_pcd.points = open3d.utility.Vector3dVector(np.squeeze(decoder_output))
- recfile = checkpoint_path+'/pc_files/'+'d1_rec_'+str(step)+"_"+name+'.ply'
- open3d.io.write_point_cloud(recfile, rec_pcd, write_ascii=True)
-
- pc_error_metrics = pc_error(infile1=orifile, infile2=recfile, res=2) # res is the peak-to-peak range of the data
- pc_errors = [pc_error_metrics["mse1,PSNR (p2point)"][0],
- pc_error_metrics["mse2,PSNR (p2point)"][0],
- pc_error_metrics["mseF,PSNR (p2point)"][0],
- pc_error_metrics["mse1 (p2point)"][0],
- pc_error_metrics["mse2 (p2point)"][0],
- pc_error_metrics["mseF (p2point)"][0]]
-
- return pc_errors
-
- def cal_d2(pc_gt, decoder_output, step, name, checkpoint_path):
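- # D2 metric: point-to-plane MSE/PSNR; pc_error needs normals on the reference cloud, estimated here with Open3D.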
-
- # Write the original point cloud to a .ply file
- ori_pcd = open3d.geometry.PointCloud() # create an empty point cloud
- ori_pcd.points = open3d.utility.Vector3dVector(np.squeeze(pc_gt)) # point coordinates, [N, 3]
- ori_pcd.estimate_normals(search_param=open3d.geometry.KDTreeSearchParamHybrid(radius=0.1, max_nn=30)) # estimate normals
- os.makedirs(checkpoint_path + '/pc_files', exist_ok=True) # make sure the output directory exists
- orifile = checkpoint_path+'/pc_files/'+'d2_ori_'+str(step)+"_"+name+'.ply' # save path
- open3d.io.write_point_cloud(orifile, ori_pcd, write_ascii=True)
- # Rewrite the .ply header so the normal dtype reads float32 instead of double
- # (lines 7-9 of Open3D's ASCII header declare the nx/ny/nz properties)
- with open(orifile) as f:
- lines = f.readlines()
- to_be_modified = [7, 8, 9]
- for i in to_be_modified:
- lines[i] = lines[i].replace('double', 'float32')
- with open(orifile, 'w') as f:
- f.writelines(lines)
- # Visualize the point cloud (xyz only)
- # open3d.visualization.draw_geometries([ori_pcd])
-
- # Write the reconstructed point cloud to a .ply file
- rec_pcd = open3d.geometry.PointCloud()
- decoder_output = decoder_output.reshape(-1, 3) # flatten to [N, 3]
- rec_pcd.points = open3d.utility.Vector3dVector(np.squeeze(decoder_output))
- # rec_pcd.estimate_normals(search_param=open3d.geometry.KDTreeSearchParamHybrid(radius=0.1, max_nn=30)) # estimate normals
- recfile = checkpoint_path+'/pc_files/'+'d2_rec_'+str(step)+"_"+name+'.ply'
- open3d.io.write_point_cloud(recfile, rec_pcd, write_ascii=True)
-
- pc_error_metrics = pc_error(infile1=orifile, infile2=recfile, normal=True, res=2) # res is the peak-to-peak range; normal=True selects the D2 (point-to-plane) metric
- pc_errors = [pc_error_metrics["mse1,PSNR (p2plane)"][0],
- pc_error_metrics["mse2,PSNR (p2plane)"][0],
- pc_error_metrics["mseF,PSNR (p2plane)"][0],
- pc_error_metrics["mse1 (p2plane)"][0],
- pc_error_metrics["mse2 (p2plane)"][0],
- pc_error_metrics["mseF (p2plane)"][0]]
-
- return pc_errors
-
- def test(model, args, batch_size=1):
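- # Evaluate on the shapenetcorev2 test split and log per-category / overall bpp, D1, D2 and runtime.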
- checkpoint_path = '.'
- # shapenetcorev2
- test_data = Dataset(root=args.dataset_path, dataset_name='shapenetcorev2', num_points=2048, split='test')
- test_loader = DataLoader(test_data, num_workers=2, batch_size=batch_size, shuffle=False)
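- # NOTE: the per-category bookkeeping below assumes batch_size == 1 (label is used as a scalar index).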
-
- # Initialize per-category accumulators (55 shapenetcorev2 categories)
- avg_chamfer_dist = np.zeros(55)
- avg_d1_psnr = np.zeros(55)
- avg_d1_mse = np.zeros(55)
- avg_d2_psnr = np.zeros(55)
- avg_d2_mse = np.zeros(55)
- avg_bpp = np.zeros(55)
- counter = np.zeros(55)
- total_chamfer_dist = 0.0
- total_d1_psnr = 0.0
- total_d1_mse = 0.0
- total_d2_psnr = 0.0
- total_d2_mse = 0.0
- total_bpp = 0.0
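- # NOTE: no chamfer distance is computed anywhere below, so the chamfer averages stay 0 in the report.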
-
- num_samples = 0
-
- label_names = {0: 'airplane', 1: 'bag', 2: 'basket', 3: 'bathtub', 4: 'bed',
- 5: 'bench', 6: 'bottle', 7: 'bowl', 8: 'bus', 9: 'cabinet',
- 10: 'can', 11: 'camera', 12: 'cap', 13: 'car', 14: 'chair',
- 15: 'clock', 16: 'dishwasher', 17: 'monitor', 18: 'table', 19: 'telephone',
- 20: 'tin_can', 21: 'tower', 22: 'train', 23: 'keyboard', 24: 'earphone',
- 25: 'faucet', 26: 'file', 27: 'guitar', 28: 'helmet', 29: 'jar',
- 30: 'knife', 31: 'lamp', 32: 'laptop', 33: 'speaker', 34: 'mailbox',
- 35: 'microphone', 36: 'microwave', 37: 'motorcycle', 38: 'mug', 39: 'piano',
- 40: 'pillow', 41: 'pistol', 42: 'pot', 43: 'printer', 44: 'remote_control',
- 45: 'rifle', 46: 'rocket', 47: 'skateboard', 48: 'sofa', 49: 'stove',
- 50: 'vessel', 51: 'washer', 52: 'cellphone', 53: 'birdhouse', 54: 'bookshelf'}
- names = list(label_names.values())
-
- logger.info("Start update")
- avg_point_cloud_time = np.zeros(55)
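- # cuDNN is disabled here, presumably to avoid autotuning overhead skewing the per-cloud timings.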
- torch.backends.cudnn.enabled = False
- torch.backends.cudnn.benchmark = False
-
- for step, data in enumerate(test_loader):
- with torch.no_grad():
- # start timing this sample (e.g. one point cloud of the 'car' category)
- start_time = time.time()
-
- pc_data = data[0] # [1, N, C] -- > [1, 2048, 3]
- label = data[1] # like tensor([[3]])
- label_name = data[2] # like ("airplane", )
- #print("pc_data, label, label_name", pc_data.shape, label.shape, label_name)
-
- if torch.cuda.is_available():
- pc_gt = pc_data.cuda()
- # pc_data = pc_gt.clone()
- # The default code leads to the problem of dimension mismatch
- # pc_data = pc_data.cuda().transpose(1,2) # [B,N,3] --> [1, 3, 2048]
- pc_data = pc_data.cuda().repeat(2, 1, 1) # duplicate along the batch dim; per-cloud time below is elapsed / 2
- # print("pc_data shape", pc_data.shape)
-
- #torch.backends.cudnn.enabled = False
- bpp, decoder_output, _ = model(pc_data) # forward pass on the duplicated batch
-
- avg_bpp[label] += bpp.item() # .item() keeps the accumulators as plain floats
- total_bpp += bpp.item()
-
- # Convert to numpy
- pc_gt = pc_gt.cpu().detach().numpy()
- decoder_output = decoder_output.cpu().detach().numpy()
-
- # D1 psnr & D1 mse
- # d1_results = cal_d1(pc_gt, decoder_output, step, checkpoint_path)
- d1_results = cal_d1(pc_gt, decoder_output, step, label_name[0], checkpoint_path)
- d1_psnr = d1_results[2].item()
- d1_mse = d1_results[5].item()
- avg_d1_mse[label] += d1_mse
- #print("avg_d1_mse[label]", avg_d1_mse[label])
- total_d1_mse += d1_mse
- avg_d1_psnr[label] += d1_psnr
- total_d1_psnr += d1_psnr
- # D2 psnr & D2 mse
- #d2_results = cal_d2(pc_gt, decoder_output, step, checkpoint_path)
- d2_results = cal_d2(pc_gt, decoder_output, step, label_name[0], checkpoint_path)
- d2_psnr = d2_results[2].item()
- d2_mse = d2_results[5].item()
- avg_d2_mse[label] += d2_mse
- total_d2_mse += d2_mse
- avg_d2_psnr[label] += d2_psnr
- total_d2_psnr += d2_psnr
-
- end_time = time.time()
- each_point_cloud_time = round((end_time-start_time)/2.0, 3) # seconds per cloud (the batch holds two copies)
- avg_point_cloud_time[label] += each_point_cloud_time
-
- logger.info(f"The cost each_point_cloud_time is: {each_point_cloud_time}")
- logger.info(f"The cost of avg_point_cloud_time[label] is: {avg_point_cloud_time[label]}")
-
- logger.info(f"bpp in label_name: {label_name[0]} and label: {label[0].numpy().tolist()[0]} is:\t" + str(bpp.item()))
- logger.info(f"d1_psnr in label_name: {label_name[0]} and label: {label[0].numpy().tolist()[0]} is:\t"+ str(d1_psnr))
- logger.info(f"d2_psnr in label_name: {label_name[0]} and label: {label[0].numpy().tolist()[0]} is:\t" + str(d2_psnr))
- logger.info(f"d1_mse in label_name: {label_name[0]} and label: {label[0].numpy().tolist()[0]} is:\t" + str(d1_mse))
- logger.info(f"d2_mse in label_name: {label_name[0]} and label: {label[0].numpy().tolist()[0]} is:\t" + str(d2_mse))
-
- counter[label] += 1
- num_samples += 1
-
- logger.info(f"label_name: {label_name[0]} and label: {label[0].numpy().tolist()[0]} update: {str(counter[label])}")
- logger.info(f"num_samples is: {num_samples}")
-
- logger.info("End update")
- print("num_samples", num_samples)
-
- for i in range(55):
- if counter[i] == 0: # skip categories absent from the test split
- continue
- avg_chamfer_dist[i] /= counter[i]
- avg_d1_psnr[i] /= counter[i]
- avg_d1_mse[i] /= counter[i]
- avg_d2_psnr[i] /= counter[i]
- avg_d2_mse[i] /= counter[i]
- avg_bpp[i] /= counter[i]
-
- avg_point_cloud_time[i] /= counter[i]
- logger.info(f"avg_point_cloud_time[i] of {names[i]} is: {avg_point_cloud_time[i]}")
- logger.info(f"avg_d1_psnr[i] of {names[i]} is: {avg_d1_psnr[i]}")
- logger.info(f"avg_d2_psnr[i] of {names[i]} is: {avg_d2_psnr[i]}")
- logger.info(f"avg_bpp[i] of {names[i]} is: {avg_bpp[i]}")
-
- total_chamfer_dist /= num_samples
- total_d1_psnr /= num_samples
- total_d1_mse /= num_samples
- total_d2_psnr /= num_samples
- total_d2_mse /= num_samples
- total_bpp /= num_samples
-
- # print("Average_Chamfer_Dist:", avg_chamfer_dist)
- # print("Average_D1_PSNR:", avg_d1_psnr)
- # print("Average_D1_mse:", avg_d1_mse)
- # print("Average_D2_PSNR:", avg_d2_psnr)
- # print("Average_D2_mse:", avg_d2_mse)
- # print("Average_bpp:", avg_bpp)
-
- for i in range(55):
- outstr = f"The {str(i)}th point cloud category {names[i]}\t" + \
- "Average_bpp: %.6f, \
- Average_D1_PSNR: %.6f, \
- Average_D2_PSNR: %.6f, \
- Average_D1_mse: %.6f, \
- Average_D2_mse: %.6f, \
- Average_Chamfer_Dist: %.6f \
- \n" % (
- avg_bpp[i],
- avg_d1_psnr[i],
- avg_d2_psnr[i],
- avg_d1_mse[i],
- avg_d2_mse[i],
- avg_chamfer_dist[i]
- )
-
- logger.info(outstr)
-
- # Average_all_categories
- outstr = f"Average of all point cloud categories \t" + \
- "Average_all_categories_bpp: %.6f, \
- Average_all_categories_D1_PSNR: %.6f, \
- Average_all_categories_D2_PSNR: %.6f, \
- Average_all_categories_D1_mse: %.6f, \
- Average_all_categories_D2_mse: %.6f, \
- Average_all_categories_Chamfer_Dist: %.6f \
- \n" % (
- total_bpp,
- total_d1_psnr,
- total_d2_psnr,
- total_d1_mse,
- total_d2_mse,
- total_chamfer_dist
- )
-
- logger.info(outstr)
-
-
- if __name__ == '__main__':
-
- args = parse_args()
- model_name = args.model # 'NGS_PCC'
-
- log_dir = f"test_log_d1_d2_bpp_time/{model_name}"
- os.makedirs(log_dir, exist_ok=True)
-
- logger = logging.getLogger("Test Log")
- logger.setLevel(logging.INFO)
- formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
-
- ckpt_of_different_rates = ['4000_256_2048','2000_256_2048','1000_256_2048','500_256_2048','100_256_2048','10_256_2048',
- '4000_256_2400','2000_256_2400','1000_256_2400','500_256_2400','100_256_2400','10_256_2400'
- ]
- # model_name = 'NGS_PCC'
- for exp_name in ckpt_of_different_rates:
- experiment_dir = 'log/'+model_name+'/'+exp_name+"/checkpoints/"+'best_model.pth'
- if not os.path.exists(experiment_dir):
- print(experiment_dir+" not exist!")
- continue
-
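- # exp_name appears to follow '<rate_weight>_<bottleneck_size>_<recon_points>', e.g. '4000_256_2048'.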
- recon_points = int(exp_name.split("_")[-1])
-
- for h in list(logger.handlers): # drop handlers from previous experiments so logs are not duplicated across files
- logger.removeHandler(h)
- file_handler = logging.FileHandler('%s/%s_%s.txt' % (log_dir, model_name, exp_name))
- file_handler.setLevel(logging.INFO)
- file_handler.setFormatter(formatter)
- logger.addHandler(file_handler)
-
- MODEL = importlib.import_module(model_name)
- model = MODEL.get_model(use_hyperprior=True, bottleneck_size=256, recon_points=recon_points).cuda() # recon_points=2400
- model.eval()
-
- checkpoint = torch.load(str(experiment_dir)) ## best_model
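- # Checkpoints saved from nn.DataParallel prefix parameter names with 'module.'; strip it so a single-GPU model can load them.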
- new_state_dict = collections.OrderedDict()
- for k, v in checkpoint['model_state_dict'].items():
- name = k.replace('module.', '') # remove `module.`
- new_state_dict[name] = v
- model.load_state_dict(new_state_dict)
-
- start_time = time.time()
- test(model, args)
- end_time = time.time()
- outstr = "test_time: %.6f" % ((end_time-start_time)//3600.0)
- print(outstr)
-
- # import glob
- # log_folder = "log"
- # chk_list = glob.glob(f"{log_folder}/*/checkpoints/*.pth")
- # #print(chk_list)
-
- # dic = {}
- # for chk in chk_list:
- # key = chk.split("/")[1]
- # if key not in dic:
- # dic[key] = [chk]
- # else:
- # dic[key].append(chk)
- # #print(dic)
-
- # all_chk_paths = []
- # for keys, values in dic.items():
- # print(values)
- # all_chk_paths.append(values)
- # #print("all_chk_paths", all_chk_paths)
-
- # model_name = 'NGS_PCC'
- # for paths in all_chk_paths:
-
- # file_handler = logging.FileHandler('%s/%s_%s.txt' % (log_dir, model_name, exp_name))
- # file_handler.setLevel(logging.INFO)
- # file_handler.setFormatter(formatter)
- # logger.addHandler(file_handler)
-
- # paths = sorted(paths)
- # if "best_model.pth" in "/".join(paths):
- # #print("best model", paths)
- # model_path = os.path.join(os.path.split(paths[0])[0], "best_model.pth")
- # else:
- # #print("no best", paths)
- # #print("last model", paths[-1])
- # model_path = paths[-1]
-
- # MODEL = importlib.import_module(model_name)
- # model = MODEL.get_model(use_hyperprior=True, bottleneck_size=256, recon_points=args.recon_points).cuda() # recon_points=2400
- # model.eval()
- # # checkpoint = torch.load(str(experiment_dir) + '/checkpoints/10.pth') ## best_model
- # checkpoint = torch.load(model_path)
- # new_state_dict = collections.OrderedDict()
- # for k, v in checkpoint['model_state_dict'].items():
- # name = k.replace('module.', '') # remove `module.`
- # new_state_dict[name] = v
- # model.load_state_dict(new_state_dict)
-
- # start_time = time.time()
- # test(model, args)
- # end_time = time.time()
- # outstr = "test_time: %.6f" % ((end_time-start_time)//3600.0)
- # print(outstr)