Online Softmax Demystified: From PyTorch to FlashAttention in Triton
FlashAttention is 6–9× faster than naive PyTorch attention (depending on precision) and uses O(N) memory instead of O(N²). The secret? Online softmax. This post takes you from a clean PyTorch baseline through naive Triton to a fused FlashAttention kernel—step by step, with no magic operators.
The Problem: Why Standard Attention Doesn't Scale
In the previous post, we saw that attention creates an N×N matrix of scores. For sequence length 4096 with batch 32 and 32 heads, that's 32 × 32 × 4096 × 4096 ≈ 17 billion elements. In FP32 (4 bytes each), that's ~69 GB just for attention scores. In FP16, half that—but still enormous.
Throughout this post we benchmark in FP16 (torch.float16), which is the standard precision for transformer inference and the sweet spot for GPU tensor cores. FP16 halves the memory footprint compared to FP32 and typically doubles throughput on modern GPUs. Our benchmark config (B=4, H=32, N=4096) uses 4 GB for the attention scores in FP16 vs 8 GB in FP32—far less than the 69 GB example above, because we use a smaller batch size.
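These footprint numbers are easy to verify: the score tensor has B·H·N² elements. A quick back-of-envelope helper (not part of any attention code):

```python
def attn_score_bytes(B, H, N, bytes_per_el):
    """Size in bytes of the [B, H, N, N] attention-score tensor."""
    return B * H * N * N * bytes_per_el

# The ~69 GB example: batch 32, 32 heads, N=4096, FP32 (4 bytes)
print(attn_score_bytes(32, 32, 4096, 4) / 1e9)   # ~68.7 GB
# Benchmark config: batch 4, 32 heads, N=4096, FP16 (2 bytes)
print(attn_score_bytes(4, 32, 4096, 2) / 2**30)  # 4.0 GiB
```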
But memory isn't the only problem. Even when we have enough memory, standard attention is slow because of memory bandwidth. The attention scores must be:
- Written to HBM after computing QKᵀ
- Read from HBM for softmax
- Written to HBM after softmax
- Read from HBM for the final matmul with V
That's 4 trips to slow global memory for the N×N matrix. On an H100, HBM bandwidth is ~3 TB/s, but SRAM bandwidth is ~33 TB/s. We're leaving 10× performance on the table.
The FlashAttention idea: compute attention without ever materializing the N×N matrix in HBM. Keep everything in fast SRAM by processing tiles and accumulating results on the fly.
The blocker? Softmax needs the entire row before it can produce a single output. Or does it?
Part 1: PyTorch Baseline
Let's start with clean, readable PyTorch. This is our reference implementation—correct but slow.
import torch
import torch.nn.functional as F
import math
def attention_pytorch(Q, K, V):
"""
Standard scaled dot-product attention in PyTorch.
Args:
Q: [batch, heads, seq_len, head_dim]
K: [batch, heads, seq_len, head_dim]
V: [batch, heads, seq_len, head_dim]
Returns:
Output: [batch, heads, seq_len, head_dim]
"""
B, H, N, D = Q.shape
scale = 1.0 / math.sqrt(D)
# Step 1: Compute attention scores
# [B, H, N, D] @ [B, H, D, N] → [B, H, N, N]
scores = torch.matmul(Q, K.transpose(-2, -1)) * scale
# Step 2: Softmax (this is the bottleneck!)
# Needs full row to compute max and sum
attn = F.softmax(scores, dim=-1)
# Step 3: Weighted sum of values
# [B, H, N, N] @ [B, H, N, D] → [B, H, N, D]
output = torch.matmul(attn, V)
return output
Simple and correct. The scores tensor is shape [B, H, N, N]—that's our O(N²) memory. Let's profile where time actually goes:
# Profiling setup
B, H, N, D = 4, 32, 4096, 64
Q = torch.randn(B, H, N, D, device='cuda', dtype=torch.float16)
K = torch.randn(B, H, N, D, device='cuda', dtype=torch.float16)
V = torch.randn(B, H, N, D, device='cuda', dtype=torch.float16)
# Warmup
for _ in range(10):
_ = attention_pytorch(Q, K, V)
# Benchmark
torch.cuda.synchronize()
start = torch.cuda.Event(enable_timing=True)
end = torch.cuda.Event(enable_timing=True)
start.record()
for _ in range(100):
_ = attention_pytorch(Q, K, V)
end.record()
torch.cuda.synchronize()
print(f"PyTorch: {start.elapsed_time(end) / 100:.2f} ms")
On an NVIDIA RTX PRO 6000 Blackwell, this runs at roughly 18.4 ms per attention layer. We'll use this as our baseline.
Part 2: Triton 101 — A Gentle Introduction
Before we optimize attention, let's understand Triton. If you've never written a GPU kernel, this section is for you. If you're comfortable with Triton, skip to Part 3.
What is Triton?
Triton is a language for writing GPU kernels that sits between PyTorch and CUDA. It handles the painful parts (memory coalescing, shared memory management, warp synchronization) while giving you control over the algorithm.
PyTorch
High-level, automatic
Limited control
Can't fuse operations
Triton
Block-level programming
Control over tiling
Automatic memory management
The Execution Model
In Triton, you write a kernel that is executed by many programs in parallel. Each program processes a block of data. Think of it like this:
When you launch a kernel with grid=(4,), Triton runs 4 programs. Each program gets a unique program_id (0, 1, 2, 3) and processes its assigned block of data.
A Simple Example: Vector Addition
Let's start with the "hello world" of GPU programming:
import triton
import triton.language as tl
@triton.jit
def add_kernel(
x_ptr, # Pointer to first input
y_ptr, # Pointer to second input
out_ptr, # Pointer to output
N, # Total number of elements
BLOCK: tl.constexpr, # Block size (compile-time constant)
):
# Which block am I?
pid = tl.program_id(0)
# Compute the range of indices this block handles
# If pid=0 and BLOCK=256, offsets = [0, 1, 2, ..., 255]
# If pid=1 and BLOCK=256, offsets = [256, 257, ..., 511]
offsets = pid * BLOCK + tl.arange(0, BLOCK)
# Mask: don't read/write out of bounds
mask = offsets < N
# Load data from global memory
x = tl.load(x_ptr + offsets, mask=mask)
y = tl.load(y_ptr + offsets, mask=mask)
# Compute
out = x + y
# Store result to global memory
tl.store(out_ptr + offsets, out, mask=mask)
Key concepts:
- `@triton.jit` — compiles the function to GPU code
- `tl.program_id(0)` — which program am I? (like `blockIdx.x` in CUDA)
- `tl.arange(0, BLOCK)` — create indices [0, 1, ..., BLOCK-1]
- `tl.load` / `tl.store` — read/write global memory
- `mask` — prevent out-of-bounds access
To launch the kernel:
def add(x, y):
out = torch.empty_like(x)
N = x.numel()
BLOCK = 256
# Number of blocks needed
grid = ((N + BLOCK - 1) // BLOCK,)
# Launch kernel
add_kernel[grid](x, y, out, N, BLOCK)
return out
Triton kernels receive pointers to data, not tensors. x_ptr + offsets computes memory addresses. This is lower-level than PyTorch but gives us control over memory access patterns.
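If you want to internalize the execution model without a GPU, the launch above can be emulated in plain Python (toy names, with lists standing in for device memory): each simulated program computes its offsets, applies the mask, and writes only its in-bounds slice.

```python
def add_program(pid, BLOCK, N, x, y, out):
    # Mirror of add_kernel: compute this program's offsets, mask, store
    for i in range(BLOCK):
        off = pid * BLOCK + i
        if off < N:              # the mask: skip out-of-bounds positions
            out[off] = x[off] + y[off]

N, BLOCK = 10, 4
x, y = list(range(N)), [10] * N
out = [None] * N
grid = (N + BLOCK - 1) // BLOCK  # 3 programs cover 10 elements
for pid in range(grid):          # the GPU would run these in parallel
    add_program(pid, BLOCK, N, x, y, out)
print(out)                       # [10, 11, ..., 19]
```

The last block of the last program runs off the end of the data (offsets 8 through 11), which is exactly what the mask guards against.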
Part 3: Naive Attention in Triton
Now let's write attention in Triton—the same algorithm as PyTorch, just on GPU. This won't be faster (it might be slower), but it teaches us the mechanics.
We'll process one query row at a time. Each program handles one (batch, head, query) combination:
@triton.jit
def attention_naive_kernel(
Q_ptr, K_ptr, V_ptr, Out_ptr,
stride_qb, stride_qh, stride_qn, stride_qd, # Q strides
stride_kb, stride_kh, stride_kn, stride_kd, # K strides
stride_vb, stride_vh, stride_vn, stride_vd, # V strides
stride_ob, stride_oh, stride_on, stride_od, # Out strides
N, D,
scale,
BLOCK_D: tl.constexpr,
):
# Which query row am I processing?
pid_b = tl.program_id(0) # batch index
pid_h = tl.program_id(1) # head index
pid_n = tl.program_id(2) # query row index
# Pointer to the start of my Q row: Q[b, h, n, :]
q_offset = pid_b * stride_qb + pid_h * stride_qh + pid_n * stride_qn
# Load entire query vector (size D)
d_range = tl.arange(0, BLOCK_D)
q = tl.load(Q_ptr + q_offset + d_range * stride_qd, mask=d_range < D)
# Compute attention scores with ALL keys
# This is where we'd like to tile, but can't yet (softmax needs full row)
m_i = -float('inf') # running max for numerical stability
l_i = 0.0 # running sum of exp(scores - max)
acc = tl.zeros([BLOCK_D], dtype=tl.float32) # accumulator for output
# Loop over all key positions
for j in range(N):
# Load K[b, h, j, :]
k_offset = pid_b * stride_kb + pid_h * stride_kh + j * stride_kn
k = tl.load(K_ptr + k_offset + d_range * stride_kd, mask=d_range < D)
# Compute dot product: q · k
score = tl.sum(q * k) * scale
# Update running max
m_new = tl.maximum(m_i, score)
# Rescale previous accumulator and sum
alpha = tl.exp(m_i - m_new)
l_i = l_i * alpha + tl.exp(score - m_new)
acc = acc * alpha
# Load V[b, h, j, :] and accumulate
v_offset = pid_b * stride_vb + pid_h * stride_vh + j * stride_vn
v = tl.load(V_ptr + v_offset + d_range * stride_vd, mask=d_range < D)
        acc += tl.exp(score - m_new) * v
        # Carry the new max into the next iteration (without this,
        # alpha is computed against a stale max from the second key on)
        m_i = m_new
# Normalize by sum of exponentials
acc = acc / l_i
# Store result
out_offset = pid_b * stride_ob + pid_h * stride_oh + pid_n * stride_on
tl.store(Out_ptr + out_offset + d_range * stride_od, acc, mask=d_range < D)
This kernel processes one query at a time, with a serial loop issuing one dot product per key. It's pedagogical, not practical. We'll fix this in Part 6.
Notice something sneaky? We're already using the online softmax idea—maintaining a running max (m_i) and sum (l_i), rescaling as we go. Let's understand why this works.
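To see why this works, here is the same per-key update in plain Python for a single query row (a CPU sketch with lists in place of pointers), which can be checked against the direct two-pass definition:

```python
import math

def online_attention_row(q, K, V, scale):
    """softmax(q @ K^T * scale) @ V for one query row, one key at a time."""
    m_i, l_i = float('-inf'), 0.0
    acc = [0.0] * len(V[0])
    for k_row, v_row in zip(K, V):
        score = sum(qi * ki for qi, ki in zip(q, k_row)) * scale
        m_new = max(m_i, score)
        alpha = math.exp(m_i - m_new)       # rescale all previous work
        p = math.exp(score - m_new)         # weight for this key
        l_i = l_i * alpha + p
        acc = [a * alpha + p * v for a, v in zip(acc, v_row)]
        m_i = m_new                         # carry the new max forward
    return [a / l_i for a in acc]

q = [1.0, -0.5]
K = [[0.2, 0.4], [-1.0, 0.3], [0.5, 0.5]]
V = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
out = online_attention_row(q, K, V, scale=1.0 / math.sqrt(len(q)))
```

Running this against a standard "compute all scores, then softmax, then weight V" reference gives identical results to floating-point precision.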
Part 4: Why Standard Softmax Blocks Tiling
Before we can understand the solution, we need to deeply understand the problem. Let's look at softmax carefully.
Anatomy of Softmax
For a row x = [x₁, x₂, ..., xₙ], softmax is:
softmax(xᵢ) = exp(xᵢ) / Σⱼ exp(xⱼ)
Look at this formula. The numerator is easy—it only depends on xᵢ. We can compute it the moment we see xᵢ.
The denominator is the problem. It's a sum over all elements. We cannot compute softmax(x₁) until we've seen x₂, x₃, ..., xₙ.
Every single output depends on every single input. To compute any softmax value, you need to have seen all N values first. This global dependency is what blocks tiling.
The Numerically Stable Version
In practice, we subtract the max for numerical stability (prevents overflow):
softmax(xᵢ) = exp(xᵢ - m) / Σⱼ exp(xⱼ - m)
This doesn't help our problem—it makes it worse. Now we have two global dependencies:
Dependency 1: max
m = max(x₁, x₂, ..., xₙ)
Must see ALL values to know the maximum.
Dependency 2: sum
l = Σⱼ exp(xⱼ - m)
Must see ALL values (and know m) to compute sum.
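As plain code, the stable version needs two complete passes over the row before it can emit a single output, which is exactly the dependency we want to break:

```python
import math

def softmax_two_pass(x):
    m = max(x)                            # pass 1: global max over ALL values
    exps = [math.exp(v - m) for v in x]
    l = sum(exps)                         # pass 2: global sum over ALL values
    return [e / l for e in exps]

probs = softmax_two_pass([2.0, 1.0, 4.0, 1.5, 3.0])
```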
Cost Analysis: Why This Kills Performance
Let's trace what happens in standard attention with N tokens:
Count the HBM accesses for the N×N attention matrix:
| Operation | HBM Reads | HBM Writes |
|---|---|---|
| Q @ K.T → scores | — | N² |
| max(scores) | N² | — |
| exp(scores - m) | N² | N² |
| sum(exp_scores) | N² | — |
| exp_scores / l | N² | N² |
| attn @ V | N² | — |
| Total | 5N² | 3N² |
That's 8N² memory operations for the attention scores alone. For N=4096, that's 134 million accesses to slow HBM, just for softmax bookkeeping.
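Totaling the table is one line of arithmetic (the counts are per attention matrix, i.e. per batch-head pair):

```python
N = 4096
reads, writes = 5 * N * N, 3 * N * N
total = reads + writes
print(f"{total:,} accesses")     # 134,217,728, i.e. ~134 million
# In FP16 (2 bytes each) that is 256 MiB of HBM traffic per
# (batch, head) pair, just for softmax bookkeeping
print(total * 2 / 2**20, "MiB")  # 256.0
```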
The Tiling Dream (That Doesn't Work)
We'd love to process attention in tiles: load a small block of Q and K, compute partial scores, immediately multiply by V, accumulate the result. This would keep data in fast SRAM instead of slow HBM.
But the denominator blocks us:
Say we process the first tile and compute scores [s₁, s₂] for tokens 1-2. We want to compute:
softmax(s₁) = exp(s₁) / (exp(s₁) + exp(s₂) + ... + exp(sₙ))
But we don't know s₃, s₄, ..., sₙ yet! We can't compute the denominator. So we're stuck—we must either:
- Store all scores and come back later (defeats tiling)
- Find a way to update the denominator incrementally (online softmax!)
What We Need
The question becomes: can we maintain a running estimate of the denominator that we correct as we see more values?
If yes, we could:
- Process tile 1: compute partial scores, partial denominator, partial output
- Process tile 2: update denominator, correct our previous work, accumulate more output
- Continue until done
This is exactly what online softmax does. The key insight is that when our estimate of max changes, we can mathematically correct our running sum with a simple multiplication.
Part 5: The Online Softmax Algorithm
Now that we understand the problem—the denominator requires seeing all values—let's see the elegant solution.
The Key Insight
What if we could update our answer as we process each tile? When we see a new value larger than our current max, we rescale everything we've computed so far.
Define for the first k elements:
mₖ = max(x₁, x₂, ..., xₖ) ← our running max estimate
lₖ = Σⱼ₌₁ᵏ exp(xⱼ - mₖ) ← this is our running denominator estimate
The final softmax is: softmax(xᵢ) = exp(xᵢ - mₙ) / lₙ
Numerical Example: Watch It Happen
Let's trace through a concrete example: computing softmax over 5 values with the online algorithm, one value at a time.
Notice how the online version maintains a running estimate that gets corrected whenever we see a larger value. The correction factor exp(m_old - m_new) rescales our previous work. At the end, both methods produce identical results.
Detailed Numerical Trace
Here's the exact arithmetic for input x = [2.0, 1.0, 4.0, 1.5, 3.0]:
| Step | xₖ | m_new | Correction | l_new |
|---|---|---|---|---|
| init | — | -∞ | — | 0 |
| k=1 | 2.0 | 2.0 | exp(-∞ - 2) = 0 | 0 × 0 + exp(0) = 1.0 |
| k=2 | 1.0 | 2.0 | exp(2 - 2) = 1.0 | 1.0 × 1 + exp(-1) = 1.368 |
| k=3 | 4.0 ★ | 4.0 ↑ | exp(2 - 4) = 0.135 | 1.368 × 0.135 + exp(0) = 1.185 |
| k=4 | 1.5 | 4.0 | exp(4 - 4) = 1.0 | 1.185 × 1 + exp(-2.5) = 1.267 |
| k=5 | 3.0 | 4.0 | exp(4 - 4) = 1.0 | 1.267 × 1 + exp(-1) = 1.635 |
softmax(xᵢ) = exp(xᵢ - 4.0) / 1.635
The key moment is step k=3: we encounter 4.0, which is larger than our previous max of 2.0. The correction factor exp(2 - 4) = 0.135 scales down all our previous work. This is the "online" magic—we fix our running estimate without going back to reprocess earlier elements.
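The whole trace is a five-iteration loop; reproducing it in Python also confirms the final stats match a direct two-pass computation:

```python
import math

x = [2.0, 1.0, 4.0, 1.5, 3.0]
m, l = float('-inf'), 0.0
for xk in x:
    m_new = max(m, xk)
    # Rescale the old sum, then add this element's contribution
    l = l * math.exp(m - m_new) + math.exp(xk - m_new)
    m = m_new
print(f"m = {m}, l = {l:.3f}")   # m = 4.0, l = 1.635
```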
The Recurrence Relations
The recurrence relations when adding element xₖ₊₁:
mₖ₊₁ = max(mₖ, xₖ₊₁)
lₖ₊₁ = lₖ · exp(mₖ - mₖ₊₁) + exp(xₖ₊₁ - mₖ₊₁)
The magic: exp(mₖ - mₖ₊₁) is the correction factor. If the new max is larger, this scales down our previous sum. If the new max equals the old max, this equals 1 (no change).
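A useful consequence: the same correction merges two partial results in one step. If chunk A yields stats (m_a, l_a) and chunk B yields (m_b, l_b), the combined stats are m = max(m_a, m_b) and l = l_a·exp(m_a - m) + l_b·exp(m_b - m). This is what lets whole blocks, not just single elements, be folded in. A quick check (helper names are mine, not standard API):

```python
import math

def stats(chunk):
    """Running-softmax stats (max, sum of exp(x - max)) for one chunk."""
    m = max(chunk)
    return m, sum(math.exp(v - m) for v in chunk)

def combine(m_a, l_a, m_b, l_b):
    """Merge two chunks' stats as if they were processed as one chunk."""
    m = max(m_a, m_b)
    return m, l_a * math.exp(m_a - m) + l_b * math.exp(m_b - m)

x = [2.0, 1.0, 4.0, 1.5, 3.0]
m_merged, l_merged = combine(*stats(x[:2]), *stats(x[2:]))
```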
Standalone Online Softmax in Triton
Let's implement just softmax first, to verify the algorithm:
@triton.jit
def online_softmax_kernel(
X_ptr, Out_ptr,
N,
BLOCK: tl.constexpr,
):
# One program per row
row_idx = tl.program_id(0)
# Initialize running statistics
m_i = -float('inf') # running max
l_i = 0.0 # running sum of exp(x - m)
# First pass: compute m and l online
for block_start in range(0, N, BLOCK):
cols = block_start + tl.arange(0, BLOCK)
mask = cols < N
# Load block of values
x = tl.load(X_ptr + row_idx * N + cols, mask=mask, other=-float('inf'))
# Block-level max
m_block = tl.max(x, axis=0)
# Update global max
m_new = tl.maximum(m_i, m_block)
# Rescale previous sum
l_i = l_i * tl.exp(m_i - m_new)
# Add contributions from this block
l_i += tl.sum(tl.exp(x - m_new), axis=0)
# Update max
m_i = m_new
# Second pass: compute final softmax values
for block_start in range(0, N, BLOCK):
cols = block_start + tl.arange(0, BLOCK)
mask = cols < N
x = tl.load(X_ptr + row_idx * N + cols, mask=mask, other=-float('inf'))
# Final softmax: exp(x - m) / l
out = tl.exp(x - m_i) / l_i
tl.store(Out_ptr + row_idx * N + cols, out, mask=mask)
This version still has two passes, but they're over blocks, not individual elements. More importantly, we never write the intermediate values to global memory. The stats (m_i, l_i) stay in registers.
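The kernel's structure translates almost line for line to plain Python (BLOCK-sized slices standing in for the tl loads):

```python
import math

def online_softmax(x, BLOCK=2):
    m, l = float('-inf'), 0.0
    # Pass 1: fold in one block of values at a time
    for start in range(0, len(x), BLOCK):
        block = x[start:start + BLOCK]
        m_new = max(m, max(block))   # block max folded into global max
        l = l * math.exp(m - m_new) + sum(math.exp(v - m_new) for v in block)
        m = m_new
    # Pass 2: emit outputs with the final stats
    return [math.exp(v - m) / l for v in x]

probs = online_softmax([2.0, 1.0, 4.0, 1.5, 3.0])
```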
Why This Matters for Attention
In attention, we don't just want softmax—we want softmax(QKᵀ) @ V. The online approach lets us:
- Process a block of K and V
- Compute partial attention scores (Q @ K_blockᵀ)
- Update running max and sum
- Accumulate weighted V_block into output
- Repeat for next block
The N×N scores are never fully materialized. We compute and consume them block by block.
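Those steps for a single query row, as a plain-Python sketch (the Triton kernel in Part 6 does the same thing for a whole block of queries at once):

```python
import math

def flash_attention_row(q, K, V, scale, BLOCK=2):
    D = len(V[0])
    m, l = float('-inf'), 0.0
    acc = [0.0] * D
    for start in range(0, len(K), BLOCK):
        K_blk, V_blk = K[start:start + BLOCK], V[start:start + BLOCK]
        # Partial scores for this K block
        s = [sum(qi * ki for qi, ki in zip(q, k)) * scale for k in K_blk]
        # Online softmax update
        m_new = max(m, max(s))
        alpha = math.exp(m - m_new)              # rescales previous work
        p = [math.exp(si - m_new) for si in s]   # this block's weights
        l = l * alpha + sum(p)
        # Rescale accumulator and add this block's weighted values
        acc = [a * alpha + sum(pj * v[d] for pj, v in zip(p, V_blk))
               for d, a in enumerate(acc)]
        m = m_new
    return [a / l for a in acc]

q = [1.0, 0.5]
K = [[0.3, -0.2], [1.0, 0.1], [-0.5, 0.7], [0.2, 0.2]]
V = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0], [0.5, 0.5]]
out = flash_attention_row(q, K, V, scale=1.0 / math.sqrt(len(q)))
```

No 4×4 score matrix is ever stored; only the running (m, l, acc) triple survives between blocks.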
Part 6: Fused FlashAttention Kernel
Now we put it all together. This is a simplified but functional FlashAttention kernel:
@triton.jit
def flash_attention_kernel(
Q_ptr, K_ptr, V_ptr, Out_ptr,
stride_qb, stride_qh, stride_qn, stride_qd,
stride_kb, stride_kh, stride_kn, stride_kd,
stride_vb, stride_vh, stride_vn, stride_vd,
stride_ob, stride_oh, stride_on, stride_od,
N, D,
scale,
BLOCK_M: tl.constexpr, # Block size for queries
BLOCK_N: tl.constexpr, # Block size for keys/values
BLOCK_D: tl.constexpr, # Head dimension (must cover full D)
):
# Grid: (num_q_blocks, batch * heads)
pid_m = tl.program_id(0) # which query block
pid_bh = tl.program_id(1) # which batch*head
    # No separate batch/head decode needed: the wrapper makes the tensors
    # contiguous, so stride_b == H * stride_h and pid_bh * stride_h lands
    # on the right (batch, head) slice
# Offsets for the query block this program handles
offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
offs_d = tl.arange(0, BLOCK_D)
# Pointers to Q block: [BLOCK_M, D]
q_ptrs = Q_ptr + pid_bh * stride_qh + \
offs_m[:, None] * stride_qn + \
offs_d[None, :] * stride_qd
# Load Q block (stays in SRAM for entire kernel)
mask_m = offs_m < N
q = tl.load(q_ptrs, mask=mask_m[:, None], other=0.0)
# Initialize accumulators (in registers)
m_i = tl.full([BLOCK_M], -float('inf'), dtype=tl.float32) # max per query
l_i = tl.zeros([BLOCK_M], dtype=tl.float32) # sum per query
acc = tl.zeros([BLOCK_M, BLOCK_D], dtype=tl.float32) # output accumulator
# Loop over key/value blocks
for block_start in range(0, N, BLOCK_N):
offs_n = block_start + tl.arange(0, BLOCK_N)
mask_n = offs_n < N
# Load K block: [BLOCK_N, D]
k_ptrs = K_ptr + pid_bh * stride_kh + \
offs_n[:, None] * stride_kn + \
offs_d[None, :] * stride_kd
k = tl.load(k_ptrs, mask=mask_n[:, None], other=0.0)
# Compute attention scores: [BLOCK_M, BLOCK_N]
scores = tl.dot(q, tl.trans(k)) * scale
# Mask out invalid positions
scores = tl.where(
mask_m[:, None] & mask_n[None, :],
scores,
-float('inf')
)
# === Online softmax update ===
# Max of this block (per query row)
m_block = tl.max(scores, axis=1)
# New global max
m_new = tl.maximum(m_i, m_block)
# Correction factors
alpha = tl.exp(m_i - m_new) # rescale old accumulator
beta = tl.exp(m_block - m_new) # scale for this block
# Rescale running sum and accumulator
l_i = l_i * alpha
acc = acc * alpha[:, None]
# Compute softmax weights for this block
p = tl.exp(scores - m_new[:, None])
# Update running sum
l_i += tl.sum(p, axis=1)
# Load V block: [BLOCK_N, D]
v_ptrs = V_ptr + pid_bh * stride_vh + \
offs_n[:, None] * stride_vn + \
offs_d[None, :] * stride_vd
v = tl.load(v_ptrs, mask=mask_n[:, None], other=0.0)
# Accumulate weighted values: [BLOCK_M, D]
acc += tl.dot(p.to(v.dtype), v)
# Update max
m_i = m_new
# Final normalization
acc = acc / l_i[:, None]
# Store output
out_ptrs = Out_ptr + pid_bh * stride_oh + \
offs_m[:, None] * stride_on + \
offs_d[None, :] * stride_od
tl.store(out_ptrs, acc, mask=mask_m[:, None])
The wrapper function to launch the kernel:
def flash_attention(Q, K, V):
"""
FlashAttention forward pass.
Args:
Q, K, V: [batch, heads, seq_len, head_dim]
Returns:
Output: [batch, heads, seq_len, head_dim]
"""
B, H, N, D = Q.shape
# Ensure contiguous
Q = Q.contiguous()
K = K.contiguous()
V = V.contiguous()
# Allocate output
Out = torch.empty_like(Q)
# Block sizes (tune for your GPU)
BLOCK_M = 64
BLOCK_N = 64
BLOCK_D = triton.next_power_of_2(D)
# Grid: one program per (query_block, batch*head)
num_m_blocks = triton.cdiv(N, BLOCK_M)
grid = (num_m_blocks, B * H)
# Scale factor
scale = 1.0 / math.sqrt(D)
# Launch kernel
flash_attention_kernel[grid](
Q, K, V, Out,
Q.stride(0), Q.stride(1), Q.stride(2), Q.stride(3),
K.stride(0), K.stride(1), K.stride(2), K.stride(3),
V.stride(0), V.stride(1), V.stride(2), V.stride(3),
Out.stride(0), Out.stride(1), Out.stride(2), Out.stride(3),
N, D,
scale,
BLOCK_M, BLOCK_N, BLOCK_D,
)
return Out
What Makes This "Flash"?
Standard Attention
Write N×N scores to HBM
Read for softmax
Write N×N attention to HBM
Read for V multiplication
Memory: O(N²)
FlashAttention
Load Q block once
Stream K, V blocks
Accumulate in registers
Write output only
Memory: O(N)
The key operations happen in SRAM (fast) not HBM (slow):
- `scores = tl.dot(q, tl.trans(k))` — computed on-chip
- `p = tl.exp(scores - m_new)` — computed on-chip
- `acc += tl.dot(p, v)` — accumulated on-chip
Only the final output is written to HBM. The intermediate attention scores never leave the chip.
Part 7: Benchmarks
Let's see if all this complexity is worth it. Test setup:
- GPU: NVIDIA RTX PRO 6000 Blackwell Workstation Edition (96GB)
- Config: B=4, H=32, D=64
- PyTorch: 2.8.0+cu129, Triton: 3.4.0
def benchmark(fn, Q, K, V, name, warmup=10, iters=100):
# Warmup
for _ in range(warmup):
_ = fn(Q, K, V)
# Benchmark
torch.cuda.synchronize()
start = torch.cuda.Event(enable_timing=True)
end = torch.cuda.Event(enable_timing=True)
start.record()
for _ in range(iters):
_ = fn(Q, K, V)
end.record()
torch.cuda.synchronize()
ms = start.elapsed_time(end) / iters
print(f"{name}: {ms:.2f} ms")
return ms
FP16 Results
FP16 is the standard precision for transformer inference. Each element is 2 bytes, which halves memory and enables tensor core acceleration.
| Seq Length | Attn Mem | PyTorch | PyTorch SDPA | Our FlashAttn | Speedup |
|---|---|---|---|---|---|
| 512 | 0.06 GB | 0.17 ms | 0.03 ms | 0.04 ms | 4.7× |
| 1024 | 0.25 GB | 1.10 ms | 0.10 ms | 0.12 ms | 8.9× |
| 2048 | 1.00 GB | 4.48 ms | 0.37 ms | 0.49 ms | 9.2× |
| 4096 | 4.00 GB | 18.41 ms | 1.54 ms | 2.06 ms | 8.9× |
| 8192 | 16.00 GB | 71.58 ms | 6.31 ms | 8.18 ms | 8.8× |
FP32 Results
FP32 doubles memory per element (4 bytes), so the attention score matrix is twice as large. This matters both for capacity (will it fit in VRAM?) and for bandwidth (twice as many bytes to move through the memory bus). Everything slows down:
| Seq Length | Attn Mem | PyTorch | PyTorch SDPA | Our FlashAttn | Speedup |
|---|---|---|---|---|---|
| 512 | 0.12 GB | 0.58 ms | 0.17 ms | 0.11 ms | 5.1× |
| 1024 | 0.50 GB | 2.34 ms | 0.62 ms | 0.39 ms | 6.0× |
| 2048 | 2.00 GB | 9.28 ms | 2.53 ms | 1.51 ms | 6.2× |
| 4096 | 8.00 GB | 36.77 ms | 10.18 ms | 5.90 ms | 6.2× |
| 8192 | 32.00 GB | 147.64 ms | 40.78 ms | 23.57 ms | 6.3× |
FP16 vs FP32: What Changes?
Comparing the two precision modes reveals interesting patterns:
FP16
~9× speedup over manual PyTorch
4 GB attention scores at N=4096
Tensor core acceleration
Standard for inference
FP32
~6× speedup over manual PyTorch
8 GB attention scores at N=4096
No tensor core benefit
Higher numerical precision
Manual PyTorch is ~2× slower in FP32 than FP16—expected, since twice the bytes means twice the memory traffic. Our FlashAttention kernel sees a larger slowdown (~2.9× at N=4096) because it relies heavily on tl.dot which benefits from FP16 tensor cores. The net effect: the speedup over manual PyTorch drops from ~9× (FP16) to ~6× (FP32), but is still substantial.
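The quoted ratios follow directly from the N=4096 rows of the two tables:

```python
fp16_pytorch, fp16_flash = 18.41, 2.06   # ms, from the FP16 table
fp32_pytorch, fp32_flash = 36.77, 5.90   # ms, from the FP32 table

print(f"FP16 speedup: {fp16_pytorch / fp16_flash:.1f}x")        # 8.9x
print(f"FP32 speedup: {fp32_pytorch / fp32_flash:.1f}x")        # 6.2x
print(f"FP32/FP16 kernel cost: {fp32_flash / fp16_flash:.1f}x") # 2.9x
```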
An interesting detail: in FP32 our kernel consistently beats PyTorch's SDPA across all sequence lengths (e.g. 5.90 ms vs 10.18 ms at N=4096). This is because SDPA's FlashAttention backend requires FP16/BF16 inputs—when given FP32, it falls back to a less optimized path. Our Triton kernel has no such restriction and handles FP32 natively with the same tiled algorithm.
Memory Usage
Speed is only half the story. At N=2048 with FP16:
PyTorch Peak Memory
2.16 GB
Includes the 1.0 GB N×N attention matrix
FlashAttention Peak Memory
0.16 GB
93% reduction — N×N matrix never materialized
The speedup is consistent across sequence lengths: our FlashAttention kernel is ~9× faster than manual PyTorch in FP16, ~6× in FP32. PyTorch's built-in SDPA (which uses an optimized FlashAttention backend) is faster still—our kernel lands within ~1.3× of it in FP16, which is a solid result for an educational implementation.
Our implementation is educational. For production, use PyTorch's built-in F.scaled_dot_product_attention or the real flash-attention library, which add optimizations we skipped: better block sizes, warp-level tiling, and causal masking fusion.
Conclusion
We went from a 12-line PyTorch function to a 100-line Triton kernel. Was it worth it? In FP16: a consistent ~9× speedup and 93% memory reduction. In FP32: still ~6× faster. Both within striking distance of PyTorch's production-optimized SDPA backend.
The key insights:
- Standard softmax requires the full row — this forces materializing N×N scores
- Online softmax computes incrementally — maintaining running max and sum with correction factors
- Fused kernels avoid HBM round-trips — compute scores, softmax, and output accumulation on-chip
- Memory complexity drops from O(N²) to O(N) — enabling long contexts
The online softmax trick is elegant: just two extra variables (m and l) and a correction factor (exp(m_old - m_new)). But it unlocks a fundamental change in how we compute attention.
This pattern—restructuring algorithms to fit memory hierarchies—is the future of performance optimization. As models get larger and sequences get longer, understanding where your data lives matters more than raw FLOPs.
Complete implementation with tests: github.com/isztld/flash-attention-triton