|
- import argparse
- import os
- import sys
- import json
- import time
-
- import moxing as mox
- from mindspore import Model, dataset, nn, context, load_checkpoint, load_param_into_net
- from mindspore.common import set_seed
- from mindspore.communication import init, get_rank, get_group_size
- # from mindspore.communication.management import init
- from mindspore.context import ParallelMode
- from mindspore.nn import TrainOneStepCell, learning_rate_schedule
- from mindspore.train.callback import LossMonitor, ModelCheckpoint, CheckpointConfig
-
- from model.delg_model import delg_model, DelgWithGlobalLoss, DelgWithLocalLoss
-
- import ckpt.convert_h5_to_weight as h5
- import mindspore.dataset.transforms.c_transforms as C
- import mindspore.dataset.vision.c_transforms as vision
- import mindspore.common.dtype as mstype
-
### Copy multiple datasets from obs to training image ###
def MultiObsToEnv(multi_data_url, data_dir):
    """Download every dataset listed in *multi_data_url* into *data_dir*.

    Args:
        multi_data_url: JSON string — a list of objects with keys
            ``dataset_name`` and ``dataset_url`` (the platform passes it in
            this shape; each download failure is logged, not raised).
        data_dir: local root directory; each dataset lands in
            ``data_dir/<dataset_name>``.
    """
    multi_data_json = json.loads(multi_data_url)
    for entry in multi_data_json:
        path = data_dir + "/" + entry["dataset_name"]
        # exist_ok avoids the check-then-create race of the exists() guard.
        os.makedirs(path, exist_ok=True)
        try:
            mox.file.copy_parallel(entry["dataset_url"], path)
            print("Successfully Download {} to {}".format(entry["dataset_url"], path))
        except Exception as e:
            print('moxing download {} to {} failed: '.format(
                entry["dataset_url"], path) + str(e))
    # Marker file telling the other cards that the copy has finished.
    # During multi-card training only card 0 copies; the rest poll for this file
    # (see DownloadFromQizhi), so the dataset is not copied multiple times.
    with open("/cache/download_input.txt", 'w'):
        pass
    if os.path.exists("/cache/download_input.txt"):
        print("download_input succeed")
    return
### Copy the output model to obs ###
def EnvToObs(train_dir, obs_train_url):
    """Upload the local training output directory to OBS (best effort).

    Any failure is reported on stdout rather than raised, so a broken
    upload never aborts the training job.
    """
    try:
        mox.file.copy_parallel(train_dir, obs_train_url)
    except Exception as err:
        print('moxing upload {} to {} failed: '.format(train_dir,
                                                       obs_train_url) + str(err))
    else:
        print("Successfully Upload {} to {}".format(train_dir,
                                                    obs_train_url))
    return
def DownloadFromQizhi(multi_data_url, data_dir):
    """Fetch the datasets and configure the MindSpore execution context.

    Single-device runs copy the data directly; multi-device runs let only
    every 8th local rank (card 0 of each node) copy, while the other ranks
    wait for the marker file, then data-parallel mode is configured.

    NOTE(review): reads the module-level ``args`` for ``device_target`` —
    must be called after ``parser.parse_args()`` in ``__main__``.
    """
    # Default to single-device when the scheduler did not export RANK_SIZE,
    # instead of crashing with int(None).
    device_num = int(os.getenv('RANK_SIZE', '1'))
    if device_num == 1:
        MultiObsToEnv(multi_data_url, data_dir)
        context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target)
    if device_num > 1:
        # set device_id and init for multi-card training
        init()
        context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target,
                            device_id=int(os.getenv('ASCEND_DEVICE_ID')))
        # Copying obs data does not need to be executed multiple times,
        # just let the 0th card of each 8-card node copy the data.
        local_rank = int(os.getenv('RANK_ID', '0'))
        if local_rank % 8 == 0:
            MultiObsToEnv(multi_data_url, data_dir)
        # If the marker file does not exist the copy has not completed yet:
        # wait for the 0th card to finish copying data.
        while not os.path.exists("/cache/download_input.txt"):
            time.sleep(1)
        # Set up parallel training mode.
        context.reset_auto_parallel_context()
        context.set_auto_parallel_context(device_num=device_num,
                                          parallel_mode=ParallelMode.DATA_PARALLEL,
                                          gradients_mean=True,
                                          parameter_broadcast=True)
    return
def UploadToQizhi(train_dir, obs_train_url):
    """Upload training output to OBS, once per node.

    Single-device runs always upload; multi-device runs upload only from
    every 8th local rank (card 0 of each node) to avoid duplicate copies.
    """
    # Safe defaults so an unset env var means "single device, rank 0"
    # instead of int(None) raising TypeError.
    device_num = int(os.getenv('RANK_SIZE', '1'))
    local_rank = int(os.getenv('RANK_ID', '0'))
    if device_num == 1:
        EnvToObs(train_dir, obs_train_url)
    if device_num > 1 and local_rank % 8 == 0:
        EnvToObs(train_dir, obs_train_url)
    return
-
# (debug) inspect the import search path:
# print("*" * 10)
# print(sys.path)
# print("*" * 10)

# Number of landmark classes in the training set.
NUM_CLASSES = 81313
# Paths relative to the script directory (sys.path[0] is the script's folder).
INFO_PATH = sys.path[0] + "/data/"
CKPT_PATH = sys.path[0] + "/ckpt/"

######################## Argument parser configuration ########################
parser = argparse.ArgumentParser(description="DELG Mindspore Version")
# data_url
parser.add_argument('--data_url',
                    help='Path to training dataset folder',
                    default='/home/work/user-job-dir/data/')
