Aussie AI

AVX Vector Max and Min Reductions

  • Book Excerpt from "Generative AI in C++"
  • by David Spuler, Ph.D.

AVX Vector Max and Min Reductions

The need to find a minimum or maximum of a vector's elements is similar to a summation reduction. Again, AVX1 and AVX2 don't have proper “reduction” intrinsics for max or min, but we can compute them in parallel by keeping a running min or max value of 4 or 8 float values (i.e. analogous to parallel accumulators when doing summation). The AVX intrinsics are:

  • MIN: _mm_min_ps, _mm256_min_ps
  • MAX: _mm_max_ps, _mm256_max_ps

Here is the AVX1 version of MAX vector reduction:

    // Maximum (horizontal) of a single vector, processed 4 floats at a time.
    //   v - input array of at least n floats (unaligned loads are used).
    //   n - element count; must be a positive multiple of 4.
    // Returns the largest element of v[0..n-1]. If n is not a positive
    // multiple of 4, the failure is reported through yassert (a project
    // macro defined elsewhere) and 0.0 is returned.
    float aussie_vector_max_AVX1(float v[], int n)
    {
        if (n <= 0 || n % 4 != 0) {
            // n == 0 must be rejected as well: it satisfies n % 4 == 0, but
            // the initial load below would read v[0..3] out of bounds.
            yassert(n > 0 && n % 4 == 0);
            return 0.0f; // fail
        }
        __m128 maxdst = _mm_loadu_ps(&v[0]);   // Initial 4 running maxima
        for (int i = 4 /*not 0*/; i < n; i += 4) {
            __m128 r1 = _mm_loadu_ps(&v[i]);   // Load next 4 floats (128 bits)
            maxdst = _mm_max_ps(r1, maxdst);   // dst = MAX(dst, r1), per lane
        }
        // Find Max of the final 4 accumulators.
        // Copy the lanes out with a store intrinsic: the .m128_f32 member
        // is an MSVC-specific extension and does not compile on GCC/Clang.
        float lanes[4];
        _mm_storeu_ps(lanes, maxdst);
        float best = lanes[0];
        for (int k = 1; k < 4; k++) {
            if (lanes[k] > best) best = lanes[k];
        }
        return best;
    }

And here is the analogous AVX2 version of MAX vector reduction:

    // Maximum (horizontal) of a single vector, processed 8 floats at a time.
    //   v - input array of at least n floats (unaligned loads are used).
    //   n - element count; must be a positive multiple of 8.
    // Returns the largest element of v[0..n-1]. If n is not a positive
    // multiple of 8, the failure is reported through yassert (a project
    // macro defined elsewhere) and 0.0 is returned.
    float aussie_vector_max_AVX2(float v[], int n)
    {
        if (n <= 0 || n % 8 != 0) { // Safety check (no extra cases)
            // n == 0 must be rejected as well: it satisfies n % 8 == 0, but
            // the initial load below would read v[0..7] out of bounds.
            yassert(n > 0 && n % 8 == 0);
            return 0.0f; // fail
        }
        __m256 maxdst = _mm256_loadu_ps(&v[0]);   // Initial 8 running maxima
        for (int i = 8 /*not 0*/; i < n; i += 8) {
            __m256 r1 = _mm256_loadu_ps(&v[i]);   // Load next 8 floats (256 bits)
            maxdst = _mm256_max_ps(r1, maxdst);   // dst = MAX(dst, r1), per lane
        }

        // Find Max of the final 8 accumulators.
        // Copy the lanes out with a store intrinsic: the .m256_f32 member
        // is an MSVC-specific extension and does not compile on GCC/Clang.
        float lanes[8];
        _mm256_storeu_ps(lanes, maxdst);
        float best = lanes[0];
        for (int k = 1; k < 8; k++) {
            if (lanes[k] > best) best = lanes[k];
        }
        return best;
    }

