|
- """
- The script for testing similarity2brain_snn in the openI community.
- Author: Yu Liutao @ PCL, 2022.12
- """
-
- # ---------------------------------------------------
- # Imports
- # ---------------------------------------------------
- from __future__ import print_function
- import argparse
- import torch
- # from torchvision import datasets, transforms
- # from torch.utils.data.dataloader import DataLoader
- import datetime
- import sys
- import os
- import numpy as np
-
- from dataset import NeuralDataset
- from visualmodel import VisualModel
- from metric import RSAMetric
- from benchmark import Benchmark
- # import spikingjelly
- # from spikingjelly.activation_based import layer, neuron, surrogate, functional
-
-
def get_args(argv=None):
    """Build the command-line parser and parse the evaluation options.

    Args:
        argv: optional list of argument strings to parse. ``None`` (the
            default) makes argparse read ``sys.argv[1:]``, exactly as before;
            passing an explicit list lets tests or notebooks call this
            function without touching ``sys.argv``.

    Returns:
        argparse.Namespace holding the parsed options.
    """
    parser = argparse.ArgumentParser(description='similarity2brain_snn')

    # Benchmark and data location. main() derives the dataset name, the
    # neural-data directory and the stimulus file from these two options.
    parser.add_argument("--benchmark", default="AllenMouse", type=str,
                        choices=["AllenMouse", "MacaqueSynthetic", "MacaqueFace"],
                        help="benchmark name")  # currently, only 'AllenMouse' is available
    parser.add_argument("--dataset-path", default="./", type=str,
                        help="dataset path")  # e.g. /dataset/

    # Similarity metric and the model under test.
    parser.add_argument("--metric", default="RSA", type=str,
                        choices=["RSA", "CCA", "LR"],
                        help="neural similarity metric")  # currently, only 'RSA' is available
    parser.add_argument("--model-name", default="sew_resnet18", type=str,
                        help="name of model")  # to be set by users
    parser.add_argument("--model-checkpoint", default="model_checkpoint/sew_resnet18.pth", type=str,
                        help="model checkpoint path")  # e.g. /pretrainmodel/xxx.pth (set by users)
    parser.add_argument("--device", default="cuda:0", type=str, help="torch device")
    parser.add_argument('--T', default=4, type=int,
                        help="total time-steps")  # may differ between networks

    # Reporting.
    parser.add_argument("--output-path", default="results/", help="path to save outputs")
    parser.add_argument("--model-description", default="SEW ResNet trained on ImageNet",
                        help="one sentence less than 200 characters to describe your model briefly")

    return parser.parse_args(argv)
-
-
def dataset_areas(dataset):
    """Return the brain areas recorded in the given neural dataset.

    Args:
        dataset: dataset key, one of ``'allen_natural_scenes'``,
            ``'macaque_synthetic'`` or ``'macaque_face'``.

    Returns:
        A freshly built list of area names for that dataset.

    Raises:
        KeyError: if *dataset* is not a known key. The message names the
            offending key and lists the supported ones.
    """
    # main() currently only resolves the 'allen_natural_scenes' dataset.
    brain_areas_dict = {
        'allen_natural_scenes': ['visp', 'visl', 'visrl', 'visal', 'vispm', 'visam'],
        'macaque_synthetic': ['V4', 'IT'],
        'macaque_face': ['AM'],
    }
    try:
        return brain_areas_dict[dataset]
    except KeyError:
        # Same exception type as a plain dict lookup, but with an actionable message.
        raise KeyError(
            f"unknown dataset {dataset!r}; expected one of {sorted(brain_areas_dict)}"
        ) from None
-
-
def preprocess_input(x, T=4):
    """Tile *x* along a new leading time axis for a multi-step spiking network.

    ``T`` plays the role of the network's time-step count (``args.T`` /
    ``_time_step``). A new dim 0 is inserted and the tensor is repeated
    ``T`` times along it, using a fixed 5-entry repeat spec:

    * batched 4-D input ``(N, C, H, W)`` -> ``(T, N, C, H, W)``
    * single 3-D image ``(C, H, W)`` -> ``(T, 1, C, H, W)``, because torch
      prepends an extra singleton dim when the repeat spec is longer than
      the tensor's rank
    * inputs with more than 4 dims are rejected by ``Tensor.repeat``

    The result is a new tensor (``repeat`` copies data), not a view of *x*.
    """
    tiling = (T,) + (1,) * 4  # repeat T times on the time axis, keep all others
    with_time_axis = x.unsqueeze(0)
    return with_time_axis.repeat(*tiling)
