
optimizing-attention-flash

davila7
Updated Today
Tags: Other, Optimization, Flash Attention, Attention Optimization, Memory Efficiency, Speed Optimization, Long Context, PyTorch, SDPA, H100, FP8, Transformers

About

This skill implements Flash Attention to optimize transformer models, providing 2-4x speedups and 10-20x memory reductions for long sequences (>512 tokens). It's ideal when you encounter GPU memory issues or need faster inference with transformers. The skill supports multiple backends including PyTorch's native SDPA, the flash-attn library, and H100 FP8 acceleration.

Quick Install

Claude Code

Primary (recommended):
npx skills add davila7/claude-code-templates
Plugin command (alternative):
/plugin add https://github.com/davila7/claude-code-templates
Git clone (alternative):
git clone https://github.com/davila7/claude-code-templates.git ~/.claude/skills/optimizing-attention-flash

Copy and paste this command in Claude Code to install this skill

Documentation

Flash Attention - Fast Memory-Efficient Attention

Quick start

Flash Attention provides 2-4x speedup and 10-20x memory reduction for transformer attention through IO-aware tiling and recomputation.
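
The tiling idea can be illustrated without a GPU. The following is a minimal pure-Python sketch for a single query vector, assuming the usual online-softmax formulation (a running maximum and a running sum that are rescaled as each key/value block arrives). The function names and the toy data are illustrative only and are not part of PyTorch or flash-attn; the real kernels additionally tile over queries and keep blocks in on-chip SRAM.

```python
import math

def naive_attention_row(q, keys, values):
    """Reference: compute every score first, then apply softmax and weight the values."""
    scale = 1.0 / math.sqrt(len(q))
    scores = [scale * sum(a * b for a, b in zip(q, k)) for k in keys]
    top = max(scores)
    weights = [math.exp(s - top) for s in scores]
    total = sum(weights)
    dim = len(values[0])
    return [sum(w * v[d] for w, v in zip(weights, values)) / total for d in range(dim)]

def tiled_attention_row(q, keys, values, block_size=2):
    """Same result, but keys/values are visited block by block.

    Only one block of scores exists at a time; a running max and running sum
    let earlier partial results be rescaled when a larger score shows up.
    """
    scale = 1.0 / math.sqrt(len(q))
    dim = len(values[0])
    running_max = float("-inf")
    running_sum = 0.0
    acc = [0.0] * dim
    for start in range(0, len(keys), block_size):
        k_blk = keys[start:start + block_size]
        v_blk = values[start:start + block_size]
        scores = [scale * sum(a * b for a, b in zip(q, k)) for k in k_blk]
        new_max = max(running_max, max(scores))
        # Rescale what was accumulated under the old maximum (0.0 on the first block)
        correction = math.exp(running_max - new_max)
        running_sum *= correction
        acc = [a * correction for a in acc]
        for s, v in zip(scores, v_blk):
            w = math.exp(s - new_max)
            running_sum += w
            for d in range(dim):
                acc[d] += w * v[d]
        running_max = new_max
    return [a / running_sum for a in acc]

# Toy data: 5 keys/values with head dimension 2
q = [1.0, 0.5]
keys = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0], [0.5, -1.0], [2.0, 0.3]]
values = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0], [9.0, 10.0]]

reference = naive_attention_row(q, keys, values)
tiled = tiled_attention_row(q, keys, values, block_size=2)
```

Both functions return the same vector up to floating-point rounding, which is why the tiled form can skip storing the full score matrix.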

PyTorch native (easiest, PyTorch 2.2+):

import torch
import torch.nn.functional as F

q = torch.randn(2, 8, 512, 64, device='cuda', dtype=torch.float16)  # [batch, heads, seq, dim]
k = torch.randn(2, 8, 512, 64, device='cuda', dtype=torch.float16)
v = torch.randn(2, 8, 512, 64, device='cuda', dtype=torch.float16)

# Automatically uses Flash Attention if available
out = F.scaled_dot_product_attention(q, k, v)

flash-attn library (more features):

# Shell
pip install flash-attn --no-build-isolation

# Python
from flash_attn import flash_attn_func

# q, k, v: [batch, seqlen, nheads, headdim]
# (tensors in the SDPA layout above need .transpose(1, 2) first)
out = flash_attn_func(q, k, v, dropout_p=0.0, causal=True)

Common workflows

Workflow 1: Enable in existing PyTorch model

Copy this checklist:

Flash Attention Integration:
- [ ] Step 1: Check PyTorch version (≥2.2)
- [ ] Step 2: Enable Flash Attention backend
- [ ] Step 3: Verify speedup with profiling
- [ ] Step 4: Test accuracy matches baseline

Step 1: Check PyTorch version

python -c "import torch; print(torch.__version__)"
# Should be ≥2.2.0

If <2.2, upgrade:

pip install --upgrade torch
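
For scripts that need to make this check programmatically, comparing version strings directly can mislead (for example, "2.10" sorts before "2.2" as text). A small sketch, assuming the version string starts with "major.minor" as PyTorch's does; `version_tuple` and `meets_minimum` are hypothetical helper names, not PyTorch functions:

```python
import re

def version_tuple(version: str) -> tuple:
    """Extract (major, minor) from strings such as '2.2.0+cu121'."""
    match = re.match(r"(\d+)\.(\d+)", version)
    if match is None:
        raise ValueError(f"unrecognised version string: {version!r}")
    return int(match.group(1)), int(match.group(2))

def meets_minimum(version: str, minimum=(2, 2)) -> bool:
    """True if the version is at least the given (major, minor) pair."""
    return version_tuple(version) >= minimum

# Example inputs; in practice pass torch.__version__
checks = {v: meets_minimum(v) for v in ["2.2.0+cu121", "2.10.1", "2.1.2", "1.13.1"]}
```

In a real environment the call would be `meets_minimum(torch.__version__)`.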

Step 2: Enable Flash Attention backend

Replace standard attention:

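# Before (standard attention)
import math
d_k = q.size(-1)
attn_weights = torch.softmax(q @ k.transpose(-2, -1) / math.sqrt(d_k), dim=-1)
out = attn_weights @ v

# After (SDPA; dispatches to Flash Attention when the inputs qualify)
import torch.nn.functional as F
# For causal models pass is_causal=True. An arbitrary attn_mask typically
# makes SDPA fall back to a different backend than Flash Attention.
out = F.scaled_dot_product_attention(q, k, v)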

Force Flash Attention backend:

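# Works on PyTorch 2.2. Newer releases deprecate this context manager in favor of
# torch.nn.attention.sdpa_kernel(SDPBackend.FLASH_ATTENTION).
with torch.backends.cuda.sdp_kernel(
    enable_flash=True,
    enable_math=False,
    enable_mem_efficient=False
):
    # Raises a RuntimeError if the Flash kernel cannot handle these inputs
    out = F.scaled_dot_product_attention(q, k, v)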

Step 3: Verify speedup with profiling

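import torch
import torch.nn.functional as F
import torch.utils.benchmark as benchmark

# Create inputs once so tensor allocation is not part of the timing
q, k, v = [torch.randn(2, 8, 2048, 64, device='cuda', dtype=torch.float16) for _ in range(3)]

def test_attention(use_flash):
    if use_flash:
        # Disable the other backends so the Flash kernel is really what gets measured
        with torch.backends.cuda.sdp_kernel(
            enable_flash=True, enable_math=False, enable_mem_efficient=False
        ):
            return F.scaled_dot_product_attention(q, k, v)
    else:
        attn = (q @ k.transpose(-2, -1) / 8.0).softmax(dim=-1)  # 8.0 = sqrt(head_dim=64)
        return attn @ v

# Benchmark (torch.utils.benchmark handles warmup and CUDA synchronization)
t_flash = benchmark.Timer(stmt='test_attention(True)', globals=globals())
t_standard = benchmark.Timer(stmt='test_attention(False)', globals=globals())

# Measurement.mean is in seconds; report milliseconds
print(f"Flash: {t_flash.timeit(100).mean * 1e3:.3f} ms")
print(f"Standard: {t_standard.timeit(100).mean * 1e3:.3f} ms")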

Expected: 2-4x speedup for sequences >512 tokens.

Step 4: Test accuracy matches baseline

# Compare outputs
q, k, v = [torch.randn(1, 8, 512, 64, device='cuda', dtype=torch.float16) for _ in range(3)]

# Flash Attention
out_flash = F.scaled_dot_product_attention(q, k, v)

# Standard attention
attn_weights = torch.softmax(q @ k.transpose(-2, -1) / 8.0, dim=-1)
out_standard = attn_weights @ v

# Check difference
diff = (out_flash - out_standard).abs().max()
print(f"Max difference: {diff:.6f}")
# Should be <1e-3 for float16

Workflow 2: Use flash-attn library for advanced features

For multi-query attention, sliding window, or H100 FP8.

Copy this checklist:

flash-attn Library Setup:
- [ ] Step 1: Install flash-attn library
- [ ] Step 2: Modify attention code
- [ ] Step 3: Enable advanced features
- [ ] Step 4: Benchmark performance

Step 1: Install flash-attn library

# NVIDIA GPUs (CUDA 12.0+)
pip install flash-attn --no-build-isolation

# Verify installation
python -c "from flash_attn import flash_attn_func; print('Success')"

Step 2: Modify attention code

from flash_attn import flash_attn_func

# Input: [batch_size, seq_len, num_heads, head_dim]
# Transpose from [batch, heads, seq, dim] if needed
q = q.transpose(1, 2)  # [batch, seq, heads, dim]
k = k.transpose(1, 2)
v = v.transpose(1, 2)

out = flash_attn_func(
    q, k, v,
    dropout_p=0.1,  # Training only; set to 0.0 for evaluation/inference
    causal=True,  # For autoregressive models
    window_size=(-1, -1),  # No sliding window
    softmax_scale=None  # Auto-scale
)

out = out.transpose(1, 2)  # Back to [batch, heads, seq, dim]

Step 3: Enable advanced features

Multi-query attention (shared K/V across heads):

from flash_attn import flash_attn_func

# q: [batch, seq, num_q_heads, dim]
# k, v: [batch, seq, num_kv_heads, dim]  # Fewer KV heads; num_q_heads must be divisible by num_kv_heads
out = flash_attn_func(q, k, v)  # Handles MQA (one KV head) and GQA from the head counts
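
How query heads are paired with the smaller set of KV heads can be sketched in plain Python. This assumes the contiguous-group mapping described in the flash-attn docstrings (with 6 query heads and 2 KV heads, query heads 0-2 use KV head 0 and heads 3-5 use KV head 1); `kv_head_for_query_head` is a hypothetical helper for illustration, not a library function.

```python
def kv_head_for_query_head(q_head: int, num_q_heads: int, num_kv_heads: int) -> int:
    """Return the index of the KV head that a given query head attends with."""
    if num_q_heads % num_kv_heads != 0:
        raise ValueError("num_q_heads must be divisible by num_kv_heads")
    group_size = num_q_heads // num_kv_heads  # query heads sharing one KV head
    return q_head // group_size

# Grouped-query attention: 6 query heads share 2 KV heads
gqa_mapping = [kv_head_for_query_head(h, 6, 2) for h in range(6)]

# Multi-query attention: every query head shares the single KV head
mqa_mapping = [kv_head_for_query_head(h, 8, 1) for h in range(8)]
```

Standard multi-head attention is the special case where both head counts are equal and each query head maps to its own KV head.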

Sliding window attention (local attention):

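# Causal local attention: each token attends to itself and the 256 tokens before it
out = flash_attn_func(
    q, k, v,
    window_size=(256, 0),  # (left, right) window; causal=True limits the right side to 0
    causal=True
)
# For a window of 256 tokens before and after, use window_size=(256, 256) with causal=False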
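
To see which positions a given `window_size` lets through, the mask can be written out explicitly. This is a rough sketch assuming equal query and key lengths and the semantics in the flash-attn docstrings, as I understand them: query `i` sees keys in `[i - left, i + right]`, `-1` means unbounded, and causal masking forces the right side to 0. `sliding_window_mask` is an illustrative helper only; the library never materialises such a mask.

```python
def sliding_window_mask(seq_len, left, right, causal=False):
    """Return a seq_len x seq_len boolean matrix; True means query i may attend to key j."""
    if causal:
        right = 0  # a causal model never looks at future tokens
    mask = []
    for i in range(seq_len):
        row = []
        for j in range(seq_len):
            within_left = left < 0 or j >= i - left
            within_right = right < 0 or j <= i + right
            row.append(within_left and within_right)
        mask.append(row)
    return mask

bidirectional = sliding_window_mask(5, left=1, right=1)
causal_local = sliding_window_mask(5, left=1, right=1, causal=True)
no_window = sliding_window_mask(3, left=-1, right=-1)
```

Here row 2 of `bidirectional` allows keys 1 to 3, while row 2 of `causal_local` allows only keys 1 and 2.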

Step 4: Benchmark performance

import torch
from flash_attn import flash_attn_func
import time

q, k, v = [torch.randn(4, 4096, 32, 64, device='cuda', dtype=torch.float16) for _ in range(3)]

# Warmup
for _ in range(10):
    _ = flash_attn_func(q, k, v)

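# Benchmark
torch.cuda.reset_peak_memory_stats()
torch.cuda.synchronize()
start = time.perf_counter()
for _ in range(100):
    out = flash_attn_func(q, k, v)
torch.cuda.synchronize()  # Wait once for all queued kernels before stopping the clock
end = time.perf_counter()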

print(f"Time per iteration: {(end-start)/100*1000:.2f}ms")
print(f"Memory allocated: {torch.cuda.max_memory_allocated()/1e9:.2f}GB")
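
A time per iteration is easier to compare across GPUs once it is converted to throughput. The sketch below assumes the common FLOP count for the attention forward pass (two matrix multiplications of `2 * seq_len^2 * head_dim` FLOPs each per head, halved under causal masking). The helper names and the 5 ms figure are made up for illustration and are not measurements.

```python
def attention_forward_flops(batch, seq_len, num_heads, head_dim, causal=False):
    """Approximate FLOPs of one attention forward pass (QK^T plus scores @ V)."""
    flops = 4 * batch * num_heads * seq_len * seq_len * head_dim
    return flops // 2 if causal else flops

def achieved_tflops(flops, seconds):
    """Convert a FLOP count and a wall-clock time into TFLOPS."""
    return flops / seconds / 1e12

# Shapes from the benchmark above: batch=4, seq_len=4096, heads=32, head_dim=64
flops = attention_forward_flops(4, 4096, 32, 64)

# Hypothetical timing of 5 ms per iteration, for illustration only
example_tflops = achieved_tflops(flops, 5e-3)
```

Substituting the measured `(end - start) / 100` for the hypothetical 5 ms gives the throughput of the actual run.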

Workflow 3: H100 FP8 optimization (FlashAttention-3)

For maximum performance on H100 GPUs.

FP8 Setup:
- [ ] Step 1: Verify H100 GPU available
- [ ] Step 2: Install flash-attn with FP8 support
- [ ] Step 3: Convert inputs to FP8
- [ ] Step 4: Run with FP8 attention

Step 1: Verify H100 GPU

nvidia-smi --query-gpu=name --format=csv
# Should show "H100" or "H800"

Step 2: Install flash-attn with FP8 support

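# The flash-attn package from pip provides the FlashAttention-2 kernels (FP16/BF16).
# At the time of writing, FlashAttention-3 with FP8 is built separately from the
# hopper/ directory of the flash-attention repository:
git clone https://github.com/Dao-AILab/flash-attention.git
cd flash-attention/hopper
python setup.py install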

Step 3: Convert inputs to FP8

import torch

q = torch.randn(2, 4096, 32, 64, device='cuda', dtype=torch.float16)
k = torch.randn(2, 4096, 32, 64, device='cuda', dtype=torch.float16)
v = torch.randn(2, 4096, 32, 64, device='cuda', dtype=torch.float16)

# Convert to float8_e4m3 (FP8)
q_fp8 = q.to(torch.float8_e4m3fn)
k_fp8 = k.to(torch.float8_e4m3fn)
v_fp8 = v.to(torch.float8_e4m3fn)
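
A plain cast keeps the values as they are, so tensors whose magnitudes are far from the FP8 range lose precision or saturate. One common recipe is per-tensor scaling: stretch the values so the largest magnitude lands on the FP8 maximum, and keep the inverse factor to undo it later. The sketch below shows only the arithmetic in pure Python, assuming the `float8_e4m3fn` maximum of 448. The helper names are hypothetical, and whether the attention kernel accepts such descale factors depends on the library version.

```python
E4M3_MAX = 448.0  # largest finite value representable in float8_e4m3fn

def per_tensor_scale(values, fp8_max=E4M3_MAX):
    """Scale factor that maps the largest magnitude onto the FP8 maximum."""
    amax = max(abs(v) for v in values)
    if amax == 0.0:
        return 1.0  # all zeros: nothing to scale
    return fp8_max / amax

def scale_and_clamp(values, scale, fp8_max=E4M3_MAX):
    """Apply the scale and clamp to the representable range before casting."""
    return [max(-fp8_max, min(fp8_max, v * scale)) for v in values]

values = [0.5, -2.0, 1.0]
scale = per_tensor_scale(values)           # 448 / 2.0
scaled = scale_and_clamp(values, scale)    # values ready to be cast to FP8
descale = 1.0 / scale                      # multiply results by this to undo the scaling
restored = [s * descale for s in scaled]
```

The rounding to FP8 itself is not modelled here; the example only shows how the scale and descale factors relate.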

Step 4: Run with FP8 attention

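# FlashAttention-3 is exposed through the flash_attn_interface module
# (built from the repository's hopper/ directory), not the pip flash_attn package
from flash_attn_interface import flash_attn_func

# FP8 inputs are dispatched to the FlashAttention-3 FP8 kernels on H100
out = flash_attn_func(q_fp8, k_fp8, v_fp8)
# Depending on the version, the call may return (out, softmax_lse) rather than out alone
# Reported result: ~1.2 PFLOPS, 1.5-2x faster than FP16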

When to use vs alternatives

Use Flash Attention when:

  • Training transformers with sequences >512 tokens
  • Running inference with long context (>2K tokens)
  • GPU memory constrained (OOM with standard attention)
  • Need 2-4x speedup without accuracy loss
  • Using PyTorch 2.2+ or can install flash-attn

Use alternatives instead:

  • Standard attention: Sequences <256 tokens (overhead not worth it)
  • xFormers: Need more attention variants (not just speed)
  • Memory-efficient attention: CPU inference (Flash Attention needs GPU)
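
The memory-constrained case above follows from simple arithmetic: standard attention materialises a `seq_len x seq_len` score matrix per head, so its size grows quadratically with sequence length, while Flash Attention avoids storing that matrix. A back-of-the-envelope sketch, with a hypothetical helper name and float16 (2 bytes per element) assumed:

```python
def score_matrix_bytes(batch, num_heads, seq_len, bytes_per_element=2):
    """Size of the full attention score matrix that standard attention materialises."""
    return batch * num_heads * seq_len * seq_len * bytes_per_element

MIB = 1024 ** 2
GIB = 1024 ** 3

# batch=2, heads=8, float16
small = score_matrix_bytes(2, 8, 512) / MIB      # 8 MiB
medium = score_matrix_bytes(2, 8, 4096) / MIB    # 512 MiB
large = score_matrix_bytes(2, 8, 32768) / GIB    # 32 GiB
```

Going from 512 to 32,768 tokens multiplies the sequence length by 64 but the score matrix by 4,096, which is where out-of-memory errors with standard attention tend to come from.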

Common issues

Issue: ImportError: cannot import flash_attn

Install with no-build-isolation flag:

pip install flash-attn --no-build-isolation

Or install CUDA toolkit first:

conda install cuda -c nvidia
pip install flash-attn --no-build-isolation

Issue: Slower than expected (no speedup)

Flash Attention benefits increase with sequence length:

  • <512 tokens: Minimal speedup (10-20%)
  • 512-2K tokens: 2-3x speedup
  • >2K tokens: 3-4x speedup

Check sequence length is sufficient.

Issue: RuntimeError: CUDA error

Verify GPU supports Flash Attention:

import torch
print(torch.cuda.get_device_capability())
# Should be ≥(8, 0) (Ampere or newer) for flash-attn 2.x

Flash Attention requires:

  • Ampere and newer (A100, A10, H100): ✅ Full support
  • Turing (T4): ⚠️ flash-attn 1.x only; not supported by flash-attn 2.x
  • Volta (V100): ❌ Not supported
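
The compute-capability check can be wrapped in a small helper. This sketch assumes that flash-attn 2.x needs Ampere or newer (compute capability 8.0+), as listed under Hardware requirements, and that Turing (7.5) is covered only by the older 1.x releases. `flash_attention_support` is a hypothetical name, not a library function.

```python
def flash_attention_support(capability):
    """Map a (major, minor) compute capability to the flash-attn line that covers it."""
    capability = tuple(capability)
    if capability >= (8, 0):
        return "flash-attn 2.x"        # Ampere, Ada, Hopper
    if capability == (7, 5):
        return "flash-attn 1.x only"   # Turing, e.g. T4
    return "unsupported"               # Volta and older

# Example capabilities; in practice pass torch.cuda.get_device_capability()
report = {cap: flash_attention_support(cap) for cap in [(9, 0), (8, 6), (8, 0), (7, 5), (7, 0)]}
```

On a machine with a GPU the call would be `flash_attention_support(torch.cuda.get_device_capability())`.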

Issue: Accuracy degradation

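Check that the inputs are float16 or bfloat16 (not float32):

q = q.to(torch.float16)  # Or torch.bfloat16

Flash Attention kernels run in float16/bfloat16 for speed; float32 inputs are not supported. If float16 overflows or loses too much precision, try bfloat16, which has a wider dynamic range.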

Advanced topics

Integration with HuggingFace Transformers: See references/transformers-integration.md for enabling Flash Attention in BERT, GPT, Llama models.

Performance benchmarks: See references/benchmarks.md for detailed speed and memory comparisons across GPUs and sequence lengths.

Algorithm details: See references/algorithm.md for tiling strategy, recomputation, and IO complexity analysis.

Advanced features: See references/advanced-features.md for rotary embeddings, ALiBi, paged KV cache, and custom attention masks.

Hardware requirements

  • GPU: NVIDIA Ampere+ (A100, A10, A30) or AMD MI200+
  • VRAM: No more than standard attention (Flash Attention lowers peak memory rather than raising it)
  • CUDA: 12.0+ recommended (11.8 minimum)
  • PyTorch: 2.2+ for native support

Not supported: V100 (Volta), CPU inference

Resources

GitHub Repository

davila7/claude-code-templates
Path: cli-tool/components/skills/ai-research/optimization-flash-attention
Tags: anthropic, anthropic-claude, claude, claude-code