The MIN versions are very similar. They use the “min” AVX intrinsics, and the final steps use “<” not “>” operations. Here's the AVX1 version of a MIN vector reduction:

    // Minimum (horizontal) of a single vector, processed 4 floats at a time.
    //   v - input array of at least n floats (unaligned loads are used).
    //   n - element count; must be a positive multiple of 4.
    // Returns the smallest element of v[0..n-1]. If n is not a positive
    // multiple of 4, the failure is reported through yassert (a project
    // macro defined elsewhere) and 0.0 is returned.
    float aussie_vector_min_AVX1(float v[], int n)
    {
        if (n <= 0 || n % 4 != 0) {
            // n == 0 must be rejected as well: it satisfies n % 4 == 0, but
            // the initial load below would read v[0..3] out of bounds.
            yassert(n > 0 && n % 4 == 0);
            return 0.0f; // fail
        }
        __m128 mindst = _mm_loadu_ps(&v[0]);   // Initial 4 running minima
        for (int i = 4 /*not 0*/; i < n; i += 4) {
            __m128 r1 = _mm_loadu_ps(&v[i]);   // Load next 4 floats (128 bits)
            mindst = _mm_min_ps(r1, mindst);   // dst = MIN(dst, r1), per lane
        }
        // Find Min of the final 4 accumulators.
        // Copy the lanes out with a store intrinsic: the .m128_f32 member
        // is an MSVC-specific extension and does not compile on GCC/Clang.
        float lanes[4];
        _mm_storeu_ps(lanes, mindst);
        float best = lanes[0];
        for (int k = 1; k < 4; k++) {
            if (lanes[k] < best) best = lanes[k];
        }
        return best;
    }

This is the AVX2 version of a MIN vector reduction:

    // Minimum (horizontal) of a single vector, processed 8 floats at a time.
    //   v - input array of at least n floats (unaligned loads are used).
    //   n - element count; must be a positive multiple of 8.
    // Returns the smallest element of v[0..n-1]. If n is not a positive
    // multiple of 8, the failure is reported through yassert (a project
    // macro defined elsewhere) and 0.0 is returned.
    float aussie_vector_min_AVX2(float v[], int n)
    {
        if (n <= 0 || n % 8 != 0) { // Safety check (no extra cases)
            // n == 0 must be rejected as well: it satisfies n % 8 == 0, but
            // the initial load below would read v[0..7] out of bounds.
            yassert(n > 0 && n % 8 == 0);
            return 0.0f; // fail
        }
        __m256 mindst = _mm256_loadu_ps(&v[0]);   // Initial 8 running minima
        for (int i = 8 /*not 0*/; i < n; i += 8) {
            __m256 r1 = _mm256_loadu_ps(&v[i]);   // Load next 8 floats (256 bits)
            mindst = _mm256_min_ps(r1, mindst);   // dst = MIN(dst, r1), per lane
        }

        // Find Min of the final 8 accumulators.
        // Copy the lanes out with a store intrinsic: the .m256_f32 member
        // is an MSVC-specific extension and does not compile on GCC/Clang.
        float lanes[8];
        _mm256_storeu_ps(lanes, mindst);
        float best = lanes[0];
        for (int k = 1; k < 8; k++) {
            if (lanes[k] < best) best = lanes[k];
        }
        return best;
    }

These versions are not especially optimized. AVX-512 would allow us to further vectorize to 16 float values. Also, the final computation of the maximum or minimum of 8 float numbers is far from optimal. The AVX horizontal min/max intrinsics would be used (pairwise, multiple times). Or we can at least avoid some comparisons by doing it pairwise sequentially. Here's the alternative for AVX1 minimum computation:

        // Find Min of the final 4 accumulators
    #define FMIN(x,y)  ( (x) < (y) ? (x) : (y) )
        float* farr = sumdst.m128_f32;
        float fmin1 = FMIN(farr[0], farr[1]);
        float fmin2 = FMIN(farr[2], farr[3]);
        float fmin = FMIN(fmin1, fmin2);
        return fmin;

These functions can also have their main loops further improved. Other basic optimizations would include using loop pointer arithmetic to remove the index variable “i” and also unrolling the loop body multiple times.

 

Next:

Up: Table of Contents

Buy: Generative AI in C++: Coding Transformers and LLMs

Generative AI in C++ The new AI programming book by Aussie AI co-founders:
  • AI coding in C++
  • Transformer engine speedups
  • LLM models
  • Phone and desktop AI
  • Code examples
  • Research citations

Get your copy from Amazon: Generative AI in C++