|
- import sys
- sys.path.append('/code/')
- import onnx
- import onnxruntime
- import numpy as np
- import torch
- import argparse
- from timm.models import create_model, load_checkpoint
- from timm.models.helpers import load_state_dict
- import os
- from prototype.prototype.model.vit.swin_transformer import swin_tiny
def main():
    """Compare outputs of three Swin-Tiny variants on one shared random input.

    Runs the same random 1x3x224x224 image through (1) the timm reference
    implementation, (2) the locally exported ``model.Model``, and (3) the
    RobustART ``swin_tiny``, printing each model's logits so they can be
    compared by eye.  Checkpoints are loaded from hard-coded paths
    (``/userhome/magic_liu/checkpoint/ckpt_swinTiny.pth``) — TODO: make
    these CLI arguments if the script is reused outside that environment.
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # One shared input tensor, moved to `device` once, so all three models
    # see identical data and run on the same hardware.
    im = torch.from_numpy(
        np.random.randn(1, 3, 224, 224).astype(np.float32)
    ).to(device)

    # (1) timm reference implementation.
    ori_model = create_model(
        'swin_tiny_patch4_window7_224',
        pretrained=False,
        num_classes=100,
        drop_rate=0.1,
        drop_path_rate=0.0,
    )
    load_checkpoint(ori_model, '/userhome/magic_liu/checkpoint/ckpt_swinTiny.pth')
    # BUG FIX: models were previously left on CPU while inputs were (partly)
    # moved to `device`; move every model to `device` for a fair comparison.
    ori_model.to(device).eval()
    with torch.no_grad():  # inference only — skip autograd bookkeeping
        output_ori = ori_model(im)
    print('ori_torch_model:', output_ori)

    # (2) Locally exported model.  Imported lazily since model.py may be a
    # generated artifact that is not importable at module-load time.
    from model import Model
    # BUG FIX: the original ran `Model()` on CPU but passed a `.to(device)`
    # input, which crashes with a device-mismatch error on CUDA machines.
    torch_model = Model().to(device).eval()
    with torch.no_grad():
        torch_outs = torch_model(im)
    print('torch_output:', torch_outs)

    # (3) RobustART implementation, restored from the same checkpoint.
    rob_model = swin_tiny(
        drop_rate=0.1, attn_drop_rate=0.0, drop_path_rate=0.0, num_classes=100
    )
    params = load_state_dict('/userhome/magic_liu/checkpoint/ckpt_swinTiny.pth')
    rob_model.load_state_dict(params, strict=True)
    rob_model.to(device).eval()
    with torch.no_grad():
        rob_output = rob_model(im)
    print('Rob_output:', rob_output)
-
# Script entry point: run the model-output comparison only when executed
# directly, not when this module is imported.
if __name__ == '__main__':
    main()
|