|
- import argparse
- import importlib
- from pathlib import Path
-
- import mindspore
- from mindspore import context
-
-
def build_parser():
    """Build the command-line parser for the algorithm launcher.

    Returns:
        argparse.ArgumentParser: parser exposing ``--algo``,
        ``--save_ckpt_path``, ``--seed``, ``--device``, ``--device_id``,
        ``--pynative_mode`` / ``--graph_mode`` and ``--mode``.
    """
    parser = argparse.ArgumentParser()
    # training params
    parser.add_argument('--algo', default='MGF', type=str, help='the algorithm need to train or eval.')
    parser.add_argument('--save_ckpt_path', default='/checkpoints/MGF/data/best_eval.ckpt', type=str)
    parser.add_argument('--seed', default=1000, type=int, help='set seed for reproducibility')
    parser.add_argument('--device', default='GPU', type=str, choices=['Ascend', 'CPU', 'GPU'])
    parser.add_argument('--device_id', default=0, type=int)
    # ``store_true`` combined with ``default=True`` can never produce False,
    # so the flag is kept for compatibility and paired with an explicit
    # opt-out that writes to the same destination.
    parser.add_argument('--pynative_mode', dest='pynative_mode', default=True, action='store_true')
    parser.add_argument('--graph_mode', dest='pynative_mode', action='store_false',
                        help='run in graph mode instead of the default PyNative mode')
    # train/eval
    parser.add_argument('--mode', default='eval', type=str, choices=['train', 'eval'])
    return parser


def main(argv=None):
    """Parse the command line and hand over to the selected algorithm module.

    The module named by ``--algo`` is imported dynamically and its
    ``main(mode, save_ckpt_path)`` function is called.

    Args:
        argv: Optional list of argument strings. ``None`` makes argparse read
            ``sys.argv[1:]``, which is what the script entry point uses.

    Raises:
        ImportError: if the module named by ``--algo`` cannot be imported.
    """
    opt = build_parser().parse_args(argv)

    # NOTE(review): the MindSpore context setup below is disabled, so
    # --device, --device_id and --pynative_mode/--graph_mode are parsed but not
    # consumed in this script; --seed is not applied here either. Presumably
    # the algorithm module configures these itself -- confirm before enabling.
    # if opt.pynative_mode:
    #     context.set_context(mode=context.PYNATIVE_MODE, device_target=opt.device, device_id=opt.device_id)
    #     print(f'[PYNATIVE] Start running algorithm {opt.algo} in {opt.mode} mode.')
    # else:
    #     context.set_context(mode=context.GRAPH_MODE, device_target=opt.device, device_id=opt.device_id)
    #     print(f'[GRAPH] Start running algorithm {opt.algo} in {opt.mode} mode.')

    algo = importlib.import_module(opt.algo)
    algo.main(opt.mode, opt.save_ckpt_path)


if __name__ == '__main__':
    main()
|