|
- """
- Benchmark inference speed of Deformable DETR.
- """
- import os
- import time
- import argparse
- import torch
- from main import get_args_parser as get_main_args_parser
- from models import build_model
- from datasets import build_dataset
- from util.misc import nested_tensor_from_tensor_list
-
-
def get_benckmark_arg_parser():
    """Build the CLI parser for benchmark-only options.

    NOTE(review): "benckmark" is a typo, but the name is kept as-is for
    compatibility with existing callers.

    Returns:
        argparse.ArgumentParser with --num_iters, --warm_iters,
        --batch_size and --resume options.
    """
    parser = argparse.ArgumentParser('Benchmark inference speed of Deformable DETR.')
    # Integer options registered table-style: (flag, default, help text).
    int_options = (
        ('--num_iters', 300, 'total iters to benchmark speed'),
        ('--warm_iters', 5, 'ignore first several iters that are very slow'),
        ('--batch_size', 1, 'batch size in inference'),
    )
    for flag, default, message in int_options:
        parser.add_argument(flag, type=int, default=default, help=message)
    parser.add_argument('--resume', type=str, help='load the pre-trained checkpoint')
    return parser
-
-
@torch.no_grad()
def measure_average_inference_time(model, inputs, num_iters=100, warm_iters=5):
    """Return the mean wall-clock time (seconds) of one ``model(inputs)`` call.

    The first ``warm_iters`` iterations are executed but excluded from the
    average, since early iterations (lazy init, cudnn autotune, caches) are
    typically much slower.

    Args:
        model: callable invoked as ``model(inputs)``; its result is discarded.
        inputs: whatever ``model`` accepts (e.g. a NestedTensor batch).
        num_iters: total forward passes to run; must exceed ``warm_iters``.
        warm_iters: leading iterations to ignore; must be >= 0.

    Raises:
        ValueError: if the iteration counts would leave nothing to average
            (previously this surfaced as a bare ZeroDivisionError).
    """
    if warm_iters < 0 or num_iters <= warm_iters:
        raise ValueError(
            f'need 0 <= warm_iters < num_iters, got warm_iters={warm_iters}, '
            f'num_iters={num_iters}')
    # CUDA kernels launch asynchronously, so we must synchronize around the
    # timed region for an accurate measurement; on CPU-only hosts the sync
    # is a no-op instead of a crash.
    if torch.cuda.is_available():
        sync = torch.cuda.synchronize
    else:
        sync = lambda: None
    ts = []
    for iter_ in range(num_iters):
        sync()
        t_ = time.perf_counter()
        model(inputs)
        sync()
        t = time.perf_counter() - t_
        if iter_ >= warm_iters:
            ts.append(t)
    # Print once after timing finishes; the original printed the growing
    # list on every iteration, spamming output and perturbing the loop.
    print(ts)
    return sum(ts) / len(ts)
-
-
def benchmark():
    """Benchmark Deformable DETR inference and return throughput in FPS.

    CLI flags are split between this script's parser (--num_iters,
    --warm_iters, --batch_size, --resume) and the main training parser,
    which consumes all remaining argv (model/dataset configuration).

    Returns:
        float: processed images per second (batch_size / mean batch time).

    Raises:
        ValueError: on non-positive iteration/batch settings.
        FileNotFoundError: if --resume points to a missing checkpoint.
    """
    args, remaining = get_benckmark_arg_parser().parse_known_args()
    main_args = get_main_args_parser().parse_args(remaining)
    # Validate with real exceptions: `assert` is stripped under `python -O`.
    if args.num_iters <= 0:
        raise ValueError('--num_iters must be positive')
    if not 0 <= args.warm_iters < args.num_iters:
        raise ValueError('--warm_iters must be in [0, num_iters)')
    if args.batch_size <= 0:
        raise ValueError('--batch_size must be positive')
    if args.resume is not None and not os.path.exists(args.resume):
        raise FileNotFoundError(args.resume)
    dataset = build_dataset('val', main_args)
    model, _, _ = build_model(main_args)
    model.cuda()
    model.eval()
    if args.resume is not None:
        # Load to CPU first, then rely on the already-.cuda()'d parameters.
        ckpt = torch.load(args.resume, map_location='cpu')
        model.load_state_dict(ckpt['model'])
    # Build one batch by replicating the first validation image.
    sample = dataset[0][0]  # idiomatic indexing instead of __getitem__
    inputs = nested_tensor_from_tensor_list(
        [sample.cuda() for _ in range(args.batch_size)])
    t = measure_average_inference_time(model, inputs, args.num_iters, args.warm_iters)
    return args.batch_size / t
-
-
if __name__ == '__main__':
    # Script entry point: run the benchmark and report throughput.
    frames_per_second = benchmark()
    print(f'Inference Speed: {frames_per_second:.1f} FPS')
|