# Copyright 2020-2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
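"""Evaluate a trained Transformer-XL (MemTransformerLM) checkpoint on the validation and test splits.

A typical invocation (illustrative only; the actual options come from
src.model_utils.config and its YAML defaults):

    python eval.py --load_path ./ckpt --dataset enwik8
"""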

import os
from mindspore import load_checkpoint, context
from mindspore.dataset import GeneratorDataset
from src.callback.eval import doEval
from src.metric.calc import bpc, ppl
from src.model.mem_transformer import MemTransformerLM
from src.model_utils.config import config
from src.model_utils.device_adapter import get_device_num, get_rank_id
from src.utils.dataset_util import get_dataset


def main():
    dataset = get_dataset()
    ntokens = len(dataset.vocab)

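    # Run in graph mode on Ascend, allowing the process to use up to 39 GB of device memory.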
    context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", max_device_memory="39.0GB")

    rank_size, rank_id = get_device_num(), get_rank_id()
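    # Build deterministic (unshuffled) evaluation pipelines over the validation and test splits.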
    valid_dataset = GeneratorDataset(source=dataset.get_valid_generator(), column_names=['data', 'target'],
                                     shuffle=False)
    test_dataset = GeneratorDataset(source=dataset.get_test_generator(), column_names=['data', 'target'],
                                    shuffle=False)
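    # Path of the checkpoint to evaluate; model0.ckpt is assumed to be the file saved by device 0 during training.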
    model_filename = os.path.join(config.load_path, 'model0.ckpt')
    # adaptive softmax / embedding
    cutoffs, tie_projs = [], [False]
    if config.adaptive:
        # assert config.dataset in ['wt103', 'lm1b']
        if config.dataset == 'wt103':
            cutoffs = [20000, 40000, 200000]
            tie_projs += [True] * len(cutoffs)
        elif config.dataset == 'lm1b':
            cutoffs = [60000, 100000, 640000]
            tie_projs += [False] * len(cutoffs)
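    # Rebuild the Transformer-XL network with the same hyper-parameters used at training time.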
    net = MemTransformerLM(ntokens, config.n_layer, config.n_head, config.d_model,
                           config.d_head, config.d_inner, config.dropout, config.dropatt,
                           batch_size=config.batch_size, d_embed=config.d_embed, div_val=config.div_val,
                           pre_lnorm=config.pre_lnorm, tgt_len=config.tgt_len, ext_len=config.ext_len,
                           mem_len=config.mem_len, eval_tgt_len=config.eval_tgt_len,
                           cutoffs=cutoffs, same_length=config.same_length, clamp_len=config.clamp_len)
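    # Load the trained weights, then measure the average loss on each split.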
    load_checkpoint(net=net, ckpt_file_name=model_filename)
    valid_loss = doEval(net, valid_dataset, config.tgt_len, config.ext_len, config.mem_len, config.eval_tgt_len)
    test_loss = doEval(net, test_dataset, config.tgt_len, config.ext_len, config.mem_len, config.eval_tgt_len)

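    # Character-level datasets (enwik8, text8) report bits per character; word-level datasets report perplexity.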
    print('=' * 100)
    if config.dataset in ['enwik8', 'text8']:
        print('| End of valid | valid loss {:5.2f} | valid bpc {:9.5f}'.format(
            valid_loss, bpc(valid_loss)))
        print('| End of test  | test loss {:5.2f} | test bpc {:9.5f}'.format(
            test_loss, bpc(test_loss)))
    else:
        print('| End of valid | valid loss {:5.2f} | valid ppl {:9.3f}'.format(
            valid_loss, ppl(valid_loss)))
        print('| End of test  | test loss {:5.2f} | test ppl {:9.3f}'.format(
            test_loss, ppl(test_loss)))
    print('=' * 100)


if __name__ == '__main__':
    main()