|
- # Copyright 2020 Huawei Technologies Co., Ltd
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- # ============================================================================
- """train resnet."""
import argparse
import datetime
import os
import time

import moxing as mox
from mindspore import context
from mindspore.common import set_seed
from mindspore.nn.loss import SoftmaxCrossEntropyWithLogits
from mindspore.train.model import Model
from mindspore.train.serialization import load_checkpoint, load_param_into_net

from src.CrossEntropySmooth import CrossEntropySmooth
-
# Remote dataset location on OBS and the local cache directory it is
# copied into before evaluation.
DATASET_DIR = 'obs://pcl-verify/yizx/other_verify/mindspore-client/cifar10/'
LOCAL_PATH = "/cache/cifar10"

# Command-line interface: network/dataset selection, ModelArts-style
# train_url/data_url arguments, and checkpoint/dataset paths.
parser = argparse.ArgumentParser(description='Image classification')
for _flag, _help in (
        ('--net', 'Resnet Model, either resnet50 or resnet101'),
        ('--train_url', 'train_url'),
        ('--data_url', 'data_url'),
        ('--dataset', 'Dataset, either cifar10 or imagenet2012'),
        ('--checkpoint_path', 'Checkpoint file path'),
        ('--dataset_path', 'Dataset path'),
):
    parser.add_argument(_flag, type=str, default=None, help=_help)
parser.add_argument('--device_target', type=str, default='Ascend', help='Device target')
args_opt = parser.parse_args()

# Fix the global RNG seed for reproducible results.
set_seed(1)
-
# Pick the backbone constructor, its config, and the matching dataset
# factory from the requested network/dataset combination.  resnet50 and
# resnet18 share the cifar10/imagenet pipelines; resnet101 and the
# se_resnet50 fallback each use a dedicated config/pipeline.
if args_opt.net in ("resnet50", "resnet18"):
    if args_opt.net == "resnet50":
        from src.resnet import resnet50 as resnet
    else:
        from src.resnet import resnet18 as resnet
    if args_opt.dataset == "cifar10":
        from src.config import config1 as config
        from src.dataset import create_dataset1 as create_dataset
    else:
        from src.config import config2 as config
        from src.dataset import create_dataset2 as create_dataset
elif args_opt.net == "resnet101":
    from src.resnet import resnet101 as resnet
    from src.config import config3 as config
    from src.dataset import create_dataset3 as create_dataset
else:
    from src.resnet import se_resnet50 as resnet
    from src.config import config4 as config
    from src.dataset import create_dataset4 as create_dataset
-
- def _parse_param(model_path):
- split_items = os.path.basename(model_path).split('-')
- if len(split_items) == 4:
- timestamp, step, _round, uuid = split_items
- elif len(split_items) > 4:
- timestamp = split_items[0]
- step = split_items[1]
- _round = split_items[2]
- uuid = split_items[3]
- else:
- uuid = 'avg'
- timestamp=''
- _round = 1
- step = 1
- #界面展示用
- #获取任务组id
- group_id=0
- try:
- client_uuids = [d.name for d in os.scandir(os.path.dirname(os.path.dirname(model_path))) if d.is_dir() and 'avg' not in d.name]
- _group_id, _ = client_uuids[0].split('_')
- group_id=int(_group_id)
- except Exception as e:
- print(e)
- return step, _round, uuid, group_id
-
def _watchdog_callback(model_path):
    """Evaluate a newly written model checkpoint and report the result.

    Parses step/round/uuid/group metadata from ``model_path``, loads the
    corresponding parameters, measures accuracy on ``test_data`` and
    uploads the timing/accuracy record to the task-tracking service.

    NOTE(review): ``param_hunter``, ``test``, ``test_data`` and
    ``api_client`` are not defined anywhere in this file — presumably
    injected by the surrounding training harness; confirm before running
    this standalone.
    """
    print(model_path)
    step, _round, uuid, group_id = _parse_param(model_path)
    print(step, _round, uuid, group_id)
    model = param_hunter.fill_params(step, _round, uuid)
    # Bracket the evaluation so the reported interval reflects its actual
    # duration (previously both timestamps were taken after the eval,
    # always reporting a ~zero-length interval).
    start_time = datetime.datetime.now()
    accu = test(model, test_data)
    end_time = datetime.datetime.now()
    result = api_client.add_task_training_data(
        group_id, 0, _round,
        recall=0, precision=accu,
        startTime=start_time.strftime("%Y-%m-%d %H:%M:%S.%f"),
        endTime=end_time.strftime("%Y-%m-%d %H:%M:%S.%f"))
    print(result)
-
- #同步监听avg模型目录
- def _watchdog(model_path):
- try:
- path_to_watch = model_path
- before = dict ([(f.path, None) for f in os.scandir(path_to_watch)])
- while True:
- time.sleep(2)
- after = dict ([(f.path, None) for f in os.scandir(path_to_watch)])
- added = [f for f in after if not f in before]
- removed = [f for f in before if not f in after]
- if added:
- _watchdog_callback(added[0])
- print("Added: ", ", ".join(added))
- before = after
- except KeyboardInterrupt:
- print(f'stop watching folder:{model_path}')
-
if __name__ == '__main__':
    target = args_opt.device_target

    # init context
    context.set_context(mode=context.GRAPH_MODE, device_target=target, save_graphs=False)
    if target != "GPU":
        # DEVICE_ID is required for Ascend runs; default to 0 so a bare
        # single-device run does not crash with int(None).
        device_id = int(os.getenv('DEVICE_ID', '0'))
        context.set_context(device_id=device_id)

    # Pull the dataset from OBS into the local cache.
    mox.file.copy_parallel(src_url=DATASET_DIR, dst_url=LOCAL_PATH)
    # Marker file used to block other processes until environment setup
    # and the dataset download have finished.
    with open("/cache/install.txt", 'w'):
        pass
    while not os.path.exists("/cache/install.txt"):
        time.sleep(0.6)

    # create dataset
    # NOTE(review): --dataset_path defaults to None, which would make the
    # concatenation below raise TypeError — confirm callers always pass it.
    dataset = create_dataset(dataset_path=args_opt.dataset_path + '/eval', do_train=False,
                             batch_size=config.batch_size, target=target)
    step_size = dataset.get_dataset_size()

    # define net
    net = resnet(class_num=config.class_num)

    # Download the checkpoint from OBS and load it into the network.
    tmp_ckpt_path = '/cache/tmp.ckpt'
    mox.file.copy(src_url=args_opt.checkpoint_path, dst_url=tmp_ckpt_path)
    param_dict = load_checkpoint(tmp_ckpt_path)
    load_param_into_net(net, param_dict)
    net.set_train(False)

    # define loss: label-smoothed cross entropy for imagenet2012,
    # plain softmax cross entropy otherwise.
    if args_opt.dataset == "imagenet2012":
        if not config.use_label_smooth:
            config.label_smooth_factor = 0.0
        loss = CrossEntropySmooth(sparse=True, reduction='mean',
                                  smooth_factor=config.label_smooth_factor,
                                  num_classes=config.class_num)
    else:
        loss = SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean')

    # define model and evaluate the checkpoint
    model = Model(net, loss_fn=loss, metrics={'top_1_accuracy', 'top_5_accuracy'})
    res = model.eval(dataset)
    print("result:", res, "ckpt=", args_opt.checkpoint_path)
|