Aussie AI
Float-to-Float Precomputation
Book Excerpt from "Generative AI in C++"
by David Spuler, Ph.D.
Using a precomputed table lookup for a float-to-float function is more complicated than it is for integers. However, this is also the main approximation needed for AI activation functions, or even for simple math functions like sqrtf, expf, or logf.
Why is it tricky? The reason that float inputs are more difficult is that we need to convert a float into an array index in order to look it up. For example, we could try a type cast:

int offset = (int)f;

But that limits us to only precalculating values for 1.0, 2.0, 3.0, etc. Our approximation works poorly on any fractions, and we also haven't limited the array index to a fixed finite range, so it won't work for any negative values or very large positive values. And the type cast of a float is also slow!
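To make the problem concrete, here is a minimal sketch of that naive cast-based lookup (the table name and its size are hypothetical), showing both the bounds checks it needs and how the truncation discards all fractional information:

#include <math.h>

#define NAIVE_TABLE_SIZE 1024                 // hypothetical: covers only 0.0 .. 1023.0
float g_naive_sqrt_table[NAIVE_TABLE_SIZE];   // filled at startup with sqrtf(0.0f), sqrtf(1.0f), ...

float naive_table_lookup_sqrt(float f)
{
    int offset = (int)f;                      // truncation: 3.0, 3.5 and 3.99 all become 3
    if (offset < 0 || offset >= NAIVE_TABLE_SIZE)
        return sqrtf(f);                      // out of range: fall back to the real function
    return g_naive_sqrt_table[offset];        // one crude answer for the whole interval [3,4)
}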
Scaled Multiple: Another idea is that we could scale it upwards to get more decimals:
int offset = (int) (f * 1000.0f);
This approach at least gives us 3 decimal places: e.g. 1.234 or 23.456, or similar. We will still have to check for negatives and large values to bound it. But again, this is even slower!
Bitwise Floating-Point Truncations:
The above truncation via a floating-point scaled multiple is not very fast.
Twiddling the bits is much faster.
For example, if we have a standard 32-bit float type, it has 1 sign bit, 8 exponent bits, and 23 mantissa bits. This is from left to right, with the sign bit as the most significant bit, and the low-end mantissa bits as the least significant bits.
Remember that this is like Scientific notation:
- Number = Mantissa x 2 ^ Exponent
Also, the sign bit makes it all negative, if set. Note that the 8-bit exponent is stored with a bias of 127, encoding effective exponents from roughly -126 to +127 (the extreme patterns are reserved for zeros, subnormals, infinities, and NaN), so it ranges from very small near-zero values around 2^-126, to very huge values around 2^127.
If the mantissa was in decimal, and it was “1234567” and the exponent was “17” then we'd have:
- Number = 1.234567 x 10^17
Of course, a 23-bit mantissa is really binary digits, and with about 3 binary digits per decimal digit, a 23-bit mantissa is about 7 or 8 decimal digits. Note that the mantissa is actually 24 bits, not 23, because there's an extra “implicit one” mantissa bit; not that it changes the above calculation, but you needed to know that for C++ trivia night.
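To make that layout concrete, here is a minimal sketch (not the book's code) that pulls apart the three bit-fields of an FP32 value; memcpy is used to reinterpret the bits without a pointer cast:

#include <stdio.h>
#include <string.h>

// Print the sign, exponent, and mantissa fields of a 32-bit float
void print_float_bits(float f)
{
    unsigned int u = 0;
    memcpy(&u, &f, sizeof u);               // reinterpret the 4 bytes as an integer
    unsigned int sign = (u >> 31) & 0x1;    // 1 sign bit (most significant bit)
    unsigned int expo = (u >> 23) & 0xFF;   // 8 exponent bits (stored with a bias of 127)
    unsigned int mant = u & 0x7FFFFF;       // 23 stored mantissa bits
    printf("sign=%u exponent=%u (unbiased %d) mantissa=0x%06X\n",
        sign, expo, (int)expo - 127, mant);
}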
So, if we think about it for a year or two, it becomes obvious that the rightmost bits of the mantissa are simply the rightmost digits in “1.234567”, and if we truncate some of the rightmost bits, it's like truncating a very small fraction (e.g. “1.234567” becomes “1.2345” or whatever).
Hence, a first idea is just to cut off 2 of the 4 bytes of a 32-bit float. This leaves us with 1 sign bit, 8 exponent bits, and 7 mantissa bits (plus 1 implied bit makes 8 mantissa bits). In decimal, the 8-bit mantissa now encodes only about 2 or 3 decimal digits, as if we've truncated “1.234567” to “1.23”.
Incidentally, congratulations, you've created the “bfloat16” type, which is what Google did with TPUs, making a 2-byte float format with 1 sign bit, 8 exponent bits, and 7 stored mantissa bits. So, now you can get into your blue telephone booth, time travel back a decade, file a patent, and retire on your royalties. If you're ever a contestant on Wheel of Fortune you probably won't need to know that the “b” in “bfloat16” stands for “brain float” and that is such a great name. But I digress.

Anyhow, this idea actually works for precomputation. A 2-byte integer in bfloat16 format is easy to extract from a 4-byte FP32 float (i.e., the uppermost two bytes).
The trick for bitwise processing is to convert the float to unsigned int, because the bitwise shift operators don't work on float (it's planned for C++37, as I heard at my fungus collector's club trivia night).

float f32 = 3.14f;
unsigned u32 = *(unsigned int*)&f32;
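As an aside, that pointer cast is the traditional trick, but it technically violates C++'s strict aliasing rules. If you prefer a standards-clean version, a 4-byte memcpy (which optimizing compilers reduce to the same register move) is a safe alternative, shown here as a sketch:

#include <string.h>

float f32 = 3.14f;
unsigned int u32 = 0;
memcpy(&u32, &f32, sizeof u32);  // same bit pattern as the cast, without the aliasing issue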
Extracting the top-most 2 bytes (16 bits) is simply a right bitshift:
unsigned ubf16 = ( u32 >> 16 );
Note that this is a good reason why we had to use the “unsigned” integer type. The right bitshift operator (>>) has implementation-defined behavior on negative values (it was only pinned down as an arithmetic shift in C++20), so the “int” type wouldn't work predictably (or portably) if the floating-point sign bit was set. The result is a 16-bit unsigned integer to use as the array index.
Hence, there are only 1<<16=65,536 entries in our precomputation table. Assuming we store results as 4-byte float values, this makes the precomputation array's memory size about 262KB. What's more, it works for negative float numbers, because the sign bit is still part of that shemozzle, and we also don't need to check any minimum or maximum bounds, because it works for all 32-bit float numbers.
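Putting the 16-bit pieces together, here is a minimal sketch of the whole scheme applied to sqrtf (the array and function names here are illustrative rather than the book's, and memcpy replaces the pointer cast):

#include <math.h>
#include <string.h>

#define SQRT_16BIT_TABLE_SIZE (1u << 16)   // 65,536 entries
float g_sqrt_float_16bit_precomp_table[SQRT_16BIT_TABLE_SIZE];

// Precompute sqrtf for every 16-bit truncated float bit pattern (startup phase)
void precompute_sqrt_16bit_float()
{
    for (unsigned int u = 0; u < SQRT_16BIT_TABLE_SIZE; u++) {
        unsigned int unum = (u << 16);     // restore the upper 16 bits of the float
        float f = 0.0f;
        memcpy(&f, &unum, sizeof f);       // reinterpret the bits as a float
        g_sqrt_float_16bit_precomp_table[u] = sqrtf(f);
    }
}

// Fast approximate sqrtf via the 16-bit lookup table
float table_lookup_sqrt_16bit_float(float f)
{
    unsigned int u = 0;
    memcpy(&u, &f, sizeof u);
    u >>= 16;                              // keep sign, exponent, and top 7 mantissa bits
    return g_sqrt_float_16bit_precomp_table[u];
}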
Precomputing with 24-Bit Lookup Tables:
Interestingly, none of the above code is especially tied to 16-bit sizes. The bfloat16 version truncates a 32-bit float to 16 bits by cutting off the rightmost 16 mantissa bits. But we can actually choose to keep however many mantissa bits we like. The trade-off is that more mantissa bits increase accuracy, but at the cost of needing a much bigger precomputation array (doubling the storage size for each extra bit). Let's try cutting only the rightmost 8 mantissa bits, leaving us with 24 stored bits total (i.e. 1 sign bit, 8 exponent bits, and 15 stored mantissa bits). The mantissa bits reduce from 23 to 15 (plus one implied bit makes 16), so this now stores about 5 decimal digits (e.g. “1.2345”), giving quite good precision on our results.
When I tested the 16-bit version, it had some reasonably large errors of almost 0.1 in computing sqrt, whereas this 24-bit version has much lower errors, as expected. Code changes are minor. The bitshift operations simply change from 16 bits to 8 bits (i.e. 32-24=8 bits).
This is the precomputation loop for 24 bits:
void aussie_generic_precompute_24bit_float(float farr[], unsigned int maxn, float (*fnptr)(float))
{
    for (unsigned int u = 0; u < maxn; u++) {
        unsigned int unum = (u << 8u);   // 32-24=8 bits!
        float f = *(float*)&unum;
        farr[u] = fnptr(f);
    }
}
And this is the call to the precomputation function in the startup phase:
aussie_generic_precompute_24bit_float(
    g_sqrt_float_24bit_precomp_table,   // Bigger array float[1<<24]
    (int)AUSSIE_SQRT_24bit_MAX,         // 1 << 24
    aussie_sqrtf_basic_float            // Function pointer
);
The table lookup routine also similarly shifts 8 bits, rather than 16, but is otherwise unchanged:
float aussie_table_lookup_sqrt_24bit_float(float f)
{
    unsigned u = *(unsigned int*)&f;
    u >>= 8;   // 32-24=8 bits
    return g_sqrt_float_24bit_precomp_table[u];
}
Note that this only works if we are sure that both “float” and “unsigned int” are 32 bits, so we should check that during startup with some assertions via static_assert. If we are sure of that fact, then not only will it work, but we also don't need to check the array bounds. It won't try a negative array index, and won't overflow no matter what bit pattern we send in as a float.
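For example, compile-time checks along these lines (a minimal sketch) would catch a platform where those type sizes differ:

static_assert(sizeof(float) == 4, "float must be 32 bits");
static_assert(sizeof(unsigned int) == 4, "unsigned int must be 32 bits");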
But there is one problem. If we send the fast table lookup version the special float value of NaN (“not a number”), then the table lookup routine will actually return a valid numeric answer, which probably isn't what we want. Maybe we need to add a check for that special case, and this needs more testing.
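One possible fix (a sketch, not the book's code) is a wrapper that propagates NaN before doing the table lookup; it assumes the 24-bit table from the code above:

#include <math.h>
#include <string.h>

float aussie_table_lookup_sqrt_24bit_float_safe(float f)   // hypothetical wrapper name
{
    if (isnan(f)) return f;       // NaN in, NaN out, instead of a bogus table value
    unsigned int u = 0;
    memcpy(&u, &f, sizeof u);
    u >>= 8;                      // 32-24=8 bits, as before
    return g_sqrt_float_24bit_precomp_table[u];
}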
The new size of the precomputation array is 2^24=16,777,216, so we have about 16.7 million results. Full disclosure: I used an AI because I ran out of fingers and toes. If our results are 32-bit float values, our bfloat16 precomputed array above requires about 262KB, whereas the new 24-bit version needs a lookup table (array) of about 67 megabytes. It wouldn't have worked on my old TRS-80 CoCo in 1986, but it'll work nowadays.
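For the record, the arithmetic behind those two sizes is:
- 16-bit table: 1<<16 = 65,536 entries x 4 bytes = 262,144 bytes, or about 262KB.
- 24-bit table: 1<<24 = 16,777,216 entries x 4 bytes = 67,108,864 bytes, or about 67MB.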
The new AI programming book by Aussie AI co-founders:
Get your copy from Amazon: Generative AI in C++