|
- from mindspore import Tensor
- from mindspore.train.serialization import save_checkpoint
- import mindspore as ms
- import mindspore
- import torch
- import numpy as np
- import collections
-
-
-
-
def transfer_name(name):
    """Map a PyTorch ResNet parameter name to its MindSpore equivalent.

    Parameters of the ``bn1``/``bn2``/``bn3`` and ``downsample.1`` submodules
    are renamed from PyTorch's BatchNorm attribute names to MindSpore's:
    ``weight`` -> ``gamma``, ``bias`` -> ``beta``,
    ``running_mean`` -> ``moving_mean``, ``running_var`` -> ``moving_variance``.
    Every other name is returned unchanged.

    Only the trailing attribute is rewritten, so a module path that happens
    to contain e.g. ``weight`` or ``bias`` elsewhere is left intact.

    Args:
        name: Dotted PyTorch state-dict key, e.g. ``"layer1.0.bn1.weight"``.

    Returns:
        The corresponding MindSpore parameter name.
    """
    bn_modules = ('bn1', 'bn2', 'bn3', 'downsample.1')
    renames = (
        ('weight', 'gamma'),
        ('bias', 'beta'),
        ('running_mean', 'moving_mean'),
        ('running_var', 'moving_variance'),
    )
    for torch_attr, ms_attr in renames:
        suffixes = tuple('{}.{}'.format(module, torch_attr) for module in bn_modules)
        if name.endswith(suffixes):
            # Replace only the final attribute, not every occurrence in the path.
            return name[:-len(torch_attr)] + ms_attr
    return name
-
-
def _state_dict_to_ms_params(state_dict, prefix='', skip_suffix=None):
    """Convert a PyTorch state dict into the ``[{'name': ..., 'data': ...}]``
    list accepted by ``save_checkpoint``.

    Args:
        state_dict: Mapping of PyTorch parameter names to tensors.
        prefix: String prepended to every (renamed) parameter name.
        skip_suffix: If given, keys ending with this suffix are dropped.

    Returns:
        A list of dicts with ``'name'`` (renamed via ``transfer_name``) and
        ``'data'`` (a MindSpore ``Tensor`` built from the tensor's numpy data).
    """
    params = []
    for key, tensor in state_dict.items():
        if skip_suffix is not None and key.endswith(skip_suffix):
            continue
        params.append({
            'name': prefix + transfer_name(key),
            'data': ms.Tensor(tensor.numpy()),
        })
    return params


def pytorch2mindspore(
        pt_path='mvtec_wide_resnet50_2_freia-cflow_pl3_cb8_inp512_run0_bottle_2022-07-28-10_58_16.pt',
        out_dir='./ckpt/pt2ckpt',
        shape_log_path='decoder weight size.txt'):
    """Convert a PyTorch checkpoint into MindSpore ``.ckpt`` files.

    The loaded object is iterated as a dict:

    * the ``'encoder_state_dict'`` entry is converted (dropping
      ``num_batches_tracked`` buffers) and saved to
      ``<out_dir>/wide_resnet50_2_encoder.ckpt``;
    * any other list/tuple entry (presumably ``'decoder_state_dict'`` —
      confirm against the checkpoint) is treated as one state dict per
      decoder; decoder ``i`` (0-based) is saved to
      ``<out_dir>/wide_resnet50_2_decoder_<i>.ckpt`` with every name
      prefixed by ``'net.'``, and the shape of each of its tensors is
      appended, one per line, to ``shape_log_path``;
    * any other entry is skipped with a message.

    Args:
        pt_path: Path of the PyTorch checkpoint to convert.
        out_dir: Output directory for the MindSpore checkpoints; created if
            it does not exist.
        shape_log_path: Text file that receives the decoder tensor shapes.
    """
    import os

    # NOTE(review): torch.load unpickles arbitrary objects; only call this on
    # trusted checkpoint files.
    par_dict = torch.load(pt_path, map_location='cpu')
    os.makedirs(out_dir, exist_ok=True)

    # Open the shape log once so shapes from every decoder list are kept
    # (reopening it in 'w' mode per key would truncate earlier output).
    with open(shape_log_path, 'w') as shape_log:
        for key, value in par_dict.items():
            if key == 'encoder_state_dict':
                encoder_params = _state_dict_to_ms_params(
                    value, skip_suffix='num_batches_tracked')
                print('Save encoder')
                save_checkpoint(
                    encoder_params,
                    os.path.join(out_dir, 'wide_resnet50_2_encoder.ckpt'))
                print('Save successfully')
            elif isinstance(value, (list, tuple)):
                for i, decoder_state in enumerate(value):
                    for tensor in decoder_state.values():
                        shape_log.write(str(tensor.shape) + '\n')
                    decoder_params = _state_dict_to_ms_params(
                        decoder_state, prefix='net.')
                    print('Save decoder {}'.format(i))
                    save_checkpoint(
                        decoder_params,
                        os.path.join(out_dir, 'wide_resnet50_2_decoder_' + str(i) + '.ckpt'))
                    print('Save successfully')
            else:
                # Non-state-dict metadata entries cannot be converted.
                print('Skip {}: not a list of decoder state dicts'.format(key))
-
- #pytorch2mindspore()
-
def printckpt(ckpt_file_name1, ckpt_file_name2):
    """Print two MindSpore checkpoints side by side for manual comparison.

    Parameters are paired by the iteration order of the two loaded
    dictionaries, not by name. For each pair, both names are printed on one
    line, followed by the first value, the second value and a separator.

    Args:
        ckpt_file_name1: Path of the first ``.ckpt`` file.
        ckpt_file_name2: Path of the second ``.ckpt`` file.
    """
    param_dict1 = mindspore.load_checkpoint(ckpt_file_name1)
    param_dict2 = mindspore.load_checkpoint(ckpt_file_name2)
    # zip() stops at the shorter checkpoint; report the mismatch instead of
    # silently dropping the unpaired parameters.
    if len(param_dict1) != len(param_dict2):
        print('Warning: {} has {} parameters but {} has {}; only the first {} pairs are printed'.format(
            ckpt_file_name1, len(param_dict1), ckpt_file_name2, len(param_dict2),
            min(len(param_dict1), len(param_dict2))))
    for name1, name2 in zip(param_dict1, param_dict2):
        print(name1 + " " + name2)
        print(param_dict1[name1])
        print(param_dict2[name2])
        print("------------------------------------------")
-
if __name__ == "__main__":
    # Run the comparison only when executed as a script, not on import.
    # The second path is where pytorch2mindspore() writes the converted
    # encoder; the first appears to be an encoder checkpoint saved during
    # MindSpore training (epoch 7) — confirm.
    printckpt(
        "./weights/bottle/2022-07-30-20:49:30/epoch_7_mvtec_wide_resnet50_2_freia-cflow_pl3_cb8_inp512_run0_bottle_2022-07-30-20:49:30_encoder.ckpt",
        "./ckpt/pt2ckpt/wide_resnet50_2_encoder.ckpt",
    )
-
-
-
-
-
-
-
-
-
-
-
-
-
|