FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness

Tri Dao, Daniel Y. Fu

2022 · NeurIPS


Problem

Framing

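Standard exact attention is IO-bound. It materializes the $N \times N$ score matrix in HBM, so at long context memory traffic, not arithmetic, dominates both runtime and memory use. FlashAttention removes this bottleneck with SRAM-resident tiling and an online softmax. It preserves exact attention while reducing HBM accesses from $\Theta(Nd + N^2)$ to $\Theta(N^2 d^2 M^{-1})$, where $N$ is the sequence length, $d$ the head dimension, and $M$ the SRAM size. Because $d^2$ is typically many times smaller than $M$, this is a large reduction.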
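The following sketch makes the asymptotic gap concrete by plugging illustrative numbers into both bounds with all constants dropped. The function name `hbm_accesses` and the values $N = 1024$, $d = 64$, and $M = 100{,}000$ elements are hypothetical choices, not figures from the paper. Only the order of magnitude of the ratio is meaningful.

```python
def hbm_accesses(n: int, d: int, m: int) -> tuple[float, float]:
    """Asymptotic HBM-access counts with constants dropped (illustrative only).

    n: sequence length, d: head dimension, m: SRAM size in elements (assumed unit).
    """
    # Theta(Nd + N^2): read Q/K/V/O plus the N x N score/probability matrices.
    standard = n * d + n * n
    # Theta(N^2 d^2 / M): Q and O (N*d each) are streamed once per K/V tile,
    # and there are Theta(N d / M) K/V tiles.
    flash = n * n * d * d / m
    return standard, flash


if __name__ == "__main__":
    std, fa = hbm_accesses(n=1024, d=64, m=100_000)  # hypothetical sizes
    print(f"standard ~ {std:,.0f}, flash ~ {fa:,.0f}, ratio ~ {std / fa:.1f}x")
```

The ratio grows linearly with $M / d^2$, which is why the savings depend on how much SRAM is available per head dimension.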

Currently Used Methods

Foundational

Proposed Method

Architecture

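FlashAttention preserves exact self-attention and changes only the execution order. It tiles $\mathbf{Q}$, $\mathbf{K}$, and $\mathbf{V}$ into blocks that fit in SRAM, fuses the score computation, masking, softmax, dropout, and value accumulation into a single kernel, and writes back to HBM only the output tiles (plus the per-row softmax statistics needed for the backward pass).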
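The tiling can be sketched as a loop schedule. This sketch assumes the block-size rule $B_c = \lceil M/4d \rceil$, $B_r = \min(\lceil M/4d \rceil, d)$, with an outer loop over K/V tiles and an inner loop over Q tiles. The helper names are hypothetical, and the code only enumerates tile index ranges. It does not compute attention.

```python
import math


def block_sizes(m: int, d: int) -> tuple[int, int]:
    """Assumed rule: B_c = ceil(M / 4d), B_r = min(ceil(M / 4d), d). Returns (B_r, B_c)."""
    b_c = math.ceil(m / (4 * d))
    b_r = min(b_c, d)
    return b_r, b_c


def tile_ranges(n: int, block: int) -> list[tuple[int, int]]:
    """Half-open [start, end) row ranges covering n rows in chunks of `block`."""
    return [(s, min(s + block, n)) for s in range(0, n, block)]


def tiling_schedule(n: int, d: int, m: int) -> list[tuple[tuple[int, int], tuple[int, int]]]:
    """(K/V tile, Q tile) pairs in visit order: outer loop over K/V, inner loop over Q."""
    b_r, b_c = block_sizes(m, d)
    return [(kv, q) for kv in tile_ranges(n, b_c) for q in tile_ranges(n, b_r)]
```

With $M = 100{,}000$ and $d = 64$ this gives $B_r = 64$ and $B_c = 391$. Each K/V tile is loaded into SRAM once, while the Q and O tiles are streamed through it.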

Architecture figure: the GPU memory hierarchy, FlashAttention tiling over Q/K/V blocks in SRAM, and the attention speedup on GPT-2 compared with PyTorch.

Loss / Objective

The operator is unchanged from standard exact attention:

$$\mathbf{O} = \operatorname{softmax}(\mathbf{Q}\mathbf{K}^\top)\mathbf{V}.$$
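For reference, here is a direct (non-IO-aware) implementation of this operator in pure Python. It materializes the full $N \times N$ score and probability matrices, which is exactly what FlashAttention avoids. The helper names are illustrative. The $1/\sqrt{d}$ scaling, masking, and dropout are omitted to match the equation.

```python
import math

Matrix = list[list[float]]


def matmul(a: Matrix, b: Matrix) -> Matrix:
    """Naive matrix product a @ b."""
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)] for row in a]


def transpose(a: Matrix) -> Matrix:
    return [list(col) for col in zip(*a)]


def softmax_rows(s: Matrix) -> Matrix:
    out = []
    for row in s:
        mx = max(row)  # subtract the row max for numerical stability
        e = [math.exp(x - mx) for x in row]
        z = sum(e)
        out.append([x / z for x in e])
    return out


def standard_attention(q: Matrix, k: Matrix, v: Matrix) -> Matrix:
    s = matmul(q, transpose(k))  # full N x N score matrix is materialized here
    p = softmax_rows(s)          # and a second N x N matrix here
    return matmul(p, v)
```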

Algorithm

The key step is an online softmax merge across key-value tiles, so earlier score blocks never need to be materialized. For query tile $i$ and key-value tile $j$, let $\mathbf{S}_{ij} = \mathbf{Q}_i \mathbf{K}_j^\top$, $\tilde{\mathbf{m}}_{ij} = \operatorname{rowmax}(\mathbf{S}_{ij})$, $\tilde{\mathbf{P}}_{ij} = \exp(\mathbf{S}_{ij} - \tilde{\mathbf{m}}_{ij})$, and $\tilde{\mathbf{l}}_{ij} = \operatorname{rowsum}(\tilde{\mathbf{P}}_{ij})$. The running row maxima $\mathbf{m}_i$, row sums $\mathbf{l}_i$, and output $\mathbf{O}_i$ are then updated as follows:

$$\mathbf{m}_i^{\mathrm{new}} = \max\big(\mathbf{m}_i, \tilde{\mathbf{m}}_{ij}\big), \qquad \mathbf{l}_i^{\mathrm{new}} = e^{\mathbf{m}_i-\mathbf{m}_i^{\mathrm{new}}}\mathbf{l}_i + e^{\tilde{\mathbf{m}}_{ij}-\mathbf{m}_i^{\mathrm{new}}}\tilde{\mathbf{l}}_{ij},$$

$$\mathbf{O}_i^{\mathrm{new}} = \operatorname{diag}(\mathbf{l}_i^{\mathrm{new}})^{-1}\left(\operatorname{diag}(\mathbf{l}_i)\,e^{\mathbf{m}_i-\mathbf{m}_i^{\mathrm{new}}}\mathbf{O}_i + e^{\tilde{\mathbf{m}}_{ij}-\mathbf{m}_i^{\mathrm{new}}}\tilde{\mathbf{P}}_{ij}\mathbf{V}_j\right).$$
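Below is a minimal pure-Python sketch of this merge. It assumes a single query block, so only the K/V dimension is tiled, and it omits scaling, masking, and dropout. It illustrates the update equations and is not the fused CUDA kernel. `flash_attention_forward` and `block_c` are hypothetical names.

```python
import math

Matrix = list[list[float]]


def flash_attention_forward(q: Matrix, k: Matrix, v: Matrix, block_c: int = 2) -> Matrix:
    """Exact softmax(Q K^T) V computed one K/V tile at a time via the online-softmax merge.

    Simplification (assumption): all query rows form one block; the per-row merge
    rule is the same as in the update equations.
    """
    n_q, d_v = len(q), len(v[0])
    m = [-math.inf] * n_q                   # running row max m_i
    l = [0.0] * n_q                         # running row sum l_i
    o = [[0.0] * d_v for _ in range(n_q)]   # running normalized output O_i

    for start in range(0, len(k), block_c):
        k_j, v_j = k[start:start + block_c], v[start:start + block_c]
        for r, q_row in enumerate(q):
            s = [sum(a * b for a, b in zip(q_row, k_row)) for k_row in k_j]  # row of S_ij
            m_tilde = max(s)                                # rowmax(S_ij)
            p_tilde = [math.exp(x - m_tilde) for x in s]    # exp(S_ij - m~_ij)
            l_tilde = sum(p_tilde)                          # rowsum(P~_ij)
            pv = [sum(p * v_row[c] for p, v_row in zip(p_tilde, v_j)) for c in range(d_v)]

            m_new = max(m[r], m_tilde)
            alpha = math.exp(m[r] - m_new)   # rescales old state; exp(-inf) == 0.0 on first tile
            beta = math.exp(m_tilde - m_new)  # rescales the new tile's contribution
            l_new = alpha * l[r] + beta * l_tilde
            o[r] = [(l[r] * alpha * o_c + beta * pv_c) / l_new for o_c, pv_c in zip(o[r], pv)]
            m[r], l[r] = m_new, l_new
    return o
```

The new tile is rescaled by $e^{\tilde{\mathbf{m}}_{ij} - \mathbf{m}_i^{\mathrm{new}}}$ and the old state by $e^{\mathbf{m}_i - \mathbf{m}_i^{\mathrm{new}}}$, so the result does not depend on `block_c`. Any tile size gives the same output as a single-tile (standard) softmax, up to floating-point rounding.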

Training Procedure

Evaluation

Datasets

Metrics

Headline results

Results plot: bar chart of FlashAttention speedup on a T4 GPU across sequence lengths. The largest gains occur when masking and dropout are fused into the kernel.

Table 3: Long-Range Arena accuracy (%) and speedup relative to the standard Transformer

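| Model | ListOps | Text | Retrieval | Image | Pathfinder | Avg | Speedup |
|---|---|---|---|---|---|---|---|
| Transformer | 36.0 | 63.6 | 81.6 | 42.3 | 72.7 | 59.3 | - |
| FlashAttention | 37.6 | 63.9 | 81.4 | 43.5 | 72.7 | 59.8 | 2.4× |
| Block-sparse FlashAttention | 37.0 | 63.0 | 81.3 | 43.6 | 73.3 | 59.6 | 2.8× |
| Linformer [84] | 35.6 | 55.9 | 77.7 | 37.8 | 67.6 | 54.9 | 2.5× |
| Linear Attention [50] | 38.8 | 63.2 | 80.7 | 42.6 | 72.5 | 59.6 | 2.3× |
| Performer [12] | 36.8 | 63.6 | 82.2 | 42.1 | 69.9 | 58.9 | 1.8× |
| Local Attention [80] | 36.1 | 60.2 | 76.7 | 40.6 | 66.6 | 56.0 | 1.7× |
| Reformer [51] | 36.5 | 63.8 | 78.5 | 39.6 | 69.4 | 57.6 | 1.3× |
| Smyrf [19] | 36.1 | 64.1 | 79.0 | 39.6 | 70.5 | 57.9 | 1.7× |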

Ablations

Method Strengths and Weaknesses

Strengths

Weaknesses

Suggestions from the authors

Links

Prior Papers

Further Papers

No vault papers identified as further work yet.