
speculative-decoding


About

This Claude Skill accelerates LLM inference using speculative decoding techniques like Medusa and lookahead decoding, achieving 1.5-3.6× speedups. It's designed for developers optimizing latency in real-time applications or deploying models on limited compute. The implementation covers draft models, tree-based attention, and parallel token generation strategies.

Quick Install

Claude Code

Plugin command (recommended):
/plugin add https://github.com/davila7/claude-code-templates

Git clone (alternative):
git clone https://github.com/davila7/claude-code-templates.git ~/.claude/skills/speculative-decoding

Documentation

Speculative Decoding: Accelerating LLM Inference

When to Use This Skill

Use Speculative Decoding when you need to:

  • Speed up inference by 1.5-3.6× without quality loss
  • Reduce latency for real-time applications (chatbots, code generation)
  • Optimize throughput for high-volume serving
  • Deploy efficiently on limited hardware
  • Generate faster without changing model architecture

Key Techniques: Draft model speculative decoding, Medusa (multiple heads), Lookahead Decoding (Jacobi iteration)

Papers: Medusa (arXiv 2401.10774), Lookahead Decoding (ICML 2024), Speculative Decoding Survey (ACL 2024)

Installation

# Standard speculative decoding (transformers)
pip install transformers accelerate

# Medusa (multiple decoding heads)
git clone https://github.com/FasterDecoding/Medusa
cd Medusa
pip install -e .

# Lookahead Decoding
git clone https://github.com/hao-ai-lab/LookaheadDecoding
cd LookaheadDecoding
pip install -e .

# Optional: vLLM with speculative decoding
pip install vllm

Quick Start

Basic Speculative Decoding (Draft Model)

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# Load target model (large, slow)
target_model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-2-70b-hf",
    device_map="auto",
    torch_dtype=torch.float16
)

# Load draft model (small, fast)
draft_model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-2-7b-hf",
    device_map="auto",
    torch_dtype=torch.float16
)

tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-70b-hf")

# Generate with speculative decoding
prompt = "Explain quantum computing in simple terms:"
inputs = tokenizer(prompt, return_tensors="pt").to("cuda")

# Transformers 4.36+ supports assisted generation
outputs = target_model.generate(
    **inputs,
    assistant_model=draft_model,  # Enable speculative decoding
    max_new_tokens=256,
    do_sample=True,
    temperature=0.7,
)

response = tokenizer.decode(outputs[0], skip_special_tokens=True)
print(response)

Medusa (Multiple Decoding Heads)

import torch
from transformers import AutoTokenizer
from medusa.model.medusa_model import MedusaModel

# Load Medusa-enhanced model
model = MedusaModel.from_pretrained(
    "FasterDecoding/medusa-vicuna-7b-v1.3",  # Pre-trained with Medusa heads
    torch_dtype=torch.float16,
    device_map="auto"
)

tokenizer = AutoTokenizer.from_pretrained("FasterDecoding/medusa-vicuna-7b-v1.3")

# Generate with Medusa (2-3× speedup)
prompt = "Write a Python function to calculate fibonacci numbers:"
inputs = tokenizer(prompt, return_tensors="pt").to("cuda")

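# Note: in the upstream Medusa repo, medusa_generate is a generator that takes
# input_ids and max_steps and yields {"text": ...} dicts with the text so far.
# Verify the signature against the version you have installed.
response = ""
for output in model.medusa_generate(
    inputs["input_ids"],
    max_steps=256,
    temperature=0.7,
    posterior_threshold=0.09,  # Typical-acceptance threshold (epsilon)
    posterior_alpha=0.3,       # Entropy-scaling factor of the threshold (delta)
):
    response = output["text"]  # Keep the latest partial output

print(response)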

Lookahead Decoding (Jacobi Iteration)

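import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# Note: the wrapper class below is illustrative. The upstream repository exposes
# its API through the `lade` package (lade.augment_all(), lade.config_lade(...)),
# so check the repository README for the interface of your installed version.
from lookahead.lookahead_decoding import LookaheadDecoding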

# Load model
model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-2-7b-hf",
    torch_dtype=torch.float16,
    device_map="auto"
)
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf")

# Initialize lookahead decoding
lookahead = LookaheadDecoding(
    model=model,
    tokenizer=tokenizer,
    window_size=15,    # Lookahead window (W)
    ngram_size=5,      # N-gram size (N)
    guess_size=5       # Number of parallel guesses
)

# Generate (1.5-2.3× speedup)
prompt = "Implement quicksort in Python:"
output = lookahead.generate(prompt, max_new_tokens=256)
print(output)

Core Concepts

1. Speculative Decoding (Draft Model)

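Idea: Use a small draft model to propose candidate tokens and the large target model to verify them in parallel.

Algorithm:

  1. Draft model generates K tokens speculatively
  2. Target model scores all K tokens in parallel (single forward pass)
  3. Accept each draft token with probability min(1, p_target / p_draft), left to right
  4. At the first rejection, resample that token from the target-corrected distribution and discard the rest

import torch

@torch.no_grad()
def speculative_decode(target_model, draft_model, input_ids, K=4):
    """One round of speculative sampling; returns the newly accepted token ids.

    Assumes both models map a (1, seq_len) id tensor to an output with
    `.logits` of shape (1, seq_len, vocab) and share one vocabulary.
    """
    n = input_ids.shape[1]
    ids, q = input_ids, []
    # 1. Draft K tokens autoregressively, keeping each draft distribution
    for _ in range(K):
        probs = torch.softmax(draft_model(ids).logits[0, -1], dim=-1)
        tok = torch.multinomial(probs, 1)
        q.append(probs)
        ids = torch.cat([ids, tok.view(1, 1)], dim=1)

    # 2. Target model scores all K positions in one forward pass
    p = torch.softmax(target_model(ids).logits[0, n - 1:], dim=-1)  # (K+1, vocab)

    # 3. Accept draft token i with probability min(1, p / q)
    accepted = []
    for i in range(K):
        tok = ids[0, n + i]
        if torch.rand(()) < torch.clamp(p[i, tok] / q[i][tok], max=1.0):
            accepted.append(tok.item())
        else:
            # 4. Reject: resample from the residual max(0, p - q), renormalized
            residual = torch.clamp(p[i] - q[i], min=0)
            accepted.append(torch.multinomial(residual / residual.sum(), 1).item())
            return accepted

    # All K accepted: the same forward pass yields one extra target token
    accepted.append(torch.multinomial(p[K], 1).item())
    return accepted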

Performance:

  • Speedup: 1.5-2× with good draft model
  • No quality loss: the accept/reject rule keeps the output distribution identical to sampling from the target model
  • Best when draft model is 5-10× smaller than target
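
The "no quality loss" property can be checked empirically on a toy vocabulary. The sketch below is a minimal single-token illustration, not part of any library: `speculative_sample_one`, `empirical_distribution`, and the two three-token distributions are hypothetical names and values chosen for the demonstration. A token is proposed from the draft distribution, accepted with probability min(1, p/q), and otherwise resampled from the renormalized residual max(0, p - q).

```python
import random
from collections import Counter

def speculative_sample_one(p, q, rng):
    """Draw one token distributed according to p, using a proposal from q.

    p, q: probability lists over the same toy vocabulary (target, draft).
    """
    vocab = list(range(len(p)))
    x = rng.choices(vocab, weights=q)[0]          # draft proposes x ~ q
    if rng.random() < min(1.0, p[x] / q[x]):      # accept with prob min(1, p/q)
        return x
    # Rejected: resample from the residual max(0, p - q), renormalized
    residual = [max(0.0, pi - qi) for pi, qi in zip(p, q)]
    return rng.choices(vocab, weights=residual)[0]

def empirical_distribution(p, q, n=50_000, seed=0):
    """Estimate the output distribution of speculative_sample_one."""
    rng = random.Random(seed)
    counts = Counter(speculative_sample_one(p, q, rng) for _ in range(n))
    return [counts[i] / n for i in range(len(p))]

target = [0.6, 0.3, 0.1]   # toy target distribution (assumed for illustration)
draft = [0.3, 0.3, 0.4]    # deliberately mismatched draft distribution
freqs = empirical_distribution(target, draft)
print([round(f, 3) for f in freqs])
```

Even though the draft distribution differs substantially from the target, the measured frequencies should track `target`, not `draft`. A worse draft lowers the acceptance rate, which costs speed but not correctness.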

2. Medusa (Multiple Decoding Heads)

Source: arXiv 2401.10774 (2024)

Innovation: Add multiple prediction heads to an existing model so it can predict several future tokens without a separate draft model.

Architecture:

Input → Base LLM (frozen in Medusa-1) → Hidden State
                                ├→ LM head       (predicts token t+1)
                                ├→ Medusa head 1 (predicts token t+2)
                                ├→ Medusa head 2 (predicts token t+3)
                                ├→ Medusa head 3 (predicts token t+4)
                                └→ Medusa head 4 (predicts token t+5)

Training:

  • Medusa-1: Freeze base LLM, train only heads
    • 2.2× speedup, lossless
  • Medusa-2: Fine-tune base LLM + heads together
    • 2.3-3.6× speedup from more accurate heads; needs a training recipe that preserves the base model's quality

Tree-based Attention:

# Medusa constructs tree of candidates
# Example: Predict 2 steps ahead with top-2 per step

#         Root
#        /    \
#      T1a    T1b  (Step 1: 2 candidates)
#     /  \    / \
#  T2a  T2b T2c T2d  (Step 2: 4 candidates total)

# Single forward pass evaluates entire tree!
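
A single forward pass can score the whole tree because each candidate attends only to itself and its ancestors. The following is a minimal sketch of how such a mask might be built from candidate paths; `build_tree_attention_mask` is a hypothetical helper, not the Medusa implementation, and the shared prompt (the root, which every node may attend to) is left out for brevity.

```python
def build_tree_attention_mask(paths):
    """Return (nodes, mask) for a candidate tree.

    paths: iterable of tuples; each tuple gives the top-k index chosen at
    every depth, e.g. (0, 1) = "top-1 at step 1, then top-2 at step 2".
    mask[i][j] is True when node i may attend to node j, i.e. when j is
    node i itself or one of its ancestors.
    """
    nodes = sorted({tuple(p) for p in paths}, key=lambda p: (len(p), p))
    index = {node: i for i, node in enumerate(nodes)}
    mask = [[False] * len(nodes) for _ in nodes]
    for node, i in index.items():
        for depth in range(1, len(node) + 1):
            ancestor = node[:depth]              # depth == len(node) is the node itself
            if ancestor not in index:
                raise ValueError(f"missing ancestor {ancestor} for node {node}")
            mask[i][index[ancestor]] = True
    return nodes, mask

# The 2-step, top-2 tree from the diagram: T1a, T1b, T2a, T2b, T2c, T2d
paths = [(0,), (1,), (0, 0), (0, 1), (1, 0), (1, 1)]
nodes, mask = build_tree_attention_mask(paths)
for node, row in zip(nodes, mask):
    print(node, "".join("1" if allowed else "0" for allowed in row))
```

Each row has one entry per tree depth, so sibling branches never see each other and the six candidates behave like six independent continuations inside one batch position.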

Advantages:

  • No separate draft model needed
  • Minimal training (only heads)
  • Compatible with any LLM

3. Lookahead Decoding (Jacobi Iteration)

Source: ICML 2024

Core idea: Reformulate autoregressive decoding as solving a system of nonlinear equations, then solve it in parallel with Jacobi iteration.

Mathematical formulation:

Traditional:  y_t = f(x, y_1, ..., y_{t-1})  (sequential)
Jacobi:       y_t^{(k+1)} = f(x, y_1^{(k)}, ..., y_{t-1}^{(k)})  (parallel)
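
The Jacobi update can be illustrated with a toy deterministic "model". In the sketch below, `next_token` is a made-up stand-in for greedy decoding (an assumption for illustration only), and `jacobi_decode` applies the parallel update until the guess stops changing. The fixed point equals the sequential greedy output, and it is reached within n + 1 iterations because each iteration fixes at least one more leading token.

```python
def next_token(prefix):
    """Toy deterministic 'greedy model' (hypothetical, for illustration)."""
    return (sum(prefix) * 31 + len(prefix)) % 50

def greedy_decode(prompt, n):
    """Ordinary sequential decoding: one token per step."""
    seq = list(prompt)
    for _ in range(n):
        seq.append(next_token(seq))
    return seq[len(prompt):]

def jacobi_decode(prompt, n, init_token=0):
    """Solve y_t = f(x, y_1..y_{t-1}) for t = 1..n by fixed-point iteration."""
    y = [init_token] * n
    iterations = 0
    while True:
        # Every position is updated from the previous iterate, so with a real
        # model all n updates would fit into one batched forward pass.
        new_y = [next_token(list(prompt) + y[:t]) for t in range(n)]
        iterations += 1
        if new_y == y:
            return y, iterations
        y = new_y

prompt = [3, 14, 15]
tokens, iterations = jacobi_decode(prompt, n=8)
print(tokens == greedy_decode(prompt, 8), iterations)
```

Plain Jacobi decoding rarely fixes more than one token per iteration with real language models, which is why lookahead decoding also harvests n-grams from the intermediate iterates and verifies them.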

Two branches:

  1. Lookahead Branch: Generate n-grams in parallel

    • Window size W: How many future positions to guess at each step
    • N-gram size N: How many past Jacobi steps are combined into each n-gram
  2. Verification Branch: Verify promising n-grams

    • Select pooled n-grams whose first token matches the last generated token
    • Accept the longest continuation the model confirms

class LookaheadDecoding:
    """Conceptual sketch of one decoding step (not the library implementation)."""

    def __init__(self, model, window_size=15, ngram_size=5):
        self.model = model
        self.W = window_size  # Lookahead window
        self.N = ngram_size   # N-gram size

    def generate_ngram(self, tokens, start, length):
        """Placeholder: return a candidate n-gram from the Jacobi trajectory."""
        raise NotImplementedError

    def verify(self, tokens, ngram):
        """Placeholder: True if the model's greedy output reproduces ngram."""
        raise NotImplementedError

    def generate_step(self, tokens):
        # Lookahead branch: collect up to W × N candidate n-grams
        candidates = [
            self.generate_ngram(tokens, start=w, length=n)
            for w in range(1, self.W + 1)
            for n in range(1, self.N + 1)
        ]

        # Verification branch: keep n-grams that start at the last token and verify
        verified = [
            ngram for ngram in candidates
            if len(ngram) > 1 and ngram[0] == tokens[-1] and self.verify(tokens, ngram)
        ]

        # Accept the longest verified n-gram, minus its first token (already in tokens)
        if verified:
            return list(max(verified, key=len)[1:])
        # `generate_next` is a hypothetical single-token decoding method
        return [self.model.generate_next(tokens)]
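
The verification branch can be made concrete with a small n-gram pool. This is a simplified sketch under stated assumptions: `build_ngram_pool`, `verify_longest`, and the counting `toy_next_token` model are hypothetical, and candidates are checked one token at a time for readability, whereas the real method verifies them inside a single forward pass.

```python
from collections import defaultdict

def build_ngram_pool(trajectory_ngrams):
    """Index candidate n-grams by their first token."""
    pool = defaultdict(list)
    for ngram in trajectory_ngrams:
        pool[ngram[0]].append(tuple(ngram))
    return pool

def verify_longest(tokens, pool, next_token):
    """Return the longest candidate continuation confirmed by greedy decoding.

    next_token(prefix) stands in for the model's greedy prediction.
    """
    best = []
    for ngram in pool.get(tokens[-1], []):
        accepted, prefix = [], list(tokens)
        for guess in ngram[1:]:                 # ngram[0] is already in tokens
            if next_token(prefix) != guess:
                break                           # first mismatch ends this candidate
            accepted.append(guess)
            prefix.append(guess)
        if len(accepted) > len(best):
            best = accepted
    # Fall back to ordinary one-token decoding when nothing verifies
    return best or [next_token(list(tokens))]

def toy_next_token(prefix):
    """Toy model that counts upward modulo 10 (illustration only)."""
    return (prefix[-1] + 1) % 10

pool = build_ngram_pool([(3, 4, 5, 9), (3, 4, 5, 6), (3, 7, 8), (2, 3, 4)])
print(verify_longest([1, 2, 3], pool, toy_next_token))   # → [4, 5, 6]
print(verify_longest([7], pool, toy_next_token))         # → [8]
```

Because every accepted token is one the model would have produced anyway, the output matches greedy decoding; the pool only determines how many tokens are confirmed per step.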

Performance:

  • Speedup: 1.5-2.3× (up to 3.6× for code generation)
  • No draft model or training needed
  • Works out-of-the-box with any model

Method Comparison

| Method | Speedup | Training Needed | Draft Model | Quality Loss |
|---|---|---|---|---|
| Draft Model Speculative | 1.5-2× | No | Yes (external) | None |
| Medusa | 2-3.6× | Minimal (heads only) | No (built-in heads) | None |
| Lookahead | 1.5-2.3× | None | No | None |
| Naive Batching | 1.2-1.5× | No | No | None |

Advanced Patterns

Training Medusa Heads

import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import AutoModelForCausalLM

# 1. Load base model
base_model = AutoModelForCausalLM.from_pretrained(
    "lmsys/vicuna-7b-v1.3",
    torch_dtype=torch.float16
)

# 2. Add Medusa heads (simplified here to one linear layer each; the
#    reference implementation puts a residual block before the projection)
num_heads = 4
medusa_heads = nn.ModuleList([
    nn.Linear(base_model.config.hidden_size, base_model.config.vocab_size, bias=False)
    for _ in range(num_heads)
])

# 3. Training loop (freeze base model for Medusa-1)
for param in base_model.parameters():
    param.requires_grad = False  # Freeze base

optimizer = torch.optim.Adam(medusa_heads.parameters(), lr=1e-3)

# `dataloader` is assumed to yield dicts with 'input_ids' and 'attention_mask'
for batch in dataloader:
    # Forward pass through the frozen base; no gradients are needed here
    with torch.no_grad():
        hidden_states = base_model(**batch, output_hidden_states=True).hidden_states[-1]
    hidden_states = hidden_states.float()  # Heads are trained in float32

    # Predict future tokens with each head
    loss = 0
    for i, head in enumerate(medusa_heads):
        # The base LM head already predicts the next token (shift 1),
        # so Medusa head i is trained on tokens shifted by i+2 positions
        shift = i + 2
        logits = head(hidden_states[:, :-shift])
        target = batch['input_ids'][:, shift:]
        loss += F.cross_entropy(logits.reshape(-1, logits.size(-1)), target.reshape(-1))

    # Backward
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

Hybrid: Speculative + Medusa

# Conceptual sketch: use a Medusa model as the drafter for a larger target.
# The model ids are abbreviated, and `assistant_model` may not accept a
# MedusaModel in your transformers version, so treat this as pseudocode.
draft_medusa = MedusaModel.from_pretrained("medusa-vicuna-7b")
target_model = AutoModelForCausalLM.from_pretrained("vicuna-33b")

# Target verifies the draft's candidates in a single forward pass
outputs = target_model.generate(
    prompt,
    assistant_model=draft_medusa,  # Use Medusa as draft
    max_new_tokens=256
)

# Goal: Medusa drafting speed with large-model output quality

Optimal Draft Model Selection

def select_draft_model(target_model_size):
    """Suggest a draft model size for speculative decoding."""
    # Rule of thumb: draft should be roughly 5-10× smaller than the target
    suggestions = {
        "70B": "7B",  # 10× smaller
        "33B": "7B",  # ~5× smaller
        "13B": "1B",  # 13× smaller (outside the 5-10× rule of thumb)
    }
    # None → target too small, use Medusa/Lookahead instead
    return suggestions.get(target_model_size)

# Example
draft = select_draft_model("70B")
# Returns "7B" → Use Llama-2-7b as draft for Llama-2-70b

Best Practices

1. Choose the Right Method

def choose_method(deploying_new_model=False,
                  have_small_version_of_model=False,
                  want_plug_and_play=False):
    """Rule-of-thumb method selection; returns None if no rule applies."""
    # New deployment → Medusa (best overall speedup, no draft model)
    if deploying_new_model:
        return "Medusa"
    # Existing deployment with small model available → Draft speculative
    if have_small_version_of_model:
        return "Draft Model Speculative"
    # Want zero training/setup → Lookahead
    if want_plug_and_play:
        return "Lookahead Decoding"
    return None

2. Hyperparameter Tuning

Draft Model Speculative:

# K = number of speculative tokens
K = 4  # Good default
K = 2  # Conservative (higher acceptance)
K = 8  # Aggressive (lower acceptance, but more when accepted)

# Rule: Larger K → more speedup IF draft model is good
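
The trade-off behind this rule can be estimated with the standard analysis of speculative decoding. The sketch below assumes each draft token is accepted independently with probability `alpha` and that one draft step costs `cost_ratio` target steps; both are simplifying assumptions, and the function names are hypothetical. Under those assumptions a round of K draft tokens yields (1 - alpha^(K+1)) / (1 - alpha) tokens on average and costs K * cost_ratio + 1 target-step equivalents.

```python
def expected_tokens_per_round(alpha, k):
    """Expected tokens emitted per target forward pass.

    alpha: probability that a draft token is accepted (assumed i.i.d.)
    k: number of draft tokens proposed per round
    """
    if alpha >= 1.0:
        return float(k + 1)   # everything accepted, plus the bonus target token
    return (1 - alpha ** (k + 1)) / (1 - alpha)

def expected_speedup(alpha, k, cost_ratio):
    """Estimated walltime speedup over plain autoregressive decoding.

    cost_ratio: cost of one draft step relative to one target step.
    """
    return expected_tokens_per_round(alpha, k) / (k * cost_ratio + 1)

# Compare K = 2, 4, 8 for a weak and a strong draft model (cost_ratio = 0.1)
for alpha in (0.5, 0.9):
    row = {k: round(expected_speedup(alpha, k, cost_ratio=0.1), 2) for k in (2, 4, 8)}
    print(f"alpha={alpha}: {row}")
```

With a weak draft model the estimate falls as K grows, because the extra draft steps are mostly wasted; with a strong one it keeps rising. Measuring the acceptance rate on representative prompts is therefore a reasonable first step before tuning K.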

Medusa:

# Posterior threshold (minimum base-model probability for accepting a candidate)
posterior_threshold = 0.09  # Standard (from paper)
posterior_threshold = 0.05  # More permissive (faster, may degrade quality)
posterior_threshold = 0.15  # Stricter (slower, stays closer to the base model)

# Candidate tree: each path lists the top-k index taken at successive heads
medusa_choices = [[0], [0, 0], [0, 1], [0, 0, 0]]  # Small example tree, depth 3
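
The effect of these two acceptance parameters is easier to see in code. The sketch below is a hedged reading of Medusa's typical acceptance rule, in which a candidate is kept when the base model's probability for it exceeds min(posterior_threshold, posterior_alpha * exp(-entropy)); `typical_accept` is a hypothetical helper that works on a plain probability list, not the library function.

```python
import math

def typical_accept(probs, token, posterior_threshold=0.09, posterior_alpha=0.3):
    """Return True if `token` passes the typical-acceptance test.

    probs: base-model probabilities for the next position (sums to 1).
    """
    entropy = -sum(p * math.log(p) for p in probs if p > 0)
    bar = min(posterior_threshold, posterior_alpha * math.exp(-entropy))
    return probs[token] > bar

peaked = [0.9, 0.05, 0.05]   # confident base model
flat = [0.05] * 20           # uncertain base model (high entropy)

print(typical_accept(peaked, 0))                             # likely token
print(typical_accept(peaked, 1))                             # unlikely token, default threshold
print(typical_accept(peaked, 1, posterior_threshold=0.04))   # same token, lower threshold
print(typical_accept(flat, 3))                               # high entropy lowers the bar
```

Lowering `posterior_threshold` admits more candidates per step, which raises speed at some risk to quality, while the entropy term relaxes the bar automatically when the base model itself is uncertain.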

Lookahead:

# Window size W (lookahead distance)
# N-gram size N (context for generation)

# 7B model (more resources)
W, N = 15, 5

# 13B model (moderate)
W, N = 10, 5

# 33B+ model (limited resources)
W, N = 7, 5

3. Production Deployment

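# vLLM with speculative decoding
# Note: these keyword arguments follow older vLLM releases; newer releases
# configure speculation through a `speculative_config` dict, so check the
# documentation for your installed version.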
from vllm import LLM, SamplingParams

# Initialize with draft model
llm = LLM(
    model="meta-llama/Llama-2-70b-hf",
    speculative_model="meta-llama/Llama-2-7b-hf",  # Draft model
    num_speculative_tokens=5,
    use_v2_block_manager=True,
)

# Generate
prompts = ["Tell me about AI:", "Explain quantum physics:"]
sampling_params = SamplingParams(temperature=0.7, max_tokens=256)

outputs = llm.generate(prompts, sampling_params)
for output in outputs:
    print(output.outputs[0].text)

Resources

See Also

  • references/draft_model.md - Draft model selection and training
  • references/medusa.md - Medusa architecture and training
  • references/lookahead.md - Lookahead decoding implementation details

GitHub Repository

davila7/claude-code-templates
Path: cli-tool/components/skills/ai-research/emerging-techniques-speculative-decoding