-
-
def metric_preset(args):
    """Instantiate the neural-similarity metric selected by ``args.metric``.

    Args:
        args: namespace with a ``metric`` attribute (see ``--metric``).

    Returns:
        An ``RSAMetric`` instance when ``args.metric == "RSA"``.

    Raises:
        NotImplementedError: for any other metric name (e.g. "CCA", "LR").
    """
    if args.metric == "RSA":
        return RSAMetric()
    # Raise explicitly so an unsupported metric fails here instead of
    # returning None and breaking later inside the benchmark.
    raise NotImplementedError(
        f"metric {args.metric!r} is not implemented yet; only 'RSA' is available"
    )
-
-
def main(args):
    """Evaluate a user-supplied model on a neural benchmark and print the scores.

    Steps:
        1. Validate the arguments and resolve the dataset paths for the chosen benchmark.
        2. Load the neural data.
        3. Load the model and the layers to analyze from the ``custom_model`` module.
        4. Wrap the model in ``VisualModel``.
        5. Run ``Benchmark`` with the selected metric.
        6. Print the best layer per brain area and the mean score.

    Args:
        args: namespace as produced by ``get_args()``. This function adds
            ``dataset``, ``neuraldata_path`` and ``stimulus_path`` to it.

    Raises:
        ValueError: if the benchmark, model name or model description is empty.
        NotImplementedError: if ``args.benchmark`` is not ``'AllenMouse'``.
    """
    print("Showing arguments")
    print(args)
    # Explicit checks rather than ``assert`` so validation survives ``python -O``.
    if args.benchmark == '':
        raise ValueError('benchmark (benchmark) not chosen')
    if args.model_name == '':
        raise ValueError('model name (model_name) not provided')
    if args.model_description == '':
        raise ValueError('model description (model_description) not provided')

    # Map the benchmark to its dataset name and on-disk locations.
    if args.benchmark == 'AllenMouse':
        args.dataset = 'allen_natural_scenes'
        args.neuraldata_path = os.path.join(args.dataset_path, 'neural_data')
        args.stimulus_path = os.path.join(args.dataset_path, 'stimulus/allen_natural_scenes_224.pt')
    else:
        raise NotImplementedError(
            f"benchmark {args.benchmark!r} is not available yet; only 'AllenMouse' is supported"
        )

    print("Loading neural data")
    neural_dataset = NeuralDataset(args.dataset, dataset_areas(args.dataset),
                                   data_path=args.neuraldata_path)

    print("Loading model and chosen layers")
    # Imported here rather than at module level. custom_model must provide
    # get_model(name, checkpoint) -> (model, checkpoint_path) and get_layers().
    import custom_model
    model, model_path = custom_model.get_model(args.model_name, args.model_checkpoint)
    layers = custom_model.get_layers()
    print(f"The model to be tested: {args.model_name}\n{model}")
    print(f"The layers to be analyzed: \n{layers}")

    # Model size is taken from the size of the checkpoint file on disk.
    modelsize = os.path.getsize(model_path) / 1024 / 1024  # bytes -> MB
    print('\n Model size: {:0.1f} MB'.format(modelsize))

    print("Initializing the visual_model, metric, benchmark")
    device = torch.device(args.device)
    visual_model = VisualModel(model, args.model_name, layers, args.stimulus_path,
                               _time_step=args.T, device=device)
    metric = metric_preset(args)
    save_path = os.path.join(args.output_path, args.metric.lower(), args.dataset)
    benchmark = Benchmark(neural_dataset, metric, save_path=save_path, suffix='')

    print("Start testing ...")
    start_time = datetime.datetime.now()
    scores, max_scores_dic = benchmark(visual_model)
    elapsed = datetime.datetime.now() - start_time
    # timedelta.seconds is only the seconds *component* (whole days are
    # dropped), so use total_seconds(); truncate to whole seconds for display.
    time_cost = datetime.timedelta(seconds=int(elapsed.total_seconds()))
    print('\n Evaluation done! Time cost: {}'.format(time_cost))

    m_score = np.mean(scores)
    print(f'Highest score and corresponding layer for each brain area\n{max_scores_dic}')
    print(f'\nMean score of {args.model_name}: {m_score}')
-
-
if __name__ == "__main__":
    # Script entry point: parse the command-line options, then run the evaluation.
    parsed_args = get_args()
    main(parsed_args)
|