|
- """
- The script for testing snn4ecoset.
- Author: Yu Liutao @ PCL, 2022.11
- """
-
- # ---------------------------------------------------
- # Imports
- # ---------------------------------------------------
- from __future__ import print_function
- import argparse
- import torch
- from torchvision import datasets, transforms
- from torch.utils.data.dataloader import DataLoader
- import datetime
- import sys
- import os
-
-
- # ---------------------------------------------------
- # Utils
- # ---------------------------------------------------
class AverageMeter(object):
    """Tracks the most recent value and the running average of a metric.

    ``val`` holds the last reported value, ``sum``/``count`` accumulate
    weighted totals, and ``avg`` is ``sum / count``.
    """

    def __init__(self, name, fmt=':f'):
        # Display name and format spec used by __str__ (e.g. ':f', ':.4f').
        self.name = name
        self.fmt = fmt
        self.reset()

    def reset(self):
        # Clear all statistics back to the initial state.
        self.val, self.avg, self.sum, self.count = 0, 0, 0, 0

    def update(self, val, n=1):
        # Record a value observed over n samples (weighted running mean).
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count

    def __str__(self):
        template = '{name} {val%s} ({avg%s})' % (self.fmt, self.fmt)
        return template.format(**self.__dict__)
-
-
def accuracy(output, target, topk=(1,)):
    """Compute the top-k accuracy (in percent) for each k in ``topk``.

    ``output`` is a (batch, classes) score tensor; ``target`` holds the
    ground-truth class index per sample. Returns one 1-element tensor per k.
    """
    with torch.no_grad():
        k_max = max(topk)
        n_samples = target.size(0)

        # Indices of the k_max highest-scoring classes, transposed to
        # shape (k_max, batch) so row slicing gives "within top-k" hits.
        _, top_idx = output.topk(k_max, 1, True, True)
        top_idx = top_idx.t()
        hits = top_idx.eq(target.view(1, -1).expand_as(top_idx))

        return [
            hits[:k].reshape(-1).float().sum(0, keepdim=True).mul_(100.0 / n_samples)
            for k in topk
        ]
-
-
def find_file(directory, fileName):
    """Recursively search *directory* for a file named *fileName*.

    Args:
        directory: the folder to be searched, e.g. /dataset
        fileName: the name of the model file, e.g. model.pth

    Returns:
        The path of the first match found by ``os.walk``, or '' when the
        file does not exist anywhere under *directory*.
    """
    for root, _dirs, files in os.walk(directory):
        # Membership test replaces the original per-name loop + `flag`
        # sentinel: we return on the first directory containing the file.
        if fileName in files:
            filepath = os.path.join(root, fileName)
            print('\n Model path: {} \n'.format(filepath))
            return filepath

    # Not found anywhere: caller falls back to the pretrained model
    # configured in utils.py.
    print('\n {} not found in {}, use pretrained model defined in utils.py \n'.format(fileName, directory))
    return ''
