|
- # -*- coding: utf-8 -*-
- """
- @author:XuMing(xuming624@qq.com)
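- @description: Demo script that trains and/or runs prediction with ConvSeq2SeqModel on a Chinese dialogue TSV file.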
- """
- import argparse
- from loguru import logger
- import sys
-
- sys.path.append('../..')
- from textgen.seq2seq import ConvSeq2SeqModel
-
-
def main():
    """Command-line entry point for the ConvSeq2Seq dialogue demo.

    Parses the command-line options, logs them, and then runs the stages
    that were requested:

    * ``--do_train``: builds a ``ConvSeq2SeqModel``, trains it on
      ``--train_file`` and prints the result of evaluating on that same file.
    * ``--do_predict``: builds a new ``ConvSeq2SeqModel`` with the same
      settings and prints its predictions for three fixed sample sentences.

    Both stages may be requested in one invocation; if neither flag is
    given, only the parsed arguments are logged.
    """
    cli = argparse.ArgumentParser()
    cli.add_argument(
        '--train_file',
        default='../data/zh_dialog.tsv',
        type=str,
        help='Training data file',
    )
    cli.add_argument(
        '--do_train',
        action='store_true',
        help='Whether to run training.',
    )
    cli.add_argument(
        '--do_predict',
        action='store_true',
        help='Whether to run predict.',
    )
    cli.add_argument(
        '--output_dir',
        default='./outputs/convseq2seq_zh/',
        type=str,
        help='Model output directory',
    )
    cli.add_argument(
        '--max_seq_length',
        default=50,
        type=int,
        help='Max sequence length',
    )
    cli.add_argument(
        '--num_epochs',
        default=200,
        type=int,
        help='Number of training epochs',
    )
    cli.add_argument(
        '--batch_size',
        default=32,
        type=int,
        help='Batch size',
    )
    args = cli.parse_args()
    logger.info(args)

    # Both stages construct their model from the same settings; each stage
    # still creates its own instance.
    model_settings = dict(
        epochs=args.num_epochs,
        batch_size=args.batch_size,
        model_dir=args.output_dir,
        max_length=args.max_seq_length,
    )

    if args.do_train:
        logger.info('Loading data...')
        trainer = ConvSeq2SeqModel(**model_settings)
        trainer.train_model(args.train_file)
        # Evaluation runs on the training file itself (no separate eval set).
        eval_result = trainer.eval_model(args.train_file)
        print(eval_result)

    if args.do_predict:
        # NOTE(review): presumably the model picks up trained weights from
        # output_dir when predicting -- confirm against ConvSeq2SeqModel.
        predictor = ConvSeq2SeqModel(**model_settings)
        sentences = ["什么是ai", "你是什么类型的计算机", "你知道热力学吗"]
        print("inputs:", sentences)
        predictions = predictor.predict(sentences)
        print("outputs:", predictions)
-
-
# Run the demo only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
|