
Online Softmax Demystified: From PyTorch to FlashAttention in Triton

FlashAttention is 6–9× faster (depending on precision) and uses O(N) memory instead of O(N²). The secret? Online softmax. This post takes you from a clean PyTorch baseline through naive Triton to a fused FlashAttention kernel—step by step, with no magic operators.

The Problem: Why Standard Attention Doesn't Scale

In the previous post, we saw that attention creates an N×N matrix of scores. For sequence length 4096 with batch 32 and 32 heads, that's 32 × 32 × 4096 × 4096 ≈ 17 billion elements. In FP32 (4 bytes each), that's ~69 GB just for attention scores. In FP16, half that—but still enormous.
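
The arithmetic is worth sanity-checking in plain Python (no GPU needed; the helper name is just for illustration):

```python
def score_matrix_bytes(batch, heads, seq_len, bytes_per_elem):
    """Memory needed to materialize the [B, H, N, N] attention scores."""
    return batch * heads * seq_len * seq_len * bytes_per_elem

# The example above: batch 32, 32 heads, N = 4096, FP32 (4 bytes/element)
gb = score_matrix_bytes(32, 32, 4096, 4) / 1e9
print(f"{gb:.0f} GB")  # → 69 GB
```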

A note on precision

Throughout this post we benchmark in FP16 (torch.float16), which is the standard precision for transformer inference and the sweet spot for GPU tensor cores. FP16 halves the memory footprint compared to FP32 and typically doubles throughput on modern GPUs. Our benchmark config (B=4, H=32, N=4096) uses 4 GB for the attention scores in FP16 vs 8 GB in FP32—far less than the 69 GB example above, because we use a smaller batch size.

But memory isn't the only problem. Even when we have enough memory, standard attention is slow because of memory bandwidth. The attention scores must be:

  1. Written to HBM after computing QKᵀ
  2. Read from HBM for softmax
  3. Written to HBM after softmax
  4. Read from HBM for the final matmul with V

That's 4 trips to slow global memory for the N×N matrix. On an H100, HBM bandwidth is ~3 TB/s, but SRAM bandwidth is ~33 TB/s. We're leaving 10× performance on the table.

The Goal

Compute attention without ever materializing the N×N matrix in HBM. Keep everything in fast SRAM by processing tiles and accumulating results on-the-fly.

The blocker? Softmax needs the entire row before it can produce a single output. Or does it?

Part 1: PyTorch Baseline

Let's start with clean, readable PyTorch. This is our reference implementation—correct but slow.

import torch
import torch.nn.functional as F
import math

def attention_pytorch(Q, K, V):
    """
    Standard scaled dot-product attention in PyTorch.
    
    Args:
        Q: [batch, heads, seq_len, head_dim]
        K: [batch, heads, seq_len, head_dim]
        V: [batch, heads, seq_len, head_dim]
    
    Returns:
        Output: [batch, heads, seq_len, head_dim]
    """
    B, H, N, D = Q.shape
    scale = 1.0 / math.sqrt(D)
    
    # Step 1: Compute attention scores
    # [B, H, N, D] @ [B, H, D, N] → [B, H, N, N]
    scores = torch.matmul(Q, K.transpose(-2, -1)) * scale
    
    # Step 2: Softmax (this is the bottleneck!)
    # Needs full row to compute max and sum
    attn = F.softmax(scores, dim=-1)
    
    # Step 3: Weighted sum of values
    # [B, H, N, N] @ [B, H, N, D] → [B, H, N, D]
    output = torch.matmul(attn, V)
    
    return output

Simple and correct. The scores tensor is shape [B, H, N, N]—that's our O(N²) memory. Let's profile where time actually goes:

# Profiling setup
B, H, N, D = 4, 32, 4096, 64
Q = torch.randn(B, H, N, D, device='cuda', dtype=torch.float16)
K = torch.randn(B, H, N, D, device='cuda', dtype=torch.float16)
V = torch.randn(B, H, N, D, device='cuda', dtype=torch.float16)

# Warmup
for _ in range(10):
    _ = attention_pytorch(Q, K, V)

# Benchmark
torch.cuda.synchronize()
start = torch.cuda.Event(enable_timing=True)
end = torch.cuda.Event(enable_timing=True)

start.record()
for _ in range(100):
    _ = attention_pytorch(Q, K, V)
end.record()
torch.cuda.synchronize()

print(f"PyTorch: {start.elapsed_time(end) / 100:.2f} ms")

On an NVIDIA RTX PRO 6000 Blackwell, this runs at roughly 18.4 ms per attention layer. We'll use this as our baseline.

Part 2: Triton 101 — A Gentle Introduction

Before we optimize attention, let's understand Triton. If you've never written a GPU kernel, this section is for you. If you're comfortable with Triton, skip to Part 3.

What is Triton?

Triton is a language for writing GPU kernels that sits between PyTorch and CUDA. It handles the painful parts (memory coalescing, shared memory management, warp synchronization) while giving you control over the algorithm.

PyTorch

  - High-level, automatic
  - Limited control
  - Can't fuse operations

Triton

  - Block-level programming
  - Control over tiling
  - Automatic memory management

The Execution Model

In Triton, you write a kernel that runs on many programs in parallel. Each program processes a block of data. Think of it like this:

Triton Execution Model

When you launch a kernel with grid=(4,), Triton runs 4 programs. Each program gets a unique program_id (0, 1, 2, 3) and processes its assigned block of data.

A Simple Example: Vector Addition

Let's start with the "hello world" of GPU programming:

import triton
import triton.language as tl

@triton.jit
def add_kernel(
    x_ptr,        # Pointer to first input
    y_ptr,        # Pointer to second input
    out_ptr,      # Pointer to output
    N,            # Total number of elements
    BLOCK: tl.constexpr,  # Block size (compile-time constant)
):
    # Which block am I?
    pid = tl.program_id(0)
    
    # Compute the range of indices this block handles
    # If pid=0 and BLOCK=256, offsets = [0, 1, 2, ..., 255]
    # If pid=1 and BLOCK=256, offsets = [256, 257, ..., 511]
    offsets = pid * BLOCK + tl.arange(0, BLOCK)
    
    # Mask: don't read/write out of bounds
    mask = offsets < N
    
    # Load data from global memory
    x = tl.load(x_ptr + offsets, mask=mask)
    y = tl.load(y_ptr + offsets, mask=mask)
    
    # Compute
    out = x + y
    
    # Store result to global memory
    tl.store(out_ptr + offsets, out, mask=mask)

Key concepts:

  1. tl.program_id(0) — which block this program handles
  2. tl.arange(0, BLOCK) — a vector of offsets inside the block
  3. mask — prevents out-of-bounds reads and writes on the last block
  4. tl.load / tl.store — explicit movement between global memory and registers

To launch the kernel:

def add(x, y):
    out = torch.empty_like(x)
    N = x.numel()
    BLOCK = 256
    
    # Number of blocks needed
    grid = ((N + BLOCK - 1) // BLOCK,)
    
    # Launch kernel
    add_kernel[grid](x, y, out, N, BLOCK)
    
    return out
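
The grid-size expression is just ceiling division, which Triton also provides as triton.cdiv. A quick CPU-side check of the arithmetic:

```python
def cdiv(a, b):
    """Ceiling division: how many BLOCK-sized programs cover a elements."""
    return (a + b - 1) // b

print(cdiv(1000, 256))  # → 4 (the last program is partially masked)
print(cdiv(1024, 256))  # → 4 (exact fit, no masking needed)
```
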

Why Pointers?

Triton kernels receive pointers to data, not tensors. x_ptr + offsets computes memory addresses. This is lower-level than PyTorch but gives us control over memory access patterns.

Part 3: Naive Attention in Triton

Now let's write attention in Triton—the same algorithm as PyTorch, just on GPU. This won't be faster (it might be slower), but it teaches us the mechanics.

We'll process one query row at a time. Each program handles one (batch, head, query) combination:

@triton.jit
def attention_naive_kernel(
    Q_ptr, K_ptr, V_ptr, Out_ptr,
    stride_qb, stride_qh, stride_qn, stride_qd,  # Q strides
    stride_kb, stride_kh, stride_kn, stride_kd,  # K strides
    stride_vb, stride_vh, stride_vn, stride_vd,  # V strides
    stride_ob, stride_oh, stride_on, stride_od,  # Out strides
    N, D,
    scale,
    BLOCK_D: tl.constexpr,
):
    # Which query row am I processing?
    pid_b = tl.program_id(0)  # batch index
    pid_h = tl.program_id(1)  # head index
    pid_n = tl.program_id(2)  # query row index
    
    # Pointer to the start of my Q row: Q[b, h, n, :]
    q_offset = pid_b * stride_qb + pid_h * stride_qh + pid_n * stride_qn
    
    # Load entire query vector (size D)
    d_range = tl.arange(0, BLOCK_D)
    q = tl.load(Q_ptr + q_offset + d_range * stride_qd, mask=d_range < D, other=0.0)
    
    # Compute attention scores with ALL keys
    # This is where we'd like to tile, but can't yet (softmax needs full row)
    m_i = -float('inf')  # running max for numerical stability
    l_i = 0.0             # running sum of exp(scores - max)
    acc = tl.zeros([BLOCK_D], dtype=tl.float32)  # accumulator for output
    
    # Loop over all key positions
    for j in range(N):
        # Load K[b, h, j, :]
        k_offset = pid_b * stride_kb + pid_h * stride_kh + j * stride_kn
        k = tl.load(K_ptr + k_offset + d_range * stride_kd, mask=d_range < D, other=0.0)
        
        # Compute dot product: q · k
        score = tl.sum(q * k, axis=0) * scale
        
        # Update running max
        m_new = tl.maximum(m_i, score)
        
        # Rescale previous accumulator and sum
        alpha = tl.exp(m_i - m_new)
        l_i = l_i * alpha + tl.exp(score - m_new)
        acc = acc * alpha
        
        # Load V[b, h, j, :] and accumulate
        v_offset = pid_b * stride_vb + pid_h * stride_vh + j * stride_vn
        v = tl.load(V_ptr + v_offset + d_range * stride_vd, mask=d_range < D, other=0.0)
        acc += tl.exp(score - m_new) * v
    
    # Normalize by sum of exponentials
    acc = acc / l_i
    
    # Store result
    out_offset = pid_b * stride_ob + pid_h * stride_oh + pid_n * stride_on
    tl.store(Out_ptr + out_offset + d_range * stride_od, acc, mask=d_range < D)

This is slow!

This kernel processes one query at a time, and its inner loop handles a single key per iteration: no tiling, no tensor-core matmuls. It's pedagogical, not practical. We'll fix this in Part 6.

Notice something sneaky? We're already using the online softmax idea—maintaining a running max (m_i) and sum (l_i), rescaling as we go. Let's understand why this works.

Part 4: Why Standard Softmax Blocks Tiling

Before we can understand the solution, we need to deeply understand the problem. Let's look at softmax carefully.

Anatomy of Softmax

For a row x = [x₁, x₂, ..., xₙ], softmax is:

softmax(xᵢ) = exp(xᵢ) / Σⱼ exp(xⱼ)

Look at this formula. The numerator is easy—it only depends on xᵢ. We can compute it the moment we see xᵢ.

The denominator is the problem. It's a sum over all elements. We cannot compute softmax(x₁) until we've seen x₂, x₃, ..., xₙ.

The Core Problem

Every single output depends on every single input. To compute any softmax value, you need to have seen all N values first. This global dependency is what blocks tiling.

The Numerically Stable Version

In practice, we subtract the max for numerical stability (prevents overflow):

m = max(x)
softmax(xᵢ) = exp(xᵢ - m) / Σⱼ exp(xⱼ - m)
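
As a plain-Python reference (a sketch, not the GPU code), the stable two-pass version is:

```python
import math

def softmax_stable(x):
    """Numerically stable softmax: subtract the row max before exponentiating."""
    m = max(x)                            # pass 1: global max
    exps = [math.exp(v - m) for v in x]   # pass 2: shifted exponentials
    l = sum(exps)
    return [e / l for e in exps]

# Large inputs no longer overflow: exp(1000) would be inf, exp(0) is fine.
print(softmax_stable([999.0, 1000.0]))  # ≈ [0.269, 0.731]
```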

This doesn't help our problem—it makes it worse. Now we have two global dependencies:

Dependency 1: max

m = max(x₁, x₂, ..., xₙ)

Must see ALL values to know the maximum.

Dependency 2: sum

l = Σⱼ exp(xⱼ - m)

Must see ALL values (and know m) to compute sum.

Cost Analysis: Why This Kills Performance

Let's trace what happens in standard attention with N tokens:

Standard Attention Memory Access Pattern
# Step 1: Compute scores
scores = Q @ K.T               # [N, N] — write to HBM
# Step 2: Find max (read scores from HBM)
m = max(scores, dim=-1)       # read N² elements
# Step 3: Compute exp and sum (read scores again)
exp_scores = exp(scores - m)  # read N², write N²
l = sum(exp_scores, dim=-1)   # read N²
# Step 4: Normalize (read exp_scores again)
attn = exp_scores / l         # read N², write N²
# Step 5: Apply to values (read attn)
output = attn @ V             # read N²

Count the HBM accesses for the N×N attention matrix:

Operation          HBM Reads   HBM Writes
Q @ K.T → scores   0           N²
max(scores)        N²          0
exp(scores - m)    N²          N²
sum(exp_scores)    N²          0
exp_scores / l     N²          N²
attn @ V           N²          0
Total              5N²         3N²

That's 8N² memory operations for the attention scores alone. For N=4096, that's 134 million accesses to slow HBM, just for softmax bookkeeping.

8N²
HBM accesses for attention scores (read + write)
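
To make that figure concrete, using the 5N² read + 3N² write totals from above:

```python
N = 4096
reads, writes = 5 * N * N, 3 * N * N   # totals for the N×N score matrix
print(f"{reads + writes:,}")  # → 134,217,728
```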

The Tiling Dream (That Doesn't Work)

We'd love to process attention in tiles: load a small block of Q and K, compute partial scores, immediately multiply by V, accumulate the result. This would keep data in fast SRAM instead of slow HBM.

But the denominator blocks us:

The Softmax Dependency Problem

Say we process the first tile and compute scores [s₁, s₂] for tokens 1-2. We want to compute:

softmax(s₁) = exp(s₁) / (exp(s₁) + exp(s₂) + exp(s₃) + ... + exp(sₙ))

We don't know s₃, s₄, ..., sₙ yet! We can't compute the denominator. So we're stuck—we must either:

  1. Store all scores and come back later (defeats tiling)
  2. Find a way to update the denominator incrementally (online softmax!)

What We Need

The question becomes: can we maintain a running estimate of the denominator that we correct as we see more values?

If yes, we could:

  1. Process tile 1: compute partial scores, partial denominator, partial output
  2. Process tile 2: update denominator, correct our previous work, accumulate more output
  3. Continue until done

This is exactly what online softmax does. The key insight is that when our estimate of max changes, we can mathematically correct our running sum with a simple multiplication.

Part 5: The Online Softmax Algorithm

Now that we understand the problem—the denominator requires seeing all values—let's see the elegant solution.

The Key Insight

What if we could update our answer as we process each tile? When we see a new value larger than our current max, we rescale everything we've computed so far.

Define for the first k elements:

mₖ = max(x₁, ..., xₖ)
lₖ = Σⱼ₌₁ᵏ exp(xⱼ - mₖ)    ← this is our running denominator estimate

The final softmax is: softmax(xᵢ) = exp(xᵢ - mₙ) / lₙ

Numerical Example: Watch It Happen

Let's trace through a concrete example: five values, processed one at a time.

Notice how the online version maintains a running estimate that gets corrected whenever we see a larger value. The correction factor exp(m_old - m_new) rescales our previous work. At the end, it produces exactly the same result as standard softmax.

Detailed Numerical Trace

Here's the exact arithmetic for input x = [2.0, 1.0, 4.0, 1.5, 3.0]:

Step-by-Step Online Softmax Computation
Step   xₖ     m_new    Correction            l_new
init   —      -∞       —                     0
k=1    2.0    2.0      exp(-∞ - 2) = 0       0 × 0 + exp(0) = 1.0
k=2    1.0    2.0      exp(2 - 2) = 1.0      1.0 × 1 + exp(-1) = 1.368
k=3    4.0    4.0 ↑    exp(2 - 4) = 0.135    1.368 × 0.135 + exp(0) = 1.185
k=4    1.5    4.0      exp(4 - 4) = 1.0      1.185 × 1 + exp(-2.5) = 1.267
k=5    3.0    4.0      exp(4 - 4) = 1.0      1.267 × 1 + exp(-1) = 1.635

Final: m = 4.0, l = 1.635, so softmax(xᵢ) = exp(xᵢ - 4.0) / 1.635

The key moment is step k=3: we encounter 4.0, which is larger than our previous max of 2.0. The correction factor exp(2 - 4) = 0.135 scales down all our previous work. This is the "online" magic—we fix our running estimate without going back to reprocess earlier elements.

The Recurrence Relations

The recurrence relations when adding element xₖ₊₁:

Online Softmax Update
Given: mₖ, lₖ (current max and sum)
New element: xₖ₊₁
# Update max
mₖ₊₁ = max(mₖ, xₖ₊₁)
# Rescale previous sum and add new term
lₖ₊₁ = lₖ × exp(mₖ - mₖ₊₁) + exp(xₖ₊₁ - mₖ₊₁)

The magic: exp(mₖ - mₖ₊₁) is the correction factor. If the new max is larger, this scales down our previous sum. If the new max equals the old max, this equals 1 (no change).
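
The recurrence is easy to verify in plain Python before going anywhere near a GPU. This sketch (illustrative, not the kernel) runs the update on the same input as the trace above:

```python
import math

def softmax_online(x):
    """One pass over x, maintaining running max m and running sum l."""
    m, l = -math.inf, 0.0
    for v in x:
        m_new = max(m, v)
        # Correction factor: rescales the old sum whenever the max grows.
        l = l * math.exp(m - m_new) + math.exp(v - m_new)
        m = m_new
    return [math.exp(v - m) / l for v in x]

out = softmax_online([2.0, 1.0, 4.0, 1.5, 3.0])
# Final stats match the trace (m = 4.0, l ≈ 1.635), and the probabilities
# agree with the two-pass version to machine precision.
```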

Standalone Online Softmax in Triton

Let's implement just softmax first, to verify the algorithm:

@triton.jit
def online_softmax_kernel(
    X_ptr, Out_ptr,
    N,
    BLOCK: tl.constexpr,
):
    # One program per row
    row_idx = tl.program_id(0)
    
    # Initialize running statistics
    m_i = -float('inf')  # running max
    l_i = 0.0             # running sum of exp(x - m)
    
    # First pass: compute m and l online
    for block_start in range(0, N, BLOCK):
        cols = block_start + tl.arange(0, BLOCK)
        mask = cols < N
        
        # Load block of values
        x = tl.load(X_ptr + row_idx * N + cols, mask=mask, other=-float('inf'))
        
        # Block-level max
        m_block = tl.max(x, axis=0)
        
        # Update global max
        m_new = tl.maximum(m_i, m_block)
        
        # Rescale previous sum
        l_i = l_i * tl.exp(m_i - m_new)
        
        # Add contributions from this block
        l_i += tl.sum(tl.exp(x - m_new), axis=0)
        
        # Update max
        m_i = m_new
    
    # Second pass: compute final softmax values
    for block_start in range(0, N, BLOCK):
        cols = block_start + tl.arange(0, BLOCK)
        mask = cols < N
        
        x = tl.load(X_ptr + row_idx * N + cols, mask=mask, other=-float('inf'))
        
        # Final softmax: exp(x - m) / l
        out = tl.exp(x - m_i) / l_i
        
        tl.store(Out_ptr + row_idx * N + cols, out, mask=mask)

Still Two Passes?

This version still has two passes, but they're over blocks, not individual elements. More importantly, we never write the intermediate values to global memory. The stats (m_i, l_i) stay in registers.

Why This Matters for Attention

In attention, we don't just want softmax—we want softmax(QKᵀ) @ V. The online approach lets us:

  1. Process a block of K and V
  2. Compute partial attention scores (Q @ K_blockᵀ)
  3. Update running max and sum
  4. Accumulate weighted V_block into output
  5. Repeat for next block

The N×N scores are never fully materialized. We compute and consume them block by block.
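
The steps above can be rehearsed in plain Python before writing any Triton. This is a toy single-head sketch (lists of row vectors instead of tensors; all names are illustrative) that streams K/V blocks and applies the same (m, l, acc) update per query row:

```python
import math

def attention_blocked(Q, K, V, block=2):
    """Tiled attention with online softmax; Q, K, V are lists of row vectors."""
    N, D = len(Q), len(Q[0])
    scale = 1.0 / math.sqrt(D)
    out = [[0.0] * D for _ in range(N)]   # output accumulator
    m = [-math.inf] * N                   # running max per query row
    l = [0.0] * N                         # running denominator per query row
    for s in range(0, len(K), block):     # stream K/V one block at a time
        hi = min(s + block, len(K))
        for i in range(N):
            # Partial scores of query i against this key block.
            scores = [scale * sum(qd * kd for qd, kd in zip(Q[i], K[j]))
                      for j in range(s, hi)]
            m_new = max(m[i], max(scores))
            alpha = math.exp(m[i] - m_new)          # rescale earlier work
            l[i] = l[i] * alpha + sum(math.exp(sc - m_new) for sc in scores)
            out[i] = [o * alpha for o in out[i]]
            for j, sc in zip(range(s, hi), scores):
                w = math.exp(sc - m_new)
                out[i] = [o + w * vd for o, vd in zip(out[i], V[j])]
            m[i] = m_new
    return [[o / l[i] for o in row] for i, row in enumerate(out)]
```

For random inputs this matches unblocked attention to machine precision regardless of block size, which is exactly the property the fused kernel in Part 6 relies on.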

Part 6: Fused FlashAttention Kernel

Now we put it all together. This is a simplified but functional FlashAttention kernel:

@triton.jit
def flash_attention_kernel(
    Q_ptr, K_ptr, V_ptr, Out_ptr,
    stride_qb, stride_qh, stride_qn, stride_qd,
    stride_kb, stride_kh, stride_kn, stride_kd,
    stride_vb, stride_vh, stride_vn, stride_vd,
    stride_ob, stride_oh, stride_on, stride_od,
    N, D,
    scale,
    BLOCK_M: tl.constexpr,  # Block size for queries
    BLOCK_N: tl.constexpr,  # Block size for keys/values
    BLOCK_D: tl.constexpr,  # Head dimension (must cover full D)
):
    # Grid: (num_q_blocks, batch * heads)
    pid_m = tl.program_id(0)  # which query block
    pid_bh = tl.program_id(1)  # which batch*head
    
    # Combined batch*head index: for contiguous inputs stride_qb == H * stride_qh,
    # so pid_bh * stride_qh addresses batch pid_bh // H, head pid_bh % H directly.
    
    # Offsets for the query block this program handles
    offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    offs_d = tl.arange(0, BLOCK_D)
    
    # Pointers to Q block: [BLOCK_M, D]
    q_ptrs = Q_ptr + pid_bh * stride_qh + \
             offs_m[:, None] * stride_qn + \
             offs_d[None, :] * stride_qd
    
    # Load Q block (stays in SRAM for entire kernel)
    mask_m = offs_m < N
    q = tl.load(q_ptrs, mask=mask_m[:, None], other=0.0)
    
    # Initialize accumulators (in registers)
    m_i = tl.full([BLOCK_M], -float('inf'), dtype=tl.float32)  # max per query
    l_i = tl.zeros([BLOCK_M], dtype=tl.float32)               # sum per query
    acc = tl.zeros([BLOCK_M, BLOCK_D], dtype=tl.float32)      # output accumulator
    
    # Loop over key/value blocks
    for block_start in range(0, N, BLOCK_N):
        offs_n = block_start + tl.arange(0, BLOCK_N)
        mask_n = offs_n < N
        
        # Load K block: [BLOCK_N, D]
        k_ptrs = K_ptr + pid_bh * stride_kh + \
                 offs_n[:, None] * stride_kn + \
                 offs_d[None, :] * stride_kd
        k = tl.load(k_ptrs, mask=mask_n[:, None], other=0.0)
        
        # Compute attention scores: [BLOCK_M, BLOCK_N]
        scores = tl.dot(q, tl.trans(k)) * scale
        
        # Mask out invalid positions
        scores = tl.where(
            mask_m[:, None] & mask_n[None, :],
            scores,
            -float('inf')
        )
        
        # === Online softmax update ===
        
        # Max of this block (per query row)
        m_block = tl.max(scores, axis=1)
        
        # New global max
        m_new = tl.maximum(m_i, m_block)
        
        # Correction factor: rescales the old accumulator and sum
        alpha = tl.exp(m_i - m_new)
        
        # Rescale running sum and accumulator
        l_i = l_i * alpha
        acc = acc * alpha[:, None]
        
        # Compute softmax weights for this block
        p = tl.exp(scores - m_new[:, None])
        
        # Update running sum
        l_i += tl.sum(p, axis=1)
        
        # Load V block: [BLOCK_N, D]
        v_ptrs = V_ptr + pid_bh * stride_vh + \
                 offs_n[:, None] * stride_vn + \
                 offs_d[None, :] * stride_vd
        v = tl.load(v_ptrs, mask=mask_n[:, None], other=0.0)
        
        # Accumulate weighted values: [BLOCK_M, D]
        acc += tl.dot(p.to(v.dtype), v)
        
        # Update max
        m_i = m_new
    
    # Final normalization
    acc = acc / l_i[:, None]
    
    # Store output
    out_ptrs = Out_ptr + pid_bh * stride_oh + \
               offs_m[:, None] * stride_on + \
               offs_d[None, :] * stride_od
    tl.store(out_ptrs, acc, mask=mask_m[:, None])

The wrapper function to launch the kernel:

def flash_attention(Q, K, V):
    """
    FlashAttention forward pass.
    
    Args:
        Q, K, V: [batch, heads, seq_len, head_dim]
    
    Returns:
        Output: [batch, heads, seq_len, head_dim]
    """
    B, H, N, D = Q.shape
    
    # Ensure contiguous
    Q = Q.contiguous()
    K = K.contiguous()
    V = V.contiguous()
    
    # Allocate output
    Out = torch.empty_like(Q)
    
    # Block sizes (tune for your GPU)
    BLOCK_M = 64
    BLOCK_N = 64
    BLOCK_D = triton.next_power_of_2(D)
    
    # Grid: one program per (query_block, batch*head)
    num_m_blocks = triton.cdiv(N, BLOCK_M)
    grid = (num_m_blocks, B * H)
    
    # Scale factor
    scale = 1.0 / math.sqrt(D)
    
    # Launch kernel
    flash_attention_kernel[grid](
        Q, K, V, Out,
        Q.stride(0), Q.stride(1), Q.stride(2), Q.stride(3),
        K.stride(0), K.stride(1), K.stride(2), K.stride(3),
        V.stride(0), V.stride(1), V.stride(2), V.stride(3),
        Out.stride(0), Out.stride(1), Out.stride(2), Out.stride(3),
        N, D,
        scale,
        BLOCK_M, BLOCK_N, BLOCK_D,
    )
    
    return Out

What Makes This "Flash"?

Standard Attention

  - Write N×N scores to HBM
  - Read for softmax
  - Write N×N attention to HBM
  - Read for V multiplication
  - Memory: O(N²)

FlashAttention

  - Load Q block once
  - Stream K, V blocks
  - Accumulate in registers
  - Write output only
  - Memory: O(N)

The key operations happen in SRAM (fast) not HBM (slow):

Only the final output is written to HBM. The intermediate attention scores never leave the chip.

Part 7: Benchmarks

Let's see if all this complexity is worth it. Test setup:

def benchmark(fn, Q, K, V, name, warmup=10, iters=100):
    # Warmup
    for _ in range(warmup):
        _ = fn(Q, K, V)
    
    # Benchmark
    torch.cuda.synchronize()
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    
    start.record()
    for _ in range(iters):
        _ = fn(Q, K, V)
    end.record()
    torch.cuda.synchronize()
    
    ms = start.elapsed_time(end) / iters
    print(f"{name}: {ms:.2f} ms")
    return ms

FP16 Results

FP16 is the standard precision for transformer inference. Each element is 2 bytes, which halves memory and enables tensor core acceleration.

Seq Length   Attn Mem   PyTorch    PyTorch SDPA   Our FlashAttn   Speedup
512          0.06 GB    0.17 ms    0.03 ms        0.04 ms         4.7×
1024         0.25 GB    1.10 ms    0.10 ms        0.12 ms         8.9×
2048         1.00 GB    4.48 ms    0.37 ms        0.49 ms         9.2×
4096         4.00 GB    18.41 ms   1.54 ms        2.06 ms         8.9×
8192         16.00 GB   71.58 ms   6.31 ms        8.18 ms         8.8×

FP32 Results

FP32 doubles memory per element (4 bytes), so the attention score matrix is twice as large. This matters both for capacity (will it fit in VRAM?) and for bandwidth (twice as many bytes to move through the memory bus). Everything slows down:

Seq Length   Attn Mem   PyTorch     PyTorch SDPA   Our FlashAttn   Speedup
512          0.12 GB    0.58 ms     0.17 ms        0.11 ms         5.1×
1024         0.50 GB    2.34 ms     0.62 ms        0.39 ms         6.0×
2048         2.00 GB    9.28 ms     2.53 ms        1.51 ms         6.2×
4096         8.00 GB    36.77 ms    10.18 ms       5.90 ms         6.2×
8192         32.00 GB   147.64 ms   40.78 ms       23.57 ms        6.3×

FP16 vs FP32: What Changes?

Comparing the two precision modes reveals interesting patterns:

FP16

  - ~9× speedup over manual PyTorch
  - 4 GB attention scores at N=4096
  - Tensor core acceleration
  - Standard for inference

FP32

  - ~6× speedup over manual PyTorch
  - 8 GB attention scores at N=4096
  - No tensor core benefit
  - Higher numerical precision

Manual PyTorch is ~2× slower in FP32 than FP16—expected, since twice the bytes means twice the memory traffic. Our FlashAttention kernel sees a larger slowdown (~2.9× at N=4096) because it relies heavily on tl.dot which benefits from FP16 tensor cores. The net effect: the speedup over manual PyTorch drops from ~9× (FP16) to ~6× (FP32), but is still substantial.

An interesting detail: in FP32 our kernel consistently beats PyTorch's SDPA across all sequence lengths (e.g. 5.90 ms vs 10.18 ms at N=4096). This is because SDPA's FlashAttention backend requires FP16/BF16 inputs—when given FP32, it falls back to a less optimized path. Our Triton kernel has no such restriction and handles FP32 natively with the same tiled algorithm.

6–9×
speedup over manual PyTorch (FP32–FP16)

Memory Usage

Speed is only half the story. At N=2048 with FP16:

PyTorch Peak Memory

2.16 GB
Includes the 1.0 GB N×N attention matrix

FlashAttention Peak Memory

0.16 GB
93% reduction — N×N matrix never materialized

The speedup is consistent across sequence lengths: our FlashAttention kernel is ~9× faster than manual PyTorch in FP16, ~6× in FP32. PyTorch's built-in SDPA (which uses an optimized FlashAttention backend) is faster still—our kernel lands within ~1.3× of it in FP16, which is a solid result for an educational implementation.

Production FlashAttention

Our implementation is educational. PyTorch's built-in F.scaled_dot_product_attention uses an optimized FlashAttention backend and is ~1.3× faster than our kernel. The real flash-attention library is faster still due to additional optimizations: better block sizes, warp-level tiling, and causal masking fusion. Use those for production.

Conclusion

We went from a 12-line PyTorch function to a 100-line Triton kernel. Was it worth it? In FP16: a consistent ~9× speedup and 93% memory reduction. In FP32: still ~6× faster. Both within striking distance of PyTorch's production-optimized SDPA backend.

The key insights:

  1. Standard softmax requires the full row — this forces materializing N×N scores
  2. Online softmax computes incrementally — maintaining running max and sum with correction factors
  3. Fused kernels avoid HBM round-trips — compute scores, softmax, and output accumulation on-chip
  4. Memory complexity drops from O(N²) to O(N) — enabling long contexts

The online softmax trick is elegant: just two extra variables (m and l) and a correction factor (exp(m_old - m_new)). But it unlocks a fundamental change in how we compute attention.

This pattern—restructuring algorithms to fit memory hierarchies—is the future of performance optimization. As models get larger and sequences get longer, understanding where your data lives matters more than raw FLOPs.

Full Code

Complete implementation with tests: github.com/isztld/flash-attention-triton