|
- # Copyright 2020 Huawei Technologies Co., Ltd
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- # ============================================================================
- """train resnet."""
- import os
- import argparse
- from mindspore import context
- from mindspore.common import set_seed
- from mindspore.nn.loss import SoftmaxCrossEntropyWithLogits
- from mindspore.train.model import Model
- from mindspore.train.serialization import load_checkpoint, load_param_into_net
- from mindspore import save_checkpoint
- from src.CrossEntropySmooth import CrossEntropySmooth
-
- from global_init_new import ParamHunter
- from thgy_client import THGYApiClient
- import datetime
- import time
-
-
# Command-line interface of the evaluation watcher.
parser = argparse.ArgumentParser(description='Image classification')
parser.add_argument('--net', type=str, default=None,
                    help='Network to evaluate: resnet50 or resnet101; any other value '
                         '(including the default) selects se_resnet50')
# NOTE(review): the default below is a directory path, so with the default value
# neither the "cifar10" nor the "imagenet2012" comparison made later in this
# script matches -- confirm that this is intended.
parser.add_argument('--dataset', type=str, default="/root/jointcloud/data/eval",
                    help='Dataset, either cifar10 or imagenet2012')
parser.add_argument('--val_dir', type=str, default="/root/jointcloud/data/eval",
                    help='Validation data directory')
parser.add_argument('--checkpoint_path', type=str, default='/root/jointcloud/models/avg',
                    help='Checkpoint file path')
parser.add_argument('--dataset_path', type=str, default=None, help='Dataset path')
parser.add_argument('--device_target', type=str, default='Ascend', help='Device target')
parser.add_argument('--model_path', type=str, default='models/avg', help='model path')
parser.add_argument('--uuid', type=str, default='test', help='uuid.')
args_opt = parser.parse_args()
# NOTE: --uuid is not read by the code visible in this script. An earlier
# revision split it as "<group_id>_<task_id>", or treated a value without "_"
# as a bare task id belonging to group 0.

# Fixed seed so that repeated runs behave the same.
set_seed(1)
# Directory that is polled for newly written models.
model_path = args_opt.model_path
val_dir = args_opt.val_dir
-
# Bind `resnet`, `config` and `create_dataset` according to the --net option
# (and, for resnet50 only, the --dataset option).
if args_opt.net not in ("resnet50", "resnet101"):
    # Fallback for every other value, including the default None.
    from src.resnet import se_resnet50 as resnet
    from src.config import config4 as config
    from src.dataset import create_dataset4 as create_dataset
elif args_opt.net == "resnet101":
    from src.resnet import resnet101 as resnet
    from src.config import config3 as config
    from src.dataset import create_dataset3 as create_dataset
else:
    # resnet50: the configuration depends on the dataset name.
    from src.resnet import resnet50 as resnet
    if args_opt.dataset != "cifar10":
        from src.config import config2 as config
        from src.dataset import create_dataset2 as create_dataset
    else:
        from src.config import config1 as config
        from src.dataset import create_dataset1 as create_dataset
