|
- # Copyright 2022 Huawei Technologies Co., Ltd
-
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
-
- # http://www.apache.org/licenses/LICENSE-2.0
-
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
-
- import os
- import moxing as mox
- import numpy as np
- import mindspore as ms
- import time
-
- from mindspore import Tensor, load_checkpoint, load_param_into_net, export, context, nn, Model
- from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, LossMonitor, TimeMonitor
-
- from data_provider.mnist_to_mindrecord import create_mnist_dataset
- from nets.predrnn_pp import PreRNN, NetWithLossCell
- from config import config
- from mindspore.communication.management import init, get_rank
- from mindspore.context import ParallelMode
-
def ObsToEnv(obs_data_url, data_dir):
    """Copy the dataset from OBS into the local training environment.

    The copy is best-effort: any exception raised by the transfer is printed
    and not re-raised. Afterwards an empty marker file,
    ``/cache/download_input.txt``, is created whether or not the copy
    succeeded.

    Args:
        obs_data_url: Source location handed to ``mox.file.copy_parallel``.
        data_dir: Destination directory handed to ``mox.file.copy_parallel``.

    Returns:
        None.
    """
    try:
        mox.file.copy_parallel(obs_data_url, data_dir)
        print("Successfully Download {} to {}".format(obs_data_url, data_dir))
    except Exception as e:
        print('moxing download {} to {} failed: '.format(obs_data_url, data_dir) + str(e))

    # Marker file that DownloadFromQizhi polls for during multi-card training,
    # so the dataset is copied once instead of once per card. It is written
    # even after a failed copy; otherwise the polling ranks would never stop
    # waiting.
    marker_path = "/cache/download_input.txt"
    with open(marker_path, 'w'):
        pass  # creating/truncating the file is all that is needed
    if os.path.exists(marker_path):
        print("download_input succeed")
-
- ### Copy the output to obs###
def EnvToObs(train_dir, obs_train_url):
    """Upload the local training output to OBS.

    Best-effort: an exception raised during the transfer is printed together
    with the source and destination, and is not propagated to the caller.

    Args:
        train_dir: Local directory handed to ``mox.file.copy_parallel`` as the source.
        obs_train_url: Destination handed to ``mox.file.copy_parallel``.

    Returns:
        None.
    """
    success_template = "Successfully Upload {} to {}"
    failure_template = 'moxing upload {} to {} failed: '
    try:
        mox.file.copy_parallel(train_dir, obs_train_url)
        print(success_template.format(train_dir, obs_train_url))
    except Exception as err:
        print(failure_template.format(train_dir, obs_train_url) + str(err))
-
def DownloadFromQizhi(obs_data_url, data_dir):
    """Configure the MindSpore context and fetch the dataset from OBS.

    The device count is read from the ``RANK_SIZE`` environment variable.

    * One device: the data is copied, then the context is set to graph mode
      on Ascend.
    * More than one device: the context is bound to ``ASCEND_DEVICE_ID``,
      data-parallel mode is configured and communication is initialised.
      Only ranks whose ``RANK_ID`` is a multiple of 8 copy the data; every
      rank then waits until ``/cache/download_input.txt`` exists.

    Args:
        obs_data_url: OBS location of the dataset.
        data_dir: Local directory the dataset is copied into.

    Returns:
        None.
    """
    device_num = int(os.getenv('RANK_SIZE'))

    if device_num == 1:
        ObsToEnv(obs_data_url, data_dir)
        context.set_context(mode=context.GRAPH_MODE, device_target='Ascend')
        return

    if device_num < 1:
        # Neither the single-card nor the multi-card path applies.
        return

    # Multi-card training: pin this process to its device and bring up the
    # data-parallel communication group before touching the data.
    ascend_device_id = int(os.getenv('ASCEND_DEVICE_ID'))
    context.set_context(mode=context.GRAPH_MODE, device_target='Ascend', device_id=ascend_device_id)
    context.reset_auto_parallel_context()
    context.set_auto_parallel_context(
        device_num=device_num,
        parallel_mode=ParallelMode.DATA_PARALLEL,
        gradients_mean=True,
        parameter_broadcast=True,
    )
    init()

    # The copy does not need to run on every card; only ranks that are a
    # multiple of 8 perform it.
    rank_id = int(os.getenv('RANK_ID'))
    if rank_id % 8 == 0:
        ObsToEnv(obs_data_url, data_dir)

    # A missing marker file means the copy has not finished yet, so poll
    # once per second until it appears.
    marker_path = "/cache/download_input.txt"
    while not os.path.exists(marker_path):
        time.sleep(1)
