Optimizing Attention with Flash Attention
Optimizes transformer attention with Flash Attention for a 2-4x speedup and 10-20x memory reduction. Use when training or running transformers with long sequences (>512 tokens), hitting GPU memory limits in attention, or needing faster inference. Supports PyTorch native SDPA, the flash-attn library, H100 FP8, and sliding window attention.
Skill metadata
| Source | Optional — install with `hermes skills install official/mlops/flash-attention` |
| Path | optional-skills/mlops/flash-attention |
| Version | 1.0.0 |
| Author | Orchestra Research |
| License | MIT |
| Dependencies | flash-attn, torch, transformers |
| Tags | Optimization, Flash Attention, Attention Optimization, Memory Efficiency, Speed Optimization, Long Context, PyTorch, SDPA, H100, FP8, Transformers |
Reference: full SKILL.md
The following is the complete skill definition that Hermes loads when this skill is triggered. This is what the agent sees as instructions when the skill is active.
Flash Attention - Fast Memory-Efficient Attention
Quick start
Flash Attention provides 2-4x speedup and 10-20x memory reduction for transformer attention through IO-aware tiling and recomputation.
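To make the "tiling and recomputation" idea concrete, here is a minimal educational sketch (not from the skill, and not how the real fused CUDA kernel is written) of the block-by-block online-softmax accumulation that Flash Attention performs; the function name and tile size are illustrative:

```python
import torch

def tiled_attention(q, k, v, tile=128):
    """Reference sketch of Flash Attention's online softmax over K/V tiles.

    q, k, v: [batch, heads, seq, dim]. Purely illustrative: the real kernel
    fuses these steps in on-chip SRAM instead of materializing full scores.
    """
    scale = q.shape[-1] ** -0.5
    out = torch.zeros_like(q)
    row_max = q.new_full(q.shape[:-1] + (1,), float('-inf'))  # running max per query row
    row_sum = torch.zeros_like(row_max)                        # running softmax denominator
    for start in range(0, k.shape[-2], tile):
        k_blk = k[..., start:start + tile, :]
        v_blk = v[..., start:start + tile, :]
        scores = (q @ k_blk.transpose(-2, -1)) * scale         # [b, h, seq, tile]
        new_max = torch.maximum(row_max, scores.amax(dim=-1, keepdim=True))
        correction = torch.exp(row_max - new_max)              # rescale earlier partial results
        p = torch.exp(scores - new_max)
        row_sum = row_sum * correction + p.sum(dim=-1, keepdim=True)
        out = out * correction + p @ v_blk
        row_max = new_max
    return out / row_sum
```

On random float32 inputs this matches `F.scaled_dot_product_attention(q, k, v)` up to floating-point tolerance, which is a handy way to convince yourself the tiling changes the memory traffic, not the math.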
PyTorch native (easiest, PyTorch 2.2+):
```python
import torch
import torch.nn.functional as F

q = torch.randn(2, 8, 512, 64, device='cuda', dtype=torch.float16)  # [batch, heads, seq, dim]
k = torch.randn(2, 8, 512, 64, device='cuda', dtype=torch.float16)
v = torch.randn(2, 8, 512, 64, device='cuda', dtype=torch.float16)

# Automatically uses Flash Attention if available
out = F.scaled_dot_product_attention(q, k, v)
```

flash-attn library (more features):
```bash
pip install flash-attn --no-build-isolation
```

```python
from flash_attn import flash_attn_func

# q, k, v: [batch, seqlen, nheads, headdim]
out = flash_attn_func(q, k, v, dropout_p=0.0, causal=True)
```

Common workflows
Workflow 1: Enable in existing PyTorch model
Copy this checklist:
Flash Attention Integration:
- [ ] Step 1: Check PyTorch version (≥2.2)
- [ ] Step 2: Enable Flash Attention backend
- [ ] Step 3: Verify speedup with profiling
- [ ] Step 4: Test accuracy matches baseline

Step 1: Check PyTorch version
python -c "import torch; print(torch.__version__)"# Should be ≥2.2.0If <2.2, upgrade:
```bash
pip install --upgrade torch
```

Step 2: Enable Flash Attention backend
Replace standard attention:
```python
import math
import torch.nn.functional as F

# Before (standard attention)
attn_weights = torch.softmax(q @ k.transpose(-2, -1) / math.sqrt(d_k), dim=-1)
out = attn_weights @ v

# After (Flash Attention)
out = F.scaled_dot_product_attention(q, k, v, attn_mask=mask)
```

Force Flash Attention backend:
```python
with torch.backends.cuda.sdp_kernel(
    enable_flash=True, enable_math=False, enable_mem_efficient=False
):
    out = F.scaled_dot_product_attention(q, k, v)
```
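On PyTorch 2.3+, `torch.backends.cuda.sdp_kernel` is deprecated in favor of `torch.nn.attention.sdpa_kernel`; a hedged equivalent of the snippet above:

```python
import torch.nn.functional as F
from torch.nn.attention import SDPBackend, sdpa_kernel

# Restrict SDPA to the Flash Attention backend (errors if the inputs aren't eligible)
with sdpa_kernel(SDPBackend.FLASH_ATTENTION):
    out = F.scaled_dot_product_attention(q, k, v)
```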
Step 3: Verify speedup with profiling

```python
import torch
import torch.nn.functional as F
import torch.utils.benchmark as benchmark

def test_attention(use_flash):
    q, k, v = [torch.randn(2, 8, 2048, 64, device='cuda', dtype=torch.float16) for _ in range(3)]
    if use_flash:
        with torch.backends.cuda.sdp_kernel(enable_flash=True):
            return F.scaled_dot_product_attention(q, k, v)
    else:
        attn = (q @ k.transpose(-2, -1) / 8.0).softmax(dim=-1)
        return attn @ v

# Benchmark
t_flash = benchmark.Timer(stmt='test_attention(True)', globals=globals())
t_standard = benchmark.Timer(stmt='test_attention(False)', globals=globals())

print(f"Flash: {t_flash.timeit(100).mean:.3f}s")
print(f"Standard: {t_standard.timeit(100).mean:.3f}s")
```

Expected: 2-4x speedup for sequences >512 tokens.
Step 4: Test accuracy matches baseline
```python
# Compare outputs
q, k, v = [torch.randn(1, 8, 512, 64, device='cuda', dtype=torch.float16) for _ in range(3)]

# Flash Attention
out_flash = F.scaled_dot_product_attention(q, k, v)

# Standard attention
attn_weights = torch.softmax(q @ k.transpose(-2, -1) / 8.0, dim=-1)
out_standard = attn_weights @ v

# Check difference
diff = (out_flash - out_standard).abs().max()
print(f"Max difference: {diff:.6f}")
# Should be <1e-3 for float16
```

Workflow 2: Use flash-attn library for advanced features
For multi-query attention, sliding window, or H100 FP8.
Copy this checklist:
flash-attn Library Setup:
- [ ] Step 1: Install flash-attn library
- [ ] Step 2: Modify attention code
- [ ] Step 3: Enable advanced features
- [ ] Step 4: Benchmark performance

Step 1: Install flash-attn library
```bash
# NVIDIA GPUs (CUDA 12.0+)
pip install flash-attn --no-build-isolation

# Verify installation
python -c "from flash_attn import flash_attn_func; print('Success')"
```

Step 2: Modify attention code
```python
from flash_attn import flash_attn_func

# Input: [batch_size, seq_len, num_heads, head_dim]
# Transpose from [batch, heads, seq, dim] if needed
q = q.transpose(1, 2)  # [batch, seq, heads, dim]
k = k.transpose(1, 2)
v = v.transpose(1, 2)

out = flash_attn_func(
    q, k, v,
    dropout_p=0.1,
    causal=True,           # For autoregressive models
    window_size=(-1, -1),  # No sliding window
    softmax_scale=None,    # Auto-scale
)

out = out.transpose(1, 2)  # Back to [batch, heads, seq, dim]
```

Step 3: Enable advanced features
Multi-query attention (shared K/V across heads):
```python
from flash_attn import flash_attn_func

# q: [batch, seq, num_q_heads, dim]
# k, v: [batch, seq, num_kv_heads, dim]  # Fewer KV heads
out = flash_attn_func(q, k, v)  # Automatically handles MQA
```
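To make the shapes concrete, here is a hedged sketch of grouped-query attention with illustrative dimensions (32 query heads sharing 8 KV heads); the library expects the query-head count to be a multiple of the KV-head count:

```python
import torch
from flash_attn import flash_attn_func

# Illustrative grouped-query shapes: 32 query heads share 8 KV heads
q = torch.randn(2, 1024, 32, 64, device='cuda', dtype=torch.float16)
k = torch.randn(2, 1024, 8, 64, device='cuda', dtype=torch.float16)
v = torch.randn(2, 1024, 8, 64, device='cuda', dtype=torch.float16)

out = flash_attn_func(q, k, v, causal=True)
print(out.shape)  # torch.Size([2, 1024, 32, 64])
```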
Sliding window attention (local attention):

```python
# Only attend to a window of 256 tokens before/after
out = flash_attn_func(
    q, k, v,
    window_size=(256, 256),  # (left, right) window
    causal=True,
)
```

Step 4: Benchmark performance
```python
import time
import torch
from flash_attn import flash_attn_func

q, k, v = [torch.randn(4, 4096, 32, 64, device='cuda', dtype=torch.float16) for _ in range(3)]

# Warmup
for _ in range(10):
    _ = flash_attn_func(q, k, v)

# Benchmark
torch.cuda.synchronize()
start = time.time()
for _ in range(100):
    out = flash_attn_func(q, k, v)
torch.cuda.synchronize()
end = time.time()

print(f"Time per iteration: {(end - start) / 100 * 1000:.2f}ms")
print(f"Memory allocated: {torch.cuda.max_memory_allocated() / 1e9:.2f}GB")
```

Workflow 3: H100 FP8 optimization (FlashAttention-3)
For maximum performance on H100 GPUs.
FP8 Setup:
- [ ] Step 1: Verify H100 GPU available
- [ ] Step 2: Install flash-attn with FP8 support
- [ ] Step 3: Convert inputs to FP8
- [ ] Step 4: Run with FP8 attention

Step 1: Verify H100 GPU
```bash
nvidia-smi --query-gpu=name --format=csv
# Should show "H100" or "H800"
```

Step 2: Install flash-attn with FP8 support
```bash
pip install flash-attn --no-build-isolation
# FP8 support included for H100
```

Step 3: Convert inputs to FP8
```python
import torch

q = torch.randn(2, 4096, 32, 64, device='cuda', dtype=torch.float16)
k = torch.randn(2, 4096, 32, 64, device='cuda', dtype=torch.float16)
v = torch.randn(2, 4096, 32, 64, device='cuda', dtype=torch.float16)

# Convert to float8_e4m3 (FP8)
q_fp8 = q.to(torch.float8_e4m3fn)
k_fp8 = k.to(torch.float8_e4m3fn)
v_fp8 = v.to(torch.float8_e4m3fn)
```

Step 4: Run with FP8 attention
```python
from flash_attn import flash_attn_func

# FlashAttention-3 automatically uses FP8 kernels on H100
out = flash_attn_func(q_fp8, k_fp8, v_fp8)
# Result: ~1.2 PFLOPS, 1.5-2x faster than FP16
```

When to use vs alternatives
Use Flash Attention when:
- Training transformers with sequences >512 tokens
- Running inference with long context (>2K tokens)
- GPU memory constrained (OOM with standard attention)
- Need 2-4x speedup without accuracy loss
- Using PyTorch 2.2+ or able to install flash-attn (see the fallback sketch after this list)
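If it is unclear which path a given environment supports, one option (a sketch that assumes either flash-attn or PyTorch ≥2.2 is present, not something the skill prescribes) is to try the library and fall back to SDPA:

```python
import torch
import torch.nn.functional as F

try:
    from flash_attn import flash_attn_func  # optional dependency

    def attention(q, k, v, causal=True):
        # flash-attn expects [batch, seq, heads, dim]; transpose from [batch, heads, seq, dim]
        out = flash_attn_func(q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2), causal=causal)
        return out.transpose(1, 2)
except ImportError:
    def attention(q, k, v, causal=True):
        # PyTorch SDPA picks the Flash backend automatically when the inputs allow it
        return F.scaled_dot_product_attention(q, k, v, is_causal=causal)
```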
Use alternatives instead:
- Standard attention: Sequences <256 tokens (overhead not worth it)
- xFormers: Need more attention variants (not just speed)
- Memory-efficient attention: CPU inference (Flash Attention needs GPU)
Common issues
Issue: ImportError: cannot import flash_attn
Install with the `--no-build-isolation` flag:
```bash
pip install flash-attn --no-build-isolation
```

Or install the CUDA toolkit first:
```bash
conda install cuda -c nvidia
pip install flash-attn --no-build-isolation
```

Issue: Slower than expected (no speedup)
Flash Attention benefits increase with sequence length:
- <512 tokens: Minimal speedup (10-20%)
- 512-2K tokens: 2-3x speedup
-
2K tokens: 3-4x speedup
Check that your sequence length is long enough to benefit.
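One quick way to confirm whether you are in the payoff regime is to time a few sequence lengths; a minimal sketch using the native SDPA path (lengths and shapes are illustrative):

```python
import time
import torch
import torch.nn.functional as F

for seq in (256, 512, 2048, 8192):
    q, k, v = [torch.randn(1, 8, seq, 64, device='cuda', dtype=torch.float16) for _ in range(3)]
    torch.cuda.synchronize()
    start = time.time()
    for _ in range(20):
        F.scaled_dot_product_attention(q, k, v)
    torch.cuda.synchronize()
    print(f"seq={seq}: {(time.time() - start) / 20 * 1000:.2f} ms")
```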
Issue: RuntimeError: CUDA error
Verify GPU supports Flash Attention:
```python
import torch

print(torch.cuda.get_device_capability())
# Should be ≥(7, 5) for Turing+
```

Flash Attention requires:
- Ampere (A100, A10): ✅ Full support
- Turing (T4): ✅ Supported
- Volta (V100): ❌ Not supported
Issue: Accuracy degradation
Check dtype is float16 or bfloat16 (not float32):
```python
q = q.to(torch.float16)  # Or torch.bfloat16
```

Flash Attention uses float16/bfloat16 for speed; float32 is not supported.
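If the surrounding model has to stay in float32, one workaround (a sketch using PyTorch autocast, not something the skill mandates) is to run the attention call under autocast so its inputs are downcast automatically:

```python
import torch
import torch.nn.functional as F

q, k, v = [torch.randn(1, 8, 1024, 64, device='cuda') for _ in range(3)]  # float32 inputs

# Autocast runs eligible ops (including SDPA) in bfloat16 on CUDA
with torch.autocast('cuda', dtype=torch.bfloat16):
    out = F.scaled_dot_product_attention(q, k, v)

print(out.dtype)  # torch.bfloat16
```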
Advanced topics
Integration with HuggingFace Transformers: See references/transformers-integration.md for enabling Flash Attention in BERT, GPT, Llama models.
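As a hedged illustration of what that integration usually looks like in recent transformers releases (the model name is only an example; see the referenced file for the skill's actual instructions):

```python
import torch
from transformers import AutoModelForCausalLM

# Requires flash-attn installed and a GPU that supports it
model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-3.1-8B",                # example model
    torch_dtype=torch.bfloat16,               # Flash Attention needs fp16/bf16
    attn_implementation="flash_attention_2",
).to("cuda")
```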
Performance benchmarks: See references/benchmarks.md for detailed speed and memory comparisons across GPUs and sequence lengths.
Algorithm details: See references/algorithm.md for tiling strategy, recomputation, and IO complexity analysis.
Advanced features: See references/advanced-features.md for rotary embeddings, ALiBi, paged KV cache, and custom attention masks.
Hardware requirements
- GPU: NVIDIA Ampere+ (A100, A10, A30) or AMD MI200+
- VRAM: No additional requirement (Flash Attention reduces attention memory rather than increasing it)
- CUDA: 12.0+ (11.8 minimum)
- PyTorch: 2.2+ for native support
Not supported: V100 (Volta), CPU inference
Resources
Section titled “Resources”- Paper: “FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness” (NeurIPS 2022)
- Paper: “FlashAttention-2: Faster Attention with Better Parallelism and Work Partitioning” (ICLR 2024)
- Blog: https://tridao.me/blog/2024/flash3/
- GitHub: https://github.com/Dao-AILab/flash-attention
- PyTorch docs: https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html