Aussie AI Blog
RAG Optimization via Caching
-
September 26, 2024
-
by David Spuler, Ph.D.
Types of RAG Speed Optimizations
The latency of an LLM system that uses Retrieval Augmented Generation (RAG) can be reduced in various ways. The main approaches are:
- RAG component optimization
- General LLM inference optimizations
- Text-to-text caching ("inference cache")
- Global KV caching methods (general)
- RAG-specific KV caching
RAG Component Optimization
Every component in a RAG application architecture is critical to the overall latency as seen by the user. Hence, we can look at optimizing any of the main components:
- Retriever — use any of the well-known vector database optimizations, such as indexing and caching.
- Serving stack — general web architectural optimizations.
- LLM optimizations — various well-known industry and research approaches.
This section mainly deals with how to optimize the LLM inference portion of the RAG stack.
RAG LLM Optimization
Optimizing LLM inference is a well-known problem, with literally thousands of research papers, and here's my list of 500 inference optimization techniques.
The first point about optimizing the LLM in a RAG application is that most of the general LLM optimization ideas apply. You're probably familiar with the various techniques:
- Buy a better GPU
- Quantization
- Pruning
- Small Language Models (SLMs)
And there are various inference software improvements that apply to a RAG LLM just as well as they apply to any LLM:
- Attention algorithm optimizations — e.g. Flash attention, paged attention.
- Speculative decoding (parallel decoding)
I'm not going to list all 500 of them.
General Text-to-Text Caching
Any LLM architecture can put a text-to-text cache in front of the query processor. Whenever the user types in the same query, we can output an answer that we've previously cached. There are two main types:
- Inference cache — exact text-to-text mapping (identical queries).
- Semantic cache — detects similar-enough queries with the same meaning.
Text-to-text caching works for RAG applications and is perhaps even more effective than for other LLM applications, because a text-to-text cache removes the need to launch the RAG retriever. We don't need chunks of text if we already have stored the answer.
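As a concrete illustration, here's a minimal Python sketch of a text-to-text cache sitting in front of the RAG pipeline, covering both cache types. The embed() and run_rag_pipeline() functions are hypothetical stand-ins for an embedding model and the full retriever-plus-LLM pipeline.

import hashlib
import numpy as np

exact_cache = {}      # query hash -> cached answer text (inference cache)
semantic_cache = []   # list of (query embedding, answer) pairs (semantic cache)

def cached_answer(query, similarity_threshold=0.95):
    # 1. Exact-match inference cache: identical query text.
    key = hashlib.sha256(query.encode("utf-8")).hexdigest()
    if key in exact_cache:
        return exact_cache[key]

    # 2. Semantic cache: a similar-enough query with the same meaning.
    q_vec = embed(query)  # hypothetical embedding call
    for vec, answer in semantic_cache:
        cos = np.dot(q_vec, vec) / (np.linalg.norm(q_vec) * np.linalg.norm(vec))
        if cos >= similarity_threshold:
            return answer

    # 3. Cache miss: run the full RAG pipeline (retriever + LLM) and store the result.
    answer = run_rag_pipeline(query)  # hypothetical full RAG call
    exact_cache[key] = answer
    semantic_cache.append((q_vec, answer))
    return answer

On a cache hit, neither the retriever nor the LLM runs at all, which is why this approach is so effective for RAG.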
Global KV Caching
Instead of text-to-text caching, we can go halfway and cache the internal computations of the LLM, known as the "KV cache" data. The basic type of KV caching happens inside every query, building up one token at a time during decoding. But we can also store the whole KV cache data in a "global KV cache." We can do this for an "inference cache" of exact queries, but it doesn't work for semantic caching, because the number of tokens is usually different (i.e., the user query used slightly different words), and KV cache data is very specific to a token sequence.
The idea is that for a known query, such as one in an inference cache, a huge part of the computation is already stored in this internal data. Most of the tokens have already been processed, so there's no need to do any of the "prefill phase" that usually delays the LLM's first token. Only the decoding phase remains, and the LLM can output its first token very quickly, giving good user responsiveness.
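Here's a minimal Python sketch of a global KV cache keyed on the exact token sequence. The model.prefill() and model.decode() calls are hypothetical stand-ins for an inference engine's prefill and decoding phases.

global_kv_cache = {}   # token-sequence key -> stored KV cache data

def generate_with_global_kv_cache(model, tokenizer, prompt):
    tokens = tokenizer.encode(prompt)
    key = tuple(tokens)   # KV data is specific to this exact token sequence

    kv_cache = global_kv_cache.get(key)
    if kv_cache is None:
        # First time we see this sequence: pay the full prefill cost, then store the KV data.
        kv_cache = model.prefill(tokens)
        global_kv_cache[key] = kv_cache

    # Only the decoding phase remains; time-to-first-token is much lower on a cache hit.
    return model.decode(tokens, kv_cache)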
Global KV caching works especially well for RAG applications because:
- Input data is large — e.g. chunks of text
- Queries are small — usually questions of a few words.
- Output is smaller than the input (it's usually a summary extracted from a chunk).
Hence, this type of global KV caching works in any LLM and is a good candidate for RAG caching.
RAG-Specific Caching Optimizations
But what can we only do in a RAG architecture? How can we make the LLM run faster when we're looking up most of the answer as text chunks in a vector database? The main answers are:
- Prefix KV caching
- Precomputed RAG KV cache
Prefix KV caching is a slight generalization of global KV caching, in that it will cache any prefix of tokens that it's seen before. If you think about it, there are a lot of common prefixes in a RAG architecture:
- Global prompt instructions (prepended)
- RAG chunks (prepended)
- User query (a few words)
We have a common prefix every time, consisting of the global instructions plus the first RAG chunk. The common prefix extends across more than one RAG chunk if our vector database is deterministic and returns the chunks in the same order.
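Here's a minimal Python sketch of the prefix lookup, assuming a hypothetical model.prefill_from() call that extends an existing KV cache with extra tokens. Production implementations typically use a radix tree or hashed cache blocks rather than this linear scan, but the idea is the same: reuse the KV data for the longest previously-seen token prefix and prefill only the new suffix.

prefix_kv_cache = {}   # token-prefix tuple -> KV cache data

def prefill_with_prefix_cache(model, tokens):
    # Find the longest cached prefix of this token sequence.
    best_len, best_kv = 0, None
    for prefix, kv in prefix_kv_cache.items():
        n = len(prefix)
        if n > best_len and tuple(tokens[:n]) == prefix:
            best_len, best_kv = n, kv

    # Prefill only the tokens after the cached prefix (all of them on a miss).
    kv_cache = model.prefill_from(best_kv, tokens[best_len:])  # hypothetical engine call

    # Store the full sequence so it can serve as a reusable prefix next time.
    prefix_kv_cache[tuple(tokens)] = kv_cache
    return kv_cache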
Prefix KV caching is a very powerful technique, and it's been implemented in various commercial inference frameworks and platforms, such as vLLM, DeepSeek, and Anthropic. It's so effective that some LLM API providers offer "cached tokens" at a much lower price point.
Precomputed RAG Cache
The main problem with all of the above caching techniques is that they require a user's query to have been seen already. We don't have a cache entry the first time.
So why not precompute it?
In order to precompute our cache, we could try to guess all of our user's likely queries. Then we can pre-populate the cache by running all those candidate user queries.
In fact, that would work for any of the above caching methods, but it's somewhat hard to know all the variations of words that users might choose. And remember that semantic caching doesn't work for KV data, so we can't use vector embeddings to help with a KV cache.
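For completeness, here's a tiny Python sketch of warming the cache with guessed queries, reusing the hypothetical cached_answer() helper from the text-to-text caching sketch above. The example queries are made up.

anticipated_queries = [
    "What is your refund policy?",
    "How do I reset my password?",
    "What are your support hours?",
]

def warm_cache(queries):
    # Run each candidate query once so its answer is cached
    # before any real user asks it.
    for q in queries:
        cached_answer(q)

warm_cache(anticipated_queries)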
A better idea is to precompute our cache for each RAG chunk.
Precomputed RAG Chunk KV Cache
The idea here is to store the KV cache data in the vector database. For every text chunk, we have a precomputed KV cache. When the retriever returns a chunk, or more than one chunk, it also returns any KV cache data that it has for that chunk.
Note that if we don't want to precompute for all of the text chunks, we can precompute for a subset of the most frequently used chunks, and then add more KV cache data to the RAG cache as other chunks appear in user queries.
However, our RAG chunk is not actually the prefix for the LLM if there's a prepended set of global instructions. The trick is to precompute the KV cache data for the combined query: global instructions plus the RAG chunk.
One consequence is that if we change the global instructions, we have to discard any previously cached KV data. Similarly, this doesn't work if there are per-user global instructions, because we don't want to store a wholly different cache for every user!
If this approach works, the only part of the query that isn't precomputed is the user's query. Usually, it's only a few words of a question.
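Here's a minimal Python sketch of the indexing side, precomputing the KV cache for the combined prefix (global instructions plus chunk) and storing it alongside the chunk in the vector database. The model.prefill(), embed(), and vector_db.add() calls are hypothetical stand-ins for an inference engine, an embedding model, and a vector database client.

GLOBAL_INSTRUCTIONS = "You are a helpful assistant. Answer using the context below.\n"

def index_chunk(model, tokenizer, vector_db, chunk_text):
    # Precompute the KV cache for the combined prefix: global instructions + this chunk.
    prefix_tokens = tokenizer.encode(GLOBAL_INSTRUCTIONS + chunk_text)
    kv_cache = model.prefill(prefix_tokens)  # hypothetical prefill call

    # Store the KV data alongside the chunk's embedding, so the retriever
    # can return both the text and its precomputed cache.
    # Note: changing GLOBAL_INSTRUCTIONS invalidates all of these cached entries.
    vector_db.add(
        embedding=embed(chunk_text),          # hypothetical embedding call
        metadata={"text": chunk_text, "kv_cache": kv_cache},
    )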
Multiple RAG Chunks
You've probably noticed a problem with the precomputed RAG KV cache: what if the retriever returns multiple chunks? After all, it usually does!
Yes, there's a problem here, because the retriever could return the chunks in different orders. It's not always going to be the same prefix.
If we have 10 chunks, and have 10 precomputed KV caches, we can't just munge them together. Furthermore, each of our precomputed KV caches might be assuming it's the one just after the global instructions, whereas only 1 of the 10 chunks is actually the prefix chunk, so we've got a big mess.
There are several basic approaches used in the research literature to resolve this problem:
- Get the retriever to favor returning the one RAG chunk with a cache.
- Store multiple caches for different RAG chunk orders.
- Fuse the per-chunk KV caches together ("fused KV caching").
Using a "cache-aware" RAG retriever asks the retriever to try to return the RAG chunks that have a cache as the first one. But it's not a very general approach, because the other 9 RAG chunks don't have a cache, or even if they have one, we can't use it, because they're not the prefix chunk.
Storing multiple KV caches for different chunk orders is possible, but it's very expensive. If the chunks come back in different orders, we need to cache each unique order. If we have 10 chunks in a retrieved result, there are 10! (over 3.6 million) possible orderings we might need to cache. This can be done with prefix KV caching, but we can't precompute all of that!
The last idea is "fused KV caching." I said above that we can't just munge multiple caches for chunks all together. Sorry, I lied, because you actually can.
Fused KV Caching of RAG Chunks
This is definitely at the cutting edge of the AI research papers. There are only a handful of papers that look at this idea of fusing KV caches together.
If we have a query sent to the LLM, after the retriever has returned 10 chunks, then the LLM sees this:
- Global instructions prompt
- Chunk 1
- ....
- Chunk 10
- User query text
Each of those parts has its own KV cache data precomputed. We have a general precomputation for the global instructions, and our retriever returns a separate precomputed KV cache for each of the 10 chunks. The only part without a KV cache is the user's query text.
How do we merge multiple KV caches? It seems like it should be a fancy GPU computation, and yet, no.
Instead, we just put one after the other, in a sequence, and pretend that we've computed the cache for the full token sequence. This is "fused KV caching," and it's obviously a fast method that requires only memory copies and literally zero extra computation.
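Here's a minimal Python (PyTorch) sketch of the fusion step, assuming each precomputed cache is a list of per-layer (K, V) tensors shaped [num_heads, seq_len, head_dim]. The model.decode() call in the usage note is a hypothetical stand-in for the decoding phase.

import torch

def fuse_kv_caches(segment_caches):
    # Concatenate KV caches for [instructions, chunk 1, ..., chunk N]
    # along the sequence axis, layer by layer.
    num_layers = len(segment_caches[0])
    fused = []
    for layer in range(num_layers):
        keys = torch.cat([cache[layer][0] for cache in segment_caches], dim=1)
        values = torch.cat([cache[layer][1] for cache in segment_caches], dim=1)
        fused.append((keys, values))
    return fused

# Usage: fuse the instruction cache with the per-chunk caches returned by the
# retriever, then run decoding over just the user's query tokens.
#   fused_kv = fuse_kv_caches([instructions_kv] + chunk_kv_list)
#   answer = model.decode(query_tokens, kv_cache=fused_kv)  # hypothetical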
I really don't understand why this works at all, let alone works well. Nevertheless, there are a couple of research papers that assert that it seems to be effective.
Even so, I feel like it needs a lot more research. Also, there are literally zero research papers on using "fused KV caching" with other optimizations, such as linear attention algorithms or hybrid local-global attention, which would be an obvious improvement to make RAG run even faster.
Further Reading on RAG Caching
- RAG caching research
- Prefix KV caching
- Fused KV caching
- LLM caching research overview
- 500 inference optimization techniques