|
-
- from mindspore.train.callback import ModelCheckpoint
- from mindspore import context
- from mindspore._checkparam import Validator
- from mindspore.common.parameter import Parameter
- from mindspore.common.tensor import Tensor
- import mindspore.nn as nn
- import os
- from threading import Thread, Lock
-
- from param_server import ParamHunter
-
# Serializes checkpoint dumps: save_checkpoint() holds this lock while it
# flattens parameter data into `data_list`, so concurrent saves cannot
# interleave their work.
_ckpt_mutex = Lock()
-
-
- def _get_merged_param_data(net, param_name, param_data, integrated_save):
- """
- Gets the merged data(tensor) from tensor slice, by device arrangement and tensor map.
-
- Args:
- net (Cell): MindSpore network.
- param_name (str): The parameter name, which to be combined.
- param_data (Tensor): The parameter data on the local device, which was a slice of the whole parameter data.
- integrated_save (bool): Whether to integrated save in automatic model parallel scene.
- Returns:
- Tensor, the combined tensor which with the whole data value.
- """
- from mindspore.parallel._cell_wrapper import get_allgather_cell
- from mindspore.parallel._tensor import _reshape_param_data, _reshape_param_data_with_weight
- layout = net.parameter_layout_dict[param_name]
- if len(layout) < 6:
- return param_data
-
- dev_mat = layout[0]
- tensor_map = layout[1]
- field_size = layout[3]
- uniform_split = layout[4]
- opt_shard_group = layout[5]
-
- allgather_net = None
- if param_name in net.parallel_parameter_merge_net_dict:
- allgather_net = net.parallel_parameter_merge_net_dict[param_name]
-
- if integrated_save:
- if uniform_split == 0:
- raise RuntimeError("Integrated save checkpoint only support uniform split tensor now.")
- # while any dim is not equal to -1, means param is split and needs to be merged
- # pipeline parallel need to be supported here later
- for dim in tensor_map:
- if dim != -1:
- if allgather_net is None:
- if opt_shard_group:
- allgather_net = get_allgather_cell(opt_shard_group, True)
- else:
- allgather_net = get_allgather_cell(opt_shard_group, False)
- net.parallel_parameter_merge_net_dict[param_name] = allgather_net
- param_data = allgather_net(param_data)
- if field_size:
- return _reshape_param_data_with_weight(param_data, dev_mat, field_size)
- return _reshape_param_data(param_data, dev_mat, tensor_map)
- if opt_shard_group:
- if allgather_net is None:
- allgather_net = get_allgather_cell(opt_shard_group, False)
- net.parallel_parameter_merge_net_dict[param_name] = allgather_net
- param_data = allgather_net(param_data)
- elif opt_shard_group:
- if allgather_net is None:
- allgather_net = get_allgather_cell(opt_shard_group, False)
- net.parallel_parameter_merge_net_dict[param_name] = allgather_net
- param_data = allgather_net(param_data)
- return param_data
-
-
def save_checkpoint(save_obj, ckpt_file_name, integrated_save=True, async_save=False):
    """
    Saves checkpoint info to a specified file.

    Args:
        save_obj (Union[Cell, list]): The cell object or data list (each element is a dictionary, like
                                      [{"name": param_name, "data": param_data},...]; the type of
                                      param_name would be string, and the type of param_data would
                                      be parameter or `Tensor`).
        ckpt_file_name (str): Checkpoint file name. If the file name already exists, it will be overwritten.
        integrated_save (bool): Whether to integrated save in automatic model parallel scene. Default: True
        async_save (bool): Whether asynchronous execution saves the checkpoint to a file. Default: False

    Raises:
        TypeError: If the parameter save_obj is not `nn.Cell` or list type. And if the parameter
            `integrated_save` and `async_save` are not bool type.
    """
    if not isinstance(save_obj, (nn.Cell, list)):
        raise TypeError("The parameter save_obj should be nn.Cell or list, but got {}".format(type(save_obj)))
    integrated_save = Validator.check_bool(integrated_save)
    async_save = Validator.check_bool(async_save)

    if isinstance(save_obj, nn.Cell):
        # Materialize parameter data, then flatten the cell into the list form so
        # the serialization loop below only deals with one representation.
        save_obj.init_parameters_data()
        param_dict = {param.name: param for _, param in save_obj.parameters_and_names()}
        param_list = []
        for key, value in param_dict.items():
            param_data = Tensor(value.data)
            # In the automatic model-parallel scenario some parameters are split
            # across devices; they must be combined before saving.
            if key in save_obj.parameter_layout_dict:
                param_data = _get_merged_param_data(save_obj, key, param_data, integrated_save)
            param_list.append({"name": key, "data": param_data})
        save_obj = param_list

    data_list = {}
    # Hold the module lock so concurrent saves cannot interleave.
    with _ckpt_mutex:
        for param in save_obj:
            key = param["name"]
            if isinstance(param["data"], Parameter):
                param["data"].init_data()
            # Record shape (a scalar tensor gets the sentinel dim 0), dtype name,
            # and the flattened numpy data for each parameter.
            shape = param["data"].shape
            dims = list(shape) if shape != () else [0]
            tensor_type = str(param["data"].dtype)
            flat_data = param["data"].asnumpy().reshape(-1)
            data_list[key] = [dims, tensor_type, flat_data]

    # NOTE(review): `data_list` is built and then dropped -- nothing is ever
    # written to `ckpt_file_name`, and `async_save` is unused. The upstream
    # implementation hands `data_list` to an exporter at this point; confirm
    # whether that call was removed intentionally. (A leftover debug `print()`
    # that stood here has been removed.)
