|
- import os
- import argparse
- import sys
- import moxing as mox
- import mindspore.nn as nn
- from mindspore import context
- from mindspore.train.serialization import load_checkpoint, load_param_into_net
- from mindspore.train import Model
- from mindspore.nn.metrics import Accuracy
- from mindspore import Tensor
-
- import numpy as np
- from data.eval_dataset import configdataset
- from tools.exactor_global_feature import get_X_and_Q
- from tools.evaluate_with_global import global_search, reportMAP
- from model.delg_model import delg_model
-
# Dataset-info directory next to the entry script; passed to configdataset()
# when the evaluation runs. os.path.join is used (instead of string
# concatenation) so an empty sys.path[0] -- e.g. an interactive session or
# ``python -c`` -- yields the relative "data/" rather than the root "/data/".
INFO_PATH = os.path.join(sys.path[0], "data", "")
# Local checkpoint directory next to the entry script.
# NOTE(review): not read by the live code visible in this file -- confirm
# whether it is still needed.
CKPT_PATH = os.path.join(sys.path[0], "ckpt", "")
# Number of classifier classes.
# NOTE(review): not read by the live code visible in this file -- presumably
# the size of the training classification head; confirm against the model.
NUM_CLASSES = 81312
-
def _build_parser():
    """Return the command-line parser for the DELG evaluation script.

    Options (in registration order):
      --device_target  backend to run on: Ascend, GPU or CPU.
      --data_url       dataset location.
      --ckpt_url       checkpoint location.
      --result_url     result folder.
      --dataset        evaluation dataset: roxford5k or rparis6k.
      --eval_model     which model to evaluate: Global or Local.
      --epoch_size     integer epoch count.

    NOTE(review): --eval_model and --epoch_size are not read by the
    evaluation code visible in this file.
    """
    ap = argparse.ArgumentParser(description='DELG Mindspore Version')

    ap.add_argument(
        '--device_target', type=str, default="Ascend",
        choices=['Ascend', 'GPU', 'CPU'],
        help='device where the code will be implemented (default: Ascend)')
    ap.add_argument(
        '--data_url', type=str, default="./Data",
        help='path where the dataset is saved')
    ap.add_argument(
        '--ckpt_url', default='./ckpt_url',
        help='model to save/load')
    ap.add_argument(
        '--result_url', default='./result',
        help='result folder to save/load')

    # TODO: args.dataset added to select the evaluation dataset.
    ap.add_argument(
        '--dataset', default='roxford5k',
        choices=['roxford5k', 'rparis6k'],
        help='select evaluate dataset.')
    # TODO: args.eval_model added to choose whether the global or the local
    # model is evaluated.
    ap.add_argument(
        '--eval_model', default='Global',
        choices=['Global', 'Local'],
        help='select evaluate model.')
    ap.add_argument(
        '--epoch_size', type=int, default=21,
        help='Training epochs.')

    return ap


parser = _build_parser()
-
-
# Local working directory of the platform job container; OBS inputs are
# copied here and results are staged here before upload.
_JOB_DIR = '/home/work/user-job-dir'


def _ensure_dir(path):
    """Create directory *path* (and any missing parents) if it does not exist."""
    # makedirs(exist_ok=True) avoids the exists()/mkdir() race and also works
    # when the parent directory is missing.
    os.makedirs(path, exist_ok=True)


def _download_from_obs(copy_fn, src, dst):
    """Copy *src* to *dst* with *copy_fn*, aborting the run on failure.

    Args:
        copy_fn: callable ``copy_fn(src, dst)``, e.g. ``mox.file.copy`` or
            ``mox.file.copy_parallel``.
        src: source location (OBS URL).
        dst: local destination path.

    Raises:
        Whatever *copy_fn* raised. Evaluation cannot proceed without its
        inputs, so the failure is reported and re-raised rather than swallowed
        (which previously surfaced later as a confusing missing-file error).
    """
    try:
        copy_fn(src, dst)
    except Exception as e:
        print('moxing download {} to {} failed: '.format(src, dst) + str(e))
        raise
    print("Successfully Download {} to {}".format(src, dst))


def main():
    """Download inputs from OBS, evaluate the DELG model, upload results."""
    args = parser.parse_args()

    # Copy the dataset from OBS into the job container.
    obs_data_url = args.data_url
    args.data_url = _JOB_DIR + '/data/'
    _ensure_dir(args.data_url)
    _download_from_obs(mox.file.copy_parallel, obs_data_url, args.data_url)

    # Copy the model checkpoint file from OBS into the job container.
    obs_ckpt_url = args.ckpt_url
    args.ckpt_url = _JOB_DIR + '/checkpoint.ckpt'
    _download_from_obs(mox.file.copy, obs_ckpt_url, args.ckpt_url)

    # Local output directory; its contents are uploaded to result_url at the end.
    obs_result_url = args.result_url
    args.result_url = _JOB_DIR + '/result/'
    _ensure_dir(args.result_url)

    # NOTE(review): presumably read by the data/tools helpers below via args.
    args.dataset_path = args.data_url
    args.save_checkpoint_path = args.ckpt_url

    context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target)

    # Build the DELG model and load the checkpoint weights into it.
    # NOTE(review): args.eval_model is not consulted; only the global-descriptor
    # path (global_search) is evaluated below.
    model = delg_model()
    model.update_parameters_name('net.')
    args.load_ckpt_url = args.save_checkpoint_path
    param_dict = load_checkpoint(args.load_ckpt_url)
    load_param_into_net(model, param_dict)

    # Run the evaluation.
    # TODO: check the called modules.
    pkl = configdataset(args, INFO_PATH)
    X, Q = get_X_and_Q(args, pkl, model)
    ranks = global_search(X, Q)
    reportMAP(args, pkl, ranks)

    # Copy results from the local environment back to OBS, where the
    # corresponding inference task on the OpenI platform offers them for
    # download. Kept best-effort: a failed upload is reported, not raised.
    try:
        mox.file.copy_parallel(args.result_url, obs_result_url)
        print("Successfully Upload {} to {}".format(args.result_url, obs_result_url))
    except Exception as e:
        print('moxing upload {} to {} failed: '.format(args.result_url, obs_result_url) + str(e))


if __name__ == '__main__':
    main()
-
-
-
-
|