-
-
if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='snn4ecoset')
    # NOTE(review): argparse `type=bool` is a known pitfall -- any non-empty
    # string (including "False") parses as True. Kept as-is for CLI
    # compatibility; confirm before relying on `--log False`.
    parser.add_argument('--log', default=True, type=bool, help='to print the output to log file')
    parser.add_argument('--datapath', default='/dataset/', type=str, help='dataset path')
    parser.add_argument('--modelname', default='sew_resnet18', type=str, help='SNN model name')
    parser.add_argument('--modelpath', default='/code/pretrained_model/sew_resnet18.pth', type=str, help='the directory that contains your pretrained snn model')
    parser.add_argument('--modeldescription', default='', type=str, help='one sentence less than 200 characters to describe your model briefly')

    args = parser.parse_args()

    # Fail fast on missing settings (note: the default modeldescription is
    # empty, so the user is forced to supply one on the command line).
    assert args.modelname != '', 'model name (modelname) not provided'
    assert args.modelpath != '', 'model path (modelpath) not provided'
    assert args.modeldescription != '', 'model description (modeldescription) not provided'
    assert args.datapath != '', 'data path (datapath) not provided'

    # Pin the run to GPU 0; new tensors default to CUDA floats when available.
    os.environ['CUDA_VISIBLE_DEVICES'] = '0'
    if torch.cuda.is_available():
        torch.set_default_tensor_type('torch.cuda.FloatTensor')

    batch_size = 16  # try 32 when using small models

    # ---------------------------------------------------
    # log_file
    # ---------------------------------------------------
    log_file = '/code/logs/'
    try:
        os.makedirs(log_file)  # recursive directory creation
    except OSError:
        # Directory already exists (or is not creatable); best effort only.
        pass
    identifier = 'snn_' + args.modelname.lower() + '.log'
    log_file += identifier

    if args.log:
        # Line-buffered so progress lines survive an interrupted run.
        f = open(log_file, 'w', buffering=1)
    else:
        f = sys.stdout

    # ---------------------------------------------------
    # Prepare the data
    # ---------------------------------------------------
    testdir = os.path.join(args.datapath, 'val')

    # Standard ImageNet-style evaluation preprocessing and normalization.
    testset = datasets.ImageFolder(
        testdir,
        transforms.Compose([
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ]))

    # NOTE(review): the CUDA generator matches the CUDA default tensor type
    # set above, but this line fails on CPU-only machines -- confirm the
    # script is only ever run with a GPU present.
    test_loader = DataLoader(testset, batch_size=batch_size, shuffle=True, num_workers=0, generator=torch.Generator(device='cuda'))
    f.write('\n\n Make Ecoset TEST DataLoader finished! Time: {} \n'.format(datetime.datetime.now()))
    print('\n\n Make Ecoset TEST DataLoader finished! Time: {} \n'.format(datetime.datetime.now()))

    # The pretrained model file is taken directly from --modelpath.
    abs_modelpath = args.modelpath

    # ----------------------------------------------------------------
    # Load pretrained model and import related functions/packages
    # through a USER-DEFINED .py file (/code/utils.py)
    # ----------------------------------------------------------------
    sys.path.append("/code/")
    import utils

    # load_MyModel is expected to return (model, resolved_path,
    # *extra_forward_args) -- see the len(model_pars) branches below.
    model_pars = utils.load_MyModel(abs_modelpath)
    model = model_pars[0]
    abs_modelpath = model_pars[1]
    if len(model_pars) > 2:
        # Extra arguments the forward pass needs besides the input batch.
        forward_pars = model_pars[2:]

    f.write("\n\n Load model '{}' successfully!".format(abs_modelpath))
    f.write("\n\n {} \n".format(model))

    # Model size on disk, reported in MB.
    modelsize = os.path.getsize(abs_modelpath)  # bytes
    modelsize = float(modelsize / 1024 / 1024)  # MB
    print('\n Model name: {}'.format(args.modelname))
    print('\n Model size: {:0.1f} MB'.format(modelsize))

    # Record run metadata at the top of the log before testing starts.
    f.write('\n Run on time: {}'.format(datetime.datetime.now()))
    f.write('\n\n Model name: {}'.format(args.modelname))
    f.write('\n Model size: {:0.1f} MB'.format(modelsize))
    f.write('\n Model description: {}'.format(args.modeldescription))

    # -------------------------------------------------------------------
    # The main testing process (accuracy & time consumption)
    # -------------------------------------------------------------------
    top1 = AverageMeter('Acc@1')
    top5 = AverageMeter('Acc@5')

    if torch.cuda.is_available():
        model.cuda(device=0)

    with torch.no_grad():
        model.eval()

        start_time = datetime.datetime.now()

        for batch_idx, (data, target) in enumerate(test_loader):

            if torch.cuda.is_available():
                data, target = data.cuda(), target.cuda()

            # Dispatch on how load_MyModel packed the model: with or without
            # extra forward-pass arguments.
            if len(model_pars) == 2:
                output = utils.eval_one_batch(model, data)
            elif len(model_pars) > 2:
                output = utils.eval_one_batch(model, data, *forward_pars)

            acc1, acc5 = accuracy(output, target, topk=(1, 5))
            top1.update(acc1[0], data.size(0))
            top5.update(acc5[0], data.size(0))

            # Report every 100 batches and on the final batch.
            # BUGFIX: the original condition was
            # `batch_idx == len(test_loader) // batch_size + 1`, but
            # len(test_loader) already counts batches (not samples), so the
            # final batch was compared against a meaningless index.
            if batch_idx % 100 == 0 or batch_idx == len(test_loader) - 1:
                print(
                    '\n Batch {}, Batch acc@1: {:.4f}, Avg acc@1: {:.4f}, Batch acc@5: {:.4f}, Avg acc@5: {:.4f},Time: {}'.format(
                        batch_idx,
                        top1.val,
                        top1.avg,
                        top5.val,
                        top5.avg,
                        datetime.timedelta(seconds=(datetime.datetime.now() - start_time).seconds)
                    )
                )

            if args.log:
                f.write('\n Batch {}, Batch acc@1: {:.4f}, Avg acc@1: {:.4f}, Batch acc@5: {:.4f}, Avg acc@5: {:.4f},Time: {}'.format(
                    batch_idx,
                    top1.val,
                    top1.avg,
                    top5.val,
                    top5.avg,
                    datetime.timedelta(seconds=(datetime.datetime.now() - start_time).seconds)
                    )
                )

    # Average accuracy across all testing samples.
    top1acc = top1.avg
    top5acc = top5.avg
    timeCost = datetime.timedelta(seconds=(datetime.datetime.now() - start_time).seconds)
    # Print and log the final result.
    print('\n\n Testing average acc@1: {:.4f}, time cost: {}'.format(top1acc, timeCost))
    f.write('\n\n Testing average acc@1: {:.4f}, time cost: {}'.format(top1acc, timeCost))
    print('\n Testing average acc@5: {:.4f}, time cost: {}'.format(top5acc, timeCost))
    f.write('\n Testing average acc@5: {:.4f}, time cost: {}'.format(top5acc, timeCost))
-
|