# train_url
parser.add_argument('--train_url',
                    help='Model folder to save/load',
                    default='/home/work/user-job-dir/model/')
# multi_data_url — the platform passes this as a JSON list of datasets
# (see MultiObsToEnv); the default is only a placeholder path.
parser.add_argument('--multi_data_url',
                    help='path to multi dataset',
                    default='/cache/data/')
# device_target
parser.add_argument('--device_target',
                    type=str,
                    default="Ascend",
                    choices=['Ascend', 'GPU', 'CPU'],
                    help='device where the code will be implemented (default: Ascend)')
# epoch_size
parser.add_argument('--epoch_size',
                    type=int,
                    default=15,
                    help='Training epochs.')
# batch_size
parser.add_argument('--batch_size',
                    type=int,
                    default=64,
                    help='batch size of datasets.')
# train_mode — selects which DELG head is trained (global descriptor
# vs. local features); see the model construction in __main__.
parser.add_argument('--train_mode',
                    type=str,
                    default='global',
                    choices=['global', 'local'],
                    help='train global part or local part.')
# ckpt_url — checkpoint file name, resolved under <data_dir>/checkpoint/.
parser.add_argument('--ckpt_url',
                    type=str,
                    default='checkpoint.ckpt',
                    help='ckpt_url.')


# Fix the global random seed for reproducibility.
set_seed(114514)
-
if __name__ == '__main__':
    args = parser.parse_args()

    data_dir = '/cache/data'
    train_dir = '/cache/output'
    DownloadFromQizhi(args.multi_data_url, data_dir)

    # Point dataset_path at data_url (kept for downstream code that reads it).
    args.dataset_path = args.data_url

    # Distributed info; init() was already called inside DownloadFromQizhi
    # for multi-card runs.
    rank_id = get_rank()
    rank_size = get_group_size()

    # Training dataset: sharded MindRecord files of the GLDv2-clean data.
    data_path = os.path.join(data_dir + "/GLDv2_Clean_MindRecord/MindRecord", "train.mindrecord000")
    ds_train = dataset.MindDataset(data_path, num_shards=rank_size, shard_id=rank_id)
    ds_train = ds_train.project(columns=["data", "label"])

    # Data augmentation pipeline: decode -> float32 -> normalize -> random
    # resized crop (321x321) -> HWC-to-CHW. The original list applied
    # TypeCast(float32) a second time at the end; that cast was a no-op on
    # already-float32 data, so it is applied only once here.
    decode_op = vision.Decode()
    normalize_op = vision.Normalize(mean=[128.0, 128.0, 128.0], std=[128.0, 128.0, 128.0])
    crop_resize_op = vision.RandomResizedCrop(321, max_attempts=100)
    chw_op = vision.HWC2CHW()
    type_cast_op = C.TypeCast(mstype.float32)
    label_type_cast_op = C.TypeCast(mstype.int32)
    transforms_list = [decode_op, type_cast_op, normalize_op, crop_resize_op, chw_op]

    ds_train = ds_train.map(operations=transforms_list, input_columns="data")
    ds_train = ds_train.map(operations=[label_type_cast_op], input_columns="label")
    dataset.config.set_seed(0)
    ds_train = ds_train.shuffle(buffer_size=100)
    ds_train = ds_train.batch(args.batch_size)

    # Build the model and restore the pretrained checkpoint. The backbone and
    # checkpoint handling are identical for both modes (this used to be
    # duplicated in each branch); only the loss wrapper differs.
    net = delg_model(mode=args.train_mode)
    net.update_parameters_name('net.')
    ckpt_url = os.path.join(data_dir + "/checkpoint", args.ckpt_url)
    param_dict = load_checkpoint(ckpt_url)
    load_param_into_net(net, param_dict)
    if args.train_mode == 'global':
        net_with_loss = DelgWithGlobalLoss(net)
    elif args.train_mode == 'local':
        net_with_loss = DelgWithLocalLoss(net)
    else:
        # Unreachable via argparse choices, but fail loudly rather than hit
        # an UnboundLocalError below if the guard ever changes.
        raise ValueError("unsupported train_mode: {}".format(args.train_mode))

    # Linear (power=1.0) decay from 1e-3 down to 1e-4 over 500k steps.
    init_lr = 0.001
    lr_schedule = nn.PolynomialDecayLR(learning_rate=init_lr, end_learning_rate=0.0001,
                                       decay_steps=500000, power=1.0)
    optimizer = nn.SGD(net.trainable_params(), learning_rate=lr_schedule, momentum=0.9)

    net_with_grad = nn.TrainOneStepCell(net_with_loss, optimizer)
    net_with_grad.set_train()

    # Checkpointing: save every 2877 steps (presumably one shard epoch —
    # TODO confirm against dataset size); each rank writes to its own
    # subdirectory under train_url.
    ck_config = CheckpointConfig(save_checkpoint_steps=2877,
                                 keep_checkpoint_max=(args.epoch_size * 8))
    model_ck_cb = ModelCheckpoint(prefix="Delg_resnet50_100+",
                                  directory=args.train_url + str(get_rank()) + "/",
                                  config=ck_config)
    loss_cb = LossMonitor()

    # Train, then push the results back to OBS.
    model = Model(net_with_grad)
    model.train(args.epoch_size, ds_train, callbacks=[loss_cb, model_ck_cb],
                dataset_sink_mode=True)
    UploadToQizhi(train_dir, args.train_url)
|