|
- import os
- import mindspore
- import mindspore.nn as nn
- import mindspore.common.dtype as mstype
- from mindspore import context
- import mindspore.dataset as ds
- import mindspore.dataset.transforms.c_transforms as C2
- import mindspore.dataset.vision.c_transforms as C
- from mindspore import Tensor
- from CosFace import MobileFaceNet, WholeNet
- from mindspore.nn.loss import SoftmaxCrossEntropyWithLogits
- from mindspore.nn.optim import Momentum, SGD
- from mindspore.ops import operations as P
- from mindspore.train.model import Model
- from mindspore.common.initializer import initializer
- from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, LossMonitor
- from lr_schedule import get_multi_step_lr, warmup_cosine_annealing_lr
- import numpy as np
- import argparse
- from collections import Counter
- import moxing as mox
# Default workspace root for the platform job.
# Set `environment` to 'debug' for interactive debug jobs, 'train' for
# submitted training jobs — each mode mounts a different workspace path.
# environment = 'debug'
environment = 'train'
# Debug jobs use /home/ma-user/work; training jobs use /home/work/user-job-dir.
workroot = '/home/ma-user/work' if environment == 'debug' else '/home/work/user-job-dir'
print('current work mode:' + environment + ', workroot:' + workroot)
class BuildTrainNetwork(nn.Cell):
    """Cell that fuses a backbone network with a loss criterion.

    The forward pass feeds (input_data, label) through the wrapped
    network and returns the criterion's loss on its output, so the
    whole thing can be handed to `Model` as a single trainable cell.
    """

    def __init__(self, network, criterion):
        super(BuildTrainNetwork, self).__init__()
        self.network = network      # backbone producing logits
        self.criterion = criterion  # loss applied to (logits, label)

    def construct(self, input_data, label):
        # The network itself consumes the label (margin-based head),
        # then the criterion scores its output against the same label.
        logits = self.network(input_data, label)
        return self.criterion(logits, label)
-
def get_parse():
    """Build and parse the command-line arguments for training.

    Returns:
        argparse.Namespace: parsed arguments (batch_size, init_lr,
        epoch_size, lr_strategy, data_url, train_url, device_target).
    """
    parser = argparse.ArgumentParser(description='Face Recognition')
    # NOTE: help strings previously contradicted the real defaults
    # (claimed 0.075 / 100); they now reflect the actual values.
    parser.add_argument('--batch_size', type=int, default=512,
                        help='number of samples per training batch, default is 512')
    parser.add_argument('--init_lr', type=float, default=0.1,
                        help='initial learning rate, default is 0.1')
    parser.add_argument('--epoch_size', type=int, default=1,
                        help='number of training epochs, default is 1')
    parser.add_argument('--lr_strategy', type=str, default='Multistep',
                        help="learning-rate schedule: 'default' (constant init lr), "
                             "'cosine', or 'Multistep'")
    parser.add_argument('--data_url', default=workroot + '/data/',
                        help='path to training/inference dataset folder')
    parser.add_argument('--train_url', default=workroot + '/model/',
                        help='model folder to save/load')
    parser.add_argument('--device_target', type=str, default='Ascend',
                        choices=['Ascend', 'CPU'])
    args_opt = parser.parse_args()
    return args_opt
-
def create_dataset(data_dir, batch_size):
    """Build the CASIA-WebFace training pipeline.

    Images are randomly flipped, resized to 112x96, normalized with
    per-channel statistics on the 0-255 scale, and converted to CHW
    layout; labels are cast to int32. Batches drop the final remainder.

    Args:
        data_dir (str): root folder consumed by ImageFolderDataset.
        batch_size (int): samples per batch.

    Returns:
        the batched MindSpore dataset.
    """
    # Per-channel mean/std expressed on the raw 0-255 pixel scale.
    channel_mean = [0.4914 * 255, 0.4822 * 255, 0.4465 * 255]
    channel_std = [0.2023 * 255, 0.1994 * 255, 0.2010 * 255]

    image_ops = [
        C.RandomHorizontalFlip(),
        C.Resize((112, 96)),
        C.Normalize(mean=channel_mean, std=channel_std),
        C.HWC2CHW(),
    ]
    label_ops = [C2.TypeCast(mstype.int32)]

    casia_ds = ds.ImageFolderDataset(data_dir, decode=True)
    casia_ds = casia_ds.map(input_columns='image', operations=image_ops)
    casia_ds = casia_ds.map(input_columns='label', operations=label_ops)
    casia_ds = casia_ds.project(columns=["image", "label"])
    # NOTE(review): buffer_size=10 gives a very weak shuffle for a large
    # dataset — confirm this is intentional.
    casia_ds = casia_ds.shuffle(buffer_size=10)
    casia_ds = casia_ds.batch(batch_size=batch_size, drop_remainder=True)
    casia_ds = casia_ds.repeat(1)
    return casia_ds
