|
- # Copyright 2021 Huawei Technologies Co., Ltd
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- # ============================================================================
- """train delf"""
- import os
-
- import mindspore.nn as nn
- from mindspore import context, Model
- from mindspore.train.callback import CheckpointConfig, ModelCheckpoint
- from mindspore.context import ParallelMode
- from mindspore import load_checkpoint, load_param_into_net
- import mindspore.ops as ops
- from mindspore.train.callback import LossMonitor
- from mindspore.train.callback import Callback
- from mindspore.train.callback import SummaryCollector
- from mindspore.communication.management import init
- from mindspore.profiler import Profiler
- import moxing as mox
- import json
- import time
-
- import numpy as np
-
- import src.convert_h5_to_weight as h5
- import src.data_augmentation_parallel as daa
- import src.delg_model as model_h5
-
- from src.delg_model import delg_model, ArcFace
- from model_utils.config import config as args
- from model_utils.moxing_adapter import moxing_wrapper
- from model_utils.device_adapter import get_device_id, get_device_num
-
### Copy multiple datasets from obs to training image ###
def MultiObsToEnv(multi_data_url, data_dir):
    """Download every dataset listed in the ``multi_data_url`` JSON string
    from OBS into ``data_dir``, then drop a marker file used by other cards.

    Args:
        multi_data_url (str): JSON array; each element is an object with
            ``dataset_name`` and ``dataset_url`` keys.
        data_dir (str): local root directory the datasets are copied into.
    """
    #--multi_data_url is json data, need to do json parsing for multi_data_url
    multi_data_json = json.loads(multi_data_url)
    for dataset in multi_data_json:
        path = data_dir + "/" + dataset["dataset_name"]
        if not os.path.exists(path):
            os.makedirs(path)
        try:
            mox.file.copy_parallel(dataset["dataset_url"], path)
            print("Successfully Download {} to {}".format(dataset["dataset_url"], path))
        except Exception as e:
            # Best-effort: log the failure and keep copying the remaining datasets.
            print('moxing download {} to {} failed: '.format(
                dataset["dataset_url"], path) + str(e))
    #Set a cache file to determine whether the data has been copied to obs.
    #If this file exists during multi-card training, there is no need to copy the dataset multiple times.
    with open("/cache/download_input.txt", 'w'):
        pass
    # os.path.exists() never raises, so the original try/except could never
    # reach its failure branch; a plain check reports both outcomes.
    if os.path.exists("/cache/download_input.txt"):
        print("download_input succeed")
    else:
        print("download_input failed")
    return
### Copy the output model to obs ###
def EnvToObs(train_dir, obs_train_url):
    """Upload the local training output directory to OBS (best effort:
    a failed upload is logged, never raised)."""
    try:
        mox.file.copy_parallel(train_dir, obs_train_url)
    except Exception as err:
        print('moxing upload {} to {} failed: '.format(train_dir,
            obs_train_url) + str(err))
        return
    print("Successfully Upload {} to {}".format(train_dir,
        obs_train_url))
    return
def DownloadFromQizhi(multi_data_url, data_dir):
    """Fetch the OBS datasets, handling single-card and multi-card runs.

    On a single card the data is simply copied; on multiple cards only the
    rank-0 card of each node copies, and the others wait on the marker file
    that MultiObsToEnv creates.
    """
    device_num = int(os.getenv('RANK_SIZE'))
    if device_num == 1:
        MultiObsToEnv(multi_data_url, data_dir)
        context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target)
    elif device_num > 1:
        # set device_id and init for multi-card training
        context.set_context(mode=context.GRAPH_MODE,
                            device_target=args.device_target,
                            device_id=int(os.getenv('ASCEND_DEVICE_ID')))
        context.reset_auto_parallel_context()
        context.set_auto_parallel_context(device_num=device_num,
                                          parallel_mode=ParallelMode.DATA_PARALLEL,
                                          gradients_mean=True,
                                          parameter_broadcast=True)
        init()
        #Copying obs data does not need to be executed multiple times,
        #just let the 0th card copy the data
        local_rank = int(os.getenv('RANK_ID'))
        if local_rank % 8 == 0:
            MultiObsToEnv(multi_data_url, data_dir)
        #If the cache file does not exist, it means that the copy data has
        #not been completed, and wait for 0th card to finish copying data
        while not os.path.exists("/cache/download_input.txt"):
            time.sleep(1)
    return
def UploadToQizhi(train_dir, obs_train_url):
    """Upload training results to OBS; on multi-card runs only the rank-0
    card of each node performs the upload."""
    device_num = int(os.getenv('RANK_SIZE'))
    local_rank = int(os.getenv('RANK_ID'))
    if device_num == 1:
        EnvToObs(train_dir, obs_train_url)
    elif device_num > 1 and local_rank % 8 == 0:
        EnvToObs(train_dir, obs_train_url)
    return
