|
- # Copyright 2021 Huawei Technologies Co., Ltd
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- # ============================================================================
- """pre process for 310 inference"""
- import argparse
- import os
-
- import numpy as np
- import mindspore.dataset as ds
- from src.dataset import DatasetGenerator
-
-
def parse_args(arg=None):
    """Define and parse the command-line configuration for preprocessing.

    Args:
        arg: Optional list of argument strings to parse. When ``None``,
            ``argparse`` falls back to ``sys.argv[1:]``.

    Returns:
        ``argparse.Namespace`` with the fields ``data_path``,
        ``preprocess_path``, ``num_category`` (10 or 40, default 40),
        ``num_point`` (default 1024), ``use_normals`` and
        ``use_uniform_sample`` (both flags, default ``False``).
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--data_path', type=str, help='data path')
    parser.add_argument('--preprocess_path', type=str, help='preprocess path')
    parser.add_argument('--num_category', default=40, type=int, choices=[10, 40], help='training on ModelNet10/40')
    parser.add_argument('--num_point', type=int, default=1024, help='Point Number')
    parser.add_argument('--use_normals', action='store_true', default=False, help='use normals')
    parser.add_argument('--use_uniform_sample', action='store_true', default=False, help='use uniform sampling')

    return parser.parse_args(arg)
-
-
def run_pre_process():
    """Export the test split as per-sample ``.bin`` files plus a label array.

    Options are read from the command line via ``parse_args()``. The ``'test'``
    split produced by ``DatasetGenerator`` is iterated in batches of one
    sample, without shuffling. For each iteration:

    * the ``"data"`` array is written with ``ndarray.tofile`` to
      ``<preprocess_path>/data/pointnet2_data_bs_<idx>.bin``, where ``idx``
      is the zero-based iteration index zero-padded to 4 digits;
    * the ``"label"`` array is appended to a list. Once iteration finishes,
      the list is saved with ``np.save`` to
      ``<preprocess_path>/label/labels.npy``.

    Output directories are created if missing; existing files with the same
    names are overwritten.
    """
    args = parse_args()

    num_workers = 8
    test_ds_generator = DatasetGenerator(root=args.data_path, args=args, split='test')
    test_ds = ds.GeneratorDataset(test_ds_generator, ["data", "label"], num_parallel_workers=num_workers, shuffle=False)
    # One sample per batch, so each iteration below produces exactly one .bin file.
    test_ds = test_ds.batch(batch_size=1, num_parallel_workers=num_workers)

    preprocess_data_path = os.path.join(args.preprocess_path, 'data')
    preprocess_label_path = os.path.join(args.preprocess_path, 'label')
    # exist_ok=True replaces the racy "os.path.exists then os.makedirs" check:
    # it creates the directory tree if needed and is a no-op if it already exists.
    os.makedirs(preprocess_data_path, exist_ok=True)
    os.makedirs(preprocess_label_path, exist_ok=True)

    label_list = []
    for idx, data in enumerate(test_ds.create_dict_iterator(output_numpy=True)):
        file_name = "pointnet2_data_bs_" + str(idx).zfill(4) + ".bin"
        file_path = os.path.join(preprocess_data_path, file_name)
        # tofile writes raw bytes with no header: the consumer must already
        # know the dtype and shape of the array.
        data["data"].tofile(file_path)
        label_list.append(data["label"])
    # Labels are stored in the same order as the .bin file indices.
    np.save(os.path.join(preprocess_label_path, "labels.npy"), label_list)
    print("=" * 20, "export bin files finished", "=" * 20)
-
-
if __name__ == "__main__":
    # Script entry point: run_pre_process() reads its options from sys.argv
    # (it calls parse_args() with no explicit argument list).
    run_pre_process()
|