A learnable, content-aware soft dictionary lookup. Each token forms a query and asks every other token (via their keys) “how relevant are you to me?”, then collects a weighted sum of their values. The weights are softmaxed dot-product similarities — fully differentiable, computable in parallel for every pair of positions, and indifferent to how far apart those positions are.
What problem it solves
RNNs process tokens sequentially and forget across long distances; gradients vanish through long chains. CNNs only see a local receptive field; bridging long-range dependencies requires deep stacks of layers. Yet language is full of arbitrarily long-range dependencies — the subject and verb of a sentence can be separated by an unbounded relative clause.
Self-attention removes the bottleneck. Every token can directly look at every other token in the sequence, in a single layer, with a learnable importance weighting. Distance no longer matters; only relevance does.
Intuition: a soft dictionary lookup
A Python dictionary or JSON object stores key–value pairs and answers queries by exact-matching the query against the keys:
```python
{
    "Name": "Jane Doe",
    "Address": "37 Coronation street",
    "Date of birth": "May 5th 2000",
    "Place of birth": "Hull",
}
```
Query "Date of birth" → exact-match the key "Date of birth" → return the value "May 5th 2000". Mathematically:
The indicator is 1 for exactly one matching key, 0 elsewhere. Hard match. Not differentiable, so it can’t be inserted into a neural network and trained with gradient descent.
Relaxing the match
Replace the hard indicator with a soft, continuous similarity. Instead of “are these equal?”, ask “how aligned are these vectors?” — a dot product. Push the dot products through softmax to turn them into a probability distribution over keys:

$$\text{attention}(q) = \sum_i \frac{e^{\,q \cdot k_i}}{\sum_j e^{\,q \cdot k_j}}\; v_i$$

Now every key contributes to the result, but the highest-similarity key contributes the most. The result is a weighted average of values, not a single value. Fully differentiable.
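A minimal sketch of the soft lookup in NumPy; the keys, values, and query below are made-up vectors chosen purely to illustrate the weighting:

```python
import numpy as np

def softmax(x):
    x = x - x.max()                    # subtract max for numerical stability
    e = np.exp(x)
    return e / e.sum()

def soft_lookup(query, keys, values):
    """Soft dictionary lookup: a softmax-weighted average of the values."""
    scores = keys @ query              # dot-product similarity with every key
    weights = softmax(scores)          # probability distribution over the keys
    return weights @ values            # weighted average of the value vectors

# Three key/value pairs; the query is closest to key 0, so the result is
# weighted toward values[0] but still blends in the other values a little.
keys = np.array([[1.0, 0.0], [0.0, 1.0], [-1.0, 0.0]])
values = np.array([[10.0, 0.0], [0.0, 10.0], [5.0, 5.0]])
query = np.array([0.9, 0.1])

print(soft_lookup(query, keys, values))
```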
This soft, content-addressed lookup is the entire substance of attention. The rest of the page is just “where do $q$, $k$, $v$ come from, and how do we apply this across a sequence?”
Self-attention: every token makes its own Q, K, V
In self-attention, the queries, keys, and values are all derived from the same input sequence — every token contributes one query, one key, and one value. (When the queries come from a different source than the keys/values, you get cross-attention instead.)
For a sequence of $n$ token embeddings $x_1, \dots, x_n$, stacked row-wise into $X \in \mathbb{R}^{n \times d_{\text{model}}}$, the layer holds three learnable weight matrices:

| Matrix | Shape | Job |
|---|---|---|
| $W_Q$ | $d_{\text{model}} \times d_k$ | Project token embedding into the query space |
| $W_K$ | $d_{\text{model}} \times d_k$ | Project token embedding into the key space |
| $W_V$ | $d_{\text{model}} \times d_v$ | Project token embedding into the value space |

For each token $x_i$:

$$q_i = x_i W_Q, \qquad k_i = x_i W_K, \qquad v_i = x_i W_V$$

Stacking all tokens row-wise gives matrices $Q$, $K$, and $V$:

$$Q = X W_Q, \qquad K = X W_K, \qquad V = X W_V,$$

each in $\mathbb{R}^{n \times d_k}$ (with $d_k$ or $d_v$ columns). Three projections of the same sequence into three different “roles” — interrogator, label, content.
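A shape-level sketch; the sizes $n = 4$, $d_{\text{model}} = 8$, $d_k = d_v = 4$ and the random weights are illustrative, since in a real layer $W_Q$, $W_K$, $W_V$ are learned:

```python
import numpy as np

rng = np.random.default_rng(0)
n, d_model, d_k, d_v = 4, 8, 4, 4        # toy sizes, purely illustrative

X = rng.normal(size=(n, d_model))        # one embedding row per token

W_Q = rng.normal(size=(d_model, d_k))    # learnable in a real model;
W_K = rng.normal(size=(d_model, d_k))    # random here just to show the shapes
W_V = rng.normal(size=(d_model, d_v))

Q = X @ W_Q                              # (n, d_k): one query per token
K = X @ W_K                              # (n, d_k): one key per token
V = X @ W_V                              # (n, d_v): one value per token
print(Q.shape, K.shape, V.shape)         # (4, 4) (4, 4) (4, 4)
```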
TIP — Why three projections of the same input?
If $Q = K = V = X$ (no learned projections), the attention pattern becomes “every token attends most strongly to itself” (because a vector is most similar to itself). The three separate learned projections let the network decouple “what is this token about?” (key) from “what does this token want to know about?” (query) from “what information does this token contribute when attended to?” (value). All three roles emerge from the same word, but live in different learned subspaces.
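One way to see the degenerate case, assuming for simplicity that the embeddings have equal norm: with the rows of $X$ unit-normalised and no projections, the largest softmax weight in every row sits on the diagonal.

```python
import numpy as np

rng = np.random.default_rng(0)
X = rng.normal(size=(6, 16))
X /= np.linalg.norm(X, axis=1, keepdims=True)    # unit-norm rows

scores = X @ X.T                                  # queries == keys == X
weights = np.exp(scores) / np.exp(scores).sum(axis=1, keepdims=True)

# Without separate projections, each token attends most strongly to itself.
print(np.argmax(weights, axis=1))                 # [0 1 2 3 4 5]
```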
Scaled dot-product attention
Combine the three projections into the canonical formula:

$$\operatorname{Attention}(Q, K, V) = \operatorname{softmax}\!\left(\frac{Q K^\top}{\sqrt{d_k}}\right) V$$
Reading it left to right:
- $QK^\top$ ($n \times n$) — every token’s query dotted against every other token’s key. Entry $(i, j)$ is the raw similarity score between token $i$ as a querier and token $j$ as a target.
- $1/\sqrt{d_k}$ — divide by the square root of the key dimension. Without scaling, large $d_k$ makes dot products grow large in magnitude, pushing softmax into saturation regions where gradients vanish.
- $\operatorname{softmax}$ (applied row-wise) — turn each row of similarities into a probability distribution. Row $i$ now says “how much does token $i$ care about each other token?” — it sums to 1.
- $\cdot\, V$ — multiply the attention matrix by the value matrix. Each output row $i$ is a weighted sum of value vectors, weighted by token $i$’s attention distribution. Output is $n \times d_v$ — one new vector per input token.
The output is a sequence of the same length as the input, but every output position is now a content-mixed combination of every input position.
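The formula, line for line, in a minimal NumPy sketch (random matrices stand in for the projected queries, keys, and values):

```python
import numpy as np

def softmax(s, axis=-1):
    e = np.exp(s - s.max(axis=axis, keepdims=True))   # shift for numerical stability
    return e / e.sum(axis=axis, keepdims=True)

def scaled_dot_product_attention(Q, K, V):
    """Attention(Q, K, V) = softmax(Q K^T / sqrt(d_k)) V."""
    d_k = K.shape[-1]
    scores = Q @ K.T / np.sqrt(d_k)       # (n, n): every query against every key
    weights = softmax(scores, axis=-1)    # each row is a probability distribution
    return weights @ V, weights           # (n, d_v) outputs + the attention matrix

rng = np.random.default_rng(1)
n, d_k, d_v = 5, 8, 8
Q = rng.normal(size=(n, d_k))
K = rng.normal(size=(n, d_k))
V = rng.normal(size=(n, d_v))

out, attn = scaled_dot_product_attention(Q, K, V)
print(out.shape, attn.shape, attn.sum(axis=1))    # (5, 8) (5, 5) and rows summing to 1
```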
A worked picture
Consider the input “the animal didn’t cross the street because it was too tired”. For the token it, the model needs to figure out what it refers to. After training, its query strongly aligns with the key of animal, weakly with street, weakly with the others. The attention distribution for it looks roughly like:
| Token | Attention weight from it |
|---|---|
| the | 0.02 |
| animal | 0.45 |
| didn't | 0.05 |
| … | … |
| street | 0.10 |
| … | … |
| tired | 0.08 |
The output vector at the position of it is now a weighted blend dominated by the value vector of animal — the network has contextually rewritten it to mean “it (referring to animal)“. Coreference resolution as a side effect of training the right attention distribution.
Self-attention vs. feed-forward
A standard MLP layer applies a fixed weight matrix to every input: $y = Wx$ (plus bias and nonlinearity). The same $W$ is used regardless of the input.

Self-attention applies a dynamic, input-dependent weight matrix: the matrix $\operatorname{softmax}\!\left(QK^\top/\sqrt{d_k}\right)$ is computed from the inputs themselves. Different inputs produce different attention patterns, and therefore different effective weight matrices.

| Layer | Weight matrix | What it depends on |
|---|---|---|
| MLP / feed-forward | $W$ — fixed after training | Nothing per-input |
| Self-attention | $\operatorname{softmax}\!\left(QK^\top/\sqrt{d_k}\right)$ — computed at every forward pass | The input sequence itself |
This is the deeper reason transformers are so expressive: the connection pattern between tokens is itself learned and varies with input. The MLP analogue is a network whose connections rewire on the fly to fit each new sentence.
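A toy illustration of that contrast, with random weights and the Q/K/V projections omitted for brevity: the MLP’s $W$ is reused unchanged for two different inputs, while the attention pattern is recomputed from each one.

```python
import numpy as np

def softmax(s):
    e = np.exp(s - s.max(axis=-1, keepdims=True))
    return e / e.sum(axis=-1, keepdims=True)

rng = np.random.default_rng(2)
W = rng.normal(size=(8, 8))               # MLP weight: fixed once trained
X1 = rng.normal(size=(4, 8))              # two different "sentences"
X2 = rng.normal(size=(4, 8))

# Feed-forward: the very same W is applied to both inputs.
Y1, Y2 = X1 @ W, X2 @ W

# Self-attention: the effective weight matrix softmax(X X^T / sqrt(d))
# is recomputed from each input, so the connection pattern changes.
A1 = softmax(X1 @ X1.T / np.sqrt(8))
A2 = softmax(X2 @ X2.T / np.sqrt(8))
print(np.allclose(A1, A2))                # False: different inputs, different pattern
```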
Masked self-attention (causal attention)
When self-attention is used in a decoder for autoregressive generation, every token must only see previous tokens — never future ones. Otherwise, predicting the next token would be trivial: the model could just look at the answer.

Masked attention enforces causality by setting the attention scores for future positions to $-\infty$ before the softmax:

$$\text{score}_{ij} = \begin{cases} \dfrac{q_i \cdot k_j}{\sqrt{d_k}} & j \le i \\[4pt] -\infty & j > i \end{cases}$$

After softmax, $-\infty$ becomes $0$, so future positions contribute nothing. The attention matrix is lower-triangular: each row $i$ has non-zero weights only on columns $j \le i$.
In the encoder, no masking is applied (the encoder can freely attend to the entire input). The decoder uses masked self-attention on its own outputs and unmasked cross-attention over the encoder’s latent code.
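A minimal sketch of the mask, assuming a single head: scores above the diagonal are replaced by a very large negative number (standing in for $-\infty$) before the softmax, which makes the attention matrix lower-triangular.

```python
import numpy as np

def softmax(s):
    e = np.exp(s - s.max(axis=-1, keepdims=True))
    return e / e.sum(axis=-1, keepdims=True)

def causal_attention(Q, K, V):
    d_k = K.shape[-1]
    scores = Q @ K.T / np.sqrt(d_k)                          # (n, n) raw similarities
    mask = np.triu(np.ones_like(scores, dtype=bool), k=1)    # True strictly above the diagonal
    scores = np.where(mask, -1e9, scores)                    # future positions get ~ -inf
    weights = softmax(scores)                                 # upper triangle becomes ~0
    return weights @ V, weights

rng = np.random.default_rng(3)
X = rng.normal(size=(4, 8))
out, attn = causal_attention(X, X, X)      # projections omitted for brevity
print(np.round(attn, 2))                   # lower-triangular: row i attends only to columns j <= i
```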
Time and memory cost
Self-attention computes an $n \times n$ score matrix for a sequence of length $n$. Time and memory both scale as $O(n^2)$. For sequences of a few hundred tokens this is fine; for sequences of tens of thousands of tokens it becomes prohibitive. This is the quadratic complexity bottleneck that motivated countless approximate-attention variants (sparse attention, linear attention, Longformer, Performer, FlashAttention’s IO-aware optimisations). For the canonical “Attention Is All You Need” transformer, the quadratic cost is just the price you pay.
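A back-of-envelope sketch of just the score-matrix memory (one head, one layer, float32; the sequence lengths are illustrative):

```python
# Rough memory for a single float32 attention score matrix (one head, one layer).
for n in (512, 4_096, 100_000):
    bytes_needed = n * n * 4              # n^2 entries, 4 bytes each
    print(f"n={n:>7}: {bytes_needed / 1e9:.3f} GB")
# n=    512: 0.001 GB
# n=   4096: 0.067 GB
# n= 100000: 40.000 GB
```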
Worked example: tiny self-attention by hand
For a single token’s query $q$ and three other tokens’ keys $k_1, k_2, k_3$ and values $v_1, v_2, v_3$, all 3-dimensional, with $d_k = 3$. (The numbers are chosen so the outcome is easy to read off.)

Query: $q = [3, 0, 3]$.

Keys and values:

$$k_1 = [1, 0, -1], \quad k_2 = [0, 1, 0], \quad k_3 = [1, 0, 1]$$
$$v_1 = [1, 0, 0], \quad v_2 = [0, 1, 0], \quad v_3 = [0, 0, 1]$$

Step 1 — raw similarities $q \cdot k_j$:

$$q \cdot k_1 = 0, \quad q \cdot k_2 = 0, \quad q \cdot k_3 = 6$$

Step 2 — scale by $\sqrt{d_k} = \sqrt{3} \approx 1.73$:

$[0,\ 0,\ 3.46]$.

Step 3 — softmax (numerically: $e^{0} = 1$, $e^{0} = 1$, $e^{3.46} \approx 32$; sum $\approx 34$):

$[0.03,\ 0.03,\ 0.94]$.

Token 3 dominates the attention.

Step 4 — weighted sum of values:

$\text{output} = 0.03\, v_1 + 0.03\, v_2 + 0.94\, v_3 \approx [0.03,\ 0.03,\ 0.94]$.

The output for this token is essentially $v_3$, lightly perturbed by tiny contributions from $v_1$ and $v_2$. The token “looked at” token 3 and copied its value, because $q$ was much more aligned with $k_3$ than with the other keys.
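The same arithmetic in NumPy, as a quick check:

```python
import numpy as np

q = np.array([3.0, 0.0, 3.0])
K = np.array([[1.0, 0.0, -1.0],
              [0.0, 1.0,  0.0],
              [1.0, 0.0,  1.0]])
V = np.eye(3)                                     # v1, v2, v3 as one-hot rows

scores = K @ q / np.sqrt(3)                       # [0, 0, 3.46]
weights = np.exp(scores) / np.exp(scores).sum()   # ~[0.03, 0.03, 0.94]
print(weights @ V)                                # ~[0.03, 0.03, 0.94], essentially v3
```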
Related
- transformer — the architecture built around stacked self-attention layers
- multi-head-attention — running self-attention with parallel heads
- cross-attention — same machinery, but $Q$ comes from a different sequence than $K$ and $V$
- positional-encoding — required because self-attention is permutation-invariant
- softmax — the row-wise normalisation that turns scores into a distribution
- dot-product — the per-pair similarity measure
- autoregressive-model — masked self-attention enables autoregressive decoders
Active Recall
What is the formula for scaled dot-product attention, and what does each piece do?
$\operatorname{Attention}(Q,K,V) = \operatorname{softmax}\!\left(QK^\top/\sqrt{d_k}\right)V$. $QK^\top$ computes pairwise query-key similarities; the $\sqrt{d_k}$ divisor scales the magnitudes so the softmax doesn’t saturate at large $d_k$; softmax (row-wise) turns the similarity scores into a probability distribution per query; multiplying by $V$ computes a weighted sum of value vectors using those probabilities. Output is one new vector per query position.
Why three separate projections instead of using the input directly as queries, keys, and values?
If $Q = K = V = X$, every token’s query is most similar to its own key (a vector is most similar to itself), so attention degenerates to “every token attends to itself”. The three learned projections decouple “what does this token want to ask?” (query), “what is this token labelled as?” (key), and “what content does this token contribute when attended to?” (value), all from the same input embedding but in different learned subspaces. This is what makes attention expressive.
How is self-attention different from a feed-forward / MLP layer in terms of the weights?
A feed-forward layer applies a fixed weight matrix $W$ — the same matrix is used for every input, set during training and frozen at inference. Self-attention applies a dynamic weight matrix $\operatorname{softmax}\!\left(QK^\top/\sqrt{d_k}\right)$ that is computed from the input itself at every forward pass. Different inputs produce different attention patterns. This input-dependence is why transformers are so much more expressive per parameter than MLPs.
Why is masked self-attention needed in the decoder, and how is it implemented?
An autoregressive decoder predicts token $t$ given tokens $1, \dots, t-1$. If the decoder’s self-attention layer could see future tokens, prediction would be trivial — the answer is in the input. Masked attention prevents this by setting attention scores for positions $j > i$ to $-\infty$ before softmax, so they become $0$ after softmax and contribute nothing. The attention matrix is lower-triangular: each row $i$ only has non-zero weights on columns $j \le i$.
Why does self-attention scale quadratically with sequence length, and why does that matter?
The attention score matrix is $n \times n$ — one entry per (query, key) pair. Computing and storing this matrix costs $O(n^2)$ time and $O(n^2)$ memory. For a few hundred tokens this is fine; for very long sequences (long documents, high-resolution image sequences) it becomes prohibitive. This is the quadratic bottleneck that motivates approximate-attention variants (Longformer, Performer, FlashAttention).
A query $q = [1, 1]$ is compared against keys $k_1 = [1, 1]$, $k_2 = [1, -1]$, $k_3 = [-1, -1]$. Without softmax-scaling, what are the raw dot-product similarities?
$q \cdot k_1 = 2$, $q \cdot k_2 = 0$, $q \cdot k_3 = -2$. Token 1 is most similar (highest dot product), token 2 is orthogonal (zero similarity), token 3 is anti-aligned (negative similarity). After softmax, almost all weight goes to token 1, very little to token 2, and even less to token 3.