-
-
-
class MyModelCheckpoint(ModelCheckpoint):
    """ModelCheckpoint variant that routes saving through this module's save_checkpoint()."""

    def step_end(self, run_context):
        """
        Save the checkpoint at the end of step.

        Args:
            run_context (RunContext): Context of the train running.
        """
        from mindspore.parallel._ps_context import _is_role_pserver, _get_ps_mode_rank
        from mindspore.train._utils import _make_directory
        from mindspore.train.serialization import _save_graph
        import threading
        import os

        # Tag parameter-server checkpoints with the PS rank. Guard against
        # re-prepending: step_end runs every step, and without the startswith
        # check the prefix grew unboundedly ("PServer_0_PServer_0_...").
        if _is_role_pserver() and not self._prefix.startswith("PServer_"):
            self._prefix = "PServer_" + str(_get_ps_mode_rank()) + "_" + self._prefix
        cb_params = run_context.original_args()
        _make_directory(self._directory)
        # Save the graph only once, replacing any stale file in graph mode.
        if not self._graph_saved:
            graph_file_name = os.path.join(self._directory, self._prefix + '-graph.meta')
            if os.path.isfile(graph_file_name) and context.get_context("mode") == context.GRAPH_MODE:
                os.remove(graph_file_name)
            _save_graph(cb_params.train_network, graph_file_name)
            self._graph_saved = True
        # Wait for any in-flight asynchronous save before triggering a new one.
        # `thread.name` replaces the deprecated Thread.getName().
        for thread in threading.enumerate():
            if thread.name == "asyn_save_ckpt":
                thread.join()
        self._save_ckpt(cb_params)

    def _save_ckpt(self, cb_params, force_to_save=False):
        """Save checkpoint files for the current step, honoring the retention config."""
        import time

        # Already saved for this step -- nothing to do.
        if cb_params.cur_step_num == self._last_triggered_step:
            return

        # If params are cache-enabled, flush data from cache to host before saving.
        if self._need_flush_from_cache:
            self._flush_from_cache(cb_params)

        step_num_in_epoch = int((cb_params.cur_step_num - 1) % cb_params.batch_num + 1)
        if not self._check_save_ckpt(cb_params, force_to_save):
            return

        cur_ckpoint_file = self._prefix + "-" + str(cb_params.cur_epoch_num) + "_" \
            + str(step_num_in_epoch) + ".ckpt"
        # Refresh the checkpoint file list, then enforce the retention policy:
        # either a max file count or a minimum spacing in minutes.
        self._manager.update_ckpoint_filelist(self._directory, self._prefix)
        if self._config.keep_checkpoint_max and 0 < self._config.keep_checkpoint_max <= self._manager.ckpoint_num:
            self._manager.remove_oldest_ckpoint_file()
        elif self._config.keep_checkpoint_per_n_minutes and self._config.keep_checkpoint_per_n_minutes > 0:
            self._cur_time_for_keep = time.time()
            if (self._cur_time_for_keep - self._last_time_for_keep) \
                    < self._config.keep_checkpoint_per_n_minutes * 60:
                self._manager.keep_one_ckpoint_per_minutes(self._config.keep_checkpoint_per_n_minutes,
                                                           self._cur_time_for_keep)

        # Generate the new checkpoint file and record where it went.
        # NOTE(review): the module-level global `_save_dir` is only written here,
        # never read in this file -- presumably consumed elsewhere; confirm.
        global _save_dir
        _save_dir = self._directory
        cur_file = os.path.join(self._directory, cur_ckpoint_file)
        self._last_time_for_keep = time.time()
        self._last_triggered_step = cb_params.cur_step_num

        network = self._config.saved_network if self._config.saved_network is not None else cb_params.train_network
        save_checkpoint(network, cur_file, self._config.integrated_save,
                        self._config.async_save)

        self._latest_ckpt_file_name = cur_file
|