Attention Mechanism: From Math to GPU
A deep dive into the attention mechanism that powers transformers. We'll go from the intuition to the math, implement it three ways (single-head, multi-head loop, multi-head vectorized), and understand exactly how it maps to GPU memory and compute.
- The Intuition: What Problem Does Attention Solve?
- The Math: Scaled Dot-Product Attention
- Single-Head Attention Implementation
- Multi-Head Attention: Loop-Based
- Multi-Head Attention: Vectorized
- How Attention Maps to the GPU
- Memory Analysis: The Quadratic Problem
- FlashAttention: Solving the Memory Bottleneck
- Conclusion
The Intuition: What Problem Does Attention Solve?
Before transformers, sequence models like RNNs processed tokens one at a time, carrying information through a hidden state. This created two problems: the hidden state became a bottleneck (all information had to squeeze through it), and long-range dependencies were hard to learn (gradients vanished or exploded over many steps).
Attention solves both. Instead of processing sequentially, it lets every token directly look at every other token and decide which ones are relevant. No bottleneck, no vanishing gradients over distance.
[Interactive visualization: click a token to see what it attends to. Brighter = higher attention weight.]
The key insight: attention computes a weighted average over all tokens, where the weights are learned based on relevance. In a sentence like "The cat sat on the mat", when encoding "sat" the model might attend strongly to "cat" (the subject) and "mat" (the location), while mostly ignoring "the" (not informative).
This is done through three projections of each token:
- Query (Q) — "What am I looking for?"
- Key (K) — "What do I contain?"
- Value (V) — "What information do I provide?"
The query of one token is compared against the keys of all tokens to compute attention scores. These scores determine how much of each token's value is included in the output.
The Math: Scaled Dot-Product Attention
The attention function is elegantly simple:
Attention(Q, K, V) = softmax(QK^T / √d_k) V
Let's break this down step by step:
1. Compute similarity scores: QK^T
The dot product between query and key vectors measures similarity. If Q and K point in the same direction, their dot product is large (high attention). If orthogonal, it's zero (no attention).
2. Scale by √d_k:
This is crucial and often glossed over. Without scaling, the dot products grow with dimension size. For high-dimensional vectors, the dot products become very large, pushing softmax into regions where gradients vanish (all attention on one token, zero on others).
If the elements of Q and K are independent with zero mean and unit variance, their dot product has variance d_k. Dividing by √d_k brings the variance back to 1, keeping softmax in a well-behaved region regardless of dimension.
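A quick numerical check makes this concrete (a minimal sketch with random unit-variance vectors; the dimensions are arbitrary):
import torch

d_k = 512
q = torch.randn(10_000, d_k)    # 10k random query vectors, unit-variance elements
k = torch.randn(10_000, d_k)    # 10k random key vectors

raw = (q * k).sum(dim=-1)       # one dot product per row
scaled = raw / d_k ** 0.5       # the √d_k scaling

print(raw.var().item())         # ≈ d_k (here ≈ 512)
print(scaled.var().item())      # ≈ 1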
3. Apply softmax:
Softmax converts raw scores to a probability distribution (sums to 1, all positive). This makes the attention weights interpretable and ensures the output is a proper weighted average of values.
4. Weighted sum of values:
Finally, we compute the weighted sum of value vectors. Each token's output is a blend of all values, weighted by how relevant each key was to its query.
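Before wrapping this in a module, here is the whole recipe on raw tensors as a quick sanity check (shapes are arbitrary; this is just the four steps above, in order):
import torch
import torch.nn.functional as F

N, d_k = 6, 64                       # toy sequence length and key dimension
Q = torch.randn(N, d_k)
K = torch.randn(N, d_k)
V = torch.randn(N, d_k)

scores = Q @ K.T / d_k ** 0.5        # steps 1 + 2: similarity, then scale
weights = F.softmax(scores, dim=-1)  # step 3: each row is a probability distribution
output = weights @ V                 # step 4: weighted sum of values

print(weights.sum(dim=-1))           # every row sums to 1.0
print(output.shape)                  # torch.Size([6, 64])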
Single-Head Attention Implementation
Let's implement this in PyTorch. I'll annotate every line with tensor shapes—this is critical for understanding and debugging attention code.
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
class AttentionHead(nn.Module):
"""
Single Attention Head
Computes scaled dot-product attention:
Attention(Q, K, V) = softmax(QK^T / √d_k) V
"""
def __init__(self, embed_dim: int, head_dim: int) -> None:
super().__init__()
self.head_dim = head_dim
self.scale = 1.0 / math.sqrt(head_dim) # Precompute for efficiency
# Linear projections: embed_dim → head_dim
self.W_q = nn.Linear(embed_dim, head_dim, bias=False)
self.W_k = nn.Linear(embed_dim, head_dim, bias=False)
self.W_v = nn.Linear(embed_dim, head_dim, bias=False)
def forward(self, x: torch.Tensor, mask: torch.Tensor = None) -> torch.Tensor:
"""
Args:
x: Input tensor [batch, seq_len, embed_dim]
mask: Optional attention mask
Returns:
Output tensor [batch, seq_len, head_dim]
"""
# Project to Q, K, V
Q = self.W_q(x) # [B, N, embed_dim] @ [embed_dim, head_dim] → [B, N, head_dim]
K = self.W_k(x) # [B, N, head_dim]
V = self.W_v(x) # [B, N, head_dim]
# Compute attention scores
# [B, N, head_dim] @ [B, head_dim, N] → [B, N, N]
scores = torch.matmul(Q, K.transpose(-2, -1)) * self.scale
# Apply mask (for causal attention or padding)
if mask is not None:
scores = scores.masked_fill(mask == 0, float('-inf'))
# Softmax over last dimension (keys)
attn_weights = F.softmax(scores, dim=-1) # [B, N, N]
# Weighted sum of values
# [B, N, N] @ [B, N, head_dim] → [B, N, head_dim]
output = torch.matmul(attn_weights, V)
return output
The tensor shape journey is:
x: [B, N, d] → Q, K, V: [B, N, h] → scores: [B, N, N] → attn_weights: [B, N, N] → output: [B, N, h]
where B = batch size, N = sequence length, d = embed_dim, and h = head_dim.
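A quick usage sketch (the sizes here are arbitrary, chosen only to show the shapes):
head = AttentionHead(embed_dim=512, head_dim=64)
x = torch.randn(2, 128, 512)   # [B=2, N=128, d=512]
out = head(x)
print(out.shape)               # torch.Size([2, 128, 64]), i.e. [B, N, h]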
Multi-Head Attention: Loop-Based
A single attention head can only capture one type of relationship. But language is rich—we need syntax, semantics, coreference, and more. Multi-head attention runs multiple attention heads in parallel, each learning different patterns.
The naive implementation loops over heads. This is clear and correct, but not GPU-optimal:
class MultiHeadAttention(nn.Module):
"""
Multi-Head Attention (Loop-based, for pedagogical clarity)
Runs h attention heads in parallel and concatenates results.
This version explicitly loops over heads—clear but not GPU-optimal.
"""
def __init__(self, embed_dim: int, num_heads: int, head_dim: int = None) -> None:
super().__init__()
self.num_heads = num_heads
self.head_dim = head_dim if head_dim else embed_dim // num_heads
# Create h independent attention heads
self.heads = nn.ModuleList([
AttentionHead(embed_dim, self.head_dim)
for _ in range(num_heads)
])
# Output projection: (num_heads * head_dim) → embed_dim
self.W_o = nn.Linear(num_heads * self.head_dim, embed_dim)
def forward(self, x: torch.Tensor, mask: torch.Tensor = None) -> torch.Tensor:
# Run each head (this is the inefficient part—Python loop!)
head_outputs = [head(x, mask) for head in self.heads]
# Concatenate along feature dimension
# List of [B, N, head_dim] → [B, N, num_heads * head_dim]
concat = torch.cat(head_outputs, dim=-1)
# Final projection back to embed_dim
output = self.W_o(concat)
return output
This works, but there's a problem: each head launches its own CUDA kernels. With 8 heads, that's 24 separate matrix multiplications just for the Q, K, V projections, plus 8 separate attention computations. The overhead adds up.
Multi-Head Attention: Vectorized
The vectorized version computes all heads in a single operation. Instead of h separate weight matrices, we use one large matrix that projects to all heads simultaneously, then reshape.
class MultiHeadAttentionVectorized(nn.Module):
"""
Multi-Head Attention (Vectorized / GPU-Optimized)
Key insight: Instead of h separate linear layers, use ONE large linear
layer that computes all heads simultaneously, then reshape.
This is how production implementations (PyTorch, HuggingFace) do it.
"""
def __init__(self, embed_dim: int, num_heads: int, head_dim: int = None) -> None:
super().__init__()
self.num_heads = num_heads
self.head_dim = head_dim if head_dim else embed_dim // num_heads
self.scale = 1.0 / math.sqrt(self.head_dim)
# Single large projection for all heads
# embed_dim → (num_heads * head_dim) for Q, K, V each
self.W_q = nn.Linear(embed_dim, num_heads * self.head_dim, bias=False)
self.W_k = nn.Linear(embed_dim, num_heads * self.head_dim, bias=False)
self.W_v = nn.Linear(embed_dim, num_heads * self.head_dim, bias=False)
# Output projection
self.W_o = nn.Linear(num_heads * self.head_dim, embed_dim)
def forward(self, x: torch.Tensor, mask: torch.Tensor = None) -> torch.Tensor:
B, N, _ = x.shape
H, D = self.num_heads, self.head_dim
# Project all heads at once
# [B, N, embed_dim] → [B, N, H*D]
Q = self.W_q(x)
K = self.W_k(x)
V = self.W_v(x)
# Reshape to separate heads: [B, N, H*D] → [B, N, H, D] → [B, H, N, D]
Q = Q.view(B, N, H, D).transpose(1, 2)
K = K.view(B, N, H, D).transpose(1, 2)
V = V.view(B, N, H, D).transpose(1, 2)
# Batched attention computation (all heads in parallel!)
# [B, H, N, D] @ [B, H, D, N] → [B, H, N, N]
scores = torch.matmul(Q, K.transpose(-2, -1)) * self.scale
if mask is not None:
scores = scores.masked_fill(mask == 0, float('-inf'))
attn_weights = F.softmax(scores, dim=-1)
# [B, H, N, N] @ [B, H, N, D] → [B, H, N, D]
attn_output = torch.matmul(attn_weights, V)
# Merge heads back: [B, H, N, D] → [B, N, H, D] → [B, N, H*D]
attn_output = attn_output.transpose(1, 2).contiguous().view(B, N, H * D)
# Final projection
output = self.W_o(attn_output)
return output
After transpose(), the tensor's memory layout doesn't match its logical layout. The view() operation requires contiguous memory. Calling .contiguous() copies the data into a contiguous block. This is a common gotcha in attention implementations.
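You can reproduce the gotcha in isolation (a small illustration with arbitrary sizes):
t = torch.randn(2, 8, 128, 64)                # [B, H, N, D]
t = t.transpose(1, 2)                         # [B, N, H, D], but strides are now out of order

try:
    t.view(2, 128, 8 * 64)                    # RuntimeError: view needs contiguous memory
except RuntimeError as e:
    print(e)

merged = t.contiguous().view(2, 128, 8 * 64)  # copy into a contiguous block, then view
merged = t.reshape(2, 128, 8 * 64)            # or: reshape, which copies only when it must
print(merged.shape)                           # torch.Size([2, 128, 512])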
How Attention Maps to the GPU
Understanding GPU execution is essential for optimizing attention. Let's trace what happens when the vectorized attention runs.
Loop vs Vectorized: Kernel Launch Comparison
Loop-Based (8 heads):
- 24 kernel launches for the Q/K/V projections (3 per head)
- 8 kernels for Q @ K^T
- 8 kernels for softmax
- 8 kernels for attn @ V
- plus synchronization overhead

Vectorized:
- 3 kernel launches for the Q/K/V projections
- 1 batched GEMM for Q @ K^T
- 1 batched softmax
- 1 batched GEMM for attn @ V
- Tensor Core utilization
Each CUDA kernel launch has ~5-10μs overhead. With 48+ kernel launches per attention layer, this adds up in deep transformers.
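If you want to measure this yourself, a rough micro-benchmark along these lines works (a sketch that assumes a CUDA GPU and the two classes defined above; absolute numbers vary a lot with sizes, dtype, and hardware):
import time
import torch

device = "cuda"
x = torch.randn(32, 512, 512, device=device)   # [B, N, embed_dim], arbitrary sizes

loop_mha = MultiHeadAttention(embed_dim=512, num_heads=8).to(device)
vec_mha = MultiHeadAttentionVectorized(embed_dim=512, num_heads=8).to(device)

def bench(module, iters=50):
    for _ in range(5):                         # warm-up (caches, cuBLAS heuristics)
        module(x)
    torch.cuda.synchronize()
    start = time.perf_counter()
    for _ in range(iters):
        module(x)
    torch.cuda.synchronize()                   # wait for all queued kernels to finish
    return (time.perf_counter() - start) / iters

print(f"loop-based:  {bench(loop_mha) * 1e3:.2f} ms")
print(f"vectorized:  {bench(vec_mha) * 1e3:.2f} ms")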
GPU Memory Hierarchy
The key insight: attention is memory-bound, not compute-bound. The GPU spends more time moving data than doing math. The N×N attention matrix must be written to HBM (slow global memory) between the softmax and the value multiplication. This is the bottleneck FlashAttention addresses.
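A back-of-the-envelope arithmetic-intensity estimate for the softmax step alone shows why (the per-element FLOP count is a rough guess and the compute/bandwidth balance point varies by GPU, but the conclusion is robust):
N = 4096                          # sequence length
elems = N * N                     # score-matrix entries per (batch, head)

bytes_moved = 2 * elems * 4       # read + write the FP32 score matrix
flops = 5 * elems                 # roughly: subtract max, exp, sum, divide per element

print(flops / bytes_moved)        # ≈ 0.6 FLOPs per byte, far below the tens-to-hundreds
                                  # of FLOPs/byte a modern GPU needs to keep its compute
                                  # units busy, so softmax is limited by memory bandwidth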
Tensor Core Utilization
Modern NVIDIA GPUs (Volta and later) have Tensor Cores—specialized hardware for matrix multiply-accumulate operations. They operate on small tiles (e.g., 16×16) and can deliver 8× the throughput of regular CUDA cores.
The vectorized implementation enables Tensor Core usage because:
- Batched matrix multiplications map directly to cuBLAS batched GEMM
- Contiguous memory layout allows efficient tiling
- Dimensions are typically multiples of 8 (required for Tensor Cores)
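In practice, the easiest way to make sure those batched GEMMs actually hit the Tensor Cores is to run them in a low-precision dtype via autocast (a minimal sketch that assumes a CUDA GPU; speedups depend on GPU generation and sizes):
mha = MultiHeadAttentionVectorized(embed_dim=512, num_heads=8).cuda()
x = torch.randn(32, 1024, 512, device="cuda")

with torch.autocast(device_type="cuda", dtype=torch.float16):
    out = mha(x)            # matmuls execute in FP16 on Tensor Cores

print(out.dtype)            # torch.float16

# For FP32 code on Ampere and later GPUs, TF32 Tensor Core math can be enabled instead:
torch.backends.cuda.matmul.allow_tf32 = True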
Memory Analysis: The Quadratic Problem
Attention has a fundamental scaling problem: the N×N attention matrix.
Let's do the math for a typical setup: batch size 32, 32 attention heads, head dimension 64, sequence length 4096.
| Tensor | Shape | Memory (FP32) |
|---|---|---|
| Q, K, V (each) | [32, 32, 4096, 64] | 1.07 GB |
| Attention scores | [32, 32, 4096, 4096] | 68.7 GB (!) |
| Output | [32, 32, 4096, 64] | 1.07 GB |
The attention scores alone consume most of an H100's memory (80GB). This is why long-context models are so challenging.
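The table entries fall straight out of shape arithmetic at 4 bytes per FP32 element:
B, H, N, D = 32, 32, 4096, 64                 # batch, heads, sequence length, head dim
bytes_per_el = 4                              # FP32

qkv_bytes = B * H * N * D * bytes_per_el      # one of Q, K, or V
score_bytes = B * H * N * N * bytes_per_el    # the N×N attention matrix

print(f"Q/K/V (each): {qkv_bytes / 1e9:.2f} GB")    # ≈ 1.07 GB
print(f"Scores:       {score_bytes / 1e9:.1f} GB")  # ≈ 68.7 GB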
FlashAttention: Solving the Memory Bottleneck
FlashAttention (Dao et al., 2022) is one of the most impactful optimizations in modern deep learning. It reduces memory from O(N²) to O(N) while also being faster.
The key insight: never materialize the full N×N attention matrix. Instead, compute attention in tiles that fit in fast SRAM (shared memory), accumulating results on-the-fly.
The Tiling Strategy
FlashAttention processes attention in blocks:
- Load a tile of Q (e.g., 64 rows) into SRAM
- For each tile of K and V:
  - Load the K tile and compute partial attention scores
  - Compute a partial softmax (with online normalization)
  - Load the V tile and accumulate the weighted output
- Write the final output block to HBM
The attention scores are computed and consumed within SRAM—they never hit slow global memory. This is a 2-4× speedup and enables sequence lengths that would otherwise OOM.
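To make the tiling concrete, here is a minimal pure-PyTorch sketch of the online-softmax idea for a single head. It is purely pedagogical: the tile size is arbitrary, nothing here actually lives in SRAM, and the real FlashAttention is a fused CUDA kernel, but the accumulation logic is the same.
import torch

def tiled_attention(Q, K, V, tile=64):
    """Single-head attention computed one K/V tile at a time.

    Never materializes the full [N, N] score matrix; instead keeps running
    softmax statistics (row max and normalizer) and a running output.
    """
    N, d = Q.shape
    scale = d ** -0.5

    out = torch.zeros_like(Q)                       # running weighted sum of V
    row_max = torch.full((N, 1), float("-inf"))     # running max of scores per query row
    row_sum = torch.zeros(N, 1)                     # running softmax denominator

    for j in range(0, N, tile):
        Kj, Vj = K[j:j + tile], V[j:j + tile]
        S = Q @ Kj.T * scale                        # [N, tile] partial scores

        new_max = torch.maximum(row_max, S.max(dim=-1, keepdim=True).values)
        correction = torch.exp(row_max - new_max)   # rescale previously accumulated values
        P = torch.exp(S - new_max)                  # un-normalized weights for this tile

        row_sum = row_sum * correction + P.sum(dim=-1, keepdim=True)
        out = out * correction + P @ Vj
        row_max = new_max

    return out / row_sum                            # normalize once at the end

# Matches the naive computation up to floating-point error:
Q, K, V = (torch.randn(256, 64) for _ in range(3))
ref = torch.softmax(Q @ K.T / 64 ** 0.5, dim=-1) @ V
print(torch.allclose(tiled_attention(Q, K, V), ref, atol=1e-5))  # True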
You don't need to implement FlashAttention yourself. PyTorch 2.0+ includes it via torch.nn.functional.scaled_dot_product_attention() with automatic backend selection. For explicit control, use the flash-attention library.
# PyTorch 2.0+ with automatic FlashAttention
import torch.nn.functional as F
# This automatically uses FlashAttention when available
output = F.scaled_dot_product_attention(
    query, key, value,
    attn_mask=None,    # supply an explicit mask OR set is_causal=True, not both
    dropout_p=0.0,
    is_causal=True     # for decoder self-attention
)
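To check or force which backend scaled_dot_product_attention actually uses, recent PyTorch releases (2.3+, to the best of my recollection) expose a context manager; verify the exact import path against your version:
from torch.nn.attention import SDPBackend, sdpa_kernel

# Restrict SDPA to the FlashAttention backend for this block; if FlashAttention
# can't handle the inputs (dtype, head_dim, GPU), the call errors instead of
# silently falling back to another backend.
with sdpa_kernel(SDPBackend.FLASH_ATTENTION):
    output = F.scaled_dot_product_attention(query, key, value, is_causal=True)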
Conclusion
We've covered a lot of ground:
- Intuition: Attention is a learned weighted average that lets every token see every other token
- Math: Scaled dot-product attention with √d_k normalization to prevent softmax saturation
- Single-head: The basic building block with Q, K, V projections
- Multi-head (loop): Conceptually clear but GPU-inefficient due to kernel launch overhead
- Multi-head (vectorized): Production implementation with batched operations
- GPU mapping: Memory-bound nature, Tensor Core utilization, kernel fusion
- Memory scaling: The O(N²) problem that limits sequence length
- FlashAttention: Tiled computation that keeps attention scores in SRAM
The attention mechanism is deceptively simple—just a weighted average. But the details matter: scaling factors, memory layout, GPU utilization. Understanding these details is the difference between code that works and code that's fast.