|
- # Copyright 2021 Huawei Technologies Co., Ltd
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- # ============================================================================
- """eval net"""
- import os
- import argparse
- import mindspore
- from mindspore import context
- from mindspore import Tensor
- from mindspore.nn import Adam
- from mindspore.train.model import Model
- from mindspore.context import ParallelMode
- from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, LossMonitor, TimeMonitor
- from mindspore.train.serialization import load_checkpoint, load_param_into_net
- from mindspore.communication.management import init, get_rank, get_group_size
- from mindspore.common import set_seed
- from mindspore.train.summary.summary_record import SummaryRecord
- import mindspore.nn as nn
- from mindspore import dtype as mstype
- import itertools
- import mindspore as ms
- import numpy as np
- import math
- from os import path as osp
- from loguru import logger
- from src.gaitset import SetNet
- from src.dataset import get_one_set
-
-
# Command-line interface. ``--config`` is accepted and stored on ``args_opt``;
# NOTE(review): nothing visible in this file reads ``args_opt`` -- the actual
# configuration comes from ``src.config.cfg`` in the ``__main__`` section.
parser = argparse.ArgumentParser(description='Gait recognition')
parser.add_argument("--config", default=None)
args_opt = parser.parse_args()

# Fix MindSpore's global random seed to 1 for this run.
set_seed(seed=1)
-
def calc_dist(x, y):
    """Return the Euclidean (L2) distance between ``x`` and ``y``.

    The squared element-wise difference is summed over *all* elements (no
    axis is given to ``np.sum``), so multi-dimensional inputs are treated as
    flattened vectors.

    Args:
        x: array-like supporting ``-`` with ``y`` (e.g. a numpy array).
        y: array-like with a shape broadcast-compatible with ``x``.

    Returns:
        The non-negative distance as a numpy scalar; ``inf`` if any
        difference is infinite.
    """
    delta = x - y
    squared_norm = np.sum(delta ** 2)
    return np.sqrt(squared_norm)
-
def find_in_gallery(query_features, gallery):
    """Locate the gallery entry closest to ``query_features``.

    Every entry ``gallery[s][c][v]`` is compared with ``query_features``
    using ``calc_dist`` (Euclidean distance), and the index of the smallest
    distance is returned.

    The number of subjects, conditions and views used to be hard-coded as
    50 x 4 x 11. That silently ignored extra entries, or raised IndexError for
    smaller galleries. These sizes are now taken from the first three axes of
    ``gallery``.

    Args:
        query_features: feature array of one probe sequence; must be
            broadcast-compatible with each gallery entry.
        gallery: array (or nested sequence convertible by ``np.asarray``)
            whose first three axes are indexed (subject, condition, view).

    Returns:
        Tuple ``(subject_idx, condition_idx, view_idx)`` of the nearest entry.
        Ties resolve to the first entry in row-major order, as with
        ``np.argmin``.
    """
    gallery = np.asarray(gallery)  # no copy when already an ndarray
    dist = np.zeros(gallery.shape[:3])

    # np.ndindex walks (s, c, v) in the same row-major order as the nested loops.
    for idx in np.ndindex(*dist.shape):
        dist[idx] = calc_dist(gallery[idx], query_features)

    return np.unravel_index(np.argmin(dist, axis=None), dist.shape)
-
def _extract_features(net, path_to_silhs):
    """Run ``net`` on the silhouette set stored under ``path_to_silhs``.

    Args:
        net: the network being evaluated (a ``SetNet`` in this script).
        path_to_silhs: directory handed to ``get_one_set``.

    Returns:
        The network output converted with ``.asnumpy()``.

    NOTE(review): an earlier inline comment described the input as
    [1 x <silhs_num> x 1 x 64 x 64] and the output as [1 x 62*128], while the
    gallery reserves 62*256 values per entry -- confirm the real sizes.
    """
    input_set = mindspore.Tensor(get_one_set(path_to_silhs), mstype.float32)
    return net(input_set).asnumpy()


if __name__ == '__main__':
    from src.config import cfg

    # init context. The device id used to be hard-coded to 4; a ``device_id``
    # attribute on ``cfg`` now overrides it, and 4 remains the default.
    context.set_context(mode=context.GRAPH_MODE,
                        device_target=cfg.device_target,
                        save_graphs=False,
                        device_id=getattr(cfg, 'device_id', 4))

    # define net
    net = SetNet()

    # load checkpoint
    print('--------------------')
    print(cfg.eval_checkpoint)
    print('--------------------')
    param_dict = load_checkpoint(cfg.eval_checkpoint)
    load_param_into_net(net, param_dict)
    net.set_train(False)

    subjects = list(range(75, 125))
    gallery_conditions = ['nm-01', 'nm-02', 'nm-03', 'nm-04']
    probes = [['nm-05', 'nm-06'], ['bg-01', 'bg-02'], ['cl-01', 'cl-02']]
    views = ['000', '018', '036', '054', '072', '090', '108', '126', '144', '162', '180']

    # Fill the gallery: gallery[s][c][v] holds the features of subjects[s],
    # gallery_conditions[c], views[v]. Unreadable sequences are set to +inf so
    # they can never be selected as the nearest neighbour.
    gallery = np.zeros((len(subjects), len(gallery_conditions), len(views), 1, 62 * 256))
    for s, subject in enumerate(subjects):
        for c, condition in enumerate(gallery_conditions):
            for v, view in enumerate(views):
                path_to_silhs = osp.join(cfg.dataset_path,
                                         str(subject).zfill(3),
                                         condition,
                                         view)
                # Only loading + inference is guarded. The former bare
                # ``except:`` also swallowed KeyboardInterrupt, and it covered
                # the assignment below, so a feature-size mismatch silently
                # turned every gallery entry into inf.
                try:
                    value = _extract_features(net, path_to_silhs)
                except Exception:
                    value = math.inf
                    logger.warning(f"WARNING! {path_to_silhs} - SKIPPED")
                gallery[s][c][v] = value

        logger.info(f"Put features for subject {subject} in gallery")

    # Evaluate each probe group. Every probe sequence is matched against the
    # whole gallery (all conditions and views, identical views included). A
    # match counts as correct when the nearest entry belongs to the same subject.
    res_including = []
    for probe in probes:
        true_count_in = 0
        all_count_in = 0
        logger.info(f"Start calculating metric for probe [{probe}]")
        for s, subject in enumerate(subjects):
            for condition in probe:
                for view in views:
                    path_to_silhs = osp.join(cfg.dataset_path,
                                             str(subject).zfill(3),
                                             condition,
                                             view)
                    # Missing probe sequences used to abort the whole run;
                    # they are now skipped just like missing gallery ones.
                    try:
                        query_features = _extract_features(net, path_to_silhs)
                    except Exception:
                        logger.warning(f"WARNING! {path_to_silhs} - SKIPPED")
                        continue
                    index = find_in_gallery(query_features, gallery)
                    if index[0] == s:
                        true_count_in += 1
                    all_count_in += 1

        # Avoid ZeroDivisionError when no sequence of this probe was readable.
        top_1_in = true_count_in / all_count_in if all_count_in else math.nan
        res_including.append(top_1_in)
        logger.info(f"End calculating metric for probe [{probe}]. Top-1 accuracy is {top_1_in}")

    logger.info(f"Mean top 1 ranks for {len(probes)} probes (including): {res_including}")
|