On-Policy Distillation for LLMs, Simply Explained

AI, But Simple Issue #106

Large language models are impressive. They summarize articles, translate languages, solve math problems, and write code.

However, one of the key problems surrounding them is that they are expensive. Serving a model with billions of parameters costs real money at real scale.

This has made model compression a critical topic in modern AI research.

One of the most effective compression strategies is knowledge distillation: train a smaller student model to replicate the behavior of a larger teacher model. The student runs faster, uses less memory, and can be deployed cheaply.

This is simple in principle but harder in practice.

The Problem With LLM Distillation

Here is the core issue with how distillation is done today.

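During training, the student model learns from a fixed dataset of sequences. These sequences are either ground-truth outputs or outputs generated by the teacher in advance. Either way, the student is trained with teacher forcing: it predicts each next token conditioned on a prefix taken from that fixed data, not on tokens it generated itself.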
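
To make "a fixed dataset" concrete, here is a minimal sketch of an off-policy (teacher-forced) distillation loss in plain Python. It assumes hypothetical toy teacher and student models whose next-token probabilities depend only on the previous token; `TEACHER`, `STUDENT`, `kl`, and `off_policy_kd_loss` are illustrative names, not a library API. A real LLM conditions on the whole prefix and the prompt, and training would backpropagate this loss into the student's parameters.

```python
import math

# Hypothetical toy models: next-token probabilities over the tokens ("a", "b", "c"),
# keyed by the previous token.
TEACHER = {"a": [0.8, 0.1, 0.1], "b": [0.6, 0.3, 0.1], "c": [0.7, 0.2, 0.1]}
STUDENT = {"a": [0.7, 0.2, 0.1], "b": [0.2, 0.7, 0.1], "c": [0.3, 0.3, 0.4]}


def kl(p, q):
    """Forward KL D_KL(p || q) between two discrete distributions."""
    return sum(pi * math.log(pi / qi) for pi, qi in zip(p, q) if pi > 0)


def off_policy_kd_loss(sequence):
    """Average per-token D_KL(teacher || student) over a FIXED sequence.

    sequence[0] plays the role of the prompt. At every step both models see the
    same prefix taken from the data (teacher forcing), never tokens the student
    sampled itself.
    """
    losses = []
    for t in range(1, len(sequence)):
        prev_token = sequence[t - 1]  # the toy models only look at the last prefix token
        losses.append(kl(TEACHER[prev_token], STUDENT[prev_token]))
    return sum(losses) / len(losses)


# A fixed training sequence, e.g. one the teacher generated in advance.
print(off_policy_kd_loss(["a", "a", "a", "b", "a"]))
```

Every prefix here comes from the fixed sequence, so the loss only ever checks the student in contexts drawn from the data, never in contexts produced by its own samples.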

During inference, the student generates its own sequences from scratch, one token at a time.
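
A minimal sketch of what inference looks like, using a hypothetical toy student whose next-token probabilities are keyed only by the previous token (`STUDENT` and `generate` are illustrative names). The key line is the `append`: each sampled token becomes part of the context for the next step.

```python
import random

VOCAB = ["a", "b", "c"]

# Hypothetical toy student: next-token probabilities over VOCAB, keyed by the previous token.
STUDENT = {
    "a": [0.7, 0.2, 0.1],
    "b": [0.2, 0.7, 0.1],
    "c": [0.3, 0.3, 0.4],
}


def generate(prompt, num_tokens, rng):
    """Auto-regressive sampling: each new token is conditioned on the student's own output so far."""
    sequence = list(prompt)
    for _ in range(num_tokens):
        probs = STUDENT[sequence[-1]]  # condition on the (self-generated) prefix
        token = rng.choices(VOCAB, weights=probs, k=1)[0]
        sequence.append(token)  # the sample becomes part of the next context
    return sequence


rng = random.Random(0)
print(generate(["a"], 10, rng))
```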

The problem, however, is that these are two completely different situations.

The sequences the student encounters while training can look very different from the sequences it generates when deployed. More formally, the distribution of contexts (the prefixes the model conditions on) seen during training differs from the distribution encountered at inference.
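
The mismatch can be measured exactly in a toy setting. This sketch assumes hypothetical teacher and student models over a three-token vocabulary where the "context" is just the previous token, and it treats the teacher's rollouts as the fixed training data (the "generated by the teacher in advance" case). It pushes each model's distribution over contexts forward one step at a time and reports the total variation distance between them; `step`, `total_variation`, and `context_drift` are illustrative names.

```python
# Hypothetical toy models: row i holds next-token probabilities over ["a", "b", "c"]
# given that the previous token was VOCAB[i].
VOCAB = ["a", "b", "c"]
TEACHER = [[0.8, 0.1, 0.1], [0.6, 0.3, 0.1], [0.7, 0.2, 0.1]]
STUDENT = [[0.7, 0.2, 0.1], [0.2, 0.7, 0.1], [0.3, 0.3, 0.4]]


def step(dist, table):
    """Push a distribution over the previous token one generation step forward."""
    n = len(dist)
    return [sum(dist[i] * table[i][j] for i in range(n)) for j in range(n)]


def total_variation(p, q):
    """Total variation distance between two discrete distributions."""
    return 0.5 * sum(abs(pi - qi) for pi, qi in zip(p, q))


def context_drift(num_steps, start=0):
    """Distance between the teacher's and the student's context distributions at each step.

    Both start from the same prompt token VOCAB[start]. The teacher's rollouts stand in
    for the fixed training data; the student's rollouts are what it sees at inference.
    """
    teacher_ctx = [1.0 if i == start else 0.0 for i in range(len(VOCAB))]
    student_ctx = list(teacher_ctx)
    gaps = []
    for _ in range(num_steps):
        teacher_ctx = step(teacher_ctx, TEACHER)
        student_ctx = step(student_ctx, STUDENT)
        gaps.append(total_variation(teacher_ctx, student_ctx))
    return gaps


print([round(g, 3) for g in context_drift(4)])
```

In this toy the gap grows over the first few steps (from 0.1 after one step to about 0.27 after three) because the student drifts toward token "b" and tends to stay there, a context that is much less common in the teacher's data.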

This gap is called the train-inference distribution mismatch, and it is a well-known problem in imitation learning research.

Here is why the mismatch matters. In auto-regressive models, every token depends on all the previous tokens, so at inference time the student is conditioned on its own earlier outputs.

If the student makes a slightly wrong prediction early in a sequence, that error feeds forward into the next token, and the next, and the next. This way, the errors compound.
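
A back-of-the-envelope calculation shows why small per-token errors matter for long outputs. The sketch assumes each token is wrong independently with a fixed probability (`clean_sequence_probability` is an illustrative name). That assumption is optimistic: an early mistake also changes the context for every later token, so this likely understates the problem.

```python
def clean_sequence_probability(per_token_error, length):
    """Probability that all `length` tokens are correct, assuming each token is
    independently wrong with probability `per_token_error`."""
    return (1.0 - per_token_error) ** length


for length in (10, 100, 1000):
    print(length, round(clean_sequence_probability(0.01, length), 4))
```

With a 1% per-token error rate, a 100-token answer comes out entirely clean only about 37% of the time.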

The student was trained only on clean, expert-level sequences, so it never learned how to recover from its own mistakes.

This is the central problem that Agarwal et al. (2024), published at ICLR, set out to fix.

What You’ll Learn

  1. Knowledge distillation from the ground up

  2. Adding “on-policy” to distillation

  3. Generalized knowledge distillation (GKD)

  4. Combining GKD with reinforcement learning

  5. How on-policy distillation outperforms standard (off-policy) distillation

  6. Why this all matters

What’s Helpful to Know

  • Auto-Regressive Generation

    • Generating text one token at a time, with each new token conditioned on everything generated so far. This is how language models like GPT and T5 produce their outputs.

  • KL Divergence

    • A measure of how different two probability distributions are. For discrete distributions P and Q:

       $$D_{\mathrm{KL}}(P \,\|\, Q) = \sum_{x} P(x) \log \frac{P(x)}{Q(x)}$$

    • KL divergence is not symmetric: in general, DKL(P || Q) ≠ DKL(Q || P). This asymmetry turns out to matter a great deal for distillation.

  • Forward vs. Reverse KL

    • DKL(P || Q) is called the forward KL. Minimizing it forces Q to cover all the regions where P has mass. It is mode-covering.

    • DKL(Q || P) is called the reverse KL. Minimizing it forces Q to concentrate only where P is large. It is mode-seeking.

  • Jensen-Shannon Divergence (JSD)

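    • A bounded divergence that interpolates between forward and reverse KL. For a parameter β in (0, 1), the generalized JSD is:

       $$\mathrm{JSD}(\beta)(P \,\|\, Q) = \beta\, D_{\mathrm{KL}}(P \,\|\, M) + (1-\beta)\, D_{\mathrm{KL}}(Q \,\|\, M)$$

    • Here, M = βP + (1-β)Q. At β = 0.5 this is the standard, symmetric JSD. When β approaches 0, JSD behaves like a scaled forward KL (JSD(β)/β → DKL(P || Q)). When β approaches 1, it behaves like a scaled reverse KL (JSD(β)/(1-β) → DKL(Q || P)).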

  • Imitation Learning

    • A family of techniques where an agent learns a policy by mimicking an expert. The key challenge is that a mistake during deployment can lead the agent into situations the expert never demonstrated, where further mistakes become more likely and errors compound.
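
The KL and JSD definitions above can be checked numerically. This is a minimal sketch in plain Python; `kl` and `generalized_jsd` are illustrative helpers (not a library API), and the "teacher" P and the two candidate "students" are made-up distributions over three outcomes. It shows the asymmetry of KL, the mode-covering vs. mode-seeking preference, and the two β limits of the generalized JSD.

```python
import math


def kl(p, q):
    """D_KL(P || Q) = sum_x P(x) log(P(x) / Q(x)); terms with P(x) = 0 contribute 0."""
    return sum(pi * math.log(pi / qi) for pi, qi in zip(p, q) if pi > 0)


def generalized_jsd(p, q, beta):
    """JSD(beta)(P || Q) = beta*KL(P || M) + (1 - beta)*KL(Q || M), M = beta*P + (1 - beta)*Q."""
    m = [beta * pi + (1 - beta) * qi for pi, qi in zip(p, q)]
    return beta * kl(p, m) + (1 - beta) * kl(q, m)


# Hypothetical bimodal "teacher" P and two candidate "students".
p = [0.495, 0.01, 0.495]
q_cover = [1 / 3, 1 / 3, 1 / 3]  # spreads mass over both modes
q_seek = [0.98, 0.01, 0.01]      # commits to a single mode

# Asymmetry: swapping the arguments changes the value (about 1.59 vs 0.63).
print(kl(p, q_seek), kl(q_seek, p))

# Forward KL prefers the covering student; reverse KL prefers the mode-seeking one.
print(kl(p, q_cover) < kl(p, q_seek))  # True (forward KL, mode-covering)
print(kl(q_seek, p) < kl(q_cover, p))  # True (reverse KL, mode-seeking)

# At beta = 0.5 the generalized JSD is symmetric.
print(generalized_jsd(p, q_seek, 0.5), generalized_jsd(q_seek, p, 0.5))

# Limits: JSD/beta -> forward KL as beta -> 0; JSD/(1 - beta) -> reverse KL as beta -> 1.
eps = 1e-6
print(generalized_jsd(p, q_seek, eps) / eps, kl(p, q_seek))
b = 1 - eps
print(generalized_jsd(p, q_seek, b) / (1 - b), kl(q_seek, p))
```

Forward KL punishes the mode-seeking student for putting almost no mass where P is large, while reverse KL punishes the covering student for putting mass where P is nearly zero.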
