# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
- """train launch."""
-
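# Example launch (a sketch; the script name and data paths are placeholders,
# the flags are defined by the argument parser at the bottom of this file):
#
#   python train.py \
#       --image_path ./data/images \
#       --peta_dataset_mat_dir ./data/PETA.mat \
#       --pretrain_chechpoint ./data/resnet50_frompytorch.ckpt \
#       --lr_strategy Cosine --epoch_size 500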
import argparse
import os

import mindspore
import mindspore.nn as nn
import mindspore.ops as ops
from mindspore import Tensor
from mindspore import context
from mindspore.train.model import Model
from mindspore.communication.management import init, get_rank, get_group_size

from src.deepmar_dataset import create_dataset
from src.deepmar import Deep_Mar
from src.lr_schedule import get_multi_step_lr, warmup_cosine_annealing_lr
from src.create_callback import DeepMAR_Callback


class Loss_Network(nn.Cell):
    """Wrap the DeepMAR network with its weighted binary cross-entropy loss."""

    def __init__(self, deepmar):
        super(Loss_Network, self).__init__()
        self.deepmar = deepmar
        self.criterion = ops.BinaryCrossEntropy()

    def construct(self, x, y, w):
        # x: input images, y: attribute labels, w: per-attribute sample weights.
        output = self.deepmar(x)
        loss = self.criterion(output, y, w)
        return loss


def get_dynamic_lr(args, train_dataset):
    """Build the learning-rate schedule selected by --lr_strategy."""
    steps_per_epoch = train_dataset.get_dataset_size()
    if args.lr_strategy == 'Default':
        # Constant learning rate.
        lr = args.init_lr
    elif args.lr_strategy == 'Multistep':
        if args.is_distributed:
            lr = get_multi_step_lr(steps_per_epoch * 4, init_lr=args.init_lr, epoch=args.epoch_size)
            lr = Tensor(lr[::4])
        else:
            lr = get_multi_step_lr(steps_per_epoch, init_lr=args.init_lr, epoch=args.epoch_size)
            lr = Tensor(lr)
    elif args.lr_strategy == 'Cosine':
        if args.is_distributed:
            lr = warmup_cosine_annealing_lr(steps_per_epoch * 4, init_lr=args.init_lr,
                                            max_epoch=args.epoch_size)
            lr = Tensor(lr[::4])
        else:
            lr = warmup_cosine_annealing_lr(steps_per_epoch, init_lr=args.init_lr,
                                            max_epoch=args.epoch_size)
            lr = Tensor(lr)
    else:
        raise ValueError("Unsupported --lr_strategy: expected 'Default', 'Multistep' or 'Cosine'.")

    return lr


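# Training flow: optionally pull the dataset and pretrained checkpoint from OBS
# when running on the PengCheng Cloud Brain, set up the (distributed) execution
# context, build DeepMAR on top of a ResNet-50 checkpoint converted from
# PyTorch, group parameters so backbone and classifier get different weight
# decay, then train with SGD under the selected learning-rate schedule.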
def main(args):
    if args.enable_pengcheng_cloud:
        # On the PengCheng Cloud Brain, pull the dataset and pretrained
        # checkpoint from OBS into the local working directory before training.
        import moxing as mox
        data_dir = args.workroot + '/data'
        if not os.path.exists(data_dir):
            os.mkdir(data_dir)

        obs_data_url = args.data_url
        mox.file.copy_parallel(obs_data_url, data_dir)
        print("Successfully downloaded {} to {}".format(obs_data_url, data_dir))

        args.image_path = data_dir + '/images'
        args.peta_dataset_mat_dir = data_dir + '/PETA.mat'
        args.pretrain_chechpoint = data_dir + '/resnet50_frompytorch.ckpt'
        args.save_ckpt_dir = args.workroot + '/model'

        if not os.path.exists(args.save_ckpt_dir):
            os.mkdir(args.save_ckpt_dir)

    if args.is_distributed:
        device_id = int(os.getenv('DEVICE_ID'))
        context.set_context(device_id=device_id)
        context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target)
        init()
        rank_id = get_rank()
        rank_size = get_group_size()
        context.set_auto_parallel_context(parallel_mode=context.ParallelMode.DATA_PARALLEL, gradients_mean=True,
                                          parameter_broadcast=True)
        train_dataset = create_dataset(args, rank_size, rank_id)
    else:
        context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target)
        train_dataset = create_dataset(args, rank_size=None, rank_id=None)

    # Build DeepMAR and initialize its ResNet-50 backbone from the converted
    # PyTorch checkpoint.
    net = Deep_Mar()
    init_dict = mindspore.train.serialization.load_checkpoint(args.pretrain_chechpoint)
    rename_dict = {}
    for k, v in init_dict.items():
        # The converted checkpoint names the shortcut branch "down_sample_layer",
        # while the MindSpore backbone expects "downsample"; rename those keys.
        parts = ["downsample" if name == "down_sample_layer" else name
                 for name in k.split(".")]
        rename_dict[".".join(parts)] = v
    mindspore.train.serialization.load_param_into_net(net, parameter_dict=rename_dict, strict_load=False)

    # Apply separate weight decay values to the pretrained backbone and to the
    # newly added classifier layers.
    classifier_param = []
    resnet_backbone_param = []
    for param in net.trainable_params():
        if 'classifier' not in param.name and 'add_block' not in param.name:
            resnet_backbone_param.append(param)
        else:
            classifier_param.append(param)
    group_param = [{'params': resnet_backbone_param, 'weight_decay': args.backbone_weight_decay},
                   {'params': classifier_param, 'weight_decay': args.classifier_weight_decay},
                   {'order_params': net.trainable_params()}]

    dynamic_lr = get_dynamic_lr(args, train_dataset)

    optim = nn.SGD(group_param, learning_rate=dynamic_lr, momentum=args.momentum, nesterov=True)

    loss_net = Loss_Network(net)
    model = Model(loss_net, optimizer=optim)

    if args.is_distributed:
        my_call = DeepMAR_Callback(args, train_dataset.get_dataset_size(), rank_id)
    else:
        my_call = DeepMAR_Callback(args, train_dataset.get_dataset_size())

    model.train(args.epoch_size, train_dataset, callbacks=[my_call], dataset_sink_mode=True)

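    # When running on the PengCheng Cloud Brain, push the saved checkpoints
    # back to the OBS location given by --train_url.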
    if args.enable_pengcheng_cloud:
        mox.file.copy_parallel(args.save_ckpt_dir, args.train_url)
        print("Successfully uploaded {} to {}".format(args.save_ckpt_dir, args.train_url))


if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='DeepMAR training')

    # dataset options
    parser.add_argument('--split', type=str, default='train', choices=['trainval', 'train', "test"],
                        help="Dataset split to use for training")
    parser.add_argument('--partition_idx', type=int, default=0,
                        help="Index of the dataset partition to use")
    parser.add_argument('--image_resize', type=int, nargs=2, default=(224, 224),
                        help="Height and width to which input images are resized")

    # training options
    parser.add_argument('--batch_size', type=int, default=16)
    parser.add_argument('--momentum', type=float, default=0.9)
    parser.add_argument('--init_lr', type=float, default=0.1)
    parser.add_argument('--epoch_size', type=int, default=500)
    parser.add_argument('--num_work', type=int, default=2)
    parser.add_argument('--epoch_star_save', type=int, default=400,
                        help="Epoch at which checkpoint saving starts")
    parser.add_argument('--epoch_per_save', type=int, default=10,
                        help="Number of epochs between saved checkpoints")
    parser.add_argument('--backbone_weight_decay', type=float, default=4e-5,
                        help="Weight decay for the backbone parameters")
    parser.add_argument('--classifier_weight_decay', type=float, default=4e-4,
                        help="Weight decay for the classifier parameters")
    parser.add_argument('--lr_strategy', type=str, default="Cosine",
                        choices=['Default', 'Multistep', 'Cosine'],
                        help="Learning-rate schedule strategy")

    # device options
    parser.add_argument('--device_target', type=str, default="Ascend")
    parser.add_argument('--is_distributed', type=int, default=0,
                        help="Whether to run distributed training (0 or 1)")
    parser.add_argument('--train_mode', type=str, default='train', choices=['test', 'train'],
                        help="Mode used when loading the dataset")
    parser.add_argument('--save_ckpt_device', type=int, default=0,
                        help="In distributed mode, the rank that saves checkpoints")

    # path options
    parser.add_argument('--save_ckpt_dir', type=str, default='./model',
                        help="Directory in which checkpoints are saved")
    parser.add_argument('--pretrain_chechpoint', type=str, default="./data/resnet50_frompytorch.ckpt",
                        help="Path to the pretrained ResNet-50 checkpoint")
    parser.add_argument('--peta_dataset_mat_dir', type=str, default='./data/PETA.mat',
                        help="Path to PETA.mat")
    parser.add_argument('--image_path', type=str, default='./data/images',
                        help="Path to the image folder")

    # PengCheng Cloud Brain options
    parser.add_argument('--enable_pengcheng_cloud', type=int, default=0,
                        help="Whether the job runs on the PengCheng Cloud Brain platform")
    parser.add_argument('--workroot', type=str, default='/home/work/user-job-dir',
                        help="Working directory of the training job on the cloud")
    parser.add_argument('--train_url', type=str, default=' ',
                        help="OBS location where training results are uploaded")
    parser.add_argument('--data_url', type=str, default=' ',
                        help="OBS location of the training dataset")

    my_args = parser.parse_args()
    main(my_args)