Diffstat (limited to 'llama.cpp/ggml/src/ggml-cpu/amx')
 -rw-r--r--   llama.cpp/ggml/src/ggml-cpu/amx/amx.cpp      224
 -rw-r--r--   llama.cpp/ggml/src/ggml-cpu/amx/amx.h          8
 -rw-r--r--   llama.cpp/ggml/src/ggml-cpu/amx/common.h      91
 -rw-r--r--   llama.cpp/ggml/src/ggml-cpu/amx/mmq.cpp     2512
 -rw-r--r--   llama.cpp/ggml/src/ggml-cpu/amx/mmq.h         10
5 files changed, 2845 insertions, 0 deletions
diff --git a/llama.cpp/ggml/src/ggml-cpu/amx/amx.cpp b/llama.cpp/ggml/src/ggml-cpu/amx/amx.cpp
new file mode 100644
index 0000000..895a571
--- /dev/null
+++ b/llama.cpp/ggml/src/ggml-cpu/amx/amx.cpp
@@ -0,0 +1,224 @@
| 1 | #include "amx.h" | ||
| 2 | #include "common.h" | ||
| 3 | #include "mmq.h" | ||
| 4 | #include "ggml-backend-impl.h" | ||
| 5 | #include "ggml-backend.h" | ||
| 6 | #include "ggml-impl.h" | ||
| 7 | #include "ggml-cpu.h" | ||
| 8 | #include "traits.h" | ||
| 9 | |||
| 10 | #if defined(__linux__) | ||
| 11 | #include <sys/syscall.h> | ||
| 12 | #include <unistd.h> | ||
| 13 | #endif | ||
| 14 | |||
| 15 | #include <cstdlib> | ||
| 16 | #include <cstring> | ||
| 17 | #include <memory> | ||
| 18 | |||
| 19 | #if defined(__AMX_INT8__) && defined(__AVX512VNNI__) | ||
| 20 | |||
| 21 | // AMX type traits | ||
| 22 | namespace ggml::cpu::amx { | ||
| 23 | class tensor_traits : public ggml::cpu::tensor_traits { | ||
| 24 | bool work_size(int /* n_threads */, const struct ggml_tensor * op, size_t & size) override { | ||
| 25 | size = ggml_backend_amx_desired_wsize(op); | ||
| 26 | return true; | ||
| 27 | } | ||
| 28 | |||
| 29 | bool compute_forward(struct ggml_compute_params * params, struct ggml_tensor * op) override { | ||
| 30 | if (op->op == GGML_OP_MUL_MAT) { | ||
| 31 | ggml_backend_amx_mul_mat(params, op); | ||
| 32 | return true; | ||
| 33 | } | ||
| 34 | return false; | ||
| 35 | } | ||
| 36 | }; | ||
| 37 | |||
| 38 | static ggml::cpu::tensor_traits * get_tensor_traits(ggml_backend_buffer_t, struct ggml_tensor *) { | ||
| 39 | static tensor_traits traits; | ||
| 40 | return &traits; | ||
| 41 | } | ||
| 42 | } // namespace ggml::cpu::amx | ||
| 43 | |||
| 44 | // AMX buffer interface | ||
| 45 | static void ggml_backend_amx_buffer_free_buffer(ggml_backend_buffer_t buffer) { | ||
| 46 | free(buffer->context); | ||
| 47 | } | ||
| 48 | |||
| 49 | static void * ggml_backend_amx_buffer_get_base(ggml_backend_buffer_t buffer) { | ||
| 50 | return (void *) (buffer->context); | ||
| 51 | } | ||
| 52 | |||
| 53 | static enum ggml_status ggml_backend_amx_buffer_init_tensor(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor) { | ||
| 54 | tensor->extra = (void *) ggml::cpu::amx::get_tensor_traits(buffer, tensor); | ||
| 55 | |||
| 56 | GGML_UNUSED(buffer); | ||
| 57 | return GGML_STATUS_SUCCESS; | ||
| 58 | } | ||
| 59 | |||
| 60 | static void ggml_backend_amx_buffer_memset_tensor(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor, | ||
| 61 | uint8_t value, size_t offset, size_t size) { | ||
| 62 | memset((char *) tensor->data + offset, value, size); | ||
| 63 | |||
| 64 | GGML_UNUSED(buffer); | ||
| 65 | } | ||
| 66 | |||
| 67 | static void ggml_backend_amx_buffer_set_tensor(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor, | ||
| 68 | const void * data, size_t offset, size_t size) { | ||
| 69 | if (qtype_has_amx_kernels(tensor->type)) { | ||
| 70 | GGML_LOG_DEBUG("%s: amx repack tensor %s of type %s\n", __func__, tensor->name, ggml_type_name(tensor->type)); | ||
| 71 | ggml_backend_amx_convert_weight(tensor, data, offset, size); | ||
| 72 | } else { | ||
| 73 | memcpy((char *) tensor->data + offset, data, size); | ||
| 74 | } | ||
| 75 | |||
| 76 | GGML_UNUSED(buffer); | ||
| 77 | } | ||
| 78 | |||
| 79 | /* | ||
| 80 | // need to figure out what we need to do with buffer->extra. | ||
| 81 | static void ggml_backend_amx_buffer_get_tensor(ggml_backend_buffer_t buffer, const struct ggml_tensor * tensor, void * data, size_t offset, size_t size) { | ||
| 82 | GGML_ASSERT(!qtype_has_amx_kernels(tensor->type)); | ||
| 83 | memcpy(data, (const char *)tensor->data + offset, size); | ||
| 84 | |||
| 85 | GGML_UNUSED(buffer); | ||
| 86 | } | ||
| 87 | |||
| 88 | static bool ggml_backend_amx_buffer_cpy_tensor(ggml_backend_buffer_t buffer, const struct ggml_tensor * src, struct ggml_tensor * dst) { | ||
| 89 | if (ggml_backend_buffer_is_host(src->buffer)) { | ||
| 90 | if (qtype_has_amx_kernels(src->type)) { | ||
| 91 | ggml_backend_amx_convert_weight(dst, src->data, 0, ggml_nbytes(dst)); | ||
| 92 | } else { | ||
| 93 | memcpy(dst->data, src->data, ggml_nbytes(src)); | ||
| 94 | } | ||
| 95 | return true; | ||
| 96 | } | ||
| 97 | return false; | ||
| 98 | |||
| 99 | GGML_UNUSED(buffer); | ||
| 100 | } | ||
| 101 | */ | ||
| 102 | |||
| 103 | static void ggml_backend_amx_buffer_clear(ggml_backend_buffer_t buffer, uint8_t value) { | ||
| 104 | memset(buffer->context, value, buffer->size); | ||
| 105 | } | ||
| 106 | |||
| 107 | static ggml_backend_buffer_i ggml_backend_amx_buffer_interface = { | ||
| 108 | /* .free_buffer = */ ggml_backend_amx_buffer_free_buffer, | ||
| 109 | /* .get_base = */ ggml_backend_amx_buffer_get_base, | ||
| 110 | /* .init_tensor = */ ggml_backend_amx_buffer_init_tensor, | ||
| 111 | /* .memset_tensor = */ ggml_backend_amx_buffer_memset_tensor, | ||
| 112 | /* .set_tensor = */ ggml_backend_amx_buffer_set_tensor, | ||
| 113 | /* .get_tensor = */ nullptr, | ||
| 114 | /* .cpy_tensor = */ nullptr, | ||
| 115 | /* .clear = */ ggml_backend_amx_buffer_clear, | ||
| 116 | /* .reset = */ nullptr, | ||
| 117 | }; | ||
| 118 | |||
| 119 | static const char * ggml_backend_amx_buffer_type_get_name(ggml_backend_buffer_type_t buft) { | ||
| 120 | return "AMX"; | ||
| 121 | |||
| 122 | GGML_UNUSED(buft); | ||
| 123 | } | ||
| 124 | |||
| 125 | static ggml_backend_buffer_t ggml_backend_amx_buffer_type_alloc_buffer(ggml_backend_buffer_type_t buft, size_t size) { | ||
| 126 | void * data = ggml_aligned_malloc(size); | ||
| 127 | if (data == NULL) { | ||
| 128 | fprintf(stderr, "%s: failed to allocate buffer of size %zu\n", __func__, size); | ||
| 129 | return NULL; | ||
| 130 | } | ||
| 131 | |||
| 132 | return ggml_backend_buffer_init(buft, ggml_backend_amx_buffer_interface, data, size); | ||
| 133 | } | ||
| 134 | |||
| 135 | static size_t ggml_backend_amx_buffer_type_get_alignment(ggml_backend_buffer_type_t buft) { | ||
| 136 | return TENSOR_ALIGNMENT; | ||
| 137 | |||
| 138 | GGML_UNUSED(buft); | ||
| 139 | } | ||
| 140 | |||
| 141 | namespace ggml::cpu::amx { | ||
| 142 | class extra_buffer_type : ggml::cpu::extra_buffer_type { | ||
| 143 | bool supports_op(ggml_backend_dev_t, const struct ggml_tensor * op) override { | ||
| 144 | // handle only 2d gemm for now | ||
| 145 | auto is_contiguous_2d = [](const struct ggml_tensor * t) { | ||
| 146 | return ggml_is_contiguous(t) && t->ne[3] == 1 && t->ne[2] == 1; | ||
| 147 | }; | ||
| 148 | |||
| 149 | if (op->op == GGML_OP_MUL_MAT && is_contiguous_2d(op->src[0]) && // src0 must be contiguous | ||
| 150 | is_contiguous_2d(op->src[1]) && // src1 must be contiguous | ||
| 151 | op->src[0]->buffer && op->src[0]->buffer->buft == ggml_backend_amx_buffer_type() && | ||
| 152 | op->src[0]->ne[0] % (TILE_K * 2 * 32) == 0 && // TODO: not sure if correct (https://github.com/ggml-org/llama.cpp/pull/16315) | ||
| 153 | op->ne[0] % (TILE_N * 2) == 0 && // out_features must be a multiple of 32 (TILE_N * 2) | ||
| 154 | (qtype_has_amx_kernels(op->src[0]->type) || (op->src[0]->type == GGML_TYPE_F16))) { | ||
| 155 | // src1 must be host buffer | ||
| 156 | if (op->src[1]->buffer && !ggml_backend_buft_is_host(op->src[1]->buffer->buft)) { | ||
| 157 | return false; | ||
| 158 | } | ||
| 159 | // src1 must be float32 | ||
| 160 | if (op->src[1]->type == GGML_TYPE_F32) { | ||
| 161 | return true; | ||
| 162 | } | ||
| 163 | } | ||
| 164 | return false; | ||
| 165 | } | ||
| 166 | |||
| 167 | ggml::cpu::tensor_traits * get_tensor_traits(const struct ggml_tensor * op) override { | ||
| 168 | if (op->op == GGML_OP_MUL_MAT && op->src[0]->buffer && | ||
| 169 | op->src[0]->buffer->buft == ggml_backend_amx_buffer_type()) { | ||
| 170 | return (ggml::cpu::tensor_traits *) op->src[0]->extra; | ||
| 171 | } | ||
| 172 | |||
| 173 | return nullptr; | ||
| 174 | } | ||
| 175 | }; | ||
| 176 | } // namespace ggml::cpu::amx | ||
| 177 | |||
| 178 | static size_t ggml_backend_amx_buffer_type_get_alloc_size(ggml_backend_buffer_type_t buft, const ggml_tensor * tensor) { | ||
| 179 | return ggml_backend_amx_get_alloc_size(tensor); | ||
| 180 | |||
| 181 | GGML_UNUSED(buft); | ||
| 182 | } | ||
| 183 | |||
| 184 | #define ARCH_GET_XCOMP_PERM 0x1022 | ||
| 185 | #define ARCH_REQ_XCOMP_PERM 0x1023 | ||
| 186 | #define XFEATURE_XTILECFG 17 | ||
| 187 | #define XFEATURE_XTILEDATA 18 | ||
| 188 | |||
| 189 | static bool ggml_amx_init() { | ||
| 190 | #if defined(__linux__) | ||
| 191 | if (syscall(SYS_arch_prctl, ARCH_REQ_XCOMP_PERM, XFEATURE_XTILEDATA)) { | ||
| 192 | fprintf(stderr, "AMX is not ready to be used!\n"); | ||
| 193 | return false; | ||
| 194 | } | ||
| 195 | return true; | ||
| 196 | #elif defined(_WIN32) | ||
| 197 | return true; | ||
| 198 | #else | ||
| 199 | return false; | ||
| 200 | #endif | ||
| 201 | } | ||
| 202 | |||
| 203 | ggml_backend_buffer_type_t ggml_backend_amx_buffer_type() { | ||
| 204 | static struct ggml_backend_buffer_type ggml_backend_buffer_type_amx = { | ||
| 205 | /* .iface = */ { | ||
| 206 | /* .get_name = */ ggml_backend_amx_buffer_type_get_name, | ||
| 207 | /* .alloc_buffer = */ ggml_backend_amx_buffer_type_alloc_buffer, | ||
| 208 | /* .get_alignment = */ ggml_backend_amx_buffer_type_get_alignment, | ||
| 209 | /* .get_max_size = */ nullptr, // defaults to SIZE_MAX | ||
| 210 | /* .get_alloc_size = */ ggml_backend_amx_buffer_type_get_alloc_size, | ||
| 211 | /* .is_host = */ nullptr, | ||
| 212 | }, | ||
| 213 | /* .device = */ ggml_backend_reg_dev_get(ggml_backend_cpu_reg(), 0), | ||
| 214 | /* .context = */ new ggml::cpu::amx::extra_buffer_type(), | ||
| 215 | }; | ||
| 216 | |||
| 217 | if (!ggml_amx_init()) { | ||
| 218 | return nullptr; | ||
| 219 | } | ||
| 220 | |||
| 221 | return &ggml_backend_buffer_type_amx; | ||
| 222 | } | ||
| 223 | |||
| 224 | #endif // defined(__AMX_INT8__) && defined(__AVX512VNNI__) | ||
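
For reference, ggml_amx_init() above relies on the Linux XSTATE permission mechanism: a process must request the XTILEDATA feature via arch_prctl before executing any tile instruction. Below is a minimal standalone sketch of that handshake (illustrative only, not part of the patch), assuming the documented kernel interface; ARCH_GET_XCOMP_PERM, defined above but unused in the patch, reads back the permission bitmap.

// amx_perm_check.cpp - illustrative sketch, not part of the patch
#include <sys/syscall.h>
#include <unistd.h>
#include <cstdio>

#define ARCH_GET_XCOMP_PERM 0x1022
#define ARCH_REQ_XCOMP_PERM 0x1023
#define XFEATURE_XTILEDATA  18

int main() {
    // ask the kernel to enable AMX tile data state for this process
    if (syscall(SYS_arch_prctl, ARCH_REQ_XCOMP_PERM, XFEATURE_XTILEDATA) != 0) {
        std::perror("ARCH_REQ_XCOMP_PERM");
        return 1;
    }
    // read back the permission bitmap and confirm the XTILEDATA bit is set
    unsigned long bitmask = 0;
    if (syscall(SYS_arch_prctl, ARCH_GET_XCOMP_PERM, &bitmask) == 0 &&
        (bitmask & (1UL << XFEATURE_XTILEDATA))) {
        std::puts("AMX tile data permitted");
        return 0;
    }
    return 1;
}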
diff --git a/llama.cpp/ggml/src/ggml-cpu/amx/amx.h b/llama.cpp/ggml/src/ggml-cpu/amx/amx.h
new file mode 100644
index 0000000..5b65d76
--- /dev/null
+++ b/llama.cpp/ggml/src/ggml-cpu/amx/amx.h
@@ -0,0 +1,8 @@
| 1 | #include "ggml-backend.h" | ||
| 2 | #include "ggml-cpu-impl.h" | ||
| 3 | |||
| 4 | // GGML internal header | ||
| 5 | |||
| 6 | #if defined(__AMX_INT8__) && defined(__AVX512VNNI__) | ||
| 7 | ggml_backend_buffer_type_t ggml_backend_amx_buffer_type(void); | ||
| 8 | #endif | ||
diff --git a/llama.cpp/ggml/src/ggml-cpu/amx/common.h b/llama.cpp/ggml/src/ggml-cpu/amx/common.h
new file mode 100644
index 0000000..f392e89
--- /dev/null
+++ b/llama.cpp/ggml/src/ggml-cpu/amx/common.h
@@ -0,0 +1,91 @@
| 1 | #pragma once | ||
| 2 | |||
| 3 | #include "ggml.h" | ||
| 4 | #include "ggml-cpu-impl.h" | ||
| 5 | |||
| 6 | #include <algorithm> | ||
| 7 | #include <memory> | ||
| 8 | #include <type_traits> | ||
| 9 | |||
| 10 | #if defined(GGML_USE_OPENMP) | ||
| 11 | #include <omp.h> | ||
| 12 | #endif | ||
| 13 | |||
| 14 | #define TILE_M 16 | ||
| 15 | #define TILE_N 16 | ||
| 16 | #define TILE_K 32 | ||
| 17 | #define VNNI_BLK 4 | ||
| 18 | |||
| 19 | #define AMX_BLK_SIZE 32 | ||
| 20 | |||
| 21 | #define TMM0 0 | ||
| 22 | #define TMM1 1 | ||
| 23 | #define TMM2 2 | ||
| 24 | #define TMM3 3 | ||
| 25 | #define TMM4 4 | ||
| 26 | #define TMM5 5 | ||
| 27 | #define TMM6 6 | ||
| 28 | #define TMM7 7 | ||
| 29 | |||
| 30 | // parallel routines | ||
| 31 | template <typename T, typename std::enable_if<std::is_integral<T>::value, int>::type = 0> | ||
| 32 | inline T div_up(T x, T y) { return (x + y - 1) / y; } | ||
| 33 | |||
| 34 | template <typename T> | ||
| 35 | inline void balance211(T n, T nth, T ith, T& n_start, T& n_end) { | ||
| 36 | #if 0 | ||
| 37 | // onednn partition pattern | ||
| 38 | T& n_my = n_end; | ||
| 39 | if (nth <= 1 || n == 0) { | ||
| 40 | n_start = 0; | ||
| 41 | n_my = n; | ||
| 42 | } else { | ||
| 43 | T n1 = div_up(n, nth); | ||
| 44 | T n2 = n1 - 1; | ||
| 45 | T T1 = n - n2 * nth; | ||
| 46 | n_my = ith < T1 ? n1 : n2; | ||
| 47 | n_start = ith <= T1 ? ith*n1 : T1 * n1 + (ith - T1) * n2; | ||
| 48 | } | ||
| 49 | n_end += n_start; | ||
| 50 | #else | ||
| 51 | // pytorch aten partition pattern | ||
| 52 | T n_my = div_up(n, nth); | ||
| 53 | n_start = ith * n_my; | ||
| 54 | n_end = std::min(n_start + n_my, n); | ||
| 55 | #endif | ||
| 56 | } | ||
| 57 | |||
| 58 | template <typename func_t> | ||
| 59 | inline void parallel_for(int n, const func_t& f) { | ||
| 60 | #if defined(GGML_USE_OPENMP) | ||
| 61 | #pragma omp parallel | ||
| 62 | { | ||
| 63 | int nth = omp_get_num_threads(); | ||
| 64 | int ith = omp_get_thread_num(); | ||
| 65 | int tbegin, tend; | ||
| 66 | balance211(n, nth, ith, tbegin, tend); | ||
| 67 | f(tbegin, tend); | ||
| 68 | } | ||
| 69 | #else | ||
| 70 | f(0, n); | ||
| 71 | #endif | ||
| 72 | } | ||
| 73 | |||
| 74 | template <typename func_t> | ||
| 75 | inline void parallel_for_ggml(const ggml_compute_params * params, int n, const func_t & f) { | ||
| 76 | int tbegin, tend; | ||
| 77 | balance211(n, params->nth, params->ith, tbegin, tend); | ||
| 78 | f(tbegin, tend); | ||
| 79 | } | ||
| 80 | |||
| 81 | // quantized types that have AMX support | ||
| 82 | inline bool qtype_has_amx_kernels(const enum ggml_type type) { | ||
| 83 | // TODO: fix padding for vnni format | ||
| 84 | return (type == GGML_TYPE_Q4_0) || | ||
| 85 | (type == GGML_TYPE_Q4_1) || | ||
| 86 | (type == GGML_TYPE_Q8_0) || | ||
| 87 | (type == GGML_TYPE_Q4_K) || | ||
| 88 | (type == GGML_TYPE_Q5_K) || | ||
| 89 | (type == GGML_TYPE_Q6_K) || | ||
| 90 | (type == GGML_TYPE_IQ4_XS); | ||
| 91 | } | ||
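
A small usage sketch of the partitioning helpers above (illustrative only; assumes this header and its ggml dependencies are on the include path): with n = 100 work items and 8 OpenMP threads, balance211 gives each thread div_up(100, 8) = 13 items, so thread 0 covers [0, 13) and thread 7 the remainder [91, 100); without GGML_USE_OPENMP the functor simply runs once over [0, 100).

// parallel_for_demo.cpp - illustrative sketch, not part of the patch
#include <cstdio>
#include "common.h"   // the header above

int main() {
    parallel_for(100, [](int begin, int end) {
        // each invocation handles one contiguous slice of the 100 items
        std::printf("processing [%d, %d)\n", begin, end);
    });
    return 0;
}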
diff --git a/llama.cpp/ggml/src/ggml-cpu/amx/mmq.cpp b/llama.cpp/ggml/src/ggml-cpu/amx/mmq.cpp
new file mode 100644
index 0000000..47c61b8
--- /dev/null
+++ b/llama.cpp/ggml/src/ggml-cpu/amx/mmq.cpp
@@ -0,0 +1,2512 @@
| 1 | |||
| 2 | #if defined(__GNUC__) | ||
| 3 | #pragma GCC diagnostic ignored "-Wpedantic" | ||
| 4 | #pragma GCC diagnostic ignored "-Wunused-local-typedefs" | ||
| 5 | #endif | ||
| 6 | |||
| 7 | #include "amx.h" | ||
| 8 | #include "mmq.h" | ||
| 9 | #include "ggml-impl.h" | ||
| 10 | #include "ggml-cpu-impl.h" | ||
| 11 | #include "simd-mappings.h" | ||
| 12 | #include "quants.h" | ||
| 13 | #include "ggml-quants.h" | ||
| 14 | #include <algorithm> | ||
| 15 | #include <type_traits> | ||
| 16 | |||
| 17 | #if defined(__gnu_linux__) | ||
| 18 | #include <sys/syscall.h> | ||
| 19 | #include <unistd.h> | ||
| 20 | #endif | ||
| 21 | |||
| 22 | #if (defined(_WIN32) || defined(_WIN64)) | ||
| 23 | #define RESTRICT __restrict | ||
| 24 | #else | ||
| 25 | #define RESTRICT __restrict__ | ||
| 26 | #endif | ||
| 27 | |||
| 28 | #if (defined(_WIN32) || defined(_WIN64)) | ||
| 29 | #define ALWAYS_INLINE __forceinline | ||
| 30 | #elif __has_attribute(always_inline) || defined(__GNUC__) | ||
| 31 | #define ALWAYS_INLINE __attribute__((__always_inline__)) inline | ||
| 32 | #else | ||
| 33 | #define ALWAYS_INLINE inline | ||
| 34 | #endif | ||
| 35 | |||
| 36 | #if defined(__AMX_INT8__) && defined(__AVX512VNNI__) | ||
| 37 | |||
| 38 | namespace { | ||
| 39 | |||
| 40 | // Forced unrolling | ||
| 41 | template <int n> | ||
| 42 | struct Unroll { | ||
| 43 | template <typename Func, typename... Args> | ||
| 44 | ALWAYS_INLINE void operator()(const Func& f, Args... args) const { | ||
| 45 | Unroll<n - 1>{}(f, args...); | ||
| 46 | f(std::integral_constant<int, n - 1>{}, args...); | ||
| 47 | } | ||
| 48 | }; | ||
| 49 | |||
| 50 | template <> | ||
| 51 | struct Unroll<1> { | ||
| 52 | template <typename Func, typename... Args> | ||
| 53 | ALWAYS_INLINE void operator()(const Func& f, Args... args) const { | ||
| 54 | f(std::integral_constant<int, 0>{}, args...); | ||
| 55 | } | ||
| 56 | }; | ||
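// Note (illustrative, not part of the patch): Unroll<N>{}(f, args...) expands to
// f(std::integral_constant<int, 0>{}, args...), ..., f(std::integral_constant<int, N-1>{}, args...),
// so the loop index remains a compile-time constant. That matters for AMX intrinsics
// such as _tile_loadd, whose tile-register operand must be an immediate. A hypothetical
// use (names and strides are examples only):
//     auto load_a = [&](auto i, const int8_t * src) {
//         _tile_loadd(TMM2 + i.value, src + i.value * TILE_M * TILE_K, TILE_K);
//     };
//     Unroll<2>{}(load_a, tile_src);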
| 57 | |||
| 58 | // type traits | ||
| 59 | template <typename T> struct PackedTypes {}; | ||
| 60 | template <> struct PackedTypes<block_q4_0> { using type = int8_t; }; | ||
| 61 | template <> struct PackedTypes<block_q4_1> { using type = uint8_t; }; | ||
| 62 | template <> struct PackedTypes<block_q8_0> { using type = int8_t; }; | ||
| 63 | template <typename T> using packed_B_type = typename PackedTypes<T>::type; | ||
| 64 | |||
| 65 | template <typename T> | ||
| 66 | struct do_compensate : std::integral_constant<bool, | ||
| 67 | std::is_same<T, block_q8_0>::value> {}; | ||
| 68 | |||
| 69 | template <typename T> | ||
| 70 | struct do_unpack : std::integral_constant<bool, | ||
| 71 | std::is_same<T, block_q4_0>::value || | ||
| 72 | std::is_same<T, block_q4_1>::value> {}; | ||
| 73 | |||
| 74 | template <typename T> | ||
| 75 | struct is_type_qkk : std::integral_constant<bool, | ||
| 76 | std::is_same<T, block_q4_K>::value || | ||
| 77 | std::is_same<T, block_q5_K>::value || | ||
| 78 | std::is_same<T, block_q6_K>::value || | ||
| 79 | std::is_same<T, block_iq4_xs>::value> {}; | ||
| 80 | |||
| 81 | #define GGML_DISPATCH_FLOATING_TYPES(TYPE, ...) \ | ||
| 82 | [&] { \ | ||
| 83 | switch (TYPE) { \ | ||
| 84 | case GGML_TYPE_F16: { \ | ||
| 85 | using type = ggml_fp16_t; \ | ||
| 86 | constexpr int blck_size = 16; \ | ||
| 87 | return __VA_ARGS__(); \ | ||
| 88 | } \ | ||
| 89 | case GGML_TYPE_BF16: { \ | ||
| 90 | using type = ggml_bf16_t; \ | ||
| 91 | constexpr int blck_size = 32; \ | ||
| 92 | return __VA_ARGS__(); \ | ||
| 93 | } \ | ||
| 94 | default: \ | ||
| 95 | fprintf(stderr, "Unsupported floating data type\n"); \ | ||
| 96 | } \ | ||
| 97 | }() | ||
| 98 | |||
| 99 | #define GGML_DISPATCH_QTYPES(QT, ...) \ | ||
| 100 | [&] { \ | ||
| 101 | switch (QT) { \ | ||
| 102 | case GGML_TYPE_Q4_0: { \ | ||
| 103 | using type = block_q4_0; \ | ||
| 104 | using vec_dot_type = block_q8_0; \ | ||
| 105 | constexpr int blck_size = QK4_0; \ | ||
| 106 | return __VA_ARGS__(); \ | ||
| 107 | } \ | ||
| 108 | case GGML_TYPE_Q4_1: { \ | ||
| 109 | using type = block_q4_1; \ | ||
| 110 | using vec_dot_type = block_q8_1; \ | ||
| 111 | constexpr int blck_size = QK4_1; \ | ||
| 112 | return __VA_ARGS__(); \ | ||
| 113 | } \ | ||
| 114 | case GGML_TYPE_Q8_0: { \ | ||
| 115 | using type = block_q8_0; \ | ||
| 116 | using vec_dot_type = block_q8_0; \ | ||
| 117 | constexpr int blck_size = QK8_0; \ | ||
| 118 | return __VA_ARGS__(); \ | ||
| 119 | } \ | ||
| 120 | case GGML_TYPE_Q4_K: { \ | ||
| 121 | using type = block_q4_K; \ | ||
| 122 | using vec_dot_type = block_q8_K; \ | ||
| 123 | constexpr int blck_size = QK_K; \ | ||
| 124 | return __VA_ARGS__(); \ | ||
| 125 | } \ | ||
| 126 | case GGML_TYPE_Q5_K: { \ | ||
| 127 | using type = block_q5_K; \ | ||
| 128 | using vec_dot_type = block_q8_K; \ | ||
| 129 | constexpr int blck_size = QK_K; \ | ||
| 130 | return __VA_ARGS__(); \ | ||
| 131 | } \ | ||
| 132 | case GGML_TYPE_Q6_K: { \ | ||
| 133 | using type = block_q6_K; \ | ||
| 134 | using vec_dot_type = block_q8_K; \ | ||
| 135 | constexpr int blck_size = QK_K; \ | ||
| 136 | return __VA_ARGS__(); \ | ||
| 137 | } \ | ||
| 138 | case GGML_TYPE_IQ4_XS: { \ | ||
| 139 | using type = block_iq4_xs; \ | ||
| 140 | using vec_dot_type = block_q8_K; \ | ||
| 141 | constexpr int blck_size = QK_K; \ | ||
| 142 | return __VA_ARGS__(); \ | ||
| 143 | } \ | ||
| 144 | default: \ | ||
| 145 | fprintf(stderr, "Unsupported quantized data type: %d\n", int(TYPE)); \ | ||
| 146 | } \ | ||
| 147 | }() | ||
| 148 | |||
| 149 | #define GGML_DISPATCH_BOOL(BOOL_V, BOOL_NAME, ...) \ | ||
| 150 | [&] { \ | ||
| 151 | if (BOOL_V) { \ | ||
| 152 | constexpr bool BOOL_NAME = true; \ | ||
| 153 | return __VA_ARGS__(); \ | ||
| 154 | } else { \ | ||
| 155 | constexpr bool BOOL_NAME = false; \ | ||
| 156 | return __VA_ARGS__(); \ | ||
| 157 | } \ | ||
| 158 | }() | ||
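// Usage sketch for the dispatch macros above (illustrative; handle_mul_mat_example and
// gemm_kernel are hypothetical names, not part of the patch). The runtime type is
// switched once and the lambda body sees `type`, `vec_dot_type` and `blck_size` as
// compile-time names, so one templated kernel can be instantiated per quant type.
static void handle_mul_mat_example(enum ggml_type qt, bool is_parallel) {
    GGML_DISPATCH_QTYPES(qt, [&] {
        GGML_DISPATCH_BOOL(is_parallel, kParallel, [&] {
            // here `type` is e.g. block_q4_0, `vec_dot_type` is block_q8_0,
            // `blck_size` is QK4_0 and kParallel is a constexpr bool:
            // gemm_kernel<type, vec_dot_type, blck_size, kParallel>(...);
        });
    });
}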
| 159 | |||
| 160 | // define amx tile config data structure | ||
| 161 | struct tile_config_t{ | ||
| 162 | uint8_t palette_id = 0; | ||
| 163 | uint8_t start_row = 0; | ||
| 164 | uint8_t reserved_0[14] = {0}; | ||
| 165 | uint16_t colsb[16] = {0}; | ||
| 166 | uint8_t rows[16] = {0}; | ||
| 167 | }; | ||
| 168 | |||
| 169 | // Notes: amx tile config | ||
| 170 | // | ||
| 171 | // Typically, TMUL multiplies A and B tiles of size 16 x 64 containing INT8 values | ||
| 172 | // and accumulates the result into a 16 x 16 matrix C containing INT32 values. | ||
| 173 | // | ||
| 174 | // Since many GGUF quantized types have a `block_size` of 32, a 16-16-32 config is used | ||
| 175 | // instead of the more common 16-16-64 config. | ||
| 176 | // | ||
| 177 | // Block A: {16, 32}, dtype = int8_t | ||
| 178 | // Block B: {16, 32}, dtype = uint8_t/int8_t | ||
| 179 | // Block C: {16, 16}, dtype = int32_t | ||
| 180 | // | ||
| 181 | // Block B needs to be prepacked to vnni format before feeding into TMUL: | ||
| 182 | // packed_B: from {n, k} to {k/vnni_blk, n, vnni_blk}; viewed in 2d, this gives {8, 64} | ||
| 183 | // | ||
| 184 | // Therefore, we get tileconfig: | ||
| 185 | // A B C | ||
| 186 | // rows 16 8 16 | ||
| 187 | // colsb 32 64 16 | ||
| 188 | // | ||
| 189 | // For tile distribution, a 2-2-4 pattern is followed: A uses TMM2-TMM3, B uses TMM0-TMM1, | ||
| 190 | // and C uses TMM4-TMM7: | ||
| 191 | // B TMM0 B TMM1 | ||
| 192 | // A TMM2 C TMM4 C TMM6 | ||
| 193 | // A TMM3 C TMM5 C TMM7 | ||
| 194 | // | ||
| 195 | // Each `amx` kernel handles 4 blocks at a time (2 MB x 2 NB); when m < 2 * BLOCK_M, A | ||
| 196 | // needs to be unpacked first. | ||
| 197 | // | ||
| 198 | // Another commonly used pattern, 1-3-3, is skipped here, as it is mostly useful when m <= 16, | ||
| 199 | // and the single-batch gemm (m = 1) has a special fast path with `avx512-vnni`. | ||
| 200 | // | ||
| 201 | // ref: https://www.intel.com/content/www/us/en/developer/articles/code-sample/ | ||
| 202 | // advanced-matrix-extensions-intrinsics-functions.html | ||
| 203 | // | ||
| 204 | |||
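// Illustration (editorial, not part of the patch): for the 16-16-32 config above, the
// VNNI repack appears to map element B[n][k] (n in [0, TILE_N), k in [0, TILE_K)) to
// the byte offset
//     (k / VNNI_BLK) * TILE_N * VNNI_BLK + n * VNNI_BLK + (k % VNNI_BLK)
// which is exactly the {8, 64} 2d view mentioned above: 8 rows of 64 bytes, matching
// the B tile shape (rows 8, colsb 64) configured below.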
| 205 | #define TC_CONFIG_TILE(i, r, cb) tc.rows[i] = r; tc.colsb[i] = cb | ||
| 206 | void ggml_tile_config_init(void) { | ||
| 207 | static thread_local bool is_first_time = true; | ||
| 208 | |||
| 209 | if (!is_first_time) { | ||
| 210 | return; | ||
| 211 | } | ||
| 212 | |||
| 213 | static thread_local tile_config_t tc; | ||
| 214 | tile_config_t current_tc; | ||
| 215 | _tile_storeconfig(¤t_tc); | ||
| 216 | |||
| 217 | // load only when config changes | ||
| 218 | if (tc.palette_id == 0 || (memcmp(¤t_tc.colsb, &tc.colsb, sizeof(uint16_t) * 8) != 0 && | ||
| 219 | memcmp(¤t_tc.rows, &tc.rows, sizeof(uint8_t) * 8) != 0)) { | ||
| 220 | tc.palette_id = 1; | ||
| 221 | tc.start_row = 0; | ||
| 222 | TC_CONFIG_TILE(TMM0, 8, 64); | ||
| 223 | TC_CONFIG_TILE(TMM1, 8, 64); | ||
| 224 | TC_CONFIG_TILE(TMM2, 16, 32); | ||
| 225 | TC_CONFIG_TILE(TMM3, 16, 32); | ||
| 226 | TC_CONFIG_TILE(TMM4, 16, 64); | ||
| 227 | TC_CONFIG_TILE(TMM5, 16, 64); | ||
| 228 | TC_CONFIG_TILE(TMM6, 16, 64); | ||
| 229 | TC_CONFIG_TILE(TMM7, 16, 64); | ||
| 230 | _tile_loadconfig(&tc); | ||
| 231 | } | ||
| 232 | |||
| 233 | is_first_time = false; | ||
| 234 | } | ||
| 235 | |||
| 236 | // we need an extra 16 * 4 bytes (TILE_N * sizeof(int32_t)) for each NB/KB block for compensation. | ||
| 237 | // See the notes `s8s8 igemm compensation in avx512-vnni` for details. | ||
| 238 | template <typename TB> | ||
| 239 | int get_tile_size() { | ||
| 240 | int tile_size = TILE_N * sizeof(TB); | ||
| 241 | if (do_compensate<TB>::value) { | ||
| 242 | tile_size += TILE_N * sizeof(int32_t); | ||
| 243 | } | ||
| 244 | if (std::is_same<TB, block_q4_K>::value || | ||
| 245 | std::is_same<TB, block_q5_K>::value) { | ||
| 246 | tile_size += TILE_N * 4; | ||
| 247 | } | ||
| 248 | if (std::is_same<TB, block_iq4_xs>::value) { | ||
| 249 | tile_size += TILE_N * 2; | ||
| 250 | } | ||
| 251 | return tile_size; | ||
| 252 | } | ||
| 253 | |||
| 254 | template <typename TB, int BLOCK_K> | ||
| 255 | int get_row_size(int K) { | ||
| 256 | int KB = K / BLOCK_K; | ||
| 257 | int row_size = KB * sizeof(TB); | ||
| 258 | if (do_compensate<TB>::value) { | ||
| 259 | row_size += KB * sizeof(int32_t); | ||
| 260 | } | ||
| 261 | if (std::is_same<TB, block_q4_K>::value || | ||
| 262 | std::is_same<TB, block_q5_K>::value) { | ||
| 263 | row_size += KB * 4; | ||
| 264 | } | ||
| 265 | if (std::is_same<TB, block_iq4_xs>::value) { | ||
| 266 | row_size += KB * 2; | ||
| 267 | } | ||
| 268 | return row_size; | ||
| 269 | } | ||
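// Worked example (editorial; assumes the usual block_q8_0 layout of a ggml_half scale
// plus 32 int8 quants, i.e. sizeof(block_q8_0) == 34):
//     get_tile_size<block_q8_0>()            == 16 * 34 + 16 * 4   == 608 bytes
//     get_row_size<block_q8_0, QK8_0>(4096)  == 128 * 34 + 128 * 4 == 4864 bytes per row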
| 270 | |||
| 271 | // vectorized dtype conversion | ||
| 272 | inline float FP16_TO_FP32(ggml_half val) { | ||
| 273 | __m256i v = _mm256_setr_epi16( | ||
| 274 | val, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0); | ||
| 275 | __m512 o = _mm512_cvtph_ps(v); | ||
| 276 | return _mm512_cvtss_f32(o); | ||
| 277 | } | ||
| 278 | |||
| 279 | inline __m512 FP16_TO_FP32_VEC(ggml_half val) { | ||
| 280 | __m256i v = _mm256_set1_epi16(val); | ||
| 281 | return _mm512_cvtph_ps(v); | ||
| 282 | } | ||
| 283 | |||
| 284 | // horizontal reduce | ||
| 285 | inline float _mm512_reduce_max_ps(const __m512 x) { | ||
| 286 | __m512 v = x; | ||
| 287 | __m512 v1 = _mm512_shuffle_f32x4(v, v, 0x4E); | ||
| 288 | v = _mm512_max_ps(v, v1); | ||
| 289 | v1 = _mm512_shuffle_f32x4(v, v, 0xB1); | ||
| 290 | v = _mm512_max_ps(v, v1); | ||
| 291 | v1 = _mm512_shuffle_ps(v, v, 0x4E); | ||
| 292 | v = _mm512_max_ps(v, v1); | ||
| 293 | v1 = _mm512_shuffle_ps(v, v, 0xB1); | ||
| 294 | v = _mm512_max_ps(v, v1); | ||
| 295 | return _mm512_cvtss_f32(v); | ||
| 296 | } | ||
| 297 | |||
| 298 | // transpose utils | ||
| 299 | #define SHUFFLE_EPI32(a, b, mask) \ | ||
| 300 | _mm256_castps_si256(_mm256_shuffle_ps(_mm256_castsi256_ps(a), _mm256_castsi256_ps(b), mask)) | ||
| 301 | inline void transpose_8x8_32bit(__m256i * v, __m256i * v1) { | ||
| 302 | // unpack and interleave 32-bit elements | ||
| 303 | v1[0] = _mm256_unpacklo_epi32(v[0], v[1]); | ||
| 304 | v1[1] = _mm256_unpackhi_epi32(v[0], v[1]); | ||
| 305 | v1[2] = _mm256_unpacklo_epi32(v[2], v[3]); | ||
| 306 | v1[3] = _mm256_unpackhi_epi32(v[2], v[3]); | ||
| 307 | v1[4] = _mm256_unpacklo_epi32(v[4], v[5]); | ||
| 308 | v1[5] = _mm256_unpackhi_epi32(v[4], v[5]); | ||
| 309 | v1[6] = _mm256_unpacklo_epi32(v[6], v[7]); | ||
| 310 | v1[7] = _mm256_unpackhi_epi32(v[6], v[7]); | ||
| 311 | |||
| 312 | // shuffling the 32-bit elements | ||
| 313 | v[0] = SHUFFLE_EPI32(v1[0], v1[2], 0x44); | ||
| 314 | v[1] = SHUFFLE_EPI32(v1[0], v1[2], 0xee); | ||
| 315 | v[2] = SHUFFLE_EPI32(v1[4], v1[6], 0x44); | ||
| 316 | v[3] = SHUFFLE_EPI32(v1[4], v1[6], 0xee); | ||
| 317 | v[4] = SHUFFLE_EPI32(v1[1], v1[3], 0x44); | ||
| 318 | v[5] = SHUFFLE_EPI32(v1[1], v1[3], 0xee); | ||
| 319 | v[6] = SHUFFLE_EPI32(v1[5], v1[7], 0x44); | ||
| 320 | v[7] = SHUFFLE_EPI32(v1[5], v1[7], 0xee); | ||
| 321 | |||
| 322 | // shuffling 128-bit elements | ||
| 323 | v1[0] = _mm256_permute2f128_si256(v[2], v[0], 0x02); | ||
| 324 | v1[1] = _mm256_permute2f128_si256(v[3], v[1], 0x02); | ||
| 325 | v1[2] = _mm256_permute2f128_si256(v[6], v[4], 0x02); | ||
| 326 | v1[3] = _mm256_permute2f128_si256(v[7], v[5], 0x02); | ||
| 327 | v1[4] = _mm256_permute2f128_si256(v[2], v[0], 0x13); | ||
| 328 | v1[5] = _mm256_permute2f128_si256(v[3], v[1], 0x13); | ||
| 329 | v1[6] = _mm256_permute2f128_si256(v[6], v[4], 0x13); | ||
| 330 | v1[7] = _mm256_permute2f128_si256(v[7], v[5], 0x13); | ||
| 331 | } | ||
| 332 | |||
| 333 | inline void transpose_16x4_32bit(__m512i * r, __m512i * d) { | ||
| 334 | |||
| 335 | static const __m512i index1 = _mm512_set_epi32( | ||
| 336 | 0x0f, 0x0b, 0x07, 0x03, | ||
| 337 | 0x0e, 0x0a, 0x06, 0x02, | ||
| 338 | 0x0d, 0x09, 0x05, 0x01, | ||
| 339 | 0x0c, 0x08, 0x04, 0x00); | ||
| 340 | |||
| 341 | d[0] = _mm512_permutexvar_epi32(index1, r[0]); | ||
| 342 | d[1] = _mm512_permutexvar_epi32(index1, r[1]); | ||
| 343 | d[2] = _mm512_permutexvar_epi32(index1, r[2]); | ||
| 344 | d[3] = _mm512_permutexvar_epi32(index1, r[3]); | ||
| 345 | |||
| 346 | r[0] = _mm512_shuffle_i32x4(d[0], d[1], 0x44); | ||
| 347 | r[1] = _mm512_shuffle_i32x4(d[0], d[1], 0xee); | ||
| 348 | r[2] = _mm512_shuffle_i32x4(d[2], d[3], 0x44); | ||
| 349 | r[3] = _mm512_shuffle_i32x4(d[2], d[3], 0xee); | ||
| 350 | |||
| 351 | d[0] = _mm512_shuffle_i32x4(r[0], r[2], 0x88); | ||
| 352 | d[1] = _mm512_shuffle_i32x4(r[0], r[2], 0xdd); | ||
| 353 | d[2] = _mm512_shuffle_i32x4(r[1], r[3], 0x88); | ||
| 354 | d[3] = _mm512_shuffle_i32x4(r[1], r[3], 0xdd); | ||
| 355 | } | ||
| 356 | |||
| 357 | inline void transpose_16x16_32bit(__m512i * v) { | ||
| 358 | __m512i v1[16]; | ||
| 359 | v1[0] = _mm512_unpacklo_epi32(v[0], v[1]); | ||
| 360 | v1[1] = _mm512_unpackhi_epi32(v[0], v[1]); | ||
| 361 | v1[2] = _mm512_unpacklo_epi32(v[2], v[3]); | ||
| 362 | v1[3] = _mm512_unpackhi_epi32(v[2], v[3]); | ||
| 363 | v1[4] = _mm512_unpacklo_epi32(v[4], v[5]); | ||
| 364 | v1[5] = _mm512_unpackhi_epi32(v[4], v[5]); | ||
| 365 | v1[6] = _mm512_unpacklo_epi32(v[6], v[7]); | ||
| 366 | v1[7] = _mm512_unpackhi_epi32(v[6], v[7]); | ||
| 367 | v1[8] = _mm512_unpacklo_epi32(v[8], v[9]); | ||
| 368 | v1[9] = _mm512_unpackhi_epi32(v[8], v[9]); | ||
| 369 | v1[10] = _mm512_unpacklo_epi32(v[10], v[11]); | ||
| 370 | v1[11] = _mm512_unpackhi_epi32(v[10], v[11]); | ||
| 371 | v1[12] = _mm512_unpacklo_epi32(v[12], v[13]); | ||
| 372 | v1[13] = _mm512_unpackhi_epi32(v[12], v[13]); | ||
| 373 | v1[14] = _mm512_unpacklo_epi32(v[14], v[15]); | ||
| 374 | v1[15] = _mm512_unpackhi_epi32(v[14], v[15]); | ||
| 375 | |||
| 376 | v[0] = _mm512_unpacklo_epi64(v1[0], v1[2]); | ||
| 377 | v[1] = _mm512_unpackhi_epi64(v1[0], v1[2]); | ||
| 378 | v[2] = _mm512_unpacklo_epi64(v1[1], v1[3]); | ||
| 379 | v[3] = _mm512_unpackhi_epi64(v1[1], v1[3]); | ||
| 380 | v[4] = _mm512_unpacklo_epi64(v1[4], v1[6]); | ||
| 381 | v[5] = _mm512_unpackhi_epi64(v1[4], v1[6]); | ||
| 382 | v[6] = _mm512_unpacklo_epi64(v1[5], v1[7]); | ||
| 383 | v[7] = _mm512_unpackhi_epi64(v1[5], v1[7]); | ||
| 384 | v[8] = _mm512_unpacklo_epi64(v1[8], v1[10]); | ||
| 385 | v[9] = _mm512_unpackhi_epi64(v1[8], v1[10]); | ||
| 386 | v[10] = _mm512_unpacklo_epi64(v1[9], v1[11]); | ||
| 387 | v[11] = _mm512_unpackhi_epi64(v1[9], v1[11]); | ||
| 388 | v[12] = _mm512_unpacklo_epi64(v1[12], v1[14]); | ||
| 389 | v[13] = _mm512_unpackhi_epi64(v1[12], v1[14]); | ||
| 390 | v[14] = _mm512_unpacklo_epi64(v1[13], v1[15]); | ||
| 391 | v[15] = _mm512_unpackhi_epi64(v1[13], v1[15]); | ||
| 392 | |||
| 393 | v1[0] = _mm512_shuffle_i32x4(v[0], v[4], 0x88); | ||
| 394 | v1[1] = _mm512_shuffle_i32x4(v[1], v[5], 0x88); | ||
| 395 | v1[2] = _mm512_shuffle_i32x4(v[2], v[6], 0x88); | ||
| 396 | v1[3] = _mm512_shuffle_i32x4(v[3], v[7], 0x88); | ||
| 397 | v1[4] = _mm512_shuffle_i32x4(v[0], v[4], 0xdd); | ||
| 398 | v1[5] = _mm512_shuffle_i32x4(v[1], v[5], 0xdd); | ||
| 399 | v1[6] = _mm512_shuffle_i32x4(v[2], v[6], 0xdd); | ||
| 400 | v1[7] = _mm512_shuffle_i32x4(v[3], v[7], 0xdd); | ||
| 401 | v1[8] = _mm512_shuffle_i32x4(v[8], v[12], 0x88); | ||
| 402 | v1[9] = _mm512_shuffle_i32x4(v[9], v[13], 0x88); | ||
| 403 | v1[10] = _mm512_shuffle_i32x4(v[10], v[14], 0x88); | ||
| 404 | v1[11] = _mm512_shuffle_i32x4(v[11], v[15], 0x88); | ||
| 405 | v1[12] = _mm512_shuffle_i32x4(v[8], v[12], 0xdd); | ||
| 406 | v1[13] = _mm512_shuffle_i32x4(v[9], v[13], 0xdd); | ||
| 407 | v1[14] = _mm512_shuffle_i32x4(v[10], v[14], 0xdd); | ||
| 408 | v1[15] = _mm512_shuffle_i32x4(v[11], v[15], 0xdd); | ||
| 409 | |||
| 410 | v[0] = _mm512_shuffle_i32x4(v1[0], v1[8], 0x88); | ||
| 411 | v[1] = _mm512_shuffle_i32x4(v1[1], v1[9], 0x88); | ||
| 412 | v[2] = _mm512_shuffle_i32x4(v1[2], v1[10], 0x88); | ||
| 413 | v[3] = _mm512_shuffle_i32x4(v1[3], v1[11], 0x88); | ||
| 414 | v[4] = _mm512_shuffle_i32x4(v1[4], v1[12], 0x88); | ||
| 415 | v[5] = _mm512_shuffle_i32x4(v1[5], v1[13], 0x88); | ||
| 416 | v[6] = _mm512_shuffle_i32x4(v1[6], v1[14], 0x88); | ||
| 417 | v[7] = _mm512_shuffle_i32x4(v1[7], v1[15], 0x88); | ||
| 418 | v[8] = _mm512_shuffle_i32x4(v1[0], v1[8], 0xdd); | ||
| 419 | v[9] = _mm512_shuffle_i32x4(v1[1], v1[9], 0xdd); | ||
| 420 | v[10] = _mm512_shuffle_i32x4(v1[2], v1[10], 0xdd); | ||
| 421 | v[11] = _mm512_shuffle_i32x4(v1[3], v1[11], 0xdd); | ||
| 422 | v[12] = _mm512_shuffle_i32x4(v1[4], v1[12], 0xdd); | ||
| 423 | v[13] = _mm512_shuffle_i32x4(v1[5], v1[13], 0xdd); | ||
| 424 | v[14] = _mm512_shuffle_i32x4(v1[6], v1[14], 0xdd); | ||
| 425 | v[15] = _mm512_shuffle_i32x4(v1[7], v1[15], 0xdd); | ||
| 426 | } | ||
| 427 | |||
| 428 | void quantize_row_q8_K_vnni(const float * RESTRICT x, void * RESTRICT vy, int64_t k) { | ||
| 429 | assert(k % QK_K == 0); | ||
| 430 | const int KB = k / QK_K; | ||
| 431 | constexpr int kVecs = QK_K / 16; | ||
| 432 | |||
| 433 | block_q8_K * y = reinterpret_cast<block_q8_K *>(vy); | ||
| 434 | |||
| 435 | // hold 16 float vecs from x | ||
| 436 | __m512 v[kVecs]; | ||
| 437 | |||
| 438 | // hold the quants vecs | ||
| 439 | __m512i vq[kVecs / 4]; | ||
| 440 | |||
| 441 | // hold the packed quants vecs | ||
| 442 | __m512i vq_packed[kVecs / 4]; | ||
| 443 | |||
| 444 | const __m512 signBit = _mm512_set1_ps(-0.f); | ||
| 445 | |||
| 446 | for (int i = 0; i < KB; ++i) { | ||
| 447 | // Compute max(abs(e)) for the block | ||
| 448 | __m512 vamax = _mm512_set1_ps(0.f); | ||
| 449 | for (int j = 0; j < kVecs; ++j) { | ||
| 450 | v[j] = _mm512_loadu_ps(x); x += 16; | ||
| 451 | vamax = _mm512_max_ps(vamax, _mm512_andnot_ps(signBit, v[j])); | ||
| 452 | } | ||
| 453 | const float amax = _mm512_reduce_max_ps(vamax); | ||
| 454 | |||
| 455 | // Quantize these floats | ||
| 456 | const float iscale = 127.f / amax; | ||
| 457 | y[i].d = GGML_CPU_FP32_TO_FP16(1 / iscale); | ||
| 458 | const float id = ( amax != 0.0f ) ? iscale : 0.f; | ||
| 459 | const __m512 vscale = _mm512_set1_ps(id); | ||
| 460 | |||
| 461 | // Apply multiplier and round to nearest integer | ||
| 462 | for (int j = 0; j < kVecs; ++j) { | ||
| 463 | v[j] = _mm512_mul_ps(v[j], vscale); | ||
| 464 | v[j] = _mm512_roundscale_ps(v[j], (_MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC)); | ||
| 465 | } | ||
| 466 | |||
| 467 | // Pack to epi8 vecs | ||
| 468 | for (int j = 0; j < kVecs / 4; ++j) { | ||
| 469 | __m128i q8_0 = _mm512_cvtepi32_epi8(_mm512_cvtps_epi32(v[j * 4 + 0])); | ||
| 470 | __m128i q8_1 = _mm512_cvtepi32_epi8(_mm512_cvtps_epi32(v[j * 4 + 1])); | ||
| 471 | __m128i q8_2 = _mm512_cvtepi32_epi8(_mm512_cvtps_epi32(v[j * 4 + 2])); | ||
| 472 | __m128i q8_3 = _mm512_cvtepi32_epi8(_mm512_cvtps_epi32(v[j * 4 + 3])); | ||
| 473 | |||
| 474 | __m256i q8_01 = _mm256_insertf128_si256(_mm256_castsi128_si256(q8_0), (q8_1), 1); | ||
| 475 | __m256i q8_23 = _mm256_insertf128_si256(_mm256_castsi128_si256(q8_2), (q8_3), 1); | ||
| 476 | |||
| 477 | vq[j] = _mm512_inserti32x8(_mm512_castsi256_si512(q8_01), q8_23, 1); | ||
| 478 | _mm512_storeu_si512((__m512i *)(y[i].qs + j * 64), vq[j]); | ||
| 479 | } | ||
| 480 | |||
| 481 | // Compute the bsums with vnni | ||
| 482 | transpose_16x4_32bit(vq, vq_packed); | ||
| 483 | |||
| 484 | const __m512i one = _mm512_set1_epi8(1); | ||
| 485 | __m512i sum = _mm512_setzero_si512(); | ||
| 486 | for (int k = 0; k < 4; ++k) { | ||
| 487 | sum = _mm512_dpbusd_epi32(sum, one, vq_packed[k]); | ||
| 488 | } | ||
| 489 | _mm256_storeu_si256((__m256i *)(y[i].bsums), _mm512_cvtepi32_epi16(sum)); | ||
| 490 | } | ||
| 491 | } | ||
| 492 | |||
| 493 | // quantize A from float to `vec_dot_type` | ||
| 494 | template <typename T> | ||
| 495 | inline void from_float(const float * x, char * vy, int64_t k); | ||
| 496 | |||
| 497 | template <> | ||
| 498 | inline void from_float<block_q8_0>(const float * x, char * vy, int64_t k) { | ||
| 499 | quantize_row_q8_0(x, (block_q8_0 *)vy, k); | ||
| 500 | } | ||
| 501 | |||
| 502 | template <> | ||
| 503 | inline void from_float<block_q8_1>(const float * x, char * vy, int64_t k) { | ||
| 504 | quantize_row_q8_1(x, (block_q8_1 *)vy, k); | ||
| 505 | } | ||
| 506 | |||
| 507 | template <> | ||
| 508 | inline void from_float<block_q8_K>(const float * x, char * vy, int64_t k) { | ||
| 509 | #if 1 | ||
| 510 | // TODO: this is reference impl! | ||
| 511 | quantize_row_q8_K_ref(x, (block_q8_K *)vy, k); | ||
| 512 | #else | ||
| 513 | quantize_row_q8_K_vnni(x, vy, k); | ||
| 514 | #endif | ||
| 515 | } | ||
| 516 | |||
| 517 | // load A from memory into a tile buffer when nrows cannot fill a whole tile | ||
| 518 | void unpack_A(int8_t * RESTRICT tile, const block_q8_0 * RESTRICT A, int lda, int nr) { | ||
| 519 | assert(nr != TILE_M); | ||
| 520 | for (int m = 0; m < nr; ++m) { | ||
| 521 | const __m256i v = _mm256_loadu_si256((const __m256i *)(A[m * lda].qs)); | ||
| 522 | _mm256_storeu_si256((__m256i *)(tile + m * TILE_K), v); | ||
| 523 | } | ||
| 524 | } | ||
| 525 | |||
| 526 | void unpack_A(int8_t * RESTRICT tile, const block_q8_1 * RESTRICT A, int lda, int nr) { | ||
| 527 | assert(nr != TILE_M); | ||
| 528 | for (int m = 0; m < nr; ++m) { | ||
| 529 | const __m256i v = _mm256_loadu_si256((const __m256i *)(A[m * lda].qs)); | ||
| 530 | _mm256_storeu_si256((__m256i *)(tile + m * TILE_K), v); | ||
| 531 | } | ||
| 532 | } | ||
| 533 | |||
| 534 | template <typename TB> | ||
| 535 | void unpack_A(int8_t * RESTRICT tile, const block_q8_K * RESTRICT A, int lda, int k, int nr) { | ||
| 536 | assert(nr <= TILE_M); | ||
| 537 | for (int m = 0; m < nr; ++m) { | ||
| 538 | const __m256i v = _mm256_loadu_si256((const __m256i *)(A[m * lda].qs + k * 32)); | ||
| 539 | _mm256_storeu_si256((__m256i *)(tile + m * TILE_K), v); | ||
| 540 | } | ||
| 541 | } | ||
| 542 | |||
| 543 | template <> | ||
| 544 | void unpack_A<block_q6_K>(int8_t * RESTRICT tile, const block_q8_K * RESTRICT A, int lda, int k, int nr) { | ||
| 545 | assert(nr <= TILE_M); | ||
| 546 | // zero-pad k from 16 to 32 so that we don't have to reconfigure amx | ||
| 547 | const __m128i zero = _mm_setzero_si128(); | ||
| 548 | for (int m = 0; m < nr; ++m) { | ||
| 549 | const __m128i v = _mm_loadu_si128((const __m128i *)(A[m * lda].qs + k * 16)); | ||
| 550 | const __m256i r = _mm256_insertf128_si256(_mm256_castsi128_si256(v), zero, 1); | ||
| 551 | _mm256_storeu_si256((__m256i *)(tile + m * TILE_K), r); | ||
| 552 | } | ||
| 553 | } | ||
| 554 | |||
| 555 | #define MM256_SET_M128I(a, b) _mm256_insertf128_si256(_mm256_castsi128_si256(b), (a), 1) | ||
| 556 | inline __m256i bytes_from_nibbles_32(const uint8_t * rsi) { | ||
| 557 | const __m128i tmp = _mm_loadu_si128((const __m128i *)rsi); | ||
| 558 | const __m256i bytes = MM256_SET_M128I(_mm_srli_epi16(tmp, 4), tmp); | ||
| 559 | const __m256i lowMask = _mm256_set1_epi8(0xF); | ||
| 560 | return _mm256_and_si256(lowMask, bytes); | ||
| 561 | } | ||
| 562 | |||
| 563 | // used for block_q4_K | ||
| 564 | inline __m512i bytes_from_nibbles_64(const uint8_t * rsi) { | ||
| 565 | const __m256i tmp = _mm256_loadu_si256((const __m256i *)rsi); | ||
| 566 | const __m256i lowMask = _mm256_set1_epi8(0xF); | ||
| 567 | const __m256i q4l = _mm256_and_si256(tmp, lowMask); | ||
| 568 | const __m256i q4h = _mm256_and_si256(_mm256_srli_epi16(tmp, 4), lowMask); | ||
| 569 | return _mm512_inserti32x8(_mm512_castsi256_si512(q4l), q4h, 1); | ||
| 570 | } | ||
| 571 | |||
| 572 | // used for block_q5_K | ||
| 573 | inline __m512i bytes_from_nibbles_64(const uint8_t * qs, const uint8_t * qh, int k) { | ||
| 574 | const __m256i lowMask = _mm256_set1_epi8(0xF); | ||
| 575 | __m256i hmask = _mm256_set1_epi8(1); | ||
| 576 | hmask = _mm256_slli_epi16(hmask, k); | ||
| 577 | |||
| 578 | const __m256i q5bits = _mm256_loadu_si256((const __m256i *)qs); | ||
| 579 | const __m256i hbits = _mm256_loadu_si256((const __m256i *)qh); | ||
| 580 | |||
| 581 | const __m256i q5l_0 = _mm256_and_si256(q5bits, lowMask); | ||
| 582 | const __m256i q5h_0 = _mm256_slli_epi16(_mm256_srli_epi16(_mm256_and_si256(hbits, hmask), k + 0), 4); | ||
| 583 | const __m256i q5_0 = _mm256_add_epi8(q5l_0, q5h_0); | ||
| 584 | hmask = _mm256_slli_epi16(hmask, 1); | ||
| 585 | |||
| 586 | const __m256i q5l_1 = _mm256_and_si256(_mm256_srli_epi16(q5bits, 4), lowMask); | ||
| 587 | const __m256i q5h_1 = _mm256_slli_epi16(_mm256_srli_epi16(_mm256_and_si256(hbits, hmask), k + 1), 4); | ||
| 588 | const __m256i q5_1 = _mm256_add_epi8(q5l_1, q5h_1); | ||
| 589 | |||
| 590 | return _mm512_inserti32x8(_mm512_castsi256_si512(q5_0), q5_1, 1); | ||
| 591 | } | ||
| 592 | |||
| 593 | // used for block_q6_K | ||
| 594 | inline void bytes_from_nibbles_128(__m512i& r0, __m512i& r1, const uint8_t * qs, const uint8_t * qh) { | ||
| 595 | const __m256i m4 = _mm256_set1_epi8(0xF); | ||
| 596 | const __m256i m2 = _mm256_set1_epi8(0x3); | ||
| 597 | |||
| 598 | const __m256i q6bits1 = _mm256_loadu_si256((const __m256i *)qs); | ||
| 599 | const __m256i q6bits2 = _mm256_loadu_si256((const __m256i *)(qs + 32)); | ||
| 600 | const __m256i q6bitsH = _mm256_loadu_si256((const __m256i *)qh); | ||
| 601 | |||
| 602 | const __m256i q6h_0 = _mm256_slli_epi16(_mm256_and_si256( q6bitsH, m2), 4); | ||
| 603 | const __m256i q6h_1 = _mm256_slli_epi16(_mm256_and_si256(_mm256_srli_epi16(q6bitsH, 2), m2), 4); | ||
| 604 | const __m256i q6h_2 = _mm256_slli_epi16(_mm256_and_si256(_mm256_srli_epi16(q6bitsH, 4), m2), 4); | ||
| 605 | const __m256i q6h_3 = _mm256_slli_epi16(_mm256_and_si256(_mm256_srli_epi16(q6bitsH, 6), m2), 4); | ||
| 606 | |||
| 607 | const __m256i q6_0 = _mm256_or_si256(_mm256_and_si256(q6bits1, m4), q6h_0); | ||
| 608 | const __m256i q6_1 = _mm256_or_si256(_mm256_and_si256(q6bits2, m4), q6h_1); | ||
| 609 | const __m256i q6_2 = _mm256_or_si256(_mm256_and_si256(_mm256_srli_epi16(q6bits1, 4), m4), q6h_2); | ||
| 610 | const __m256i q6_3 = _mm256_or_si256(_mm256_and_si256(_mm256_srli_epi16(q6bits2, 4), m4), q6h_3); | ||
| 611 | |||
| 612 | r0 = _mm512_inserti32x8(_mm512_castsi256_si512(q6_0), q6_1, 1); | ||
| 613 | r1 = _mm512_inserti32x8(_mm512_castsi256_si512(q6_2), q6_3, 1); | ||
| 614 | } | ||
| 615 | |||
| 616 | inline __m512i packNibbles(__m512i r0, __m512i r1) { | ||
| 617 | return _mm512_or_si512(r0, _mm512_slli_epi16(r1, 4)); | ||
| 618 | } | ||
| 619 | |||
| 620 | template <typename TB> | ||
| 621 | inline void pack_qs(void * RESTRICT packed_B, const TB * RESTRICT B, int KB) { | ||
| 622 | int8_t tmp[8 * 64]; | ||
| 623 | __m256i v[8], v2[8]; | ||
| 624 | for (int n = 0; n < 8; ++n) { | ||
| 625 | v[n] = bytes_from_nibbles_32(B[n * KB].qs); | ||
| 626 | } | ||
| 627 | transpose_8x8_32bit(v, v2); | ||
| 628 | for (int n = 0; n < 8; ++n) { | ||
| 629 | _mm256_storeu_si256((__m256i *)(tmp + n * 64), v2[n]); | ||
| 630 | } | ||
| 631 | for (int n = 0; n < 8; ++n) { | ||
| 632 | v[n] = bytes_from_nibbles_32(B[(n + 8) * KB].qs); | ||
| 633 | } | ||
| 634 | transpose_8x8_32bit(v, v2); | ||
| 635 | for (int n = 0; n < 8; ++n) { | ||
| 636 | _mm256_storeu_si256((__m256i *)(tmp + n * 64 + 32), v2[n]); | ||
| 637 | } | ||
| 638 | |||
| 639 | // pack again with 128 to fully utilize vector length | ||
| 640 | for (int n = 0; n < 8; n += 2) { | ||
| 641 | __m512i r0 = _mm512_loadu_si512((const __m512i *)(tmp + n * 64)); | ||
| 642 | __m512i r1 = _mm512_loadu_si512((const __m512i *)(tmp + n * 64 + 64)); | ||
| 643 | __m512i r1r0 = packNibbles(r0, r1); | ||
| 644 | _mm512_storeu_si512((__m512i *)((char *)packed_B + n * 32), r1r0); | ||
| 645 | } | ||
| 646 | } | ||
| 647 | |||
| 648 | template <> | ||
| 649 | inline void pack_qs<block_q8_0>(void * RESTRICT packed_B, const block_q8_0 * RESTRICT B, int KB) { | ||
| 650 | __m256i v[8], v2[8]; | ||
| 651 | for (int n = 0; n < 8; ++n) { | ||
| 652 | v[n] = _mm256_loadu_si256((const __m256i *)(B[n * KB].qs)); | ||
| 653 | } | ||
| 654 | transpose_8x8_32bit(v, v2); | ||
| 655 | for (int n = 0; n < 8; ++n) { | ||
| 656 | _mm256_storeu_si256((__m256i *)((char *)packed_B + n * 64), v2[n]); | ||
| 657 | } | ||
| 658 | for (int n = 0; n < 8; ++n) { | ||
| 659 | v[n] = _mm256_loadu_si256((const __m256i *)(B[(n + 8) * KB].qs)); | ||
| 660 | } | ||
| 661 | transpose_8x8_32bit(v, v2); | ||
| 662 | for (int n = 0; n < 8; ++n) { | ||
| 663 | _mm256_storeu_si256((__m256i *)((char *)packed_B + n * 64 + 32), v2[n]); | ||
| 664 | } | ||
| 665 | } | ||
| 666 | |||
| 667 | template <> | ||
| 668 | inline void pack_qs<block_q4_K>(void * RESTRICT packed_B, const block_q4_K * RESTRICT B, int KB) { | ||
| 669 | __m512i v[16]; | ||
| 670 | // QK_K 256 with 8 groups, handle 2 groups at a time | ||
| 671 | char * pb = (char *)packed_B; | ||
| 672 | for (int k = 0; k < QK_K / 64; ++k) { | ||
| 673 | // pack 2 groups { n, g, k} to {g, k/4, 4n} | ||
| 674 | // e.g. {16, 2, 32} to {2, 8, 64} | ||
| 675 | for (int n = 0; n < TILE_N; ++n) { | ||
| 676 | v[n] = bytes_from_nibbles_64(B[n * KB].qs + k * 32); | ||
| 677 | } | ||
| 678 | |||
| 679 | transpose_16x16_32bit(v); | ||
| 680 | |||
| 681 | // pack again with 128 to fully utilize vector length | ||
| 682 | for (int n = 0; n < TILE_N; n += 2) { | ||
| 683 | _mm512_storeu_si512((__m512i *)pb, packNibbles(v[n], v[n + 1])); | ||
| 684 | pb += 64; | ||
| 685 | } | ||
| 686 | } | ||
| 687 | } | ||
| 688 | |||
| 689 | template <> | ||
| 690 | inline void pack_qs<block_q5_K>(void * RESTRICT packed_B, const block_q5_K * RESTRICT B, int KB) { | ||
| 691 | __m512i v[16]; | ||
| 692 | const __m512i lowMask = _mm512_set1_epi8(0xF); | ||
| 693 | // QK_K 256 with 8 groups, handle 2 groups at a time | ||
| 694 | char * pb = (char *)packed_B; | ||
| 695 | char * ph = (char *)packed_B + (QK_K / 2) * TILE_N; | ||
| 696 | for (int k = 0; k < QK_K / 64; ++k) { | ||
| 697 | // pack 2 groups { n, g, k} to {g, k/4, 4n} | ||
| 698 | // e.g. {16, 2, 32} to {2, 8, 64} | ||
| 699 | for (int n = 0; n < TILE_N; ++n) { | ||
| 700 | v[n] = bytes_from_nibbles_64(B[n * KB].qs + k * 32, B[n * KB].qh, /* group */2 * k); | ||
| 701 | } | ||
| 702 | |||
| 703 | transpose_16x16_32bit(v); | ||
| 704 | |||
| 705 | // 1. pack lower 4bits with 2 groups | ||
| 706 | for (int n = 0; n < TILE_N; n += 2) { | ||
| 707 | // get lower 4 bits | ||
| 708 | const __m512i r0 = _mm512_and_si512(v[n], lowMask); | ||
| 709 | const __m512i r1 = _mm512_and_si512(v[n + 1], lowMask); | ||
| 710 | _mm512_storeu_si512((__m512i *)pb, packNibbles(r0, r1)); pb += 64; | ||
| 711 | } | ||
| 712 | |||
| 713 | // 2. pack higher 1bit with 2 groups | ||
| 714 | const __m512i hmask = _mm512_set1_epi8(0x10); | ||
| 715 | for (int g = 0; g < 2; ++g) { | ||
| 716 | __m512i hbits = _mm512_setzero_si512(); | ||
| 717 | hbits = _mm512_add_epi8(hbits, _mm512_srli_epi16(_mm512_and_si512(v[g * 8 + 0], hmask), 4)); | ||
| 718 | hbits = _mm512_add_epi8(hbits, _mm512_srli_epi16(_mm512_and_si512(v[g * 8 + 1], hmask), 3)); | ||
| 719 | hbits = _mm512_add_epi8(hbits, _mm512_srli_epi16(_mm512_and_si512(v[g * 8 + 2], hmask), 2)); | ||
| 720 | hbits = _mm512_add_epi8(hbits, _mm512_srli_epi16(_mm512_and_si512(v[g * 8 + 3], hmask), 1)); | ||
| 721 | hbits = _mm512_add_epi8(hbits, _mm512_and_si512(v[g * 8 + 4], hmask) ); | ||
| 722 | hbits = _mm512_add_epi8(hbits, _mm512_slli_epi16(_mm512_and_si512(v[g * 8 + 5], hmask), 1)); | ||
| 723 | hbits = _mm512_add_epi8(hbits, _mm512_slli_epi16(_mm512_and_si512(v[g * 8 + 6], hmask), 2)); | ||
| 724 | hbits = _mm512_add_epi8(hbits, _mm512_slli_epi16(_mm512_and_si512(v[g * 8 + 7], hmask), 3)); | ||
| 725 | _mm512_storeu_si512((__m512i *)ph, hbits); ph += 64; | ||
| 726 | } | ||
| 727 | } | ||
| 728 | } | ||
| 729 | |||
| 730 | template <> | ||
| 731 | inline void pack_qs<block_q6_K>(void * RESTRICT packed_B, const block_q6_K * RESTRICT B, int KB) { | ||
| 732 | __m512i v[32]; | ||
| 733 | const __m512i lowMask = _mm512_set1_epi8(0xF); | ||
| 734 | // QK_K 256 with 8 groups, handle 4 groups at a time | ||
| 735 | char * pb = (char *)packed_B; | ||
| 736 | char * ph = (char *)packed_B + (QK_K / 2) * TILE_N; | ||
| 737 | for (int k = 0; k < QK_K / 128; ++k) { | ||
| 738 | for (int n = 0; n < TILE_N; ++n) { | ||
| 739 | bytes_from_nibbles_128(v[n], v[n + 16], B[n * KB].ql + k * 64, B[n * KB].qh + k * 32); | ||
| 740 | } | ||
| 741 | |||
| 742 | // top half: group 0,1 or 4,5; bottom half: group 2,3 or 6,7 | ||
| 743 | transpose_16x16_32bit(v); | ||
| 744 | transpose_16x16_32bit(v + 16); | ||
| 745 | |||
| 746 | // 1. pack lower 4bits with 4 groups | ||
| 747 | for (int n = 0; n < 32; n += 2) { | ||
| 748 | const __m512i r0 = _mm512_and_si512(v[n], lowMask); | ||
| 749 | const __m512i r1 = _mm512_and_si512(v[n + 1], lowMask); | ||
| 750 | _mm512_storeu_si512((__m512i *)pb, packNibbles(r0, r1)); pb += 64; | ||
| 751 | } | ||
| 752 | |||
| 753 | // 2. pack higher 2bit with 4 groups | ||
| 754 | const __m512i hmask = _mm512_set1_epi8(0x30); | ||
| 755 | for (int g = 0; g < 8; ++g) { | ||
| 756 | __m512i hbits = _mm512_setzero_si512(); | ||
| 757 | hbits = _mm512_add_epi8(hbits, _mm512_srli_epi16(_mm512_and_si512(v[g * 4 + 0], hmask), 4)); | ||
| 758 | hbits = _mm512_add_epi8(hbits, _mm512_srli_epi16(_mm512_and_si512(v[g * 4 + 1], hmask), 2)); | ||
| 759 | hbits = _mm512_add_epi8(hbits, _mm512_and_si512(v[g * 4 + 2], hmask) ); | ||
| 760 | hbits = _mm512_add_epi8(hbits, _mm512_slli_epi16(_mm512_and_si512(v[g * 4 + 3], hmask), 2)); | ||
| 761 | _mm512_storeu_si512((__m512i *)ph, hbits); ph += 64; | ||
| 762 | } | ||
| 763 | } | ||
| 764 | } | ||
| 765 | |||
| 766 | template <> | ||
| 767 | inline void pack_qs<block_iq4_xs>(void * RESTRICT packed_B, const block_iq4_xs * RESTRICT B, int KB) { | ||
| 768 | __m512i v[16]; | ||
| 769 | char * pb = (char *)packed_B; | ||
| 770 | for (int k = 0; k < QK_K / 64; ++k) { | ||
| 771 | for (int n = 0; n < TILE_N; ++n) { | ||
| 772 | __m256i r0 = bytes_from_nibbles_32(B[n * KB].qs + k * 32 + 0); | ||
| 773 | __m256i r1 = bytes_from_nibbles_32(B[n * KB].qs + k * 32 + 16); | ||
| 774 | v[n] = _mm512_inserti32x8(_mm512_castsi256_si512(r0), r1, 1); | ||
| 775 | } | ||
| 776 | |||
| 777 | transpose_16x16_32bit(v); | ||
| 778 | |||
| 779 | // pack again with 128 to fully utilize vector length | ||
| 780 | for (int n = 0; n < TILE_N; n += 2) { | ||
| 781 | _mm512_storeu_si512((__m512i *)pb, packNibbles(v[n], v[n + 1])); | ||
| 782 | pb += 64; | ||
| 783 | } | ||
| 784 | } | ||
| 785 | } | ||
| 786 | |||
| 787 | // pack B into VNNI format, with 4-bit or 8-bit quants | ||
| 788 | void pack_B(void * RESTRICT packed_B, const block_q4_0 * RESTRICT B, int KB) { | ||
| 789 | pack_qs(packed_B, B, KB); | ||
| 790 | ggml_half * d0 = reinterpret_cast<ggml_half *>((char *)packed_B + TILE_N * TILE_K / 2); | ||
| 791 | for (int n = 0; n < TILE_N; ++n) { | ||
| 792 | d0[n] = B[n * KB].d; | ||
| 793 | } | ||
| 794 | } | ||
| 795 | |||
| 796 | void pack_B(void * RESTRICT packed_B, const block_q4_1 * RESTRICT B, int KB) { | ||
| 797 | pack_qs(packed_B, B, KB); | ||
| 798 | ggml_half * d0 = reinterpret_cast<ggml_half *>((char *)packed_B + TILE_N * TILE_K / 2); | ||
| 799 | ggml_half * m0 = d0 + TILE_N; | ||
| 800 | for (int n = 0; n < TILE_N; ++n) { | ||
| 801 | d0[n] = B[n * KB].d; | ||
| 802 | m0[n] = B[n * KB].m; | ||
| 803 | } | ||
| 804 | } | ||
| 805 | |||
| 806 | inline void s8s8_compensation(void * RESTRICT packed_B) { | ||
| 807 | // packed_B layout: | ||
| 808 | // quants {TILE_N, TILE_K} int8_t | ||
| 809 | // d0 {TILE_N} ggml_half | ||
| 810 | // comp {TILE_N} int32_t | ||
| 811 | const int offset = TILE_N * TILE_K + TILE_N * sizeof(ggml_half); | ||
| 812 | __m512i vcomp = _mm512_setzero_si512(); | ||
| 813 | const __m512i off = _mm512_set1_epi8(static_cast<char>(0x80)); | ||
| 814 | for (int k = 0; k < 8; ++k) { | ||
| 815 | __m512i vb = _mm512_loadu_si512((const __m512i *)((const char *)packed_B + k * 64)); | ||
| 816 | vcomp = _mm512_dpbusd_epi32(vcomp, off, vb); | ||
| 817 | } | ||
| 818 | _mm512_storeu_si512((__m512i *)((char *)(packed_B) + offset), vcomp); | ||
| 819 | } | ||
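// Note (illustrative reading, not part of the patch): _mm512_dpbusd_epi32 multiplies
// unsigned bytes by signed bytes, so an s8 x s8 product is usually obtained by biasing
// the activation bytes with 0x80 (making them unsigned) and subtracting a precomputed
// correction. The loop above appears to build that correction: feeding a vector of 0x80
// as the unsigned operand yields comp[n] = 128 * sum_k B[n][k] per output column, so the
// true result can later be recovered as dot(A + 128, B) - comp[n].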
| 820 | |||
| 821 | void pack_B(void * RESTRICT packed_B, const block_q8_0 * RESTRICT B, int KB) { | ||
| 822 | pack_qs(packed_B, B, KB); | ||
| 823 | ggml_half * d0 = reinterpret_cast<ggml_half *>((char *)packed_B + TILE_N * TILE_K); | ||
| 824 | for (int n = 0; n < TILE_N; ++n) { | ||
| 825 | d0[n] = B[n * KB].d; | ||
| 826 | } | ||
| 827 | s8s8_compensation(packed_B); | ||
| 828 | } | ||
| 829 | |||
| 830 | // convert 8 * {min, scale} from int6 to int8 | ||
| 831 | inline void unpack_mins_and_scales(const uint8_t * scales, uint32_t * utmp) { | ||
| 832 | const uint32_t kmask1 = 0x3f3f3f3f; | ||
| 833 | const uint32_t kmask2 = 0x0f0f0f0f; | ||
| 834 | const uint32_t kmask3 = 0x03030303; | ||
| 835 | |||
| 836 | memcpy(utmp, scales, 12); | ||
| 837 | utmp[3] = ((utmp[2] >> 4) & kmask2) | (((utmp[1] >> 6) & kmask3) << 4); | ||
| 838 | const uint32_t uaux = utmp[1] & kmask1; | ||
| 839 | utmp[1] = (utmp[2] & kmask2) | (((utmp[0] >> 6) & kmask3) << 4); | ||
| 840 | utmp[2] = uaux; | ||
| 841 | utmp[0] &= kmask1; | ||
| 842 | } | ||
| 843 | |||
| 844 | // packed_B layout: | ||
| 845 | // quants {8, TILE_N, 16} uint8 | ||
| 846 | // scales {8, TILE_N} uint8 | ||
| 847 | // mins {8, TILE_N} uint8 | ||
| 848 | // d {TILE_N} ggml_half | ||
| 849 | // dmin {TILE_N} ggml_half | ||
| 850 | void pack_B(void * RESTRICT packed_B, const block_q4_K * RESTRICT B, int KB) { | ||
| 851 | pack_qs(packed_B, B, KB); | ||
| 852 | |||
| 853 | uint8_t * scales = reinterpret_cast<uint8_t *>((char *)packed_B + (QK_K / 2) * TILE_N); | ||
| 854 | uint8_t * mins = scales + 8 * TILE_N; | ||
| 855 | ggml_half * d = reinterpret_cast<ggml_half *>(mins + 8 * TILE_N); | ||
| 856 | ggml_half * dmin = d + TILE_N; | ||
| 857 | |||
| 858 | union { | ||
| 859 | uint32_t u32[4]; | ||
| 860 | uint8_t u8[16]; | ||
| 861 | } s; | ||
| 862 | |||
| 863 | for (int n = 0; n < TILE_N; ++n) { | ||
| 864 | unpack_mins_and_scales(B[n * KB].scales, s.u32); | ||
| 865 | for (int k = 0; k < 8; ++k) { | ||
| 866 | scales[k * TILE_N + n] = s.u8[k]; | ||
| 867 | mins[(k >> 1) * TILE_N * 2 + n * 2 + (k & 0x1)] = s.u8[k + 8]; | ||
| 868 | } | ||
| 869 | d[n] = B[n * KB].d; | ||
| 870 | dmin[n] = B[n * KB].dmin; | ||
| 871 | } | ||
| 872 | } | ||
| 873 | |||
| 874 | // packed_B layout: | ||
| 875 | // quants {8, TILE_N, 16} uint8 | ||
| 876 | // qh {8, TILE_N, 4} uint8 | ||
| 877 | // scales {8, TILE_N} uint8 | ||
| 878 | // mins {8, TILE_N} uint8 | ||
| 879 | // d {TILE_N} ggml_half | ||
| 880 | // dmin {TILE_N} ggml_half | ||
| 881 | void pack_B(void * RESTRICT packed_B, const block_q5_K * RESTRICT B, int KB) { | ||
| 882 | pack_qs(packed_B, B, KB); | ||
| 883 | |||
| 884 | uint8_t * scales = reinterpret_cast<uint8_t *>((char *)packed_B + (QK_K / 2) * TILE_N + (QK_K / 8) * TILE_N); | ||
| 885 | uint8_t * mins = scales + 8 * TILE_N; | ||
| 886 | ggml_half * d = reinterpret_cast<ggml_half *>(mins + 8 * TILE_N); | ||
| 887 | ggml_half * dmin = d + TILE_N; | ||
| 888 | |||
| 889 | union { | ||
| 890 | uint32_t u32[4]; | ||
| 891 | uint8_t u8[16]; | ||
| 892 | } s; | ||
| 893 | |||
| 894 | for (int n = 0; n < TILE_N; ++n) { | ||
| 895 | unpack_mins_and_scales(B[n * KB].scales, s.u32); | ||
| 896 | for (int k = 0; k < 8; ++k) { | ||
| 897 | scales[k * TILE_N + n] = s.u8[k]; | ||
| 898 | mins[(k >> 1) * TILE_N * 2 + n * 2 + (k & 0x1)] = s.u8[k + 8]; | ||
| 899 | } | ||
| 900 | d[n] = B[n * KB].d; | ||
| 901 | dmin[n] = B[n * KB].dmin; | ||
| 902 | } | ||
| 903 | } | ||
| 904 | |||
| 905 | // packed_B layout: | ||
| 906 | // quants {16, TILE_N, 8} uint8 | ||
| 907 | // qh {16, TILE_N, 4} uint8 | ||
| 908 | // scales {16, TILE_N} uint8 | ||
| 909 | // d {TILE_N} ggml_half | ||
| 910 | void pack_B(void * RESTRICT packed_B, const block_q6_K * RESTRICT B, int KB) { | ||
| 911 | pack_qs(packed_B, B, KB); | ||
| 912 | |||
| 913 | uint8_t * scales = reinterpret_cast<uint8_t *>((char *)packed_B + (QK_K / 2) * TILE_N + (QK_K / 4) * TILE_N); | ||
| 914 | ggml_half * d = reinterpret_cast<ggml_half *>(scales + 16 * TILE_N); | ||
| 915 | for (int n = 0; n < TILE_N; ++n) { | ||
| 916 | const int8_t * ps = B[n * KB].scales; | ||
| 917 | for (int k = 0; k < 16; ++k) { | ||
| 918 | scales[k * TILE_N + n] = ps[k]; | ||
| 919 | } | ||
| 920 | d[n] = B[n * KB].d; | ||
| 921 | } | ||
| 922 | } | ||
| 923 | |||
| 924 | // packed_B layout: | ||
| 925 | // quants {8, TILE_N, 16} uint8 | ||
| 926 | // scales {8, TILE_N} int8 | ||
| 927 | // d {TILE_N} ggml_half | ||
| 928 | void pack_B(void * RESTRICT packed_B, const block_iq4_xs * RESTRICT B, int KB) { | ||
| 929 | pack_qs(packed_B, B, KB); | ||
| 930 | |||
| 931 | int8_t * scales = reinterpret_cast<int8_t *>((char *)packed_B + (QK_K / 2) * TILE_N); | ||
| 932 | ggml_half * d = reinterpret_cast<ggml_half *>(scales + 8 * TILE_N); | ||
| 933 | |||
| 934 | // pack the scales | ||
| 935 | for (int n = 0; n < TILE_N; ++n) { | ||
| 936 | uint16_t sh = B[n * KB].scales_h; | ||
| 937 | for (int k = 0; k < 8; k += 2) { | ||
| 938 | const int16_t ls1 = ((B[n * KB].scales_l[k / 2] & 0xf) | ((sh << 4) & 0x30)) - 32; | ||
| 939 | const int16_t ls2 = ((B[n * KB].scales_l[k / 2] >> 4) | ((sh << 2) & 0x30)) - 32; | ||
| 940 | scales[(k + 0) * TILE_N + n] = ls1; | ||
| 941 | scales[(k + 1) * TILE_N + n] = ls2; | ||
| 942 | sh >>= 4; | ||
| 943 | } | ||
| 944 | d[n] = B[n * KB].d; | ||
| 945 | } | ||
| 946 | } | ||
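| | // Illustration of the 6-bit scale reconstruction above: each sub-block scale is the | ||
| | // 4 low bits from scales_l combined with 2 high bits from scales_h, recentered by -32. | ||
| | // For example, with scales_l[0] = 0xB4 and the low nibble of scales_h = 0x6: | ||
| | //   ls1 = (0x4 | ((0x6 << 4) & 0x30)) - 32 = (4 | 32) - 32 =  4 | ||
| | //   ls2 = (0xB | ((0x6 << 2) & 0x30)) - 32 = (11 | 16) - 32 = -5 | ||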
| 947 | |||
| 948 | template<typename TB, typename packed_B_t = packed_B_type<TB>> | ||
| 949 | void unpack_B(packed_B_t * RESTRICT tile, const void * RESTRICT packed_B) { | ||
| 950 | GGML_UNUSED(tile); | ||
| 951 | GGML_UNUSED(packed_B); | ||
| 952 | } | ||
| 953 | |||
| 954 | template <> | ||
| 955 | void unpack_B<block_q4_0>(int8_t * RESTRICT tile, const void * RESTRICT packed_B) { | ||
| 956 | const __m512i off = _mm512_set1_epi8(8); | ||
| 957 | const __m512i lowMask = _mm512_set1_epi8(0xF); | ||
| 958 | for (int n = 0; n < 8; n += 2) { | ||
| 959 | __m512i bytes = _mm512_loadu_si512((const __m512i *)((const char *)packed_B + n * 32)); | ||
| 960 | const __m512i r0 = _mm512_sub_epi8(_mm512_and_si512(bytes, lowMask), off); | ||
| 961 | const __m512i r1 = _mm512_sub_epi8(_mm512_and_si512(_mm512_srli_epi16(bytes, 4), lowMask), off); | ||
| 962 | _mm512_storeu_si512((__m512i *)(tile + n * 64 + 0), r0); | ||
| 963 | _mm512_storeu_si512((__m512i *)(tile + n * 64 + 64), r1); | ||
| 964 | } | ||
| 965 | } | ||
| 966 | |||
| 967 | template <> | ||
| 968 | void unpack_B<block_q4_1>(uint8_t * RESTRICT tile, const void * RESTRICT packed_B) { | ||
| 969 | const __m512i lowMask = _mm512_set1_epi8(0xF); | ||
| 970 | for (int n = 0; n < 8; n += 2) { | ||
| 971 | __m512i bytes = _mm512_loadu_si512((const __m512i *)((const char *)packed_B + n * 32)); | ||
| 972 | const __m512i r0 = _mm512_and_si512(bytes, lowMask); | ||
| 973 | const __m512i r1 = _mm512_and_si512(_mm512_srli_epi16(bytes, 4), lowMask); | ||
| 974 | _mm512_storeu_si512((__m512i *)(tile + n * 64 + 0), r0); | ||
| 975 | _mm512_storeu_si512((__m512i *)(tile + n * 64 + 64), r1); | ||
| 976 | } | ||
| 977 | } | ||
| 978 | |||
| 979 | // packed_B_t for QKK is int8_t | ||
| 980 | template <typename TB> | ||
| 981 | void unpack_B(int8_t * RESTRICT tile, const void * RESTRICT packed_B, int k) { | ||
| 982 | const int packed_B_group_size = QK_K / 2 * TILE_N / 8; | ||
| 983 | const char * packed_B_group = (const char *)packed_B + k * packed_B_group_size; | ||
| 984 | const __m512i lowMask = _mm512_set1_epi8(0xF); | ||
| 985 | for (int n = 0; n < 8; n += 2) { | ||
| 986 | __m512i bytes = _mm512_loadu_si512(packed_B_group + n * 32); | ||
| 987 | const __m512i r0 = _mm512_and_si512(bytes, lowMask); | ||
| 988 | const __m512i r1 = _mm512_and_si512(_mm512_srli_epi16(bytes, 4), lowMask); | ||
| 989 | _mm512_storeu_si512((__m512i *)(tile + n * 64 + 0), r0); | ||
| 990 | _mm512_storeu_si512((__m512i *)(tile + n * 64 + 64), r1); | ||
| 991 | } | ||
| 992 | } | ||
| 993 | |||
| 994 | template <> | ||
| 995 | void unpack_B<block_q5_K>(int8_t * RESTRICT tile, const void * RESTRICT packed_B, int k) { | ||
| 996 | // lower 4bits, stride 256 bytes | ||
| 997 | const int packed_l4_group_size = QK_K / 2 * TILE_N / 8; | ||
| 998 | const char * pb = (const char *)packed_B + k * packed_l4_group_size; | ||
| 999 | |||
| 1000 | // higher 1bit, stride 64 bytes | ||
| 1001 | const int packed_h1_group_size = QK_K / 8 * TILE_N / 8; | ||
| 1002 | const char * ph = (const char *)packed_B + (QK_K / 2) * TILE_N + k * packed_h1_group_size; | ||
| 1003 | const __m512i hbits = _mm512_loadu_si512(ph); | ||
| 1004 | |||
| 1005 | const __m512i lowMask = _mm512_set1_epi8(0xF); | ||
| 1006 | __m512i hmask0 = _mm512_set1_epi8(0x1); | ||
| 1007 | __m512i hmask1 = _mm512_set1_epi8(0x2); | ||
| 1008 | |||
| 1009 | for (int n = 0; n < 8; n += 2) { | ||
| 1010 | __m512i bytes = _mm512_loadu_si512(pb + n * 32); | ||
| 1011 | __m512i r0 = _mm512_and_si512(bytes, lowMask); | ||
| 1012 | __m512i r1 = _mm512_and_si512(_mm512_srli_epi16(bytes, 4), lowMask); | ||
| 1013 | __m512i h0 = _mm512_slli_epi16(_mm512_srli_epi16(_mm512_and_si512(hbits, hmask0), n), 4); | ||
| 1014 | __m512i h1 = _mm512_slli_epi16(_mm512_srli_epi16(_mm512_and_si512(hbits, hmask1), n + 1), 4); | ||
| 1015 | |||
| 1016 | hmask0 = _mm512_slli_epi16(hmask0, 2); | ||
| 1017 | hmask1 = _mm512_slli_epi16(hmask1, 2); | ||
| 1018 | r0 = _mm512_add_epi8(r0, h0); | ||
| 1019 | r1 = _mm512_add_epi8(r1, h1); | ||
| 1020 | _mm512_storeu_si512((__m512i *)(tile + n * 64 + 0), r0); | ||
| 1021 | _mm512_storeu_si512((__m512i *)(tile + n * 64 + 64), r1); | ||
| 1022 | } | ||
| 1023 | } | ||
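| | // In the loop above, bit n of each `hbits` byte is the 5th (high) bit for the same | ||
| | // byte position in VNNI row n; it is shifted up to bit 4 and added to the low nibble, | ||
| | // producing unsigned 5-bit quants in [0, 31]. | ||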
| 1024 | |||
| 1025 | template <> | ||
| 1026 | void unpack_B<block_q6_K>(int8_t * RESTRICT tile, const void * RESTRICT packed_B, int k) { | ||
| 1027 | // lower 4bits, stride 128 bytes | ||
| 1028 | const int packed_l4_group_size = QK_K / 2 * TILE_N / 16; | ||
| 1029 | const char * pb = (const char *)packed_B + k * packed_l4_group_size; | ||
| 1030 | |||
| 1031 | // higher 2bits, stride 64 bytes | ||
| 1032 | const int packed_h2_group_size = QK_K / 4 * TILE_N / 16; | ||
| 1033 | const char * ph = (const char *)packed_B + (QK_K / 2) * TILE_N + k * packed_h2_group_size; | ||
| 1034 | const __m512i hbits = _mm512_loadu_si512(ph); | ||
| 1035 | |||
| 1036 | const __m512i off = _mm512_set1_epi8(32); | ||
| 1037 | const __m512i lowMask = _mm512_set1_epi8(0xF); | ||
| 1038 | __m512i hmask0 = _mm512_set1_epi8(0x3); // 0011 | ||
| 1039 | __m512i hmask1 = _mm512_set1_epi8(0xC); // 1100 | ||
| 1040 | |||
| 1041 | // note: skip the zero padding for rows 4..7 here, as it was already done in `unpack_A` | ||

| 1042 | __m512i bytes = _mm512_loadu_si512(pb); | ||
| 1043 | __m512i r0 = _mm512_and_si512(bytes, lowMask); | ||
| 1044 | __m512i r1 = _mm512_and_si512(_mm512_srli_epi16(bytes, 4), lowMask); | ||
| 1045 | __m512i h0 = _mm512_slli_epi16(_mm512_and_si512(hbits, hmask0), 4); | ||
| 1046 | __m512i h1 = _mm512_slli_epi16(_mm512_and_si512(hbits, hmask1), 2); | ||
| 1047 | _mm512_storeu_si512((__m512i *)(tile + 0), _mm512_sub_epi8(_mm512_add_epi8(r0, h0), off)); | ||
| 1048 | _mm512_storeu_si512((__m512i *)(tile + 64), _mm512_sub_epi8(_mm512_add_epi8(r1, h1), off)); | ||
| 1049 | |||
| 1050 | hmask0 = _mm512_slli_epi16(hmask0, 4); | ||
| 1051 | hmask1 = _mm512_slli_epi16(hmask1, 4); | ||
| 1052 | |||
| 1053 | bytes = _mm512_loadu_si512(pb + 64); | ||
| 1054 | r0 = _mm512_and_si512(bytes, lowMask); | ||
| 1055 | r1 = _mm512_and_si512(_mm512_srli_epi16(bytes, 4), lowMask); | ||
| 1056 | h0 = _mm512_and_si512(hbits, hmask0); | ||
| 1057 | h1 = _mm512_srli_epi16(_mm512_and_si512(hbits, hmask1), 2); | ||
| 1058 | _mm512_storeu_si512((__m512i *)(tile + 128), _mm512_sub_epi8(_mm512_add_epi8(r0, h0), off)); | ||
| 1059 | _mm512_storeu_si512((__m512i *)(tile + 192), _mm512_sub_epi8(_mm512_add_epi8(r1, h1), off)); | ||
| 1060 | } | ||
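| | // Each 6-bit quant is rebuilt here as (low 4 bits) | (2 high bits << 4) and recentered | ||
| | // by -32, so the resulting int8 values lie in [-32, 31]. | ||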
| 1061 | |||
| 1062 | template <> | ||
| 1063 | void unpack_B<block_iq4_xs>(int8_t * RESTRICT tile, const void * RESTRICT packed_B, int k) { | ||
| 1064 | static const __m512i values128 = _mm512_set_epi8( | ||
| 1065 | 113, 89, 69, 53, 38, 25, 13, 1, -10, -22, -35, -49, -65, -83, -104, -127, | ||
| 1066 | 113, 89, 69, 53, 38, 25, 13, 1, -10, -22, -35, -49, -65, -83, -104, -127, | ||
| 1067 | 113, 89, 69, 53, 38, 25, 13, 1, -10, -22, -35, -49, -65, -83, -104, -127, | ||
| 1068 | 113, 89, 69, 53, 38, 25, 13, 1, -10, -22, -35, -49, -65, -83, -104, -127 | ||
| 1069 | ); | ||
| 1070 | |||
| 1071 | const int packed_B_group_size = QK_K / 2 * TILE_N / 8; | ||
| 1072 | const char * pb = (const char *)packed_B + k * packed_B_group_size; | ||
| 1073 | const __m512i lowMask = _mm512_set1_epi8(0xF); | ||
| 1074 | |||
| 1075 | for (int n = 0; n < 8; n += 2) { | ||
| 1076 | __m512i bytes = _mm512_loadu_si512(pb + n * 32); | ||
| 1077 | const __m512i r0 = _mm512_shuffle_epi8(values128, _mm512_and_si512(bytes, lowMask)); | ||
| 1078 | const __m512i r1 = _mm512_shuffle_epi8(values128, _mm512_and_si512(_mm512_srli_epi16(bytes, 4), lowMask)); | ||
| 1079 | _mm512_storeu_si512((__m512i *)(tile + n * 64 + 0), r0); | ||
| 1080 | _mm512_storeu_si512((__m512i *)(tile + n * 64 + 64), r1); | ||
| 1081 | } | ||
| 1082 | } | ||
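| | // `values128` is the 16-entry non-linear iq4 codebook replicated into all four 128-bit | ||
| | // lanes, so the byte shuffle above maps each 4-bit index directly to its signed int8 | ||
| | // value without a separate lookup pass. | ||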
| 1083 | |||
| 1084 | template <typename TA, typename TB, bool is_acc> | ||
| 1085 | struct acc_C {}; | ||
| 1086 | |||
| 1087 | template <bool is_acc> | ||
| 1088 | struct acc_C<block_q8_0, block_q4_0, is_acc> { | ||
| 1089 | static void apply(float * RESTRICT C, int ldc, const int32_t * RESTRICT tile, const block_q8_0 * A, int lda, const void * packed_B, int nr) { | ||
| 1090 | const int offset = TILE_N * TILE_K / 2; | ||
| 1091 | const __m512 vd0 = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i *)((const char *)packed_B + offset))); | ||
| 1092 | |||
| 1093 | for (int m = 0; m < nr; ++m) { | ||
| 1094 | const __m512 vd1 = _mm512_set1_ps(GGML_CPU_FP16_TO_FP32(A[m * lda].d)); | ||
| 1095 | const __m512 vtile = _mm512_cvtepi32_ps(_mm512_loadu_si512(tile + m * TILE_N)); | ||
| 1096 | |||
| 1097 | __m512 vsum; | ||
| 1098 | if (is_acc) { | ||
| 1099 | vsum = _mm512_loadu_ps(C + m * ldc); | ||
| 1100 | } else { | ||
| 1101 | vsum = _mm512_set1_ps(0.f); | ||
| 1102 | } | ||
| 1103 | vsum = _mm512_fmadd_ps(vtile, _mm512_mul_ps(vd0, vd1), vsum); | ||
| 1104 | _mm512_storeu_ps(C + m * ldc, vsum); | ||
| 1105 | } | ||
| 1106 | } | ||
| 1107 | }; | ||
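| | // In the q4_0 specialization above, the raw int32 AMX/VNNI tile is rescaled as | ||
| | //   C[m][n] (+)= d_A[m] * d_B[n] * tile[m][n] | ||
| | // with d_B read from the ggml_half vector stored right after the packed quants. | ||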
| 1108 | |||
| 1109 | template <bool is_acc> | ||
| 1110 | struct acc_C<block_q8_1, block_q4_1, is_acc> { | ||
| 1111 | static void apply(float * RESTRICT C, int ldc, const int32_t * RESTRICT tile, const block_q8_1 * A, int lda, const void * packed_B, int nr) { | ||
| 1112 | const int offset = TILE_N * TILE_K / 2; | ||
| 1113 | const __m512 vd0 = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i *)((const char *)packed_B + offset))); | ||
| 1114 | const __m512 vm0 = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i *)((const char *)packed_B + offset + TILE_N * sizeof(ggml_half)))); | ||
| 1115 | |||
| 1116 | for (int m = 0; m < nr; ++m) { | ||
| 1117 | const __m512 vd1 = _mm512_set1_ps(GGML_CPU_FP16_TO_FP32(A[m * lda].d)); | ||
| 1118 | const __m512 vs1 = _mm512_set1_ps(GGML_CPU_FP16_TO_FP32(A[m * lda].s)); | ||
| 1119 | const __m512 vtile = _mm512_cvtepi32_ps(_mm512_loadu_si512(tile + m * TILE_N)); | ||
| 1120 | |||
| 1121 | __m512 vsum; | ||
| 1122 | if (is_acc) { | ||
| 1123 | vsum = _mm512_loadu_ps(C + m * ldc); | ||
| 1124 | } else { | ||
| 1125 | vsum = _mm512_set1_ps(0.f); | ||
| 1126 | } | ||
| 1127 | vsum = _mm512_fmadd_ps(vtile, _mm512_mul_ps(vd0, vd1), vsum); | ||
| 1128 | vsum = _mm512_fmadd_ps(vm0, vs1, vsum); | ||
| 1129 | _mm512_storeu_ps(C + m * ldc, vsum); | ||
| 1130 | } | ||
| 1131 | } | ||
| 1132 | }; | ||
| 1133 | |||
| 1134 | template <bool is_acc> | ||
| 1135 | struct acc_C<block_q8_0, block_q8_0, is_acc> { | ||
| 1136 | static void apply(float * RESTRICT C, int ldc, const int32_t * RESTRICT tile, const block_q8_0 * A, int lda, const void * packed_B, int nr) { | ||
| 1137 | const int offset = TILE_N * TILE_K; | ||
| 1138 | const __m512 vd0 = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i *)((const char *)packed_B + offset))); | ||
| 1139 | |||
| 1140 | for (int m = 0; m < nr; ++m) { | ||
| 1141 | const __m512 vd1 = _mm512_set1_ps(GGML_CPU_FP16_TO_FP32(A[m * lda].d)); | ||
| 1142 | const __m512 vtile = _mm512_cvtepi32_ps(_mm512_loadu_si512(tile + m * TILE_N)); | ||
| 1143 | |||
| 1144 | __m512 vsum; | ||
| 1145 | if (is_acc) { | ||
| 1146 | vsum = _mm512_loadu_ps(C + m * ldc); | ||
| 1147 | } else { | ||
| 1148 | vsum = _mm512_set1_ps(0.f); | ||
| 1149 | } | ||
| 1150 | vsum = _mm512_fmadd_ps(vtile, _mm512_mul_ps(vd0, vd1), vsum); | ||
| 1151 | _mm512_storeu_ps(C + m * ldc, vsum); | ||
| 1152 | } | ||
| 1153 | } | ||
| 1154 | }; | ||
| 1155 | |||
| 1156 | template <bool is_acc> | ||
| 1157 | struct acc_C<block_q8_K, block_q4_K, is_acc> { | ||
| 1158 | static void apply(float * RESTRICT C, int ldc, const int32_t * RESTRICT tile, const block_q8_K * A, int lda, const void * packed_B, int nr) { | ||
| 1159 | const uint8_t * scales = reinterpret_cast<const uint8_t *>((const char *)packed_B + (QK_K / 2) * TILE_N); | ||
| 1160 | const uint8_t * mins = scales + 8 * TILE_N; | ||
| 1161 | const ggml_half * d0 = reinterpret_cast<const ggml_half *>(mins + 8 * TILE_N); | ||
| 1162 | const ggml_half * dmin = d0 + TILE_N; | ||
| 1163 | |||
| 1164 | const __m512 vd0 = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i *)d0)); | ||
| 1165 | const __m512 vdmin = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i *)dmin)); | ||
| 1166 | |||
| 1167 | for (int m = 0; m < nr; ++m) { | ||
| 1168 | const float d1 = A[m * lda].d; | ||
| 1169 | const __m512 vd = _mm512_mul_ps(_mm512_set1_ps(d1), vd0); | ||
| 1170 | const __m512 vdm = _mm512_mul_ps(_mm512_set1_ps(-d1), vdmin); | ||
| 1171 | const __m512 vtile = _mm512_cvtepi32_ps(_mm512_loadu_si512(tile + m * TILE_N)); | ||
| 1172 | |||
| 1173 | __m512 vsum; | ||
| 1174 | if (is_acc) { | ||
| 1175 | vsum = _mm512_loadu_ps(C + m * ldc); | ||
| 1176 | } else { | ||
| 1177 | vsum = _mm512_set1_ps(0.f); | ||
| 1178 | } | ||
| 1179 | |||
| 1180 | const __m256i q8sums = _mm256_loadu_si256((const __m256i *)A[m * lda].bsums); | ||
| 1181 | const __m128i q8s = _mm_hadd_epi16(_mm256_extracti128_si256(q8sums, 0), _mm256_extracti128_si256(q8sums, 1)); | ||
| 1182 | |||
| 1183 | __m512i acc_m = _mm512_setzero_si512(); | ||
| 1184 | for (int k = 0; k < 4; ++k) { | ||
| 1185 | __m512i vmask = _mm512_set1_epi32(k); | ||
| 1186 | __m512i va = _mm512_permutexvar_epi32(vmask, _mm512_castsi128_si512(q8s)); | ||
| 1187 | __m512i vb = _mm512_cvtepi8_epi16(_mm256_loadu_si256((const __m256i *)(mins + k * 32))); | ||
| 1188 | acc_m = _mm512_dpwssds_epi32(acc_m, va, vb); | ||
| 1189 | } | ||
| 1190 | |||
| 1191 | vsum = _mm512_fmadd_ps(vtile, vd, vsum); | ||
| 1192 | vsum = _mm512_fmadd_ps(_mm512_cvtepi32_ps(acc_m), vdm, vsum); | ||
| 1193 | _mm512_storeu_ps(C + m * ldc, vsum); | ||
| 1194 | } | ||
| 1195 | } | ||
| 1196 | }; | ||
| 1197 | |||
| 1198 | template <bool is_acc> | ||
| 1199 | struct acc_C<block_q8_K, block_q5_K, is_acc> { | ||
| 1200 | static void apply(float * RESTRICT C, int ldc, const int32_t * RESTRICT tile, const block_q8_K * A, int lda, const void * packed_B, int nr) { | ||
| 1201 | const uint8_t * scales = reinterpret_cast<const uint8_t *>((const char *)packed_B + (QK_K / 2) * TILE_N + (QK_K / 8) * TILE_N); | ||
| 1202 | const uint8_t * mins = scales + 8 * TILE_N; | ||
| 1203 | const ggml_half * d0 = reinterpret_cast<const ggml_half *>(mins + 8 * TILE_N); | ||
| 1204 | const ggml_half * dmin = d0 + TILE_N; | ||
| 1205 | |||
| 1206 | const __m512 vd0 = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i *)d0)); | ||
| 1207 | const __m512 vdmin = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i *)dmin)); | ||
| 1208 | |||
| 1209 | for (int m = 0; m < nr; ++m) { | ||
| 1210 | const float d1 = A[m * lda].d; | ||
| 1211 | const __m512 vd = _mm512_mul_ps(_mm512_set1_ps(d1), vd0); | ||
| 1212 | const __m512 vdm = _mm512_mul_ps(_mm512_set1_ps(-d1), vdmin); | ||
| 1213 | const __m512 vtile = _mm512_cvtepi32_ps(_mm512_loadu_si512(tile + m * TILE_N)); | ||
| 1214 | |||
| 1215 | __m512 vsum; | ||
| 1216 | if (is_acc) { | ||
| 1217 | vsum = _mm512_loadu_ps(C + m * ldc); | ||
| 1218 | } else { | ||
| 1219 | vsum = _mm512_set1_ps(0.f); | ||
| 1220 | } | ||
| 1221 | |||
| 1222 | const __m256i q8sums = _mm256_loadu_si256((const __m256i *)A[m * lda].bsums); | ||
| 1223 | const __m128i q8s = _mm_hadd_epi16(_mm256_extracti128_si256(q8sums, 0), _mm256_extracti128_si256(q8sums, 1)); | ||
| 1224 | |||
| 1225 | __m512i acc_m = _mm512_setzero_si512(); | ||
| 1226 | for (int k = 0; k < 4; ++k) { | ||
| 1227 | __m512i vmask = _mm512_set1_epi32(k); | ||
| 1228 | __m512i va = _mm512_permutexvar_epi32(vmask, _mm512_castsi128_si512(q8s)); | ||
| 1229 | __m512i vb = _mm512_cvtepi8_epi16(_mm256_loadu_si256((const __m256i *)(mins + k * 32))); | ||
| 1230 | acc_m = _mm512_dpwssds_epi32(acc_m, va, vb); | ||
| 1231 | } | ||
| 1232 | |||
| 1233 | vsum = _mm512_fmadd_ps(vtile, vd, vsum); | ||
| 1234 | vsum = _mm512_fmadd_ps(_mm512_cvtepi32_ps(acc_m), vdm, vsum); | ||
| 1235 | _mm512_storeu_ps(C + m * ldc, vsum); | ||
| 1236 | } | ||
| 1237 | } | ||
| 1238 | }; | ||
| 1239 | |||
| 1240 | template <bool is_acc> | ||
| 1241 | struct acc_C<block_q8_K, block_q6_K, is_acc> { | ||
| 1242 | static void apply(float * RESTRICT C, int ldc, const int32_t * RESTRICT tile, const block_q8_K * A, int lda, const void * packed_B, int nr) { | ||
| 1243 | const uint8_t * scales = reinterpret_cast<const uint8_t *>((const char *)packed_B + (QK_K / 2) * TILE_N + (QK_K / 4) * TILE_N); | ||
| 1244 | const ggml_half * d0 = reinterpret_cast<const ggml_half *>(scales + 16 * TILE_N); | ||
| 1245 | |||
| 1246 | const __m512 vd0 = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i *)d0)); | ||
| 1247 | |||
| 1248 | for (int m = 0; m < nr; ++m) { | ||
| 1249 | const float d1 = A[m * lda].d; | ||
| 1250 | const __m512 vd = _mm512_mul_ps(_mm512_set1_ps(d1), vd0); | ||
| 1251 | const __m512 vtile = _mm512_cvtepi32_ps(_mm512_loadu_si512(tile + m * TILE_N)); | ||
| 1252 | |||
| 1253 | __m512 vsum; | ||
| 1254 | if (is_acc) { | ||
| 1255 | vsum = _mm512_loadu_ps(C + m * ldc); | ||
| 1256 | } else { | ||
| 1257 | vsum = _mm512_set1_ps(0.f); | ||
| 1258 | } | ||
| 1259 | |||
| 1260 | vsum = _mm512_fmadd_ps(vtile, vd, vsum); | ||
| 1261 | _mm512_storeu_ps(C + m * ldc, vsum); | ||
| 1262 | } | ||
| 1263 | } | ||
| 1264 | }; | ||
| 1265 | |||
| 1266 | template <bool is_acc> | ||
| 1267 | struct acc_C<block_q8_K, block_iq4_xs, is_acc> { | ||
| 1268 | static void apply(float * RESTRICT C, int ldc, const int32_t * RESTRICT tile, const block_q8_K * A, int lda, const void * packed_B, int nr) { | ||
| 1269 | const int8_t * scales = reinterpret_cast<const int8_t *>((const char *)packed_B + (QK_K / 2) * TILE_N); | ||
| 1270 | const ggml_half * d0 = reinterpret_cast<const ggml_half *>(scales + 8 * TILE_N); | ||
| 1271 | |||
| 1272 | const __m512 vd0 = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i *)d0)); | ||
| 1273 | |||
| 1274 | for (int m = 0; m < nr; ++m) { | ||
| 1275 | const float d1 = A[m * lda].d; | ||
| 1276 | const __m512 vd = _mm512_mul_ps(_mm512_set1_ps(d1), vd0); | ||
| 1277 | const __m512 vtile = _mm512_cvtepi32_ps(_mm512_loadu_si512(tile + m * TILE_N)); | ||
| 1278 | |||
| 1279 | __m512 vsum; | ||
| 1280 | if (is_acc) { | ||
| 1281 | vsum = _mm512_loadu_ps(C + m * ldc); | ||
| 1282 | } else { | ||
| 1283 | vsum = _mm512_set1_ps(0.f); | ||
| 1284 | } | ||
| 1285 | |||
| 1286 | vsum = _mm512_fmadd_ps(vtile, vd, vsum); | ||
| 1287 | _mm512_storeu_ps(C + m * ldc, vsum); | ||
| 1288 | } | ||
| 1289 | } | ||
| 1290 | }; | ||
| 1291 | |||
| 1292 | template <typename TB> constexpr int get_quants_size(); | ||
| 1293 | template <> constexpr int get_quants_size<block_q4_K>() { return (QK_K / 2) * TILE_N; } | ||
| 1294 | template <> constexpr int get_quants_size<block_q5_K>() { return (QK_K / 2) * TILE_N + (QK_K / 8) * TILE_N; } | ||
| 1295 | template <> constexpr int get_quants_size<block_q6_K>() { return (QK_K / 2) * TILE_N + (QK_K / 4) * TILE_N; } | ||
| 1296 | template <> constexpr int get_quants_size<block_iq4_xs>() { return (QK_K / 2) * TILE_N; } | ||
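| | // Assuming the usual QK_K == 256 and TILE_N == 16, these evaluate to 2048 (q4_K), | ||
| | // 2560 (q5_K), 3072 (q6_K) and 2048 (iq4_xs) bytes, i.e. the byte offset of the | ||
| | // packed scales inside packed_B. | ||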
| 1297 | |||
| 1298 | // used for QKK format | ||
| 1299 | template <typename TB, bool is_acc, | ||
| 1300 | typename std::enable_if<is_type_qkk<TB>::value, int>::type = 0> | ||
| 1301 | inline void scale_C(const int32_t * RESTRICT tile, int32_t * RESTRICT sumi, const void * packed_B, int k, int nr) { | ||
| 1302 | const uint8_t * scales = reinterpret_cast<const uint8_t *>((const char *)packed_B + get_quants_size<TB>()); | ||
| 1303 | const __m512i vscale = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i *)(scales + k * TILE_N))); | ||
| 1304 | |||
| 1305 | for (int m = 0; m < nr; ++m) { | ||
| 1306 | __m512i vsumi; | ||
| 1307 | if (is_acc) { | ||
| 1308 | vsumi = _mm512_loadu_si512(sumi + m * TILE_N); | ||
| 1309 | } else { | ||
| 1310 | vsumi = _mm512_setzero_si512(); | ||
| 1311 | } | ||
| 1312 | __m512i vtile = _mm512_loadu_si512(tile + m * TILE_N); | ||
| 1313 | vsumi = _mm512_add_epi32(vsumi, _mm512_mullo_epi32(vtile, vscale)); | ||
| 1314 | _mm512_storeu_si512((__m512i *)(sumi + m * TILE_N), vsumi); | ||
| 1315 | } | ||
| 1316 | } | ||
| 1317 | |||
| 1318 | template <typename TA, typename TB, typename TC, int BLOCK_M, int BLOCK_N, int BLOCK_K> | ||
| 1319 | struct tinygemm_kernel_avx { | ||
| 1320 | static void apply(int K, const TA * RESTRICT A, const TB * RESTRICT B, TC * RESTRICT C, int ldc) { | ||
| 1321 | GGML_UNUSED(K); | ||
| 1322 | GGML_UNUSED(A); | ||
| 1323 | GGML_UNUSED(B); | ||
| 1324 | GGML_UNUSED(C); | ||
| 1325 | GGML_UNUSED(ldc); | ||
| 1326 | } | ||
| 1327 | }; | ||
| 1328 | |||
| 1329 | template <int BLOCK_M, int BLOCK_N, int BLOCK_K> | ||
| 1330 | struct tinygemm_kernel_avx<float, ggml_fp16_t, float, BLOCK_M, BLOCK_N, BLOCK_K> { | ||
| 1331 | static void apply(int K, const float * RESTRICT A, const ggml_fp16_t * RESTRICT B, float * RESTRICT C, int ldc) { | ||
| 1332 | constexpr int ROWS = BLOCK_M; | ||
| 1333 | constexpr int COLS = BLOCK_N; | ||
| 1334 | assert(BLOCK_K == 16); | ||
| 1335 | |||
| 1336 | __m512 va; | ||
| 1337 | __m512 vb[COLS]; | ||
| 1338 | __m512 vc[ROWS * COLS]; | ||
| 1339 | |||
| 1340 | auto loadc = [&](auto idx) { | ||
| 1341 | vc[idx] = _mm512_setzero_ps(); | ||
| 1342 | }; | ||
| 1343 | Unroll<ROWS * COLS>{}(loadc); | ||
| 1344 | |||
| 1345 | auto compute = [&](auto idx, auto k) { | ||
| 1346 | constexpr int row = idx / COLS; | ||
| 1347 | constexpr int col = idx % COLS; | ||
| 1348 | |||
| 1349 | if constexpr (col == 0) { | ||
| 1350 | va = _mm512_loadu_ps(A + row * K + k); | ||
| 1351 | } | ||
| 1352 | if constexpr (row == 0) { | ||
| 1353 | vb[col] = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i *)(B + col * K + k))); | ||
| 1354 | } | ||
| 1355 | vc[idx] = _mm512_fmadd_ps(va, vb[col], vc[idx]); | ||
| 1356 | }; | ||
| 1357 | |||
| 1358 | for (int k = 0; k < K; k += 16) { | ||
| 1359 | Unroll<ROWS * COLS>{}(compute, k); | ||
| 1360 | } | ||
| 1361 | |||
| 1362 | auto storec = [&](auto idx) { | ||
| 1363 | constexpr int row = idx / COLS; | ||
| 1364 | constexpr int col = idx % COLS; | ||
| 1365 | C[row * ldc + col] = _mm512_reduce_add_ps(vc[idx]); | ||
| 1366 | }; | ||
| 1367 | Unroll<ROWS * COLS>{}(storec); | ||
| 1368 | } | ||
| 1369 | }; | ||
| 1370 | |||
| 1371 | #define LAUNCH_TINYGEMM_KERNEL_AVX(MB_SIZE, NB_SIZE) \ | ||
| 1372 | tinygemm_kernel_avx<float, type, float, MB_SIZE, NB_SIZE, blck_size>::apply( \ | ||
| 1373 | K, (const float *)src1->data + mb_start * K, \ | ||
| 1374 | (const type *)src0->data + nb_start * K, \ | ||
| 1375 | (float *)dst->data + mb_start * ldc + nb_start, ldc); | ||
| 1376 | |||
| 1377 | |||
| 1378 | // re-organize in the format {NB, KB, TILE_SIZE}: | ||
| 1379 | #define PACKED_INDEX(n, k, KB, tile_size) (n * KB + k) * tile_size | ||
| 1380 | |||
| 1381 | template<typename TB, int BLOCK_K> | ||
| 1382 | void convert_B_packed_format(void * RESTRICT packed_B, const TB * RESTRICT B, int N, int K) { | ||
| 1383 | const int NB = N / TILE_N; | ||
| 1384 | const int KB = K / BLOCK_K; | ||
| 1385 | const int TILE_SIZE = get_tile_size<TB>(); | ||
| 1386 | |||
| 1387 | // parallel on NB should be enough | ||
| 1388 | parallel_for(NB, [&](int begin, int end) { | ||
| 1389 | for (int n = begin; n < end; ++n) { | ||
| 1390 | for (int k = 0; k < KB; ++k) { | ||
| 1391 | int n0 = n * TILE_N; | ||
| 1392 | pack_B((char *)packed_B + PACKED_INDEX(n, k, KB, TILE_SIZE), &B[n0 * KB + k], KB); | ||
| 1393 | } | ||
| 1394 | } | ||
| 1395 | }); | ||
| 1396 | } | ||
| 1397 | |||
| 1398 | template <typename TA, typename TB, typename TC, int BLOCK_M, int BLOCK_N, int BLOCK_K> | ||
| 1399 | struct tinygemm_kernel_vnni {}; | ||
| 1400 | |||
| 1401 | template <int BLOCK_M, int BLOCK_N, int BLOCK_K> | ||
| 1402 | struct tinygemm_kernel_vnni<block_q8_0, block_q4_0, float, BLOCK_M, BLOCK_N, BLOCK_K> { | ||
| 1403 | static void apply(int KB, const void * RESTRICT _A, const void * RESTRICT _B, float * RESTRICT C, int ldc) { | ||
| 1404 | |||
| 1405 | constexpr int COLS = BLOCK_N / 16; | ||
| 1406 | const int TILE_SIZE = TILE_N * sizeof(block_q4_0); | ||
| 1407 | |||
| 1408 | const block_q8_0 * RESTRICT A = static_cast<const block_q8_0 *>(_A); | ||
| 1409 | const char * RESTRICT B = static_cast<const char *>(_B); | ||
| 1410 | |||
| 1411 | __m512i va[8]; | ||
| 1412 | __m512 vc[COLS]; | ||
| 1413 | __m512 vd1; | ||
| 1414 | |||
| 1415 | // sum of offsets, shared across COLS | ||
| 1416 | // | ||
| 1417 | // avx512-vnni does not have `_mm512_dpbssd_epi32`, | ||
| 1418 | // need to transform ss to us: | ||
| 1419 | // a * (b - 8) is equivalent to b * a - 8 * a | ||
| 1420 | // s u u u s u s | ||
| 1421 | // | ||
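| | // Scalar sketch of the identity above for one group of 4 packed values: | ||
| | //   sum_k a_k * (b_k - 8) = sum_k b_k * a_k - 8 * sum_k a_k | ||
| | // dpbusd computes the unsigned-times-signed term sum_k b_k * a_k, while `vcomp` | ||
| | // (accumulated in the compute lambda as dpbusd(8, a)) holds the 8 * sum_k a_k | ||
| | // correction that is subtracted afterwards. | ||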
| 1422 | __m512i vcomp; | ||
| 1423 | |||
| 1424 | const __m512i off = _mm512_set1_epi8(8); | ||
| 1425 | const __m512i lowMask = _mm512_set1_epi8(0xF); | ||
| 1426 | |||
| 1427 | auto loadc = [&](auto col) { | ||
| 1428 | vc[col] = _mm512_setzero_ps(); | ||
| 1429 | }; | ||
| 1430 | Unroll<COLS>{}(loadc); | ||
| 1431 | |||
| 1432 | auto compute = [&](auto col, auto i) { | ||
| 1433 | // load a and compute compensation | ||
| 1434 | if constexpr (col == 0) { | ||
| 1435 | const int32_t * a_ptr = reinterpret_cast<const int32_t *>(A[0 * KB + i].qs); | ||
| 1436 | vcomp = _mm512_setzero_si512(); | ||
| 1437 | for (int k = 0; k < 8; ++k) { | ||
| 1438 | va[k] = _mm512_set1_epi32(a_ptr[k]); | ||
| 1439 | vcomp = _mm512_dpbusd_epi32(vcomp, off, va[k]); | ||
| 1440 | } | ||
| 1441 | vd1 = _mm512_set1_ps(GGML_CPU_FP16_TO_FP32(A[0 * KB + i].d)); | ||
| 1442 | } | ||
| 1443 | |||
| 1444 | // load b | ||
| 1445 | __m512i vsum = _mm512_setzero_si512(); | ||
| 1446 | const char * b_ptr = B + PACKED_INDEX(col, i, KB, TILE_SIZE); | ||
| 1447 | for (int k = 0; k < 8; k += 2) { | ||
| 1448 | __m512i bytes = _mm512_loadu_si512((const __m512i *)(b_ptr + k * 32)); | ||
| 1449 | __m512i vb0 = _mm512_and_si512(bytes, lowMask); | ||
| 1450 | vsum = _mm512_dpbusd_epi32(vsum, vb0, va[k + 0]); | ||
| 1451 | __m512i vb1 = _mm512_and_si512(_mm512_srli_epi16(bytes, 4), lowMask); | ||
| 1452 | vsum = _mm512_dpbusd_epi32(vsum, vb1, va[k + 1]); | ||
| 1453 | } | ||
| 1454 | const int offset = TILE_N * TILE_K / 2; | ||
| 1455 | const __m512 vd0 = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i *)(b_ptr + offset))); | ||
| 1456 | vsum = _mm512_sub_epi32(vsum, vcomp); | ||
| 1457 | |||
| 1458 | vc[col] = _mm512_fmadd_ps(_mm512_cvtepi32_ps(vsum), _mm512_mul_ps(vd0, vd1), vc[col]); | ||
| 1459 | }; | ||
| 1460 | |||
| 1461 | for (int i = 0; i < KB; ++i) { | ||
| 1462 | Unroll<COLS>{}(compute, i); | ||
| 1463 | } | ||
| 1464 | |||
| 1465 | //store to C | ||
| 1466 | auto storec = [&](auto col) { | ||
| 1467 | _mm512_storeu_ps((__m512i*)(C + 0 * ldc + col * 16), vc[col]); | ||
| 1468 | }; | ||
| 1469 | Unroll<COLS>{}(storec); | ||
| 1470 | } | ||
| 1471 | }; | ||
| 1472 | |||
| 1473 | template <int BLOCK_N, int BLOCK_K> | ||
| 1474 | struct tinygemm_kernel_vnni<block_q8_1, block_q4_1, float, 1, BLOCK_N, BLOCK_K> { | ||
| 1475 | static void apply(int KB, const void * RESTRICT _A, const void * RESTRICT _B, float * RESTRICT C, int ldc) { | ||
| 1476 | |||
| 1477 | constexpr int COLS = BLOCK_N / 16; | ||
| 1478 | const int TILE_SIZE = TILE_N * sizeof(block_q4_1); | ||
| 1479 | |||
| 1480 | const block_q8_1 * RESTRICT A = static_cast<const block_q8_1 *>(_A); | ||
| 1481 | const char * RESTRICT B = static_cast<const char *>(_B); | ||
| 1482 | |||
| 1483 | __m512i va[8]; | ||
| 1484 | __m512i vb[8]; | ||
| 1485 | __m512 vc[COLS]; | ||
| 1486 | __m512 vd1, vs1; | ||
| 1487 | |||
| 1488 | const __m512i lowMask = _mm512_set1_epi8(0xF); | ||
| 1489 | |||
| 1490 | auto loadc = [&](auto col) { | ||
| 1491 | vc[col] = _mm512_setzero_ps(); | ||
| 1492 | }; | ||
| 1493 | Unroll<COLS>{}(loadc); | ||
| 1494 | |||
| 1495 | auto compute = [&](auto col, auto i) { | ||
| 1496 | // load a | ||
| 1497 | if constexpr (col == 0) { | ||
| 1498 | const int32_t * a_ptr = reinterpret_cast<const int32_t *>(A[0 * KB + i].qs); | ||
| 1499 | for (int k = 0; k < 8; ++k) { | ||
| 1500 | va[k] = _mm512_set1_epi32(a_ptr[k]); | ||
| 1501 | } | ||
| 1502 | vd1 = _mm512_set1_ps(GGML_CPU_FP16_TO_FP32(A[0 * KB + i].d)); | ||
| 1503 | vs1 = _mm512_set1_ps(GGML_CPU_FP16_TO_FP32(A[0 * KB + i].s)); | ||
| 1504 | } | ||
| 1505 | |||
| 1506 | // load b | ||
| 1507 | const char * b_ptr = B + PACKED_INDEX(col, i, KB, TILE_SIZE); | ||
| 1508 | for (int k = 0; k < 8; k += 2) { | ||
| 1509 | __m512i bytes = _mm512_loadu_si512((const __m512i *)(b_ptr + k * 32)); | ||
| 1510 | vb[k + 0] = _mm512_and_si512(bytes, lowMask); | ||
| 1511 | vb[k + 1] = _mm512_and_si512(_mm512_srli_epi16(bytes, 4), lowMask); | ||
| 1512 | } | ||
| 1513 | const int offset = TILE_N * TILE_K / 2; | ||
| 1514 | const __m512 vd0 = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i *)(b_ptr + offset))); | ||
| 1515 | const __m512 vm0 = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i *)(b_ptr + offset + TILE_N * sizeof(ggml_half)))); | ||
| 1516 | |||
| 1517 | __m512i vsum = _mm512_setzero_si512(); | ||
| 1518 | for (int k = 0; k < 8; ++k) { | ||
| 1519 | vsum = _mm512_dpbusd_epi32(vsum, vb[k], va[k]); | ||
| 1520 | } | ||
| 1521 | |||
| 1522 | vc[col] = _mm512_fmadd_ps(_mm512_cvtepi32_ps(vsum), _mm512_mul_ps(vd0, vd1), vc[col]); | ||
| 1523 | vc[col] = _mm512_fmadd_ps(vm0, vs1, vc[col]); | ||
| 1524 | }; | ||
| 1525 | |||
| 1526 | for (int i = 0; i < KB; ++i) { | ||
| 1527 | Unroll<COLS>{}(compute, i); | ||
| 1528 | } | ||
| 1529 | |||
| 1530 | //store to C | ||
| 1531 | auto storec = [&](auto col) { | ||
| 1532 | _mm512_storeu_ps((__m512i*)(C + 0 * ldc + col * 16), vc[col]); | ||
| 1533 | }; | ||
| 1534 | Unroll<COLS>{}(storec); | ||
| 1535 | } | ||
| 1536 | }; | ||
| 1537 | |||
| 1538 | template <int BLOCK_M, int BLOCK_N, int BLOCK_K> | ||
| 1539 | struct tinygemm_kernel_vnni<block_q8_0, block_q8_0, float, BLOCK_M, BLOCK_N, BLOCK_K> { | ||
| 1540 | static void apply(int KB, const void * RESTRICT _A, const void * RESTRICT _B, float * RESTRICT C, int ldc) { | ||
| 1541 | |||
| 1542 | constexpr int COLS = BLOCK_N / 16; | ||
| 1543 | const int TILE_SIZE = TILE_N * sizeof(block_q8_0) + TILE_N * sizeof(int32_t); | ||
| 1544 | |||
| 1545 | const block_q8_0 * RESTRICT A = static_cast<const block_q8_0 *>(_A); | ||
| 1546 | const char * RESTRICT B = static_cast<const char *>(_B); | ||
| 1547 | |||
| 1548 | __m512i va[8]; | ||
| 1549 | __m512i vb[8]; | ||
| 1550 | __m512 vc[COLS]; | ||
| 1551 | __m512 vd1; | ||
| 1552 | |||
| 1553 | // Notes: s8s8 igemm compensation in avx512-vnni | ||
| 1554 | // change s8s8 to u8s8 with compensation | ||
| 1555 | // a * b = (a + 128) * b - 128 * b | ||
| 1556 | // s s u s u s | ||
| 1557 | // | ||
| 1558 | // (128 * b is pre-computed when packing B to vnni formats) | ||
| 1559 | // | ||
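| | // Scalar sketch: sum_k a_k * b_k = sum_k (a_k + 128) * b_k - 128 * sum_k b_k. | ||
| | // The kernel biases A by 128 so dpbusd can treat it as unsigned, then subtracts the | ||
| | // precomputed 128 * sum(b) term loaded from the packed buffer (vcomp below). | ||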
| 1560 | const __m512i off = _mm512_set1_epi8(static_cast<char>(0x80)); | ||
| 1561 | |||
| 1562 | auto loadc = [&](auto col) { | ||
| 1563 | vc[col] = _mm512_setzero_ps(); | ||
| 1564 | }; | ||
| 1565 | Unroll<COLS>{}(loadc); | ||
| 1566 | |||
| 1567 | auto compute = [&](auto col, auto i) { | ||
| 1568 | // load a and add offset 128 | ||
| 1569 | if constexpr (col == 0) { | ||
| 1570 | const int32_t * a_ptr = reinterpret_cast<const int32_t *>(A[0 * KB + i].qs); | ||
| 1571 | for (int k = 0; k < 8; ++k) { | ||
| 1572 | va[k] = _mm512_set1_epi32(a_ptr[k]); | ||
| 1573 | va[k] = _mm512_add_epi8(va[k], off); | ||
| 1574 | } | ||
| 1575 | vd1 = _mm512_set1_ps(GGML_CPU_FP16_TO_FP32(A[0 * KB + i].d)); | ||
| 1576 | } | ||
| 1577 | |||
| 1578 | // load b | ||
| 1579 | const char * b_ptr = B + PACKED_INDEX(col, i, KB, TILE_SIZE); | ||
| 1580 | for (int k = 0; k < 8; ++k) { | ||
| 1581 | vb[k] = _mm512_loadu_si512((const __m512i *)(b_ptr + k * 64)); | ||
| 1582 | } | ||
| 1583 | const int offset = TILE_N * TILE_K; | ||
| 1584 | const __m512 vd0 = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i *)(b_ptr + offset))); | ||
| 1585 | const int offset2 = TILE_N * TILE_K + TILE_N * sizeof(ggml_half); | ||
| 1586 | const __m512i vcomp = _mm512_loadu_si512((const __m512i *)(b_ptr + offset2)); | ||
| 1587 | |||
| 1588 | __m512i vsum = _mm512_setzero_si512(); | ||
| 1589 | for (int k = 0; k < 8; ++k) { | ||
| 1590 | vsum = _mm512_dpbusd_epi32(vsum, va[k], vb[k]); | ||
| 1591 | } | ||
| 1592 | vsum = _mm512_sub_epi32(vsum, vcomp); | ||
| 1593 | |||
| 1594 | vc[col] = _mm512_fmadd_ps(_mm512_cvtepi32_ps(vsum), _mm512_mul_ps(vd0, vd1), vc[col]); | ||
| 1595 | }; | ||
| 1596 | |||
| 1597 | for (int i = 0; i < KB; ++i) { | ||
| 1598 | Unroll<COLS>{}(compute, i); | ||
| 1599 | } | ||
| 1600 | |||
| 1601 | //store to C | ||
| 1602 | auto storec = [&](auto col) { | ||
| 1603 | _mm512_storeu_ps((__m512i*)(C + 0 * ldc + col * 16), vc[col]); | ||
| 1604 | }; | ||
| 1605 | Unroll<COLS>{}(storec); | ||
| 1606 | } | ||
| 1607 | }; | ||
| 1608 | |||
| 1609 | template <int BLOCK_M, int BLOCK_N, int BLOCK_K> | ||
| 1610 | struct tinygemm_kernel_vnni<block_q8_K, block_q4_K, float, BLOCK_M, BLOCK_N, BLOCK_K> { | ||
| 1611 | static void apply(int KB, const void * RESTRICT _A, const void * RESTRICT _B, float * RESTRICT C, int ldc) { | ||
| 1612 | |||
| 1613 | constexpr int COLS = BLOCK_N / 16; | ||
| 1614 | const int TILE_SIZE = TILE_N * sizeof(block_q4_K) + TILE_N * 4; | ||
| 1615 | |||
| 1616 | const block_q8_K * RESTRICT A = static_cast<const block_q8_K *>(_A); | ||
| 1617 | const char * RESTRICT B = static_cast<const char *>(_B); | ||
| 1618 | |||
| 1619 | // a.qs: 8 groups, 32 bytes each group (m256i) | ||
| 1620 | __m512i va[8]; | ||
| 1621 | // a.bsum: 8 groups, 2 bytes each group (m128i) | ||
| 1622 | __m512i va_bsum; | ||
| 1623 | __m512 vc[COLS]; | ||
| 1624 | __m512 vd1; | ||
| 1625 | |||
| 1626 | // packed_B: | ||
| 1627 | const int offset_scales = (QK_K / 2) * TILE_N; | ||
| 1628 | const int offset_mins = (QK_K / 2) * TILE_N + 8 * TILE_N; | ||
| 1629 | const int offset_d0 = (QK_K / 2) * TILE_N + 16 * TILE_N; | ||
| 1630 | const int offset_dmin = (QK_K / 2) * TILE_N + 16 * TILE_N + TILE_N * sizeof(ggml_half); | ||
| 1631 | |||
| 1632 | const __m512i lowMask = _mm512_set1_epi8(0xF); | ||
| 1633 | |||
| 1634 | auto loadc = [&](auto col) { | ||
| 1635 | vc[col] = _mm512_setzero_ps(); | ||
| 1636 | }; | ||
| 1637 | Unroll<COLS>{}(loadc); | ||
| 1638 | |||
| 1639 | // Notes: vnni formats in QK_K | ||
| 1640 | // a) quants vnni format | ||
| 1641 | // int8 {k/4, n, 4}, viewed as 2d {k/4, 4n}, k = 32 | ||
| 1642 | // from {16, 32} to {8, 64} | ||
| 1643 | // | ||
| 1644 | // b) min vnni format | ||
| 1645 | // int16 {k/2, n, 2}, viewed as 2d {k/2, 2n}, k = 8 | ||
| 1646 | // from {16, 8} to {4, 32} | ||
| 1647 | // | ||
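| | // Concretely: within one 32-wide k group, the B sub-tile is viewed as 8 VNNI rows of | ||
| | // 64 int8 values, row r holding k = 4r..4r+3 for each of the 16 columns; since the | ||
| | // quants are 4-bit, every 64-byte load below yields two such rows (low/high nibbles). | ||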
| 1648 | auto compute = [&](auto col, auto i) { | ||
| 1649 | // load a | ||
| 1650 | if constexpr (col == 0) { | ||
| 1651 | for (int k_group = 0; k_group < QK_K / 32; ++k_group) { | ||
| 1652 | va[k_group] = _mm512_castsi256_si512(_mm256_loadu_si256((const __m256i *)(A[0 * KB + i].qs + k_group * 32))); | ||
| 1653 | } | ||
| 1654 | const __m256i q8sums = _mm256_loadu_si256((const __m256i *)A[0 * KB + i].bsums); | ||
| 1655 | const __m128i q8s = _mm_hadd_epi16(_mm256_extracti128_si256(q8sums, 0), _mm256_extracti128_si256(q8sums, 1)); | ||
| 1656 | va_bsum = _mm512_castsi128_si512(q8s); | ||
| 1657 | vd1 = _mm512_set1_ps(A[0 * KB + i].d); | ||
| 1658 | } | ||
| 1659 | |||
| 1660 | // step 1: accumulate the quants | ||
| 1661 | __m512i acc = _mm512_setzero_si512(); | ||
| 1662 | const char * b_ptr = B + PACKED_INDEX(col, i, KB, TILE_SIZE); | ||
| 1663 | const char * b_qs = b_ptr; | ||
| 1664 | for (int k_group = 0; k_group < QK_K / 32; ++k_group) { | ||
| 1665 | __m512i vsum = _mm512_setzero_si512(); | ||
| 1666 | for (int k = 0; k < 8; k += 2) { | ||
| 1667 | __m512i va0 = _mm512_permutexvar_epi32(_mm512_set1_epi32(k + 0), va[k_group]); | ||
| 1668 | __m512i va1 = _mm512_permutexvar_epi32(_mm512_set1_epi32(k + 1), va[k_group]); | ||
| 1669 | |||
| 1670 | __m512i bytes = _mm512_loadu_si512((const __m512i *)b_qs); | ||
| 1671 | __m512i vb0 = _mm512_and_si512(bytes, lowMask); | ||
| 1672 | vsum = _mm512_dpbusd_epi32(vsum, vb0, va0); | ||
| 1673 | __m512i vb1 = _mm512_and_si512(_mm512_srli_epi16(bytes, 4), lowMask); | ||
| 1674 | vsum = _mm512_dpbusd_epi32(vsum, vb1, va1); | ||
| 1675 | |||
| 1676 | b_qs += 64; | ||
| 1677 | } | ||
| 1678 | // vacc += scale * (q8 @ q4) | ||
| 1679 | const __m512i vscale = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i *)(b_ptr + offset_scales + k_group * TILE_N))); | ||
| 1680 | acc = _mm512_add_epi32(acc, _mm512_mullo_epi32(vsum, vscale)); | ||
| 1681 | } | ||
| 1682 | const __m512 vd0 = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i *)(b_ptr + offset_d0))); | ||
| 1683 | vc[col] = _mm512_fmadd_ps(_mm512_cvtepi32_ps(acc), _mm512_mul_ps(vd0, vd1), vc[col]); | ||
| 1684 | |||
| 1685 | // step 2: accumulate the mins | ||
| 1686 | __m512i acc_m = _mm512_setzero_si512(); | ||
| 1687 | for (int k = 0; k < 4; ++k) { | ||
| 1688 | __m512i vmask = _mm512_set1_epi32(k); | ||
| 1689 | __m512i va = _mm512_permutexvar_epi32(vmask, va_bsum); | ||
| 1690 | __m512i vb = _mm512_cvtepi8_epi16(_mm256_loadu_si256((const __m256i *)(b_ptr + offset_mins + k * 32))); | ||
| 1691 | acc_m = _mm512_dpwssds_epi32(acc_m, va, vb); | ||
| 1692 | } | ||
| 1693 | const __m512 vdmin = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i *)(b_ptr + offset_dmin))); | ||
| 1694 | vc[col] = _mm512_fnmadd_ps(_mm512_cvtepi32_ps(acc_m), _mm512_mul_ps(vdmin, vd1), vc[col]); | ||
| 1695 | }; | ||
| 1696 | |||
| 1697 | for (int i = 0; i < KB; ++i) { | ||
| 1698 | Unroll<COLS>{}(compute, i); | ||
| 1699 | } | ||
| 1700 | |||
| 1701 | //store to C | ||
| 1702 | auto storec = [&](auto col) { | ||
| 1703 | _mm512_storeu_ps((__m512i*)(C + 0 * ldc + col * 16), vc[col]); | ||
| 1704 | }; | ||
| 1705 | Unroll<COLS>{}(storec); | ||
| 1706 | } | ||
| 1707 | }; | ||
| 1708 | |||
| 1709 | template <int BLOCK_M, int BLOCK_N, int BLOCK_K> | ||
| 1710 | struct tinygemm_kernel_vnni<block_q8_K, block_q5_K, float, BLOCK_M, BLOCK_N, BLOCK_K> { | ||
| 1711 | static void apply(int KB, const void * RESTRICT _A, const void * RESTRICT _B, float * RESTRICT C, int ldc) { | ||
| 1712 | |||
| 1713 | constexpr int COLS = BLOCK_N / 16; | ||
| 1714 | const int TILE_SIZE = TILE_N * sizeof(block_q5_K) + TILE_N * 4; | ||
| 1715 | |||
| 1716 | const block_q8_K * RESTRICT A = static_cast<const block_q8_K *>(_A); | ||
| 1717 | const char * RESTRICT B = static_cast<const char *>(_B); | ||
| 1718 | |||
| 1719 | // a.qs: 8 groups, 32 bytes each group (m256i) | ||
| 1720 | __m512i va[8]; | ||
| 1721 | // a.bsum: 8 groups, 2 bytes each group (m128i) | ||
| 1722 | __m512i va_bsum; | ||
| 1723 | __m512 vc[COLS]; | ||
| 1724 | __m512 vd1; | ||
| 1725 | |||
| 1726 | // packed_B: | ||
| 1727 | const int offset_qh = (QK_K / 2) * TILE_N; | ||
| 1728 | const int offset_scales = (QK_K / 2) * TILE_N + (QK_K / 8) * TILE_N; | ||
| 1729 | const int offset_mins = (QK_K / 2) * TILE_N + (QK_K / 8) * TILE_N + 8 * TILE_N; | ||
| 1730 | const int offset_d0 = (QK_K / 2) * TILE_N + (QK_K / 8) * TILE_N + 16 * TILE_N; | ||
| 1731 | const int offset_dmin = (QK_K / 2) * TILE_N + (QK_K / 8) * TILE_N + 16 * TILE_N + TILE_N * sizeof(ggml_half); | ||
| 1732 | |||
| 1733 | const __m512i lowMask = _mm512_set1_epi8(0xF); | ||
| 1734 | |||
| 1735 | auto loadc = [&](auto col) { | ||
| 1736 | vc[col] = _mm512_setzero_ps(); | ||
| 1737 | }; | ||
| 1738 | Unroll<COLS>{}(loadc); | ||
| 1739 | |||
| 1740 | // Q5_K and Q4_K share the same vnni format; refer to the notes above. | ||
| 1741 | auto compute = [&](auto col, auto i) { | ||
| 1742 | // load a | ||
| 1743 | if constexpr (col == 0) { | ||
| 1744 | for (int k_group = 0; k_group < QK_K / 32; ++k_group) { | ||
| 1745 | va[k_group] = _mm512_castsi256_si512(_mm256_loadu_si256((const __m256i *)(A[0 * KB + i].qs + k_group * 32))); | ||
| 1746 | } | ||
| 1747 | const __m256i q8sums = _mm256_loadu_si256((const __m256i *)A[0 * KB + i].bsums); | ||
| 1748 | const __m128i q8s = _mm_hadd_epi16(_mm256_extracti128_si256(q8sums, 0), _mm256_extracti128_si256(q8sums, 1)); | ||
| 1749 | va_bsum = _mm512_castsi128_si512(q8s); | ||
| 1750 | vd1 = _mm512_set1_ps(A[0 * KB + i].d); | ||
| 1751 | } | ||
| 1752 | |||
| 1753 | // step 1: accumulate the quants | ||
| 1754 | __m512i acc = _mm512_setzero_si512(); | ||
| 1755 | const char * b_ptr = B + PACKED_INDEX(col, i, KB, TILE_SIZE); | ||
| 1756 | const char * b_qs = b_ptr; | ||
| 1757 | const char * b_qh = b_ptr + offset_qh; | ||
| 1758 | for (int k_group = 0; k_group < QK_K / 32; ++k_group) { | ||
| 1759 | __m512i vsum = _mm512_setzero_si512(); | ||
| 1760 | __m512i hmask0 = _mm512_set1_epi8(0x1); | ||
| 1761 | __m512i hmask1 = _mm512_set1_epi8(0x2); | ||
| 1762 | __m512i hbits = _mm512_loadu_si512((const __m512i *)(b_qh + k_group * 64)); | ||
| 1763 | for (int k = 0; k < 8; k += 2) { | ||
| 1764 | __m512i va0 = _mm512_permutexvar_epi32(_mm512_set1_epi32(k + 0), va[k_group]); | ||
| 1765 | __m512i va1 = _mm512_permutexvar_epi32(_mm512_set1_epi32(k + 1), va[k_group]); | ||
| 1766 | |||
| 1767 | __m512i bytes = _mm512_loadu_si512((const __m512i *)b_qs); | ||
| 1768 | __m512i vb0 = _mm512_and_si512(bytes, lowMask); | ||
| 1769 | __m512i vb1 = _mm512_and_si512(_mm512_srli_epi16(bytes, 4), lowMask); | ||
| 1770 | |||
| 1771 | __m512i vh0 = _mm512_slli_epi16(_mm512_srli_epi16(_mm512_and_si512(hbits, hmask0), k), 4); | ||
| 1772 | __m512i vh1 = _mm512_slli_epi16(_mm512_srli_epi16(_mm512_and_si512(hbits, hmask1), k + 1), 4); | ||
| 1773 | |||
| 1774 | hmask0 = _mm512_slli_epi16(hmask0, 2); | ||
| 1775 | hmask1 = _mm512_slli_epi16(hmask1, 2); | ||
| 1776 | vb0 = _mm512_add_epi8(vb0, vh0); | ||
| 1777 | vb1 = _mm512_add_epi8(vb1, vh1); | ||
| 1778 | |||
| 1779 | vsum = _mm512_dpbusd_epi32(vsum, vb0, va0); | ||
| 1780 | vsum = _mm512_dpbusd_epi32(vsum, vb1, va1); | ||
| 1781 | |||
| 1782 | b_qs += 64; | ||
| 1783 | } | ||
| 1784 | // vacc += scale * (q8 @ q5) | ||
| 1785 | const __m512i vscale = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i *)(b_ptr + offset_scales + k_group * TILE_N))); | ||
| 1786 | acc = _mm512_add_epi32(acc, _mm512_mullo_epi32(vsum, vscale)); | ||
| 1787 | } | ||
| 1788 | const __m512 vd0 = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i *)(b_ptr + offset_d0))); | ||
| 1789 | vc[col] = _mm512_fmadd_ps(_mm512_cvtepi32_ps(acc), _mm512_mul_ps(vd0, vd1), vc[col]); | ||
| 1790 | |||
| 1791 | // step 2: accumulate the mins | ||
| 1792 | __m512i acc_m = _mm512_setzero_si512(); | ||
| 1793 | for (int k = 0; k < 4; ++k) { | ||
| 1794 | __m512i vmask = _mm512_set1_epi32(k); | ||
| 1795 | __m512i va = _mm512_permutexvar_epi32(vmask, va_bsum); | ||
| 1796 | __m512i vb = _mm512_cvtepi8_epi16(_mm256_loadu_si256((const __m256i *)(b_ptr + offset_mins + k * 32))); | ||
| 1797 | acc_m = _mm512_dpwssds_epi32(acc_m, va, vb); | ||
| 1798 | } | ||
| 1799 | const __m512 vdmin = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i *)(b_ptr + offset_dmin))); | ||
| 1800 | vc[col] = _mm512_fnmadd_ps(_mm512_cvtepi32_ps(acc_m), _mm512_mul_ps(vdmin, vd1), vc[col]); | ||
| 1801 | }; | ||
| 1802 | |||
| 1803 | for (int i = 0; i < KB; ++i) { | ||
| 1804 | Unroll<COLS>{}(compute, i); | ||
| 1805 | } | ||
| 1806 | |||
| 1807 | //store to C | ||
| 1808 | auto storec = [&](auto col) { | ||
| 1809 | _mm512_storeu_ps((__m512i*)(C + 0 * ldc + col * 16), vc[col]); | ||
| 1810 | }; | ||
| 1811 | Unroll<COLS>{}(storec); | ||
| 1812 | } | ||
| 1813 | }; | ||
| 1814 | |||
| 1815 | template <int BLOCK_M, int BLOCK_N, int BLOCK_K> | ||
| 1816 | struct tinygemm_kernel_vnni<block_q8_K, block_q6_K, float, BLOCK_M, BLOCK_N, BLOCK_K> { | ||
| 1817 | static void apply(int KB, const void * RESTRICT _A, const void * RESTRICT _B, float * RESTRICT C, int ldc) { | ||
| 1818 | |||
| 1819 | constexpr int COLS = BLOCK_N / 16; | ||
| 1820 | const int TILE_SIZE = TILE_N * sizeof(block_q6_K); | ||
| 1821 | |||
| 1822 | const block_q8_K * RESTRICT A = static_cast<const block_q8_K *>(_A); | ||
| 1823 | const char * RESTRICT B = static_cast<const char *>(_B); | ||
| 1824 | |||
| 1825 | // load the 256 bytes from A to 4 avx512 vectors | ||
| 1826 | __m512i va[4]; | ||
| 1827 | __m512 vc[COLS]; | ||
| 1828 | __m512 vd1; | ||
| 1829 | |||
| 1830 | // packed_B: | ||
| 1831 | const int offset_qh = (QK_K / 2) * TILE_N; | ||
| 1832 | const int offset_scales = (QK_K / 2) * TILE_N + (QK_K / 4) * TILE_N; | ||
| 1833 | const int offset_d0 = (QK_K / 2) * TILE_N + (QK_K / 4) * TILE_N + 16 * TILE_N; | ||
| 1834 | |||
| 1835 | // compensation | ||
| 1836 | __m512i vcomp; | ||
| 1837 | |||
| 1838 | const __m512i m32s = _mm512_set1_epi32(32); | ||
| 1839 | const __m512i lowMask = _mm512_set1_epi8(0xF); | ||
| 1840 | |||
| 1841 | auto loadc = [&](auto col) { | ||
| 1842 | vc[col] = _mm512_setzero_ps(); | ||
| 1843 | }; | ||
| 1844 | Unroll<COLS>{}(loadc); | ||
| 1845 | |||
| 1846 | auto compute = [&](auto col, auto i) { | ||
| 1847 | if constexpr (col == 0) { | ||
| 1848 | // load a | ||
| 1849 | va[0] = _mm512_loadu_si512((const __m512i *)(A[0 * KB + i].qs + 0)); | ||
| 1850 | va[1] = _mm512_loadu_si512((const __m512i *)(A[0 * KB + i].qs + 64)); | ||
| 1851 | va[2] = _mm512_loadu_si512((const __m512i *)(A[0 * KB + i].qs + 128)); | ||
| 1852 | va[3] = _mm512_loadu_si512((const __m512i *)(A[0 * KB + i].qs + 192)); | ||
| 1853 | |||
| 1854 | const __m256i q8sums = _mm256_loadu_si256((const __m256i *)A[0 * KB + i].bsums); | ||
| 1855 | vcomp = _mm512_mullo_epi32(_mm512_cvtepi16_epi32(q8sums), m32s); | ||
| 1856 | vd1 = _mm512_set1_ps(A[0 * KB + i].d); | ||
| 1857 | } | ||
| 1858 | |||
| 1859 | // accumulate the quants | ||
| 1860 | __m512i acc = _mm512_setzero_si512(); | ||
| 1861 | const char * b_ptr = B + PACKED_INDEX(col, i, KB, TILE_SIZE); | ||
| 1862 | const char * b_qs = b_ptr; | ||
| 1863 | const char * b_qh = b_ptr + offset_qh; | ||
| 1864 | int mask = 0; | ||
| 1865 | for (int k_group = 0; k_group < QK_K / 16; ++k_group) { | ||
| 1866 | int r = k_group >> 2; | ||
| 1867 | __m512i va0 = _mm512_permutexvar_epi32(_mm512_set1_epi32(mask++), va[r]); | ||
| 1868 | __m512i va1 = _mm512_permutexvar_epi32(_mm512_set1_epi32(mask++), va[r]); | ||
| 1869 | |||
| 1870 | __m512i vsum = _mm512_setzero_si512(); | ||
| 1871 | __m512i hmask = _mm512_set1_epi8(0x3); | ||
| 1872 | |||
| 1873 | __m512i bytes = _mm512_loadu_si512(b_qs); | ||
| 1874 | __m512i hbits = _mm512_loadu_si512(b_qh); | ||
| 1875 | __m512i vb0 = _mm512_and_si512(bytes, lowMask); | ||
| 1876 | __m512i vb1 = _mm512_and_si512(_mm512_srli_epi16(bytes, 4), lowMask); | ||
| 1877 | __m512i vh0 = _mm512_slli_epi16(_mm512_and_si512(hbits, hmask), 4); | ||
| 1878 | __m512i vh1 = _mm512_slli_epi16(_mm512_and_si512(hbits, _mm512_slli_epi16(hmask, 2)), 2); | ||
| 1879 | |||
| 1880 | vb0 = _mm512_add_epi8(vb0, vh0); | ||
| 1881 | vb1 = _mm512_add_epi8(vb1, vh1); | ||
| 1882 | vsum = _mm512_dpbusd_epi32(vsum, vb0, va0); | ||
| 1883 | vsum = _mm512_dpbusd_epi32(vsum, vb1, va1); | ||
| 1884 | b_qs += 64; | ||
| 1885 | |||
| 1886 | va0 = _mm512_permutexvar_epi32(_mm512_set1_epi32(mask++), va[r]); | ||
| 1887 | va1 = _mm512_permutexvar_epi32(_mm512_set1_epi32(mask++), va[r]); | ||
| 1888 | |||
| 1889 | bytes = _mm512_loadu_si512(b_qs); | ||
| 1890 | vb0 = _mm512_and_si512(bytes, lowMask); | ||
| 1891 | vb1 = _mm512_and_si512(_mm512_srli_epi16(bytes, 4), lowMask); | ||
| 1892 | vh0 = _mm512_and_si512(hbits, _mm512_slli_epi16(hmask, 4)); | ||
| 1893 | vh1 = _mm512_srli_epi16(_mm512_and_si512(hbits, _mm512_slli_epi16(hmask, 6)), 2); | ||
| 1894 | vb0 = _mm512_add_epi8(vb0, vh0); | ||
| 1895 | vb1 = _mm512_add_epi8(vb1, vh1); | ||
| 1896 | vsum = _mm512_dpbusd_epi32(vsum, vb0, va0); | ||
| 1897 | vsum = _mm512_dpbusd_epi32(vsum, vb1, va1); | ||
| 1898 | b_qs += 64; | ||
| 1899 | b_qh += 64; | ||
| 1900 | |||
| 1901 | // B * A - 32 * A | ||
| 1902 | __m512i vmask = _mm512_set1_epi32(k_group); | ||
| 1903 | vsum = _mm512_sub_epi32(vsum, _mm512_permutexvar_epi32(vmask, vcomp)); | ||
| 1904 | |||
| 1905 | // vacc += scale * (q8 @ q6) | ||
| 1906 | const __m512i vscale = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i *)(b_ptr + offset_scales + k_group * TILE_N))); | ||
| 1907 | acc = _mm512_add_epi32(acc, _mm512_mullo_epi32(vsum, vscale)); | ||
| 1908 | } | ||
| 1909 | const __m512 vd0 = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i *)(b_ptr + offset_d0))); | ||
| 1910 | vc[col] = _mm512_fmadd_ps(_mm512_cvtepi32_ps(acc), _mm512_mul_ps(vd0, vd1), vc[col]); | ||
| 1911 | }; | ||
| 1912 | |||
| 1913 | for (int i = 0; i < KB; ++i) { | ||
| 1914 | Unroll<COLS>{}(compute, i); | ||
| 1915 | } | ||
| 1916 | |||
| 1917 | //store to C | ||
| 1918 | auto storec = [&](auto col) { | ||
| 1919 | _mm512_storeu_ps((__m512i*)(C + 0 * ldc + col * 16), vc[col]); | ||
| 1920 | }; | ||
| 1921 | Unroll<COLS>{}(storec); | ||
| 1922 | } | ||
| 1923 | }; | ||
| 1924 | |||
| 1925 | template <int BLOCK_M, int BLOCK_N, int BLOCK_K> | ||
| 1926 | struct tinygemm_kernel_vnni<block_q8_K, block_iq4_xs, float, BLOCK_M, BLOCK_N, BLOCK_K> { | ||
| 1927 | static void apply(int KB, const void * RESTRICT _A, const void * RESTRICT _B, float * RESTRICT C, int ldc) { | ||
| 1928 | |||
| 1929 | constexpr int COLS = BLOCK_N / 16; | ||
| 1930 | const int TILE_SIZE = TILE_N * sizeof(block_iq4_xs) + TILE_N * 2; | ||
| 1931 | |||
| 1932 | const block_q8_K * RESTRICT A = static_cast<const block_q8_K *>(_A); | ||
| 1933 | const char * RESTRICT B = static_cast<const char *>(_B); | ||
| 1934 | |||
| 1935 | // load the 256 bytes from A to 4 avx512 vectors | ||
| 1936 | __m512i va[4]; | ||
| 1937 | __m512 vc[COLS]; | ||
| 1938 | __m512 vd1; | ||
| 1939 | |||
| 1940 | // packed_B: | ||
| 1941 | const int offset_scales = (QK_K / 2) * TILE_N; | ||
| 1942 | const int offset_d0 = (QK_K / 2) * TILE_N + 8 * TILE_N; | ||
| 1943 | |||
| 1944 | // compensation | ||
| 1945 | __m512i vcomp; | ||
| 1946 | |||
| 1947 | const __m256i m128s = _mm256_set1_epi16(128); | ||
| 1948 | const __m512i lowMask = _mm512_set1_epi8(0xF); | ||
| 1949 | |||
| 1950 | const __m512i values128 = _mm512_set_epi8( | ||
| 1951 | 113, 89, 69, 53, 38, 25, 13, 1, -10, -22, -35, -49, -65, -83, -104, -127, | ||
| 1952 | 113, 89, 69, 53, 38, 25, 13, 1, -10, -22, -35, -49, -65, -83, -104, -127, | ||
| 1953 | 113, 89, 69, 53, 38, 25, 13, 1, -10, -22, -35, -49, -65, -83, -104, -127, | ||
| 1954 | 113, 89, 69, 53, 38, 25, 13, 1, -10, -22, -35, -49, -65, -83, -104, -127 | ||
| 1955 | ); | ||
| 1956 | const __m512i off = _mm512_set1_epi8(static_cast<char>(0x80)); | ||
| 1957 | const __m512i values256 = _mm512_add_epi8(values128, off); | ||
| 1958 | |||
| 1959 | auto loadc = [&](auto col) { | ||
| 1960 | vc[col] = _mm512_setzero_ps(); | ||
| 1961 | }; | ||
| 1962 | Unroll<COLS>{}(loadc); | ||
| 1963 | |||
| 1964 | auto compute = [&](auto col, auto i) { | ||
| 1965 | if constexpr (col == 0) { | ||
| 1966 | // load a | ||
| 1967 | va[0] = _mm512_loadu_si512((const __m512i *)(A[0 * KB + i].qs + 0)); | ||
| 1968 | va[1] = _mm512_loadu_si512((const __m512i *)(A[0 * KB + i].qs + 64)); | ||
| 1969 | va[2] = _mm512_loadu_si512((const __m512i *)(A[0 * KB + i].qs + 128)); | ||
| 1970 | va[3] = _mm512_loadu_si512((const __m512i *)(A[0 * KB + i].qs + 192)); | ||
| 1971 | |||
| 1972 | // compensation: 128 * A | ||
| 1973 | const __m256i q8sums = _mm256_loadu_si256((const __m256i *)A[0 * KB + i].bsums); | ||
| 1974 | vcomp = _mm512_castsi256_si512(_mm256_madd_epi16(q8sums, m128s)); | ||
| 1975 | vd1 = _mm512_set1_ps(A[0 * KB + i].d); | ||
| 1976 | } | ||
| 1977 | |||
| 1978 | // accumulate the quants | ||
| 1979 | __m512i acc = _mm512_setzero_si512(); | ||
| 1980 | const char * b_ptr = B + PACKED_INDEX(col, i, KB, TILE_SIZE); | ||
| 1981 | const char * b_qs = b_ptr; | ||
| 1982 | int mask = 0; | ||
| 1983 | for (int k_group = 0; k_group < QK_K / 32; ++k_group) { | ||
| 1984 | int r = k_group >> 1; | ||
| 1985 | __m512i vmask = _mm512_set1_epi32(k_group); | ||
| 1986 | __m512i vsum = _mm512_setzero_si512(); | ||
| 1987 | for (int k = 0; k < 8; k += 2) { | ||
| 1988 | __m512i va0 = _mm512_permutexvar_epi32(_mm512_set1_epi32(mask++), va[r]); | ||
| 1989 | __m512i va1 = _mm512_permutexvar_epi32(_mm512_set1_epi32(mask++), va[r]); | ||
| 1990 | |||
| 1991 | __m512i bytes = _mm512_loadu_si512(b_qs); | ||
| 1992 | __m512i vb0 = _mm512_shuffle_epi8(values256, _mm512_and_si512(bytes, lowMask)); | ||
| 1993 | __m512i vb1 = _mm512_shuffle_epi8(values256, _mm512_and_si512(_mm512_srli_epi16(bytes, 4), lowMask)); | ||
| 1994 | |||
| 1995 | vsum = _mm512_dpbusd_epi32(vsum, vb0, va0); | ||
| 1996 | vsum = _mm512_dpbusd_epi32(vsum, vb1, va1); | ||
| 1997 | b_qs += 64; | ||
| 1998 | } | ||
| 1999 | // (B + 128) * A - 128 * A | ||
| 2000 | vsum = _mm512_sub_epi32(vsum, _mm512_permutexvar_epi32(vmask, vcomp)); | ||
| 2001 | |||
| 2002 | // vacc += scale * (q8 @ q4) | ||
| 2003 | const __m512i vscale = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i *)(b_ptr + offset_scales + k_group * TILE_N))); | ||
| 2004 | acc = _mm512_add_epi32(acc, _mm512_mullo_epi32(vsum, vscale)); | ||
| 2005 | } | ||
| 2006 | const __m512 vd0 = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i *)(b_ptr + offset_d0))); | ||
| 2007 | vc[col] = _mm512_fmadd_ps(_mm512_cvtepi32_ps(acc), _mm512_mul_ps(vd0, vd1), vc[col]); | ||
| 2008 | }; | ||
| 2009 | |||
| 2010 | for (int i = 0; i < KB; ++i) { | ||
| 2011 | Unroll<COLS>{}(compute, i); | ||
| 2012 | } | ||
| 2013 | |||
| 2014 | //store to C | ||
| 2015 | auto storec = [&](auto col) { | ||
| 2016 | _mm512_storeu_ps((__m512i*)(C + 0 * ldc + col * 16), vc[col]); | ||
| 2017 | }; | ||
| 2018 | Unroll<COLS>{}(storec); | ||
| 2019 | } | ||
| 2020 | }; | ||
| 2021 | |||
| 2022 | #define LAUNCH_TINYGEMM_KERNEL_VNNI(NB_SIZE) \ | ||
| 2023 | tinygemm_kernel_vnni<vec_dot_type, type, float, 1, NB_SIZE, blck_size>::apply( \ | ||
| 2024 | KB, (const char *)wdata + 0 * row_size_A, \ | ||
| 2025 | (const char *)src0->data + PACKED_INDEX(nb * kTilesN, 0, KB, TILE_SIZE), \ | ||
| 2026 | (float *) dst->data + 0 * N + nb_start, ldc) | ||
| 2027 | |||
| 2028 | template <typename TA, typename TB, typename TC, int BLOCK_K, | ||
| 2029 | typename std::enable_if<!is_type_qkk<TB>::value, int>::type = 0> | ||
| 2030 | void tinygemm_kernel_amx(int M, int N, int KB, const void * RESTRICT _A, const void * RESTRICT _B, TC * RESTRICT C, int ldc) { | ||
| 2031 | using packed_B_t = packed_B_type<TB>; | ||
| 2032 | const int TILE_SIZE = get_tile_size<TB>(); | ||
| 2033 | const bool need_unpack = do_unpack<TB>::value; | ||
| 2034 | |||
| 2035 | GGML_ASSERT(M <= 2 * TILE_M && N == 2 * TILE_N); | ||
| 2036 | const TA * RESTRICT A = static_cast<const TA *>(_A); | ||
| 2037 | const char * RESTRICT B = static_cast<const char *>(_B); | ||
| 2038 | |||
| 2039 | const int m0 = std::min(M, TILE_M); | ||
| 2040 | const int m1 = std::max(M - TILE_M, 0); | ||
| 2041 | const int lda = KB * sizeof(TA); | ||
| 2042 | //const int ldb = KB * sizeof(TB); | ||
| 2043 | |||
| 2044 | static thread_local packed_B_t Tile0[TILE_N * TILE_K]; | ||
| 2045 | static thread_local packed_B_t Tile1[TILE_N * TILE_K]; | ||
| 2046 | static thread_local int8_t Tile23[TILE_M * TILE_K]; | ||
| 2047 | |||
| 2048 | static thread_local int32_t TileC0[TILE_M * TILE_N * 4]; | ||
| 2049 | static thread_local int32_t TileC1[TILE_M * TILE_N * 4]; | ||
| 2050 | |||
| 2051 | // double buffering C to interleave avx512 and amx | ||
| 2052 | int32_t * C_cur = TileC0; | ||
| 2053 | int32_t * C_pre = TileC1; | ||
| 2054 | |||
| 2055 | auto Tile4 = [&](int32_t * base) { return base; }; | ||
| 2056 | auto Tile5 = [&](int32_t * base) { return base + TILE_M * TILE_N; }; | ||
| 2057 | auto Tile6 = [&](int32_t * base) { return base + 2 * TILE_M * TILE_N; }; | ||
| 2058 | auto Tile7 = [&](int32_t * base) { return base + 3 * TILE_M * TILE_N; }; | ||
| 2059 | |||
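| | // The loop below software-pipelines the two engines: while AMX (_tile_dpbssd) fills | ||
| | // the current int32 tiles in C_cur, the AVX-512 acc_C path rescales the previous | ||
| | // iteration's tiles from C_pre into fp32 C; the two buffers are swapped each iteration. | ||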
| 2060 | if (M == 2 * TILE_M) { | ||
| 2061 | // i = 0 | ||
| 2062 | const char * B_blk0 = B + PACKED_INDEX(0, 0, KB, TILE_SIZE); | ||
| 2063 | const char * B_blk1 = B + PACKED_INDEX(1, 0, KB, TILE_SIZE); | ||
| 2064 | if (need_unpack) { | ||
| 2065 | unpack_B<TB>(Tile0, B_blk0); | ||
| 2066 | _tile_loadd(TMM0, Tile0, TILE_N * VNNI_BLK); | ||
| 2067 | } else { | ||
| 2068 | _tile_loadd(TMM0, B_blk0, TILE_N * VNNI_BLK); | ||
| 2069 | } | ||
| 2070 | |||
| 2071 | _tile_zero(TMM4); | ||
| 2072 | _tile_loadd(TMM2, A[0].qs, lda); | ||
| 2073 | _tile_dpbssd(TMM4, TMM2, TMM0); | ||
| 2074 | _tile_stored(TMM4, Tile4(C_pre), TILE_N * sizeof(int32_t)); | ||
| 2075 | |||
| 2076 | _tile_zero(TMM5); | ||
| 2077 | _tile_loadd(TMM3, A[TILE_M * KB + 0].qs, lda); | ||
| 2078 | _tile_dpbssd(TMM5, TMM3, TMM0); | ||
| 2079 | _tile_stored(TMM5, Tile5(C_pre), TILE_N * sizeof(int32_t)); | ||
| 2080 | |||
| 2081 | if (need_unpack) { | ||
| 2082 | unpack_B<TB>(Tile1, B_blk1); | ||
| 2083 | _tile_loadd(TMM1, Tile1, TILE_N * VNNI_BLK); | ||
| 2084 | } else { | ||
| 2085 | _tile_loadd(TMM1, B_blk1, TILE_N * VNNI_BLK); | ||
| 2086 | } | ||
| 2087 | |||
| 2088 | _tile_zero(TMM6); | ||
| 2089 | _tile_dpbssd(TMM6, TMM2, TMM1); | ||
| 2090 | _tile_stored(TMM6, Tile6(C_pre), TILE_N * sizeof(int32_t)); | ||
| 2091 | |||
| 2092 | _tile_zero(TMM7); | ||
| 2093 | _tile_dpbssd(TMM7, TMM3, TMM1); | ||
| 2094 | _tile_stored(TMM7, Tile7(C_pre), TILE_N * sizeof(int32_t)); | ||
| 2095 | |||
| 2096 | for (int i = 1; i < KB; ++i) { | ||
| 2097 | // index of previous iter | ||
| 2098 | const int ii = i - 1; | ||
| 2099 | const char * B_blk0 = B + PACKED_INDEX(0, i, KB, TILE_SIZE); | ||
| 2100 | const char * B_blk1 = B + PACKED_INDEX(1, i, KB, TILE_SIZE); | ||
| 2101 | GGML_DISPATCH_BOOL(ii > 0, is_acc, [&] { | ||
| 2102 | if (need_unpack) { | ||
| 2103 | unpack_B<TB>(Tile0, B_blk0); | ||
| 2104 | _tile_loadd(TMM0, Tile0, TILE_N * VNNI_BLK); | ||
| 2105 | } else { | ||
| 2106 | _tile_loadd(TMM0, B_blk0, TILE_N * VNNI_BLK); | ||
| 2107 | } | ||
| 2108 | _tile_zero(TMM4); | ||
| 2109 | _tile_loadd(TMM2, A[i].qs, lda); | ||
| 2110 | acc_C<TA, TB, is_acc>::apply(C, ldc, Tile4(C_pre), &A[ii], KB, B + PACKED_INDEX(0, ii, KB, TILE_SIZE), TILE_M); | ||
| 2111 | |||
| 2112 | _tile_dpbssd(TMM4, TMM2, TMM0); | ||
| 2113 | _tile_stored(TMM4, Tile4(C_cur), TILE_N * sizeof(int32_t)); | ||
| 2114 | |||
| 2115 | _tile_zero(TMM5); | ||
| 2116 | _tile_loadd(TMM3, A[TILE_M * KB + i].qs, lda); | ||
| 2117 | acc_C<TA, TB, is_acc>::apply(C + TILE_M * ldc, ldc, Tile5(C_pre), &A[TILE_M * KB + ii], KB, B + PACKED_INDEX(0, ii, KB, TILE_SIZE), TILE_M); | ||
| 2118 | |||
| 2119 | _tile_dpbssd(TMM5, TMM3, TMM0); | ||
| 2120 | _tile_stored(TMM5, Tile5(C_cur), TILE_N * sizeof(int32_t)); | ||
| 2121 | |||
| 2122 | if (need_unpack) { | ||
| 2123 | unpack_B<TB>(Tile1, B_blk1); | ||
| 2124 | _tile_loadd(TMM1, Tile1, TILE_N * VNNI_BLK); | ||
| 2125 | } else { | ||
| 2126 | _tile_loadd(TMM1, B_blk1, TILE_N * VNNI_BLK); | ||
| 2127 | } | ||
| 2128 | _tile_zero(TMM6); | ||
| 2129 | acc_C<TA, TB, is_acc>::apply(C + TILE_N, ldc, Tile6(C_pre), &A[ii], KB, B + PACKED_INDEX(1, ii, KB, TILE_SIZE), TILE_M); | ||
| 2130 | |||
| 2131 | _tile_dpbssd(TMM6, TMM2, TMM1); | ||
| 2132 | _tile_stored(TMM6, Tile6(C_cur), TILE_N * sizeof(int32_t)); | ||
| 2133 | |||
| 2134 | _tile_zero(TMM7); | ||
| 2135 | acc_C<TA, TB, is_acc>::apply(C + TILE_M * ldc + TILE_N, ldc, Tile7(C_pre), &A[TILE_M * KB + ii], KB, B + PACKED_INDEX(1, ii, KB, TILE_SIZE), TILE_M); | ||
| 2136 | |||
| 2137 | _tile_dpbssd(TMM7, TMM3, TMM1); | ||
| 2138 | _tile_stored(TMM7, Tile7(C_cur), TILE_N * sizeof(int32_t)); | ||
| 2139 | |||
| 2140 | std::swap(C_cur, C_pre); | ||
| 2141 | }); | ||
| 2142 | } | ||
| 2143 | // final accumulation | ||
| 2144 | { | ||
| 2145 | int ii = KB - 1; | ||
| 2146 | acc_C<TA, TB, true>::apply(C, ldc, Tile4(C_pre), &A[ii], KB, B + PACKED_INDEX(0, ii, KB, TILE_SIZE), TILE_M); | ||
| 2147 | acc_C<TA, TB, true>::apply(C + TILE_M * ldc, ldc, Tile5(C_pre), &A[TILE_M * KB + ii], KB, B + PACKED_INDEX(0, ii, KB, TILE_SIZE), TILE_M); | ||
| 2148 | acc_C<TA, TB, true>::apply(C + TILE_N, ldc, Tile6(C_pre), &A[ii], KB, B + PACKED_INDEX(1, ii, KB, TILE_SIZE), TILE_M); | ||
| 2149 | acc_C<TA, TB, true>::apply(C + TILE_M * ldc + TILE_N, ldc, Tile7(C_pre), &A[TILE_M * KB + ii], KB, B + PACKED_INDEX(1, ii, KB, TILE_SIZE), TILE_M); | ||
| 2150 | } | ||
| 2151 | } else { | ||
| 2152 | for (int i = 0; i < KB; ++i) { | ||
| 2153 | _tile_zero(TMM4); | ||
| 2154 | _tile_zero(TMM6); | ||
| 2155 | if (m1 != 0) { | ||
| 2156 | _tile_zero(TMM5); | ||
| 2157 | _tile_zero(TMM7); | ||
| 2158 | } | ||
| 2159 | |||
| 2160 | const char * B_blk0 = B + PACKED_INDEX(0, i, KB, TILE_SIZE); | ||
| 2161 | const char * B_blk1 = B + PACKED_INDEX(1, i, KB, TILE_SIZE); | ||
| 2162 | if (need_unpack) { | ||
| 2163 | unpack_B<TB>(Tile0, B_blk0); | ||
| 2164 | _tile_loadd(TMM0, Tile0, TILE_N * VNNI_BLK); | ||
| 2165 | } else { | ||
| 2166 | _tile_loadd(TMM0, B_blk0, TILE_N * VNNI_BLK); | ||
| 2167 | } | ||
| 2168 | |||
| 2169 | if (need_unpack) { | ||
| 2170 | unpack_B<TB>(Tile1, B_blk1); | ||
| 2171 | _tile_loadd(TMM1, Tile1, TILE_N * VNNI_BLK); | ||
| 2172 | } else { | ||
| 2173 | _tile_loadd(TMM1, B_blk1, TILE_N * VNNI_BLK); | ||
| 2174 | } | ||
| 2175 | |||
| 2176 | if (m0 == TILE_M) { | ||
| 2177 | _tile_loadd(TMM2, A[i].qs, lda); | ||
| 2178 | } else { | ||
| 2179 | unpack_A(Tile23, &A[i], KB, m0); | ||
| 2180 | _tile_loadd(TMM2, Tile23, TILE_K); | ||
| 2181 | } | ||
| 2182 | |||
| 2183 | _tile_dpbssd(TMM4, TMM2, TMM0); | ||
| 2184 | _tile_dpbssd(TMM6, TMM2, TMM1); | ||
| 2185 | |||
| 2186 | _tile_stored(TMM4, Tile4(C_cur), TILE_N * sizeof(int32_t)); | ||
| 2187 | _tile_stored(TMM6, Tile6(C_cur), TILE_N * sizeof(int32_t)); | ||
| 2188 | |||
| 2189 | GGML_DISPATCH_BOOL(i > 0, is_acc, [&] { | ||
| 2190 | acc_C<TA, TB, is_acc>::apply(C, ldc, Tile4(C_cur), &A[i], KB, B + PACKED_INDEX(0, i, KB, TILE_SIZE), m0); | ||
| 2191 | acc_C<TA, TB, is_acc>::apply(C + TILE_N, ldc, Tile6(C_cur), &A[i], KB, B + PACKED_INDEX(1, i, KB, TILE_SIZE), m0); | ||
| 2192 | }); | ||
| 2193 | |||
| 2194 | if (m1 != 0) { | ||
| 2195 | unpack_A(Tile23, &A[TILE_M * KB + i], KB, m1); | ||
| 2196 | _tile_loadd(TMM3, Tile23, TILE_K); | ||
| 2197 | |||
| 2198 | _tile_dpbssd(TMM5, TMM3, TMM0); | ||
| 2199 | _tile_dpbssd(TMM7, TMM3, TMM1); | ||
| 2200 | _tile_stored(TMM5, Tile5(C_cur), TILE_N * sizeof(int32_t)); | ||
| 2201 | _tile_stored(TMM7, Tile7(C_cur), TILE_N * sizeof(int32_t)); | ||
| 2202 | GGML_DISPATCH_BOOL(i > 0, is_acc, [&] { | ||
| 2203 | acc_C<TA, TB, is_acc>::apply(C + TILE_M * ldc, ldc, Tile5(C_cur), &A[TILE_M * KB + i], KB, B + PACKED_INDEX(0, i, KB, TILE_SIZE), m1); | ||
| 2204 | acc_C<TA, TB, is_acc>::apply(C + TILE_M * ldc + TILE_N, ldc, Tile7(C_cur), &A[TILE_M * KB + i], KB, B + PACKED_INDEX(1, i, KB, TILE_SIZE), m1); | ||
| 2205 | }); | ||
| 2206 | } | ||
| 2207 | } | ||
| 2208 | } | ||
| 2209 | return; | ||
| 2210 | } | ||
| 2211 | |||
| 2212 | template <typename TA, typename TB, typename TC, int BLOCK_K, | ||
| 2213 | typename std::enable_if<is_type_qkk<TB>::value, int>::type = 0> | ||
| 2214 | void tinygemm_kernel_amx(int M, int N, int KB, const void * RESTRICT _A, const void * RESTRICT _B, float * RESTRICT C, int ldc) { | ||
| 2215 | static_assert(std::is_same<TA, block_q8_K>::value); | ||
| 2216 | const int TILE_SIZE = get_tile_size<TB>(); | ||
| 2217 | |||
| 2218 | GGML_ASSERT(M <= 2 * TILE_M && N == 2 * TILE_N); | ||
| 2219 | const TA * RESTRICT A = static_cast<const TA *>(_A); | ||
| 2220 | const char * RESTRICT B = static_cast<const char *>(_B); | ||
| 2221 | |||
| 2222 | const int m0 = std::min(M, TILE_M); | ||
| 2223 | const int m1 = std::max(M - TILE_M, 0); | ||
| 2224 | //const int lda = KB * sizeof(TA); | ||
| 2225 | |||
| 2226 | static thread_local int8_t Tile0[TILE_N * TILE_K]; | ||
| 2227 | static thread_local int8_t Tile1[TILE_N * TILE_K]; | ||
| 2228 | static thread_local int8_t Tile23[TILE_M * TILE_K]; | ||
| 2229 | |||
| 2230 | // mat mul result for each group | ||
| 2231 | static thread_local int32_t Tile4[TILE_M * TILE_N]; | ||
| 2232 | static thread_local int32_t Tile5[TILE_M * TILE_N]; | ||
| 2233 | static thread_local int32_t Tile6[TILE_M * TILE_N]; | ||
| 2234 | static thread_local int32_t Tile7[TILE_M * TILE_N]; | ||
| 2235 | |||
| 2236 | // running per-block sums, accumulated over QK_K / k_group_size groups, int32 | ||
| 2237 | static thread_local int32_t Sumi4[TILE_M * TILE_N]; | ||
| 2238 | static thread_local int32_t Sumi5[TILE_M * TILE_N]; | ||
| 2239 | static thread_local int32_t Sumi6[TILE_M * TILE_N]; | ||
| 2240 | static thread_local int32_t Sumi7[TILE_M * TILE_N]; | ||
| 2241 | |||
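// Editor's note (not part of the upstream file): K-quant super-blocks carry per-group
// scales, so one AMX dot product is issued per group of k_group_size quants; scale_C
// applies the group scale to the raw int32 tile results and accumulates them into the
// Sumi* buffers, and acc_C then folds in the per-block factors (the block scale d and,
// for types with minimums, dmin combined with the bsums of A) when converting to float.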
| 2242 | const int k_group_size = std::is_same<TB, block_q6_K>::value ? 16 : 32; | ||
| 2243 | for (int i = 0; i < KB; ++i) { | ||
| 2244 | // step 1: accumulate the quants group by group (QK_K / k_group_size groups per block) | ||
| 2245 | for (int k = 0; k < QK_K / k_group_size; ++k) { | ||
| 2246 | GGML_DISPATCH_BOOL(k > 0, is_acc, [&] { | ||
| 2247 | _tile_zero(TMM4); | ||
| 2248 | _tile_zero(TMM6); | ||
| 2249 | |||
| 2250 | unpack_B<TB>(Tile0, B + PACKED_INDEX(0, i, KB, TILE_SIZE), k); | ||
| 2251 | _tile_loadd(TMM0, Tile0, TILE_N * VNNI_BLK); | ||
| 2252 | |||
| 2253 | unpack_B<TB>(Tile1, B + PACKED_INDEX(1, i, KB, TILE_SIZE), k); | ||
| 2254 | _tile_loadd(TMM1, Tile1, TILE_N * VNNI_BLK); | ||
| 2255 | |||
| 2256 | unpack_A<TB>(Tile23, &A[i], KB, k, m0); | ||
| 2257 | _tile_loadd(TMM2, Tile23, TILE_K); | ||
| 2258 | |||
| 2259 | _tile_dpbssd(TMM4, TMM2, TMM0); | ||
| 2260 | _tile_dpbssd(TMM6, TMM2, TMM1); | ||
| 2261 | |||
| 2262 | _tile_stored(TMM4, Tile4, TILE_N * sizeof(int32_t)); | ||
| 2263 | _tile_stored(TMM6, Tile6, TILE_N * sizeof(int32_t)); | ||
| 2264 | |||
| 2265 | scale_C<TB, is_acc>(Tile4, Sumi4, B + PACKED_INDEX(0, i, KB, TILE_SIZE), k, m0); | ||
| 2266 | scale_C<TB, is_acc>(Tile6, Sumi6, B + PACKED_INDEX(1, i, KB, TILE_SIZE), k, m0); | ||
| 2267 | |||
| 2268 | if (m1 != 0) { | ||
| 2269 | _tile_zero(TMM5); | ||
| 2270 | _tile_zero(TMM7); | ||
| 2271 | |||
| 2272 | unpack_A<TB>(Tile23, &A[TILE_M * KB + i], KB, k, m1); | ||
| 2273 | _tile_loadd(TMM3, Tile23, TILE_K); | ||
| 2274 | |||
| 2275 | _tile_dpbssd(TMM5, TMM3, TMM0); | ||
| 2276 | _tile_dpbssd(TMM7, TMM3, TMM1); | ||
| 2277 | |||
| 2278 | _tile_stored(TMM5, Tile5, TILE_N * sizeof(int32_t)); | ||
| 2279 | _tile_stored(TMM7, Tile7, TILE_N * sizeof(int32_t)); | ||
| 2280 | |||
| 2281 | scale_C<TB, is_acc>(Tile5, Sumi5, B + PACKED_INDEX(0, i, KB, TILE_SIZE), k, m1); | ||
| 2282 | scale_C<TB, is_acc>(Tile7, Sumi7, B + PACKED_INDEX(1, i, KB, TILE_SIZE), k, m1); | ||
| 2283 | } | ||
| 2284 | }); | ||
| 2285 | } | ||
| 2286 | |||
| 2287 | // step 2: accumulate the mins | ||
| 2288 | GGML_DISPATCH_BOOL(i > 0, is_acc, [&] { | ||
| 2289 | acc_C<TA, TB, is_acc>::apply(C, ldc, Sumi4, &A[i], KB, B + PACKED_INDEX(0, i, KB, TILE_SIZE), m0); | ||
| 2290 | acc_C<TA, TB, is_acc>::apply(C + TILE_N, ldc, Sumi6, &A[i], KB, B + PACKED_INDEX(1, i, KB, TILE_SIZE), m0); | ||
| 2291 | if (m1 != 0) { | ||
| 2292 | acc_C<TA, TB, is_acc>::apply(C + TILE_M * ldc, ldc, Sumi5, &A[TILE_M * KB + i], KB, B + PACKED_INDEX(0, i, KB, TILE_SIZE), m1); | ||
| 2293 | acc_C<TA, TB, is_acc>::apply(C + TILE_M * ldc + TILE_N, ldc, Sumi7, &A[TILE_M * KB + i], KB, B + PACKED_INDEX(1, i, KB, TILE_SIZE), m1); | ||
| 2294 | } | ||
| 2295 | }); | ||
| 2296 | } | ||
| 2297 | return; | ||
| 2298 | } | ||
| 2299 | |||
| 2300 | } // anonymous namespace | ||
| 2301 | |||
| 2302 | // get the packed tensor size for quantized weights | ||
| 2303 | size_t ggml_backend_amx_get_alloc_size(const struct ggml_tensor * tensor) { | ||
| 2304 | const enum ggml_type TYPE = tensor->type; | ||
| 2305 | |||
| 2306 | const int K = tensor->ne[0]; // ne0: in_features | ||
| 2307 | const int N = tensor->ne[1]; // ne1: out_features | ||
| 2308 | |||
| 2309 | auto get_tensor_size = [&] { | ||
| 2310 | size_t row_size_B{0}; | ||
| 2311 | GGML_DISPATCH_QTYPES(TYPE, [&] { | ||
| 2312 | row_size_B = get_row_size<type, blck_size>(K); | ||
| 2313 | }); | ||
| 2314 | return N * row_size_B; | ||
| 2315 | }; | ||
| 2316 | |||
| 2317 | if (qtype_has_amx_kernels(TYPE)) { | ||
| 2318 | return get_tensor_size(); | ||
| 2319 | } else { | ||
| 2320 | // for f16, bf16 we don't do packing | ||
| 2321 | return ggml_nbytes(tensor); | ||
| 2322 | } | ||
| 2323 | } | ||
| 2324 | |||
| 2325 | // pack weight to vnni format | ||
| 2326 | void ggml_backend_amx_convert_weight(struct ggml_tensor * tensor, const void * data, size_t offset, size_t size) { | ||
| 2327 | GGML_ASSERT(offset == 0 && size == ggml_nbytes(tensor)); // only full tensor conversion is supported for now | ||
| 2328 | |||
| 2329 | const enum ggml_type TYPE = tensor->type; | ||
| 2330 | |||
| 2331 | const int K = tensor->ne[0]; // ne0: in_features | ||
| 2332 | const int N = tensor->ne[1]; // ne1: out_features | ||
| 2333 | |||
| 2334 | GGML_DISPATCH_QTYPES(TYPE, [&] { | ||
| 2335 | convert_B_packed_format<type, blck_size>((void *)((char *)tensor->data + offset), (const type *)data, N, K); | ||
| 2336 | }); | ||
| 2337 | } | ||
| 2338 | |||
| 2339 | size_t ggml_backend_amx_desired_wsize(const struct ggml_tensor * dst) { | ||
| 2340 | struct ggml_tensor * src0 = dst->src[0]; | ||
| 2341 | |||
| 2342 | const enum ggml_type TYPE = src0->type; | ||
| 2343 | |||
| 2344 | const bool is_floating_type = TYPE == GGML_TYPE_F16; | ||
| 2345 | if (is_floating_type) { | ||
| 2346 | return 0; | ||
| 2347 | } | ||
| 2348 | |||
| 2349 | const int M = dst->ne[1]; | ||
| 2350 | const int K = src0->ne[0]; | ||
| 2351 | |||
| 2352 | size_t desired_wsize = 0; | ||
| 2353 | |||
| 2354 | GGML_DISPATCH_QTYPES(TYPE, [&] { | ||
| 2355 | const size_t row_size_A = K / blck_size * sizeof(vec_dot_type); | ||
| 2356 | desired_wsize = M * row_size_A; | ||
| 2357 | }); | ||
| 2358 | |||
| 2359 | return desired_wsize; | ||
| 2360 | } | ||
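// Editor's note (not part of the upstream file), a worked example of the sizing above:
// for a Q4_0 weight the activations are quantized to Q8_0 (blck_size 32, 34-byte blocks,
// assuming the usual 2-byte delta + 32 int8 quants), so with K = 4096 each row of A needs
// 4096 / 32 * 34 = 4352 bytes and the work buffer is M * 4352 bytes.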
| 2361 | |||
| 2362 | // NB: mixed dtype gemm with Advanced Matrix Extensions (Intel AMX) | ||
| 2363 | // | ||
| 2364 | // src0: weight in shape of {N, K}, quantized | ||
| 2365 | // src1: input in shape of {M, K}, float32 | ||
| 2366 | // dst: output in shape of {M, N}, float32 | ||
| 2367 | // | ||
| 2368 | // the function performs: dst = src1 @ src0.T | ||
| 2369 | // | ||
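// Editor's note (not part of the upstream file), reference semantics with plain loops,
// ignoring quantization, packing and threading:
//
//   for (int m = 0; m < M; ++m)
//       for (int n = 0; n < N; ++n) {
//           float acc = 0.0f;
//           for (int k = 0; k < K; ++k)
//               acc += src1[m * K + k] * src0[n * K + k];   // src0 stored row-major as {N, K}
//           dst[m * N + n] = acc;
//       }
//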
| 2370 | void ggml_backend_amx_mul_mat(const ggml_compute_params * params, struct ggml_tensor * dst) { | ||
| 2371 | struct ggml_tensor * src0 = dst->src[0]; | ||
| 2372 | struct ggml_tensor * src1 = dst->src[1]; | ||
| 2373 | |||
| 2374 | const enum ggml_type TYPE = src0->type; | ||
| 2375 | |||
| 2376 | // f16 only has avx512 kernels for now, | ||
| 2377 | // amx kernels will be added once 6th gen xeon is released. | ||
| 2378 | const bool is_floating_type = TYPE == GGML_TYPE_F16; | ||
| 2379 | |||
| 2380 | const int M = dst->ne[1]; | ||
| 2381 | const int N = dst->ne[0]; | ||
| 2382 | const int K = src0->ne[0]; | ||
| 2383 | const int ldc = dst->nb[1] / dst->nb[0]; | ||
| 2384 | |||
| 2385 | if (is_floating_type) { | ||
| 2386 | constexpr int BLOCK_M = 4; | ||
| 2387 | constexpr int BLOCK_N = 6; | ||
| 2388 | const int MB = div_up(M, BLOCK_M); | ||
| 2389 | const int NB = div_up(N, BLOCK_N); | ||
| 2390 | |||
| 2391 | parallel_for_ggml(params, MB * NB, [&](int begin, int end) { | ||
| 2392 | GGML_DISPATCH_FLOATING_TYPES(TYPE, [&] { | ||
| 2393 | for (int i = begin; i < end; ++i) { | ||
| 2394 | int mb = i / NB; | ||
| 2395 | int nb = i % NB; | ||
| 2396 | |||
| 2397 | int mb_start = mb * BLOCK_M; | ||
| 2398 | int mb_size = std::min(BLOCK_M, M - mb_start); | ||
| 2399 | int nb_start = nb * BLOCK_N; | ||
| 2400 | int nb_size = std::min(BLOCK_N, N - nb_start); | ||
| 2401 | |||
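// Editor's note (not part of the upstream file): the switch key packs the block shape
// into one byte, mb_size in the high nibble and nb_size in the low nibble, so e.g.
// 0x36 selects the 3x6 register-blocked AVX512 kernel.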
| 2402 | switch (mb_size << 4 | nb_size) { | ||
| 2403 | case 0x12: LAUNCH_TINYGEMM_KERNEL_AVX(1, 2); break; | ||
| 2404 | case 0x14: LAUNCH_TINYGEMM_KERNEL_AVX(1, 4); break; | ||
| 2405 | case 0x16: LAUNCH_TINYGEMM_KERNEL_AVX(1, 6); break; | ||
| 2406 | case 0x22: LAUNCH_TINYGEMM_KERNEL_AVX(2, 2); break; | ||
| 2407 | case 0x24: LAUNCH_TINYGEMM_KERNEL_AVX(2, 4); break; | ||
| 2408 | case 0x26: LAUNCH_TINYGEMM_KERNEL_AVX(2, 6); break; | ||
| 2409 | case 0x32: LAUNCH_TINYGEMM_KERNEL_AVX(3, 2); break; | ||
| 2410 | case 0x34: LAUNCH_TINYGEMM_KERNEL_AVX(3, 4); break; | ||
| 2411 | case 0x36: LAUNCH_TINYGEMM_KERNEL_AVX(3, 6); break; | ||
| 2412 | case 0x42: LAUNCH_TINYGEMM_KERNEL_AVX(4, 2); break; | ||
| 2413 | case 0x44: LAUNCH_TINYGEMM_KERNEL_AVX(4, 4); break; | ||
| 2414 | case 0x46: LAUNCH_TINYGEMM_KERNEL_AVX(4, 6); break; | ||
| 2415 | default: fprintf(stderr, "Unexpected block size!\n"); | ||
| 2416 | } | ||
| 2417 | } | ||
| 2418 | }); | ||
| 2419 | }); | ||
| 2420 | return; | ||
| 2421 | } | ||
| 2422 | |||
| 2423 | // pointer to work space, used to convert A from float to quantized type | ||
| 2424 | void * wdata = params->wdata; | ||
| 2425 | |||
| 2426 | //TODO: performance improvement: merge quant A | ||
| 2427 | if (params->ith == 0) { | ||
| 2428 | GGML_DISPATCH_QTYPES(TYPE, [&] { | ||
| 2429 | const size_t row_size_A = K / blck_size * sizeof(vec_dot_type); | ||
| 2430 | const size_t desired_wsize = M * row_size_A; | ||
| 2431 | if (params->wsize < desired_wsize) { | ||
| 2432 | GGML_ABORT("insufficient work space size"); | ||
| 2433 | } | ||
| 2434 | |||
| 2435 | // Q4_0, Q4_1, Q8_0 handles 1 TILE_K per blck_size | ||
| 2436 | // Q4_K, Q5_K, Q6_K, IQ4_XS handles 8 TILE_K per blck_size | ||
| 2437 | GGML_ASSERT(TILE_K == blck_size || TILE_K * 8 == blck_size); | ||
| 2438 | |||
| 2439 | const float * A_data = static_cast<const float *>(src1->data); | ||
| 2440 | for (int m = 0; m < M; ++m) { | ||
| 2441 | from_float<vec_dot_type>(A_data + m * K, (char *)wdata + m * row_size_A, K); | ||
| 2442 | } | ||
| 2443 | }); | ||
| 2444 | } | ||
| 2445 | |||
| 2446 | ggml_barrier(params->threadpool); | ||
| 2447 | |||
| 2448 | if (M == 1) { | ||
| 2449 | // MB = 1 and handle 8 tiles in each block | ||
| 2450 | constexpr int kTilesN = 8; | ||
| 2451 | constexpr int BLOCK_N = TILE_N * kTilesN; | ||
| 2452 | const int NB = div_up(N, BLOCK_N); | ||
| 2453 | |||
| 2454 | parallel_for_ggml(params, NB, [&](int begin, int end) { | ||
| 2455 | GGML_DISPATCH_QTYPES(TYPE, [&] { | ||
| 2456 | const int KB = K / blck_size; | ||
| 2457 | const int TILE_SIZE = get_tile_size<type>(); | ||
| 2458 | const int row_size_A = KB * sizeof(vec_dot_type); | ||
| 2459 | for (int i = begin; i < end; ++i) { | ||
| 2460 | int nb = i; | ||
| 2461 | int nb_start = nb * BLOCK_N; | ||
| 2462 | int nb_size = std::min(BLOCK_N, N - nb_start); // 32, 64, 96 or 128 | ||
| 2463 | |||
| 2464 | switch (nb_size) { | ||
| 2465 | //case 160: LAUNCH_TINYGEMM_KERNEL_VNNI(160); break; | ||
| 2466 | case 128: LAUNCH_TINYGEMM_KERNEL_VNNI(128); break; | ||
| 2467 | case 96: LAUNCH_TINYGEMM_KERNEL_VNNI(96); break; | ||
| 2468 | case 64: LAUNCH_TINYGEMM_KERNEL_VNNI(64); break; | ||
| 2469 | case 32: LAUNCH_TINYGEMM_KERNEL_VNNI(32); break; | ||
| 2470 | default: fprintf(stderr, "Unexpected n block size!\n"); | ||
| 2471 | } | ||
| 2472 | } | ||
| 2473 | }); | ||
| 2474 | }); | ||
| 2475 | return; | ||
| 2476 | } | ||
| 2477 | |||
| 2478 | // handle 4 tiles (2 x 2) at a time | ||
| 2479 | constexpr int BLOCK_M = TILE_M * 2; | ||
| 2480 | constexpr int BLOCK_N = TILE_N * 2; | ||
| 2481 | const int MB = div_up(M, BLOCK_M); | ||
| 2482 | const int NB = div_up(N, BLOCK_N); | ||
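// Editor's note (not part of the upstream file): assuming TILE_M == TILE_N == 16, each
// work item covers a 32x32 block of dst, e.g. M = 64, N = 4096 gives MB = 2, NB = 128
// and 256 independent blocks to spread across the threads.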
| 2483 | |||
| 2484 | parallel_for_ggml(params, MB * NB, [&](int begin, int end) { | ||
| 2485 | // init tile config for each thread | ||
| 2486 | ggml_tile_config_init(); | ||
| 2487 | |||
| 2488 | GGML_DISPATCH_QTYPES(TYPE, [&] { | ||
| 2489 | const int KB = K / blck_size; | ||
| 2490 | const int TILE_SIZE = get_tile_size<type>(); | ||
| 2491 | const int row_size_A = KB * sizeof(vec_dot_type); | ||
| 2492 | |||
| 2493 | for (int i = begin; i < end; ++i) { | ||
| 2494 | int mb = i / NB; | ||
| 2495 | int nb = i % NB; | ||
| 2496 | |||
| 2497 | int mb_start = mb * BLOCK_M; | ||
| 2498 | int mb_size = std::min(BLOCK_M, M - mb_start); | ||
| 2499 | int nb_start = nb * BLOCK_N; | ||
| 2500 | int nb_size = BLOCK_N; | ||
| 2501 | |||
| 2502 | tinygemm_kernel_amx<vec_dot_type, type, float, blck_size>( | ||
| 2503 | mb_size, nb_size, KB, | ||
| 2504 | (const char *)wdata + mb_start * row_size_A, | ||
| 2505 | (const char *)src0->data + PACKED_INDEX(nb * 2, 0, KB, TILE_SIZE), | ||
| 2506 | (float *) dst->data + mb_start * N + nb_start, ldc); | ||
| 2507 | } | ||
| 2508 | }); | ||
| 2509 | }); | ||
| 2510 | } | ||
| 2511 | |||
| 2512 | #endif // if defined(__AMX_INT8__) && defined(__AVX512VNNI__) | ||
diff --git a/llama.cpp/ggml/src/ggml-cpu/amx/mmq.h b/llama.cpp/ggml/src/ggml-cpu/amx/mmq.h new file mode 100644 index 0000000..baf7684 --- /dev/null +++ b/llama.cpp/ggml/src/ggml-cpu/amx/mmq.h | |||
| @@ -0,0 +1,10 @@ | |||
| 1 | #pragma once | ||
| 2 | #include "common.h" | ||
| 3 | |||
| 4 | size_t ggml_backend_amx_desired_wsize(const struct ggml_tensor * dst); | ||
| 5 | |||
| 6 | size_t ggml_backend_amx_get_alloc_size(const struct ggml_tensor * tensor); | ||
| 7 | |||
| 8 | void ggml_backend_amx_convert_weight(struct ggml_tensor * tensor, const void * data, size_t offset, size_t size); | ||
| 9 | |||
| 10 | void ggml_backend_amx_mul_mat(const struct ggml_compute_params * params, struct ggml_tensor * dst); | ||