-
-
class LossFunc(nn.Cell):
    """Combined DELG training loss.

    In the ``'tuning'`` state only the ArcFace-margin softmax loss on the
    global descriptor is used.  In any other state the total is
    ``attention_loss_weight * attention CE + 10 * autoencoder MSE +
    global ArcFace CE``.

    Args:
        attention_loss_weight (float): weight of the attention branch loss.
        state (str): training stage; ``'tuning'`` selects global-only loss.
        num_classes (int): number of landmark classes. Default 81313
            (the value previously hard-coded here).
    """
    def __init__(self, attention_loss_weight=1.0, state='tuning', num_classes=81313):
        super(LossFunc, self).__init__()
        self.arcface = ArcFace(num_classes)
        self._loss_fn = nn.SoftmaxCrossEntropyWithLogits(
            sparse=True, reduction="mean")
        self._autoencoder_loss_fn = nn.MSELoss()
        self.state = state
        self.attention_loss_weight = attention_loss_weight
        self.num_classes = num_classes

    def construct(self, base, label):
        """Compute the stage-dependent total loss for ``base`` vs ``label``."""
        # Valid sparse labels are [0, num_classes - 1]; the previous upper
        # bound of num_classes itself was off by one and allowed an
        # out-of-range class index through to the CE loss.
        label = ops.clip_by_value(label, 0, self.num_classes - 1)
        if self.state == 'tuning':
            base = self.arcface(base, label)
            total_loss = self._loss_fn(base, label)
        else:
            (desc_prelogits, attn_logits, stop_block3, dim_expanded_features) = base
            desc_prelogits = self.arcface(desc_prelogits, label)
            global_loss = self._loss_fn(desc_prelogits, label)
            attn_loss = self._loss_fn(attn_logits, label)
            # 10.0: fixed weight on the dimensionality-reduction reconstruction term
            autoencoder_loss = self._autoencoder_loss_fn(stop_block3, dim_expanded_features) * 10.0
            total_loss = self.attention_loss_weight * attn_loss + autoencoder_loss + global_loss

        return total_loss
-
-
class MySGD(nn.SGD):
    """SGD optimizer that clips gradients by global norm (max 10.0) before
    delegating to the standard SGD update."""
    def __init__(self, *args_in, **kwargs):
        super().__init__(*args_in, **kwargs)
        # Bind the parent update once here; construct() delegates to it after
        # clipping. NOTE(review): kept as-is — presumably this binding style
        # matters for MindSpore graph-mode compilation; confirm before changing.
        self._original_construct = super().construct
        # Per-parameter "<name>.gradient" labels; not used inside this class —
        # presumably consumed by external tooling. TODO confirm.
        self.gradient_names = [param.name +
                               ".gradient" for param in self.parameters]
        self.count = len(self.gradient_names)

    def construct(self, grads):
        # Clip the global gradient norm to 10.0, then run the normal SGD step.
        grads = ops.clip_by_global_norm(grads, 10.0)
        return self._original_construct(grads)
-
-
class EvalCallBack(Callback):
    """Training callback that periodically prints attention-branch accuracy
    and stops the run once the configured final iteration is exceeded."""
    def __init__(self, cur_iters, fianel_iter, state,
                 eval_interval=1000, eval_start_step=1):
        super(EvalCallBack, self).__init__()
        if eval_interval < 1:
            raise ValueError("interval should >= 1.")
        self.eval_start_step = eval_start_step
        self.eval_interval = eval_interval
        self.fianel_iter = fianel_iter
        self.state = state
        self.cur_iters = cur_iters

    def begin(self, run_context):
        # Resume the global step counter from the configured starting iteration.
        run_context.original_args().cur_step_num = self.cur_iters

    def record_accuracy(self, logits, label):
        """Record accuracy given predicted logits and ground-truth labels."""
        predictions = np.argmax(logits.asnumpy(), axis=1)
        hits = (np.equal(predictions, label.asnumpy()) * 1).reshape(-1)
        return np.sum(hits) / hits.shape[0]

    # calculate accuracy
    def compute_acc(self, attn_logits, label):
        return self.record_accuracy(attn_logits, label)

    def step_end(self, run_context):
        """Per-step hook: report accuracy at the configured interval and
        request a stop after the final iteration."""
        cb_params = run_context.original_args()
        step = cb_params.cur_step_num
        net = cb_params.network
        data, label = cb_params.train_dataset_element

        # print accuracy
        due = step >= self.eval_start_step and \
            (step - self.eval_start_step) % self.eval_interval == 0
        if due and self.state != 'tuning':
            (desc_prelogits, attn_logits, stop_block3, dim_expanded_features) = net(data)
            attn_acc = self.compute_acc(attn_logits, label)
            print("step: %s, train attn Acc %s" %
                  (step, attn_acc), flush=True)

        # stop the training
        if step > self.fianel_iter:
            run_context.request_stop()
-
-
def modelarts_pre_process():
    '''modelarts pre process function: make the checkpoint directory
    device-specific by appending the device id.'''
    device_suffix = args.save_ckpt + str(get_device_id())
    args.save_ckpt = os.path.join(args.output_path, device_suffix)
