|
- """
- python pth2ckpt.py
- """
- import torch
- from mindspore.train.serialization import save_checkpoint
- from mindspore import Tensor
- import numpy
-
# EDVR parameter-name translation table.
#
# Keys are the MindSpore parameter names (flattened underscore form, e.g.
# "pcd_align.offset_conv1_l3.weight"); values are the corresponding PyTorch
# parameter names from the official EDVR checkpoint (dotted sub-module form,
# e.g. "pcd_align.offset_conv1.l3.weight").
paramorg = {
    'pcd_align.offset_conv1_l3.weight': 'pcd_align.offset_conv1.l3.weight',
    'pcd_align.offset_conv1_l3.bias': 'pcd_align.offset_conv1.l3.bias',
    'pcd_align.offset_conv2_l3.weight': 'pcd_align.offset_conv2.l3.weight',
    'pcd_align.offset_conv2_l3.bias': 'pcd_align.offset_conv2.l3.bias',
    'pcd_align.dcn_pack_l3.conv.weight': 'pcd_align.dcn_pack.l3.conv.weight',
    'pcd_align.dcn_pack_l3.p_conv.weight': 'pcd_align.dcn_pack.l3.p_conv.weight',
    'pcd_align.dcn_pack_l2.conv.weight': 'pcd_align.dcn_pack.l2.conv.weight',
    # BUGFIX: the value previously repeated the key
    # ('pcd_align.dcn_pack_l2.p_conv.weight'), so this parameter was never
    # renamed during conversion; it now follows the same pattern as l1/l3.
    'pcd_align.dcn_pack_l2.p_conv.weight': 'pcd_align.dcn_pack.l2.p_conv.weight',
    'pcd_align.offset_conv1_l2.weight': 'pcd_align.offset_conv1.l2.weight',
    'pcd_align.offset_conv1_l2.bias': 'pcd_align.offset_conv1.l2.bias',
    'pcd_align.offset_conv2_l2.weight': 'pcd_align.offset_conv2.l2.weight',
    'pcd_align.offset_conv2_l2.bias': 'pcd_align.offset_conv2.l2.bias',
    'pcd_align.offset_conv3_l2.weight': 'pcd_align.offset_conv3.l2.weight',
    'pcd_align.offset_conv3_l2.bias': 'pcd_align.offset_conv3.l2.bias',
    'pcd_align.feat_conv_l2.weight': 'pcd_align.feat_conv.l2.weight',
    'pcd_align.feat_conv_l2.bias': 'pcd_align.feat_conv.l2.bias',
    'pcd_align.offset_conv1_l1.weight': 'pcd_align.offset_conv1.l1.weight',
    'pcd_align.offset_conv1_l1.bias': 'pcd_align.offset_conv1.l1.bias',
    'pcd_align.offset_conv2_l1.weight': 'pcd_align.offset_conv2.l1.weight',
    'pcd_align.offset_conv2_l1.bias': 'pcd_align.offset_conv2.l1.bias',
    'pcd_align.offset_conv3_l1.weight': 'pcd_align.offset_conv3.l1.weight',
    'pcd_align.offset_conv3_l1.bias': 'pcd_align.offset_conv3.l1.bias',
    'pcd_align.dcn_pack_l1.conv.weight': 'pcd_align.dcn_pack.l1.conv.weight',
    'pcd_align.dcn_pack_l1.p_conv.weight': 'pcd_align.dcn_pack.l1.p_conv.weight',
    'pcd_align.feat_conv_l1.weight': 'pcd_align.feat_conv.l1.weight',
    'pcd_align.feat_conv_l1.bias': 'pcd_align.feat_conv.l1.bias',
    # NOTE: 'pcd_align.cas_dcnpack.conv.weight' intentionally excluded
    # (was commented out in the original table).
}

# Inverted table used during conversion: PyTorch name -> MindSpore name.
param = {torch_name: ms_name for ms_name, torch_name in paramorg.items()}
print(param)
-
-
-
def pytorch2mindspore(pth_path='EDVR_L_x4_SR_Vimeo90K_official-162b54e4.pth',
                      ckpt_path='EDVR_L_x4_SR_Vimeo90K_official-162b54e4.ckpt'):
    """Convert a PyTorch EDVR checkpoint (.pth) to a MindSpore checkpoint.

    Loads the state dict stored under the ``'params'`` key of the .pth file,
    renames every parameter whose name ends with a known PyTorch suffix to
    the matching MindSpore suffix (module-level ``param`` table), and writes
    the result with MindSpore's ``save_checkpoint``.

    Args:
        pth_path (str): path of the source PyTorch checkpoint.
        ckpt_path (str): path of the MindSpore checkpoint to write.

    Returns:
        None. Side effect: writes ``ckpt_path`` to disk.
    """
    # map_location='cpu' lets the conversion run on machines without CUDA.
    par_dict = torch.load(pth_path, map_location='cpu')
    new_params_list = []
    for name, parameter in par_dict['params'].items():
        print(name)
        # Rewrite the tail of the name when it matches a known PyTorch suffix.
        for torch_suffix, ms_suffix in param.items():
            if name.endswith(torch_suffix):
                name = name[:name.rfind(torch_suffix)] + ms_suffix
                break  # a name can match at most one suffix
        print('========================ibn_name', name)
        new_params_list.append({'name': name, 'data': Tensor(parameter.numpy())})

    save_checkpoint(new_params_list, ckpt_path)
-
-
if __name__ == '__main__':
    # Run the conversion only when executed as a script, not on import.
    pytorch2mindspore()
|