# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
'''cifar_resnet50
Inference sample for a ResNet-50 classifier; it can be run on the Ascend 910 AI processor.
'''
import os
import time
import json
import argparse
import cv2
import numpy as np
import mindspore.common.dtype as mstype
from mindspore import context
from mindspore import Tensor
from mindspore.nn.loss import SoftmaxCrossEntropyWithLogits
from mindspore.communication.management import init
from mindspore.train.model import Model
from mindspore.context import ParallelMode
from mindspore.train.serialization import load_checkpoint, load_param_into_net
from resnet import resnet50
import moxing as mox

parser = argparse.ArgumentParser(description='Image classification')
parser.add_argument('--data_url',
                    help='path to training/inference dataset folder',
                    default='/cache/data1/')

parser.add_argument('--multi_data_url',
                    help='path to multi dataset',
                    default='/cache/data/')

parser.add_argument('--ckpt_url',
                    help='model to save/load',
                    default='/cache/checkpoint.ckpt')

parser.add_argument('--result_url',
                    help='model folder to save/load',
                    default='/cache/result/')

parser.add_argument(
    '--device_target',
    type=str,
    default="Ascend",
    choices=['Ascend', 'CPU'],
    help='device where the code will be implemented (default: Ascend); '
         'to use an NPU on the Qizhi platform, add the run parameter '
         'device_target=Ascend on the platform training page')

args, unknown = parser.parse_known_args()

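# For reference, a hand-run invocation might look like the sketch below; on the
# Qizhi platform these arguments are normally injected by the training job, and
# the bucket paths here are placeholders, not real locations:
#
#   python <this_script>.py \
#       --multi_data_url='[{"dataset_name": "CIFAR10_1000", "dataset_url": "s3://bucket/cifar10_1000/"}]' \
#       --ckpt_url='s3://bucket/checkpoint.ckpt' \
#       --result_url='s3://bucket/result/' \
#       --device_target=Ascend
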
context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target)

if args.device_target == "Ascend":
    # DEVICE_ID is set by the platform; fall back to 0 for standalone runs
    device_id = int(os.getenv('DEVICE_ID', '0'))
    context.set_context(device_id=device_id)

def ObsUrlToEnv(obs_ckpt_url, ckpt_url):
    """Copy a single checkpoint file from OBS into the training environment."""
    try:
        mox.file.copy(obs_ckpt_url, ckpt_url)
        print("Successfully Download {} to {}".format(obs_ckpt_url, ckpt_url))
    except Exception as e:
        print('moxing download {} to {} failed: '.format(obs_ckpt_url, ckpt_url) + str(e))
    return

def MultiObsToEnv(multi_data_url, data_dir):
    """Copy every dataset listed in multi_data_url from OBS to data_dir."""
    # --multi_data_url is a JSON string, so parse it first
    multi_data_json = json.loads(multi_data_url)
    for i in range(len(multi_data_json)):
        path = data_dir + "/" + multi_data_json[i]["dataset_name"]
        if not os.path.exists(path):
            os.makedirs(path)
        try:
            mox.file.copy_parallel(multi_data_json[i]["dataset_url"], path)
            print("Successfully Download {} to {}".format(multi_data_json[i]["dataset_url"], path))
        except Exception as e:
            print('moxing download {} to {} failed: '.format(
                multi_data_json[i]["dataset_url"], path) + str(e))
    # Write a marker file recording that the data has been copied from OBS.
    # During multi-card training the other cards wait for this file instead of
    # copying the dataset again.
    f = open("/cache/download_input.txt", 'w')
    f.close()
    if os.path.exists("/cache/download_input.txt"):
        print("download_input succeed")
    else:
        print("download_input failed")
    return
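# For reference, MultiObsToEnv expects --multi_data_url to decode to a list of
# objects carrying "dataset_name" and "dataset_url" keys; the values below are
# made-up examples, not real buckets:
#
#   [{"dataset_name": "CIFAR10_1000",
#     "dataset_url": "s3://bucket/attachment/cifar10_1000/"}]
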
### Copy the output model to obs ###
def EnvToObs(train_dir, obs_train_url):
    try:
        mox.file.copy_parallel(train_dir, obs_train_url)
        print("Successfully Upload {} to {}".format(train_dir, obs_train_url))
    except Exception as e:
        print('moxing upload {} to {} failed: '.format(train_dir, obs_train_url) + str(e))
    return

def DownloadFromQizhi(multi_data_url, data_dir):
    # RANK_SIZE is set by the platform; default to 1 for standalone runs
    device_num = int(os.getenv('RANK_SIZE', '1'))
    if device_num == 1:
        MultiObsToEnv(multi_data_url, data_dir)
        context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target)
    if device_num > 1:
        # set device_id and call init() for multi-card training
        context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target,
                            device_id=int(os.getenv('ASCEND_DEVICE_ID')))
        context.reset_auto_parallel_context()
        context.set_auto_parallel_context(device_num=device_num,
                                          parallel_mode=ParallelMode.DATA_PARALLEL,
                                          gradients_mean=True, parameter_broadcast=True)
        init()
        # Copying OBS data does not need to run on every card; only the 0th
        # card of each node copies the data
        local_rank = int(os.getenv('RANK_ID', '0'))
        if local_rank % 8 == 0:
            MultiObsToEnv(multi_data_url, data_dir)
        # If the marker file does not exist, the copy has not finished yet;
        # wait for the 0th card to finish copying the data
        while not os.path.exists("/cache/download_input.txt"):
            time.sleep(1)
    return
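# Note: RANK_SIZE, RANK_ID and ASCEND_DEVICE_ID are environment variables set
# by the platform's launch scripts, and the `local_rank % 8 == 0` check assumes
# 8 NPUs per node (e.g. with RANK_SIZE=16, RANK_ID=8 is card 0 of node 1).
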
def UploadToQizhi(train_dir, obs_train_url):
    device_num = int(os.getenv('RANK_SIZE', '1'))
    local_rank = int(os.getenv('RANK_ID', '0'))
    if device_num == 1:
        EnvToObs(train_dir, obs_train_url)
    if device_num > 1:
        # only the 0th card of each node uploads results
        if local_rank % 8 == 0:
            EnvToObs(train_dir, obs_train_url)
    return

def DirAll(pathName, file_type):
    """Recursively list the input files under pathName, filtered by file_type."""
    image_list = []
    if os.path.exists(pathName):
        fileList = os.listdir(pathName)
        for f in fileList:
            f = os.path.join(pathName, f)
            if os.path.isdir(f):
                image_list.extend(DirAll(f, file_type))
            else:
                dirName = os.path.dirname(f)
                baseName = os.path.basename(f)
                if file_type == "npy":
                    if baseName.endswith("inputs.npy"):
                        image_list.append(dirName + os.sep + baseName)
                else:
                    if baseName.endswith(".txt"):
                        continue
                    if dirName.endswith(os.sep):
                        image_list.append(dirName + baseName)
                    else:
                        image_list.append(dirName + os.sep + baseName)
    return image_list

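# Illustrative usage of DirAll (the paths are hypothetical):
#
#   DirAll("/cache/data/ImageNet1000_100", "jpg")
#   # -> ["/cache/data/ImageNet1000_100/class_0/img_0001.JPEG", ...]
#
# With file_type == "npy", only files ending in "inputs.npy" are collected;
# otherwise every non-.txt file under pathName is returned.
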
def pre_deal(data_path):
    """Load one sample: either a pre-batched .npy array or an image file."""
    if data_path.endswith(".npy"):
        images = np.load(data_path)
        images = Tensor(images, mstype.float32)
    else:
        image = cv2.imread(data_path)
        norm_img = normalize(image)
        images = [norm_img]  # add a batch dimension
        images = Tensor(images, mstype.float32)
    return images

def normalize(image):
    # **************************** user modification start ******************
    # Replace the input size and statistics to match the actual dataset
    mean = [0.4914 * 255, 0.4822 * 255, 0.4465 * 255]
    std = [0.2023 * 255, 0.1994 * 255, 0.2010 * 255]
    image = cv2.resize(image, (224, 224), interpolation=cv2.INTER_LINEAR)
    # **************************** user modification end ******************
    image = image[:, :, ::-1]  # cv2.imread returns BGR; flip to RGB to match mean/std
    image = image / 1.0  # promote to float
    image = (image - mean) / std
    image = image.transpose((2, 0, 1))  # HWC -> CHW
    return image

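# For reference: given any HxWx3 uint8 BGR frame from cv2.imread, normalize()
# yields a 3x224x224 float array in CHW order, e.g.:
#
#   dummy = np.zeros((480, 640, 3), dtype=np.uint8)
#   normalize(dummy).shape  # -> (3, 224, 224)
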
if __name__ == '__main__':
    # Users decide which datasets to evaluate via the args passed in above

    # **************************** user modification start ******************
    # Define and initialize the model; the normalize() routine above may also
    # need to be adapted to the dataset
    net = resnet50(32, 10)  # assumed signature: resnet50(batch_size, num_classes)
    ls = SoftmaxCrossEntropyWithLogits(sparse=True, reduction="mean")
    model = Model(net, loss_fn=ls, metrics={'acc'})
    # **************************** user modification end ******************

    data_dir = '/cache/data'
    train_dir = '/cache/output'
    ckpt_url = '/cache/checkpoint.ckpt'

    if not os.path.exists(data_dir):
        os.makedirs(data_dir)
    if not os.path.exists(train_dir):
        os.makedirs(train_dir)

    ObsUrlToEnv(args.ckpt_url, ckpt_url)
    ### Initialize and copy data to the training image ###
    DownloadFromQizhi(args.multi_data_url, data_dir)

    dataset = data_dir
    c_dataset = data_dir
    file_type = "jpg"
    # Map the downloaded dataset folders to the clean inputs (dataset) and the
    # FGSM adversarial inputs (c_dataset)
    for file_name in os.listdir(data_dir):
        if file_name == "ImageNet1000_100_FGSM":
            c_dataset = data_dir + "/ImageNet1000_100_FGSM/fgsm_ImageNet1000_100/"
        elif file_name == "ImageNet1000_100":
            dataset = data_dir + "/ImageNet1000_100/ImageNet1000_100/ImageNet1000_100/"
        elif file_name == "CIFAR10_1000":
            dataset = data_dir + "/CIFAR10_1000/cifar10_1000/cifar10_1000/"
            file_type = "npy"
        elif file_name == "CIFAR10_1000_FGSM":
            c_dataset = data_dir + "/CIFAR10_1000_FGSM/fgsm_cifar10_1000/"
            file_type = "npy"

    print("c_dataset: " + c_dataset)
    print("dataset: " + dataset)
    print("file_type: " + file_type)

    save_path = train_dir
    save_name = "result.json"

    checkpoint_path = ckpt_url
    # for evaluation, users could also use model.eval
    result_dic = {"model": {}}
    result_dic["model"]["BDResult"] = []

    param_dict = load_checkpoint(checkpoint_path)
    load_param_into_net(net, param_dict)
    origin_pic_list = DirAll(dataset, file_type)
    origin_pic_list.sort()
    print("number of clean samples: {}".format(len(origin_pic_list)))
    for imgpath in origin_pic_list:
        print('path: ' + imgpath)
        image = pre_deal(imgpath)
        output = model.predict(image)
        output_num = output.asnumpy()
        result_dic["model"]["BDResult"].extend(output_num.tolist())

    for c_dir in os.listdir(c_dataset):
        print("adversarial subfolder: " + c_dir)
        if os.path.isdir(c_dataset + "/" + c_dir):
            child_data = DirAll(os.path.join(c_dataset, c_dir), file_type)
            child_data.sort()
            for imgpath in child_data:
                image = pre_deal(imgpath)
                output = model.predict(image)
                output_num = output.asnumpy()
                if "CDResult" not in result_dic["model"]:
                    result_dic["model"]["CDResult"] = {}
                if c_dir not in result_dic["model"]["CDResult"]:
                    result_dic["model"]["CDResult"][c_dir] = []
                result_dic["model"]["CDResult"][c_dir].extend(output_num.tolist())

    # Save the prediction results
    if not os.path.exists(save_path):
        os.makedirs(save_path)
    with open(os.path.join(save_path, save_name), "w") as f:
        json.dump(result_dic, f)

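    # For reference, result.json comes out shaped roughly as below (the logit
    # values are made up; each inner list is one model output row):
    #
    #   {"model": {"BDResult": [[0.1, ..., 0.9], ...],
    #              "CDResult": {"<c_dir>": [[0.2, ..., 0.8], ...]}}}
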
    UploadToQizhi(train_dir, args.result_url)