-
if __name__ == '__main__':
    args_opt = get_parse()
    print('args:')
    print(args_opt)

    data_dir = workroot + '/data'     # local dataset path inside the job container
    train_dir = workroot + '/model/'  # local checkpoint/output path
    obs_train_url = args_opt.train_url
    # makedirs(exist_ok=True) is race-safe and creates parents, unlike mkdir.
    os.makedirs(data_dir, exist_ok=True)
    os.makedirs(train_dir, exist_ok=True)

    ############ Copy the dataset from OBS into the training image ############
    # In the training environment data_url is an OBS path; the dataset is
    # copied to the fixed local path under the job workroot.
    if environment == 'train':
        obs_data_url = args_opt.data_url
        try:
            mox.file.copy_parallel(obs_data_url, data_dir)
            print("Successfully Download {} to {}".format(obs_data_url,
                                                          data_dir))
        except Exception as e:
            # Best-effort: log and continue; training fails later if data is missing.
            print('moxing download {} to {} failed: '.format(
                obs_data_url, data_dir) + str(e))

    context.set_context(mode=context.GRAPH_MODE, device_target=args_opt.device_target)
    epoch_size = args_opt.epoch_size
    dataset = create_dataset(os.path.join(data_dir, "Align-CASIA-WebFace/CASIA-WebFace-112X96"),
                             args_opt.batch_size)
    if dataset.get_dataset_size() == 0:
        raise ValueError(
            "Please check dataset size > 0 and batch_size <= dataset size")

    net = WholeNet(num_class=10575)
    loss = SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean')

    # Learning-rate schedule selection. An unknown strategy previously left
    # `lr` undefined and crashed later with NameError — fail fast instead.
    if args_opt.lr_strategy == 'default':
        lr = args_opt.init_lr
    elif args_opt.lr_strategy == 'Multistep':
        lr = Tensor(get_multi_step_lr(dataset.get_dataset_size(),
                                      init_lr=args_opt.init_lr, epoch=epoch_size))
    elif args_opt.lr_strategy == 'cosine':
        lr = Tensor(warmup_cosine_annealing_lr(dataset.get_dataset_size(),
                                               init_lr=args_opt.init_lr, max_epoch=epoch_size))
    else:
        raise ValueError("Unsupported lr_strategy: {}".format(args_opt.lr_strategy))

    # Parameter groups: the FC embedding layer and the margin-product weights
    # get stronger weight decay than the rest of the backbone.
    linear1_params = list(filter(lambda x: 'backbone.linear1' in x.name, net.trainable_params()))
    product_weight_params = list(filter(lambda x: 'product.weight' in x.name, net.trainable_params()))
    prelu_params = list(filter(lambda x: 'prelu' in x.name, net.trainable_params()))
    base_params = list(filter(lambda x: 'backbone.linear1' not in x.name and
                              'product.weight' not in x.name and
                              'prelu' not in x.name, net.trainable_params()))
    group_params = [{'params': base_params, 'weight_decay': 4e-5},
                    {'params': linear1_params, 'weight_decay': 4e-4},
                    {'params': product_weight_params, 'weight_decay': 4e-4},
                    # BUGFIX: prelu_params were filtered out of base_params but
                    # never added to any group, excluding them from the optimizer.
                    # Convention is to apply no weight decay to PReLU slopes.
                    {'params': prelu_params, 'weight_decay': 0.0}]
    opt = Momentum(group_params, lr, 0.9)
    train_net = BuildTrainNetwork(net, loss)
    model = Model(train_net, optimizer=opt)

    print(dataset.num_classes())
    batch_num = dataset.get_dataset_size()
    print(batch_num)
    # Save one checkpoint per epoch, keeping at most the last 100.
    config_ck = CheckpointConfig(save_checkpoint_steps=batch_num, keep_checkpoint_max=100)
    ckpoint_cb = ModelCheckpoint(prefix="mobilefacenet_cosineface_{}_casia10575".format(args_opt.lr_strategy),
                                 directory=train_dir, config=config_ck)
    loss_cb = LossMonitor()
    model.train(epoch_size, dataset, callbacks=[ckpoint_cb, loss_cb], dataset_sink_mode=False)

    # Copy the trained model from the local environment back to OBS so it can
    # be downloaded from the corresponding training task on the platform.
    if environment == 'train':
        try:
            mox.file.copy_parallel(train_dir, obs_train_url)
            print("Successfully Upload {} to {}".format(train_dir, obs_train_url))
        except Exception as e:
            print('moxing upload {} to {} failed: '.format(train_dir, obs_train_url) + str(e))
-
|