Multi-Token Prediction, Simply Explained

AI, But Simple Issue #104

When humans read text, we don't process words strictly one at a time, pausing to guess the next word before continuing.

We read ahead, absorbing chunks of words at a time, anticipating where the thought is heading before it arrives, forming a broader sense of direction as we read.

This "reading ahead" is more than a speed trick: it helps us grasp the deeper structure of language and how ideas unfold across a span of text.

If you’re familiar with large language models (LLMs), you’ll know that the dominant training paradigm has been the opposite: next-token prediction (NTP), where the model is trained to predict only the single next token and then generates text one token at a time until completion.

In April 2024, researchers at Meta's FAIR lab published a paper, "Better & Faster Large Language Models via Multi-token Prediction," proposing a training approach in which LLMs adopt this human-like reading technique.

By training models to predict multiple future tokens simultaneously at every position in a sequence, they reported meaningful gains in downstream performance, sample efficiency, and inference speed.

Incredibly, this deeper semantic and contextual understanding came without increasing the training cost.
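To make "predicting multiple future tokens at every position" concrete, here is a minimal sketch of how the training targets could be laid out. The function name `multi_token_targets` and the toy sentence are illustrative assumptions, not code from the paper. With `n=1` the layout reduces to ordinary next-token prediction.

```python
def multi_token_targets(tokens, n):
    """Pair each prefix of `tokens` with the n tokens that follow it.

    Hypothetical helper for illustration only: it shows the shape of the
    targets, not how a real model or training loop is implemented.
    """
    pairs = []
    # Stop early enough that every position still has n future tokens.
    for t in range(len(tokens) - n):
        prefix = tokens[: t + 1]            # context seen so far
        future = tokens[t + 1 : t + 1 + n]  # the n tokens to predict
        pairs.append((prefix, future))
    return pairs


tokens = ["the", "cat", "sat", "on", "the", "mat"]

# n=1: classic next-token prediction, one target per position.
ntp_pairs = multi_token_targets(tokens, n=1)

# n=3: each position must predict the next three tokens at once.
mtp_pairs = multi_token_targets(tokens, n=3)

for prefix, future in mtp_pairs:
    print(prefix, "->", future)
```

Note that the number of usable positions shrinks slightly as `n` grows, because the last few positions no longer have `n` tokens ahead of them.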

What You’ll Learn

  1. The issue with next-token prediction

  2. The mechanism behind multi-token prediction

  3. Multi-token prediction’s memory problem

  4. Self-speculative decoding, a new technique

  5. Empirical results of multi-token prediction

What’s Helpful to Know

  • Tokens

    • The units of text (words, subwords, or characters) that an input is split into before it is fed into a transformer model.

  • Vocabulary

    • The set of unique tokens a model recognizes. The vocabulary size is the number of tokens in that set.

  • Next-Token Prediction (NTP)

    • The standard autoregressive training objective: modeling the conditional probability of the next token given all of the tokens before it.

  • Logits

    • Raw, unnormalized network outputs. At each position, the model produces one logit for every token in the vocabulary, and a softmax turns those logits into probabilities.

  • Latent Representation

    • A compressed vector representation produced by the transformer's hidden layers that encodes the context of a sequence.
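The terms above fit together in a few lines of code. The following is a minimal sketch, assuming a toy four-token vocabulary and hand-picked logits (both made up for illustration). It shows how logits become probabilities and how the next-token prediction loss is computed for one position.

```python
import math


def softmax(logits):
    """Convert raw logits into probabilities that sum to 1."""
    m = max(logits)  # subtract the max for numerical stability
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def next_token_loss(logits, target_id):
    """Cross-entropy for one position: -log P(target | context)."""
    probs = softmax(logits)
    return -math.log(probs[target_id])


# Toy vocabulary and logits; these values are assumptions for illustration.
vocab = ["the", "cat", "sat", "mat"]
logits = [1.0, 3.0, 0.5, -1.0]  # one logit per vocabulary token

probs = softmax(logits)
predicted = vocab[probs.index(max(probs))]

# The loss is small when the true next token received a high probability.
loss = next_token_loss(logits, vocab.index("cat"))

print("predicted:", predicted)
print("loss:", round(loss, 4))
```

In a real model the logits come from the final layer applied to the latent representation, and the vocabulary holds tens of thousands of tokens rather than four.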
