# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
- """Train RegNet."""
import mindspore as ms
import mindspore.dataset as ds
import mindspore.dataset.transforms.c_transforms as C
import mindspore.dataset.vision.c_transforms as vision
from mindspore import dtype as mstype
from mindspore import nn
from mindspore.communication.management import init, get_rank, get_group_size
from mindspore.context import ParallelMode
from mindspore.train import Model
from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, LossMonitor, TimeMonitor

from src.config import config as args_opt
from src.regnet import regnet20, RegNetWithLossCell, TrainingWrapper


def create_dataset(data_path, repeat_num=1, batch_size=32, rank_id=0, rank_size=1):
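    """Create an augmented CIFAR-10 training dataset, sharded across devices.

    Args:
        data_path (str): Path to the CIFAR-10 data files.
        repeat_num (int): Number of times to repeat the dataset.
        batch_size (int): Number of samples per batch.
        rank_id (int): Shard index of the current device.
        rank_size (int): Total number of shards (devices).

    Returns:
        The augmented, batched training dataset.
    """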
    rescale = 1.0 / 255.0
    shift = 0.0

    data_set = ds.Cifar10Dataset(data_path, num_shards=rank_size, shard_id=rank_id, usage='train')

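    # Augmentation and preprocessing: 32x32 crop with 4-pixel padding, random
    # horizontal flip, rescale pixels to [0, 1], per-channel normalization
    # with CIFAR-10 statistics, and HWC -> CHW layout conversion.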
    random_crop_op = vision.RandomCrop((32, 32), (4, 4, 4, 4))
    random_horizontal_op = vision.RandomHorizontalFlip()
    rescale_op = vision.Rescale(rescale, shift)
    normalize_op = vision.Normalize((0.4465, 0.4822, 0.4914), (0.2010, 0.1994, 0.2023))
    changeswap_op = vision.HWC2CHW()
    type_cast_op = C.TypeCast(mstype.int32)

    c_trans = [random_crop_op, random_horizontal_op]
    c_trans += [rescale_op, normalize_op, changeswap_op]
    data_set = data_set.map(operations=type_cast_op, input_columns="label")
    data_set = data_set.map(operations=c_trans, input_columns="image")
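    # Note: buffer_size=10 gives only a weak local shuffle; Cifar10Dataset
    # already applies a global shuffle by default when no sampler is set.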
    data_set = data_set.shuffle(buffer_size=10)
    data_set = data_set.batch(batch_size=batch_size, drop_remainder=True)
    data_set = data_set.repeat(repeat_num)

    return data_set


def train(arguments):
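    """Configure the device context, build RegNet-20, and run training.

    Args:
        arguments (dict): Training configuration covering the seed, target
            device, dataset path, hyperparameters, and checkpoint settings.
    """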
    seed = arguments['seed']
    device_id = arguments['device_id']
    batch_size = arguments['batch_size']
    image_size = arguments['image_size']
    initial_lr = arguments['initial_lr']
    class_num = arguments['class_num']
    save_checkpoint_steps = arguments['save_checkpoint_steps']
    keep_checkpoint_max = arguments['keep_checkpoint_max']
    checkpoint_prefix = arguments['checkpoint_prefix']
    checkpoint_save_path = arguments['checkpoint_save_path']
    dataset_path = arguments['dataset_path']
    max_epoch = arguments['max_epoch']
    is_distributed = arguments['is_distributed']
    device_target = arguments['device_target']
    ms.set_seed(seed)
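    # Configure the execution context and create the (optionally sharded)
    # training dataset for the requested device target.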
    if device_target == 'Ascend':
        ms.set_context(mode=ms.GRAPH_MODE, device_target='Ascend', save_graphs=False,
                       device_id=device_id)
        if is_distributed:
            init()
            rank_id = get_rank()
            device_num = get_group_size()
            ms.set_auto_parallel_context(device_num=device_num, parallel_mode=ParallelMode.DATA_PARALLEL,
                                         gradients_mean=True)
            checkpoint_save_path = checkpoint_save_path + "ckpt_" + str(rank_id) + "/"
            dataset = create_dataset(data_path=dataset_path, repeat_num=1, batch_size=batch_size, rank_id=rank_id,
                                     rank_size=device_num)
        else:
            dataset = create_dataset(data_path=dataset_path, repeat_num=1, batch_size=batch_size, rank_id=0,
                                     rank_size=1)
    elif device_target == 'GPU':
        # Set the context unconditionally so single-GPU runs also target GPU.
        ms.set_context(mode=ms.GRAPH_MODE, device_target='GPU', save_graphs=False)
        if is_distributed:
            init('nccl')
            ms.set_auto_parallel_context(parallel_mode=ParallelMode.DATA_PARALLEL, gradients_mean=True)
            dataset = create_dataset(data_path=dataset_path, repeat_num=1, batch_size=batch_size,
                                     rank_size=get_group_size(), rank_id=get_rank())
        else:
            dataset = create_dataset(data_path=dataset_path, repeat_num=1, batch_size=batch_size, rank_id=0,
                                     rank_size=1)
    elif device_target == 'CPU':
        ms.set_context(mode=ms.GRAPH_MODE, device_target='CPU', save_graphs=False)
        dataset = create_dataset(data_path=dataset_path, repeat_num=1, batch_size=batch_size, rank_id=0,
                                 rank_size=1)
    else:
        raise ValueError("Unsupported device_target: " + str(device_target))
    ds_train = dataset
    print('Dataset size:', ds_train.get_dataset_size())
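    # Build RegNet-20, attach the cross-entropy loss, and wrap the loss cell
    # with the custom training wrapper that applies the optimizer.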
    regnet = regnet20(batch_size=batch_size, im_size=image_size, class_num=class_num)
    ce_loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True)
    regnet.set_train(True)
    net = RegNetWithLossCell(regnet, ce_loss)
    opt = nn.Adam(net.trainable_params(), initial_lr)
    net = TrainingWrapper(net, opt)
    model = Model(net)
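    # Callbacks: periodic checkpointing, per-step loss logging, and epoch timing.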
    config_ck = CheckpointConfig(save_checkpoint_steps=save_checkpoint_steps,
                                 keep_checkpoint_max=keep_checkpoint_max)
    ckpoint_cb = ModelCheckpoint(prefix=checkpoint_prefix, directory=checkpoint_save_path, config=config_ck)
    time_cb = TimeMonitor(data_size=ds_train.get_dataset_size())
    callback_list = [LossMonitor(), time_cb, ckpoint_cb]
    # Dataset sink mode is not supported on CPU targets.
    model.train(max_epoch, ds_train, callbacks=callback_list,
                dataset_sink_mode=(device_target != 'CPU'))


if __name__ == '__main__':
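    # Expose the parsed config as a dict and strip trailing whitespace from
    # every string-valued entry before training.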
    args_opt = vars(args_opt)
    for key in args_opt:
        if isinstance(args_opt[key], str):
            args_opt[key] = args_opt[key].rstrip()
    train(arguments=args_opt)