|
- #!/usr/bin/env python
- # -*- coding: utf-8 -*-
-
- # example:
- # from models.vgg import vgg11_bn as msa_vgg11
- # from t_models.vgg import vgg11_bn as t_vgg11
- # net1 = msa_vgg11()
- # net2 = t_vgg11()
- # input_shape = (5, 3, 32, 32)
- # comp_torch_inf_msa_inf(net2, net1, input_shape, 'vgg11')
-
- import numpy as np
-
- import mindtorch.torch as m_torch
-
- import mindspore as ms
- from mindspore import Tensor, save_checkpoint
-
- import torch
-
-
def comp_torch_inf_msa_inf(torch_net, msa_net, input_shape, net_prefix='', eps=1e-5):
    """Compare the inference output of a PyTorch net and its MindTorch port.

    The PyTorch weights are saved to ``<net_prefix>_model.pth``, converted to
    ``<net_prefix>_model.ckpt`` with ``pth2ckpt`` and loaded into ``msa_net``
    so that both networks run with the same parameters. Both are then put in
    eval mode and fed the same random float32 input, and the two outputs are
    compared with ``np.allclose``.

    Example:
        from models.vgg import vgg11_bn as msa_vgg11
        from t_models.vgg import vgg11_bn as t_vgg11
        comp_torch_inf_msa_inf(t_vgg11(), msa_vgg11(), (5, 3, 32, 32), 'vgg11')

    Args:
        torch_net: PyTorch module providing the reference weights and output.
        msa_net: MindTorch module that receives the converted weights.
        input_shape: Shape of the random input passed to both networks.
        net_prefix: Prefix of the two weight files written to the current
            directory; also used as the model name in the printed message.
        eps: Relative tolerance (``rtol``) handed to ``np.allclose``; the
            absolute tolerance keeps NumPy's default.

    Returns:
        bool: True when the two outputs agree within the tolerance. The
        verdict is also printed.
    """
    pth_path = net_prefix + '_model.pth'
    ckpt_path = net_prefix + '_model.ckpt'

    # One NumPy array feeds both frameworks so the inputs are identical.
    np_input = np.random.random(input_shape).astype(np.float32)
    torch_input = torch.tensor(np_input)
    msa_input = m_torch.tensor(np_input)

    # Export the PyTorch weights and convert them to a MindSpore checkpoint.
    torch.save(torch_net.state_dict(), pth_path)
    pth2ckpt(pth_path)

    # PyTorch side: reload the saved weights and run inference only, so no
    # autograd graph needs to be recorded.
    torch_net.load_state_dict(torch.load(pth_path), strict=True)
    torch_net.eval()
    with torch.no_grad():
        torch_output = torch_net(torch_input)

    # MindTorch side: load the converted checkpoint and run inference.
    ms.load_checkpoint(ckpt_path, msa_net)
    msa_net.eval()
    msa_output = msa_net(msa_input)

    # The third positional argument of np.allclose is rtol; name it explicitly.
    matched = np.allclose(
        torch_output.detach().numpy(), msa_output.numpy(), rtol=eps
    )
    if matched:
        print("Model %s has been verified." % net_prefix)
    else:
        print("Model %s fails the verification." % net_prefix)
    return matched
-
-
def _to_ms_param(param_name, param_value):
    """Build one ``{'name', 'data'}`` record as consumed by ``save_checkpoint``."""
    if isinstance(param_value, torch.Tensor):
        # Go through NumPy on the CPU so the MindSpore Tensor owns plain data.
        param_value = param_value.detach().cpu().numpy()
    return {'name': param_name, 'data': Tensor(param_value)}


def pth2ckpt(pth_file):
    """Convert a PyTorch weights file into a MindSpore checkpoint.

    The file is read with ``torch.load`` and every entry becomes a
    ``{'name': ..., 'data': Tensor}`` record. An entry whose value is itself a
    dict is flattened one level: its inner keys are used as parameter names
    and the outer key is dropped. The records are written with
    ``save_checkpoint`` to a file with the same base name as ``pth_file`` and
    a ``.ckpt`` extension.

    Args:
        pth_file: Path of the PyTorch file to convert (typically ``*.pth``).
    """
    import os

    # NOTE: torch.load unpickles the file; only use it on trusted inputs.
    torch_dict = torch.load(pth_file)

    ms_params = []
    for name, value in torch_dict.items():
        if isinstance(value, dict):
            ms_params.extend(_to_ms_param(k, v) for k, v in value.items())
        else:
            ms_params.append(_to_ms_param(name, value))

    # Swap the real extension rather than assuming it is exactly "pth".
    ckpt_file = os.path.splitext(pth_file)[0] + ".ckpt"
    save_checkpoint(ms_params, ckpt_file)
|