Policy Gradients (PPO)
Walk through PPO and why it stabilises post-training policy updates.
- Estimated time: 40 minutes
- Difficulty: intermediate
- Prerequisites: 2 modules
Clipped PPO Objective
Chapter 11 recaps policy gradient methods and introduces Proximal Policy Optimisation (PPO) as the workhorse for RLHF. PPO modifies REINFORCE by clipping the policy ratio and applying a KL regulariser to stay near the instruction-tuned reference policy.
Notation. r_t(θ) = π_θ(a_t | s_t) / π_old(a_t | s_t) is the policy ratio, Â_t the advantage estimate, ε the clip range, and β the KL weight that Chapter 11 recommends decaying or adapting during training.
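Written out in this notation, PPO maximises the standard clipped surrogate objective:
L_CLIP(θ) = E_t[ min( r_t(θ) · Â_t, clip(r_t(θ), 1 − ε, 1 + ε) · Â_t ) ]
The KL regulariser with weight β is then added on top, either folded into the reward signal or, as in the code sketch later on this page, as an explicit penalty on the loss.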
Why PPO Stabilises Updates
Vanilla policy gradients push probabilities in proportion to the advantage. When the reward model spikes on a batch, the update can double or halve token probabilities, derailing the language model. PPO keeps the desirable gradient signal but shaves off updates that move the ratio beyond 1 ± ε.
Chapter 11 also recommends limiting each rollout batch to a few gradient epochs and optionally whitening rewards or advantages. Together with the KL penalty, these tactics keep the RLHF fine-tuning loop on track without sacrificing exploration.
PPO variants highlighted in the chapter—RLOO, GRPO, and trust-region methods—can often be seen as choosing different advantage estimators or adapting the internal step size while preserving the clipping intuition.
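To make the shaving concrete, here is a minimal sketch (the advantage and clip values are illustrative, not taken from the chapter) showing that the gradient of the clipped surrogate with respect to the ratio vanishes once a positive-advantage ratio leaves [1 − ε, 1 + ε]:

import torch

# Illustrative values only: positive advantage, ε = 0.2.
advantage, clip_range = 1.0, 0.2
for r0 in (0.9, 1.1, 1.5):
    ratio = torch.tensor(r0, requires_grad=True)
    clipped_ratio = torch.clamp(ratio, 1 - clip_range, 1 + clip_range)
    objective = torch.min(ratio * advantage, clipped_ratio * advantage)
    objective.backward()
    # Inside the trust region the gradient is 1.0; beyond 1 + ε it drops to 0.0.
    print(f"ratio={r0}: objective={objective.item():.2f}, grad={ratio.grad.item():.1f}")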
Analogy: Arcade Bot With Bumpers
Picture the Atari-style analogy from the introduction: the bot wants to chase a higher score but bumper rails keep it from flying off the playfield. PPO’s clip range is that bumper; the KL term is the cord back to the reference console.
Arcade bot
Learns to clear a level faster but has bumpers that stop reckless moves. PPO’s clip range plays the same role described in Chapter 11.
Coach with joystick
Reviews past runs and nudges the joystick with limited pressure so the bot continues exploring without flipping off the table.
When the bot earns a high advantage, it can still surge forward—but only until it hits the bumper. Negative advantages pull it back toward steady play.
Policy Update Lab
Experiment with the hyperparameters Chapter 11 emphasises: learning rate, clip range, and KL penalty. Then inspect how clipping alters the local objective compared with vanilla policy gradients.
PPO policy update playground
Track how policy ratios evolve over mini-batches when you vary the learning rate, clipping range, and KL penalty — mirroring the PPO discussion in Chapter 11.
Reading the curves
The blue line shows where plain policy gradient updates would push the ratio. The green line applies PPO clipping and the KL penalty, emulating the stabilisation strategies emphasised in Chapter 11. When clipping is tight or the KL weight grows, the curve hugs 1.0 rather than overshooting.
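If you want to reproduce the qualitative behaviour offline, a stripped-down simulation in the same spirit (this is not the playground's implementation, and every hyperparameter value below is an illustrative assumption) could look like:

import torch

# Two-action toy policy with learnable logits, updated with the clipped
# surrogate plus a KL penalty to a fixed reference policy.
theta = torch.zeros(2, requires_grad=True)   # current policy logits
ref_logits = torch.zeros(2)                  # frozen reference policy
advantage = 1.0                              # constant positive advantage for action 0
lr, clip_range, kl_weight, inner_epochs = 0.3, 0.2, 0.5, 4
opt = torch.optim.SGD([theta], lr=lr)

for batch_idx in range(5):
    old_logprob = torch.log_softmax(theta, dim=-1)[0].detach()  # freeze the rollout policy
    for _ in range(inner_epochs):
        logprob = torch.log_softmax(theta, dim=-1)[0]
        ratio = torch.exp(logprob - old_logprob)
        clipped_ratio = torch.clamp(ratio, 1 - clip_range, 1 + clip_range)
        surrogate = torch.min(ratio * advantage, clipped_ratio * advantage)
        kl = torch.distributions.kl.kl_divergence(
            torch.distributions.Categorical(logits=theta),
            torch.distributions.Categorical(logits=ref_logits),
        )
        loss = -(surrogate - kl_weight * kl)
        opt.zero_grad()
        loss.backward()
        opt.step()
    final_ratio = torch.exp(torch.log_softmax(theta, dim=-1)[0] - old_logprob).item()
    print(f"mini-batch {batch_idx}: ratio vs rollout policy ≈ {final_ratio:.3f}")

Tightening clip_range or raising kl_weight keeps the printed ratios closer to 1.0, mirroring the stabilised curve described above.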
Clipping intuition
See how PPO’s clip(r_t(θ), 1 − ε, 1 + ε) guardrail limits the policy update compared with vanilla policy gradients, and why it prevents runaway ratios.
Unclipped policy gradient
Objective = ratio × advantage = 1.040. Vanilla REINFORCE would follow this value even if the ratio drifts far from 1.
PPO clipped objective
Clipped ratio = 1.200 ⇒ objective = 0.960.
When the advantage is positive and the ratio overshoots 1.20, the clip holds the update at the safe boundary. For negative advantages the same guardrail applies with 0.80.
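Working back from the readout above, the displayed numbers are consistent with a ratio of about 1.3, an advantage of about 0.8, and ε = 0.2 (inferred values, not stated explicitly by the widget); the arithmetic is:

ratio, advantage, clip_range = 1.3, 0.8, 0.2   # inferred from the readout, not given explicitly
unclipped_objective = ratio * advantage        # 1.04: what vanilla REINFORCE follows
clipped_ratio = min(max(ratio, 1 - clip_range), 1 + clip_range)   # 1.20
clipped_objective = clipped_ratio * advantage  # 0.96: PPO's pessimistic min picks this smaller value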
Comparison with vanilla policy gradients
Chapter 11 contrasts PPO with the simpler REINFORCE objective. This control shows how PPO flattens the objective near large ratios, avoiding steps that would irreversibly warp the policy. If ε → ∞ the two objectives converge; if ε → 0 PPO refuses to change the policy at all.
The RLHF book provides a compact PyTorch loop for PPO. The snippet below adapts that pattern, showing how clipping and KL penalties appear in code.
import torch
import torch.nn.functional as F
def ppo_step(model, optimiser, batch, clip_range=0.2, kl_weight=0.01):
    logits, values = model(batch["input_ids"], attention_mask=batch["attention_mask"])
    # Per-token log-probabilities of the sampled tokens (the usual one-position
    # shift between logits and labels is omitted here for brevity).
    all_logprobs = F.log_softmax(logits, dim=-1)
    logprobs = all_logprobs.gather(-1, batch["input_ids"].unsqueeze(-1)).squeeze(-1)

    ratio = torch.exp(logprobs - batch["logprobs_old"])  # r_t(θ)
    advantages = batch["advantages"]

    # Pessimistic PPO surrogate: take the worse (larger) of the two losses per token.
    pg_loss_unclipped = -advantages * ratio
    pg_loss_clipped = -advantages * torch.clamp(ratio, 1.0 - clip_range, 1.0 + clip_range)
    policy_loss = torch.max(pg_loss_unclipped, pg_loss_clipped).mean()

    # Value-head regression against the rollout returns.
    value_loss = 0.5 * (batch["returns"] - values).pow(2).mean()

    # Full-distribution KL against the frozen reference policy.
    kl = torch.distributions.kl.kl_divergence(
        torch.distributions.Categorical(logits=logits),
        torch.distributions.Categorical(logits=batch["logits_ref"]),
    ).mean()

    loss = policy_loss + 0.5 * value_loss + kl_weight * kl

    optimiser.zero_grad()
    loss.backward()
    torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
    optimiser.step()
    return loss.item()
Implementation Notes
- PPO clips the policy ratio around 1 ± ε and pairs it with a KL penalty to prevent drift from the reference model.
- Reward/advantage normalisation, value heads, and limited epochs per batch keep updates on-policy (Chapter 11).
- Variants like GRPO and RLOO tweak advantage estimation but share the same trust region intuition.
- Entropy bonuses and adaptive KL schedules (measuring actual KL against a target) help balance exploration with safety constraints; see the sketch of advantage whitening and an adaptive KL weight after this list.
- Sequence packing and per-token losses reduce the wall-clock cost of PPO for long generations, as discussed in the implementation notes.
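A minimal sketch of those two stabilisers (the target KL, tolerance, and multiplier below are assumptions, not values from the chapter):

import torch

def whiten(advantages, eps=1e-8):
    # Zero-mean, unit-variance advantages per batch reduce gradient variance.
    return (advantages - advantages.mean()) / (advantages.std() + eps)

def adapt_kl_weight(kl_weight, observed_kl, target_kl=0.05, tolerance=1.5):
    # Simple proportional controller: tighten the penalty when the measured KL
    # overshoots the target, relax it when the policy barely moves.
    if observed_kl > target_kl * tolerance:
        return kl_weight * 1.5
    if observed_kl < target_kl / tolerance:
        return kl_weight / 1.5
    return kl_weight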
Many RLHF deployments still fall back to rejection sampling when compute is scarce. Chapter 11 positions PPO as the flexible middle ground between one-shot filtering and expensive RL loops.
PPO Check
Test your grasp of the PPO objective, clipping behaviour, and stabilisation tricks from Chapter 11.
1. What objective does PPO maximise according to Chapter 11?
2. Why introduce the clip min/max operation?
3. How does the KL penalty interact with PPO in RLHF?
4. According to the book, what failure mode does PPO avoid that vanilla policy gradients suffer from?
5. Why do PPO implementations often cap the number of gradient steps per batch?