1include(CheckCSourceRuns)
  2
  # Runtime probe for AVX (consumed by check_sse via check_c_source_runs, which
# only reports success when the program exits with status 0).
# The input comes from a volatile and the result feeds the exit code, so an
# optimizing build cannot constant-fold or discard the AVX instructions; on a
# CPU without AVX the program therefore faults instead of silently "passing".
# +0.0f has a clear sign bit, so the movemask is 0 and main returns 0.
set(AVX_CODE "
    #include <immintrin.h>
    int main()
    {
        volatile float seed = 0.0f;
        __m256 a = _mm256_set1_ps(seed);
        return _mm256_movemask_ps(a);
    }
")
 12
 # Runtime probe for AVX-512 byte compares (_mm512_cmp_epi8_mask).
# b is built from a volatile so the compare cannot be folded (the original
# compared a with a copy of itself, which is trivially all-equal), and the
# resulting mask decides the exit code so the instruction must execute.
# Both vectors are all zero, so every lane compares equal, the mask is all
# ones and main returns 0.
set(AVX512_CODE "
    #include <immintrin.h>
    int main()
    {
        volatile char seed = 0;
        __m512i a = _mm512_set_epi8(0, 0, 0, 0, 0, 0, 0, 0,
                                    0, 0, 0, 0, 0, 0, 0, 0,
                                    0, 0, 0, 0, 0, 0, 0, 0,
                                    0, 0, 0, 0, 0, 0, 0, 0,
                                    0, 0, 0, 0, 0, 0, 0, 0,
                                    0, 0, 0, 0, 0, 0, 0, 0,
                                    0, 0, 0, 0, 0, 0, 0, 0,
                                    0, 0, 0, 0, 0, 0, 0, 0);
        __m512i b = _mm512_set1_epi8(seed);
        __mmask64 equality_mask = _mm512_cmp_epi8_mask(a, b, _MM_CMPINT_EQ);
        return equality_mask == ~(__mmask64)0 ? 0 : 1;
    }
")
 30
 # Runtime probe for AVX2 (_mm256_abs_epi16) plus _mm256_extract_epi64, which
# the AVX2 code paths depend on (that intrinsic needs a 64-bit target).
# The original read an uninitialized __m256i, which is undefined behavior in a
# program that check_c_source_runs actually executes; MSVC flags it (C4700 and,
# in builds using /RTCu, a run-time check failure). Here every value is
# initialized from a volatile and the extracted lane is returned, so the AVX2
# instructions cannot be optimized away. abs(0) == 0, so main returns 0.
set(AVX2_CODE "
    #include <immintrin.h>
    int main()
    {
        volatile short seed = 0;
        __m256i a = _mm256_set1_epi16(seed);
        a = _mm256_abs_epi16(a);
        return (int)_mm256_extract_epi64(a, 0); // we rely on this in our AVX2 code
    }
")
 42
 # Runtime probe for FMA (_mm256_fmadd_ps).
# Operands are derived from a volatile and the accumulator feeds the exit
# code, so the fused multiply-add cannot be folded away by an optimizing
# build. 0 * 0 + 0 == +0.0f (sign bit clear), so the movemask is 0 and main
# returns 0.
set(FMA_CODE "
    #include <immintrin.h>
    int main()
    {
        volatile float seed = 0.0f;
        __m256 acc = _mm256_setzero_ps();
        const __m256 d = _mm256_set1_ps(seed);
        const __m256 p = _mm256_set1_ps(seed);
        acc = _mm256_fmadd_ps( d, p, acc );
        return _mm256_movemask_ps(acc);
    }
")
 54
 # check_sse(<type> <flags>)
#
# Checks whether the test program stored in the variable <type>_CODE
# (e.g. AVX_CODE) compiles, links and runs. Each entry of the ;-separated
# <flags> list is tried in order as CMAKE_REQUIRED_FLAGS, stopping at the
# first one that works. A whitespace-only entry (" ") means "no extra flag".
#
# Results are published as cache entries:
#   <type>_FOUND - TRUE if some entry of <flags> worked, FALSE otherwise
#   <type>_FLAGS - the entry that worked ("" when nothing worked)
# Each attempt's result is cached by check_c_source_runs as HAS_<type>_<n>,
# with n counting from 1.
#
# Nothing runs if <type>_FOUND is already true. The cache entries are created
# without FORCE, so an existing <type>_FOUND/<type>_FLAGS entry is never
# overwritten.
#
# Implemented as a function (not a macro) so the loop variables and the
# temporary CMAKE_REQUIRED_FLAGS value stay local to this call. The caller's
# CMAKE_REQUIRED_FLAGS is left exactly as it was, including the
# "defined but empty" case.
function(check_sse type flags)
    if(NOT ${type}_FOUND)
        set(attempt 1)
        foreach(flag IN LISTS flags)
            # Function scope: this does not leak into the caller.
            set(CMAKE_REQUIRED_FLAGS "${flag}")
            check_c_source_runs("${${type}_CODE}" HAS_${type}_${attempt})
            if(HAS_${type}_${attempt})
                set(${type}_FOUND TRUE CACHE BOOL "${type} support")
                set(${type}_FLAGS "${flag}" CACHE STRING "${type} flags")
                break()
            endif()
            math(EXPR attempt "${attempt} + 1")
        endforeach()
    endif()

    # Record a negative result so later configures see a defined value.
    if(NOT ${type}_FOUND)
        set(${type}_FOUND FALSE CACHE BOOL "${type} support")
        set(${type}_FLAGS "" CACHE STRING "${type} flags")
    endif()

    mark_as_advanced(${type}_FOUND ${type}_FLAGS)
endfunction()
 78
 # The /arch: entries passed to check_sse are MSVC compiler options; the
# leading " " entry means "first try without any extra flag".
check_sse("AVX" " ;/arch:AVX")
# Test the variable itself rather than ${AVX_FOUND}, so an empty or odd cached
# value cannot turn into a malformed condition.
# NOTE(review): if GGML_AVX is also a cache option elsewhere, this normal
# variable shadows the user's choice - confirm auto-detection should win.
if(AVX_FOUND)
    set(GGML_AVX ON)
else()
    set(GGML_AVX OFF)
endif()
 86
 check_sse("AVX2" " ;/arch:AVX2")
# FMA is probed with the same /arch:AVX2 switch as AVX2.
check_sse("FMA" " ;/arch:AVX2")
# GGML_AVX2 is enabled only when both the AVX2 and the FMA probe passed.
if(AVX2_FOUND AND FMA_FOUND)
    set(GGML_AVX2 ON)
else()
    set(GGML_AVX2 OFF)
endif()
 94
 check_sse("AVX512" " ;/arch:AVX512")
# Enable GGML_AVX512 only when the AVX-512 probe succeeded.
if(AVX512_FOUND)
    set(GGML_AVX512 ON)
else()
    set(GGML_AVX512 OFF)
endif()