// from 2000 GPUs to 384KB of SRAM

Attention Mechanism: From Math to GPU

A deep dive into the attention mechanism that powers transformers. We'll go from the intuition to the math, implement it three ways (single-head, multi-head loop, multi-head vectorized), and understand exactly how it maps to GPU memory and compute.

The Intuition: What Problem Does Attention Solve?

Before transformers, sequence models like RNNs processed tokens one at a time, carrying information through a hidden state. This created two problems: the hidden state became a bottleneck (all information had to squeeze through it), and long-range dependencies were hard to learn (gradients vanished or exploded over many steps).

Attention solves both. Instead of processing sequentially, it lets every token directly look at every other token and decide which ones are relevant. No bottleneck, no vanishing gradients over distance.

[Interactive: attention weights visualization. Click a token to see what it attends to; brighter = higher attention weight.]

The key insight: attention computes a weighted average of all tokens, where the weights are learned based on relevance. When encoding "sat", the model might attend strongly to "cat" (the subject) and "mat" (the location), while ignoring "the" (not informative).

This is done through three learned projections of each token:

Query (Q): what this token is looking for
Key (K): what this token offers to be matched against
Value (V): the information this token contributes to the output

The query of one token is compared against the keys of all tokens to compute attention scores. These scores determine how much of each token's value is included in the output.
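To make the mechanics concrete, here is a tiny hand-wired example (the numbers are invented purely for illustration, nothing is learned): one query from "sat" is scored against the keys of "the", "cat", and "mat", and the output is a weighted blend of their values.

import torch
import torch.nn.functional as F

# Keys and values for three context tokens: "the", "cat", "mat".
keys   = torch.tensor([[0.1, 0.0],   # "the" (not very informative)
                       [1.0, 0.4],   # "cat"
                       [0.2, 0.9]])  # "mat"
values = torch.tensor([[0.0, 0.0],
                       [2.0, 0.0],
                       [0.0, 2.0]])
query  = torch.tensor([1.0, 1.0])    # query vector for "sat"

scores  = keys @ query              # dot-product similarity, one score per token
weights = F.softmax(scores, dim=0)  # normalize into a distribution
output  = weights @ values          # weighted average of the values

print(weights)  # ≈ [0.14, 0.50, 0.37]: "cat" and "mat" dominate, "the" is mostly ignored
print(output)   # a blend of the "cat" and "mat" values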

The Math: Scaled Dot-Product Attention

The attention function is elegantly simple:

Attention(Q, K, V) = softmax(QKᵀ / √d_k) · V

Let's break this down step by step:

1. Compute similarity scores: QKᵀ

The dot product between query and key vectors measures similarity. If Q and K point in the same direction, their dot product is large (high attention). If orthogonal, it's zero (no attention).
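A two-line check of that geometric intuition (arbitrary toy vectors):

import torch

a            = torch.tensor([1.0, 0.0, 1.0])
b_aligned    = torch.tensor([2.0, 0.0, 2.0])   # points the same way as a
b_orthogonal = torch.tensor([0.0, 3.0, 0.0])   # perpendicular to a

print(torch.dot(a, b_aligned))     # 4.0 → large score, high attention
print(torch.dot(a, b_orthogonal))  # 0.0 → no attention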

[Interactive: step-by-step attention computation]

2. Scale by √d_k:

This is crucial and often glossed over. Without scaling, the dot products grow with dimension size. For high-dimensional vectors, the dot products become very large, pushing softmax into regions where gradients vanish (all attention on one token, zero on others).

Why √d_k?

If Q and K have elements drawn from a distribution with variance 1, their dot product has variance dk. Dividing by √dk normalizes the variance back to 1, keeping softmax in a well-behaved region regardless of dimension.
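You can verify this empirically in a few lines (a throwaway sketch; exact numbers vary from run to run):

import torch

torch.manual_seed(0)
d_k = 512
q = torch.randn(10_000, d_k)    # unit-variance random queries
k = torch.randn(10_000, d_k)    # unit-variance random keys

raw    = (q * k).sum(dim=-1)    # one dot product per row
scaled = raw / d_k ** 0.5

print(raw.var().item())     # ≈ d_k (here ≈ 512)
print(scaled.var().item())  # ≈ 1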

3. Apply softmax:

Softmax converts raw scores to a probability distribution (sums to 1, all positive). This makes the attention weights interpretable and ensures the output is a proper weighted average of values.
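Tying the last two steps together, here is what scaling does to the softmax output. The score magnitudes below are chosen to be plausible for unscaled d_k = 512 dot products:

import torch
import torch.nn.functional as F

scores = torch.tensor([22.0, 20.0, 17.0, 15.0])  # plausible unscaled scores for d_k = 512

print(F.softmax(scores, dim=0))                # ≈ [0.88, 0.12, 0.01, 0.00]: nearly one-hot, tiny gradients
print(F.softmax(scores / 512**0.5, dim=0))     # ≈ [0.29, 0.27, 0.23, 0.21]: a usable, trainable distribution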

4. Weighted sum of values:

Finally, we compute the weighted sum of value vectors. Each token's output is a blend of all values, weighted by how relevant each key was to its query.

Single-Head Attention Implementation

Let's implement this in PyTorch. I'll annotate every line with tensor shapes—this is critical for understanding and debugging attention code.

import torch
import torch.nn as nn
import torch.nn.functional as F
import math

class AttentionHead(nn.Module):
    """
    Single Attention Head
    
    Computes scaled dot-product attention:
        Attention(Q, K, V) = softmax(QK^T / √d_k) V
    """
    def __init__(self, embed_dim: int, head_dim: int) -> None:
        super().__init__()
        
        self.head_dim = head_dim
        self.scale = 1.0 / math.sqrt(head_dim)  # Precompute for efficiency
        
        # Linear projections: embed_dim → head_dim
        self.W_q = nn.Linear(embed_dim, head_dim, bias=False)
        self.W_k = nn.Linear(embed_dim, head_dim, bias=False)
        self.W_v = nn.Linear(embed_dim, head_dim, bias=False)
    
    def forward(self, x: torch.Tensor, mask: torch.Tensor = None) -> torch.Tensor:
        """
        Args:
            x: Input tensor [batch, seq_len, embed_dim]
            mask: Optional attention mask
        
        Returns:
            Output tensor [batch, seq_len, head_dim]
        """
        # Project to Q, K, V
        Q = self.W_q(x)  # [B, N, embed_dim] @ [embed_dim, head_dim] → [B, N, head_dim]
        K = self.W_k(x)  # [B, N, head_dim]
        V = self.W_v(x)  # [B, N, head_dim]
        
        # Compute attention scores
        # [B, N, head_dim] @ [B, head_dim, N] → [B, N, N]
        scores = torch.matmul(Q, K.transpose(-2, -1)) * self.scale
        
        # Apply mask (for causal attention or padding)
        if mask is not None:
            scores = scores.masked_fill(mask == 0, float('-inf'))
        
        # Softmax over last dimension (keys)
        attn_weights = F.softmax(scores, dim=-1)  # [B, N, N]
        
        # Weighted sum of values
        # [B, N, N] @ [B, N, head_dim] → [B, N, head_dim]
        output = torch.matmul(attn_weights, V)
        
        return output

The tensor shape journey is:

[B, N, d] --W_q--> [B, N, h] --@ Kᵀ--> [B, N, N] --softmax--> [B, N, N] --@ V--> [B, N, h]

Where B = batch size, N = sequence length, d = embed_dim, h = head_dim.
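A quick sanity check of that journey, assuming the AttentionHead class above is in scope:

head = AttentionHead(embed_dim=512, head_dim=64)
x = torch.randn(2, 10, 512)              # [B=2, N=10, d=512]

print(head(x).shape)                     # torch.Size([2, 10, 64])

# Causal mask: token i may only attend to tokens j <= i
causal = torch.tril(torch.ones(10, 10))  # [N, N], broadcasts over the batch
print(head(x, mask=causal).shape)        # torch.Size([2, 10, 64])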

Multi-Head Attention: Loop-Based

A single attention head can only capture one type of relationship. But language is rich—we need syntax, semantics, coreference, and more. Multi-head attention runs multiple attention heads in parallel, each learning different patterns.

[Figure: multi-head attention, different heads learn different patterns. Head 1: syntactic; Head 2: local; Head 3: semantic; Head 4: positional.]

The naive implementation loops over heads. This is clear and correct, but not GPU-optimal:

class MultiHeadAttention(nn.Module):
    """
    Multi-Head Attention (Loop-based, for pedagogical clarity)
    
    Runs h attention heads in parallel and concatenates results.
    This version explicitly loops over heads—clear but not GPU-optimal.
    """
    def __init__(self, embed_dim: int, num_heads: int, head_dim: int = None) -> None:
        super().__init__()
        
        self.num_heads = num_heads
        self.head_dim = head_dim if head_dim else embed_dim // num_heads
        
        # Create h independent attention heads
        self.heads = nn.ModuleList([
            AttentionHead(embed_dim, self.head_dim) 
            for _ in range(num_heads)
        ])
        
        # Output projection: (num_heads * head_dim) → embed_dim
        self.W_o = nn.Linear(num_heads * self.head_dim, embed_dim)
    
    def forward(self, x: torch.Tensor, mask: torch.Tensor = None) -> torch.Tensor:
        # Run each head (this is the inefficient part—Python loop!)
        head_outputs = [head(x, mask) for head in self.heads]
        
        # Concatenate along feature dimension
        # List of [B, N, head_dim] → [B, N, num_heads * head_dim]
        concat = torch.cat(head_outputs, dim=-1)
        
        # Final projection back to embed_dim
        output = self.W_o(concat)
        
        return output

This works, but there's a problem: each head launches its own CUDA kernels. With 8 heads, that's 24 separate matrix multiplications just for the Q, K, V projections (three per head), plus 8 attention computations. The overhead adds up.

Multi-Head Attention: Vectorized

The vectorized version computes all heads in a single operation. Instead of h separate weight matrices, we use one large matrix that projects to all heads simultaneously, then reshape.

class MultiHeadAttentionVectorized(nn.Module):
    """
    Multi-Head Attention (Vectorized / GPU-Optimized)
    
    Key insight: Instead of h separate linear layers, use ONE large linear
    layer that computes all heads simultaneously, then reshape.
    
    This is how production implementations (PyTorch, HuggingFace) do it.
    """
    def __init__(self, embed_dim: int, num_heads: int, head_dim: int = None) -> None:
        super().__init__()
        
        self.num_heads = num_heads
        self.head_dim = head_dim if head_dim else embed_dim // num_heads
        self.scale = 1.0 / math.sqrt(self.head_dim)
        
        # Single large projection for all heads
        # embed_dim → (num_heads * head_dim) for Q, K, V each
        self.W_q = nn.Linear(embed_dim, num_heads * self.head_dim, bias=False)
        self.W_k = nn.Linear(embed_dim, num_heads * self.head_dim, bias=False)
        self.W_v = nn.Linear(embed_dim, num_heads * self.head_dim, bias=False)
        
        # Output projection
        self.W_o = nn.Linear(num_heads * self.head_dim, embed_dim)
    
    def forward(self, x: torch.Tensor, mask: torch.Tensor = None) -> torch.Tensor:
        B, N, _ = x.shape
        H, D = self.num_heads, self.head_dim
        
        # Project all heads at once
        # [B, N, embed_dim] → [B, N, H*D]
        Q = self.W_q(x)
        K = self.W_k(x)
        V = self.W_v(x)
        
        # Reshape to separate heads: [B, N, H*D] → [B, N, H, D] → [B, H, N, D]
        Q = Q.view(B, N, H, D).transpose(1, 2)
        K = K.view(B, N, H, D).transpose(1, 2)
        V = V.view(B, N, H, D).transpose(1, 2)
        
        # Batched attention computation (all heads in parallel!)
        # [B, H, N, D] @ [B, H, D, N] → [B, H, N, N]
        scores = torch.matmul(Q, K.transpose(-2, -1)) * self.scale
        
        if mask is not None:
            scores = scores.masked_fill(mask == 0, float('-inf'))
        
        attn_weights = F.softmax(scores, dim=-1)
        
        # [B, H, N, N] @ [B, H, N, D] → [B, H, N, D]
        attn_output = torch.matmul(attn_weights, V)
        
        # Merge heads back: [B, H, N, D] → [B, N, H, D] → [B, N, H*D]
        attn_output = attn_output.transpose(1, 2).contiguous().view(B, N, H * D)
        
        # Final projection
        output = self.W_o(attn_output)
        
        return output

Why .contiguous()?

After transpose(), the tensor's memory layout doesn't match its logical layout. The view() operation requires contiguous memory. Calling .contiguous() copies the data into a contiguous block. This is a common gotcha in attention implementations.
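A minimal illustration of the gotcha (arbitrary shapes):

x = torch.randn(2, 8, 16, 64)            # [B, H, N, D]
y = x.transpose(1, 2)                    # logically [B, N, H, D], memory still laid out as [B, H, N, D]

print(y.is_contiguous())                 # False

try:
    y.view(2, 16, 8 * 64)                # view() requires contiguous memory
except RuntimeError as err:
    print("view failed:", err)

z = y.contiguous().view(2, 16, 8 * 64)   # copy into a contiguous block, then reinterpret
z = y.reshape(2, 16, 8 * 64)             # equivalent: reshape() copies only when it must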

How Attention Maps to the GPU

Understanding GPU execution is essential for optimizing attention. Let's trace what happens when the vectorized attention runs.

Loop vs Vectorized: Kernel Launch Comparison

Loop-Based (8 heads):
  • 24 kernel launches for Q/K/V projections
  • 8 kernels for Q @ Kᵀ
  • 8 kernels for softmax
  • 8 kernels for attn @ V
  • plus synchronization overhead between them

Vectorized:
  • 3 kernel launches for Q/K/V projections
  • 1 batched GEMM for Q @ Kᵀ
  • 1 batched softmax
  • 1 batched GEMM for attn @ V
  • full Tensor Core utilization

Each CUDA kernel launch has ~5-10μs overhead. With 48+ kernel launches per attention layer, this adds up in deep transformers.
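If you want to measure this on your own hardware, a rough benchmark looks like the sketch below (it assumes both classes above are defined and a CUDA GPU is available; absolute numbers depend heavily on the GPU and the shapes):

import time

@torch.no_grad()
def benchmark(module, x, iters=50):
    for _ in range(5):                   # warm-up: trigger kernel caching and allocator growth
        module(x)
    torch.cuda.synchronize()
    start = time.perf_counter()
    for _ in range(iters):
        module(x)
    torch.cuda.synchronize()             # wait for all queued kernels before stopping the clock
    return (time.perf_counter() - start) / iters

x = torch.randn(8, 512, 512, device="cuda")
loop = MultiHeadAttention(embed_dim=512, num_heads=8).cuda()
vect = MultiHeadAttentionVectorized(embed_dim=512, num_heads=8).cuda()

print(f"loop-based: {benchmark(loop, x) * 1e3:.2f} ms")
print(f"vectorized: {benchmark(vect, x) * 1e3:.2f} ms")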

GPU Memory Hierarchy

[Figure: GPU memory hierarchy during attention]

The key insight: attention is memory-bound, not compute-bound. The GPU spends more time moving data than doing math. The N×N attention matrix must be written to HBM (slow global memory) between the softmax and the value multiplication. This is the bottleneck FlashAttention addresses.

Tensor Core Utilization

Modern NVIDIA GPUs (Volta and later) have Tensor Cores—specialized hardware for matrix multiply-accumulate operations. They operate on small tiles (e.g., 16×16) and can deliver 8× the throughput of regular CUDA cores.

The vectorized implementation makes better use of Tensor Cores because the fused projections and the batched attention matmuls are large, regular GEMMs: the matmul libraries can tile them cleanly onto the Tensor Core fragments, instead of dispatching many small per-head multiplications that leave the hardware underutilized.

Memory Analysis: The Quadratic Problem

Attention has a fundamental scaling problem: the N×N attention matrix.

O(N²) memory complexity for attention scores

Let's do the math for a typical setup: batch size 32, 32 attention heads, sequence length 4096.

Tensor              Shape                   Memory (FP32)
Q, K, V (each)      [32, 32, 4096, 64]      ~1.07 GB
Attention scores    [32, 32, 4096, 4096]    ~68.7 GB (!)
Output              [32, 32, 4096, 64]      ~1.07 GB

The attention scores alone consume most of an H100's memory (80GB). This is why long-context models are so challenging.
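The arithmetic behind that table, as a quick back-of-the-envelope script (FP32, decimal gigabytes):

B, H, N, D = 32, 32, 4096, 64
BYTES_FP32 = 4

def gigabytes(*shape):
    n = 1
    for s in shape:
        n *= s
    return n * BYTES_FP32 / 1e9

print(f"Q (same for K, V): {gigabytes(B, H, N, D):6.2f} GB")   # ~1.07 GB
print(f"attention scores:  {gigabytes(B, H, N, N):6.2f} GB")   # ~68.72 GB
print(f"output:            {gigabytes(B, H, N, D):6.2f} GB")   # ~1.07 GB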

[Figure: attention memory vs sequence length]

FlashAttention: Solving the Memory Bottleneck

FlashAttention (Dao et al., 2022) is one of the most impactful optimizations in modern deep learning. It reduces memory from O(N²) to O(N) while also being faster.

The key insight: never materialize the full N×N attention matrix. Instead, compute attention in tiles that fit in fast SRAM (shared memory), accumulating results on-the-fly.

[Figure: standard vs FlashAttention memory access]

The Tiling Strategy

FlashAttention processes attention in blocks:

  1. Load a tile of Q (e.g., 64 rows) into SRAM
  2. For each tile of K and V:
    • Load K tile, compute partial attention scores
    • Compute partial softmax (with online normalization)
    • Load V tile, accumulate weighted output
  3. Write final output block to HBM

The attention scores are computed and consumed within SRAM—they never hit slow global memory. This is a 2-4× speedup and enables sequence lengths that would otherwise OOM.
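Here is a minimal, single-head PyTorch sketch of that loop with online softmax. It is a readability-first reference, not the fused CUDA kernel (no batching, no masking, Python-level loops), but it matches standard attention without ever materializing the full N×N score matrix:

import math
import torch

def flash_attention_reference(Q, K, V, block: int = 64):
    """Tiled attention with online softmax. Q, K, V: [N, d]."""
    N, d = Q.shape
    O = torch.empty_like(Q)

    for i in range(0, N, block):
        Qi = Q[i:i + block] / math.sqrt(d)                # query tile ("in SRAM"), pre-scaled
        m = Q.new_full((Qi.shape[0],), float("-inf"))     # running row-wise max
        l = Q.new_zeros(Qi.shape[0])                      # running softmax denominator
        acc = torch.zeros_like(Qi)                        # running weighted sum of V

        for j in range(0, N, block):
            S = Qi @ K[j:j + block].T                     # partial scores [Bq, Bk]
            m_new = torch.maximum(m, S.max(dim=-1).values)
            p = torch.exp(S - m_new[:, None])             # partial softmax numerator
            corr = torch.exp(m - m_new)                   # rescale previously accumulated partials
            l = l * corr + p.sum(dim=-1)
            acc = acc * corr[:, None] + p @ V[j:j + block]
            m = m_new

        O[i:i + block] = acc / l[:, None]                 # finish the softmax division
    return O

# Sanity check against the naive formula:
# q, k, v = (torch.randn(256, 64) for _ in range(3))
# ref = torch.softmax(q @ k.T / math.sqrt(64), dim=-1) @ v
# assert torch.allclose(flash_attention_reference(q, k, v), ref, atol=1e-5)

The trick is the running maximum m and denominator l: whenever a new tile raises the maximum, the previously accumulated numerator and output are rescaled by exp(m_old - m_new), so the final division by l yields exactly the softmax-weighted sum.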

Using FlashAttention in Practice

You don't need to implement FlashAttention yourself. PyTorch 2.0+ includes it via torch.nn.functional.scaled_dot_product_attention() with automatic backend selection. For explicit control, use the flash-attention library.

# PyTorch 2.0+ with automatic FlashAttention
import torch.nn.functional as F

# This automatically uses FlashAttention when available.
# Note: attn_mask and is_causal are mutually exclusive; for decoder
# self-attention, set is_causal=True and let the kernel build the
# triangular mask rather than passing one explicitly.
output = F.scaled_dot_product_attention(
    query, key, value,
    dropout_p=0.0,
    is_causal=True  # For decoder self-attention
)
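If you want to confirm that the flash backend actually ran rather than silently falling back to the math backend, PyTorch ships a context manager for restricting the choice. The spelling below is the PyTorch 2.0/2.1 API (torch.backends.cuda.sdp_kernel); newer releases prefer torch.nn.attention.sdpa_kernel, so check your version:

import torch
import torch.nn.functional as F

q = torch.randn(1, 8, 1024, 64, device="cuda", dtype=torch.float16)  # [B, H, N, D]
k, v = torch.randn_like(q), torch.randn_like(q)

# Allow only the FlashAttention backend; this raises an error instead of
# silently falling back if the inputs aren't flash-compatible.
with torch.backends.cuda.sdp_kernel(enable_flash=True,
                                    enable_math=False,
                                    enable_mem_efficient=False):
    out = F.scaled_dot_product_attention(q, k, v, is_causal=True)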

Conclusion

We've covered a lot of ground: the intuition behind attention, the scaled dot-product math, three implementations (single-head, loop-based multi-head, vectorized multi-head), how the computation maps onto GPU kernels and memory, and how FlashAttention removes the O(N²) memory bottleneck.

The attention mechanism is deceptively simple—just a weighted average. But the details matter: scaling factors, memory layout, GPU utilization. Understanding these details is the difference between code that works and code that's fast.