-
- def _parse_param(model_path):
- split_items = os.path.basename(model_path).split('-')
- if len(split_items) == 4:
- timestamp, step, _round, uuid = split_items
- elif len(split_items) > 4:
- timestamp = split_items[0]
- step = split_items[1]
- _round = split_items[2]
- uuid = split_items[3]
- else:
- uuid = 'avg'
- timestamp=''
- _round = 1
- step = 1
- #界面展示用
- #获取任务组id
- group_id = 0
- _task_id = 0
- try:
- client_uuids = [d.name for d in os.scandir(os.path.dirname(os.path.dirname(model_path))) if d.is_dir() and 'avg' not in d.name]
- _group_id, _task_id = client_uuids[0].split('_')
- group_id=int(_group_id)
- except Exception as e:
- print(e)
- return step, _round, uuid, group_id, _task_id
-
- # @app.route('/')
- # def index():
- # # Before training, the model's parameters must be initialised from JCCE.agent
- # model_path = request.args.get("model_path")
-
def _watchdog_callback(model_path):
    """Evaluate the model found at ``model_path`` and report the result.

    Relies on the module-level ``param_hunter``, ``dataset``, ``api_client``
    and ``args_opt`` objects created in the ``__main__`` section.

    Steps: parse the metadata from the path, let ``param_hunter`` fill the
    network parameters for that step/round/uuid, evaluate the network on
    ``dataset`` and send the outcome through ``api_client``.

    Args:
        model_path: Path of the newly detected entry in the watched directory.
    """
    print('======'*20)
    print(model_path)
    step, _round, uuid, group_id, task_id = _parse_param(model_path)
    print(step, _round, uuid, group_id, task_id)

    net, fill_params_num = param_hunter.fill_params(step, _round, uuid)
    print(f'fill_params_num : {fill_params_num}')

    loss = SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean')
    model = Model(net, loss_fn=loss, metrics={'top_1_accuracy', 'top_5_accuracy'})

    # Take the timestamps around the evaluation so the reported interval
    # covers the actual work instead of being zero-length.
    start_time = datetime.datetime.now()
    accu = model.eval(dataset, dataset_sink_mode=False)
    end_time = datetime.datetime.now()

    print("result:", accu, "ckpt=", args_opt.checkpoint_path)
    # The top-5 accuracy is sent in the `precision` field; `recall` is always 0.
    result = api_client.add_task_training_data(group_id, task_id, _round,
                                               recall=0, precision=accu['top_5_accuracy'],
                                               startTime=start_time.strftime("%Y-%m-%d %H:%M:%S.%f"),
                                               endTime=end_time.strftime("%Y-%m-%d %H:%M:%S.%f"))
    print(result)
-
# Synchronously watch the avg model directory.
def _watchdog(model_path):
    """Poll ``model_path`` every two seconds and evaluate each new entry.

    The directory listing is compared with the previous one; every path that
    appeared since the last poll is passed to ``_watchdog_callback`` (in
    sorted order). The loop runs until interrupted with Ctrl-C.

    Args:
        model_path: Directory to watch. It must already exist.
    """
    try:
        known = {entry.path for entry in os.scandir(model_path)}
        time_log = time.time()
        while True:
            time.sleep(2)
            current = {entry.path for entry in os.scandir(model_path)}
            added = sorted(current - known)
            if added:
                print(f'-----------Took {time.time()-time_log} seconds per epoch-----------')
                time_log = time.time()
                # Handle every new entry: all of them are marked as known
                # below, so skipping any would lose it for good.
                for new_path in added:
                    _watchdog_callback(new_path)
                print("Added: ", ", ".join(added))
            known = current
    except KeyboardInterrupt:
        print(f'stop watching folder:{model_path}')
-
-
if __name__ == '__main__':
    # Entry point: build the dataset, network and helper objects that
    # _watchdog_callback reads as module-level globals, then start polling.
    target = args_opt.device_target
    # exist_ok avoids the race between an existence check and the creation.
    os.makedirs(model_path, exist_ok=True)

    # init context
    context.set_context(mode=context.GRAPH_MODE, device_target=target, save_graphs=False)
    if target != "GPU":
        # Fall back to device 0 instead of crashing when DEVICE_ID is unset.
        device_id = int(os.getenv('DEVICE_ID', '0'))
        context.set_context(device_id=device_id)

    # create dataset
    dataset = create_dataset(dataset_path=args_opt.dataset_path, do_train=False,
                             batch_size=config.batch_size, target=target)
    step_size = dataset.get_dataset_size()

    # define net
    net = resnet(class_num=config.class_num)

    # No checkpoint is loaded here; the parameters are filled per model by
    # param_hunter inside _watchdog_callback.
    net.set_train(False)

    # define loss, model
    if args_opt.dataset == "imagenet2012":
        if not config.use_label_smooth:
            config.label_smooth_factor = 0.0
        loss = CrossEntropySmooth(sparse=True, reduction='mean',
                                  smooth_factor=config.label_smooth_factor,
                                  num_classes=config.class_num)
    else:
        loss = SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean')

    # define model
    model = Model(net, loss_fn=loss, metrics={'top_1_accuracy', 'top_5_accuracy'})
    print('get model success')

    # Parameter manager
    param_hunter = ParamHunter(net, debug=False)
    # API client
    api_client = THGYApiClient()

    # Blocks until interrupted.
    _watchdog(model_path)
|