#291 [MSAdapter User Experience] Improving the pth2ckpt model parameter conversion tool

Closed
created 1 year ago by mirror_yun · 2 comments
A PyTorch model may have been wrapped in torch.nn.DataParallel, and a pth file may hold a dict that contains more than just the parameters. Here is the version I use:

```
import argparse
import os
from collections import OrderedDict

import torch
from mindspore import Tensor, save_checkpoint


def parse_args():
    parser = argparse.ArgumentParser(description='convert pth2ckpt')
    parser.add_argument('model', help='pytorch pth file path')
    return parser.parse_args()


def to_ms_param(name, value):
    # Build the {'name', 'data'} entry that save_checkpoint takes in its list form.
    if isinstance(value, torch.Tensor):
        value = value.detach().cpu().numpy()
    return {'name': name, 'data': Tensor(value)}


def pth2ckpt(path):
    torch_dict = torch.load(path, map_location='cpu')
    # Full training checkpoints keep the weights under 'state_dict'.
    if 'state_dict' in torch_dict:
        torch_dict = torch_dict['state_dict']

    # DataParallel prefixes every key with 'module.'; strip it only where present.
    prefix = 'module.'
    torch_dict = OrderedDict(
        (k[len(prefix):] if k.startswith(prefix) else k, v)
        for k, v in torch_dict.items()
    )

    ms_params = []
    for name, value in torch_dict.items():
        if isinstance(value, dict):
            # Flatten one level of nesting.
            ms_params.extend(to_ms_param(k, v) for k, v in value.items())
        else:
            ms_params.append(to_ms_param(name, value))

    # Swap the extension rather than slicing, so 'model.pth' becomes 'model.ckpt'.
    out_path = os.path.splitext(path)[0] + '.ckpt'
    save_checkpoint(ms_params, out_path)
    print(f"convert ckpt finish, output file: {out_path}")


def main():
    args = parse_args()
    pth2ckpt(args.model)


if __name__ == '__main__':
    main()
```

Usage:

```
python pth2ckpt.py path_to_pth_file
```
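The two string-handling steps in the script (dropping the `module.` prefix that DataParallel adds, and deriving the output file name) can be checked without PyTorch or MindSpore installed. Below is a minimal sketch; the helper names `strip_module_prefix` and `ckpt_path_for` are hypothetical and not part of MSAdapter. It assumes the prefix is exactly `module.` and should be removed only from keys that actually start with it.

```python
import os
from collections import OrderedDict

PREFIX = "module."  # prefix added to parameter names by torch.nn.DataParallel


def strip_module_prefix(state_dict):
    """Return a copy of state_dict with a leading 'module.' removed from keys that have it."""
    cleaned = OrderedDict()
    for key, value in state_dict.items():
        # Only touch keys that start with the prefix; leave all others unchanged.
        new_key = key[len(PREFIX):] if key.startswith(PREFIX) else key
        cleaned[new_key] = value
    return cleaned


def ckpt_path_for(pth_path):
    """Replace the file extension (.pth, .pt, ...) with .ckpt."""
    root, _ext = os.path.splitext(pth_path)
    return root + ".ckpt"


if __name__ == "__main__":
    wrapped = OrderedDict([("module.conv1.weight", 1), ("module.fc.bias", 2)])
    print(list(strip_module_prefix(wrapped)))
    print(ckpt_path_for("resnet50.pth"))
```

Checking with `startswith` per key avoids truncating names such as `submodule_a.weight`, which merely contain the substring `module`, and `os.path.splitext` avoids depending on the extension being a fixed number of characters.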
SuMuzi commented 1 year ago
good project
ZezhengZ commented 1 year ago
nice job
zoulq closed this issue 3 days ago