|
- # Copyright 2021 Huawei Technologies Co., Ltd
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- # ============================================================================
-
- import mindspore as ms
- import mindspore.nn as nn
- import mindspore.ops as ops
- from datasets.dataset_xcit import create_dataset
- import time
- import logging
- import os
- from collections import OrderedDict
- from create_model import create_model
- import argparse
-
# Command-line interface for the validation run.
parser = argparse.ArgumentParser(description='XBM Validation')
parser.add_argument('--model_name', default='xcit_large_24_p8_224', type=str,
                    help='model name')
# Parse booleans explicitly: without this, any CLI string (even "False")
# would be truthy. The default (True) is untouched by `type`.
parser.add_argument('--pretrained', default=True,
                    type=lambda s: str(s).lower() in ('true', '1', 'yes'),
                    help='use pre-trained model ')
parser.add_argument('--dataset_val_path', default='', type=str,
                    help='ImageNet Val Path')
# Two ints (height width); previously a CLI value would arrive as a single
# string and indexing [0]/[1] downstream would yield characters.
parser.add_argument('--image_size', default=(224, 224), nargs=2, type=int,
                    help='the size of image')
# type=int is required: downstream code does integer division and
# arithmetic with batch_size.
parser.add_argument('--batch_size', default=16, type=int,
                    help='batch size')
-
def accuracy(output, target, topk=(1,)):
    """Compute top-k accuracy percentages over a batch of predictions.

    Args:
        output: logits tensor of shape (batch, num_classes).
        target: integer class-label tensor of shape (batch,).
        topk: tuple of k values to evaluate (each clamped to num_classes).

    Returns:
        List of scalar tensors, one per k in `topk`, each the percentage
        (0-100) of samples whose true label is among the top-k predictions.
    """
    # Clamp k so TopK never asks for more classes than exist.
    maxk = min(max(topk), output.shape[1])
    batch_size = target.shape[0]
    _, pred = ops.TopK()(output, maxk)
    # Transpose to (maxk, batch) so row i holds every sample's i-th guess;
    # correct[:k] then selects the top-k rows for each requested k.
    pred = pred.T
    correct = ops.Equal()(pred, ops.Reshape()(target, (1, -1)).expand_as(pred))
    # Cast bool -> float32 before summing; ReduceSum yields a scalar hit count.
    return [ops.ReduceSum()(ops.Cast()(ops.Reshape()(correct[:min(k, maxk)], (-1,)), ms.float32))
            * 100. / batch_size for k in topk]
-
class AverageMeter:
    """Running-statistics tracker: latest value, sum, count, and mean."""

    def __init__(self):
        self.reset()

    def reset(self):
        """Zero out every tracked statistic."""
        self.val, self.avg = 0, 0
        self.sum, self.count = 0, 0

    def update(self, val, n=1):
        """Record `val` observed `n` times and refresh the running mean."""
        self.count += n
        self.sum += n * val
        self.val = val
        self.avg = self.sum / self.count
-
if __name__ == '__main__':
    # Single-device ImageNet validation loop: build the model and eval
    # dataset, run one pass, and log running loss / top-1 / top-5 accuracy.
    args = parser.parse_args()
    model_name = args.model_name
    pretrained = args.pretrained
    dataset_path = args.dataset_val_path
    image_height = args.image_size[0]
    image_width = args.image_size[1]
    batch_size = args.batch_size

    # DEVICE_ID selects the Ascend chip; launch scripts export it, default 0.
    device_id = int(os.getenv('DEVICE_ID', '0'))
    ms.context.set_context(mode=ms.context.GRAPH_MODE, device_target='Ascend', device_id=device_id)

    logging.basicConfig(level=logging.DEBUG)
    batch_time = AverageMeter()
    losses = AverageMeter()
    top1 = AverageMeter()
    top5 = AverageMeter()

    criterion = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean')
    net = create_model(model_name=model_name, pretrained=pretrained)

    dataset = create_dataset(dataset_path=dataset_path,
                             do_train=False,
                             image_height=image_height,
                             image_width=image_width,
                             batch_size=batch_size,
                             run_distribute=False)
    # Total batch count from the dataset itself, instead of hard-coding the
    # 50000-image ImageNet validation split.
    num_batches = dataset.get_dataset_size()

    batch_idx = 0
    end = time.time()
    ds_iter = dataset.create_dict_iterator(output_numpy=False, num_epochs=1)
    for item in ds_iter:
        images = item['image']  # was `input`, which shadowed the builtin
        target = item['label']
        output = net(images)
        loss = criterion(output, target)
        acc1, acc5 = accuracy(output, target, topk=(1, 5))

        # Weight every meter by the actual batch size (last batch may be short).
        losses.update(loss.asnumpy().item(), images.shape[0])
        top1.update(acc1.asnumpy().item(), images.shape[0])
        top5.update(acc5.asnumpy().item(), images.shape[0])
        batch_time.update(time.time() - end)
        end = time.time()

        if batch_idx % 10 == 0:
            logging.info(
                'Test: [{0:>4d}/{1}] '
                'Time: {batch_time.val:.3f}s ({batch_time.avg:.3f}s, {rate_avg:>7.2f}/s) '
                'Loss: {loss.val:>7.4f} ({loss.avg:>6.4f}) '
                'Acc@1: {top1.val:>7.3f} ({top1.avg:>7.3f}) '
                'Acc@5: {top5.val:>7.3f} ({top5.avg:>7.3f})'.format(
                    batch_idx,
                    num_batches,
                    batch_time=batch_time,
                    rate_avg=images.shape[0] / batch_time.avg,
                    loss=losses,
                    top1=top1,
                    top5=top5
                )
            )
        batch_idx += 1

    top1a, top5a = top1.avg, top5.avg
    # Report the evaluated model by its configured name rather than a
    # hard-coded string, so results stay correct when --model_name differs.
    results = OrderedDict(model=model_name,
                          top1=round(top1a, 4), top1_err=round(100 - top1a, 4),
                          top5=round(top5a, 4), top5_err=round(100 - top5a, 4),
                          )

    logging.info(' * Acc@1 {:.3f} ({:.3f}) Acc@5 {:.3f} ({:.3f})'.format(
        results['top1'], results['top1_err'], results['top5'], results['top5_err']))
-
-
|