|
- # Copyright (c) 2015-present, Facebook, Inc.
- # All rights reserved.
- import os
- import argparse
- import datetime
- import numpy as np
- import time
- import torch
- import torch.backends.cudnn as cudnn
- import torch.multiprocessing as mp
- import torchvision
- import json
-
- from pathlib import Path
-
- from timm.data import Mixup
- from timm.models import create_model
- from timm.loss import LabelSmoothingCrossEntropy, SoftTargetCrossEntropy
- from timm.scheduler import create_scheduler
- from timm.optim import create_optimizer
- from timm.utils import NativeScaler, get_state_dict, ModelEma
-
- from datasets import build_dataset, build_transform
- from engine import train_one_epoch, evaluate
- from losses import DistillationLoss
- from samplers import RASampler
- import robust_models
- import utils
-
-
def get_args_parser():
    """Build the command-line parser for RVT training and evaluation.

    The parser is created with ``add_help=False``, so it does not register
    ``-h/--help`` on its own; presumably the entry point (outside this view)
    composes it via ``parents=[...]`` -- TODO confirm.

    Boolean switches that default to ``True`` (``model_ema``,
    ``repeated_aug``, ``pin_mem``) are expressed as a ``--flag`` /
    ``--no-flag`` pair sharing one ``dest``, with the default set through
    ``parser.set_defaults``.

    Returns:
        argparse.ArgumentParser: parser covering model, EMA, optimizer,
        LR schedule, augmentation, random erase, mixup/cutmix, distillation,
        fine-tuning, dataset, evaluation and distributed options.
    """
    parser = argparse.ArgumentParser('RVT training and evaluation script', add_help=False)
    parser.add_argument('--batch-size', default=64, type=int)
    parser.add_argument('--epochs', default=300, type=int)

    # Model parameters
    parser.add_argument('--model', default='rvt_small', type=str, metavar='MODEL',
                        help='Name of model to train')
    parser.add_argument('--input-size', default=224, type=int, help='images input size')

    parser.add_argument('--drop', type=float, default=0.0, metavar='PCT',
                        help='Dropout rate (default: 0.)')
    parser.add_argument('--drop-path', type=float, default=0.1, metavar='PCT',
                        help='Drop path rate (default: 0.1)')

    # Model EMA: --model-ema / --no-model-ema both write args.model_ema,
    # which is on by default.
    parser.add_argument('--model-ema', action='store_true')
    parser.add_argument('--no-model-ema', action='store_false', dest='model_ema')
    parser.set_defaults(model_ema=True)
    parser.add_argument('--model-ema-decay', type=float, default=0.99996, help='')
    parser.add_argument('--model-ema-force-cpu', action='store_true', default=False, help='')

    # RVT params
    parser.add_argument('--use_patch_aug', action='store_true')

    # Optimizer parameters
    parser.add_argument('--opt', default='adamw', type=str, metavar='OPTIMIZER',
                        help='Optimizer (default: "adamw")')
    parser.add_argument('--opt-eps', default=1e-8, type=float, metavar='EPSILON',
                        help='Optimizer Epsilon (default: 1e-8)')
    parser.add_argument('--opt-betas', default=None, type=float, nargs='+', metavar='BETA',
                        help='Optimizer Betas (default: None, use opt default)')
    parser.add_argument('--clip-grad', type=float, default=None, metavar='NORM',
                        help='Clip gradient norm (default: None, no clipping)')
    parser.add_argument('--momentum', type=float, default=0.9, metavar='M',
                        help='SGD momentum (default: 0.9)')
    parser.add_argument('--weight-decay', type=float, default=0.05,
                        help='weight decay (default: 0.05)')

    # Learning rate schedule parameters
    parser.add_argument('--sched', default='cosine', type=str, metavar='SCHEDULER',
                        help='LR scheduler (default: "cosine")')
    parser.add_argument('--lr', type=float, default=5e-4, metavar='LR',
                        help='learning rate (default: 5e-4)')
    parser.add_argument('--lr-noise', type=float, nargs='+', default=None, metavar='pct, pct',
                        help='learning rate noise on/off epoch percentages')
    parser.add_argument('--lr-noise-pct', type=float, default=0.67, metavar='PERCENT',
                        help='learning rate noise limit percent (default: 0.67)')
    parser.add_argument('--lr-noise-std', type=float, default=1.0, metavar='STDDEV',
                        help='learning rate noise std-dev (default: 1.0)')
    parser.add_argument('--warmup-lr', type=float, default=1e-6, metavar='LR',
                        help='warmup learning rate (default: 1e-6)')
    parser.add_argument('--min-lr', type=float, default=1e-5, metavar='LR',
                        help='lower lr bound for cyclic schedulers that hit 0 (default: 1e-5)')

    parser.add_argument('--decay-epochs', type=float, default=30, metavar='N',
                        help='epoch interval to decay LR')
    parser.add_argument('--warmup-epochs', type=int, default=5, metavar='N',
                        help='epochs to warmup LR, if scheduler supports')
    parser.add_argument('--cooldown-epochs', type=int, default=10, metavar='N',
                        help='epochs to cooldown LR at min_lr, after cyclic schedule ends')
    parser.add_argument('--patience-epochs', type=int, default=10, metavar='N',
                        help='patience epochs for Plateau LR scheduler (default: 10)')
    parser.add_argument('--decay-rate', '--dr', type=float, default=0.1, metavar='RATE',
                        help='LR decay rate (default: 0.1)')

    # Augmentation parameters
    parser.add_argument('--color-jitter', type=float, default=0.4, metavar='PCT',
                        help='Color jitter factor (default: 0.4)')
    # The original help literal embedded a stray `" + \` continuation and the
    # call ended with a trailing comma (building a throwaway tuple).
    parser.add_argument('--aa', type=str, default='rand-m9-mstd0.5-inc1', metavar='NAME',
                        help='Use AutoAugment policy. "v0" or "original". '
                             '(default: rand-m9-mstd0.5-inc1)')
    parser.add_argument('--smoothing', type=float, default=0.1, help='Label smoothing (default: 0.1)')
    parser.add_argument('--train-interpolation', type=str, default='bicubic',
                        help='Training interpolation (random, bilinear, bicubic default: "bicubic")')

    # Repeated augmentation: --repeated-aug / --no-repeated-aug, on by default.
    parser.add_argument('--repeated-aug', action='store_true')
    parser.add_argument('--no-repeated-aug', action='store_false', dest='repeated_aug')
    parser.set_defaults(repeated_aug=True)

    # * Random Erase params
    parser.add_argument('--reprob', type=float, default=0.25, metavar='PCT',
                        help='Random erase prob (default: 0.25)')
    parser.add_argument('--remode', type=str, default='pixel',
                        help='Random erase mode (default: "pixel")')
    parser.add_argument('--recount', type=int, default=1,
                        help='Random erase count (default: 1)')
    parser.add_argument('--resplit', action='store_true', default=False,
                        help='Do not random erase first (clean) augmentation split')

    # * Mixup params
    parser.add_argument('--mixup', type=float, default=0.8,
                        help='mixup alpha, mixup enabled if > 0. (default: 0.8)')
    parser.add_argument('--cutmix', type=float, default=1.0,
                        help='cutmix alpha, cutmix enabled if > 0. (default: 1.0)')
    parser.add_argument('--cutmix-minmax', type=float, nargs='+', default=None,
                        help='cutmix min/max ratio, overrides alpha and enables cutmix if set (default: None)')
    parser.add_argument('--mixup-prob', type=float, default=1.0,
                        help='Probability of performing mixup or cutmix when either/both is enabled')
    parser.add_argument('--mixup-switch-prob', type=float, default=0.5,
                        help='Probability of switching to cutmix when both mixup and cutmix enabled')
    parser.add_argument('--mixup-mode', type=str, default='batch',
                        help='How to apply mixup/cutmix params. Per "batch", "pair", or "elem"')

    # Distillation parameters
    parser.add_argument('--teacher-model', default='regnety_160', type=str, metavar='MODEL',
                        help='Name of teacher model to train (default: "regnety_160")')
    parser.add_argument('--teacher-path', type=str, default='')
    parser.add_argument('--distillation-type', default='none', choices=['none', 'soft', 'hard'], type=str, help="")
    parser.add_argument('--distillation-alpha', default=0.5, type=float, help="")
    parser.add_argument('--distillation-tau', default=1.0, type=float, help="")

    # * Finetuning params
    parser.add_argument('--finetune', default='', help='finetune from checkpoint')
    parser.add_argument('--pretrained', action='store_true', help='load pretrained model')

    # Dataset parameters
    parser.add_argument('--data-path', default='/dataset', type=str,
                        help='dataset path')
    parser.add_argument('--data-set', default='CVPR', choices=['CIFAR', 'IMNET', 'INAT', 'INAT19', 'CVPR'],
                        type=str, help='dataset to train / evaluate on')
    parser.add_argument('--inat-category', default='name',
                        choices=['kingdom', 'phylum', 'class', 'order', 'supercategory', 'family', 'genus', 'name'],
                        type=str, help='semantic granularity')

    parser.add_argument('--output_dir', default='/model',
                        help='path where to save, empty for no saving')
    parser.add_argument('--device', default='cuda',
                        help='device to use for training / testing')
    parser.add_argument('--seed', default=0, type=int)
    parser.add_argument('--resume', default='', help='resume from checkpoint')
    parser.add_argument('--start_epoch', default=0, type=int, metavar='N',
                        help='start epoch')

    # CVPR data format: image directories plus label ("meta") files
    parser.add_argument('--data_dir_train', default='/dataset/train', type=str,
                        help='train dataset path')
    parser.add_argument('--meta_file_train', default='/dataset/p2_train.txt', type=str,
                        help='train dataset label path')
    parser.add_argument('--data_dir_test', default='/dataset/test', type=str,
                        help='test dataset path')
    parser.add_argument('--meta_file_test', default='/dataset/p2_val.txt', type=str,
                        help='test dataset label path')
    parser.add_argument('--save_ckpt_freq', default=50, type=int)

    # Eval parameters
    parser.add_argument('--eval', action='store_true', help='Perform evaluation only')
    parser.add_argument('--inc_path', default=None, type=str, help='imagenet-c')
    parser.add_argument('--ina_path', default=None, type=str, help='imagenet-a')
    parser.add_argument('--inr_path', default=None, type=str, help='imagenet-r')
    parser.add_argument('--insk_path', default=None, type=str, help='imagenet-sketch')
    parser.add_argument('--fgsm_test', action='store_true', default=False, help='test on FGSM attacker')
    parser.add_argument('--pgd_test', action='store_true', default=False, help='test on PGD attacker')

    parser.add_argument('--dist-eval', action='store_true', default=False, help='Enabling distributed evaluation')
    parser.add_argument('--num_workers', default=10, type=int)

    # Pinned memory: --pin-mem / --no-pin-mem, on by default.
    parser.add_argument('--pin-mem', action='store_true',
                        help='Pin CPU memory in DataLoader for more efficient (sometimes) transfer to GPU.')
    parser.add_argument('--no-pin-mem', action='store_false', dest='pin_mem',
                        help='')
    parser.set_defaults(pin_mem=True)

    # Distributed training parameters
    parser.add_argument("--local_rank", default=0, type=int)
    parser.add_argument('--world_size', default=1, type=int,
                        help='number of distributed processes')
    parser.add_argument('--dist_url', default='env://', help='url used to set up distributed training')
    return parser
