Aussie AI
Example: AVX Vectorized Dot Product
Book Excerpt from "Generative AI in C++"
by David Spuler, Ph.D.
Here is the basic non-vectorized dot product computation without any optimization attempts.
float aussie_vecdot_basic(float v1[], float v2[], int n)
{
    // Basic FLOAT vector dot product
    float sum = 0.0;
    for (int i = 0; i < n; i++) {
        sum += v1[i] * v2[i];
    }
    return sum;
}
To use AVX to vectorize it, we need to unroll the loop first.
Here's a simple vector dot product with its inner loop unrolled 4 times. This version assumes that n is a multiple of 4 rather than handling the odd cases:
float aussie_vecdot_unroll4_basic(float v1[], float v2[], int n)
{
    // Loop-unrolled Vector dot product
    if (n % 4 != 0) {
        yassert(n % 4 == 0);
        return 0.0; // fail
    }
    float sum = 0.0;
    for (int i = 0; i < n; ) {
        sum += v1[i] * v2[i]; i++;
        sum += v1[i] * v2[i]; i++;
        sum += v1[i] * v2[i]; i++;
        sum += v1[i] * v2[i]; i++;
    }
    return sum;
}
So, now we can change those 4 unrolled multiplications into one AVX computation of the vector dot product of 4 float numbers.
#include <intrin.h>

float aussie_vecdot_unroll_AVX1(float v1[], float v2[], int n)
{
    // AVX-1 loop-unrolled (4 floats) vector dot product
    if (n % 4 != 0) {
        yassert(n % 4 == 0);
        return 0.0; // fail
    }
    float sum = 0.0;
    for (int i = 0; i < n; i += 4) {
        // AVX1: Vector dot product of 2 vectors
        // ... process 4x32-bit floats in 128 bits
        __m128 r1 = _mm_loadu_ps(&v1[i]);  // Load floats into 128-bits
        __m128 r2 = _mm_loadu_ps(&v2[i]);
        __m128 dst = _mm_dp_ps(r1, r2, 0xf1);  // Dot product
        sum += _mm_cvtss_f32(dst);
    }
    return sum;
}
This basic AVX sequence for the 4-float dot product has been analyzed in a separate chapter. The main dot product computation is "_mm_dp_ps", an AVX intrinsic that multiplies 4 pairs of 32-bit float numbers and then sums the products, all in one call (the 0xf1 mask tells it to multiply all four pairs and store the resulting sum in the lowest lane of the result). Note that the loop now iterates through the array of float values 4 at a time (i.e. "i += 4"), and the AVX intrinsic does the rest.
Here's the benchmark analysis showing that the AVX-vectorized version is more than twice as fast:
FLOAT Vector dot product benchmarks:
Time taken: Vecdot basic: 2805 ticks (2.81 seconds)
Time taken: Vecdot AVX1 unroll (4 floats, 128-bits): 1142 ticks (1.14 seconds)
Fused Multiply-Add (FMA) in AVX-2. The AVX-2 FMA intrinsic takes 3 vectors, each 256 bits in size, multiplies two of them pair-wise, and then adds the third vector. Both the multiplication and addition are done in element-wise SIMD style. At first blush this sounds like doing a vector multiply and then adding a "bias" vector, and hence doesn't sound like a good optimization for the vector dot product. The SIMD pairwise multiplication is the first step of dot products, but the vector addition seems the opposite of what we want, which is "horizontal" addition of the products that result from the multiplications.
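As a point of reference, here is a scalar sketch of what the FMA intrinsic computes lane-by-lane (the function name here is just for illustration, not from the book):

// Scalar equivalent of _mm256_fmadd_ps(a, b, c): each of the 8 lanes computes a*b + c
void fma8_scalar_sketch(const float a[8], const float b[8], const float c[8], float dst[8])
{
    for (int i = 0; i < 8; i++) {
        dst[i] = a[i] * b[i] + c[i];  // pair-wise multiply, then add the third vector
    }
}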
The obvious approach is to do a dot product of 8 float values, then another 8, and so on, adding up the individual sums at the end. With that approach, the vertical addition in FMA is not what we want, and it looks like using SIMD multiplication plus an extra horizontal addition would be better than a single FMA intrinsic.
However, we can make like Superman III...
Reverse it!
If you think about FMA not as a multiplication and then addition, but as “adding multiplications” in the reverse order, then there is a eureka moment: put the addition first. The idea is that we can maintain a vector of running sums, and then only do a single horizontal addition at the very end. It's kind of mind-bending, but here's the code:
float aussie_vecdot_FMA_unroll_AVX2(float v1[], float v2[], int n)
{
    // AVX2 vecdot using FMA (Fused Multiply-Add) primitives
    if (n % 8 != 0) {
        yassert(n % 8 == 0);
        return 0.0; // fail
    }
    __m256 sumdst = _mm256_setzero_ps();  // Set accumulators to zero
    for (int i = 0; i < n; i += 8) {
        // AVX2: process 8x32-bit floats in 256 bits
        __m256 r1 = _mm256_loadu_ps(&v1[i]);  // Load floats into 256-bits
        __m256 r2 = _mm256_loadu_ps(&v2[i]);
        sumdst = _mm256_fmadd_ps(r1, r2, sumdst);  // FMA of 3 vectors
    }
    // Add the final 8 accumulators manually
    float* farr = (float*)&sumdst;
    float sum = farr[0] + farr[1] + farr[2] + farr[3]
              + farr[4] + farr[5] + farr[6] + farr[7];
    return sum;
}
How does this work? Well, we declare "sumdst" as a vector of 8 float numbers that maintains 8 parallel accumulators, which is first initialized to all-zeros via the "_mm256_setzero_ps" intrinsic. In the main loop, we use "sumdst" to maintain a running sum in all 8 of those parallel accumulators across multiple segments of the vector. One accumulator sums the products at array indices 0,8,16,..., the next accumulator sums the products at indices 1,9,17,..., and so on. We use the FMA intrinsic ("_mm256_fmadd_ps" in AVX-2) to do the SIMD multiplication, but rather than trying to add the 8 resulting products together, we add each product to a separate accumulator. This works very neatly, because the combined AVX-2 FMA intrinsic does all of this in SIMD parallelism. Only at the very end, after the main loop, do we do a horizontal add of the 8 parallel accumulators to get the final sum.
This idea works surprisingly well, which is gratifying, since I couldn't get the AVX-2 256-bit version with the dot product intrinsic "_mm256_dp_ps" to run correctly on 8 float values (that intrinsic does its dot products within each 128-bit lane separately, which complicates producing a single sum of 8 floats).
Here's the benchmarking, which shows that AVX-2 using FMA on 8 float values in parallel runs much faster than the AVX1 unrolled vector dot product using the "_mm_dp_ps" intrinsic on 4 float values.
FLOAT Vector dot product benchmarks: (N=1024, Iter=1000000)
Vecdot basic: 2961 ticks (2.96 seconds)
Vecdot AVX1 unroll (4 floats, 128-bits): 1169 ticks (1.17 seconds)
Vecdot AVX1 FMA (4 floats, 128-bits): 1314 ticks (1.31 seconds)
Vecdot AVX2 FMA (8 floats, 256-bits): 783 ticks (0.78 seconds)
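The "Vecdot AVX1 FMA" row refers to a 128-bit FMA variant whose source is not shown in this excerpt. As a rough sketch only, such a version could follow the same running-accumulator idea using the 128-bit "_mm_fmadd_ps" intrinsic (the function name below is a guess, not from the book):

float aussie_vecdot_FMA_unroll_AVX1(float v1[], float v2[], int n)
{
    // Sketch: 128-bit FMA vector dot product (4 floats per iteration)
    if (n % 4 != 0) {
        yassert(n % 4 == 0);
        return 0.0; // fail
    }
    __m128 sumdst = _mm_setzero_ps();  // 4 parallel accumulators, zeroed
    for (int i = 0; i < n; i += 4) {
        __m128 r1 = _mm_loadu_ps(&v1[i]);  // Load 4 floats into 128 bits
        __m128 r2 = _mm_loadu_ps(&v2[i]);
        sumdst = _mm_fmadd_ps(r1, r2, sumdst);  // sumdst += r1 * r2 (element-wise)
    }
    // Horizontal add of the 4 accumulators
    float* farr = (float*)&sumdst;
    return farr[0] + farr[1] + farr[2] + farr[3];
}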
Note that we can improve on the horizontal addition at the very end. The example code just uses basic C++ with 7 additions and 8 array index computations. Instead, this last computation should really use the AVX "hadd" intrinsics (it needs 3 calls to horizontal-pairwise add 8 float values).
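For example, here is one possible sketch of that final horizontal addition using three "hadd" calls (the helper name is illustrative, not from the book):

float hsum8_hadd_sketch(__m256 v)
{
    // Horizontal sum of 8 floats using 3 horizontal-pairwise add calls
    __m128 lo = _mm256_castps256_ps128(v);    // lower 4 floats
    __m128 hi = _mm256_extractf128_ps(v, 1);  // upper 4 floats
    __m128 s = _mm_hadd_ps(lo, hi);  // {l0+l1, l2+l3, h0+h1, h2+h3}
    s = _mm_hadd_ps(s, s);           // {sum of lows, sum of highs, ...}
    s = _mm_hadd_ps(s, s);           // total sum in every lane
    return _mm_cvtss_f32(s);         // extract the lowest lane
}

A helper like this could replace the manual farr[0]+...+farr[7] addition at the end of the AVX-2 FMA dot product function above.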