@@ -12,19 +12,30 @@ from networkTool import CPrintl,expName
from octAttention import model
import glob,datetime,os
import pt as pointCloud #无框架
import argparse
def parse_args():
parser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter)
parser.add_argument("--input", default='', help="Input filename.") #'compressed/28_airplane_0270' '/userhome/PCGCv1/testdata/8iVFB/redandblack_vox10_1550.ply'
parser.add_argument("--ckpt_dir", type=str, default='/userhome/PCGCv1/pytorch2/ckpts/hyper_mgpu2/epoch_13_12599.pth', dest="ckpt_dir", help='checkpoint')
args = parser.parse_args()
return args
############## warning ###############
## decoder.py relys on this model here
## do not move this lines to somewhere else
# model = model.to(device)
# saveDic = reload(None,'modelsave/obj/encoder_epoch_00800093.pth')
# model.load_state_dict(saveDic['encoder'])
args = parse_args()
input_data = tf.ones([1024, 32, 4, 6],dtype=tf.int32)
input_mask = model.generate_square_subsequent_mask(1024)
output = model(input_data, input_mask,[],training=False)
model.load_weights(tf.train.latest_checkpoint('/userhome/OctAttention/OctAttention-obj/Exp/Obj/checkpoint_TF/')) #文件夹名称
model.load_weights(tf.train.latest_checkpoint(args.ckpt_dir )) #文件夹名称
###########Objct##############
list_orifile = ['../28_airplane_0270.ply' ]
list_orifile = [args.input ]
if __name__=="__main__":
printl = CPrintl(expName+'/encoderPLY_tf.txt') #输出的同时保存到log
printl('_'*50,'OctAttention V tf','_'*50)
@@ -38,8 +49,8 @@ if __name__=="__main__":
ptName = os.path.splitext(os.path.basename(oriFile))[0]
for qs in [1]:
ptNamePrefix = ptName
matFile,DQpt,refPt = dataPrepare(oriFile,saveMatDir='.. /Data/testPly',qs=qs,ptNamePrefix='',rotation=False)
matFile,DQpt,refPt = dataPrepare(oriFile,saveMatDir='./Data/testPly',qs=qs,ptNamePrefix='',rotation=False)
# please set `rotation=True` in the `dataPrepare` function when processing MVUB data
main(matFile,model,actualcode=True,printl =printl) # actualcode=False: bin file will not be generated
print('_'*50,'pc_error','_'*50)
pointCloud.pcerror(refPt,DQpt,None,'-r 1023',None).wait() #这里比的是量化前,量化后再反量化的点
# print('_'*50,'pc_error','_'*50)
# pointCloud.pcerror(refPt,DQpt,None,'-r 1023',None).wait() #这里比的是量化前,量化后再反量化的点