Aussie AI
29. Caching Optimizations
Book Excerpt from "Generative AI in C++"
by David Spuler, Ph.D.
“The modern world is not geared properly to the storage of goods.”
— Benjamin Graham.
Caching is the general optimization method where computed results are stored and re-used instead of repeating a later computation. Generally, the idea is to trade off use of extra memory in order to save on execution time. This mainly works if the same exact computations are being repeated, but can also work for repetitions of similar near-identical computations. In the research literature, caching algorithms for neural networks are also called “memoization,” “data re-use” or “computation re-use” algorithms.
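To make the general idea concrete, here is a minimal memoization sketch in C++ (the expensive_compute function is a hypothetical placeholder): a small map trades extra memory for avoided recomputation of previously seen inputs.

    #include <unordered_map>

    double expensive_compute(int x);   // hypothetical costly function

    double memoized_compute(int x) {
        static std::unordered_map<int, double> cache;
        auto it = cache.find(x);
        if (it != cache.end()) return it->second;   // cache hit: re-use stored result
        double result = expensive_compute(x);
        cache.emplace(x, result);                   // cache miss: compute and store
        return result;
    }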
There are at least seven caching optimizations known for Transformers:
- KV caching
- Encoder/prefill KV caching
- Inference cache
- Semantic cache
- Vector dot product caching
- Input similarity caching
- Cached matrix transpose
KV caching is the best known of these optimizations, and relates to the K and V tensors used in the QKV attention mechanism. It was discovered quickly that some of the K and V tensor calculations could be cached between tokens, thereby avoiding repeated matrix computations in the usual autoregressive model. This is only temporary caching used while processing a single query, rather than across multiple user queries.
Caching can also be done at the highest level, with model inference answers stored in a global cache. Inference results can be cached across multiple queries from multiple users, so that repeated identical queries need not be re-computed. When the entire results of an inference calculation are saved and re-used, the optimization is called an “Inference Cache.”
Vectorized caching is also possible for non-exact cache matches in at least two ways. Semantic caching with vector hashing or vector databases can help identify user queries that are non-identical, but have the same meaning, and need not be fully computed. Incremental caching of full inference results can be used with “input similarity” algorithms, such as when analyzing the individual frames of a video in an AI engine. Optimizations such as frame skipping or partial image caching are possible.
Some research papers have attempted to use caching deep inside the engine to reduce the sheer number of weight multiplications in matrix computations. Low-level caching and computation reuse can be done even at the vector dot product level. By detecting when similar vectors have been calculated before, such as using Locality-Sensitive Hashing (LSH), the cache results of the dot product calculation can be accessed and re-used instead. This approach does not seem to have reached production usage yet.
Before introducing the various caching methods, one proviso: not all queries can be cached. Any time-dependent queries cannot be cached (over a long duration), either in terms of the text outputs or the KV calculations, because they differ over time. Consider the response to this user query: “What day is today?”
KV Caching
KV caching is storing the results of the K and V vector operations that are performed in Transformer attention heads. Analysis of the vanilla Transformer by researchers has discovered at least two distinct ways to cache these results.
- Autoregressive KV caching
- Global encoder/prefill KV caching
Autoregressive decoder KV caching: This is in-memory caching during one query as the decoder processes multiple tokens. Partial KV tensor operations can be cached in memory during decoding, across the processing of multiple tokens, avoiding re-computations in the next cycle of decoder stacks. In autoregressive decoder mode, the extra KV computations related to the new token are not cached, but all prior KV-related calculations can be cached. This is a subtype of autoregression optimization.
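As a rough illustration (not the code of any particular engine), here is a minimal C++ sketch of a per-layer KV cache for autoregressive decoding, with hypothetical flattened buffers: each step appends only the new token’s K and V projections, and everything earlier is re-used.

    #include <vector>

    // One KV cache per layer: K and V entries for all tokens processed so far.
    struct LayerKVCache {
        std::vector<float> keys;     // flattened K vectors, one block per token
        std::vector<float> values;   // flattened V vectors, one block per token
        int num_tokens = 0;

        // Append the K/V projections computed for the newest token only;
        // earlier tokens' cached K/V entries are re-used, not recomputed.
        void append(const float* k, const float* v, int per_token_size) {
            keys.insert(keys.end(), k, k + per_token_size);
            values.insert(values.end(), v, v + per_token_size);
            ++num_tokens;
        }
    };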
Uncaching KV: Care must be taken in special cases with KV caching to keep the cache accurate and updated. This is particularly true in algorithms that “look ahead” but must sometimes “backtrack” to a prior token. Caching is efficient when moving forwards, but some of the cached items must be flushed and the cache recalculated whenever there is a token rejected. For example, this occurs in speculative decoding, parallel decoding, beam search decoding, and various other non-autoregressive decoding algorithms. It may also occur in various research algorithms such as token pruning, token merging, and prompt compression.
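A sketch of the “uncaching” step under the same assumptions (flat per-layer K and V buffers): when lookahead tokens are rejected, the cache is simply truncated back to the last accepted token position.

    #include <cstddef>
    #include <vector>

    // Truncate cached K/V entries back to the last accepted token position,
    // e.g. after speculative or beam-search decoding rejects lookahead tokens.
    void rollback_kv(std::vector<float>& keys, std::vector<float>& values,
                     int keep_tokens, int per_token_size) {
        const std::size_t keep =
            static_cast<std::size_t>(keep_tokens) * per_token_size;
        keys.resize(keep);
        values.resize(keep);
    }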
Global KV Prefill/Encoder Caching
Basic KV caching stores the values of K and V across multiple token processing phases, but only within the one query. It is a form of temporary local cache. This helps with autoregression complexity in long sequences, but won't be stored between queries, and there is no global cache used across multiple queries. At the other extreme is a full “inference cache” that stores the results for identical queries in a global cache of all prior answers, and may completely avoid the inference expense for any cached answers that are used.
In between these two approaches is the encoder/prefill KV caching method. This is on-disk caching of the prefill/encoder KV results across multiple user queries. For an encoder-decoder architecture, this stores the K and V results after the encoder has finished. For a decoder-only architecture, the KV results after the encoder-like “prefill” phase are stored.
This idea avoids the expense of running the encoder or the prefill phase, but the full decoder stack is still executed. Hence, it is a partial caching of the inference algorithms, and significantly different answers can result from the randomness inherent to the various decoding algorithms (e.g. top-k decoding).
The KV operations can be cached for identical queries, across many users, so that when a user inputs the same text, the KV operations do not have to be re-done, but can be loaded from a disk cache. If there are no cached KV results detected, the full encoder/prefill is performed without caching, and its results can be added to the cache.
The simplest approach is to cache KV results for exactly identical queries. Also possible is to extend this idea to a “semantic cache” with vector caching, which caches the encoder/prefill KV results for any “close enough” queries. This differs from the full semantic cache, because only the encoder/prefill data is cached, rather than the resulting logits or the full answer text.
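Here is a minimal sketch of the exact-match case, assuming hypothetical PrefillKV and run_prefill placeholders: the prompt text is the cache key, a hit skips the prefill phase entirely, and a miss falls through to the full computation (whose results are then cached).

    #include <string>
    #include <unordered_map>
    #include <vector>

    struct PrefillKV {
        std::vector<float> keys;     // K tensors saved after the prefill/encoder phase
        std::vector<float> values;   // V tensors saved after the prefill/encoder phase
    };

    PrefillKV run_prefill(const std::string& prompt);   // full prefill computation (hypothetical)

    static std::unordered_map<std::string, PrefillKV> g_prefill_cache;

    const PrefillKV& get_prefill_kv(const std::string& prompt) {
        auto it = g_prefill_cache.find(prompt);
        if (it != g_prefill_cache.end())
            return it->second;                          // cache hit: skip the prefill phase
        // Cache miss: run the full prefill and remember its KV results.
        auto inserted = g_prefill_cache.emplace(prompt, run_prefill(prompt));
        return inserted.first->second;
    }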
Inference Cache
A full inference cache is where the entire results of a model inference are stored, and re-used for a later identical query. For example, such an approach would recognize that 100 users are all submitting “This is a test,” whether concurrently or over time, and would do the inference computation only the first time, and retrieve it from the cache for the other 99 users.
Inference caching could involve storing the actual identical results in text form, in which case all users would get exactly the same response. Alternatively, a more flexible approach that still avoids most computations is storing the near-final results, in some intermediate form with logits (probabilities), and a final brief computation can still emit different results to different users. In this way, most of the computation is avoided, and some variability is added to the final output. Another simpler way to add variability to responses would be to cache more than one possible answer for a given input.
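A simple sketch of a whole-answer cache along these lines (illustrative names only), storing several candidate answers per query and returning one at random to add variability:

    #include <cstdlib>
    #include <string>
    #include <unordered_map>
    #include <vector>

    static std::unordered_map<std::string, std::vector<std::string>> g_answer_cache;

    // Returns true and fills 'answer' on a cache hit; false means the caller
    // must run full inference (and can then store the new answer in the cache).
    bool lookup_cached_answer(const std::string& query, std::string& answer) {
        auto it = g_answer_cache.find(query);
        if (it == g_answer_cache.end() || it->second.empty())
            return false;                                       // cache miss
        answer = it->second[std::rand() % it->second.size()];   // random cached variant
        return true;
    }

    void add_cached_answer(const std::string& query, const std::string& answer) {
        g_answer_cache[query].push_back(answer);
    }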
Caching logit arrays fails. Here’s the idea: for every token, cache the array of logits (or their probabilities after exponentiation via Softmax), or rather, just cache the top-k logits for each output token. With k=50, this is 50 times more space usage than just the token list, but it also holds more useful information. Then, rather than emitting the exact same token list for a cached user query, we can re-run the top-k decoding algorithm to get a different random choice from the cached top-50 tokens.
Unfortunately, it doesn't work very well, because if you change one of the words early in a sequence (i.e. choose a different token with a random top-k), then the whole sentence should change. Changing one choice of token should alter the words that would be in the top-k for all subsequent tokens, but they probably won't be the ones we've cached, and the probabilities would be wrong even if we did happen to cache them.
Another use case for a full inference cache is where the input is similar or continuous. This is typically the case in image processing for machine vision (e.g. self-driving cars) or video analysis (e.g. security camera monitoring). There are many frames per second and they are often not very different. In such cases, the cached inference results from the previous image or frame can often be re-used with modifications for a faster incremental algorithm, or alternatively, the entire results from the previous frame can simply replace the current inference computation (i.e. “frame skipping”).
Semantic Caching and Vector Databases
Semantic caching refers to a partial inference cache that finds cached responses not only to identical queries, but also to queries that are “semantically similar” to a cached query. For example, these queries have different token sequences, but the same meaning:
- “What is a dog”
- “What is a dog?”
- “What is dog?”
- “Dog what is it”
Immediately, we can think of various heuristics that could detect very similar queries and probe a cache. For example, the cache lookup could normalize question-mark punctuation and reordered words in a query. These ideas could improve speed a little, but such heuristics won’t generalize to the semantic meaning (e.g. synonyms or other equivalent phrasings).
Vector hashing. The generalization of these heuristics is to use “vector hashing” to find semantically similar queries. The idea is that we first create a vector from the query, which could be the token vector, although the embeddings vector is likely to be more effective. Then we can use vector hashing to find the “nearest neighbor” of that vector, in N-dimensional space, among the vectors stored in our cache. Returning the cached results avoids any further computation on the query.
That sounds good, but it glosses over something important: cache misses. For example, if our cache has seen only “What is a cat?” and this is returned as the nearest-neighbor vector for the query “What is a dog?”, then the answer won’t be very accurate. What’s missing is a measure of “closeness” of the two vectors, so that a poorly matching cached vector can be rejected and a full inference cycle executed (with its results then added to the cache).
Semantic cache lookup needs to have both cache hits and misses, like any normal caching algorithm. The semantic cache needs to make sure that the two vectors are similar enough (i.e. a “cache hit”), and this requires a measure of closeness between the query vector and the cached vector.
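As a sketch of these hit/miss mechanics, here is a brute-force semantic cache lookup in C++ using cosine similarity and a closeness threshold; a production system would use LSH or a vector database instead of a linear scan, and the threshold value shown is only an illustrative assumption.

    #include <cmath>
    #include <cstddef>
    #include <string>
    #include <vector>

    struct CachedEntry {
        std::vector<float> embedding;  // embedding vector of the cached query
        std::string response;          // cached answer text
    };

    float cosine_similarity(const std::vector<float>& a, const std::vector<float>& b) {
        float dot = 0.0f, na = 0.0f, nb = 0.0f;
        for (std::size_t i = 0; i < a.size(); ++i) {
            dot += a[i] * b[i];
            na += a[i] * a[i];
            nb += b[i] * b[i];
        }
        return dot / (std::sqrt(na) * std::sqrt(nb) + 1e-8f);
    }

    // Returns the best cached entry, or nullptr for a cache miss.
    const CachedEntry* semantic_lookup(const std::vector<CachedEntry>& cache,
                                       const std::vector<float>& query_embedding,
                                       float min_similarity /* e.g. 0.95f */) {
        const CachedEntry* best = nullptr;
        float best_sim = min_similarity;
        for (const auto& entry : cache) {
            float sim = cosine_similarity(entry.embedding, query_embedding);
            if (sim >= best_sim) {
                best_sim = sim;
                best = &entry;  // close enough to count as a cache hit
            }
        }
        return best;
    }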
Vector databases. To implement our vector hashing capability for the semantic cache we can use Locality Sensitive Hashing (LSH) or some other algorithms. It's just a small matter of coding. Alternatively, there are “vector databases” available that have already implemented this functionality. Vector databases have been in use for years in various semi-AI functionality such as semantic document indexing and image similarity analysis. For example, open source and commercial vector databases include Pinecone, Weaviate, Milvus/Zilliz, Chroma, FAISS, Vespa, Qdrant, and Vald, to name a few.
Note that semantic caching with a vector database is technically a type of approximation. There is a trade-off in setting how close two vectors must be before the cache is used. If the match criterion is too lenient, then some cached answers will be wrong for the query. If it is too strict, then the cache will miss often, and there will be the expense of computing additional inference queries.
Cached or Precomputed Transpose
Matrix multiplications can be much faster if they operate on the transpose, because this has the columns stored in sequential memory addresses. Our MatMul/GEMM kernel is much faster if it can send sequential blocks of data to the GPU, and it's also faster for CPU-only versions of matrix multiplication because of data locality benefits that speed up memory accesses.
The value of using a transpose of a matrix is so significant that we can calculate the transpose on the fly if we need it. Creating a transpose is O(N^2) and it speeds up the O(N^3) MatMul operation, so the extra cost is well worth it. Then we can further optimize by storing the transpose in a cache for next time.
On the other hand, why not precompute the transpose? If it’s the transpose of a weight matrix, then it’s known at compile-time (i.e. pre-inference time), and we could fully precompute it and store it with the rest of the model. Thus, a significant way to optimize matrix multiplications is to store both versions of the matrix in memory: the original matrix and its transpose. This can help to speed up inference by:
(a) avoiding the need to compute the transpose on-the-fly, and
(b) having the transpose already laid out in contiguous memory for pipelining and dataflow efficiency.
Note that this transpose caching method doesn't work as well for training or fine-tuning, because the weights in the matrix change often, and our cached transpose would become out-of-date. Further details of using a precomputed transpose in MatMul are covered in Chapter 34.
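As a rough sketch of this idea (square N×N matrices assumed for brevity), the transpose can be computed once at model-load time and stored alongside the original weights:

    #include <cstddef>
    #include <vector>

    // Compute the transpose so that MatMul can read the second matrix's
    // "columns" as contiguous rows (better locality for CPU and GPU kernels).
    void transpose(const std::vector<float>& m, std::vector<float>& mt, int n) {
        mt.resize(static_cast<std::size_t>(n) * n);
        for (int i = 0; i < n; ++i)
            for (int j = 0; j < n; ++j)
                mt[static_cast<std::size_t>(j) * n + i] =
                    m[static_cast<std::size_t>(i) * n + j];
    }

    // Store both layouts with the model weights, computed once at load time,
    // so the transpose is never recomputed during inference.
    struct WeightMatrix {
        std::vector<float> data;        // original row-major weights
        std::vector<float> transposed;  // cached/precomputed transpose
        int n = 0;

        void build_transpose_cache() { transpose(data, transposed, n); }
    };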
Vector Dot Product Computation Reuse
A lot of the low-level computations of tensor multiplications in AI inference break down into a vector dot product computation (also called a “scalar product”). This is a multiplication-addition (multiply-accumulate or MAC) of all of the paired elements of two vectors to create a single number, usually a 32-bit floating-point result. Hence, it's a good candidate for caching with computation reuse since it's a large amount of arithmetic, and the result is only a single floating-point number, which won't need much extra space to store. The trade-off is attractive, with a small amount of extra storage used to avoid significant computations.
Another interesting feature is that one of those vectors is static during inference (e.g. the weights vector), whereas the other vector operand is dynamic (activations). Hence, the idea with vector computation reuse is to cache the computed dot product results and then detect when the second, incoming (dynamic) vector is the same as, or similar enough to, a previous incoming vector.
An alternative approach is to combine each pair of vectors and use the combination as the cached vector. The two vectors are simply concatenated and treated as a single vector of double length, with the vector caching methods applied to the longer vector.
Various researchers have looked into this type of vector caching. The main methods to detect similar vectors are:
- Locality-Sensitive Hashing (LSH)
- Bit signatures
- K-means clustering
- Hyper-Cube
LSH is the most popular method; it uses cosine similarity to find reasonably “close” vectors in N-dimensional space. If the vectors are similar enough, the cached result is a reasonable approximation of the dot product computation, which is thereby avoided. If no cached vector is close enough, then the full dot product must be performed, and its result can be added to the vector cache.
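As a toy illustration of the hit/miss logic (not a full LSH implementation), this sketch uses a simple sign-bit signature of the dynamic vector as the cache key for each static weight vector; matching signatures re-use a cached dot product, which is an approximation since different vectors can share a signature.

    #include <cstddef>
    #include <cstdint>
    #include <unordered_map>
    #include <vector>

    // Crude stand-in for LSH: one sign bit per element (up to 64 elements).
    std::uint64_t sign_signature(const std::vector<float>& v) {
        std::uint64_t sig = 0;
        for (std::size_t i = 0; i < v.size() && i < 64; ++i)
            if (v[i] > 0.0f) sig |= (std::uint64_t{1} << i);
        return sig;
    }

    float dot_product(const std::vector<float>& a, const std::vector<float>& b) {
        float sum = 0.0f;
        for (std::size_t i = 0; i < a.size(); ++i) sum += a[i] * b[i];
        return sum;
    }

    // One cache per static weight vector, keyed by the dynamic vector's signature.
    float cached_dot_product(const std::vector<float>& weights,
                             const std::vector<float>& activations,
                             std::unordered_map<std::uint64_t, float>& cache) {
        const std::uint64_t sig = sign_signature(activations);
        auto it = cache.find(sig);
        if (it != cache.end()) return it->second;   // cache hit: re-use prior result
        float result = dot_product(weights, activations);
        cache.emplace(sig, result);                 // cache miss: compute and store
        return result;
    }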
The cost of looking up the cache must be low for this method to be effective, since the computation is being done in the busiest part of the AI algorithm: vector dot products inside matrix multiplications. Hence, this method typically uses hand-coded in-memory vector caching methods rather than vector databases. Theoretically, an in-memory highly tuned vector database could also do the job.
Assuming similar vectors can be identified efficiently, the question is: how often does AI model inference perform vector computations on similar vectors? What is the cache hit rate? At the start, the vector cache is almost empty, and it takes a while for there to be enough cached vectors to achieve a high hit rate. But after this “warm-up” period, research papers suggest the hit rate can be quite high, with some reporting a 50% inference speedup over time.
Input Similarity-Based Caching
When an input is similar enough to a prior input, the previous inference results can be cached and re-used. This is applicable to analysis of continual feeds, such as audio or video frames, where the incremental differences are relatively small. This is a type of incremental algorithm for neural network inference.
The overall idea is to detect situations where the input does not need to be processed, because it is similar enough to the previous input. There does not need to be a large cache of previously seen images; that can be done too, but it’s a different algorithm (i.e. the “Inference Cache” idea). For the input similarity approach, only the results from the previous frame are needed. If the previous frame’s results are close enough, the new frame can be “skipped” and the prior results retained.
The choice is basically whether to re-use the inference results from the prior video frame, or to re-compute a new set of results. Potentially, the same results can be re-used for multiple frames, if there are minimal changes, but eventually a new computation will be required.
Input similarity could be checked using vector hashing or vector databases. However, more commonly for images, there are non-vector methods to detect when images only have minimal changes between them.
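Here is a hedged sketch of frame skipping, using a crude mean-absolute-difference check between consecutive frames (InferenceResult, run_inference, and the threshold value are hypothetical placeholders): if the new frame is close enough to the previous one, the cached results are simply re-used.

    #include <cmath>
    #include <cstddef>
    #include <vector>

    struct InferenceResult { /* e.g. detected objects, labels, scores */ };

    InferenceResult run_inference(const std::vector<float>& frame);  // full model run (hypothetical)

    // Mean absolute pixel difference as a crude similarity measure.
    float frame_difference(const std::vector<float>& a, const std::vector<float>& b) {
        float diff = 0.0f;
        for (std::size_t i = 0; i < a.size(); ++i) diff += std::fabs(a[i] - b[i]);
        return diff / static_cast<float>(a.size());
    }

    // Re-use the previous frame's results ("frame skipping") if the new frame
    // differs by less than the threshold; otherwise recompute and update.
    InferenceResult process_frame(const std::vector<float>& frame,
                                  std::vector<float>& prev_frame,
                                  InferenceResult& prev_result,
                                  float threshold /* e.g. 0.01f, illustrative */) {
        if (!prev_frame.empty() && frame_difference(frame, prev_frame) < threshold) {
            return prev_result;              // input barely changed: re-use cached results
        }
        prev_result = run_inference(frame);  // input changed enough: full recomputation
        prev_frame = frame;
        return prev_result;
    }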