-
def UploadToQizhi(train_dir, obs_train_url):
    """Upload the training output to OBS from a single designated process.

    ``RANK_SIZE`` and ``RANK_ID`` are both read from the environment. The
    upload runs when there is exactly one device, or when there are several
    devices and this process's rank is a multiple of 8.

    Args:
        train_dir: Local directory containing the training output.
        obs_train_url: OBS destination for the output.

    Returns:
        None.
    """
    device_num = int(os.getenv('RANK_SIZE'))
    rank_id = int(os.getenv('RANK_ID'))

    single_card = device_num == 1
    designated_multi_card = device_num > 1 and rank_id % 8 == 0

    if single_card or designated_multi_card:
        EnvToObs(train_dir, obs_train_url)
-
if __name__ == "__main__":
    # Entry point: fetch the dataset from OBS, build the PredRNN++ network,
    # train it on the MNIST mindrecord data and upload the output to OBS.
    data_dir = "/cache/data"
    train_dir = "/cache/output"
    # exist_ok=True avoids a check-then-create race when several ranks start
    # this script at the same time.
    os.makedirs(data_dir, exist_ok=True)
    os.makedirs(train_dir, exist_ok=True)

    # Initialize the context and copy the data into the training image.
    DownloadFromQizhi(config.data_url, data_dir)

    config.ckpt_save_dir = os.path.join(train_dir, "train.ckpt")

    # Training setup.
    # NOTE(review): this overrides the device_id chosen inside
    # DownloadFromQizhi for multi-card runs - confirm that is intended.
    context.set_context(mode=context.GRAPH_MODE, device_target='Ascend', device_id=config.device_id)
    device_num = config.device_num
    # NOTE(review): the rank is hard-coded, so every process requests dataset
    # shard 0; get_rank is imported by this file but never used - confirm.
    rank = 0

    num_hidden = [int(x) for x in config.num_hidden.split(',')]
    num_layers = len(num_hidden)

    # Input shape: (batch, sequence, patch*patch*channels, width/patch, width/patch).
    shape = [config.batch_size,
             config.seq_length,
             config.patch_size*config.patch_size*config.img_channel,
             int(config.img_width/config.patch_size),
             int(config.img_width/config.patch_size)]

    shape = list(map(int, shape))

    network = PreRNN(input_shape=shape,
                     num_layers=num_layers,
                     num_hidden=num_hidden,
                     filter_size=config.filter_size,
                     stride=config.stride,
                     seq_length=config.seq_length,
                     input_length=config.input_length,
                     tln=config.layer_norm)

    netwithloss = NetWithLossCell(network, config.batch_size, config.seq_length,
                                  config.input_length, shape[-3], shape[-1], config.reverse_input, True)
    # NOTE(review): this schedule is built but never handed to the optimizer,
    # which trains with the constant config.lr - confirm which is intended.
    exponential_decay_lr = nn.ExponentialDecayLR(config.lr, 0.95, 10000, is_stair=True)
    opt = nn.Adam(params=netwithloss.trainable_params(), learning_rate=config.lr)

    train_step = nn.TrainOneStepCell(netwithloss, opt).set_train()

    model = Model(train_step)

    dataset_dir = os.path.join(data_dir, "mnist_train")
    ds = create_mnist_dataset(dataset_files=os.path.join(dataset_dir, "mnist_train.mindrecord"),
                              rank_size=device_num, rank_id=rank, do_shuffle=True,
                              batch_size=config.batch_size)

    time_cb = TimeMonitor(data_size=ds.get_dataset_size())
    config_ck = CheckpointConfig(save_checkpoint_steps=config.snapshot_interval, keep_checkpoint_max=5)
    ckpt_cb = ModelCheckpoint(prefix=config.model_name, directory=config.ckpt_save_dir, config=config_ck)

    # With dataset sinking, one "epoch" here is sink_size steps, so the epoch
    # count is max_iterations / sink_size.
    model.train(epoch=int(config.max_iterations/config.sink_size), train_dataset=ds,
                sink_size=config.sink_size, dataset_sink_mode=True,
                callbacks=[time_cb, ckpt_cb, LossMonitor()])

    # Copy the trained output from the local running environment back to OBS
    # so it can be downloaded from the corresponding Qizhi training task.
    # This step is not required if UploadOutput is called.
    UploadToQizhi(train_dir, config.train_url)
|