|
- #coding=utf-8
-
- import os
- import numpy as np
- #import keras.models as KM
-
- import client_agent
- from mindspore import Tensor
- import mindspore.nn
- import datetime
-
- # upload_f = open("/cache/upload_param.txt", "a")
- # download_f = open("/cache/download_param.txt", "a")
-
- # upload_f = open("./upload_param.txt", "a")
- # download_f = open("./download_param.txt", "a")
-
class ParamHunter:
    """Exchange a MindSpore network's parameters with the parameter server.

    Parameters are read from / written to ``net.parameters_and_names()`` and
    exchanged as ``(name, numpy_array)`` pairs; the transport itself is
    delegated to the ``client_agent`` module.
    """

    def __init__(self, net, debug=False):
        """
        :param net: network whose ``parameters_and_names()`` is read and written.
        :param debug: when True, print a few diagnostics and forward the flag
            to the ``client_agent`` calls.
        """
        self.net = net
        # Sorted parameter names; passed to client_agent when pulling params.
        self.param_keys = self._get_param_list(net)
        self.debug = debug
        if debug:
            print('using debug mode!')
            print(self.param_keys[:3])

    def _get_param_list(self, net):
        """Return the names of all parameters of *net*, sorted ascending."""
        return sorted(f"{param.name}" for _, param in net.parameters_and_names())

    def _parameterNameValue(self):
        """Snapshot every network parameter as a numpy array.

        :return: ``(pairs, count)`` where ``pairs`` is a list of
            ``(name, numpy_array)`` sorted by name and ``count`` is the number
            of parameters visited.
        """
        var_value = [
            (f'{param.name}', param.asnumpy())
            for _, param in self.net.parameters_and_names()
        ]
        var_value.sort(key=lambda item: item[0])
        return var_value, len(var_value)

    def _setModelPara(self, items):
        """Load ``(name, weights)`` pairs into the network parameters with the same name.

        Parameters whose name does not appear in *items* are left untouched.
        If a name appears more than once in *items*, its first occurrence is used.

        :param items: iterable of ``(name, weights)`` pairs; ``weights`` must be
            accepted by ``mindspore.Tensor``.
        :return: number of parameters that were overwritten.
        """
        print('download_params....................')
        # Index the incoming pairs once: linear instead of rescanning items for
        # every parameter, and correct even if items is a one-shot iterator.
        weights_by_name = {}
        for name, weights in items:
            weights_by_name.setdefault(name, weights)

        params_num = 0
        for _, param in self.net.parameters_and_names():
            if param.name not in weights_by_name:
                continue
            # Build the tensor in the parameter's own dtype so non-float32
            # parameters (e.g. float16 or integer ones) keep their type.
            param.set_data(Tensor(weights_by_name[param.name], param.dtype))
            params_num += 1
        return params_num

    def init_params(self, initial, globalStep):
        '''
        Pull parameters from the server and load them into the network.

        - ``initial`` truthy and ``globalStep == 0``: fetch the shared initial
          parameters (so every node starts from the same values).
        - otherwise: fetch the parameters averaged across nodes.

        Any error is printed and swallowed (best effort); in that case the
        returned count reflects only what was loaded before the failure (0 if
        nothing was).

        :return: ``(net, number_of_parameters_loaded)``
        '''
        params_num = 0
        start_time = datetime.datetime.now()
        try:
            if initial and globalStep == 0:
                initial_params = client_agent.initialize_params_from_server(self.param_keys, self.debug)
            else:
                initial_params = client_agent.get_avg_params_from_server(self.param_keys, self.debug)
            end_pull_time = datetime.datetime.now()
            # total_seconds(): .seconds would truncate and ignore the days part.
            print('[1] Pull time is {:.3f}s'.format((end_pull_time - start_time).total_seconds()))
            params_num = self._setModelPara(initial_params)
            end_load_time = datetime.datetime.now()
            print('[2] Then load params time is {:.3f}s'.format((end_load_time - end_pull_time).total_seconds()))
        except Exception as e:
            print('init params error!!')
            print(e)

        return self.net, params_num

    def upload_params(self, uuid, step_per_round):
        '''
        Upload the network's current parameters to the server.

        Unlike the other sync methods, errors from the upload propagate to
        the caller.

        :param uuid: forwarded to ``client_agent.upload_final_models_to_server``.
        :param step_per_round: forwarded to the same call.
        :return: number of parameters uploaded.
        '''
        var_value, params_num = self._parameterNameValue()
        print('upload_params....................')
        client_agent.upload_final_models_to_server(uuid, step_per_round, var_value)
        print('!! upload finished !!')
        return params_num

    def fill_params(self, step, _round, uuid):
        '''
        Load averaged parameters for one step/round from the local agent
        (the original comment says this is served on the agent's port 50054;
        confirm against client_agent).

        Any error is printed and swallowed (best effort).

        :return: ``(net, number_of_parameters_loaded)``; the count is 0 if the
            fetch or load failed.
        '''
        params_num = 0
        try:
            # Key the agent expects when loading the stored parameters.
            uuids = [f'{step}-{_round}-{uuid}']
            initial_params = client_agent.get_avg_params_from_server_localfile(self.param_keys, uuids, self.debug)
            params_num = self._setModelPara(initial_params)
        except Exception as e:
            print(e)

        return self.net, params_num

    def first_init(self):
        '''
        Upload this node's (randomly initialised) parameters to the parameter
        server so that every training node starts its first round from the
        same values.

        Any error is printed and swallowed (best effort).

        :return: number of parameters uploaded, or 0 if the upload failed.
        '''
        params_num = 0
        try:
            var_value, count = self._parameterNameValue()
            client_agent.first_init(var_value)
            params_num = count
        except Exception as e:
            print(e)
        return params_num
-
-
-
if __name__ == "__main__":
    # This module only defines ParamHunter for import; running the file
    # directly performs no work.
    ...
|