基于trlx
库使用RLHF训练Pangu 2.6B中文对话模型pipeline
基于chat-gpt的人工反馈的强化学习(RLHF)流程,开发了基于盘古-alpha 2.6B GPU版本模型的RLHF pipeline。我们的pipeline是基于OpenAI论文 "Learning to Summarize from human feedback"的复现代码trlx进行修改。
使用盘古-alpha 2.6B 模型为基础模型,通过监督预训练(SFT)在webtext等对话语料上进行微调得到对话版本盘古-alpha模型。标注人员从15个常见领域设计问题对盘古对话模型进行提问,针对盘古对话模型的输出结果,从适用性,具体性,正确性,安全性4个维度进行人工反馈评测,并收集人工反馈数据用于训练评价模型(RM)代替人工反馈。最后,使用经典RL方法PPO算法和RM模型对SFT阶段的盘古模型进行强化学习训练。
准备阶段
1). 需要配置trlx库相关环境,参考 "trlx"
git clone https://github.com/CarperAI/trlx.git
cd trlx
pip install torch --extra-index-url https://download.pytorch.org/whl/cu116 # for cuda
pip install -e .
2). 下载盘古-2.6B模型:
https://huggingface.co/imone/pangu_2_6B
模型.bin文件保存至 ./Pangu_chk
3). 准备SFT数据集(以webtext为例):
https://paperswithcode.com/dataset/webtext
数据样例保存至: ./dialogue_dir/demo.json
收集人工反馈数据
数据样例保存至: ./reward_data_dir/processed/demo.json
下图所示为用户标注界面,数据标注相关细节可参考: PanGu-Dialog-HFDataset
代码替换
将repalce 文件夹内的 ppo_models.py 文件替换trlx/trainer/nn文件夹下的ppo_models.py
主要修改为盘古模型的载入部分.
if "pangu" in config.lower():
self.config = transformers.AutoConfig.from_pretrained(config,trust_remote_code=True)
self.base_model = transformers.AutoModelForCausalLM.from_pretrained(config,trust_remote_code=True)
gpt_branch_supported_archs = [
"GPTJForCausalLM",
"GPT2LMHeadModel",
"GPTNeoForCausalLM",
"GPTNeoXForCausalLM",
"GPTPanguForCausalLM",
]
盘古的分词器
为了和Trlx兼容,我们将分词器修改为 与CPM分词器的接口相同格式,与原始的盘古分词器有所不同。SPM文件是盘古的。可以使用入下命令导入。
from transformers import AutoTokenizer
tokenizer = AutoTokenizer.from_pretrained("./PanguTokenizer")
或者从Hugging face下载
tokenizer = AutoTokenizer.from_pretrained("Hanlard/Pangu_alpha")
训练步骤
1). 监督微调 (SFT):
cd sft/ && deepspeed train_SFT.py
2). 训练 Reward 模型:
cd reward_model/ && deepspeed train_reward_model.py
3). 使用PPO算法强化学习:
accelerate launch --config_file configs/default_accelerate_config.yaml trlx_pangu.py
备注: 至少需要1张V100显卡。
参考文献
- Nisan Stiennon, Long Ouyang, Jeff Wu, Daniel M. Ziegler, Ryan Lowe, Chelsea Voss, Alec Radford, Dario Amodei, Paul Christiano, "Learning to Summarize from human feedback", Neural Information Processing Systems, 2020.