Mixtral 8x7B
Install
# Install the latest xtuner
pip install -U 'xtuner[deepspeed]'
# Mixtral requires flash-attn
pip install flash-attn
# Install the latest transformers
pip install -U transformers
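Before launching a job, it can help to confirm that the three packages import cleanly. This is a minimal sketch. It assumes python3 is the interpreter xtuner will run under and that the import names are xtuner, flash_attn and transformers. The function name check_modules is hypothetical and not part of xtuner.

```bash
#!/usr/bin/env bash
# Hypothetical post-install check (not shipped with xtuner).
# Assumption: python3 is the same interpreter that will run `xtuner train`.

check_modules() {
  # Try to import each module name given as an argument.
  # Prints "ok: <name>" on success, "missing: <name>" on stderr otherwise,
  # and returns 1 if any import failed.
  local missing=0 mod
  for mod in "$@"; do
    if python3 -c "import ${mod}" 2>/dev/null; then
      echo "ok: ${mod}"
    else
      echo "missing: ${mod}" >&2
      missing=1
    fi
  done
  return "${missing}"
}

# pip package -> import name used by the check:
#   xtuner -> xtuner, flash-attn -> flash_attn, transformers -> transformers
```

For example, check_modules xtuner flash_attn transformers prints one "ok:" line per module and returns a non-zero status if any of them fails to import.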
QLoRA Fine-tune
QLoRA fine-tuning needs only a single A100-80G GPU:
xtuner train mixtral_8x7b_instruct_qlora_oasst1_e3 --deepspeed deepspeed_zero2
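The arithmetic below is a rough check on the single-GPU claim. It estimates the memory taken by the frozen base weights alone. It assumes Mixtral 8x7B has about 46.7B total parameters and that 4-bit quantization stores roughly 0.5 bytes per parameter. Quantization constants, the LoRA adapters and their optimizer state, and activations are all ignored, so the result is a lower bound, not a measurement. The helper name weights_gb is hypothetical.

```bash
#!/usr/bin/env bash
# Back-of-envelope estimate (assumptions, not measurements) of base-weight memory.
# Assumptions: ~46.7e9 total parameters; 4-bit storage ~= 0.5 bytes/parameter;
# quantization constants, LoRA adapters, optimizer state and activations ignored.

weights_gb() {
  # $1 = parameter count, $2 = bytes per parameter; prints decimal GB (1e9 bytes)
  awk -v p="$1" -v b="$2" 'BEGIN { printf "%.2f\n", p * b / 1e9 }'
}

echo "4-bit base weights: $(weights_gb 46.7e9 0.5) GB"  # 23.35 GB
echo "bf16 base weights:  $(weights_gb 46.7e9 2) GB"    # 93.40 GB
```

Under these assumptions, the bf16 weights alone would already exceed the 80 GB of one A100-80G. The 4-bit quantization is therefore what leaves room for the adapters and activations on a single card.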
Full Parameter Fine-tune
Full-parameter fine-tuning needs 16 A100-80G GPUs (2 nodes with 8 GPUs each).
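This is a rough estimate of why one 8-GPU node is not enough. It assumes the common mixed-precision Adam accounting of 16 bytes per parameter: 2 bytes of bf16 weights, 2 bytes of bf16 gradients, and 12 bytes of fp32 optimizer state (master weights, momentum, variance). It also assumes ZeRO-3 partitions all of that evenly across GPUs. The parameter count (~46.7B) is also an assumption of this sketch. Activations, communication buffers and fragmentation are ignored, so real usage is higher. The helper name zero3_gb_per_gpu is hypothetical.

```bash
#!/usr/bin/env bash
# Back-of-envelope estimate (assumptions, not measurements) of per-GPU memory
# for full-parameter training under ZeRO-3.
# Assumptions: ~46.7e9 parameters; 16 bytes/parameter
#   = 2 (bf16 weights) + 2 (bf16 grads) + 12 (fp32 master weights, Adam m and v);
# ZeRO-3 splits all of it evenly across GPUs; activations and buffers ignored.

zero3_gb_per_gpu() {
  # $1 = parameter count, $2 = number of GPUs; prints decimal GB (1e9 bytes)
  awk -v p="$1" -v n="$2" 'BEGIN { printf "%.2f\n", p * 16 / n / 1e9 }'
}

for gpus in 8 16; do
  echo "${gpus} GPUs: $(zero3_gb_per_gpu 46.7e9 "${gpus}") GB per GPU"
done
# 8 GPUs: 93.40 GB per GPU   (more than one A100-80G holds)
# 16 GPUs: 46.70 GB per GPU  (leaves headroom for activations)
```

With these assumptions, model and optimizer state alone need about 93 GB per GPU on one node. Across two nodes, the figure drops to about 47 GB per GPU, which is consistent with the stated 16-GPU requirement.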
Slurm
Note: $PARTITION is the name of the Slurm partition to run the job on.
srun -p $PARTITION --job-name=mixtral --nodes=2 --gres=gpu:8 --ntasks-per-node=8 xtuner train mixtral_8x7b_instruct_full_oasst1_e3 --deepspeed deepspeed_zero3 --launcher slurm
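The srun line is long, and it only fails late if $PARTITION is empty, so it can be convenient to wrap it. The function below is a hypothetical helper, not part of xtuner. It checks that PARTITION is set before doing anything. When DRY_RUN=1 (also an assumption of this sketch) is set, it prints the command instead of running it. The srun flags and xtuner arguments are copied unchanged from the command above.

```bash
#!/usr/bin/env bash
# Hypothetical wrapper around the srun command above (not part of xtuner).

launch_mixtral_full_slurm() {
  # Fail early with a clear message instead of letting srun reject "-p ''".
  if [[ -z "${PARTITION:-}" ]]; then
    echo "PARTITION must be set to a Slurm partition name" >&2
    return 1
  fi
  # 2 nodes x 8 tasks per node = 16 processes, one per GPU.
  local cmd=(srun -p "${PARTITION}" --job-name=mixtral --nodes=2 --gres=gpu:8
             --ntasks-per-node=8
             xtuner train mixtral_8x7b_instruct_full_oasst1_e3
             --deepspeed deepspeed_zero3 --launcher slurm)
  if [[ "${DRY_RUN:-0}" == 1 ]]; then
    echo "${cmd[*]}"   # print only
  else
    "${cmd[@]}"        # actually submit
  fi
}
```

For example, PARTITION=gpu DRY_RUN=1 launch_mixtral_full_slurm prints the full srun command without submitting anything. Here gpu stands in for your own partition name.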
torchrun
Note: $NODE_0_ADDR is the IP address of node 0, the machine launched with NODE_RANK=0.
# Execute on node 0
NPROC_PER_NODE=8 NNODES=2 PORT=29600 ADDR=$NODE_0_ADDR NODE_RANK=0 xtuner train mixtral_8x7b_instruct_full_oasst1_e3 --deepspeed deepspeed_zero3
# Execute on node 1
NPROC_PER_NODE=8 NNODES=2 PORT=29600 ADDR=$NODE_0_ADDR NODE_RANK=1 xtuner train mixtral_8x7b_instruct_full_oasst1_e3 --deepspeed deepspeed_zero3
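The two torchrun commands differ only in NODE_RANK. A hypothetical helper (not part of xtuner) can take the rank as its only argument, so the same script runs on both nodes. DRY_RUN=1 is likewise an assumption of this sketch. The environment variables and their values are copied from the commands above.

```bash
#!/usr/bin/env bash
# Hypothetical per-node launcher (not part of xtuner). Only the rank changes per node.

launch_mixtral_full_torchrun() {
  local node_rank="${1:-}"
  if [[ -z "${NODE_0_ADDR:-}" ]]; then
    echo "NODE_0_ADDR must be set to the IP address of node 0" >&2
    return 1
  fi
  # This setup has NNODES=2, so the only valid ranks are 0 and 1.
  if [[ "${node_rank}" != 0 && "${node_rank}" != 1 ]]; then
    echo "usage: launch_mixtral_full_torchrun <0|1>" >&2
    return 1
  fi
  # `env` passes the same variables as the prefix assignments in the commands above.
  local cmd=(env NPROC_PER_NODE=8 NNODES=2 PORT=29600
             ADDR="${NODE_0_ADDR}" NODE_RANK="${node_rank}"
             xtuner train mixtral_8x7b_instruct_full_oasst1_e3
             --deepspeed deepspeed_zero3)
  if [[ "${DRY_RUN:-0}" == 1 ]]; then
    echo "${cmd[*]}"   # print only
  else
    "${cmd[@]}"        # actually launch
  fi
}
```

After sourcing the script and exporting NODE_0_ADDR on both machines, run launch_mixtral_full_torchrun 0 on node 0 and launch_mixtral_full_torchrun 1 on node 1.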
Speed
Throughput on 16 A100-80G GPUs:
| Model | Sequence Length | Use Varlen Attn | Sequence Parallel World Size | Tokens per Second |
| :----------: | :-------------: | :-------------: | :--------------------------: | :---------------: |
| mixtral_8x7b | 32k | False | 1 | 853.7 |
| mixtral_8x7b | 32k | True | 1 | 910.1 |
| mixtral_8x7b | 32k | False | 2 | 635.2 |
| mixtral_8x7b | 32k | True | 2 | 650.9 |
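Simple ratios over the four measured rows show the relative effect of each setting. Because only ratios are computed, the result does not depend on whether Tokens per Second is per GPU or summed over all 16 GPUs, which the table does not state. The helper name pct_change is hypothetical.

```bash
#!/usr/bin/env bash
# Percentage change between two throughput figures taken from the table.

pct_change() {
  # $1 = new value, $2 = baseline; prints signed percentage with one decimal
  awk -v a="$1" -v b="$2" 'BEGIN { printf "%+.1f%%\n", (a / b - 1) * 100 }'
}

echo "varlen vs. no varlen, SP=1: $(pct_change 910.1 853.7)"  # +6.6%
echo "varlen vs. no varlen, SP=2: $(pct_change 650.9 635.2)"  # +2.5%
echo "SP=2 vs. SP=1, no varlen:   $(pct_change 635.2 853.7)"  # -25.6%
echo "SP=2 vs. SP=1, varlen:      $(pct_change 650.9 910.1)"  # -28.5%
```

In these measurements, variable-length attention gives a larger gain without sequence parallelism than with it. At 32k tokens, a sequence-parallel world size of 2 lowers throughput by roughly a quarter.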