Training Language Models to Follow Instructions with Human Feedback

Long Ouyang, Jeff Wu

2022 · NeurIPS

Problem

Framing

Scaling GPT-3 alone does not make it follow instructions, tell the truth, or avoid toxic output. The paper addresses this gap with a three-stage RLHF pipeline: supervised fine-tuning (SFT) on demonstrations, a learned reward model, and PPO fine-tuning. In human evaluation, outputs from the 1.3B InstructGPT model are preferred to outputs from 175B GPT-3.

Currently Used Methods

Foundational

Proposed Method

Architecture

The policy uses GPT-3 backbones at 1.3B, 6B, and 175B parameters. The reward model shares the transformer body but replaces the unembedding head with a scalar score; PPO adds a value head and regularizes the policy toward the SFT policy with a KL penalty.
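To make the scalar head concrete, here is a minimal sketch in plain Python. It assumes the score is read from the hidden vector at the final token position and produced by a single linear projection; the function name `scalar_head`, the pooling choice, and the toy numbers are illustrative assumptions, not details taken from the paper.

```python
def scalar_head(hidden_states, weights, bias=0.0):
    """Project the final token's hidden vector to one scalar score.

    hidden_states: list of per-token hidden vectors (lists of floats).
    weights, bias: parameters of a hypothetical linear head that stands in
    for the removed unembedding layer.
    """
    if not hidden_states:
        raise ValueError("need at least one token position")
    last = hidden_states[-1]  # assumption: score is read at the last token
    if len(last) != len(weights):
        raise ValueError("weights must match the hidden size")
    return sum(h * w for h, w in zip(last, weights)) + bias


# Toy hidden states for a two-token sequence with hidden size 3.
hidden = [[0.1, -0.2, 0.3], [0.5, 0.0, -1.0]]
reward = scalar_head(hidden, weights=[2.0, 1.0, 0.5], bias=0.1)
```

A value head for PPO could reuse the same shape of projection, applied at every token position instead of only the last one.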

Three-stage RLHF pipeline: collect demonstrations for SFT, collect ranked comparisons for reward-model training, then optimize the policy with PPO against the learned reward.

Loss / Objective

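Reward learning fits pairwise preferences, where y_w is the preferred and y_l the rejected response to prompt x. Policy learning maximizes the learned reward with a KL penalty toward the SFT policy (weight β) and an optional pretraining log-likelihood term (weight γ, the PPO-ptx variant).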

$$\mathcal{L}_{\mathrm{RM}}(\theta) = -\mathbb{E}_{(x,y_w,y_l)\sim D}\left[\log \sigma\left(r_\theta(x,y_w)-r_\theta(x,y_l)\right)\right]$$

$$\mathcal{J}(\phi)=\mathbb{E}_{(x,y)\sim D_{\pi_\phi^{\mathrm{RL}}}}\left[r_\theta(x,y)-\beta \log \frac{\pi_\phi^{\mathrm{RL}}(y\mid x)}{\pi^{\mathrm{SFT}}(y\mid x)}\right]+\gamma\,\mathbb{E}_{x\sim D_{\mathrm{pretrain}}}\left[\log \pi_\phi^{\mathrm{RL}}(x)\right]$$
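The reward-model loss can be sketched in a few lines of plain Python. The version below averages the negative log-sigmoid of the score margin over a batch of comparison pairs, using a numerically stable form; the function names and the toy scores are assumptions for illustration, and the scores stand in for outputs of r_θ.

```python
import math


def neg_log_sigmoid(d):
    """Return -log(sigmoid(d)) without overflowing for large |d|."""
    return max(-d, 0.0) + math.log1p(math.exp(-abs(d)))


def pairwise_rm_loss(preferred_scores, rejected_scores):
    """Mean of -log sigmoid(r(x, y_w) - r(x, y_l)) over comparison pairs."""
    if len(preferred_scores) != len(rejected_scores) or not preferred_scores:
        raise ValueError("need equal-length, non-empty score lists")
    losses = [
        neg_log_sigmoid(w - l)
        for w, l in zip(preferred_scores, rejected_scores)
    ]
    return sum(losses) / len(losses)


# Toy scores: the first pair is ranked correctly, the second is not.
loss = pairwise_rm_loss([1.5, -0.2], [0.3, 0.4])
```

The loss shrinks as the preferred response's score exceeds the rejected one by a wider margin, and equals log 2 when the two scores tie.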

Algorithm

PPO treats each prompt-response pair as a bandit episode and uses the reward model plus KL penalty as the scalar return.

$$R(x,y)=r_\theta(x,y)-\beta \log \frac{\pi_\phi^{\mathrm{RL}}(y\mid x)}{\pi^{\mathrm{SFT}}(y\mid x)}$$
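As a rough illustration of the scalar return, the sketch below computes the KL-penalized reward for one prompt-response pair and a Monte Carlo estimate of the full objective from the Loss / Objective section. It assumes the sequence log-probabilities (sums of token log-probabilities) have already been computed under each policy; the function names and the β and γ values in the example are placeholders, not the paper's settings.

```python
def kl_penalized_reward(rm_score, logp_rl, logp_sft, beta):
    """R(x, y) = r(x, y) - beta * log(pi_RL(y|x) / pi_SFT(y|x))."""
    return rm_score - beta * (logp_rl - logp_sft)


def objective_estimate(samples, pretrain_logps, beta, gamma):
    """Sample estimate of the PPO objective with an optional pretraining term.

    samples: list of (rm_score, logp_rl, logp_sft), one per sampled episode.
    pretrain_logps: log pi_RL(x) for pretraining sequences (may be empty).
    """
    if not samples:
        raise ValueError("need at least one sampled episode")
    rl_term = sum(
        kl_penalized_reward(r, lp_rl, lp_sft, beta)
        for r, lp_rl, lp_sft in samples
    ) / len(samples)
    ptx_term = (
        sum(pretrain_logps) / len(pretrain_logps) if pretrain_logps else 0.0
    )
    return rl_term + gamma * ptx_term


# One episode: the RL policy assigns the response a higher log-probability
# than the SFT policy does, so the KL term subtracts from the reward.
ret = kl_penalized_reward(rm_score=1.0, logp_rl=-10.0, logp_sft=-12.0, beta=0.1)
```

Setting `gamma` to zero recovers the plain PPO variant; a positive `gamma` adds the pretraining mix.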

Training Procedure

Evaluation

Datasets

Metrics

Headline results

Table 1: Held-out-worker preference results on the instruct distribution, cropped as a line plot rather than a readable numeric table.

| Observation | Value |
| --- | --- |
| Compared systems | GPT, prompted GPT, SFT, PPO, PPO-ptx |
| X-axis | Model size: 1.3B, 6B, 175B |
| Y-axis | Win rate against SFT 175B |
| Main pattern | PPO-ptx is highest across scales |
| Secondary pattern | PPO also beats SFT and GPT baselines |
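For reference, a win rate against a fixed baseline such as SFT 175B is the fraction of head-to-head comparisons in which labelers prefer the candidate's output. A minimal sketch follows; the function name is hypothetical, the labels are made up for illustration and are not results from the paper, and ties are not modeled.

```python
def win_rate(candidate_preferred):
    """Fraction of comparisons in which the candidate beat the baseline.

    candidate_preferred: list of booleans, one per labeler comparison,
    True when the candidate's output was preferred.
    """
    if not candidate_preferred:
        raise ValueError("need at least one comparison")
    wins = sum(1 for preferred in candidate_preferred if preferred)
    return wins / len(candidate_preferred)


# Made-up labels for four comparisons.
rate = win_rate([True, False, True, True])
```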

Ablations

Method Strengths and Weaknesses

Strengths

Weaknesses

Suggestions from the authors

Links

Prior Papers

Further Papers