Self-Attention


At its core, self-attention is a sequence-to-sequence operation. It takes a sequence of vectors and produces a new sequence of vectors of the same length. Each output vector is a weighted sum of (linearly projected versions of) all the input vectors, and the weights themselves are computed from the inputs.

The easiest way to understand the math is through a database retrieval analogy. You have a Query and check it against a set of Keys. Depending on how well the query matches each key, you retrieve the corresponding Values. Unlike a real database lookup, this retrieval is soft: every value contributes, weighted by how well its key matched. In self-attention, the queries, keys, and values all come from the *same* input sequence.
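
To make the analogy concrete, the sketch below looks up a single query against a toy table of three keys and values. It uses Python with NumPy, which is an assumption since the text prescribes no library. The keys, values, and query are made-up numbers. The match scores are dot products, and a softmax turns them into weights, previewing the steps formalized below.

```python
import numpy as np

def softmax(z):
    # Subtract the max before exponentiating for numerical stability
    e = np.exp(z - np.max(z))
    return e / e.sum()

# Hypothetical toy "database": three keys, each paired with a value vector
keys = np.array([[1.0, 0.0],
                 [0.0, 1.0],
                 [1.0, 1.0]])
values = np.array([[10.0, 0.0],
                   [0.0, 10.0],
                   [5.0, 5.0]])

query = np.array([3.0, -1.0])

scores = keys @ query        # match score of the query against each key: [3., -1., 2.]
weights = softmax(scores)    # non-negative, sums to 1; largest for the first key
result = weights @ values    # a weighted blend of all values, not a single lookup

print(weights.round(3))
print(result.round(3))
```

The result is pulled mostly toward the first value, because that key matches the query best. The other values still contribute a little.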

1. The Input and Linear Projections

Let the input sequence be represented by a matrix \(X \in \mathbb{R}^{N \times d_{\text{model}}}\), where \(N\) is the sequence length (number of tokens) and \(d_{\text{model}}\) is the embedding dimension.

To create the Queries (\(Q\)), Keys (\(K\)), and Values (\(V\)), we multiply the input matrix \(X\) by three learnable weight matrices:

  • \(W^Q \in \mathbb{R}^{d_{\text{model}} \times d_k}\)
  • \(W^K \in \mathbb{R}^{d_{\text{model}} \times d_k}\)
  • \(W^V \in \mathbb{R}^{d_{\text{model}} \times d_v}\)

This gives us our \(Q\), \(K\), and \(V\) matrices:

\[Q = XW^Q\]

\[K = XW^K\]

\[V = XW^V\]

If you are tracking the tensor dimensions (as you typically would when setting up your network architecture in code), \(Q\) and \(K\) will have the shape \(N \times d_k\), and \(V\) will have the shape \(N \times d_v\).
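
The minimal NumPy sketch below performs the projections, mainly to confirm the shapes. The sizes are arbitrary assumptions. \(d_v\) is deliberately different from \(d_k\) to show that the two need not match. The weight matrices are random stand-ins for trained parameters.

```python
import numpy as np

rng = np.random.default_rng(0)

# Illustrative sizes (assumptions, not values from the text)
N, d_model, d_k, d_v = 4, 8, 3, 5

X = rng.standard_normal((N, d_model))   # input sequence: one row per token

# Learnable projections; random here, learned during training in practice
W_Q = rng.standard_normal((d_model, d_k))
W_K = rng.standard_normal((d_model, d_k))
W_V = rng.standard_normal((d_model, d_v))

Q = X @ W_Q   # (N, d_k)
K = X @ W_K   # (N, d_k)
V = X @ W_V   # (N, d_v)

print(Q.shape, K.shape, V.shape)  # (4, 3) (4, 3) (4, 5)
```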

2. Computing Attention Scores (Dot Product)

Next, we need to determine how much focus (or "attention") each token should place on every token in the sequence, including itself. We do this by multiplying the Query matrix by the transposed Key matrix, which computes the dot product of every query with every key:

\[S = QK^T\]

The resulting score matrix \(S \in \mathbb{R}^{N \times N}\) represents the raw alignment scores. For instance, the entry \(S_{ij}\) tells us how much the token at position \(i\) (the query) wants to attend to the token at position \(j\) (the key).
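
The sketch below performs the same step in NumPy, using random stand-ins for \(Q\) and \(K\). The final check confirms that entry \((i, j)\) of the matrix product is just the dot product of query row \(i\) with key row \(j\). \(S\) is generally not symmetric, because queries and keys come from different projections.

```python
import numpy as np

rng = np.random.default_rng(1)
N, d_k = 4, 3

# Stand-in Q and K (random here; in practice Q = X @ W_Q and K = X @ W_K)
Q = rng.standard_normal((N, d_k))
K = rng.standard_normal((N, d_k))

S = Q @ K.T   # (N, N) raw alignment scores

# Entry (i, j) is the dot product of query i with key j
i, j = 1, 2
print(S.shape)                            # (4, 4)
print(np.isclose(S[i, j], Q[i] @ K[j]))   # True
```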

3. Scaling and Softmax

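The raw dot products can grow very large in magnitude when the key dimension \(d_k\) is large. For example, if the components of a query and a key are independent with mean 0 and variance 1, their dot product has mean 0 and variance \(d_k\). Large scores push the subsequent softmax into saturated regions. There, almost all the weight falls on a single entry and the gradients become extremely small (vanishing gradients).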
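
A quick simulation shows the growth. For illustration only, it assumes that query and key components are independent standard normal values. Under that assumption, the variance of the raw dot product tracks \(d_k\), while dividing by \(\sqrt{d_k}\) keeps it near 1.

```python
import numpy as np

rng = np.random.default_rng(2)
n_samples = 10_000
results = {}

for d_k in (4, 64, 512):
    # Demo assumption: q and k components are independent, mean 0, variance 1
    q = rng.standard_normal((n_samples, d_k))
    k = rng.standard_normal((n_samples, d_k))
    dots = (q * k).sum(axis=1)       # one raw dot product per sample
    scaled = dots / np.sqrt(d_k)     # divide by sqrt(d_k) before the softmax
    results[d_k] = (dots.var(), scaled.var())
    print(f"d_k={d_k:3d}  var(raw)={dots.var():7.1f}  var(scaled)={scaled.var():.2f}")
```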

To stabilize the gradients during training, the scores are scaled down by the square root of the key dimension, \(\sqrt{d_k}\). Then, a softmax function is applied row-wise so that the attention weights for each query sum to 1.

\[A = \text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right)\]

Here, \(A \in \mathbb{R}^{N \times N}\) is the attention weight matrix. \(A_{ij}\) is the probability-like weight that token \(i\) assigns to token \(j\).
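
The NumPy sketch below computes the scaled scores and the row-wise softmax. The helper name `softmax_rows` is hypothetical. Subtracting each row's maximum before exponentiating does not change the result, but it avoids overflow in `exp`.

```python
import numpy as np

def softmax_rows(S):
    # Row-wise softmax; subtracting each row's max avoids overflow in exp
    e = np.exp(S - S.max(axis=-1, keepdims=True))
    return e / e.sum(axis=-1, keepdims=True)

rng = np.random.default_rng(3)
N, d_k = 4, 64
Q = rng.standard_normal((N, d_k))
K = rng.standard_normal((N, d_k))

A = softmax_rows(Q @ K.T / np.sqrt(d_k))   # (N, N) attention weights

print(A.sum(axis=1))   # each row sums to 1 (up to floating-point rounding)
```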

4. The Final Output

Finally, we compute the output sequence by multiplying the attention weight matrix \(A\) by the Value matrix \(V\):

\[\text{Output} = AV\]

Because \(A\) has shape \(N \times N\) and \(V\) has shape \(N \times d_v\), the resulting Output matrix has shape \(N \times d_v\). Each row of \(A\) is non-negative and sums to 1, so each row of the output is a convex combination of the Value vectors (the rows of \(V\)). That combination is weighted most heavily toward the tokens that received the highest attention weights for that row's query.
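
The sketch below uses NumPy with random stand-ins for \(Q\), \(K\), and \(V\). It computes the output and checks the convex-combination property. Because the weights in each row are non-negative and sum to 1, each output coordinate must lie between the smallest and largest values in the corresponding column of \(V\). The small tolerance only absorbs floating-point rounding.

```python
import numpy as np

def softmax_rows(S):
    # Row-wise softmax with max-subtraction for numerical stability
    e = np.exp(S - S.max(axis=-1, keepdims=True))
    return e / e.sum(axis=-1, keepdims=True)

rng = np.random.default_rng(4)
N, d_k, d_v = 4, 3, 5
Q = rng.standard_normal((N, d_k))
K = rng.standard_normal((N, d_k))
V = rng.standard_normal((N, d_v))

A = softmax_rows(Q @ K.T / np.sqrt(d_k))
out = A @ V   # (N, d_v)

# Convex combination: every output coordinate lies within the range of
# the corresponding column of V (tolerance covers rounding only)
tol = 1e-12
inside = (out >= V.min(axis=0) - tol) & (out <= V.max(axis=0) + tol)
print(out.shape, inside.all())   # (4, 5) True
```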

The Complete Equation

Putting it all together, the single, unified equation for Scaled Dot-Product Attention is:

\[\text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right)V\]
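
The compact, unbatched NumPy sketch below implements the equation, plus a self-attention wrapper that derives \(Q\), \(K\), and \(V\) from the same \(X\). The function names are hypothetical. The sketch leaves out batching and masking.

```python
import numpy as np

def scaled_dot_product_attention(Q, K, V):
    """softmax(Q K^T / sqrt(d_k)) V for unbatched Q (N, d_k), K (N, d_k), V (N, d_v)."""
    d_k = Q.shape[-1]
    scores = Q @ K.T / np.sqrt(d_k)                       # (N, N)
    scores = scores - scores.max(axis=-1, keepdims=True)  # numerical stability
    weights = np.exp(scores)
    weights /= weights.sum(axis=-1, keepdims=True)        # row-wise softmax
    return weights @ V                                    # (N, d_v)

def self_attention(X, W_Q, W_K, W_V):
    # Self-attention: queries, keys, and values all come from the same X
    return scaled_dot_product_attention(X @ W_Q, X @ W_K, X @ W_V)

rng = np.random.default_rng(5)
N, d_model, d_k, d_v = 6, 16, 8, 8
X = rng.standard_normal((N, d_model))
W_Q, W_K, W_V = (rng.standard_normal((d_model, d)) for d in (d_k, d_k, d_v))

out = self_attention(X, W_Q, W_K, W_V)
print(out.shape)   # (6, 8): same sequence length as the input
```

The output has one row per input token, which matches the sequence-to-sequence description at the top of this page.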