Aussie AI
AVX Vector Max and Min Reductions
Book Excerpt from "Generative AI in C++"
by David Spuler, Ph.D.
The need to find a minimum or maximum of a vector's elements is similar to a summation reduction. Again, AVX1 and AVX2 don't have proper “reduction” intrinsics for max or min, but we can compute them in parallel by keeping 4 or 8 running min or max values, one per float lane (i.e., analogous to parallel accumulators when doing summation), and then combining those values at the end.
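To make the "parallel accumulators" idea concrete without any intrinsics, here is a plain scalar sketch (the function name scalar_max_4lanes is hypothetical, not from the book). It keeps 4 running maxima, one per lane, just as a 128-bit register would, and only combines them in a final horizontal step.

```cpp
#include <cassert>

// Scalar emulation of a 4-lane SIMD max reduction (illustrative only).
// Lane k tracks the maximum of v[k], v[k+4], v[k+8], ...
float scalar_max_4lanes(const float v[], int n)
{
    assert(n > 0 && n % 4 == 0);  // Same restriction as the AVX1 version
    float lane[4] = { v[0], v[1], v[2], v[3] };  // Initial values
    for (int i = 4; i < n; i += 4) {
        for (int k = 0; k < 4; k++) {
            if (v[i + k] > lane[k]) lane[k] = v[i + k];  // Per-lane running max
        }
    }
    // Horizontal step: combine the 4 lane maxima into one result
    float fmax = lane[0];
    for (int k = 1; k < 4; k++) {
        if (lane[k] > fmax) fmax = lane[k];
    }
    return fmax;
}
```

A SIMD max instruction performs the whole inner `k` loop in one operation, which is where the speedup comes from.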
The AVX intrinsics are:
MIN: _mm_min_ps (4 floats), _mm256_min_ps (8 floats)
MAX: _mm_max_ps (4 floats), _mm256_max_ps (8 floats)
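These intrinsics are "vertical": lane k of the result is the min or max of lane k of the two inputs, and no comparison happens across lanes. The minimal sketch below (the helper name lanewise_max4 is hypothetical) shows this for _mm_max_ps.

```cpp
#include <immintrin.h>

// Demonstrates that _mm_max_ps works lane by lane (vertically):
// out[k] = max(a[k], b[k]) for k = 0..3. No horizontal reduction happens.
void lanewise_max4(const float a[4], const float b[4], float out[4])
{
    __m128 va = _mm_loadu_ps(a);  // Unaligned load of 4 floats
    __m128 vb = _mm_loadu_ps(b);
    _mm_storeu_ps(out, _mm_max_ps(va, vb));  // Store the 4 lane-wise maxima
}
```

For example, a = {1, 5, -2, 8} and b = {3, 4, -1, 0} give out = {3, 5, -1, 8}. This is why a separate horizontal step is needed at the end of each reduction below.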
Here is the AVX1 version of a MAX vector reduction:
#include <immintrin.h>  // SSE/AVX intrinsics

float aussie_vector_max_AVX1(float v[], int n)  // Maximum (horizontal) of a single vector
{
    if (n <= 0 || n % 4 != 0) {  // Safety check (no extra cases, no empty vector)
        yassert(n > 0 && n % 4 == 0);  // yassert is an assertion macro (not shown here)
        return 0.0f;  // fail
    }
    __m128 sumdst = _mm_loadu_ps(&v[0]);  // Initial 4 values
    for (int i = 4 /*not 0*/; i < n; i += 4) {
        __m128 r1 = _mm_loadu_ps(&v[i]);  // Load floats into 128-bits
        sumdst = _mm_max_ps(r1, sumdst);  // dst = MAX(dst, r1)
    }
    // Find Max of the final 4 accumulators
    float farr[4];
    _mm_storeu_ps(farr, sumdst);  // Portable; sumdst.m128_f32 is MSVC-only
    float fmax = farr[0];
    if (farr[1] > fmax) fmax = farr[1];
    if (farr[2] > fmax) fmax = farr[2];
    if (farr[3] > fmax) fmax = farr[3];
    return fmax;
}
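The version above rejects any n that is not a multiple of 4. One common alternative, sketched here under the assumption that a scalar cleanup loop is acceptable (the name vector_max_AVX1_anyn is hypothetical), runs the SIMD loop over the largest multiple of 4 and handles the 0 to 3 leftover elements with plain comparisons.

```cpp
#include <cassert>
#include <immintrin.h>

// Hypothetical variant of the AVX1 max reduction that accepts any n >= 1:
// the SIMD loop handles the largest multiple of 4, and a scalar loop
// handles the leftover elements (or all of them if n < 4).
float vector_max_AVX1_anyn(const float v[], int n)
{
    assert(n >= 1);
    int i = 0;
    float fmax = v[0];
    if (n >= 4) {
        __m128 acc = _mm_loadu_ps(&v[0]);  // Initial 4 values
        for (i = 4; i + 4 <= n; i += 4) {
            acc = _mm_max_ps(_mm_loadu_ps(&v[i]), acc);  // 4 running maxima
        }
        float farr[4];
        _mm_storeu_ps(farr, acc);  // Portable alternative to acc.m128_f32
        fmax = farr[0];
        for (int k = 1; k < 4; k++) {
            if (farr[k] > fmax) fmax = farr[k];
        }
    }
    for (; i < n; i++) {  // Scalar tail: elements not covered by the SIMD loop
        if (v[i] > fmax) fmax = v[i];
    }
    return fmax;
}
```

When n is a multiple of 4, the tail loop does nothing and the function behaves like the fixed-size version.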
And here is the analogous AVX2 version of a MAX vector reduction:
float aussie_vector_max_AVX2(float v[], int n)  // Maximum (horizontal) of a single vector
{
    if (n <= 0 || n % 8 != 0) {  // Safety check (no extra cases, no empty vector)
        yassert(n > 0 && n % 8 == 0);
        return 0.0f;  // fail
    }
    __m256 sumdst = _mm256_loadu_ps(&v[0]);  // Initial 8 values
    for (int i = 8 /*not 0*/; i < n; i += 8) {
        __m256 r1 = _mm256_loadu_ps(&v[i]);  // Load floats into 256-bits
        sumdst = _mm256_max_ps(r1, sumdst);  // dst = MAX(dst, r1)
    }
    // Find Max of the final 8 accumulators
    float farr[8];
    _mm256_storeu_ps(farr, sumdst);  // Portable; sumdst.m256_f32 is MSVC-only
    float fmax = farr[0];
    if (farr[1] > fmax) fmax = farr[1];
    if (farr[2] > fmax) fmax = farr[2];
    if (farr[3] > fmax) fmax = farr[3];
    if (farr[4] > fmax) fmax = farr[4];
    if (farr[5] > fmax) fmax = farr[5];
    if (farr[6] > fmax) fmax = farr[6];
    if (farr[7] > fmax) fmax = farr[7];
    return fmax;
}
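A simple way to sanity-check any of these reductions is to compare it against std::max_element on random data. The sketch below assumes a hypothetical helper, check_max_reduction, that accepts any function with the same signature as aussie_vector_max_AVX2. An exact == comparison is safe here because max and min only select one of the input values; unlike summation, there is no rounding error to tolerate.

```cpp
#include <algorithm>
#include <random>
#include <vector>

// Hypothetical test helper: compares a vector-max function (such as
// aussie_vector_max_AVX2) against std::max_element on random data.
// n must satisfy the function's own restriction (a multiple of 8 for AVX2).
bool check_max_reduction(float (*fn)(float[], int), int n, unsigned seed)
{
    std::mt19937 gen(seed);
    std::uniform_real_distribution<float> dist(-1000.0f, 1000.0f);
    std::vector<float> v(n);
    for (float& x : v) x = dist(gen);  // Fill with random values
    float expected = *std::max_element(v.begin(), v.end());
    return fn(v.data(), n) == expected;  // Exact: max selects, never rounds
}
```

Typical usage would be a loop such as `check_max_reduction(aussie_vector_max_AVX2, 1024, seed)` over several seeds and sizes.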
The MIN versions are very similar. They use the “min” AVX intrinsics, and the final steps use “<” rather than “>” comparisons.
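Because the MIN and MAX versions differ only in the lane-wise intrinsic and the final comparison, one option (a sketch, not the book's code) is a single template parameterized by the operation. Here the hypothetical functors MinOp and MaxOp wrap _mm_min_ps and _mm_max_ps, and the same functor also performs the final horizontal step on broadcast copies of each lane.

```cpp
#include <cassert>
#include <immintrin.h>

// Functors wrapping the lane-wise intrinsics (hypothetical names).
struct MinOp { __m128 operator()(__m128 a, __m128 b) const { return _mm_min_ps(a, b); } };
struct MaxOp { __m128 operator()(__m128 a, __m128 b) const { return _mm_max_ps(a, b); } };

// Generic 4-lane reduction: works for MIN or MAX depending on Op.
template <typename Op>
float reduce4_sse(const float v[], int n, Op op)
{
    assert(n > 0 && n % 4 == 0);
    __m128 acc = _mm_loadu_ps(&v[0]);  // Initial 4 values
    for (int i = 4; i < n; i += 4) {
        acc = op(_mm_loadu_ps(&v[i]), acc);  // 4 running min/max values
    }
    // Horizontal step: fold each lane into r using the same operation
    float farr[4];
    _mm_storeu_ps(farr, acc);
    __m128 r = _mm_set1_ps(farr[0]);
    for (int k = 1; k < 4; k++) {
        r = op(_mm_set1_ps(farr[k]), r);
    }
    return _mm_cvtss_f32(r);  // Extract lane 0
}
```

Usage: `reduce4_sse(v, n, MinOp{})` or `reduce4_sse(v, n, MaxOp{})`. The functor calls are normally inlined, so this should cost about the same as the hand-written versions.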
Here's the AVX1 version of a MIN vector reduction:
float aussie_vector_min_AVX1(float v[], int n)  // Minimum (horizontal) of a single vector
{
    if (n <= 0 || n % 4 != 0) {  // Safety check (no extra cases, no empty vector)
        yassert(n > 0 && n % 4 == 0);
        return 0.0f;  // fail
    }
    __m128 sumdst = _mm_loadu_ps(&v[0]);  // Initial 4 values
    for (int i = 4 /*not 0*/; i < n; i += 4) {
        __m128 r1 = _mm_loadu_ps(&v[i]);  // Load floats into 128-bits
        sumdst = _mm_min_ps(r1, sumdst);  // dst = MIN(dst, r1)
    }
    // Find Min of the final 4 accumulators
    float farr[4];
    _mm_storeu_ps(farr, sumdst);  // Portable; sumdst.m128_f32 is MSVC-only
    float fmin = farr[0];
    if (farr[1] < fmin) fmin = farr[1];
    if (farr[2] < fmin) fmin = farr[2];
    if (farr[3] < fmin) fmin = farr[3];
    return fmin;
}
This is the AVX2 version of a MIN vector reduction:
float aussie_vector_min_AVX2(float v[], int n)  // Minimum (horizontal) of a single vector
{
    if (n <= 0 || n % 8 != 0) {  // Safety check (no extra cases, no empty vector)
        yassert(n > 0 && n % 8 == 0);
        return 0.0f;  // fail
    }
    __m256 sumdst = _mm256_loadu_ps(&v[0]);  // Initial 8 values
    for (int i = 8 /*not 0*/; i < n; i += 8) {
        __m256 r1 = _mm256_loadu_ps(&v[i]);  // Load floats into 256-bits
        sumdst = _mm256_min_ps(r1, sumdst);  // dst = MIN(dst, r1)
    }
    // Find Min of the final 8 accumulators
    float farr[8];
    _mm256_storeu_ps(farr, sumdst);  // Portable; sumdst.m256_f32 is MSVC-only
    float fmin = farr[0];
    if (farr[1] < fmin) fmin = farr[1];
    if (farr[2] < fmin) fmin = farr[2];
    if (farr[3] < fmin) fmin = farr[3];
    if (farr[4] < fmin) fmin = farr[4];
    if (farr[5] < fmin) fmin = farr[5];
    if (farr[6] < fmin) fmin = farr[6];
    if (farr[7] < fmin) fmin = farr[7];
    return fmin;
}
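When both the minimum and the maximum are needed (for example, to find a value range), the two reductions can share one pass over the data, so each group of 8 floats is loaded only once. This is a sketch, not the book's code; vector_minmax_AVX2 is a hypothetical name. It uses the GCC/Clang `__attribute__((target("avx")))` extension to enable 256-bit intrinsics for this one function (MSVC does not support that attribute). As in the rest of this section, "AVX2" names the 256-bit version, although _mm256_min_ps and _mm256_max_ps themselves only require AVX.

```cpp
#include <cassert>
#include <immintrin.h>

// Hypothetical single-pass variant: computes both the minimum and the
// maximum of v[0..n-1], loading each group of 8 floats only once.
// target("avx") (GCC/Clang) enables 256-bit intrinsics for this function.
__attribute__((target("avx")))
void vector_minmax_AVX2(const float v[], int n, float* outmin, float* outmax)
{
    assert(n > 0 && n % 8 == 0);
    __m256 vmin = _mm256_loadu_ps(&v[0]);  // Initial 8 values for both
    __m256 vmax = vmin;
    for (int i = 8; i < n; i += 8) {
        __m256 r1 = _mm256_loadu_ps(&v[i]);  // One load feeds both reductions
        vmin = _mm256_min_ps(r1, vmin);
        vmax = _mm256_max_ps(r1, vmax);
    }
    // Horizontal step for both sets of 8 accumulators
    float minarr[8], maxarr[8];
    _mm256_storeu_ps(minarr, vmin);
    _mm256_storeu_ps(maxarr, vmax);
    float lo = minarr[0], hi = maxarr[0];
    for (int k = 1; k < 8; k++) {
        if (minarr[k] < lo) lo = minarr[k];
        if (maxarr[k] > hi) hi = maxarr[k];
    }
    *outmin = lo;
    *outmax = hi;
}
```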
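These versions are not especially optimized. AVX-512 would allow us to further vectorize to 16 float values. Also, the final computation of the maximum or minimum of the 8 float accumulators is far from optimal, since it is a chain of 7 scalar comparisons, each waiting on the result of the previous one. AVX has no horizontal min or max intrinsic for float values (unlike _mm_hadd_ps for addition), so the faster approach is to apply _mm_max_ps or _mm_min_ps pairwise, multiple times, to shuffled or permuted copies of the register, halving the number of candidate lanes at each step.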
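Here is a sketch of that shuffle-based horizontal step (the names hmax_m128, hmax_m256, and hmax8 are hypothetical). For 4 lanes, _mm_movehl_ps folds lanes 2 and 3 onto lanes 0 and 1, and a shuffle then folds lane 1 onto lane 0. For 8 lanes, the upper 128 bits are first folded onto the lower 128 bits. The 256-bit functions use the GCC/Clang `target("avx")` attribute; the MIN versions are identical with _mm_min_ps and _mm_min_ss.

```cpp
#include <immintrin.h>

// Horizontal max of the 4 lanes of an __m128, using shuffles and no branches.
float hmax_m128(__m128 x)
{
    __m128 hi = _mm_movehl_ps(x, x);  // [x2, x3, x2, x3]
    __m128 m = _mm_max_ps(x, hi);     // lane 0 = max(x0,x2), lane 1 = max(x1,x3)
    __m128 l1 = _mm_shuffle_ps(m, m, _MM_SHUFFLE(1, 1, 1, 1));  // Broadcast lane 1
    return _mm_cvtss_f32(_mm_max_ss(m, l1));  // Lane 0 = max of all 4 lanes
}

// Horizontal max of the 8 lanes of an __m256: 8 -> 4 -> 2 -> 1.
__attribute__((target("avx")))
float hmax_m256(__m256 x)
{
    __m128 lo = _mm256_castps256_ps128(x);    // Lanes 0..3 (no instruction needed)
    __m128 hi = _mm256_extractf128_ps(x, 1);  // Lanes 4..7
    return hmax_m128(_mm_max_ps(lo, hi));
}

// Convenience wrapper so that callers need not be compiled with AVX enabled.
__attribute__((target("avx")))
float hmax8(const float a[8])
{
    return hmax_m256(_mm256_loadu_ps(a));
}
```

In the AVX2 max function, the store to farr and the 7 if-statements could then be replaced by `return hmax_m256(sumdst);`.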
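Or we can at least do the scalar comparisons pairwise. This does not reduce the number of comparisons (3 for 4 values either way), but the two first-level comparisons are independent of each other, which shortens the dependency chain. Here's the alternative for the AVX1 minimum computation: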
// Find Min of the final 4 accumulators
#define FMIN(x,y) ( (x) < (y) ? (x) : (y) )
float farr[4];
_mm_storeu_ps(farr, sumdst);  // Portable alternative to sumdst.m128_f32
float fmin1 = FMIN(farr[0], farr[1]);  // Independent of fmin2
float fmin2 = FMIN(farr[2], farr[3]);
float fmin = FMIN(fmin1, fmin2);
return fmin;
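The same pairwise idea extends to the 8 accumulators of the AVX2 version. The sketch below (min8_pairwise is a hypothetical name) still uses 7 comparisons, but they form a tree that is 3 levels deep rather than a chain of 7, and with optimization enabled each FMIN often compiles to a branch-free scalar min instruction.

```cpp
// Pairwise (tree) minimum of the 8 accumulators from the AVX2 MIN version.
#define FMIN(x,y) ( (x) < (y) ? (x) : (y) )

float min8_pairwise(const float farr[8])
{
    // Level 1: four independent comparisons
    float m01 = FMIN(farr[0], farr[1]);
    float m23 = FMIN(farr[2], farr[3]);
    float m45 = FMIN(farr[4], farr[5]);
    float m67 = FMIN(farr[6], farr[7]);
    // Level 2: two independent comparisons
    float m0123 = FMIN(m01, m23);
    float m4567 = FMIN(m45, m67);
    // Level 3: final result
    return FMIN(m0123, m4567);
}
```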
These functions can also have their main loops further improved. Other basic optimizations would include using pointer arithmetic in the loop to remove the index variable “i”, and unrolling the loop body multiple times.
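Here is a sketch combining both of those loop optimizations for the AVX1 MAX reduction (vector_max_AVX1_unrolled is a hypothetical name, and the extra requirement that n be a multiple of 8 is an assumption of this sketch). The loop steps a pointer instead of an index, and the body is unrolled twice with two independent accumulators, so consecutive max operations do not have to wait on each other's results.

```cpp
#include <cassert>
#include <immintrin.h>

// Hypothetical AVX1 (128-bit) max reduction with pointer arithmetic and a
// loop body unrolled twice into two independent accumulators.
// Assumes n is a positive multiple of 8 (two 4-float loads per iteration).
float vector_max_AVX1_unrolled(const float v[], int n)
{
    assert(n > 0 && n % 8 == 0);
    const float* p = v;
    const float* end = v + n;
    __m128 acc0 = _mm_loadu_ps(p);      // Initial values p[0..3]
    __m128 acc1 = _mm_loadu_ps(p + 4);  // Initial values p[4..7]
    for (p += 8; p != end; p += 8) {
        acc0 = _mm_max_ps(_mm_loadu_ps(p), acc0);      // Floats p[0..3]
        acc1 = _mm_max_ps(_mm_loadu_ps(p + 4), acc1);  // Floats p[4..7]
    }
    __m128 acc = _mm_max_ps(acc0, acc1);  // Merge the two accumulators
    float farr[4];
    _mm_storeu_ps(farr, acc);
    float fmax = farr[0];
    for (int k = 1; k < 4; k++) {
        if (farr[k] > fmax) fmax = farr[k];
    }
    return fmax;
}
```

Whether two accumulators are enough to hide the latency of the max instruction depends on the CPU, so the best unrolling factor is worth measuring rather than assuming.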
The new AI programming book by Aussie AI co-founders:
Get your copy from Amazon: Generative AI in C++