|
- #coding=utf-8
-
- import os
- import numpy as np
- #import keras.models as KM
-
- import client_agent
- from mindspore import Tensor
- import mindspore.nn
-
class ParamHunter:
    """Exchange a network's parameters with the parameter server.

    Wraps a network object that exposes ``parameters_and_names()`` and moves
    its parameters to and from the server through the ``client_agent`` module.
    All server-facing methods are best-effort: failures are printed and
    reported through the return value instead of being raised.
    """

    def __init__(self, net, debug=False):
        """Store the network and cache the names of its parameters.

        Args:
            net: Network whose ``parameters_and_names()`` yields
                ``(name, parameter)`` pairs.
            debug: When true, print diagnostics; the flag is also forwarded
                to the ``client_agent`` fetch calls.
        """
        self.net = net
        self.param_keys = self._get_param_list(net)
        self.debug = debug
        if debug:
            print('using debug mode!')
            print(self.param_keys[:3])

    def _get_param_list(self, net):
        """Return the parameter names of *net* in iteration order."""
        return [str(param.name) for _, param in net.parameters_and_names()]

    def _parameterNameValue(self):
        """Return ``(name, value)`` pairs for every parameter of the network.

        Each value is whatever the parameter's ``asnumpy()`` returns.
        """
        return [
            (str(param.name), param.asnumpy())
            for _, param in self.net.parameters_and_names()
        ]

    def _setModelPara(self, items):
        """Write the given values into the network's parameters by name.

        Args:
            items: Iterable of ``(name, weights)`` pairs. Names that do not
                match any network parameter are ignored; if a name occurs
                more than once, the first occurrence is used.

        Returns:
            The number of network parameters that were updated.
        """
        # Index the incoming values once instead of scanning ``items`` for
        # every parameter. ``setdefault`` keeps the first value seen per name.
        weights_by_name = {}
        for name, weights in items:
            weights_by_name.setdefault(name, weights)

        params_num = 0
        for _, param in self.net.parameters_and_names():
            if param.name in weights_by_name:
                new_value = Tensor(weights_by_name[param.name])
                # Cast to the parameter's own dtype before assigning.
                param.set_data(new_value.astype(param.dtype))
                params_num += 1
        return params_num

    def init_params(self, initial):
        """Fetch parameters from the server and load them into the network.

        Args:
            initial: If true, fetch the shared initial parameters; otherwise
                fetch the parameters averaged across the training nodes.

        Returns:
            Tuple ``(net, params_num)`` where ``params_num`` is the number of
            parameters updated (0 if the fetch or the update failed).
        """
        params_num = 0
        try:
            if initial:
                fetch = client_agent.initialize_params_from_server
            else:
                fetch = client_agent.get_avg_params_from_server
            initial_params = fetch(self.param_keys, self.debug)
            params_num = self._setModelPara(initial_params)
        except Exception as e:
            # Best-effort: report the failure and return the network as-is.
            print('init params error!!')
            print(e)

        return self.net, params_num

    def upload_params(self, keras_model, uuid, step_per_round):
        """Upload the network's current parameters to the server.

        Args:
            keras_model: Unused; kept so existing callers keep working. The
                parameters are always read from ``self.net``.
            uuid: Identifier forwarded to the server.
            step_per_round: Value forwarded to the server.

        Returns:
            The number of parameters uploaded, or 0 if the upload failed.
        """
        params_num = 0
        try:
            var_value = self._parameterNameValue()
            print('upload_params....................')
            if self.debug:
                print(var_value[:3])
            client_agent.upload_final_models_to_server(uuid, step_per_round, var_value)
            # Only count the parameters once the upload call has returned.
            params_num = len(var_value)
        except Exception as e:
            # Report the failure instead of swallowing it silently.
            print('upload params error!!')
            print(e)
        return params_num

    def fill_params(self, step, _round, uuid):
        """Load averaged parameters via the agent's local-file interface.

        The original comment states these are served on the agent's port
        50054.

        Args:
            step: Step component of the lookup key.
            _round: Round component of the lookup key.
            uuid: Identifier component of the lookup key.

        Returns:
            Tuple ``(net, params_num)`` where ``params_num`` is the number of
            parameters updated (0 if the fetch or the update failed).
        """
        params_num = 0
        try:
            # Key the agent expects when loading parameters.
            uuids = [f'{step}-{_round}-{uuid}']
            initial_params = client_agent.get_avg_params_from_server_localfile(
                self.param_keys, uuids, self.debug
            )
            params_num = self._setModelPara(initial_params)
        except Exception as e:
            print(e)

        return self.net, params_num

    def first_init(self):
        """Upload the network's current parameters as the shared start point.

        Intended to give every training node the same starting parameters on
        its first round. Failures are printed, not raised.
        """
        try:
            var_value = self._parameterNameValue()
            client_agent.first_init(var_value)
        except Exception as e:
            print(e)
-
-
-
if __name__ == "__main__":
    # Running this module as a script performs no action.
    pass
|