|
- import os
-
- from mindspore import load_checkpoint
- from mindspore.dataset import GeneratorDataset
-
- from src.metric.calc import bpc, ppl
- from src.model.mem_transformer import MemTransformerLM
- from src.model_utils.config import config
- from src.model_utils.device_adapter import get_device_num, get_rank_id
- from src.utils.dataset_util import get_dataset
- from src.callback.eval import doEval
-
-
def main():
    """Evaluate a trained Memory Transformer-XL checkpoint on valid/test splits.

    Builds the vocabulary-sized model described by ``config``, loads the
    checkpoint ``model0.ckpt`` from ``config.load_path``, runs ``doEval`` on
    the validation and test generators, and prints loss together with bpc
    (character-level datasets: enwik8/text8) or ppl (word-level datasets).

    Side effects: reads the checkpoint file from disk and prints the report
    to stdout. No return value.
    """
    dataset = get_dataset()
    ntokens = len(dataset.vocab)

    # Evaluation must see the corpus in order, hence shuffle=False.
    valid_dataset = GeneratorDataset(source=dataset.get_valid_generator(),
                                     column_names=['data', 'target'],
                                     shuffle=False)
    test_dataset = GeneratorDataset(source=dataset.get_test_generator(),
                                    column_names=['data', 'target'],
                                    shuffle=False)
    model_filename = os.path.join(config.load_path, 'model0.ckpt')

    # Adaptive softmax/embedding cluster boundaries; tie_projs[0] is the
    # head cluster. Cutoff values are corpus-specific vocabulary splits.
    cutoffs, tie_projs = [], [False]
    if config.adaptive:
        if config.dataset == 'wt103':
            cutoffs = [20000, 40000, 200000]
            tie_projs += [True] * len(cutoffs)
        elif config.dataset == 'lm1b':
            cutoffs = [60000, 100000, 640000]
            tie_projs += [False] * len(cutoffs)
    # NOTE(review): tie_projs is built but never passed to MemTransformerLM
    # below — confirm whether the constructor should receive it.

    net = MemTransformerLM(ntokens, config.n_layer, config.n_head, config.d_model,
                           config.d_head, config.d_inner, config.dropout,
                           config.dropatt, batch_size=config.batch_size,
                           d_embed=config.d_embed, div_val=config.div_val,
                           pre_lnorm=config.pre_lnorm, tgt_len=config.tgt_len,
                           ext_len=config.ext_len, mem_len=config.mem_len,
                           eval_tgt_len=config.eval_tgt_len,
                           cutoffs=cutoffs, same_length=config.same_length,
                           clamp_len=config.clamp_len)
    load_checkpoint(net=net, ckpt_file_name=model_filename)

    valid_loss = doEval(net, valid_dataset, config.tgt_len, config.ext_len,
                        config.mem_len, config.eval_tgt_len)
    test_loss = doEval(net, test_dataset, config.tgt_len, config.ext_len,
                       config.mem_len, config.eval_tgt_len)

    print('=' * 100)
    if config.dataset in ['enwik8', 'text8']:
        # Character-level corpora report bits-per-character.
        print('| End of valid | valid loss {:5.2f} | valid bpc {:9.5f}'.format(
            valid_loss, bpc(valid_loss)))
        print('| End of test | test loss {:5.2f} | test bpc {:9.5f}'.format(
            test_loss, bpc(test_loss)))
    else:
        # BUG FIX: the valid line previously said 'valid bpc {:9.5f}' while
        # printing ppl(valid_loss); label and format now match the metric.
        # The test line said 'End of training' — made consistent with the
        # bpc branch's 'End of test'.
        print('| End of valid | valid loss {:5.2f} | valid ppl {:9.3f}'.format(
            valid_loss, ppl(valid_loss)))
        print('| End of test | test loss {:5.2f} | test ppl {:9.3f}'.format(
            test_loss, ppl(test_loss)))
    print('=' * 100)
-
-
# Script entry point: only run evaluation when executed directly,
# not when this module is imported.
if __name__ == '__main__':
    main()
|