-
def setup_for_distributed(is_master):
    """Silence ``print`` on every process except the master.

    Replaces ``builtins.print`` with a wrapper around the previous ``print``.
    The wrapper forwards a call only when ``is_master`` is truthy or the caller
    passes ``force=True``. The ``force`` keyword is always consumed by the
    wrapper, so it never reaches the wrapped ``print``.

    Args:
        is_master: whether the current process should keep printing.
    """
    import builtins

    original_print = builtins.print

    # Named ``print`` on purpose so the installed replacement keeps the same
    # ``__name__`` as before.
    def print(*args, **kwargs):
        forced = kwargs.pop('force', False)
        if not (is_master or forced):
            return
        original_print(*args, **kwargs)

    builtins.print = print
-
- def main(args):
- if 'RANK' in os.environ and 'WORLD_SIZE' in os.environ:
- args.distributed = True
- torch.cuda.set_device(args.local_rank)
- args.dist_backend = 'nccl'
- torch.distributed.init_process_group(backend='nccl', init_method=args.dist_url)
- args.world_size = torch.distributed.get_world_size()
- args.rank = torch.distributed.get_rank()
- print('| distributed init {}(rank {})'.format(
- args.world_size, args.rank), flush=True)
- torch.distributed.barrier()
- setup_for_distributed(args.rank == 0)
- else:
- print('Not using distributed mode')
- args.distributed = False
-
- print(args)
-
- if args.distillation_type != 'none' and args.finetune and not args.eval:
- raise NotImplementedError("Finetuning with distillation not yet supported")
-
- device = torch.device(args.device)
-
- # fix the seed for reproducibility
- seed = args.seed + utils.get_rank()
- torch.manual_seed(seed)
- np.random.seed(seed)
- # random.seed(seed)
-
- cudnn.benchmark = True
-
- dataset_train, args.nb_classes = build_dataset(is_train=True, args=args)
- dataset_val, _ = build_dataset(is_train=False, args=args)
-
- import pdb
- pdb.set_trace()
-
- if args.distributed: #True:
- num_tasks = utils.get_world_size()
- global_rank = utils.get_rank()
- if args.repeated_aug:
- sampler_train = RASampler(
- dataset_train, num_replicas=num_tasks, rank=global_rank, shuffle=True
- )
- else:
- sampler_train = torch.utils.data.DistributedSampler(
- dataset_train, num_replicas=num_tasks, rank=global_rank, shuffle=True
- )
- if args.dist_eval:
- if len(dataset_val) % num_tasks != 0:
- print('Warning: Enabling distributed evaluation with an eval dataset not divisible by process number. '
- 'This will slightly alter validation results as extra duplicate entries are added to achieve '
- 'equal num of samples per-process.')
- sampler_val = torch.utils.data.DistributedSampler(
- dataset_val, num_replicas=num_tasks, rank=global_rank, shuffle=False)
- else:
- sampler_val = torch.utils.data.SequentialSampler(dataset_val)
- else:
- sampler_train = torch.utils.data.RandomSampler(dataset_train)
- sampler_val = torch.utils.data.SequentialSampler(dataset_val)
-
- data_loader_train = torch.utils.data.DataLoader(
- dataset_train, sampler=sampler_train,
- batch_size=args.batch_size,
- num_workers=args.num_workers,
- pin_memory=args.pin_mem,
- drop_last=True,
- )
-
- data_loader_val = torch.utils.data.DataLoader(
- dataset_val, sampler=sampler_val,
- batch_size=int(1.5 * args.batch_size),
- num_workers=args.num_workers,
- pin_memory=args.pin_mem,
- drop_last=False
- )
-
- mixup_fn = None
- mixup_active = args.mixup > 0 or args.cutmix > 0. or args.cutmix_minmax is not None
- if mixup_active:
- mixup_fn = Mixup(
- mixup_alpha=args.mixup, cutmix_alpha=args.cutmix, cutmix_minmax=args.cutmix_minmax,
- prob=args.mixup_prob, switch_prob=args.mixup_switch_prob, mode=args.mixup_mode,
- label_smoothing=args.smoothing, num_classes=args.nb_classes)
-
- print(f"Creating model: {args.model}")
- model = create_model(
- args.model,
- pretrained=args.pretrained,
- num_classes=args.nb_classes,
- drop_rate=args.drop,
- drop_path_rate=args.drop_path,
- drop_block_rate=None
- )
-
- if args.finetune:
- if args.finetune.startswith('https'):
- checkpoint = torch.hub.load_state_dict_from_url(
- args.finetune, map_location='cpu', check_hash=True)
- else:
- checkpoint = torch.load(args.finetune, map_location='cpu')
-
- checkpoint_model = checkpoint['model']
- state_dict = model.state_dict()
- for k in ['head.weight', 'head.bias', 'head_dist.weight', 'head_dist.bias']:
- if k in checkpoint_model and checkpoint_model[k].shape != state_dict[k].shape:
- print(f"Removing key {k} from pretrained checkpoint")
- del checkpoint_model[k]
-
- # interpolate position embedding
- pos_embed_checkpoint = checkpoint_model['pos_embed']
- embedding_size = pos_embed_checkpoint.shape[-1]
- num_patches = model.patch_embed.num_patches
- num_extra_tokens = model.pos_embed.shape[-2] - num_patches
- # height (== width) for the checkpoint position embedding
- orig_size = int((pos_embed_checkpoint.shape[-2] - num_extra_tokens) ** 0.5)
- # height (== width) for the new position embedding
- new_size = int(num_patches ** 0.5)
- # class_token and dist_token are kept unchanged
- extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens]
- # only the position tokens are interpolated
- pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:]
- pos_tokens = pos_tokens.reshape(-1, orig_size, orig_size, embedding_size).permute(0, 3, 1, 2)
- pos_tokens = torch.nn.functional.interpolate(
- pos_tokens, size=(new_size, new_size), mode='bicubic', align_corners=False)
- pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2)
- new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1)
- checkpoint_model['pos_embed'] = new_pos_embed
-
- model.load_state_dict(checkpoint_model, strict=False)
-
- model.to(device)
-
- model_ema = None
- if args.model_ema:
- # Important to create EMA model after cuda(), DP wrapper, and AMP but before SyncBN and DDP wrapper
- model_ema = ModelEma(
- model,
- decay=args.model_ema_decay,
- device='cpu' if args.model_ema_force_cpu else '',
- resume='')
-
- model_without_ddp = model
- if args.distributed:
- model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.local_rank])
- model_without_ddp = model.module
- n_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad)
- print('number of params:', n_parameters)
-
- linear_scaled_lr = args.lr * args.batch_size * utils.get_world_size() / 512.0
- args.lr = linear_scaled_lr
- optimizer = create_optimizer(args, model_without_ddp)
- loss_scaler = NativeScaler()
-
- lr_scheduler, _ = create_scheduler(args, optimizer)
-
- criterion = LabelSmoothingCrossEntropy()
-
- if args.mixup > 0.:
- # smoothing is handled with mixup label transform
- criterion = SoftTargetCrossEntropy()
- elif args.smoothing:
- criterion = LabelSmoothingCrossEntropy(smoothing=args.smoothing)
- else:
- criterion = torch.nn.CrossEntropyLoss()
-
- teacher_model = None
- if args.distillation_type != 'none':
- assert args.teacher_path, 'need to specify teacher-path when using distillation'
- print(f"Creating teacher model: {args.teacher_model}")
- teacher_model = create_model(
- args.teacher_model,
- pretrained=False,
- num_classes=args.nb_classes,
- global_pool='avg',
- )
- if args.teacher_path.startswith('https'):
- checkpoint = torch.hub.load_state_dict_from_url(
- args.teacher_path, map_location='cpu', check_hash=True)
- else:
- checkpoint = torch.load(args.teacher_path, map_location='cpu')
- teacher_model.load_state_dict(checkpoint['model'])
- teacher_model.to(device)
- teacher_model.eval()
-
- # wrap the criterion in our custom DistillationLoss, which
- # just dispatches to the original criterion if args.distillation_type is 'none'
- criterion = DistillationLoss(
- criterion, teacher_model, args.distillation_type, args.distillation_alpha, args.distillation_tau
- )
-
- output_dir = Path(args.output_dir)
- if args.resume:
- if args.resume.startswith('https'):
- checkpoint = torch.hub.load_state_dict_from_url(
- args.resume, map_location='cpu', check_hash=True)
- else:
- checkpoint = torch.load(args.resume, map_location='cpu')
- model_without_ddp.load_state_dict(checkpoint['model'])
- if not args.eval and 'optimizer' in checkpoint and 'lr_scheduler' in checkpoint and 'epoch' in checkpoint:
- optimizer.load_state_dict(checkpoint['optimizer'])
- lr_scheduler.load_state_dict(checkpoint['lr_scheduler'])
- args.start_epoch = checkpoint['epoch'] + 1
- if args.model_ema:
- utils._load_checkpoint_for_ema(model_ema, checkpoint['model_ema'])
- if 'scaler' in checkpoint:
- loss_scaler.load_state_dict(checkpoint['scaler'])
-
- if args.eval:
-
- test_transform = build_transform(False, args)
-
- test_stats = evaluate(data_loader_val, model, device)
- print(f"Accuracy of the network on the {len(dataset_val)} test images: {test_stats['acc1']:.1f}%")
-
- if args.inc_path:
- result_dict = {}
- ce_alexnet = utils.get_ce_alexnet()
-
- # transform for imagenet-c
- inc_transform = torchvision.transforms.Compose([torchvision.transforms.CenterCrop(224),
- torchvision.transforms.ToTensor(),
- torchvision.transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))])
-
- for name, path in utils.data_loaders_names.items():
- for severity in range(1, 6):
- inc_dataset = torchvision.datasets.ImageFolder(os.path.join(args.inc_path, path, str(severity)), transform=inc_transform)
- inc_data_loader = torch.utils.data.DataLoader(
- inc_dataset, batch_size=int(1.5 * args.batch_size),
- num_workers=args.num_workers,
- pin_memory=args.pin_mem,
- drop_last=False
- )
- test_stats = evaluate(inc_data_loader, model, device)
- print(f"Accuracy on the {name+'({})'.format(severity)}: {test_stats['acc1']:.1f}%")
- result_dict[name+'({})'.format(severity)] = test_stats['acc1']
-
- mCE = 0
- counter = 0
- overall_acc = 0
- for name, path in utils.data_loaders_names.items():
- acc_top1 = 0
- for severity in range(1, 6):
- acc_top1 += result_dict[name+'({})'.format(severity)]
- acc_top1 /= 5
- CE = utils.get_mce_from_accuracy(acc_top1, ce_alexnet[name])
- mCE += CE
- overall_acc += acc_top1
- counter += 1
- print("{0}: Top1 accuracy {1:.2f}, CE: {2:.2f}".format(
- name, acc_top1, 100. * CE))
-
- overall_acc /= counter
- mCE /= counter
- print("Corruption Top1 accuracy {0:.2f}, mCE: {1:.2f}".format(overall_acc, mCE * 100.))
-
- if args.ina_path:
- all_wnids = ['n01440764', 'n01443537', 'n01484850', 'n01491361', 'n01494475', 'n01496331', 'n01498041', 'n01514668', 'n01514859', 'n01518878', 'n01530575', 'n01531178', 'n01532829', 'n01534433', 'n01537544', 'n01558993', 'n01560419', 'n01580077', 'n01582220', 'n01592084', 'n01601694', 'n01608432', 'n01614925', 'n01616318', 'n01622779', 'n01629819', 'n01630670', 'n01631663', 'n01632458', 'n01632777', 'n01641577', 'n01644373', 'n01644900', 'n01664065', 'n01665541', 'n01667114', 'n01667778', 'n01669191', 'n01675722', 'n01677366', 'n01682714', 'n01685808', 'n01687978', 'n01688243', 'n01689811', 'n01692333', 'n01693334', 'n01694178', 'n01695060', 'n01697457', 'n01698640', 'n01704323', 'n01728572', 'n01728920', 'n01729322', 'n01729977', 'n01734418', 'n01735189', 'n01737021', 'n01739381', 'n01740131', 'n01742172', 'n01744401', 'n01748264', 'n01749939', 'n01751748', 'n01753488', 'n01755581', 'n01756291', 'n01768244', 'n01770081', 'n01770393', 'n01773157', 'n01773549', 'n01773797', 'n01774384', 'n01774750', 'n01775062', 'n01776313', 'n01784675', 'n01795545', 'n01796340', 'n01797886', 'n01798484', 'n01806143', 'n01806567', 'n01807496', 'n01817953', 'n01818515', 'n01819313', 'n01820546', 'n01824575', 'n01828970', 'n01829413', 'n01833805', 'n01843065', 'n01843383', 'n01847000', 'n01855032', 'n01855672', 'n01860187', 'n01871265', 'n01872401', 'n01873310', 'n01877812', 'n01882714', 'n01883070', 'n01910747', 'n01914609', 'n01917289', 'n01924916', 'n01930112', 'n01943899', 'n01944390', 'n01945685', 'n01950731', 'n01955084', 'n01968897', 'n01978287', 'n01978455', 'n01980166', 'n01981276', 'n01983481', 'n01984695', 'n01985128', 'n01986214', 'n01990800', 'n02002556', 'n02002724', 'n02006656', 'n02007558', 'n02009229', 'n02009912', 'n02011460', 'n02012849', 'n02013706', 'n02017213', 'n02018207', 'n02018795', 'n02025239', 'n02027492', 'n02028035', 'n02033041', 'n02037110', 'n02051845', 'n02056570', 'n02058221', 'n02066245', 'n02071294', 'n02074367', 'n02077923', 'n02085620', 
'n02085782', 'n02085936', 'n02086079', 'n02086240', 'n02086646', 'n02086910', 'n02087046', 'n02087394', 'n02088094', 'n02088238', 'n02088364', 'n02088466', 'n02088632', 'n02089078', 'n02089867', 'n02089973', 'n02090379', 'n02090622', 'n02090721', 'n02091032', 'n02091134', 'n02091244', 'n02091467', 'n02091635', 'n02091831', 'n02092002', 'n02092339', 'n02093256', 'n02093428', 'n02093647', 'n02093754', 'n02093859', 'n02093991', 'n02094114', 'n02094258', 'n02094433', 'n02095314', 'n02095570', 'n02095889', 'n02096051', 'n02096177', 'n02096294', 'n02096437', 'n02096585', 'n02097047', 'n02097130', 'n02097209', 'n02097298', 'n02097474', 'n02097658', 'n02098105', 'n02098286', 'n02098413', 'n02099267', 'n02099429', 'n02099601', 'n02099712', 'n02099849', 'n02100236', 'n02100583', 'n02100735', 'n02100877', 'n02101006', 'n02101388', 'n02101556', 'n02102040', 'n02102177', 'n02102318', 'n02102480', 'n02102973', 'n02104029', 'n02104365', 'n02105056', 'n02105162', 'n02105251', 'n02105412', 'n02105505', 'n02105641', 'n02105855', 'n02106030', 'n02106166', 'n02106382', 'n02106550', 'n02106662', 'n02107142', 'n02107312', 'n02107574', 'n02107683', 'n02107908', 'n02108000', 'n02108089', 'n02108422', 'n02108551', 'n02108915', 'n02109047', 'n02109525', 'n02109961', 'n02110063', 'n02110185', 'n02110341', 'n02110627', 'n02110806', 'n02110958', 'n02111129', 'n02111277', 'n02111500', 'n02111889', 'n02112018', 'n02112137', 'n02112350', 'n02112706', 'n02113023', 'n02113186', 'n02113624', 'n02113712', 'n02113799', 'n02113978', 'n02114367', 'n02114548', 'n02114712', 'n02114855', 'n02115641', 'n02115913', 'n02116738', 'n02117135', 'n02119022', 'n02119789', 'n02120079', 'n02120505', 'n02123045', 'n02123159', 'n02123394', 'n02123597', 'n02124075', 'n02125311', 'n02127052', 'n02128385', 'n02128757', 'n02128925', 'n02129165', 'n02129604', 'n02130308', 'n02132136', 'n02133161', 'n02134084', 'n02134418', 'n02137549', 'n02138441', 'n02165105', 'n02165456', 'n02167151', 'n02168699', 'n02169497', 
'n02172182', 'n02174001', 'n02177972', 'n02190166', 'n02206856', 'n02219486', 'n02226429', 'n02229544', 'n02231487', 'n02233338', 'n02236044', 'n02256656', 'n02259212', 'n02264363', 'n02268443', 'n02268853', 'n02276258', 'n02277742', 'n02279972', 'n02280649', 'n02281406', 'n02281787', 'n02317335', 'n02319095', 'n02321529', 'n02325366', 'n02326432', 'n02328150', 'n02342885', 'n02346627', 'n02356798', 'n02361337', 'n02363005', 'n02364673', 'n02389026', 'n02391049', 'n02395406', 'n02396427', 'n02397096', 'n02398521', 'n02403003', 'n02408429', 'n02410509', 'n02412080', 'n02415577', 'n02417914', 'n02422106', 'n02422699', 'n02423022', 'n02437312', 'n02437616', 'n02441942', 'n02442845', 'n02443114', 'n02443484', 'n02444819', 'n02445715', 'n02447366', 'n02454379', 'n02457408', 'n02480495', 'n02480855', 'n02481823', 'n02483362', 'n02483708', 'n02484975', 'n02486261', 'n02486410', 'n02487347', 'n02488291', 'n02488702', 'n02489166', 'n02490219', 'n02492035', 'n02492660', 'n02493509', 'n02493793', 'n02494079', 'n02497673', 'n02500267', 'n02504013', 'n02504458', 'n02509815', 'n02510455', 'n02514041', 'n02526121', 'n02536864', 'n02606052', 'n02607072', 'n02640242', 'n02641379', 'n02643566', 'n02655020', 'n02666196', 'n02667093', 'n02669723', 'n02672831', 'n02676566', 'n02687172', 'n02690373', 'n02692877', 'n02699494', 'n02701002', 'n02704792', 'n02708093', 'n02727426', 'n02730930', 'n02747177', 'n02749479', 'n02769748', 'n02776631', 'n02777292', 'n02782093', 'n02783161', 'n02786058', 'n02787622', 'n02788148', 'n02790996', 'n02791124', 'n02791270', 'n02793495', 'n02794156', 'n02795169', 'n02797295', 'n02799071', 'n02802426', 'n02804414', 'n02804610', 'n02807133', 'n02808304', 'n02808440', 'n02814533', 'n02814860', 'n02815834', 'n02817516', 'n02823428', 'n02823750', 'n02825657', 'n02834397', 'n02835271', 'n02837789', 'n02840245', 'n02841315', 'n02843684', 'n02859443', 'n02860847', 'n02865351', 'n02869837', 'n02870880', 'n02871525', 'n02877765', 'n02879718', 'n02883205', 
'n02892201', 'n02892767', 'n02894605', 'n02895154', 'n02906734', 'n02909870', 'n02910353', 'n02916936', 'n02917067', 'n02927161', 'n02930766', 'n02939185', 'n02948072', 'n02950826', 'n02951358', 'n02951585', 'n02963159', 'n02965783', 'n02966193', 'n02966687', 'n02971356', 'n02974003', 'n02977058', 'n02978881', 'n02979186', 'n02980441', 'n02981792', 'n02988304', 'n02992211', 'n02992529', 'n02999410', 'n03000134', 'n03000247', 'n03000684', 'n03014705', 'n03016953', 'n03017168', 'n03018349', 'n03026506', 'n03028079', 'n03032252', 'n03041632', 'n03042490', 'n03045698', 'n03047690', 'n03062245', 'n03063599', 'n03063689', 'n03065424', 'n03075370', 'n03085013', 'n03089624', 'n03095699', 'n03100240', 'n03109150', 'n03110669', 'n03124043', 'n03124170', 'n03125729', 'n03126707', 'n03127747', 'n03127925', 'n03131574', 'n03133878', 'n03134739', 'n03141823', 'n03146219', 'n03160309', 'n03179701', 'n03180011', 'n03187595', 'n03188531', 'n03196217', 'n03197337', 'n03201208', 'n03207743', 'n03207941', 'n03208938', 'n03216828', 'n03218198', 'n03220513', 'n03223299', 'n03240683', 'n03249569', 'n03250847', 'n03255030', 'n03259280', 'n03271574', 'n03272010', 'n03272562', 'n03290653', 'n03291819', 'n03297495', 'n03314780', 'n03325584', 'n03337140', 'n03344393', 'n03345487', 'n03347037', 'n03355925', 'n03372029', 'n03376595', 'n03379051', 'n03384352', 'n03388043', 'n03388183', 'n03388549', 'n03393912', 'n03394916', 'n03400231', 'n03404251', 'n03417042', 'n03424325', 'n03425413', 'n03443371', 'n03444034', 'n03445777', 'n03445924', 'n03447447', 'n03447721', 'n03450230', 'n03452741', 'n03457902', 'n03459775', 'n03461385', 'n03467068', 'n03476684', 'n03476991', 'n03478589', 'n03481172', 'n03482405', 'n03483316', 'n03485407', 'n03485794', 'n03492542', 'n03494278', 'n03495258', 'n03496892', 'n03498962', 'n03527444', 'n03529860', 'n03530642', 'n03532672', 'n03534580', 'n03535780', 'n03538406', 'n03544143', 'n03584254', 'n03584829', 'n03590841', 'n03594734', 'n03594945', 'n03595614', 
'n03598930', 'n03599486', 'n03602883', 'n03617480', 'n03623198', 'n03627232', 'n03630383', 'n03633091', 'n03637318', 'n03642806', 'n03649909', 'n03657121', 'n03658185', 'n03661043', 'n03662601', 'n03666591', 'n03670208', 'n03673027', 'n03676483', 'n03680355', 'n03690938', 'n03691459', 'n03692522', 'n03697007', 'n03706229', 'n03709823', 'n03710193', 'n03710637', 'n03710721', 'n03717622', 'n03720891', 'n03721384', 'n03724870', 'n03729826', 'n03733131', 'n03733281', 'n03733805', 'n03742115', 'n03743016', 'n03759954', 'n03761084', 'n03763968', 'n03764736', 'n03769881', 'n03770439', 'n03770679', 'n03773504', 'n03775071', 'n03775546', 'n03776460', 'n03777568', 'n03777754', 'n03781244', 'n03782006', 'n03785016', 'n03786901', 'n03787032', 'n03788195', 'n03788365', 'n03791053', 'n03792782', 'n03792972', 'n03793489', 'n03794056', 'n03796401', 'n03803284', 'n03804744', 'n03814639', 'n03814906', 'n03825788', 'n03832673', 'n03837869', 'n03838899', 'n03840681', 'n03841143', 'n03843555', 'n03854065', 'n03857828', 'n03866082', 'n03868242', 'n03868863', 'n03871628', 'n03873416', 'n03874293', 'n03874599', 'n03876231', 'n03877472', 'n03877845', 'n03884397', 'n03887697', 'n03888257', 'n03888605', 'n03891251', 'n03891332', 'n03895866', 'n03899768', 'n03902125', 'n03903868', 'n03908618', 'n03908714', 'n03916031', 'n03920288', 'n03924679', 'n03929660', 'n03929855', 'n03930313', 'n03930630', 'n03933933', 'n03935335', 'n03937543', 'n03938244', 'n03942813', 'n03944341', 'n03947888', 'n03950228', 'n03954731', 'n03956157', 'n03958227', 'n03961711', 'n03967562', 'n03970156', 'n03976467', 'n03976657', 'n03977966', 'n03980874', 'n03982430', 'n03983396', 'n03991062', 'n03992509', 'n03995372', 'n03998194', 'n04004767', 'n04005630', 'n04008634', 'n04009552', 'n04019541', 'n04023962', 'n04026417', 'n04033901', 'n04033995', 'n04037443', 'n04039381', 'n04040759', 'n04041544', 'n04044716', 'n04049303', 'n04065272', 'n04067472', 'n04069434', 'n04070727', 'n04074963', 'n04081281', 'n04086273', 
'n04090263', 'n04099969', 'n04111531', 'n04116512', 'n04118538', 'n04118776', 'n04120489', 'n04125021', 'n04127249', 'n04131690', 'n04133789', 'n04136333', 'n04141076', 'n04141327', 'n04141975', 'n04146614', 'n04147183', 'n04149813', 'n04152593', 'n04153751', 'n04154565', 'n04162706', 'n04179913', 'n04192698', 'n04200800', 'n04201297', 'n04204238', 'n04204347', 'n04208210', 'n04209133', 'n04209239', 'n04228054', 'n04229816', 'n04235860', 'n04238763', 'n04239074', 'n04243546', 'n04251144', 'n04252077', 'n04252225', 'n04254120', 'n04254680', 'n04254777', 'n04258138', 'n04259630', 'n04263257', 'n04264628', 'n04265275', 'n04266014', 'n04270147', 'n04273569', 'n04275548', 'n04277352', 'n04285008', 'n04286575', 'n04296562', 'n04310018', 'n04311004', 'n04311174', 'n04317175', 'n04325704', 'n04326547', 'n04328186', 'n04330267', 'n04332243', 'n04335435', 'n04336792', 'n04344873', 'n04346328', 'n04347754', 'n04350905', 'n04355338', 'n04355933', 'n04356056', 'n04357314', 'n04366367', 'n04367480', 'n04370456', 'n04371430', 'n04371774', 'n04372370', 'n04376876', 'n04380533', 'n04389033', 'n04392985', 'n04398044', 'n04399382', 'n04404412', 'n04409515', 'n04417672', 'n04418357', 'n04423845', 'n04428191', 'n04429376', 'n04435653', 'n04442312', 'n04443257', 'n04447861', 'n04456115', 'n04458633', 'n04461696', 'n04462240', 'n04465501', 'n04467665', 'n04476259', 'n04479046', 'n04482393', 'n04483307', 'n04485082', 'n04486054', 'n04487081', 'n04487394', 'n04493381', 'n04501370', 'n04505470', 'n04507155', 'n04509417', 'n04515003', 'n04517823', 'n04522168', 'n04523525', 'n04525038', 'n04525305', 'n04532106', 'n04532670', 'n04536866', 'n04540053', 'n04542943', 'n04548280', 'n04548362', 'n04550184', 'n04552348', 'n04553703', 'n04554684', 'n04557648', 'n04560804', 'n04562935', 'n04579145', 'n04579432', 'n04584207', 'n04589890', 'n04590129', 'n04591157', 'n04591713', 'n04592741', 'n04596742', 'n04597913', 'n04599235', 'n04604644', 'n04606251', 'n04612504', 'n04613696', 'n06359193', 
'n06596364', 'n06785654', 'n06794110', 'n06874185', 'n07248320', 'n07565083', 'n07579787', 'n07583066', 'n07584110', 'n07590611', 'n07613480', 'n07614500', 'n07615774', 'n07684084', 'n07693725', 'n07695742', 'n07697313', 'n07697537', 'n07711569', 'n07714571', 'n07714990', 'n07715103', 'n07716358', 'n07716906', 'n07717410', 'n07717556', 'n07718472', 'n07718747', 'n07720875', 'n07730033', 'n07734744', 'n07742313', 'n07745940', 'n07747607', 'n07749582', 'n07753113', 'n07753275', 'n07753592', 'n07754684', 'n07760859', 'n07768694', 'n07802026', 'n07831146', 'n07836838', 'n07860988', 'n07871810', 'n07873807', 'n07875152', 'n07880968', 'n07892512', 'n07920052', 'n07930864', 'n07932039', 'n09193705', 'n09229709', 'n09246464', 'n09256479', 'n09288635', 'n09332890', 'n09399592', 'n09421951', 'n09428293', 'n09468604', 'n09472597', 'n09835506', 'n10148035', 'n10565667', 'n11879895', 'n11939491', 'n12057211', 'n12144580', 'n12267677', 'n12620546', 'n12768682', 'n12985857', 'n12998815', 'n13037406', 'n13040303', 'n13044778', 'n13052670', 'n13054560', 'n13133613', 'n15075141']
- imagenet_a_wnids = ['n01498041', 'n01531178', 'n01534433', 'n01558993', 'n01580077', 'n01614925', 'n01616318', 'n01631663', 'n01641577', 'n01669191', 'n01677366', 'n01687978', 'n01694178', 'n01698640', 'n01735189', 'n01770081', 'n01770393', 'n01774750', 'n01784675', 'n01819313', 'n01820546', 'n01833805', 'n01843383', 'n01847000', 'n01855672', 'n01882714', 'n01910747', 'n01914609', 'n01924916', 'n01944390', 'n01985128', 'n01986214', 'n02007558', 'n02009912', 'n02037110', 'n02051845', 'n02077923', 'n02085620', 'n02099601', 'n02106550', 'n02106662', 'n02110958', 'n02119022', 'n02123394', 'n02127052', 'n02129165', 'n02133161', 'n02137549', 'n02165456', 'n02174001', 'n02177972', 'n02190166', 'n02206856', 'n02219486', 'n02226429', 'n02231487', 'n02233338', 'n02236044', 'n02259212', 'n02268443', 'n02279972', 'n02280649', 'n02281787', 'n02317335', 'n02325366', 'n02346627', 'n02356798', 'n02361337', 'n02410509', 'n02445715', 'n02454379', 'n02486410', 'n02492035', 'n02504458', 'n02655020', 'n02669723', 'n02672831', 'n02676566', 'n02690373', 'n02701002', 'n02730930', 'n02777292', 'n02782093', 'n02787622', 'n02793495', 'n02797295', 'n02802426', 'n02814860', 'n02815834', 'n02837789', 'n02879718', 'n02883205', 'n02895154', 'n02906734', 'n02948072', 'n02951358', 'n02980441', 'n02992211', 'n02999410', 'n03014705', 'n03026506', 'n03124043', 'n03125729', 'n03187595', 'n03196217', 'n03223299', 'n03250847', 'n03255030', 'n03291819', 'n03325584', 'n03355925', 'n03384352', 'n03388043', 'n03417042', 'n03443371', 'n03444034', 'n03445924', 'n03452741', 'n03483316', 'n03584829', 'n03590841', 'n03594945', 'n03617480', 'n03666591', 'n03670208', 'n03717622', 'n03720891', 'n03721384', 'n03724870', 'n03775071', 'n03788195', 'n03804744', 'n03837869', 'n03840681', 'n03854065', 'n03888257', 'n03891332', 'n03935335', 'n03982430', 'n04019541', 'n04033901', 'n04039381', 'n04067472', 'n04086273', 'n04099969', 'n04118538', 'n04131690', 'n04133789', 'n04141076', 'n04146614', 'n04147183', 'n04179913', 
'n04208210', 'n04235860', 'n04252077', 'n04252225', 'n04254120', 'n04270147', 'n04275548', 'n04310018', 'n04317175', 'n04344873', 'n04347754', 'n04355338', 'n04366367', 'n04376876', 'n04389033', 'n04399382', 'n04442312', 'n04456115', 'n04482393', 'n04507155', 'n04509417', 'n04532670', 'n04540053', 'n04554684', 'n04562935', 'n04591713', 'n04606251', 'n07583066', 'n07695742', 'n07697313', 'n07697537', 'n07714990', 'n07718472', 'n07720875', 'n07734744', 'n07749582', 'n07753592', 'n07760859', 'n07768694', 'n07831146', 'n09229709', 'n09246464', 'n09472597', 'n09835506', 'n11879895', 'n12057211', 'n12144580', 'n12267677']
- imagenet_a_mask = [wnid in set(imagenet_a_wnids) for wnid in all_wnids]
- ina_dataset = torchvision.datasets.ImageFolder(args.ina_path, transform=test_transform)
- ina_data_loader = torch.utils.data.DataLoader(
- ina_dataset, batch_size=int(1.5 * args.batch_size),
- num_workers=args.num_workers,
- pin_memory=args.pin_mem,
- drop_last=False
- )
- test_stats = evaluate(ina_data_loader, model, device, mask=imagenet_a_mask)
- print(f"Accuracy on the ImageNet-A: {test_stats['acc1']:.1f}%")
-
- if args.inr_path:
- all_wnids = ['n01440764', 'n01443537', 'n01484850', 'n01491361', 'n01494475', 'n01496331', 'n01498041', 'n01514668', 'n01514859', 'n01518878', 'n01530575', 'n01531178', 'n01532829', 'n01534433', 'n01537544', 'n01558993', 'n01560419', 'n01580077', 'n01582220', 'n01592084', 'n01601694', 'n01608432', 'n01614925', 'n01616318', 'n01622779', 'n01629819', 'n01630670', 'n01631663', 'n01632458', 'n01632777', 'n01641577', 'n01644373', 'n01644900', 'n01664065', 'n01665541', 'n01667114', 'n01667778', 'n01669191', 'n01675722', 'n01677366', 'n01682714', 'n01685808', 'n01687978', 'n01688243', 'n01689811', 'n01692333', 'n01693334', 'n01694178', 'n01695060', 'n01697457', 'n01698640', 'n01704323', 'n01728572', 'n01728920', 'n01729322', 'n01729977', 'n01734418', 'n01735189', 'n01737021', 'n01739381', 'n01740131', 'n01742172', 'n01744401', 'n01748264', 'n01749939', 'n01751748', 'n01753488', 'n01755581', 'n01756291', 'n01768244', 'n01770081', 'n01770393', 'n01773157', 'n01773549', 'n01773797', 'n01774384', 'n01774750', 'n01775062', 'n01776313', 'n01784675', 'n01795545', 'n01796340', 'n01797886', 'n01798484', 'n01806143', 'n01806567', 'n01807496', 'n01817953', 'n01818515', 'n01819313', 'n01820546', 'n01824575', 'n01828970', 'n01829413', 'n01833805', 'n01843065', 'n01843383', 'n01847000', 'n01855032', 'n01855672', 'n01860187', 'n01871265', 'n01872401', 'n01873310', 'n01877812', 'n01882714', 'n01883070', 'n01910747', 'n01914609', 'n01917289', 'n01924916', 'n01930112', 'n01943899', 'n01944390', 'n01945685', 'n01950731', 'n01955084', 'n01968897', 'n01978287', 'n01978455', 'n01980166', 'n01981276', 'n01983481', 'n01984695', 'n01985128', 'n01986214', 'n01990800', 'n02002556', 'n02002724', 'n02006656', 'n02007558', 'n02009229', 'n02009912', 'n02011460', 'n02012849', 'n02013706', 'n02017213', 'n02018207', 'n02018795', 'n02025239', 'n02027492', 'n02028035', 'n02033041', 'n02037110', 'n02051845', 'n02056570', 'n02058221', 'n02066245', 'n02071294', 'n02074367', 'n02077923', 'n02085620', 
'n02085782', 'n02085936', 'n02086079', 'n02086240', 'n02086646', 'n02086910', 'n02087046', 'n02087394', 'n02088094', 'n02088238', 'n02088364', 'n02088466', 'n02088632', 'n02089078', 'n02089867', 'n02089973', 'n02090379', 'n02090622', 'n02090721', 'n02091032', 'n02091134', 'n02091244', 'n02091467', 'n02091635', 'n02091831', 'n02092002', 'n02092339', 'n02093256', 'n02093428', 'n02093647', 'n02093754', 'n02093859', 'n02093991', 'n02094114', 'n02094258', 'n02094433', 'n02095314', 'n02095570', 'n02095889', 'n02096051', 'n02096177', 'n02096294', 'n02096437', 'n02096585', 'n02097047', 'n02097130', 'n02097209', 'n02097298', 'n02097474', 'n02097658', 'n02098105', 'n02098286', 'n02098413', 'n02099267', 'n02099429', 'n02099601', 'n02099712', 'n02099849', 'n02100236', 'n02100583', 'n02100735', 'n02100877', 'n02101006', 'n02101388', 'n02101556', 'n02102040', 'n02102177', 'n02102318', 'n02102480', 'n02102973', 'n02104029', 'n02104365', 'n02105056', 'n02105162', 'n02105251', 'n02105412', 'n02105505', 'n02105641', 'n02105855', 'n02106030', 'n02106166', 'n02106382', 'n02106550', 'n02106662', 'n02107142', 'n02107312', 'n02107574', 'n02107683', 'n02107908', 'n02108000', 'n02108089', 'n02108422', 'n02108551', 'n02108915', 'n02109047', 'n02109525', 'n02109961', 'n02110063', 'n02110185', 'n02110341', 'n02110627', 'n02110806', 'n02110958', 'n02111129', 'n02111277', 'n02111500', 'n02111889', 'n02112018', 'n02112137', 'n02112350', 'n02112706', 'n02113023', 'n02113186', 'n02113624', 'n02113712', 'n02113799', 'n02113978', 'n02114367', 'n02114548', 'n02114712', 'n02114855', 'n02115641', 'n02115913', 'n02116738', 'n02117135', 'n02119022', 'n02119789', 'n02120079', 'n02120505', 'n02123045', 'n02123159', 'n02123394', 'n02123597', 'n02124075', 'n02125311', 'n02127052', 'n02128385', 'n02128757', 'n02128925', 'n02129165', 'n02129604', 'n02130308', 'n02132136', 'n02133161', 'n02134084', 'n02134418', 'n02137549', 'n02138441', 'n02165105', 'n02165456', 'n02167151', 'n02168699', 'n02169497', 
'n02172182', 'n02174001', 'n02177972', 'n02190166', 'n02206856', 'n02219486', 'n02226429', 'n02229544', 'n02231487', 'n02233338', 'n02236044', 'n02256656', 'n02259212', 'n02264363', 'n02268443', 'n02268853', 'n02276258', 'n02277742', 'n02279972', 'n02280649', 'n02281406', 'n02281787', 'n02317335', 'n02319095', 'n02321529', 'n02325366', 'n02326432', 'n02328150', 'n02342885', 'n02346627', 'n02356798', 'n02361337', 'n02363005', 'n02364673', 'n02389026', 'n02391049', 'n02395406', 'n02396427', 'n02397096', 'n02398521', 'n02403003', 'n02408429', 'n02410509', 'n02412080', 'n02415577', 'n02417914', 'n02422106', 'n02422699', 'n02423022', 'n02437312', 'n02437616', 'n02441942', 'n02442845', 'n02443114', 'n02443484', 'n02444819', 'n02445715', 'n02447366', 'n02454379', 'n02457408', 'n02480495', 'n02480855', 'n02481823', 'n02483362', 'n02483708', 'n02484975', 'n02486261', 'n02486410', 'n02487347', 'n02488291', 'n02488702', 'n02489166', 'n02490219', 'n02492035', 'n02492660', 'n02493509', 'n02493793', 'n02494079', 'n02497673', 'n02500267', 'n02504013', 'n02504458', 'n02509815', 'n02510455', 'n02514041', 'n02526121', 'n02536864', 'n02606052', 'n02607072', 'n02640242', 'n02641379', 'n02643566', 'n02655020', 'n02666196', 'n02667093', 'n02669723', 'n02672831', 'n02676566', 'n02687172', 'n02690373', 'n02692877', 'n02699494', 'n02701002', 'n02704792', 'n02708093', 'n02727426', 'n02730930', 'n02747177', 'n02749479', 'n02769748', 'n02776631', 'n02777292', 'n02782093', 'n02783161', 'n02786058', 'n02787622', 'n02788148', 'n02790996', 'n02791124', 'n02791270', 'n02793495', 'n02794156', 'n02795169', 'n02797295', 'n02799071', 'n02802426', 'n02804414', 'n02804610', 'n02807133', 'n02808304', 'n02808440', 'n02814533', 'n02814860', 'n02815834', 'n02817516', 'n02823428', 'n02823750', 'n02825657', 'n02834397', 'n02835271', 'n02837789', 'n02840245', 'n02841315', 'n02843684', 'n02859443', 'n02860847', 'n02865351', 'n02869837', 'n02870880', 'n02871525', 'n02877765', 'n02879718', 'n02883205', 
'n02892201', 'n02892767', 'n02894605', 'n02895154', 'n02906734', 'n02909870', 'n02910353', 'n02916936', 'n02917067', 'n02927161', 'n02930766', 'n02939185', 'n02948072', 'n02950826', 'n02951358', 'n02951585', 'n02963159', 'n02965783', 'n02966193', 'n02966687', 'n02971356', 'n02974003', 'n02977058', 'n02978881', 'n02979186', 'n02980441', 'n02981792', 'n02988304', 'n02992211', 'n02992529', 'n02999410', 'n03000134', 'n03000247', 'n03000684', 'n03014705', 'n03016953', 'n03017168', 'n03018349', 'n03026506', 'n03028079', 'n03032252', 'n03041632', 'n03042490', 'n03045698', 'n03047690', 'n03062245', 'n03063599', 'n03063689', 'n03065424', 'n03075370', 'n03085013', 'n03089624', 'n03095699', 'n03100240', 'n03109150', 'n03110669', 'n03124043', 'n03124170', 'n03125729', 'n03126707', 'n03127747', 'n03127925', 'n03131574', 'n03133878', 'n03134739', 'n03141823', 'n03146219', 'n03160309', 'n03179701', 'n03180011', 'n03187595', 'n03188531', 'n03196217', 'n03197337', 'n03201208', 'n03207743', 'n03207941', 'n03208938', 'n03216828', 'n03218198', 'n03220513', 'n03223299', 'n03240683', 'n03249569', 'n03250847', 'n03255030', 'n03259280', 'n03271574', 'n03272010', 'n03272562', 'n03290653', 'n03291819', 'n03297495', 'n03314780', 'n03325584', 'n03337140', 'n03344393', 'n03345487', 'n03347037', 'n03355925', 'n03372029', 'n03376595', 'n03379051', 'n03384352', 'n03388043', 'n03388183', 'n03388549', 'n03393912', 'n03394916', 'n03400231', 'n03404251', 'n03417042', 'n03424325', 'n03425413', 'n03443371', 'n03444034', 'n03445777', 'n03445924', 'n03447447', 'n03447721', 'n03450230', 'n03452741', 'n03457902', 'n03459775', 'n03461385', 'n03467068', 'n03476684', 'n03476991', 'n03478589', 'n03481172', 'n03482405', 'n03483316', 'n03485407', 'n03485794', 'n03492542', 'n03494278', 'n03495258', 'n03496892', 'n03498962', 'n03527444', 'n03529860', 'n03530642', 'n03532672', 'n03534580', 'n03535780', 'n03538406', 'n03544143', 'n03584254', 'n03584829', 'n03590841', 'n03594734', 'n03594945', 'n03595614', 
'n03598930', 'n03599486', 'n03602883', 'n03617480', 'n03623198', 'n03627232', 'n03630383', 'n03633091', 'n03637318', 'n03642806', 'n03649909', 'n03657121', 'n03658185', 'n03661043', 'n03662601', 'n03666591', 'n03670208', 'n03673027', 'n03676483', 'n03680355', 'n03690938', 'n03691459', 'n03692522', 'n03697007', 'n03706229', 'n03709823', 'n03710193', 'n03710637', 'n03710721', 'n03717622', 'n03720891', 'n03721384', 'n03724870', 'n03729826', 'n03733131', 'n03733281', 'n03733805', 'n03742115', 'n03743016', 'n03759954', 'n03761084', 'n03763968', 'n03764736', 'n03769881', 'n03770439', 'n03770679', 'n03773504', 'n03775071', 'n03775546', 'n03776460', 'n03777568', 'n03777754', 'n03781244', 'n03782006', 'n03785016', 'n03786901', 'n03787032', 'n03788195', 'n03788365', 'n03791053', 'n03792782', 'n03792972', 'n03793489', 'n03794056', 'n03796401', 'n03803284', 'n03804744', 'n03814639', 'n03814906', 'n03825788', 'n03832673', 'n03837869', 'n03838899', 'n03840681', 'n03841143', 'n03843555', 'n03854065', 'n03857828', 'n03866082', 'n03868242', 'n03868863', 'n03871628', 'n03873416', 'n03874293', 'n03874599', 'n03876231', 'n03877472', 'n03877845', 'n03884397', 'n03887697', 'n03888257', 'n03888605', 'n03891251', 'n03891332', 'n03895866', 'n03899768', 'n03902125', 'n03903868', 'n03908618', 'n03908714', 'n03916031', 'n03920288', 'n03924679', 'n03929660', 'n03929855', 'n03930313', 'n03930630', 'n03933933', 'n03935335', 'n03937543', 'n03938244', 'n03942813', 'n03944341', 'n03947888', 'n03950228', 'n03954731', 'n03956157', 'n03958227', 'n03961711', 'n03967562', 'n03970156', 'n03976467', 'n03976657', 'n03977966', 'n03980874', 'n03982430', 'n03983396', 'n03991062', 'n03992509', 'n03995372', 'n03998194', 'n04004767', 'n04005630', 'n04008634', 'n04009552', 'n04019541', 'n04023962', 'n04026417', 'n04033901', 'n04033995', 'n04037443', 'n04039381', 'n04040759', 'n04041544', 'n04044716', 'n04049303', 'n04065272', 'n04067472', 'n04069434', 'n04070727', 'n04074963', 'n04081281', 'n04086273', 
'n04090263', 'n04099969', 'n04111531', 'n04116512', 'n04118538', 'n04118776', 'n04120489', 'n04125021', 'n04127249', 'n04131690', 'n04133789', 'n04136333', 'n04141076', 'n04141327', 'n04141975', 'n04146614', 'n04147183', 'n04149813', 'n04152593', 'n04153751', 'n04154565', 'n04162706', 'n04179913', 'n04192698', 'n04200800', 'n04201297', 'n04204238', 'n04204347', 'n04208210', 'n04209133', 'n04209239', 'n04228054', 'n04229816', 'n04235860', 'n04238763', 'n04239074', 'n04243546', 'n04251144', 'n04252077', 'n04252225', 'n04254120', 'n04254680', 'n04254777', 'n04258138', 'n04259630', 'n04263257', 'n04264628', 'n04265275', 'n04266014', 'n04270147', 'n04273569', 'n04275548', 'n04277352', 'n04285008', 'n04286575', 'n04296562', 'n04310018', 'n04311004', 'n04311174', 'n04317175', 'n04325704', 'n04326547', 'n04328186', 'n04330267', 'n04332243', 'n04335435', 'n04336792', 'n04344873', 'n04346328', 'n04347754', 'n04350905', 'n04355338', 'n04355933', 'n04356056', 'n04357314', 'n04366367', 'n04367480', 'n04370456', 'n04371430', 'n04371774', 'n04372370', 'n04376876', 'n04380533', 'n04389033', 'n04392985', 'n04398044', 'n04399382', 'n04404412', 'n04409515', 'n04417672', 'n04418357', 'n04423845', 'n04428191', 'n04429376', 'n04435653', 'n04442312', 'n04443257', 'n04447861', 'n04456115', 'n04458633', 'n04461696', 'n04462240', 'n04465501', 'n04467665', 'n04476259', 'n04479046', 'n04482393', 'n04483307', 'n04485082', 'n04486054', 'n04487081', 'n04487394', 'n04493381', 'n04501370', 'n04505470', 'n04507155', 'n04509417', 'n04515003', 'n04517823', 'n04522168', 'n04523525', 'n04525038', 'n04525305', 'n04532106', 'n04532670', 'n04536866', 'n04540053', 'n04542943', 'n04548280', 'n04548362', 'n04550184', 'n04552348', 'n04553703', 'n04554684', 'n04557648', 'n04560804', 'n04562935', 'n04579145', 'n04579432', 'n04584207', 'n04589890', 'n04590129', 'n04591157', 'n04591713', 'n04592741', 'n04596742', 'n04597913', 'n04599235', 'n04604644', 'n04606251', 'n04612504', 'n04613696', 'n06359193', 
'n06596364', 'n06785654', 'n06794110', 'n06874185', 'n07248320', 'n07565083', 'n07579787', 'n07583066', 'n07584110', 'n07590611', 'n07613480', 'n07614500', 'n07615774', 'n07684084', 'n07693725', 'n07695742', 'n07697313', 'n07697537', 'n07711569', 'n07714571', 'n07714990', 'n07715103', 'n07716358', 'n07716906', 'n07717410', 'n07717556', 'n07718472', 'n07718747', 'n07720875', 'n07730033', 'n07734744', 'n07742313', 'n07745940', 'n07747607', 'n07749582', 'n07753113', 'n07753275', 'n07753592', 'n07754684', 'n07760859', 'n07768694', 'n07802026', 'n07831146', 'n07836838', 'n07860988', 'n07871810', 'n07873807', 'n07875152', 'n07880968', 'n07892512', 'n07920052', 'n07930864', 'n07932039', 'n09193705', 'n09229709', 'n09246464', 'n09256479', 'n09288635', 'n09332890', 'n09399592', 'n09421951', 'n09428293', 'n09468604', 'n09472597', 'n09835506', 'n10148035', 'n10565667', 'n11879895', 'n11939491', 'n12057211', 'n12144580', 'n12267677', 'n12620546', 'n12768682', 'n12985857', 'n12998815', 'n13037406', 'n13040303', 'n13044778', 'n13052670', 'n13054560', 'n13133613', 'n15075141']
- imagenet_r_wnids = ['n01443537', 'n01484850', 'n01494475', 'n01498041', 'n01514859', 'n01518878', 'n01531178', 'n01534433', 'n01614925', 'n01616318', 'n01630670', 'n01632777', 'n01644373', 'n01677366', 'n01694178', 'n01748264', 'n01770393', 'n01774750', 'n01784675', 'n01806143', 'n01820546', 'n01833805', 'n01843383', 'n01847000', 'n01855672', 'n01860187', 'n01882714', 'n01910747', 'n01944390', 'n01983481', 'n01986214', 'n02007558', 'n02009912', 'n02051845', 'n02056570', 'n02066245', 'n02071294', 'n02077923', 'n02085620', 'n02086240', 'n02088094', 'n02088238', 'n02088364', 'n02088466', 'n02091032', 'n02091134', 'n02092339', 'n02094433', 'n02096585', 'n02097298', 'n02098286', 'n02099601', 'n02099712', 'n02102318', 'n02106030', 'n02106166', 'n02106550', 'n02106662', 'n02108089', 'n02108915', 'n02109525', 'n02110185', 'n02110341', 'n02110958', 'n02112018', 'n02112137', 'n02113023', 'n02113624', 'n02113799', 'n02114367', 'n02117135', 'n02119022', 'n02123045', 'n02128385', 'n02128757', 'n02129165', 'n02129604', 'n02130308', 'n02134084', 'n02138441', 'n02165456', 'n02190166', 'n02206856', 'n02219486', 'n02226429', 'n02233338', 'n02236044', 'n02268443', 'n02279972', 'n02317335', 'n02325366', 'n02346627', 'n02356798', 'n02363005', 'n02364673', 'n02391049', 'n02395406', 'n02398521', 'n02410509', 'n02423022', 'n02437616', 'n02445715', 'n02447366', 'n02480495', 'n02480855', 'n02481823', 'n02483362', 'n02486410', 'n02510455', 'n02526121', 'n02607072', 'n02655020', 'n02672831', 'n02701002', 'n02749479', 'n02769748', 'n02793495', 'n02797295', 'n02802426', 'n02808440', 'n02814860', 'n02823750', 'n02841315', 'n02843684', 'n02883205', 'n02906734', 'n02909870', 'n02939185', 'n02948072', 'n02950826', 'n02951358', 'n02966193', 'n02980441', 'n02992529', 'n03124170', 'n03272010', 'n03345487', 'n03372029', 'n03424325', 'n03452741', 'n03467068', 'n03481172', 'n03494278', 'n03495258', 'n03498962', 'n03594945', 'n03602883', 'n03630383', 'n03649909', 'n03676483', 'n03710193', 'n03773504', 
'n03775071', 'n03888257', 'n03930630', 'n03947888', 'n04086273', 'n04118538', 'n04133789', 'n04141076', 'n04146614', 'n04147183', 'n04192698', 'n04254680', 'n04266014', 'n04275548', 'n04310018', 'n04325704', 'n04347754', 'n04389033', 'n04409515', 'n04465501', 'n04487394', 'n04522168', 'n04536866', 'n04552348', 'n04591713', 'n07614500', 'n07693725', 'n07695742', 'n07697313', 'n07697537', 'n07714571', 'n07714990', 'n07718472', 'n07720875', 'n07734744', 'n07742313', 'n07745940', 'n07749582', 'n07753275', 'n07753592', 'n07768694', 'n07873807', 'n07880968', 'n07920052', 'n09472597', 'n09835506', 'n10565667', 'n12267677']
- imagenet_r_mask = [wnid in imagenet_r_wnids for wnid in all_wnids]
- inr_dataset = torchvision.datasets.ImageFolder(args.inr_path, transform=test_transform)
- inr_data_loader = torch.utils.data.DataLoader(
- inr_dataset, batch_size=int(1.5 * args.batch_size),
- num_workers=args.num_workers,
- pin_memory=args.pin_mem,
- drop_last=False
- )
- test_stats = evaluate(inr_data_loader, model, device, mask=imagenet_r_mask)
- print(f"Accuracy on the ImageNet-R: {test_stats['acc1']:.1f}%")
-
- if args.insk_path:
- insk_dataset = torchvision.datasets.ImageFolder(args.insk_path, transform=test_transform)
- insk_data_loader = torch.utils.data.DataLoader(
- insk_dataset, batch_size=int(1.5 * args.batch_size),
- num_workers=args.num_workers,
- pin_memory=args.pin_mem,
- drop_last=False
- )
- test_stats = evaluate(insk_data_loader, model, device)
- print(f"Accuracy on the ImageNet-Sketch: {test_stats['acc1']:.1f}%")
-
- if args.fgsm_test:
- test_stats = evaluate(data_loader_val, model, device, adv='FGSM')
- print(f"Accuracy of the FGSM: {test_stats['acc1']:.1f}%")
-
- if args.pgd_test:
- test_stats = evaluate(data_loader_val, model, device, adv='PGD')
- print(f"Accuracy of the PGD: {test_stats['acc1']:.1f}%")
-
- return
-
- print(f"Start training for {args.epochs} epochs")
- start_time = time.time()
- max_accuracy = 0.0
- for epoch in range(args.start_epoch, args.epochs):
- if args.distributed:
- data_loader_train.sampler.set_epoch(epoch)
-
- train_stats = train_one_epoch(
- args, model, criterion, data_loader_train,
- optimizer, device, epoch, loss_scaler,
- args.clip_grad, model_ema, mixup_fn,
- set_training_mode=args.finetune == '' # keep in eval mode during finetuning
- )
-
- lr_scheduler.step(epoch)
- if args.output_dir:
- checkpoint_paths = [output_dir / 'checkpoint.pth']
- if (epoch + 1) % args.save_ckpt_freq == 0 or epoch + 1 == args.epochs:
- for checkpoint_path in checkpoint_paths:
- utils.save_on_master({
- 'model': model_without_ddp.state_dict(),
- 'optimizer': optimizer.state_dict(),
- 'lr_scheduler': lr_scheduler.state_dict(),
- 'epoch': epoch,
- 'model_ema': get_state_dict(model_ema),
- 'scaler': loss_scaler.state_dict(),
- 'args': args,
- }, checkpoint_path)
-
- test_stats = evaluate(data_loader_val, model, device)
- print(f"Accuracy of the network on the {len(dataset_val)} test images: {test_stats['acc1']:.1f}%")
- max_accuracy = max(max_accuracy, test_stats["acc1"])
- print(f'Max accuracy: {max_accuracy:.2f}%')
-
- log_stats = {**{f'train_{k}': v for k, v in train_stats.items()},
- **{f'test_{k}': v for k, v in test_stats.items()},
- 'epoch': epoch,
- 'n_parameters': n_parameters}
-
- if args.output_dir and utils.is_main_process():
- with (output_dir / "log.txt").open("a") as f:
- f.write(json.dumps(log_stats) + "\n")
-
- total_time = time.time() - start_time
- total_time_str = str(datetime.timedelta(seconds=int(total_time)))
- print('Training time {}'.format(total_time_str))
-
-
-
if __name__ == '__main__':
    # Command-line entry point: build the full parser on top of the shared
    # argument definitions from get_args_parser(), parse the command line,
    # make sure the requested output directory exists, then run main().
    #
    # The first positional argument of ArgumentParser is `prog`, i.e. the
    # program name printed in the usage/help text. It is kept consistent with
    # the name used in get_args_parser() ('RVT ...'); the previous 'DeiT ...'
    # value looks like a leftover from the script this file was adapted from.
    parser = argparse.ArgumentParser('RVT training and evaluation script', parents=[get_args_parser()])
    args = parser.parse_args()

    if args.output_dir:
        # parents=True creates missing intermediate directories;
        # exist_ok=True lets a run reuse an existing output directory.
        Path(args.output_dir).mkdir(parents=True, exist_ok=True)
    main(args)
|