Attention Is All You Need

Ashish Vaswani, Noam Shazeer, Niki Parmar

2017 · NeurIPS

Problem

Framing

Recurrent transducers compute one position at a time, which limits parallel training, and in both recurrent and convolutional models the path between distant positions grows with their distance. The paper replaces both with a pure self-attention encoder-decoder, the Transformer, reaching 28.4 BLEU on WMT14 En-De and 41.8 BLEU on WMT14 En-Fr.

Currently Used Methods

Foundational

Proposed Method

Architecture

The model stacks 6 encoder and 6 decoder layers. Each encoder layer applies multi-head self-attention followed by a position-wise FFN; each decoder layer uses masked self-attention, adds encoder-decoder attention over the encoder output, and ends with the same position-wise FFN. Base settings are $d_{\mathrm{model}}=512$, $h=8$, $d_k=d_v=64$, and $d_{\mathrm{ff}}=2048$.

Transformer encoder-decoder architecture with stacked self-attention, masked decoder self-attention, encoder-decoder attention, residual Add&Norm blocks, and positional encodings.
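The base settings are tied together: the per-head width is the model width divided by the number of heads, so $d_k = d_v = d_{\mathrm{model}}/h = 64$. The sketch below records the base hyperparameters and derives the per-head width and rough per-sublayer weight counts. The class name `TransformerConfig` and its method names are hypothetical, and the counts ignore biases, layer normalization, and embeddings, so they are a back-of-envelope check rather than the paper's parameter accounting.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class TransformerConfig:
    """Hypothetical container for the base hyperparameters listed above."""
    n_layers: int = 6
    d_model: int = 512
    n_heads: int = 8
    d_ff: int = 2048

    @property
    def d_head(self) -> int:
        # Assumes d_k = d_v = d_model / h, as in the base model.
        if self.d_model % self.n_heads != 0:
            raise ValueError("d_model must be divisible by n_heads")
        return self.d_model // self.n_heads

    def attention_weights(self) -> int:
        # W^Q, W^K, W^V across all heads plus W^O: four d_model x d_model
        # matrices when h * d_head == d_model (biases ignored).
        return 4 * self.d_model * self.d_model

    def ffn_weights(self) -> int:
        # Two linear maps: d_model -> d_ff and d_ff -> d_model (biases ignored).
        return 2 * self.d_model * self.d_ff


base = TransformerConfig()
print(base.d_head, base.attention_weights(), base.ffn_weights())
# prints: 64 1048576 2097152
```

Under these assumptions the FFN holds about twice as many weights as one attention sublayer in the base model.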

Loss / Objective

Training uses next-token cross-entropy with label smoothing $\epsilon_{ls}=0.1$.

$$\mathcal{L} = - \sum_t \sum_{i=1}^{V} q_i^{(t)} \log p_{\theta}\!\left(y_i^{(t)} \mid y_{<t}, x\right)$$
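A minimal sketch of this objective in plain Python, for one sequence. The summary does not say how the smoothed target $q^{(t)}$ is built, so the code assumes the common uniform-mixing form, $(1-\epsilon_{ls})$ times the one-hot target plus $\epsilon_{ls}/V$ on every token; other implementations spread the mass differently. The function names are illustrative only.

```python
import math


def smoothed_targets(true_index, vocab_size, eps=0.1):
    # Assumed smoothing: (1 - eps) * one_hot + eps / vocab_size on every token.
    q = [eps / vocab_size] * vocab_size
    q[true_index] += 1.0 - eps
    return q


def log_softmax(logits):
    # Subtract the max before exponentiating for numerical stability.
    m = max(logits)
    log_z = m + math.log(sum(math.exp(v - m) for v in logits))
    return [v - log_z for v in logits]


def label_smoothed_loss(logits_per_step, targets, eps=0.1):
    # Sum over time steps t and vocabulary entries i of -q_i * log p_i.
    total = 0.0
    for logits, y in zip(logits_per_step, targets):
        log_p = log_softmax(logits)
        q = smoothed_targets(y, len(logits), eps)
        total -= sum(q_i * lp_i for q_i, lp_i in zip(q, log_p))
    return total


# Two decoding steps over a toy vocabulary of 4 tokens.
logits = [[2.0, 0.5, 0.1, -1.0], [0.0, 0.0, 3.0, 0.0]]
print(label_smoothed_loss(logits, targets=[0, 2], eps=0.1))
```

With `eps=0.0` the function reduces to ordinary negative log-likelihood of the target tokens.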

Algorithm

The core computation is scaled dot-product attention, composed into multi-head attention.

$$\mathrm{Attention}(Q,K,V) = \mathrm{softmax}\!\left(\frac{QK^T}{\sqrt{d_k}}\right)V$$

$$\mathrm{MultiHead}(Q,K,V) = \mathrm{Concat}(\mathrm{head}_1,\ldots,\mathrm{head}_h)W^O, \quad \mathrm{head}_i = \mathrm{Attention}(QW_i^Q, KW_i^K, VW_i^V)$$
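The two formulas can be sketched in NumPy as follows. This is an illustrative, unbatched version, not the authors' implementation: the per-head projections are stored as stacked arrays of shape `(h, d_model, d_k)`, the boolean `mask` argument (True where attention is allowed) is an assumed convention for the decoder's masked self-attention, and the toy sizes are chosen only to keep the example small.

```python
import numpy as np


def softmax(x, axis=-1):
    # Shift by the row maximum for numerical stability.
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)


def attention(Q, K, V, mask=None):
    # Scaled dot-product attention: softmax(Q K^T / sqrt(d_k)) V.
    d_k = Q.shape[-1]
    scores = Q @ K.swapaxes(-1, -2) / np.sqrt(d_k)
    if mask is not None:
        # Assumed convention: mask is boolean, True where attending is allowed.
        scores = np.where(mask, scores, -1e9)
    return softmax(scores) @ V


def multi_head(Q, K, V, W_q, W_k, W_v, W_o, mask=None):
    # W_q, W_k: (h, d_model, d_k); W_v: (h, d_model, d_v); W_o: (h * d_v, d_model).
    heads = [
        attention(Q @ W_q[i], K @ W_k[i], V @ W_v[i], mask)
        for i in range(W_q.shape[0])
    ]
    return np.concatenate(heads, axis=-1) @ W_o


# Toy self-attention: 5 positions, d_model = 16, h = 4, d_k = d_v = 4.
rng = np.random.default_rng(0)
x = rng.normal(size=(5, 16))
W_q, W_k, W_v = (rng.normal(size=(4, 16, 4)) for _ in range(3))
W_o = rng.normal(size=(16, 16))
causal = np.tril(np.ones((5, 5), dtype=bool))  # position t sees positions <= t
out = multi_head(x, x, x, W_q, W_k, W_v, W_o, mask=causal)
print(out.shape)  # (5, 16)
```

Dividing by $\sqrt{d_k}$ keeps the dot products from growing with the head width, which would otherwise push the softmax into regions with very small gradients.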

Training Procedure

$$\mathrm{lrate} = d_{\mathrm{model}}^{-0.5} \min\!\left(\mathrm{step}^{-0.5},\; \mathrm{step}\cdot \mathrm{warmup\_steps}^{-1.5}\right)$$
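The schedule grows linearly for the first `warmup_steps` updates and then decays with the inverse square root of the step number; the two branches of the `min` meet exactly at `step = warmup_steps`. A small sketch in Python, where the default `warmup_steps=4000` is the value reported in the paper rather than one stated in this summary, and the function name is illustrative:

```python
def lrate(step, d_model=512, warmup_steps=4000):
    # step is 1-based; step = 0 would divide by zero in step ** -0.5.
    if step < 1:
        raise ValueError("step must be >= 1")
    return d_model ** -0.5 * min(step ** -0.5, step * warmup_steps ** -1.5)


for step in (1, 1000, 4000, 16000, 100000):
    print(step, lrate(step))
```

With these defaults the peak rate, reached at step 4000, is $(512 \cdot 4000)^{-0.5}$, roughly $7.0 \times 10^{-4}$.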

Evaluation

Datasets

Metrics

Headline results

Table 1: En-De development ablations over heads, width, dropout, label smoothing, and model size

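| Setting | $N$ | $d_{\mathrm{model}}$ | $d_{\mathrm{ff}}$ | $h$ | $d_k$ | $d_v$ | $P_{drop}$ | $\epsilon_{ls}$ | train steps | PPL (dev) | BLEU (dev) | params $\times 10^6$ |
|---|---|---|---|---|---|---|---|---|---|---|---|---|
| base | 6 | 512 | 2048 | 8 | 64 | 64 | 0.1 | 0.1 | 100K | 4.92 | 25.8 | 65 |
| (A) |  |  |  | 1 | 512 | 512 |  |  |  | 5.29 | 24.9 |  |
| (A) |  |  |  | 4 | 128 | 128 |  |  |  | 5.00 | 25.5 |  |
| (A) |  |  |  | 16 | 32 | 32 |  |  |  | 4.91 | 25.8 |  |
| (A) |  |  |  | 32 | 16 | 16 |  |  |  | 5.01 | 25.4 |  |
| (B) |  |  |  |  | 16 |  |  |  |  | 5.16 | 25.1 | 58 |
| (B) |  |  |  |  | 32 |  |  |  |  | 5.01 | 25.4 | 60 |
| (C) | 2 |  |  |  |  |  |  |  |  | 6.11 | 23.7 | 36 |
| (C) | 4 |  |  |  |  |  |  |  |  | 5.19 | 25.3 | 50 |
| (C) | 8 |  |  |  |  |  |  |  |  | 4.88 | 25.5 | 80 |
| (C) |  | 256 |  |  | 32 | 32 |  |  |  | 5.75 | 24.5 | 28 |
| (C) |  | 1024 |  |  | 128 | 128 |  |  |  | 4.66 | 26.0 | 168 |
| (C) |  |  | 1024 |  |  |  |  |  |  | 5.12 | 25.4 | 53 |
| (C) |  |  | 4096 |  |  |  |  |  |  | 4.75 | 26.2 | 90 |
| (D) |  |  |  |  |  |  | 0.0 |  |  | 5.77 | 24.6 |  |
| (D) |  |  |  |  |  |  | 0.2 |  |  | 4.95 | 25.5 |  |
| (D) |  |  |  |  |  |  |  | 0.0 |  | 4.67 | 25.3 |  |
| (D) |  |  |  |  |  |  |  | 0.2 |  | 5.47 | 25.7 |  |
| (E) positional embedding instead of sinusoids |  |  |  |  |  |  |  |  |  | 4.92 | 25.7 |  |
| big | 6 | 1024 | 4096 | 16 |  |  | 0.3 |  | 300K | 4.33 | 26.4 | 213 |

Blank cells are unchanged from the base row.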

Ablations

Method Strengths and Weaknesses

Strengths

Weaknesses

Suggestions from the authors

Links

Prior Papers

Further Papers