- # Copyright 2020-2022 Huawei Technologies Co., Ltd
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- # ============================================================================
- """Transformer training script."""
-
- import math
- import os
- import numpy as np
- import mindspore as ms
- from mindspore import DynamicLossScaleManager
- from mindspore.communication import init
-
- import mindspore.nn.optim as optim
- import mindspore.context as context
- from mindspore.dataset import GeneratorDataset
- from mindspore.train.model import Model
-
- # from key_mapping import get_mapping
- from src.callback.eval import EvalDuringTrain, doEval
- from src.callback.log import TrainLogger
- from src.callback.flag import FlagModifiedCallback
- from src.model.mem_transformer import MemTransformerLM
- from src.model_utils.device_adapter import get_device_id, get_device_num, get_rank_id
- from src.model_utils.config import config
- from src.utils.dataset_util import get_dataset
- from src.utils.nnUtils import uniform_, normal_, constant_
- from src.metric.calc import bpc, ppl
- from static_lr import get_lr_of_40w_steps
-
- # os.environ['RANK_TABLE_FILE'] = '/home/wut/txl_ascend/rank_table_8pcs.json'
- # from torch2msp import convert, load_state
-
- # os.environ['CUDA_VISIBLE_DEVICES'] = '0,1,2,3'
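- # GLOG_v sets MindSpore's global log level (0 = DEBUG, 1 = INFO, 2 = WARNING, 3 = ERROR),
- # so the line below keeps INFO-and-above messages from the runtime.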
- os.environ['GLOG_v'] = '1'
-
-
- # os.environ['CUDA_VISIBLE_DEVICES'] = '0,1,3'
-
- # os.environ['CUDA_VISIBLE_DEVICES'] = '0'
-
- def init_weight(weight, _config):
- """
- Initialize a weight tensor in place using the scheme selected by _config.init.
- Args:
- weight: the weight parameter to initialize
- _config: configuration providing init ('uniform' or 'normal'), init_range and init_std
-
- Returns:
- None
- """
- if _config.init == 'uniform':
- uniform_(weight, -_config.init_range, _config.init_range)
- elif _config.init == 'normal':
- normal_(weight, 0.0, _config.init_std)
-
-
- def init_bias(bias):
- constant_(bias, 0.0)
-
-
- def weights_init(m, config):
- """Initialize the weights and biases of a cell according to its class name."""
- classname = m.__class__.__name__
- if classname.find('Dense') != -1:
- if hasattr(m, 'weight') and m.weight is not None:
- init_weight(m.weight, config)
- if hasattr(m, 'bias') and m.bias is not None:
- init_bias(m.bias)
- elif classname.find('AdaptiveEmbedding') != -1:
- if hasattr(m, 'emb_projs'):
- for i in range(len(m.emb_projs)):
- if m.emb_projs[i] is not None:
- normal_(m.emb_projs[i], 0.0, config.proj_init_std)
- elif classname.find('Embedding') != -1:
- if hasattr(m, 'weight'):
- init_weight(m.weight, config)
- elif classname.find('ProjectedAdaptiveLogSoftmax') != -1:
- if hasattr(m, 'cluster_weight') and m.cluster_weight is not None:
- init_weight(m.cluster_weight, config)
- if hasattr(m, 'cluster_bias') and m.cluster_bias is not None:
- init_bias(m.cluster_bias)
- if hasattr(m, 'out_projs'):
- for i in range(len(m.out_projs)):
- if m.out_projs[i] is not None:
- normal_(m.out_projs[i], 0.0, config.proj_init_std)
- elif classname.find('LayerNorm') != -1:
- if hasattr(m, 'weight'):
- normal_(m.weight, 1.0, config.init_std)
- if hasattr(m, 'bias') and m.bias is not None:
- init_bias(m.bias)
- elif classname.find('TransformerLM') != -1:
- if hasattr(m, 'r_emb'):
- init_weight(m.r_emb, config)
- if hasattr(m, 'r_w_bias'):
- init_weight(m.r_w_bias, config)
- if hasattr(m, 'r_r_bias'):
- init_weight(m.r_r_bias, config)
- if hasattr(m, 'r_bias'):
- init_bias(m.r_bias)
-
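- # Note: unlike the original PyTorch Transformer-XL script, which applies weights_init
- # recursively via model.apply, this port calls weights_init only on the top-level cell and
- # the word embedding (see main()). If a recursive application were wanted, a sketch could be:
- #   for _, cell in net.cells_and_names():
- #       weights_init(cell, config)
- # (cells_and_names() yields (name, Cell) pairs for the cell itself and all sub-cells.)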
-
- # def weights_init(m, _config):
- # classname = m.__class__.__name__
- # if classname.find('AdaptiveEmbedding') != -1:
- # if hasattr(m, 'emb_projs'):
- # for i in range(len(m.emb_projs)):
- # if m.emb_projs[i] is not None:
- # normal_(m.emb_projs[i], 0.0, _config.proj_init_std)
- # else:
- # if hasattr(m, 'r_emb'):
- # init_weight(m.r_emb, _config)
- # if hasattr(m, 'r_w_bias'):
- # init_weight(m.r_w_bias, _config)
- # if hasattr(m, 'r_r_bias'):
- # init_weight(m.r_r_bias, _config)
- # if hasattr(m, 'r_bias'):
- # init_bias(m.r_bias)
-
-
- # def update_dropout(m, config):
- # classname = m.__class__.__name__
- # if classname.find('Dropout') != -1:
- # if hasattr(m, 'p'):
- # m.p = config.dropout
- #
- #
- # def update_dropatt(m, config):
- # if hasattr(m, 'dropatt'):
- # m.dropatt.p = config.dropatt
-
-
- def get_optimizer(_config, net, scheduler):
- """
- Build the optimizer (sgd, adam or adagrad) selected by _config.optim.
- Args:
- _config: configuration object
- net: network whose trainable parameters are optimized
- scheduler: learning-rate schedule passed to Adam when sample_softmax <= 0;
- the other branches use the per-step list produced by dynamic_lr()
-
- Returns:
- optimizer: optimizer for the dense parameters
- optimizer_sparse: optimizer for the sparse embedding parameters, None unless sample_softmax > 0
- """
- optimizer = optimizer_sparse = None
- lr = dynamic_lr()
- if _config.optim.lower() == 'sgd':
- if _config.sample_softmax > 0:
- dense_params, sparse_params = [], []
- for param in net.trainable_params():
- if len(param) == len(net.word_emb.embedding_table):
- sparse_params.append(param)
- else:
- dense_params.append(param)
- optimizer_sparse = optim.SGD(sparse_params, learning_rate=_config.lr * 2)
- optimizer = optim.SGD(dense_params, learning_rate=_config.lr, momentum=_config.mom)
- else:
- optimizer = optim.SGD(net.trainable_params(), learning_rate=_config.lr,
- momentum=_config.mom)
- elif _config.optim.lower() == 'adam':
- if _config.sample_softmax > 0:
- dense_params, sparse_params = [], []
- for param in net.trainable_params():
- if len(param) == len(net.word_emb.embedding_table):
- sparse_params.append(param)
- else:
- dense_params.append(param)
- # MindSpore has no SparseAdam operator, so this call would fail if sample_softmax > 0 were used
- optimizer_sparse = optim.SparseAdam(sparse_params, lr=lr)
- optimizer = optim.Adam(dense_params, learning_rate=lr)
- else:
- optimizer = optim.Adam(net.trainable_params(), learning_rate=scheduler)
- # optimizer = optim.Adam(net.trainable_params(), learning_rate=lr)
- elif _config.optim.lower() == 'adagrad':
- optimizer = optim.Adagrad(net.trainable_params(), learning_rate=lr)
- return optimizer, optimizer_sparse
-
-
- def rsqrt_decay(warmup_steps, current_step):
- return float(max([current_step, warmup_steps])) ** -0.5
-
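- # rsqrt_decay implements the inverse-square-root decay factor; it is kept for reference but
- # is not used by dynamic_lr() below, which combines linear warmup with a cosine-style decay.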
-
- def linear_warmup_learning_rate(current_step, warmup_steps, base_lr, init_lr):
- lr_inc = (float(base_lr) - float(init_lr)) / float(warmup_steps)
- learning_rate = float(init_lr) + lr_inc * current_step
- return learning_rate
-
-
- def a_cosine_learning_rate(current_step, base_lr, warmup_steps, total_steps):
- decay_steps = total_steps - warmup_steps
- linear_decay = (total_steps - current_step) / decay_steps
- cosine_decay = 0.5 * (1 + math.cos(math.pi * 2 * 0.47 * current_step / decay_steps))
- decayed = linear_decay * cosine_decay + 0.00001
- learning_rate = decayed * base_lr
- return learning_rate
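- # In the formula above, for warmup_steps <= t <= total_steps:
- #   lr(t) = base_lr * [ (total_steps - t) / decay_steps * 0.5 * (1 + cos(2 * pi * 0.47 * t / decay_steps)) + 1e-5 ]
- # i.e. a linear ramp down to zero multiplied by roughly half a cosine period, plus a small additive floor.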
-
-
- def dynamic_lr():
- """dynamic learning rate generator"""
- base_lr = config.lr
- total_steps = int(config.max_step)
- warmup_steps = int(config.warmup_step)
- lr = []
- for i in range(total_steps):
- if i < warmup_steps:
- lr.append(linear_warmup_learning_rate(i, warmup_steps, base_lr, base_lr * config.warmup_ratio))
- else:
- lr.append(a_cosine_learning_rate(i, base_lr, warmup_steps, total_steps))
- return lr
-
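- # The list returned by dynamic_lr() holds one learning rate per global step; MindSpore
- # optimizers accept such a per-step list (as well as a scalar or a LearningRateSchedule)
- # for the learning_rate argument. A quick sanity check might look like:
- #   lrs = dynamic_lr()
- #   print(lrs[0], lrs[min(config.warmup_step, len(lrs) - 1)], lrs[-1])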
-
- def get_scheduler(_config):
- """Return the learning-rate scheduler selected by _config.scheduler (only 'cosine' is implemented)."""
- scheduler = scheduler_sparse = None
- if _config.scheduler == 'cosine':
- # here we do not set eta_min to lr_min to be backward compatible
- # because in previous versions eta_min is default to 0
- # rather than the default value of lr_min 1e-6
-
- # MindSpore lacks a built-in CosineAnnealingLR, so a local implementation is imported
- from src.utils.additional_algorithms import CosineAnnealingLR
- # scheduler = CosineDecayLR(min_lr=config.eta_min, max_lr=config.lr, decay_steps=1)
-
- scheduler = CosineAnnealingLR(total_step=_config.max_step, lr=_config.lr, min_lr=_config.eta_min)
- # scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer,
- # config.max_step, eta_min=config.eta_min) # should use eta_min arg
- # if config.sample_softmax > 0:
- # scheduler_sparse = optim.lr_scheduler.CosineAnnealingLR(optimizer_sparse,
- # config.max_step,
- # eta_min=config.eta_min) # should use eta_min arg
-
- elif _config.scheduler == 'inv_sqrt':
- pass
- # originally used for Transformer (in Attention is all you need)
- # def lr_lambda(step):
- # # return a multiplier instead of a learning rate
- # if step == 0 and _config.warmup_step == 0:
- # return 1.
- # else:
- # return 1. / (step ** 0.5) if step > _config.warmup_step \
- # else step / (_config.warmup_step ** 1.5)
-
- # scheduler = optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lr_lambda)
- elif _config.scheduler == 'dev_perf':
- pass
- # scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer,
- # factor=config.decay_rate, patience=config.patience,
- # min_lr=config.lr_min)
- # if config.sample_softmax > 0:
- # scheduler_sparse = optim.lr_scheduler.ReduceLROnPlateau(optimizer_sparse,
- # factor=config.decay_rate, patience=config.patience,
- # min_lr=config.lr_min)
- elif _config.scheduler == 'constant':
- pass
- return scheduler, scheduler_sparse
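- # The 'inv_sqrt' and 'dev_perf' branches are left unimplemented in this port. If an inverse-
- # square-root schedule were needed, a per-step list mirroring the commented LambdaLR above
- # could be sketched as (hypothetical helper, not part of this repo):
- #   def inv_sqrt_lr(base_lr, warmup_step, max_step):
- #       return [base_lr * (step ** -0.5 if step > warmup_step else step / (warmup_step ** 1.5))
- #               for step in range(1, max_step + 1)]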
-
-
- def set_seed():
- np.random.seed(config.seed)
- ms.set_seed(config.seed)
-
-
- def main():
- # Set the random seed manually for reproducibility.
- set_seed()
- print('data_url : ', config.data_url)
- print('train_url : ', config.train_url)
- # config.data = config.data_url
-
- # os.listdir(config.data_url)
-
- device_id = get_device_id()
- device_num = get_device_num()
- print('device_id : ', device_id)
- print('device_num : ', device_num)
-
- if config.device_target.lower() == 'ascend':
- context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
- if device_num > 1:
- context.reset_auto_parallel_context()
- context.set_auto_parallel_context(device_num=device_num, parallel_mode=context.ParallelMode.DATA_PARALLEL,
- gradients_mean=True)
- init()
-
- elif config.device_target.lower() == 'gpu':
- # context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU", max_device_memory="39.0GB")
- context.set_context(mode=context.GRAPH_MODE, device_target="GPU", max_device_memory="39.0GB")
- if device_num > 1:
- init()
- context.reset_auto_parallel_context()
- context.set_auto_parallel_context(device_num=device_num, parallel_mode=context.ParallelMode.DATA_PARALLEL,
- gradients_mean=True)
- else:
- context.set_context(mode=context.PYNATIVE_MODE, device_target="CPU")
-
- rank_size, rank_id = get_device_num(), get_rank_id()
- print('rank_id : ', rank_id)
-
- ###############################################################################
- # Load data
- ###############################################################################
-
- dataset = get_dataset()
- ntokens = len(dataset.vocab)
- config.n_token = ntokens
-
- # adaptive softmax / embedding
- cutoffs, tie_projs = [], [False]
- if config.adaptive:
- # assert config.dataset in ['wt103', 'lm1b']
- if config.dataset == 'wt103':
- cutoffs = [20000, 40000, 200000]
- tie_projs += [True] * len(cutoffs)
- elif config.dataset == 'lm1b':
- cutoffs = [60000, 100000, 640000]
- tie_projs += [False] * len(cutoffs)
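- # The cutoffs split the vocabulary into frequency bands for the adaptive embedding and
- # adaptive softmax; tie_projs[i] marks whether the i-th projection is shared between the
- # embedding and the softmax head (the leading False corresponds to the head cluster, and the
- # tail projections are tied for wt103 but untied for lm1b, as in the original Transformer-XL).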
-
- ###############################################################################
- # Build the model
- ###############################################################################
- # if config.restart:
- # with open(os.path.join(config.restart_dir, 'model.ckpt'), 'rb') as f:
- # net = ms.load(f)
- # if not config.fp16:
- # net = net.float()
- # update_dropout(net)
- # update_dropatt(net)
- # else:
- net = MemTransformerLM(ntokens, config.n_layer, config.n_head, config.d_model,
- config.d_head, config.d_inner, config.dropout, config.dropatt, batch_size=config.batch_size,
- d_embed=config.d_embed, div_val=config.div_val,
- pre_lnorm=config.pre_lnorm, tgt_len=config.tgt_len,
- ext_len=config.ext_len, mem_len=config.mem_len, eval_tgt_len=config.eval_tgt_len,
- cutoffs=cutoffs, same_length=config.same_length, clamp_len=config.clamp_len)
- # ensure embedding init is not overridden by out_layer in case of weight sharing
- weights_init(net, config)
- weights_init(net.word_emb, config)
-
- config.n_all_param = sum([p.size for p in net.trainable_params()])
- config.n_nonemb_param = sum([p.size for p in net.layers.trainable_params()])
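- # n_all_param counts every trainable parameter, while n_nonemb_param counts only the
- # transformer layers (net.layers), i.e. it excludes the adaptive embedding and softmax tables.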
-
- #### scheduler
- scheduler, _ = get_scheduler(config)
- # scheduler = get_lr_of_40w_steps()  # lr list generated with the corresponding torch recipe
- #### optimizer
- optimizer, _ = get_optimizer(config, net, scheduler)
-
- # if config.restart:
- # if os.path.exists(os.path.join(config.restart_dir, 'optimizer.ckpt')):
- # with open(os.path.join(config.restart_dir, 'optimizer.ckpt'), 'rb') as f:
- # opt_state_dict = ms.load(f)
- # optimizer.load_state_dict(opt_state_dict)
- # else:
- # print('Optimizer was not saved. Start from scratch.')
- if device_id == 0:
- print('=' * 100)
- for k, v in config.__dict__.items():
- print(' - {} : {}'.format(k, v))
- print('=' * 100)
- print('#params = {}'.format(config.n_all_param))
- print('#non emb params = {}'.format(config.n_nonemb_param))
-
- ###############################################################################
- # Training code
- ###############################################################################
- config.n_batch = dataset.get_train_generator().n_batch
- config.max_epoch = math.ceil(config.max_step / config.n_batch)
-
- train_dataset = GeneratorDataset(source=dataset.get_train_generator(), column_names=['data', 'target'],
- num_shards=rank_size, shard_id=rank_id, shuffle=False)
- valid_dataset = GeneratorDataset(source=dataset.get_valid_generator(), column_names=['data', 'target'],
- shuffle=False)
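- # The training GeneratorDataset is sharded across devices with num_shards / shard_id so each
- # rank sees a distinct slice; shuffle=False keeps the batch order produced by the generator.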
-
- # Train #
-
- flagModifiedCallback = FlagModifiedCallback()
- train_log = TrainLogger(per_print_times=config.log_interval, n_batch=config.n_batch)
- evalDuringTrain = EvalDuringTrain(dataset=valid_dataset, per_print_times=config.eval_interval,
- tgt_len=config.tgt_len, ext_len=config.ext_len, mem_len=config.mem_len,
- eval_tgt_len=config.eval_tgt_len)
-
- model = Model(network=net, loss_fn=None, optimizer=optimizer, metrics=None)
- model.train(config.max_step, train_dataset, sink_size=1,
- callbacks=[flagModifiedCallback, train_log, evalDuringTrain])
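- # With dataset sinking (Model.train's dataset_sink_mode defaults to True) and sink_size=1,
- # each "epoch" consumes a single step, so passing config.max_step as the epoch count trains
- # for exactly max_step optimizer steps.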
-
- # Test #
-
- # Validation and testing do not go through Model.eval; doEval implements them separately
- if device_id == 0:
- test_dataset = GeneratorDataset(source=dataset.get_test_generator(), column_names=['data', 'target'],
- shuffle=False)
- test_loss = doEval(net=net, dataset=test_dataset, tgt_len=config.tgt_len, ext_len=config.ext_len,
- mem_len=config.mem_len, eval_tgt_len=config.eval_tgt_len)
- print('=' * 100)
- if config.dataset in ['enwik8', 'text8']:
- print('| End of training | test loss {:5.2f} | test bpc {:9.5f}'.format(
- test_loss, bpc(test_loss)))
- else:
- print('| End of training | test loss {:5.2f} | test ppl {:9.3f}'.format(
- test_loss, ppl(test_loss)))
- print('=' * 100)
-
-
- if __name__ == '__main__':
- main()