import os
import random

import numpy as np
import tensorflow as tf

class BaseDataset:
    def __init__(self, path, cube_size, batch_size, is_inference=False):
        # is_inference: True for validation, False for training.
        self.cube_size = cube_size
        self.path = path
        if os.path.isfile(path):
            print('reading dataset from {}'.format(path))
            with open(path, 'r') as f:
                self.file_list = [line.strip() for line in f.readlines()]
        else:
            raise FileNotFoundError('file list not found: {}'.format(path))
        # np.random.shuffle cannot be used here: it shuffles in place and
        # returns None. random.sample returns a shuffled copy instead.
        self.file_list = random.sample(self.file_list, len(self.file_list))
        self.dataset = tf.data.Dataset.from_tensor_slices(self.file_list)
        self.dataset = self.dataset.map(
            lambda x: tf.py_function(self.parse_file, [x], [tf.float32, tf.float32]))
        # dataset.shuffle(buffer_size=len(self.file_list)) was abandoned: with
        # a buffer spanning the whole list it takes too long and uses too much
        # memory, hence the up-front shuffle of the file list above.
        # During training, drop the final partial batch.
        self.dataset = self.dataset.batch(batch_size, drop_remainder=not is_inference)
        self.dataset = self.dataset.prefetch(buffer_size=tf.data.experimental.AUTOTUNE)
        # One pass over the data; the epoch count should be computed from the
        # dataset size and the batch size.
        self.dataset = self.dataset.repeat(1)
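    # A lighter-weight alternative (a sketch, not what this class does): let
    # tf.data reshuffle a bounded buffer each epoch, trading shuffle quality
    # for constant memory; buffer_size=1024 is an assumed value.
    #   self.dataset = self.dataset.shuffle(buffer_size=1024,
    #                                       reshuffle_each_iteration=True)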

    def parse_file(self, filename):
        # tf.py_function hands us an EagerTensor; recover the Python string.
        filename = filename.numpy().decode(encoding='utf8', errors='ignore')
        # Re-root the list entry under the dataset directory (entries are
        # assumed to start with '.', e.g. './...').
        filename = '/userhome/postprocess_v1_bek' + filename[1:]
        gt_name = self.to_gt_name(filename)
        decompressed_points = np.loadtxt(filename).astype(np.int32)
        gt_points = np.loadtxt(gt_name).astype(np.int32)
        decompressed = self.convert_to_onehot(decompressed_points)
        gt = self.convert_to_onehot(gt_points)
        return decompressed, gt
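    # For reference: np.loadtxt above expects whitespace-separated integer
    # coordinates, one point per line. The values below are invented for
    # illustration:
    #   12 5 40
    #   12 6 40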

    def convert_to_onehot(self, points):
        # Despite the name, this builds a binary occupancy grid: voxels listed
        # in points are set to 1.0, all others stay 0.0.
        inputs = np.zeros((self.cube_size, self.cube_size, self.cube_size), dtype=np.float32)
        inputs[points[:, 0], points[:, 1], points[:, 2]] = 1.0
        inputs = np.expand_dims(inputs, 0)  # add a channel axis: [1, 64, 64, 64]
        return tf.convert_to_tensor(inputs)
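    # Illustrative sketch, not part of the original pipeline: the inverse of
    # convert_to_onehot, recovering an [N, 3] integer point list from an
    # occupancy grid. The method name is hypothetical; handy for round-trip
    # sanity checks.
    def convert_to_points(self, grid):
        grid = np.squeeze(np.asarray(grid), axis=0)  # [1, C, C, C] -> [C, C, C]
        return np.argwhere(grid > 0.5).astype(np.int32)  # occupied voxel coords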

    def to_gt_name(self, name):
        # Swap the fourth-from-last path segment for 'gt' to find the
        # ground-truth file paired with a decompressed file.
        name_list = name.split('/')
        name_list[-4] = 'gt'
        return '/'.join(name_list)
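    # Example of the mapping (directory names other than 'gt' are invented
    # for illustration):
    #   /userhome/postprocess_v1_bek/decompressed/seq01/frame00/cube_000.txt
    #   -> /userhome/postprocess_v1_bek/gt/seq01/frame00/cube_000.txt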

if __name__ == "__main__":
    train_set = BaseDataset('/userhome/postprocess_v1_bek/train_test/train.txt',
                            cube_size=64, batch_size=64, is_inference=False)
    for decompressed, gt in train_set.dataset:
        print("decompressed.shape:", decompressed.shape)  # (64, 1, 64, 64, 64)
        print("gt.shape:", gt.shape)  # (64, 1, 64, 64, 64)
        break
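    # Flipping is_inference keeps the final partial batch, which suits
    # validation (a usage sketch; 'val.txt' is an assumed filename).
    val_set = BaseDataset('/userhome/postprocess_v1_bek/train_test/val.txt',
                          cube_size=64, batch_size=64, is_inference=True)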