Aussie AI
19. Encoders & Decoders
Book Excerpt from "Generative AI in C++"
by David Spuler, Ph.D.
“It's alive!”
— Mary Shelley, Frankenstein, 1818.
What are Encoders and Decoders?
The original 2017 Transformer had two major structures: encoders and decoders. Both are still used in various ways by modern Transformers, and each of these structures has many sub-components and layers. However, an encoder-decoder architecture is not the only way to fly. In fact, at the top level, the main types of Transformer architectures are:
- Encoder-decoder — original vanilla Transformer
- Encoder-only — BERT (research)
- Decoder-only — GPT-2, GPT-3
From this list you can see that although the original 2017 Transformer was based on an encoder-decoder architecture, more recent commercial models including GPT-2 and GPT-3 use decoder-only engines. Encoder-decoder architectures have since been largely relegated to foreign language translation use cases, where encoding and decoding are very distinct, each in different languages. Encoder-only architectures have quite limited use cases, mainly where there's no need to produce a long output, such as classification. Anything that involves using an LLM to write a Shakespearean sonnet about your first-born child is running in a decoder-only Transformer architecture.
Why decoder-only? Researchers found that removing the encoder gave a faster model by greatly reducing the number of weights/parameters; roughly half of the LLM's weights turned out to be unnecessary. This was possible because the encoder was largely redundant for most LLM-related use cases, since its operation was similar to a decoder anyway. Instead of an encoder, decoder-only architectures use an initialization phase called “prefill” that runs an encoder-like process in the decoder. Hence, most models changed to decoder-only architectures from GPT-2 onwards, and subsequently for GPT-3, GPT-3.5, ChatGPT, InstructGPT, and (unofficially) GPT-4.
Meta's Llama and Llama2 models are also decoder-only, similar to GPT-3 versions. Google Gemini and its earlier Bard or PaLM versions are also based on a decoder-only architecture. Although Google's Gemini model was rumored to be reverting to a multimodal encoder-decoder architecture, its release information confirmed a decoder-only architecture. Pity all those poor unwanted encoders.
Transformer Layers and Components
Inside the encoder and decoder blocks there are lots of sub-structures, many of which are organized into “layers.” For example, GPT-2 is decoder-only with 12 decoder layers. Collectively, these layers are called the “encoder stack” and the “decoder stack.”
However, layers are not the whole story. Each layer has lots of sub-components, and there are also other parts of the engine that aren't in the layers. In fact, there are numerous low-level component parts of an AI engine, such as:
- Model Loader
- Tokenizer (input module)
- Embeddings
- Positional Encoding
- Vector Arithmetic (e.g. addition)
- Matrix Multiplier (MatMul/GEMM)
- Attention Heads (i.e. Q, K, and V)
- Feed-Forward Network (FFN)
- Activation Functions
- Normalization
- Softmax
- Linearization/De-embedding
- Decoding Algorithm (choosing words)
- Output module (formatting)
So, that's 14 distinct C++ modules you need to write. If we estimate two weeks for each, your engine will be done in about six months. (I wonder, dear reader, did you check my count in the above list?)
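To make that module count a little more concrete, here is a minimal, hypothetical C++ skeleton of the inference pipeline. Every name below is illustrative rather than taken from any real engine, and each one-line declaration stands in for what would be a substantial module in practice:

    // Hypothetical module skeleton for an inference-only engine (illustrative names only).
    #include <string>
    #include <vector>

    using Vector = std::vector<float>;

    struct Model;  // Opaque handle to the loaded weights (see the Model Loader section).

    Model* load_model(const std::string& path);                       // Model loader
    std::vector<int> tokenize(const Model& m, const std::string& s);  // Tokenizer (input module)
    Vector embed(const Model& m, const std::vector<int>& tokens);     // Embeddings
    void add_positional_encoding(Vector& v);                          // Positional encoding
    void vector_add(Vector& dst, const Vector& src);                  // Vector arithmetic
    Vector matmul(const Model& m, int layer, const Vector& v);        // Matrix multiplier (MatMul/GEMM)
    Vector attention(const Model& m, int layer, const Vector& v);     // Attention heads (Q, K, V)
    Vector ffn(const Model& m, int layer, const Vector& v);           // Feed-forward network (FFN)
    float activation(float x);                                        // Activation function
    void layer_norm(Vector& v);                                       // Normalization
    void softmax(Vector& v);                                          // Softmax
    Vector unembed(const Model& m, const Vector& v);                  // Linearization/de-embedding
    int choose_next_token(const Vector& logits);                      // Decoding algorithm
    std::string detokenize(const Model& m, int token);                // Output module (formatting)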
But that's not all. We've forgotten training, and the above engine would be inference-only. All of the above components are used in both inference and training, but the training-specific extra algorithms and modules include:
- Learning algorithms (e.g. supervised vs unsupervised)
- Training Optimizer (i.e. “gradient descent” method)
- Loss function
- Dropout
- Evaluation metrics
FAQs on Transformer Architecture
Here are some top-level questions about the architecture of modern Transformers.
What is attention? Attention is an important underpinning concept in how LLMs work. The idea is for the model to focus its “attention” on particular tokens in a sequence of words, and parts needing the most attention are amplified by larger model weights. More about attention is found in the next chapter, if I haven't lost yours by then.
What is prefill? It's a Clayton's encoder: the encoder you have when you don't have an encoder. Don't worry, it's an Aussie joke. Prefill is an encoder-like phase at the start of inference for decoder-only architectures (e.g. GPT-2). There's no encoder, so the first step in a decoder-only architecture is to process the input prompt so as to “prefill” the internal embeddings with known data. It's very similar to having an encoder, but it's inside the decoder. The second phase that the decoder runs is then the “decoding” phase, which emits one token at a time.
What are linear and quadratic attention? These are statements about the efficiency, or lack thereof, of the “attention” phase of a Transformer. The vanilla 2017 Transformer had quadratic or O(n^2) complexity in the length of the input, which is slow for a long token prompt. Various modifications to the attention architecture in research papers aim to reduce this to linear O(n) complexity, and related optimizations such as Flash Attention make standard attention much faster in practice, mainly by reducing memory accesses rather than changing the asymptotic complexity.
What are pre-norm and post-norm? This refers to the placement of the normalization module relative to the attention and feed-forward sub-layers in a Transformer architecture. The original 2017 Transformer used post-norm, with normalization after the sub-layer outputs, but this architecture had an instability in training. The first GPT used this “post-norm” architecture. Various researchers subsequently confirmed that changing the Transformer architecture to “pre-norm”, with normalization applied to the sub-layer inputs, removed the instability and thereby allowed faster training. GPT-2 was subsequently released with a pre-norm architecture. Although the general view is that “pre-norm” is preferred, I'm still seeing some research papers that say the opposite, so this is somewhat unresolved.
What are BatchNorm and LayerNorm? These are normalization modules. BatchNorm came first and normalizes values across a batch of inputs, whereas LayerNorm normalizes across the elements of each vector within a layer, independently of the batch. LayerNorm is a little more complicated, but is broadly regarded as the better fit for Transformers.
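As a concrete illustration, here's a minimal LayerNorm sketch in C++: it normalizes a single vector to zero mean and unit variance. A full implementation would also apply a learned per-element scale (gamma) and bias (beta), which are omitted here.

    #include <cmath>
    #include <vector>

    // Minimal LayerNorm sketch: normalize one vector to zero mean and unit variance.
    // A full version would also apply learned per-element scale (gamma) and bias (beta).
    void layer_norm(std::vector<float>& v, float epsilon = 1e-5f) {
        if (v.empty()) return;

        float mean = 0.0f;
        for (float x : v) mean += x;
        mean /= v.size();

        float variance = 0.0f;
        for (float x : v) variance += (x - mean) * (x - mean);
        variance /= v.size();

        float denom = std::sqrt(variance + epsilon);
        for (float& x : v) x = (x - mean) / denom;
    }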
What's an igloo? Oh, you mean SwiGLU? That's a Swish function in a Gated Linear Unit (GLU). It's one of the many possible “activation functions” that you can choose. There's also RELU, GELU, leaky RELU, and a bunch more in research papers. See the chapter on activation functions, or just skip it, because I still have my doubts that these fancy functions are worth the effort.
What are autoregressive and non-autoregressive? The standard Transformer with the GPT architecture has an “autoregressive” decoding algorithm when it emits tokens. This means that it sends its own output back to itself (“auto”) and loops around again (“regressive”). The simplest decoding method is for the decoder to emit one token, and then it adds that new token onto the end of the input sequence, creating a longer “input sequence”, which is then processed again by the entire decoder stack to spit out the next one. In a word: sloooow. But very smart. Generally, non-autoregressive decoding algorithms, such as parallel decoding, will be faster than the default autoregressive mode, but possibly less accurate.
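Here's a sketch of that naive autoregressive loop in C++. It assumes two hypothetical helpers: one that runs the whole decoder stack over a token sequence and returns the next-token logits, and a greedy picker that chooses the highest-probability token.

    #include <vector>

    // Hypothetical helpers (not a real engine's API): run the full decoder stack
    // over a token sequence to get next-token logits, then pick greedily.
    std::vector<float> run_decoder_stack(const std::vector<int>& tokens);
    int greedy_pick(const std::vector<float>& logits);

    // Naive autoregressive decoding: append each new token and re-process the
    // whole (growing) sequence through the entire decoder stack. Simple, but slow.
    std::vector<int> generate(std::vector<int> tokens, int max_new_tokens, int end_token) {
        for (int i = 0; i < max_new_tokens; ++i) {
            std::vector<float> logits = run_decoder_stack(tokens);
            int next = greedy_pick(logits);
            if (next == end_token) break;
            tokens.push_back(next);  // "auto": feed our own output back in as input
        }
        return tokens;
    }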
What is overfitting? When you put on a jacket and your hands don't appear. No, wait. Overfitting is an obscure statistical concept that the approximation (i.e. the AI model) fits the data too well, is too specific, and cannot generalize its insight to newer data. Any further attempt to explain this will just get me into trouble, because overfitting is something that everyone sort-of understands, but no-one can explain properly. The way I think about it, which isn't fully accurate but is a useful approximation, is that an overfitting model has “too much” capability to predict with too much specificity. Overfitting also doesn't really mean that the model has too many parameters, and could have been just as smart with fewer weights. At the very least, overfitting is better than underfitting, which means the model can't predict much of anything.
What are linear and bilinear layers? The standard Transformer layer has a Feed-Forward Network (FFN) component that consists of two linear layers. The term “linear layer” is a fancy way of saying matrix multiplication (similarly “linear projection”), where a matrix of weights is multiplied against a vector of probabilities (embedding vector) to get an updated vector of probabilities (with extra geniusness added). The default Transformer FFNs do a linear layer twice, with an activation function applied on the vector as an extra step between them (usually RELU). A “bilinear layer” is an FFN that's lost its in-between activation function, so it just does two matrix multiplies. Bilinear layers are not normally used in a Transformer, although researchers have tried.
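In C++ terms, a “linear layer” is just a matrix-vector multiply (bias term omitted here for brevity). Here's a rough sketch of an FFN built from two of them with a RELU step in between; remove the RELU line and you have the “bilinear” version. The type and function names are only for illustration.

    #include <vector>

    using Vector = std::vector<float>;
    using Matrix = std::vector<Vector>;  // row-major: Matrix[row][col]

    // Linear layer: matrix-vector multiplication (bias omitted for brevity).
    Vector linear(const Matrix& W, const Vector& x) {
        Vector y(W.size(), 0.0f);
        for (size_t i = 0; i < W.size(); ++i)
            for (size_t j = 0; j < x.size(); ++j)
                y[i] += W[i][j] * x[j];
        return y;
    }

    // FFN sketch: two linear layers with a RELU activation in between.
    // Removing the RELU step turns this into a "bilinear" layer.
    Vector ffn(const Matrix& W1, const Matrix& W2, const Vector& x) {
        Vector hidden = linear(W1, x);
        for (float& v : hidden) if (v < 0.0f) v = 0.0f;  // RELU
        return linear(W2, hidden);
    }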
What is masking? Well, it's not bit masks, if that's what you're thinking. It refers to attention masks for tokens. In an encoder-decoder Transformer, the encoder is allowed not only to examine tokens backwards, but also to look ahead and see all of the later tokens in the sequence. However, the decoder is a naughty child that is only allowed to look backwards to the tokens it has already output. No copying allowed! This is done in the decoder's “attention” module by “masking” the lookahead tokens so that the decoder can't see them (no matter how hard it tries to peek). Encoders have a non-masked attention allowing lookahead, whereas decoders have a “masked attention” module only allowing look-backwards.
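Here's a minimal sketch of how causal (“masked”) attention is typically enforced in code: the lookahead entries of the attention score matrix are set to negative infinity before the softmax, so they end up with zero attention weight. The helper name is illustrative.

    #include <limits>
    #include <vector>

    // Causal attention mask sketch: scores[i][j] is how much token i attends to token j.
    // Setting the lookahead entries (j > i) to -infinity means they become 0 after softmax,
    // so each token can only "see" itself and earlier tokens.
    void apply_causal_mask(std::vector<std::vector<float>>& scores) {
        const float neg_inf = -std::numeric_limits<float>::infinity();
        for (size_t i = 0; i < scores.size(); ++i)
            for (size_t j = i + 1; j < scores[i].size(); ++j)
                scores[i][j] = neg_inf;
    }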
Advances in Transformer Architectures
Here are some of the notable research advances that have become commonplace in the industry.
Decoder-only architectures (e.g. GPT). As mentioned above, one of the major architectural improvements was to use decoder-only architectures, rather than encoder-decoder methods. Note that encoder-decoder models are still valuable for some use cases, such as foreign language translation (although GPT can also do this).
Pre-Norm beats Post-Norm: Probably the first improvement over the original 2017 vanilla Transformer architecture was to move normalization to apply on the inputs (“pre-norm”) rather than on the layer outputs (“post-norm”). This change sped up training because it solved a training instability problem in the original Transformer that had required a “warmup period” at the beginning of training. Modern commercial engines use pre-norm, but various research papers still continue to try post-norm with some success.
Quantization: This is so prevalent it hardly needs saying. Quantization changes weights from 32-bit floats to smaller data sizes, such as 16-bit floating-point or integers. There are numerous quantized versions of the major open source models available in the repos. Quantization to 8-bit integer, and even down to 4-bit integer, has become commonplace. This improves inference efficiency tremendously at the cost of a few percentage points of accuracy (perplexity).
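As a simple illustration, here's a sketch of symmetric 8-bit quantization of a block of FP32 weights using a single scale factor per block. Real quantization schemes are fancier (per-channel scales, grouped 4-bit formats, and so on), so treat this as the basic idea only.

    #include <algorithm>
    #include <cmath>
    #include <cstdint>
    #include <vector>

    // Simple symmetric INT8 quantization sketch: one scale factor per block of weights.
    // Dequantize later with: weight ~ q * scale.
    struct QuantizedBlock {
        std::vector<std::int8_t> q;  // quantized weights in [-127, +127]
        float scale;                 // per-block scale factor
    };

    QuantizedBlock quantize_int8(const std::vector<float>& weights) {
        float max_abs = 0.0f;
        for (float w : weights) max_abs = std::max(max_abs, std::fabs(w));
        QuantizedBlock out;
        out.scale = (max_abs > 0.0f) ? (max_abs / 127.0f) : 1.0f;
        out.q.reserve(weights.size());
        for (float w : weights)
            out.q.push_back(static_cast<std::int8_t>(std::lround(w / out.scale)));
        return out;
    }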
Pruning: Pruning is removal of small or less important weights to reduce model size. Various types of model pruning are widely supported in model frameworks, such as PyTorch or TensorFlow. Unstructured pruning means removing or zeroing any weights that are too small. Structured pruning means removing whole Transformer components, such as layer pruning. In all types of pruning, removing weights allows the engine to run lighter and faster, with a reasonable trade-off in model accuracy.
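Here's a sketch of the unstructured (magnitude) flavor in C++: any weight whose absolute value is below a threshold is zeroed. A real framework would then store the result in a sparse format, or fine-tune the model again, to actually realize the speedup.

    #include <cmath>
    #include <cstddef>
    #include <vector>

    // Unstructured magnitude pruning sketch: zero out any weight below the threshold.
    // Returns the number of weights pruned.
    std::size_t prune_small_weights(std::vector<float>& weights, float threshold) {
        std::size_t pruned = 0;
        for (float& w : weights) {
            if (std::fabs(w) < threshold) {
                w = 0.0f;
                ++pruned;
            }
        }
        return pruned;
    }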
Flash Attention: There have been many attempts to overcome the quadratic cost of attention on long contexts. Flash Attention, followed by Flash Attention 2, seems to have succeeded as the best in practice, mainly by reorganizing the attention computation to reduce memory accesses, and it is now being implemented on major engine platforms.
Rotary Positional Embeddings (RoPE): The RoPE method of adding positional encoding is becoming a standard way to (a) efficiently handle longer contexts, and (b) have models become better at understanding or generating long texts (“length generalization”).
KV Caching: There are various ways to do KV caching, and issues of how much to cache, but the general idea that a KV cache is required is widespread nowadays.
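The basic idea is simple enough to sketch. A KV cache keeps the per-layer K and V vectors already computed for earlier tokens, so each new token only appends its own K and V instead of recomputing the whole history. A hypothetical C++ data structure might look like this:

    #include <cstddef>
    #include <utility>
    #include <vector>

    using Vector = std::vector<float>;

    // Minimal KV cache sketch: for each layer, store the K and V vectors of every
    // token processed so far. Each new token appends one (K, V) pair per layer,
    // instead of recomputing them for the entire history.
    struct KVCache {
        // kv[layer] holds one (K, V) pair per token seen so far.
        std::vector<std::vector<std::pair<Vector, Vector>>> kv;

        explicit KVCache(int num_layers) : kv(num_layers) {}

        void append(int layer, const Vector& k, const Vector& v) {
            kv[layer].emplace_back(k, v);
        }

        std::size_t tokens_cached() const { return kv.empty() ? 0 : kv[0].size(); }
    };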
Flash Decoding: The decoding algorithm is the last part of the Decoder block, where it chooses the next output token. A new fast decoding algorithm from the team that created Flash Attention is now garnering some attention.
Model Loader
The first step in Transformer execution is loading the model. What does a model look like? At a basic level, it's simply a very large binary file containing mostly numeric data. The main things you'll find in a model file include:
- Header data with settings and hyper-parameters
- String data for tokens
- Lots and lots and lots of numbers.
Model Header. The start of the model file contains some header data and hyper-parameter values which define the “shape” of the model. For example, it will have the number of “layers” in the model (depth), and the size of the “hidden dimension” (width) and various other settings.
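As a sketch, the header hyper-parameters might map onto a C++ struct something like this; the field names are illustrative only, since every model file format defines its own layout.

    #include <cstdint>

    // Hypothetical model file header: the hyper-parameters that define the "shape"
    // of the model. Field names are illustrative; real formats differ.
    struct ModelHeader {
        std::uint32_t magic;           // file format identifier
        std::uint32_t version;         // format version number
        std::int32_t  vocab_size;      // number of tokens in the vocabulary (e.g. ~50,000)
        std::int32_t  hidden_dim;      // embedding/hidden dimension (model "width")
        std::int32_t  num_layers;      // number of decoder layers (model "depth")
        std::int32_t  num_heads;       // attention heads per layer
        std::int32_t  context_length;  // maximum sequence length
    };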
Billions of Numbers. Almost all of the model's size is taken up with floating-point numbers, because a 7B model will have literally 7 billion numbers, usually in 32-bit format (i.e. assuming it's an FP32 model). These numbers are further organized into sub-structures that represent “layers” and “tensors” and other fun stuff.
Numbers are Static. The first point about these numbers is that they don't change. Model data is static data for any pre-trained model. These are read-only numbers that have been pre-computed “offline” during training or fine-tuning. When you run a model doing “inference,” these numbers don't change. The whole big bang of a full cycle of the entire model, when it spits out one word, actually runs on static numbers. Only if you're doing fine-tuning of a model will the numbers change again. The exceptions are the various dynamic inference methods, which are mainly at the research level; the default inference of a model is static.
All Numbers Are Processed. Another point is that all of these numbers get used. You might read that model files have lots of “redundancy,” but that only means that many of these numbers are less important; they will all still be used for arithmetic, because it's hard to figure out which ones to discard. (It's like what they used to say about your advertising budget before, you know, cookies: half the money was wasted, but you didn't know which half.) An inference cycle will perform a floating-point operation on every single one of these billions of numbers. This is usually a multiply, but there are also additions, and these are all called “floating-point operations” or FLOPs. Since inference uses every single number in a model file and there are billions of numbers, there are GigaFLOPs of calculations just for the decoder to spit out one word (or a part-word or punctuation, or whatever token). And then the Transformer repeats all of that for the next word. Again, there are exceptions to this in advanced algorithms, such as “model pruning”, where some of the numbers are skipped.
String Data. The second type of model data is strings, and there's much less of it. Model files contain some string data for a few different purposes. There are a few descriptive strings that give names to things, and these strings are effectively overhead, since they're not part of the computation (e.g. they might appear in reports or be useful during debugging). The main string data in a model file is the “vocabulary” for the tokenizer. Typically, the model will have about 50,000 different strings that represent words, part-words, punctuation, and any fancy stuff (e.g. UTF8 codes for love heart emojis).
String Data is Static. Again, this string data is fixed at runtime. Actually, the string data is chosen right at the start when designing a model and can't even change during training! Hence, the strings that make up the tokenizer do not change during runtime inference of a model. The AI engine cannot learn a new word and add it to the vocabulary. So, whatever token set was set up in the vocabulary of the model before and during training has to remain the same during inference.
Load Order Matters. The order of the string tokens also matters in the model file. The inference engine treats tokens as numbers, using the offset of the string in the vocabulary array. If you mess up the model file loading of its vocabulary so that it's out-of-order or missing a few words, then the AI engine is going to get very confused and output gibberish.
Handling Unknown Words. A fixed vocabulary doesn't mean the AI engine falls over on unrecognized text. The tokenizer instead uses some default token strategies to handle unknown words or symbols. Individual digits are used as tokens for numbers, because it's difficult to encode every single number up to infinity as a separate token. Words that are not recognized are tokenized using part-words or, in the worst case, with individual letters. Unusual symbols, like emoji codes, are also tokenized to UTF8 single-byte tokens.
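Putting those last two points together, here's a toy sketch of a vocabulary where the token ID is simply the string's index in the array, plus a single-byte fallback for unrecognized text. Real tokenizers use longest-match or BPE merges over a vocabulary of around 50,000 entries; the byte-fallback layout here (bytes in the first 256 slots) is an assumption for illustration only.

    #include <string>
    #include <unordered_map>
    #include <vector>

    // Toy tokenizer sketch: the token ID is just the string's index in the vocabulary
    // array, so load order matters. Unknown text falls back to single-byte tokens.
    struct ToyVocab {
        std::vector<std::string> tokens;              // index == token ID
        std::unordered_map<std::string, int> lookup;  // string -> token ID

        void add(const std::string& s) {
            lookup[s] = static_cast<int>(tokens.size());
            tokens.push_back(s);
        }

        // Return the token ID for a word, or -1 if it needs byte-level fallback.
        int find(const std::string& word) const {
            auto it = lookup.find(word);
            return (it == lookup.end()) ? -1 : it->second;
        }

        // Byte-level fallback for unknown words: emit one token per byte,
        // assuming (for this sketch) the first 256 vocabulary entries are single-byte tokens.
        std::vector<int> fallback_bytes(const std::string& word) const {
            std::vector<int> ids;
            for (unsigned char c : word) ids.push_back(static_cast<int>(c));
            return ids;
        }
    };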
Engine Initialization. Since the numbers and strings are static, the model loader doesn't need to do anything to this data, other than to store the strings into a tokenizer module, and organize the numbers into tensors and layers. But that is kind of a lot of coding work anyway!
Other parts of the initialization involve getting ready to run a fast inference or training procedure. For example, the lookup tables to optimize the various non-linear (expensive) activation functions could be computed at program startup, although these really should be precomputed offline for a production model.
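For example, here's a sketch of a precomputed lookup table for an expensive activation function (GELU in this case), sampled over a fixed input range with inputs clamped to that range. A production version would tune the range, table size, and interpolation, so treat this as the idea rather than a recipe.

    #include <cmath>
    #include <vector>

    // Sketch of a precomputed lookup table for an expensive activation function (GELU).
    // The table samples the function over [-8, +8]; inputs are clamped to that range.
    class GeluTable {
    public:
        explicit GeluTable(int size = 4096) : table_(size) {
            for (int i = 0; i < size; ++i) {
                float x = kMin + (kMax - kMin) * i / (size - 1);
                table_[i] = gelu_exact(x);  // expensive call done once, at startup
            }
        }
        float lookup(float x) const {
            if (x <= kMin) return table_.front();
            if (x >= kMax) return table_.back();
            int i = static_cast<int>((x - kMin) / (kMax - kMin) * (table_.size() - 1));
            return table_[i];  // cheap table lookup at inference time
        }
    private:
        static constexpr float kMin = -8.0f;
        static constexpr float kMax = 8.0f;
        static float gelu_exact(float x) {
            return 0.5f * x * (1.0f + std::erf(x * 0.70710678f));  // exact GELU via erf
        }
        std::vector<float> table_;
    };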
Note that in production deployment of an AI engine, this initialization cost isn't very important. A server should handle lots of queries, whereas this initialization occurs once, so any initialization time cost is amortized over many server queries. Even so, as any longtime Windows user knows, it's annoying if anything starts up slow.
Where is the magic? If all of the numbers are static, and the shape of the model is fixed and finite, how is it so smart? Moreover, how is it creative? I mean, it sounds like a robotic piece of number-crunching code. Yes, indeedy. It can neither feel bad nor taste dessert. The first part of the explanation of an LLM's abilities is that hyperscale brute-force simply works. Having a huge enough model of billions of weights mapping word probabilities is amazingly good at predicting the top 50 words that I should end this sentence with. That part is deterministic and also smart enough not to choose a preposition. The second part is “intentional randomness” introduced into this algorithm, mostly in the final “decoding algorithm” that chooses which of the highest-probability 50 words to pick. Or select. Or choose. Or culminate.