Atari Game Bot

Policy Gradients (PPO)

Walk through PPO and why it stabilises post-training policy updates.

Estimated time
40 minutes
Difficulty
intermediate
Prerequisites
2 modules
Equation

Clipped PPO Objective

Chapter 11 recaps policy gradient methods and introduces Proximal Policy Optimisation (PPO) as the workhorse for RLHF. PPO modifies REINFORCE by clipping the policy ratio and applying a KL regulariser to stay near the instruction-tuned reference policy.

J(\theta) = \mathbb{E}_{t}\Big[\min\big(r_t(\theta) A_t,\; \text{clip}(r_t(\theta), 1 - \epsilon, 1 + \epsilon) A_t\big) - \beta\, D_{\mathrm{KL}}\big(\pi_\theta(\cdot \mid s_t)\,\|\,\pi_{\text{ref}}(\cdot \mid s_t)\big)\Big]

Notation. $r_t(\theta) = \frac{\pi_\theta(a_t \mid s_t)}{\pi_{\theta_{\text{old}}}(a_t \mid s_t)}$ is the policy ratio, $A_t$ the advantage estimate, $\epsilon$ the clip range, and $\beta$ the KL weight that Chapter 11 recommends decaying or adapting during training.
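As a quick numeric check (a minimal sketch, not taken from the chapter), the snippet below evaluates the clipped surrogate for a few ratio/advantage pairs with ε = 0.2; the helper name clipped_surrogate is purely illustrative.

import torch

def clipped_surrogate(ratio, advantage, eps=0.2):
    """Per-token PPO surrogate: min(r * A, clip(r, 1 - eps, 1 + eps) * A)."""
    unclipped = ratio * advantage
    clipped = torch.clamp(ratio, 1.0 - eps, 1.0 + eps) * advantage
    return torch.min(unclipped, clipped)

ratios = torch.tensor([0.5, 0.9, 1.0, 1.1, 1.5])
print(clipped_surrogate(ratios, torch.tensor(1.0)))
# tensor([0.5000, 0.9000, 1.0000, 1.1000, 1.2000])   -> gains past 1 + eps are capped
print(clipped_surrogate(ratios, torch.tensor(-1.0)))
# tensor([-0.8000, -0.9000, -1.0000, -1.1000, -1.5000])  -> the min keeps the pessimistic value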

Intuition

Why PPO Stabilises Updates

Vanilla policy gradients push probabilities in proportion to the advantage. When the reward model spikes on a batch, the update can double or halve token probabilities, derailing the language model. PPO keeps the desirable gradient signal but shaves off updates that move the ratio beyond $1 \pm \epsilon$.

Chapter 11 also recommends limiting each rollout batch to a few gradient epochs and optionally whitening rewards or advantages. Together with the KL penalty, these tactics keep the RLHF fine-tuning loop on track without sacrificing exploration.
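Whitening here simply means normalising to zero mean and unit variance within the rollout batch; a minimal helper (illustrative, not the book's exact code) looks like this:

import torch

def whiten(advantages, eps=1e-8):
    # Zero-mean, unit-variance advantages within a rollout batch; eps guards
    # against division by zero on (near-)constant batches.
    return (advantages - advantages.mean()) / (advantages.std() + eps)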

PPO variants highlighted in the chapter—RLOO, GRPO, and trust-region methods—can often be seen as choosing different advantage estimators or adapting the internal step size while preserving the clipping intuition.
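As a rough illustration of how those estimators differ (common formulations, with details varying across implementations), both replace the learned value baseline with statistics over a group of completions sampled for the same prompt:

import torch

def grpo_advantages(rewards, eps=1e-8):
    # GRPO-style: score each completion against the group mean/std
    # instead of a learned value head. Assumes several samples per prompt.
    return (rewards - rewards.mean()) / (rewards.std() + eps)

def rloo_advantages(rewards):
    # RLOO-style: each sample's baseline is the mean reward of the *other*
    # completions drawn for the same prompt (leave-one-out).
    n = rewards.numel()
    return rewards - (rewards.sum() - rewards) / (n - 1)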

Analogy

Analogy: Arcade Bot With Bumpers

Picture the Atari-style analogy from the introduction: the bot wants to chase a higher score but bumper rails keep it from flying off the playfield. PPO’s clip range is that bumper; the KL term is the cord back to the reference console.

Arcade bot

Learns to clear a level faster but has bumpers that stop reckless moves. PPO’s clip range plays the same role described in Chapter 11.

Coach with joystick

Reviews past runs and nudges the joystick with limited pressure so the bot continues exploring without flipping off the table.

When the bot earns a high advantage, it can still surge forward—but only until it hits the bumper. Negative advantages pull it back toward steady play.

Visualization

Policy Update Lab

Experiment with the hyperparameters Chapter 11 emphasises: learning rate, clip range, and KL penalty. Then inspect how clipping alters the local objective compared with vanilla policy gradients.

PPO policy update playground

Track how policy ratios evolve over mini-batches when you vary the learning rate, clipping range, and KL penalty — mirroring the PPO discussion in Chapter 11.


Interactive visualization
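To reproduce the playground's behaviour offline, here is a toy one-parameter sketch (not the widget's actual code): a two-action policy is pushed by a constant positive advantage, and the printed ratio stalls once it crosses 1 + ε because the clipped branch contributes no gradient.

import torch

# Toy simulation: 2-action policy, constant advantage, clipped surrogate.
theta = torch.zeros(2, requires_grad=True)      # current policy logits
theta_old = theta.detach().clone()              # frozen rollout policy
action, advantage, eps = 0, 1.0, 0.2

opt = torch.optim.SGD([theta], lr=0.2)
for step in range(8):
    logp = torch.log_softmax(theta, dim=-1)[action]
    logp_old = torch.log_softmax(theta_old, dim=-1)[action]
    ratio = torch.exp(logp - logp_old)
    loss = -torch.min(ratio * advantage,
                      torch.clamp(ratio, 1 - eps, 1 + eps) * advantage)
    opt.zero_grad()
    loss.backward()
    opt.step()
    print(step, round(ratio.item(), 3))  # ratio plateaus shortly after exceeding 1 + eps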

Clipping intuition

See how PPO's clipping guardrail — clamping the ratio to [1 − ε, 1 + ε] — limits the policy update compared with vanilla policy gradients and why it prevents runaway ratios.


Interactive visualization

The RLHF book provides a compact PyTorch loop for PPO. The snippet below adapts that pattern, showing how clipping and KL penalties appear in code; it assumes the rollout batch already carries the sampled token ids, their log-probabilities under the rollout policy, advantages, returns, and the reference model's logits.

import torch
import torch.nn.functional as F

def ppo_step(model, optimiser, batch, clip_range=0.2, kl_weight=0.01):
    # Forward pass of the current policy; the model is assumed to return
    # per-token logits and a value estimate per position.
    logits, values = model(batch["input_ids"], attention_mask=batch["attention_mask"])
    logprobs = F.log_softmax(logits, dim=-1)
    # Log-probabilities of the tokens actually sampled during the rollout.
    logprobs_taken = torch.gather(
        logprobs, dim=-1, index=batch["actions"].unsqueeze(-1)
    ).squeeze(-1)
    ratio = torch.exp(logprobs_taken - batch["logprobs_old"])  # r_t(θ)

    advantages = batch["advantages"]
    pg_loss_unclipped = -advantages * ratio
    pg_loss_clipped = -advantages * torch.clamp(ratio, 1.0 - clip_range, 1.0 + clip_range)
    # max over the negated terms == negated min in the clipped objective above
    policy_loss = torch.max(pg_loss_unclipped, pg_loss_clipped).mean()

    value_loss = 0.5 * (batch["returns"] - values).pow(2).mean()
    # Per-token KL against the frozen reference policy.
    kl = torch.distributions.kl.kl_divergence(
        torch.distributions.Categorical(logits=logprobs),
        torch.distributions.Categorical(logits=batch["logits_ref"]),
    ).mean()

    loss = policy_loss + 0.5 * value_loss + kl_weight * kl
    optimiser.zero_grad()
    loss.backward()
    torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
    optimiser.step()
    return loss.item()
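Two implementation details are worth noting. Because the optimiser minimises, both surrogate terms are negated and combined with torch.max, which is equivalent to maximising the min in the clipped objective above. The 0.5 coefficient on the value loss is a common default rather than a value prescribed by the chapter, so treat it as a tunable weight alongside clip_range and kl_weight.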
Takeaways

Implementation Notes

  • PPO clips the policy ratio around 1 ± ε and pairs it with a KL penalty to prevent drift from the reference model.
  • Reward/advantage normalisation, value heads, and limited epochs per batch keep updates on-policy (Chapter 11).
  • Variants like GRPO and RLOO tweak advantage estimation but share the same trust region intuition.
  • Entropy bonuses and adaptive KL schedules (measuring actual KL against a target, as sketched after this list) help balance exploration with safety constraints.
  • Sequence packing and per-token losses reduce the wall-clock cost of PPO for long generations, as discussed in the implementation notes.
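A minimal adaptive-KL controller in the spirit of that point might look like the sketch below; the constants and the exact update rule are illustrative assumptions, not taken from the chapter. The idea is to raise the KL weight when the measured KL overshoots a target and relax it otherwise.

class AdaptiveKLController:
    # Adjust the KL weight beta toward a target KL, in the spirit of adaptive
    # schedules used in PPO-based RLHF (illustrative constants).
    def __init__(self, init_beta=0.01, target_kl=6.0, horizon=10_000):
        self.beta = init_beta
        self.target_kl = target_kl
        self.horizon = horizon

    def update(self, observed_kl, batch_size):
        # Proportional error, clipped to ±20% so a single noisy batch
        # cannot swing beta too hard.
        error = max(min(observed_kl / self.target_kl - 1.0, 0.2), -0.2)
        self.beta *= 1.0 + error * batch_size / self.horizon
        return self.beta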

Many RLHF deployments still fall back to rejection sampling when compute is scarce. Chapter 11 positions PPO as the flexible middle ground between one-shot filtering and expensive RL loops.

Self-check

PPO Check

Test your grasp of the PPO objective, clipping behaviour, and stabilisation tricks from Chapter 11.


  1. What objective does PPO maximise according to Chapter 11?

  2. Why introduce the clip min/max operation?

  3. How does the KL penalty interact with PPO in RLHF?

  4. According to the book, what failure mode does PPO avoid that vanilla policy gradients suffer from?

  5. Why do PPO implementations often cap the number of gradient steps per batch?