
18. Parallel Data Structures

  • Book Excerpt from "Generative AI in C++"
  • by David Spuler, Ph.D.

“Bad programmers worry about the code.
Good programmers worry about
data structures and their relationships.”

— Linus Torvalds

 

 

Data Structures in AI Engines

The main data structures used in AI engines are vectors, matrices, and tensors. Examples of how these are used:

  • Vectors (1-D arrays): The input sequence of words is converted to a vector of tokens. Each token is processed to create an embedding vector.
  • Matrices (2-D arrays): The weights and activations are stored in matrices. Applying a set of weights to an embedding vector is a matrix-vector multiplication of the weight matrix over the vector, creating a new vector of updated values (with amazing intelligence added). A minimal code sketch of this operation appears after this list.
  • Tensors (3-D arrays): Every “slice” of a 3-D tensor is a 2-D matrix. Because there are so many 2-dimensional matrix multiplications happening, it can be efficient to generalize this procedure into 3-dimensional tensors. It is very mind-bending to try to understand what's happening, but at its core, it's basically just doing a lot of 2-D matrix multiplications, where the 3-D structure of tensors allows for some fancy parallelizations.
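To make that matrix-vector multiplication concrete, here is a minimal, deliberately naive sketch in C++. The function name and the flat row-major layout are illustrative choices, not taken from any particular engine; real kernels are tiled and vectorized, but the arithmetic is the same.

    #include <cstddef>
    #include <vector>

    // Naive matrix-vector multiply: y = W * x
    // W is an n x m weight matrix stored flat in row-major order,
    // x is the input vector (length m), y is the output vector (length n).
    void matvec(const std::vector<float>& W, const std::vector<float>& x,
                std::vector<float>& y, std::size_t n, std::size_t m)
    {
        y.assign(n, 0.0f);
        for (std::size_t i = 0; i < n; ++i) {
            float sum = 0.0f;
            for (std::size_t j = 0; j < m; ++j) {
                sum += W[i * m + j] * x[j];   // dot product of row i with x
            }
            y[i] = sum;
        }
    }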

Okay, we're done. There's more on vectors and matrices in Chapter 23, and that's all you need to know about data structures for AI. You can stop reading this chapter.

I'm only half kidding, because AI inference and training do a whole lot of vector and matrix operations (using tensors), and not a whole lot of anything else. In fact, I'm struggling to think of where in an AI engine there's even one hash table. Ah, yes, there's probably a small hash table in the tokenizer that maps 50,000 words to token numbers, but there doesn't need to be, because you could implement your tokenizer as an automaton, and that's more of an algorithm than a data structure. So, I'm going to say it out loud:

You don't need classic data structures in AI.

I think it's fair to say that a plain vanilla Transformer needs a lot of fancy coding of algorithms, but doesn't need all those obscure data structures you learned in Computer Science 101. You need maybe one hash table in the tokenizer, but then it's vectors, vectors, vectors (e.g. embeddings, dot product, probabilities) and matrices, matrices, matrices (e.g. FFNs, attention heads, GEMM/MatMul kernels), some weird statistical math functions (e.g. activation functions, normalization, Softmax), and some AI-specific algorithms (e.g. decoding algorithms, parallelization, vectorization, tiling).

I'm not seeing any binary trees.

Where the data structures come out to play is when you try to optimize any of that tensor stuff to go faster. Then the roster of data structures looks like:

  • Lookup tables. Precomputed arrays are used to optimize activation functions, Softmax, and other mathematical methods. If you work in AI research for long enough, you'll call it a LUT, and it's your go-to data structure for speedups (and not in the Edsger Dijkstra sense). A small example appears after this list.
  • Permutation arrays. Used to sort data without losing track of the indices (e.g. for mappings between word tokens and their probabilities) and also important for sparse matrices.
  • Bit vectors. Can be a fast way to do masks, or to mark some items as pruned.
  • Locality-sensitive hashing (LSH). This is “vector hashing.” Can be useful for optimizing weights and tracking previously seen inputs.
  • KV Caching. This is a widely used optimization that needs a specific hand-coded data structure.
  • Inference caching. This overall cache of user input strings can potentially be done using many data structures. Probably not a binary tree, though.
  • Bloom filters. These are a probabilistic combination of hashing and bit vectors. I've only seen these in research papers, although they look fast to me, and deserve more consideration.
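To make the lookup-table item concrete, here is a minimal sketch of a precomputed LUT for the sigmoid activation. The table size, clamping range, and class name are illustrative choices only; whether this beats calling the exponential function directly depends on the hardware and the accuracy you need.

    #include <cmath>
    #include <vector>

    // Precomputed lookup table for sigmoid over the range [-8, +8].
    // Inputs outside the range are clamped (sigmoid is ~0 or ~1 out there anyway).
    class SigmoidLUT {
    public:
        explicit SigmoidLUT(int size = 4096) : table_(size) {
            for (int i = 0; i < size; ++i) {
                float x = kMin + (kMax - kMin) * i / (size - 1);
                table_[i] = 1.0f / (1.0f + std::exp(-x));   // exact value, computed once
            }
        }
        float operator()(float x) const {   // fast approximate sigmoid at runtime
            if (x <= kMin) return table_.front();
            if (x >= kMax) return table_.back();
            int i = static_cast<int>((x - kMin) / (kMax - kMin) * (table_.size() - 1));
            return table_[i];               // nearest-entry lookup (no interpolation)
        }
    private:
        static constexpr float kMin = -8.0f;
        static constexpr float kMax = 8.0f;
        std::vector<float> table_;
    };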

The reason that classic data structures are missing from AI engines seems simple: parallelization. It's much easier to do parallel arithmetic on the contiguous memory blocks that underlie vectors, matrices, and tensors. Similarly, lookup tables, permutation arrays, bit vectors, and vector hashing also have good vectorization characteristics.

Bit Vectors

Bit vectors are conceptually an array of N bits with 0 or 1 values. The term “bit set” is almost synonymous, but has a slightly different meaning. A bit vector maps a number at the index position to its binary bit value, whereas a bit set specifies whether a number is in a set of numbers. Both interpretations are valid, depending mostly on the application, and the underlying implementation of the data structure is almost identical.

In AI applications, a bit vector may represent a set of weights with 0 or 1 values, such as with binary quantization or XNOR neural networks. The operation of vector dot product on two bit vectors can be performed arithmetically using bitwise arithmetic.
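For example, here is a minimal sketch of a bitwise dot product for two binary (0/1) vectors packed into 64-bit words, assuming C++20 for std::popcount; each pair of matching 1-bits contributes 1 to the sum. (For XNOR-style networks with weights encoded as ±1, the analogous trick is dot = N - 2*popcount(a XOR b).)

    #include <bit>       // std::popcount (C++20)
    #include <cstddef>
    #include <cstdint>
    #include <vector>

    // Dot product of two binary (0/1) vectors, packed 64 bits per word.
    int binary_dot_product(const std::vector<uint64_t>& a,
                           const std::vector<uint64_t>& b)
    {
        int sum = 0;
        std::size_t n = (a.size() < b.size()) ? a.size() : b.size();
        for (std::size_t i = 0; i < n; ++i) {
            sum += std::popcount(a[i] & b[i]);   // 64 multiply-adds via one AND and a popcount
        }
        return sum;
    }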

Sparsity optimizations are another application of bit vectors. Pruning can often create “sparse” weight matrices, with lots of zeros and very few non-zero weights. A bit vector can then efficiently represent whether a weight in a vector has a non-zero value, which is then used to avoid doing any computations on zero values. An alternative to bit vectors for sparsity is to use permutation arrays of indices, as discussed further below.

Another application of bit vectors occurs in Bloom filter data structures, which are a probabilistic hybrid of hash tables and bit vectors. In this usage, a bit set represents whether an input number is found in the set of already-mapped numbers.

In practice, bit vectors or bit sets are often implemented as arrays of unsigned integers, with the bits packed into each integer. If the underlying unsigned type is 32 or 64 bits, then many bitwise operations on bit vectors can be performed 32 or 64 bits at a time, achieving significant parallelism without using any form of hardware acceleration beyond basic CPU instructions. Use of AVX SIMD instructions can then further vectorize many operations without a GPU. But it absolutely flies if you use a GPU with bit vectors or bit sets, because that's two levels of parallelization.

There are several pre-built C++ bit set classes that can be considered:

  • std::bitset<N> (in <bitset>)
  • std::vector<bool>
  • boost::dynamic_bitset<>

If the maximum size of the bit vector is known at compile-time, which is often the case with AI models, then std::bitset is a good choice. If not, then std::vector<bool> or boost::dynamic_bitset<> are good choices for dynamic-sized bit vectors. Alternatively, you can build your own bit vectors, if there is a particular need to hand-code them or if you just want some fun.
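For example, here is a minimal sketch of std::bitset used as a pruning mask for a fixed-size vector; the size constant and function names are illustrative only.

    #include <bitset>
    #include <cstddef>

    constexpr std::size_t kVectorSize = 4096;      // dimension known at compile time

    std::bitset<kVectorSize> nonzero_mask;         // all bits start as 0 (pruned)

    void mark_nonzero(std::size_t i) { nonzero_mask.set(i); }

    bool is_pruned(std::size_t i) { return !nonzero_mask.test(i); }

    std::size_t count_nonzero() { return nonzero_mask.count(); }   // popcount of the whole mask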

Permutation Arrays

Most of the vectors in AI engines are not just random lists of numbers. Rather, they are (conceptually) an array of the probabilities of output words, where the position in the vector indicates which word. So, if we have our logits array, then logits[0] is the probability of “the” whereas logits[1] is the probability for “cat”, and so on, up to about 50,000, which is a common vocabulary size for LLMs.

Problems arise if we want to sort our probabilities in the logit array, and we need this for our decoding top-k algorithm. We can't just sort the vector of probability numbers, because we'll lose track of which probability maps to which token number.

Permutation arrays to the rescue! A permutation array is an array of the same size as some other array, whose elements are indices into that other array. A permutation array for our vocabulary has 50,000 integers, each of which is an index into the other arrays.

The downside of permutation arrays is that they introduce inefficiency in both space and time. Space usage is increased by having two vectors. The time cost to access a vector element increases, too. Rather than just looking up the probability for the nth word in the logits array (i.e. “prob=logits[n]”), we have a two-step procedure:

    1. Look up the index in the nth element of the permutation array (i.e. “i=permut[n]”),

    2. Use that index to look up the probabilities in the main logits array (i.e. “prob=logits[i]”).

So, it's bigger and slower. Some rescue.
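In code, the two-step lookup above looks something like this (the array names are illustrative):

    #include <vector>

    // Direct access:   float prob = logits[n];
    // Indirect access via a permutation array:
    float permuted_lookup(const std::vector<float>& logits,
                          const std::vector<int>& permut, int n)
    {
        int i = permut[n];      // step 1: look up the index in the permutation array
        return logits[i];       // step 2: use that index to fetch the probability
    }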

However, permutations can be valuable if they allow us to do much less arithmetic overall, which is the case with “sparse” arrays where most elements are zero. This is why permutation arrays are used for LLM sparsity optimizations, but not for ordinary dense vectors.

Sorting with a Permutation Array: The way to sort another array, indirectly via a permutation array, is shown in detail for the top-k decoding algorithm in Chapter 26. The basic idea is:

    1. Set up the identity permutation.

    2. Sort using an indirect procedure: (a) compare elements in the main array indirectly accessed via the permutation array, (b) swap the indices in the permutation array (not changing the main array).

So, the original array doesn't actually get sorted; only the permutation array changes. If we want to print out the main array in a sorted list, we have to do so via the permutation array. The original main array is still unsorted if we access it directly.
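Here is a minimal sketch of this indirect sort using std::sort with a comparison lambda that reads the logits through the permutation indices (descending order, for top-k); the book's Chapter 26 version may differ in details.

    #include <algorithm>
    #include <numeric>   // std::iota
    #include <vector>

    // Build a permutation that orders the logits from highest to lowest,
    // without touching the logits array itself.
    std::vector<int> sort_permutation(const std::vector<float>& logits)
    {
        std::vector<int> perm(logits.size());
        std::iota(perm.begin(), perm.end(), 0);        // step 1: identity permutation 0,1,2,...
        std::sort(perm.begin(), perm.end(),
                  [&logits](int a, int b) {            // step 2: compare indirectly via the main array
                      return logits[a] > logits[b];    //         swaps happen only in perm
                  });
        return perm;                                   // perm[0] is the index of the largest logit
    }

Printing logits[perm[0]], logits[perm[1]], and so on walks the main array in sorted order, even though logits itself is untouched.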

Sparsity with Permutation Arrays. Sparsity is an optimization where most of the weights have been “pruned” to zero, and only a small percentage remain non-zero. This saves a lot of storage space for the model, and can also run much faster. The basic vector dot product kernel only needs to calculate with non-zero weights, so we want a way to avoid processing all of the many zero weights. Again, permutation arrays are the solution!

Sparse vectors (or matrices or tensors) can be stored as parallel arrays of:

  • Non-zero weights only
  • Permuted integer index of that non-zero weight in the original vector

These two arrays are much shorter than the original vectors if there is high sparsity. If sparsity is 90%, then only 10% of the numbers are non-zero, and the permutation approach uses two arrays, so it is 20% of the original size. The cost of a sparse dot product drops from the full length of the original vectors down to the number of non-zero values. In other words, the multiplications go down to 10% of the FLOPs; there is also the extra permutation lookup, so it might seem more like 20%, but the permutation array step can often be hardware-accelerated on CPU or GPU architectures. Hence, sparse vector dot products are fast: the vector dot product for AI inference need only multiply the much smaller number of non-zero weights.
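A minimal scalar sketch of this sparse dot product over the two parallel arrays (the names are illustrative):

    #include <cstddef>
    #include <vector>

    // Sparse dot product: only the non-zero weights are stored, alongside
    // a parallel array of their positions in the original dense vector.
    float sparse_dot(const std::vector<float>& nz_weights,   // non-zero weights only
                     const std::vector<int>& nz_index,       // original index of each weight
                     const std::vector<float>& activations)  // full dense activation vector
    {
        float sum = 0.0f;
        for (std::size_t k = 0; k < nz_weights.size(); ++k) {
            sum += nz_weights[k] * activations[nz_index[k]];  // gather, then multiply-add
        }
        return sum;
    }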

Can we vectorize permuted arrays for hardware acceleration? Short answer: yes. Permutations can be vectorized with hardware acceleration in both CPU and GPU versions. On x86 CPUs, the AVX2 “gather” (load) intrinsics handle permuted reads, and AVX-512 adds “scatter” (store) intrinsics for permuted writes. Different GPU primitives are available for permuted arrays.
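As a rough illustration of the CPU side, here is a sketch of the same sparse dot product using the AVX2 gather intrinsic (compile with -mavx2); this is a simplified example to show the idea, not production kernel code.

    #include <immintrin.h>   // AVX2 intrinsics
    #include <cstddef>
    #include <cstdint>
    #include <vector>

    float sparse_dot_avx2(const std::vector<float>& nz_weights,
                          const std::vector<int32_t>& nz_index,
                          const std::vector<float>& activations)
    {
        std::size_t n = nz_weights.size();
        __m256 acc = _mm256_setzero_ps();
        std::size_t k = 0;
        for (; k + 8 <= n; k += 8) {
            // Load 8 permuted indices, then gather the 8 corresponding activations.
            __m256i idx = _mm256_loadu_si256(
                reinterpret_cast<const __m256i*>(&nz_index[k]));
            __m256 act = _mm256_i32gather_ps(activations.data(), idx, 4);  // scale = 4 bytes
            __m256 wts = _mm256_loadu_ps(&nz_weights[k]);
            acc = _mm256_add_ps(acc, _mm256_mul_ps(wts, act));
        }
        float partial[8];
        _mm256_storeu_ps(partial, acc);                 // horizontal sum of the accumulator
        float sum = partial[0] + partial[1] + partial[2] + partial[3]
                  + partial[4] + partial[5] + partial[6] + partial[7];
        for (; k < n; ++k) {                            // scalar tail for leftover elements
            sum += nz_weights[k] * activations[nz_index[k]];
        }
        return sum;
    }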

Sparsity doesn't really work without permutations. A raw full-size vector containing lots of zeros doesn't vectorize well, because it still sends all of those zeros for processing. A permuted index of sparse values works much better because it only considers non-zero values.

Vector Hashing

Vector hashing is needed in various parts of an AI engine as a speedup. There is a sizable body of AI research on using hashing for computations involving vectors and higher-dimensional tensors. Implementations of such algorithms are available in open source and commercial “vector database” products that you can use. Some of the applications for LLMs include inference caching, embeddings, and RAG architectures.

But how do you hash a full-length vector? Or a matrix? It's a complicated theoretical area. One of the main techniques is Locality-Sensitive Hashing (LSH), which is hashing to find vectors that are “close” in n-dimensional space.
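As an illustration, here is a minimal sketch of one common LSH scheme, random hyperplane hashing (often called SimHash): each signature bit is the sign of a dot product with a random hyperplane, so vectors at a small angle to each other tend to share most of their bits. The 32-bit signature size, seed, and class name are illustrative choices.

    #include <cstddef>
    #include <cstdint>
    #include <random>
    #include <vector>

    // Random hyperplane LSH: produce a 32-bit signature for a vector.
    class HyperplaneLSH {
    public:
        explicit HyperplaneLSH(std::size_t dim, unsigned seed = 42) {
            std::mt19937 rng(seed);
            std::normal_distribution<float> gauss(0.0f, 1.0f);
            planes_.assign(32, std::vector<float>(dim));
            for (auto& plane : planes_)
                for (auto& coeff : plane)
                    coeff = gauss(rng);                 // one random Gaussian hyperplane per bit
        }

        uint32_t signature(const std::vector<float>& v) const {
            uint32_t sig = 0;
            for (std::size_t b = 0; b < planes_.size(); ++b) {
                float dot = 0.0f;
                for (std::size_t i = 0; i < v.size(); ++i)
                    dot += planes_[b][i] * v[i];
                if (dot > 0.0f) sig |= (1u << b);       // which side of hyperplane b?
            }
            return sig;
        }

    private:
        std::vector<std::vector<float>> planes_;
    };

Two vectors whose signatures match in most bit positions are probably close in angle, so the signature can be used as a cheap pre-filter before doing an exact comparison.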

One of the interesting research areas for vector hashing is total precomputation of vector dot products in AI inference. If you could hash the two vectors, then you could replace the main bottleneck in AI inference with two hash lookups. Is there a way to efficiently convert a vector dot product on two vectors into a hash lookup, thereby avoiding all those multiplications? What about speeding up matrix multiplication by hashing?

Remember that you can pre-compute anything about the weights before inference, because they don't change during inference. Hence, one of the vectors could potentially be pre-hashed offline. Maybe you could even use some type of “perfect hashing” for those vector hashes, if you've got a big enough compute budget. But you can't pre-hash both of the vectors or pre-compute the dot product, because the other vectors are dynamically calculated along the way, dependent on user inputs. This is being examined by advanced researchers, and is still a work in progress.

Perfect Hashing

Perfect hashing aims to achieve collision-free O(1) hashing at runtime, by investing a lot of offline compute budget to find an optimal hash function for a set of static data. There are many possible hash functions, and some are better than others. Perfect hashing tries to find an optimal hash function within the search space of possible methods. Mostly, it's by trial-and-error. Searching for a perfect hash function typically uses a brute-force and computationally expensive method of simply trying multiple hash functions and testing them for collisions.

Perfect hashing only works in the situation where all of the possible keys are known in advance (i.e. static data). Interestingly, this is exactly the situation with AI model vocabularies!

Hence, the idea of perfect hashing can be used to improve the performance of a hash table in the tokenizer. The general concept is that different hash tables are tested with different meta-parameters (e.g. the hash table size, and multipliers in the hashing function). So, you can test various hash functions against the 50,000 known tokens in the vocabulary, until you find a “perfect” one where there are no clashes. Amusingly, this longstanding algorithmic method sounds exactly like doing Neural Architecture Search (NAS) to find the best AI model hyper-parameters.
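Here is a minimal sketch of that brute-force search, trying successive multipliers of a simple multiplicative string hash until one maps every token to a distinct slot; the hash family, search budget, and function names are illustrative only. In practice, the table size usually needs to be considerably larger than the vocabulary (or a more elaborate scheme, such as two-level hashing, is used) before a collision-free function turns up.

    #include <cstddef>
    #include <cstdint>
    #include <string>
    #include <vector>

    // Simple multiplicative string hash, parameterized by the multiplier.
    static uint32_t hash_str(const std::string& s, uint32_t multiplier) {
        uint32_t h = 0;
        for (unsigned char c : s) h = h * multiplier + c;
        return h;
    }

    // Try successive odd multipliers until one maps every token to a distinct slot.
    // Returns the winning multiplier, or 0 if none is found within the search budget.
    uint32_t find_perfect_multiplier(const std::vector<std::string>& vocab,
                                     std::size_t table_size)
    {
        for (uint32_t multiplier = 3; multiplier < 100000; multiplier += 2) {
            std::vector<bool> used(table_size, false);
            bool collision = false;
            for (const auto& token : vocab) {
                std::size_t slot = hash_str(token, multiplier) % table_size;
                if (used[slot]) { collision = true; break; }   // clash: try the next multiplier
                used[slot] = true;
            }
            if (!collision) return multiplier;                 // perfect for this vocabulary
        }
        return 0;   // no perfect multiplier found; enlarge table_size and retry
    }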

Bloom Filters

Bloom filters are a probabilistic data structure based on a combination of hashing and bit vectors. Multiple hash functions are computed for each key, and these are used to set bitflags, as described in more detail below. Bloom filters are mentioned in various research papers on AI, but are not yet used much in industrial AI applications. Perhaps they should be, as they seem very efficient.

Like hashing, Bloom filters have been used as a data structure to speed up neural network inference. However, much of the research literature about Bloom filters is about a different topic: Weightless Neural Networks (WNNs). WNNs have a different type of neuron based on binary bits, rather than matrix multiplications. These bitflag neurons can be approximated using Bloom filters. As such, that part of the research is less relevant to optimization of Transformer inference, and has not been examined in detail below.

How do Bloom Filters work? Given a key, multiple hash functions are calculated for that key, and a binary flag is set in a bitflag table for each of those hash offsets. In this way, an input key maps to a pattern of multiple bits.

The Bloom filter lookup for a key value works as follows: To test whether a key is found, the multiple hash functions are computed, and then the bitflag table is analyzed to see if all those bits are set. If any of the bits are missing, the key is not in the Bloom filter. If all of the bits are found, the key is probably in the Bloom filter, but it may also be that other keys have coincidentally set all those bits (a “false positive”), so it is not 100% guaranteed to be present.
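Here is a minimal Bloom filter sketch, deriving the k bit positions from two base hashes (the “double hashing” trick); the choice of std::hash, the crude second hash, and the class name are illustrative only.

    #include <cstddef>
    #include <functional>   // std::hash
    #include <string>
    #include <vector>

    class BloomFilter {
    public:
        BloomFilter(std::size_t num_bits, int num_hashes)
            : bits_(num_bits, false), k_(num_hashes) {}

        void add(const std::string& key) {
            for (int i = 0; i < k_; ++i)
                bits_[position(key, i)] = true;            // set one bit per hash function
        }

        // False means definitely absent; true means probably present.
        bool possibly_contains(const std::string& key) const {
            for (int i = 0; i < k_; ++i)
                if (!bits_[position(key, i)]) return false;
            return true;
        }

    private:
        std::size_t position(const std::string& key, int i) const {
            // Double hashing: position_i = h1 + i * h2, derived from two base hashes.
            std::size_t h1 = std::hash<std::string>{}(key);
            std::size_t h2 = std::hash<std::string>{}(key + "#salt");  // crude second hash
            return (h1 + static_cast<std::size_t>(i) * h2) % bits_.size();
        }

        std::vector<bool> bits_;
        int k_;
    };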

If a probabilistic speedup is good enough, then a Bloom filter is all you need. For a 100% accurate lookup, a second backup data structure must be queried to confirm any positive result. Hence, the Bloom filter is a fast test that a key is not in a set, but a slower test if the key is (probably) found. This makes it an example of a “common case first” optimization, where fast computations may skip more involved computations.

The computational complexity of Bloom filters is constant, but not as fast as hashing. A hash table uses only a single hash function, so it has O(1) lookup. However, a Bloom filter uses k hash functions, so it has O(k) lookup complexity.

 
