|
- #coding=utf-8
-
- import os
- import numpy as np
- #import keras.models as KM
-
- import client_agent
-
-
class ParamHunter:
    """Move a Keras model's weights to and from the parameter server.

    All server communication is delegated to the ``client_agent`` module.
    Individual weight arrays are identified by keys of the form
    ``"<layer name>_<NN>"``, where ``NN`` is the zero-padded position of the
    array inside ``layer.get_weights()``.
    """

    def __init__(self, keras_model, debug=False):
        """Remember the model and precompute the keys of its weight arrays.

        Args:
            keras_model: Object exposing ``layers``; every layer must provide
                ``name``, ``get_weights()`` and ``set_weights()``.
            debug: When true, print diagnostic output; the flag is also
                forwarded to the ``client_agent`` fetch calls.
        """
        self.keras_model = keras_model
        self.param_keys = self._get_param_list(keras_model)
        self.debug = debug
        if debug:
            print(self.param_keys)

    @staticmethod
    def _iter_named_weights(keras_model):
        """Yield ``(key, array)`` for each weight array of each layer, in order."""
        for layer in keras_model.layers:
            for index, weight in enumerate(layer.get_weights()):
                yield f"{layer.name}_{index:02d}", weight

    def _get_param_list(self, keras_model):
        """Return the list of weight keys (``"<layer>_<NN>"``) of ``keras_model``."""
        return [key for key, _ in self._iter_named_weights(keras_model)]

    def _match_layers(self, params):
        """Yield ``(layer, weights)`` for each model layer named in ``params``.

        ``params`` must provide ``.items()`` yielding ``(name, weights)`` pairs.
        Layers whose name is absent from ``params`` are skipped.
        """
        # NOTE(review): entries are matched on the bare ``layer.name`` although
        # the keys sent to the server are "<layer>_<NN>".  This only works if
        # client_agent regroups its response per layer - confirm.
        weights_by_name = dict(params.items())
        for layer in self.keras_model.layers:
            if layer.name in weights_by_name:
                yield layer, weights_by_name[layer.name]

    def init_params(self, initial):
        """Fetch parameters from the server and load them into the model.

        Args:
            initial: If truthy, call
                ``client_agent.initialize_params_from_server`` (described by
                the original author as the shared initial parameters);
                otherwise call ``client_agent.get_avg_params_from_server``
                (the parameters averaged over the nodes).

        Returns:
            ``(keras_model, params_num)`` where ``params_num`` is the number
            of layers whose weights were set.  Errors are printed, not raised;
            the count then reflects the layers updated before the failure.
        """
        params_num = 0
        try:
            if initial:
                fetch = client_agent.initialize_params_from_server
            else:
                fetch = client_agent.get_avg_params_from_server
            fetched = fetch(self.param_keys, self.debug)

            for layer, weights in self._match_layers(fetched):
                layer.set_weights(weights)
                params_num += 1
        except Exception as e:
            # Best effort: report the problem and hand back what was applied.
            print(e)

        return self.keras_model, params_num

    def upload_params(self, keras_model, uuid, step_per_round):
        """Upload every weight array of ``keras_model`` to the server.

        Args:
            keras_model: The Keras model whose weights are uploaded.
            uuid: Forwarded unchanged to
                ``client_agent.upload_final_models_to_server``.
            step_per_round: Forwarded unchanged to the same call.

        Returns:
            The number of weight arrays collected for upload.  A failure is
            printed rather than raised, so the count is returned even when
            the upload itself did not succeed.
        """
        var_value = []
        try:
            var_value.extend(self._iter_named_weights(keras_model))
            client_agent.upload_final_models_to_server(uuid, step_per_round, var_value)
        except Exception as e:
            # Previously swallowed silently; surface it like the other methods.
            print(e)
        return len(var_value)

    def fill_params(self, step, _round, uuid):
        """Load the parameters stored for one step/round/uuid into the model.

        The parameters are requested through
        ``client_agent.get_avg_params_from_server_localfile`` using the single
        identifier ``"<step>-<_round>-<uuid>"``.  (Original author's note: this
        data is served from the agent's port 50054.)

        Returns:
            The model held by this instance.  Errors are printed, not raised.
        """
        try:
            uuids = [f'{step}-{_round}-{uuid}']
            fetched = client_agent.get_avg_params_from_server_localfile(
                self.param_keys, uuids, self.debug
            )

            for layer, weights in self._match_layers(fetched):
                layer.set_weights(weights)
        except Exception as e:
            print(e)
        return self.keras_model

    def first_init(self):
        """Upload the model's current weights via ``client_agent.first_init``.

        Original author's intent: push the randomly initialised parameters to
        the parameter server so that every training node starts its first
        round from the same values.  Errors are printed, not raised.
        """
        try:
            var_value = dict(self._iter_named_weights(self.keras_model))
            if self.debug:
                # Dumping every weight array is only useful when debugging.
                print(var_value)
            client_agent.first_init(var_value)
        except Exception as e:
            print(e)
-
-
-
- if __name__ == "__main__":  # Script entry guard; currently only a placeholder.
- pass  # Nothing is executed when this file is run directly.
|