|
- # Copyright 2022 Huawei Technologies Co., Ltd
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- # ============================================================================
- """ M2Det evaluation """
-
- import os
- import argparse
- import numpy as np
- from tqdm import tqdm
-
- from src.priors import PriorBox
- from src.dataset import get_dataset, BaseTransform
- from src.utils import Timer
- from src import config as cfg
-
# Command-line interface: a single option choosing where the .bin files are written.
# NOTE(review): arguments are parsed at import time, so merely importing this
# module consumes sys.argv; consider moving parse_args() into main().
parser = argparse.ArgumentParser()
parser.add_argument(
    '--output_path',
    type=str,
    default='./result_dir',
    help="save dir",
)
args = parser.parse_args()
-
def test_net(testset, transform, save_folder):
    """Preprocess every image in ``testset`` and dump it as a raw ``.bin`` file.

    Despite its name, no network is run here. For each image the function
    applies ``transform``, prepends a batch axis with ``np.expand_dims`` and
    writes the array with ``ndarray.tofile`` to
    ``<save_folder>/<image stem>.bin``.

    Args:
        testset: object supporting ``len()`` and ``pull_image(i)``, which must
            return ``(img, img_id)``. ``img_id`` is treated as a '/'-separated
            path whose last component is the image file name.
        transform: callable mapping ``img`` to an array-like.
        save_folder: output directory; created (including parents) if missing.
    """
    # makedirs(exist_ok=True) also creates missing parent directories and avoids
    # the race between a separate exists() check and mkdir().
    os.makedirs(save_folder, exist_ok=True)

    num_images = len(testset)
    print('=> Total {} images to test.'.format(num_images))
    print('Begin to evaluate')
    for i in tqdm(range(num_images)):
        img, img_id = testset.pull_image(i)
        # Preprocessing only; presumably the .bin files are consumed by a
        # separate inference step — confirm against the surrounding tooling.
        images = np.expand_dims(transform(img), 0)
        base_name = str(img_id).split('/')[-1]
        # splitext strips only the final extension, so "a.1.jpg" and "a.2.jpg"
        # map to distinct files instead of both collapsing to "a.bin".
        stem = os.path.splitext(base_name)[0]
        images.tofile(os.path.join(save_folder, "{}.bin".format(stem)))
-
def main():
    """Script entry point: build the COCO eval split and dump preprocessed images.

    The output directory comes from the module-level ``args`` (``--output_path``).
    """
    # The priors produced by PriorBox are handed to the dataset builder.
    priors = PriorBox(cfg).forward()
    # get_dataset yields a pair; only the second element is needed, and
    # test_net consumes it via len() and pull_image().
    _unused, eval_set = get_dataset(
        cfg=cfg,
        dataset='COCO',
        priors=priors,
        setname='eval_sets',
    )
    preprocess = BaseTransform(
        cfg.model['input_size'],
        cfg.model['rgb_means'],
        (2, 0, 1),  # presumably an HWC -> CHW axis order; confirm in BaseTransform
    )
    test_net(eval_set, transform=preprocess, save_folder=args.output_path)


if __name__ == "__main__":
    main()
|