-
- # Copyright 2020 Huawei Technologies Co., Ltd
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- # ============================================================================
- """train resnet."""
- import moxing as mox
- import os
- import argparse
- import ast
- from mindspore import context
- from mindspore import Tensor
- from mindspore.nn.optim.momentum import Momentum
- from mindspore.nn.optim.sgd import SGD
- # from mindspore.train.model import Model  # superseded by the local src.model.Model below
- from src.model import Model
- from mindspore.context import ParallelMode
- from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, LossMonitor, TimeMonitor
- from mindspore.nn.loss import SoftmaxCrossEntropyWithLogits
- from mindspore.train.loss_scale_manager import FixedLossScaleManager
- from mindspore.train.serialization import load_checkpoint, load_param_into_net
- from mindspore.communication.management import init, get_rank, get_group_size
- from mindspore.common import set_seed
- import mindspore.nn as nn
- import mindspore.common.initializer as weight_init
- from src.lr_generator import get_lr, warmup_cosine_annealing_lr
- from src.CrossEntropySmooth import CrossEntropySmooth
- from mindspore.train.callback import Callback
- from thgy_client import THGYApiClient
- import datetime
- import yaml
- import time
- import numpy as np
-
- from src.pangu_alpha import PanguAlpha, PanguAlphaWithLoss, CrossEntropyLoss
- from src.pangu_alpha_wrapcell import PanguAlphaTrainOneStepWithLossScaleCell
- from src.pangu_alpha_config import PANGUALPHAConfig  # , set_parse
- from src.utils import LearningRate, FP32StateAdamWeightDecay
- from src.dataset_restore_data0_taoht import create_dataset3 as create_dataset
- import mindspore.common.dtype as mstype
- from mindspore.nn.wrap.loss_scale import DynamicLossScaleUpdateCell
- import mindspore.communication.management as D
- from mindspore.parallel import set_algo_parameters
- from mindspore.parallel._cost_model_context import _set_multi_subgraphs
- import math
- import logging
-
- from metrics import PPLMetric
-
- import AISyncore as aisc
-
-
- class LossCallBack(Callback):
- """
- Monitor the loss in training.
- If the loss is NAN or INF, terminate training.
- """
-
- def __init__(self, dataset_size=-1, local_rank=0, has_trained_epoch=0, has_trained_step=0, micro_size=1):
- super(LossCallBack, self).__init__()
- self._dataset_size = dataset_size
- self.local_rank = local_rank
- self.has_trained_epoch = has_trained_epoch
- self.has_trained_step = has_trained_step
- self.micro_size = micro_size
- print("load has trained epoch :{} and step: {}".format(has_trained_epoch, has_trained_step), flush=True)
-
- def step_end(self, run_context):
- """
- Print loss after each step
- """
- cb_params = run_context.original_args()
- if self._dataset_size > 0 and self.local_rank % 8 == 0:
- percent, epoch_num = math.modf(cb_params.cur_step_num /
- self._dataset_size)
- if percent == 0:
- epoch_num -= 1
- date = time.asctime(time.localtime(time.time()))
- loss_value = cb_params.net_outputs[0].asnumpy() / self.micro_size
- print("time: {} local_rank: {}, epoch: {}, step: {}, loss is {}, overflow is {}, scale is {}, lr is {}".
- format(date, int(self.local_rank), int(epoch_num) + int(self.has_trained_epoch),
- cb_params.cur_step_num + int(self.has_trained_step), loss_value,
- cb_params.net_outputs[1].asnumpy(), cb_params.net_outputs[2].asnumpy(),
- cb_params.net_outputs[3].asnumpy()
- ))
-
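- # Command-line arguments. Several flags below (--net, --dataset, --dataset_path,
- # --pre_trained, ...) are inherited from the original ResNet launcher and are not
- # used by this script.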
- parser = argparse.ArgumentParser(description='PanguAlpha training')
- parser.add_argument('--net', type=str, default=None, help='Resnet Model, either resnet50 or resnet101')
- parser.add_argument('--train_url', type=str, default=None, help='train_url')
- parser.add_argument('--data_url', type=str, default=None, help='data_url')
- parser.add_argument('--dataset', type=str, default=None, help='Dataset, either cifar10 or imagenet2012')
- parser.add_argument('--run_distribute', type=ast.literal_eval, default=False, help='Run distribute')
- parser.add_argument('--device_num', type=int, default=1, help='Device num.')
- parser.add_argument('--dataset_path', type=str, default=None, help='Dataset path')
- parser.add_argument('--device_target', type=str, default='Ascend', help='Device target')
- parser.add_argument('--pre_trained', type=str, default=None, help='Pretrained checkpoint path')
- parser.add_argument('--parameter_server', type=ast.literal_eval, default=False, help='Run parameter server train')
- parser.add_argument('--initial', type=str, default='True', choices=['False', 'True'], help='initial flag')
- parser.add_argument('--globalStept', type=int, default=1, help='global step')
- parser.add_argument('--uuid', type=str, default='112_12', help='Unique id of this training job.')
- parser.add_argument('--num_epoch', type=int, default=80, help='Number of training epochs.')
-
- parser.add_argument("--per_batch_size",
- type=int,
- default=0,
- help="The batch size for each data parallel way. default 6")
- parser.add_argument("--eod_reset",
- type=int,
- default=1,
- help="Enable eod mask, default is 1.")
- parser.add_argument("--param_init_type",
- type=str,
- default="fp32",
- help="The initialization type for parameters. Default fp32.")
- parser.add_argument("--word_emb_dp",
- type=int,
- default=1,
- choices=[0, 1],
- help="Whether do data parallel in word embedding. default 1")
- parser.add_argument("--eod_id",
- type=int,
- default=6,
- help="The id of end of document")
- parser.add_argument("--start_lr",
- type=float,
- default="2e-5",# Fix me 2.5e-5
- help="Start learning rate, default is 5e-5.")
- parser.add_argument("--end_lr",
- type=float,
- default="5e-6",
- help="End learning rate, default is 1e-10.")
- parser.add_argument("--warmup_step",
- type=int,
- default=60,
- help="Warmup step, default is 10000.")
- parser.add_argument("--epoch_size",
- type=int,
- default=2,
- help="Epoch size, default is 2.")
- args_opt = parser.parse_args()
- TRAINING_ROUND = args_opt.globalStept
- TRAINED_EPOCH = args_opt.globalStept
-
- DATASET_DIR = 'obs://pcl-verify/yizx/other_verify/pangu-mindspore-AISyn/wiki/'
- LOCAL_PATH = "/cache/wiki"
- NPY_LOCAL_PATH = '/cache/res3.npy'
-
-
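- # Fix the global seed so weight initialization and dataset shuffling are reproducible.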
- set_seed(1)
-
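- # Taken from the CLIENT_STEP env var (presumably the local steps per federated round); defaults to 1.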
- step_per_round = int(os.environ["CLIENT_STEP"]) if 'CLIENT_STEP' in os.environ else 1
-
-
- class EvalCallBack(Callback):
- """
- Periodically evaluate the model and report loss and perplexity (PPL).
- Note:
- If print_per_step is 0, do not print loss.
- 
- Args:
- print_per_step (int): Evaluate and print every print_per_step steps. Default: 200.
- """
- def __init__(self, model, eval_dataset, ppl_metric, acc_record, print_per_step=200, has_trained_step=0):
- super(EvalCallBack, self).__init__()
- if not isinstance(print_per_step, int) or print_per_step < 0:
- raise ValueError("print_per_step must be int and >= 0.")
- self.print_per_step = print_per_step
- self.model = model
- self.eval_dataset = eval_dataset
- self.pplMetric = ppl_metric
- self.acc_record = acc_record
- self.has_trained_step = has_trained_step
- self.pplMetric.clear()
-
- def step_end(self, run_context):
- """
- step end
- """
- cb_params = run_context.original_args()
- current_step = cb_params.cur_step_num + self.has_trained_step
- # print("@@@@ current_step is: {} @@@@".format(current_step))
- if current_step % self.print_per_step != 0:
- return
- self.pplMetric.clear()
- rank_id = 0
- start_time = time.time()
- output = self.model.eval(self.eval_dataset, dataset_sink_mode=False)
- print(output)
- loss, ppl = output['ppl'][0], output['ppl'][1]
- self.acc_record.append(loss)
- print("loss, ppl", loss, ppl)
- end_time = time.time()
- eval_time = int(end_time - start_time)
- time_str = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime())
- out_str = "{} == Rank: {} == EvalCallBack model.eval()\n Loss={}, PPL={}; eval_time: {}s". \
- format(time_str, rank_id, loss, ppl, eval_time)
- print(out_str)
-
-
- def train(model, ds, callbacks, epochs=1):
- print("============== Starting Training ==============")
- model.train(epochs, ds, callbacks=callbacks, sink_size=2, dataset_sink_mode=True)
-
- def test(model, ds):
- print("============== Starting Testing ==============")
- ppl_mat = model.eval(ds)
- loss, ppl = ppl_mat['ppl']
- return loss, ppl
-
-
- if __name__ == '__main__':
- os.environ["GRPC_VERBOSITY"] = "INFO"
-
- target = args_opt.device_target
- project_root = os.path.abspath(
- os.path.dirname(os.path.realpath(__file__)) + os.path.sep + "..")
- # Set hccl connect time
- os.environ['HCCL_CONNECT_TIMEOUT'] = "6000"
- EXEC_PATH = os.path.join(project_root, 'pangu-mindspore-AISyn')
- device_id = int(os.getenv("DEVICE_ID"))
- rank_id_str = os.getenv('RANK_ID', '0')
- rank_id = int(
- rank_id_str[rank_id_str.rfind('-') +
- 1:]) # 'RANK_ID': 'job24535502-job-facereidtome-hn-0/1'
- print('rank_id:{}'.format(rank_id), "rank_id str:{}".format(rank_id_str))
- local_rank = rank_id
- print('local_rank:{}, device id:{}'.format(local_rank, device_id))
-
- # download dataset
- if local_rank % 8 == 0:
- os.environ['HCCL_CONNECT_TIMEOUT'] = "6000"
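- # NOTE: os.system runs 'ulimit' in a child shell, so it does not raise this
- # process's stack limit; resource.setrlimit would be needed for that.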
- os.system('ulimit -s 102400')
- mox.file.copy_parallel(src_url=DATASET_DIR, dst_url=LOCAL_PATH)
- RANK_FILE = open('/cache/wiki/rank_id.txt', 'r')
- RANK_ID = int(RANK_FILE.read())
- RANK_FILE.close()
- if RANK_ID == 19:
- tmp_f = open('/cache/wiki/rank_id.txt', 'w+')
- tmp_f.write('0')
- tmp_f.close()
- print("Download dataset success.")
- f = open("%s/install.txt" % (EXEC_PATH), 'w')
- f.close()
- # Block the other processes here until the package refresh and the dataset download have completed.
- while not os.path.exists("%s/install.txt" % (EXEC_PATH)):
- time.sleep(1)
-
- initial = args_opt.initial == 'True'
- print(f"initial:{initial}")
-
- # init context
- context.set_context(mode=context.GRAPH_MODE, device_target=target, save_graphs=False)
- context.set_context(variable_memory_max_size="31GB")
-
- if args_opt.parameter_server:
- context.set_ps_context(enable_ps=True)
- if args_opt.run_distribute:
- if target == "Ascend":
- D.init()
- device_num = D.get_group_size()
- rank = D.get_rank()
- print("rank_id is {}, device_num is {}".format(rank, device_num))
- context.reset_auto_parallel_context()
- context.set_auto_parallel_context(
- parallel_mode=ParallelMode.DATA_PARALLEL, gradients_mean=False,
- device_num=device_num,
- full_batch=False,
- enable_parallel_optimizer=False)
- set_algo_parameters(elementwise_op_strategy_follow=True)
- _set_multi_subgraphs()
- else:
- exit('Wrong Device in modelarts!')
- D.init()
- rank = 0
- device_num = 1
-
- # Set model property
- model_parallel_num = 1 #args_opt.op_level_model_parallel_num
- data_parallel_num = int(device_num / model_parallel_num)
- args_opt.per_batch_size = 16  # hard-coded override of the --per_batch_size CLI flag
- batch_size = args_opt.per_batch_size * data_parallel_num
- print("@@@@@ batch_size_perDevice is : {} @@@@@".format(batch_size))
-
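- # Roughly a 350M-parameter PanguAlpha: 24 transformer layers, 16 heads, hidden
- # size 1024, vocab 40000, sequence length 1024; fp16 compute with fp32 parameter
- # initialization by default.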
- pangu_config = PANGUALPHAConfig(
- data_parallel_num=data_parallel_num, model_parallel_num=model_parallel_num, batch_size=batch_size,
- seq_length=1024, vocab_size=40000, embedding_size=1024,
- num_layers=24, num_heads=16, expand_ratio=4, dropout_rate=0.1,
- compute_dtype=mstype.float16, stage_num=1, micro_size=1,
- eod_reset=bool(args_opt.eod_reset), load_ckpt_path='/cache',
- param_init_type=mstype.float32 if args_opt.param_init_type == 'fp32' else mstype.float16,
- word_emb_dp=bool(args_opt.word_emb_dp))
- print("===config is: ", pangu_config, flush=True)
- gpt = PanguAlpha(pangu_config, is_teacher=False)
-
- uuid = args_opt.uuid
-
- # Initial scaling sens
- loss_scale_value = math.pow(2, 32)
- epoch_num = args_opt.epoch_size
- loss = CrossEntropyLoss(pangu_config)
- gpt_with_loss = PanguAlphaWithLoss(pangu_config, gpt, loss, eos_token=args_opt.eod_id)
- pangu_alpha_with_loss = gpt_with_loss
-
- dataset = create_dataset(pangu_config.batch_size, data_path=LOCAL_PATH + '/train', data_start_index=0, eod_reset=pangu_config.eod_reset, eod_id=args_opt.eod_id, device_num=device_num, rank=rank, hash_check=False)
- step_per_epoch = dataset.get_dataset_size()
- print("step_per_epoch is: {}".format(step_per_epoch))
- callback_size = 2
- actual_epoch_num = int(epoch_num * step_per_epoch / callback_size)
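- # callback_size matches the sink_size=2 passed to model.train(); with dataset
- # sink mode, callbacks fire once per sink of 2 steps, so actual_epoch_num
- # re-expresses the requested epochs as a number of 2-step sinks.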
- # define callbacks
- time_cb = TimeMonitor(callback_size)
- # print loss information
- loss_cb = LossCallBack(callback_size, local_rank, 0, 0)
- cb = [time_cb, loss_cb]
-
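- # Dynamic loss scaling for fp16 training: start at 2**32, halve the scale on
- # overflow, and double it after 1000 consecutive overflow-free steps.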
- update_cell = DynamicLossScaleUpdateCell(loss_scale_value=loss_scale_value, scale_factor=2, scale_window=1000)
-
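- # Apply weight decay (0.1) to every parameter except LayerNorm weights and
- # biases, which are conventionally excluded from decay in transformer training.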
- decay_filter = lambda x: 'layernorm' not in x.name.lower() and "bias" not in x.name.lower()
- params = pangu_alpha_with_loss.trainable_params()
- decay_params = list(filter(decay_filter, params))
- other_params = list(filter(lambda x: not decay_filter(x), params))
- group_params = [{
- 'params': decay_params,
- 'weight_decay': 1e-1
- }, {
- 'params': other_params,
- 'weight_decay': 0.0
- }, {
- 'order_params': params
- }]
- ######################################################################################################
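- # Hard-coded schedule: decay from 5e-5 to 5e-6 over 200 steps with no warmup.
- # Note this ignores the --start_lr/--end_lr/--warmup_step CLI flags parsed above.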
- lr = LearningRate(learning_rate=5e-5, end_learning_rate=5e-6,
- warmup_steps=0, decay_steps=200)
- optimizer = FP32StateAdamWeightDecay(group_params, learning_rate=lr, eps=1e-8, beta1=0.9, beta2=0.95)
- ######################################################################################################
- pangu_alpha_with_grads = PanguAlphaTrainOneStepWithLossScaleCell(
- pangu_alpha_with_loss, optimizer=optimizer, scale_update_cell=update_cell, enable_global_norm=True,
- config=pangu_config)
-
- acc_record=[]
- eval_dataset = create_dataset(pangu_config.batch_size, data_path=LOCAL_PATH + '/test', data_start_index=0, eod_reset=pangu_config.eod_reset, eod_id=args_opt.eod_id, device_num=device_num, rank=rank, hash_check=False, is_train=False)
- ppl_metric = PPLMetric(1024)
-
- model = Model(pangu_alpha_with_grads, eval_network=pangu_alpha_with_loss, metrics={"ppl": ppl_metric})
- #eval_cb = EvalCallBack(model, eval_dataset, ppl_metric, acc_record)
- #cb += [eval_cb]
-
- ##################################################################################################################
- actual_epoch_num = dataset.get_dataset_size()
- print("Dataset size: {}, actual_epoch_num: {}".format(dataset.get_dataset_size(), actual_epoch_num), flush=True)
- ##################################################################################################################
-
- # Federated-learning client (Flower-style NumPy adapter)
- class PanguClient(aisc.client.NumPyClientAdap):
- # def get_parameters(self):
-
- # def set_parameters(self, parameters):
-
- # def fit(self, parameters, config):
-
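- # get_parameters/set_parameters/fit appear to be supplied by the NumPyClientAdap
- # base class; only evaluate() is overridden here.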
- def evaluate(self, parameters, config):
- self.set_parameters(parameters)
- loss, ppl = test(model, eval_dataset)
- return float(loss), eval_dataset.get_dataset_size(), {'ppl': ppl}
-
- # Start client
- #'dl_framework', 'model', 'train_ds', 'test_ds', 'train_func', 'test_func' and callbacks
-
- aisc.client.run_numpyAdap_client("xxx.xxx.xxx.xxx:30003", client=PanguClient('mindspore', model, dataset, eval_dataset, train, test, callbacks=cb),
- grpc_max_message_length=1536870912)