Aussie AI
20. Attention
Book Excerpt from "Generative AI in C++"
by David Spuler, Ph.D.
“When people talk, listen completely. Most people never listen.”
— Ernest Hemingway
What is Attention?
The attention mechanism is one of the major breakthroughs that allowed advanced AI to take shape. After all, the seminal 2017 Transformer paper was titled Attention is all you need (Vaswani et al., 2017). It's such an endlessly cited paper that it must be a real downer if your name is in the “et al” part, so here's the full list of names: Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit, Llion Jones, Aidan N. Gomez, Lukasz Kaiser, Illia Polosukhin; see https://arxiv.org/abs/1706.03762.
What's so great about that paper? The overall class of attention algorithms used in Transformers is called “self-attention” because the different tokens in a sequence pay attention to each other and their relative positions. The vanilla Transformer used a specific type of self-attention called “scaled dot-product attention.”
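For reference, the scaled dot-product attention of Vaswani et al. (2017) is usually written as a single formula. Here Q, K and V are the query, key and value matrices described later in this chapter, d_k is the length of each key vector, and dividing by √d_k is the “scaling” that gives the method its name:

```latex
\mathrm{Attention}(Q, K, V) = \mathrm{softmax}\left(\frac{Q K^{\top}}{\sqrt{d_k}}\right) V
```

The Softmax is applied to each row of the scores separately, so each token ends up with a set of attention weights that sums to one.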
The idea for attention comes from human intelligence. When we are considering something, we tend to pay more attention to certain features than others. This is true when humans examine words in a sentence or parts of an image. Hence, the AI idea is to apply attention to tokens with different weightings, and have the model learn these weights through training.
Attention is a very powerful mechanism in terms of model capability. It allows the model to learn how much “attention” it should pay to other tokens in the sequence. Hence, it is a mapping between tokens (or words) that indicates how to interrelate the presence of a token in a sequence with the probabilities of the next output token. Thus, it is deeply involved in deciding on the next token to output based on what tokens have previously appeared.
Masking and Lookahead
The above discussion makes it sound like attention only looks backwards at the previously emitted tokens, whereas the situation is more complicated. When an engine starts processing an input query from a user, it actually has an existing sequence of tokens from the prompt, and can “look ahead” at some of the upcoming tokens, too. This idea is executed in the “encoder” part of the Transformer, which can look at the entirety of the user prompt. However, the decoder is typically disallowed from lookahead features, and uses “masked attention” which blocks the decoder from looking at future tokens (i.e., its view of the future tokens is “masked off”). Hence, the usual decoder is only allowed to look at the already-produced output tokens from the engine.
Thus, in the vanilla encoder-decoder architecture, there are two different types of attention. The encoder is used to pay attention to token positions of the input text (i.e. the user prompt) and that is a major part of its “encoding” intelligence. The decoder's attention mechanism pays attention only to the output sequence that it has already produced, rather than the input sequence. The decoder indirectly gets attention information about the input prompt from the encoder via “cross attention” links, but the decoder doesn't directly examine the input text.
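A minimal C++ sketch of how the masking works, assuming the raw n×n score matrix (every query dotted with every key) has already been computed. Future positions are set to negative infinity, so Softmax gives them exactly zero weight. The type alias `Matrix` and the function names are illustrative, not from any particular library.

```cpp
#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstddef>
#include <limits>
#include <vector>

using Matrix = std::vector<std::vector<float>>;

// Apply a causal ("look-behind only") mask to an n x n matrix of raw
// attention scores: row i may only see columns 0..i.
void applyCausalMask(Matrix& scores) {
    const float negInf = -std::numeric_limits<float>::infinity();
    for (std::size_t i = 0; i < scores.size(); ++i) {
        assert(scores[i].size() == scores.size());
        for (std::size_t j = i + 1; j < scores[i].size(); ++j) {
            scores[i][j] = negInf;  // future token: masked off
        }
    }
}

// Row-wise Softmax. Masked entries become exactly 0 because exp(-inf) == 0,
// and the diagonal is never masked, so every row has a finite maximum.
void softmaxRows(Matrix& m) {
    for (auto& row : m) {
        float maxVal = -std::numeric_limits<float>::infinity();
        for (float x : row) maxVal = std::max(maxVal, x);
        float sum = 0.0f;
        for (float& x : row) { x = std::exp(x - maxVal); sum += x; }
        for (float& x : row) x /= sum;
    }
}
```

After these two steps, row 0 (the first token) puts all of its weight on itself, and row i spreads its weight over tokens 0 to i only.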
What is Cross Attention?
Cross attention is the link that carries attention information across from the encoder to the decoder. It uses the same low-level QKV mechanism as the other attention heads, but with a twist: the queries (Q) come from the decoder, while the keys (K) and values (V) come from the encoder's output. The cross attention mechanism allows the decoder to pay attention to the encoder output in every layer of the decoder.
In the vanilla encoder-decoder architecture, cross attention allows the decoder to get some attention information about the input prompt. The encoder's attention calculations are based on analyzing the input prompt with lookahead. The decoder's attention is focused only on the output sequence (using “masked attention”), and it doesn't directly analyze the input prompt. Hence, the decoder indirectly pays attention to the input prompt via cross attention results coming across from the encoder.
In a decoder-only architecture (e.g. GPT), the whole of cross attention is removed because there isn't an encoder to provide this input. Similarly, in an encoder-only architecture (e.g. BERT), you can code up cross attention if you like, but there'll be no-one listening on the other end.
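A minimal C++ sketch of cross attention, assuming the Q, K and V matrices have already been computed (the learned projections are left out to keep it short). The defining feature is where the inputs come from: the queries come from the decoder, while the keys and values come from the encoder's output, so the score matrix is m×n for m decoder tokens and n encoder tokens, and no causal mask is needed. The function name `crossAttention` is hypothetical.

```cpp
#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

using Matrix = std::vector<std::vector<float>>;

// Cross attention sketch: queries from the decoder (m tokens), keys and
// values from the encoder output (n tokens). No causal mask is applied,
// because the whole encoded prompt is already available.
Matrix crossAttention(const Matrix& decoderQ,   // m x d_k
                      const Matrix& encoderK,   // n x d_k
                      const Matrix& encoderV) { // n x d_v
    assert(!encoderK.empty() && encoderK.size() == encoderV.size());
    const std::size_t m = decoderQ.size(), n = encoderK.size();
    const std::size_t dk = encoderK[0].size(), dv = encoderV[0].size();
    const float scale = 1.0f / std::sqrt(static_cast<float>(dk));
    Matrix out(m, std::vector<float>(dv, 0.0f));
    for (std::size_t i = 0; i < m; ++i) {
        assert(decoderQ[i].size() == dk);
        // Scores of decoder token i against every encoder token (one row of m x n).
        std::vector<float> w(n);
        for (std::size_t j = 0; j < n; ++j) {
            float dot = 0.0f;
            for (std::size_t c = 0; c < dk; ++c) dot += decoderQ[i][c] * encoderK[j][c];
            w[j] = dot * scale;
        }
        // Numerically stable Softmax over the encoder positions.
        const float maxVal = *std::max_element(w.begin(), w.end());
        float sum = 0.0f;
        for (float& x : w) { x = std::exp(x - maxVal); sum += x; }
        // Weighted sum of the encoder's value vectors.
        for (std::size_t j = 0; j < n; ++j)
            for (std::size_t c = 0; c < dv; ++c) out[i][c] += (w[j] / sum) * encoderV[j][c];
    }
    return out;
}
```

Each decoder token gets back one vector built from the encoder's values, which is how the decoder “sees” the prompt without attending to it directly.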
What are Q, K and V?
Each attention head has a significant amount of computational work to do each iteration. The attention mechanism works at runtime by using three different vectors:
- Q — Query
- K — Key
- V — Value
All three of these vectors are computed using parameters that are learned during training. In fact, the ability to learn how to show attention to different tokens is deeply enmeshed in the intelligence of LLMs in their processing of text sequences. At runtime, the three vectors are computed afresh for each token from these learned parameters.
The attention block performs multiple levels of computations on the QKV matrices. In pseudocode, the QKV attention mechanism looks like:
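```
// 3 linear projections of the same input
Q = linear-projection(WQ, Input);
K = linear-projection(WK, Input);
V = linear-projection(WV, Input);
// Combine Q and K (dot product of every query with every key)
QKCombined = MatMul(Q, Transpose(K));
// Scale by the square root of the key dimension
QKCombined = QKCombined / sqrt(d_k);
// Softmax normalization (row by row)
QKCombined = Softmax(QKCombined);
// Merge V in too
QKVCombined = MatMul(QKCombined, V);
return QKVCombined;
```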
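The pseudocode above can be made concrete with a small, unoptimized C++ sketch of one attention head. It uses plain `std::vector` matrices and naive loops rather than any real tensor library, and it applies the 1/√d_k scaling of scaled dot-product attention; the helper names (`matMul`, `transpose`, `softmaxRows`, `attentionHead`) are illustrative.

```cpp
#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

using Matrix = std::vector<std::vector<float>>;  // row-major: rows x cols

// C = A * B, where A is (r x k) and B is (k x c).
Matrix matMul(const Matrix& A, const Matrix& B) {
    assert(!A.empty() && !B.empty() && A[0].size() == B.size());
    Matrix C(A.size(), std::vector<float>(B[0].size(), 0.0f));
    for (std::size_t i = 0; i < A.size(); ++i)
        for (std::size_t k = 0; k < B.size(); ++k)
            for (std::size_t j = 0; j < B[0].size(); ++j)
                C[i][j] += A[i][k] * B[k][j];
    return C;
}

Matrix transpose(const Matrix& A) {
    assert(!A.empty());
    Matrix T(A[0].size(), std::vector<float>(A.size()));
    for (std::size_t i = 0; i < A.size(); ++i)
        for (std::size_t j = 0; j < A[0].size(); ++j) T[j][i] = A[i][j];
    return T;
}

// Numerically stable Softmax applied to each row independently.
void softmaxRows(Matrix& M) {
    for (auto& row : M) {
        const float maxVal = *std::max_element(row.begin(), row.end());
        float sum = 0.0f;
        for (float& x : row) { x = std::exp(x - maxVal); sum += x; }
        for (float& x : row) x /= sum;
    }
}

// Single-head scaled dot-product attention (no masking).
// input: n x d_model; WQ, WK: d_model x d_k; WV: d_model x d_v.
Matrix attentionHead(const Matrix& input, const Matrix& WQ,
                     const Matrix& WK, const Matrix& WV) {
    Matrix Q = matMul(input, WQ);             // n x d_k
    Matrix K = matMul(input, WK);             // n x d_k
    Matrix V = matMul(input, WV);             // n x d_v
    Matrix scores = matMul(Q, transpose(K));  // n x n
    const float scale = 1.0f / std::sqrt(static_cast<float>(K[0].size()));
    for (auto& row : scores)
        for (float& x : row) x *= scale;
    softmaxRows(scores);       // each row now sums to 1
    return matMul(scores, V);  // n x d_v
}
```

With the 2×2 identity matrix as the input and as all three weight matrices, Q, K and V are all the identity, and the first output row comes out at roughly [0.670, 0.330].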
Are QKV vectors or matrices? Both, really. Let me explain. Q, K and V are often mentioned as “vectors” and I've also been calling them vectors above, but that's not the full story. They can be called “vectors” because each of Q, K, and V ends up as a vector for each token. The key point is at the end: for each token.
There are three QKV vectors for each token. However, for a sequence of multiple tokens that make up the prompt, each of Q, K and V has one vector per token, so the structure is really a vector-of-vectors, which is a matrix. Hence, the Q, K and V calculations produce three matrices, which are also called Q, K, and V.
Aren't you glad that you asked, now?
What are the QKV weight matrices? These are three other matrices, not the QKV matrices, so there are six matrices floating around inside your GPU. Three are dynamically computed, and three are static parts of the model. Each of the computations that produce Q, K and V involves its own matrix of weights, and these three weight matrices are named after the QKV matrix that they produce (i.e. WQ, WK, and WV). These three weight matrices are:
(a) learned during training, and
(b) static during inference.
Hence, the intelligence in the attention mechanism is trained into these three weight matrices, and the resulting three QKV matrices are the dynamic computations based on this learned attention, as computed during runtime inference.
In more detail, what's actually happening is that the input matrix has one dimension of the sequence length (of tokens) and one dimension of the embedding size (an internal model meta-parameter). Each of Q, K, and V has its own (static) matrix of weights. The input into the attention block is the same for all three computations. This input matrix is separately multiplied in a matrix-by-matrix multiplication (MatMul), using each of the three different weight matrices, to give you the three resulting Q, K and V matrices.
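The same projections in matrix notation, writing X for the input matrix, n for the sequence length and d for the embedding size. The query/key width d_k and the value width d_v are extra symbols introduced here; in the simplest single-head case both can equal d:

```latex
\begin{aligned}
Q &= X W_Q, \qquad X \in \mathbb{R}^{n \times d},\; W_Q \in \mathbb{R}^{d \times d_k},\; Q \in \mathbb{R}^{n \times d_k} \\
K &= X W_K, \qquad W_K \in \mathbb{R}^{d \times d_k},\; K \in \mathbb{R}^{n \times d_k} \\
V &= X W_V, \qquad W_V \in \mathbb{R}^{d \times d_v},\; V \in \mathbb{R}^{n \times d_v}
\end{aligned}
```

For example, with a 4-token prompt and d = d_k = d_v = 8, X is 4×8, each weight matrix is 8×8, and Q, K and V each come out as 4×8: one 8-element row per token, which is the “vector per token” view from earlier.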
Yes, matrices. These three Q, K and V matrices are dynamically computed values during processing, and thus don't contain any learned data or static weights themselves. However, the QKV matrices are still two-dimensional matrices indexed by:
(a) token sequence (in the prompt), and
(b) embedding vector dimensions.
Technically, the Q, K and V matrices are “linear projections” of input sequences based on the (static) parameters in the three weight matrices. And the creation of the three QKV matrices is just the first step inside the attention block, with multiple subsequent steps that combine the Q, K and V matrices back together. It is analogous to a “mini-model” within the overall model, because this attention method has trained weights and ongoing tracking of attention scores (computed from the Q and K matrices, then normalized by Softmax into probability-like values) that indicate how much “attention” each token should pay to another.
Are QKV used for inference or training? Both. The computations of the Q, K and V matrices occur in both training and inference. During training, the weights in the three related weight matrices are updated (and put into the model file at the end), whereas for inference, the weights are static. The QKV matrices themselves are not part of the model file, because they contain dynamic calculations during both training and inference.
Are QKV used by encoders or decoders? Both. There are attention mechanisms in both encoders and decoders. In the vanilla Transformer, the encoder mechanism allows “lookahead” for processing, whereas the decoder uses “masked attention” that disallows processing of upcoming tokens. Masked attention means the decoder can look backwards at already-emitted tokens only. Hence, the encoder pays attention to the input prompt, whereas the decoder can only directly consider the output tokens. Just to confuse matters further, there's also “cross attention” where the decoder indirectly gets information about the prompt, but only via the encoder's work.
What about decoder-only models? Yes, there are differences in attention architectures for encoder-only models (e.g. BERT) and decoder-only models (e.g. GPT). In a decoder-only model, the encoder does not provide input to the decoder layers (because there's no encoder at all), and the “cross attention” capabilities are therefore removed from the decoder. However, the decoder-only architecture still uses masked attention without lookahead: each token can attend only to itself and the tokens before it, which in a decoder-only model includes the prompt tokens as well as the already-output tokens.
Softmax Normalization
The Softmax component is used as part of the attention head. It normalizes the output values into proper probabilities (i.e., not negative and not too large), by scaling them using a “sum-of-exponentials” method. This also ensures that all of the distribution sums to one, as probabilities should. See Chapter 25 for more about Softmax.
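A minimal C++ sketch of Softmax on a single vector of scores. Subtracting the largest value before exponentiating is a standard numerical-stability trick (an addition here, not something described above): it cancels out in the division but keeps `std::exp` from overflowing on large inputs. The function name is illustrative.

```cpp
#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

// Softmax: exponentiate each value, then divide by the sum of exponentials.
// Subtracting the maximum first does not change the result mathematically
// (it cancels in the ratio) but stops std::exp from overflowing.
std::vector<float> softmax(const std::vector<float>& logits) {
    assert(!logits.empty());
    const float maxVal = *std::max_element(logits.begin(), logits.end());
    std::vector<float> probs(logits.size());
    float sum = 0.0f;
    for (std::size_t i = 0; i < logits.size(); ++i) {
        probs[i] = std::exp(logits[i] - maxVal);  // always in (0, 1]
        sum += probs[i];
    }
    for (float& p : probs) p /= sum;  // now non-negative and sums to one
    return probs;
}
```

For example, `softmax({1000.0f, 1000.0f})` returns {0.5, 0.5}, whereas exponentiating 1000 directly would overflow a float to infinity.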
Some research papers use a different normalization in the attention heads. For example, the “Hardmax” function can be used instead of Softmax, which makes it a different type of distribution that isn't a range of probabilities. Another possibility is the “Sparsemax” function. However, only Softmax has mainstream acceptance in Transformer architectures.
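For comparison, here are illustrative C++ sketches of the two alternatives named above. Hardmax puts all of the weight on the single largest score. Sparsemax projects the scores onto the set of probability distributions, so its output still sums to one but small scores become exactly zero. These are minimal versions with hypothetical function names, not code from any particular paper.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <functional>
#include <vector>

// Hardmax: all the weight goes to the single largest score (a one-hot vector).
std::vector<float> hardmax(const std::vector<float>& z) {
    assert(!z.empty());
    std::vector<float> out(z.size(), 0.0f);
    const auto best = static_cast<std::size_t>(std::max_element(z.begin(), z.end()) - z.begin());
    out[best] = 1.0f;
    return out;
}

// Sparsemax: Euclidean projection of z onto the probability simplex.
// Output sums to 1, but scores below a threshold tau become exactly 0.
std::vector<float> sparsemax(const std::vector<float>& z) {
    assert(!z.empty());
    std::vector<float> sorted(z);
    std::sort(sorted.begin(), sorted.end(), std::greater<float>());
    // Find the largest k with 1 + k * sorted[k-1] > (sum of the top k scores).
    float cumsum = 0.0f, supportSum = 0.0f;
    std::size_t k = 0;
    for (std::size_t i = 0; i < sorted.size(); ++i) {
        cumsum += sorted[i];
        if (1.0f + static_cast<float>(i + 1) * sorted[i] > cumsum) {
            k = i + 1;  // always true at i == 0, so k >= 1
            supportSum = cumsum;
        }
    }
    const float tau = (supportSum - 1.0f) / static_cast<float>(k);
    std::vector<float> out(z.size());
    for (std::size_t i = 0; i < z.size(); ++i) out[i] = std::max(z[i] - tau, 0.0f);
    return out;
}
```

For example, `sparsemax({1.0f, 0.8f, 0.0f})` gives roughly {0.6, 0.4, 0.0}, while Softmax on the same scores would give every entry a nonzero weight.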
Positional Encoding
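The way I think about positional encoding is that it's a kind of work-around. A weird feature of the attention mechanism is that, by itself, it has no notion of word order: the attention scores depend on which tokens are present, not on where they sit, so the Transformer can lose track of the relative position of words. And the ordering of words in a sentence somewhat matters!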
In order to help the attention mechanism know to pay attention not just to the words, but also to their position in sentences, positional encoding is used to put more “position” information into the input sequences. Technically, positional encoding creates a vector that is added into the “embeddings” vector as an extra step.
The conversion of tokens to embeddings is a complex learned procedure that is then combined with positional encoding. The vanilla Transformer used trigonometric sine and cosine functions, without any trainable parameters, but there are various alternative positional encoding algorithms. The computation of positional encodings adds some extra mathematical values at the very end of the creation of the embeddings. With the sinusoidal method, this occurs dynamically at runtime, and the positional encoding values are not learned weights. See Chapter 27 for more about embeddings and positional encoding.
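A minimal C++ sketch of the sinusoidal scheme, using the formulas from the original Transformer paper: even dimensions get a sine and odd dimensions a cosine, at wavelengths that grow with the dimension index. The function names are hypothetical, and a real engine might precompute these values into a table rather than recompute them for every token.

```cpp
#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

// Sinusoidal positional encoding from the vanilla Transformer:
//   PE(pos, 2i)   = sin(pos / 10000^(2i / d_model))
//   PE(pos, 2i+1) = cos(pos / 10000^(2i / d_model))
// No trainable parameters: the values depend only on pos and d_model.
std::vector<float> positionalEncoding(std::size_t pos, std::size_t dModel) {
    assert(dModel % 2 == 0);
    std::vector<float> pe(dModel);
    for (std::size_t i = 0; i < dModel; i += 2) {  // i plays the role of 2i
        const double angle = static_cast<double>(pos) /
            std::pow(10000.0, static_cast<double>(i) / static_cast<double>(dModel));
        pe[i] = static_cast<float>(std::sin(angle));
        pe[i + 1] = static_cast<float>(std::cos(angle));
    }
    return pe;
}

// Add the positional encoding into a token's embedding vector, in place.
void addPositionalEncoding(std::vector<float>& embedding, std::size_t pos) {
    const std::vector<float> pe = positionalEncoding(pos, embedding.size());
    for (std::size_t k = 0; k < embedding.size(); ++k) embedding[k] += pe[k];
}
```

At position 0, the encoding is simply [0, 1, 0, 1, ...], since sin(0) = 0 and cos(0) = 1.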
Interestingly, maybe this work-around is not necessary. There has been some research suggesting that positional encoding can be omitted completely, and that Transformers are apparently capable of learning positional context without the extra hints. The algorithm is amusingly named NoPE (“no positional encoding”), but this is still early research (see Chapter 27).
Multi-Head Attention
The standard Transformer architecture splits the attention algorithm in parallel using Multi-Head Attention (MHA). Typically, the calculation of the attention algorithm is parallelized across a number of heads, such as 8 or 16 (the 2017 paper's base model used 8 heads, and its larger “big” model used 16). This is the reason that the “model dimension” of Transformers is usually an exact multiple of the number of heads, so that it can be split evenly across the heads without any “extra” cases (i.e. avoids needing any padding).
When I read about the vanilla Transformer in 2017 having “Multi-Head Attention” with multiple attention heads acting in parallel (16 in its big model), I had a major question about it: why? Or more specifically, why 16?
Is splitting the attention mechanism into 16 different “heads” just a code optimization? Is it only to make it faster? Or does it affect the model's output results to make it “smarter” because of changes to the distribution of weights and its processing algorithm? In other words, would the model act any differently in terms of outputs if we merged these 16 heads into one, or if we doubled it to 32 attention heads? Or would changing the number of attention heads only affect the speed?
I'm like a three-year-old: why? why? why?
The whole thing is very complicated, and I could be wrong, but here's my best answer: it's just a code optimization and there's nothing special about 16.
To be clear, I'm not saying that attention is just a code optimization, but only the “multi-head attention” modification. The attention mechanism makes the model smarter, but not splitting it into multiple heads. The innovation was attention itself, and being able to run it fast enough mattered for that reason.
In more detail, the computation of a multi-head attention simply computes exactly the same numbers as if it did the whole thing with one head and one massive matrix computation. Or if it had 32 heads instead of 16, and did smaller slices with more parallelism. I don't see anywhere in the multi-head attention algorithm where the final calculations are dependent on the number of heads or the size of each head. Yes, it's a nice parallelism, where it does 16 slices of the same computation and then merges them back together at the end. It seems to be purely a parallelization coding optimization, and does not change the accuracy or intelligence of the output.
Maybe I'm wrong. I've seen a lot of articles that say that MHA allows each attention head to focus on different features of the inputs, and gives a more intricate understanding of the text. The implication is that splitting the computations across multiple attention heads makes it smarter than one big tensor operation. But I can't see that in the code.
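A minimal C++ sketch of the multi-head split, assuming Q, K and V have already been projected (so the per-head slices of WQ, WK and WV are folded into the inputs) and omitting the final output projection W_O that a full implementation would apply. Each head works on its own slice of columns and writes its result back into the same columns, which is the “merge back together” step. One detail that bears on the question above: in this formulation, Softmax runs separately over each head's slice of the scores, so the split happens inside a nonlinear step, not only in the linear algebra. The function name is hypothetical.

```cpp
#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

using Matrix = std::vector<std::vector<float>>;

// Multi-head attention over already-projected Q, K, V (each n x d_model).
// Head h uses columns [h*dHead, (h+1)*dHead); writing each head's output into
// its own columns is the "concatenate" step. W_O is omitted for brevity.
Matrix multiHeadAttention(const Matrix& Q, const Matrix& K, const Matrix& V,
                          std::size_t numHeads) {
    assert(!Q.empty() && numHeads > 0);
    const std::size_t n = Q.size(), dModel = Q[0].size();
    assert(dModel % numHeads == 0);
    const std::size_t dHead = dModel / numHeads;
    const float scale = 1.0f / std::sqrt(static_cast<float>(dHead));
    Matrix out(n, std::vector<float>(dModel, 0.0f));
    for (std::size_t h = 0; h < numHeads; ++h) {
        const std::size_t c0 = h * dHead;  // first column of this head's slice
        for (std::size_t i = 0; i < n; ++i) {
            // Scaled dot products of query i against all keys, within this slice.
            std::vector<float> w(n);
            for (std::size_t j = 0; j < n; ++j) {
                float dot = 0.0f;
                for (std::size_t c = c0; c < c0 + dHead; ++c) dot += Q[i][c] * K[j][c];
                w[j] = dot * scale;
            }
            // Softmax is applied per head, per query row.
            const float maxVal = *std::max_element(w.begin(), w.end());
            float sum = 0.0f;
            for (float& x : w) { x = std::exp(x - maxVal); sum += x; }
            for (std::size_t j = 0; j < n; ++j)
                for (std::size_t c = c0; c < c0 + dHead; ++c)
                    out[i][c] += (w[j] / sum) * V[j][c];
        }
    }
    return out;
}
```

For example, with Q = K = V set to the 2×2 identity matrix, one head gives out[0][0] ≈ 0.670, while two heads give out[0][0] ≈ 0.731 on the same inputs.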
Efficient Attention Algorithms
The self-attention mechanism is computationally expensive, being perhaps a case of “too much of a good thing.” The quadratic complexity of self-attention in a vanilla Transformer is well-known, and becomes a performance bottleneck of the engine.
There has been much research on how to speed up attention, including efforts to reduce it to a linear-complexity algorithm (often called “linear attention” algorithms). The main approaches have been:
- Parallel attention optimizations — going beyond MHA.
- Efficient attention algorithms — faster attention algorithms, notably “Flash Attention.”
- Removing some attention heads — called “attention head pruning.”
- Approximations of attention — not paying attention to every token.
- Sparse attention — using sparse matrices for attention.
- Low-rank matrices — using smaller matrices and matrix algebra.
- Alternative architectures — using other ideas instead of attention.
- QKV code optimizations — advanced coding speedups such as “KV caching” and “QKV tensor merging.”
Some of these are discussed more fully in other chapters. For example, KV caching is in Chapter 29, and attention head pruning is under “width pruning” in Chapter 48. Although we'll now examine some of these interesting attention speedup methods in extra detail, let's be honest in admitting these are mostly research techniques. The main way to optimize attention in industry practice: more GPUs.
Attention Head Approximation
Much of the research into attention heads has been in regard to attention head pruning (removing redundant or under-utilized attention head components) or reducing the quadratic cost of attention in terms of sequence length (related to non-autoregressive algorithms). However, there are also some “simple” or “approximate” attention heads that have been considered to replace the original Transformer components.
Note that the default Transformer's attention method is “full attention” where every token attends to every other one. This is what creates the quadratic complexity, because N tokens attend to N other tokens, giving N*N computational mappings.
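A quick worked example of the N*N growth, counting the query-key scores for a single attention head in a single layer:

```latex
\begin{aligned}
N = 2{,}048 &: \quad N^2 = 4{,}194{,}304 \text{ scores} \\
N = 32{,}768 &: \quad N^2 = 1{,}073{,}741{,}824 \text{ scores}
\end{aligned}
```

A 16× longer context needs 16² = 256× as many scores. Masked attention skips the future positions, which roughly halves these counts, but the growth is still quadratic.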
The idea of approximation is to process fewer mappings. Some of the various approximate attention approaches include:
- Local attention: the idea is that each token only pays attention to a few prior tokens. This has also been called “sliding window attention” because the scope of attention moves along with the position. This method is fast but has obvious implications for accuracy.
- Sparse attention: Instead of having every token attending to every token, various attention algorithms reduce this mapping by paying attention to fewer tokens. There are numerous research papers that only consider tokens in various patterns. The cost reduction depends roughly on the sparsity level, but accuracy also declines with sparsity. For example, attending to only every second token is a simplistic idea in this vein, which would halve the cost, but there are much more sophisticated methods in the research literature.
- Random attention: one type of sparse attention is to randomly pay attention to different tokens in a random pattern. Although this sounds silly at first, it's an example where a stochastic (probabilistic) algorithm works in AI, although the model does suffer some accuracy degradation.
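A minimal C++ sketch of the local (sliding window) idea, expressed as a boolean mask over the N×N score matrix; the blocked positions would be set to negative infinity before Softmax, just as with masked attention. The names `slidingWindowMask` and `countScores`, and the window convention (a token sees itself plus the window-1 tokens before it), are illustrative assumptions.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// allowed[i][j] is true if token i may attend to token j.
using Mask = std::vector<std::vector<bool>>;

// Local ("sliding window") causal attention: token i sees itself and the
// (window - 1) tokens before it, i.e. j in [i - window + 1, i].
Mask slidingWindowMask(std::size_t n, std::size_t window) {
    assert(window >= 1);
    Mask allowed(n, std::vector<bool>(n, false));
    for (std::size_t i = 0; i < n; ++i) {
        const std::size_t start = (i + 1 >= window) ? i + 1 - window : 0;
        for (std::size_t j = start; j <= i; ++j) allowed[i][j] = true;
    }
    return allowed;
}

// Number of query-key scores that actually need computing under a mask.
std::size_t countScores(const Mask& allowed) {
    std::size_t count = 0;
    for (const auto& row : allowed)
        for (bool a : row) count += a ? 1 : 0;
    return count;
}
```

With 8 tokens and a window of 3, only 21 scores are needed versus 36 for full causal attention; in general the count grows like N × window rather than N².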
Note that in addition to these various faster attention algorithms, there are also research papers that try to do more computations to get a slower-but-smarter attention algorithm. For example, one approach is “double attention.”
Attention Head Pruning
Attention head pruning is a type of model “width pruning” (see Chapter 48) where some of the less important heads are removed. Research has shown that some of the attention heads are more important than others, and there is some redundancy that can be removed. The pruning can be done statically as a type of model compression, or dynamically depending on the user's inputs.
Alternatives to Attention
Although attention has performed very well, there are still various attempts to replace it with something even better. Usually, the goal is a simpler type of attention algorithm with fewer computations involved. There are newer ideas, and also older ideas, where some of the research papers also try reverting back to earlier methods that existed before the current generation of attention algorithms. There is a lot of overlap with the research area of “non-autoregressive decoding algorithms” and “parallel decoding” (see Chapter 26) where the aim is also to avoid the quadratic cost of attention on long context sequences.
Long Context Research
Context size is the number of input tokens that a model can process. Early models, even ChatGPT, had small context sizes of about 2,048 tokens. Each token is usually a part-word or a whole word, so this meant it could process inputs of about 1,000-2,000 words.
Why seek a longer context size? Because context length is not just the user's current prompt, but is also everything else that the engine must process. And note that context is not only input, but also output text, because the Transformer must track its own prior output during creation of the next token. Some examples where long context matters include:
- Completions: for the engine to write long text responses or creative essays, it needs to track the context of what it's already written, and its output text is also processed as input context as it works along.
- Chatbots or Q&A: The full length of the text to be processed is not just the user's current question, but also the full contextual history of the prior conversation.
- Retrieval Augmented Generation (RAG): the extra retrieved document “chunks” must be processed as context in addition to the user's query.
- Editing: the document to be analyzed is input context. A user's report could be 5,000 words, and a full-length novel is 50,000-100,000 words, or even 200,000 words if it's in the “epic sci-fi” genre.
Newer models have been increasing the context size. For example, GPT-4 has a 32k window size (32,768 tokens), which will handle a small novella or novelette of maybe 15,000-20,000 words. Anthropic reportedly has a Claude model with a 100k context size, which will hold most document sizes. MosaicML has an open-source model in its MPT family called “MPT-StoryWriter-65k+” with a 65,000 token window size.
Why is there a context size limitation? One of the main bottlenecks is the “quadratic” cost of the self-attention mechanism. And there are various ways to optimize attention to overcome this limitation. However, it's not the only bottleneck, and Alperovich (2023) offers the "secret sauce" for long contexts as fixing three main bottlenecks:
- Quadratic attention cost (in the input token size)
- Quadratic size of internal tensors (in the model dimension)
- Positional embedding cost.
Some of the engine optimization techniques with relevance to allowing the processing and creation of longer token sequences include:
- Faster attention algorithms
- Autoregression optimizations
- Tokenization algorithms
- Token pruning
- Embeddings pruning
- Length pruning
- Positional encoding optimization
Length Generalization
Speed is not the only problem with long contexts. Vanilla Transformers are also not particularly good at generalizing their results with a long context size. Although a key innovation of the Transformer was its “attention” capability, the engine starts to lose track as the output elongates.
This ability to intelligently process long texts is known as “length generalization” (or “length extrapolation”), and improving the accuracy in long inputs and longer outputs is an area of active research.
One of the methods being analyzed to improve length generalization is called “scratchpad” or “chain-of-thought” algorithms. The idea is that the AI inference engine emits rough summaries to an internal scratchpad at regular intervals, and these are merged into subsequent inference, so that the AI helps itself keep track of its own “chain of thoughts” over a longer output sequence.
The new AI programming book by Aussie AI co-founders:
Get your copy from Amazon: Generative AI in C++