Are you sure you want to delete this task? Once this task is deleted, it cannot be recovered.
Jiaqi Liu fa1e0ba115 | 2 years ago | |
---|---|---|
.. | ||
README.md | 3 years ago | |
args.py | 3 years ago | |
data.py | 3 years ago | |
model.py | 2 years ago | |
predict.py | 3 years ago | |
train.py | 2 years ago |
以下是本范例模型的简要目录结构及说明:
.
├── README.md # 文档,本文件
├── args.py # 训练、预测以及模型参数配置程序
├── data.py # 数据读入程序
├── train.py # 训练主程序
├── predict.py # 预测主程序
└── model.py # 带注意力机制的对联生成程序
Sequence to Sequence (Seq2Seq),使用编码器-解码器(Encoder-Decoder)结构,用编码器将源序列编码成vector,再用解码器将该vector解码为目标序列。Seq2Seq 广泛应用于机器翻译,自动对话机器人,文档摘要自动生成,图片描述自动生成等任务中。
本目录包含Seq2Seq的一个经典样例:自动对联生成,带attention机制的文本生成模型。
本模型中,在编码器方面,我们采用了基于LSTM的多层的RNN encoder;在解码器方面,我们使用了带注意力(Attention)机制的RNN decoder,在预测时我们使用Beam Search算法来生对联的下联。
本教程使用couplet数据集作为训练语料,该数据集来源于这个github repo,其中train_src.tsv及train_tgt.tsv为训练集,dev_src.tsv及dev_tgt.tsv为开发集,test_src.tsv及test_tgt.tsv为测试集。
数据集会在调用paddlenlp.datasets.load_dataset
时自动下载,在linux系统下,数据集会自动下载到~/.paddlenlp/datasets/Couplet/
目录下
执行以下命令即可训练带有注意力机制的Seq2Seq模型:
python train.py \
--num_layers 2 \
--hidden_size 512 \
--batch_size 128 \
--device gpu \
--model_path ./couplet_models \
--max_epoch 20
各参数的具体说明请参阅 args.py
。训练程序会在每个epoch训练结束之后,保存一次模型。
NOTE: 如需恢复模型训练,则init_from_ckpt
只需指定到文件名即可,不需要添加文件尾缀。如--init_from_ckpt=couplet_models/19
即可,程序会自动加载模型参数couplet_models/19.pdparams
,也会自动加载优化器状态couplet_models/19.pdopt
。
训练完成之后,可以使用保存的模型(由 --init_from_ckpt
指定)对测试集进行beam search解码,命令如下:
python predict.py \
--num_layers 2 \
--hidden_size 512 \
--batch_size 128 \
--init_from_ckpt couplet_models/19 \
--infer_output_file infer_output.txt \
--beam_size 10 \
--device gpu
各参数的具体说明请参阅 args.py
,注意预测时所用模型超参数需和训练时一致。
上联:崖悬风雨骤 下联:月落水云寒
上联:约春章柳下 下联:邀月醉花间
上联:箬笠红尘外 下联:扁舟明月中
上联:书香醉倒窗前月 下联:烛影摇红梦里人
上联:踏雪寻梅求雅趣 下联:临风把酒觅知音
上联:未出南阳天下论 下联:先登北斗汉中书
上联:朱联妙语千秋颂 下联:赤胆忠心万代传
上联:月半举杯圆月下 下联:花间对酒醉花间
上联:挥笔如剑倚麓山豪气干云揽月去 下联:落笔似龙飞沧海龙吟破浪乘风来
我们的数据集采用了开源对联数据集couplet-clean-dataset,地址:https://github.com/v-zich/couplet-clean-dataset ,该数据集过滤了couplet-dataset(地址:https://github.com/wb14123/couplet-dataset )中的低俗、敏感内容。
黑客松task_55,在PaddleNLP的Roberta中,新增 MultipleChoice,MaskedLM 和 CausalLM三个类,7个模型权重. ,新增BPETokenizer
Python C++ Cuda Text Shell other
Dear OpenI User
Thank you for your continuous support to the Openl Qizhi Community AI Collaboration Platform. In order to protect your usage rights and ensure network security, we updated the Openl Qizhi Community AI Collaboration Platform Usage Agreement in January 2024. The updated agreement specifies that users are prohibited from using intranet penetration tools. After you click "Agree and continue", you can continue to use our services. Thank you for your cooperation and understanding.
For more agreement content, please refer to the《Openl Qizhi Community AI Collaboration Platform Usage Agreement》