|
- import onnx
- import onnxruntime
- import numpy as np
- import torch
- import argparse
- from timm.models import create_model, load_checkpoint
- from timm.models.helpers import load_state_dict
- import os
def main():
    """Print the outputs of the reference DeiT model and the local ``Model``.

    A single random float32 batch of shape (1, 3, 224, 224) is generated and
    fed to both models so their printed outputs can be compared by eye:

    1. ``deit_small_patch16_224`` built through timm's ``create_model`` with
       100 classes, with weights loaded from the hard-coded checkpoint path
       below.
    2. ``Model`` imported from the project-local ``model`` module.

    Both models are placed on the same device as the input (CUDA when
    available, otherwise CPU) and are run under ``torch.no_grad()``; as a
    consequence the printed tensors carry no ``grad_fn``.

    An earlier variant of this script loaded a plain state dict from a
    ``ckpt.pth`` file next to the script (``strict=True``) instead of calling
    ``load_checkpoint``; that path is no longer exercised here.
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # One shared random input so both models see exactly the same data.
    im = np.random.randn(1, 3, 224, 224).astype(np.float32)
    batch = torch.from_numpy(im).to(device)

    ori_model = create_model(
        'deit_small_patch16_224',
        pretrained=False,
        num_classes=100,
        drop_rate=0.0,
        drop_path_rate=0.1,
        drop_block_rate=None,
    )
    load_checkpoint(ori_model, '/model/checkpoint-299-ares_cvpr_deit_small_baseline.pth.tar')
    # Keep the model and its input on the same device.
    ori_model = ori_model.to(device)
    ori_model.eval()

    # Inference only: no autograd graph is needed for a printed comparison.
    with torch.no_grad():
        output_ori = ori_model(batch)

    print('ori_torch_model:', output_ori)

    # Imported here (not at file level) so the reference model's output is
    # printed even if the project-local ``model`` module fails to import.
    from model import Model
    torch_model = Model().eval()
    # The input lives on ``device``; move the model there too when it is a
    # regular nn.Module. NOTE(review): ``Model`` is defined outside this file
    # -- confirm it is an nn.Module and does not pin its own device.
    if isinstance(torch_model, torch.nn.Module):
        torch_model = torch_model.to(device)

    with torch.no_grad():
        torch_outs = torch_model(batch)
    print('torch_output:', torch_outs)
-
# Script entry point: run the comparison only when executed directly,
# not when this file is imported as a module.
if __name__ == "__main__":
    main()
|