|
import argparse
import datetime
import os
import time
import warnings
from pathlib import Path

import moxing as mox
from mindspore import context
from mindspore import dataset as de
from mindspore import dtype as mstype
from mindspore import load_checkpoint
from mindspore import load_param_into_net
from mindspore import set_seed

from src.core.eval_utils import get_official_eval_result
from src.predict import predict
from src.predict import predict_kitti_to_anno
from src.utils import get_config
from src.utils import get_model_dataset
from src.utils import get_params_for_net
- warnings.filterwarnings('ignore')
-
-
def train(args):
    """Download the dataset from OBS, then evaluate a pretrained PointPillars
    checkpoint on the KITTI eval split and print the official eval result.

    NOTE(review): despite its name this function only evaluates — it loads a
    fixed checkpoint and never optimizes. The name is kept because renaming
    would break the ModelArts job entry point.

    Args:
        args: parsed CLI namespace. Only ``args.data_url`` (OBS source path)
            is read here; ``args.train_url`` and ``args.device_target`` are
            accepted for ModelArts compatibility but unused in this function.
    """
    # Local directory the dataset is mirrored into inside the training image.
    data_dir = '/home/work/user-job-dir/data'
    # makedirs(..., exist_ok=True) also creates missing parents and is
    # race-free, unlike the original bare os.mkdir.
    os.makedirs(data_dir, exist_ok=True)

    # Copy the dataset from OBS into the training environment.
    # Best-effort by design: a failed copy is logged and we proceed, so the
    # later dataset construction surfaces the real error.
    obs_data_url = args.data_url
    try:
        mox.file.copy_parallel(obs_data_url, data_dir)
        print("Successfully Download {} to {}".format(obs_data_url, data_dir))
    except Exception as e:  # moxing raises backend-specific exception types
        print('moxing download {} to {} failed: '.format(obs_data_url, data_dir) + str(e))

    # Debug listing of the job directory contents.
    # NOTE(review): this walks V0002 while the config/checkpoint below live
    # under V0004 — looks like a leftover from an earlier job version; confirm.
    for root, dirs, files in os.walk("/home/work/user-job-dir/V0002"):
        print(root, dirs, files)
    print("___________________________________________________")

    cfg_path = Path("/home/work/user-job-dir/V0004/experiments/car/car_xyres16_modelarts.yaml")
    ckpt_path = "/home/work/user-job-dir/V0004/experiments/car/pointpillars-160_296960.ckpt"
    cfg = get_config(cfg_path)

    # Device target comes from the YAML config, not from args.device_target.
    context.set_context(mode=context.GRAPH_MODE, device_target=cfg["device_target"])
    device_id = int(os.getenv('DEVICE_ID', '0'))
    context.set_context(device_id=device_id)

    model_cfg = cfg['model']
    center_limit_range = model_cfg['post_center_limit_range']

    # Build the network and the evaluation dataset (training=False).
    pointpillarsnet, eval_dataset, box_coder = get_model_dataset(cfg, False)

    # Restore pretrained weights into the network.
    params = load_checkpoint(ckpt_path)
    new_params = get_params_for_net(params)
    load_param_into_net(pointpillarsnet, new_params)

    eval_input_cfg = cfg['eval_input_reader']
    eval_column_names = eval_dataset.data_keys

    ds = de.GeneratorDataset(
        eval_dataset,
        column_names=eval_column_names,
        python_multiprocessing=True,
        num_parallel_workers=1,
        max_rowsize=100,
        shuffle=False
    )
    batch_size = eval_input_cfg['batch_size']
    # Keep the final partial batch so every sample is evaluated.
    ds = ds.batch(batch_size, drop_remainder=False)
    data_loader = ds.create_dict_iterator(num_epochs=1)

    class_names = list(eval_input_cfg['class_names'])

    dt_annos = []
    # Ground-truth annotations come straight from the KITTI infos.
    gt_annos = [info["annos"] for info in eval_dataset.kitti_infos]

    log_freq = 100
    len_dataset = len(eval_dataset)
    start = time.time()
    for i, data in enumerate(data_loader):
        voxels = data["voxels"]
        num_points = data["num_points"]
        coors = data["coordinates"]
        # bev_map is optional in the dataset output; the net accepts False.
        bev_map = data.get('bev_map', False)
        preds = pointpillarsnet(voxels, num_points, coors, bev_map)
        # The net returns (box, cls) or (box, cls, dir_cls) depending on
        # whether direction classification is enabled; zip truncates to
        # however many outputs are present, replacing the if/else ladder.
        preds = dict(zip(('box_preds', 'cls_preds', 'dir_cls_preds'), preds))
        preds = predict(data, preds, model_cfg, box_coder)

        dt_annos += predict_kitti_to_anno(preds,
                                          data,
                                          class_names,
                                          center_limit_range)

        if i % log_freq == 0 and i > 0:
            time_used = time.time() - start
            print(f'processed: {i * batch_size}/{len_dataset} imgs, time elapsed: {time_used} s',
                  flush=True)

    result = get_official_eval_result(
        gt_annos,
        dt_annos,
        class_names,
    )
    print("\n+++++++++++++++++++++++++++++++++++++++++++\n")
    print("\n+++++++++++++++++++++++++++++++++++++++++++\n")
    print(result)
    print("\n+++++++++++++++++++++++++++++++++++++++++++\n")
    print("\n+++++++++++++++++++++++++++++++++++++++++++\n")
-
-
if __name__ == '__main__':
    # ModelArts passes --data_url / --train_url automatically; device_target
    # is accepted for compatibility (the actual target is read from the YAML).
    parser = argparse.ArgumentParser(
        description='PointPillars evaluation entry point for ModelArts')
    parser.add_argument('--device_target', default='Ascend',
                        help='device target (e.g. Ascend)')
    parser.add_argument('--data_url', required=True,
                        help='OBS path of the dataset to download')
    parser.add_argument('--train_url', required=True,
                        help='OBS output path (required by ModelArts, unused here)')
    parse_args = parser.parse_args()
    train(parse_args)
|