|
- import argparse
- import os
- import sys
- import time
-
- import moxing as mox
- from mindspore import Model, dataset, nn, context, load_checkpoint, load_param_into_net, save_checkpoint
- from mindspore.common import set_seed
- from mindspore.communication import init, get_rank, get_group_size
- from mindspore.context import ParallelMode
- from mindspore.nn import TrainOneStepCell, learning_rate_schedule
- from mindspore.train.callback import LossMonitor, ModelCheckpoint, CheckpointConfig
-
- from data.dataset import DataLoader
- from model.delg import Delg, DelgWithLoss
- from model.resnet import resnet50
- from model.delg_model import delg_model
-
- import ckpt.convert_h5_to_weight as h5
-
# Debug aid: show the interpreter's module search path on the training image.
print("*" * 10)
print(sys.path)
print("*" * 10)

NUM_CLASSES = 81313  # total class count — presumably GLDv2 landmark classes; TODO confirm
# Data-info and checkpoint directories, resolved relative to the script location.
INFO_PATH = sys.path[0] + "/data/"
CKPT_PATH = sys.path[0] + "/ckpt/"

######################## command-line argument setup ########################
parser = argparse.ArgumentParser(description="DELG Mindspore Version")
parser.add_argument('--data_url',
                    default='/home/work/user-job-dir/data/',
                    help='Path to training dataset folder')
parser.add_argument('--train_url',
                    default='/home/work/user-job-dir/model/',
                    help='Model folder to save/load')
parser.add_argument('--train_list_file',
                    default='train_list.txt',
                    help='Get data info')
parser.add_argument('--device_target',
                    type=str,
                    choices=['Ascend', 'GPU', 'CPU'],
                    default="Ascend",
                    help='device where the code will be implemented (default: Ascend)')
parser.add_argument('--epoch_size',
                    type=int,
                    default=21,
                    help='Training epochs.')
parser.add_argument('--batch_size',
                    type=int,
                    default=32,
                    help='batch size of datasets.')

# Fix the global RNG seed so runs are reproducible.
set_seed(0)
-
if __name__ == '__main__':
    args = parser.parse_args()

    ######################## copy the dataset from OBS into the training image ########################
    obs_data_url = args.data_url
    args.data_url = '/home/work/user-job-dir/data/'
    # BUGFIX: os.mkdir fails when the parent directory is missing and when the
    # path already exists; makedirs(..., exist_ok=True) handles both cases.
    os.makedirs(args.data_url, exist_ok=True)

    obs_train_url = args.train_url
    args.train_url = '/home/work/user-job-dir/model/'
    os.makedirs(args.train_url, exist_ok=True)

    try:
        mox.file.copy_parallel(obs_data_url, args.data_url)
        print("Successfully Download {} to {}".format(obs_data_url, args.data_url))
    except Exception as e:
        # Best-effort download: log and continue; training fails later if data is absent.
        print('moxing download {} to {} failed: '.format(obs_data_url, args.data_url) + str(e))

    # Point dataset_path at the local data copy (save path is args.train_url).
    args.dataset_path = args.data_url

    ######################## training setup ########################
    # BUGFIX: os.getenv returns None when DEVICE_ID is unset, so int(None)
    # raised TypeError; default to device 0 instead.
    device_id = int(os.getenv("DEVICE_ID", "0"))
    context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target)
    context.set_context(device_id=device_id)
    context.reset_auto_parallel_context()
    context.set_auto_parallel_context(parallel_mode=ParallelMode.DATA_PARALLEL, gradients_mean=True)
    init()

    # Distributed (Ascend) process info.
    rank_id = get_rank()
    rank_size = get_group_size()

    # Training dataset: generator wrapped by MindSpore, then batched.
    INFO_PATH = INFO_PATH + args.train_list_file
    dataset_generator = DataLoader(args, INFO_PATH)
    ds_train = dataset.GeneratorDataset(dataset_generator, ["data", "label"],
                                        num_parallel_workers=8, shuffle=True)
    ds_train = ds_train.batch(args.batch_size)
    # BUGFIX: the iterator was created with num_epochs=1 but re-used for
    # args.epoch_size epochs, so it was exhausted after the first epoch.
    data_loader = ds_train.create_dict_iterator(num_epochs=args.epoch_size)

    # Model: DELG network initialised from ImageNet-pretrained ResNet-50 weights
    # converted from the Keras .h5 checkpoint.
    net = delg_model()
    imagenet_checkpoint = "./DELG/ckpt/resnet50_weights_tf_dim_ordering_tf_kernels_notop.h5"
    param_dict = h5.translate_h5(imagenet_checkpoint)
    load_param_into_net(net, param_dict)
    net_with_loss = DelgWithLoss(net)

    # Linear (power=1.0) LR decay from 0.01 to 1e-4 over 500k steps; SGD + momentum.
    init_lr = 0.01
    lr_schedule = nn.PolynomialDecayLR(learning_rate=init_lr, end_learning_rate=0.0001,
                                       decay_steps=500000, power=1.0)
    optimizer = nn.SGD(net.trainable_params(), learning_rate=lr_schedule, momentum=0.9)
    net_with_grad = nn.TrainOneStepCell(net_with_loss, optimizer)
    net_with_grad.set_train()

    for epoch in range(args.epoch_size):
        epoch_start = time.time()
        for i, data in enumerate(data_loader):
            start = time.time()
            loss = net_with_grad(data["data"], data["label"])
            end = time.time()
            # NOTE(review): the *100 scaling is kept from the original; it is
            # neither seconds nor milliseconds (*1000) — confirm the intended unit.
            print("Epoch", epoch, " batch", i, " need time:", (end - start)*100, "loss:", loss)
        print("Epoch", epoch, " need time:", (time.time() - epoch_start))
        # BUGFIX: the original concatenated str + int ("Epoch" + epoch), which
        # raised TypeError on the first checkpoint save; format the epoch instead.
        saved_url = args.train_url + "Epoch{}-delg.ckpt".format(epoch)
        save_checkpoint(net, saved_url)
        try:
            mox.file.copy_parallel(saved_url, obs_train_url)
            print("Successfully Upload {} to {}".format(saved_url, obs_train_url))
        except Exception as e:
            # Best-effort upload: keep training even if the OBS copy fails.
            print('moxing upload {} to {} failed: '.format(saved_url, obs_train_url) + str(e))
-
-
-
-
- # ######################## copy the output model back to OBS (platform boilerplate) ########################
- # # Copy the trained model from the local runtime back to OBS; the matching
- # # training task on the platform then offers it for download.
- # try:
- #     mox.file.copy_parallel(args.train_url, obs_train_url)
- #     print("Successfully Upload {} to {}".format(args.train_url, obs_train_url))
- # except Exception as e:
- #     print('moxing upload {} to {} failed: '.format(args.train_url, obs_train_url) + str(e))
- # ######################## copy the output model back to OBS ########################
|