<aside> 📢
<aside> 🎯
</aside>
<aside>
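Standard self-attention is expensive because its time complexity and, in typical implementations, its memory footprint scale quadratically with the sequence length N. A standard implementation materializes the full N×N score and probability matrices in GPU High Bandwidth Memory (HBM). Reading and writing those matrices is slow because HBM bandwidth is much lower than the bandwidth of on-chip SRAM. FlashAttention (Dao et al., 2022, arXiv:2205.14135) restructures self-attention to be IO-aware. It tiles Q, K, and V into blocks that fit in fast on-chip SRAM, computes attention block by block, and never writes the full N×N matrix to HBM, which reduces the extra memory from O(N²) to O(N).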
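The short Python sketch below makes the quadratic memory term concrete. It computes how many bytes a single N×N attention matrix occupies for one head, assuming 2-byte (fp16/bf16) elements. It ignores batch size and head count, and it ignores that a standard implementation may keep both the score matrix and the softmax output, so real footprints are larger.

```python
def attention_matrix_bytes(seq_len: int, bytes_per_elem: int = 2) -> int:
    """Bytes for one N x N attention matrix (one head, one batch element)."""
    return seq_len * seq_len * bytes_per_elem


# Footprint grows with N**2: each 4x increase in N costs 16x the memory.
for n in (1024, 4096, 16384):
    mib = attention_matrix_bytes(n) / 2**20
    print(f"N = {n:>5}: {mib:,.0f} MiB per head at 2 bytes/element")
```

At 2 bytes per element this gives 2 MiB at N = 1024, 32 MiB at N = 4096, and 512 MiB at N = 16384, per head and per sequence.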
FlashAttention does not speed things up by reducing FLOPs. It performs a comparable number of FLOPs in the forward pass, and more in the backward pass, where it recomputes attention instead of reading a stored copy. Its speedups come from reducing HBM traffic:
⚠️ Limitations & Caveats
$$
\begin{aligned} & \text{[Standard Attention IO Complexity]: } \Theta(Nd + N^2) \text{ HBM accesses} \\ & \text{[FlashAttention IO Complexity]: } \Theta\left(\frac{N^2 d^2}{M}\right) \text{ HBM accesses, where } d \text{ is the head dimension and } M \text{ is the SRAM size } (d \le M \le Nd) \\ & \text{[Online Softmax (Maximum)]: } m_{new} = \max(m_{old}, m_{block}) \\ & \text{[Online Softmax (Exp Sum)]: } l_{new} = l_{old} \cdot e^{m_{old} - m_{new}} + l_{block} \cdot e^{m_{block} - m_{new}} \\ & \text{[Rescaled Output]: } O_{new} = \frac{l_{old} \cdot e^{m_{old} - m_{new}} \cdot O_{old} + l_{block} \cdot e^{m_{block} - m_{new}} \cdot O_{block}}{l_{new}} \\ & \text{[FlashAttention-3 FP8 Scaling]: uses per-block scaling factors } (s_Q, s_K, s_V) \text{ (block quantization) for low-precision computation} \end{aligned}
$$
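Dao et al. (2022) bound the HBM accesses of standard attention by Θ(Nd + N²) and those of FlashAttention by Θ(N²d²/M), where d is the head dimension and M is the SRAM size in elements, with d ≤ M ≤ Nd. The sketch below plugs numbers into those two expressions. It drops the hidden constants, so it only shows the trend. The SRAM size `m` is a hypothetical value, not a measured hardware figure.

```python
def standard_hbm_accesses(n: int, d: int) -> float:
    # Theta(N*d + N^2): Q, K, V, O plus the N x N intermediate matrices.
    return n * d + n * n


def flash_hbm_accesses(n: int, d: int, m: int) -> float:
    # Theta(N^2 * d^2 / M), stated for d <= M <= N*d (constants dropped).
    return n * n * d * d / m


n, d = 4096, 64
m = 100_000  # hypothetical SRAM capacity, in elements (assumption)
assert d <= m <= n * d
ratio = standard_hbm_accesses(n, d) / flash_hbm_accesses(n, d, m)
print(f"standard / FlashAttention HBM accesses ~ {ratio:.1f}x (constants ignored)")
```

The gain grows with M relative to d²: more SRAM means larger tiles, so each element of K and V is reloaded from HBM fewer times.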
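Standard (numerically stable) softmax needs two row-wide quantities before it can produce any probability. The first is the row maximum m, which is subtracted from every score to keep the exponentials from overflowing. The second is the sum of exponentials l, which normalizes the row. This creates a global barrier: the kernel must read the entire row of query-key dot products before it can normalize.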
FlashAttention circumvents this barrier using online softmax (Milakov & Gimelshein, 2018), which maintains running statistics (m and l) as it processes the row block by block. This enables numerically stable, single-pass computation without ever writing the N×N attention matrix to HBM.
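Below is a minimal plain-Python sketch of the online normalizer idea, without a GPU and without tiling Q, K, or V. It streams one row of scores in blocks, keeps only the running maximum m and the running sum l, and checks that the result matches the usual two-pass computation. The block size and the scores are arbitrary illustration values.

```python
import math


def two_pass_stats(scores):
    """Reference: one pass finds the max, a second pass sums exp(x - max)."""
    m = max(scores)
    return m, sum(math.exp(x - m) for x in scores)


def online_stats(scores, block_size=3):
    """Single streaming pass over blocks, keeping only (m, l)."""
    m, l = float("-inf"), 0.0
    for start in range(0, len(scores), block_size):
        block = scores[start:start + block_size]
        m_block = max(block)
        l_block = sum(math.exp(x - m_block) for x in block)
        m_new = max(m, m_block)
        # Rescale the old running sum and the block sum to the shared new max.
        l = l * math.exp(m - m_new) + l_block * math.exp(m_block - m_new)
        m = m_new
    return m, l


scores = [0.5, 2.0, -1.0, 3.5, 0.0, 1.25, 10.0]
m_ref, l_ref = two_pass_stats(scores)
m_on, l_on = online_stats(scores)
assert m_on == m_ref and math.isclose(l_on, l_ref)
probs = [math.exp(x - m_on) / l_on for x in scores]
```

On the first block the running maximum is -inf, so `math.exp(m - m_new)` evaluates to 0.0 and the empty running sum contributes nothing.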
When a new block is processed, FlashAttention compares the block's local maximum with the running maximum. It then rescales the previously accumulated sum and output by the exponential of the difference between the old and new maxima. This correction makes the final result mathematically equivalent to standard softmax attention, up to floating-point rounding, while allowing block-by-block computation.
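The same bookkeeping extends to the output. The sketch below computes attention for a single query row in plain Python. It iterates over key/value blocks and combines block-local normalized outputs using the rescaled-output rule from the math block above. It is a readability-oriented model of the algorithm, not a kernel: it has no SRAM and no parallelism. The helper names are hypothetical.

```python
import math


def dot(a, b):
    return sum(x * y for x, y in zip(a, b))


def reference_attention(q, keys, values):
    """softmax(q K^T / sqrt(d)) V for one query, materializing every score."""
    scores = [dot(q, k) / math.sqrt(len(q)) for k in keys]
    m = max(scores)
    weights = [math.exp(s - m) for s in scores]
    total = sum(weights)
    return [sum(w * v[j] for w, v in zip(weights, values)) / total
            for j in range(len(values[0]))]


def tiled_attention(q, keys, values, block_size=2):
    """Block-by-block attention with running (m, l) and output rescaling."""
    d_v = len(values[0])
    m, l, out = float("-inf"), 0.0, [0.0] * d_v
    for start in range(0, len(keys), block_size):
        k_blk = keys[start:start + block_size]
        v_blk = values[start:start + block_size]
        s_blk = [dot(q, k) / math.sqrt(len(q)) for k in k_blk]
        m_blk = max(s_blk)
        p_blk = [math.exp(s - m_blk) for s in s_blk]
        l_blk = sum(p_blk)
        # Block-local normalized output O_block.
        o_blk = [sum(p * v[j] for p, v in zip(p_blk, v_blk)) / l_blk
                 for j in range(d_v)]
        m_new = max(m, m_blk)
        w_old = l * math.exp(m - m_new)          # weight of earlier blocks
        w_blk = l_blk * math.exp(m_blk - m_new)  # weight of this block
        l_new = w_old + w_blk
        out = [(w_old * o + w_blk * ob) / l_new for o, ob in zip(out, o_blk)]
        m, l = m_new, l_new
    return out


q = [1.0, -0.5]
keys = [[0.2, 1.0], [1.5, -0.3], [-0.7, 0.4], [2.0, 2.0], [0.0, -1.0]]
values = [[1.0, 0.0, 2.0], [0.5, 1.5, -1.0], [3.0, 0.2, 0.0],
          [-1.0, 1.0, 1.0], [0.0, 0.0, 4.0]]
expected = reference_attention(q, keys, values)
for bs in (1, 2, 3, 5):
    got = tiled_attention(q, keys, values, bs)
    assert all(math.isclose(a, b, rel_tol=1e-9, abs_tol=1e-12)
               for a, b in zip(got, expected))
```

This sketch follows the normalized-output form of the equations. FlashAttention-2 instead keeps the output unnormalized inside the loop and divides by l once at the end, which saves non-matmul work.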
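The original FlashAttention underutilized the hardware for three reasons:

- It parallelized only over the batch and head dimensions, not the sequence length.
- It partitioned work inefficiently between warps.
- It spent a relatively large share of time on non-matmul operations (e.g., softmax rescaling). These run on CUDA cores at much lower throughput than Tensor Core matmuls.

FlashAttention-2 addressed these issues, primarily targeting Ampere-class GPUs (e.g., A100). It did not exploit the features introduced in Hopper-class GPUs (e.g., H100), such as asynchronous Tensor Core instructions (WGMMA), the Tensor Memory Accelerator (TMA), and FP8 Tensor Cores. That gap motivated FlashAttention-3.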
| Dimension | FlashAttention | FlashAttention-2 | FlashAttention-3 |
|---|---|---|---|
| Runtime Efficiency | Lowest of the three (underutilizes the GPU) | Higher (better parallelism and work partitioning; optimized for Ampere/A100) | Highest on Hopper (overlaps memory movement and compute on H100) |
| Data Types | bfloat16 / float16 | bfloat16 / float16 | bfloat16 / float16 / float8 (with scaling) |
| Known Limitation | Limited parallelism | Shared memory pressure / bank conflicts | Hopper-specific optimizations, limited portability |
<aside>
Speculative Decoding (Leviathan et al., 2023, arXiv:2211.17192) addresses the memory-bandwidth bottleneck of autoregressive inference by converting sequential generation into parallel verification. In standard decoding, the accelerator must stream all of the target model's weights from High Bandwidth Memory (HBM) to generate a single token. At small batch sizes, the compute units therefore sit mostly idle.
Speculative Decoding mitigates this with a smaller, faster "draft" model that autoregressively proposes γ (gamma) future tokens. The large "target" model then scores the entire drafted sequence in a single forward pass. Decoding at small batch sizes is memory-bandwidth-bound rather than compute-bound. As a result, running γ + 1 positions through the target model at once takes nearly the same time as running one.
$$
\begin{aligned} & \text{[Acceptance Probability]: } P_{accept}(x) = \min\left(1, \frac{P_{target}(x)}{P_{draft}(x)}\right) \\ & \text{[Resampling Distribution (if rejected)]: } P_{resample}(x) = \frac{\max(0, P_{target}(x) - P_{draft}(x))}{\sum_{x'} \max(0, P_{target}(x') - P_{draft}(x'))} \\ & \text{[Expected Tokens per Step]: } \mathbb{E}[N] = 1 + \sum_{i=1}^{\gamma} \mathbb{E}\left[\prod_{j=1}^{i} P_{accept}(x_j)\right] = \frac{1 - \alpha^{\gamma+1}}{1 - \alpha} \text{ (if each token is accepted independently with rate } \alpha\text{)} \\ & \text{[Wall-Clock Speedup Ratio]: } S = \frac{\mathbb{E}[N] \cdot t_{target}}{t_{target} + \gamma \cdot t_{draft}} \\ & \text{[Medusa Head Objective]: } \mathcal{L}_{Medusa} = \sum_{k=1}^{K} \lambda_k \cdot \text{CrossEntropy}\left(P_{target}^{(t+k+1)}, P_{head_k}^{(t)}\right) \end{aligned} $$
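The Python sketch below shows how the expected-tokens and speedup formulas interact. It evaluates them under the simplifying assumption behind the closed form: each drafted token is accepted independently with the same rate α. The timings are hypothetical and expressed relative to one target forward pass. The sketch also assumes that scoring γ + 1 positions costs the same as scoring one.

```python
def expected_tokens(alpha: float, gamma: int) -> float:
    """E[N] = 1 + alpha + ... + alpha**gamma for i.i.d. acceptance rate alpha."""
    return sum(alpha ** i for i in range(gamma + 1))


def speedup(alpha: float, gamma: int, t_target: float, t_draft: float) -> float:
    """Tokens per unit time relative to plain decoding with the target model."""
    return expected_tokens(alpha, gamma) * t_target / (t_target + gamma * t_draft)


# Hypothetical setting: acceptance rate 0.8, one draft step costs 5% of a target pass.
for gamma in (1, 2, 4, 8):
    print(gamma, round(expected_tokens(0.8, gamma), 3),
          round(speedup(0.8, gamma, 1.0, 0.05), 3))
```

E[N] saturates at 1/(1 − α) as γ grows, while the draft cost keeps growing linearly. S therefore peaks at a finite γ that depends on α and on the cost ratio t_draft / t_target.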
Speculative Decoding uses a modified rejection sampling scheme so that its output distribution is identical to what the target model would have produced on its own. For each drafted token x, the system compares the probability the target model assigns to x (P_target(x)) with the probability the draft model assigned to it (P_draft(x)).
If P_target(x) ≥ P_draft(x), the token is always accepted, because the draft model did not over-weight it. If P_target(x) < P_draft(x), the token is accepted with probability P_target(x) / P_draft(x). This corrects for cases where the draft model was more confident about a token than the target model is.
If a token is rejected, it and every drafted token after it are discarded. The system then samples a replacement token from the residual distribution (the "Resampling Distribution" in the math block). This distribution keeps only the positive part of P_target − P_draft and renormalizes it. The correction makes the marginal distribution of each emitted token exactly equal to P_target. If all γ drafted tokens are accepted, the target model's distribution at the next position yields one additional token at no extra cost.
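Below is a minimal Python sketch of the accept/resample rule for a single drafted token over a toy three-token vocabulary. The distributions are made-up inputs and the helper names are hypothetical. `exact_output_distribution` computes the resulting token distribution in closed form, showing that it equals P_target. A seeded Monte Carlo loop then checks the same thing empirically.

```python
import random


def residual_distribution(p_target, p_draft):
    """normalize(max(0, P_target - P_draft)); used only after a rejection."""
    residual = [max(0.0, p - q) for p, q in zip(p_target, p_draft)]
    total = sum(residual)
    return [r / total for r in residual]


def verify_token(x, p_target, p_draft, rng):
    """Accept drafted token x w.p. min(1, P_target/P_draft); else resample."""
    if rng.random() < min(1.0, p_target[x] / p_draft[x]):
        return x, True
    weights = residual_distribution(p_target, p_draft)
    return rng.choices(range(len(weights)), weights=weights)[0], False


def exact_output_distribution(p_target, p_draft):
    """P(out = y) = min(P_target(y), P_draft(y)) + P(reject) * residual(y)."""
    accepted = [min(p, q) for p, q in zip(p_target, p_draft)]
    if sum(max(0.0, p - q) for p, q in zip(p_target, p_draft)) == 0.0:
        return accepted  # identical distributions: every draft is accepted
    p_reject = 1.0 - sum(accepted)
    residual = residual_distribution(p_target, p_draft)
    return [a + p_reject * r for a, r in zip(accepted, residual)]


p_target = [0.5, 0.3, 0.2]
p_draft = [0.2, 0.6, 0.2]
print(exact_output_distribution(p_target, p_draft))  # p_target up to rounding

# Monte Carlo check: draft x ~ P_draft, then verify it against P_target.
rng = random.Random(0)
trials = 50_000
counts = [0, 0, 0]
for _ in range(trials):
    x = rng.choices(range(3), weights=p_draft)[0]
    token, _accepted = verify_token(x, p_target, p_draft, rng)
    counts[token] += 1
print([c / trials for c in counts])
```

The exact function is the one-line proof: min(P_target, P_draft) + max(0, P_target − P_draft) = P_target for every token. The rejection probability is exactly the total mass of that positive difference.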
Standard Speculative Decoding requires a separate draft model that shares the target model's tokenizer and approximates its output distribution closely enough to achieve high acceptance rates. Such a model is often hard to find or train in practice. Medusa (Cai et al., 2024, arXiv:2401.10774) removes the secondary model entirely by attaching multiple lightweight decoding heads to the target model's final hidden states.
Each Medusa head is trained to predict a different future position. The original language-model head already predicts token t+1, so Medusa head k predicts token t+k+1 (Head 1 predicts t+2, Head 2 predicts t+3, and so on). During inference, the top candidates from all heads, produced in a single forward pass, are combined into a tree of possible continuations. The system then uses a tree-structured attention mask ("Tree Attention") to verify all candidate paths simultaneously in one forward pass, and it accepts the longest valid prefix.
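The sketch below illustrates tree construction and masking for the simplest dense tree: it takes the Cartesian product of each head's top candidates and merges the resulting paths into a prefix tree. Each node may attend only to itself and its ancestors, plus the already-accepted context, which is omitted here. This restriction lets one forward pass score every path without different paths seeing each other. The token strings and helper names are hypothetical.

```python
from itertools import product


def build_tree(head_topk):
    """Merge the Cartesian product of per-head candidates into a prefix tree.

    head_topk[i] lists candidate tokens for the (i + 1)-th speculative position.
    Returns (token, parent_index) pairs; parent -1 means the node hangs off the
    last accepted token.
    """
    nodes, index = [], {}
    for path in product(*head_topk):
        for depth in range(1, len(path) + 1):
            prefix = path[:depth]
            if prefix not in index:
                parent = index[prefix[:-1]] if depth > 1 else -1
                index[prefix] = len(nodes)
                nodes.append((prefix[-1], parent))
    return nodes


def tree_attention_mask(nodes):
    """mask[i][j] is True iff node j is node i itself or one of its ancestors."""
    mask = [[False] * len(nodes) for _ in nodes]
    for i in range(len(nodes)):
        j = i
        while j != -1:
            mask[i][j] = True
            j = nodes[j][1]
    return mask


nodes = build_tree([["the", "a"], ["cat", "dog"]])
mask = tree_attention_mask(nodes)
for (token, parent), row in zip(nodes, mask):
    visible = [nodes[j][0] for j, ok in enumerate(row) if ok]
    print(f"{token:>4} parent={parent:>2} attends to {visible}")
```

With s_i candidates for position i, this dense tree has s_1 + s_1·s_2 + … nodes, so it grows quickly with more heads or candidates. This growth is why a practical system would want a sparser tree.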