|
- # Copyright 2022 Huawei Technologies Co., Ltd
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- # ============================================================================
- """eval launch."""
-
- import os
- import sys
- import argparse
- import numpy as np
-
- import mindspore
- from mindspore import context
-
- from src.deepmar_dataset import create_dataset
- from src.deepmar import Deep_Mar
-
-
def extract_feat(model, test_dataset, test_len):
    """Run ``model`` over every batch of ``test_dataset`` and stack the outputs.

    Args:
        model: callable network; ``model(images)`` returns a tensor exposing
            ``.asnumpy()``.
        test_dataset: dataset exposing ``create_dict_iterator()`` that yields
            dicts with a ``'data'`` entry (the image batch).
        test_len: total number of samples N across all batches; rows of the
            result matrix are filled in iteration order.

    Returns:
        ``np.ndarray`` of shape ``(test_len, feat_dim)`` holding one flattened
        feature row per sample. For an empty dataset the result is a
        ``(test_len, 0)`` matrix (the original code raised UnboundLocalError).
    """
    feat = None  # allocated lazily: feature width is known only after batch 0
    start = 0

    for batch in test_dataset.create_dict_iterator():
        feat_tmp = model(batch['data']).asnumpy()
        batch_size = feat_tmp.shape[0]
        if feat is None:
            feat = np.zeros((test_len, feat_tmp.size // batch_size))
        feat[start:start + batch_size, :] = feat_tmp.reshape((batch_size, -1))
        start += batch_size

    if feat is None:
        # Dataset produced no batches: return an empty feature matrix instead
        # of crashing on an unbound local.
        feat = np.zeros((test_len, 0))
    return feat
-
-
def attribute_evaluate(model, test_dataset, test_label, test_len):
    """Score ``model`` on the test set with label- and instance-based metrics.

    Features are extracted for every sample, binarized at a 0.5 threshold and
    compared against ``test_label`` via ``attribute_evaluate_lidw``.
    """
    predictions = extract_feat(model, test_dataset, test_len)

    # Stack the ground-truth labels into a matrix shaped like the predictions.
    groundtruth = np.zeros(predictions.shape)
    for row, label in enumerate(test_label):
        groundtruth[row, :] = label

    # Binarize the raw scores: >= 0.5 becomes 1, everything else 0.
    predictions = np.where(predictions >= 0.5, 1.0, 0.0)

    return attribute_evaluate_lidw(groundtruth, predictions)
-
-
def attribute_evaluate_lidw(gt_result, pt_result):
    """Compute label-based and instance-based attribute recognition metrics.

    Args:
        gt_result: (num_samples, num_attributes) 0/1 ground-truth matrix.
        pt_result: (num_samples, num_attributes) 0/1 prediction matrix.

    Returns:
        dict with per-attribute accuracies ('label_pos_acc', 'label_neg_acc',
        'label_acc', 'avg_acc'), their mean ('all_acac'), and instance-level
        'instance_acc' / 'instance_precision' / 'instance_recall' /
        'instance_F1'.

    Raises:
        Exception: when the two matrices disagree in shape.
    """
    if gt_result.shape != pt_result.shape:
        raise Exception('Shape between groundtruth and predicted results are different')

    gt_is_pos = (gt_result == 1).astype(float)
    gt_is_neg = (gt_result == 0).astype(float)
    pt_is_pos = (pt_result == 1).astype(float)
    pt_is_neg = (pt_result == 0).astype(float)

    # ---- label-based metrics: one score per attribute column ----
    pos_total = gt_is_pos.sum(axis=0)
    neg_total = gt_is_neg.sum(axis=0)
    pos_hit = (gt_is_pos * pt_is_pos).sum(axis=0)
    neg_hit = (gt_is_neg * pt_is_neg).sum(axis=0)

    label_pos_acc = 1.0 * pos_hit / pos_total
    label_neg_acc = 1.0 * neg_hit / neg_total
    result = {
        'label_pos_acc': label_pos_acc,
        'label_neg_acc': label_neg_acc,
        'label_acc': (label_pos_acc + label_neg_acc) / 2,
        'avg_acc': (neg_hit + pos_hit) / (neg_total + pos_total),
    }
    result['all_acac'] = np.sum(result['avg_acc']) / len(result['avg_acc'])

    # ---- instance-based metrics: one score per sample row ----
    gt_pos = gt_is_pos.sum(axis=1)
    pt_pos = pt_is_pos.sum(axis=1)
    intersect_pos = (gt_is_pos * pt_is_pos).sum(axis=1)
    union_pos = ((gt_result == 1) | (pt_result == 1)).astype(float).sum(axis=1)

    # Samples with no positive ground-truth attribute are excluded from the
    # effective count; their denominators are patched to 1 so they add zero
    # to every ratio. A zero predicted-positive count is likewise patched to
    # avoid division by zero in the precision term.
    cnt_eff = float(gt_result.shape[0])
    for row in range(len(gt_pos)):
        if gt_pos[row] == 0:
            union_pos[row] = 1
            pt_pos[row] = 1
            gt_pos[row] = 1
            cnt_eff -= 1
        elif pt_pos[row] == 0:
            pt_pos[row] = 1

    instance_precision = np.sum(intersect_pos / pt_pos) / cnt_eff
    instance_recall = np.sum(intersect_pos / gt_pos) / cnt_eff
    result['instance_acc'] = np.sum(intersect_pos / union_pos) / cnt_eff
    result['instance_precision'] = instance_precision
    result['instance_recall'] = instance_recall
    result['instance_F1'] = (2 * instance_precision * instance_recall
                             / (instance_precision + instance_recall))

    return result
-
-
def main(input_args):
    """Evaluation entry point.

    Optionally syncs the dataset from OBS (PengCheng cloud), normalizes the
    checkpoint path relative to the script directory, builds the test
    dataset, restores the checkpoint into the network and prints the mean
    per-label accuracy.
    """
    if input_args.enable_pengcheng_cloud:
        # Running on PengCheng cloud brain: pull the dataset from OBS first.
        import moxing as mox
        data_dir = input_args.workroot + '/data'
        if not os.path.exists(data_dir):
            os.mkdir(data_dir)

        obs_data_url = input_args.data_url
        mox.file.copy_parallel(obs_data_url, data_dir)
        print("Successfully Download {} to {}".format(obs_data_url, data_dir))

        input_args.image_path = data_dir + '/images'
        input_args.peta_dataset_mat_dir = data_dir + '/PETA.mat'

    # Strip a leading "./" or "/" so the checkpoint path can be anchored at
    # the script directory.
    ckpt = input_args.ckpt_path
    if ckpt[0] == "." and ckpt[1] == "/":
        ckpt = ckpt[2:]
    elif ckpt[0] == "/":
        ckpt = ckpt[1:]
    input_args.ckpt_path = os.path.join(sys.path[0], ckpt)

    context.set_context(mode=context.GRAPH_MODE, device_target=input_args.device_target)
    test_dataset, test_label, test_len = create_dataset(input_args, rank_size=None, rank_id=None)

    feature_net = Deep_Mar()
    params = mindspore.train.serialization.load_checkpoint(input_args.ckpt_path)
    mindspore.train.serialization.load_param_into_net(feature_net, parameter_dict=params)

    metrics = attribute_evaluate(feature_net, test_dataset, test_label, test_len)
    mean_label_acc = np.sum(metrics['label_acc']) / len(metrics['label_acc'])
    print("-" * 50)
    print('Test result , Label_acc_avg: {0:}'.format(mean_label_acc))
-
-
if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='reidentification')

    # dataset option
    parser.add_argument('--split', type=str, default='test', choices=['trainval', 'train', "test"],
                        help="Select dataset only or training set and validation set")
    parser.add_argument('--partition_idx', type=int, default=0,
                        help="Dataset split sequence number")
    # NOTE: the original used type=tuple, which turns a CLI value such as
    # "224,224" into a tuple of single characters. Parse "H,W" into a tuple
    # of ints instead; the default is unchanged.
    parser.add_argument('--image_resize',
                        type=lambda s: tuple(int(v) for v in s.split(',')),
                        default=(224, 224),
                        help="Data set picture specified size")

    # eval option
    parser.add_argument('--batch_size', type=int, default=256)
    parser.add_argument('--num_work', type=int, default=2)

    # device option
    parser.add_argument('--device_target', type=str, default="Ascend")
    parser.add_argument('--train_mode', type=str, default='test', choices=['test', 'train'],
                        help="the mode of loading a dataset")

    # url option
    parser.add_argument('--peta_dataset_mat_dir', type=str, default='./data/PETA.mat',
                        help="The absolute address of PETA.mat ")
    parser.add_argument('--image_path', type=str, default='./data/images',
                        help="The absolute address of image folder")
    parser.add_argument('--ckpt_path', type=str, default='./checkpoint/Deepmar_Distributed_Epoch_400_.ckpt',
                        help="The absolute address of ckpt")

    # PengCheng cloud brain option (0/1 flag)
    parser.add_argument('--enable_pengcheng_cloud', type=int, default=0,
                        help="Whether it runs on Pengcheng cloud brain")
    parser.add_argument('--workroot', type=str, default='/home/work/user-job-dir',
                        help="Cloud brain working environment for training tasks")
    parser.add_argument('--train_url', type=str, default=' ',
                        help="Training task result saving address")
    parser.add_argument('--data_url', type=str, default=' ',
                        help="Dataset address of training task")

    args = parser.parse_args()
    main(args)
|