-
-
@moxing_wrapper(pre_process=modelarts_pre_process)
def run_train():
    """Build the dataset, model, loss and optimizer, then train for one pass.

    NOTE(review): reads module-level globals set in the ``__main__`` block
    (``device_id``, ``summary_collector``, ``profiler``) — it must only be
    called after that setup has run.
    """
    # load data (augmented, repeating training pipeline)
    train_dataset = daa.create_dataset(
        args.traindata_path, args.image_size, args.batch_size,
        seed=args.seed, augmentation=True, repeat=True)

    # initial forward net: build the DELG model and load the ImageNet
    # backbone weights converted from the h5 checkpoint
    delg_net = model_h5.delg_model(state=args.train_state)
    param_dict = h5.translate_h5(args.imagenet_checkpoint)
    load_param_into_net(delg_net, param_dict)

    # load ckpt (optional resume; prints parameters the ckpt did not cover)
    if args.checkpoint_path != "":
        param_dict = load_checkpoint(args.checkpoint_path)
        not_load = load_param_into_net(delg_net, param_dict)
        print('weights not load in ckpt: ', not_load)

    # freeze laysers
    # print('freeze param:')
    # if args.train_state == "attn":
    #     for param in delg_net.get_parameters():
    #         if ('attention.' not in param.name and
    #                 'attn.' not in param.name and
    #                 'autoencoder.' not in param.name):
    #             print(param.name)
    #             param.requires_grad = False
    # elif args.train_state == "tuning":
    #     for param in delg_net.get_parameters():
    #         if ('attention.' in param.name or
    #                 'attn.' in param.name and
    #                 'autoencoder.'in param.name):
    #             print(param.name)
    #             param.requires_grad = False

    # loss func
    loss_func = LossFunc(attention_loss_weight=args.attention_loss_weight,
                         state=args.train_state)

    # dynamic lr: scale the initial lr down by the fraction of the 250k-step
    # schedule already completed, then decay linearly (power=1) over 500k steps
    init_lr = args.initial_lr * (1 - args.start_iter/250000)
    lr_schedule = nn.PolynomialDecayLR(
        learning_rate=init_lr, end_learning_rate=0.0001, decay_steps=500000, power=1.0)

    # SGD with global-norm gradient clipping (see MySGD above)
    optim = MySGD(delg_net.trainable_params(), learning_rate=lr_schedule,
                  momentum=0.9, weight_decay=0.0)

    config_ck = CheckpointConfig(
        save_checkpoint_steps=args.save_ckpt_step, keep_checkpoint_max=args.keep_checkpoint_max)

    ckpoint_cb = ModelCheckpoint(prefix="checkpoint_delf_"+args.train_state,
                                 directory=args.save_ckpt,
                                 config=config_ck)

    # periodic accuracy reporting + stop after max_iters
    eval_cb = EvalCallBack(args.start_iter, args.max_iters, args.train_state)

    # only card 0 saves checkpoints / summaries to avoid duplicate writes
    callback_list = [eval_cb, LossMonitor(100)]
    if device_id == 0:
        callback_list.append(ckpoint_cb)
        if args.need_summary:
            callback_list.append(summary_collector)

    print("Ready to train!")
    # amp_level="O3": full fp16 training on Ascend
    model = Model(network=delg_net, loss_fn=loss_func,
                  optimizer=optim, amp_level="O3")
    model.train(1, train_dataset, callbacks=callback_list, dataset_sink_mode=False)

    if args.need_profile:
        profiler.analyse()

    print("Train successfully!")
-
-
- if __name__ == "__main__":
- data_dir = '/cache/data'
- train_dir = '/cache/output'
- if not os.path.exists(data_dir):
- os.makedirs(data_dir)
- if not os.path.exists(train_dir):
- os.makedirs(train_dir)
- if not os.path.exists(args.save_ckpt):
- os.makedirs(args.save_ckpt)
-
- DownloadFromQizhi(args.multi_data_url, data_dir)
-
- context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
- device_id = get_device_id()
- device_num = get_device_num()
- specified = {'histogram_regular': 'attention.*'}
- summary_collector = None
- profiler = None
- if args.enable_modelarts:
- args.save_summary = os.path.join(args.output_path, args.save_summary)
- if device_num > 1:
- context.set_context(device_id=device_id)
- if args.need_profile:
- profiler = Profiler(output_path=os.path.join(
- args.save_summary, 'summary_dir'+str(device_id)))
- if args.need_summary:
- summary_collector = SummaryCollector(
- summary_dir=os.path.join(
- args.save_summary, 'summary_dir'+str(device_id)),
- collect_specified_data=specified, collect_freq=200)
- context.reset_auto_parallel_context()
- context.set_auto_parallel_context(device_num=device_num, parallel_mode=ParallelMode.DATA_PARALLEL,
- gradients_mean=False)
- init()
- else:
- context.set_context(device_id=device_id)
- if args.need_profile:
- profiler = Profiler(output_path=os.path.join(
- args.save_summary, 'summary_dir'))
- if args.need_summary:
- summary_collector = SummaryCollector(summary_dir=os.path.join(args.save_summary, 'summary_dir'),
- collect_specified_data=specified, collect_freq=200)
- run_train()
- UploadToQizhi(train_dir,args.train_url)
|