Browse Source

update

master
yehua 1 month ago
parent
commit
867e6f1d3f
5 changed files with 34 additions and 14 deletions
  1. +6
    -5
      README.md
  2. +2
    -2
      TensorFlow/decoder.py
  3. +16
    -5
      TensorFlow/encoder.py
  4. +2
    -2
      TensorFlow/networkTool.py
  5. +8
    -0
      requirements.txt

+ 6
- 5
README.md View File

@@ -20,21 +20,22 @@ root
## environment
1. pytorch
refer to OctAttention-obj-pytorch/README.md

2. tensorflow
tensorflow-gpu 2.3.1
others same as pytorch
refer to requirements.txt

## command
* pytorch && tensorflow:
* tensorflow:
> cd TensorFlow

training:
>python octAttention.py

encode:
>python encoder.py
>python encoder.py --input="/userhome/PCGCv1/pytorch_eval/28_airplane_0270.ply" --ckpt_dir="checkpoint_TF_1024"

decode:
>python decoder.py
>python decoder.py --input="/userhome/PCGCv1/pytorch_eval/28_airplane_0270.ply" --ckpt_dir="checkpoint_TF_1024"

## performance
* test on some PC files. From the result below, we can see that, encoding in pytorch is faster than tensorflow, maybe because of the fast operation of linear module of pytorch. And bpip of tensorflow is closed to pytorch.


+ 2
- 2
TensorFlow/decoder.py View File

@@ -121,7 +121,7 @@ if __name__=="__main__":

for oriFile in list_orifile: # from encoder.py
ptName = os.path.basename(oriFile)[:-4]
matName = '../Data/testPly/'+ptName+'.mat'
matName = './Data/testPly/'+ptName+'.mat'
binfile = expName+'/data/'+ptName+'.bin'
cell,mat =matloader(matName)

@@ -141,4 +141,4 @@ if __name__=="__main__":
# Dequantization
DQpt = (ptrec*qs+offset)
pt.write_ply_data(expName+"/temp/test/rec.ply",DQpt)
pt.pcerror(p,DQpt,None,'-r 1',None).wait()
# pt.pcerror(p,DQpt,None,'-r 1',None).wait()

+ 16
- 5
TensorFlow/encoder.py View File

@@ -12,19 +12,30 @@ from networkTool import CPrintl,expName
from octAttention import model
import glob,datetime,os
import pt as pointCloud #无框架
import argparse

def parse_args():
parser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter)
parser.add_argument("--input", default='', help="Input filename.") #'compressed/28_airplane_0270' '/userhome/PCGCv1/testdata/8iVFB/redandblack_vox10_1550.ply'
parser.add_argument("--ckpt_dir", type=str, default='/userhome/PCGCv1/pytorch2/ckpts/hyper_mgpu2/epoch_13_12599.pth', dest="ckpt_dir", help='checkpoint')
args = parser.parse_args()
return args

############## warning ###############
## decoder.py relys on this model here
## do not move this lines to somewhere else
# model = model.to(device)
# saveDic = reload(None,'modelsave/obj/encoder_epoch_00800093.pth')
# model.load_state_dict(saveDic['encoder'])
args = parse_args()
input_data = tf.ones([1024, 32, 4, 6],dtype=tf.int32)
input_mask = model.generate_square_subsequent_mask(1024)
output = model(input_data, input_mask,[],training=False)
model.load_weights(tf.train.latest_checkpoint('/userhome/OctAttention/OctAttention-obj/Exp/Obj/checkpoint_TF/')) #文件夹名称
model.load_weights(tf.train.latest_checkpoint(args.ckpt_dir)) #文件夹名称

###########Objct##############
list_orifile = ['../28_airplane_0270.ply']
list_orifile = [args.input]
if __name__=="__main__":
printl = CPrintl(expName+'/encoderPLY_tf.txt') #输出的同时保存到log
printl('_'*50,'OctAttention V tf','_'*50)
@@ -38,8 +49,8 @@ if __name__=="__main__":
ptName = os.path.splitext(os.path.basename(oriFile))[0]
for qs in [1]:
ptNamePrefix = ptName
matFile,DQpt,refPt = dataPrepare(oriFile,saveMatDir='../Data/testPly',qs=qs,ptNamePrefix='',rotation=False)
matFile,DQpt,refPt = dataPrepare(oriFile,saveMatDir='./Data/testPly',qs=qs,ptNamePrefix='',rotation=False)
# please set `rotation=True` in the `dataPrepare` function when processing MVUB data
main(matFile,model,actualcode=True,printl =printl) # actualcode=False: bin file will not be generated
print('_'*50,'pc_error','_'*50)
pointCloud.pcerror(refPt,DQpt,None,'-r 1023',None).wait() #这里比的是量化前,量化后再反量化的点
# print('_'*50,'pc_error','_'*50)
# pointCloud.pcerror(refPt,DQpt,None,'-r 1023',None).wait() #这里比的是量化前,量化后再反量化的点

+ 2
- 2
TensorFlow/networkTool.py View File

@@ -15,8 +15,8 @@ os.environ["CUDA_VISIBLE_DEVICES"] = "0"

# Network parameters
bptt = 1024 # Context window length
expName = '../Exp/Obj'
DataRoot = '../Data/Obj'
expName = './Exp/Obj'
DataRoot = './Data/Obj'

checkpointPath = expName+'/checkpoint_TF'
levelNumK = 4


+ 8
- 0
requirements.txt View File

@@ -0,0 +1,8 @@
h5py==2.10.0
hdf5storage==0.1.18
numpy==1.18.4
pandas==1.0.4
Pillow==9.1.0
plyfile==0.7.4
tensorflow_gpu==2.3.1
tqdm==4.48.0

Loading…
Cancel
Save