Diffstat (limited to 'llama.cpp/ggml/src/ggml-hexagon/htp')
40 files changed, 11859 insertions, 0 deletions
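For reference, the act-ops.c kernels in this change evaluate the following scalar activation definitions using HVX vector primitives. A minimal scalar sketch of that math, assuming only what the kernel comments state (the ref_* helper names are illustrative and not part of this change):

    #include <math.h>

    static inline float ref_sigmoid(float x) { return 1.0f / (1.0f + expf(-x)); }

    /* silu(x) = x * sigmoid(x) */
    static inline float ref_silu(float x) { return x * ref_sigmoid(x); }

    /* gelu is approximated as x * sigmoid(1.702 * x) in the kernel */
    static inline float ref_gelu_quick(float x) { return x * ref_sigmoid(1.702f * x); }

    /* swiglu(x, g) = silu(x) * g; with a single fused input, x and g are the two halves of each row */
    static inline float ref_swiglu(float x, float g) { return ref_silu(x) * g; }

    /* swiglu-oai: clamp both inputs by limit, then out = x' * sigmoid(alpha * x') * (g' + 1) */
    static inline float ref_swiglu_oai(float x, float g, float alpha, float limit) {
        float xc = fminf(x, limit);
        float gc = fminf(fmaxf(g, -limit), limit) + 1.0f;
        return xc * ref_sigmoid(alpha * xc) * gc;
    }

    /* geglu(x, g) = gelu_tanh(x) * g, with the usual tanh approximation of gelu */
    static inline float ref_geglu(float x, float g) {
        const float a = 0.044715f;            /* GELU_COEF_A  */
        const float s = 0.79788456080286535f; /* sqrt(2 / pi) */
        return 0.5f * x * (1.0f + tanhf(s * x * (1.0f + a * x * x))) * g;
    }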
diff --git a/llama.cpp/ggml/src/ggml-hexagon/htp/CMakeLists.txt b/llama.cpp/ggml/src/ggml-hexagon/htp/CMakeLists.txt new file mode 100644 index 0000000..2c23b60 --- /dev/null +++ b/llama.cpp/ggml/src/ggml-hexagon/htp/CMakeLists.txt @@ -0,0 +1,45 @@ +cmake_minimum_required(VERSION 3.22.2) +project(ggml-htp C CXX ASM) + +include(${HEXAGON_SDK_ROOT}/build/cmake/hexagon_fun.cmake) + +include_directories( + ${HEXAGON_SDK_ROOT}/incs + ${HEXAGON_SDK_ROOT}/incs/stddef + ${CMAKE_CURRENT_SOURCE_DIR}/../../../include + ${CMAKE_CURRENT_SOURCE_DIR}/../.. + ${CMAKE_CURRENT_SOURCE_DIR}/.. + ${CMAKE_CURRENT_SOURCE_DIR} + ${CMAKE_CURRENT_BINARY_DIR}) + +set(HTP_LIB ggml-htp-${DSP_VERSION}) + +add_library(${HTP_LIB} SHARED + main.c + htp_iface_skel.c + worker-pool.c + hex-dma.c + matmul-ops.c + binary-ops.c + unary-ops.c + sum-rows-ops.c + softmax-ops.c + act-ops.c + rope-ops.c + flash-attn-ops.c + set-rows-ops.c + get-rows-ops.c + cpy-ops.c + argsort-ops.c +) + +target_compile_definitions(${HTP_LIB} PRIVATE + $<IF:$<BOOL:${HEXAGON_HTP_DEBUG}>,HTP_DEBUG=1,NDEBUG=1> + $<IF:$<BOOL:${HEXAGON_HTP_DEBUG}>,FARF_HIGH=1,> + FP32_QUANTIZE_GROUP_SIZE=${GGML_HEXAGON_FP32_QUANTIZE_GROUP_SIZE}) + +build_idl(htp_iface.idl ${HTP_LIB}) + +set_target_properties(${HTP_LIB} PROPERTIES EXPORT_COMPILE_COMMANDS ON) + +install(TARGETS ${HTP_LIB}) diff --git a/llama.cpp/ggml/src/ggml-hexagon/htp/act-ops.c b/llama.cpp/ggml/src/ggml-hexagon/htp/act-ops.c new file mode 100644 index 0000000..950d836 --- /dev/null +++ b/llama.cpp/ggml/src/ggml-hexagon/htp/act-ops.c @@ -0,0 +1,823 @@ +#pragma clang diagnostic ignored "-Wunused-variable" +#pragma clang diagnostic ignored "-Wunused-function" +#pragma clang diagnostic ignored "-Wunused-but-set-variable" + +#include <HAP_farf.h> +#include <HAP_perf.h> + +#include <math.h> +#include <string.h> + +#include "hex-dma.h" +#include "hvx-utils.h" + +#define GGML_COMMON_DECL_C +#include "ggml-common.h" +#include "htp-ctx.h" +#include "htp-msg.h" +#include "htp-ops.h" + +#define htp_act_preamble3 \ + const uint32_t ne00 = src0->ne[0]; \ + const uint32_t ne01 = src0->ne[1]; \ + const uint32_t ne02 = src0->ne[2]; \ + const uint32_t ne03 = src0->ne[3]; \ + \ + const uint32_t ne10 = src1->ne[0]; \ + const uint32_t ne11 = src1->ne[1]; \ + const uint32_t ne12 = src1->ne[2]; \ + const uint32_t ne13 = src1->ne[3]; \ + \ + const uint32_t ne0 = dst->ne[0]; \ + const uint32_t ne1 = dst->ne[1]; \ + const uint32_t ne2 = dst->ne[2]; \ + const uint32_t ne3 = dst->ne[3]; \ + \ + const uint32_t nb00 = src0->nb[0]; \ + const uint32_t nb01 = src0->nb[1]; \ + const uint32_t nb02 = src0->nb[2]; \ + const uint32_t nb03 = src0->nb[3]; \ + \ + const uint32_t nb10 = src1->nb[0]; \ + const uint32_t nb11 = src1->nb[1]; \ + const uint32_t nb12 = src1->nb[2]; \ + const uint32_t nb13 = src1->nb[3]; \ + \ + const uint32_t nb0 = dst->nb[0]; \ + const uint32_t nb1 = dst->nb[1]; \ + const uint32_t nb2 = dst->nb[2]; \ + const uint32_t nb3 = dst->nb[3]; + +#define htp_act_preamble2 \ + const uint32_t ne00 = src0->ne[0]; \ + const uint32_t ne01 = src0->ne[1]; \ + const uint32_t ne02 = src0->ne[2]; \ + const uint32_t ne03 = src0->ne[3]; \ + \ + const uint32_t ne0 = dst->ne[0]; \ + const uint32_t ne1 = dst->ne[1]; \ + const uint32_t ne2 = dst->ne[2]; \ + const uint32_t ne3 = dst->ne[3]; \ + \ + const uint32_t nb00 = src0->nb[0]; \ + const uint32_t nb01 = src0->nb[1]; \ + const uint32_t nb02 = src0->nb[2]; \ + const uint32_t nb03 = src0->nb[3]; \ + \ + const uint32_t nb0 = dst->nb[0]; \ + const uint32_t nb1 = dst->nb[1]; \ + const 
uint32_t nb2 = dst->nb[2]; \ + const uint32_t nb3 = dst->nb[3]; + +static void glu_swiglu_f32_per_thread(const struct htp_tensor * src0, + const struct htp_tensor * src1, + struct htp_tensor * dst, + const int32_t * op_params, + struct htp_spad * src0_spad, + struct htp_spad * src1_spad, + struct htp_spad * dst_spad, + uint32_t nth, + uint32_t ith, + uint32_t src0_nrows_per_thread, + dma_queue * dma_queue) { + htp_act_preamble3; + + size_t src0_row_size = nb01; + size_t src1_row_size = nb11; + size_t dst_row_size = nb1; + + const uint32_t src0_nrows = ne01 * ne02 * ne03; // src0 rows + + const uint32_t src0_start_row = src0_nrows_per_thread * ith; + const uint32_t src0_end_row = MIN(src0_start_row + src0_nrows_per_thread, src0_nrows); + + // no work for this thread + if (src0_start_row >= src0_end_row) { + return; + } + + uint64_t t1, t2; + t1 = HAP_perf_get_qtimer_count(); + + const uint8_t * restrict data_src0 = (const uint8_t *) src0->data; + const uint8_t * restrict data_src1 = (const uint8_t *) src1->data; + uint8_t * restrict data_dst = (uint8_t *) dst->data; + + const bool src1_valid = src1->ne[0]; + const int nc = (src1_valid) ? ne00 : ne00 / 2; + if (!src1_valid) { + const int32_t swapped = op_params[1]; + data_src1 = data_src0; + src1_row_size = src0_row_size; + + const size_t nc_in_bytes = nc * SIZEOF_FP32; + data_src0 += swapped ? nc_in_bytes : 0; + data_src1 += swapped ? 0 : nc_in_bytes; + } + + const size_t src0_row_size_aligned = hex_round_up(src0_row_size, VLEN); + const size_t src1_row_size_aligned = hex_round_up(src1_row_size, VLEN); + const size_t dst_row_size_aligned = hex_round_up(dst_row_size, VLEN); + + uint8_t * restrict src0_spad_data = src0_spad->data + (ith * src0_spad->size_per_thread); + uint8_t * restrict src1_spad_data = src1_spad->data + (ith * src1_spad->size_per_thread); + uint8_t * restrict dst_spad_data = dst_spad->data + (ith * dst_spad->size_per_thread); + + // Split each per-thread scratchpad (size_per_thread bytes) in half to form a ping-pong buffer + size_t src0_spad_half_size = src0_spad->size_per_thread / 2; + size_t src1_spad_half_size = src1_spad->size_per_thread / 2; + size_t dst_spad_half_size = dst_spad->size_per_thread / 2; + + const int BLOCK = src0_spad_half_size / src0_row_size_aligned; // How many rows can we process in one block + if (BLOCK == 0) { + FARF(ERROR, + "swiglu-f32 : current VTCM reservation %zu is too small for even 1 row per thread, needed at least %zu\n", + src0_spad->size_per_thread, src0_row_size_aligned); + return; + } + + // See discussion: https://github.com/ggml-org/llama.cpp/pull/18151#issuecomment-3678235379 + for (uint32_t ir = src0_start_row, spad_idx = 0; ir < src0_end_row && spad_idx < 2; ir += BLOCK, spad_idx++) { + const uint32_t block_size = MIN(BLOCK, src0_end_row - ir); + + // Dummy DMA transaction for sequencing (interleaving dst,src,dst,...)
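+ // Note (inferred from the pop order in the compute loop below): the zero-row store is a no-op + // descriptor that only seeds the queue, so each iteration pops its dst scratchpad first and then + // the src scratchpads, matching this dst,src,... push order.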
+ dma_queue_push_vtcm_to_ddr(dma_queue, + dma_make_ptr(data_dst, dst_spad_data + (spad_idx * dst_spad_half_size)), + dst_row_size, dst_row_size_aligned, 0); + + dma_queue_push_ddr_to_vtcm(dma_queue, + dma_make_ptr(src0_spad_data + (spad_idx * src0_spad_half_size), data_src0 + (ir * src0_row_size)), + src0_row_size_aligned, src0_row_size, block_size); + dma_queue_push_ddr_to_vtcm(dma_queue, + dma_make_ptr(src1_spad_data + (spad_idx * src1_spad_half_size), data_src1 + (ir * src1_row_size)), + src1_row_size_aligned, src1_row_size, block_size); + } + + for (uint32_t ir = src0_start_row; ir < src0_end_row; ir += BLOCK) { + const uint32_t block_size = MIN(BLOCK, src0_end_row - ir); + + float * dst_spad = (float *) dma_queue_pop(dma_queue).src; + float * src0_spad = (float *) dma_queue_pop(dma_queue).dst; + float * src1_spad = (float *) dma_queue_pop(dma_queue).dst; + + for (uint32_t ib = 0; ib < block_size; ib++) { + const float * src0_spad_ptr = src0_spad + ib * (src0_row_size_aligned / sizeof(float)); + const float * src1_spad_ptr = src1_spad + ib * (src1_row_size_aligned / sizeof(float)); + float * dst_spad_ptr = dst_spad + ib * (dst_row_size_aligned / sizeof(float)); + + //swiglu(x) = x1 * sigmoid(x0) + hvx_sigmoid_f32_aa((uint8_t *) dst_spad_ptr, (const uint8_t *) src0_spad_ptr, nc); + hvx_mul_mul_f32_aa((uint8_t *) dst_spad_ptr, (const uint8_t *) src0_spad_ptr, (const uint8_t *) dst_spad_ptr, + (const uint8_t *) src1_spad_ptr, nc); + } + + dma_queue_push_vtcm_to_ddr(dma_queue, dma_make_ptr(data_dst + (ir * dst_row_size), dst_spad), dst_row_size, + dst_row_size_aligned, block_size); + + // prefetch N+2 loop iteration if any + const uint32_t pref_block = (ir + BLOCK * 2); + if (pref_block < src0_end_row) { + const uint32_t pref_block_size = MIN(BLOCK, src0_end_row - pref_block); + dma_queue_push_ddr_to_vtcm(dma_queue, dma_make_ptr(src0_spad, data_src0 + (pref_block * src0_row_size)), + src0_row_size_aligned, src0_row_size, pref_block_size); + dma_queue_push_ddr_to_vtcm(dma_queue, dma_make_ptr(src1_spad, data_src1 + (pref_block * src1_row_size)), + src1_row_size_aligned, src1_row_size, pref_block_size); + } + } + + dma_queue_flush(dma_queue); + + t2 = HAP_perf_get_qtimer_count(); + + FARF(HIGH, "swiglu-f32 %d/%d: %ux%ux%ux%u (%u:%u) x %ux%ux%ux%u -> %ux%ux%ux%u usec %u\n", ith, nth, + ne00, ne01, ne02, ne03, src0_start_row, src0_end_row, ne10, ne11, ne12, ne13, ne0, ne1, ne2, ne3, + (unsigned) HAP_perf_qtimer_count_to_us(t2 - t1)); +} + +static void glu_swiglu_oai_f32_per_thread(const struct htp_tensor * src0, + const struct htp_tensor * src1, + struct htp_tensor * dst, + const int32_t * op_params, + struct htp_spad * src0_spad, + struct htp_spad * src1_spad, + struct htp_spad * dst_spad, + uint32_t nth, + uint32_t ith, + uint32_t src0_nrows_per_thread, + dma_queue * dma_queue) { + htp_act_preamble3; + + uint64_t t1, t2; + t1 = HAP_perf_get_qtimer_count(); + + size_t src0_row_size = nb01; + size_t src1_row_size = nb11; + size_t dst_row_size = nb1; + + const uint32_t src0_nrows = ne01 * ne02 * ne03; // src0 rows + + const uint32_t src0_start_row = src0_nrows_per_thread * ith; + const uint32_t src0_end_row = MIN(src0_start_row + src0_nrows_per_thread, src0_nrows); + + // no work for this thread + if (src0_start_row >= src0_end_row) { + return; + } + + const uint8_t * restrict data_src0 = (const uint8_t *) src0->data; + const uint8_t * restrict data_src1 = (const uint8_t *) src1->data; + uint8_t * restrict data_dst = (uint8_t *) dst->data; + + const bool src1_valid = src1->ne[0]; + const int nc 
= (src1_valid) ? ne00 : ne00 / 2; + if (!src1_valid) { + const int32_t swapped = op_params[1]; + data_src1 = data_src0; + src1_row_size = src0_row_size; + + const size_t nc_in_bytes = nc * SIZEOF_FP32; + data_src0 += swapped ? nc_in_bytes : 0; + data_src1 += swapped ? 0 : nc_in_bytes; + } + + const size_t src0_row_size_aligned = hex_round_up(src0_row_size, VLEN); + const size_t src1_row_size_aligned = hex_round_up(src1_row_size, VLEN); + const size_t dst_row_size_aligned = hex_round_up(dst_row_size, VLEN); + + uint8_t * restrict src0_spad_data = src0_spad->data + (ith * src0_spad->size_per_thread); + uint8_t * restrict src1_spad_data = src1_spad->data + (ith * src1_spad->size_per_thread); + uint8_t * restrict dst_spad_data = dst_spad->data + (ith * dst_spad->size_per_thread); + + // Split each per-thread scratchpad (size_per_thread bytes) in half to form a ping-pong buffer + size_t src0_spad_half_size = src0_spad->size_per_thread / 2; + size_t src1_spad_half_size = src1_spad->size_per_thread / 2; + size_t dst_spad_half_size = dst_spad->size_per_thread / 2; + + const int BLOCK = src0_spad_half_size / src0_row_size_aligned; // How many rows can we process in one block + if (BLOCK == 0) { + FARF(ERROR, + "swiglu-oai-f32 : current VTCM reservation %zu is too small for even 1 row per thread, needed at least " + "%zu\n", + src0_spad->size_per_thread, src0_row_size_aligned); + return; + } + const float alpha = ((const float *) (op_params))[2]; + const float limit = ((const float *) (op_params))[3]; + + // See discussion: https://github.com/ggml-org/llama.cpp/pull/18151#issuecomment-3678235379 + for (uint32_t ir = src0_start_row, spad_idx = 0; ir < src0_end_row && spad_idx < 2; ir += BLOCK, spad_idx++) { + const uint32_t block_size = MIN(BLOCK, src0_end_row - ir); + + // Dummy DMA transaction for sequencing (interleaving dst,src,dst,...)
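+ // Note (inferred from the loop bounds): only two blocks are primed here, one per ping-pong half + // (spad_idx 0 and 1); the main loop below refills each half as it drains via the N+2 prefetch.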
+ dma_queue_push_vtcm_to_ddr(dma_queue, dma_make_ptr(data_dst, dst_spad_data + (spad_idx * dst_spad_half_size)), + dst_row_size, dst_row_size_aligned, 0); + + dma_queue_push_ddr_to_vtcm( + dma_queue, + dma_make_ptr(src0_spad_data + (spad_idx * src0_spad_half_size), data_src0 + (ir * src0_row_size)), + src0_row_size_aligned, src0_row_size, block_size); + dma_queue_push_ddr_to_vtcm( + dma_queue, + dma_make_ptr(src1_spad_data + (spad_idx * src1_spad_half_size), data_src1 + (ir * src1_row_size)), + src1_row_size_aligned, src1_row_size, block_size); + } + + for (uint32_t ir = src0_start_row; ir < src0_end_row; ir += BLOCK) { + const uint32_t block_size = MIN(BLOCK, src0_end_row - ir); + + float * dst_spad = (float *) dma_queue_pop(dma_queue).src; + float * src0_spad = (float *) dma_queue_pop(dma_queue).dst; + float * src1_spad = (float *) dma_queue_pop(dma_queue).dst; + + for (uint32_t ib = 0; ib < block_size; ib++) { + const float * src0_spad_ptr = src0_spad + ib * (src0_row_size_aligned / sizeof(float)); + const float * src1_spad_ptr = src1_spad + ib * (src1_row_size_aligned / sizeof(float)); + float * dst_spad_ptr = dst_spad + ib * (dst_row_size_aligned / sizeof(float)); + + // x (src0_spad_data) = std::min(src0_p[k], limit); + hvx_min_scalar_f32((uint8_t *) src0_spad_ptr, (const uint8_t *) src0_spad_ptr, limit, nc); + // y1 (src1_spad_data) = std::clamp(src1_p[k], -limit, limit); + hvx_clamp_scalar_f32((uint8_t *) src1_spad_ptr, (const uint8_t *) src1_spad_ptr, -limit, limit, nc); + // y (src1_spad_data) = y1 + 1.f + hvx_add_scalar_f32((uint8_t *) src1_spad_ptr, (const uint8_t *) src1_spad_ptr, 1.0, nc); + // x1 (dst_spad_data) = alpha * (x) + hvx_mul_scalar_f32((uint8_t *) dst_spad_ptr, (const uint8_t *) src0_spad_ptr, alpha, nc); + // x2 (dst_spad_data) = sigmoid(x1) = 1/(1+exp(-x1)) + hvx_sigmoid_f32_aa((uint8_t *) dst_spad_ptr, (const uint8_t *) dst_spad_ptr, nc); + // out = x * sigmoid(alpha * x) * (y + 1.f) + hvx_mul_mul_f32_aa((uint8_t *) dst_spad_ptr, (const uint8_t *) src0_spad_ptr, (const uint8_t *) dst_spad_ptr, + (const uint8_t *) src1_spad_ptr, nc); + } + + dma_queue_push_vtcm_to_ddr(dma_queue, dma_make_ptr(data_dst + (ir * dst_row_size), dst_spad), dst_row_size, + dst_row_size_aligned, block_size); + + // prefetch N+2 loop iteration if any + const uint32_t pref_block = (ir + BLOCK * 2); + if (pref_block < src0_end_row) { + const uint32_t pref_block_size = MIN(BLOCK, src0_end_row - pref_block); + dma_queue_push_ddr_to_vtcm(dma_queue, dma_make_ptr(src0_spad, data_src0 + (pref_block * src0_row_size)), + src0_row_size_aligned, src0_row_size, pref_block_size); + dma_queue_push_ddr_to_vtcm(dma_queue, dma_make_ptr(src1_spad, data_src1 + (pref_block * src1_row_size)), + src1_row_size_aligned, src1_row_size, pref_block_size); + } + } + + dma_queue_flush(dma_queue); + + t2 = HAP_perf_get_qtimer_count(); + + FARF(HIGH, "swiglu-oai-f32 %d/%d: %ux%ux%ux%u (%u:%u) x %ux%ux%ux%u -> %ux%ux%ux%u usec %u\n", ith, nth, src0->ne[0], + src0->ne[1], src0->ne[2], src0->ne[3], src0_start_row, src0_end_row, src1->ne[0], src1->ne[1], src1->ne[2], + src1->ne[3], dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3], (unsigned) HAP_perf_qtimer_count_to_us(t2 - t1)); +} + + +static void unary_gelu_f32_per_thread(const struct htp_tensor * src0, + struct htp_tensor * dst, + const int32_t * op_params, + struct htp_spad * src0_spad, + struct htp_spad * dst_spad, + uint32_t nth, + uint32_t ith, + uint32_t src0_nrows_per_thread, + dma_queue * dma_queue) { + htp_act_preamble2; + + uint64_t t1, t2; + t1 = 
HAP_perf_get_qtimer_count(); + + const size_t src0_row_size = nb01; + const size_t dst_row_size = nb1; + const size_t src0_row_size_aligned = hex_round_up(src0_row_size, VLEN); + const size_t dst_row_size_aligned = hex_round_up(dst_row_size, VLEN); + + const uint32_t src0_nrows = ne01 * ne02 * ne03; + + const uint32_t src0_start_row = src0_nrows_per_thread * ith; + const uint32_t src0_end_row = MIN(src0_start_row + src0_nrows_per_thread, src0_nrows); + + // no work for this thread + if (src0_start_row >= src0_end_row) { + return; + } + + const uint8_t * data_src0 = (const uint8_t *) src0->data; + uint8_t * data_dst = (uint8_t *) dst->data; + + uint8_t * src0_spad_data = src0_spad->data + (ith * src0_spad->size_per_thread); + uint8_t * dst_spad_data = dst_spad->data + (ith * dst_spad->size_per_thread); + + // Split each per-thread scratchpad (size_per_thread bytes) in half to form a ping-pong buffer + size_t src0_spad_half_size = src0_spad->size_per_thread / 2; + size_t dst_spad_half_size = dst_spad->size_per_thread / 2; + + const int BLOCK = src0_spad_half_size / src0_row_size_aligned; // How many rows can we process in one block + + if (BLOCK == 0) { + FARF(ERROR, "gelu-f32 : current VTCM reservation %zu is too small for even 1 row per thread, needed at least %zu\n", + src0_spad->size_per_thread, src0_row_size_aligned); + return; + } + + // See discussion: https://github.com/ggml-org/llama.cpp/pull/18151#issuecomment-3678235379 + for (uint32_t ir = src0_start_row, spad_idx = 0; ir < src0_end_row && spad_idx < 2; ir += BLOCK, spad_idx++) { + const uint32_t block_size = MIN(BLOCK, src0_end_row - ir); + + // Dummy DMA transaction for sequencing (interleaving dst,src,dst,...) + dma_queue_push_vtcm_to_ddr(dma_queue, + dma_make_ptr(data_dst, dst_spad_data + (spad_idx * dst_spad_half_size)), + dst_row_size, dst_row_size_aligned, 0); + + dma_queue_push_ddr_to_vtcm(dma_queue, + dma_make_ptr(src0_spad_data + (spad_idx * src0_spad_half_size), data_src0 + (ir * src0_row_size)), + src0_row_size_aligned, src0_row_size, block_size); + } + + for (uint32_t ir = src0_start_row; ir < src0_end_row; ir += BLOCK) { + const uint32_t block_size = MIN(BLOCK, src0_end_row - ir); + + float* dst_spad = (float *) dma_queue_pop(dma_queue).src; + float* src0_spad = (float *) dma_queue_pop(dma_queue).dst; + + for (uint32_t ib = 0; ib < block_size; ib++) { + const float* src0_spad_ptr = src0_spad + ib * (src0_row_size_aligned / sizeof(float)); + float* dst_spad_ptr = dst_spad + ib * (dst_row_size_aligned / sizeof(float)); + + // gelu = x * sigmoid(1.702 * x) (current implementation) + hvx_mul_scalar_f32((uint8_t *) dst_spad_ptr, (const uint8_t *) src0_spad_ptr, (float) 1.702, ne0); + hvx_sigmoid_f32_aa((uint8_t *) dst_spad_ptr, (const uint8_t *) dst_spad_ptr, ne0); + hvx_mul_f32_aaa((uint8_t *) dst_spad_ptr, (const uint8_t *) src0_spad_ptr, (const uint8_t *) dst_spad_ptr, ne0); + } + + dma_queue_push_vtcm_to_ddr(dma_queue, + dma_make_ptr(data_dst + (ir * dst_row_size), dst_spad), + dst_row_size, dst_row_size_aligned, block_size); + + // prefetch N+2 loop iteration if any + const uint32_t pref_block = (ir + BLOCK * 2); + if (pref_block < src0_end_row) { + const uint32_t pref_block_size = MIN(BLOCK, src0_end_row - pref_block); + dma_queue_push_ddr_to_vtcm(dma_queue, + dma_make_ptr(src0_spad, data_src0 + (pref_block * src0_row_size)), + src0_row_size_aligned, src0_row_size, pref_block_size); + } + } + + dma_queue_flush(dma_queue); + + t2 = HAP_perf_get_qtimer_count(); + + FARF(HIGH, "gelu-f32 %d/%d: 
%ux%ux%ux%u (%u:%u) -> %ux%ux%ux%u usec %u\n", ith, nth, ne00, ne01, ne02, + ne03, src0_start_row, src0_end_row, ne0, ne1, ne2, ne3, (unsigned) HAP_perf_qtimer_count_to_us(t2 - t1)); +} + +static void unary_gelu_f32(unsigned int n, unsigned int i, void * data) { + struct htp_ops_context * octx = (struct htp_ops_context *) data; + unary_gelu_f32_per_thread(&octx->src0, &octx->dst, octx->op_params, &octx->src0_spad, &octx->dst_spad, n, i, + octx->src0_nrows_per_thread, octx->ctx->dma[i]); +} + +static void unary_silu_f32_per_thread(const struct htp_tensor * src0, + struct htp_tensor * dst, + const int32_t * op_params, + struct htp_spad * src0_spad, + struct htp_spad * dst_spad, + uint32_t nth, + uint32_t ith, + uint32_t src0_nrows_per_thread, + dma_queue * dma_queue) { + htp_act_preamble2; + + uint64_t t1, t2; + t1 = HAP_perf_get_qtimer_count(); + + const size_t src0_row_size = nb01; + const size_t dst_row_size = nb1; + const size_t src0_row_size_aligned = hex_round_up(src0_row_size, VLEN); + const size_t dst_row_size_aligned = hex_round_up(dst_row_size, VLEN); + + const uint32_t src0_nrows = ne01 * ne02 * ne03; + + const uint32_t src0_start_row = src0_nrows_per_thread * ith; + const uint32_t src0_end_row = MIN(src0_start_row + src0_nrows_per_thread, src0_nrows); + + // no work for this thread + if (src0_start_row >= src0_end_row) { + return; + } + + const uint8_t * data_src0 = (const uint8_t *) src0->data; + uint8_t * data_dst = (uint8_t *) dst->data; + + uint8_t * src0_spad_data = src0_spad->data + (ith * src0_spad->size_per_thread); + uint8_t * dst_spad_data = dst_spad->data + (ith * dst_spad->size_per_thread); + + // Split each per-thread scratchpad (size_per_thread bytes) in half to form a ping-pong buffer + size_t src0_spad_half_size = src0_spad->size_per_thread / 2; + size_t dst_spad_half_size = dst_spad->size_per_thread / 2; + + const int BLOCK = src0_spad_half_size / src0_row_size_aligned; // How many rows can we process in one block + + if (BLOCK == 0) { + FARF(ERROR, "silu-f32 : current VTCM reservation %zu is too small for even 1 row per thread, needed at least %zu\n", + src0_spad->size_per_thread, src0_row_size_aligned); + return; + } + + // See discussion: https://github.com/ggml-org/llama.cpp/pull/18151#issuecomment-3678235379 + for (uint32_t ir = src0_start_row, spad_idx = 0; ir < src0_end_row && spad_idx < 2; ir += BLOCK, spad_idx++) { + const uint32_t block_size = MIN(BLOCK, src0_end_row - ir); + + // Dummy DMA transaction for sequencing (interleaving dst,src,dst,...)
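+ // Note (assumption based on the _aa kernel-variant naming): rows are padded to VLEN so the + // scratchpad pointers stay HVX-vector aligned, which the aligned (_aa) HVX helpers rely on.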
+ dma_queue_push_vtcm_to_ddr(dma_queue, + dma_make_ptr(data_dst, dst_spad_data + (spad_idx * dst_spad_half_size)), + dst_row_size, dst_row_size_aligned, 0); + + dma_queue_push_ddr_to_vtcm(dma_queue, + dma_make_ptr(src0_spad_data + (spad_idx * src0_spad_half_size), data_src0 + (ir * src0_row_size)), + src0_row_size_aligned, src0_row_size, block_size); + } + + for (uint32_t ir = src0_start_row; ir < src0_end_row; ir += BLOCK) { + const uint32_t block_size = MIN(BLOCK, src0_end_row - ir); + + float* dst_spad = (float *) dma_queue_pop(dma_queue).src; + float* src0_spad = (float *) dma_queue_pop(dma_queue).dst; + + for (uint32_t ib = 0; ib < block_size; ib++) { + const float* src0_spad_ptr = src0_spad + ib * (src0_row_size_aligned / sizeof(float)); + float* dst_spad_ptr = dst_spad + ib * (dst_row_size_aligned / sizeof(float)); + + // silu = x * sigmoid(x) + hvx_sigmoid_f32_aa((uint8_t *) dst_spad_ptr, (const uint8_t *) src0_spad_ptr, ne0); + hvx_mul_f32_aaa((uint8_t *) dst_spad_ptr, (const uint8_t *) src0_spad_ptr, (const uint8_t *) dst_spad_ptr, ne0); + } + + dma_queue_push_vtcm_to_ddr(dma_queue, + dma_make_ptr(data_dst + (ir * dst_row_size), dst_spad), + dst_row_size, dst_row_size_aligned, block_size); + + // prefetch N+2 loop iteration if any + const uint32_t pref_block = (ir + BLOCK * 2); + if (pref_block < src0_end_row) { + const uint32_t pref_block_size = MIN(BLOCK, src0_end_row - pref_block); + dma_queue_push_ddr_to_vtcm(dma_queue, + dma_make_ptr(src0_spad, data_src0 + (pref_block * src0_row_size)), + src0_row_size_aligned, src0_row_size, pref_block_size); + } + } + + dma_queue_flush(dma_queue); + + t2 = HAP_perf_get_qtimer_count(); + + FARF(HIGH, "silu-f32 %d/%d: %ux%ux%ux%u (%u:%u) -> %ux%ux%ux%u usec %u\n", ith, nth, ne00, ne01, ne02, + ne03, src0_start_row, src0_end_row, ne0, ne1, ne2, ne3, (unsigned) HAP_perf_qtimer_count_to_us(t2 - t1)); +} + +static const float GELU_COEF_A = 0.044715f; +static const float SQRT_2_OVER_PI = 0.79788456080286535587989211986876f; + +static void glu_geglu_f32_per_thread(const struct htp_tensor * src0, + const struct htp_tensor * src1, + struct htp_tensor * dst, + const int32_t * op_params, + struct htp_spad * src0_spad, + struct htp_spad * src1_spad, + struct htp_spad * dst_spad, + uint32_t nth, + uint32_t ith, + uint32_t src0_nrows_per_thread, + dma_queue * dma_queue) { + htp_act_preamble3; + + size_t src0_row_size = nb01; + size_t src1_row_size = nb11; + size_t dst_row_size = nb1; + + uint64_t t1, t2; + t1 = HAP_perf_get_qtimer_count(); + + const uint32_t src0_nrows = ne01 * ne02 * ne03; // src0 rows + + const uint32_t src0_start_row = src0_nrows_per_thread * ith; + const uint32_t src0_end_row = MIN(src0_start_row + src0_nrows_per_thread, src0_nrows); + + // no work for this thread + if (src0_start_row >= src0_end_row) { + return; + } + + const uint8_t * restrict data_src0 = (const uint8_t *) src0->data; + const uint8_t * restrict data_src1 = (const uint8_t *) src1->data; + uint8_t * restrict data_dst = (uint8_t *) dst->data; + + const bool src1_valid = src1->ne[0]; + const int nc = (src1_valid) ? ne00 : ne00 / 2; + if (!src1_valid) { + const int32_t swapped = op_params[1]; + data_src1 = data_src0; + src1_row_size = src0_row_size; + + const size_t nc_in_bytes = nc * SIZEOF_FP32; + data_src0 += swapped ? nc_in_bytes : 0; + data_src1 += swapped ? 
0 : nc_in_bytes; + } + + const size_t src0_row_size_aligned = hex_round_up(src0_row_size, VLEN); + const size_t src1_row_size_aligned = hex_round_up(src1_row_size, VLEN); + const size_t dst_row_size_aligned = hex_round_up(dst_row_size, VLEN); + + uint8_t * restrict src0_spad_data = src0_spad->data + (ith * src0_spad->size_per_thread); + uint8_t * restrict src1_spad_data = src1_spad->data + (ith * src1_spad->size_per_thread); + uint8_t * restrict dst_spad_data = dst_spad->data + (ith * dst_spad->size_per_thread); + + // Split each per-thread scratchpad (size_per_thread bytes) in half to form a ping-pong buffer + size_t src0_spad_half_size = src0_spad->size_per_thread / 2; + size_t src1_spad_half_size = src1_spad->size_per_thread / 2; + size_t dst_spad_half_size = dst_spad->size_per_thread / 2; + + const int BLOCK = src0_spad_half_size / src0_row_size_aligned; // How many rows can we process in one block + if (BLOCK == 0) { + FARF(ERROR, + "geglu-f32 : current VTCM reservation %zu is too small for even 1 row per thread, needed at least %zu\n", + src0_spad->size_per_thread, src0_row_size_aligned); + return; + } + + // See discussion: https://github.com/ggml-org/llama.cpp/pull/18151#issuecomment-3678235379 + for (uint32_t ir = src0_start_row, spad_idx = 0; ir < src0_end_row && spad_idx < 2; ir += BLOCK, spad_idx++) { + const uint32_t block_size = MIN(BLOCK, src0_end_row - ir); + + // Dummy DMA transaction for sequencing (interleaving dst,src,dst,...) + dma_queue_push_vtcm_to_ddr(dma_queue, + dma_make_ptr(data_dst, dst_spad_data + (spad_idx * dst_spad_half_size)), + dst_row_size, dst_row_size_aligned, 0); + + dma_queue_push_ddr_to_vtcm(dma_queue, + dma_make_ptr(src0_spad_data + (spad_idx * src0_spad_half_size), data_src0 + (ir * src0_row_size)), + src0_row_size_aligned, src0_row_size, block_size); + dma_queue_push_ddr_to_vtcm(dma_queue, + dma_make_ptr(src1_spad_data + (spad_idx * src1_spad_half_size), data_src1 + (ir * src1_row_size)), + src1_row_size_aligned, src1_row_size, block_size); + } + + for (uint32_t ir = src0_start_row; ir < src0_end_row; ir += BLOCK) { + const uint32_t block_size = MIN(BLOCK, src0_end_row - ir); + + float * dst_spad = (float *) dma_queue_pop(dma_queue).src; + float * src0_spad = (float *) dma_queue_pop(dma_queue).dst; + float * src1_spad = (float *) dma_queue_pop(dma_queue).dst; + + for (uint32_t ib = 0; ib < block_size; ib++) { + const uint8_t * src0_spad_ptr = (const uint8_t *)(src0_spad + ib * (src0_row_size_aligned / sizeof(float))); + const uint8_t * src1_spad_ptr = (const uint8_t *)(src1_spad + ib * (src1_row_size_aligned / sizeof(float))); + uint8_t * dst_spad_ptr = (uint8_t *)(dst_spad + ib * (dst_row_size_aligned / sizeof(float))); + + // geglu tanh implementation + // geglu(x, g) = gelu(x) * g + // gelu(x) = 0.5f*x*(1.0f + tanhf(SQRT_2_OVER_PI*x*(1.0f + GELU_COEF_A*x*x))) + hvx_mul_f32_aaa(dst_spad_ptr, src0_spad_ptr, src0_spad_ptr, nc); // res = x*x + hvx_mul_scalar_f32_aa(dst_spad_ptr, (const uint8_t *)dst_spad_ptr, GELU_COEF_A, nc); // res = res * GELU_COEF_A + hvx_add_scalar_f32_aa(dst_spad_ptr, (const uint8_t *)dst_spad_ptr, 1.0f, nc); // res = res + 1.0f + hvx_mul_f32_aaa(dst_spad_ptr, src0_spad_ptr, (const uint8_t *)dst_spad_ptr, nc); // res = res * x + hvx_mul_scalar_f32_aa(dst_spad_ptr, (const uint8_t*)dst_spad_ptr, SQRT_2_OVER_PI, nc); // res = res * SQRT_2_OVER_PI + hvx_tanh_f32_aa((uint8_t *) dst_spad_ptr, (const uint8_t *) dst_spad_ptr, nc); // res = tanh(res) + hvx_add_scalar_f32_aa(dst_spad_ptr, (const uint8_t*)dst_spad_ptr, 1.0f, nc); // res = res
+ 1.0f + hvx_mul_f32_aaa(dst_spad_ptr, src0_spad_ptr, (const uint8_t *)dst_spad_ptr, nc); // res = res * x + hvx_mul_scalar_f32_aa(dst_spad_ptr, (const uint8_t *)dst_spad_ptr, 0.5f, nc); // res = res * 0.5f + hvx_mul_f32_aaa(dst_spad_ptr, (const uint8_t *)dst_spad_ptr, src1_spad_ptr, nc); // res = res * g + } + + dma_queue_push_vtcm_to_ddr(dma_queue, dma_make_ptr(data_dst + (ir * dst_row_size), dst_spad), dst_row_size, + dst_row_size_aligned, block_size); + + // prefetch N+2 loop iteration if any + const uint32_t pref_block = (ir + BLOCK * 2); + if (pref_block < src0_end_row) { + const uint32_t pref_block_size = MIN(BLOCK, src0_end_row - pref_block); + dma_queue_push_ddr_to_vtcm(dma_queue, dma_make_ptr(src0_spad, data_src0 + (pref_block * src0_row_size)), + src0_row_size_aligned, src0_row_size, pref_block_size); + dma_queue_push_ddr_to_vtcm(dma_queue, dma_make_ptr(src1_spad, data_src1 + (pref_block * src1_row_size)), + src1_row_size_aligned, src1_row_size, pref_block_size); + } + } + + dma_queue_flush(dma_queue); + + t2 = HAP_perf_get_qtimer_count(); + + FARF(HIGH, "geglu-f32 %d/%d: %ux%ux%ux%u (%u:%u) x %ux%ux%ux%u -> %ux%ux%ux%u usec %u\n", ith, nth, + ne00, ne01, ne02, ne03, src0_start_row, src0_end_row, ne10, ne11, ne12, ne13, ne0, ne1, ne2, ne3, + (unsigned) HAP_perf_qtimer_count_to_us(t2 - t1)); +} + +static void unary_silu_f32(unsigned int n, unsigned int i, void * data) { + struct htp_ops_context * octx = (struct htp_ops_context *) data; + unary_silu_f32_per_thread(&octx->src0, &octx->dst, octx->op_params, &octx->src0_spad, &octx->dst_spad, n, i, + octx->src0_nrows_per_thread, octx->ctx->dma[i]); +} + +static void glu_swiglu_f32(unsigned int n, unsigned int i, void * data) { + struct htp_ops_context * octx = (struct htp_ops_context *) data; + glu_swiglu_f32_per_thread(&octx->src0, &octx->src1, &octx->dst, octx->op_params, &octx->src0_spad, + &octx->src1_spad, &octx->dst_spad, n, i, octx->src0_nrows_per_thread, octx->ctx->dma[i]); +} + +static void glu_swiglu_oai_f32(unsigned int n, unsigned int i, void * data) { + struct htp_ops_context * octx = (struct htp_ops_context *) data; + glu_swiglu_oai_f32_per_thread(&octx->src0, &octx->src1, &octx->dst, octx->op_params, &octx->src0_spad, + &octx->src1_spad, &octx->dst_spad, n, i, octx->src0_nrows_per_thread, octx->ctx->dma[i]); +} + +static void glu_geglu_f32(unsigned int n, unsigned int i, void * data) { + struct htp_ops_context * octx = (struct htp_ops_context *) data; + glu_geglu_f32_per_thread(&octx->src0, &octx->src1, &octx->dst, octx->op_params, &octx->src0_spad, + &octx->src1_spad, &octx->dst_spad, n, i, octx->src0_nrows_per_thread, octx->ctx->dma[i]); +} + +static int execute_op_activations_f32(struct htp_ops_context * octx) { + int err = HTP_STATUS_OK; + + const struct htp_tensor * src0 = &octx->src0; + const struct htp_tensor * src1 = &octx->src1; + struct htp_tensor * dst = &octx->dst; + + if (((src0->ne[0] * SIZEOF_FP32) != src0->nb[1]) || ((dst->ne[0] * SIZEOF_FP32) != dst->nb[1])) { + FARF(ERROR, "Non-contiguous tensors are not supported at this time\n"); + return HTP_STATUS_NO_SUPPORT; + } + + worker_callback_t act_op_func; + const char * op_type = NULL; + + switch (octx->op) { + case HTP_OP_UNARY_SILU: + act_op_func = unary_silu_f32; + op_type = "silu-f32"; + break; + + case HTP_OP_GLU_SWIGLU: + act_op_func = glu_swiglu_f32; + op_type = "swiglu-f32"; + break; + + case HTP_OP_GLU_SWIGLU_OAI: + act_op_func = glu_swiglu_oai_f32; + op_type = "swiglu-oai-f32"; + break; + + case HTP_OP_UNARY_GELU: + act_op_func = unary_gelu_f32;
+ op_type = "gelu-f32"; + break; + + case HTP_OP_GLU_GEGLU: + act_op_func = glu_geglu_f32; + op_type = "geglu-f32"; + break; + default: + FARF(ERROR, "Unsupported activations Op %u\n", octx->op); + return HTP_STATUS_NO_SUPPORT; + } + + const uint32_t n_threads = octx->n_threads; + const uint32_t src0_nrows = src0->ne[1] * src0->ne[2] * src0->ne[3]; + + size_t src0_row_size = src0->nb[1]; + size_t src1_row_size = src1->nb[1]; // zero bytes if src1 is not used + size_t dst_row_size = dst->nb[1]; + + const bool src1_valid = src1->ne[0]; + if (!src1_valid) { + src1_row_size = src0_row_size; + } + + const size_t src0_row_size_aligned = hex_round_up(src0_row_size, VLEN); + const size_t src1_row_size_aligned = hex_round_up(src1_row_size, VLEN); + const size_t dst_row_size_aligned = hex_round_up(dst_row_size, VLEN); + // VTCM scratchpads for all tensors + // N rows per thread, padded to HVX vector size + + size_t spad_size_per_row = (src0_row_size_aligned + src1_row_size_aligned) + dst_row_size_aligned; + size_t vtcm_row_per_thread = (octx->ctx->vtcm_size)/ (n_threads* spad_size_per_row); + + // Make sure the reserved vtcm size is sufficient + if(vtcm_row_per_thread ==0){ + FARF(ERROR, "act-%s : current VTCM reservation %zu is too small for even 1 row per thread, needed at least %zu\n", op_type, octx->ctx->vtcm_size, + spad_size_per_row * n_threads); + return HTP_STATUS_VTCM_TOO_SMALL; + } + + octx->src0_spad.size_per_thread = src0_row_size_aligned * vtcm_row_per_thread; + octx->src1_spad.size_per_thread = src1_row_size_aligned * vtcm_row_per_thread; + octx->dst_spad.size_per_thread = dst_row_size_aligned * vtcm_row_per_thread; + + octx->dst_spad.size = n_threads* octx->dst_spad.size_per_thread; + octx->src0_spad.size = n_threads* octx->src0_spad.size_per_thread; + octx->src1_spad.size = n_threads* octx->src1_spad.size_per_thread; + + octx->src0_spad.data = octx->ctx->vtcm_base; + octx->src1_spad.data = octx->src0_spad.data + octx->src0_spad.size; + octx->dst_spad.data = octx->src1_spad.data + octx->src1_spad.size; + + if (src1->ne[0]) { + FARF(HIGH, "%s: %ux%ux%ux%u x %ux%ux%ux%u -> %ux%ux%ux%u : src0-spad-size %u src1-spad-size %u dst-spad-size %u\n", + op_type, src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3], src1->ne[0], src1->ne[1], src1->ne[2], + src1->ne[3], dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3], octx->src0_spad.size, octx->src1_spad.size, + octx->dst_spad.size); + } else { + FARF(HIGH, "%s: %ux%ux%ux%u -> %ux%ux%ux%u : src0-spad-size %u src1-spad-size %u dst-spad-size %u\n", op_type, + src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3], dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3], + octx->src0_spad.size, octx->src1_spad.size, octx->dst_spad.size); + } + + if (!(octx->flags & HTP_OPFLAGS_SKIP_COMPUTE)) { + uint32_t n_jobs = MIN(n_threads, src0_nrows); + octx->src0_nrows_per_thread = (src0_nrows + n_jobs - 1) / n_jobs; + worker_pool_run_func(octx->ctx->worker_pool, act_op_func, octx, n_jobs); + } + + return err; +} + +int op_activations(struct htp_ops_context * octx) { + int err = HTP_STATUS_OK; + + switch (octx->src0.type) { + case HTP_TYPE_F32: + err = execute_op_activations_f32(octx); + break; + + default: + err = HTP_STATUS_NO_SUPPORT; + break; + } + + return err; +} diff --git a/llama.cpp/ggml/src/ggml-hexagon/htp/argsort-ops.c b/llama.cpp/ggml/src/ggml-hexagon/htp/argsort-ops.c new file mode 100644 index 0000000..a4cee98 --- /dev/null +++ b/llama.cpp/ggml/src/ggml-hexagon/htp/argsort-ops.c @@ -0,0 +1,281 @@ +#include <string.h> +#include <stdlib.h> +#include <math.h> 
+#include <HAP_farf.h> +#include <HAP_perf.h> + +#define GGML_COMMON_DECL_C +#include "ggml-common.h" +#include "ggml.h" + +#include "hvx-utils.h" +#include "hex-dma.h" + +#include "htp-ctx.h" +#include "htp-msg.h" +#include "htp-ops.h" + +#ifndef MIN +#define MIN(a, b) ((a) < (b) ? (a) : (b)) +#endif + +struct htp_argsort_context { + struct htp_ops_context * octx; + uint32_t nrows_per_thread; +}; + +static inline bool all_greater_f32(HVX_Vector x, HVX_Vector y) +{ + const HVX_Vector one = Q6_V_vsplat_R(1); + const HVX_Vector zero = Q6_V_vzero(); + + HVX_VectorPred pred = Q6_Q_vcmp_gt_VsfVsf(x, y); + HVX_Vector matches = Q6_V_vmux_QVV(pred, one, zero); + HVX_Vector sum = hvx_vec_reduce_sum_i32(matches); + return hvx_vec_get_i32(sum) == 32; +} + +// Sorts values and mirrors swaps to indices. +static void quicksort_values_indices_asc(float * values, int32_t * indices, int left, int right) { + if (left >= right) return; + + int pivot_idx = (left + right) / 2; + float pivot = values[pivot_idx]; + int i = left; + int j = right; + + HVX_Vector pivot_vec = hvx_vec_splat_f32(pivot); + while (i <= j) { + // Vectorized scan for i + while (i <= j) { + // Check if we have at least one full vector + if (i + 32 <= j) { + HVX_Vector vals_vec = *(HVX_UVector *)(values + i); + if (all_greater_f32(pivot_vec, vals_vec)) { + // If all elements are < pivot, we can skip this whole block + i += 32; + continue; + } + } + + // Scalar fallback / cleanup + if (values[i] < pivot) { + i++; + } else { + break; + } + } + + // Vectorized scan for j + while (i <= j) { + if (j - 32 >= i) { + // Load 32 elements ending at j. + // Since we want `values[j] > pivot`, let's load from j-31 to j. + HVX_Vector vals_vec = *(HVX_UVector *)(values + j - 31); + if (all_greater_f32(vals_vec, pivot_vec)) { + j -= 32; + continue; + } + } + + if (values[j] > pivot) { + j--; + } else { + break; + } + } + + if (i <= j) { + float tmp_val = values[i]; + values[i] = values[j]; + values[j] = tmp_val; + + int32_t tmp_idx = indices[i]; + indices[i] = indices[j]; + indices[j] = tmp_idx; + i++; + j--; + } + } + + if (left < j) quicksort_values_indices_asc(values, indices, left, j); + if (i < right) quicksort_values_indices_asc(values, indices, i, right); +} + +static void quicksort_values_indices_desc(float * values, int32_t * indices, int left, int right) { + if (left >= right) return; + + int pivot_idx = (left + right) / 2; + float pivot = values[pivot_idx]; + int i = left; + int j = right; + + HVX_Vector pivot_vec = hvx_vec_splat_f32(pivot); + + while (i <= j) { + // Vectorized scan for i (values[i] > pivot) + while (i <= j) { + if (i + 32 <= j) { + HVX_Vector vals_vec = *(HVX_UVector *)(values + i); + if (all_greater_f32(vals_vec, pivot_vec)) { + i += 32; + continue; + } + } + + if (values[i] > pivot) { + i++; + } else { + break; + } + } + + // Vectorized scan for j (values[j] < pivot) + while (i <= j) { + if (j - 32 >= i) { + HVX_Vector vals_vec = *(HVX_UVector *)(values + j - 31); + if (all_greater_f32(pivot_vec, vals_vec)) { + j -= 32; + continue; + } + } + + if (values[j] < pivot) { + j--; + } else { + break; + } + } + + if (i <= j) { + float tmp_val = values[i]; + values[i] = values[j]; + values[j] = tmp_val; + + int32_t tmp_idx = indices[i]; + indices[i] = indices[j]; + indices[j] = tmp_idx; + i++; + j--; + } + } + + if (left < j) quicksort_values_indices_desc(values, indices, left, j); + if (i < right) quicksort_values_indices_desc(values, indices, i, right); +} + +static void htp_argsort_f32(unsigned int n, unsigned int i, void * data) 
{ + struct htp_argsort_context * actx = (struct htp_argsort_context *)data; + struct htp_ops_context * octx = actx->octx; + + // Unpack context + const struct htp_tensor * src0 = &octx->src0; + const struct htp_tensor * dst = &octx->dst; + + // Scratchpad memory + uint8_t * spad = octx->src0_spad.data + octx->src0_spad.size_per_thread * i; + + // Dimensions + uint32_t ne00 = src0->ne[0]; + uint32_t ne01 = src0->ne[1]; + uint32_t ne02 = src0->ne[2]; + uint32_t ne03 = src0->ne[3]; + + uint32_t nb01 = src0->nb[1]; + //uint32_t nb02 = src0->nb[2]; + //uint32_t nb03 = src0->nb[3]; + + uint32_t nb1 = dst->nb[1]; + //uint32_t nb2 = dst->nb[2]; + //uint32_t nb3 = dst->nb[3]; + + // Sort order + enum ggml_sort_order order = (enum ggml_sort_order) octx->op_params[0]; + + // Rows to process + uint32_t total_rows = ne01 * ne02 * ne03; + uint32_t rows_per_thread = actx->nrows_per_thread; + uint32_t start_row = rows_per_thread * i; + uint32_t end_row = MIN(start_row + rows_per_thread, total_rows); + + // Scratchpad layout: + // We need space for one row of float data (values) and one row of int32 indices. + // values: ne00 * sizeof(float) + // indices: ne00 * sizeof(int32_t) + // Padded to 128 bytes. + + size_t values_size = hex_round_up(ne00 * sizeof(float), 128); + float * values_buf = (float *) spad; + int32_t * indices_buf = (int32_t *) (spad + values_size); + + for (uint32_t r = start_row; r < end_row; r++) { + uint32_t src_offset = r * nb01; + uint32_t dst_offset = r * nb1; + + uint8_t * src_ptr = (uint8_t *) src0->data + src_offset; + uint8_t * dst_ptr = (uint8_t *) dst->data + dst_offset; + + hex_l2fetch(src_ptr, ne00 * sizeof(float), ne00 * sizeof(float), 1); + hvx_copy_f32_au((uint8_t*)values_buf, src_ptr, ne00); + + // Initialize indices + for (uint32_t j = 0; j < ne00; j++) { + indices_buf[j] = j; + } + + // Sort values and mirror swaps to indices + if (order == GGML_SORT_ORDER_ASC) { + quicksort_values_indices_asc(values_buf, indices_buf, 0, ne00 - 1); + } else { + quicksort_values_indices_desc(values_buf, indices_buf, 0, ne00 - 1); + } + + // Copy indices back to DDR + hvx_copy_f32_ua(dst_ptr, (const uint8_t *) indices_buf, ne00); + } +} + +int op_argsort(struct htp_ops_context * octx) { + // Check supported types + if (octx->src0.type != HTP_TYPE_F32) { + return HTP_STATUS_NO_SUPPORT; + } + + // Allocate scratchpad + // We need 1 row of float + 1 row of int32 per thread. + uint32_t ne00 = octx->src0.ne[0]; + size_t values_size = hex_round_up(ne00 * sizeof(float), 128); + size_t indices_size = hex_round_up(ne00 * sizeof(int32_t), 128); + size_t spad_per_thread = values_size + indices_size; + + // Make sure we round up to 256 for alignment requirements + spad_per_thread = hex_round_up(spad_per_thread, 256); + + size_t total_spad_size = spad_per_thread * octx->n_threads; + + if (octx->ctx->vtcm_size < total_spad_size) { + FARF(ERROR, "argsort: VTCM size too small. 
Needed %zu, have %zu", total_spad_size, octx->ctx->vtcm_size); + return HTP_STATUS_VTCM_TOO_SMALL; + } + + octx->src0_spad.data = octx->ctx->vtcm_base; + octx->src0_spad.size = total_spad_size; + octx->src0_spad.size_per_thread = spad_per_thread; + + FARF(HIGH, "argsort: %ux%ux%ux%u -> %ux%ux%ux%u (0x%x, 0x%x)", + octx->src0.ne[0], octx->src0.ne[1], octx->src0.ne[2], octx->src0.ne[3], + octx->dst.ne[0], octx->dst.ne[1], octx->dst.ne[2], octx->dst.ne[3], + octx->src0.data, octx->dst.data); + + uint32_t total_rows = octx->src0.ne[1] * octx->src0.ne[2] * octx->src0.ne[3]; + uint32_t n_jobs = MIN(total_rows, octx->n_threads); + + struct htp_argsort_context actx; + actx.octx = octx; + actx.nrows_per_thread = (total_rows + n_jobs - 1) / n_jobs; + + // Run jobs + worker_pool_run_func(octx->ctx->worker_pool, htp_argsort_f32, &actx, n_jobs); + + return HTP_STATUS_OK; +} diff --git a/llama.cpp/ggml/src/ggml-hexagon/htp/binary-ops.c b/llama.cpp/ggml/src/ggml-hexagon/htp/binary-ops.c new file mode 100644 index 0000000..00dbcf8 --- /dev/null +++ b/llama.cpp/ggml/src/ggml-hexagon/htp/binary-ops.c @@ -0,0 +1,827 @@ +#pragma clang diagnostic ignored "-Wunused-variable" +#pragma clang diagnostic ignored "-Wunused-function" +#pragma clang diagnostic ignored "-Wunused-but-set-variable" + +#include <HAP_farf.h> +#include <HAP_perf.h> + +#include <math.h> +#include <string.h> + +#include "hex-dma.h" +#include "hvx-utils.h" + +#define GGML_COMMON_DECL_C +#include "ggml-common.h" +#include "htp-ctx.h" +#include "htp-msg.h" +#include "htp-ops.h" + +#ifndef MIN +#define MIN(a, b) ((a) < (b) ? (a) : (b)) +#endif + +// Context for binary operations +struct htp_binary_context { + struct htp_ops_context * octx; + struct fastdiv_values dim1_div; + struct fastdiv_values dim2_div; + struct fastdiv_values dim12_div; + + struct fastdiv_values src1_dim1_div; // ne11 + struct fastdiv_values src1_dim2_div; // ne12 + struct fastdiv_values src1_dim3_div; // ne13 + + uint32_t nrows_per_thread; + bool split_at_ne01; + bool split_at_ne02; + + // Precomputed values + uint32_t block_max; + size_t src0_row_size_aligned; + size_t src1_row_size_aligned; + size_t dst_row_size_aligned; + uint32_t src1_fetch_rows; // 1 or block_max + uint32_t src1_dma_stride; // 0 or stride +}; + +#define htp_binary_preamble \ + const struct htp_tensor * src0 = &octx->src0; \ + const struct htp_tensor * src1 = &octx->src1; \ + struct htp_tensor * dst = &octx->dst; \ + \ + const uint32_t ne00 = src0->ne[0]; \ + const uint32_t ne01 = src0->ne[1]; \ + const uint32_t ne02 = src0->ne[2]; \ + const uint32_t ne03 = src0->ne[3]; \ + \ + const uint32_t ne10 = src1->ne[0]; \ + const uint32_t ne11 = src1->ne[1]; \ + const uint32_t ne12 = src1->ne[2]; \ + const uint32_t ne13 = src1->ne[3]; \ + \ + const uint32_t nb01 = src0->nb[1]; \ + const uint32_t nb02 = src0->nb[2]; \ + const uint32_t nb03 = src0->nb[3]; \ + \ + const uint32_t nb11 = src1->nb[1]; \ + const uint32_t nb12 = src1->nb[2]; \ + const uint32_t nb13 = src1->nb[3]; \ + \ + const uint32_t nb1 = dst->nb[1]; \ + const uint32_t nb2 = dst->nb[2]; \ + const uint32_t nb3 = dst->nb[3]; + +static inline uint32_t calc_block_size(struct htp_binary_context * bctx, uint32_t ir, uint32_t end_row, + uint32_t ne01, uint32_t ne02) { + uint32_t i03, i02, i01, rem; + i03 = fastdiv(ir, &bctx->dim12_div); + rem = ir - i03 * (ne02 * ne01); + i02 = fastdiv(rem, &bctx->dim1_div); + i01 = rem - i02 * ne01; + + uint32_t rows_left = end_row - ir; + uint32_t block_limit = rows_left; + + if (bctx->split_at_ne01) { + block_limit = 
MIN(block_limit, ne01 - i01); + } + if (bctx->split_at_ne02) { + uint32_t rows_in_plane = (ne02 * ne01) - rem; + block_limit = MIN(block_limit, rows_in_plane); + } + + return MIN(bctx->block_max, block_limit); +} + +// Macro for scalar op switch +#define COMPUTE_SCALAR_OP(DST, SRC, VAL, N) \ + switch (octx->op) { \ + case HTP_OP_ADD: hvx_add_scalar_f32_aa(DST, SRC, VAL, N); break; \ + case HTP_OP_SUB: hvx_sub_scalar_f32_aa(DST, SRC, VAL, N); break; \ + case HTP_OP_MUL: hvx_mul_scalar_f32_aa(DST, SRC, VAL, N); break; \ + case HTP_OP_DIV: hvx_mul_scalar_f32_aa(DST, SRC, 1.0f / (VAL), N); break; \ + default: break; \ + } + +// Macro for vector op switch (All Aligned) +#define COMPUTE_VECTOR_OP_AAA(DST, SRC0, SRC1, N) \ + switch (octx->op) { \ + case HTP_OP_ADD: hvx_add_f32_aaa(DST, SRC0, SRC1, N); break; \ + case HTP_OP_SUB: hvx_sub_f32_aaa(DST, SRC0, SRC1, N); break; \ + case HTP_OP_MUL: hvx_mul_f32_aaa(DST, SRC0, SRC1, N); break; \ + case HTP_OP_DIV: hvx_div_f32_aaa(DST, SRC0, SRC1, N); break; \ + default: break; \ + } + +// Macro for vector op switch (Dst Aligned, Src0 Aligned, Src1 Unaligned) +#define COMPUTE_VECTOR_OP_AAU(DST, SRC0, SRC1, N) \ + switch (octx->op) { \ + case HTP_OP_ADD: hvx_add_f32_aau(DST, SRC0, SRC1, N); break; \ + case HTP_OP_SUB: hvx_sub_f32_aau(DST, SRC0, SRC1, N); break; \ + case HTP_OP_MUL: hvx_mul_f32_aau(DST, SRC0, SRC1, N); break; \ + case HTP_OP_DIV: hvx_div_f32_aau(DST, SRC0, SRC1, N); break; \ + default: break; \ + } + +// Macro for vector op switch (All Unaligned - generic loop used in element repeat) +#define COMPUTE_VECTOR_OP_UUU(DST, SRC0, SRC1, N) \ + switch (octx->op) { \ + case HTP_OP_ADD: hvx_add_f32_uuu(DST, SRC0, SRC1, N); break; \ + case HTP_OP_SUB: hvx_sub_f32_uuu(DST, SRC0, SRC1, N); break; \ + case HTP_OP_MUL: hvx_mul_f32_uuu(DST, SRC0, SRC1, N); break; \ + case HTP_OP_DIV: hvx_div_f32_uuu(DST, SRC0, SRC1, N); break; \ + default: break; \ + } + +// 1. 
Scalar src1 (ne10 == 1) +static void binary_job_scalar(unsigned int nth, unsigned int ith, void * data) { + struct htp_binary_context * bctx = (struct htp_binary_context *) data; + struct htp_ops_context * octx = bctx->octx; + htp_binary_preamble; + + const uint32_t total_rows = ne01 * ne02 * ne03; + const uint32_t start_row = bctx->nrows_per_thread * ith; + const uint32_t end_row = MIN(start_row + bctx->nrows_per_thread, total_rows); + if (start_row >= end_row) return; + + uint8_t * src0_spad_base = octx->src0_spad.data + (ith * octx->src0_spad.size_per_thread); + uint8_t * dst_spad_base = octx->dst_spad.data + (ith * octx->dst_spad.size_per_thread); + size_t src0_spad_half = octx->src0_spad.size_per_thread / 2; + size_t dst_spad_half = octx->dst_spad.size_per_thread / 2; + + dma_queue * q = octx->ctx->dma[ith]; + uint32_t ir_prefetch = start_row; + int spad_idx = 0; + + // Preamble + for (int k = 0; k < 2 && ir_prefetch < end_row; k++) { + uint32_t current_block_size = calc_block_size(bctx, ir_prefetch, end_row, ne01, ne02); + uint32_t i03, i02, i01, rem; + i03 = fastdiv(ir_prefetch, &bctx->dim12_div); + rem = ir_prefetch - i03 * (ne02 * ne01); + i02 = fastdiv(rem, &bctx->dim1_div); + i01 = rem - i02 * ne01; + + uint8_t * src0_curr = (uint8_t *)src0->data + i03 * nb03 + i02 * nb02 + i01 * nb01; + uint8_t * dst_curr = (uint8_t *)dst->data + i03 * nb3 + i02 * nb2 + i01 * nb1; + + uint8_t * s0_spad = src0_spad_base + spad_idx * src0_spad_half; + uint8_t * d_spad = dst_spad_base + spad_idx * dst_spad_half; + + dma_queue_push_vtcm_to_ddr(q, dma_make_ptr(dst_curr, d_spad), nb1, bctx->dst_row_size_aligned, 0); + dma_queue_push(q, dma_make_ptr(s0_spad, src0_curr), bctx->src0_row_size_aligned, nb01, ne00 * sizeof(float), current_block_size); + ir_prefetch += current_block_size; + spad_idx ^= 1; + } + + // Main loop + for (uint32_t ir = start_row; ir < end_row; ) { + uint32_t current_block_size = calc_block_size(bctx, ir, end_row, ne01, ne02); + + uint8_t * d_spad = (uint8_t *) dma_queue_pop(q).src; + uint8_t * s0_spad = (uint8_t *) dma_queue_pop(q).dst; + + uint32_t i03, i02, i01, rem; + i03 = fastdiv(ir, &bctx->dim12_div); + rem = ir - i03 * (ne02 * ne01); + i02 = fastdiv(rem, &bctx->dim1_div); + i01 = rem - i02 * ne01; + + // src1 indices (broadcast/repeat) + uint32_t i13 = fastmodulo(i03, ne13, &bctx->src1_dim3_div); + uint32_t i12 = fastmodulo(i02, ne12, &bctx->src1_dim2_div); + uint32_t i11 = fastmodulo(i01, ne11, &bctx->src1_dim1_div); + + uint8_t * src1_ptr = (uint8_t *)src1->data + i13 * nb13 + i12 * nb12 + i11 * nb11; + uint32_t s1_stride = (ne11 == 1) ? 
0 : nb11; + + for (uint32_t r = 0; r < current_block_size; r++) { + uint8_t * r_src0 = s0_spad + r * bctx->src0_row_size_aligned; + uint8_t * r_dst = d_spad + r * bctx->dst_row_size_aligned; + float val = *(float *)src1_ptr; + src1_ptr += s1_stride; + COMPUTE_SCALAR_OP(r_dst, r_src0, val, ne00); + } + + uint8_t * dst_curr = (uint8_t *)dst->data + i03 * nb3 + i02 * nb2 + i01 * nb1; + dma_queue_push(q, dma_make_ptr(dst_curr, d_spad), nb1, bctx->dst_row_size_aligned, ne00 * sizeof(float), current_block_size); + + if (ir_prefetch < end_row) { + uint32_t next_block_size = calc_block_size(bctx, ir_prefetch, end_row, ne01, ne02); + uint32_t p03, p02, p01, prem; + p03 = fastdiv(ir_prefetch, &bctx->dim12_div); + prem = ir_prefetch - p03 * (ne02 * ne01); + p02 = fastdiv(prem, &bctx->dim1_div); + p01 = prem - p02 * ne01; + uint8_t * s0_next = (uint8_t *)src0->data + p03 * nb03 + p02 * nb02 + p01 * nb01; + + dma_queue_push(q, dma_make_ptr(s0_spad, s0_next), bctx->src0_row_size_aligned, nb01, ne00 * sizeof(float), next_block_size); + ir_prefetch += next_block_size; + } + ir += current_block_size; + } + dma_queue_flush(q); +} + +// 2. Vector Same Shape (ne1x == ne0x) or Simple Broadcast +static void binary_job_vector_same_shape(unsigned int nth, unsigned int ith, void * data) { + struct htp_binary_context * bctx = (struct htp_binary_context *) data; + struct htp_ops_context * octx = bctx->octx; + htp_binary_preamble; + + const uint32_t total_rows = ne01 * ne02 * ne03; + const uint32_t start_row = bctx->nrows_per_thread * ith; + const uint32_t end_row = MIN(start_row + bctx->nrows_per_thread, total_rows); + if (start_row >= end_row) return; + + uint8_t * src0_spad_base = octx->src0_spad.data + (ith * octx->src0_spad.size_per_thread); + uint8_t * src1_spad_base = octx->src1_spad.data + (ith * octx->src1_spad.size_per_thread); + uint8_t * dst_spad_base = octx->dst_spad.data + (ith * octx->dst_spad.size_per_thread); + + size_t src0_spad_half = octx->src0_spad.size_per_thread / 2; + size_t src1_spad_half = octx->src1_spad.size_per_thread / 2; + size_t dst_spad_half = octx->dst_spad.size_per_thread / 2; + + dma_queue * q = octx->ctx->dma[ith]; + uint32_t ir_prefetch = start_row; + int spad_idx = 0; + + for (int k = 0; k < 2 && ir_prefetch < end_row; k++) { + uint32_t current_block_size = calc_block_size(bctx, ir_prefetch, end_row, ne01, ne02); + uint32_t i03, i02, i01, rem; + i03 = fastdiv(ir_prefetch, &bctx->dim12_div); + rem = ir_prefetch - i03 * (ne02 * ne01); + i02 = fastdiv(rem, &bctx->dim1_div); + i01 = rem - i02 * ne01; + + uint32_t i13 = (ne13 == 1) ? 0 : i03; + uint32_t i12 = (ne12 == 1) ? 0 : i02; + uint32_t i11 = (ne11 == 1) ? 
0 : i01; + + uint8_t * src0_curr = (uint8_t *)src0->data + i03 * nb03 + i02 * nb02 + i01 * nb01; + uint8_t * src1_base = (uint8_t *)src1->data + i13 * nb13 + i12 * nb12 + i11 * nb11; + uint8_t * dst_curr = (uint8_t *)dst->data + i03 * nb3 + i02 * nb2 + i01 * nb1; + + uint8_t * s0_spad = src0_spad_base + spad_idx * src0_spad_half; + uint8_t * s1_spad = src1_spad_base + spad_idx * src1_spad_half; + uint8_t * d_spad = dst_spad_base + spad_idx * dst_spad_half; + + dma_queue_push_vtcm_to_ddr(q, dma_make_ptr(dst_curr, d_spad), nb1, bctx->dst_row_size_aligned, 0); + dma_queue_push(q, dma_make_ptr(s0_spad, src0_curr), bctx->src0_row_size_aligned, nb01, ne00 * sizeof(float), current_block_size); + dma_queue_push(q, dma_make_ptr(s1_spad, src1_base), bctx->src1_row_size_aligned, bctx->src1_dma_stride, ne00 * sizeof(float), current_block_size); + ir_prefetch += current_block_size; + spad_idx ^= 1; + } + + for (uint32_t ir = start_row; ir < end_row; ) { + uint32_t current_block_size = calc_block_size(bctx, ir, end_row, ne01, ne02); + uint8_t * d_spad = (uint8_t *) dma_queue_pop(q).src; + uint8_t * s0_spad = (uint8_t *) dma_queue_pop(q).dst; + uint8_t * s1_spad = (uint8_t *) dma_queue_pop(q).dst; + + for (uint32_t r = 0; r < current_block_size; r++) { + uint8_t * r_src0 = s0_spad + r * bctx->src0_row_size_aligned; + uint8_t * r_src1 = s1_spad + r * bctx->src1_row_size_aligned; + uint8_t * r_dst = d_spad + r * bctx->dst_row_size_aligned; + COMPUTE_VECTOR_OP_AAA(r_dst, r_src0, r_src1, ne00); + } + + uint32_t i03, i02, i01, rem; + i03 = fastdiv(ir, &bctx->dim12_div); + rem = ir - i03 * (ne02 * ne01); + i02 = fastdiv(rem, &bctx->dim1_div); + i01 = rem - i02 * ne01; + uint8_t * dst_curr = (uint8_t *)dst->data + i03 * nb3 + i02 * nb2 + i01 * nb1; + dma_queue_push(q, dma_make_ptr(dst_curr, d_spad), nb1, bctx->dst_row_size_aligned, ne00 * sizeof(float), current_block_size); + + if (ir_prefetch < end_row) { + uint32_t next_block_size = calc_block_size(bctx, ir_prefetch, end_row, ne01, ne02); + uint32_t p03, p02, p01, prem; + p03 = fastdiv(ir_prefetch, &bctx->dim12_div); + prem = ir_prefetch - p03 * (ne02 * ne01); + p02 = fastdiv(prem, &bctx->dim1_div); + p01 = prem - p02 * ne01; + + uint32_t p13 = (ne13 == 1) ? 0 : p03; + uint32_t p12 = (ne12 == 1) ? 0 : p02; + uint32_t p11 = (ne11 == 1) ? 0 : p01; + + uint8_t * s0_next = (uint8_t *)src0->data + p03 * nb03 + p02 * nb02 + p01 * nb01; + uint8_t * s1_next = (uint8_t *)src1->data + p13 * nb13 + p12 * nb12 + p11 * nb11; + + dma_queue_push(q, dma_make_ptr(s0_spad, s0_next), bctx->src0_row_size_aligned, nb01, ne00 * sizeof(float), next_block_size); + dma_queue_push(q, dma_make_ptr(s1_spad, s1_next), bctx->src1_row_size_aligned, bctx->src1_dma_stride, ne00 * sizeof(float), next_block_size); + + ir_prefetch += next_block_size; + } + ir += current_block_size; + } + dma_queue_flush(q); +} + +// 3. 
Row Broadcast (ne11 == 1, ne12 == 1, single row src1) +static void binary_job_vector_row_broadcast(unsigned int nth, unsigned int ith, void * data) { + struct htp_binary_context * bctx = (struct htp_binary_context *) data; + struct htp_ops_context * octx = bctx->octx; + htp_binary_preamble; + + const uint32_t total_rows = ne01 * ne02 * ne03; + const uint32_t start_row = bctx->nrows_per_thread * ith; + const uint32_t end_row = MIN(start_row + bctx->nrows_per_thread, total_rows); + if (start_row >= end_row) return; + + uint8_t * src0_spad_base = octx->src0_spad.data + (ith * octx->src0_spad.size_per_thread); + uint8_t * src1_spad = octx->src1_spad.data + (ith * octx->src1_spad.size_per_thread); + uint8_t * dst_spad_base = octx->dst_spad.data + (ith * octx->dst_spad.size_per_thread); + + size_t src0_spad_half = octx->src0_spad.size_per_thread / 2; + size_t dst_spad_half = octx->dst_spad.size_per_thread / 2; + + dma_queue * q = octx->ctx->dma[ith]; + uint32_t ir_prefetch = start_row; + int spad_idx = 0; + + void * s1_ptr = (void *) src1_spad; + + for (int k = 0; k < 2 && ir_prefetch < end_row; k++) { + uint32_t current_block_size = calc_block_size(bctx, ir_prefetch, end_row, ne01, ne02); + uint32_t i03, i02, i01, rem; + i03 = fastdiv(ir_prefetch, &bctx->dim12_div); + rem = ir_prefetch - i03 * (ne02 * ne01); + i02 = fastdiv(rem, &bctx->dim1_div); + i01 = rem - i02 * ne01; + + uint8_t * src0_curr = (uint8_t *)src0->data + i03 * nb03 + i02 * nb02 + i01 * nb01; + uint8_t * dst_curr = (uint8_t *)dst->data + i03 * nb3 + i02 * nb2 + i01 * nb1; + + uint8_t * s0_spad = src0_spad_base + spad_idx * src0_spad_half; + uint8_t * d_spad = dst_spad_base + spad_idx * dst_spad_half; + + dma_queue_push_vtcm_to_ddr(q, dma_make_ptr(dst_curr, d_spad), nb1, bctx->dst_row_size_aligned, 0); + dma_queue_push(q, dma_make_ptr(s0_spad, src0_curr), bctx->src0_row_size_aligned, nb01, ne00 * sizeof(float), current_block_size); + ir_prefetch += current_block_size; + spad_idx ^= 1; + } + + for (uint32_t ir = start_row; ir < end_row; ) { + uint32_t current_block_size = calc_block_size(bctx, ir, end_row, ne01, ne02); + uint8_t * d_spad = (uint8_t *) dma_queue_pop(q).src; + uint8_t * s0_spad = (uint8_t *) dma_queue_pop(q).dst; + + for (uint32_t r = 0; r < current_block_size; r++) { + uint8_t * r_src0 = s0_spad + r * bctx->src0_row_size_aligned; + uint8_t * r_src1 = (uint8_t *)s1_ptr; // Constant + uint8_t * r_dst = d_spad + r * bctx->dst_row_size_aligned; + COMPUTE_VECTOR_OP_AAA(r_dst, r_src0, r_src1, ne00); + } + + uint32_t i03, i02, i01, rem; + i03 = fastdiv(ir, &bctx->dim12_div); + rem = ir - i03 * (ne02 * ne01); + i02 = fastdiv(rem, &bctx->dim1_div); + i01 = rem - i02 * ne01; + uint8_t * dst_curr = (uint8_t *)dst->data + i03 * nb3 + i02 * nb2 + i01 * nb1; + dma_queue_push(q, dma_make_ptr(dst_curr, d_spad), nb1, bctx->dst_row_size_aligned, ne00 * sizeof(float), current_block_size); + + if (ir_prefetch < end_row) { + uint32_t next_block_size = calc_block_size(bctx, ir_prefetch, end_row, ne01, ne02); + uint32_t p03, p02, p01, prem; + p03 = fastdiv(ir_prefetch, &bctx->dim12_div); + prem = ir_prefetch - p03 * (ne02 * ne01); + p02 = fastdiv(prem, &bctx->dim1_div); + p01 = prem - p02 * ne01; + uint8_t * s0_next = (uint8_t *)src0->data + p03 * nb03 + p02 * nb02 + p01 * nb01; + dma_queue_push(q, dma_make_ptr(s0_spad, s0_next), bctx->src0_row_size_aligned, nb01, ne00 * sizeof(float), next_block_size); + ir_prefetch += next_block_size; + } + ir += current_block_size; + } + dma_queue_flush(q); +} + +// 4. 
Vector Complex (ne10 == ne00, complex broadcast) +static void binary_job_vector_complex(unsigned int nth, unsigned int ith, void * data) { + struct htp_binary_context * bctx = (struct htp_binary_context *) data; + struct htp_ops_context * octx = bctx->octx; + htp_binary_preamble; + + const uint32_t total_rows = ne01 * ne02 * ne03; + const uint32_t start_row = bctx->nrows_per_thread * ith; + const uint32_t end_row = MIN(start_row + bctx->nrows_per_thread, total_rows); + if (start_row >= end_row) return; + + uint8_t * src0_spad_base = octx->src0_spad.data + (ith * octx->src0_spad.size_per_thread); + uint8_t * dst_spad_base = octx->dst_spad.data + (ith * octx->dst_spad.size_per_thread); + size_t src0_spad_half = octx->src0_spad.size_per_thread / 2; + size_t dst_spad_half = octx->dst_spad.size_per_thread / 2; + + dma_queue * q = octx->ctx->dma[ith]; + uint32_t ir_prefetch = start_row; + int spad_idx = 0; + + for (int k = 0; k < 2 && ir_prefetch < end_row; k++) { + uint32_t current_block_size = calc_block_size(bctx, ir_prefetch, end_row, ne01, ne02); + uint32_t i03, i02, i01, rem; + i03 = fastdiv(ir_prefetch, &bctx->dim12_div); + rem = ir_prefetch - i03 * (ne02 * ne01); + i02 = fastdiv(rem, &bctx->dim1_div); + i01 = rem - i02 * ne01; + + uint8_t * src0_curr = (uint8_t *)src0->data + i03 * nb03 + i02 * nb02 + i01 * nb01; + uint8_t * dst_curr = (uint8_t *)dst->data + i03 * nb3 + i02 * nb2 + i01 * nb1; + + uint8_t * s0_spad = src0_spad_base + spad_idx * src0_spad_half; + uint8_t * d_spad = dst_spad_base + spad_idx * dst_spad_half; + + dma_queue_push_vtcm_to_ddr(q, dma_make_ptr(dst_curr, d_spad), nb1, bctx->dst_row_size_aligned, 0); + dma_queue_push(q, dma_make_ptr(s0_spad, src0_curr), bctx->src0_row_size_aligned, nb01, ne00 * sizeof(float), current_block_size); + ir_prefetch += current_block_size; + spad_idx ^= 1; + } + + for (uint32_t ir = start_row; ir < end_row; ) { + uint32_t current_block_size = calc_block_size(bctx, ir, end_row, ne01, ne02); + uint8_t * d_spad = (uint8_t *) dma_queue_pop(q).src; + uint8_t * s0_spad = (uint8_t *) dma_queue_pop(q).dst; + + uint32_t i03, i02, i01, rem; + i03 = fastdiv(ir, &bctx->dim12_div); + rem = ir - i03 * (ne02 * ne01); + i02 = fastdiv(rem, &bctx->dim1_div); + i01 = rem - i02 * ne01; + + for (uint32_t r = 0; r < current_block_size; r++) { + uint32_t r_i01 = i01 + r; + uint32_t i13 = fastmodulo(i03, ne13, &bctx->src1_dim3_div); + uint32_t i12 = fastmodulo(i02, ne12, &bctx->src1_dim2_div); + uint32_t i11 = fastmodulo(r_i01, ne11, &bctx->src1_dim1_div); + + uint8_t * r_src0 = s0_spad + r * bctx->src0_row_size_aligned; + uint8_t * r_src1 = (uint8_t *)src1->data + i13 * nb13 + i12 * nb12 + i11 * nb11; + uint8_t * r_dst = d_spad + r * bctx->dst_row_size_aligned; + + // Read src1 from DDR (unaligned) + COMPUTE_VECTOR_OP_AAU(r_dst, r_src0, r_src1, ne00); + } + + uint8_t * dst_curr = (uint8_t *)dst->data + i03 * nb3 + i02 * nb2 + i01 * nb1; + dma_queue_push(q, dma_make_ptr(dst_curr, d_spad), nb1, bctx->dst_row_size_aligned, ne00 * sizeof(float), current_block_size); + + if (ir_prefetch < end_row) { + uint32_t next_block_size = calc_block_size(bctx, ir_prefetch, end_row, ne01, ne02); + uint32_t p03, p02, p01, prem; + p03 = fastdiv(ir_prefetch, &bctx->dim12_div); + prem = ir_prefetch - p03 * (ne02 * ne01); + p02 = fastdiv(prem, &bctx->dim1_div); + p01 = prem - p02 * ne01; + uint8_t * s0_next = (uint8_t *)src0->data + p03 * nb03 + p02 * nb02 + p01 * nb01; + dma_queue_push(q, dma_make_ptr(s0_spad, s0_next), bctx->src0_row_size_aligned, nb01, ne00 * sizeof(float), 
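/* nrows: rows refilled into this double-buffer half */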
next_block_size); + ir_prefetch += next_block_size; + } + ir += current_block_size; + } + dma_queue_flush(q); +} + +// 5. Element Repeat (ne10 != ne00) +static void binary_job_element_repeat(unsigned int nth, unsigned int ith, void * data) { + struct htp_binary_context * bctx = (struct htp_binary_context *) data; + struct htp_ops_context * octx = bctx->octx; + htp_binary_preamble; + + const uint32_t total_rows = ne01 * ne02 * ne03; + const uint32_t start_row = bctx->nrows_per_thread * ith; + const uint32_t end_row = MIN(start_row + bctx->nrows_per_thread, total_rows); + if (start_row >= end_row) return; + + uint8_t * src0_spad_base = octx->src0_spad.data + (ith * octx->src0_spad.size_per_thread); + uint8_t * dst_spad_base = octx->dst_spad.data + (ith * octx->dst_spad.size_per_thread); + size_t src0_spad_half = octx->src0_spad.size_per_thread / 2; + size_t dst_spad_half = octx->dst_spad.size_per_thread / 2; + + dma_queue * q = octx->ctx->dma[ith]; + uint32_t ir_prefetch = start_row; + int spad_idx = 0; + + for (int k = 0; k < 2 && ir_prefetch < end_row; k++) { + uint32_t current_block_size = calc_block_size(bctx, ir_prefetch, end_row, ne01, ne02); + uint32_t i03, i02, i01, rem; + i03 = fastdiv(ir_prefetch, &bctx->dim12_div); + rem = ir_prefetch - i03 * (ne02 * ne01); + i02 = fastdiv(rem, &bctx->dim1_div); + i01 = rem - i02 * ne01; + + uint8_t * src0_curr = (uint8_t *)src0->data + i03 * nb03 + i02 * nb02 + i01 * nb01; + uint8_t * dst_curr = (uint8_t *)dst->data + i03 * nb3 + i02 * nb2 + i01 * nb1; + + uint8_t * s0_spad = src0_spad_base + spad_idx * src0_spad_half; + uint8_t * d_spad = dst_spad_base + spad_idx * dst_spad_half; + + dma_queue_push_vtcm_to_ddr(q, dma_make_ptr(dst_curr, d_spad), nb1, bctx->dst_row_size_aligned, 0); + dma_queue_push(q, dma_make_ptr(s0_spad, src0_curr), bctx->src0_row_size_aligned, nb01, ne00 * sizeof(float), current_block_size); + ir_prefetch += current_block_size; + spad_idx ^= 1; + } + + for (uint32_t ir = start_row; ir < end_row; ) { + uint32_t current_block_size = calc_block_size(bctx, ir, end_row, ne01, ne02); + uint8_t * d_spad = (uint8_t *) dma_queue_pop(q).src; + uint8_t * s0_spad = (uint8_t *) dma_queue_pop(q).dst; + + uint32_t i03, i02, i01, rem; + i03 = fastdiv(ir, &bctx->dim12_div); + rem = ir - i03 * (ne02 * ne01); + i02 = fastdiv(rem, &bctx->dim1_div); + i01 = rem - i02 * ne01; + + for (uint32_t r = 0; r < current_block_size; r++) { + uint32_t r_i01 = i01 + r; + uint32_t i13 = fastmodulo(i03, ne13, &bctx->src1_dim3_div); + uint32_t i12 = fastmodulo(i02, ne12, &bctx->src1_dim2_div); + uint32_t i11 = fastmodulo(r_i01, ne11, &bctx->src1_dim1_div); + + uint8_t * r_src0 = s0_spad + r * bctx->src0_row_size_aligned; + uint8_t * r_src1_row = (uint8_t *)src1->data + i13 * nb13 + i12 * nb12 + i11 * nb11; + uint8_t * r_dst = d_spad + r * bctx->dst_row_size_aligned; + + // Repeat src1 row + for (uint32_t c = 0; c < ne00; c += ne10) { + uint32_t len = MIN(ne10, ne00 - c); + // Use UUU for speed and simplicity + COMPUTE_VECTOR_OP_UUU(r_dst + c * sizeof(float), r_src0 + c * sizeof(float), r_src1_row, len); + } + } + + uint8_t * dst_curr = (uint8_t *)dst->data + i03 * nb3 + i02 * nb2 + i01 * nb1; + dma_queue_push(q, dma_make_ptr(dst_curr, d_spad), nb1, bctx->dst_row_size_aligned, ne00 * sizeof(float), current_block_size); + + if (ir_prefetch < end_row) { + uint32_t next_block_size = calc_block_size(bctx, ir_prefetch, end_row, ne01, ne02); + uint32_t p03, p02, p01, prem; + p03 = fastdiv(ir_prefetch, &bctx->dim12_div); + prem = ir_prefetch - p03 * (ne02 * ne01); + p02 
= fastdiv(prem, &bctx->dim1_div); + p01 = prem - p02 * ne01; + uint8_t * s0_next = (uint8_t *)src0->data + p03 * nb03 + p02 * nb02 + p01 * nb01; + dma_queue_push(q, dma_make_ptr(s0_spad, s0_next), bctx->src0_row_size_aligned, nb01, ne00 * sizeof(float), next_block_size); + ir_prefetch += next_block_size; + } + ir += current_block_size; + } + dma_queue_flush(q); +} + +// 6. ADD_ID (src1 gathered via src2 indices) +static void binary_job_add_id(unsigned int nth, unsigned int ith, void * data) { + struct htp_binary_context * bctx = (struct htp_binary_context *) data; + struct htp_ops_context * octx = bctx->octx; + + const struct htp_tensor * src0 = &octx->src0; + const struct htp_tensor * src1 = &octx->src1; + const struct htp_tensor * src2 = &octx->src2; + struct htp_tensor * dst = &octx->dst; + + const uint32_t ne00 = src0->ne[0]; + const uint32_t ne01 = src0->ne[1]; + const uint32_t ne02 = src0->ne[2]; + const uint32_t ne03 = src0->ne[3]; + const uint32_t ne11 = src1->ne[1]; // for bounds check + + const uint32_t nb01 = src0->nb[1]; + const uint32_t nb02 = src0->nb[2]; + const uint32_t nb03 = src0->nb[3]; + const uint32_t nb11 = src1->nb[1]; // src1 row stride + const uint32_t nb1 = dst->nb[1]; + const uint32_t nb2 = dst->nb[2]; + const uint32_t nb3 = dst->nb[3]; + + const uint32_t total_rows = ne01 * ne02 * ne03; + const uint32_t start_row = bctx->nrows_per_thread * ith; + const uint32_t end_row = MIN(start_row + bctx->nrows_per_thread, total_rows); + if (start_row >= end_row) return; + + uint8_t * src0_spad_base = octx->src0_spad.data + (ith * octx->src0_spad.size_per_thread); + uint8_t * dst_spad_base = octx->dst_spad.data + (ith * octx->dst_spad.size_per_thread); + size_t src0_spad_half = octx->src0_spad.size_per_thread / 2; + size_t dst_spad_half = octx->dst_spad.size_per_thread / 2; + + dma_queue * q = octx->ctx->dma[ith]; + uint32_t ir_prefetch = start_row; + int spad_idx = 0; + + for (int k = 0; k < 2 && ir_prefetch < end_row; k++) { + uint32_t current_block_size = calc_block_size(bctx, ir_prefetch, end_row, ne01, ne02); + uint32_t i03, i02, i01, rem; + i03 = fastdiv(ir_prefetch, &bctx->dim12_div); + rem = ir_prefetch - i03 * (ne02 * ne01); + i02 = fastdiv(rem, &bctx->dim1_div); + i01 = rem - i02 * ne01; + + uint8_t * src0_curr = (uint8_t *)src0->data + i03 * nb03 + i02 * nb02 + i01 * nb01; + uint8_t * dst_curr = (uint8_t *)dst->data + i03 * nb3 + i02 * nb2 + i01 * nb1; + + uint8_t * s0_spad = src0_spad_base + spad_idx * src0_spad_half; + uint8_t * d_spad = dst_spad_base + spad_idx * dst_spad_half; + + dma_queue_push_vtcm_to_ddr(q, dma_make_ptr(dst_curr, d_spad), nb1, bctx->dst_row_size_aligned, 0); + dma_queue_push(q, dma_make_ptr(s0_spad, src0_curr), bctx->src0_row_size_aligned, nb01, ne00 * sizeof(float), current_block_size); + ir_prefetch += current_block_size; + spad_idx ^= 1; + } + + for (uint32_t ir = start_row; ir < end_row; ) { + uint32_t current_block_size = calc_block_size(bctx, ir, end_row, ne01, ne02); + uint8_t * d_spad = (uint8_t *) dma_queue_pop(q).src; + uint8_t * s0_spad = (uint8_t *) dma_queue_pop(q).dst; + + uint32_t i03, i02, i01, rem; + i03 = fastdiv(ir, &bctx->dim12_div); + rem = ir - i03 * (ne02 * ne01); + i02 = fastdiv(rem, &bctx->dim1_div); + i01 = rem - i02 * ne01; + + for (uint32_t r = 0; r < current_block_size; r++) { + uint32_t r_i01 = i01 + r; // linear within block since we split at ne01 + + const int32_t idx = *(int32_t *)((char *)src2->data + r_i01 * src2->nb[0] + i02 * src2->nb[1]); + + uint8_t * r_src1 = (uint8_t *)src1->data + idx * nb11; + 
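// ADD_ID adds a per-expert bias row: src2[r_i01, i02] holds the expert index,
// and r_src1 points at that row of src1 in DDR. A scalar sketch of the step
// performed below (illustration only, assuming contiguous f32 rows):
//
//   for (uint32_t c = 0; c < ne00; c++)
//       ((float *) r_dst)[c] = ((const float *) r_src0)[c] + ((const float *) r_src1)[c];
//
// The gathered row stays in DDR; the "aau" suffix of hvx_add_f32_aau presumably
// marks it as the unaligned operand, so it needs no VTCM staging.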
uint8_t * r_src0 = s0_spad + r * bctx->src0_row_size_aligned; + uint8_t * r_dst = d_spad + r * bctx->dst_row_size_aligned; + + hvx_add_f32_aau(r_dst, r_src0, r_src1, ne00); + } + + uint8_t * dst_curr = (uint8_t *)dst->data + i03 * nb3 + i02 * nb2 + i01 * nb1; + dma_queue_push(q, dma_make_ptr(dst_curr, d_spad), nb1, bctx->dst_row_size_aligned, ne00 * sizeof(float), current_block_size); + + if (ir_prefetch < end_row) { + uint32_t next_block_size = calc_block_size(bctx, ir_prefetch, end_row, ne01, ne02); + uint32_t p03, p02, p01, prem; + p03 = fastdiv(ir_prefetch, &bctx->dim12_div); + prem = ir_prefetch - p03 * (ne02 * ne01); + p02 = fastdiv(prem, &bctx->dim1_div); + p01 = prem - p02 * ne01; + uint8_t * s0_next = (uint8_t *)src0->data + p03 * nb03 + p02 * nb02 + p01 * nb01; + dma_queue_push(q, dma_make_ptr(s0_spad, s0_next), bctx->src0_row_size_aligned, nb01, ne00 * sizeof(float), next_block_size); + ir_prefetch += next_block_size; + } + ir += current_block_size; + } + dma_queue_flush(q); +} + +static int execute_op_binary_f32(struct htp_ops_context * octx) { + const struct htp_tensor * src0 = &octx->src0; + const struct htp_tensor * src1 = &octx->src1; + struct htp_tensor * dst = &octx->dst; + + const uint32_t n_threads = octx->n_threads; + const uint32_t src0_nrows = src0->ne[1] * src0->ne[2] * src0->ne[3]; + + // Use packed row sizes for VTCM allocation + const size_t src0_row_size = src0->ne[0] * sizeof(float); + const size_t src1_row_size = src1->ne[0] * sizeof(float); + const size_t dst_row_size = dst->ne[0] * sizeof(float); + + // Align to VLEN + const size_t src0_row_size_aligned = hex_round_up(src0_row_size, VLEN); + const size_t dst_row_size_aligned = hex_round_up(dst_row_size, VLEN); + size_t src1_row_size_aligned = hex_round_up(src1_row_size, VLEN); + + bool is_add_id = (octx->op == HTP_OP_ADD_ID); + bool is_scalar = !is_add_id && (src1->ne[0] == 1); + + // Determine which kernel we will use to alloc memory and dispatch + bool use_vector_same = !is_add_id && !is_scalar && src1->ne[0] == src0->ne[0] && + (src1->ne[1] == src0->ne[1] || src1->ne[1] == 1) && + (src1->ne[2] == src0->ne[2] || src1->ne[2] == 1) && + (src1->ne[3] == src0->ne[3] || src1->ne[3] == 1); + + bool is_row_bcast = use_vector_same && (src1->ne[1] == 1 && src1->ne[2] == 1 && src1->ne[3] == 1); + bool use_complex = !is_add_id && !is_scalar && !use_vector_same && (src1->ne[0] == src0->ne[0]); + bool use_repeat = !is_add_id && !is_scalar && !use_vector_same && (src1->ne[0] != src0->ne[0]); + + size_t spad_row_total; + if (is_scalar) { + spad_row_total = 2 * (src0_row_size_aligned + dst_row_size_aligned); + } else if (is_row_bcast) { + spad_row_total = 2 * (src0_row_size_aligned + dst_row_size_aligned); + } else if (use_vector_same) { + spad_row_total = 2 * (src0_row_size_aligned + src1_row_size_aligned + dst_row_size_aligned); + } else if (is_add_id) { + spad_row_total = 2 * (src0_row_size_aligned + dst_row_size_aligned); // src1 read directly + } else { + spad_row_total = 2 * (src0_row_size_aligned + dst_row_size_aligned); + } + + size_t rows_per_buffer = octx->ctx->vtcm_size / (n_threads * spad_row_total); + // Adjust for static src1 in row_bcast case + if (is_row_bcast) { + size_t needed_static = src1_row_size_aligned; + if (octx->ctx->vtcm_size < needed_static) return HTP_STATUS_VTCM_TOO_SMALL; + size_t avail = octx->ctx->vtcm_size - needed_static; + rows_per_buffer = avail / (n_threads * spad_row_total); + } + + if (rows_per_buffer < 1) { + FARF(ERROR, "binary-f32: VTCM too small\n"); + return 
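/* not even one row fits per double-buffer half */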
HTP_STATUS_VTCM_TOO_SMALL; + } + + octx->src0_spad.size_per_thread = rows_per_buffer * 2 * src0_row_size_aligned; + octx->dst_spad.size_per_thread = rows_per_buffer * 2 * dst_row_size_aligned; + + if (is_scalar || use_complex || use_repeat || is_add_id) { + octx->src1_spad.size_per_thread = 0; + } else if (is_row_bcast) { + octx->src1_spad.size_per_thread = 0; + } else { + octx->src1_spad.size_per_thread = rows_per_buffer * 2 * src1_row_size_aligned; + } + + octx->src0_spad.size = n_threads * octx->src0_spad.size_per_thread; + if (is_row_bcast) { + octx->src1_spad.size = src1_row_size_aligned; + } else { + octx->src1_spad.size = n_threads * octx->src1_spad.size_per_thread; + } + octx->dst_spad.size = n_threads * octx->dst_spad.size_per_thread; + + if (octx->ctx->vtcm_size < (octx->src0_spad.size + octx->src1_spad.size + octx->dst_spad.size)) { + return HTP_STATUS_VTCM_TOO_SMALL; + } + + octx->src0_spad.data = octx->ctx->vtcm_base; + octx->src1_spad.data = octx->src0_spad.data + octx->src0_spad.size; + octx->dst_spad.data = octx->src1_spad.data + octx->src1_spad.size; + + if ((octx->flags & HTP_OPFLAGS_SKIP_COMPUTE)) { + return HTP_STATUS_OK; + } + + uint32_t n_jobs = MIN(n_threads, src0_nrows); + + dma_queue * q = octx->ctx->dma[0]; + if (is_row_bcast) { + dma_queue_push(q, dma_make_ptr(octx->src1_spad.data, (const void *) src1->data), src1_row_size_aligned, 0, src1->ne[0] * sizeof(float), 1); + } + + struct htp_binary_context bctx; + bctx.octx = octx; + bctx.nrows_per_thread = (src0_nrows + n_jobs - 1) / n_jobs; + bctx.block_max = rows_per_buffer; + bctx.src0_row_size_aligned = src0_row_size_aligned; + bctx.src1_row_size_aligned = src1_row_size_aligned; + bctx.dst_row_size_aligned = dst_row_size_aligned; + + bctx.dim1_div = init_fastdiv_values(src0->ne[1]); + bctx.dim2_div = init_fastdiv_values(src0->ne[2]); + bctx.dim12_div = init_fastdiv_values(src0->ne[1] * src0->ne[2]); + + bctx.src1_dim1_div = init_fastdiv_values(src1->ne[1]); + bctx.src1_dim2_div = init_fastdiv_values(src1->ne[2]); + bctx.src1_dim3_div = init_fastdiv_values(src1->ne[3]); + + bool src0_contig_dim1 = (src0->nb[2] == src0->ne[1] * src0->nb[1]); + bool dst_contig_dim1 = (dst->nb[2] == src0->ne[1] * dst->nb[1]); + + bool src0_contig_dim2 = (src0->nb[3] == src0->ne[2] * src0->nb[2]); + bool dst_contig_dim2 = (dst->nb[3] == src0->ne[2] * dst->nb[2]); + + bctx.split_at_ne01 = (src0->ne[2] > 1) && + ((src1->ne[1] > 1) || (src1->ne[2] > 1) || !src0_contig_dim1 || !dst_contig_dim1); + + bctx.split_at_ne02 = (src0->ne[3] > 1) && + ((src1->ne[2] > 1) || (src1->ne[3] > 1) || !src0_contig_dim2 || !dst_contig_dim2); + + // Precompute specific kernel parameters + if (use_vector_same) { + bctx.src1_dma_stride = (src1->ne[1] == 1) ? 0 : src1->nb[1]; + bctx.src1_fetch_rows = (src1->ne[1] == 1) ? 
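/* broadcast along dim1: only one distinct src1 row */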
1 : rows_per_buffer; + } + + worker_callback_t worker_func; + if (is_add_id) worker_func = binary_job_add_id; + else if (is_scalar) worker_func = binary_job_scalar; + else if (is_row_bcast) worker_func = binary_job_vector_row_broadcast; + else if (use_vector_same) worker_func = binary_job_vector_same_shape; + else if (use_complex) worker_func = binary_job_vector_complex; + else worker_func = binary_job_element_repeat; + + if (is_row_bcast) { + dma_queue_pop(q); + } + + worker_pool_run_func(octx->ctx->worker_pool, worker_func, &bctx, n_jobs); + + return HTP_STATUS_OK; +} + +int op_binary(struct htp_ops_context * octx) { + if (octx->src0.type == HTP_TYPE_F32) { + return execute_op_binary_f32(octx); + } + return HTP_STATUS_NO_SUPPORT; +} diff --git a/llama.cpp/ggml/src/ggml-hexagon/htp/cmake-toolchain.cmake b/llama.cpp/ggml/src/ggml-hexagon/htp/cmake-toolchain.cmake new file mode 100644 index 0000000..7fa236e --- /dev/null +++ b/llama.cpp/ggml/src/ggml-hexagon/htp/cmake-toolchain.cmake @@ -0,0 +1,157 @@ +if (HEXAGON_TOOLCHAIN_INCLUDED) + return() +endif() +set(HEXAGON_TOOLCHAIN_INCLUDED true) + +#Cross Compiling for Hexagon +set(HEXAGON TRUE) +set(CMAKE_SYSTEM_NAME QURT) +set(CMAKE_SYSTEM_PROCESSOR Hexagon) +set(CMAKE_SYSTEM_VERSION "1") #${HEXAGON_PLATFORM_LEVEL}) +set(CMAKE_FIND_ROOT_PATH_MODE_PROGRAM NEVER) +set(CMAKE_FIND_ROOT_PATH_MODE_LIBRARY ONLY) +set(CMAKE_FIND_ROOT_PATH_MODE_INCLUDE ONLY) +set(CMAKE_FIND_ROOT_PATH_MODE_PACKAGE ONLY) +set(CUSTOM_RUNELF_PATH "") + +#To fix backward compatibility with EAI addon. +if (NOT HEXAGON_SDK_ROOT) + set(HEXAGON_SDK_ROOT $ENV{HEXAGON_SDK_ROOT}) +endif() + +if (NOT HEXAGON_TOOLS_ROOT) + if (DEFINED ENV{HEXAGON_TOOLS_ROOT}) + set(HEXAGON_TOOLS_ROOT $ENV{HEXAGON_TOOLS_ROOT}) + endif() + if(NOT HEXAGON_TOOLS_ROOT) + set(HEXAGON_TOOLS_ROOT $ENV{DEFAULT_HEXAGON_TOOLS_ROOT}) + endif() +endif() + +file(TO_CMAKE_PATH "${HEXAGON_TOOLS_ROOT}" HEXAGON_TOOLS_ROOT) +file(TO_CMAKE_PATH "${HEXAGON_SDK_ROOT}" HEXAGON_SDK_ROOT) + +#Get the Binary extension of the Hexagon Toolchain +if(CMAKE_HOST_SYSTEM_NAME STREQUAL Windows) + set(HEXAGON_TOOLCHAIN_SUFFIX .exe) +endif() +message(DEBUG "CMAKE_HOST_SYSTEM_NAME:${CMAKE_HOST_SYSTEM_NAME}") + +include(${HEXAGON_SDK_ROOT}/build/cmake/hexagon_arch.cmake) + +set(HEXAGON_TOOLCHAIN ${HEXAGON_TOOLS_ROOT}) +set(HEXAGON_LIB_DIR "${HEXAGON_TOOLCHAIN}/Tools/target/hexagon/lib") +set(HEXAGON_ISS_DIR ${HEXAGON_TOOLCHAIN}/Tools/lib/iss) + +set(CMAKE_TRY_COMPILE_PLATFORM_VARIABLES + HEXAGON_SDK_ROOT + HEXAGON_TOOLS_ROOT +) + +#QURT Related includes and linker flags +set(V_ARCH ${HEXAGON_ARCH}) +set(_QURT_INSTALL_DIR "${HEXAGON_SDK_ROOT}/rtos/qurt/ADSP${V_ARCH}MP${V_ARCH_EXTN}") +set(_QURT_INSTALL_DIR "${HEXAGON_SDK_ROOT}/rtos/qurt/compute${V_ARCH}${V_ARCH_EXTN}") + +if( ${TREE} MATCHES PAKMAN ) + set(_QURT_INSTALL_DIR "${QURT_IMAGE_DIR}/compute${V_ARCH}${V_ARCH_EXTN}") +endif() +message(DEBUG "_QURT_INSTALL_DIR:${_QURT_INSTALL_DIR}") +set(RTOS_DIR ${_QURT_INSTALL_DIR}) +set(QCC_DIR "${HEXAGON_QCC_DIR}/${V_ARCH}/G0") +set(TARGET_DIR "${HEXAGON_LIB_DIR}/${V_ARCH}/G0") + +include_directories( + ${_QURT_INSTALL_DIR}/include + ${_QURT_INSTALL_DIR}/include/qurt + ${_QURT_INSTALL_DIR}/include/posix + ) + +set(QURT_START_LINK_LIBS) +set(QURT_START_LINK_LIBS + "${TARGET_DIR}/init.o" + "${RTOS_DIR}/lib/crt1.o" + "${RTOS_DIR}/lib/debugmon.o" + "${RTOS_DIR}/lib/libqurt.a" + "${TARGET_DIR}/libc.a" + "${TARGET_DIR}/libqcc.a" + "${TARGET_DIR}/libhexagon.a" + "${RTOS_DIR}/lib/libqurtcfs.a" + "${RTOS_DIR}/lib/libtimer_island.a" + 
"${RTOS_DIR}/lib/libtimer_main.a" + "${RTOS_DIR}/lib/libposix.a" + ) +STRING(REPLACE ";" " " QURT_START_LINK_LIBS "${QURT_START_LINK_LIBS}") + +set(QURT_END_LINK_LIBS + ${TARGET_DIR}/fini.o + ) + +#Non QURT related includes and linker flags + +set(TARGET_DIR_NOOS "${HEXAGON_TOOLCHAIN}/Tools/target/hexagon/lib/${HEXAGON_ARCH}") + +if (NOT NO_WRAP_MEM_API) + set(WRAP_MALLOC -Wl,--wrap=malloc) + set(WRAP_CALLOC -Wl,--wrap=calloc) + set(WRAP_FREE -Wl,--wrap=free) + set(WRAP_REALLOC -Wl,--wrap=realloc) + set(WRAP_MEMALIGN -Wl,--wrap=memalign) +endif() + +set(PIC_SHARED_LD_FLAGS + -mcpu=${V_ARCH} -m${V_ARCH} -mhvx=${V_ARCH} + -G0 + -fpic + -Wl,-Bsymbolic + -Wl,-L${TARGET_DIR_NOOS}/G0/pic + -Wl,-L${HEXAGON_TOOLCHAIN}/Tools/target/hexagon/lib/ + -Wl,--no-threads ${WRAP_MALLOC} ${WRAP_CALLOC} ${WRAP_FREE} ${WRAP_REALLOC} ${WRAP_MEMALIGN} + -shared + "-o <TARGET> <SONAME_FLAG><TARGET_SONAME>" + "<LINK_FLAGS>" + -Wl,--start-group + "<OBJECTS>" + "<LINK_LIBRARIES>" + -Wl,--end-group + -lc + ) +STRING(REPLACE ";" " " PIC_SHARED_LD_FLAGS "${PIC_SHARED_LD_FLAGS}") + +set(HEXAGON_PIC_SHARED_LINK_OPTIONS "${PIC_SHARED_LD_FLAGS}") + +#System include paths +include_directories(SYSTEM ${HEXAGON_SDK_ROOT}/incs) +include_directories(SYSTEM ${HEXAGON_SDK_ROOT}/incs/stddef) +include_directories(SYSTEM ${HEXAGON_SDK_ROOT}/ipc/fastrpc/incs) + +#LLVM toolchain setup +#Compiler paths, options and architecture +set(CMAKE_C_COMPILER ${HEXAGON_TOOLCHAIN}/Tools/bin/hexagon-clang${HEXAGON_TOOLCHAIN_SUFFIX}) +set(CMAKE_CXX_COMPILER ${HEXAGON_TOOLCHAIN}/Tools/bin/hexagon-clang++${HEXAGON_TOOLCHAIN_SUFFIX}) +set(CMAKE_AR ${HEXAGON_TOOLCHAIN}/Tools/bin/hexagon-ar${HEXAGON_TOOLCHAIN_SUFFIX}) +set(CMAKE_ASM_COMPILER ${HEXAGON_TOOLCHAIN}/Tools/bin/hexagon-clang++${HEXAGON_TOOLCHAIN_SUFFIX}) +set(HEXAGON_LINKER ${CMAKE_C_COMPILER}) +set(CMAKE_PREFIX_PATH ${HEXAGON_TOOLCHAIN}/Tools/target/hexagon) + +set(CMAKE_SHARED_LIBRARY_SONAME_C_FLAG "-Wl,-soname,") +set(CMAKE_SHARED_LIBRARY_SONAME_CXX_FLAG "-Wl,-soname,") + +#Compiler Options +set(COMMON_FLAGS "-mcpu=hexagon${V_ARCH} -m${V_ARCH} -mhvx=${V_ARCH} -fvectorize -Wall -Werror -fno-zero-initialized-in-bss -G0 -fdata-sections -fpic ${XQF_ARGS}") + +set(CMAKE_CXX_FLAGS_DEBUG "${COMMON_FLAGS} -O0 -D_DEBUG -g") +set(CMAKE_CXX_FLAGS_RELWITHDEBINFO "${COMMON_FLAGS} -O3 -g") +set(CMAKE_CXX_FLAGS_RELEASE "${COMMON_FLAGS} -O3") + +set(CMAKE_C_FLAGS_DEBUG "${COMMON_FLAGS} -O0 -D_DEBUG -g") +set(CMAKE_C_FLAGS_RELWITHDEBINFO "${COMMON_FLAGS} -O3 -g") +set(CMAKE_C_FLAGS_RELEASE "${COMMON_FLAGS} -O3") + +set(CMAKE_ASM_FLAGS_DEBUG "${COMMON_FLAGS} ${CMAKE_CXX_FLAGS_DEBUG}") +set(CMAKE_ASM_FLAGS_RELEASE "${COMMON_FLAGS} ${CMAKE_CXX_FLAGS_RELEASE}") +set(CMAKE_ASM_FLAGS_RELWITHDEBINFO "${COMMON_FLAGS} ${CMAKE_CXX_FLAGS_RELWITHDEBINFO}" ) + +#Linker Options +set(CMAKE_C_CREATE_SHARED_LIBRARY "${HEXAGON_LINKER} ${HEXAGON_PIC_SHARED_LINK_OPTIONS}") +set(CMAKE_CXX_CREATE_SHARED_LIBRARY "${HEXAGON_LINKER} ${HEXAGON_PIC_SHARED_LINK_OPTIONS}") diff --git a/llama.cpp/ggml/src/ggml-hexagon/htp/cpy-ops.c b/llama.cpp/ggml/src/ggml-hexagon/htp/cpy-ops.c new file mode 100644 index 0000000..559ca18 --- /dev/null +++ b/llama.cpp/ggml/src/ggml-hexagon/htp/cpy-ops.c @@ -0,0 +1,251 @@ +#pragma clang diagnostic ignored "-Wunused-variable" +#pragma clang diagnostic ignored "-Wunused-function" +#pragma clang diagnostic ignored "-Wunused-but-set-variable" + +#include <HAP_farf.h> +#include <HAP_perf.h> + +#include <math.h> +#include <string.h> + +#define GGML_COMMON_DECL_C +#include "ggml-common.h" +#include "htp-ctx.h" 
+#include "htp-msg.h" +#include "htp-ops.h" +#include "hvx-utils.h" + +struct htp_copy_context { + struct htp_ops_context * octx; + + uint32_t src0_type_size; + uint32_t src0_block_size; + + uint32_t dst_type_size; + uint32_t dst_block_size; + + uint32_t src0_blocks_per_row; + uint32_t dst_blocks_per_row; + + uint32_t src0_nrows_per_thread; + + void (*copy)(struct htp_copy_context * ct, struct htp_ops_context * octx, int nth, int ith); +}; + +#define cpy_preamble \ + struct htp_tensor *src0 = &octx->src0; \ + struct htp_tensor *dst = &octx->dst; \ + \ + const uint32_t ne00 = src0->ne[0]; \ + const uint32_t ne01 = src0->ne[1]; \ + const uint32_t ne02 = src0->ne[2]; \ + const uint32_t ne03 = src0->ne[3]; \ + \ + const uint32_t nb00 = src0->nb[0]; \ + const uint32_t nb01 = src0->nb[1]; \ + const uint32_t nb02 = src0->nb[2]; \ + const uint32_t nb03 = src0->nb[3]; \ + \ + const uint32_t ne0 = dst->ne[0]; \ + const uint32_t ne1 = dst->ne[1]; \ + const uint32_t ne2 = dst->ne[2]; \ + const uint32_t ne3 = dst->ne[3]; \ + \ + const uint32_t nb0 = dst->nb[0]; \ + const uint32_t nb1 = dst->nb[1]; \ + const uint32_t nb2 = dst->nb[2]; \ + const uint32_t nb3 = dst->nb[3]; \ + \ + const uint32_t nr = ne01; + +static void cpy_thread_sametype_sameshape(struct htp_copy_context * ct, struct htp_ops_context * octx, const int nth, const int ith) { + cpy_preamble; + + // parallelize by src0 rows + const uint32_t dr = ct->src0_nrows_per_thread; + const uint32_t ir0 = dr * ith; + const uint32_t ir1 = (ir0 + dr) < nr ? (ir0 + dr) : nr; + + // copy by rows + for (uint32_t i03 = 0; i03 < ne03; i03++) { + for (uint32_t i02 = 0; i02 < ne02; i02++) { + #pragma unroll(2) + for (uint32_t i01 = ir0; i01 < ir1; i01++) { + uint8_t* dst_ptr = (uint8_t*) dst->data + i01*nb1 + i02*nb2 + i03*nb3; + uint8_t* src0_ptr = (uint8_t*) src0->data + i01*nb01 + i02*nb02 + i03*nb03; + hex_l2fetch(src0_ptr, ne00 * ct->src0_type_size, nb01, 2); + hvx_copy_uu(dst_ptr, src0_ptr, ne00, ct->src0_type_size); + } + } + } +} + +static void cpy_thread_sametype_reshape(struct htp_copy_context * ct, struct htp_ops_context * octx, int nth, int ith) { + cpy_preamble; + + // parallelize by src0 rows + const uint32_t dr = ct->src0_nrows_per_thread; + const uint32_t ir0 = dr * ith; + const uint32_t ir1 = (ir0 + dr) < nr ? 
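/* clamp: the last thread takes the remainder */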
(ir0 + dr) : nr; + + // dst counters + int64_t k10 = 0; + int64_t i11 = 0; + int64_t i12 = 0; + int64_t i13 = 0; + + // number of blocks in a row + const int64_t nk00 = ct->src0_blocks_per_row; + const int64_t nk0 = ct->dst_blocks_per_row; + + for (int64_t i03 = 0; i03 < ne03; i03++) { + for (int64_t i02 = 0; i02 < ne02; i02++) { + k10 += nk00 * ir0; + while (k10 >= nk0) { + k10 -= nk0; + if (++i11 == ne1) { + i11 = 0; + if (++i12 == ne2) { + i12 = 0; + if (++i13 == ne3) { + i13 = 0; + } + } + } + } + for (int64_t i01 = ir0; i01 < ir1; i01++) { + for (int64_t k00 = 0; k00 < nk00; k00++) { + const char * src0_ptr = ((char *) src0->data + k00*nb00 + i01*nb01 + i02*nb02 + i03*nb03); + char * dst_ptr = ((char *) dst->data + k10*nb0 + i11*nb1 + i12*nb2 + i13*nb3); + memcpy(dst_ptr, src0_ptr, ct->dst_type_size); + + if (++k10 == nk0) { + k10 = 0; + if (++i11 == ne1) { + i11 = 0; + if (++i12 == ne2) { + i12 = 0; + if (++i13 == ne3) { + i13 = 0; + } + } + } + } + } + } + k10 += nk00 * (ne01 - ir1); + while (k10 >= nk0) { + k10 -= nk0; + if (++i11 == ne1) { + i11 = 0; + if (++i12 == ne2) { + i12 = 0; + if (++i13 == ne3) { + i13 = 0; + } + } + } + } + } + } +} + +static void cpy_thread_f16_f32_sameshape(struct htp_copy_context * ct, struct htp_ops_context * octx, const int nth, const int ith) { + cpy_preamble; + + // parallelize by src0 rows + const uint32_t dr = ct->src0_nrows_per_thread; + const uint32_t ir0 = dr * ith; + const uint32_t ir1 = (ir0 + dr) < nr ? (ir0 + dr) : nr; + + // copy by rows + for (uint32_t i03 = 0; i03 < ne03; i03++) { + for (uint32_t i02 = 0; i02 < ne02; i02++) { + #pragma unroll(2) + for (uint32_t i01 = ir0; i01 < ir1; i01++) { + uint8_t* dst_ptr = (uint8_t*) dst->data + i01*nb1 + i02*nb2 + i03*nb3; + uint8_t* src0_ptr = (uint8_t*) src0->data + i01*nb01 + i02*nb02 + i03*nb03; + hex_l2fetch(src0_ptr, ne00 * sizeof(float), nb01, 2); + hvx_copy_f16_f32_uu(dst_ptr, src0_ptr, ne00); + } + } + } +} + +static void cpy_thread_f32_f16_sameshape(struct htp_copy_context * ct, struct htp_ops_context * octx, const int nth, const int ith) { + cpy_preamble; + + // parallelize by src0 rows + const uint32_t dr = ct->src0_nrows_per_thread; + const uint32_t ir0 = dr * ith; + const uint32_t ir1 = (ir0 + dr) < nr ? 
(ir0 + dr) : nr; + + // copy by rows + for (uint32_t i03 = 0; i03 < ne03; i03++) { + for (uint32_t i02 = 0; i02 < ne02; i02++) { + #pragma unroll(2) + for (uint32_t i01 = ir0; i01 < ir1; i01++) { + uint8_t* dst_ptr = (uint8_t*) dst->data + i01*nb1 + i02*nb2 + i03*nb3; + uint8_t* src0_ptr = (uint8_t*) src0->data + i01*nb01 + i02*nb02 + i03*nb03; + hex_l2fetch(src0_ptr, ne00 * sizeof(__fp16), nb01, 2); + hvx_copy_f32_f16_uu(dst_ptr, src0_ptr, ne00); + } + } + } +} + +static void cpy_work_func(unsigned int n, unsigned int i, void *data) { + struct htp_copy_context *ct = (struct htp_copy_context *) data; + ct->copy(ct, ct->octx, n, i); +} + +int op_cpy(struct htp_ops_context * octx) { + cpy_preamble; + + struct htp_copy_context ct; + ct.octx = octx; + + switch (src0->type) { + case HTP_TYPE_F32: ct.src0_type_size = 4; ct.src0_block_size = 1; ct.src0_blocks_per_row = ne00 / 1; break; + case HTP_TYPE_F16: ct.src0_type_size = 2; ct.src0_block_size = 1; ct.src0_blocks_per_row = ne00 / 1; break; + default: + return HTP_STATUS_NO_SUPPORT; + } + + switch (dst->type) { + case HTP_TYPE_F32: ct.dst_type_size = 4; ct.dst_block_size = 1; ct.dst_blocks_per_row = ne0 / 1; break; + case HTP_TYPE_F16: ct.dst_type_size = 2; ct.dst_block_size = 1; ct.dst_blocks_per_row = ne0 / 1; break; + default: + return HTP_STATUS_NO_SUPPORT; + } + + if (octx->flags & HTP_OPFLAGS_SKIP_COMPUTE) { + return HTP_STATUS_OK; + } + + const bool sametype = (src0->type == dst->type); + const bool transposed = (nb00 > nb01) || (nb0 > nb1); + const bool sameshape = !transposed && (ne00 == ne0 && ne01 == ne1 && ne02 == ne2 && ne03 == ne3); + + const uint32_t n_jobs = MIN(nr, octx->n_threads); + ct.src0_nrows_per_thread = (nr + n_jobs - 1) / n_jobs; + + if (sametype && sameshape) { + ct.copy = cpy_thread_sametype_sameshape; + } else if (sameshape) { + /**/ if (dst->type == HTP_TYPE_F16 && src0->type == HTP_TYPE_F32) + ct.copy = cpy_thread_f16_f32_sameshape; + else if (dst->type == HTP_TYPE_F32 && src0->type == HTP_TYPE_F16) + ct.copy = cpy_thread_f32_f16_sameshape; + else + return HTP_STATUS_NO_SUPPORT; + } else if (sametype) { + ct.copy = cpy_thread_sametype_reshape; + } else { + return HTP_STATUS_NO_SUPPORT; + } + + worker_pool_run_func(octx->ctx->worker_pool, cpy_work_func, &ct, n_jobs); + + return HTP_STATUS_OK; +} diff --git a/llama.cpp/ggml/src/ggml-hexagon/htp/flash-attn-ops.c b/llama.cpp/ggml/src/ggml-hexagon/htp/flash-attn-ops.c new file mode 100644 index 0000000..c184637 --- /dev/null +++ b/llama.cpp/ggml/src/ggml-hexagon/htp/flash-attn-ops.c @@ -0,0 +1,684 @@ +#pragma clang diagnostic ignored "-Wunused-variable" +#pragma clang diagnostic ignored "-Wunused-function" +#pragma clang diagnostic ignored "-Wunused-but-set-variable" + +#include <assert.h> +#include <HAP_farf.h> +#include <HAP_perf.h> +#include <math.h> +#include <string.h> + +#include "hex-dma.h" +#include "hvx-utils.h" + +#define GGML_COMMON_DECL_C +#include "ggml-common.h" +#include "htp-ctx.h" +#include "htp-msg.h" +#include "htp-ops.h" + +static inline HVX_Vector hvx_load_f32_to_f16(const HVX_Vector * restrict src, const HVX_Vector zero) { + HVX_Vector y0_qf = Q6_Vqf32_vsub_VsfVsf(src[0], zero); // 32 elements + HVX_Vector y1_qf = Q6_Vqf32_vsub_VsfVsf(src[1], zero); // 32 elements + return Q6_Vh_vdeal_Vh(Q6_Vhf_equals_Wqf32(Q6_W_vcombine_VV(y1_qf, y0_qf))); +} + +// Dot product of FP32 and FP16 vectors, accumulating to float +static inline void hvx_dot_f32_f16_aa(float * restrict r, const void * restrict y, const void * restrict x, unsigned int n, float s) { + 
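// Mixed-precision dot product: writes s * dot(y, x) to r[0], with y in FP32 and
// x in FP16 (the "aa" suffix presumably marks both pointers as VLEN-aligned).
// The f32 operand is narrowed to f16 first (hvx_load_f32_to_f16), so both sides
// share the same Vhf multiply. Scalar sketch of the computation (illustration only):
//
//   float acc = 0.0f;
//   for (unsigned i = 0; i < n; i++)
//       acc += (float) (__fp16) ((const float *) y)[i] * (float) ((const __fp16 *) x)[i];
//   *r = s * acc;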
const HVX_Vector * restrict vy = (const HVX_Vector * restrict) y; // fp32 + const HVX_Vector * restrict vx = (const HVX_Vector * restrict) x; // fp16 + + uint32_t nvec = n / VLEN_FP16; // num full fp16 hvx vectors + uint32_t nloe = n % VLEN_FP16; // leftover elements + + const HVX_Vector zero = Q6_V_vsplat_R(0); + HVX_Vector rsum = Q6_V_vsplat_R(0); + + uint32_t i = 0; + + #pragma unroll(4) + for (i = 0; i < nvec; i++) { + // Load y (fp32) and convert into fp16 + HVX_Vector y_hf = hvx_load_f32_to_f16(&vy[i*2], zero); + + // Load x (fp16) + HVX_Vector x_hf = vx[i]; + + HVX_VectorPair xy_qf = Q6_Wqf32_vmpy_VhfVhf(x_hf, y_hf); + + rsum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(xy_qf), Q6_V_hi_W(xy_qf)), rsum)); + } + + if (nloe) { + // Load y (fp32) and convert into fp16 + HVX_Vector y_hf = hvx_load_f32_to_f16(&vy[i*2], zero); + + // Load x (fp16) + HVX_Vector x_hf = vx[i]; + + // Zero-out unused elements + // Note that we need to clear both x and y because they may contain NANs + HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe * 2); + x_hf = Q6_V_vand_QV(bmask, x_hf); + y_hf = Q6_V_vand_QV(bmask, y_hf); + + HVX_VectorPair xy_qf = Q6_Wqf32_vmpy_VhfVhf(x_hf, y_hf); + + rsum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(xy_qf), Q6_V_hi_W(xy_qf)), rsum)); + } + + rsum = Q6_Vqf32_vmpy_VsfVsf(hvx_vec_splat_f32(s), hvx_vec_reduce_sum_f32(rsum)); + hvx_vec_store_u(r, 4, Q6_Vsf_equals_Vqf32(rsum)); +} + +// Dot product of FP32 and FP16 vectors, accumulating to float +static inline void hvx_dot_f32_f16_aa_rx2(float * restrict r, + const void * restrict y, + const void * restrict x0, + const void * restrict x1, + unsigned int n, + float s) { + const HVX_Vector * restrict vy = (const HVX_Vector * restrict) y; // fp32 + const HVX_Vector * restrict vx0 = (const HVX_Vector * restrict) x0; // fp16 + const HVX_Vector * restrict vx1 = (const HVX_Vector * restrict) x1; // fp16 + + uint32_t nvec = n / VLEN_FP16; // num full fp16 hvx vectors + uint32_t nloe = n % VLEN_FP16; // leftover elements + + const HVX_Vector zero = Q6_V_vsplat_R(0); + HVX_Vector rsum0 = Q6_V_vsplat_R(0); + HVX_Vector rsum1 = Q6_V_vsplat_R(0); + + uint32_t i = 0; + + #pragma unroll(2) + for (i = 0; i < nvec; i++) { + // Load y (fp32) and convert into fp16 + HVX_Vector y_hf = hvx_load_f32_to_f16(&vy[i*2], zero); + // Load x (fp16) + HVX_Vector x0_hf = vx0[i]; + HVX_Vector x1_hf = vx1[i]; + + HVX_VectorPair xy0_qf = Q6_Wqf32_vmpy_VhfVhf(x0_hf, y_hf); + HVX_VectorPair xy1_qf = Q6_Wqf32_vmpy_VhfVhf(x1_hf, y_hf); + + rsum0 = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(xy0_qf), Q6_V_hi_W(xy0_qf)), rsum0)); + rsum1 = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(xy1_qf), Q6_V_hi_W(xy1_qf)), rsum1)); + } + + if (nloe) { + // Load y (fp32) and convert into fp16 + HVX_Vector y_hf = hvx_load_f32_to_f16(&vy[i*2], zero); + + // Load x (fp16) + HVX_Vector x0_hf = vx0[i]; + HVX_Vector x1_hf = vx1[i]; + + // Zero-out unused elements + // Note that we need to clear both x and y because they may contain NANs + HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe * 2); + x0_hf = Q6_V_vand_QV(bmask, x0_hf); + x1_hf = Q6_V_vand_QV(bmask, x1_hf); + y_hf = Q6_V_vand_QV(bmask, y_hf); + + HVX_VectorPair xy0_qf = Q6_Wqf32_vmpy_VhfVhf(x0_hf, y_hf); + HVX_VectorPair xy1_qf = Q6_Wqf32_vmpy_VhfVhf(x1_hf, y_hf); + + rsum0 = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(xy0_qf), Q6_V_hi_W(xy0_qf)), rsum0)); + rsum1 = 
Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(xy1_qf), Q6_V_hi_W(xy1_qf)), rsum1)); + } + + HVX_Vector rsum = Q6_Vqf32_vmpy_VsfVsf(hvx_vec_splat_f32(s), hvx_vec_reduce_sum_f32x2(rsum0, rsum1)); + hvx_vec_store_u(r, 8, Q6_Vsf_equals_Vqf32(rsum)); +} + +// Dot product of two F16 vectors, accumulating to float +static inline void hvx_dot_f16_f16_aa(float * restrict r, const void * restrict x, const void * restrict y, unsigned int n, float s) { + const HVX_Vector * restrict vx = (const HVX_Vector * restrict) x; // fp16 + const HVX_Vector * restrict vy = (const HVX_Vector * restrict) y; // fp16 + + uint32_t nvec = n / VLEN_FP16; // num full fp16 hvx vectors + uint32_t nloe = n % VLEN_FP16; // leftover elements + + const HVX_Vector zero = Q6_V_vsplat_R(0); + HVX_Vector rsum = Q6_V_vsplat_R(0); + + uint32_t i = 0; + + #pragma unroll(4) + for (i = 0; i < nvec; i++) { + HVX_Vector y_hf = vy[i]; + HVX_Vector x_hf = vx[i]; + + HVX_VectorPair xy_qf = Q6_Wqf32_vmpy_VhfVhf(x_hf, y_hf); + + rsum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(xy_qf), Q6_V_hi_W(xy_qf)), rsum)); + } + + if (nloe) { + HVX_Vector y_hf = vy[i]; + + // Load x (fp16) and zero-out unused elements + HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe * 2); + HVX_Vector x_hf = Q6_V_vand_QV(bmask, vx[i]); + + HVX_VectorPair xy_qf = Q6_Wqf32_vmpy_VhfVhf(x_hf, y_hf); + + rsum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(xy_qf), Q6_V_hi_W(xy_qf)), rsum)); + } + + rsum = Q6_Vqf32_vmpy_VsfVsf(hvx_vec_splat_f32(s), hvx_vec_reduce_sum_f32(rsum)); + hvx_vec_store_u(r, 4, Q6_Vsf_equals_Vqf32(rsum)); +} + +static inline void hvx_dot_f16_f16_aa_rx2(float * restrict r, + const void * restrict y, + const void * restrict x0, + const void * restrict x1, + unsigned int n, + float s) { + const HVX_Vector * restrict vx0 = (const HVX_Vector * restrict) x0; // fp16 + const HVX_Vector * restrict vx1 = (const HVX_Vector * restrict) x1; // fp16 + const HVX_Vector * restrict vy = (const HVX_Vector * restrict) y; // fp16 + + uint32_t nvec = n / VLEN_FP16; // num full fp16 hvx vectors + uint32_t nloe = n % VLEN_FP16; // leftover elements + + const HVX_Vector zero = Q6_V_vsplat_R(0); + HVX_Vector rsum0 = Q6_V_vsplat_R(0); + HVX_Vector rsum1 = Q6_V_vsplat_R(0); + + uint32_t i = 0; + + #pragma unroll(4) + for (i = 0; i < nvec; i++) { + HVX_Vector y_hf = vy[i]; + HVX_Vector x0_hf = vx0[i]; + HVX_Vector x1_hf = vx1[i]; + + HVX_VectorPair xy0_qf = Q6_Wqf32_vmpy_VhfVhf(x0_hf, y_hf); + HVX_VectorPair xy1_qf = Q6_Wqf32_vmpy_VhfVhf(x1_hf, y_hf); + + rsum0 = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(xy0_qf), Q6_V_hi_W(xy0_qf)), rsum0)); + rsum1 = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(xy1_qf), Q6_V_hi_W(xy1_qf)), rsum1)); + } + + if (nloe) { + HVX_Vector y_hf = vy[i]; + + // Load x (fp16) and zero-out unused elements + HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe * 2); + HVX_Vector x0_hf = Q6_V_vand_QV(bmask, vx0[i]); + HVX_Vector x1_hf = Q6_V_vand_QV(bmask, vx1[i]); + + HVX_VectorPair xy0_qf = Q6_Wqf32_vmpy_VhfVhf(x0_hf, y_hf); + HVX_VectorPair xy1_qf = Q6_Wqf32_vmpy_VhfVhf(x1_hf, y_hf); + + rsum0 = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(xy0_qf), Q6_V_hi_W(xy0_qf)), rsum0)); + rsum1 = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(xy1_qf), Q6_V_hi_W(xy1_qf)), rsum1)); + } + + HVX_Vector rsum = Q6_Vqf32_vmpy_VsfVsf(hvx_vec_splat_f32(s), 
hvx_vec_reduce_sum_f32x2(rsum0, rsum1)); + hvx_vec_store_u(r, 8, Q6_Vsf_equals_Vqf32(rsum)); +} + +// MAD: y (F32) += x (F16) * s (float) +static inline void hvx_mad_f32_f16_aa(float * restrict y, const void * restrict x, int n, float s) { + const HVX_Vector * restrict ptr_x = (const HVX_Vector *) x; + HVX_Vector * restrict ptr_y = (HVX_Vector *) y; + + uint32_t nvec = n / VLEN_FP16; // num full fp16 hvx vectors + uint32_t nloe = n % VLEN_FP16; // leftover elements + + HVX_Vector S = hvx_vec_splat_f16(s); + + uint32_t i = 0; + #pragma unroll(4) + for (i = 0; i < nvec; ++i) { + // Multiply x * s -> pair of F32 vectors + HVX_VectorPair xs_p = Q6_Wqf32_vmpy_VhfVhf(Q6_Vh_vshuff_Vh(ptr_x[i]), S); + ptr_y[i*2] = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(Q6_V_lo_W(xs_p), ptr_y[i*2])); + ptr_y[i*2+1] = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(Q6_V_hi_W(xs_p), ptr_y[i*2+1])); + } + + if (nloe) { + HVX_VectorPair xs_p = Q6_Wqf32_vmpy_VhfVhf(Q6_Vh_vshuff_Vh(ptr_x[i]), S); + + HVX_Vector xs = Q6_V_lo_W(xs_p); + i = 2 * i; // index for ptr_y + + if (nloe >= 32) { + ptr_y[i] = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(xs, ptr_y[i])); + nloe -= 32; ++i; xs = Q6_V_hi_W(xs_p); + } + + if (nloe) { + HVX_Vector xy = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(xs, ptr_y[i])); + hvx_vec_store_a(&ptr_y[i], nloe * 4, xy); + } + } +} + +#define FLASH_ATTN_BLOCK_SIZE 128 + +static void flash_attn_ext_f16_thread(struct htp_ops_context * octx, int ith, int nth) { + const struct htp_tensor * q = &octx->src0; + const struct htp_tensor * k = &octx->src1; + const struct htp_tensor * v = &octx->src2; + const struct htp_tensor * mask = (octx->src3.data) ? &octx->src3 : NULL; + const struct htp_tensor * sinks = (octx->src4.data) ? &octx->src4 : NULL; + struct htp_tensor * dst = &octx->dst; + + const uint32_t neq0 = q->ne[0]; + const uint32_t neq1 = q->ne[1]; + const uint32_t neq2 = q->ne[2]; + const uint32_t neq3 = q->ne[3]; + + const uint32_t nek0 = k->ne[0]; + const uint32_t nek1 = k->ne[1]; + const uint32_t nek2 = k->ne[2]; + const uint32_t nek3 = k->ne[3]; + + const uint32_t nev0 = v->ne[0]; + const uint32_t nev1 = v->ne[1]; + const uint32_t nev2 = v->ne[2]; + const uint32_t nev3 = v->ne[3]; + + const uint32_t nbq1 = q->nb[1]; + const uint32_t nbq2 = q->nb[2]; + const uint32_t nbq3 = q->nb[3]; + + const uint32_t nbk1 = k->nb[1]; + const uint32_t nbk2 = k->nb[2]; + const uint32_t nbk3 = k->nb[3]; + + const uint32_t nbv1 = v->nb[1]; + const uint32_t nbv2 = v->nb[2]; + const uint32_t nbv3 = v->nb[3]; + + const uint32_t ne1 = dst->ne[1]; + const uint32_t ne2 = dst->ne[2]; + const uint32_t ne3 = dst->ne[3]; + + const uint32_t nb1 = dst->nb[1]; + const uint32_t nb2 = dst->nb[2]; + const uint32_t nb3 = dst->nb[3]; + + float scale = 1.0f; + float max_bias = 0.0f; + float logit_softcap = 0.0f; + + memcpy(&scale, (float *) octx->op_params + 0, sizeof(float)); + memcpy(&max_bias, (float *) octx->op_params + 1, sizeof(float)); + memcpy(&logit_softcap, (float *) octx->op_params + 2, sizeof(float)); + + if (logit_softcap != 0) { + scale /= logit_softcap; + } + + // total rows in q + const uint32_t nr = neq1*neq2*neq3; + + const uint32_t dr = (nr + nth - 1) / nth; + const uint32_t ir0 = dr * ith; + const uint32_t ir1 = MIN(ir0 + dr, nr); + + if (ir0 >= ir1) return; + + dma_queue * dma = octx->ctx->dma[ith]; + + const uint32_t DK = nek0; + const uint32_t DV = nev0; + + const size_t size_q_row = DK * ((q->type == HTP_TYPE_F32) ? 
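/* bytes per Q element */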
4 : 2); + const size_t size_q_row_padded = hex_round_up(size_q_row, 128); + + const size_t size_k_row = DK * sizeof(__fp16); + const size_t size_v_row = DV * sizeof(__fp16); + const size_t size_m_row = FLASH_ATTN_BLOCK_SIZE * sizeof(__fp16); // Treat block as one row for mask + + const size_t size_k_row_padded = hex_round_up(size_k_row, 128); + const size_t size_v_row_padded = hex_round_up(size_v_row, 128); + + const size_t size_k_block = size_k_row_padded * FLASH_ATTN_BLOCK_SIZE; + const size_t size_v_block = size_v_row_padded * FLASH_ATTN_BLOCK_SIZE; + const size_t size_m_block = hex_round_up(FLASH_ATTN_BLOCK_SIZE * sizeof(__fp16), 128); + + // Scratchpad buffers for Q, K, V, Mask, and VKQ32 accumulator + uint8_t * spad_q = octx->src0_spad.data + octx->src0_spad.size_per_thread * ith; + uint8_t * spad_k = octx->src1_spad.data + octx->src1_spad.size_per_thread * ith; + uint8_t * spad_v = octx->src2_spad.data + octx->src2_spad.size_per_thread * ith; + uint8_t * spad_m = octx->src3_spad.data + octx->src3_spad.size_per_thread * ith; + uint8_t * spad_a = octx->dst_spad.data + octx->dst_spad.size_per_thread * ith; + + const uint32_t n_head = neq2; + const uint32_t n_head_log2 = 1u << (uint32_t) floor(log2(n_head)); + const float m0 = powf(2.0f, -(max_bias ) / n_head_log2); + const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2); + + for (uint32_t ir = ir0; ir < ir1; ++ir) { + const uint32_t iq3 = fastdiv(ir, &octx->src0_div21); + const uint32_t iq2 = fastdiv(ir - iq3*neq2*neq1, &octx->src0_div1); + const uint32_t iq1 = (ir - iq3*neq2*neq1 - iq2 * neq1); + + const uint32_t ik3 = fastdiv(iq3, &octx->broadcast_rk3); + const uint32_t ik2 = fastdiv(iq2, &octx->broadcast_rk2); + + const uint32_t iv3 = fastdiv(iq3, &octx->broadcast_rv3); + const uint32_t iv2 = fastdiv(iq2, &octx->broadcast_rv2); + + // Fetch Q row + const uint8_t * q_row_ptr = (const uint8_t *) q->data + (iq1*nbq1 + iq2*nbq2 + iq3*nbq3); + dma_queue_push(dma, dma_make_ptr(spad_q, q_row_ptr), size_q_row_padded, nbq1, size_q_row, 1); + + const uint32_t h = iq2; // head index + const float slope = (max_bias > 0.0f) ? (h < n_head_log2 ? 
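/* ALiBi: the first n_head_log2 heads use powers of m0, the rest odd powers of m1 */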
powf(m0, h + 1) : powf(m1, 2*(h - n_head_log2) + 1)) : 1.0f; + + float S = 0.0f; // sum + float M = -INFINITY; // maximum KQ value + + // Clear accumulator + hvx_splat_f32_a(spad_a, 0, DV); + float * VKQ32 = (float *) spad_a; + + const __fp16 * mp_base = NULL; + if (mask) { + const uint32_t im2 = fastmodulo(iq2, mask->ne[2], &octx->src3_div2); + const uint32_t im3 = fastmodulo(iq3, mask->ne[3], &octx->src3_div3); + mp_base = (const __fp16 *) ((const uint8_t *) mask->data + iq1*mask->nb[1] + im2*mask->nb[2] + im3*mask->nb[3]); + } + + const uint32_t n_blocks = (nek1 + FLASH_ATTN_BLOCK_SIZE - 1) / FLASH_ATTN_BLOCK_SIZE; + + // Prefetch first two blocks + for (uint32_t ib = 0; ib < MIN(n_blocks, 2); ++ib) { + const uint32_t ic_start = ib * FLASH_ATTN_BLOCK_SIZE; + const uint32_t current_block_size = MIN(FLASH_ATTN_BLOCK_SIZE, nek1 - ic_start); + + // K + const uint8_t * k_src = (const uint8_t *) k->data + (ic_start*nbk1 + ik2*nbk2 + ik3*nbk3); + uint8_t * k_dst = spad_k + (ib % 2) * size_k_block; + dma_queue_push(dma, dma_make_ptr(k_dst, k_src), size_k_row_padded, nbk1, size_k_row, current_block_size); + + // V + const uint8_t * v_src = (const uint8_t *) v->data + (ic_start*nbv1 + iv2*nbv2 + iv3*nbv3); + uint8_t * v_dst = spad_v + (ib % 2) * size_v_block; + dma_queue_push(dma, dma_make_ptr(v_dst, v_src), size_v_row_padded, nbv1, size_v_row, current_block_size); + + // Mask + if (mask) { + const uint8_t * m_src = (const uint8_t *) (mp_base + ic_start); + uint8_t * m_dst = spad_m + (ib % 2) * size_m_block; + // Mask is 1D contiguous for this row + dma_queue_push(dma, dma_make_ptr(m_dst, m_src), current_block_size * 2, current_block_size * 2, current_block_size * 2, 1); + } + } + + const uint8_t * q_ptr_vtcm = dma_queue_pop(dma).dst; + + for (uint32_t ib = 0; ib < n_blocks; ++ib) { + const uint32_t ic_start = ib * FLASH_ATTN_BLOCK_SIZE; + const uint32_t current_block_size = MIN(FLASH_ATTN_BLOCK_SIZE, nek1 - ic_start); + + // Wait for DMA + uint8_t * k_base = dma_queue_pop(dma).dst; // K + uint8_t * v_base = dma_queue_pop(dma).dst; // V + __fp16 * m_base = mask ? dma_queue_pop(dma).dst : NULL; // M + + // Inner loop processing the block from VTCM + uint32_t ic = 0; + + const bool is_q_fp32 = (q->type == HTP_TYPE_F32); + + // Process in blocks of 32 (VLEN_FP32) + static_assert(FLASH_ATTN_BLOCK_SIZE / VLEN_FP32 <= 4, "FLASH_ATTN_BLOCK_SIZE changed, fix HVX_Vector_x4 usage"); + HVX_Vector_x4 scores_x4; + HVX_Vector v_max = hvx_vec_splat_f32(-INFINITY); + for (uint32_t iv = 0; ic + VLEN_FP32 <= current_block_size; ic += VLEN_FP32, ++iv) { + // 1. Compute scores + float __attribute__((aligned(VLEN))) scores_arr[VLEN_FP32]; + for (int j = 0; j < VLEN_FP32; j += 2) { + const uint32_t cur_ic = ic + j; + const uint8_t * k_ptr = k_base + cur_ic * size_k_row_padded; + if (is_q_fp32) { + hvx_dot_f32_f16_aa_rx2(&scores_arr[j], q_ptr_vtcm, k_ptr, k_ptr + size_k_row_padded, DK, scale); + } else { + hvx_dot_f16_f16_aa_rx2(&scores_arr[j], q_ptr_vtcm, k_ptr, k_ptr + size_k_row_padded, DK, scale); + } + } + + HVX_Vector scores = *(HVX_Vector *) scores_arr; + + // 2. Softcap + if (logit_softcap != 0.0f) { + scores = hvx_vec_tanh_f32(scores); + scores = Q6_Vqf32_vmpy_VsfVsf(scores, hvx_vec_splat_f32(logit_softcap)); + scores = Q6_Vsf_equals_Vqf32(scores); + } + + // 3. 
Mask + if (mask) { + const __fp16 * mp = m_base + ic; + HVX_Vector m_vals_f16 = *(const HVX_UVector *) mp; + + HVX_Vector one_f16 = Q6_Vh_vsplat_R(0x3c00); + HVX_VectorPair m_vals_f32_pair = Q6_Wqf32_vmpy_VhfVhf(Q6_Vh_vshuff_Vh(m_vals_f16), one_f16); + + HVX_Vector m_vals_f32 = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(m_vals_f32_pair)); + + HVX_Vector slope_vec = hvx_vec_splat_f32(slope); + HVX_Vector add_val = Q6_Vqf32_vmpy_VsfVsf(m_vals_f32, slope_vec); + scores = Q6_Vqf32_vadd_VsfVsf(scores, Q6_Vsf_equals_Vqf32(add_val)); + scores = Q6_Vsf_equals_Vqf32(scores); + } + + scores_x4.v[iv] = scores; + v_max = Q6_Vsf_vmax_VsfVsf(scores, v_max); + } + + { + // 4. Online Softmax Update + v_max = hvx_vec_reduce_max_f32(v_max); + float m_block = hvx_vec_get_f32(v_max); + float M_old = M; + float M_new = (m_block > M) ? m_block : M; + M = M_new; + + const float ms = expf(M_old - M_new); + hvx_scale_f32_aa((uint8_t *) VKQ32, (const uint8_t *) VKQ32, DV, ms); + + HVX_Vector M_new_vec = hvx_vec_splat_f32(M_new); + HVX_Vector p_sum_vec = hvx_vec_splat_f32(0.0f); + for (uint32_t ic2 = 0, iv = 0; ic2 + VLEN_FP32 <= current_block_size; ic2 += VLEN_FP32, ++iv) { + HVX_Vector scores = scores_x4.v[iv]; + HVX_Vector scores_shifted = Q6_Vqf32_vsub_VsfVsf(scores, M_new_vec); + HVX_Vector P = hvx_vec_exp_f32(Q6_Vsf_equals_Vqf32(scores_shifted)); + + p_sum_vec = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_VsfVsf(p_sum_vec, P)); + + // 5. Accumulate V + float __attribute__((aligned(VLEN))) p_arr[VLEN_FP32]; + *(HVX_Vector*)p_arr = P; + + for (int j = 0; j < VLEN_FP32; ++j) { + const uint32_t cur_ic = ic2 + j; + const uint8_t * v_ptr = v_base + cur_ic * size_v_row_padded; + hvx_mad_f32_f16_aa(VKQ32, v_ptr, DV, p_arr[j]); + } + } + + p_sum_vec = hvx_vec_reduce_sum_f32(p_sum_vec); + S = S * ms + hvx_vec_get_f32(p_sum_vec); + } + + // Leftover + for (; ic < current_block_size; ++ic) { + float s_val; + const uint8_t * k_ptr = k_base + ic * size_k_row_padded; + + if (is_q_fp32) { + hvx_dot_f32_f16_aa(&s_val, q_ptr_vtcm, k_ptr, DK, scale); + } else { + hvx_dot_f16_f16_aa(&s_val, q_ptr_vtcm, k_ptr, DK, scale); + } + + if (logit_softcap != 0.0f) { + s_val = logit_softcap * tanhf(s_val); + } + + if (mask) { + const float m_val = m_base[ic]; + s_val += slope * m_val; + } + + const float Mold = M; + float ms = 1.0f; + float vs = 1.0f; + + if (s_val > M) { + M = s_val; + ms = expf(Mold - M); + hvx_scale_f32_aa((uint8_t *) VKQ32, (const uint8_t *) VKQ32, DV, ms); + } else { + vs = expf(s_val - M); + } + + const uint8_t * v_ptr = v_base + ic * size_v_row_padded; + + hvx_mad_f32_f16_aa(VKQ32, v_ptr, DV, vs); + + S = S * ms + vs; + } + + // Issue DMA for next+1 block (if exists) + if (ib + 2 < n_blocks) { + const uint32_t next_ib = ib + 2; + const uint32_t next_ic_start = next_ib * FLASH_ATTN_BLOCK_SIZE; + const uint32_t next_block_size = MIN(FLASH_ATTN_BLOCK_SIZE, nek1 - next_ic_start); + + // K + const uint8_t * k_src = (const uint8_t *) k->data + (next_ic_start*nbk1 + ik2*nbk2 + ik3*nbk3); + dma_queue_push(dma, dma_make_ptr(k_base, k_src), size_k_row_padded, nbk1, size_k_row, next_block_size); + + // V + const uint8_t * v_src = (const uint8_t *) v->data + (next_ic_start*nbv1 + iv2*nbv2 + iv3*nbv3); + dma_queue_push(dma, dma_make_ptr(v_base, v_src), size_v_row_padded, nbv1, size_v_row, next_block_size); + + // Mask + if (mask) { + const uint8_t * m_src = (const uint8_t *) (mp_base + next_ic_start); + dma_queue_push(dma, dma_make_ptr(m_base, m_src), next_block_size * 2, next_block_size * 2, next_block_size * 2, 1); + } + } + } + + // sinks + if 
(sinks) { + const float s = ((float *)((char *) sinks->data))[h]; + + float ms = 1.0f; + float vs = 1.0f; + + if (s > M) { + ms = expf(M - s); + hvx_scale_f32_aa((uint8_t *) VKQ32, (const uint8_t *) VKQ32, DV, ms); + } else { + vs = expf(s - M); + } + + S = S * ms + vs; + } + + const float S_inv = S == 0.0f ? 0.0f : 1.0f/S; + hvx_scale_f32_aa((uint8_t *) VKQ32, (const uint8_t *) VKQ32, DV, S_inv); + + // Store result + // dst indices + const int i1 = iq1; + const int i2 = iq2; + const int i3 = iq3; + + // dst is permuted + uint8_t * dst_ptr = (uint8_t *) dst->data + (i3*ne2*ne1 + i2 + i1*ne1) * nb1; + + if (dst->type == HTP_TYPE_F32) { + hvx_copy_f32_ua(dst_ptr, (uint8_t *) VKQ32, DV); + } else if (dst->type == HTP_TYPE_F16) { + hvx_copy_f16_f32_ua(dst_ptr, (uint8_t *) VKQ32, DV); + } + } +} + +static void htp_flash_attn_ext_job(unsigned int n, unsigned int i, void * data) { + struct htp_ops_context * octx = data; + flash_attn_ext_f16_thread(octx, i, n); +} + +int op_flash_attn_ext(struct htp_ops_context * octx) { + const struct htp_tensor * q = &octx->src0; + const struct htp_tensor * k = &octx->src1; + const struct htp_tensor * v = &octx->src2; + const struct htp_tensor * mask = (octx->src3.type != HTP_TYPE_COUNT) ? &octx->src3 : NULL; + struct htp_tensor * dst = &octx->dst; + + // Check support + if ((q->type != HTP_TYPE_F16 && q->type != HTP_TYPE_F32) || + k->type != HTP_TYPE_F16 || + v->type != HTP_TYPE_F16) { + return HTP_STATUS_NO_SUPPORT; + } + + octx->src0_div21 = init_fastdiv_values(q->ne[2] * q->ne[1]); + octx->src0_div1 = init_fastdiv_values(q->ne[1]); + + octx->broadcast_rk2 = init_fastdiv_values(q->ne[2]/k->ne[2]); + octx->broadcast_rk3 = init_fastdiv_values(q->ne[3]/k->ne[3]); + octx->broadcast_rv2 = init_fastdiv_values(q->ne[2]/v->ne[2]); + octx->broadcast_rv3 = init_fastdiv_values(q->ne[3]/v->ne[3]); + + if (mask) { + octx->src3_div2 = init_fastdiv_values(mask->ne[2]); + octx->src3_div3 = init_fastdiv_values(mask->ne[3]); + } + + size_t size_q_row_padded = hex_round_up(q->ne[0] * (q->type == HTP_TYPE_F32 ? 4 : 2), 128); + size_t size_k_row_padded = hex_round_up(k->ne[0] * sizeof(__fp16), 128); + size_t size_v_row_padded = hex_round_up(v->ne[0] * sizeof(__fp16), 128); + + size_t size_q_block = size_q_row_padded * 1; // single row for now + size_t size_k_block = size_k_row_padded * FLASH_ATTN_BLOCK_SIZE; + size_t size_v_block = size_v_row_padded * FLASH_ATTN_BLOCK_SIZE; + size_t size_m_block = hex_round_up(FLASH_ATTN_BLOCK_SIZE * sizeof(__fp16), 128); + + size_t size_vkq_acc = hex_round_up(v->ne[0] * sizeof(float), 128); // VKQ32 + + octx->src0_spad.size_per_thread = size_q_block * 1; + octx->src1_spad.size_per_thread = size_k_block * 2; + octx->src2_spad.size_per_thread = size_v_block * 2; + octx->src3_spad.size_per_thread = mask ? 
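/* two mask blocks per thread for double buffering */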
size_m_block * 2 : 0; + octx->dst_spad.size_per_thread = size_vkq_acc; + + octx->src0_spad.size = octx->src0_spad.size_per_thread * octx->n_threads; + octx->src1_spad.size = octx->src1_spad.size_per_thread * octx->n_threads; + octx->src2_spad.size = octx->src2_spad.size_per_thread * octx->n_threads; + octx->src3_spad.size = octx->src3_spad.size_per_thread * octx->n_threads; + octx->dst_spad.size = octx->dst_spad.size_per_thread * octx->n_threads; + + size_t total_spad = octx->src0_spad.size + octx->src1_spad.size + octx->src2_spad.size + octx->src3_spad.size + octx->dst_spad.size; + + if (octx->ctx->vtcm_size < total_spad) { + return HTP_STATUS_VTCM_TOO_SMALL; + } + + octx->src0_spad.data = octx->ctx->vtcm_base; + octx->src1_spad.data = octx->src0_spad.data + octx->src0_spad.size; + octx->src2_spad.data = octx->src1_spad.data + octx->src1_spad.size; + octx->src3_spad.data = octx->src2_spad.data + octx->src2_spad.size; + octx->dst_spad.data = octx->src3_spad.data + octx->src3_spad.size; + + if (!(octx->flags & HTP_OPFLAGS_SKIP_COMPUTE)) { + worker_pool_run_func(octx->ctx->worker_pool, htp_flash_attn_ext_job, octx, octx->n_threads); + } + + return HTP_STATUS_OK; +} diff --git a/llama.cpp/ggml/src/ggml-hexagon/htp/get-rows-ops.c b/llama.cpp/ggml/src/ggml-hexagon/htp/get-rows-ops.c new file mode 100644 index 0000000..a657cd2 --- /dev/null +++ b/llama.cpp/ggml/src/ggml-hexagon/htp/get-rows-ops.c @@ -0,0 +1,106 @@ +#pragma clang diagnostic ignored "-Wunused-variable" +#pragma clang diagnostic ignored "-Wunused-function" +#pragma clang diagnostic ignored "-Wunused-but-set-variable" + +#include <HAP_farf.h> +#include <HAP_perf.h> + +#include <math.h> +#include <string.h> + +#define GGML_COMMON_DECL_C +#include "ggml-common.h" +#include "htp-ctx.h" +#include "htp-msg.h" +#include "htp-ops.h" +#include "hvx-utils.h" + +#define get_rows_preamble \ + const uint32_t ne00 = octx->src0.ne[0]; \ + const uint32_t ne01 = octx->src0.ne[1]; \ + const uint32_t ne02 = octx->src0.ne[2]; \ + const uint32_t ne03 = octx->src0.ne[3]; \ + \ + const uint32_t ne10 = octx->src1.ne[0]; \ + const uint32_t ne11 = octx->src1.ne[1]; \ + const uint32_t ne12 = octx->src1.ne[2]; \ + \ + const uint32_t nb01 = octx->src0.nb[1]; \ + const uint32_t nb02 = octx->src0.nb[2]; \ + const uint32_t nb03 = octx->src0.nb[3]; \ + \ + const uint32_t nb10 = octx->src1.nb[0]; \ + const uint32_t nb11 = octx->src1.nb[1]; \ + const uint32_t nb12 = octx->src1.nb[2]; \ + \ + const uint32_t nb1 = octx->dst.nb[1]; \ + const uint32_t nb2 = octx->dst.nb[2]; \ + const uint32_t nb3 = octx->dst.nb[3]; \ + \ + const uint32_t nr = ne10 * ne11 * ne12; + +static int get_rows_thread_f32_f32(struct htp_ops_context * octx, const int nth, const int ith) { + get_rows_preamble; + + // parallelize by src1 elements (which correspond to dst rows) + const uint32_t dr = octx->src1_nrows_per_thread; + const uint32_t ir0 = dr * ith; + const uint32_t ir1 = (ir0 + dr < nr) ? (ir0 + dr) : nr; + + const bool is_i32 = (octx->src1.type == HTP_TYPE_I32); + + for (uint32_t i = ir0; i < ir1; ++i) { + const uint32_t i12 = fastdiv(i, &octx->get_rows_div_ne10_ne11); + const uint32_t rem = i - i12 * ne11 * ne10; + const uint32_t i11 = fastdiv(rem, &octx->get_rows_div_ne10); + const uint32_t i10 = rem - i11 * ne10; + + const uintptr_t src1_addr = octx->src1.data + i10*nb10 + i11*nb11 + i12*nb12; + + uint32_t i01 = is_i32 ? 
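/* row indices may be I32 or I64; either way they are narrowed to uint32_t */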
*(int32_t *)src1_addr : (uint32_t) *(int64_t *)src1_addr;  // i64 indices are truncated to 32 bits
+
+        if (i01 >= ne01) {
+            // invalid index, skip for now to avoid crash
+            continue;
+        }
+
+        const uintptr_t src0_ptr = octx->src0.data + i01*nb01 + i11*nb02 + i12*nb03;
+        const uintptr_t dst_ptr  = octx->dst.data + i10*nb1 + i11*nb2 + i12*nb3;
+        hvx_copy_f32_uu((uint8_t *)dst_ptr, (const uint8_t *)src0_ptr, ne00);
+    }
+
+    return HTP_STATUS_OK;
+}
+
+static void get_rows_work_f32_f32(unsigned int n, unsigned int i, void *data) {
+    get_rows_thread_f32_f32((struct htp_ops_context *) data, n, i);
+}
+
+int op_get_rows(struct htp_ops_context * octx) {
+    get_rows_preamble;
+
+    if (octx->src0.type != HTP_TYPE_F32) {
+        return HTP_STATUS_NO_SUPPORT;
+    }
+
+    if (octx->dst.type != HTP_TYPE_F32) {
+        return HTP_STATUS_NO_SUPPORT;
+    }
+
+    if (octx->src1.type != HTP_TYPE_I32 && octx->src1.type != HTP_TYPE_I64) {
+        return HTP_STATUS_NO_SUPPORT;
+    }
+
+    if (octx->flags & HTP_OPFLAGS_SKIP_COMPUTE) {
+        return HTP_STATUS_OK;
+    }
+
+    octx->get_rows_div_ne10      = init_fastdiv_values(octx->src1.ne[0]);
+    octx->get_rows_div_ne10_ne11 = init_fastdiv_values(octx->src1.ne[0] * octx->src1.ne[1]);
+
+    const uint32_t n_jobs = MIN(nr, octx->n_threads);
+    octx->src1_nrows_per_thread = (nr + n_jobs - 1) / n_jobs;
+
+    worker_pool_run_func(octx->ctx->worker_pool, get_rows_work_f32_f32, octx, n_jobs);
+    return HTP_STATUS_OK;
+}
diff --git a/llama.cpp/ggml/src/ggml-hexagon/htp/hex-dma.c b/llama.cpp/ggml/src/ggml-hexagon/htp/hex-dma.c
new file mode 100644
index 0000000..44e1be4
--- /dev/null
+++ b/llama.cpp/ggml/src/ggml-hexagon/htp/hex-dma.c
@@ -0,0 +1,63 @@
+#include "hex-dma.h"
+
+#include <stdbool.h>
+#include <stdlib.h>
+#include <string.h>
+
+#pragma clang diagnostic ignored "-Wunused-function"
+
+static inline uint32_t pow2_ceil(uint32_t x) {
+    if (x <= 1) {
+        return 1;
+    }
+    int p = 2;
+    x--;
+    while (x >>= 1) {
+        p <<= 1;
+    }
+    return p;
+}
+
+dma_queue * dma_queue_create(size_t capacity) {
+    dma_queue * q = (dma_queue *) memalign(32, sizeof(dma_queue));
+    if (q == NULL) {
+        FARF(ERROR, "%s: failed to allocate DMA queue\n", __FUNCTION__);
+        return NULL;
+    }
+
+    capacity = pow2_ceil(capacity);
+
+    memset(q, 0, sizeof(dma_queue));
+    q->capacity = capacity;
+    q->idx_mask = capacity - 1;
+
+    q->desc = (hexagon_udma_descriptor_type1_t *) memalign(64, capacity * sizeof(hexagon_udma_descriptor_type1_t));
+    q->dptr = (dma_ptr *) memalign(4, capacity * sizeof(dma_ptr));
+    if (!q->desc || !q->dptr) {
+        FARF(ERROR, "%s: failed to allocate DMA queue items\n", __FUNCTION__);
+        dma_queue_delete(q);
+        return NULL;
+    }
+
+    memset(q->desc, 0, capacity * sizeof(hexagon_udma_descriptor_type1_t));
+    memset(q->dptr, 0, capacity * sizeof(dma_ptr));
+
+    q->tail = &q->desc[capacity - 1];
+
+    FARF(HIGH, "dma-queue: capacity %u\n", capacity);
+
+    return q;
+}
+
+void dma_queue_delete(dma_queue * q) {
+    if (!q) {
+        return;
+    }
+    free(q->desc);
+    free(q->dptr);
+    free(q);
+}
+
+void dma_queue_flush(dma_queue * q) {
+    while (dma_queue_pop(q).dst != NULL) ;
+}
diff --git a/llama.cpp/ggml/src/ggml-hexagon/htp/hex-dma.h b/llama.cpp/ggml/src/ggml-hexagon/htp/hex-dma.h
new file mode 100644
index 0000000..d1ddb0e
--- /dev/null
+++ b/llama.cpp/ggml/src/ggml-hexagon/htp/hex-dma.h
@@ -0,0 +1,156 @@
+#ifndef HTP_DMA_H
+#define HTP_DMA_H
+
+#include <HAP_farf.h>
+#include <hexagon_types.h>
+#include <stdbool.h>
+#include <stdint.h>
+
+#ifdef __cplusplus
+extern "C" {
+#endif
+
+typedef struct {
+    void *dst;
+    const void *src;
+} dma_ptr;
+
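+// Single-producer/single-consumer ring of uDMA transfer descriptors.
+// The capacity is rounded up to a power of two so that push/pop indices
+// wrap with a simple mask.
+typedef struct {
+    hexagon_udma_descriptor_type1_t * desc;  // descriptor pointers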
+    hexagon_udma_descriptor_type1_t * tail;  // tail pointer
+    dma_ptr * dptr;                          // dst/src pointers
+    uint32_t push_idx;
+    uint32_t pop_idx;
+    uint32_t capacity;
+    uint32_t idx_mask;
+} dma_queue;
+
+dma_queue * dma_queue_create(size_t capacity);
+void dma_queue_delete(dma_queue * q);
+void dma_queue_flush(dma_queue * q);
+
+// TODO: technically we don't need these and could use Q6_dmstart/wait/etc instead
+// but those do not seem to always compile properly.
+static inline void dmstart(void * next) {
+    asm volatile(" release(%0):at" : : "r"(next));
+    asm volatile(" dmstart(%0)" : : "r"(next));
+}
+
+static inline void dmlink(void * cur, void * next) {
+    asm volatile(" release(%0):at" : : "r"(next));
+    asm volatile(" dmlink(%0, %1)" : : "r"(cur), "r"(next));
+}
+
+static inline unsigned int dmpoll(void) {
+    unsigned int ret = 0;
+    asm volatile(" %0 = dmpoll" : "=r"(ret) : : "memory");
+    return ret;
+}
+
+static inline unsigned int dmwait(void) {
+    unsigned int ret = 0;
+    asm volatile(" %0 = dmwait" : "=r"(ret) : : "memory");
+    return ret;
+}
+
+static inline dma_ptr dma_make_ptr(void *dst, const void *src)
+{
+    dma_ptr p = { dst, src };
+    return p;
+}
+
+static inline bool dma_queue_push(dma_queue * q,
+                                  dma_ptr dptr,
+                                  size_t dst_row_size,
+                                  size_t src_row_size,
+                                  size_t width,  // width in bytes. number of bytes to transfer per row
+                                  size_t nrows) {
+    if (((q->push_idx + 1) & q->idx_mask) == q->pop_idx) {
+        FARF(ERROR, "dma-push: queue full\n");
+        return false;
+    }
+
+    hexagon_udma_descriptor_type1_t * desc = &q->desc[q->push_idx];
+
+    desc->next           = NULL;
+    desc->length         = 0;
+    desc->desctype       = HEXAGON_UDMA_DESC_DESCTYPE_TYPE1;
+#if __HVX_ARCH__ >= 73
+    desc->dstbypass      = 1;
+    desc->srcbypass      = 1;
+#else
+    desc->dstbypass      = 0;
+    desc->srcbypass      = 1;
+#endif
+    desc->order          = 0;
+    desc->dstate         = HEXAGON_UDMA_DESC_DSTATE_INCOMPLETE;
+    desc->src            = (void *) dptr.src;
+    desc->dst            = (void *) dptr.dst;
+    desc->allocation     = 0;
+    desc->padding        = 0;
+    desc->roiwidth       = width;
+    desc->roiheight      = nrows;
+    desc->srcstride      = src_row_size;
+    desc->dststride      = dst_row_size;
+    desc->srcwidthoffset = 0;
+    desc->dstwidthoffset = 0;
+
+    q->dptr[q->push_idx] = dptr;
+
+    dmlink(q->tail, desc);
+    q->tail = desc;
+
+    // FARF(ERROR, "dma-push: i %u len %u dst %p src %p\n", q->push_idx, len, dst, src);
+    q->push_idx = (q->push_idx + 1) & q->idx_mask;
+    return true;
+}
+
+static inline bool dma_queue_push_ddr_to_vtcm(dma_queue * q,
+                                              dma_ptr dptr,
+                                              size_t dst_row_size,
+                                              size_t src_row_size,
+                                              size_t nrows) {
+    return dma_queue_push(q, dptr, dst_row_size, src_row_size, src_row_size, nrows);
+}
+
+static inline bool dma_queue_push_vtcm_to_ddr(dma_queue * q,
+                                              dma_ptr dptr,
+                                              size_t dst_row_size,
+                                              size_t src_row_size,
+                                              size_t nrows) {
+    return dma_queue_push(q, dptr, dst_row_size, src_row_size, dst_row_size, nrows);
+}
+
+static inline dma_ptr dma_queue_pop(dma_queue * q) {
+    dma_ptr dptr = { NULL, NULL };
+
+    if (q->push_idx == q->pop_idx) {
+        return dptr;
+    }
+
+    hexagon_udma_descriptor_type1_t * desc = &q->desc[q->pop_idx];
+
+    // Wait for desc to complete
+    while (1) {
+        dmpoll();
+        if (desc->dstate == HEXAGON_UDMA_DESC_DSTATE_COMPLETE) {
+            break;
+        }
+        // FARF(ERROR, "dma-pop: waiting for DMA : %u\n", q->pop_idx);
+    }
+
+    dptr = q->dptr[q->pop_idx];
+
+    // FARF(ERROR, "dma-pop: i %u dst %p\n", q->pop_idx, dst);
+    q->pop_idx = (q->pop_idx + 1) & q->idx_mask;
+    return dptr;
+}
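+// Minimal usage sketch (illustrative only; the buffer names are hypothetical).
+// A typical pattern prefetches row blocks from DDR into VTCM ahead of the
+// compute loop and pops each transfer right before its data is consumed:
+//
+//   dma_queue * q = dma_queue_create(4);
+//   dma_queue_push_ddr_to_vtcm(q, dma_make_ptr(vtcm_block, ddr_rows),
+//                              dst_row_size, src_row_size, nrows);
+//   ...
+//   dma_ptr p = dma_queue_pop(q);  // spins until this transfer is complete
+//   ...
+//   dma_queue_flush(q);            // drain pending transfers before reuse
+//   dma_queue_delete(q);
+
+#ifdef __cplusplus
+}  // extern "C"
+#endif
+
+#endif /* HTP_DMA_H */
diff --git a/llama.cpp/ggml/src/ggml-hexagon/htp/hex-dump.h b/llama.cpp/ggml/src/ggml-hexagon/htp/hex-dump.h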
new file mode 100644
index 0000000..e3badb5
--- /dev/null
+++ b/llama.cpp/ggml/src/ggml-hexagon/htp/hex-dump.h
@@ -0,0 +1,77 @@
+#ifndef HEX_DUMP_H
+#define HEX_DUMP_H
+
+#include <HAP_farf.h>
+
+static inline void hex_dump_int8_line(char * pref, const int8_t * x, int n) {
+    char str[1024], *p = str, *p_end = str + sizeof(str);
+    p += snprintf(p, p_end - p, "%s: ", pref);
+    for (int i = 0; i < n && p < p_end; i++) {
+        p += snprintf(p, p_end - p, "%d, ", x[i]);
+    }
+    FARF(HIGH, "%s\n", str);
+}
+
+static inline void hex_dump_uint8_line(char * pref, const uint8_t * x, uint32_t n) {
+    char str[1024], *p = str, *p_end = str + sizeof(str);
+    p += snprintf(p, p_end - p, "%s: ", pref);
+    for (uint32_t i = 0; i < n && p < p_end; i++) {
+        p += snprintf(p, p_end - p, "%d, ", x[i]);
+    }
+    FARF(HIGH, "%s\n", str);
+}
+
+static inline void hex_dump_int32_line(char * pref, const int32_t * x, uint32_t n) {
+    char str[1024], *p = str, *p_end = str + sizeof(str);
+    p += snprintf(p, p_end - p, "%s: ", pref);
+    for (uint32_t i = 0; i < n && p < p_end; i++) {
+        p += snprintf(p, p_end - p, "%d, ", (int) x[i]);
+    }
+    FARF(HIGH, "%s\n", str);
+}
+
+static inline void hex_dump_f16_line(char * pref, const __fp16 * x, uint32_t n) {
+    char str[1024], *p = str, *p_end = str + sizeof(str);
+    p += snprintf(p, p_end - p, "%s: ", pref);
+    for (uint32_t i = 0; i < n && p < p_end; i++) {
+        p += snprintf(p, p_end - p, "%.6f, ", (float) x[i]);
+    }
+    FARF(HIGH, "%s\n", str);
+}
+
+static inline void hex_dump_f32_line(char * pref, const float * x, uint32_t n) {
+    char str[1024], *p = str, *p_end = str + sizeof(str);
+    p += snprintf(p, p_end - p, "%s: ", pref);
+    for (uint32_t i = 0; i < n && p < p_end; i++) {
+        p += snprintf(p, p_end - p, "%.6f, ", x[i]);
+    }
+    FARF(HIGH, "%s\n", str);
+}
+
+static inline void hex_dump_f32(char * pref, const float * x, uint32_t n) {
+    uint32_t n0 = n / 16;
+    uint32_t n1 = n % 16;
+
+    uint32_t i = 0;
+    for (; i < n0; i++) {
+        hex_dump_f32_line(pref, x + (16 * i), 16);
+    }
+    if (n1) {
+        hex_dump_f32_line(pref, x + (16 * i), n1);
+    }
+}
+
+static inline void hex_dump_f16(char * pref, const __fp16 * x, uint32_t n) {
+    uint32_t n0 = n / 16;
+    uint32_t n1 = n % 16;
+
+    uint32_t i = 0;
+    for (; i < n0; i++) {
+        hex_dump_f16_line(pref, x + (16 * i), 16);
+    }
+    if (n1) {
+        hex_dump_f16_line(pref, x + (16 * i), n1);
+    }
+}
+
+#endif /* HEX_DUMP_H */
diff --git a/llama.cpp/ggml/src/ggml-hexagon/htp/hex-fastdiv.h b/llama.cpp/ggml/src/ggml-hexagon/htp/hex-fastdiv.h
new file mode 100644
index 0000000..b7b5867
--- /dev/null
+++ b/llama.cpp/ggml/src/ggml-hexagon/htp/hex-fastdiv.h
@@ -0,0 +1,37 @@
+#ifndef HEX_FASTDIV_H
+#define HEX_FASTDIV_H
+
+// See https://gmplib.org/~tege/divcnst-pldi94.pdf figure 4.1.
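+//
+// Worked example (illustrative): for d = 5 the code below computes L = 3
+// and mp = 0x9999999A, so dividing n = 25 evaluates as
+// (mulhi(25, mp) + 25) >> 3 = (15 + 25) >> 3 = 5.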
+// Precompute mp (m' in the paper) and L such that division +// can be computed using a multiply (high 32b of 64b result) +// and a shift: +// +// n/d = (mulhi(n, mp) + n) >> L; +struct fastdiv_values { + uint32_t mp; + uint32_t l; +}; + +static inline struct fastdiv_values init_fastdiv_values(uint32_t d) { + struct fastdiv_values result = { 0, 0 }; + // compute L = ceil(log2(d)); + while (result.l < 32 && ((uint32_t) 1 << result.l) < d) { + ++(result.l); + } + + result.mp = (uint32_t) (((uint64_t) 1 << 32) * (((uint64_t) 1 << result.l) - d) / d + 1); + return result; +} + +static inline uint32_t fastdiv(uint32_t n, const struct fastdiv_values * vals) { + // Compute high 32 bits of n * mp + const uint32_t hi = (uint32_t) (((uint64_t) n * vals->mp) >> 32); // mulhi(n, mp) + // add n, apply bit shift + return (hi + n) >> vals->l; +} + +static inline uint32_t fastmodulo(uint32_t n, uint32_t d, const struct fastdiv_values * vals) { + return n - fastdiv(n, vals) * d; +} + +#endif /* HEX_FASTDIV_H */ diff --git a/llama.cpp/ggml/src/ggml-hexagon/htp/hex-utils.h b/llama.cpp/ggml/src/ggml-hexagon/htp/hex-utils.h new file mode 100644 index 0000000..fb8a25a --- /dev/null +++ b/llama.cpp/ggml/src/ggml-hexagon/htp/hex-utils.h @@ -0,0 +1,51 @@ +#ifndef HEX_UTILS_H +#define HEX_UTILS_H + +#include <stdbool.h> +#include <stdint.h> + +#include "hexagon_types.h" + +#include "hex-fastdiv.h" +#include "hex-dump.h" + +#ifndef MAX +#define MAX(a, b) ((a) > (b) ? (a) : (b)) +#endif + +#ifndef MIN +#define MIN(a, b) ((a) < (b) ? (a) : (b)) +#endif + +static inline uint64_t hex_get_cycles() { + uint64_t cycles = 0; + asm volatile(" %0 = c15:14\n" : "=r"(cycles)); + return cycles; +} + +static inline uint64_t hex_get_pktcnt() { + uint64_t pktcnt; + asm volatile(" %0 = c19:18\n" : "=r"(pktcnt)); + return pktcnt; +} + +static inline int32_t hex_is_aligned(void * addr, uint32_t align) { + return ((size_t) addr & (align - 1)) == 0; +} + +static inline int32_t hex_is_one_chunk(void * addr, uint32_t n, uint32_t chunk_size) { + uint32_t left_off = (size_t) addr & (chunk_size - 1); + uint32_t right_off = left_off + n; + return right_off <= chunk_size; +} + +static inline uint32_t hex_round_up(uint32_t n, uint32_t m) { + return m * ((n + m - 1) / m); +} + +static inline void hex_l2fetch(const void * p, uint32_t width, uint32_t stride, uint32_t height) { + const uint64_t control = Q6_P_combine_RR(stride, Q6_R_combine_RlRl(width, height)); + Q6_l2fetch_AP((void *) p, control); +} + +#endif /* HEX_UTILS_H */ diff --git a/llama.cpp/ggml/src/ggml-hexagon/htp/htp-ctx.h b/llama.cpp/ggml/src/ggml-hexagon/htp/htp-ctx.h new file mode 100644 index 0000000..a707d98 --- /dev/null +++ b/llama.cpp/ggml/src/ggml-hexagon/htp/htp-ctx.h @@ -0,0 +1,35 @@ +#ifndef HTP_CTX_H +#define HTP_CTX_H + +#include "hex-dma.h" +#include "worker-pool.h" + +#include <assert.h> +#include <dspqueue.h> +#include <stdatomic.h> +#include <stdint.h> + +#define HTP_MAX_NTHREADS 10 + +// Main context for htp DSP backend +struct htp_context { + dspqueue_t queue; + dma_queue * dma[HTP_MAX_NTHREADS]; + worker_pool_context_t worker_pool; + uint32_t n_threads; + + int thread_id; + int thread_prio; + + uint8_t * vtcm_base; + size_t vtcm_size; + uint32_t vtcm_rctx; + + atomic_bool vtcm_valid; + atomic_bool vtcm_inuse; + atomic_bool vtcm_needs_release; + + uint32_t opmask; +}; + +#endif /* HTP_CTX_H */ diff --git a/llama.cpp/ggml/src/ggml-hexagon/htp/htp-msg.h b/llama.cpp/ggml/src/ggml-hexagon/htp/htp-msg.h new file mode 100644 index 0000000..25403bb --- /dev/null +++ 
b/llama.cpp/ggml/src/ggml-hexagon/htp/htp-msg.h
@@ -0,0 +1,154 @@
+#ifndef HTP_MSG_H
+#define HTP_MSG_H
+
+#include <assert.h>
+
+// ggml-common.h must be included prior to this header
+
+// Mask to enable various stages of the Ops.
+// Used for debugging and profiling.
+enum {
+    HTP_OPMASK_QUEUE    = (1 << 0),  // Enable Queueing (i.e. calls into the DSP)
+    HTP_OPMASK_QUANTIZE = (1 << 1),  // Enable Quantize
+    HTP_OPMASK_COMPUTE  = (1 << 2),  // Enable Compute
+};
+
+// Op flags
+enum {
+    HTP_OPFLAGS_SKIP_QUANTIZE = (1 << 0),  // Skip dynamic quantization (reuse quantized tensors)
+    HTP_OPFLAGS_SKIP_COMPUTE  = (1 << 1),  // Skip actual computation (used for profiling)
+    HTP_OPFLAGS_EARLY_WAKEUP  = (1 << 2)   // Send early wakeup notification
+};
+
+enum htp_status {
+    HTP_STATUS_OK             = 1,
+    HTP_STATUS_INTERNAL_ERR   = 2,
+    HTP_STATUS_NO_SUPPORT     = 3,
+    HTP_STATUS_INVAL_PARAMS   = 4,
+    HTP_STATUS_VTCM_TOO_SMALL = 5,
+};
+
+// The values must match the ggml_type.
+// Duplicated here because we can't include full ggml.h in the htp build.
+// We have some static_asserts in the cpp code to ensure things are in sync.
+enum htp_data_type {
+    HTP_TYPE_F32   = 0,
+    HTP_TYPE_F16   = 1,
+    HTP_TYPE_Q4_0  = 2,
+    HTP_TYPE_Q8_0  = 8,
+    HTP_TYPE_I32   = 26,
+    HTP_TYPE_I64   = 27,
+    HTP_TYPE_MXFP4 = 39,
+    HTP_TYPE_COUNT
+};
+
+// Do not reorder first 4 (used as an index)
+enum htp_op {
+    HTP_OP_MUL = 0,
+    HTP_OP_ADD = 1,
+    HTP_OP_SUB = 2,
+    HTP_OP_DIV = 3,
+    HTP_OP_MUL_MAT,
+    HTP_OP_MUL_MAT_ID,
+    HTP_OP_RMS_NORM,
+    HTP_OP_UNARY_SILU,
+    HTP_OP_UNARY_GELU,
+    HTP_OP_GLU_SWIGLU,
+    HTP_OP_GLU_SWIGLU_OAI,
+    HTP_OP_GLU_GEGLU,
+    HTP_OP_SOFTMAX,
+    HTP_OP_ADD_ID,
+    HTP_OP_ROPE,
+    HTP_OP_FLASH_ATTN_EXT,
+    HTP_OP_SET_ROWS,
+    HTP_OP_GET_ROWS,
+    HTP_OP_SCALE,
+    HTP_OP_CPY,
+    HTP_OP_ARGSORT,
+    HTP_OP_SQR,
+    HTP_OP_SQRT,
+    HTP_OP_SUM_ROWS,
+    INVALID
+};
+
+static inline size_t htp_t_block_size(uint32_t t) {
+    switch (t) {
+        case HTP_TYPE_F32:
+            return 1;
+        case HTP_TYPE_F16:
+            return 1;
+        case HTP_TYPE_Q4_0:
+            return QK4_0;
+        case HTP_TYPE_Q8_0:
+            return QK8_0;
+        case HTP_TYPE_MXFP4:
+            return QK_MXFP4;
+        default:
+            assert(0 && "unsupported HTP data type");
+    }
+    return 0;
+}
+
+static inline size_t htp_type_nbytes(uint32_t t) {
+    switch (t) {
+        case HTP_TYPE_F32:
+            return 4;
+        case HTP_TYPE_F16:
+            return 2;
+        case HTP_TYPE_Q4_0:
+            return sizeof(block_q4_0);
+        case HTP_TYPE_Q8_0:
+            return sizeof(block_q8_0);
+        case HTP_TYPE_MXFP4:
+            return sizeof(block_mxfp4);
+        default:
+            assert(0 && "unsupported HTP data type");
+    }
+    return 0;
+}
+
+// Internal types
+#define QK_Q4_0x4x2  256  // 4x Q4_0 blocks packed with next 4x Q4_0 blocks (size in bytes 128)
+#define QK_Q8_0x4x2  256  // 4x Q8_0 blocks concat with next 4x Q8_0 blocks
+#define QK_MXFP4x4x2 256  // 4x MXFP4 blocks concat with next 4x MXFP4 blocks
+
+#define HTP_MAX_DIMS 4
+
+struct htp_tensor {
+    uint32_t data;              // Buffer offset in the messages, and data pointer on the NSP
+    uint32_t type;              // Data type
+    uint32_t ne[HTP_MAX_DIMS];  // Number of elements
+    uint32_t nb[HTP_MAX_DIMS];  // Stride in bytes (see ggml.h ggml_tensor)
+};
+
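+// Illustrative host-side packing sketch (hypothetical helper; the actual code
+// lives on the CPU side of the backend). A ggml tensor maps onto htp_tensor
+// field by field, with the data pointer replaced by a buffer offset:
+//
+//   void htp_tensor_pack(struct htp_tensor * ht, const struct ggml_tensor * t, uint32_t off) {
+//       ht->data = off;                  // becomes a real pointer on the NSP
+//       ht->type = (uint32_t) t->type;   // htp_data_type values match ggml_type
+//       for (int i = 0; i < HTP_MAX_DIMS; i++) {
+//           ht->ne[i] = (uint32_t) t->ne[i];
+//           ht->nb[i] = (uint32_t) t->nb[i];
+//       }
+//   }
+
+#define HTP_MAX_OP_PARAMS 64
+
+struct htp_general_req {
+    uint32_t op;  // GGML/HTP Op
+    int32_t  op_params[HTP_MAX_OP_PARAMS / sizeof(int32_t)];
+                  // Params for the op, e.g. 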
epsilon of RMS norm + uint32_t flags; // Request flags + + struct htp_tensor src0; // Input0 tensor + struct htp_tensor src1; // Input1 tensor + struct htp_tensor src2; // Input2 tensor + struct htp_tensor src3; // Input3 tensor + struct htp_tensor src4; // Input4 tensor + struct htp_tensor dst; // Output tensor + + // should be multiple of 64 bytes (cacheline) +}; + +struct htp_general_rsp { + uint32_t op; // GGML/HTP Op + uint32_t status; // HTP_STATUS_... + uint32_t prof_usecs; // Number of usec per request + uint32_t prof_cycles; // Number of cycles per request + uint32_t prof_pkts; // Number of instruction packets per request + uint8_t unused[44]; // Pad to 64 bytes +}; + +#define HTP_MAX_MESSAGE_SIZE sizeof(struct htp_general_req) +#define HTP_MAX_PACKET_BUFFERS 8 + +#endif /* HTP_MSG_H */ diff --git a/llama.cpp/ggml/src/ggml-hexagon/htp/htp-ops.h b/llama.cpp/ggml/src/ggml-hexagon/htp/htp-ops.h new file mode 100644 index 0000000..f1ad24d --- /dev/null +++ b/llama.cpp/ggml/src/ggml-hexagon/htp/htp-ops.h @@ -0,0 +1,91 @@ +#ifndef HTP_OPS_H +#define HTP_OPS_H + +#include "htp-ctx.h" +#include "htp-msg.h" +#include "worker-pool.h" + +#include <assert.h> +#include <stdint.h> + +#include <hex-fastdiv.h> + +// ggml-common.h must be included prior to this header + +struct htp_spad { + uint8_t * data; + size_t stride; + size_t size; + size_t size_per_thread; +}; + +struct htp_ops_context { + struct htp_context * ctx; + + enum htp_op op; + int32_t op_params[HTP_MAX_OP_PARAMS / sizeof(int32_t)]; + + struct htp_tensor src0; + struct htp_tensor src1; + struct htp_tensor src2; + struct htp_tensor src3; + struct htp_tensor src4; + struct htp_tensor dst; + + struct htp_spad src0_spad; + struct htp_spad src1_spad; + struct htp_spad src2_spad; + struct htp_spad src3_spad; + struct htp_spad dst_spad; + + worker_pool_context_t * wpool; // worker pool + uint32_t n_threads; // num threads + + uint32_t src0_nrows_per_thread; + uint32_t src1_nrows_per_thread; + + struct fastdiv_values src0_div1; // fastdiv values for ne1 + struct fastdiv_values src0_div2; // fastdiv values for ne2 + struct fastdiv_values src0_div3; // fastdiv values for ne3 + struct fastdiv_values src0_div21; // fastdiv values for ne2 * ne1 + + struct fastdiv_values src1_div1; // fastdiv values for ne1 + struct fastdiv_values src1_div2; // fastdiv values for ne2 + struct fastdiv_values src1_div3; // fastdiv values for ne3 + struct fastdiv_values src1_div21; // fastdiv values for ne2 * ne1 + + struct fastdiv_values src3_div1; // fastdiv values for ne1 + struct fastdiv_values src3_div2; // fastdiv values for ne2 + struct fastdiv_values src3_div3; // fastdiv values for ne3 + struct fastdiv_values src3_div21; // fastdiv values for ne2 * ne1 + + struct fastdiv_values broadcast_rk2; + struct fastdiv_values broadcast_rk3; + struct fastdiv_values broadcast_rv2; + struct fastdiv_values broadcast_rv3; + + struct fastdiv_values set_rows_div_ne12; // fastdiv values for ne12 + struct fastdiv_values set_rows_div_ne11; // fastdiv values for ne11 + + struct fastdiv_values get_rows_div_ne10; // fastdiv values for ne10 + struct fastdiv_values get_rows_div_ne10_ne11; // fastdiv values for ne10 * ne11 + + uint32_t flags; +}; + +int op_matmul(struct htp_ops_context * octx); +int op_matmul_id(struct htp_ops_context * octx); +int op_binary(struct htp_ops_context * octx); +int op_unary(struct htp_ops_context * octx); +int op_sum_rows(struct htp_ops_context * octx); +int op_activations(struct htp_ops_context * octx); +int op_softmax(struct htp_ops_context * octx); 
+int op_add_id(struct htp_ops_context * octx); +int op_rope(struct htp_ops_context * octx); +int op_flash_attn_ext(struct htp_ops_context * octx); +int op_set_rows(struct htp_ops_context * octx); +int op_get_rows(struct htp_ops_context * octx); +int op_cpy(struct htp_ops_context * octx); +int op_argsort(struct htp_ops_context * octx); + +#endif /* HTP_OPS_H */ diff --git a/llama.cpp/ggml/src/ggml-hexagon/htp/htp_iface.idl b/llama.cpp/ggml/src/ggml-hexagon/htp/htp_iface.idl new file mode 100644 index 0000000..9ebd937 --- /dev/null +++ b/llama.cpp/ggml/src/ggml-hexagon/htp/htp_iface.idl @@ -0,0 +1,16 @@ +// FastRPC IDL interface for GGML HTP + +#ifndef HTP_IDL +#define HTP_IDL + +#include "AEEStdDef.idl" +#include "remote.idl" + +interface htp_iface : remote_handle64 { + AEEResult start(in uint32 sess_id, in uint64 dsp_queue_id, in uint32 n_hvx); + AEEResult stop(); + AEEResult enable_etm(); + AEEResult disable_etm(); +}; + +#endif /* HTP_IDL */ diff --git a/llama.cpp/ggml/src/ggml-hexagon/htp/hvx-arith.h b/llama.cpp/ggml/src/ggml-hexagon/htp/hvx-arith.h new file mode 100644 index 0000000..2577cdd --- /dev/null +++ b/llama.cpp/ggml/src/ggml-hexagon/htp/hvx-arith.h @@ -0,0 +1,470 @@ +#ifndef HVX_ARITH_H +#define HVX_ARITH_H + +#include <assert.h> +#include <stddef.h> +#include <stdint.h> +#include <math.h> + +#include "hvx-base.h" +#include "hex-utils.h" + +// +// Binary operations (add, mul, sub) +// + +#define hvx_arith_loop_body(dst_type, src0_type, src1_type, vec_store, vec_op) \ + do { \ + dst_type * restrict vdst = (dst_type *) dst; \ + src0_type * restrict vsrc0 = (src0_type *) src0; \ + src1_type * restrict vsrc1 = (src1_type *) src1; \ + \ + const uint32_t elem_size = sizeof(float); \ + const uint32_t epv = 128 / elem_size; \ + const uint32_t nvec = n / epv; \ + const uint32_t nloe = n % epv; \ + \ + uint32_t i = 0; \ + \ + _Pragma("unroll(4)") \ + for (; i < nvec; i++) { \ + vdst[i] = vec_op(vsrc0[i], vsrc1[i]); \ + } \ + if (nloe) { \ + HVX_Vector v = vec_op(vsrc0[i], vsrc1[i]); \ + vec_store((void *) &vdst[i], nloe * elem_size, v); \ + } \ + } while(0) + +#if __HVX_ARCH__ < 79 +#define HVX_OP_ADD(a, b) Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_VsfVsf(a, b)) +#define HVX_OP_SUB(a, b) Q6_Vsf_equals_Vqf32(Q6_Vqf32_vsub_VsfVsf(a, b)) +#define HVX_OP_MUL(a, b) Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(a, b)) +#else +#define HVX_OP_ADD(a, b) Q6_Vsf_vadd_VsfVsf(a, b) +#define HVX_OP_SUB(a, b) Q6_Vsf_vsub_VsfVsf(a, b) +#define HVX_OP_MUL(a, b) Q6_Vsf_vmpy_VsfVsf(a, b) +#endif + +// Generic macro to define alignment permutations for an op +#define DEFINE_HVX_BINARY_OP_VARIANTS(OP_NAME, OP_MACRO) \ +static inline void OP_NAME##_aaa(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, uint32_t n) { \ + assert((uintptr_t) dst % 128 == 0); \ + assert((uintptr_t) src0 % 128 == 0); \ + assert((uintptr_t) src1 % 128 == 0); \ + hvx_arith_loop_body(HVX_Vector, HVX_Vector, HVX_Vector, hvx_vec_store_a, OP_MACRO); \ +} \ +static inline void OP_NAME##_aau(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, uint32_t n) { \ + assert((uintptr_t) dst % 128 == 0); \ + assert((uintptr_t) src0 % 128 == 0); \ + hvx_arith_loop_body(HVX_Vector, HVX_Vector, HVX_UVector, hvx_vec_store_a, OP_MACRO); \ +} \ +static inline void OP_NAME##_aua(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, uint32_t n) { \ + assert((uintptr_t) dst % 128 == 0); \ + assert((uintptr_t) src1 % 128 == 0); \ + hvx_arith_loop_body(HVX_Vector, 
HVX_UVector, HVX_Vector, hvx_vec_store_a, OP_MACRO); \ +} \ +static inline void OP_NAME##_auu(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, uint32_t n) { \ + assert((uintptr_t) dst % 128 == 0); \ + hvx_arith_loop_body(HVX_Vector, HVX_UVector, HVX_UVector, hvx_vec_store_a, OP_MACRO); \ +} \ +static inline void OP_NAME##_uaa(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, uint32_t n) { \ + assert((uintptr_t) src0 % 128 == 0); \ + assert((uintptr_t) src1 % 128 == 0); \ + hvx_arith_loop_body(HVX_UVector, HVX_Vector, HVX_Vector, hvx_vec_store_u, OP_MACRO); \ +} \ +static inline void OP_NAME##_uau(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, uint32_t n) { \ + assert((uintptr_t) src0 % 128 == 0); \ + hvx_arith_loop_body(HVX_UVector, HVX_Vector, HVX_UVector, hvx_vec_store_u, OP_MACRO); \ +} \ +static inline void OP_NAME##_uua(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, uint32_t n) { \ + assert((uintptr_t) src1 % 128 == 0); \ + hvx_arith_loop_body(HVX_UVector, HVX_UVector, HVX_Vector, hvx_vec_store_u, OP_MACRO); \ +} \ +static inline void OP_NAME##_uuu(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, uint32_t n) { \ + hvx_arith_loop_body(HVX_UVector, HVX_UVector, HVX_UVector, hvx_vec_store_u, OP_MACRO); \ +} \ + +DEFINE_HVX_BINARY_OP_VARIANTS(hvx_add_f32, HVX_OP_ADD) +DEFINE_HVX_BINARY_OP_VARIANTS(hvx_sub_f32, HVX_OP_SUB) +DEFINE_HVX_BINARY_OP_VARIANTS(hvx_mul_f32, HVX_OP_MUL) + +// Dispatcher logic +#define HVX_BINARY_DISPATCHER(OP_NAME) \ +static inline void OP_NAME(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, const uint32_t num_elems) { \ + if (hex_is_aligned((void *) dst, 128)) { \ + if (hex_is_aligned((void *) src0, 128)) { \ + if (hex_is_aligned((void *) src1, 128)) OP_NAME##_aaa(dst, src0, src1, num_elems); \ + else OP_NAME##_aau(dst, src0, src1, num_elems); \ + } else { \ + if (hex_is_aligned((void *) src1, 128)) OP_NAME##_aua(dst, src0, src1, num_elems); \ + else OP_NAME##_auu(dst, src0, src1, num_elems); \ + } \ + } else { \ + if (hex_is_aligned((void *) src0, 128)) { \ + if (hex_is_aligned((void *) src1, 128)) OP_NAME##_uaa(dst, src0, src1, num_elems); \ + else OP_NAME##_uau(dst, src0, src1, num_elems); \ + } else { \ + if (hex_is_aligned((void *) src1, 128)) OP_NAME##_uua(dst, src0, src1, num_elems); \ + else OP_NAME##_uuu(dst, src0, src1, num_elems); \ + } \ + } \ +} + +HVX_BINARY_DISPATCHER(hvx_add_f32) +HVX_BINARY_DISPATCHER(hvx_sub_f32) +HVX_BINARY_DISPATCHER(hvx_mul_f32) + +// Mul-Mul Optimized +static inline void hvx_mul_mul_f32_aa(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, const uint8_t * restrict src2, const uint32_t num_elems) { + assert((unsigned long) dst % 128 == 0); + assert((unsigned long) src0 % 128 == 0); + assert((unsigned long) src1 % 128 == 0); + assert((unsigned long) src2 % 128 == 0); + + HVX_Vector * restrict vdst = (HVX_Vector *) dst; + HVX_Vector * restrict vsrc0 = (HVX_Vector *) src0; + HVX_Vector * restrict vsrc1 = (HVX_Vector *) src1; + HVX_Vector * restrict vsrc2 = (HVX_Vector *) src2; + + const uint32_t elem_size = sizeof(float); + const uint32_t epv = 128 / elem_size; + const uint32_t nvec = num_elems / epv; + const uint32_t nloe = num_elems % epv; + + uint32_t i = 0; + + _Pragma("unroll(4)") + for (; i < nvec; i++) { + HVX_Vector v1 = HVX_OP_MUL(vsrc0[i], vsrc1[i]); + vdst[i] = 
HVX_OP_MUL(v1, vsrc2[i]); + } + + if (nloe) { + HVX_Vector v1 = HVX_OP_MUL(vsrc0[i], vsrc1[i]); + HVX_Vector v2 = HVX_OP_MUL(v1, vsrc2[i]); + hvx_vec_store_a((void *) &vdst[i], nloe * elem_size, v2); + } +} + +// Scalar Operations + +#define hvx_scalar_loop_body(dst_type, src_type, vec_store, scalar_op_macro) \ + do { \ + dst_type * restrict vdst = (dst_type *) dst; \ + src_type * restrict vsrc = (src_type *) src; \ + \ + const uint32_t elem_size = sizeof(float); \ + const uint32_t epv = 128 / elem_size; \ + const uint32_t nvec = n / epv; \ + const uint32_t nloe = n % epv; \ + \ + uint32_t i = 0; \ + \ + _Pragma("unroll(4)") \ + for (; i < nvec; i++) { \ + HVX_Vector v = vsrc[i]; \ + vdst[i] = scalar_op_macro(v); \ + } \ + if (nloe) { \ + HVX_Vector v = vsrc[i]; \ + v = scalar_op_macro(v); \ + vec_store((void *) &vdst[i], nloe * elem_size, v); \ + } \ + } while(0) + +#define HVX_OP_ADD_SCALAR(v) \ + ({ \ + const HVX_VectorPred pred_inf = Q6_Q_vcmp_eq_VwVw(inf, v); \ + HVX_Vector out = HVX_OP_ADD(v, val_vec); \ + Q6_V_vmux_QVV(pred_inf, inf, out); \ + }) + +#define HVX_OP_MUL_SCALAR(v) HVX_OP_MUL(v, val_vec) +#define HVX_OP_SUB_SCALAR(v) HVX_OP_SUB(v, val_vec) + +// Add Scalar Variants + +static inline void hvx_add_scalar_f32_aa(uint8_t * restrict dst, const uint8_t * restrict src, const float val, uint32_t n) { + const HVX_Vector val_vec = hvx_vec_splat_f32(val); + const HVX_Vector inf = hvx_vec_splat_f32(INFINITY); + assert((unsigned long) dst % 128 == 0); + assert((unsigned long) src % 128 == 0); + hvx_scalar_loop_body(HVX_Vector, HVX_Vector, hvx_vec_store_a, HVX_OP_ADD_SCALAR); +} + +static inline void hvx_add_scalar_f32_au(uint8_t * restrict dst, const uint8_t * restrict src, const float val, uint32_t n) { + const HVX_Vector val_vec = hvx_vec_splat_f32(val); + const HVX_Vector inf = hvx_vec_splat_f32(INFINITY); + assert((unsigned long) dst % 128 == 0); + hvx_scalar_loop_body(HVX_Vector, HVX_UVector, hvx_vec_store_a, HVX_OP_ADD_SCALAR); +} + +static inline void hvx_add_scalar_f32_ua(uint8_t * restrict dst, const uint8_t * restrict src, const float val, uint32_t n) { + const HVX_Vector val_vec = hvx_vec_splat_f32(val); + const HVX_Vector inf = hvx_vec_splat_f32(INFINITY); + assert((unsigned long) src % 128 == 0); + hvx_scalar_loop_body(HVX_UVector, HVX_Vector, hvx_vec_store_u, HVX_OP_ADD_SCALAR); +} + +static inline void hvx_add_scalar_f32_uu(uint8_t * restrict dst, const uint8_t * restrict src, const float val, uint32_t n) { + const HVX_Vector val_vec = hvx_vec_splat_f32(val); + static const float kInf = INFINITY; + const HVX_Vector inf = hvx_vec_splat_f32(kInf); + hvx_scalar_loop_body(HVX_UVector, HVX_UVector, hvx_vec_store_u, HVX_OP_ADD_SCALAR); +} + +// Sub Scalar Variants + +static inline void hvx_sub_scalar_f32_aa(uint8_t * restrict dst, const uint8_t * restrict src, const float val, uint32_t n) { + const HVX_Vector val_vec = hvx_vec_splat_f32(val); + assert((unsigned long) dst % 128 == 0); + assert((unsigned long) src % 128 == 0); + hvx_scalar_loop_body(HVX_Vector, HVX_Vector, hvx_vec_store_a, HVX_OP_SUB_SCALAR); +} + +static inline void hvx_sub_scalar_f32_au(uint8_t * restrict dst, const uint8_t * restrict src, const float val, uint32_t n) { + const HVX_Vector val_vec = hvx_vec_splat_f32(val); + assert((unsigned long) dst % 128 == 0); + hvx_scalar_loop_body(HVX_Vector, HVX_UVector, hvx_vec_store_a, HVX_OP_SUB_SCALAR); +} + +static inline void hvx_sub_scalar_f32_ua(uint8_t * restrict dst, const uint8_t * restrict src, const float val, uint32_t n) { + const HVX_Vector val_vec = 
hvx_vec_splat_f32(val); + assert((unsigned long) src % 128 == 0); + hvx_scalar_loop_body(HVX_UVector, HVX_Vector, hvx_vec_store_u, HVX_OP_SUB_SCALAR); +} + +static inline void hvx_sub_scalar_f32_uu(uint8_t * restrict dst, const uint8_t * restrict src, const float val, uint32_t n) { + const HVX_Vector val_vec = hvx_vec_splat_f32(val); + hvx_scalar_loop_body(HVX_UVector, HVX_UVector, hvx_vec_store_u, HVX_OP_SUB_SCALAR); +} + +// Mul Scalar Variants + +static inline void hvx_mul_scalar_f32_aa(uint8_t * restrict dst, const uint8_t * restrict src, const float val, uint32_t n) { + const HVX_Vector val_vec = hvx_vec_splat_f32(val); + assert((unsigned long) dst % 128 == 0); + assert((unsigned long) src % 128 == 0); + hvx_scalar_loop_body(HVX_Vector, HVX_Vector, hvx_vec_store_a, HVX_OP_MUL_SCALAR); +} + +static inline void hvx_mul_scalar_f32_au(uint8_t * restrict dst, const uint8_t * restrict src, const float val, uint32_t n) { + const HVX_Vector val_vec = hvx_vec_splat_f32(val); + assert((unsigned long) dst % 128 == 0); + hvx_scalar_loop_body(HVX_Vector, HVX_UVector, hvx_vec_store_a, HVX_OP_MUL_SCALAR); +} + +static inline void hvx_mul_scalar_f32_ua(uint8_t * restrict dst, const uint8_t * restrict src, const float val, uint32_t n) { + const HVX_Vector val_vec = hvx_vec_splat_f32(val); + assert((unsigned long) src % 128 == 0); + hvx_scalar_loop_body(HVX_UVector, HVX_Vector, hvx_vec_store_u, HVX_OP_MUL_SCALAR); +} + +static inline void hvx_mul_scalar_f32_uu(uint8_t * restrict dst, const uint8_t * restrict src, const float val, uint32_t n) { + const HVX_Vector val_vec = hvx_vec_splat_f32(val); + hvx_scalar_loop_body(HVX_UVector, HVX_UVector, hvx_vec_store_u, HVX_OP_MUL_SCALAR); +} + +static inline void hvx_add_scalar_f32(uint8_t * restrict dst, const uint8_t * restrict src, const float val, const int num_elems) { + if (hex_is_aligned((void *) dst, 128) && hex_is_aligned((void *) src, 128)) { + hvx_add_scalar_f32_aa(dst, src, val, num_elems); + } else if (hex_is_aligned((void *) dst, 128)) { + hvx_add_scalar_f32_au(dst, src, val, num_elems); + } else if (hex_is_aligned((void *) src, 128)) { + hvx_add_scalar_f32_ua(dst, src, val, num_elems); + } else { + hvx_add_scalar_f32_uu(dst, src, val, num_elems); + } +} + +static inline void hvx_mul_scalar_f32(uint8_t * restrict dst, const uint8_t * restrict src, const float val, const int num_elems) { + if (hex_is_aligned((void *) dst, 128) && hex_is_aligned((void *) src, 128)) { + hvx_mul_scalar_f32_aa(dst, src, val, num_elems); + } else if (hex_is_aligned((void *) dst, 128)) { + hvx_mul_scalar_f32_au(dst, src, val, num_elems); + } else if (hex_is_aligned((void *) src, 128)) { + hvx_mul_scalar_f32_ua(dst, src, val, num_elems); + } else { + hvx_mul_scalar_f32_uu(dst, src, val, num_elems); + } +} + +static inline void hvx_sub_scalar_f32(uint8_t * restrict dst, const uint8_t * restrict src, const float val, const int num_elems) { + if (hex_is_aligned((void *) dst, 128) && hex_is_aligned((void *) src, 128)) { + hvx_sub_scalar_f32_aa(dst, src, val, num_elems); + } else if (hex_is_aligned((void *) dst, 128)) { + hvx_sub_scalar_f32_au(dst, src, val, num_elems); + } else if (hex_is_aligned((void *) src, 128)) { + hvx_sub_scalar_f32_ua(dst, src, val, num_elems); + } else { + hvx_sub_scalar_f32_uu(dst, src, val, num_elems); + } +} + +// MIN Scalar variants + +#define HVX_OP_MIN_SCALAR(v) Q6_Vsf_vmin_VsfVsf(val_vec, v) + +static inline void hvx_min_scalar_f32_aa(uint8_t * restrict dst, const uint8_t * restrict src, const float val, uint32_t n) { + const HVX_Vector val_vec 
= hvx_vec_splat_f32(val); + assert((unsigned long) dst % 128 == 0); + assert((unsigned long) src % 128 == 0); + hvx_scalar_loop_body(HVX_Vector, HVX_Vector, hvx_vec_store_a, HVX_OP_MIN_SCALAR); +} + +static inline void hvx_min_scalar_f32_au(uint8_t * restrict dst, const uint8_t * restrict src, const float val, uint32_t n) { + const HVX_Vector val_vec = hvx_vec_splat_f32(val); + assert((unsigned long) dst % 128 == 0); + hvx_scalar_loop_body(HVX_Vector, HVX_UVector, hvx_vec_store_a, HVX_OP_MIN_SCALAR); +} + +static inline void hvx_min_scalar_f32_ua(uint8_t * restrict dst, const uint8_t * restrict src, const float val, uint32_t n) { + const HVX_Vector val_vec = hvx_vec_splat_f32(val); + assert((unsigned long) src % 128 == 0); + hvx_scalar_loop_body(HVX_UVector, HVX_Vector, hvx_vec_store_u, HVX_OP_MIN_SCALAR); +} + +static inline void hvx_min_scalar_f32_uu(uint8_t * restrict dst, const uint8_t * restrict src, const float val, uint32_t n) { + const HVX_Vector val_vec = hvx_vec_splat_f32(val); + hvx_scalar_loop_body(HVX_UVector, HVX_UVector, hvx_vec_store_u, HVX_OP_MIN_SCALAR); +} + +static inline void hvx_min_scalar_f32(uint8_t * restrict dst, const uint8_t * restrict src, const float val, const int num_elems) { + if (hex_is_aligned((void *) dst, 128) && hex_is_aligned((void *) src, 128)) { + hvx_min_scalar_f32_aa(dst, src, val, num_elems); + } else if (hex_is_aligned((void *) dst, 128)) { + hvx_min_scalar_f32_au(dst, src, val, num_elems); + } else if (hex_is_aligned((void *) src, 128)) { + hvx_min_scalar_f32_ua(dst, src, val, num_elems); + } else { + hvx_min_scalar_f32_uu(dst, src, val, num_elems); + } +} + +// CLAMP Scalar variants + +#define HVX_OP_CLAMP_SCALAR(v) \ + ({ \ + HVX_VectorPred pred_cap_right = Q6_Q_vcmp_gt_VsfVsf(v, max_vec); \ + HVX_VectorPred pred_cap_left = Q6_Q_vcmp_gt_VsfVsf(min_vec, v); \ + HVX_Vector tmp = Q6_V_vmux_QVV(pred_cap_right, max_vec, v); \ + Q6_V_vmux_QVV(pred_cap_left, min_vec, tmp); \ + }) + +static inline void hvx_clamp_scalar_f32_aa(uint8_t * restrict dst, const uint8_t * restrict src, const float min, const float max, uint32_t n) { + const HVX_Vector min_vec = hvx_vec_splat_f32(min); + const HVX_Vector max_vec = hvx_vec_splat_f32(max); + assert((unsigned long) dst % 128 == 0); + assert((unsigned long) src % 128 == 0); + hvx_scalar_loop_body(HVX_Vector, HVX_Vector, hvx_vec_store_a, HVX_OP_CLAMP_SCALAR); +} + +static inline void hvx_clamp_scalar_f32_au(uint8_t * restrict dst, const uint8_t * restrict src, const float min, const float max, uint32_t n) { + const HVX_Vector min_vec = hvx_vec_splat_f32(min); + const HVX_Vector max_vec = hvx_vec_splat_f32(max); + assert((unsigned long) dst % 128 == 0); + hvx_scalar_loop_body(HVX_Vector, HVX_UVector, hvx_vec_store_a, HVX_OP_CLAMP_SCALAR); +} + +static inline void hvx_clamp_scalar_f32_ua(uint8_t * restrict dst, const uint8_t * restrict src, const float min, const float max, uint32_t n) { + const HVX_Vector min_vec = hvx_vec_splat_f32(min); + const HVX_Vector max_vec = hvx_vec_splat_f32(max); + assert((unsigned long) src % 128 == 0); + hvx_scalar_loop_body(HVX_UVector, HVX_Vector, hvx_vec_store_u, HVX_OP_CLAMP_SCALAR); +} + +static inline void hvx_clamp_scalar_f32_uu(uint8_t * restrict dst, const uint8_t * restrict src, const float min, const float max, uint32_t n) { + const HVX_Vector min_vec = hvx_vec_splat_f32(min); + const HVX_Vector max_vec = hvx_vec_splat_f32(max); + hvx_scalar_loop_body(HVX_UVector, HVX_UVector, hvx_vec_store_u, HVX_OP_CLAMP_SCALAR); +} + +static inline void hvx_clamp_scalar_f32(uint8_t * 
restrict dst, const uint8_t * restrict src, const float min, const float max, const int num_elems) {
+    if (hex_is_aligned((void *) dst, 128) && hex_is_aligned((void *) src, 128)) {
+        hvx_clamp_scalar_f32_aa(dst, src, min, max, num_elems);
+    } else if (hex_is_aligned((void *) dst, 128)) {
+        hvx_clamp_scalar_f32_au(dst, src, min, max, num_elems);
+    } else if (hex_is_aligned((void *) src, 128)) {
+        hvx_clamp_scalar_f32_ua(dst, src, min, max, num_elems);
+    } else {
+        hvx_clamp_scalar_f32_uu(dst, src, min, max, num_elems);
+    }
+}
+
+//
+// Square
+//
+
+#define hvx_sqr_loop_body(dst_type, src_type, vec_store) \
+    do { \
+        dst_type * restrict vdst = (dst_type *) dst; \
+        src_type * restrict vsrc = (src_type *) src; \
+        \
+        const uint32_t elem_size = sizeof(float); \
+        const uint32_t epv  = 128 / elem_size; \
+        const uint32_t nvec = n / epv; \
+        const uint32_t nloe = n % epv; \
+        \
+        uint32_t i = 0; \
+        \
+        _Pragma("unroll(4)") \
+        for (; i < nvec; i++) { \
+            vdst[i] = HVX_OP_MUL(vsrc[i], vsrc[i]); \
+        } \
+        if (nloe) { \
+            HVX_Vector v = HVX_OP_MUL(vsrc[i], vsrc[i]); \
+            vec_store((void *) &vdst[i], nloe * elem_size, v); \
+        } \
+    } while(0)
+
+static inline void hvx_sqr_f32_aa(uint8_t * restrict dst, const uint8_t * restrict src, uint32_t n) {
+    assert((unsigned long) dst % 128 == 0);
+    assert((unsigned long) src % 128 == 0);
+    hvx_sqr_loop_body(HVX_Vector, HVX_Vector, hvx_vec_store_a);
+}
+
+static inline void hvx_sqr_f32_au(uint8_t * restrict dst, const uint8_t * restrict src, uint32_t n) {
+    assert((unsigned long) dst % 128 == 0);
+    hvx_sqr_loop_body(HVX_Vector, HVX_UVector, hvx_vec_store_a);
+}
+
+static inline void hvx_sqr_f32_ua(uint8_t * restrict dst, const uint8_t * restrict src, uint32_t n) {
+    assert((unsigned long) src % 128 == 0);
+    hvx_sqr_loop_body(HVX_UVector, HVX_Vector, hvx_vec_store_u);
+}
+
+static inline void hvx_sqr_f32_uu(uint8_t * restrict dst, const uint8_t * restrict src, uint32_t n) {
+    hvx_sqr_loop_body(HVX_UVector, HVX_UVector, hvx_vec_store_u);
+}
+
+static inline void hvx_sqr_f32(uint8_t * restrict dst, const uint8_t * restrict src, const uint32_t num_elems) {
+    if (hex_is_aligned((void *) dst, 128)) {
+        if (hex_is_aligned((void *) src, 128)) {
+            hvx_sqr_f32_aa(dst, src, num_elems);
+        } else {
+            hvx_sqr_f32_au(dst, src, num_elems);
+        }
+    } else {
+        if (hex_is_aligned((void *) src, 128)) {
+            hvx_sqr_f32_ua(dst, src, num_elems);
+        } else {
+            hvx_sqr_f32_uu(dst, src, num_elems);
+        }
+    }
+}
+
+#undef HVX_OP_ADD
+#undef HVX_OP_SUB
+#undef HVX_OP_MUL
+#undef hvx_arith_loop_body
+#undef HVX_OP_ADD_SCALAR
+#undef HVX_OP_SUB_SCALAR
+#undef HVX_OP_MUL_SCALAR
+#undef hvx_scalar_loop_body
+#undef HVX_OP_MIN_SCALAR
+#undef HVX_OP_CLAMP_SCALAR
+#undef DEFINE_HVX_BINARY_OP_VARIANTS
+#undef HVX_BINARY_DISPATCHER
+
+#endif // HVX_ARITH_H
diff --git a/llama.cpp/ggml/src/ggml-hexagon/htp/hvx-base.h b/llama.cpp/ggml/src/ggml-hexagon/htp/hvx-base.h
new file mode 100644
index 0000000..12a1b7f
--- /dev/null
+++ b/llama.cpp/ggml/src/ggml-hexagon/htp/hvx-base.h
@@ -0,0 +1,173 @@
+#ifndef HVX_BASE_H
+#define HVX_BASE_H
+
+#include <math.h>
+#include <stdbool.h>
+#include <stdint.h>
+
+#include "hex-utils.h"
+#include "hvx-types.h"
+
+static inline void hvx_vec_store_u(void * restrict dst, uint32_t n, HVX_Vector v) {
+    // Rotate as needed.
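+    // The vector is byte-rotated so that element 0 lines up with the store
+    // address, then written with predicated stores that mask off the bytes
+    // outside [dst, dst + n); a second store covers a 128-byte crossing.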
+ v = Q6_V_vlalign_VVR(v, v, (size_t) dst); + + uint32_t left_off = (size_t) dst & 127; + uint32_t right_off = left_off + n; + + HVX_VectorPred ql_not = Q6_Q_vsetq_R((size_t) dst); + HVX_VectorPred qr = Q6_Q_vsetq2_R(right_off); + + if (right_off > 128) { + Q6_vmem_QRIV(qr, (HVX_Vector *) dst + 1, v); + // all 1's + qr = Q6_Q_vcmp_eq_VbVb(v, v); + } + + ql_not = Q6_Q_or_QQn(ql_not, qr); + Q6_vmem_QnRIV(ql_not, (HVX_Vector *) dst, v); +} + +static inline void hvx_vec_store_a(void * restrict dst, uint32_t n, HVX_Vector v) { + assert((unsigned long) dst % 128 == 0); + HVX_VectorPred m = Q6_Q_or_QQn(Q6_Q_vsetq_R((unsigned long) dst), Q6_Q_vsetq2_R(n)); + Q6_vmem_QnRIV(m, (HVX_Vector *) dst, v); +} + +static inline HVX_Vector hvx_vec_splat_f32(float v) { + union { float f; uint32_t i; } u = { .f = v }; + return Q6_V_vsplat_R(u.i); +} + +static inline HVX_Vector hvx_vec_splat_f16(float v) { + union { __fp16 f; uint16_t i; } u = { .f = v }; + return Q6_Vh_vsplat_R(u.i); +} + +static inline HVX_Vector hvx_vec_repl4(HVX_Vector v) { + // vdelta control to replicate first 4 bytes across all elements + static const uint8_t __attribute__((aligned(128))) repl[128] = { + 0x00, 0x00, 0x00, 0x00, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, + 0x10, 0x10, 0x10, 0x10, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, + 0x20, 0x20, 0x20, 0x20, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, + 0x10, 0x10, 0x10, 0x10, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, + 0x40, 0x40, 0x40, 0x40, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, + 0x10, 0x10, 0x10, 0x10, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, + 0x20, 0x20, 0x20, 0x20, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, + 0x10, 0x10, 0x10, 0x10, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, + }; + + HVX_Vector ctrl = *(HVX_Vector *) repl; + return Q6_V_vdelta_VV(v, ctrl); +} + +static inline float hvx_vec_get_f32(HVX_Vector v) { + float __attribute__((aligned(128))) x; + hvx_vec_store_a(&x, 4, v); + return x; +} + +static inline int32_t hvx_vec_get_i32(HVX_Vector v) { + int32_t __attribute__((aligned(128))) x; + hvx_vec_store_a(&x, 4, v); + return x; +} + +static inline HVX_Vector hvx_vec_abs_f16(HVX_Vector v) { + // abs by clearing the fp16 sign bit + HVX_Vector mask = Q6_Vh_vsplat_R(0x7fff); + return Q6_V_vand_VV(v, mask); +} + +static inline HVX_Vector hvx_vec_neg_f16(HVX_Vector v) { + // neg by setting the fp16 sign bit + HVX_Vector mask = Q6_Vh_vsplat_R(0x8000); + return Q6_V_vxor_VV(v, mask); +} + +static inline HVX_Vector hvx_vec_abs_f32(HVX_Vector v) { + // abs by clearing the fp32 sign bit + HVX_Vector mask = Q6_V_vsplat_R(0x7fffffff); + return Q6_V_vand_VV(v, mask); +} + +static inline HVX_Vector hvx_vec_neg_f32(HVX_Vector v) { +#if __HVX_ARCH__ > 75 + return Q6_Vsf_vfneg_Vsf(v); +#else + // neg by setting the fp32 sign bit + HVX_Vector mask = Q6_V_vsplat_R(0x80000000); + return Q6_V_vxor_VV(v, mask); +#endif // __HVX_ARCH__ > 75 +} + +static inline HVX_VectorPred hvx_vec_is_nan_f16(HVX_Vector v) { + const HVX_Vector vnan_exp = Q6_Vh_vsplat_R(0x7C00); + const HVX_Vector vnan_frac = Q6_Vh_vsplat_R(0x7FFF); + + // get pred of which are NaN, i.e., exponent bits all 1s and fraction bits non 0s + HVX_VectorPred p_exp = Q6_Q_vcmp_eq_VhVh(Q6_V_vand_VV(v, vnan_exp), vnan_exp); + HVX_VectorPred p_frac = Q6_Q_not_Q(Q6_Q_vcmp_eq_VhVh(Q6_V_vand_VV(v, 
vnan_frac), vnan_exp)); + return Q6_Q_and_QQ(p_exp, p_frac); +} + +static inline HVX_Vector hvx_vec_f32_to_f16(HVX_Vector v0, HVX_Vector v1) { + const HVX_Vector zero = Q6_V_vsplat_R(0); + HVX_Vector q0 = Q6_Vqf32_vadd_VsfVsf(v0, zero); + HVX_Vector q1 = Q6_Vqf32_vadd_VsfVsf(v1, zero); + HVX_Vector v = Q6_Vh_vdeal_Vh(Q6_Vhf_equals_Wqf32(Q6_W_vcombine_VV(q1, q0))); + +#if __HVX_ARCH__ < 79 + // replace NaNs with -INF, older arches produce NaNs for (-INF + 0.0) + const HVX_Vector neg_inf = hvx_vec_splat_f16(-INFINITY); + HVX_VectorPred nan = hvx_vec_is_nan_f16(v); + v = Q6_V_vmux_QVV(nan, neg_inf, v); +#endif + + return v; +} + +/* Q6_Vsf_equals_Vw is only available on v73+.*/ +#if __HVX_ARCH__ < 73 +static inline HVX_Vector hvx_vec_i32_to_qf32(HVX_Vector const in) +{ + HVX_Vector const vzero = Q6_V_vzero(); + HVX_VectorPred is_zero = Q6_Q_vcmp_eq_VwVw(in, vzero); + HVX_Vector lshift = Q6_Vw_vnormamt_Vw(in); + HVX_Vector normalized = Q6_Vw_vasl_VwVw(in, lshift); + HVX_Vector vexp = Q6_Vw_vsub_VwVw(Q6_V_vsplat_R(0x7f + 30), lshift); + HVX_Vector mant = Q6_V_vand_VV(Q6_V_vsplat_R(0xFFFFFF00), normalized); + HVX_Vector ret = Q6_V_vmux_QVV(is_zero, vzero, Q6_Vw_vadd_VwVw(mant, vexp)); + return ret; +} + +static inline HVX_Vector Q6_Vsf_equals_Vw(HVX_Vector const in) +{ + return Q6_Vsf_equals_Vqf32(hvx_vec_i32_to_qf32(in)); +} +#endif + +static inline HVX_Vector hvx_vec_i16_from_hf_rnd_sat(HVX_Vector vin) { + // This looks complicated. + // Ideally should just be Q6_Vh_equals_Vhf(vin) + // but that instruction does not do proper rounding. + + // convert to qf32, multiplying by 1.0 in the process. + HVX_VectorPair v32 = Q6_Wqf32_vmpy_VhfVhf(vin, Q6_Vh_vsplat_R(0x3C00)); + + // 'in-range' values are +/32752. + // add 192K to it, convert to sf + HVX_Vector v192K = Q6_V_vsplat_R(0x48400000); + HVX_Vector vsf_0 = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(Q6_V_lo_W(v32), v192K)); + HVX_Vector vsf_1 = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(Q6_V_hi_W(v32), v192K)); + + // for in-range cases, result is {163858... 229360} so the exponent is always 144. + // if we extract bits 21..0 as a signed quantity, and round 6 bits off, that will be the answer. + // Start by <<10 to get the final 'sign' bit in bit 15... 
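+    // (vround below then keeps the upper 16 bits of each word with rounding,
+    // dropping the remaining 6 fraction bits.)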
+    vsf_0 = Q6_Vw_vasl_VwR(vsf_0, 10);
+    vsf_1 = Q6_Vw_vasl_VwR(vsf_1, 10);
+
+    // now round down to 16
+    return Q6_Vh_vround_VwVw_sat(vsf_1, vsf_0);
+}
+
+#endif /* HVX_BASE_H */
diff --git a/llama.cpp/ggml/src/ggml-hexagon/htp/hvx-copy.h b/llama.cpp/ggml/src/ggml-hexagon/htp/hvx-copy.h
new file mode 100644
index 0000000..ae0dbed
--- /dev/null
+++ b/llama.cpp/ggml/src/ggml-hexagon/htp/hvx-copy.h
@@ -0,0 +1,245 @@
+#ifndef HVX_COPY_H
+#define HVX_COPY_H
+
+#include <assert.h>
+#include <stddef.h>
+#include <stdint.h>
+
+#include "hvx-base.h"
+
+#define hvx_splat_loop_body(dst_type, vec_store) \
+    do { \
+        dst_type * restrict vdst = (dst_type *) dst; \
+        \
+        uint32_t nvec = n / (128 / elem_size); \
+        uint32_t nloe = n % (128 / elem_size); \
+        \
+        uint32_t i = 0; \
+        \
+        _Pragma("unroll(4)") \
+        for (; i < nvec; i++) { \
+            vdst[i] = src; \
+        } \
+        if (nloe) { \
+            vec_store((void *) &vdst[i], nloe * elem_size, src); \
+        } \
+    } while(0)
+
+static inline void hvx_splat_a(uint8_t * restrict dst, HVX_Vector src, uint32_t n, uint32_t elem_size) {
+    assert((unsigned long) dst % 128 == 0);
+    hvx_splat_loop_body(HVX_Vector, hvx_vec_store_a);
+}
+
+static inline void hvx_splat_u(uint8_t * restrict dst, HVX_Vector src, uint32_t n, uint32_t elem_size) {
+    hvx_splat_loop_body(HVX_UVector, hvx_vec_store_u);
+}
+
+static inline void hvx_splat_f32_a(uint8_t * restrict dst, float v, uint32_t n) {
+    hvx_splat_a(dst, hvx_vec_splat_f32(v), n, sizeof(float));
+}
+
+static inline void hvx_splat_f32_u(uint8_t * restrict dst, float v, uint32_t n) {
+    hvx_splat_u(dst, hvx_vec_splat_f32(v), n, sizeof(float));
+}
+
+static inline void hvx_splat_f16_a(uint8_t * restrict dst, float v, uint32_t n) {
+    hvx_splat_a(dst, hvx_vec_splat_f16(v), n, sizeof(__fp16));
+}
+
+static inline void hvx_splat_f16_u(uint8_t * restrict dst, float v, uint32_t n) {
+    hvx_splat_u(dst, hvx_vec_splat_f16(v), n, sizeof(__fp16));
+}
+
+#define hvx_copy_loop_body(dst_type, src_type, vec_store) \
+    do { \
+        dst_type * restrict vdst = (dst_type *) dst; \
+        src_type * restrict vsrc = (src_type *) src; \
+        \
+        const uint32_t epv  = 128 / elem_size; \
+        const uint32_t nvec = n / epv; \
+        const uint32_t nloe = n % epv; \
+        \
+        uint32_t i = 0; \
+        \
+        _Pragma("unroll(4)") \
+        for (; i < nvec; i++) { vdst[i] = vsrc[i]; } \
+        if (nloe) { \
+            vec_store((void *) &vdst[i], nloe * elem_size, vsrc[i]); \
+        } \
+    } while(0)
+
+// Generic copy routines
+static inline void hvx_copy_aa(uint8_t * restrict dst, const uint8_t * restrict src, uint32_t n, uint32_t elem_size) {
+    assert((unsigned long) dst % 128 == 0);
+    assert((unsigned long) src % 128 == 0);
+    hvx_copy_loop_body(HVX_Vector, HVX_Vector, hvx_vec_store_a);
+}
+
+static inline void hvx_copy_au(uint8_t * restrict dst, const uint8_t * restrict src, uint32_t n, uint32_t elem_size) {
+    assert((unsigned long) dst % 128 == 0);
+    hvx_copy_loop_body(HVX_Vector, HVX_UVector, hvx_vec_store_a);
+}
+
+static inline void hvx_copy_ua(uint8_t * restrict dst, const uint8_t * restrict src, uint32_t n, uint32_t elem_size) {
+    assert((unsigned long) src % 128 == 0);
+    hvx_copy_loop_body(HVX_UVector, HVX_Vector, hvx_vec_store_u);
+}
+
+static inline void hvx_copy_uu(uint8_t * restrict dst, const uint8_t * restrict src, uint32_t n, uint32_t elem_size) {
+    hvx_copy_loop_body(HVX_UVector, HVX_UVector, hvx_vec_store_u);
+}
+
+// copy n fp16 elements : source and destination are aligned to HVX Vector (128)
+static inline void hvx_copy_f16_aa(uint8_t * restrict dst, const uint8_t * restrict src, uint32_t n) {
+    hvx_copy_aa(dst, src, n, sizeof(__fp16));
+}
+
+// copy n fp16 elements : source is unaligned, destination is aligned
+static inline void hvx_copy_f16_au(uint8_t * restrict dst, const uint8_t * restrict src, uint32_t n) {
+    hvx_copy_au(dst, src, n, sizeof(__fp16));
+}
+
+// copy n fp16 elements : source is aligned, destination is unaligned
+static inline void hvx_copy_f16_ua(uint8_t * restrict dst, const uint8_t * restrict src, uint32_t n) {
+    hvx_copy_ua(dst, src, n, sizeof(__fp16));
+}
+
+// copy n fp16 elements : source and destination are potentially unaligned
+static inline void hvx_copy_f16_uu(uint8_t * restrict dst, const uint8_t * restrict src, uint32_t n) {
+    hvx_copy_uu(dst, src, n, sizeof(__fp16));
+}
+
+// copy n fp32 elements : source and destination are aligned to HVX Vector (128)
+static inline void hvx_copy_f32_aa(uint8_t * restrict dst, const uint8_t * restrict src, uint32_t n) {
+    hvx_copy_aa(dst, src, n, sizeof(float));
+}
+
+// copy n fp32 elements : source is aligned, destination is unaligned
+static inline void hvx_copy_f32_ua(uint8_t * restrict dst, const uint8_t * restrict src, uint32_t n) {
+    hvx_copy_ua(dst, src, n, sizeof(float));
+}
+
+// copy n fp32 elements : source is unaligned, destination is aligned
+static inline void hvx_copy_f32_au(uint8_t * restrict dst, const uint8_t * restrict src, uint32_t n) {
+    hvx_copy_au(dst, src, n, sizeof(float));
+}
+
+// copy n fp32 elements : source is unaligned, destination unaligned
+static inline void hvx_copy_f32_uu(uint8_t * restrict dst, const uint8_t * restrict src, uint32_t n) {
+    hvx_copy_uu(dst, src, n, sizeof(float));
+}
+
+//// fp32 -> fp16
+
+#define hvx_copy_f16_f32_loop_body(dst_type, src_type, vec_store) \
+    do { \
+        dst_type * restrict vdst = (dst_type *) dst; \
+        src_type * restrict vsrc = (src_type *) src; \
+        \
+        const uint32_t elem_size = sizeof(__fp16); \
+        const uint32_t epv  = 128 / elem_size; \
+        const uint32_t nvec = n / epv; \
+        const uint32_t nloe = n % epv; \
+        \
+        uint32_t i = 0; \
+        \
+        _Pragma("unroll(4)") \
+        for (; i < nvec; i++) { \
+            vdst[i] = hvx_vec_f32_to_f16(vsrc[i*2+0], vsrc[i*2+1]); \
+        } \
+        if (nloe) { \
+            HVX_Vector v = hvx_vec_f32_to_f16(vsrc[i*2+0], vsrc[i*2+1]); \
+            vec_store((void *) &vdst[i], nloe * elem_size, v); \
+        } \
+    } while(0)
+
+// copy/convert n fp32 elements into n fp16 elements : source is aligned, destination is aligned
+static inline void hvx_copy_f16_f32_aa(uint8_t * restrict dst, const uint8_t * restrict src, uint32_t n) {
+    assert((unsigned long) dst % 128 == 0);
+    assert((unsigned long) src % 128 == 0);
+    hvx_copy_f16_f32_loop_body(HVX_Vector, HVX_Vector, hvx_vec_store_a);
+}
+
+// copy/convert n fp32 elements into n fp16 elements : source is unaligned, destination is aligned
+static inline void hvx_copy_f16_f32_au(uint8_t * restrict dst, const uint8_t * restrict src, uint32_t n) {
+    assert((unsigned long) dst % 128 == 0);
+    hvx_copy_f16_f32_loop_body(HVX_Vector, HVX_UVector, hvx_vec_store_a);
+}
+
+// copy/convert n fp32 elements into n fp16 elements : source is aligned, destination is unaligned
+static inline void hvx_copy_f16_f32_ua(uint8_t * restrict dst, const uint8_t * restrict src, uint32_t n) {
+    assert((unsigned long) src % 128 == 0);
+    hvx_copy_f16_f32_loop_body(HVX_UVector, HVX_Vector, hvx_vec_store_u);
+}
+
+// copy/convert n fp32 elements into n fp16 elements : source is unaligned, destination is unaligned
+static inline void hvx_copy_f16_f32_uu(uint8_t * restrict dst, const uint8_t * restrict src, uint32_t n) {
hvx_copy_f16_f32_loop_body(HVX_UVector, HVX_UVector, hvx_vec_store_u); +} + +//// fp16 -> fp32 + +#define hvx_copy_f32_f16_loop_body(dst_type, src_type, vec_store) \ + do { \ + dst_type * restrict vdst = (dst_type *) dst; \ + src_type * restrict vsrc = (src_type *) src; \ + \ + const HVX_Vector one = hvx_vec_splat_f16(1.0); \ + \ + const uint32_t elem_size = sizeof(__fp16); \ + const uint32_t epv = 128 / elem_size; \ + const uint32_t nvec = n / epv; \ + uint32_t nloe = n % epv; \ + \ + uint32_t i = 0; \ + \ + _Pragma("unroll(4)") \ + for (i = 0; i < nvec; ++i) { \ + HVX_VectorPair p = Q6_Wqf32_vmpy_VhfVhf(Q6_Vh_vshuff_Vh(vsrc[i]), one); \ + vdst[i*2] = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(p)); \ + vdst[i*2+1] = Q6_Vsf_equals_Vqf32(Q6_V_hi_W(p)); \ + } \ + \ + if (nloe) { \ + HVX_VectorPair p = Q6_Wqf32_vmpy_VhfVhf(Q6_Vh_vshuff_Vh(vsrc[i]), one); \ + \ + HVX_Vector vd = Q6_V_lo_W(p); \ + i = 2 * i; \ + \ + if (nloe >= 32) { \ + vdst[i] = Q6_Vsf_equals_Vqf32(vd); \ + nloe -= 32; ++i; vd = Q6_V_hi_W(p); \ + } \ + \ + if (nloe) { \ + vd = Q6_Vsf_equals_Vqf32(vd); \ + hvx_vec_store_u(&vdst[i], nloe * sizeof(float), vd); \ + } \ + } \ + } while(0) + +// copy/convert n fp16 elements into n fp32 elements : source is aligned, destination is aligned +static inline void hvx_copy_f32_f16_aa(uint8_t * restrict dst, const uint8_t * restrict src, uint32_t n) { + assert((unsigned long) dst % 128 == 0); + assert((unsigned long) src % 128 == 0); + hvx_copy_f32_f16_loop_body(HVX_Vector, HVX_Vector, hvx_vec_store_a); +} + +// copy/convert n fp16 elements into n fp32 elements : source is unaligned, destination is aligned +static inline void hvx_copy_f32_f16_au(uint8_t * restrict dst, const uint8_t * restrict src, uint32_t n) { + assert((unsigned long) dst % 128 == 0); + hvx_copy_f32_f16_loop_body(HVX_Vector, HVX_UVector, hvx_vec_store_a); +} + +// copy/convert n fp16 elements into n fp32 elements : source is aligned, destination is unaligned +static inline void hvx_copy_f32_f16_ua(uint8_t * restrict dst, const uint8_t * restrict src, uint32_t n) { + assert((unsigned long) src % 128 == 0); + hvx_copy_f32_f16_loop_body(HVX_UVector, HVX_Vector, hvx_vec_store_u); +} + +// copy/convert n fp16 elements into n fp32 elements : source is unaligned, destination is unaligned +static inline void hvx_copy_f32_f16_uu(uint8_t * restrict dst, const uint8_t * restrict src, uint32_t n) { + hvx_copy_f32_f16_loop_body(HVX_UVector, HVX_UVector, hvx_vec_store_u); +} + +#endif // HVX_COPY_H diff --git a/llama.cpp/ggml/src/ggml-hexagon/htp/hvx-div.h b/llama.cpp/ggml/src/ggml-hexagon/htp/hvx-div.h new file mode 100644 index 0000000..7dae012 --- /dev/null +++ b/llama.cpp/ggml/src/ggml-hexagon/htp/hvx-div.h @@ -0,0 +1,116 @@ +#ifndef HVX_DIV_H +#define HVX_DIV_H + +#include <HAP_farf.h> + +#include <math.h> +#include <string.h> +#include <assert.h> +#include <stddef.h> +#include <stdint.h> + +#include "hvx-base.h" +#include "hex-utils.h" +#include "hvx-inverse.h" +#include "hvx-arith.h" + +#if __HVX_ARCH__ < 79 +#define HVX_OP_MUL(a, b) Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(a, b)) +#else +#define HVX_OP_MUL(a, b) Q6_Vsf_vmpy_VsfVsf(a, b) +#endif + +#define hvx_div_f32_loop_body(dst_type, src0_type, src1_type, vec_store) \ + do { \ + dst_type * restrict vdst = (dst_type *) dst; \ + src0_type * restrict vsrc0 = (src0_type *) src0; \ + src1_type * restrict vsrc1 = (src1_type *) src1; \ + \ + const HVX_Vector nan_inf_mask = Q6_V_vsplat_R(0x7f800000); \ + \ + const uint32_t nvec = n / VLEN_FP32; \ + const uint32_t nloe = n % VLEN_FP32; \ + \ + 
uint32_t i = 0; \ + \ + _Pragma("unroll(4)") \ + for (; i < nvec; i++) { \ + HVX_Vector inv_src1 = hvx_vec_inverse_f32_guard(vsrc1[i], nan_inf_mask); \ + HVX_Vector res = HVX_OP_MUL(vsrc0[i], inv_src1); \ + vdst[i] = res; \ + } \ + if (nloe) { \ + HVX_Vector inv_src1 = hvx_vec_inverse_f32_guard(vsrc1[i], nan_inf_mask); \ + HVX_Vector res = HVX_OP_MUL(vsrc0[i], inv_src1); \ + vec_store((void *) &vdst[i], nloe * SIZEOF_FP32, res); \ + } \ + } while(0) + +// 3-letter suffix variants +static inline void hvx_div_f32_aaa(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, uint32_t n) { + assert((uintptr_t) dst % 128 == 0); + assert((uintptr_t) src0 % 128 == 0); + assert((uintptr_t) src1 % 128 == 0); + hvx_div_f32_loop_body(HVX_Vector, HVX_Vector, HVX_Vector, hvx_vec_store_a); +} + +static inline void hvx_div_f32_aau(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, uint32_t n) { + assert((uintptr_t) dst % 128 == 0); + assert((uintptr_t) src0 % 128 == 0); + hvx_div_f32_loop_body(HVX_Vector, HVX_Vector, HVX_UVector, hvx_vec_store_a); +} + +static inline void hvx_div_f32_aua(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, uint32_t n) { + assert((uintptr_t) dst % 128 == 0); + assert((uintptr_t) src1 % 128 == 0); + hvx_div_f32_loop_body(HVX_Vector, HVX_UVector, HVX_Vector, hvx_vec_store_a); +} + +static inline void hvx_div_f32_auu(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, uint32_t n) { + assert((uintptr_t) dst % 128 == 0); + hvx_div_f32_loop_body(HVX_Vector, HVX_UVector, HVX_UVector, hvx_vec_store_a); +} + +static inline void hvx_div_f32_uaa(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, uint32_t n) { + assert((uintptr_t) src0 % 128 == 0); + assert((uintptr_t) src1 % 128 == 0); + hvx_div_f32_loop_body(HVX_UVector, HVX_Vector, HVX_Vector, hvx_vec_store_u); +} + +static inline void hvx_div_f32_uau(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, uint32_t n) { + assert((uintptr_t) src0 % 128 == 0); + hvx_div_f32_loop_body(HVX_UVector, HVX_Vector, HVX_UVector, hvx_vec_store_u); +} + +static inline void hvx_div_f32_uua(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, uint32_t n) { + assert((uintptr_t) src1 % 128 == 0); + hvx_div_f32_loop_body(HVX_UVector, HVX_UVector, HVX_Vector, hvx_vec_store_u); +} + +static inline void hvx_div_f32_uuu(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, uint32_t n) { + hvx_div_f32_loop_body(HVX_UVector, HVX_UVector, HVX_UVector, hvx_vec_store_u); +} + +static inline void hvx_div_f32(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, const uint32_t num_elems) { + if (hex_is_aligned((void *) dst, 128)) { + if (hex_is_aligned((void *) src0, 128)) { + if (hex_is_aligned((void *) src1, 128)) hvx_div_f32_aaa(dst, src0, src1, num_elems); + else hvx_div_f32_aau(dst, src0, src1, num_elems); + } else { + if (hex_is_aligned((void *) src1, 128)) hvx_div_f32_aua(dst, src0, src1, num_elems); + else hvx_div_f32_auu(dst, src0, src1, num_elems); + } + } else { + if (hex_is_aligned((void *) src0, 128)) { + if (hex_is_aligned((void *) src1, 128)) hvx_div_f32_uaa(dst, src0, src1, num_elems); + else hvx_div_f32_uau(dst, src0, src1, num_elems); + } else { + if (hex_is_aligned((void *) src1, 128)) hvx_div_f32_uua(dst, src0, src1, num_elems); + else 
hvx_div_f32_uuu(dst, src0, src1, num_elems); + } + } +} + +#undef HVX_OP_MUL + +#endif // HVX_DIV_H diff --git a/llama.cpp/ggml/src/ggml-hexagon/htp/hvx-dump.h b/llama.cpp/ggml/src/ggml-hexagon/htp/hvx-dump.h new file mode 100644 index 0000000..85201fc --- /dev/null +++ b/llama.cpp/ggml/src/ggml-hexagon/htp/hvx-dump.h @@ -0,0 +1,129 @@ +#ifndef HVX_DUMP_H +#define HVX_DUMP_H + +#include <HAP_farf.h> + +#include <stdbool.h> +#include <stdint.h> + +#include "hex-utils.h" +#include "hvx-types.h" + +static void hvx_vec_dump_f16_n(char * pref, HVX_Vector v, uint32_t n) { + HVX_VectorAlias u = { .v = v }; + + const uint32_t n0 = n / 16; + const uint32_t n1 = n % 16; + int i = 0; + for (; i < n0; i++) { + hex_dump_f16_line(pref, u.fp16 + (16 * i), 16); + } + if (n1) { + hex_dump_f16_line(pref, u.fp16 + (16 * i), n1); + } +} + +static void hvx_vec_dump_f16(char * pref, HVX_Vector v) { + hvx_vec_dump_f16_n(pref, v, 64); +} + +static void hvx_vec_dump_f32_n(char * pref, HVX_Vector v, uint32_t n) { + HVX_VectorAlias u = { .v = v }; + + const uint32_t n0 = n / 16; + const uint32_t n1 = n % 16; + int i = 0; + for (; i < n0; i++) { + hex_dump_f32_line(pref, u.fp32 + (16 * i), 16); + } + if (n1) { + hex_dump_f32_line(pref, u.fp32 + (16 * i), n1); + } +} + +static void hvx_vec_dump_f32_hmt(char * pref, HVX_Vector v) { + union { + HVX_Vector v; + float d[32]; + } u = { .v = v }; + + FARF(HIGH, "%s: %.6f %.6f %.6f %.6f ... %.6f %.6f %.6f %.6f ... %.6f %.6f %.6f %.6f\n", pref, u.d[0], u.d[1], + u.d[2], u.d[3], u.d[12], u.d[13], u.d[14], u.d[15], u.d[28], u.d[29], u.d[30], u.d[31]); +} + +static void hvx_vec_dump_f32(char * pref, HVX_Vector v) { + hvx_vec_dump_f32_n(pref, v, 32); +} + +static void hvx_vec_dump_int32(char * pref, HVX_Vector v) { + union { + HVX_Vector v; + int32_t d[32]; + } u = { .v = v }; + + for (int i = 0; i < 32 / 16; i++) { + hex_dump_int32_line(pref, u.d + (16 * i), 16); + } +} + +static void hvx_vec_dump_int32_hmt(char * pref, HVX_Vector v) { + union { + HVX_Vector v; + int32_t d[32]; + } u = { .v = v }; + + FARF(HIGH, "%s: %d %d %d %d ... %d %d %d %d ... %d %d %d %d\n", pref, u.d[0], u.d[1], u.d[2], u.d[3], u.d[12], + u.d[13], u.d[14], u.d[15], u.d[28], u.d[29], u.d[30], u.d[31]); +} + +static void hvx_vec_dump_int8_hmt(char * pref, HVX_Vector v) { + union { + HVX_Vector v; + int8_t d[128]; + } u = { .v = v }; + + FARF(HIGH, "%s: %d %d %d %d ... %d %d %d %d ... 
%d %d %d %d\n", pref, u.d[0], u.d[1], u.d[2], u.d[3], u.d[60], + u.d[61], u.d[62], u.d[63], u.d[124], u.d[125], u.d[126], u.d[127]); +} + +static void hvx_vec_dump_int8(char * pref, HVX_Vector v) { + union { + HVX_Vector v; + int8_t d[128]; + } u = { .v = v }; + + for (int i = 0; i < 128 / 16; i++) { + hex_dump_int8_line(pref, u.d + (16 * i), 16); + } +} + +static void hvx_vec_dump_uint8(char * pref, HVX_Vector v) { + union { + HVX_Vector v; + uint8_t d[128]; + } u = { .v = v }; + + for (int i = 0; i < 128 / 16; i++) { + hex_dump_uint8_line(pref, u.d + (16 * i), 16); + } +} + +static bool hvx_vec_eq(HVX_Vector v0, HVX_Vector v1, size_t n) { + typedef union { + HVX_Vector v; + int8_t d[128]; + } U; + + U u0 = { .v = v0 }; + U u1 = { .v = v1 }; + + for (int i = 0; i < n; i++) { + if (u0.d[i] != u1.d[i]) { + return false; + } + } + + return true; +} + +#endif /* HVX_DUMP_H */ diff --git a/llama.cpp/ggml/src/ggml-hexagon/htp/hvx-exp.h b/llama.cpp/ggml/src/ggml-hexagon/htp/hvx-exp.h new file mode 100644 index 0000000..44dfe23 --- /dev/null +++ b/llama.cpp/ggml/src/ggml-hexagon/htp/hvx-exp.h @@ -0,0 +1,215 @@ +#ifndef HVX_EXP_H +#define HVX_EXP_H + +#include <stdbool.h> +#include <stdint.h> + +#include "hvx-base.h" +#include "hvx-floor.h" + +#define EXP_COEFF_5 (0x39506967) // 0.000198757 = 1/(7!) +#define EXP_COEFF_4 (0x3AB743CE) // 0.0013982 = 1/(6!) +#define EXP_COEFF_3 (0x3C088908) // 0.00833345 = 1/(5!) +#define EXP_COEFF_2 (0x3D2AA9C1) // 0.416658 = 1/(4!) +#define EXP_COEFF_1 (0x3E2AAAAA) // 0.16666667 = 1/(3!) +#define EXP_COEFF_0 (0x3F000000) // 0.5 = 1/(2!) +#define EXP_LOGN2 (0x3F317218) // ln(2) = 0.6931471805 +#define EXP_LOG2E (0x3FB8AA3B) // log2(e) = 1/ln(2) = 1.4426950408 +#define EXP_ONE (0x3f800000) // 1.0 +#define EXP_RANGE_R (0x41a00000) // 20.0 +#define EXP_RANGE_L (0xc1a00000) // -20.0 + +static inline HVX_Vector hvx_vec_exp_f32(HVX_Vector in_vec) { + HVX_Vector z_qf32_v; + HVX_Vector x_v; + HVX_Vector x_qf32_v; + HVX_Vector y_v; + HVX_Vector k_v; + HVX_Vector f_v; + HVX_Vector epsilon_v; + HVX_Vector log2e = Q6_V_vsplat_R(EXP_LOG2E); + HVX_Vector logn2 = Q6_V_vsplat_R(EXP_LOGN2); + HVX_Vector E_const; + HVX_Vector zero_v = Q6_V_vzero(); + + // exp(x) is approximated as follows: + // f = floor(x/ln(2)) = floor(x*log2(e)) + // epsilon = x - f*ln(2) + // exp(x) = exp(epsilon+f*ln(2)) + // = exp(epsilon)*exp(f*ln(2)) + // = exp(epsilon)*2^f + // + // Since epsilon is close to zero, it can be approximated with its Taylor series: + // exp(x) ~= 1+x+x^2/2!+x^3/3!+...+x^n/n!+... 
+ // Keeping the first eight terms, we get: + // exp(x) ~= 1+x+e0*x^2+e1*x^3+e2*x^4+e3*x^5+e4*x^6+e5*x^7 + // = 1+x+(E0+(E1+(E2+(E3+(E4+E5*x)*x)*x)*x)*x)*x^2 + + HVX_Vector temp_v = in_vec; + + // Clamp inputs to (-20.0, 20.0) + HVX_VectorPred pred_cap_right = Q6_Q_vcmp_gt_VsfVsf(in_vec, Q6_V_vsplat_R(EXP_RANGE_R)); + HVX_VectorPred pred_cap_left = Q6_Q_vcmp_gt_VsfVsf(Q6_V_vsplat_R(EXP_RANGE_L), in_vec); + + in_vec = Q6_V_vmux_QVV(pred_cap_right, Q6_V_vsplat_R(EXP_RANGE_R), temp_v); + in_vec = Q6_V_vmux_QVV(pred_cap_left, Q6_V_vsplat_R(EXP_RANGE_L), in_vec); + + epsilon_v = Q6_Vqf32_vmpy_VsfVsf(log2e, in_vec); + epsilon_v = Q6_Vsf_equals_Vqf32(epsilon_v); + + // f_v is the floating point result and k_v is the integer result + f_v = hvx_vec_floor_f32(epsilon_v); + k_v = hvx_vec_truncate_f32(f_v); + + x_qf32_v = Q6_Vqf32_vadd_VsfVsf(in_vec, zero_v); + + // x = x - f_v * logn2; + epsilon_v = Q6_Vqf32_vmpy_VsfVsf(f_v, logn2); + x_qf32_v = Q6_Vqf32_vsub_Vqf32Vqf32(x_qf32_v, epsilon_v); + // normalize before every QFloat's vmpy + x_qf32_v = Q6_Vqf32_vadd_Vqf32Vsf(x_qf32_v, zero_v); + + // z = x * x; + z_qf32_v = Q6_Vqf32_vmpy_Vqf32Vqf32(x_qf32_v, x_qf32_v); + z_qf32_v = Q6_Vqf32_vadd_Vqf32Vsf(z_qf32_v, zero_v); + + x_v = Q6_Vsf_equals_Vqf32(x_qf32_v); + + // y = E4 + E5 * x; + E_const = Q6_V_vsplat_R(EXP_COEFF_5); + y_v = Q6_Vqf32_vmpy_VsfVsf(E_const, x_v); + E_const = Q6_V_vsplat_R(EXP_COEFF_4); + y_v = Q6_Vqf32_vadd_Vqf32Vsf(y_v, E_const); + y_v = Q6_Vqf32_vadd_Vqf32Vsf(y_v, zero_v); + + // y = E3 + y * x; + E_const = Q6_V_vsplat_R(EXP_COEFF_3); + y_v = Q6_Vqf32_vmpy_Vqf32Vqf32(y_v, x_qf32_v); + y_v = Q6_Vqf32_vadd_Vqf32Vsf(y_v, E_const); + y_v = Q6_Vqf32_vadd_Vqf32Vsf(y_v, zero_v); + + // y = E2 + y * x; + E_const = Q6_V_vsplat_R(EXP_COEFF_2); + y_v = Q6_Vqf32_vmpy_Vqf32Vqf32(y_v, x_qf32_v); + y_v = Q6_Vqf32_vadd_Vqf32Vsf(y_v, E_const); + y_v = Q6_Vqf32_vadd_Vqf32Vsf(y_v, zero_v); + + // y = E1 + y * x; + E_const = Q6_V_vsplat_R(EXP_COEFF_1); + y_v = Q6_Vqf32_vmpy_Vqf32Vqf32(y_v, x_qf32_v); + y_v = Q6_Vqf32_vadd_Vqf32Vsf(y_v, E_const); + y_v = Q6_Vqf32_vadd_Vqf32Vsf(y_v, zero_v); + + // y = E0 + y * x; + E_const = Q6_V_vsplat_R(EXP_COEFF_0); + y_v = Q6_Vqf32_vmpy_Vqf32Vqf32(y_v, x_qf32_v); + y_v = Q6_Vqf32_vadd_Vqf32Vsf(y_v, E_const); + y_v = Q6_Vqf32_vadd_Vqf32Vsf(y_v, zero_v); + + // y = x + y * z; + y_v = Q6_Vqf32_vmpy_Vqf32Vqf32(y_v, z_qf32_v); + y_v = Q6_Vqf32_vadd_Vqf32Vqf32(y_v, x_qf32_v); + y_v = Q6_Vqf32_vadd_Vqf32Vsf(y_v, zero_v); + + // y = y + 1.0; + y_v = Q6_Vqf32_vadd_Vqf32Vsf(y_v, Q6_V_vsplat_R(EXP_ONE)); + + // insert exponents + // y = ldexpf(y, k); + // y_v += k_v; // qf32 + // modify exponent + + y_v = Q6_Vsf_equals_Vqf32(y_v); + + // add k_v to the exponent of y_v + HVX_Vector y_v_exponent = Q6_Vw_vasl_VwR(y_v, 1); + + y_v_exponent = Q6_Vuw_vlsr_VuwR(y_v_exponent, IEEE_VSF_MANTLEN + 1); + y_v_exponent = Q6_Vw_vadd_VwVw(k_v, y_v_exponent); + + // the biased exponent cannot go negative; if it underflows, the result is flushed to zero + HVX_VectorPred qy_v_negative_exponent = Q6_Q_vcmp_gt_VwVw(zero_v, y_v_exponent); + + y_v = Q6_Vw_vaslacc_VwVwR(y_v, k_v, IEEE_VSF_MANTLEN); + + y_v = Q6_V_vmux_QVV(qy_v_negative_exponent, zero_v, y_v); + + return y_v; +} + +static inline HVX_Vector hvx_vec_exp_f32_guard(HVX_Vector in_vec, HVX_Vector max_exp, HVX_Vector inf) { + const HVX_VectorPred pred0 = Q6_Q_vcmp_gt_VsfVsf(in_vec, max_exp); + + HVX_Vector out = hvx_vec_exp_f32(in_vec); + + return Q6_V_vmux_QVV(pred0, inf, out); +} + +static inline void hvx_exp_f32(const uint8_t * restrict src,
uint8_t * restrict dst, const int num_elems, bool negate) { + int left_over = num_elems & (VLEN_FP32 - 1); + int num_elems_whole = num_elems - left_over; + + int unaligned_addr = 0; + int unaligned_loop = 0; + if ((0 == hex_is_aligned((void *) src, VLEN)) || (0 == hex_is_aligned((void *) dst, VLEN))) { + unaligned_addr = 1; + } + // assert((0 == unaligned_addr) || (0 == num_elems_whole)); + if ((1 == unaligned_addr) && (num_elems_whole != 0)) { + unaligned_loop = 1; + } + + HVX_Vector vec_out = Q6_V_vzero(); + + static const float kInf = INFINITY; + static const float kMaxExp = 88.02f; // log(INF) + + const HVX_Vector max_exp = hvx_vec_splat_f32(kMaxExp); + const HVX_Vector inf = hvx_vec_splat_f32(kInf); + + if (0 == unaligned_loop) { + HVX_Vector * p_vec_in1 = (HVX_Vector *) src; + HVX_Vector * p_vec_out = (HVX_Vector *) dst; + + #pragma unroll(4) + for (int i = 0; i < num_elems_whole; i += VLEN_FP32) { + if (true == negate) { + HVX_Vector neg_vec_in = hvx_vec_neg_f32(*p_vec_in1++); + *p_vec_out++ = hvx_vec_exp_f32_guard(neg_vec_in, max_exp, inf); + } else { + *p_vec_out++ = hvx_vec_exp_f32_guard(*p_vec_in1++, max_exp, inf); + } + } + } else { + #pragma unroll(4) + for (int i = 0; i < num_elems_whole; i += VLEN_FP32) { + HVX_Vector in = *(HVX_UVector *) (src + i * SIZEOF_FP32); + + if (true == negate) { + HVX_Vector neg_vec_in = hvx_vec_neg_f32(in); + *(HVX_UVector *) (dst + i * SIZEOF_FP32) = hvx_vec_exp_f32_guard(neg_vec_in, max_exp, inf); + } else { + *(HVX_UVector *) (dst + i * SIZEOF_FP32) = hvx_vec_exp_f32_guard(in, max_exp, inf); + } + } + } + + if (left_over > 0) { + const float * srcf = (float *) src + num_elems_whole; + float * dstf = (float *) dst + num_elems_whole; + + HVX_Vector in = *(HVX_UVector *) srcf; + + if (true == negate) { + HVX_Vector neg_vec_in = hvx_vec_neg_f32(in); + + vec_out = hvx_vec_exp_f32_guard(neg_vec_in, max_exp, inf); + } else { + vec_out = hvx_vec_exp_f32_guard(in, max_exp, inf); + } + + hvx_vec_store_u((void *) dstf, left_over * SIZEOF_FP32, vec_out); + } +} + +#endif /* HVX_EXP_H */ diff --git a/llama.cpp/ggml/src/ggml-hexagon/htp/hvx-floor.h b/llama.cpp/ggml/src/ggml-hexagon/htp/hvx-floor.h new file mode 100644 index 0000000..6a1bfde --- /dev/null +++ b/llama.cpp/ggml/src/ggml-hexagon/htp/hvx-floor.h @@ -0,0 +1,100 @@ +#ifndef HVX_FLOOR_H +#define HVX_FLOOR_H + +#include <stdbool.h> +#include <stdint.h> + +#include "hvx-base.h" + +#define IEEE_VSF_EXPLEN (8) +#define IEEE_VSF_EXPBIAS (127) +#define IEEE_VSF_EXPMASK (0xFF) +#define IEEE_VSF_MANTLEN (23) +#define IEEE_VSF_MANTMASK (0x7FFFFF) +#define IEEE_VSF_MIMPMASK (0x800000) + +static inline HVX_Vector hvx_vec_truncate_f32(HVX_Vector in_vec) { + HVX_Vector mask_mant_v = Q6_V_vsplat_R(IEEE_VSF_MANTMASK); + HVX_Vector mask_impl_v = Q6_V_vsplat_R(IEEE_VSF_MIMPMASK); + HVX_Vector const_zero_v = Q6_V_vzero(); + + HVX_VectorPred q_negative = Q6_Q_vcmp_gt_VwVw(const_zero_v, in_vec); + + HVX_Vector expval_v = in_vec >> IEEE_VSF_MANTLEN; + expval_v &= IEEE_VSF_EXPMASK; + expval_v -= IEEE_VSF_EXPBIAS; + + // negative exp == fractional value + HVX_VectorPred q_negexp = Q6_Q_vcmp_gt_VwVw(const_zero_v, expval_v); + + HVX_Vector rshift_v = IEEE_VSF_MANTLEN - expval_v; // fractional bits - exp shift + + HVX_Vector mant_v = in_vec & mask_mant_v; // obtain mantissa + HVX_Vector vout = Q6_Vw_vadd_VwVw(mant_v, mask_impl_v); // add implicit 1.0 + + vout = Q6_Vw_vasr_VwVw(vout, rshift_v); // shift to obtain truncated integer + vout = Q6_V_vmux_QVV(q_negexp, const_zero_v, vout); // expval<0 -> 0 + + HVX_Vector neg_vout 
= -vout; + + vout = Q6_V_vmux_QVV(q_negative, neg_vout, vout); // handle negatives + + return (vout); +} + +static inline HVX_Vector hvx_vec_floor_f32(HVX_Vector in_vec) { + HVX_Vector mask_mant_v = Q6_V_vsplat_R(IEEE_VSF_MANTMASK); + HVX_Vector mask_impl_v = Q6_V_vsplat_R(IEEE_VSF_MIMPMASK); + HVX_Vector const_mnlen_v = Q6_V_vsplat_R(IEEE_VSF_MANTLEN); + HVX_Vector const_zero_v = Q6_V_vzero(); + HVX_Vector const_negone_v = Q6_V_vsplat_R(0xbf800000); // -1 IEEE vsf + + HVX_VectorPred q_negative = Q6_Q_vcmp_gt_VwVw(const_zero_v, in_vec); + + HVX_Vector expval_v = in_vec >> IEEE_VSF_MANTLEN; + expval_v &= IEEE_VSF_EXPMASK; + expval_v -= IEEE_VSF_EXPBIAS; + + HVX_VectorPred q_negexp = Q6_Q_vcmp_gt_VwVw(const_zero_v, expval_v); + HVX_VectorPred q_expltmn = Q6_Q_vcmp_gt_VwVw(const_mnlen_v, expval_v); + HVX_VectorPred q_negexp_pos = Q6_Q_vcmp_gtand_QVwVw(q_negexp, in_vec, const_zero_v); + HVX_VectorPred q_negexp_neg = Q6_Q_vcmp_gtand_QVwVw(q_negexp, const_zero_v, in_vec); + + // if expval < 0 (q_negexp) // <0, floor is 0 + // if vin > 0 + // floor = 0 + // if vin < 0 + // floor = -1 + // if expval < mant_len (q_expltmn) // >0, but fraction may exist + // get sign (q_negative) + // mask >> expval // fraction bits to mask off + // vout = ~(mask) // apply mask to remove fraction + // if (qneg) // negative floor is one less (more, sign bit for neg) + // vout += ((impl_mask) >> expval) + // if (mask && vin) + // vout = vin + // else // already an integer + // ; // no change + + // compute floor + mask_mant_v >>= expval_v; + HVX_Vector neg_addin_v = mask_impl_v >> expval_v; + HVX_Vector vout_neg_addin = Q6_Vw_vadd_VwVw(in_vec, neg_addin_v); + HVX_Vector vout = Q6_V_vmux_QVV(q_negative, vout_neg_addin, in_vec); + + HVX_Vector mask_chk_v = Q6_V_vand_VV(in_vec, mask_mant_v); // chk if bits set + HVX_VectorPred q_integral = Q6_Q_vcmp_eq_VwVw(const_zero_v, mask_chk_v); + + HVX_Vector not_mask_v = Q6_V_vnot_V(mask_mant_v); // frac bits to clear + HVX_Vector vfrfloor_v = Q6_V_vand_VV(vout, not_mask_v); // clear frac bits + + vout = in_vec; + vout = Q6_V_vmux_QVV(q_expltmn, vfrfloor_v, vout); // expval<mant + vout = Q6_V_vmux_QVV(q_integral, in_vec, vout); // integral values + vout = Q6_V_vmux_QVV(q_negexp_pos, const_zero_v, vout); // expval<0 x>0 -> 0 + vout = Q6_V_vmux_QVV(q_negexp_neg, const_negone_v, vout); // expval<0 x<0 -> -1 + + return vout; +} + +#endif /* HVX_FLOOR_H */ diff --git a/llama.cpp/ggml/src/ggml-hexagon/htp/hvx-inverse.h b/llama.cpp/ggml/src/ggml-hexagon/htp/hvx-inverse.h new file mode 100644 index 0000000..49f3efa --- /dev/null +++ b/llama.cpp/ggml/src/ggml-hexagon/htp/hvx-inverse.h @@ -0,0 +1,176 @@ +#ifndef HVX_INVERSE_H +#define HVX_INVERSE_H + +#include <HAP_farf.h> + +#include <math.h> +#include <string.h> +#include <assert.h> +#include <stddef.h> +#include <stdint.h> + +#include "hvx-base.h" + +// ==================================================== +// FUNCTION: 1/(x+1) y(0) = 1, y(0.5) = 0.6667, y(1) = 0.5 +// Order:3; continuity: True; Ends forced: True +// Mode: unsigned; Result fractional bits: 14 +// Peak Error: 1.1295e-04 Rms Error: 2.8410e-05 Mean Error: 1.1370e-05 +// 32769 -32706 31252 -10589 +// 32590 -30635 22793 -4493 +// 32066 -27505 16481 -2348 +// 31205 -24054 11849 -1306 + +static inline HVX_Vector hvx_vec_recip_xp1_O3_unsigned(HVX_Vector vx) { + // input is 0..0xffff representing 0.0 .. 
1.0 + HVX_Vector p; + p = Q6_Vh_vlut4_VuhPh(vx, 0xFAE6F6D4EE73D6A3ull); + p = Q6_Vh_vmpa_VhVhVuhPuh_sat(p, vx, 0x2E49406159097A14ull); + p = Q6_Vh_vmps_VhVhVuhPuh_sat(p, vx, 0x5DF66B7177AB7FC2ull); + p = Q6_Vh_vmpa_VhVhVuhPuh_sat(p, vx, 0x79E57D427F4E8001ull); + return p; // signed result, 14 fractional bits +} + +// Find reciprocal of fp16. +// (1) first, convert to fp32, multiplying by 1.0; this is done to +// handle denormals. Ignoring sign and zero, result should be at +// least 5.9604645e-08 (32-bit code 0x33800000) and at most 131008 (0x47ffe000) +// (exponent in range [103,143]) +// (2) extract the mantissa into 16-bit unsigned; find reciprocal using a fitted poly +// (3) put this, along with '253-exp' (exp from (1)) together to make a qf32 +// (4) convert that to fp16 +// (5) put sign back in. Also, if the original value (w/o sign) was <0x81, replace +// the result with the max value. +static inline HVX_Vector hvx_vec_inverse_f16(HVX_Vector vals) { + HVX_Vector em_mask = Q6_Vh_vsplat_R(0x7FFF); + HVX_Vector avals = Q6_V_vand_VV(vals, em_mask); + HVX_VectorPred is_neg = Q6_Q_vcmp_gt_VhVh(avals, vals); + // is too small to 1/x ? for 'standard' fp16, this would be 0x101 + HVX_VectorPred is_small = Q6_Q_vcmp_gt_VhVh(Q6_Vh_vsplat_R(0x101), avals); + + HVX_VectorPair to_qf32 = Q6_Wqf32_vmpy_VhfVhf(avals, Q6_Vh_vsplat_R(0x3C00)); // *1.0 + HVX_Vector to_f32_0 = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(to_qf32)); + HVX_Vector to_f32_1 = Q6_Vsf_equals_Vqf32(Q6_V_hi_W(to_qf32)); + + // bits 22..13 contain the mantissa now (w/o hidden bit); move to bit 14..5 of a 16-bit vector + HVX_Vector mant_u16 = Q6_Vh_vshuffo_VhVh(Q6_Vw_vasl_VwR(to_f32_1, 9), Q6_Vw_vasl_VwR(to_f32_0, 9)); + // likewise extract the upper 16 from each, containing the exponents in range 103..142 + HVX_Vector exp_u16 = Q6_Vh_vshuffo_VhVh(to_f32_1, to_f32_0); + //Get exponent in IEEE 32-bit representation + exp_u16 = Q6_Vuh_vlsr_VuhR(exp_u16, 7); + + // so, mant_u16 contains an unbiased mantissa in upper 10 bits of each u16 lane + // We can consider it to be x-1.0, with 16 fractional bits, where 'x' is in range [1.0,2.0) + // Use poly to transform to 1/x, with 14 fractional bits + // + HVX_Vector rm = hvx_vec_recip_xp1_O3_unsigned(mant_u16); + + HVX_Vector vcl0 = Q6_Vuh_vcl0_Vuh(rm); //count leading zeros + + // Get mantissa for 16-bit representation + HVX_Vector mant_recip = Q6_V_vand_VV(Q6_Vh_vasr_VhR(Q6_Vh_vasl_VhVh(rm, vcl0), 5), Q6_Vh_vsplat_R(0x03FF)); + + //Compute Reciprocal Exponent + HVX_Vector exp_recip = + Q6_Vh_vsub_VhVh(Q6_Vh_vsub_VhVh(Q6_Vh_vsplat_R(254), exp_u16), Q6_Vh_vsub_VhVh(vcl0, Q6_Vh_vsplat_R(1))); + //Convert it for 16-bit representation + exp_recip = Q6_Vh_vadd_VhVh_sat(Q6_Vh_vsub_VhVh(exp_recip, Q6_Vh_vsplat_R(127)), Q6_Vh_vsplat_R(15)); + exp_recip = Q6_Vh_vasl_VhR(exp_recip, 10); + + //Merge exponent and mantissa for reciprocal + HVX_Vector recip = Q6_V_vor_VV(exp_recip, mant_recip); + // map 'small' inputs to standard largest value 0x7bff + recip = Q6_V_vmux_QVV(is_small, Q6_Vh_vsplat_R(0x7bff), recip); + // add sign back + recip = Q6_V_vandor_VQR(recip, is_neg, 0x80008000); + return recip; +} + +static inline HVX_Vector hvx_vec_inverse_f32(HVX_Vector v_sf) { + HVX_Vector inv_aprox_sf = Q6_V_vsplat_R(0x7EEEEBB3); + HVX_Vector two_sf = hvx_vec_splat_f32(2.0); + + // First approximation + HVX_Vector i_sf = Q6_Vw_vsub_VwVw(inv_aprox_sf, v_sf); + + HVX_Vector r_qf; + + // Refine + r_qf = Q6_Vqf32_vmpy_VsfVsf( + i_sf, Q6_Vsf_equals_Vqf32(Q6_Vqf32_vsub_VsfVsf(two_sf, Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(i_sf,
v_sf))))); + r_qf = Q6_Vqf32_vmpy_Vqf32Vqf32( + r_qf, Q6_Vqf32_vsub_VsfVsf(two_sf, Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(Q6_Vsf_equals_Vqf32(r_qf), v_sf)))); + r_qf = Q6_Vqf32_vmpy_Vqf32Vqf32( + r_qf, Q6_Vqf32_vsub_VsfVsf(two_sf, Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(Q6_Vsf_equals_Vqf32(r_qf), v_sf)))); + + return Q6_Vsf_equals_Vqf32(r_qf); +} + +static inline HVX_Vector hvx_vec_inverse_f32_guard(HVX_Vector v_sf, HVX_Vector nan_inf_mask) { + HVX_Vector out = hvx_vec_inverse_f32(v_sf); + + HVX_Vector masked_out = Q6_V_vand_VV(out, nan_inf_mask); + const HVX_VectorPred pred = Q6_Q_vcmp_eq_VwVw(nan_inf_mask, masked_out); + + return Q6_V_vmux_QVV(pred, Q6_V_vzero(), out); +} + +#define hvx_inverse_f32_loop_body(dst_type, src_type, vec_store) \ + do { \ + dst_type * restrict vdst = (dst_type *) dst; \ + src_type * restrict vsrc = (src_type *) src; \ + \ + const HVX_Vector nan_inf_mask = Q6_V_vsplat_R(0x7f800000); \ + \ + const uint32_t nvec = n / VLEN_FP32; \ + const uint32_t nloe = n % VLEN_FP32; \ + \ + uint32_t i = 0; \ + \ + _Pragma("unroll(4)") \ + for (; i < nvec; i++) { \ + vdst[i] = hvx_vec_inverse_f32_guard(vsrc[i], nan_inf_mask); \ + } \ + if (nloe) { \ + HVX_Vector v = hvx_vec_inverse_f32_guard(vsrc[i], nan_inf_mask); \ + vec_store((void *) &vdst[i], nloe * SIZEOF_FP32, v); \ + } \ + } while(0) + +static inline void hvx_inverse_f32_aa(uint8_t * restrict dst, const uint8_t * restrict src, uint32_t n) { + assert((unsigned long) dst % 128 == 0); + assert((unsigned long) src % 128 == 0); + hvx_inverse_f32_loop_body(HVX_Vector, HVX_Vector, hvx_vec_store_a); +} + +static inline void hvx_inverse_f32_au(uint8_t * restrict dst, const uint8_t * restrict src, uint32_t n) { + assert((unsigned long) dst % 128 == 0); + hvx_inverse_f32_loop_body(HVX_Vector, HVX_UVector, hvx_vec_store_a); +} + +static inline void hvx_inverse_f32_ua(uint8_t * restrict dst, const uint8_t * restrict src, uint32_t n) { + assert((unsigned long) src % 128 == 0); + hvx_inverse_f32_loop_body(HVX_UVector, HVX_Vector, hvx_vec_store_u); +} + +static inline void hvx_inverse_f32_uu(uint8_t * restrict dst, const uint8_t * restrict src, uint32_t n) { + hvx_inverse_f32_loop_body(HVX_UVector, HVX_UVector, hvx_vec_store_u); +} + +static inline void hvx_inverse_f32(uint8_t * restrict dst, uint8_t * restrict src, const int num_elems) { + if ((unsigned long) dst % 128 == 0) { + if ((unsigned long) src % 128 == 0) { + hvx_inverse_f32_aa(dst, src, num_elems); + } else { + hvx_inverse_f32_au(dst, src, num_elems); + } + } else { + if ((unsigned long) src % 128 == 0) { + hvx_inverse_f32_ua(dst, src, num_elems); + } else { + hvx_inverse_f32_uu(dst, src, num_elems); + } + } +} + +#endif // HVX_INVERSE_H diff --git a/llama.cpp/ggml/src/ggml-hexagon/htp/hvx-reduce.h b/llama.cpp/ggml/src/ggml-hexagon/htp/hvx-reduce.h new file mode 100644 index 0000000..1ca7c05 --- /dev/null +++ b/llama.cpp/ggml/src/ggml-hexagon/htp/hvx-reduce.h @@ -0,0 +1,266 @@ +#ifndef HVX_REDUCE_H +#define HVX_REDUCE_H + +#include <math.h> +#include <stdbool.h> +#include <stdint.h> +#include <assert.h> + +#include "hex-utils.h" +#include "hvx-base.h" +#include "hvx-types.h" + +static inline HVX_Vector hvx_vec_reduce_sum_n_i32(HVX_Vector in, unsigned int n) { + unsigned int total = n * 4; // total vec nbytes + unsigned int width = 4; // int32 + + HVX_Vector sum = in, sum_t; + while (width < total) { + sum_t = Q6_V_vror_VR(sum, width); // rotate right + sum = Q6_Vw_vadd_VwVw(sum_t, sum); // elementwise sum + width = width << 1; + } + return sum; +} + +static inline 
HVX_Vector hvx_vec_reduce_sum_i32(HVX_Vector in) { + return hvx_vec_reduce_sum_n_i32(in, 32); +} + +static inline HVX_Vector hvx_vec_reduce_sum_n_qf32(HVX_Vector in, unsigned int n) { + unsigned int total = n * 4; // total vec nbytes + unsigned int width = 4; // fp32 nbytes + + HVX_Vector sum = in, sum_t; + while (width < total) { + sum_t = Q6_V_vror_VR(Q6_Vsf_equals_Vqf32(sum), width); // rotate right + sum = Q6_Vqf32_vadd_Vqf32Vsf(sum, sum_t); // elementwise sum + width = width << 1; + } + return sum; +} + +static inline HVX_Vector hvx_vec_reduce_sum_qf32(HVX_Vector in) { + return hvx_vec_reduce_sum_n_qf32(in, 32); +} + +#if __HVX_ARCH__ > 75 + +static inline HVX_Vector hvx_vec_reduce_sum_f32x2(HVX_Vector in0, HVX_Vector in1) { + HVX_VectorPair sump = Q6_W_vshuff_VVR(in1, in0, 4); + HVX_Vector sum_sf = Q6_Vsf_vadd_VsfVsf(Q6_V_lo_W(sump), Q6_V_hi_W(sump)); + + sum_sf = Q6_Vsf_vadd_VsfVsf(sum_sf, Q6_V_vror_VR(sum_sf, VLEN / 2)); + sum_sf = Q6_Vsf_vadd_VsfVsf(sum_sf, Q6_V_vror_VR(sum_sf, VLEN / 4)); + sum_sf = Q6_Vsf_vadd_VsfVsf(sum_sf, Q6_V_vror_VR(sum_sf, VLEN / 8)); + sum_sf = Q6_Vsf_vadd_VsfVsf(sum_sf, Q6_V_vror_VR(sum_sf, VLEN / 16)); + return sum_sf; +} + +static inline HVX_Vector hvx_vec_reduce_sum_n_f32(HVX_Vector in, unsigned int n) { + unsigned int total = n * 4; // total vec nbytes + unsigned int width = 4; // fp32 nbytes + + HVX_Vector sum = in, sum_t; + while (width < total) { + sum_t = Q6_V_vror_VR(sum, width); // rotate right + sum = Q6_Vsf_vadd_VsfVsf(sum, sum_t); // elementwise sum + width = width << 1; + } + return sum; +} + +#else + +static inline HVX_Vector hvx_vec_reduce_sum_f32x2(HVX_Vector in0, HVX_Vector in1) { + HVX_VectorPair sump = Q6_W_vshuff_VVR(in1, in0, 4); + HVX_Vector sum_qf = Q6_Vqf32_vadd_VsfVsf(Q6_V_lo_W(sump), Q6_V_hi_W(sump)); + + sum_qf = Q6_Vqf32_vadd_Vqf32Vsf(sum_qf, Q6_V_vror_VR(Q6_Vsf_equals_Vqf32(sum_qf), VLEN / 2)); + sum_qf = Q6_Vqf32_vadd_Vqf32Vsf(sum_qf, Q6_V_vror_VR(Q6_Vsf_equals_Vqf32(sum_qf), VLEN / 4)); + sum_qf = Q6_Vqf32_vadd_Vqf32Vsf(sum_qf, Q6_V_vror_VR(Q6_Vsf_equals_Vqf32(sum_qf), VLEN / 8)); + sum_qf = Q6_Vqf32_vadd_Vqf32Vsf(sum_qf, Q6_V_vror_VR(Q6_Vsf_equals_Vqf32(sum_qf), VLEN / 16)); + return Q6_Vsf_equals_Vqf32(sum_qf); +} + +static inline HVX_Vector hvx_vec_reduce_sum_n_f32(HVX_Vector in, unsigned int n) { + unsigned int total = n * 4; // total vec nbytes + unsigned int width = 4; // fp32 nbytes + + HVX_Vector sum = in, sum_t; + while (width < total) { + sum_t = Q6_V_vror_VR(sum, width); // rotate right + sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_VsfVsf(sum, sum_t)); // elementwise sum + width = width << 1; + } + return sum; +} + +#endif + +static inline HVX_Vector hvx_vec_reduce_sum_f32(HVX_Vector in) { + return hvx_vec_reduce_sum_n_f32(in, 32); +} + +static inline HVX_Vector hvx_vec_reduce_max_f16(HVX_Vector in) { + unsigned total = 128; // total vec nbytes + unsigned width = 2; // fp16 nbytes + + HVX_Vector _max = in, _max_t; + while (width < total) { + _max_t = Q6_V_vror_VR(_max, width); // rotate right + _max = Q6_Vhf_vmax_VhfVhf(_max_t, _max); // elementwise max + width = width << 1; + } + + return _max; +} + +static inline HVX_Vector hvx_vec_reduce_max2_f16(HVX_Vector in, HVX_Vector _max) { + unsigned total = 128; // total vec nbytes + unsigned width = 2; // fp16 nbytes + + HVX_Vector _max_t; + + _max = Q6_Vhf_vmax_VhfVhf(in, _max); + while (width < total) { + _max_t = Q6_V_vror_VR(_max, width); // rotate right + _max = Q6_Vhf_vmax_VhfVhf(_max_t, _max); // elementwise max + width = width << 1; + } + + return _max; +} +
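For reference, the range reduction documented in hvx_vec_exp_f32 (hvx-exp.h above) can be sanity-checked against a scalar model. The sketch below is illustrative only and not part of the patch; exp_ref is a hypothetical helper that mirrors the EXP_COEFF_* constants and the (-20, 20) clamp, and substitutes ldexpf() for the vector code's direct add into the IEEE exponent field (Q6_Vw_vaslacc_VwVwR):

#include <math.h>
#include <stdio.h>

/* Scalar model of the HVX exp: x = f*ln(2) + eps, f = floor(x*log2(e)),
 * exp(x) = exp(eps) * 2^f, with exp(eps) from the fitted degree-7 polynomial. */
static float exp_ref(float x) {
    const float log2e = 1.4426950408f;
    const float ln2   = 0.6931471805f;

    if (x >  20.0f) x =  20.0f;   /* same clamp as the vector version */
    if (x < -20.0f) x = -20.0f;

    float f = floorf(x * log2e);
    float e = x - f * ln2;        /* eps in [0, ln2) */

    /* E0 + E1*e + ... + E5*e^5 in Horner form (EXP_COEFF_5 .. EXP_COEFF_0) */
    float p = 0.000198757f;
    p = p * e + 0.0013982f;
    p = p * e + 0.00833345f;
    p = p * e + 0.0416658f;
    p = p * e + 0.16666667f;
    p = p * e + 0.5f;

    float y = 1.0f + e + p * e * e;
    return ldexpf(y, (int) f);    /* scale by 2^f */
}

int main(void) {
    printf("exp_ref(1.0) = %f, expf(1.0) = %f\n", exp_ref(1.0f), expf(1.0f));
    return 0;
}

For x = 1.0 this yields about 2.718276, which is the accuracy one should expect from the fitted polynomial after range reduction.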
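The reductions in this header all follow the pattern visible above: rotate the vector by width bytes, combine it with itself elementwise, and double width until it spans the whole vector, so every lane ends up holding the full result after log2(lanes) rounds. A minimal scalar model of the idea follows; it rotates whole fp32 lanes where Q6_V_vror_VR rotates bytes, and rotate_reduce_sum is a hypothetical name, not part of the patch:

#include <stdio.h>

#define LANES 32  /* fp32 lanes in one 128-byte HVX vector */

/* After log2(LANES) rotate-and-add rounds, every lane holds the total sum. */
static void rotate_reduce_sum(float v[LANES]) {
    for (unsigned width = 1; width < LANES; width <<= 1) {
        float t[LANES];
        for (unsigned i = 0; i < LANES; i++) {
            t[i] = v[(i + width) % LANES];   /* rotate by width lanes */
        }
        for (unsigned i = 0; i < LANES; i++) {
            v[i] += t[i];                    /* elementwise sum */
        }
    }
}

int main(void) {
    float v[LANES];
    for (unsigned i = 0; i < LANES; i++) {
        v[i] = (float) i;                    /* 0..31, total = 496 */
    }
    rotate_reduce_sum(v);
    printf("lane 0 = %.1f (expect 496.0)\n", v[0]);
    return 0;
}

Having the result broadcast to all lanes is what lets callers such as hvx_vec_get_f32 read it from lane 0 without a final extract-and-shuffle step.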
+static inline HVX_Vector hvx_vec_reduce_max_f32(HVX_Vector in) { + unsigned total = 128; // total vec nbytes + unsigned width = 4; // fp32 nbytes + + HVX_Vector _max = in, _max_t; + while (width < total) { + _max_t = Q6_V_vror_VR(_max, width); // rotate right + _max = Q6_Vsf_vmax_VsfVsf(_max_t, _max); // elementwise max + width = width << 1; + } + + return _max; +} + +static inline HVX_Vector hvx_vec_reduce_max2_f32(HVX_Vector in, HVX_Vector _max) { + unsigned total = 128; // total vec nbytes + unsigned width = 4; // fp32 nbytes + + HVX_Vector _max_t; + + _max = Q6_Vsf_vmax_VsfVsf(in, _max); + while (width < total) { + _max_t = Q6_V_vror_VR(_max, width); // rotate right + _max = Q6_Vsf_vmax_VsfVsf(_max_t, _max); // elementwise max + width = width << 1; + } + + return _max; +} + +#define hvx_reduce_loop_body(src_type, init_vec, pad_vec, vec_op, reduce_op, scalar_reduce) \ + do { \ + src_type * restrict vsrc = (src_type *) src; \ + HVX_Vector acc = init_vec; \ + \ + const uint32_t elem_size = sizeof(float); \ + const uint32_t epv = 128 / elem_size; \ + const uint32_t nvec = num_elems / epv; \ + const uint32_t nloe = num_elems % epv; \ + \ + uint32_t i = 0; \ + _Pragma("unroll(4)") \ + for (; i < nvec; i++) { \ + acc = vec_op(acc, vsrc[i]); \ + } \ + if (nloe) { \ + const float * srcf = (const float *) src + i * epv; \ + HVX_Vector in = *(HVX_UVector *) srcf; \ + HVX_Vector temp = Q6_V_valign_VVR(in, pad_vec, nloe * elem_size); \ + acc = vec_op(acc, temp); \ + } \ + HVX_Vector v = reduce_op(acc); \ + return scalar_reduce(v); \ + } while(0) + +#define HVX_REDUCE_MAX_OP(acc, val) Q6_Vsf_vmax_VsfVsf(acc, val) +#define HVX_REDUCE_SUM_OP(acc, val) Q6_Vqf32_vadd_VsfVsf(Q6_Vsf_equals_Vqf32(acc), val) +#define HVX_SUM_SQ_OP(acc, val) Q6_Vqf32_vadd_Vqf32Vqf32(acc, Q6_Vqf32_vmpy_VsfVsf(val, val)) +#define HVX_REDUCE_MAX_SCALAR(v) hvx_vec_get_f32(v) +#define HVX_REDUCE_SUM_SCALAR(v) hvx_vec_get_f32(Q6_Vsf_equals_Vqf32(v)) + +// Max variants + +static inline float hvx_reduce_max_f32_a(const uint8_t * restrict src, const int num_elems) { + HVX_Vector init_vec = hvx_vec_splat_f32(((const float *) src)[0]); + assert((unsigned long) src % 128 == 0); + hvx_reduce_loop_body(HVX_Vector, init_vec, init_vec, HVX_REDUCE_MAX_OP, hvx_vec_reduce_max_f32, HVX_REDUCE_MAX_SCALAR); +} + +static inline float hvx_reduce_max_f32_u(const uint8_t * restrict src, const int num_elems) { + HVX_Vector init_vec = hvx_vec_splat_f32(((const float *) src)[0]); + hvx_reduce_loop_body(HVX_UVector, init_vec, init_vec, HVX_REDUCE_MAX_OP, hvx_vec_reduce_max_f32, HVX_REDUCE_MAX_SCALAR); +} + +static inline float hvx_reduce_max_f32(const uint8_t * restrict src, const int num_elems) { + if (hex_is_aligned((void *) src, 128)) { + return hvx_reduce_max_f32_a(src, num_elems); + } else { + return hvx_reduce_max_f32_u(src, num_elems); + } +} + +// Sum variants + +static inline float hvx_reduce_sum_f32_a(const uint8_t * restrict src, const int num_elems) { + HVX_Vector init_vec = Q6_V_vsplat_R(0); + assert((unsigned long) src % 128 == 0); + hvx_reduce_loop_body(HVX_Vector, init_vec, init_vec, HVX_REDUCE_SUM_OP, hvx_vec_reduce_sum_qf32, HVX_REDUCE_SUM_SCALAR); +} + +static inline float hvx_reduce_sum_f32_u(const uint8_t * restrict src, const int num_elems) { + HVX_Vector init_vec = Q6_V_vsplat_R(0); + hvx_reduce_loop_body(HVX_UVector, init_vec, init_vec, HVX_REDUCE_SUM_OP, hvx_vec_reduce_sum_qf32, HVX_REDUCE_SUM_SCALAR); +} + +static inline float hvx_reduce_sum_f32(const uint8_t * restrict src, const int num_elems) { + if (hex_is_aligned((void *) 
src, 128)) { + return hvx_reduce_sum_f32_a(src, num_elems); + } else { + return hvx_reduce_sum_f32_u(src, num_elems); + } +} + +// Sum of squares variants + +static inline float hvx_sum_of_squares_f32_a(const uint8_t * restrict src, const int num_elems) { + HVX_Vector init_vec = Q6_V_vsplat_R(0); + assert((uintptr_t) src % 128 == 0); + hvx_reduce_loop_body(HVX_Vector, init_vec, init_vec, HVX_SUM_SQ_OP, hvx_vec_reduce_sum_qf32, HVX_REDUCE_SUM_SCALAR); +} + +static inline float hvx_sum_of_squares_f32_u(const uint8_t * restrict src, const int num_elems) { + HVX_Vector init_vec = Q6_V_vsplat_R(0); + hvx_reduce_loop_body(HVX_UVector, init_vec, init_vec, HVX_SUM_SQ_OP, hvx_vec_reduce_sum_qf32, HVX_REDUCE_SUM_SCALAR); +} + +static inline float hvx_sum_of_squares_f32(const uint8_t * restrict src, const int num_elems) { + if (hex_is_aligned((void *) src, 128)) { + return hvx_sum_of_squares_f32_a(src, num_elems); + } else { + return hvx_sum_of_squares_f32_u(src, num_elems); + } +} + +#undef hvx_reduce_loop_body +#undef HVX_REDUCE_MAX_OP +#undef HVX_REDUCE_SUM_OP +#undef HVX_REDUCE_MAX_SCALAR +#undef HVX_REDUCE_SUM_SCALAR +#undef HVX_SUM_SQ_OP + +#endif /* HVX_REDUCE_H */ diff --git a/llama.cpp/ggml/src/ggml-hexagon/htp/hvx-scale.h b/llama.cpp/ggml/src/ggml-hexagon/htp/hvx-scale.h new file mode 100644 index 0000000..c65c986 --- /dev/null +++ b/llama.cpp/ggml/src/ggml-hexagon/htp/hvx-scale.h @@ -0,0 +1,133 @@ +#ifndef HVX_SCALE_H +#define HVX_SCALE_H + +#include <assert.h> +#include <stddef.h> +#include <stdint.h> + +#include "hvx-base.h" + +#define hvx_scale_f32_loop_body(dst_type, src_type, vec_store) \ + do { \ + dst_type * restrict vdst = (dst_type *) dst; \ + src_type * restrict vsrc = (src_type *) src; \ + \ + HVX_Vector vs = hvx_vec_splat_f32(scale); \ + \ + const uint32_t elem_size = sizeof(float); \ + const uint32_t epv = 128 / elem_size; \ + const uint32_t nvec = n / epv; \ + const uint32_t nloe = n % epv; \ + \ + uint32_t i = 0; \ + \ + _Pragma("unroll(4)") \ + for (; i < nvec; ++i) { \ + HVX_Vector v = Q6_Vqf32_vmpy_VsfVsf(vsrc[i], vs); \ + vdst[i] = Q6_Vsf_equals_Vqf32(v); \ + } \ + if (nloe) { \ + HVX_Vector v = Q6_Vqf32_vmpy_VsfVsf(vsrc[i], vs); \ + vec_store((void *) &vdst[i], nloe * elem_size, Q6_Vsf_equals_Vqf32(v)); \ + } \ + } while(0) + +static inline void hvx_scale_f32_aa(uint8_t * restrict dst, const uint8_t * restrict src, const int n, const float scale) { + assert((size_t) dst % 128 == 0); + assert((size_t) src % 128 == 0); + hvx_scale_f32_loop_body(HVX_Vector, HVX_Vector, hvx_vec_store_a); +} + +static inline void hvx_scale_f32_au(uint8_t * restrict dst, const uint8_t * restrict src, const int n, const float scale) { + assert((size_t) dst % 128 == 0); + hvx_scale_f32_loop_body(HVX_Vector, HVX_UVector, hvx_vec_store_a); +} + +static inline void hvx_scale_f32_ua(uint8_t * restrict dst, const uint8_t * restrict src, const int n, const float scale) { + assert((size_t) src % 128 == 0); + hvx_scale_f32_loop_body(HVX_UVector, HVX_Vector, hvx_vec_store_u); +} + +static inline void hvx_scale_f32_uu(uint8_t * restrict dst, const uint8_t * restrict src, const int n, const float scale) { + hvx_scale_f32_loop_body(HVX_UVector, HVX_UVector, hvx_vec_store_u); +} + +static inline void hvx_scale_f32(uint8_t * restrict dst, const uint8_t * restrict src, const int n, const float scale) { + if (((size_t) dst & 127) == 0) { + if (((size_t) src & 127) == 0) { + hvx_scale_f32_aa(dst, src, n, scale); + } else { + hvx_scale_f32_au(dst, src, n, scale); + } + } else { + if (((size_t) src & 127) == 0) { 
+ hvx_scale_f32_ua(dst, src, n, scale); + } else { + hvx_scale_f32_uu(dst, src, n, scale); + } + } +} + +#define hvx_scale_offset_f32_loop_body(dst_type, src_type, vec_store) \ + do { \ + dst_type * restrict vdst = (dst_type *) dst; \ + src_type * restrict vsrc = (src_type *) src; \ + \ + HVX_Vector vs = hvx_vec_splat_f32(scale); \ + HVX_Vector vo = hvx_vec_splat_f32(offset); \ + \ + const uint32_t elem_size = sizeof(float); \ + const uint32_t epv = 128 / elem_size; \ + const uint32_t nvec = n / epv; \ + const uint32_t nloe = n % epv; \ + \ + uint32_t i = 0; \ + \ + _Pragma("unroll(4)") \ + for (; i < nvec; ++i) { \ + HVX_Vector v = Q6_Vqf32_vadd_Vqf32Vsf(Q6_Vqf32_vmpy_VsfVsf(vsrc[i], vs), vo); \ + vdst[i] = Q6_Vsf_equals_Vqf32(v); \ + } \ + if (nloe) { \ + HVX_Vector v = Q6_Vqf32_vadd_Vqf32Vsf(Q6_Vqf32_vmpy_VsfVsf(vsrc[i], vs), vo); \ + vec_store((void *) &vdst[i], nloe * elem_size, Q6_Vsf_equals_Vqf32(v)); \ + } \ + } while(0) + +static inline void hvx_scale_offset_f32_aa(uint8_t * restrict dst, const uint8_t * restrict src, const int n, const float scale, const float offset) { + assert((size_t) dst % 128 == 0); + assert((size_t) src % 128 == 0); + hvx_scale_offset_f32_loop_body(HVX_Vector, HVX_Vector, hvx_vec_store_a); +} + +static inline void hvx_scale_offset_f32_au(uint8_t * restrict dst, const uint8_t * restrict src, const int n, const float scale, const float offset) { + assert((size_t) dst % 128 == 0); + hvx_scale_offset_f32_loop_body(HVX_Vector, HVX_UVector, hvx_vec_store_a); +} + +static inline void hvx_scale_offset_f32_ua(uint8_t * restrict dst, const uint8_t * restrict src, const int n, const float scale, const float offset) { + assert((size_t) src % 128 == 0); + hvx_scale_offset_f32_loop_body(HVX_UVector, HVX_Vector, hvx_vec_store_u); +} + +static inline void hvx_scale_offset_f32_uu(uint8_t * restrict dst, const uint8_t * restrict src, const int n, const float scale, const float offset) { + hvx_scale_offset_f32_loop_body(HVX_UVector, HVX_UVector, hvx_vec_store_u); +} + +static inline void hvx_scale_offset_f32(uint8_t * restrict dst, const uint8_t * restrict src, const int n, const float scale, const float offset) { + if (((size_t) dst & 127) == 0) { + if (((size_t) src & 127) == 0) { + hvx_scale_offset_f32_aa(dst, src, n, scale, offset); + } else { + hvx_scale_offset_f32_au(dst, src, n, scale, offset); + } + } else { + if (((size_t) src & 127) == 0) { + hvx_scale_offset_f32_ua(dst, src, n, scale, offset); + } else { + hvx_scale_offset_f32_uu(dst, src, n, scale, offset); + } + } +} + +#endif // HVX_SCALE_H diff --git a/llama.cpp/ggml/src/ggml-hexagon/htp/hvx-sigmoid.h b/llama.cpp/ggml/src/ggml-hexagon/htp/hvx-sigmoid.h new file mode 100644 index 0000000..0951932 --- /dev/null +++ b/llama.cpp/ggml/src/ggml-hexagon/htp/hvx-sigmoid.h @@ -0,0 +1,141 @@ +#ifndef HVX_SIGMOID_H +#define HVX_SIGMOID_H + +#include "hvx-base.h" + +#define FAST_SIGMOID_LOG2F (0x3fb8aa3b) // 1.442695022 +#define FAST_SIGMOID_C1 (0x3d009076) // 0.03138777 +#define FAST_SIGMOID_C2 (0x3e8d74bd) // 0.276281267 +#define FAST_SIGMOID_C3 (0x3f000000) // 0.5 + +static inline HVX_Vector hvx_vec_fast_sigmoid_f32(HVX_Vector v) { + v = Q6_Vqf32_vmpy_VsfVsf(v, Q6_V_vsplat_R(FAST_SIGMOID_LOG2F)); + v = Q6_Vqf32_vmpy_VsfVsf(Q6_Vsf_equals_Vqf32(v), Q6_V_vsplat_R(FAST_SIGMOID_C3)); + + HVX_Vector in_int = hvx_vec_truncate_f32(Q6_Vsf_equals_Vqf32(v)); + HVX_Vector x = Q6_Vqf32_vsub_Vqf32Vsf(v, Q6_Vsf_equals_Vw(in_int)); + HVX_Vector xx = Q6_Vqf32_vmpy_Vqf32Vqf32(x, x); + + HVX_Vector v1 = 
Q6_Vqf32_vmpy_VsfVsf(Q6_Vsf_equals_Vqf32(xx), Q6_V_vsplat_R(FAST_SIGMOID_C2)); + v1 = Q6_Vqf32_vadd_Vqf32Vsf(v1, Q6_V_vsplat_R(FAST_SIGMOID_LOG2F)); + + HVX_Vector v2 = Q6_Vqf32_vmpy_VsfVsf(Q6_Vsf_equals_Vqf32(x), Q6_V_vsplat_R(FAST_SIGMOID_C1)); + v2 = Q6_Vqf32_vmpy_Vqf32Vqf32(v2, xx); + v2 = Q6_Vqf32_vadd_Vqf32Vqf32(v2, x); + + HVX_Vector v3 = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vqf32(v2, v1)); + HVX_Vector v3_exponent = Q6_Vw_vasl_VwR(v3, 1); + v3_exponent = Q6_Vuw_vlsr_VuwR(v3_exponent, 24); + v3_exponent = Q6_Vw_vadd_VwVw(in_int, v3_exponent); + v3 = Q6_Vw_vaslacc_VwVwR(v3, in_int, 24); + + HVX_Vector v4 = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vsub_Vqf32Vqf32(v2, v1)); + HVX_Vector v5 = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vsub_VsfVsf(v3, v4)); + + HVX_Vector res = hvx_vec_inverse_f32(v5); + res = Q6_Vqf32_vmpy_VsfVsf(v3, res); + + return Q6_Vsf_equals_Vqf32(res); +} + +static inline HVX_Vector hvx_vec_fast_sigmoid_f32_guard(HVX_Vector v, + HVX_Vector one, + HVX_Vector max_exp, + HVX_Vector min_exp) { + const HVX_VectorPred pred_max = Q6_Q_vcmp_gt_VsfVsf(max_exp, v); + const HVX_VectorPred pred_min = Q6_Q_vcmp_gt_VsfVsf(v, min_exp); + + HVX_Vector out = hvx_vec_fast_sigmoid_f32(v); + out = Q6_V_vmux_QVV(pred_max, out, one); + return Q6_V_vmux_QVV(pred_min, out, Q6_V_vzero()); +} + +static inline HVX_Vector hvx_vec_tanh_f32(HVX_Vector x) { + // tanh(x) = 2 * sigmoid(2x) - 1 + HVX_Vector two = hvx_vec_splat_f32(2.0f); + HVX_Vector one = hvx_vec_splat_f32(1.0f); + HVX_Vector x2 = Q6_Vqf32_vmpy_VsfVsf(x, two); + + HVX_Vector max_exp = hvx_vec_splat_f32(87.f); + HVX_Vector min_exp = hvx_vec_splat_f32(-87.f); + + HVX_Vector sig2x = hvx_vec_fast_sigmoid_f32_guard(Q6_Vsf_equals_Vqf32(x2), one, max_exp, min_exp); + + HVX_Vector res = Q6_Vqf32_vmpy_VsfVsf(sig2x, two); + res = Q6_Vqf32_vsub_Vqf32Vsf(res, one); + return Q6_Vsf_equals_Vqf32(res); +} + +#define hvx_sigmoid_loop_body(dst_type, src_type, vec_store) \ + do { \ + dst_type * restrict vdst = (dst_type *) dst; \ + src_type * restrict vsrc = (src_type *) src; \ + \ + const HVX_Vector one = hvx_vec_splat_f32(1.f); \ + const HVX_Vector max_exp = hvx_vec_splat_f32(87.f); \ + const HVX_Vector min_exp = hvx_vec_splat_f32(-87.f); \ + \ + const uint32_t epv = 128 / sizeof(float); \ + const uint32_t nvec = n / epv; \ + const uint32_t nloe = n % epv; \ + \ + uint32_t i = 0; \ + \ + _Pragma("unroll(4)") \ + for (; i < nvec; i++) { \ + vdst[i] = hvx_vec_fast_sigmoid_f32_guard(vsrc[i], one, max_exp, min_exp); \ + } \ + if (nloe) { \ + HVX_Vector tmp = hvx_vec_fast_sigmoid_f32_guard(vsrc[i], one, max_exp, min_exp); \ + vec_store((void *) &vdst[i], nloe * sizeof(float), tmp); \ + } \ + } while(0) + +#define hvx_tanh_loop_body(dst_type, src_type, vec_store) \ + do { \ + dst_type * restrict vdst = (dst_type *) dst; \ + src_type * restrict vsrc = (src_type *) src; \ + \ + const uint32_t epv = 128 / sizeof(float); \ + const uint32_t nvec = n / epv; \ + const uint32_t nloe = n % epv; \ + \ + uint32_t i = 0; \ + \ + _Pragma("unroll(4)") \ + for (; i < nvec; i++) { \ + vdst[i] = hvx_vec_tanh_f32(vsrc[i]); \ + } \ + if (nloe) { \ + HVX_Vector tmp = hvx_vec_tanh_f32(vsrc[i]); \ + vec_store((void *) &vdst[i], nloe * sizeof(float), tmp); \ + } \ + } while(0) + +static inline void hvx_sigmoid_f32_aa(uint8_t * restrict dst, const uint8_t * restrict src, uint32_t n) { + assert((unsigned long) dst % 128 == 0); + assert((unsigned long) src % 128 == 0); + hvx_sigmoid_loop_body(HVX_Vector, HVX_Vector, hvx_vec_store_a); +} + +static inline void hvx_sigmoid_f32_au(uint8_t * restrict 
dst, const uint8_t * restrict src, uint32_t n) { + assert((unsigned long) dst % 128 == 0); + hvx_sigmoid_loop_body(HVX_Vector, HVX_UVector, hvx_vec_store_a); +} + +static inline void hvx_sigmoid_f32_ua(uint8_t * restrict dst, const uint8_t * restrict src, uint32_t n) { + assert((unsigned long) src % 128 == 0); + hvx_sigmoid_loop_body(HVX_UVector, HVX_Vector, hvx_vec_store_u); +} + +static inline void hvx_sigmoid_f32_uu(uint8_t * restrict dst, const uint8_t * restrict src, uint32_t n) { + hvx_sigmoid_loop_body(HVX_UVector, HVX_UVector, hvx_vec_store_u); +} + +static inline void hvx_tanh_f32_aa(uint8_t * restrict dst, const uint8_t * restrict src, uint32_t n) { + assert((unsigned long) dst % 128 == 0); + assert((unsigned long) src % 128 == 0); + hvx_tanh_loop_body(HVX_Vector, HVX_Vector, hvx_vec_store_a); +} + +#endif /* HVX_SIGMOID_H */ diff --git a/llama.cpp/ggml/src/ggml-hexagon/htp/hvx-sqrt.h b/llama.cpp/ggml/src/ggml-hexagon/htp/hvx-sqrt.h new file mode 100644 index 0000000..e31a100 --- /dev/null +++ b/llama.cpp/ggml/src/ggml-hexagon/htp/hvx-sqrt.h @@ -0,0 +1,126 @@ +#ifndef HVX_SQRT_H +#define HVX_SQRT_H + +#include <stdbool.h> +#include <stdint.h> + +#include "hex-utils.h" + +#include "hvx-base.h" + +#define RSQRT_CONST 0x5f3759df // Constant for fast inverse square root calculation +#define RSQRT_ONE_HALF 0x3f000000 // 0.5 +#define RSQRT_THREE_HALVES 0x3fc00000 // 1.5 + +#if __HVX_ARCH__ < 79 +#define HVX_OP_MUL(a, b) Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(a, b)) +#else +#define HVX_OP_MUL(a, b) Q6_Vsf_vmpy_VsfVsf(a, b) +#endif + +static inline HVX_Vector hvx_vec_rsqrt_f32(HVX_Vector in_vec) { + //Algorithm : + // x2 = input*0.5 + // y = * (long *) &input + // y = 0x5f3759df - (y>>1) + // y = y*(threehalfs - x2*y*y) + + HVX_Vector rsqrtconst = Q6_V_vsplat_R(RSQRT_CONST); + HVX_Vector onehalf = Q6_V_vsplat_R(RSQRT_ONE_HALF); + HVX_Vector threehalfs = Q6_V_vsplat_R(RSQRT_THREE_HALVES); + + HVX_Vector x2, y, ypower2, temp; + + x2 = Q6_Vqf32_vmpy_VsfVsf(in_vec, onehalf); + x2 = Q6_Vqf32_vadd_Vqf32Vsf(x2, Q6_V_vzero()); + + y = Q6_Vw_vasr_VwR(in_vec, 1); + y = Q6_Vw_vsub_VwVw(rsqrtconst, y); + + // 1st iteration + ypower2 = Q6_Vqf32_vmpy_VsfVsf(y, y); + ypower2 = Q6_Vqf32_vadd_Vqf32Vsf(ypower2, Q6_V_vzero()); + temp = Q6_Vqf32_vmpy_Vqf32Vqf32(x2, ypower2); + temp = Q6_Vqf32_vsub_VsfVsf(threehalfs, Q6_Vsf_equals_Vqf32(temp)); + temp = Q6_Vqf32_vmpy_VsfVsf(y, Q6_Vsf_equals_Vqf32(temp)); + + // 2nd iteration + y = Q6_Vqf32_vadd_Vqf32Vsf(temp, Q6_V_vzero()); + ypower2 = Q6_Vqf32_vmpy_Vqf32Vqf32(y, y); + ypower2 = Q6_Vqf32_vadd_Vqf32Vsf(ypower2, Q6_V_vzero()); + temp = Q6_Vqf32_vmpy_Vqf32Vqf32(x2, ypower2); + temp = Q6_Vqf32_vsub_VsfVsf(threehalfs, Q6_Vsf_equals_Vqf32(temp)); + temp = Q6_Vqf32_vmpy_Vqf32Vqf32(y, temp); + + // 3rd iteration + y = Q6_Vqf32_vadd_Vqf32Vsf(temp, Q6_V_vzero()); + ypower2 = Q6_Vqf32_vmpy_Vqf32Vqf32(y, y); + ypower2 = Q6_Vqf32_vadd_Vqf32Vsf(ypower2, Q6_V_vzero()); + temp = Q6_Vqf32_vmpy_Vqf32Vqf32(x2, ypower2); + temp = Q6_Vqf32_vsub_VsfVsf(threehalfs, Q6_Vsf_equals_Vqf32(temp)); + temp = Q6_Vqf32_vmpy_Vqf32Vqf32(y, temp); + + return Q6_Vsf_equals_Vqf32(temp); +} + +// Compute sqrt(x) as x*inv_sqrt(x) +#define hvx_sqrt_f32_loop_body(dst_type, src_type, vec_store) \ + do { \ + dst_type * restrict vdst = (dst_type *) dst; \ + src_type * restrict vsrc = (src_type *) src; \ + \ + const uint32_t nvec = n / VLEN_FP32; \ + const uint32_t nloe = n % VLEN_FP32; \ + \ + uint32_t i = 0; \ + \ + _Pragma("unroll(4)") \ + for (; i < nvec; i++) { \ + HVX_Vector inv_sqrt = 
hvx_vec_rsqrt_f32(vsrc[i]); \ + HVX_Vector sqrt_res = HVX_OP_MUL(inv_sqrt, vsrc[i]); \ + vdst[i] = sqrt_res; \ + } \ + if (nloe) { \ + HVX_Vector inv_sqrt = hvx_vec_rsqrt_f32(vsrc[i]); \ + HVX_Vector sqrt_res = HVX_OP_MUL(inv_sqrt, vsrc[i]); \ + vec_store((void *) &vdst[i], nloe * SIZEOF_FP32, sqrt_res); \ + } \ + } while(0) + +static inline void hvx_sqrt_f32_aa(uint8_t * restrict dst, const uint8_t * restrict src, uint32_t n) { + assert((unsigned long) dst % 128 == 0); + assert((unsigned long) src % 128 == 0); + hvx_sqrt_f32_loop_body(HVX_Vector, HVX_Vector, hvx_vec_store_a); +} + +static inline void hvx_sqrt_f32_au(uint8_t * restrict dst, const uint8_t * restrict src, uint32_t n) { + assert((unsigned long) dst % 128 == 0); + hvx_sqrt_f32_loop_body(HVX_Vector, HVX_UVector, hvx_vec_store_a); +} + +static inline void hvx_sqrt_f32_ua(uint8_t * restrict dst, const uint8_t * restrict src, uint32_t n) { + assert((unsigned long) src % 128 == 0); + hvx_sqrt_f32_loop_body(HVX_UVector, HVX_Vector, hvx_vec_store_u); +} + +static inline void hvx_sqrt_f32_uu(uint8_t * restrict dst, const uint8_t * restrict src, uint32_t n) { + hvx_sqrt_f32_loop_body(HVX_UVector, HVX_UVector, hvx_vec_store_u); +} + +static inline void hvx_sqrt_f32(uint8_t * restrict dst, const uint8_t * restrict src, const int num_elems) { + if ((unsigned long) dst % 128 == 0) { + if ((unsigned long) src % 128 == 0) { + hvx_sqrt_f32_aa(dst, src, num_elems); + } else { + hvx_sqrt_f32_au(dst, src, num_elems); + } + } else { + if ((unsigned long) src % 128 == 0) { + hvx_sqrt_f32_ua(dst, src, num_elems); + } else { + hvx_sqrt_f32_uu(dst, src, num_elems); + } + } +} + +#endif /* HVX_SQRT_H */ diff --git a/llama.cpp/ggml/src/ggml-hexagon/htp/hvx-types.h b/llama.cpp/ggml/src/ggml-hexagon/htp/hvx-types.h new file mode 100644 index 0000000..d495a59 --- /dev/null +++ b/llama.cpp/ggml/src/ggml-hexagon/htp/hvx-types.h @@ -0,0 +1,36 @@ +#ifndef HVX_TYPES_H +#define HVX_TYPES_H + +#include <stdbool.h> +#include <stdint.h> + +#include <hexagon_types.h> + +#define SIZEOF_FP32 (4) +#define SIZEOF_FP16 (2) +#define VLEN (128) +#define VLEN_FP32 (VLEN / SIZEOF_FP32) +#define VLEN_FP16 (VLEN / SIZEOF_FP16) + +typedef union { + HVX_Vector v; + uint8_t b[VLEN]; + uint16_t h[VLEN_FP16]; + uint32_t w[VLEN_FP32]; + __fp16 fp16[VLEN_FP16]; + float fp32[VLEN_FP32]; +} __attribute__((aligned(VLEN), packed)) HVX_VectorAlias; + +typedef struct { + HVX_Vector v[2]; +} HVX_Vector_x2; + +typedef struct { + HVX_Vector v[4]; +} HVX_Vector_x4; + +typedef struct { + HVX_Vector v[8]; +} HVX_Vector_x8; + +#endif /* HVX_TYPES_H */ diff --git a/llama.cpp/ggml/src/ggml-hexagon/htp/hvx-utils.h b/llama.cpp/ggml/src/ggml-hexagon/htp/hvx-utils.h new file mode 100644 index 0000000..a518ad3 --- /dev/null +++ b/llama.cpp/ggml/src/ggml-hexagon/htp/hvx-utils.h @@ -0,0 +1,18 @@ +#ifndef HVX_UTILS_H +#define HVX_UTILS_H + +#include "hex-utils.h" + +#include "hvx-types.h" +#include "hvx-copy.h" +#include "hvx-scale.h" +#include "hvx-exp.h" +#include "hvx-inverse.h" +#include "hvx-reduce.h" +#include "hvx-sigmoid.h" +#include "hvx-sqrt.h" +#include "hvx-arith.h" +#include "hvx-div.h" +#include "hvx-base.h" + +#endif /* HVX_UTILS_H */ diff --git a/llama.cpp/ggml/src/ggml-hexagon/htp/main.c b/llama.cpp/ggml/src/ggml-hexagon/htp/main.c new file mode 100644 index 0000000..62708ee --- /dev/null +++ b/llama.cpp/ggml/src/ggml-hexagon/htp/main.c @@ -0,0 +1,1150 @@ +#pragma clang diagnostic ignored "-Wgnu-zero-variadic-macro-arguments" +#pragma clang diagnostic ignored "-Wunused-function" + 
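For reference, hvx_vec_rsqrt_f32 in hvx-sqrt.h above is the classic bit-trick reciprocal square root: an initial guess from RSQRT_CONST (0x5f3759df) refined by three Newton-Raphson steps y' = y * (1.5 - 0.5*x*y*y), with hvx_sqrt_f32 then forming sqrt(x) as x * rsqrt(x). A scalar sketch of the same recipe; rsqrt_ref and sqrt_ref are hypothetical names, not part of the patch:

#include <stdint.h>
#include <string.h>

/* Bit-level initial guess followed by three Newton-Raphson refinements,
 * matching RSQRT_CONST / RSQRT_ONE_HALF / RSQRT_THREE_HALVES above. */
static float rsqrt_ref(float x) {
    float x2 = 0.5f * x;
    uint32_t bits;
    memcpy(&bits, &x, sizeof(bits));   /* reinterpret float bits without UB */
    bits = 0x5f3759df - (bits >> 1);   /* first approximation */
    float y;
    memcpy(&y, &bits, sizeof(y));
    for (int i = 0; i < 3; i++) {      /* three refinement iterations */
        y = y * (1.5f - x2 * y * y);
    }
    return y;
}

/* sqrt(x) recovered as x * rsqrt(x), as in hvx_sqrt_f32_loop_body. */
static float sqrt_ref(float x) {
    return x * rsqrt_ref(x);
}

Three iterations roughly double the correct bits each time, which is why the vector version settles for this instead of a hardware divide or sqrt, neither of which HVX provides for fp32.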
+#include <HAP_farf.h> +#include <HAP_perf.h> +#include <AEEStdErr.h> +#include <dspqueue.h> +#include <HAP_compute_res.h> +#include <HAP_etm_config.h> +#include <HAP_mem.h> +#include <HAP_power.h> +#include <HAP_ps.h> +#include <qurt.h> +#include <qurt_thread.h> +#include <remote.h> +#include <string.h> + +#include "hex-dma.h" +#include "hex-utils.h" + +#define GGML_COMMON_DECL_C +#include "ggml-common.h" +#include "htp-ctx.h" +#include "htp-msg.h" +#include "htp-ops.h" +#include "worker-pool.h" + +AEEResult htp_iface_open(const char * uri, remote_handle64 * handle) { + struct htp_context * ctx; + int err = 0; + + ctx = calloc(1, sizeof(*ctx)); + if (ctx == NULL) { + return AEE_ENOMEMORY; + } + + // Use the context structure as a handle + *handle = (remote_handle64) ctx; + + // Enable FARF logs + HAP_setFARFRuntimeLoggingParams(0xffff, NULL, 0); + + // Set client class + { + HAP_power_request_t request; + memset(&request, 0, sizeof(HAP_power_request_t)); + request.type = HAP_power_set_apptype; + request.apptype = HAP_POWER_COMPUTE_CLIENT_CLASS; + + if ((err = HAP_power_set((void *) ctx, &request)) != 0) { + return err; + } + } + + { + HAP_power_request_t request; + memset(&request, 0, sizeof(request)); + + request.type = HAP_power_set_DCVS_v3; + request.dcvs_v3.set_dcvs_enable = TRUE; + request.dcvs_v3.dcvs_enable = TRUE; + request.dcvs_v3.dcvs_option = HAP_DCVS_V2_PERFORMANCE_MODE; + request.dcvs_v3.set_bus_params = TRUE; + request.dcvs_v3.bus_params.min_corner = HAP_DCVS_VCORNER_MAX; + request.dcvs_v3.bus_params.max_corner = HAP_DCVS_VCORNER_MAX; + request.dcvs_v3.bus_params.target_corner = HAP_DCVS_VCORNER_MAX; + request.dcvs_v3.set_core_params = TRUE; + request.dcvs_v3.core_params.min_corner = HAP_DCVS_VCORNER_MAX; + request.dcvs_v3.core_params.max_corner = HAP_DCVS_VCORNER_MAX; + request.dcvs_v3.core_params.target_corner = HAP_DCVS_VCORNER_MAX; + request.dcvs_v3.set_sleep_disable = TRUE; + request.dcvs_v3.sleep_disable = TRUE; + if ((err = HAP_power_set((void *) ctx, &request)) != 0) { + return err; + } + + memset(&request, 0, sizeof(request)); + request.type = HAP_power_set_HVX; + request.hvx.power_up = TRUE; + if ((err = HAP_power_set((void *) ctx, &request)) != 0) { + return err; + } + } + + { + // Power on HMX + HAP_power_request_t request; + memset(&request, 0, sizeof(HAP_power_request_t)); + request.type = HAP_power_set_HMX; + request.hmx.power_up = TRUE; + FARF(ALWAYS, "Powering HMX on\n"); + err = HAP_power_set((void *) ctx, &request); + if (err != AEE_SUCCESS) { + FARF(ERROR, "Error powering on HMX."); + return err; + } + } + + return AEE_SUCCESS; +} + +AEEResult htp_iface_close(remote_handle64 handle) { + struct htp_context * ctx = (struct htp_context *) handle; + + if (!ctx) { + return AEE_EBADPARM; + } + + if (ctx->queue) { + FARF(ERROR, "Closing handle with queue still open"); + return AEE_EITEMBUSY; + } + + free(ctx); + return AEE_SUCCESS; +} + +AEEResult htp_iface_enable_etm(remote_handle64 handle) { + int err = HAP_user_etm_enable(); + if (err) { + if (err == AEE_EVERSIONNOTSUPPORT) { + FARF(ERROR, "API HAP_user_etm_enable is not supported\n"); + } else { + FARF(ERROR, "Error executing HAP_user_etm_enable with error code : 0x%x\n", err); + } + } + return err; +} + +AEEResult htp_iface_disable_etm(remote_handle64 handle) { + int err = HAP_user_etm_disable(); + if (err) { + if (err == AEE_EVERSIONNOTSUPPORT) { + FARF(ERROR, "API HAP_user_etm_disable is not supported\n"); + } else { + FARF(ERROR, "Error executing HAP_user_etm_disable with error code : 0x%x\n", err); + }
+ } + return err; +} + +static int vtcm_acquire(struct htp_context * ctx) { + int err; + if (!ctx->vtcm_valid) { + // Temporarily bump thread priority to make sure it's higher than other sessions. + // This way the resource manager will notify the other thread to release VTCM. + // Note that we need to reacquire VTCM at normal priority for this to work next time. + qurt_thread_set_priority(qurt_thread_get_id(), ctx->thread_prio - 10); + err = HAP_compute_res_acquire_cached(ctx->vtcm_rctx, 1000000); + if (err != 0) { + FARF(ERROR, "Failed to acquire VTCM: 0x%08x", (unsigned)err); + abort(); + } + HAP_compute_res_release_cached(ctx->vtcm_rctx); + qurt_thread_set_priority(qurt_thread_get_id(), ctx->thread_prio); + + err = HAP_compute_res_acquire_cached(ctx->vtcm_rctx, 1000000); + if (err != 0) { + FARF(ERROR, "Failed to acquire VTCM: 0x%08x", (unsigned)err); + abort(); + } + ctx->vtcm_valid = true; + } + + ctx->vtcm_inuse = true; + return 0; +} + +static int vtcm_release(struct htp_context * ctx) { + ctx->vtcm_inuse = false; + + if (ctx->vtcm_valid && ctx->vtcm_needs_release) { + ctx->vtcm_valid = false; + ctx->vtcm_needs_release = false; + HAP_compute_res_release_cached(ctx->vtcm_rctx); + } + + return 0; +} + +static int vtcm_release_callback(unsigned int rctx, void * state) { + struct htp_context * ctx = (struct htp_context *) state; + + if (!ctx || ctx->vtcm_rctx != rctx) { + return AEE_EBADPARM; + } + + // If VTCM is not inuse (not processing Ops) release it right here + // otherwise we'll release it once we're done with the current Op. + + if (ctx->vtcm_inuse) { + ctx->vtcm_needs_release = true; + return 0; + } + + ctx->vtcm_valid = false; + HAP_compute_res_release_cached(ctx->vtcm_rctx); + + return 0; +} + +static int vtcm_alloc(struct htp_context * ctx) { + unsigned int vtcm_size = 8 * 1024 * 1024; // 8MB default + HAP_compute_res_query_VTCM(0, &vtcm_size, NULL, NULL, NULL); + + compute_res_attr_t attr; + HAP_compute_res_attr_init(&attr); + HAP_compute_res_attr_set_serialize(&attr, 0); + HAP_compute_res_attr_set_cache_mode(&attr, 1); + HAP_compute_res_attr_set_vtcm_param_v2(&attr, vtcm_size, 0, vtcm_size); + HAP_compute_res_attr_set_release_callback(&attr, vtcm_release_callback, (void *) ctx); + HAP_compute_res_attr_set_hmx_param(&attr, 1); + + // Allocate VTCM for scratch pads + uint32_t rctx = HAP_compute_res_acquire(&attr, 1000000 /* timeout */); + if (!rctx) { + FARF(ERROR, "failed to allocate %u bytes VTCM\n", vtcm_size); + return AEE_ENOMEMORY; + } + + void * vtcm_ptr; + if (HAP_compute_res_attr_get_vtcm_ptr_v2(&attr, &vtcm_ptr, &vtcm_size) != 0) { + HAP_compute_res_release(rctx); + FARF(ERROR, "failed to allocate %u bytes VTCM (new)\n", vtcm_size); + return AEE_ENOMEMORY; + } + + ctx->vtcm_base = (uint8_t *) vtcm_ptr; + ctx->vtcm_size = vtcm_size; + ctx->vtcm_rctx = rctx; + ctx->vtcm_valid = false; + ctx->vtcm_inuse = false; + ctx->vtcm_needs_release = false; + + return 0; +} + +static void vtcm_free(struct htp_context * ctx) { + if (ctx->vtcm_rctx) { + HAP_compute_res_release(ctx->vtcm_rctx); + ctx->vtcm_base = 0; + ctx->vtcm_rctx = 0; + } +} + +static void htp_packet_callback(dspqueue_t queue, int error, void * context); +static void htp_error_callback(dspqueue_t queue, int error, void * context); + +AEEResult htp_iface_start(remote_handle64 handle, uint32 sess_id, uint64 dsp_queue_id, uint32 n_hvx) { + struct htp_context * ctx = (struct htp_context *) handle; + + if (!ctx) { + return AEE_EBADPARM; + } + + if (ctx->queue) { + FARF(ERROR, "Queue already open");
return AEE_EITEMBUSY; + } + + // Import queue created on the CPU + int err = dspqueue_import(dsp_queue_id, // Queue ID from dspqueue_export + htp_packet_callback, // Packet callback + htp_error_callback, // Error callback; no errors expected on the DSP + (void *) ctx, // Callback context + &ctx->queue); + + if (err) { + FARF(ERROR, "Queue import failed with 0x%08x", (unsigned) err); + return err; + } + + ctx->thread_id = qurt_thread_get_id(); + ctx->thread_prio = qurt_thread_get_priority(ctx->thread_id); + + // allocate VTCM + err = vtcm_alloc(ctx); + if (err != AEE_SUCCESS) { + FARF(ERROR, "Unable to allocate VTCM"); + return AEE_ENOMEMORY; + } + + qurt_sysenv_max_hthreads_t hw_threads; + qurt_sysenv_get_max_hw_threads(&hw_threads); + uint32_t hw_nhvx = (qurt_hvx_get_units() >> 8) & 0xFF; + + if (n_hvx == 0) { + n_hvx = hw_nhvx; + } + if (n_hvx > hw_threads.max_hthreads) { + n_hvx = hw_threads.max_hthreads; + } + if (n_hvx > HTP_MAX_NTHREADS) { + n_hvx = HTP_MAX_NTHREADS; + } + + ctx->n_threads = n_hvx; + for (int i = 0; i < ctx->n_threads; i++) { + // see discussion https://github.com/ggml-org/llama.cpp/pull/18151#discussion_r2632388541 + ctx->dma[i] = dma_queue_create(64); + } + + // init worker pool + err = worker_pool_init(&ctx->worker_pool, n_hvx); + if (err != AEE_SUCCESS) { + FARF(ERROR, "Unable to create worker pool"); + return err; + } + + FARF(HIGH, "session %u started: n-hvx %u vtcm-size %zu vtcm-rctx %u n-threads %u thread-id %d thread-prio %d \n", + sess_id, hw_nhvx, ctx->vtcm_size, ctx->vtcm_rctx, ctx->n_threads, ctx->thread_id, ctx->thread_prio); + + return AEE_SUCCESS; +} + +AEEResult htp_iface_stop(remote_handle64 handle) { + struct htp_context * ctx = (struct htp_context *) handle; + if (!ctx) { + return AEE_EBADPARM; + } + + if (!ctx->queue) { + FARF(ERROR, "Queue not open"); + return AEE_EBADSTATE; + } + + // Close queue. dspqueue_close() will also wait for callbacks to finish. + int err = dspqueue_close(ctx->queue); + ctx->queue = NULL; + if (err != 0) { + FARF(ERROR, "Queue close failed with 0x%08x", (unsigned) err); + return err; + } + + if (ctx->worker_pool) { + // Release worker pool + worker_pool_release(&ctx->worker_pool); + } + + for (int i = 0; i < ctx->n_threads; i++) { + dma_queue_delete(ctx->dma[i]); + } + + vtcm_free(ctx); + + return AEE_SUCCESS; +} + +static void htp_error_callback(dspqueue_t queue, int error, void * context) { + // No errors expected on the DSP. 
+ FARF(ERROR, "Error callback: 0x%08x", (unsigned) error); +} + +struct profile_data { + uint64_t usecs; + uint64_t cycles; + uint64_t pkts; +}; + +static inline void profile_start(struct profile_data * d) { + d->usecs = HAP_perf_get_qtimer_count(); + d->cycles = hex_get_cycles(); + d->pkts = hex_get_pktcnt(); +} + +static inline void profile_stop(struct profile_data * d) { + d->usecs = HAP_perf_qtimer_count_to_us(HAP_perf_get_qtimer_count() - d->usecs); + d->cycles = hex_get_cycles() - d->cycles; + d->pkts = hex_get_pktcnt() - d->pkts; +} + +static int send_htp_rsp(struct htp_context * c, + uint32_t op, + uint32_t status, + struct dspqueue_buffer * bufs, + size_t n_bufs, + struct profile_data * prof) { + // Prep response struct + struct htp_general_rsp rsp; + rsp.op = op; + rsp.status = status; + rsp.prof_usecs = prof->usecs; + rsp.prof_cycles = prof->cycles; + rsp.prof_pkts = prof->pkts; + + int err = dspqueue_write(c->queue, + 0, // Flags + n_bufs, + bufs, // Buffer references + sizeof(rsp), + (const uint8_t *) &rsp, // Message + DSPQUEUE_TIMEOUT_NONE); + + if (err != 0) { + FARF(ERROR, "dspqueue_write failed: 0x%08x", (unsigned) err); + } + + return err; +} + +static void proc_matmul_req(struct htp_context * ctx, + struct htp_general_req * req, + struct dspqueue_buffer * bufs, + size_t n_bufs) { + struct dspqueue_buffer rsp_bufs[1]; + + // We had written to the output buffer, we'd also need to flush it + rsp_bufs[0].fd = bufs[2].fd; + rsp_bufs[0].ptr = bufs[2].ptr; + rsp_bufs[0].size = bufs[2].size; + rsp_bufs[0].offset = bufs[2].offset; + rsp_bufs[0].flags = (DSPQUEUE_BUFFER_FLAG_FLUSH_SENDER | // Flush HTP + DSPQUEUE_BUFFER_FLAG_INVALIDATE_RECIPIENT); // Invalidate CPU + + // Setup Op context + struct htp_ops_context octx = { 0 }; + octx.ctx = ctx; + octx.src0 = req->src0; + octx.src1 = req->src1; + octx.dst = req->dst; + octx.flags = req->flags; + octx.op = req->op; + + // Update data pointers + octx.src0.data = (uint32_t) bufs[0].ptr; + octx.src1.data = (uint32_t) bufs[1].ptr; + octx.dst.data = (uint32_t) bufs[2].ptr; + octx.n_threads = ctx->n_threads; + + struct profile_data prof; + profile_start(&prof); + + uint32_t rsp_status = HTP_STATUS_INTERNAL_ERR; + if (vtcm_acquire(ctx) == AEE_SUCCESS) { + rsp_status = op_matmul(&octx); + vtcm_release(ctx); + } + + profile_stop(&prof); + send_htp_rsp(ctx, req->op, rsp_status, rsp_bufs, 1, &prof); +} + +static void proc_argsort_req(struct htp_context * ctx, struct htp_general_req * req, struct dspqueue_buffer * bufs) { + struct dspqueue_buffer rsp_bufs[1]; + + // We had written to the output buffer, we'd also need to flush it + rsp_bufs[0].fd = bufs[1].fd; + rsp_bufs[0].ptr = bufs[1].ptr; + rsp_bufs[0].offset = bufs[1].offset; + rsp_bufs[0].size = bufs[1].size; + rsp_bufs[0].flags = (DSPQUEUE_BUFFER_FLAG_FLUSH_SENDER | // Flush HTP + DSPQUEUE_BUFFER_FLAG_INVALIDATE_RECIPIENT); // Invalidate CPU + + // Setup Op context + struct htp_ops_context octx = { 0 }; + octx.ctx = ctx; + octx.src0 = req->src0; + octx.dst = req->dst; + octx.flags = req->flags; + octx.op = req->op; + + memcpy(octx.op_params, req->op_params, sizeof(octx.op_params)); + + // Update data pointers + octx.src0.data = (uint32_t) bufs[0].ptr; + octx.dst.data = (uint32_t) bufs[1].ptr; + octx.n_threads = ctx->n_threads; + + struct profile_data prof; + profile_start(&prof); + + uint32_t rsp_status = HTP_STATUS_INTERNAL_ERR; + if (vtcm_acquire(ctx) == AEE_SUCCESS) { + rsp_status = op_argsort(&octx); + vtcm_release(ctx); + } + + profile_stop(&prof); + send_htp_rsp(ctx, req->op, 
rsp_status, rsp_bufs, 1, &prof); +} + +static void proc_cpy_req(struct htp_context * ctx, struct htp_general_req * req, struct dspqueue_buffer * bufs) { + struct dspqueue_buffer rsp_bufs[1]; + + // We had written to the output buffer, we'd also need to flush it + rsp_bufs[0].fd = bufs[1].fd; + rsp_bufs[0].ptr = bufs[1].ptr; + rsp_bufs[0].offset = bufs[1].offset; + rsp_bufs[0].size = bufs[1].size; + rsp_bufs[0].flags = (DSPQUEUE_BUFFER_FLAG_FLUSH_SENDER | // Flush HTP + DSPQUEUE_BUFFER_FLAG_INVALIDATE_RECIPIENT); // Invalidate CPU + + // Setup Op context + struct htp_ops_context octx = { 0 }; + octx.ctx = ctx; + octx.src0 = req->src0; + octx.dst = req->dst; + octx.flags = req->flags; + octx.op = req->op; + + // Update data pointers + octx.src0.data = (uint32_t) bufs[0].ptr; + octx.dst.data = (uint32_t) bufs[1].ptr; + octx.n_threads = ctx->n_threads; + + struct profile_data prof; + profile_start(&prof); + + uint32_t rsp_status = HTP_STATUS_INTERNAL_ERR; + if (vtcm_acquire(ctx) == AEE_SUCCESS) { + rsp_status = op_cpy(&octx); + vtcm_release(ctx); + } + + profile_stop(&prof); + send_htp_rsp(ctx, req->op, rsp_status, rsp_bufs, 1, &prof); +} + +static void proc_get_rows_req(struct htp_context * ctx, struct htp_general_req * req, struct dspqueue_buffer * bufs) { + struct dspqueue_buffer rsp_bufs[1]; + + // We had written to the output buffer, we'd also need to flush it + rsp_bufs[0].fd = bufs[2].fd; + rsp_bufs[0].ptr = bufs[2].ptr; + rsp_bufs[0].offset = bufs[2].offset; + rsp_bufs[0].size = bufs[2].size; + rsp_bufs[0].flags = (DSPQUEUE_BUFFER_FLAG_FLUSH_SENDER | // Flush HTP + DSPQUEUE_BUFFER_FLAG_INVALIDATE_RECIPIENT); // Invalidate CPU + + // Setup Op context + struct htp_ops_context octx = { 0 }; + octx.ctx = ctx; + octx.src0 = req->src0; + octx.src1 = req->src1; + octx.dst = req->dst; + octx.flags = req->flags; + octx.op = req->op; + + // Update data pointers + octx.src0.data = (uint32_t) bufs[0].ptr; + octx.src1.data = (uint32_t) bufs[1].ptr; + octx.dst.data = (uint32_t) bufs[2].ptr; + octx.n_threads = ctx->n_threads; + + struct profile_data prof; + profile_start(&prof); + + uint32_t rsp_status = HTP_STATUS_INTERNAL_ERR; + if (vtcm_acquire(ctx) == AEE_SUCCESS) { + rsp_status = op_get_rows(&octx); + vtcm_release(ctx); + } + + profile_stop(&prof); + send_htp_rsp(ctx, req->op, rsp_status, rsp_bufs, 1, &prof); +} + +static void proc_matmul_id_req(struct htp_context * ctx, + struct htp_general_req * req, + struct dspqueue_buffer * bufs, + size_t n_bufs) { + struct dspqueue_buffer rsp_bufs[1]; + + // We had written to the output buffer, we'd also need to flush it + rsp_bufs[0].fd = bufs[3].fd; + rsp_bufs[0].ptr = bufs[3].ptr; + rsp_bufs[0].size = bufs[3].size; + rsp_bufs[0].offset = bufs[3].offset; + rsp_bufs[0].flags = (DSPQUEUE_BUFFER_FLAG_FLUSH_SENDER | // Flush HTP + DSPQUEUE_BUFFER_FLAG_INVALIDATE_RECIPIENT); // Invalidate CPU + + // Setup Op context + struct htp_ops_context octx = { 0 }; + octx.ctx = ctx; + octx.src0 = req->src0; + octx.src1 = req->src1; + octx.src2 = req->src2; + octx.dst = req->dst; + octx.flags = req->flags; + octx.op = req->op; + + // Update data pointers + octx.src0.data = (uint32_t) bufs[0].ptr; + octx.src1.data = (uint32_t) bufs[1].ptr; + octx.src2.data = (uint32_t) bufs[2].ptr; + octx.dst.data = (uint32_t) bufs[3].ptr; + octx.n_threads = ctx->n_threads; + + struct profile_data prof; + profile_start(&prof); + + uint32_t rsp_status = HTP_STATUS_INTERNAL_ERR; + if (vtcm_acquire(ctx) == AEE_SUCCESS) { + rsp_status = op_matmul_id(&octx); + vtcm_release(ctx); + } + + 
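+    // A minimal sketch of the per-op VTCM protocol shared by all proc_*_req
+    // handlers in this file (op_fn is a placeholder for the real op_* entry
+    // points, not a symbol defined here):
+    //
+    //   uint32_t status = HTP_STATUS_INTERNAL_ERR;
+    //   if (vtcm_acquire(ctx) == AEE_SUCCESS) {
+    //       status = op_fn(&octx);  // op uses ctx->vtcm_base as scratch space
+    //       vtcm_release(ctx);      // performs the deferred HAP release if the
+    //   }                           // resource-manager callback fired mid-op
+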
profile_stop(&prof); + send_htp_rsp(ctx, req->op, rsp_status, rsp_bufs, 1, &prof); +} + +static void proc_binary_req(struct htp_context * ctx, struct htp_general_req * req, struct dspqueue_buffer * bufs) { + struct dspqueue_buffer rsp_bufs[1]; + + // We had written to the output buffer, we'd also need to flush it + rsp_bufs[0].fd = bufs[2].fd; + rsp_bufs[0].ptr = bufs[2].ptr; + rsp_bufs[0].offset = bufs[2].offset; + rsp_bufs[0].size = bufs[2].size; + rsp_bufs[0].flags = (DSPQUEUE_BUFFER_FLAG_FLUSH_SENDER | // Flush HTP + DSPQUEUE_BUFFER_FLAG_INVALIDATE_RECIPIENT); // Invalidate CPU + + // Setup Op context + struct htp_ops_context octx = { 0 }; + octx.ctx = ctx; + octx.src0 = req->src0; + octx.src1 = req->src1; + octx.dst = req->dst; + octx.flags = req->flags; + octx.op = req->op; + + // Update data pointers + octx.src0.data = (uint32_t) bufs[0].ptr; + octx.src1.data = (uint32_t) bufs[1].ptr; + octx.dst.data = (uint32_t) bufs[2].ptr; + octx.n_threads = ctx->n_threads; + + struct profile_data prof; + profile_start(&prof); + + uint32_t rsp_status = HTP_STATUS_INTERNAL_ERR; + if (vtcm_acquire(ctx) == AEE_SUCCESS) { + rsp_status = op_binary(&octx); + vtcm_release(ctx); + } + + profile_stop(&prof); + send_htp_rsp(ctx, req->op, rsp_status, rsp_bufs, 1, &prof); +} + +static void proc_add_id_req(struct htp_context * ctx, struct htp_general_req * req, struct dspqueue_buffer * bufs) { + struct dspqueue_buffer rsp_bufs[1]; + + // We had written to the output buffer, we'd also need to flush it + rsp_bufs[0].fd = bufs[3].fd; + rsp_bufs[0].ptr = bufs[3].ptr; + rsp_bufs[0].offset = bufs[3].offset; + rsp_bufs[0].size = bufs[3].size; + rsp_bufs[0].flags = (DSPQUEUE_BUFFER_FLAG_FLUSH_SENDER | // Flush HTP + DSPQUEUE_BUFFER_FLAG_INVALIDATE_RECIPIENT); // Invalidate CPU + + // Setup Op context + struct htp_ops_context octx = { 0 }; + octx.ctx = ctx; + octx.src0 = req->src0; + octx.src1 = req->src1; + octx.src2 = req->src2; + octx.dst = req->dst; + octx.flags = req->flags; + octx.op = req->op; + + // Update data pointers + octx.src0.data = (uint32_t) bufs[0].ptr; + octx.src1.data = (uint32_t) bufs[1].ptr; + octx.src2.data = (uint32_t) bufs[2].ptr; + octx.dst.data = (uint32_t) bufs[3].ptr; + octx.n_threads = ctx->n_threads; + + struct profile_data prof; + profile_start(&prof); + + uint32_t rsp_status = HTP_STATUS_INTERNAL_ERR; + if (vtcm_acquire(ctx) == AEE_SUCCESS) { + rsp_status = op_binary(&octx); + vtcm_release(ctx); + } + + profile_stop(&prof); + send_htp_rsp(ctx, req->op, rsp_status, rsp_bufs, 1, &prof); +} + +static void proc_unary_req(struct htp_context * ctx, struct htp_general_req * req, struct dspqueue_buffer * bufs) { + struct dspqueue_buffer rsp_bufs[HTP_MAX_PACKET_BUFFERS]; + + // We had written to the output buffer, we'd also need to flush it + rsp_bufs[0].fd = bufs[1].fd; + rsp_bufs[0].ptr = bufs[1].ptr; + rsp_bufs[0].offset = bufs[1].offset; + rsp_bufs[0].size = bufs[1].size; + rsp_bufs[0].flags = (DSPQUEUE_BUFFER_FLAG_FLUSH_SENDER | // Flush HTP + DSPQUEUE_BUFFER_FLAG_INVALIDATE_RECIPIENT); // Invalidate CPU + + // Setup Op context + struct htp_ops_context octx = { 0 }; + octx.ctx = ctx; + octx.src0 = req->src0; + octx.dst = req->dst; + octx.flags = req->flags; + octx.op = req->op; + + memcpy(octx.op_params, req->op_params, sizeof(octx.op_params)); + + // Update data pointers + octx.src0.data = (uint32_t) bufs[0].ptr; + octx.dst.data = (uint32_t) bufs[1].ptr; + octx.n_threads = ctx->n_threads; + + struct profile_data prof; + profile_start(&prof); + + uint32_t rsp_status = 
HTP_STATUS_INTERNAL_ERR; + if (vtcm_acquire(ctx) == AEE_SUCCESS) { + rsp_status = op_unary(&octx); + vtcm_release(ctx); + } + + profile_stop(&prof); + send_htp_rsp(ctx, req->op, rsp_status, rsp_bufs, 1, &prof); +} + +static void proc_sum_rows_req(struct htp_context * ctx, struct htp_general_req * req, struct dspqueue_buffer * bufs) { + struct dspqueue_buffer rsp_bufs[HTP_MAX_PACKET_BUFFERS]; + + // We had written to the output buffer, we'd also need to flush it + rsp_bufs[0].fd = bufs[1].fd; + rsp_bufs[0].ptr = bufs[1].ptr; + rsp_bufs[0].offset = bufs[1].offset; + rsp_bufs[0].size = bufs[1].size; + rsp_bufs[0].flags = (DSPQUEUE_BUFFER_FLAG_FLUSH_SENDER | // Flush HTP + DSPQUEUE_BUFFER_FLAG_INVALIDATE_RECIPIENT); // Invalidate CPU + + // Setup Op context + struct htp_ops_context octx = { 0 }; + octx.ctx = ctx; + octx.src0 = req->src0; + octx.dst = req->dst; + octx.flags = req->flags; + octx.op = req->op; + + memcpy(octx.op_params, req->op_params, sizeof(octx.op_params)); + + // Update data pointers + octx.src0.data = (uint32_t) bufs[0].ptr; + octx.dst.data = (uint32_t) bufs[1].ptr; + octx.n_threads = ctx->n_threads; + + struct profile_data prof; + profile_start(&prof); + + uint32_t rsp_status = HTP_STATUS_INTERNAL_ERR; + if (vtcm_acquire(ctx) == AEE_SUCCESS) { + rsp_status = op_sum_rows(&octx); + vtcm_release(ctx); + } + + profile_stop(&prof); + send_htp_rsp(ctx, req->op, rsp_status, rsp_bufs, 1, &prof); +} + +static void proc_activations_req(struct htp_context * ctx, + struct htp_general_req * req, + struct dspqueue_buffer * bufs, + uint32_t n_bufs) { + struct dspqueue_buffer rsp_bufs[HTP_MAX_PACKET_BUFFERS]; + + int write_idx = (n_bufs == 3) ? 2 : 1; + + // We had written to the output buffer, we'd also need to flush it + rsp_bufs[0].fd = bufs[write_idx].fd; + rsp_bufs[0].ptr = bufs[write_idx].ptr; + rsp_bufs[0].offset = bufs[write_idx].offset; + rsp_bufs[0].size = bufs[write_idx].size; + rsp_bufs[0].flags = (DSPQUEUE_BUFFER_FLAG_FLUSH_SENDER | // Flush HTP + DSPQUEUE_BUFFER_FLAG_INVALIDATE_RECIPIENT); // Invalidate CPU + + // Setup Op context + struct htp_ops_context octx = { 0 }; + octx.ctx = ctx; + octx.src0 = req->src0; + if (3 == n_bufs) { + octx.src1 = req->src1; + } + octx.dst = req->dst; + octx.flags = req->flags; + octx.op = req->op; + + memcpy(octx.op_params, req->op_params, sizeof(octx.op_params)); + + // Update data pointers + octx.src0.data = (uint32_t) bufs[0].ptr; + if (3 == n_bufs) { + octx.src1.data = (uint32_t) bufs[1].ptr; + octx.dst.data = (uint32_t) bufs[2].ptr; + } else { + octx.dst.data = (uint32_t) bufs[1].ptr; + } + octx.n_threads = ctx->n_threads; + + struct profile_data prof; + profile_start(&prof); + + uint32_t rsp_status = HTP_STATUS_INTERNAL_ERR; + if (vtcm_acquire(ctx) == AEE_SUCCESS) { + if (octx.op == HTP_OP_SOFTMAX) { + rsp_status = op_softmax(&octx); + } else { + rsp_status = op_activations(&octx); + } + vtcm_release(ctx); + } + + profile_stop(&prof); + send_htp_rsp(ctx, req->op, rsp_status, rsp_bufs, 1, &prof); +} + +static void proc_rope_req(struct htp_context * ctx, + struct htp_general_req * req, + struct dspqueue_buffer * bufs, + uint32_t n_bufs) { + struct dspqueue_buffer rsp_bufs[HTP_MAX_PACKET_BUFFERS]; + + int write_idx = n_bufs - 1; + + // We had written to the output buffer, we'd also need to flush it + rsp_bufs[0].fd = bufs[write_idx].fd; + rsp_bufs[0].ptr = bufs[write_idx].ptr; + rsp_bufs[0].offset = bufs[write_idx].offset; + rsp_bufs[0].size = bufs[write_idx].size; + rsp_bufs[0].flags = (DSPQUEUE_BUFFER_FLAG_FLUSH_SENDER | // Flush HTP + 
DSPQUEUE_BUFFER_FLAG_INVALIDATE_RECIPIENT); // Invalidate CPU + + // Setup Op context + struct htp_ops_context octx = { 0 }; + octx.ctx = ctx; + octx.src0 = req->src0; + octx.src1 = req->src1; + if (4 == n_bufs) { + octx.src2 = req->src2; + } + octx.dst = req->dst; + octx.flags = req->flags; + octx.op = req->op; + + memcpy(octx.op_params, req->op_params, sizeof(octx.op_params)); + + // Update data pointers + octx.src0.data = (uint32_t) bufs[0].ptr; + octx.src1.data = (uint32_t) bufs[1].ptr; + if (4 == n_bufs) { + octx.src2.data = (uint32_t) bufs[2].ptr; + octx.dst.data = (uint32_t) bufs[3].ptr; + } else { + octx.dst.data = (uint32_t) bufs[2].ptr; + } + octx.n_threads = ctx->n_threads; + + struct profile_data prof; + profile_start(&prof); + + uint32_t rsp_status = HTP_STATUS_INTERNAL_ERR; + if (vtcm_acquire(ctx) == AEE_SUCCESS) { + rsp_status = op_rope(&octx); + vtcm_release(ctx); + } + + profile_stop(&prof); + send_htp_rsp(ctx, req->op, rsp_status, rsp_bufs, 1, &prof); +} + +static void proc_set_rows_req(struct htp_context * ctx, struct htp_general_req * req, struct dspqueue_buffer * bufs) { + struct dspqueue_buffer rsp_bufs[1]; + + // We had written to the output buffer, we'd also need to flush it + rsp_bufs[0].fd = bufs[2].fd; + rsp_bufs[0].ptr = bufs[2].ptr; + rsp_bufs[0].offset = bufs[2].offset; + rsp_bufs[0].size = bufs[2].size; + rsp_bufs[0].flags = (DSPQUEUE_BUFFER_FLAG_FLUSH_SENDER | // Flush HTP + DSPQUEUE_BUFFER_FLAG_INVALIDATE_RECIPIENT); // Invalidate CPU + + // Setup Op context + struct htp_ops_context octx = { 0 }; + octx.ctx = ctx; + octx.src0 = req->src0; + octx.src1 = req->src1; + octx.dst = req->dst; + octx.flags = req->flags; + octx.op = req->op; + + // Update data pointers + octx.src0.data = (uint32_t) bufs[0].ptr; + octx.src1.data = (uint32_t) bufs[1].ptr; + octx.dst.data = (uint32_t) bufs[2].ptr; + octx.n_threads = ctx->n_threads; + + struct profile_data prof; + profile_start(&prof); + + uint32_t rsp_status = HTP_STATUS_INTERNAL_ERR; + if (vtcm_acquire(ctx) == AEE_SUCCESS) { + rsp_status = op_set_rows(&octx); + vtcm_release(ctx); + } + + profile_stop(&prof); + send_htp_rsp(ctx, req->op, rsp_status, rsp_bufs, 1, &prof); +} + +static void proc_flash_attn_ext_req(struct htp_context * ctx, + struct htp_general_req * req, + struct dspqueue_buffer * bufs, + uint32_t n_bufs) { + // Setup Op context + struct htp_ops_context octx; + memset(&octx, 0, sizeof(octx)); + + octx.ctx = ctx; + octx.n_threads = ctx->n_threads; + + octx.src0 = req->src0; + octx.src1 = req->src1; + octx.src2 = req->src2; + octx.src3 = req->src3; + octx.src4 = req->src4; + octx.dst = req->dst; + octx.flags = req->flags; + octx.op = req->op; + + memcpy(octx.op_params, req->op_params, sizeof(octx.op_params)); + + // Update data pointers + octx.src0.data = (uint32_t) bufs[0].ptr; + octx.src1.data = (uint32_t) bufs[1].ptr; + octx.src2.data = (uint32_t) bufs[2].ptr; + + int last_buf = 3; + + if (octx.src3.ne[0]) { + octx.src3.data = (uint32_t) bufs[last_buf++].ptr; // mask is valid + } + + if (octx.src4.ne[0]) { + octx.src4.data = (uint32_t) bufs[last_buf++].ptr; // sinks is valid + } + + octx.dst.data = (uint32_t) bufs[last_buf].ptr; + + struct profile_data prof; + profile_start(&prof); + + uint32_t rsp_status = HTP_STATUS_INTERNAL_ERR; + if (vtcm_acquire(ctx) == AEE_SUCCESS) { + rsp_status = op_flash_attn_ext(&octx); + vtcm_release(ctx); + } + + profile_stop(&prof); + + struct dspqueue_buffer rsp_buf = bufs[last_buf]; + rsp_buf.flags = (DSPQUEUE_BUFFER_FLAG_FLUSH_SENDER | // Flush HTP + 
DSPQUEUE_BUFFER_FLAG_INVALIDATE_RECIPIENT); // Invalidate CPU
+
+    send_htp_rsp(ctx, req->op, rsp_status, &rsp_buf, 1, &prof);
+}
+
+static void htp_packet_callback(dspqueue_t queue, int error, void * context) {
+    struct htp_context * ctx = (struct htp_context *) context;
+
+    // Repeatedly read packets from the queue until it's empty. We don't
+    // necessarily get a separate callback for each packet, and new packets
+    // may arrive while we're processing the previous one. This ensures we
+    // keep the DSP busy as much as possible and avoid waiting for the CPU.
+
+    while (1) {
+        struct htp_general_req req;
+        uint32_t req_size;
+
+        struct dspqueue_buffer bufs[HTP_MAX_PACKET_BUFFERS];
+        uint32_t n_bufs;
+        uint32_t flags;
+
+        // Read packet from queue
+        int err = dspqueue_read_noblock(queue, &flags,
+                                        HTP_MAX_PACKET_BUFFERS, // Maximum number of buffer references
+                                        &n_bufs,                // Number of buffer references
+                                        bufs,                   // Buffer references
+                                        sizeof(req),            // Max message length
+                                        &req_size,              // Message length
+                                        (uint8_t *) &req);      // Message
+
+        if (err == AEE_EWOULDBLOCK) {
+            // Consumed all packets available for now
+            return;
+        }
+
+        if (err != 0) {
+            FARF(ERROR, "dspqueue_read_noblock failed: 0x%08x", (unsigned) err);
+            return;
+        }
+
+        if (req_size != sizeof(req)) {
+            FARF(ERROR, "Invalid request size");
+            continue;
+        }
+
+        if (req.flags & HTP_OPFLAGS_EARLY_WAKEUP) {
+            // Host wants early notification
+            dspqueue_write_early_wakeup_noblock(ctx->queue, 10, 0);
+        }
+
+        // Process packet based on its message type
+        switch (req.op) {
+            case HTP_OP_MUL_MAT:
+                if (n_bufs != 3) {
+                    FARF(ERROR, "Bad matmul-req buffer list");
+                    continue;
+                }
+                proc_matmul_req(ctx, &req, bufs, n_bufs);
+                break;
+
+            case HTP_OP_MUL_MAT_ID:
+                if (n_bufs != 4) {
+                    FARF(ERROR, "Bad matmul-id-req buffer list");
+                    continue;
+                }
+                proc_matmul_id_req(ctx, &req, bufs, n_bufs);
+                break;
+
+            case HTP_OP_MUL:
+            case HTP_OP_ADD:
+            case HTP_OP_SUB:
+            case HTP_OP_DIV:
+                if (n_bufs != 3) {
+                    FARF(ERROR, "Bad binary-req buffer list");
+                    continue;
+                }
+                proc_binary_req(ctx, &req, bufs);
+                break;
+
+            case HTP_OP_RMS_NORM:
+            case HTP_OP_SCALE:
+            case HTP_OP_SQR:
+            case HTP_OP_SQRT:
+                if (n_bufs != 2) {
+                    FARF(ERROR, "Bad unary-req buffer list");
+                    continue;
+                }
+                proc_unary_req(ctx, &req, bufs);
+                break;
+
+            case HTP_OP_SUM_ROWS:
+                if (n_bufs != 2) {
+                    FARF(ERROR, "Bad sum-rows-req buffer list");
+                    continue;
+                }
+                proc_sum_rows_req(ctx, &req, bufs);
+                break;
+
+            case HTP_OP_UNARY_SILU:
+            case HTP_OP_UNARY_GELU:
+                if (n_bufs != 2) {
+                    FARF(ERROR, "Bad act-req buffer list");
+                    continue;
+                }
+                proc_activations_req(ctx, &req, bufs, n_bufs);
+                break;
+
+            case HTP_OP_GLU_SWIGLU:
+            case HTP_OP_GLU_SWIGLU_OAI:
+            case HTP_OP_SOFTMAX:
+            case HTP_OP_GLU_GEGLU:
+                if ((n_bufs != 2) && (n_bufs != 3)) {
+                    FARF(ERROR, "Bad act-req buffer list");
+                    continue;
+                }
+                proc_activations_req(ctx, &req, bufs, n_bufs);
+                break;
+
+            case HTP_OP_ADD_ID:
+                if (n_bufs != 4) {
+                    FARF(ERROR, "Bad add-id-req buffer list");
+                    continue;
+                }
+                proc_add_id_req(ctx, &req, bufs);
+                break;
+
+            case HTP_OP_ROPE:
+                if ((n_bufs != 3) && (n_bufs != 4)) {
+                    FARF(ERROR, "Bad rope-req buffer list");
+                    continue;
+                }
+                proc_rope_req(ctx, &req, bufs, n_bufs);
+                break;
+
+            case HTP_OP_FLASH_ATTN_EXT:
+                if (!(n_bufs >= 4 && n_bufs <= 6)) {
+                    FARF(ERROR, "Bad flash-attn-ext-req buffer list");
+                    continue;
+                }
+                proc_flash_attn_ext_req(ctx, &req, bufs, n_bufs);
+
break; + + case HTP_OP_SET_ROWS: + if (n_bufs != 3) { + FARF(ERROR, "Bad set-rows-req buffer list"); + continue; + } + proc_set_rows_req(ctx, &req, bufs); + break; + + case HTP_OP_GET_ROWS: + if (n_bufs != 3) { + FARF(ERROR, "Bad get-rows-req buffer list"); + continue; + } + proc_get_rows_req(ctx, &req, bufs); + break; + + case HTP_OP_CPY: + if (n_bufs != 2) { + FARF(ERROR, "Bad cpy-req buffer list"); + continue; + } + proc_cpy_req(ctx, &req, bufs); + break; + + case HTP_OP_ARGSORT: + if (n_bufs != 2) { + FARF(ERROR, "Bad argsort-req buffer list"); + continue; + } + proc_argsort_req(ctx, &req, bufs); + break; + + default: + FARF(ERROR, "Unknown Op %u", req.op); + break; + } + } +} diff --git a/llama.cpp/ggml/src/ggml-hexagon/htp/matmul-ops.c b/llama.cpp/ggml/src/ggml-hexagon/htp/matmul-ops.c new file mode 100644 index 0000000..c360abe --- /dev/null +++ b/llama.cpp/ggml/src/ggml-hexagon/htp/matmul-ops.c @@ -0,0 +1,2665 @@ +#pragma clang diagnostic ignored "-Wgnu-zero-variadic-macro-arguments" +#pragma clang diagnostic ignored "-Wunused-function" +#pragma clang diagnostic ignored "-Wunused-variable" +#pragma clang diagnostic ignored "-Wunused-but-set-variable" + +#include <HAP_farf.h> +#include <HAP_perf.h> + +#include <math.h> +#include <string.h> + +#include "hex-dma.h" +#include "hvx-utils.h" +#include "hvx-dump.h" + +#define GGML_COMMON_DECL_C +#include "ggml-common.h" +#include "htp-ctx.h" +#include "htp-msg.h" +#include "htp-ops.h" + +#define MM_SPAD_SRC0_NROWS 16 +#define MM_SPAD_SRC1_NROWS 16 +#define MM_SPAD_DST_NROWS 2 + +struct htp_matmul_context { + const char * type; + struct htp_ops_context * octx; + + void (*vec_dot_1x1)(const int n, float * restrict s0, + const void * restrict vx0, + const void * restrict vy0); + + void (*vec_dot_2x1)(const int n, float * restrict s0, + const void * restrict vx0, const void * restrict vx1, + const void * restrict vy0); + + void (*vec_dot_2x2)(const int n, float * restrict s0, float * restrict s1, + const void * restrict vx0, const void * restrict vx1, + const void * restrict vy0, const void * restrict vy1); + + // Precomputed values + uint32_t src0_nrows_per_thread; + uint32_t src1_nrows_per_thread; + + struct fastdiv_values mm_div_ne12_ne1; + struct fastdiv_values mm_div_ne1; + struct fastdiv_values mm_div_r2; + struct fastdiv_values mm_div_r3; +}; + +// vdelta control to replicate first 4x fp32 values across lanes +static const uint8_t __attribute__((aligned(128))) repl_4x_f32[128] = { + 0x00, 0x00, 0x00, 0x00, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, 0x10, 0x10, 0x10, + 0x10, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, 0x04, 0x04, 0x04, 0x04, 0x20, 0x20, + 0x20, 0x20, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, 0x10, 0x10, 0x10, 0x10, 0x04, + 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, 0x40, 0x40, 0x40, 0x40, + 0x44, 0x44, 0x44, 0x44, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, 0x10, 0x10, 0x10, 0x10, 0x04, 0x04, 0x04, + 0x04, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, 0x20, 0x20, 0x20, 0x20, 0x04, 0x04, + 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, 0x10, 0x10, 0x10, 0x10, +}; + +// vdelta control to replicate and interleave first 8x fp32 values across lanes +static const uint8_t __attribute__((aligned(128))) repl_interleave_8x_f32[128] = { + 0x00, 0x00, 0x00, 0x00, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, 0x00, 0x00, 0x00, + 
0x00, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, 0x04, 0x04, 0x04, 0x04, 0x20, 0x20, + 0x20, 0x20, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, 0x20, 0x20, 0x20, 0x20, 0x04, + 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, 0x40, 0x40, 0x40, 0x40, + 0x44, 0x44, 0x44, 0x44, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, 0x40, 0x40, 0x40, 0x40, 0x44, 0x44, 0x44, + 0x44, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, 0x20, 0x20, 0x20, 0x20, 0x04, 0x04, + 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, 0x20, 0x20, 0x20, 0x20, +}; + +// vdelta control to replicate first fp32 value across all elements +static const uint8_t __attribute__((aligned(128))) repl_1x_f32[128] = { + 0x00, 0x00, 0x00, 0x00, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, 0x10, 0x10, 0x10, + 0x10, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, 0x20, 0x20, 0x20, 0x20, 0x04, 0x04, + 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, 0x10, 0x10, 0x10, 0x10, 0x04, 0x04, 0x04, 0x04, 0x08, + 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, 0x40, 0x40, 0x40, 0x40, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, + 0x04, 0x04, 0x04, 0x04, 0x10, 0x10, 0x10, 0x10, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, + 0x04, 0x20, 0x20, 0x20, 0x20, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, 0x10, 0x10, + 0x10, 0x10, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, +}; + +// vdelta control to replicate first fp16 value across all elements +static const uint8_t __attribute__((aligned(128))) repl_1x_f16[128] = { + 0x00, 0x00, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02, 0x08, 0x08, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02, 0x10, 0x10, 0x02, + 0x02, 0x04, 0x04, 0x02, 0x02, 0x08, 0x08, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02, 0x20, 0x20, 0x02, 0x02, 0x04, 0x04, + 0x02, 0x02, 0x08, 0x08, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02, 0x10, 0x10, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02, 0x08, + 0x08, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02, 0x40, 0x40, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02, 0x08, 0x08, 0x02, 0x02, + 0x04, 0x04, 0x02, 0x02, 0x10, 0x10, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02, 0x08, 0x08, 0x02, 0x02, 0x04, 0x04, 0x02, + 0x02, 0x20, 0x20, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02, 0x08, 0x08, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02, 0x10, 0x10, + 0x02, 0x02, 0x04, 0x04, 0x02, 0x02, 0x08, 0x08, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02, +}; + +// vdelta control to replicate first fp16 value across all elements +static const uint8_t __attribute__((aligned(128))) repl_2x_f16[128] = { + 0x00, 0x00, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02, 0x08, 0x08, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02, + 0x10, 0x10, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02, 0x08, 0x08, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02, + 0x20, 0x20, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02, 0x08, 0x08, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02, + 0x10, 0x10, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02, 0x08, 0x08, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02, + 0x00, 0x00, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02, 0x08, 0x08, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02, + 0x10, 0x10, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02, 0x08, 0x08, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02, + 0x20, 0x20, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02, 0x08, 0x08, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02, + 0x10, 0x10, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02, 0x08, 0x08, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02, +}; + +// vdelta control to expand first 32 e8m0 values into 32 uint32 elements +static const 
uint8_t __attribute__((aligned(128))) expand_x32_e8m0[128] = { + 0x00, 0x00, 0x00, 0x00, 0x01, 0x04, 0x00, 0x00, 0x02, 0x00, 0x08, 0x08, 0x01, 0x02, 0x00, 0x04, 0x04, 0x00, 0x00, + 0x00, 0x11, 0x10, 0x10, 0x10, 0x02, 0x00, 0x04, 0x00, 0x01, 0x02, 0x08, 0x08, 0x08, 0x08, 0x00, 0x00, 0x01, 0x04, + 0x00, 0x00, 0x22, 0x20, 0x20, 0x20, 0x21, 0x22, 0x20, 0x24, 0x04, 0x00, 0x00, 0x00, 0x09, 0x08, 0x00, 0x00, 0x02, + 0x00, 0x04, 0x00, 0x11, 0x12, 0x10, 0x10, 0x10, 0x10, 0x10, 0x10, 0x01, 0x04, 0x00, 0x00, 0x02, 0x00, 0x08, 0x08, + 0x01, 0x02, 0x00, 0x04, 0x44, 0x40, 0x40, 0x40, 0x41, 0x40, 0x40, 0x40, 0x42, 0x40, 0x44, 0x40, 0x41, 0x42, 0x48, + 0x48, 0x08, 0x08, 0x00, 0x00, 0x01, 0x04, 0x00, 0x00, 0x12, 0x10, 0x10, 0x10, 0x01, 0x02, 0x00, 0x04, 0x04, 0x00, + 0x00, 0x00, 0x09, 0x08, 0x00, 0x00, 0x22, 0x20, 0x24, 0x20, 0x21, 0x22, 0x20, 0x20, +}; + +static const uint8_t __attribute__((aligned(VLEN))) kvalues_mxfp4_lut[] = { + 0, 0, 1, 0, 2, 0, 3, 0, 4, 0, 6, 0, 8, 0, 12, 0, 0, 0, 0xff, 0, 0xfe, 0, 0xfd, 0, 0xfc, 0, + 0xfa, 0, 0xf8, 0, 0xf4, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, +}; + +// q4x4x2 and q8x4x2 are the flat q4/8_0 formats where all quants are stored first followed by all scales + +static inline size_t q8x4x2_row_size(uint32_t ne) { + // ensures perfect alignment of quants and full row + const uint32_t qk = QK_Q8_0x4x2; + const uint32_t nb = (ne + qk - 1) / qk; + return hex_round_up(ne + nb * 8 * sizeof(__fp16), 128); +} + +static inline HVX_Vector_x8 hvx_vec_load_q4x4x8(const uint8_t * restrict ptr) { + const HVX_Vector * restrict vptr = (const HVX_Vector *) ptr; + + HVX_Vector v0_1 = vptr[0]; // first 256 elements (128 bytes) + HVX_Vector v2_3 = vptr[1]; // ... + HVX_Vector v4_5 = vptr[2]; // ... + HVX_Vector v6_7 = vptr[3]; // ... + + const HVX_Vector mask_h4 = Q6_Vb_vsplat_R(0x0F); + const HVX_Vector i8 = Q6_Vb_vsplat_R(8); + + HVX_Vector v0 = Q6_V_vand_VV(v0_1, mask_h4); // & 0x0F + HVX_Vector v1 = Q6_Vub_vlsr_VubR(v0_1, 4); // >> 4 + HVX_Vector v2 = Q6_V_vand_VV(v2_3, mask_h4); // & 0x0F + HVX_Vector v3 = Q6_Vub_vlsr_VubR(v2_3, 4); // >> 4 + HVX_Vector v4 = Q6_V_vand_VV(v4_5, mask_h4); // & 0x0F + HVX_Vector v5 = Q6_Vub_vlsr_VubR(v4_5, 4); // >> 4 + HVX_Vector v6 = Q6_V_vand_VV(v6_7, mask_h4); // & 0x0F + HVX_Vector v7 = Q6_Vub_vlsr_VubR(v6_7, 4); // >> 4 + + // Convert uint4 to int4 (i.e. x - 8) + v0 = Q6_Vb_vsub_VbVb(v0, i8); + v1 = Q6_Vb_vsub_VbVb(v1, i8); + v2 = Q6_Vb_vsub_VbVb(v2, i8); + v3 = Q6_Vb_vsub_VbVb(v3, i8); + v4 = Q6_Vb_vsub_VbVb(v4, i8); + v5 = Q6_Vb_vsub_VbVb(v5, i8); + v6 = Q6_Vb_vsub_VbVb(v6, i8); + v7 = Q6_Vb_vsub_VbVb(v7, i8); + + HVX_Vector_x8 r = { v0, v1, v2, v3, v4, v5, v6, v7 }; + return r; +} + +static inline HVX_Vector_x8 hvx_vec_load_mxfp4x4x8(const uint8_t * restrict ptr) { + const HVX_Vector * restrict vptr = (const HVX_Vector *) ptr; + + HVX_Vector v0_1 = vptr[0]; // first 256 elements (128 bytes) + HVX_Vector v2_3 = vptr[1]; // ... + HVX_Vector v4_5 = vptr[2]; // ... + HVX_Vector v6_7 = vptr[3]; // ... 
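+    // Decode note: each nibble below is used as an index into kvalues_mxfp4_lut
+    // above, which holds the mxfp4 code points {0, 1, 2, 3, 4, 6, 8, 12} and
+    // their negations as int8 (e.g. index 0x5 -> 6, index 0x9 -> 0xff = -1).
+    // These appear to be the e2m1 values {0, 0.5, ..., 6} pre-scaled by 2,
+    // with the factor of 2 compensated in the block scale elsewhere.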
+ + const HVX_Vector mask_h4 = Q6_Vb_vsplat_R(0x0F); + const HVX_Vector lut = *(const HVX_Vector *) kvalues_mxfp4_lut; + + HVX_Vector v0 = Q6_V_vand_VV(v0_1, mask_h4); // & 0x0F + HVX_Vector v1 = Q6_Vub_vlsr_VubR(v0_1, 4); // >> 4 + HVX_Vector v2 = Q6_V_vand_VV(v2_3, mask_h4); // & 0x0F + HVX_Vector v3 = Q6_Vub_vlsr_VubR(v2_3, 4); // >> 4 + HVX_Vector v4 = Q6_V_vand_VV(v4_5, mask_h4); // & 0x0F + HVX_Vector v5 = Q6_Vub_vlsr_VubR(v4_5, 4); // >> 4 + HVX_Vector v6 = Q6_V_vand_VV(v6_7, mask_h4); // & 0x0F + HVX_Vector v7 = Q6_Vub_vlsr_VubR(v6_7, 4); // >> 4 + + v0 = Q6_Vb_vlut32_VbVbI(v0, lut, 0); + v1 = Q6_Vb_vlut32_VbVbI(v1, lut, 0); + v2 = Q6_Vb_vlut32_VbVbI(v2, lut, 0); + v3 = Q6_Vb_vlut32_VbVbI(v3, lut, 0); + v4 = Q6_Vb_vlut32_VbVbI(v4, lut, 0); + v5 = Q6_Vb_vlut32_VbVbI(v5, lut, 0); + v6 = Q6_Vb_vlut32_VbVbI(v6, lut, 0); + v7 = Q6_Vb_vlut32_VbVbI(v7, lut, 0); + + HVX_Vector_x8 r = { v0, v1, v2, v3, v4, v5, v6, v7 }; + return r; +} + +static inline HVX_Vector_x8 hvx_vec_load_q8x4x8(const uint8_t * restrict ptr) { + const HVX_Vector * restrict vptr = (const HVX_Vector *) ptr; + + HVX_Vector v0 = vptr[0]; // first 128 vals + HVX_Vector v1 = vptr[1]; // ... + HVX_Vector v2 = vptr[2]; // ... + HVX_Vector v3 = vptr[3]; // ... + HVX_Vector v4 = vptr[4]; // ... + HVX_Vector v5 = vptr[5]; // ... + HVX_Vector v6 = vptr[6]; // ... + HVX_Vector v7 = vptr[7]; // ... + + HVX_Vector_x8 r = { v0, v1, v2, v3, v4, v5, v6, v7 }; + return r; +} + +// Reduce multiply 1024 x 1024 int8 elements (32x q4/8 blocks in 8x HVX vectors). +// Accumulate each block into a single int32 value. +// Return a single HVX vector with 32x int32 accumulators. +// This version is parameterized to support less than 1024 elements. +// if() checks are optimized out at compile time -- make sure to pass N as a constexpr. 
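+//
+// Scalar model of what hvx_vec_rmpy_x8_n computes (a reference sketch, not
+// compiled here): output lane b accumulates one 32-element block,
+//
+//   for (int b = 0; b < n / 32; b++) {
+//       int32_t acc = 0;
+//       for (int j = 0; j < 32; j++) {
+//           acc += (int32_t) (int8_t) x[b * 32 + j] * (int8_t) y[b * 32 + j];
+//       }
+//       out[b] = acc;  // one of the 32 int32 lanes in the returned vector
+//   }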
+ +static inline HVX_Vector hvx_vec_rmpy_x8_n(HVX_Vector_x8 x, HVX_Vector_x8 y, unsigned int n) { + HVX_Vector r0 = Q6_V_vsplat_R(0); + HVX_Vector r1 = Q6_V_vsplat_R(0); + HVX_Vector r2 = Q6_V_vsplat_R(0); + HVX_Vector r3 = Q6_V_vsplat_R(0); + HVX_Vector r4 = Q6_V_vsplat_R(0); + HVX_Vector r5 = Q6_V_vsplat_R(0); + HVX_Vector r6 = Q6_V_vsplat_R(0); + HVX_Vector r7 = Q6_V_vsplat_R(0); + + HVX_VectorPair p3; + HVX_VectorPair p2; + HVX_VectorPair p1; + HVX_VectorPair p0; + + if (n >= 128) { r0 = Q6_Vw_vrmpy_VbVb(x.v[0], y.v[0]); } + if (n >= 256) { r1 = Q6_Vw_vrmpy_VbVb(x.v[1], y.v[1]); } + if (n >= 384) { r2 = Q6_Vw_vrmpy_VbVb(x.v[2], y.v[2]); } + if (n >= 512) { r3 = Q6_Vw_vrmpy_VbVb(x.v[3], y.v[3]); } + if (n >= 640) { r4 = Q6_Vw_vrmpy_VbVb(x.v[4], y.v[4]); } + if (n >= 768) { r5 = Q6_Vw_vrmpy_VbVb(x.v[5], y.v[5]); } + if (n >= 896) { r6 = Q6_Vw_vrmpy_VbVb(x.v[6], y.v[6]); } + if (n >= 1024) { r7 = Q6_Vw_vrmpy_VbVb(x.v[7], y.v[7]); } + + if (n >= 128) { p0 = Q6_W_vdeal_VVR(r1, r0, -4); } + if (n >= 384) { p1 = Q6_W_vdeal_VVR(r3, r2, -4); } + if (n >= 640) { p2 = Q6_W_vdeal_VVR(r5, r4, -4); } + if (n >= 896) { p3 = Q6_W_vdeal_VVR(r7, r6, -4); } + + if (n >= 128) { r0 = Q6_Vw_vadd_VwVw(Q6_V_lo_W(p0), Q6_V_hi_W(p0)); } + if (n >= 384) { r1 = Q6_Vw_vadd_VwVw(Q6_V_lo_W(p1), Q6_V_hi_W(p1)); } + if (n >= 640) { r2 = Q6_Vw_vadd_VwVw(Q6_V_lo_W(p2), Q6_V_hi_W(p2)); } + if (n >= 896) { r3 = Q6_Vw_vadd_VwVw(Q6_V_lo_W(p3), Q6_V_hi_W(p3)); } + + if (n >= 128) { p0 = Q6_W_vdeal_VVR(r1, r0, -4); } + if (n >= 640) { p1 = Q6_W_vdeal_VVR(r3, r2, -4); } + + if (n >= 128) { r0 = Q6_Vw_vadd_VwVw(Q6_V_lo_W(p0), Q6_V_hi_W(p0)); } + if (n >= 640) { r1 = Q6_Vw_vadd_VwVw(Q6_V_lo_W(p1), Q6_V_hi_W(p1)); } + + if (n >= 128) { p0 = Q6_W_vdeal_VVR(r1, r0, -4); } + if (n >= 128) { r0 = Q6_Vw_vadd_VwVw(Q6_V_lo_W(p0), Q6_V_hi_W(p0)); } + + return r0; +} + +static inline HVX_Vector hvx_vec_rmpy_x8_full(HVX_Vector_x8 x, HVX_Vector_x8 y) { + return hvx_vec_rmpy_x8_n(x, y, 1024); +} + +// Handle most common cases of tensors not multiple of 1024. +static inline HVX_Vector hvx_vec_rmpy_x8_nloe(HVX_Vector_x8 x, HVX_Vector_x8 y, unsigned int n) { + if (n <= 256) { return hvx_vec_rmpy_x8_n(x, y, 256); }; + if (n <= 512) { return hvx_vec_rmpy_x8_n(x, y, 512); }; + if (n <= 768) { return hvx_vec_rmpy_x8_n(x, y, 768); }; + return hvx_vec_rmpy_x8_n(x, y, 1024); +} + +static void vec_dot_q4x4x2_q8x4x2_1x1(const int n, float * restrict s0, const void * restrict vx0, const void * restrict vy0) { + assert(n % 32 == 0); // min sub-block size + assert((unsigned long) vx0 % 128 == 0); + assert((unsigned long) vy0 % 128 == 0); + + const uint32_t qk = QK_Q4_0x4x2 * 4; + + const uint32_t x_dblk_size = 8 * 4 * 2; // 32x __fp16 + const uint32_t x_qblk_size = qk / 2; // int4 + const uint32_t x_qrow_size = n / 2; // int4 (not padded) + + const uint32_t y_dblk_size = 8 * 4 * 2; // 32x __fp16 + const uint32_t y_qblk_size = qk; // int8 + const uint32_t y_qrow_size = n; // int8 (not padded) + + const uint8_t * restrict r0_x_q = ((const uint8_t *) vx0 + 0); // quants first + const uint8_t * restrict r0_x_d = ((const uint8_t *) vx0 + x_qrow_size); // then scales + + const uint8_t * restrict y_q = ((const uint8_t *) vy0 + 0); // quants first + const uint8_t * restrict y_d = ((const uint8_t *) vy0 + y_qrow_size); // then scales + + // Row sum (sf) + HVX_Vector r0_sum = Q6_V_vsplat_R(0); + + // Multiply and accumulate into int32. + // Compute combined scale (fp32). + // Apply scale to acc and accumulate into the row sum (qf32). 
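+    //
+    // Layout arithmetic, assuming qk = 1024 (32 blocks of 32 elements, as the
+    // 8x128-byte loads imply): e.g. for n = 4096 the q4x4x2 row is 4096/2 =
+    // 2048 bytes of nibbles followed by 4 * 64 = 256 bytes of fp16 scales,
+    // while the q8x4x2 column is 4096 bytes of int8 followed by the same
+    // 256 bytes of scales.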
+
+    const uint32_t nb = n / qk;   // num full blocks
+    const uint32_t nloe = n % qk; // num leftover elements
+
+    uint32_t i = 0;
+    for (; i < nb; i++) {
+        HVX_Vector_x8 vy_q = hvx_vec_load_q8x4x8(y_q + i * y_qblk_size);
+        HVX_Vector_x8 r0_q = hvx_vec_load_q4x4x8(r0_x_q + i * x_qblk_size);
+
+        HVX_Vector r0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r0_q, vy_q));
+
+        HVX_Vector vy_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y_d + i * y_dblk_size));
+        HVX_Vector r0_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r0_x_d + i * x_dblk_size));
+
+        HVX_Vector r0_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_d, vy_d)));
+
+        HVX_Vector r0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_ia, r0_dd);
+
+        r0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_fa, r0_sum));
+    }
+
+    // Process leftovers; we still load a full 4x4x2 block but zero out unused scales/blocks
+    if (nloe) {
+        HVX_Vector_x8 vy_q = hvx_vec_load_q8x4x8(y_q + i * y_qblk_size);
+        HVX_Vector_x8 r0_q = hvx_vec_load_q4x4x8(r0_x_q + i * x_qblk_size);
+
+        HVX_Vector r0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_nloe(r0_q, vy_q, nloe));
+
+        HVX_Vector vy_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y_d + i * y_dblk_size));
+        HVX_Vector r0_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r0_x_d + i * x_dblk_size));
+
+        HVX_Vector r0_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_d, vy_d)));
+
+        // Zero out unused scales
+        HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe / 8);
+        r0_dd = Q6_V_vand_QV(bmask, r0_dd);
+        r0_ia = Q6_V_vand_QV(bmask, r0_ia);
+
+        HVX_Vector r0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_ia, r0_dd);
+
+        r0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_fa, r0_sum));
+    }
+
+    r0_sum = hvx_vec_reduce_sum_f32(r0_sum);
+
+    hvx_vec_store_u(s0, 4, r0_sum);
+}
+
+static void vec_dot_q4x4x2_q8x4x2_2x1(const int n, float * restrict s0,
+                                      const void * restrict vx0, const void * restrict vx1,
+                                      const void * restrict vy0) {
+    assert(n % 32 == 0); // min sub-block size
+    assert((unsigned long) vx0 % 128 == 0);
+    assert((unsigned long) vx1 % 128 == 0);
+    assert((unsigned long) vy0 % 128 == 0);
+
+    const uint32_t qk = QK_Q4_0x4x2 * 4;
+
+    const uint32_t x_dblk_size = 8 * 4 * 2; // 32x __fp16
+    const uint32_t x_qblk_size = qk / 2;    // int4
+    const uint32_t x_qrow_size = n / 2;     // int4 (not padded)
+
+    const uint32_t y_dblk_size = 8 * 4 * 2; // 32x __fp16
+    const uint32_t y_qblk_size = qk;        // int8
+    const uint32_t y_qrow_size = n;         // int8 (not padded)
+
+    const uint8_t * restrict r0_x_q = ((const uint8_t *) vx0) + 0;           // quants first
+    const uint8_t * restrict r0_x_d = ((const uint8_t *) vx0) + x_qrow_size; // then scales
+    const uint8_t * restrict r1_x_q = ((const uint8_t *) vx1) + 0;           // quants first
+    const uint8_t * restrict r1_x_d = ((const uint8_t *) vx1) + x_qrow_size; // then scales
+
+    const uint8_t * restrict y_q = ((const uint8_t *) vy0 + 0);           // quants first
+    const uint8_t * restrict y_d = ((const uint8_t *) vy0 + y_qrow_size); // then scales
+
+    // Row sum (sf)
+    HVX_Vector r0_sum = Q6_V_vsplat_R(0);
+    HVX_Vector r1_sum = Q6_V_vsplat_R(0);
+
+    // Multiply and accumulate into int32.
+    // Compute combined scale (fp32).
+    // Apply scale to acc and accumulate into the row sum (qf32).
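+    //
+    // The 2x1 variant exists to amortize src1 loads: vy_q/vy_d are fetched
+    // once per superblock and dotted against two src0 rows. Hypothetical call
+    // pattern (row0/row1 are two 128-byte-aligned q4x4x2 rows):
+    //
+    //   float s[2];
+    //   vec_dot_q4x4x2_q8x4x2_2x1(n, s, row0, row1, col0);
+    //   // s[0] = dot(row0, col0), s[1] = dot(row1, col0)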
+
+    const uint32_t nb = n / qk;   // num full blocks
+    const uint32_t nloe = n % qk; // num leftover elements
+
+    uint32_t i = 0;
+    for (; i < nb; i++) {
+        HVX_Vector_x8 vy_q = hvx_vec_load_q8x4x8(y_q + i * y_qblk_size);
+        HVX_Vector_x8 r0_q = hvx_vec_load_q4x4x8(r0_x_q + i * x_qblk_size);
+        HVX_Vector_x8 r1_q = hvx_vec_load_q4x4x8(r1_x_q + i * x_qblk_size);
+
+        HVX_Vector r0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r0_q, vy_q));
+        HVX_Vector r1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r1_q, vy_q));
+
+        HVX_Vector vy_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y_d + i * y_dblk_size));
+        HVX_Vector r0_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r0_x_d + i * x_dblk_size));
+        HVX_Vector r1_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r1_x_d + i * x_dblk_size));
+
+        HVX_Vector r0_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_d, vy_d)));
+        HVX_Vector r1_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r1_d, vy_d)));
+
+        HVX_Vector r0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_ia, r0_dd);
+        HVX_Vector r1_fa = Q6_Vqf32_vmpy_VsfVsf(r1_ia, r1_dd);
+
+        r0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_fa, r0_sum));
+        r1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_fa, r1_sum));
+    }
+
+    // Process leftovers; we still load a full 4x4x2 block but zero out unused scales/blocks
+    if (nloe) {
+        HVX_Vector_x8 vy_q = hvx_vec_load_q8x4x8(y_q + i * y_qblk_size);
+        HVX_Vector_x8 r0_q = hvx_vec_load_q4x4x8(r0_x_q + i * x_qblk_size);
+        HVX_Vector_x8 r1_q = hvx_vec_load_q4x4x8(r1_x_q + i * x_qblk_size);
+
+        HVX_Vector r0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_nloe(r0_q, vy_q, nloe));
+        HVX_Vector r1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_nloe(r1_q, vy_q, nloe));
+
+        HVX_Vector vy_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y_d + i * y_dblk_size));
+        HVX_Vector r0_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r0_x_d + i * x_dblk_size));
+        HVX_Vector r1_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r1_x_d + i * x_dblk_size));
+
+        HVX_Vector r0_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_d, vy_d)));
+        HVX_Vector r1_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r1_d, vy_d)));
+
+        // Zero out unused scales
+        HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe / 8);
+        r0_dd = Q6_V_vand_QV(bmask, r0_dd);
+        r1_dd = Q6_V_vand_QV(bmask, r1_dd);
+        r0_ia = Q6_V_vand_QV(bmask, r0_ia);
+        r1_ia = Q6_V_vand_QV(bmask, r1_ia);
+
+        HVX_Vector r0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_ia, r0_dd);
+        HVX_Vector r1_fa = Q6_Vqf32_vmpy_VsfVsf(r1_ia, r1_dd);
+
+        r0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_fa, r0_sum));
+        r1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_fa, r1_sum));
+    }
+
+    HVX_Vector rsum = hvx_vec_reduce_sum_f32x2(r0_sum, r1_sum);
+    hvx_vec_store_u(s0, 8, rsum);
+}
+
+static void vec_dot_q4x4x2_q8x4x2_2x2(const int n, float * restrict s0, float * restrict s1,
+                                      const void * restrict vx0, const void * restrict vx1,
+                                      const void * restrict vy0, const void * restrict vy1) {
+    assert(n % 32 == 0);
+    assert((unsigned long) vx0 % 128 == 0);
+    assert((unsigned long) vx1 % 128 == 0);
+    assert((unsigned long) vy0 % 128 == 0);
+    assert((unsigned long) vy1 % 128 == 0);
+
+    const uint32_t qk = QK_Q4_0x4x2 * 4;
+
+    const uint32_t x_dblk_size = 8 * 4 * 2; // 32x __fp16
+    const uint32_t x_qblk_size = qk / 2;    // int4
+    const uint32_t x_qrow_size = n / 2;     // int4 (not padded)
+
+    const uint32_t y_dblk_size = 8 * 4 * 2; // 32x __fp16
+    const uint32_t y_qblk_size = qk;        // int8
+    const uint32_t y_qrow_size = n;         // int8 (not padded)
+
+    const uint8_t * restrict r0_x_q = ((const uint8_t *) vx0) + 0; //
quants first + const uint8_t * restrict r0_x_d = ((const uint8_t *) vx0) + x_qrow_size; // then scales + const uint8_t * restrict r1_x_q = ((const uint8_t *) vx1) + 0; // quants first + const uint8_t * restrict r1_x_d = ((const uint8_t *) vx1) + x_qrow_size; // then scales + + const uint8_t * restrict y0_q = ((const uint8_t *) vy0) + 0; // quants first + const uint8_t * restrict y0_d = ((const uint8_t *) vy0) + y_qrow_size; // then scales + const uint8_t * restrict y1_q = ((const uint8_t *) vy1) + 0; // quants first + const uint8_t * restrict y1_d = ((const uint8_t *) vy1) + y_qrow_size; // then scales + + // Row sums (sf) - 4 accumulators for 2×2 tile + HVX_Vector r0_c0_sum = Q6_V_vsplat_R(0); + HVX_Vector r0_c1_sum = Q6_V_vsplat_R(0); + HVX_Vector r1_c0_sum = Q6_V_vsplat_R(0); + HVX_Vector r1_c1_sum = Q6_V_vsplat_R(0); + + const uint32_t nb = n / qk; // num full blocks + const uint32_t nloe = n % qk; // num leftover elements + + uint32_t i = 0; + for (; i < nb; i++) { + // Load src1 columns (reused across both src0 rows) + HVX_Vector_x8 vy0_q = hvx_vec_load_q8x4x8(y0_q + i * y_qblk_size); + HVX_Vector_x8 vy1_q = hvx_vec_load_q8x4x8(y1_q + i * y_qblk_size); + + // Load src0 rows (reused across both src1 columns) + HVX_Vector_x8 r0_q = hvx_vec_load_q4x4x8(r0_x_q + i * x_qblk_size); + HVX_Vector_x8 r1_q = hvx_vec_load_q4x4x8(r1_x_q + i * x_qblk_size); + + // Compute 4 dot products: r0×c0, r0×c1, r1×c0, r1×c1 + HVX_Vector r0_c0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r0_q, vy0_q)); + HVX_Vector r0_c1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r0_q, vy1_q)); + HVX_Vector r1_c0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r1_q, vy0_q)); + HVX_Vector r1_c1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r1_q, vy1_q)); + + // Load scales + HVX_Vector vy0_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y0_d + i * y_dblk_size)); + HVX_Vector vy1_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y1_d + i * y_dblk_size)); + HVX_Vector r0_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r0_x_d + i * x_dblk_size)); + HVX_Vector r1_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r1_x_d + i * x_dblk_size)); + + // Compute combined scales + HVX_Vector r0_c0_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_d, vy0_d))); + HVX_Vector r0_c1_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_d, vy1_d))); + HVX_Vector r1_c0_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r1_d, vy0_d))); + HVX_Vector r1_c1_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r1_d, vy1_d))); + + // Apply scales and accumulate + HVX_Vector r0_c0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_c0_ia, r0_c0_dd); + HVX_Vector r0_c1_fa = Q6_Vqf32_vmpy_VsfVsf(r0_c1_ia, r0_c1_dd); + HVX_Vector r1_c0_fa = Q6_Vqf32_vmpy_VsfVsf(r1_c0_ia, r1_c0_dd); + HVX_Vector r1_c1_fa = Q6_Vqf32_vmpy_VsfVsf(r1_c1_ia, r1_c1_dd); + + r0_c0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_c0_fa, r0_c0_sum)); + r0_c1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_c1_fa, r0_c1_sum)); + r1_c0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_c0_fa, r1_c0_sum)); + r1_c1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_c1_fa, r1_c1_sum)); + } + + // Process leftovers + if (nloe) { + HVX_Vector_x8 vy0_q = hvx_vec_load_q8x4x8(y0_q + i * y_qblk_size); + HVX_Vector_x8 vy1_q = hvx_vec_load_q8x4x8(y1_q + i * y_qblk_size); + HVX_Vector_x8 r0_q = hvx_vec_load_q4x4x8(r0_x_q + i * x_qblk_size); + HVX_Vector_x8 r1_q = hvx_vec_load_q4x4x8(r1_x_q + i * x_qblk_size); + + HVX_Vector r0_c0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_nloe(r0_q, vy0_q, nloe)); + HVX_Vector r0_c1_ia = 
Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_nloe(r0_q, vy1_q, nloe)); + HVX_Vector r1_c0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_nloe(r1_q, vy0_q, nloe)); + HVX_Vector r1_c1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_nloe(r1_q, vy1_q, nloe)); + + HVX_Vector vy0_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y0_d + i * y_dblk_size)); + HVX_Vector vy1_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y1_d + i * y_dblk_size)); + HVX_Vector r0_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r0_x_d + i * x_dblk_size)); + HVX_Vector r1_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r1_x_d + i * x_dblk_size)); + + HVX_Vector r0_c0_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_d, vy0_d))); + HVX_Vector r0_c1_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_d, vy1_d))); + HVX_Vector r1_c0_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r1_d, vy0_d))); + HVX_Vector r1_c1_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r1_d, vy1_d))); + + // Zero out unused scales + HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe / 8); + r0_c0_dd = Q6_V_vand_QV(bmask, r0_c0_dd); + r0_c1_dd = Q6_V_vand_QV(bmask, r0_c1_dd); + r1_c0_dd = Q6_V_vand_QV(bmask, r1_c0_dd); + r1_c1_dd = Q6_V_vand_QV(bmask, r1_c1_dd); + r0_c0_ia = Q6_V_vand_QV(bmask, r0_c0_ia); + r0_c1_ia = Q6_V_vand_QV(bmask, r0_c1_ia); + r1_c0_ia = Q6_V_vand_QV(bmask, r1_c0_ia); + r1_c1_ia = Q6_V_vand_QV(bmask, r1_c1_ia); + + HVX_Vector r0_c0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_c0_ia, r0_c0_dd); + HVX_Vector r0_c1_fa = Q6_Vqf32_vmpy_VsfVsf(r0_c1_ia, r0_c1_dd); + HVX_Vector r1_c0_fa = Q6_Vqf32_vmpy_VsfVsf(r1_c0_ia, r1_c0_dd); + HVX_Vector r1_c1_fa = Q6_Vqf32_vmpy_VsfVsf(r1_c1_ia, r1_c1_dd); + + r0_c0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_c0_fa, r0_c0_sum)); + r0_c1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_c1_fa, r0_c1_sum)); + r1_c0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_c0_fa, r1_c0_sum)); + r1_c1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_c1_fa, r1_c1_sum)); + } + + // Reduce and store results + HVX_Vector r0_r1_c0_sum = hvx_vec_reduce_sum_f32x2(r0_c0_sum, r1_c0_sum); + HVX_Vector r0_r1_c1_sum = hvx_vec_reduce_sum_f32x2(r0_c1_sum, r1_c1_sum); + + hvx_vec_store_u(s0, 8, r0_r1_c0_sum); // row0,col0 row1,col0 + hvx_vec_store_u(s1, 8, r0_r1_c1_sum); // row0,col1 row1,col1 +} + +static void vec_dot_q8x4x2_q8x4x2_1x1(const int n, float * restrict s0, const void * restrict vx0, const void * restrict vy0) { + assert(n % 32 == 0); // min sub-block size + assert((unsigned long) vx0 % 128 == 0); + assert((unsigned long) vy0 % 128 == 0); + + const uint32_t qk = QK_Q4_0x4x2 * 4; + + const uint32_t x_dblk_size = 8 * 4 * 2; // 32x __fp16 + const uint32_t x_qblk_size = qk; // int8 + const uint32_t x_qrow_size = n; // int8 (not padded) + + const uint32_t y_dblk_size = 8 * 4 * 2; // 32x __fp16 + const uint32_t y_qblk_size = qk; // int8 + const uint32_t y_qrow_size = n; // int8 (not padded) + + const uint8_t * restrict r0_x_q = ((const uint8_t *) vx0 + 0); // quants first + const uint8_t * restrict r0_x_d = ((const uint8_t *) vx0 + x_qrow_size); // then scales + + const uint8_t * restrict y_q = ((const uint8_t *) vy0 + 0); // quants first + const uint8_t * restrict y_d = ((const uint8_t *) vy0 + y_qrow_size); // then scales + + // Row sum (sf) + HVX_Vector r0_sum = Q6_V_vsplat_R(0); + + // Multiply and accumulate into int32. + // Compute combined scale (fp32). + // Apply scale to acc and accumulate into the row sum (qf32). 
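+    //
+    // Accumulation idiom (the same in all kernels here): per fp32 lane this is
+    //   sum = (float) ((qf32) ia * dd + sum);
+    // where ia is the int32 block dot converted to float and dd = d_x * d_y.
+    // The qf32 intermediate keeps extra exponent range; it is normalized back
+    // to IEEE sf each iteration so hvx_vec_reduce_sum_f32 can consume it.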
+
+    const uint32_t nb = n / qk; // num full blocks
+    int32_t nloe = n % qk;      // num leftover elements (must be signed)
+
+    uint32_t i = 0;
+    for (; i < nb; i++) {
+        HVX_Vector_x8 vy_q = hvx_vec_load_q8x4x8(y_q + i * y_qblk_size);
+        HVX_Vector_x8 r0_q = hvx_vec_load_q8x4x8(r0_x_q + i * x_qblk_size);
+
+        HVX_Vector r0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r0_q, vy_q));
+
+        HVX_Vector vy_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y_d + i * y_dblk_size));
+        HVX_Vector r0_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r0_x_d + i * x_dblk_size));
+
+        HVX_Vector r0_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_d, vy_d)));
+
+        HVX_Vector r0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_ia, r0_dd);
+
+        r0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_fa, r0_sum));
+    }
+
+    // Process leftovers; we still load a full 4x4x2 block but zero out unused scales/blocks
+    if (nloe) {
+        HVX_Vector_x8 vy_q = hvx_vec_load_q8x4x8(y_q + i * y_qblk_size);
+        HVX_Vector_x8 r0_q = hvx_vec_load_q8x4x8(r0_x_q + i * x_qblk_size);
+
+        HVX_Vector r0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_nloe(r0_q, vy_q, nloe));
+
+        HVX_Vector vy_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y_d + i * y_dblk_size));
+        HVX_Vector r0_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r0_x_d + i * x_dblk_size));
+
+        HVX_Vector r0_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_d, vy_d)));
+
+        // Zero out unused scales
+        HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe / 8);
+        r0_dd = Q6_V_vand_QV(bmask, r0_dd);
+        r0_ia = Q6_V_vand_QV(bmask, r0_ia);
+
+        HVX_Vector r0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_ia, r0_dd);
+
+        r0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_fa, r0_sum));
+    }
+
+    r0_sum = hvx_vec_reduce_sum_f32(r0_sum);
+
+    hvx_vec_store_u(s0, 4, r0_sum);
+}
+
+static void vec_dot_q8x4x2_q8x4x2_2x1(const int n, float * restrict s0,
+                                      const void * restrict vx0, const void * restrict vx1,
+                                      const void * restrict vy0) {
+    assert(n % 32 == 0); // min sub-block size
+    assert((unsigned long) vx0 % 128 == 0);
+    assert((unsigned long) vx1 % 128 == 0);
+    assert((unsigned long) vy0 % 128 == 0);
+
+    const uint32_t qk = QK_Q4_0x4x2 * 4;
+
+    const uint32_t x_dblk_size = 8 * 4 * 2; // 32x __fp16
+    const uint32_t x_qblk_size = qk;        // int8
+    const uint32_t x_qrow_size = n;         // int8 (not padded)
+
+    const uint32_t y_dblk_size = 8 * 4 * 2; // 32x __fp16
+    const uint32_t y_qblk_size = qk;        // int8
+    const uint32_t y_qrow_size = n;         // int8 (not padded)
+
+    const uint8_t * restrict r0_x_q = ((const uint8_t *) vx0) + 0;           // quants first
+    const uint8_t * restrict r0_x_d = ((const uint8_t *) vx0) + x_qrow_size; // then scales
+    const uint8_t * restrict r1_x_q = ((const uint8_t *) vx1) + 0;           // quants first
+    const uint8_t * restrict r1_x_d = ((const uint8_t *) vx1) + x_qrow_size; // then scales
+
+    const uint8_t * restrict y_q = ((const uint8_t *) vy0 + 0);           // quants first
+    const uint8_t * restrict y_d = ((const uint8_t *) vy0 + y_qrow_size); // then scales
+
+    // Row sum (sf)
+    HVX_Vector r0_sum = Q6_V_vsplat_R(0);
+    HVX_Vector r1_sum = Q6_V_vsplat_R(0);
+
+    // Multiply and accumulate into int32.
+    // Compute combined scale (fp32).
+    // Apply scale to acc and accumulate into the row sum (qf32).
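+    //
+    // Leftover masking below: the accumulator holds one int32 lane per
+    // 32-element block, so nloe leftover elements cover nloe/32 lanes, i.e.
+    // (nloe / 32) * 4 = nloe / 8 bytes -- hence Q6_Q_vsetq_R(nloe / 8).
+    // E.g. nloe = 512 -> 16 live lanes -> the first 64 bytes are kept.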
+
+    const uint32_t nb = n / qk; // num full blocks
+    int32_t nloe = n % qk;      // num leftover elements (must be signed)
+
+    uint32_t i = 0;
+    for (; i < nb; i++) {
+        HVX_Vector_x8 vy_q = hvx_vec_load_q8x4x8(y_q + i * y_qblk_size);
+        HVX_Vector_x8 r0_q = hvx_vec_load_q8x4x8(r0_x_q + i * x_qblk_size);
+        HVX_Vector_x8 r1_q = hvx_vec_load_q8x4x8(r1_x_q + i * x_qblk_size);
+
+        HVX_Vector r0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r0_q, vy_q));
+        HVX_Vector r1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r1_q, vy_q));
+
+        HVX_Vector vy_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y_d + i * y_dblk_size));
+        HVX_Vector r0_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r0_x_d + i * x_dblk_size));
+        HVX_Vector r1_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r1_x_d + i * x_dblk_size));
+
+        HVX_Vector r0_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_d, vy_d)));
+        HVX_Vector r1_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r1_d, vy_d)));
+
+        HVX_Vector r0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_ia, r0_dd);
+        HVX_Vector r1_fa = Q6_Vqf32_vmpy_VsfVsf(r1_ia, r1_dd);
+
+        r0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_fa, r0_sum));
+        r1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_fa, r1_sum));
+    }
+
+    // Process leftovers; we still load a full 4x4x2 block but zero out unused scales/blocks
+    if (nloe) {
+        HVX_Vector_x8 vy_q = hvx_vec_load_q8x4x8(y_q + i * y_qblk_size);
+        HVX_Vector_x8 r0_q = hvx_vec_load_q8x4x8(r0_x_q + i * x_qblk_size);
+        HVX_Vector_x8 r1_q = hvx_vec_load_q8x4x8(r1_x_q + i * x_qblk_size);
+
+        HVX_Vector r0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_nloe(r0_q, vy_q, nloe));
+        HVX_Vector r1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_nloe(r1_q, vy_q, nloe));
+
+        HVX_Vector vy_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y_d + i * y_dblk_size));
+        HVX_Vector r0_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r0_x_d + i * x_dblk_size));
+        HVX_Vector r1_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r1_x_d + i * x_dblk_size));
+
+        HVX_Vector r0_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_d, vy_d)));
+        HVX_Vector r1_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r1_d, vy_d)));
+
+        // Zero out unused scales
+        HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe / 8);
+        r0_dd = Q6_V_vand_QV(bmask, r0_dd);
+        r1_dd = Q6_V_vand_QV(bmask, r1_dd);
+        r0_ia = Q6_V_vand_QV(bmask, r0_ia);
+        r1_ia = Q6_V_vand_QV(bmask, r1_ia);
+
+        HVX_Vector r0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_ia, r0_dd);
+        HVX_Vector r1_fa = Q6_Vqf32_vmpy_VsfVsf(r1_ia, r1_dd);
+
+        r0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_fa, r0_sum));
+        r1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_fa, r1_sum));
+    }
+
+    HVX_Vector rsum = hvx_vec_reduce_sum_f32x2(r0_sum, r1_sum);
+    hvx_vec_store_u(s0, 8, rsum);
+}
+
+static void vec_dot_q8x4x2_q8x4x2_2x2(const int n, float * restrict s0, float * restrict s1,
+                                      const void * restrict vx0, const void * restrict vx1,
+                                      const void * restrict vy0, const void * restrict vy1) {
+    assert(n % 32 == 0);
+    assert((unsigned long) vx0 % 128 == 0);
+    assert((unsigned long) vx1 % 128 == 0);
+    assert((unsigned long) vy0 % 128 == 0);
+    assert((unsigned long) vy1 % 128 == 0);
+
+    const uint32_t qk = QK_Q8_0x4x2 * 4;
+
+    const uint32_t x_dblk_size = 8 * 4 * 2; // 32x __fp16
+    const uint32_t x_qblk_size = qk;        // int8
+    const uint32_t x_qrow_size = n;         // int8 (not padded)
+
+    const uint32_t y_dblk_size = 8 * 4 * 2; // 32x __fp16
+    const uint32_t y_qblk_size = qk;        // int8
+    const uint32_t y_qrow_size = n;         // int8 (not padded)
+
+    const uint8_t * restrict r0_x_q = ((const uint8_t *) vx0) + 0; //
quants first + const uint8_t * restrict r0_x_d = ((const uint8_t *) vx0) + x_qrow_size; // then scales + const uint8_t * restrict r1_x_q = ((const uint8_t *) vx1) + 0; // quants first + const uint8_t * restrict r1_x_d = ((const uint8_t *) vx1) + x_qrow_size; // then scales + + const uint8_t * restrict y0_q = ((const uint8_t *) vy0) + 0; // quants first + const uint8_t * restrict y0_d = ((const uint8_t *) vy0) + y_qrow_size; // then scales + const uint8_t * restrict y1_q = ((const uint8_t *) vy1) + 0; // quants first + const uint8_t * restrict y1_d = ((const uint8_t *) vy1) + y_qrow_size; // then scales + + // Row sums (sf) - 4 accumulators for 2×2 tile + HVX_Vector r0_c0_sum = Q6_V_vsplat_R(0); + HVX_Vector r0_c1_sum = Q6_V_vsplat_R(0); + HVX_Vector r1_c0_sum = Q6_V_vsplat_R(0); + HVX_Vector r1_c1_sum = Q6_V_vsplat_R(0); + + const uint32_t nb = n / qk; // num full blocks + const uint32_t nloe = n % qk; // num leftover elements + + uint32_t i = 0; + for (; i < nb; i++) { + // Load src1 columns (reused across both src0 rows) + HVX_Vector_x8 vy0_q = hvx_vec_load_q8x4x8(y0_q + i * y_qblk_size); + HVX_Vector_x8 vy1_q = hvx_vec_load_q8x4x8(y1_q + i * y_qblk_size); + + // Load src0 rows (reused across both src1 columns) + HVX_Vector_x8 r0_q = hvx_vec_load_q8x4x8(r0_x_q + i * x_qblk_size); + HVX_Vector_x8 r1_q = hvx_vec_load_q8x4x8(r1_x_q + i * x_qblk_size); + + // Compute 4 dot products: r0×c0, r0×c1, r1×c0, r1×c1 + HVX_Vector r0_c0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r0_q, vy0_q)); + HVX_Vector r0_c1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r0_q, vy1_q)); + HVX_Vector r1_c0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r1_q, vy0_q)); + HVX_Vector r1_c1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r1_q, vy1_q)); + + // Load scales + HVX_Vector vy0_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y0_d + i * y_dblk_size)); + HVX_Vector vy1_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y1_d + i * y_dblk_size)); + HVX_Vector r0_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r0_x_d + i * x_dblk_size)); + HVX_Vector r1_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r1_x_d + i * x_dblk_size)); + + // Compute combined scales + HVX_Vector r0_c0_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_d, vy0_d))); + HVX_Vector r0_c1_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_d, vy1_d))); + HVX_Vector r1_c0_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r1_d, vy0_d))); + HVX_Vector r1_c1_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r1_d, vy1_d))); + + // Apply scales and accumulate + HVX_Vector r0_c0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_c0_ia, r0_c0_dd); + HVX_Vector r0_c1_fa = Q6_Vqf32_vmpy_VsfVsf(r0_c1_ia, r0_c1_dd); + HVX_Vector r1_c0_fa = Q6_Vqf32_vmpy_VsfVsf(r1_c0_ia, r1_c0_dd); + HVX_Vector r1_c1_fa = Q6_Vqf32_vmpy_VsfVsf(r1_c1_ia, r1_c1_dd); + + r0_c0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_c0_fa, r0_c0_sum)); + r0_c1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_c1_fa, r0_c1_sum)); + r1_c0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_c0_fa, r1_c0_sum)); + r1_c1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_c1_fa, r1_c1_sum)); + } + + // Process leftovers + if (nloe) { + HVX_Vector_x8 vy0_q = hvx_vec_load_q8x4x8(y0_q + i * y_qblk_size); + HVX_Vector_x8 vy1_q = hvx_vec_load_q8x4x8(y1_q + i * y_qblk_size); + HVX_Vector_x8 r0_q = hvx_vec_load_q8x4x8(r0_x_q + i * x_qblk_size); + HVX_Vector_x8 r1_q = hvx_vec_load_q8x4x8(r1_x_q + i * x_qblk_size); + + HVX_Vector r0_c0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_nloe(r0_q, vy0_q, nloe)); + HVX_Vector r0_c1_ia = 
Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_nloe(r0_q, vy1_q, nloe)); + HVX_Vector r1_c0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_nloe(r1_q, vy0_q, nloe)); + HVX_Vector r1_c1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_nloe(r1_q, vy1_q, nloe)); + + HVX_Vector vy0_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y0_d + i * y_dblk_size)); + HVX_Vector vy1_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y1_d + i * y_dblk_size)); + HVX_Vector r0_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r0_x_d + i * x_dblk_size)); + HVX_Vector r1_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r1_x_d + i * x_dblk_size)); + + HVX_Vector r0_c0_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_d, vy0_d))); + HVX_Vector r0_c1_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_d, vy1_d))); + HVX_Vector r1_c0_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r1_d, vy0_d))); + HVX_Vector r1_c1_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r1_d, vy1_d))); + + // Zero out unused scales + HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe / 8); + r0_c0_dd = Q6_V_vand_QV(bmask, r0_c0_dd); + r0_c1_dd = Q6_V_vand_QV(bmask, r0_c1_dd); + r1_c0_dd = Q6_V_vand_QV(bmask, r1_c0_dd); + r1_c1_dd = Q6_V_vand_QV(bmask, r1_c1_dd); + r0_c0_ia = Q6_V_vand_QV(bmask, r0_c0_ia); + r0_c1_ia = Q6_V_vand_QV(bmask, r0_c1_ia); + r1_c0_ia = Q6_V_vand_QV(bmask, r1_c0_ia); + r1_c1_ia = Q6_V_vand_QV(bmask, r1_c1_ia); + + HVX_Vector r0_c0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_c0_ia, r0_c0_dd); + HVX_Vector r0_c1_fa = Q6_Vqf32_vmpy_VsfVsf(r0_c1_ia, r0_c1_dd); + HVX_Vector r1_c0_fa = Q6_Vqf32_vmpy_VsfVsf(r1_c0_ia, r1_c0_dd); + HVX_Vector r1_c1_fa = Q6_Vqf32_vmpy_VsfVsf(r1_c1_ia, r1_c1_dd); + + r0_c0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_c0_fa, r0_c0_sum)); + r0_c1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_c1_fa, r0_c1_sum)); + r1_c0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_c0_fa, r1_c0_sum)); + r1_c1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_c1_fa, r1_c1_sum)); + } + + // Reduce and store results + HVX_Vector r0_r1_c0_sum = hvx_vec_reduce_sum_f32x2(r0_c0_sum, r1_c0_sum); + HVX_Vector r0_r1_c1_sum = hvx_vec_reduce_sum_f32x2(r0_c1_sum, r1_c1_sum); + + hvx_vec_store_u(&s0[0], 8, r0_r1_c0_sum); // row0,col0 row1,col0 + hvx_vec_store_u(&s1[0], 8, r0_r1_c1_sum); // row0,col1 row1,col1 +} + +static void vec_dot_mxfp4x4x2_q8x4x2_1x1(const int n, float * restrict s0, const void * restrict vx0, const void * restrict vy0) { + assert(n % 32 == 0); // min sub-block size + assert((unsigned long) vx0 % 128 == 0); + assert((unsigned long) vy0 % 128 == 0); + + const uint32_t qk = QK_MXFP4x4x2 * 4; + + const uint32_t x_dblk_size = 8 * 4 * 1; // 32x e8m0 + const uint32_t x_qblk_size = qk / 2; // fp4 + const uint32_t x_qrow_size = n / 2; // fp4 (not padded) + + const uint32_t y_dblk_size = 8 * 4 * 2; // 32x __fp16 + const uint32_t y_qblk_size = qk; // int8 + const uint32_t y_qrow_size = n; // int8 (not padded) + + const uint8_t * restrict r0_x_q = ((const uint8_t *) vx0 + 0); // quants first + const uint8_t * restrict r0_x_d = ((const uint8_t *) vx0 + x_qrow_size); // then scales + + const uint8_t * restrict y_q = ((const uint8_t *) vy0 + 0); // quants first + const uint8_t * restrict y_d = ((const uint8_t *) vy0 + y_qrow_size); // then scales + + // Row sum (sf) + HVX_Vector r0_sum = Q6_V_vsplat_R(0); + + // Multiply and accumulate into int32. + // Compute combined scale (fp32). + // Apply scale to acc and accumulate into the row sum (qf32). 
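+    // The e8m0 scales handled below are bare 8-bit biased exponents, so the
+    // conversion is just a shift into the fp32 exponent field. A scalar
+    // sketch of the same operation (hypothetical helper; as the FIXME below
+    // notes, e == 0 produces 0.0f here rather than 2^-127):
+    //
+    //   static inline float e8m0_to_f32(uint8_t e) {
+    //       uint32_t bits = (uint32_t) e << 23;  // exponent field, mantissa 0 -> 2^(e - 127)
+    //       float f;
+    //       memcpy(&f, &bits, sizeof(f));
+    //       return f;
+    //   }
+    //
+    // The 0.5 factor folded into the fp16 scales below presumably compensates
+    // for the doubled integer values used by the mxfp4 lookup.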
+
+    const uint32_t nb = n / qk;  // num full blocks
+    int32_t nloe = n % qk;       // num leftover elements (must be signed)
+
+    uint32_t i = 0;
+    for (; i < nb; i++) {
+        HVX_Vector_x8 vy_q = hvx_vec_load_q8x4x8(y_q + i * y_qblk_size);
+        HVX_Vector_x8 r0_q = hvx_vec_load_mxfp4x4x8(r0_x_q + i * x_qblk_size);
+
+        HVX_Vector r0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r0_q, vy_q));
+
+        HVX_Vector vy_d = *(const HVX_UVector *) (y_d + i * y_dblk_size);
+        HVX_Vector r0_d = *(const HVX_UVector *) (r0_x_d + i * x_dblk_size);
+
+        // Convert vy_d from fp16 to fp32 while applying 0.5 scaling which is used for e8m0 halving
+        HVX_Vector half = Q6_Vh_vsplat_R(0x3800);  // 0.5 in fp16
+        vy_d = Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(Q6_Vh_vshuff_Vh(vy_d), half));
+        vy_d = Q6_Vsf_equals_Vqf32(vy_d);
+
+        // Convert rX_d scales from e8m0 to fp32
+        // Expand and zero-pad 32x uint8 e8m0 values to uint32s : 0 0 0 0, 0 0 0 1, 0 0 0 2, ...
+        // Left shift with zero fill to create FP32
+        // FIXME: might need to handle zero as a special case (see ggml-cpu code)
+        HVX_Vector expand = *(const HVX_Vector *) expand_x32_e8m0;
+        HVX_Vector e8m0_mask = Q6_V_vsplat_R(0x000000ff);
+        r0_d = Q6_V_vdelta_VV(r0_d, expand);
+        r0_d = Q6_V_vand_VV(r0_d, e8m0_mask);
+        r0_d = Q6_Vw_vasl_VwR(r0_d, 23);
+
+        HVX_Vector r0_dd = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(r0_d, vy_d));
+
+        HVX_Vector r0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_ia, r0_dd);
+
+        r0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_fa, r0_sum));
+    }
+
+    // Process leftovers
+    if (nloe) {
+        HVX_Vector_x8 vy_q = hvx_vec_load_q8x4x8(y_q + i * y_qblk_size);
+        HVX_Vector_x8 r0_q = hvx_vec_load_mxfp4x4x8(r0_x_q + i * x_qblk_size);
+
+        HVX_Vector r0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r0_q, vy_q));
+
+        HVX_Vector vy_d = *(const HVX_UVector *) (y_d + i * y_dblk_size);
+        HVX_Vector r0_d = *(const HVX_UVector *) (r0_x_d + i * x_dblk_size);
+
+        // Convert vy_d from fp16 to fp32 while applying 0.5 scaling which is used for e8m0 halving
+        HVX_Vector half = Q6_Vh_vsplat_R(0x3800);  // 0.5 in fp16
+        vy_d = Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(Q6_Vh_vshuff_Vh(vy_d), half));
+        vy_d = Q6_Vsf_equals_Vqf32(vy_d);
+
+        // Convert rX_d scales from e8m0 to fp32
+        // Expand and zero-pad 32x uint8 e8m0 values to uint32s : 0 0 0 0, 0 0 0 1, 0 0 0 2, ...
+        // Left shift with zero fill to create FP32
+        // FIXME: might need to handle zero as a special case (see ggml-cpu code)
+        HVX_Vector expand = *(const HVX_Vector *) expand_x32_e8m0;
+        HVX_Vector e8m0_mask = Q6_V_vsplat_R(0x000000ff);
+        r0_d = Q6_V_vdelta_VV(r0_d, expand);
+        r0_d = Q6_V_vand_VV(r0_d, e8m0_mask);
+        r0_d = Q6_Vw_vasl_VwR(r0_d, 23);
+
+        HVX_Vector r0_dd = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(r0_d, vy_d));
+
+        // Zero-out unused scales
+        HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe / 8);
+        r0_dd = Q6_V_vand_QV(bmask, r0_dd);
+        r0_ia = Q6_V_vand_QV(bmask, r0_ia);
+
+        HVX_Vector r0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_ia, r0_dd);
+
+        r0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_fa, r0_sum));
+    }
+
+    r0_sum = hvx_vec_reduce_sum_f32(r0_sum);
+
+    hvx_vec_store_u(s0, 4, r0_sum);
+}
+
+static void vec_dot_mxfp4x4x2_q8x4x2_2x1(const int n, float * restrict s0,
+                                         const void * restrict vx0, const void * restrict vx1,
+                                         const void * restrict vy0) {
+    assert(n % 32 == 0);  // min sub-block size
+    assert((unsigned long) vx0 % 128 == 0);
+    assert((unsigned long) vx1 % 128 == 0);
+    assert((unsigned long) vy0 % 128 == 0);
+
+    const uint32_t qk = QK_MXFP4x4x2 * 4;
+
+    const uint32_t x_dblk_size = 8 * 4 * 1;  // 32x e8m0
+    const uint32_t x_qblk_size = qk / 2;     // fp4
+    const uint32_t x_qrow_size = n / 2;      // fp4 (not padded)
+
+    const uint32_t y_dblk_size = 8 * 4 * 2;  // 32x __fp16
+    const uint32_t y_qblk_size = qk;         // int8
+    const uint32_t y_qrow_size = n;          // int8 (not padded)
+
+    const uint8_t * restrict r0_x_q = ((const uint8_t *) vx0) + 0;            // quants first
+    const uint8_t * restrict r0_x_d = ((const uint8_t *) vx0) + x_qrow_size;  // then scales
+    const uint8_t * restrict r1_x_q = ((const uint8_t *) vx1) + 0;            // quants first
+    const uint8_t * restrict r1_x_d = ((const uint8_t *) vx1) + x_qrow_size;  // then scales
+
+    const uint8_t * restrict y_q = ((const uint8_t *) vy0) + 0;            // quants first
+    const uint8_t * restrict y_d = ((const uint8_t *) vy0) + y_qrow_size;  // then scales
+
+    // Row sums (sf)
+    HVX_Vector r0_sum = Q6_V_vsplat_R(0);
+    HVX_Vector r1_sum = Q6_V_vsplat_R(0);
+
+    // Multiply and accumulate into int32.
+    // Compute combined scale (fp32).
+    // Apply scale to acc and accumulate into the row sum (f32).
+
+    const uint32_t nb = n / qk;  // num full blocks
+    int32_t nloe = n % qk;       // num leftover elements (must be signed)
+
+    uint32_t i = 0;
+    for (; i < nb; i++) {
+        HVX_Vector_x8 vy_q = hvx_vec_load_q8x4x8(y_q + i * y_qblk_size);
+        HVX_Vector_x8 r0_q = hvx_vec_load_mxfp4x4x8(r0_x_q + i * x_qblk_size);
+        HVX_Vector_x8 r1_q = hvx_vec_load_mxfp4x4x8(r1_x_q + i * x_qblk_size);
+
+        HVX_Vector r0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r0_q, vy_q));
+        HVX_Vector r1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r1_q, vy_q));
+
+        HVX_Vector vy_d = *(const HVX_UVector *) (y_d + i * y_dblk_size);
+        HVX_Vector r0_d = *(const HVX_UVector *) (r0_x_d + i * x_dblk_size);
+        HVX_Vector r1_d = *(const HVX_UVector *) (r1_x_d + i * x_dblk_size);
+
+        // Convert vy_d from fp16 to fp32 while applying 0.5 scaling which is used for e8m0 halving
+        HVX_Vector half = Q6_Vh_vsplat_R(0x3800);  // 0.5 in fp16
+        vy_d = Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(Q6_Vh_vshuff_Vh(vy_d), half));
+        vy_d = Q6_Vsf_equals_Vqf32(vy_d);
+
+        // Convert rX_d scales from e8m0 to fp32
+        // Expand and zero-pad 32x uint8 e8m0 values to uint32s : 0 0 0 0, 0 0 0 1, 0 0 0 2, ...
+ // Left shift with zero fill to create FP32 + // FIXME: might need to handle zero as a special case (see ggml-cpu code) + HVX_Vector expand = *(const HVX_Vector *) expand_x32_e8m0; + HVX_Vector e8m0_mask = Q6_V_vsplat_R(0x000000ff); + r0_d = Q6_V_vdelta_VV(r0_d, expand); + r0_d = Q6_V_vand_VV(r0_d, e8m0_mask); + r0_d = Q6_Vw_vasl_VwR(r0_d, 23); + r1_d = Q6_V_vdelta_VV(r1_d, expand); + r1_d = Q6_V_vand_VV(r1_d, e8m0_mask); + r1_d = Q6_Vw_vasl_VwR(r1_d, 23); + + HVX_Vector r0_dd = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(r0_d, vy_d)); + HVX_Vector r1_dd = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(r1_d, vy_d)); + + HVX_Vector r0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_ia, r0_dd); + HVX_Vector r1_fa = Q6_Vqf32_vmpy_VsfVsf(r1_ia, r1_dd); + + r0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_fa, r0_sum)); + r1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_fa, r1_sum)); + } + + // Process leftovers + if (nloe) { + HVX_Vector_x8 vy_q = hvx_vec_load_q8x4x8(y_q + i * y_qblk_size); + HVX_Vector_x8 r0_q = hvx_vec_load_mxfp4x4x8(r0_x_q + i * x_qblk_size); + HVX_Vector_x8 r1_q = hvx_vec_load_mxfp4x4x8(r1_x_q + i * x_qblk_size); + + HVX_Vector r0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r0_q, vy_q)); + HVX_Vector r1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r1_q, vy_q)); + + HVX_Vector vy_d = *(const HVX_UVector *) (y_d + i * y_dblk_size); + HVX_Vector r0_d = *(const HVX_UVector *) (r0_x_d + i * x_dblk_size); + HVX_Vector r1_d = *(const HVX_UVector *) (r1_x_d + i * x_dblk_size); + + // Convert vy_d from fp16 to fp32 while applying 0.5 scaling which is used for e8m0 halving + HVX_Vector half = Q6_Vh_vsplat_R(0x3800); // 0.5 in fp16 + vy_d = Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(Q6_Vh_vshuff_Vh(vy_d), half)); + vy_d = Q6_Vsf_equals_Vqf32(vy_d); + + // Convert rX_d scales from e8m0 to fp32 + // Expand and zero-pad 32x uint8 e8m0 values to uint32s : 0 0 0 0, 0 0 0 1, 0 0 0 2, ... 
+ // Left shift with zero fill to create FP32 + // FIXME: might need to handle zero as a special case (see ggml-cpu code) + HVX_Vector expand = *(const HVX_Vector *) expand_x32_e8m0; + HVX_Vector e8m0_mask = Q6_V_vsplat_R(0x000000ff); + r0_d = Q6_V_vdelta_VV(r0_d, expand); + r0_d = Q6_V_vand_VV(r0_d, e8m0_mask); + r0_d = Q6_Vw_vasl_VwR(r0_d, 23); + r1_d = Q6_V_vdelta_VV(r1_d, expand); + r1_d = Q6_V_vand_VV(r1_d, e8m0_mask); + r1_d = Q6_Vw_vasl_VwR(r1_d, 23); + + HVX_Vector r0_dd = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(r0_d, vy_d)); + HVX_Vector r1_dd = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(r1_d, vy_d)); + + // Zero-out unused values + HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe / 8); + r0_dd = Q6_V_vand_QV(bmask, r0_dd); + r1_dd = Q6_V_vand_QV(bmask, r1_dd); + r0_ia = Q6_V_vand_QV(bmask, r0_ia); + r1_ia = Q6_V_vand_QV(bmask, r1_ia); + + HVX_Vector r0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_ia, r0_dd); + HVX_Vector r1_fa = Q6_Vqf32_vmpy_VsfVsf(r1_ia, r1_dd); + + r0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_fa, r0_sum)); + r1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_fa, r1_sum)); + } + + HVX_Vector rsum = hvx_vec_reduce_sum_f32x2(r0_sum, r1_sum); + hvx_vec_store_u(s0, 8, rsum); +} + +static void vec_dot_mxfp4x4x2_q8x4x2_2x2(const int n, float * restrict s0, float * restrict s1, + const void * restrict vx0, const void * restrict vx1, + const void * restrict vy0, const void * restrict vy1) { + assert(n % 32 == 0); + assert((unsigned long) vx0 % 128 == 0); + assert((unsigned long) vx1 % 128 == 0); + assert((unsigned long) vy0 % 128 == 0); + assert((unsigned long) vy1 % 128 == 0); + + const uint32_t qk = QK_MXFP4x4x2 * 4; + + const uint32_t x_dblk_size = 8 * 4 * 1; // 32x e8m0 + const uint32_t x_qblk_size = qk / 2; // fp4 + const uint32_t x_qrow_size = n / 2; // fp4 (not padded) + + const uint32_t y_dblk_size = 8 * 4 * 2; // 32x __fp16 + const uint32_t y_qblk_size = qk; // int8 + const uint32_t y_qrow_size = n; // int8 (not padded) + + const uint8_t * restrict r0_x_q = ((const uint8_t *) vx0) + 0; // quants first + const uint8_t * restrict r0_x_d = ((const uint8_t *) vx0) + x_qrow_size; // then scales + const uint8_t * restrict r1_x_q = ((const uint8_t *) vx1) + 0; // quants first + const uint8_t * restrict r1_x_d = ((const uint8_t *) vx1) + x_qrow_size; // then scales + + const uint8_t * restrict y0_q = ((const uint8_t *) vy0) + 0; // quants first + const uint8_t * restrict y0_d = ((const uint8_t *) vy0) + y_qrow_size; // then scales + const uint8_t * restrict y1_q = ((const uint8_t *) vy1) + 0; // quants first + const uint8_t * restrict y1_d = ((const uint8_t *) vy1) + y_qrow_size; // then scales + + // Row sums (sf) - 4 accumulators for 2×2 tile + HVX_Vector r0_c0_sum = Q6_V_vsplat_R(0); + HVX_Vector r0_c1_sum = Q6_V_vsplat_R(0); + HVX_Vector r1_c0_sum = Q6_V_vsplat_R(0); + HVX_Vector r1_c1_sum = Q6_V_vsplat_R(0); + + const uint32_t nb = n / qk; // num full blocks + const uint32_t nloe = n % qk; // num leftover elements + + uint32_t i = 0; + for (; i < nb; i++) { + // Load src1 columns (reused across both src0 rows) + HVX_Vector_x8 vy0_q = hvx_vec_load_q8x4x8(y0_q + i * y_qblk_size); + HVX_Vector_x8 vy1_q = hvx_vec_load_q8x4x8(y1_q + i * y_qblk_size); + + // Load src0 rows (reused across both src1 columns) + HVX_Vector_x8 r0_q = hvx_vec_load_mxfp4x4x8(r0_x_q + i * x_qblk_size); + HVX_Vector_x8 r1_q = hvx_vec_load_mxfp4x4x8(r1_x_q + i * x_qblk_size); + + // Compute 4 dot products: r0×c0, r0×c1, r1×c0, r1×c1 + HVX_Vector r0_c0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r0_q, 
vy0_q)); + HVX_Vector r0_c1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r0_q, vy1_q)); + HVX_Vector r1_c0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r1_q, vy0_q)); + HVX_Vector r1_c1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r1_q, vy1_q)); + + // Load scales + HVX_Vector vy0_d = *(const HVX_UVector *) (y0_d + i * y_dblk_size); + HVX_Vector vy1_d = *(const HVX_UVector *) (y1_d + i * y_dblk_size); + HVX_Vector r0_d = *(const HVX_UVector *) (r0_x_d + i * x_dblk_size); + HVX_Vector r1_d = *(const HVX_UVector *) (r1_x_d + i * x_dblk_size); + + // Convert vy_d from fp16 to fp32 while applying 0.5 scaling which is used for e8m0 halving + HVX_Vector half = Q6_Vh_vsplat_R(0x3800); // 0.5 in fp16 + vy0_d = Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(Q6_Vh_vshuff_Vh(vy0_d), half)); + vy0_d = Q6_Vsf_equals_Vqf32(vy0_d); + vy1_d = Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(Q6_Vh_vshuff_Vh(vy1_d), half)); + vy1_d = Q6_Vsf_equals_Vqf32(vy1_d); + + // Convert rX_d scales from e8m0 to fp32 + // Expand and zero-pad 32x uint8 e8m0 values to uint32s : 0 0 0 0, 0 0 0 1, 0 0 0 2, ... + // Left shift with zero fill to create FP32 + // FIXME: might need to handle zero as a special case (see ggml-cpu code) + HVX_Vector expand = *(const HVX_Vector *) expand_x32_e8m0; + HVX_Vector e8m0_mask = Q6_V_vsplat_R(0x000000ff); + r0_d = Q6_V_vdelta_VV(r0_d, expand); + r0_d = Q6_V_vand_VV(r0_d, e8m0_mask); + r0_d = Q6_Vw_vasl_VwR(r0_d, 23); + r1_d = Q6_V_vdelta_VV(r1_d, expand); + r1_d = Q6_V_vand_VV(r1_d, e8m0_mask); + r1_d = Q6_Vw_vasl_VwR(r1_d, 23); + + // Compute combined scales + HVX_Vector r0_c0_dd = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(r0_d, vy0_d)); + HVX_Vector r0_c1_dd = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(r0_d, vy1_d)); + HVX_Vector r1_c0_dd = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(r1_d, vy0_d)); + HVX_Vector r1_c1_dd = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(r1_d, vy1_d)); + + // Apply scales and accumulate + HVX_Vector r0_c0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_c0_ia, r0_c0_dd); + HVX_Vector r0_c1_fa = Q6_Vqf32_vmpy_VsfVsf(r0_c1_ia, r0_c1_dd); + HVX_Vector r1_c0_fa = Q6_Vqf32_vmpy_VsfVsf(r1_c0_ia, r1_c0_dd); + HVX_Vector r1_c1_fa = Q6_Vqf32_vmpy_VsfVsf(r1_c1_ia, r1_c1_dd); + + r0_c0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_c0_fa, r0_c0_sum)); + r0_c1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_c1_fa, r0_c1_sum)); + r1_c0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_c0_fa, r1_c0_sum)); + r1_c1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_c1_fa, r1_c1_sum)); + } + + // Process leftovers + if (nloe) { + HVX_Vector_x8 vy0_q = hvx_vec_load_q8x4x8(y0_q + i * y_qblk_size); + HVX_Vector_x8 vy1_q = hvx_vec_load_q8x4x8(y1_q + i * y_qblk_size); + HVX_Vector_x8 r0_q = hvx_vec_load_mxfp4x4x8(r0_x_q + i * x_qblk_size); + HVX_Vector_x8 r1_q = hvx_vec_load_mxfp4x4x8(r1_x_q + i * x_qblk_size); + + HVX_Vector r0_c0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_nloe(r0_q, vy0_q, nloe)); + HVX_Vector r0_c1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_nloe(r0_q, vy1_q, nloe)); + HVX_Vector r1_c0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_nloe(r1_q, vy0_q, nloe)); + HVX_Vector r1_c1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_nloe(r1_q, vy1_q, nloe)); + + HVX_Vector vy0_d = *(const HVX_UVector *) (y0_d + i * y_dblk_size); + HVX_Vector vy1_d = *(const HVX_UVector *) (y1_d + i * y_dblk_size); + HVX_Vector r0_d = *(const HVX_UVector *) (r0_x_d + i * x_dblk_size); + HVX_Vector r1_d = *(const HVX_UVector *) (r1_x_d + i * x_dblk_size); + + // Convert vy_d from fp16 to fp32 while applying 0.5 scaling which is used for e8m0 halving + HVX_Vector half = 
Q6_Vh_vsplat_R(0x3800); // 0.5 in fp16 + vy0_d = Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(Q6_Vh_vshuff_Vh(vy0_d), half)); + vy0_d = Q6_Vsf_equals_Vqf32(vy0_d); + vy1_d = Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(Q6_Vh_vshuff_Vh(vy1_d), half)); + vy1_d = Q6_Vsf_equals_Vqf32(vy1_d); + + // Convert rX_d scales from e8m0 to fp32 + // Expand and zero-pad 32x uint8 e8m0 values to uint32s : 0 0 0 0, 0 0 0 1, 0 0 0 2, ... + // Left shift with zero fill to create FP32 + // FIXME: might need to handle zero as a special case (see ggml-cpu code) + HVX_Vector expand = *(const HVX_Vector *) expand_x32_e8m0; + HVX_Vector e8m0_mask = Q6_V_vsplat_R(0x000000ff); + r0_d = Q6_V_vdelta_VV(r0_d, expand); + r0_d = Q6_V_vand_VV(r0_d, e8m0_mask); + r0_d = Q6_Vw_vasl_VwR(r0_d, 23); + r1_d = Q6_V_vdelta_VV(r1_d, expand); + r1_d = Q6_V_vand_VV(r1_d, e8m0_mask); + r1_d = Q6_Vw_vasl_VwR(r1_d, 23); + + HVX_Vector r0_c0_dd = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(r0_d, vy0_d)); + HVX_Vector r0_c1_dd = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(r0_d, vy1_d)); + HVX_Vector r1_c0_dd = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(r1_d, vy0_d)); + HVX_Vector r1_c1_dd = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(r1_d, vy1_d)); + + // Zero out unused scales + HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe / 8); + r0_c0_dd = Q6_V_vand_QV(bmask, r0_c0_dd); + r0_c1_dd = Q6_V_vand_QV(bmask, r0_c1_dd); + r1_c0_dd = Q6_V_vand_QV(bmask, r1_c0_dd); + r1_c1_dd = Q6_V_vand_QV(bmask, r1_c1_dd); + r0_c0_ia = Q6_V_vand_QV(bmask, r0_c0_ia); + r0_c1_ia = Q6_V_vand_QV(bmask, r0_c1_ia); + r1_c0_ia = Q6_V_vand_QV(bmask, r1_c0_ia); + r1_c1_ia = Q6_V_vand_QV(bmask, r1_c1_ia); + + HVX_Vector r0_c0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_c0_ia, r0_c0_dd); + HVX_Vector r0_c1_fa = Q6_Vqf32_vmpy_VsfVsf(r0_c1_ia, r0_c1_dd); + HVX_Vector r1_c0_fa = Q6_Vqf32_vmpy_VsfVsf(r1_c0_ia, r1_c0_dd); + HVX_Vector r1_c1_fa = Q6_Vqf32_vmpy_VsfVsf(r1_c1_ia, r1_c1_dd); + + r0_c0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_c0_fa, r0_c0_sum)); + r0_c1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_c1_fa, r0_c1_sum)); + r1_c0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_c0_fa, r1_c0_sum)); + r1_c1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_c1_fa, r1_c1_sum)); + } + + // Reduce and store results + HVX_Vector r0_r1_c0_sum = hvx_vec_reduce_sum_f32x2(r0_c0_sum, r1_c0_sum); + HVX_Vector r0_r1_c1_sum = hvx_vec_reduce_sum_f32x2(r0_c1_sum, r1_c1_sum); + + hvx_vec_store_u(&s0[0], 8, r0_r1_c0_sum); // row0,col0 row1,col0 + hvx_vec_store_u(&s1[0], 8, r0_r1_c1_sum); // row0,col1 row1,col1 +} + +static void vec_dot_f16_f16_aa_1x1(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) { + const HVX_Vector * restrict x = (const HVX_Vector *) vx; + const HVX_Vector * restrict y = (const HVX_Vector *) vy; + + uint32_t nvec = n / VLEN_FP16; // num full fp16 hvx vectors + uint32_t nloe = n % VLEN_FP16; // leftover elements + + HVX_Vector rsum = Q6_V_vsplat_R(0); + + uint32_t i = 0; + + #pragma unroll(4) + for (i = 0; i < nvec; i++) { + HVX_VectorPair xy_qf = Q6_Wqf32_vmpy_VhfVhf(x[i], y[i]); + rsum = Q6_Vqf32_vadd_Vqf32Vqf32(rsum, Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(xy_qf), Q6_V_hi_W(xy_qf))); + } + + if (nloe) { + HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe * 2); + HVX_Vector x_hf = Q6_V_vand_QV(bmask, x[i]); + HVX_Vector y_hf = Q6_V_vand_QV(bmask, y[i]); + + HVX_VectorPair xy_qf = Q6_Wqf32_vmpy_VhfVhf(x_hf, y_hf); + rsum = Q6_Vqf32_vadd_Vqf32Vqf32(rsum, Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(xy_qf), Q6_V_hi_W(xy_qf))); + } + + rsum = 
hvx_vec_reduce_sum_f32(Q6_Vsf_equals_Vqf32(rsum)); + hvx_vec_store_u(&s[0], 4, rsum); +} + +static void vec_dot_f16_f16_aa_2x1(const int n, float * restrict s0, + const void * restrict vx0, const void * restrict vx1, + const void * restrict vy0) { + const HVX_Vector * restrict x0 = (const HVX_Vector *) vx0; + const HVX_Vector * restrict x1 = (const HVX_Vector *) vx1; + const HVX_Vector * restrict y = (const HVX_Vector *) vy0; + + uint32_t nvec = n / VLEN_FP16; + uint32_t nloe = n % VLEN_FP16; + + HVX_Vector rsum0 = Q6_V_vsplat_R(0); + HVX_Vector rsum1 = Q6_V_vsplat_R(0); + + uint32_t i = 0; + + #pragma unroll(2) + for (i = 0; i < nvec; i++) { + HVX_Vector y_hf = y[i]; + HVX_VectorPair xy0_qf = Q6_Wqf32_vmpy_VhfVhf(x0[i], y_hf); + HVX_VectorPair xy1_qf = Q6_Wqf32_vmpy_VhfVhf(x1[i], y_hf); + + rsum0 = Q6_Vqf32_vadd_Vqf32Vqf32(rsum0, Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(xy0_qf), Q6_V_hi_W(xy0_qf))); + rsum1 = Q6_Vqf32_vadd_Vqf32Vqf32(rsum1, Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(xy1_qf), Q6_V_hi_W(xy1_qf))); + } + + if (nloe) { + HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe * 2); + HVX_Vector x0_hf = Q6_V_vand_QV(bmask, x0[i]); + HVX_Vector x1_hf = Q6_V_vand_QV(bmask, x1[i]); + HVX_Vector y_hf = Q6_V_vand_QV(bmask, y[i]); + + HVX_VectorPair xy0_qf = Q6_Wqf32_vmpy_VhfVhf(x0_hf, y_hf); + HVX_VectorPair xy1_qf = Q6_Wqf32_vmpy_VhfVhf(x1_hf, y_hf); + + rsum0 = Q6_Vqf32_vadd_Vqf32Vqf32(rsum0, Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(xy0_qf), Q6_V_hi_W(xy0_qf))); + rsum1 = Q6_Vqf32_vadd_Vqf32Vqf32(rsum1, Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(xy1_qf), Q6_V_hi_W(xy1_qf))); + } + + HVX_Vector rsum = hvx_vec_reduce_sum_f32x2(Q6_Vsf_equals_Vqf32(rsum0), Q6_Vsf_equals_Vqf32(rsum1)); + hvx_vec_store_u(s0, 8, rsum); +} + +static void vec_dot_f16_f16_aa_2x2(const int n, float * restrict s0, float * restrict s1, + const void * restrict vx0, const void * restrict vx1, + const void * restrict vy0, const void * restrict vy1) { + const HVX_Vector * restrict x0 = (const HVX_Vector *) vx0; + const HVX_Vector * restrict x1 = (const HVX_Vector *) vx1; + const HVX_Vector * restrict y0 = (const HVX_Vector *) vy0; + const HVX_Vector * restrict y1 = (const HVX_Vector *) vy1; + + uint32_t nvec = n / VLEN_FP16; + uint32_t nloe = n % VLEN_FP16; + + // Row sums (sf) - 4 accumulators for 2×2 tile + HVX_Vector r0_c0_sum = Q6_V_vsplat_R(0); + HVX_Vector r0_c1_sum = Q6_V_vsplat_R(0); + HVX_Vector r1_c0_sum = Q6_V_vsplat_R(0); + HVX_Vector r1_c1_sum = Q6_V_vsplat_R(0); + + uint32_t i = 0; + + #pragma unroll(2) + for (i = 0; i < nvec; i++) { + HVX_Vector r0_hf = x0[i]; + HVX_Vector r1_hf = x1[i]; + HVX_Vector c0_hf = y0[i]; + HVX_Vector c1_hf = y1[i]; + + // Compute 4 dot products: r0×c0, r0×c1, r1×c0, r1×c1 + HVX_VectorPair r0_c0_qf_p = Q6_Wqf32_vmpy_VhfVhf(r0_hf, c0_hf); + HVX_VectorPair r0_c1_qf_p = Q6_Wqf32_vmpy_VhfVhf(r0_hf, c1_hf); + HVX_VectorPair r1_c0_qf_p = Q6_Wqf32_vmpy_VhfVhf(r1_hf, c0_hf); + HVX_VectorPair r1_c1_qf_p = Q6_Wqf32_vmpy_VhfVhf(r1_hf, c1_hf); + + HVX_Vector r0_c0_qf = Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(r0_c0_qf_p), Q6_V_hi_W(r0_c0_qf_p)); + HVX_Vector r0_c1_qf = Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(r0_c1_qf_p), Q6_V_hi_W(r0_c1_qf_p)); + HVX_Vector r1_c0_qf = Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(r1_c0_qf_p), Q6_V_hi_W(r1_c0_qf_p)); + HVX_Vector r1_c1_qf = Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(r1_c1_qf_p), Q6_V_hi_W(r1_c1_qf_p)); + + r0_c0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_c0_qf, r0_c0_sum)); + r0_c1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_c1_qf, r0_c1_sum)); + r1_c0_sum = 
Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_c0_qf, r1_c0_sum)); + r1_c1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_c1_qf, r1_c1_sum)); + } + + if (nloe) { + HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe * 2); + + HVX_Vector r0_hf = Q6_V_vand_QV(bmask, x0[i]); + HVX_Vector r1_hf = Q6_V_vand_QV(bmask, x1[i]); + HVX_Vector c0_hf = Q6_V_vand_QV(bmask, y0[i]); + HVX_Vector c1_hf = Q6_V_vand_QV(bmask, y1[i]); + + HVX_VectorPair r0_c0_qf_p = Q6_Wqf32_vmpy_VhfVhf(r0_hf, c0_hf); + HVX_VectorPair r0_c1_qf_p = Q6_Wqf32_vmpy_VhfVhf(r0_hf, c1_hf); + HVX_VectorPair r1_c0_qf_p = Q6_Wqf32_vmpy_VhfVhf(r1_hf, c0_hf); + HVX_VectorPair r1_c1_qf_p = Q6_Wqf32_vmpy_VhfVhf(r1_hf, c1_hf); + + HVX_Vector r0_c0_qf = Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(r0_c0_qf_p), Q6_V_hi_W(r0_c0_qf_p)); + HVX_Vector r0_c1_qf = Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(r0_c1_qf_p), Q6_V_hi_W(r0_c1_qf_p)); + HVX_Vector r1_c0_qf = Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(r1_c0_qf_p), Q6_V_hi_W(r1_c0_qf_p)); + HVX_Vector r1_c1_qf = Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(r1_c1_qf_p), Q6_V_hi_W(r1_c1_qf_p)); + + r0_c0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_c0_qf, r0_c0_sum)); + r0_c1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_c1_qf, r0_c1_sum)); + r1_c0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_c0_qf, r1_c0_sum)); + r1_c1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_c1_qf, r1_c1_sum)); + + } + + // Reduce and store results + HVX_Vector r0_r1_c0_sum = hvx_vec_reduce_sum_f32x2(r0_c0_sum, r1_c0_sum); + HVX_Vector r0_r1_c1_sum = hvx_vec_reduce_sum_f32x2(r0_c1_sum, r1_c1_sum); + + hvx_vec_store_u(&s0[0], 8, r0_r1_c0_sum); // row0,col0 row1,col0 + hvx_vec_store_u(&s1[0], 8, r0_r1_c1_sum); // row0,col1 row1,col1 +} + +static void vec_dot_f16_f16_uu_1x1(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) { + const HVX_UVector * restrict x = (const HVX_UVector *) vx; + const HVX_UVector * restrict y = (const HVX_UVector *) vy; + + uint32_t nvec = n / VLEN_FP16; // num full fp16 hvx vectors + uint32_t nloe = n % VLEN_FP16; // leftover elements + + HVX_Vector rsum = Q6_V_vsplat_R(0); + + uint32_t i = 0; + + #pragma unroll(4) + for (i = 0; i < nvec; i++) { + HVX_VectorPair xy_qf = Q6_Wqf32_vmpy_VhfVhf(x[i], y[i]); + rsum = Q6_Vqf32_vadd_Vqf32Vqf32(rsum, Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(xy_qf), Q6_V_hi_W(xy_qf))); + } + + if (nloe) { + HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe * 2); + HVX_Vector x_hf = Q6_V_vand_QV(bmask, x[i]); + HVX_Vector y_hf = Q6_V_vand_QV(bmask, y[i]); + + HVX_VectorPair xy_qf = Q6_Wqf32_vmpy_VhfVhf(x_hf, y_hf); + rsum = Q6_Vqf32_vadd_Vqf32Vqf32(rsum, Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(xy_qf), Q6_V_hi_W(xy_qf))); + } + + rsum = hvx_vec_reduce_sum_f32(Q6_Vsf_equals_Vqf32(rsum)); + hvx_vec_store_u(&s[0], 4, rsum); +} + +static void vec_dot_f16_f32_uu_1x1(const int n, float * restrict s, const void * restrict x, const void * restrict y) { + const HVX_UVector * restrict vx = (const HVX_UVector * restrict) x; + const HVX_UVector * restrict vy = (const HVX_UVector * restrict) y; + + uint32_t nvec = n / VLEN_FP16; // num full fp16 hvx vectors + uint32_t nloe = n % VLEN_FP16; // leftover elements + + const HVX_Vector zero = Q6_V_vsplat_R(0); + + HVX_Vector rsum = Q6_V_vsplat_R(0); + + uint32_t i = 0; + + #pragma unroll(2) + for (i = 0; i < nvec; i++) { + // Load y (fp32) and convert into fp16 + HVX_Vector y0_qf = Q6_Vqf32_vsub_VsfVsf(vy[i*2+0], zero); // 32 elements + HVX_Vector y1_qf = Q6_Vqf32_vsub_VsfVsf(vy[i*2+1], zero); // 32 elements + HVX_Vector y_hf = 
Q6_Vh_vdeal_Vh(Q6_Vhf_equals_Wqf32(Q6_W_vcombine_VV(y1_qf, y0_qf))); + + // Load x (fp16) + HVX_Vector x_hf = vx[i]; + + HVX_VectorPair xy_qf = Q6_Wqf32_vmpy_VhfVhf(x_hf, y_hf); + + rsum = Q6_Vqf32_vadd_Vqf32Vqf32(rsum, Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(xy_qf), Q6_V_hi_W(xy_qf))); + } + + if (nloe) { + // Load y (fp32) and convert into fp16 + HVX_Vector y0_qf = Q6_Vqf32_vsub_VsfVsf(vy[i*2+0], zero); // 32 elements + HVX_Vector y1_qf = Q6_Vqf32_vsub_VsfVsf(vy[i*2+1], zero); // 32 elements + HVX_Vector y_hf = Q6_Vh_vdeal_Vh(Q6_Vhf_equals_Wqf32(Q6_W_vcombine_VV(y1_qf, y0_qf))); + + // Load x (fp16) + HVX_Vector x_hf = vx[i]; + + // Zero-out unused elements + // Note that we need to clear both x and y because they may contain NANs + HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe * 2); + x_hf = Q6_V_vand_QV(bmask, x_hf); + y_hf = Q6_V_vand_QV(bmask, y_hf); + + HVX_VectorPair xy_qf = Q6_Wqf32_vmpy_VhfVhf(x_hf, y_hf); + + rsum = Q6_Vqf32_vadd_Vqf32Vqf32(rsum, Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(xy_qf), Q6_V_hi_W(xy_qf))); + } + + // Convert into fp32 and reduce + rsum = hvx_vec_reduce_sum_f32(Q6_Vsf_equals_Vqf32(rsum)); + hvx_vec_store_u(&s[0], 4, rsum); +} + +#define htp_matmul_tensors_preamble \ + struct htp_tensor * restrict src0 = &octx->src0; \ + struct htp_tensor * restrict src1 = &octx->src1; \ + struct htp_tensor * restrict src2 = &octx->src2; \ + struct htp_tensor * restrict dst = &octx->dst; \ + struct htp_spad * restrict src0_spad = &octx->src0_spad; \ + struct htp_spad * restrict src1_spad = &octx->src1_spad; \ + struct htp_spad * restrict dst_spad = &octx->dst_spad; \ + \ + const uint32_t ne00 = src0->ne[0]; \ + const uint32_t ne01 = src0->ne[1]; \ + const uint32_t ne02 = src0->ne[2]; \ + const uint32_t ne03 = src0->ne[3]; \ + \ + const uint32_t ne10 = src1->ne[0]; \ + const uint32_t ne11 = src1->ne[1]; \ + const uint32_t ne12 = src1->ne[2]; \ + const uint32_t ne13 = src1->ne[3]; \ + \ + const uint32_t ne20 = src2->ne[0]; \ + const uint32_t ne21 = src2->ne[1]; \ + const uint32_t ne22 = src2->ne[2]; \ + const uint32_t ne23 = src2->ne[3]; \ + \ + const uint32_t ne0 = dst->ne[0]; \ + const uint32_t ne1 = dst->ne[1]; \ + const uint32_t ne2 = dst->ne[2]; \ + const uint32_t ne3 = dst->ne[3]; \ + \ + const uint32_t nb00 = src0->nb[0]; \ + const uint32_t nb01 = src0->nb[1]; \ + const uint32_t nb02 = src0->nb[2]; \ + const uint32_t nb03 = src0->nb[3]; \ + \ + const uint32_t nb10 = src1->nb[0]; \ + const uint32_t nb11 = src1->nb[1]; \ + const uint32_t nb12 = src1->nb[2]; \ + const uint32_t nb13 = src1->nb[3]; \ + \ + const uint32_t nb0 = dst->nb[0]; \ + const uint32_t nb1 = dst->nb[1]; \ + const uint32_t nb2 = dst->nb[2]; \ + const uint32_t nb3 = dst->nb[3]; + +#define htp_matmul_preamble \ + struct htp_matmul_context * mmctx = data; \ + struct htp_ops_context * octx = mmctx->octx; \ + htp_matmul_tensors_preamble; \ + dma_queue *dma_queue = octx->ctx->dma[ith]; \ + uint32_t src0_nrows_per_thread = mmctx->src0_nrows_per_thread; + +// *** matmul with support for 4d tensors and full broadcasting + +static void matmul_4d(unsigned int nth, unsigned int ith, void * data) { + htp_matmul_preamble; + + uint64_t t1, t2; + t1 = HAP_perf_get_qtimer_count(); + + assert(ne12 % ne02 == 0); + assert(ne13 % ne03 == 0); + + // This is the size of the first dimension of the result, so we can iterate that way. 
(for matmul, ne0 == ne01: one dst row per src0 row)
+    const uint32_t nr0 = ne0;
+
+    // This is the size of the rest of the dimensions of the result
+    const uint32_t nr1 = ne1 * ne2 * ne3;
+
+    // distribute the thread work across the inner or outer loop based on which one is larger
+    uint32_t nchunk0 = nr0 > nr1 ? nth : 1;  // parallelize by src0 rows
+    uint32_t nchunk1 = nr0 > nr1 ? 1 : nth;  // parallelize by src1 rows
+
+    // The number of elements in each chunk
+    const uint32_t dr0 = (nr0 + nchunk0 - 1) / nchunk0;
+    const uint32_t dr1 = (nr1 + nchunk1 - 1) / nchunk1;
+
+    uint32_t current_chunk = ith;
+
+    const uint32_t ith0 = current_chunk % nchunk0;
+    const uint32_t ith1 = current_chunk / nchunk0;
+
+    const uint32_t ir0_start = dr0 * ith0;
+    const uint32_t ir0_end = MIN(ir0_start + dr0, nr0);
+
+    const uint32_t ir1_start = dr1 * ith1;
+    const uint32_t ir1_end = MIN(ir1_start + dr1, nr1);
+
+    // no work for this thread
+    if (ir0_start >= ir0_end || ir1_start >= ir1_end) {
+        return;
+    }
+
+    // block-tiling attempt
+    const uint32_t blck_0 = 64;
+    const uint32_t blck_1 = 64;
+
+    for (uint32_t iir1 = ir1_start; iir1 < ir1_end; iir1 += blck_1) {
+        for (uint32_t iir0 = ir0_start; iir0 < ir0_end; iir0 += blck_0) {
+            for (uint32_t ir1 = iir1; ir1 < MIN(iir1 + blck_1, ir1_end); ir1++) {
+                const uint32_t i13 = fastdiv(ir1, &mmctx->mm_div_ne12_ne1);
+                const uint32_t i12 = fastdiv(ir1 - i13 * ne12 * ne1, &mmctx->mm_div_ne1);
+                const uint32_t i11 = (ir1 - i13 * ne12 * ne1 - i12 * ne1);
+
+                // broadcast src0 into src1
+                const uint32_t i03 = fastdiv(i13, &mmctx->mm_div_r3);
+                const uint32_t i02 = fastdiv(i12, &mmctx->mm_div_r2);
+
+                const uint32_t i1 = i11;
+                const uint32_t i2 = i12;
+                const uint32_t i3 = i13;
+
+                const uint8_t * restrict src0_base = (const uint8_t *) src0->data + (0 + i02 * nb02 + i03 * nb03);
+                const uint8_t * restrict src1_col = (const uint8_t *) src1->data + (i11 * nb11 + i12 * nb12 + i13 * nb13);
+                float * dst_col = (float *) ((uint8_t * restrict) dst->data + (i1 * nb1 + i2 * nb2 + i3 * nb3));
+
+                const uint32_t ir0_block_end = MIN(iir0 + blck_0, ir0_end);
+                for (uint32_t ir0 = iir0; ir0 < ir0_block_end; ir0++) {
+                    const uint8_t * restrict src0_row = src0_base + ir0 * nb01;
+                    mmctx->vec_dot_1x1(ne00, &dst_col[ir0], src0_row, src1_col);
+                }
+            }
+        }
+    }
+
+    t2 = HAP_perf_get_qtimer_count();
+
+    FARF(HIGH, "matmul-4d %d/%d: %ux%ux%ux%u (%u:%u %u:%u) * %ux%ux%ux%u -> %ux%ux%ux%u usec %u\n", ith, nth,
+         src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3], ir0_start, ir0_end, ir1_start, ir1_end, src1->ne[0],
+         src1->ne[1], src1->ne[2], src1->ne[3], dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3],
+         (unsigned) HAP_perf_qtimer_count_to_us(t2 - t1));
+}
+
+// src1 tensor is already in VTCM spad
+static void matmul_2d(unsigned int nth, unsigned int ith, void * data) {
+    htp_matmul_preamble;
+
+    const uint32_t src0_nrows = ne01 * ne02 * ne03;  // src0 rows
+    const uint32_t src1_nrows = ne11 * ne12 * ne13;  // src1 rows
+
+    const uint32_t src0_start_row = src0_nrows_per_thread * ith;
+    const uint32_t src0_end_row = MIN(src0_start_row + src0_nrows_per_thread, src0_nrows);
+    const uint32_t src0_end_row_x2 = src0_start_row + ((src0_end_row - src0_start_row) & ~1U);
+
+    // no work for this thread
+    if (src0_start_row >= src0_end_row) {
+        return;
+    }
+
+    const size_t dst_row_size = nb1;
+    const size_t src0_row_size = nb01;
+    const size_t src1_row_size = nb11;
+
+    const size_t src0_stride = src0_spad->stride;
+    const size_t src1_stride = src1_spad->stride;
+
+    // Per-thread VTCM scratchpads for all 
tensors + // Note that the entire src1 tensor is already in VTCM + // For other tensors we allocate N rows per thread, padded to HVX vector size + uint8_t * restrict spad_dst = dst_spad->data + dst_spad->size_per_thread * ith; + uint8_t * restrict spad_src0 = src0_spad->data + src0_spad->size_per_thread * ith; + uint8_t * restrict src1_data = src1_spad->data; + + volatile uint64_t t1, t2; + t1 = HAP_perf_get_qtimer_count(); + + const uint8_t * restrict src0_row = (const uint8_t *) src0->data; + + // Prefill spad with src0 rows + #pragma unroll(4) + for (uint32_t ir0 = src0_start_row; ir0 < src0_end_row_x2; ir0 += 2) { + const int is0 = (ir0 - src0_start_row); + if (is0 >= MM_SPAD_SRC0_NROWS) { + break; + } + dma_queue_push_ddr_to_vtcm(dma_queue, dma_make_ptr(spad_src0 + is0 * src0_stride, src0_row + ir0 * src0_row_size), + src0_stride, src0_row_size, 2); + } + + // Process src0 rows + for (uint32_t ir0 = src0_start_row; ir0 < src0_end_row_x2; ir0 += 2) { + const uint8_t * ss0 = dma_queue_pop(dma_queue).dst; + + // Process src1 columns in pairs (2×2 tiling) + uint32_t ir1 = 0; + for (; ir1 + 1 < src1_nrows; ir1 += 2) { + const uint8_t * restrict src1_col0 = (const uint8_t *) (src1_data + (ir1+0) * src1_stride); + const uint8_t * restrict src1_col1 = (const uint8_t *) (src1_data + (ir1+1) * src1_stride); + float * restrict dst_row0 = (float *) (dst->data + ((ir1+0) * dst_row_size)); + float * restrict dst_row1 = (float *) (dst->data + ((ir1+1) * dst_row_size)); + mmctx->vec_dot_2x2(ne00, &dst_row0[ir0], &dst_row1[ir0], ss0, ss0 + src0_stride, src1_col0, src1_col1); + } + + // Handle remaining src1 rows (fallback to 2×1) + for (; ir1 < src1_nrows; ++ir1) { + const uint8_t * restrict src1_col = (const uint8_t *) (src1_data + ir1 * src1_stride); + float * restrict dst_row = (float *) (dst->data + (ir1 * dst_row_size)); + mmctx->vec_dot_2x1(ne00, &dst_row[ir0], ss0, ss0 + src0_stride, src1_col); + } + + // Prefetch next (n + spad_nrows) row + const int pr0 = (ir0 + MM_SPAD_SRC0_NROWS); + const int is0 = (pr0 - src0_start_row) % MM_SPAD_SRC0_NROWS; + if (pr0 < src0_end_row_x2) { + dma_queue_push_ddr_to_vtcm(dma_queue, dma_make_ptr(spad_src0 + is0 * src0_stride, src0_row + pr0 * src0_row_size), + src0_stride, src0_row_size, 2); + } + } + + // Process the last row (if any) + if (src0_end_row != src0_end_row_x2) { + uint32_t ir0 = src0_end_row_x2; + const int is0 = (ir0 - src0_start_row); + dma_queue_push_ddr_to_vtcm(dma_queue, dma_make_ptr(spad_src0 + is0 * src0_stride, src0_row + ir0 * src0_row_size), + src0_stride, src0_row_size, 1); + const uint8_t * ss0 = dma_queue_pop(dma_queue).dst; + + #pragma unroll(2) + for (uint32_t ir1 = 0; ir1 < src1_nrows; ++ir1) { + const uint8_t * restrict src1_col = (const uint8_t *) (src1_data + ir1 * src1_stride); + float * restrict dst_row = (float *) (dst->data + (ir1 * dst_row_size)); + mmctx->vec_dot_1x1(ne00, &dst_row[ir0], ss0, src1_col); + } + } + + t2 = HAP_perf_get_qtimer_count(); + + FARF(HIGH, "matmul-%s %d/%d: %ux%ux%ux%u (%u:%u) * %ux%ux%ux%u -> %ux%ux%ux%u usec %u\n", mmctx->type, ith, nth, + src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3], src0_start_row, src0_end_row, src1->ne[0], src1->ne[1], + src1->ne[2], src1->ne[3], dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3], + (unsigned) HAP_perf_qtimer_count_to_us(t2 - t1)); +} + +// q8x4x2 src1 tensor is already in VTCM spad +static void matvec_2d(unsigned int nth, unsigned int ith, void * data) { + htp_matmul_preamble; + + const uint32_t src0_nrows = ne01; + + const uint32_t src0_start_row = 
src0_nrows_per_thread * ith; + const uint32_t src0_end_row = MIN(src0_start_row + src0_nrows_per_thread, src0_nrows); + const uint32_t src0_end_row_x2 = src0_start_row + ((src0_end_row - src0_start_row) & ~1U); + + // no work for this thread + if (src0_start_row >= src0_end_row) { + return; + } + + const size_t dst_row_size = nb1; + const size_t src0_row_size = nb01; + const size_t src1_row_size = nb11; + + const size_t src0_stride = src0_spad->stride; + const size_t src1_stride = src1_spad->stride; + + // Per-thread VTCM scratchpads for all tensors + // Note that the entire src1 tensor is already in VTCM + // For other tensors we allocate N rows per thread, padded to HVX vector size + uint8_t * spad_dst = dst_spad->data + dst_spad->size_per_thread * ith; + uint8_t * spad_src0 = src0_spad->data + src0_spad->size_per_thread * ith; + uint8_t * src1_data = src1_spad->data; + + uint64_t t1, t2; + t1 = HAP_perf_get_qtimer_count(); + + float * tmp = (float *) spad_dst; + + const uint8_t * restrict src0_row = (const uint8_t *) src0->data; + const uint8_t * restrict src1_col = (const uint8_t *) src1_data; + float * restrict dst_col = (float *) dst->data; + + // Prefill spad with 2x src0 rows + #pragma unroll(2) + for (uint32_t ir0 = src0_start_row; ir0 < src0_end_row_x2; ir0 += 2) { + const uint32_t is0 = (ir0 - src0_start_row); + if (is0 >= MM_SPAD_SRC0_NROWS) { + break; + } + dma_queue_push_ddr_to_vtcm(dma_queue, dma_make_ptr(spad_src0 + is0 * src0_stride, src0_row + ir0 * src0_row_size), + src0_stride, src0_row_size, 2); + } + + // Process src0 rows + for (uint32_t ir0 = src0_start_row; ir0 < src0_end_row_x2; ir0 += 2) { + const uint8_t * ss0 = dma_queue_pop(dma_queue).dst; + mmctx->vec_dot_2x1(ne00, &tmp[ir0 - src0_start_row], ss0, ss0 + src0_stride, src1_col); + + // Prefetch next (n + spad_nrows) row + const uint32_t pr0 = (ir0 + MM_SPAD_SRC0_NROWS); + const uint32_t is0 = (pr0 - src0_start_row) % MM_SPAD_SRC0_NROWS; + if (pr0 < src0_end_row_x2) { + dma_queue_push_ddr_to_vtcm(dma_queue, dma_make_ptr(spad_src0 + is0 * src0_stride, src0_row + pr0 * src0_row_size), + src0_stride, src0_row_size, 2); + } + } + + // Process the last row (if any) + if (src0_end_row != src0_end_row_x2) { + const uint32_t ir0 = src0_end_row_x2; + const uint32_t is0 = (ir0 - src0_start_row); + dma_queue_push_ddr_to_vtcm(dma_queue, dma_make_ptr(spad_src0 + is0 * src0_stride, src0_row + ir0 * src0_row_size), + src0_stride, src0_row_size, 1); + const uint8_t * ss0 = dma_queue_pop(dma_queue).dst; + mmctx->vec_dot_1x1(ne00, &tmp[ir0 - src0_start_row], ss0, src1_col); + } + + hvx_copy_f32_ua((uint8_t *) &dst_col[src0_start_row], (uint8_t *) tmp, src0_end_row - src0_start_row); + + t2 = HAP_perf_get_qtimer_count(); + + FARF(HIGH, "matvec-%s %u/%u: %ux%ux%ux%u (%u:%u) * %ux%ux%ux%u -> %ux%ux%ux%u usec %u\n", mmctx->type, ith, nth, + src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3], src0_start_row, src0_end_row, src1->ne[0], src1->ne[1], + src1->ne[2], src1->ne[3], dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3], + (unsigned) HAP_perf_qtimer_count_to_us(t2 - t1)); +} + +#define MMID_MATRIX_ROW(row_id, i1) matrix_rows[(row_id) * ids->ne[0] * ids->ne[1] + (i1)] + +struct mmid_row_mapping { + uint32_t i1; + uint32_t i2; +}; + +// src1 tensor is already in VTCM spad +static void matmul_id(unsigned int nth, unsigned int ith, void * data) { + htp_matmul_preamble; + + struct htp_tensor * restrict ids = &octx->src2; + struct htp_spad * restrict src2_spad = &octx->src2_spad; + + uint64_t t1, t2; + t1 = HAP_perf_get_qtimer_count(); + + 
const uint32_t src0_nrows = ne01; // src0 rows per expert + const uint32_t src1_nrows = ne11; + + const uint32_t src0_start_row = src0_nrows_per_thread * ith; + const uint32_t src0_end_row = MIN(src0_start_row + src0_nrows_per_thread, src0_nrows); + const uint32_t src0_end_row_x2 = src0_start_row + ((src0_end_row - src0_start_row) & ~1U); + + // no work for this thread + if (src0_start_row >= src0_end_row) { + return; + } + + const uint32_t n_ids = ids->ne[0]; // n_expert_used + const uint32_t n_as = ne02; // n_expert + + const size_t matrix_row_counts_size = n_as * sizeof(uint32_t); + const size_t matrix_row_map_size = n_as * ids->ne[0] * ids->ne[1] * sizeof(struct mmid_row_mapping); + + const uint32_t * matrix_row_counts = (const uint32_t *) src2_spad->data + 0; + const struct mmid_row_mapping * matrix_rows = (const void *) src2_spad->data + matrix_row_counts_size; + + const size_t dst_row_size = nb1; + const size_t src0_row_size = nb01; + const size_t src1_row_size = q8x4x2_row_size(ne10); + + const size_t src0_row_size_padded = hex_round_up(src0_row_size, 128); + + // Per-thread VTCM scratchpads for all tensors + // Note that the entire src1 tensor is already in VTCM + // For other tensors we allocate N rows per thread, padded to HVX vector size + uint8_t * restrict spad_dst = dst_spad->data + dst_spad->size_per_thread * ith; + uint8_t * restrict spad_src0 = src0_spad->data + src0_spad->size_per_thread * ith; + uint8_t * restrict src1_data = src1_spad->data; + + for (uint32_t cur_a = 0; cur_a < n_as; ++cur_a) { + const int32_t cne1 = matrix_row_counts[cur_a]; + + if (cne1 == 0) { + continue; + } + + const uint8_t * src0_row = (const uint8_t *) src0->data + (0 + cur_a * nb02 + 0); + + // Prefill spad with src0 rows + #pragma unroll(4) + for (uint32_t ir0 = src0_start_row; ir0 < src0_end_row_x2; ir0 += 2) { + const int is0 = (ir0 - src0_start_row); + if (is0 >= MM_SPAD_SRC0_NROWS) { + break; + } + dma_queue_push_ddr_to_vtcm(dma_queue, dma_make_ptr(spad_src0 + is0 * src0_row_size_padded, src0_row + ir0 * src0_row_size), + src0_row_size_padded, src0_row_size, 2); + } + + // Process src0 rows + for (uint32_t ir0 = src0_start_row; ir0 < src0_end_row_x2; ir0 += 2) { + const uint8_t * ss0 = dma_queue_pop(dma_queue).dst; + + for (uint32_t cid = 0; cid < cne1; ++cid) { + struct mmid_row_mapping row_mapping = MMID_MATRIX_ROW(cur_a, cid); + const int rm1 = row_mapping.i1; // expert idx + const int rm2 = row_mapping.i2; // token idx + + const uint32_t ir1 = src1_nrows == 1 ? 
0 : rm1; // src1 row idx + const uint8_t * restrict src1_col = (const uint8_t *) (src1_data + (ir1 + rm2 * ne11 + 0) * src1_row_size); + float * dst_row = (float *) (dst->data + (rm1 * nb1 + rm2 * nb2 + 0)); + + mmctx->vec_dot_2x1(ne00, &dst_row[ir0], ss0, ss0 + src0_row_size_padded, src1_col); + } + + // Prefetch next (n + spad_nrows) row + const int pr0 = (ir0 + MM_SPAD_SRC0_NROWS); + const int is0 = (pr0 - src0_start_row) % MM_SPAD_SRC0_NROWS; + if (pr0 < src0_end_row_x2) { + dma_queue_push_ddr_to_vtcm(dma_queue, dma_make_ptr(spad_src0 + is0 * src0_row_size_padded, src0_row + pr0 * src0_row_size), + src0_row_size_padded, src0_row_size, 2); + } + } + + // Process the last row (if any) + if (src0_end_row != src0_end_row_x2) { + uint32_t ir0 = src0_end_row_x2; + const uint32_t is0 = (ir0 - src0_start_row); + dma_queue_push_ddr_to_vtcm(dma_queue, dma_make_ptr(spad_src0 + is0 * src0_row_size_padded, src0_row + ir0 * src0_row_size), + src0_row_size_padded, src0_row_size, 1); + const uint8_t * ss0 = dma_queue_pop(dma_queue).dst; + + for (uint32_t cid = 0; cid < cne1; ++cid) { + struct mmid_row_mapping row_mapping = MMID_MATRIX_ROW(cur_a, cid); + const int rm1 = row_mapping.i1; // expert idx + const int rm2 = row_mapping.i2; // token idx + + const uint32_t ir1 = src1_nrows == 1 ? 0 : rm1; // src1 row idx + const uint8_t * restrict src1_col = (const uint8_t *) (src1_data + (ir1 + rm2 * ne11 + 0) * src1_row_size); + float * dst_row = (float *) (dst->data + (rm1 * nb1 + rm2 * nb2 + 0)); + + mmctx->vec_dot_1x1(ne00, &dst_row[ir0], ss0, src1_col); + } + } + } + + t2 = HAP_perf_get_qtimer_count(); + + FARF(HIGH, "matmul-id-%s %d/%d: %ux%ux%ux%u (%u:%u) * %ux%ux%ux%u (%ux%ux%ux%u) -> %ux%ux%ux%u usec %u\n", mmctx->type, + ith, nth, src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3], src0_start_row, src0_end_row, src1->ne[0], + src1->ne[1], src1->ne[2], src1->ne[3], ids->ne[0], ids->ne[1], ids->ne[2], ids->ne[3], dst->ne[0], dst->ne[1], + dst->ne[2], dst->ne[3], (unsigned) HAP_perf_qtimer_count_to_us(t2 - t1)); +} + +// src1 tensor is already in VTCM spad +static void matvec_id(unsigned int nth, unsigned int ith, void * data) { + htp_matmul_preamble; + + struct htp_tensor * restrict ids = &octx->src2; + struct htp_spad * restrict src2_spad = &octx->src2_spad; + + uint64_t t1, t2; + t1 = HAP_perf_get_qtimer_count(); + + const uint32_t src0_nrows = ne01; // src0 rows per expert + + const uint32_t src0_start_row = src0_nrows_per_thread * ith; + const uint32_t src0_end_row = MIN(src0_start_row + src0_nrows_per_thread, src0_nrows); + const uint32_t src0_end_row_x2 = src0_start_row + ((src0_end_row - src0_start_row) & ~1U); + + // no work for this thread + if (src0_start_row >= src0_end_row) { + return; + } + + assert(ne13 % ne03 == 0); + + const size_t dst_row_size = nb1; + const size_t src0_row_size = nb01; + const size_t src1_row_size = q8x4x2_row_size(ne10); + + const size_t src0_row_size_padded = hex_round_up(src0_row_size, 128); + + const uint32_t n_aids = src2->ne[0]; // num activated experts + const uint32_t n_ids = ne02; // num experts + + // Per-thread VTCM scratchpads for all tensors + // Note that the entire src1 tensor is already in VTCM + // For other tensors we allocate N rows per thread, padded to HVX vector size + uint8_t * restrict spad_dst = dst_spad->data + dst_spad->size_per_thread * ith; + uint8_t * restrict spad_src0 = src0_spad->data + src0_spad->size_per_thread * ith; + uint8_t * restrict src1_data = src1_spad->data; + + for (uint32_t ie1 = 0; ie1 < n_aids; ++ie1) { // for each 
expert + const uint32_t eid = *(const int32_t *) ((const uint8_t *) src2->data + ie1 * src2->nb[0]); + assert(eid < n_ids); + + const uint8_t * restrict src0_row = (const uint8_t *) src0->data + eid * nb02; + const uint8_t * restrict src1_col = (const uint8_t *) src1_data; + float * restrict dst_row = (float *) (dst->data + ie1 * nb1); + + // Prefill spad with src0 rows + #pragma unroll(4) + for (uint32_t ir0 = src0_start_row; ir0 < src0_end_row_x2; ir0 += 2) { + const int is0 = (ir0 - src0_start_row); + if (is0 >= MM_SPAD_SRC0_NROWS) { + break; + } + dma_queue_push_ddr_to_vtcm(dma_queue, dma_make_ptr(spad_src0 + is0 * src0_row_size_padded, src0_row + ir0 * src0_row_size), + src0_row_size_padded, src0_row_size, 2); + } + + // Process src0 rows + for (uint32_t ir0 = src0_start_row; ir0 < src0_end_row_x2; ir0 += 2) { + const uint8_t * ss0 = dma_queue_pop(dma_queue).dst; + mmctx->vec_dot_2x1(ne00, &dst_row[ir0], ss0, ss0 + src0_row_size_padded, src1_col); + + // Prefetch next (n + spad_nrows) row + const int pr0 = (ir0 + MM_SPAD_SRC0_NROWS); + const int is0 = (pr0 - src0_start_row) % MM_SPAD_SRC0_NROWS; + if (pr0 < src0_end_row_x2) { + dma_queue_push_ddr_to_vtcm(dma_queue, dma_make_ptr(spad_src0 + is0 * src0_row_size_padded, src0_row + pr0 * src0_row_size), + src0_row_size_padded, src0_row_size, 2); + } + } + + // Process the last row (if any) + if (src0_end_row != src0_end_row_x2) { + uint32_t ir0 = src0_end_row_x2; + const uint32_t is0 = (ir0 - src0_start_row); + dma_queue_push_ddr_to_vtcm(dma_queue, dma_make_ptr(spad_src0 + is0 * src0_row_size_padded, src0_row + ir0 * src0_row_size), + src0_row_size_padded, src0_row_size, 1); + const uint8_t * ss0 = dma_queue_pop(dma_queue).dst; + mmctx->vec_dot_1x1(ne00, &dst_row[ir0], ss0, src1_col); + } + } + + t2 = HAP_perf_get_qtimer_count(); + + FARF(HIGH, "matvec-id-%s %d/%d: %ux%ux%ux%u (%u:%u) * %ux%ux%ux%u (%ux%ux%ux%u) -> %ux%ux%ux%u usec %u\n", mmctx->type, + ith, nth, src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3], src0_start_row, src0_end_row, src1->ne[0], + src1->ne[1], src1->ne[2], src1->ne[3], src2->ne[0], src2->ne[1], src2->ne[2], src2->ne[3], dst->ne[0], + dst->ne[1], dst->ne[2], dst->ne[3], (unsigned) HAP_perf_qtimer_count_to_us(t2 - t1)); +} + +// *** dynamic quant + +static inline void quantize_block_f32_q8x1(float * restrict x, uint8_t * restrict y_q, uint8_t * restrict y_d) { + assert((unsigned long) x % 128 == 0); + assert((unsigned long) y_q % 128 == 0); + + HVX_Vector * vx = (HVX_Vector *) x; + HVX_Vector zero = Q6_V_vsplat_R(0); + + // Use reduce max fp32 to find max(abs(e)) first + HVX_Vector vmax0_sf = hvx_vec_reduce_max_f32(hvx_vec_abs_f32(vx[0])); + HVX_Vector vmax1_sf = hvx_vec_reduce_max_f32(hvx_vec_abs_f32(vx[1])); + HVX_Vector vmax2_sf = hvx_vec_reduce_max_f32(hvx_vec_abs_f32(vx[2])); + HVX_Vector vmax3_sf = hvx_vec_reduce_max_f32(hvx_vec_abs_f32(vx[3])); + // Load and convert into QF32 + HVX_Vector vx0_qf = Q6_Vqf32_vsub_VsfVsf(vx[0], zero); // 32 elements + HVX_Vector vx1_qf = Q6_Vqf32_vsub_VsfVsf(vx[1], zero); // 32 elements + HVX_Vector vx2_qf = Q6_Vqf32_vsub_VsfVsf(vx[2], zero); // 32 elements + HVX_Vector vx3_qf = Q6_Vqf32_vsub_VsfVsf(vx[3], zero); // 32 elements + + // Convert to QF32 + HVX_Vector vmax0_qf = Q6_Vqf32_vsub_VsfVsf(vmax0_sf, zero); + HVX_Vector vmax1_qf = Q6_Vqf32_vsub_VsfVsf(vmax1_sf, zero); + HVX_Vector vmax2_qf = Q6_Vqf32_vsub_VsfVsf(vmax2_sf, zero); + HVX_Vector vmax3_qf = Q6_Vqf32_vsub_VsfVsf(vmax3_sf, zero); + + // Combine and convert to fp16 + HVX_Vector vmax01_hf = 
Q6_Vh_vdeal_Vh(Q6_Vhf_equals_Wqf32(Q6_W_vcombine_VV(vmax1_qf, vmax0_qf))); + HVX_Vector vmax23_hf = Q6_Vh_vdeal_Vh(Q6_Vhf_equals_Wqf32(Q6_W_vcombine_VV(vmax3_qf, vmax2_qf))); + + // Convert into fp16 + HVX_Vector vx01_hf = Q6_Vh_vdeal_Vh(Q6_Vhf_equals_Wqf32(Q6_W_vcombine_VV(vx1_qf, vx0_qf))); + HVX_Vector vx23_hf = Q6_Vh_vdeal_Vh(Q6_Vhf_equals_Wqf32(Q6_W_vcombine_VV(vx3_qf, vx2_qf))); + + // Replicate first fp16 scale across all lanes + HVX_Vector ctrl = *(const HVX_Vector *) repl_2x_f16; + vmax01_hf = Q6_V_vdelta_VV(vmax01_hf, ctrl); + vmax23_hf = Q6_V_vdelta_VV(vmax23_hf, ctrl); + + HVX_Vector vd01_qf16 = Q6_Vqf16_vmpy_VhfVhf(vmax01_hf, Q6_Vh_vsplat_R(0x2008)); // 1.0 / 127.0 + HVX_Vector vd23_qf16 = Q6_Vqf16_vmpy_VhfVhf(vmax23_hf, Q6_Vh_vsplat_R(0x2008)); // 1.0 / 127.0 + HVX_Vector vd01_hf = Q6_Vhf_equals_Vqf16(vd01_qf16); + HVX_Vector vd23_hf = Q6_Vhf_equals_Vqf16(vd23_qf16); + + hvx_vec_store_u(y_d + 0, 2, vd01_hf); + HVX_Vector rotated_vd_hf = Q6_V_vror_VR(vd01_hf, 64); + hvx_vec_store_u(y_d + 2, 2, rotated_vd_hf); + + hvx_vec_store_u(y_d + 4, 2, vd23_hf); + rotated_vd_hf = Q6_V_vror_VR(vd23_hf, 64); + hvx_vec_store_u(y_d + 6, 2, rotated_vd_hf); + + // Divide input by the scale + HVX_Vector vd01_inv_hf = hvx_vec_inverse_f16(vd01_hf); + HVX_Vector vd23_inv_hf = hvx_vec_inverse_f16(vd23_hf); + vx01_hf = Q6_Vhf_equals_Vqf16(Q6_Vqf16_vmpy_VhfVhf(vx01_hf, vd01_inv_hf)); + vx23_hf = Q6_Vhf_equals_Vqf16(Q6_Vqf16_vmpy_VhfVhf(vx23_hf, vd23_inv_hf)); + + // Convert to int8 + HVX_Vector vx01_i16 = hvx_vec_i16_from_hf_rnd_sat(vx01_hf); + HVX_Vector vx23_i16 = hvx_vec_i16_from_hf_rnd_sat(vx23_hf); + HVX_Vector vx_i8 = Q6_Vb_vpack_VhVh_sat(vx23_i16, vx01_i16); + + *(HVX_Vector *) y_q = vx_i8; +} + +static inline void quantize_block_f32_q8x2(float * restrict x, uint8_t * restrict y_q, uint8_t * restrict y_d) { + assert((unsigned long) x % 128 == 0); + assert((unsigned long) y_q % 128 == 0); + + HVX_Vector * vx = (HVX_Vector *) x; + + // Load and convert into QF32 + HVX_Vector zero = Q6_V_vsplat_R(0); + HVX_Vector vx0_qf = Q6_Vqf32_vsub_VsfVsf(vx[0], zero); // 32 elements + HVX_Vector vx1_qf = Q6_Vqf32_vsub_VsfVsf(vx[1], zero); // 32 elements + HVX_Vector vx2_qf = Q6_Vqf32_vsub_VsfVsf(vx[2], zero); // 32 elements + HVX_Vector vx3_qf = Q6_Vqf32_vsub_VsfVsf(vx[3], zero); // 32 elements + + // Convert into fp16 + HVX_Vector vx01_hf = Q6_Vh_vdeal_Vh(Q6_Vhf_equals_Wqf32(Q6_W_vcombine_VV(vx1_qf, vx0_qf))); + HVX_Vector vx23_hf = Q6_Vh_vdeal_Vh(Q6_Vhf_equals_Wqf32(Q6_W_vcombine_VV(vx3_qf, vx2_qf))); + + // Compute max and scale + HVX_Vector vmax01_hf = hvx_vec_reduce_max_f16(hvx_vec_abs_f16(vx01_hf)); + HVX_Vector vmax23_hf = hvx_vec_reduce_max_f16(hvx_vec_abs_f16(vx23_hf)); + + // Replicate first fp16 scale across all lanes + HVX_Vector ctrl = *(const HVX_Vector *) repl_1x_f16; + vmax01_hf = Q6_V_vdelta_VV(vmax01_hf, ctrl); + vmax23_hf = Q6_V_vdelta_VV(vmax23_hf, ctrl); + + HVX_Vector vd01_qf16 = Q6_Vqf16_vmpy_VhfVhf(vmax01_hf, Q6_Vh_vsplat_R(0x2008)); // 1.0 / 127.0 + HVX_Vector vd23_qf16 = Q6_Vqf16_vmpy_VhfVhf(vmax23_hf, Q6_Vh_vsplat_R(0x2008)); // 1.0 / 127.0 + HVX_Vector vd01_hf = Q6_Vhf_equals_Vqf16(vd01_qf16); + HVX_Vector vd23_hf = Q6_Vhf_equals_Vqf16(vd23_qf16); + + hvx_vec_store_u(y_d + 0, 4, vd01_hf); + hvx_vec_store_u(y_d + 4, 4, vd23_hf); + + // Divide input by the scale + HVX_Vector vd01_inv_hf = hvx_vec_inverse_f16(vd01_hf); + HVX_Vector vd23_inv_hf = hvx_vec_inverse_f16(vd23_hf); + vx01_hf = Q6_Vhf_equals_Vqf16(Q6_Vqf16_vmpy_VhfVhf(vx01_hf, vd01_inv_hf)); + vx23_hf = 
Q6_Vhf_equals_Vqf16(Q6_Vqf16_vmpy_VhfVhf(vx23_hf, vd23_inv_hf)); + + // Convert to int8 + HVX_Vector vx01_i16 = hvx_vec_i16_from_hf_rnd_sat(vx01_hf); + HVX_Vector vx23_i16 = hvx_vec_i16_from_hf_rnd_sat(vx23_hf); + HVX_Vector vx_i8 = Q6_Vb_vpack_VhVh_sat(vx23_i16, vx01_i16); + + *(HVX_Vector *) y_q = vx_i8; +} + +static inline void quantize_block_f32_q8x4(float * restrict x, uint8_t * restrict y_q, uint8_t * restrict y_d) { + assert((unsigned long) x % 128 == 0); + assert((unsigned long) y_q % 128 == 0); + + HVX_Vector * vx = (HVX_Vector *) x; + + // Load and convert into QF32 + HVX_Vector zero = Q6_V_vsplat_R(0); + HVX_Vector vx0_qf = Q6_Vqf32_vsub_VsfVsf(vx[0], zero); // 32 elements + HVX_Vector vx1_qf = Q6_Vqf32_vsub_VsfVsf(vx[1], zero); // 32 elements + HVX_Vector vx2_qf = Q6_Vqf32_vsub_VsfVsf(vx[2], zero); // 32 elements + HVX_Vector vx3_qf = Q6_Vqf32_vsub_VsfVsf(vx[3], zero); // 32 elements + + // Convert into fp16 + HVX_Vector vx01_hf = Q6_Vh_vdeal_Vh(Q6_Vhf_equals_Wqf32(Q6_W_vcombine_VV(vx1_qf, vx0_qf))); + HVX_Vector vx23_hf = Q6_Vh_vdeal_Vh(Q6_Vhf_equals_Wqf32(Q6_W_vcombine_VV(vx3_qf, vx2_qf))); + + // Compute max and scale + HVX_Vector vmax_hf = hvx_vec_reduce_max_f16(hvx_vec_abs_f16(vx01_hf)); + vmax_hf = hvx_vec_reduce_max2_f16(hvx_vec_abs_f16(vx23_hf), vmax_hf); + + // Replicate first fp16 scale across all lanes + HVX_Vector ctrl = *(const HVX_Vector *) repl_1x_f16; + vmax_hf = Q6_V_vdelta_VV(vmax_hf, ctrl); + + HVX_Vector vd_qf16 = Q6_Vqf16_vmpy_VhfVhf(vmax_hf, Q6_Vh_vsplat_R(0x2008)); // 1.0 / 127.0 + HVX_Vector vd_hf = Q6_Vhf_equals_Vqf16(vd_qf16); + + *(HVX_UVector *) y_d = vd_hf; + + // Divide input by the scale + HVX_Vector vd_inv_hf = hvx_vec_inverse_f16(vd_hf); + vx01_hf = Q6_Vhf_equals_Vqf16(Q6_Vqf16_vmpy_VhfVhf(vx01_hf, vd_inv_hf)); + vx23_hf = Q6_Vhf_equals_Vqf16(Q6_Vqf16_vmpy_VhfVhf(vx23_hf, vd_inv_hf)); + + // Convert to int8 + HVX_Vector vx01_i16 = hvx_vec_i16_from_hf_rnd_sat(vx01_hf); + HVX_Vector vx23_i16 = hvx_vec_i16_from_hf_rnd_sat(vx23_hf); + HVX_Vector vx_i8 = Q6_Vb_vpack_VhVh_sat(vx23_i16, vx01_i16); + + *(HVX_Vector *) y_q = vx_i8; +} + +// Overrides input x +static void quantize_row_f32_q8x4x2(float * restrict x, uint8_t * restrict y, uint32_t k) { + assert(k % 32 == 0); + const uint32_t qk = QK_Q8_0x4x2; + const uint32_t nb = (k + qk - 1) / qk; + + const uint32_t qrow_size = k; // int8 + + const uint32_t dblk_size = 8 * 2; // 8x __fp16 + const uint32_t qblk_size = QK_Q8_0x4x2; // int8 + + uint8_t * restrict y_q = (y + 0); // quants first + uint8_t * restrict y_d = (y + qrow_size); // then scales + + // Temp scales override input since we're working off of the aligned temp buffer in VTCM + uint8_t * restrict t_d = (uint8_t *) x; + + for (uint32_t i = 0; i < nb; i++) { +#if FP32_QUANTIZE_GROUP_SIZE == 32 + quantize_block_f32_q8x1(x + (i*2 + 0) * qk/2, y_q + (i*2 + 0) * qblk_size/2, t_d + (i*2 + 0) * dblk_size/2); + quantize_block_f32_q8x1(x + (i*2 + 1) * qk/2, y_q + (i*2 + 1) * qblk_size/2, t_d + (i*2 + 1) * dblk_size/2); +#elif FP32_QUANTIZE_GROUP_SIZE == 64 + quantize_block_f32_q8x2(x + (i*2 + 0) * qk/2, y_q + (i*2 + 0) * qblk_size/2, t_d + (i*2 + 0) * dblk_size/2); + quantize_block_f32_q8x2(x + (i*2 + 1) * qk/2, y_q + (i*2 + 1) * qblk_size/2, t_d + (i*2 + 1) * dblk_size/2); +#elif FP32_QUANTIZE_GROUP_SIZE == 128 + quantize_block_f32_q8x4(x + (i*2 + 0) * qk/2, y_q + (i*2 + 0) * qblk_size/2, t_d + (i*2 + 0) * dblk_size/2); + quantize_block_f32_q8x4(x + (i*2 + 1) * qk/2, y_q + (i*2 + 1) * qblk_size/2, t_d + (i*2 + 1) * dblk_size/2); +#else +#error 
"FP32_QUANTIZE_GROUP_SIZE must be 32, 64, or 128" +#endif + } + + // now copy the scales into final location + hvx_copy_f16_ua(y_d, t_d, nb * 8); +} + +static void quantize_f32_q8x4x2(unsigned int nth, unsigned int ith, void * data) { + struct htp_matmul_context * mmctx = data; + struct htp_ops_context * octx = mmctx->octx; + + const struct htp_tensor * src = &octx->src1; + uint8_t * restrict dst = octx->src1_spad.data; + struct htp_spad * spad = &octx->src0_spad; + uint32_t nrows_per_thread = mmctx->src1_nrows_per_thread; + + uint64_t t1 = HAP_perf_get_qtimer_count(); + + const uint32_t ne0 = src->ne[0]; + const uint32_t ne1 = src->ne[1]; + const uint32_t ne2 = src->ne[2]; + const uint32_t ne3 = src->ne[3]; + + const uint32_t nrows = ne1 * ne2 * ne3; // total n_rows + + const uint32_t ir_first = nrows_per_thread * ith; // first row + const uint32_t ir_last = MIN(ir_first + nrows_per_thread, nrows); // last row + + const size_t src_row_size = src->nb[1]; + const size_t dst_row_size = q8x4x2_row_size(ne0); + + uint8_t * restrict src_data = (uint8_t *) src->data + (src_row_size * ir_first); + uint8_t * restrict dst_data = (uint8_t *) dst + (dst_row_size * ir_first); + uint8_t * restrict tmp_data = (uint8_t *) spad->data + (spad->size_per_thread * ith); + + const size_t src_row_size_padded = hex_round_up(src_row_size, QK_Q8_0x4x2 * sizeof(float)); + memset(tmp_data, 0, src_row_size_padded); // zero-out temp row data for padding + + for (uint32_t i = ir_first; i < ir_last; ++i) { + hex_l2fetch(src_data, src_row_size, src_row_size, 2); + hvx_copy_f32_aa(tmp_data, src_data, ne0); + + // FARF(HIGH, "quantize-q8x4-row: %u\n", i); + quantize_row_f32_q8x4x2((float *) tmp_data, dst_data, ne0); + dst_data += dst_row_size; + src_data += src_row_size; + } + + uint64_t t2 = HAP_perf_get_qtimer_count(); + + FARF(HIGH, "quantize-f32-q8x4: %u/%u : n-rows %u (%u:%u) row-size %u -> %u usec %u\n", ith, nth, nrows, ir_first, + ir_last, src_row_size, dst_row_size, (unsigned) HAP_perf_qtimer_count_to_us(t2 - t1)); +} + +static void quantize_f32_f16(unsigned int nth, unsigned int ith, void * data) { + struct htp_matmul_context * mmctx = data; + struct htp_ops_context * octx = mmctx->octx; + + const struct htp_tensor * src = &octx->src1; + uint8_t * restrict dst = octx->src1_spad.data; + uint32_t nrows_per_thread = mmctx->src1_nrows_per_thread; + uint32_t dst_stride = octx->src1_spad.stride; + + uint64_t t1 = HAP_perf_get_qtimer_count(); + + const uint32_t ne0 = src->ne[0]; + const uint32_t ne1 = src->ne[1]; + const uint32_t ne2 = src->ne[2]; + const uint32_t ne3 = src->ne[3]; + + const uint32_t nrows = ne1 * ne2 * ne3; // total n_rows + + const uint32_t ir_first = nrows_per_thread * ith; // first row + const uint32_t ir_last = MIN(ir_first + nrows_per_thread, nrows); // last row + + const size_t src_row_size = ne0 * sizeof(float); + const size_t src_stride = src->nb[1]; + + uint8_t * restrict src_data = (uint8_t *) src->data + (src_stride * ir_first); + uint8_t * restrict dst_data = (uint8_t *) dst + (dst_stride * ir_first); + + for (uint32_t i = ir_first; i < ir_last; ++i) { + hex_l2fetch(src_data, src_row_size, src_stride, 2); + hvx_copy_f16_f32_au(dst_data, src_data, ne0); + + dst_data += dst_stride; + src_data += src_stride; + } + + uint64_t t2 = HAP_perf_get_qtimer_count(); + + FARF(HIGH, "quantize-f32-f16: %u/%u : n-rows %u (%u:%u) row-size %u (%u) -> %u usec %u\n", ith, nth, nrows, ir_first, + ir_last, src_row_size, src_stride, dst_stride, (unsigned) HAP_perf_qtimer_count_to_us(t2 - t1)); +} + +// TODO just 
+static void quantize_f16_f16(unsigned int nth, unsigned int ith, void * data) {
+    struct htp_matmul_context * mmctx = data;
+    struct htp_ops_context * octx = mmctx->octx;
+
+    const struct htp_tensor * src = &octx->src1;
+    uint8_t * restrict dst = octx->src1_spad.data;
+    uint32_t nrows_per_thread = mmctx->src1_nrows_per_thread;
+    uint32_t dst_stride = octx->src1_spad.stride;
+
+    uint64_t t1 = HAP_perf_get_qtimer_count();
+
+    const uint32_t ne0 = src->ne[0];
+    const uint32_t ne1 = src->ne[1];
+    const uint32_t ne2 = src->ne[2];
+    const uint32_t ne3 = src->ne[3];
+
+    const uint32_t nrows = ne1 * ne2 * ne3; // total n_rows
+
+    const uint32_t ir_first = nrows_per_thread * ith; // first row
+    const uint32_t ir_last = MIN(ir_first + nrows_per_thread, nrows); // last row
+
+    const size_t src_row_size = ne0 * sizeof(__fp16); // src rows are f16 here
+    const size_t src_stride = src->nb[1];
+
+    uint8_t * restrict src_data = (uint8_t *) src->data + (src_stride * ir_first);
+    uint8_t * restrict dst_data = (uint8_t *) dst + (dst_stride * ir_first);
+
+    for (uint32_t i = ir_first; i < ir_last; ++i) {
+        hex_l2fetch(src_data, src_row_size, src_stride, 2);
+        hvx_copy_f16_au(dst_data, src_data, ne0);
+
+        dst_data += dst_stride;
+        src_data += src_stride;
+    }
+
+    uint64_t t2 = HAP_perf_get_qtimer_count();
+
+    FARF(HIGH, "quantize-f16-f16: %u/%u : n-rows %u (%u:%u) row-size %u (%u) -> %u usec %u\n", ith, nth, nrows, ir_first,
+         ir_last, src_row_size, src_stride, dst_stride, (unsigned) HAP_perf_qtimer_count_to_us(t2 - t1));
+}
+
+
+static inline bool htp_is_permuted(const struct htp_tensor * t) {
+    return t->nb[0] > t->nb[1] || t->nb[1] > t->nb[2] || t->nb[2] > t->nb[3];
+}
+
+static int htp_mminit_vec_dot(struct htp_matmul_context * mmctx, enum htp_data_type type) {
+    switch (type) {
+        case HTP_TYPE_Q4_0:
+            mmctx->type = "q4x4x2-f32";
+            mmctx->vec_dot_1x1 = vec_dot_q4x4x2_q8x4x2_1x1;
+            mmctx->vec_dot_2x1 = vec_dot_q4x4x2_q8x4x2_2x1;
+            mmctx->vec_dot_2x2 = vec_dot_q4x4x2_q8x4x2_2x2;
+            return 0;
+        case HTP_TYPE_Q8_0:
+            mmctx->type = "q8x4x2-f32";
+            mmctx->vec_dot_1x1 = vec_dot_q8x4x2_q8x4x2_1x1;
+            mmctx->vec_dot_2x1 = vec_dot_q8x4x2_q8x4x2_2x1;
+            mmctx->vec_dot_2x2 = vec_dot_q8x4x2_q8x4x2_2x2;
+            return 0;
+        case HTP_TYPE_MXFP4:
+            mmctx->type = "mxfp4x4x2-f32";
+            mmctx->vec_dot_1x1 = vec_dot_mxfp4x4x2_q8x4x2_1x1;
+            mmctx->vec_dot_2x1 = vec_dot_mxfp4x4x2_q8x4x2_2x1;
+            mmctx->vec_dot_2x2 = vec_dot_mxfp4x4x2_q8x4x2_2x2;
+            return 0;
+        default:
+            return -1;
+    }
+}
+
+static void htp_mminit_spad(struct htp_ops_context * octx,
+                            size_t dst_row_size,
+                            size_t src0_row_size_padded,
+                            size_t src1_row_size,
+                            uint32_t src1_nrows,
+                            size_t src2_spad_size_per_thread) {
+    octx->dst_spad.size_per_thread = hex_round_up(MM_SPAD_DST_NROWS * dst_row_size, 256);
+    octx->src0_spad.size_per_thread = hex_round_up(MM_SPAD_SRC0_NROWS * src0_row_size_padded, 256);
+    octx->src1_spad.size_per_thread = hex_round_up(src1_row_size * src1_nrows, 256);
+
+    if (src2_spad_size_per_thread > 0) {
+        octx->src2_spad.size_per_thread = src2_spad_size_per_thread;
+        octx->src2_spad.size = octx->src2_spad.size_per_thread;
+    }
+
+    // src0 spad is also used in dynamic quantizer to store padded src1 rows
+    size_t src1_row_size_padded = hex_round_up(src1_row_size, QK_Q8_0x4x2 * sizeof(float));
+    if (octx->src0_spad.size_per_thread < src1_row_size_padded) {
+        octx->src0_spad.size_per_thread = src1_row_size_padded;
+    }
+
+    octx->src1_spad.size = octx->src1_spad.size_per_thread;
+    octx->src0_spad.size = octx->src0_spad.size_per_thread * octx->n_threads;
+    octx->dst_spad.size = octx->dst_spad.size_per_thread * octx->n_threads;
+}
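+
+// Note: the spad sizes computed above are only requests; op_matmul() and
+// op_matmul_id() below still validate the combined footprint against the
+// VTCM reservation (octx->ctx->vtcm_size) before carving out the scratchpads.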
+
+int op_matmul(struct htp_ops_context * octx) {
+    htp_matmul_tensors_preamble;
+
+    struct htp_matmul_context mmctx_struct = {0};
+    struct htp_matmul_context * mmctx = &mmctx_struct;
+    mmctx->octx = octx;
+
+    const uint32_t src0_nrows = ne01 * ne02 * ne03;
+    const uint32_t src1_nrows = ne11 * ne12 * ne13;
+
+    // Compute src0_nrows_per_thread, rounded up to even so each thread's
+    // row range works with the 2-row vec_dot kernels
+    mmctx->src0_nrows_per_thread = (src0_nrows + octx->n_threads - 1) / octx->n_threads;
+    mmctx->src0_nrows_per_thread += (mmctx->src0_nrows_per_thread & 1);
+
+    const size_t src0_row_size = nb01;
+    const size_t dst_row_size = nb1;
+    size_t src1_row_size = nb11;
+
+    const size_t src0_row_size_padded = hex_round_up(src0_row_size, 128);
+    size_t src1_row_size_padded;
+
+    worker_callback_t quant_job_func;
+    worker_callback_t matmul_job_func = src1_nrows > 1 ? matmul_2d : matvec_2d;
+
+    bool need_quant = !(octx->flags & HTP_OPFLAGS_SKIP_QUANTIZE);
+
+    if (src0->type == HTP_TYPE_F16) {
+        // Try optimized f16-f16 path first (src1 in VTCM)
+        const size_t f16_src1_row_size = hex_round_up(ne10 * 2, 128);
+        const size_t f16_src1_spad_size = hex_round_up(f16_src1_row_size * src1_nrows, 256);
+        const size_t f16_src0_spad_size = hex_round_up(MM_SPAD_SRC0_NROWS * src0_row_size_padded, 256) * octx->n_threads;
+        const size_t f16_dst_spad_size = hex_round_up(MM_SPAD_DST_NROWS * dst_row_size, 256) * octx->n_threads;
+
+        const size_t f16_total_size = f16_src1_spad_size + f16_src0_spad_size + f16_dst_spad_size;
+
+        // Default matmul implementation does not support multi-batch src0 (N-vs-N broadcasting).
+        // It only supports 1-vs-N broadcasting (src0 is 2D) or standard 2D matmul.
+        const bool is_batched = (ne02 > 1) || (ne03 > 1);
+        const bool is_permuted = htp_is_permuted(&octx->src0) || htp_is_permuted(&octx->src1);
+
+        if (!is_batched && !is_permuted && f16_total_size <= octx->ctx->vtcm_size) {
+            // Optimized path
+            quant_job_func = (src1->type == HTP_TYPE_F32) ? 
quantize_f32_f16 : quantize_f16_f16; + mmctx->type = "f16-f16"; + mmctx->vec_dot_1x1 = vec_dot_f16_f16_aa_1x1; + mmctx->vec_dot_2x1 = vec_dot_f16_f16_aa_2x1; + mmctx->vec_dot_2x2 = vec_dot_f16_f16_aa_2x2; + + src1_row_size = f16_src1_row_size; // row size post quantization + + octx->dst_spad.size_per_thread = hex_round_up(MM_SPAD_DST_NROWS * dst_row_size, 256); + octx->src0_spad.size_per_thread = hex_round_up(MM_SPAD_SRC0_NROWS * src0_row_size_padded, 256); + octx->src1_spad.size_per_thread = hex_round_up(src1_row_size * src1_nrows, 256); + + octx->src1_spad.size = octx->src1_spad.size_per_thread; + octx->src0_spad.size = octx->src0_spad.size_per_thread * octx->n_threads; + octx->dst_spad.size = octx->dst_spad.size_per_thread * octx->n_threads; + } else { + // Fallback to f16/f32 (DDR) if src1 doesn't fit in VTCM or broadcasting is required + quant_job_func = NULL; + if (src1->type == HTP_TYPE_F32) { + mmctx->type = "f16-f32"; + mmctx->vec_dot_1x1 = vec_dot_f16_f32_uu_1x1; + matmul_job_func = matmul_4d; + } else { + mmctx->type = "f16-f16"; + mmctx->vec_dot_1x1 = vec_dot_f16_f16_uu_1x1; + matmul_job_func = matmul_4d; + } + + src1_row_size = nb11; // original row size in DDR + + octx->dst_spad.size_per_thread = hex_round_up(MM_SPAD_DST_NROWS * dst_row_size, 256); + octx->src0_spad.size_per_thread = hex_round_up(MM_SPAD_SRC0_NROWS * src0_row_size, 256); + octx->src1_spad.size_per_thread = hex_round_up(MM_SPAD_SRC1_NROWS * src1_row_size, 256); + + octx->src0_spad.size = octx->src0_spad.size_per_thread * octx->n_threads; + octx->src1_spad.size = octx->src1_spad.size_per_thread * octx->n_threads; + octx->dst_spad.size = octx->dst_spad.size_per_thread * octx->n_threads; + + // Init fastdiv for matmul_4d (supports broadcasting) + mmctx->mm_div_ne12_ne1 = init_fastdiv_values(src1->ne[2] * dst->ne[1]); + mmctx->mm_div_ne1 = init_fastdiv_values(dst->ne[1]); + mmctx->mm_div_r2 = init_fastdiv_values(src1->ne[2] / src0->ne[2]); + mmctx->mm_div_r3 = init_fastdiv_values(src1->ne[3] / src0->ne[3]); + + need_quant = false; + } + } else { + if (htp_mminit_vec_dot(mmctx, src0->type) != 0) { + return HTP_STATUS_NO_SUPPORT; + } + + quant_job_func = quantize_f32_q8x4x2; + src1_row_size = q8x4x2_row_size(ne10); + htp_mminit_spad(octx, dst_row_size, src0_row_size_padded, src1_row_size, src1_nrows, 0); + } + + // VTCM scratchpads for all tensors + size_t spad_size = octx->src1_spad.size + octx->src0_spad.size + octx->dst_spad.size; + + FARF(HIGH, "matmul-%s : src0-spad-size %u src1-spad-size %u dst-spad-size %u (%zu)\n", mmctx->type, + octx->src0_spad.size, octx->src1_spad.size, octx->dst_spad.size, spad_size); + + FARF(HIGH, "matmul-%s : %ux%ux%ux%u * %ux%ux%ux%u-> %ux%ux%ux%u (0x%p, 0x%p, 0x%p)\n", mmctx->type, src0->ne[0], + src0->ne[1], src0->ne[2], src0->ne[3], src1->ne[0], src1->ne[1], src1->ne[2], src1->ne[3], dst->ne[0], + dst->ne[1], dst->ne[2], dst->ne[3], src0->data, src1->data, dst->data); + + // Make sure the reserved vtcm size is sufficient + if (octx->ctx->vtcm_size < spad_size) { + FARF(ERROR, "matmul-%s : current VTCM reservation %zu is too small, needed %zu\n", mmctx->type, + octx->ctx->vtcm_size, spad_size); + return HTP_STATUS_VTCM_TOO_SMALL; + } + + octx->src0_spad.data = octx->ctx->vtcm_base; + octx->src1_spad.data = octx->src0_spad.data + octx->src0_spad.size; + octx->dst_spad.data = octx->src1_spad.data + octx->src1_spad.size; + + octx->src0_spad.stride = src0_row_size_padded; + octx->src1_spad.stride = src1_row_size; + + if (need_quant) { + const uint32_t n_quant_jobs = MIN(src1_nrows, 
octx->n_threads); + mmctx->src1_nrows_per_thread = (src1_nrows + n_quant_jobs - 1) / n_quant_jobs; + worker_pool_run_func(octx->ctx->worker_pool, quant_job_func, mmctx, n_quant_jobs); + } + + if (!(octx->flags & HTP_OPFLAGS_SKIP_COMPUTE)) { + const uint32_t n_matmul_jobs = octx->n_threads; + worker_pool_run_func(octx->ctx->worker_pool, matmul_job_func, mmctx, n_matmul_jobs); + } + + return HTP_STATUS_OK; +} + +int op_matmul_id(struct htp_ops_context * octx) { + htp_matmul_tensors_preamble; + + struct htp_matmul_context mmctx_struct = {0}; + struct htp_matmul_context * mmctx = &mmctx_struct; + mmctx->octx = octx; + + struct htp_tensor * restrict ids = &octx->src2; + + const size_t src0_row_size = nb01; + const size_t dst_row_size = nb1; + + const size_t src0_row_size_padded = hex_round_up(src0_row_size, 128); + + const uint32_t src0_nrows = ne01; // per expert + const uint32_t src1_nrows = ne11 * ne12 * ne13; + + worker_callback_t quant_job_func; + worker_callback_t matmul_id_job_func = src1_nrows > 1 ? matmul_id : matvec_id; + + // Compute src0_nrows_per_thread + mmctx->src0_nrows_per_thread = (src0_nrows + octx->n_threads - 1) / octx->n_threads; + mmctx->src0_nrows_per_thread += (mmctx->src0_nrows_per_thread & 1); // round up to even + + size_t src1_row_size; + size_t src1_row_size_padded; + + // row groups + const int n_ids = ids->ne[0]; // n_expert_used + const int n_as = ne02; // n_expert + + size_t matrix_row_counts_size = n_as * sizeof(uint32_t); + size_t matrix_row_map_size = n_as * ids->ne[0] * ids->ne[1] * sizeof(struct mmid_row_mapping); + + if (htp_mminit_vec_dot(mmctx, src0->type) != 0) { + return HTP_STATUS_NO_SUPPORT; + } + + quant_job_func = quantize_f32_q8x4x2; + src1_row_size = q8x4x2_row_size(ne10); + + const size_t src2_spad_size_per_thread = hex_round_up(matrix_row_counts_size + matrix_row_map_size, 256); + htp_mminit_spad(octx, dst_row_size, src0_row_size_padded, src1_row_size, src1_nrows, src2_spad_size_per_thread); + + size_t spad_size = octx->src2_spad.size + octx->src1_spad.size + octx->src0_spad.size + octx->dst_spad.size; + + FARF(HIGH, "matmul-id-%s : src0-spad-size %u src1-spad-size %u src2-spad-size %u dst-spad-size %u (%zu)\n", mmctx->type, + octx->src0_spad.size, octx->src1_spad.size, octx->src2_spad.size, octx->dst_spad.size, spad_size); + + FARF(HIGH, "matmul-id-%s : %ux%ux%ux%u * %ux%ux%ux%u (%ux%ux%ux%u) -> %ux%ux%ux%u (0x%p, 0x%p, 0x%p)\n", mmctx->type, + src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3], src1->ne[0], src1->ne[1], src1->ne[2], src1->ne[3], + ids->ne[0], ids->ne[1], ids->ne[2], ids->ne[3], dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3], src0->data, + src1->data, dst->data); + + // Make sure the reserved vtcm size is sufficient + if (octx->ctx->vtcm_size < spad_size) { + FARF(ERROR, "matmul-id-%s : current VTCM reservation %zu is too small, needed %zu\n", mmctx->type, octx->ctx->vtcm_size, spad_size); + return HTP_STATUS_VTCM_TOO_SMALL; + } + + octx->src0_spad.data = octx->ctx->vtcm_base; + octx->src1_spad.data = octx->src0_spad.data + octx->src0_spad.size; + octx->src2_spad.data = octx->src1_spad.data + octx->src1_spad.size; + octx->dst_spad.data = octx->src2_spad.data + octx->src2_spad.size; + + octx->src0_spad.stride = src0_row_size_padded; + octx->src1_spad.stride = src1_row_size; + + if (src1_nrows > 1) { + // initialize matrix_row_counts and map + uint32_t * matrix_row_counts = (uint32_t *) octx->src2_spad.data + 0; + struct mmid_row_mapping * matrix_rows = (void *) octx->src2_spad.data + matrix_row_counts_size; + + 
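// matrix_row_counts[e] counts the (expert, token) pairs routed to expert e,
+        // and matrix_rows (addressed via MMID_MATRIX_ROW) records those pairs so
+        // the matmul jobs can walk each expert's rows contiguously.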
+        memset(matrix_row_counts, 0, n_as * sizeof(uint32_t));
+
+        // group rows by src0 matrix
+        for (uint32_t iid1 = 0; iid1 < ids->ne[1]; ++iid1) { // token idx
+            for (uint32_t id = 0; id < n_ids; ++id) { // expert idx
+                const uint32_t i02 = *(const uint32_t *) ((const uint8_t *) ids->data + iid1 * ids->nb[1] + id * ids->nb[0]);
+
+                assert(i02 < (uint32_t) n_as);
+
+                MMID_MATRIX_ROW(i02, matrix_row_counts[i02]) = (struct mmid_row_mapping) { id, iid1 };
+                matrix_row_counts[i02] += 1;
+            }
+        }
+    }
+
+    // Setup worker pool callbacks
+    if (!(octx->flags & HTP_OPFLAGS_SKIP_QUANTIZE)) {
+        const uint32_t n_quant_jobs = MIN(src1_nrows, octx->n_threads);
+        mmctx->src1_nrows_per_thread = (src1_nrows + n_quant_jobs - 1) / n_quant_jobs;
+        worker_pool_run_func(octx->ctx->worker_pool, quant_job_func, mmctx, n_quant_jobs);
+    }
+
+    if (!(octx->flags & HTP_OPFLAGS_SKIP_COMPUTE)) {
+        const uint32_t n_matmul_jobs = octx->n_threads;
+        worker_pool_run_func(octx->ctx->worker_pool, matmul_id_job_func, mmctx, n_matmul_jobs);
+    }
+
+    return HTP_STATUS_OK;
+}
diff --git a/llama.cpp/ggml/src/ggml-hexagon/htp/rope-ops.c b/llama.cpp/ggml/src/ggml-hexagon/htp/rope-ops.c
new file mode 100644
index 0000000..943ca5c
--- /dev/null
+++ b/llama.cpp/ggml/src/ggml-hexagon/htp/rope-ops.c
@@ -0,0 +1,480 @@
+#pragma clang diagnostic ignored "-Wunused-variable"
+#pragma clang diagnostic ignored "-Wunused-function"
+#pragma clang diagnostic ignored "-Wunused-but-set-variable"
+
+#include <HAP_farf.h>
+#include <HAP_perf.h>
+
+#include <math.h>
+#include <string.h>
+
+#include "hex-dma.h"
+#include "hvx-utils.h"
+
+#define GGML_COMMON_DECL_C
+#include "ggml-common.h"
+#include "htp-ctx.h"
+#include "htp-msg.h"
+#include "htp-ops.h"
+
+// Redefine GGML_ROPE_TYPE_NORMAL & GGML_ROPE_TYPE_NEOX here, as we can't include ggml.h
+#define HTP_ROPE_TYPE_NORMAL 0
+#define HTP_ROPE_TYPE_NEOX 2
+
+#define htp_rope_preamble \
+    const uint32_t ne00 = src0->ne[0]; \
+    const uint32_t ne01 = src0->ne[1]; \
+    const uint32_t ne02 = src0->ne[2]; \
+    const uint32_t ne03 = src0->ne[3]; \
+    \
+    const uint32_t ne0 = dst->ne[0]; \
+    const uint32_t ne1 = dst->ne[1]; \
+    const uint32_t ne2 = dst->ne[2]; \
+    const uint32_t ne3 = dst->ne[3]; \
+    \
+    const uint32_t nb00 = src0->nb[0]; \
+    const uint32_t nb01 = src0->nb[1]; \
+    const uint32_t nb02 = src0->nb[2]; \
+    const uint32_t nb03 = src0->nb[3]; \
+    \
+    const uint32_t nb0 = dst->nb[0]; \
+    const uint32_t nb1 = dst->nb[1]; \
+    const uint32_t nb2 = dst->nb[2]; \
+    const uint32_t nb3 = dst->nb[3];
+
+struct rope_th_ctx {
+    int32_t n_dims;
+    int32_t mode;
+    int32_t n_ctx_orig;
+    int32_t sections[4];
+
+    float freq_base;
+    float freq_scale;
+    float ext_factor;
+    float attn_factor;
+    float beta_fast;
+    float beta_slow;
+    float theta_scale;
+    float corr_dims[2];
+
+    struct htp_ops_context * octx;
+};
+
+static float rope_yarn_ramp(const float low, const float high, const int i0) {
+    const float y = (i0 / 2 - low) / MAX(0.001f, high - low);
+
+    return (1 - MIN(1, MAX(0, y)));
+}
+
+static void rope_cache_init(const float theta_base,
+                            const float freq_scale,
+                            const float * freq_factors,
+                            float * corr_dims,
+                            const uint32_t ne0,
+                            const float ext_factor,
+                            const float mscale,
+                            float * cache,
+                            const float theta_scale) {
+    // ref: https://github.com/jquesnelle/yarn/blob/master/scaled_rope/LlamaYaRNScaledRotaryEmbedding.py
+    float theta = theta_base;
+
+    for (uint32_t i0 = 0; i0 < ne0; i0 += 2) {
+        const float ff = freq_factors ? freq_factors[i0 / 2] : 1.0f;
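+
+        // theta advances geometrically: for pair i0 it equals
+        // theta_base * theta_scale^(i0/2), divided by the per-frequency
+        // factor when the model provides freq_factors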
+        float theta_extrap = theta / ff;
+
+        // Get n-d rotational scaling corrected for extrapolation
+        float theta_interp = freq_scale * theta_extrap;
+        float theta_final = theta_interp;
+        float mscale_final = mscale;
+
+        if (ext_factor != 0.0f) {
+            float ramp_mix = rope_yarn_ramp(corr_dims[0], corr_dims[1], i0) * ext_factor;
+            theta_final = theta_interp * (1 - ramp_mix) + theta_extrap * ramp_mix;
+
+            // Get n-d magnitude scaling corrected for interpolation
+            mscale_final *= 1.0f + 0.1f * logf(1.0f / freq_scale);
+        }
+
+        cache[i0 + 0] = cosf(theta_final) * mscale_final;
+        cache[i0 + 1] = sinf(theta_final) * mscale_final;
+
+        theta *= theta_scale;
+    }
+}
+
+#ifndef M_PI
+#define M_PI 3.1415926535897932384626433
+#endif
+
+static void rope_corr_dims(int n_dims,
+                           int n_ctx_orig,
+                           float freq_base,
+                           float beta_fast,
+                           float beta_slow,
+                           float * dims) {
+    float start = floorf(n_dims * logf(n_ctx_orig / (beta_fast * 2 * (float) M_PI)) / (2 * logf(freq_base)));
+    float end = ceilf(n_dims * logf(n_ctx_orig / (beta_slow * 2 * (float) M_PI)) / (2 * logf(freq_base)));
+    dims[0] = MAX(0, start);
+    dims[1] = MIN(n_dims - 1, end);
+}
+
+static void init_rope_ctx(struct rope_th_ctx * rope_ctx, struct htp_ops_context * octx) {
+    memset(rope_ctx, 0, sizeof(struct rope_th_ctx));
+
+    const int32_t * op_params = &octx->op_params[0];
+
+    rope_ctx->n_dims = ((const int32_t *) op_params)[1];
+    rope_ctx->mode = ((const int32_t *) op_params)[2];
+    rope_ctx->n_ctx_orig = ((const int32_t *) op_params)[4];
+
+    memcpy(&rope_ctx->freq_base, (int32_t *) op_params + 5, sizeof(float));
+    memcpy(&rope_ctx->freq_scale, (int32_t *) op_params + 6, sizeof(float));
+    memcpy(&rope_ctx->ext_factor, (int32_t *) op_params + 7, sizeof(float));
+    memcpy(&rope_ctx->attn_factor, (int32_t *) op_params + 8, sizeof(float));
+    memcpy(&rope_ctx->beta_fast, (int32_t *) op_params + 9, sizeof(float));
+    memcpy(&rope_ctx->beta_slow, (int32_t *) op_params + 10, sizeof(float));
+    memcpy(&rope_ctx->sections, (int32_t *) op_params + 11, sizeof(int) * 4);
+
+    rope_ctx->theta_scale = powf(rope_ctx->freq_base, -2.0f / rope_ctx->n_dims);
+
+    rope_corr_dims(rope_ctx->n_dims, rope_ctx->n_ctx_orig, rope_ctx->freq_base, rope_ctx->beta_fast,
+                   rope_ctx->beta_slow, rope_ctx->corr_dims);
+
+    rope_ctx->octx = octx;
+    FARF(HIGH, "rope-f32 n_dims:%d, ext_factor:%.6f, theta_scale:%.6f, attn_factor:%.6f\n", rope_ctx->n_dims,
+         rope_ctx->ext_factor, rope_ctx->theta_scale, rope_ctx->attn_factor);
+}
+
+static void hvx_calc_rope_neox_f32(const float * restrict src0,
+                                   float * restrict dst,
+                                   const int num_elems,
+                                   const float * restrict theta_cache) {
+    // for (int i = 0; i < num_elems; i += 2) {
+    //     const float cos_theta = theta_cache[i + 0];
+    //     const float sin_theta = theta_cache[i + 1];
+
+    //     const float x0 = src[0];
+    //     const float x1 = src[num_elems/2];
+
+    //     dst[0]           = x0*cos_theta - x1*sin_theta;
+    //     dst[num_elems/2] = x0*sin_theta + x1*cos_theta;
+
+    //     src += 1;
+    //     dst += 1;
+    // }
+
+    const uint8_t * restrict src0_curr = (const uint8_t *) src0;
+    const uint8_t * restrict theta_curr = (const uint8_t *) theta_cache;
+    uint8_t * restrict dst_curr = (uint8_t *) dst;
+
+    int step_of_1 = num_elems >> 6; // 64 floats per iteration: one 32-float vector from each half of the row
+    int half_size = (sizeof(float) * (num_elems / 2));
+
+    for (int i = 0; i < step_of_1; i++) {
+        HVX_Vector v0 = *(HVX_Vector *) src0_curr;
+        HVX_Vector v1 = *(HVX_Vector *) (src0_curr + half_size);
+
+        HVX_Vector v2 = *(HVX_Vector *) theta_curr;
+        HVX_Vector v3 = *(HVX_Vector *) (theta_curr + VLEN);
+
+        
HVX_VectorPair vcos_sin = Q6_W_vdeal_VVR(v3, v2, -4); // vcos_sin[0] = cos_theta, vcos_sin[1] = sin_theta + + HVX_Vector vx0_c = Q6_Vqf32_vmpy_VsfVsf(v0, Q6_V_lo_W(vcos_sin)); + HVX_Vector vx0_s = Q6_Vqf32_vmpy_VsfVsf(v0, Q6_V_hi_W(vcos_sin)); + HVX_Vector vx1_c = Q6_Vqf32_vmpy_VsfVsf(v1, Q6_V_lo_W(vcos_sin)); + HVX_Vector vx1_s = Q6_Vqf32_vmpy_VsfVsf(v1, Q6_V_hi_W(vcos_sin)); + + HVX_Vector v4 = Q6_Vqf32_vsub_Vqf32Vqf32(vx0_c, vx1_s); + HVX_Vector v5 = Q6_Vqf32_vadd_Vqf32Vqf32(vx0_s, vx1_c); + + *(HVX_Vector *) dst_curr = Q6_Vsf_equals_Vqf32(v4); + *(HVX_Vector *) (dst_curr + half_size) = Q6_Vsf_equals_Vqf32(v5); + + src0_curr += VLEN; + theta_curr += 2 * VLEN; + dst_curr += VLEN; + } +} + +static void hvx_calc_rope_f32(const float * restrict src0, + float * restrict dst, + const int num_elems, + const float * restrict theta_cache) { + // for (int i = 0; i < num_elems; i += 2) { + //const float cos_theta = theta_cache[i + 0]; + //const float sin_theta = theta_cache[i + 1]; + + //const float x0 = src[0]; + //const float x1 = src[1]; + + //dst[0] = x0*cos_theta - x1*sin_theta; + //dst[1] = x0*sin_theta + x1*cos_theta; + + //src += 2; + //dst += 2; + // } + + const uint8_t * restrict src0_curr = (const uint8_t *) src0; + const uint8_t * restrict theta_curr = (const uint8_t *) theta_cache; + uint8_t * restrict dst_curr = (uint8_t *) dst; + + int step_of_1 = num_elems >> 6; // 6 because we process two vectors at once + + for (int i = 0; i < step_of_1; i++) { + HVX_Vector v0 = *(HVX_Vector *) src0_curr; + HVX_Vector v1 = *(HVX_Vector *) (src0_curr + VLEN); + + HVX_Vector v2 = *(HVX_Vector *) theta_curr; + HVX_Vector v3 = *(HVX_Vector *) (theta_curr + VLEN); + + HVX_VectorPair vx0_x1 = Q6_W_vdeal_VVR(v1, v0, -4); // vx0_x1[0] = x0, vx0_x1[1] = x1 + HVX_VectorPair vcos_sin = Q6_W_vdeal_VVR(v3, v2, -4); // vcos_sin[0] = cos_theta, vcos_sin[1] = sin_theta + + HVX_Vector vx0_c = Q6_Vqf32_vmpy_VsfVsf(Q6_V_lo_W(vx0_x1), Q6_V_lo_W(vcos_sin)); + HVX_Vector vx0_s = Q6_Vqf32_vmpy_VsfVsf(Q6_V_lo_W(vx0_x1), Q6_V_hi_W(vcos_sin)); + HVX_Vector vx1_c = Q6_Vqf32_vmpy_VsfVsf(Q6_V_hi_W(vx0_x1), Q6_V_lo_W(vcos_sin)); + HVX_Vector vx1_s = Q6_Vqf32_vmpy_VsfVsf(Q6_V_hi_W(vx0_x1), Q6_V_hi_W(vcos_sin)); + + HVX_Vector v4 = Q6_Vqf32_vsub_Vqf32Vqf32(vx0_c, vx1_s); + HVX_Vector v5 = Q6_Vqf32_vadd_Vqf32Vqf32(vx0_s, vx1_c); + + HVX_VectorPair vstore = Q6_W_vshuff_VVR(Q6_Vsf_equals_Vqf32(v5), Q6_Vsf_equals_Vqf32(v4), -4); + + *(HVX_Vector *) dst_curr = Q6_V_lo_W(vstore); + *(HVX_Vector *) (dst_curr + VLEN) = Q6_V_hi_W(vstore); + + src0_curr += 2 * VLEN; + theta_curr += 2 * VLEN; + dst_curr += 2 * VLEN; + } +} + +static void rope_hex_f32(struct rope_th_ctx * rope_ctx, + const uint32_t ir0, + const uint32_t ir1, + int nth, + int ith, + const int opt_path) { + struct htp_ops_context * octx = rope_ctx->octx; + + const struct htp_tensor * src0 = &octx->src0; + const struct htp_tensor * src1 = &octx->src1; + const struct htp_tensor * src2 = &octx->src2; + struct htp_tensor * dst = &octx->dst; + + const int32_t mode = rope_ctx->mode; + const bool is_neox = mode & HTP_ROPE_TYPE_NEOX; + + htp_rope_preamble; + + const int32_t * pos = (const int32_t *) src1->data; + + float * wp0 = (float *) (octx->src0_spad.data + (ith * nb01)); + + const float * freq_factors = NULL; + if (src2 != NULL) { + freq_factors = (const float *) src2->data; + } + + const uint32_t i1_end = MIN(ir1, ne1); + const int32_t half_dims = rope_ctx->n_dims / 2; + const size_t remain_bytes = (ne0 - rope_ctx->n_dims) * sizeof(float); + for (uint32_t i3 = 0; i3 < ne3; 
i3++) { // batch + for (uint32_t i2 = 0; i2 < ne2; i2++) { // seq-len + const int32_t p = pos[i2]; + + rope_cache_init(p, rope_ctx->freq_scale, freq_factors, rope_ctx->corr_dims, ne0, rope_ctx->ext_factor, + rope_ctx->attn_factor, wp0, rope_ctx->theta_scale); + + for (uint32_t i1 = ir0; i1 < i1_end; i1++) { // attn-heads + const float * src = (float *) ((char *) src0->data + i3 * nb03 + i2 * nb02 + i1 * nb01); + float * dst_data = (float *) ((char *) dst->data + i3 * nb3 + i2 * nb2 + i1 * nb1); + + const float * src_loc = src; + float * dst_data_loc = dst_data; + + if (1 == opt_path) { + if (is_neox) { + hvx_calc_rope_neox_f32(src_loc, dst_data_loc, rope_ctx->n_dims, wp0); + } else { + hvx_calc_rope_f32(src_loc, dst_data_loc, rope_ctx->n_dims, wp0); + } + + src_loc += rope_ctx->n_dims; + dst_data_loc += rope_ctx->n_dims; + } else { + for (uint32_t i0 = 0; i0 < rope_ctx->n_dims; i0 += 2) { + const float cos_theta = wp0[i0 + 0]; + const float sin_theta = wp0[i0 + 1]; + + if (is_neox) { + const float x0 = src_loc[0]; + const float x1 = src_loc[half_dims]; + + dst_data_loc[0] = x0 * cos_theta - x1 * sin_theta; + dst_data_loc[half_dims] = x0 * sin_theta + x1 * cos_theta; + + src_loc += 1; + dst_data_loc += 1; + } else { + const float x0 = src_loc[0]; + const float x1 = src_loc[1]; + + dst_data_loc[0] = x0 * cos_theta - x1 * sin_theta; + dst_data_loc[1] = x0 * sin_theta + x1 * cos_theta; + + src_loc += 2; + dst_data_loc += 2; + } + } + + src_loc += (is_neox ? half_dims : 0); + dst_data_loc += (is_neox ? half_dims : 0); + } + + // TODO: use simd to speed up the remaining elements copy + memcpy(dst_data_loc, src_loc, remain_bytes); + } + } + } +} + +static void rope_job_f32_per_thread(struct rope_th_ctx * rope_ctx, int nth, int ith) { + struct htp_ops_context * octx = rope_ctx->octx; + + const struct htp_tensor * src0 = &octx->src0; + const struct htp_tensor * src1 = &octx->src1; + struct htp_tensor * dst = &octx->dst; + + htp_rope_preamble; + + const uint32_t src0_nrows = ne01 * ne02 * ne03; // src0 rows + const uint32_t src0_nrows_per_thread = octx->src0_nrows_per_thread; + + const uint32_t src0_start_row = src0_nrows_per_thread * ith; + const uint32_t src0_end_row = MIN(src0_start_row + src0_nrows_per_thread, src0_nrows); + + // no work for this thread + if (src0_start_row >= src0_end_row) { + return; + } + + uint64_t t1, t2; + t1 = HAP_perf_get_qtimer_count(); + + int is_aligned = 1; + int opt_path = 0; + if ((0 == hex_is_aligned((void *) src0->data, VLEN)) || (0 == hex_is_aligned((void *) src1->data, VLEN)) || + (0 == hex_is_aligned((void *) dst->data, VLEN))) { + FARF(HIGH, "rope-f32: unaligned addresses in rope op, possibly slower execution\n"); + is_aligned = 0; + } + if ((1 == is_aligned) && !(nb01 & (VLEN - 1))) { + opt_path = 1; + } + + rope_hex_f32(rope_ctx, src0_start_row, src0_end_row, nth, ith, opt_path); + + t2 = HAP_perf_get_qtimer_count(); + + FARF(HIGH, "rope-f32: %d/%d/%d: (%u:%u) usec %u\n", ith, nth, opt_path, src0_start_row, src0_end_row, + (unsigned) HAP_perf_qtimer_count_to_us(t2 - t1)); +} + +static void rope_job_dispatcher_f32(unsigned int n, unsigned int i, void * data) { + struct rope_th_ctx * rope_ctx = (struct rope_th_ctx *) data; + + rope_job_f32_per_thread(rope_ctx, n, i); +} + +static int execute_op_rope_f32(struct htp_ops_context * octx) { + int err = HTP_STATUS_OK; + + const struct htp_tensor * src0 = &octx->src0; + const struct htp_tensor * src1 = &octx->src1; + const struct htp_tensor * src2 = &octx->src2; + struct htp_tensor * dst = &octx->dst; + + 
worker_callback_t op_func; + const char * op_type = NULL; + + struct rope_th_ctx rope_ctx; + + switch (octx->op) { + case HTP_OP_ROPE: + op_func = rope_job_dispatcher_f32; + op_type = "rope-f32"; + + init_rope_ctx(&rope_ctx, octx); + break; + + default: + FARF(ERROR, "Unsupported Op %u\n", octx->op); + return HTP_STATUS_NO_SUPPORT; + } + + const uint32_t n_threads = octx->n_threads; + + const size_t src0_row_size = src0->nb[1]; + const size_t src1_row_size = src0_row_size; + const size_t dst_row_size = dst->nb[1]; + + // VTCM scratchpads for all tensors + // N rows per thread, padded to HVX vector size + octx->dst_spad.size = hex_round_up(dst_row_size, 128) * n_threads; + octx->src0_spad.size = hex_round_up(src0_row_size, 128) * n_threads; + octx->src1_spad.size = hex_round_up(src1_row_size, 128) * n_threads; + + size_t spad_size = octx->src0_spad.size + octx->src1_spad.size + octx->dst_spad.size; + + if (src2->ne[0]) { + FARF(HIGH, + "%s: %ux%ux%ux%u (x %ux%ux%ux%u x %ux%ux%ux%u) -> %ux%ux%ux%u : src0-spad-size %u src1-spad-size %u " + "dst-spad-size %u\n", + op_type, src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3], src1->ne[0], src1->ne[1], src1->ne[2], + src1->ne[3], src2->ne[0], src2->ne[1], src2->ne[2], src2->ne[3], dst->ne[0], dst->ne[1], dst->ne[2], + dst->ne[3], octx->src0_spad.size, octx->src1_spad.size, octx->dst_spad.size); + } else { + FARF(HIGH, + "%s: %ux%ux%ux%u (%ux%ux%ux%u) -> %ux%ux%ux%u : src0-spad-size %u src1-spad-size %u dst-spad-size %u\n", + op_type, src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3], src1->ne[0], src1->ne[1], src1->ne[2], + src1->ne[3], dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3], octx->src0_spad.size, octx->src1_spad.size, + octx->dst_spad.size); + } + + // Make sure the reserved vtcm size is sufficient + if (octx->ctx->vtcm_size < spad_size) { + FARF(ERROR, "%s : current VTCM reservation %zu is too small, needed %zu\n", op_type, octx->ctx->vtcm_size, + spad_size); + return HTP_STATUS_VTCM_TOO_SMALL; + } + + octx->src0_spad.data = octx->ctx->vtcm_base; + octx->src1_spad.data = octx->src0_spad.data + octx->src0_spad.size; + octx->dst_spad.data = octx->src1_spad.data + octx->src1_spad.size; + + uint32_t src0_nrows = src0->ne[1] * src0->ne[2] * src0->ne[3]; + + if (!(octx->flags & HTP_OPFLAGS_SKIP_COMPUTE)) { + uint32_t n_jobs = MIN(n_threads, src0_nrows); + octx->src0_nrows_per_thread = (src0_nrows + n_jobs - 1) / n_jobs; + worker_pool_run_func(octx->ctx->worker_pool, op_func, &rope_ctx, n_jobs); + } + + return err; +} + +int op_rope(struct htp_ops_context * octx) { + int err = HTP_STATUS_OK; + + switch (octx->src0.type) { + case HTP_TYPE_F32: + err = execute_op_rope_f32(octx); + break; + + default: + err = HTP_STATUS_NO_SUPPORT; + break; + } + + return err; +} diff --git a/llama.cpp/ggml/src/ggml-hexagon/htp/set-rows-ops.c b/llama.cpp/ggml/src/ggml-hexagon/htp/set-rows-ops.c new file mode 100644 index 0000000..904484d --- /dev/null +++ b/llama.cpp/ggml/src/ggml-hexagon/htp/set-rows-ops.c @@ -0,0 +1,164 @@ +#pragma clang diagnostic ignored "-Wunused-variable" +#pragma clang diagnostic ignored "-Wunused-function" +#pragma clang diagnostic ignored "-Wunused-but-set-variable" + +#include <HAP_farf.h> +#include <HAP_perf.h> + +#include <math.h> +#include <string.h> + +#include "hex-dma.h" +#include "hvx-utils.h" + +#define GGML_COMMON_DECL_C +#include "ggml-common.h" +#include "htp-ctx.h" +#include "htp-msg.h" +#include "htp-ops.h" + +#define set_rows_preamble \ + const uint32_t ne00 = octx->src0.ne[0]; \ + const uint32_t ne01 = octx->src0.ne[1]; 
\ + const uint32_t ne02 = octx->src0.ne[2]; \ + const uint32_t ne03 = octx->src0.ne[3]; \ + \ + const uint32_t ne10 = octx->src1.ne[0]; \ + const uint32_t ne11 = octx->src1.ne[1]; \ + const uint32_t ne12 = octx->src1.ne[2]; \ + \ + const uint32_t nb01 = octx->src0.nb[1]; \ + const uint32_t nb02 = octx->src0.nb[2]; \ + const uint32_t nb03 = octx->src0.nb[3]; \ + \ + const uint32_t nb10 = octx->src1.nb[0]; \ + const uint32_t nb11 = octx->src1.nb[1]; \ + const uint32_t nb12 = octx->src1.nb[2]; \ + \ + const uint32_t nb1 = octx->dst.nb[1]; \ + const uint32_t nb2 = octx->dst.nb[2]; \ + const uint32_t nb3 = octx->dst.nb[3]; \ + \ + const uint32_t ne1 = octx->dst.ne[1]; \ + \ + const uint32_t nr = ne01; + +static int set_rows_thread_f32_f32(struct htp_ops_context * octx, const int nth, const int ith) { + set_rows_preamble; + + // parallelize by rows of src0 + const uint32_t dr = octx->src0_nrows_per_thread; + const uint32_t ir0 = dr * ith; + const uint32_t ir1 = (ir0 + dr < nr) ? (ir0 + dr) : nr; + + const bool is_i32 = (octx->src1.type == HTP_TYPE_I32); + + for (uint32_t i03 = 0; i03 < ne03; ++i03) { + for (uint32_t i02 = 0; i02 < ne02; ++i02) { + for (uint32_t i = ir0; i < ir1; ++i) { + const uint32_t i12 = fastmodulo(i03, ne12, &octx->set_rows_div_ne12); + const uint32_t i11 = fastmodulo(i02, ne11, &octx->set_rows_div_ne11); + const uint32_t i10 = i; + + const uintptr_t src1_addr = octx->src1.data + i10*nb10 + i11*nb11 + i12*nb12; + + uint32_t i1 = is_i32 ? *(int32_t *)src1_addr : *(int64_t *)src1_addr; + if (i1 >= ne1) { + // ignore invalid indices + continue; + } + + const uintptr_t src0_ptr = octx->src0.data + i*nb01 + i02*nb02 + i03*nb03; + const uintptr_t dst_ptr = octx->dst.data + i1*nb1 + i02*nb2 + i03*nb3; + + // copy row + hvx_copy_f32_uu((uint8_t *)dst_ptr, (const uint8_t *)src0_ptr, ne00); + } + } + } + + return HTP_STATUS_OK; +} + +static int set_rows_thread_f16_f32(struct htp_ops_context * octx, const int nth, const int ith) { + set_rows_preamble; + + // parallelize by rows of src0 + const uint32_t dr = octx->src0_nrows_per_thread; + const uint32_t ir0 = dr * ith; + const uint32_t ir1 = (ir0 + dr < nr) ? (ir0 + dr) : nr; + + const bool is_i32 = (octx->src1.type == HTP_TYPE_I32); + + for (uint32_t i03 = 0; i03 < ne03; ++i03) { + for (uint32_t i02 = 0; i02 < ne02; ++i02) { + for (uint32_t i = ir0; i < ir1; ++i) { + const uint32_t i12 = fastmodulo(i03, ne12, &octx->set_rows_div_ne12); + const uint32_t i11 = fastmodulo(i02, ne11, &octx->set_rows_div_ne11); + const uint32_t i10 = i; + + const uintptr_t src1_addr = octx->src1.data + i10*nb10 + i11*nb11 + i12*nb12; + + uint32_t i1 = is_i32 ? 
*(int32_t *)src1_addr : *(int64_t *)src1_addr; + if (i1 >= ne1) { + // ignore invalid indices + continue; + } + + const uint8_t* src0_ptr = (const uint8_t *) octx->src0.data + i*nb01 + i02*nb02 + i03*nb03; + uint8_t* dst_ptr = (uint8_t *) octx->dst.data + i1*nb1 + i02*nb2 + i03*nb3; + + hvx_copy_f16_f32_uu(dst_ptr, src0_ptr, ne00); + } + } + } + + return HTP_STATUS_OK; +} + +static void set_rows_work_f16_f32(unsigned int n, unsigned int i, void *data) { + set_rows_thread_f16_f32((struct htp_ops_context *) data, n, i); +} + +static void set_rows_work_f32_f32(unsigned int n, unsigned int i, void *data) { + set_rows_thread_f32_f32((struct htp_ops_context *) data, n, i); +} + +int op_set_rows(struct htp_ops_context * octx) { + set_rows_preamble; + + if (octx->src0.type != HTP_TYPE_F32) { + return HTP_STATUS_NO_SUPPORT; + } + + if (octx->dst.type != HTP_TYPE_F32 && octx->dst.type != HTP_TYPE_F16) { + return HTP_STATUS_NO_SUPPORT; + } + + if (octx->src1.type != HTP_TYPE_I32 && octx->src1.type != HTP_TYPE_I64) { + return HTP_STATUS_NO_SUPPORT; + } + + if (octx->flags & HTP_OPFLAGS_SKIP_COMPUTE) { + return HTP_STATUS_OK; + } + + octx->set_rows_div_ne12 = init_fastdiv_values(ne12); + octx->set_rows_div_ne11 = init_fastdiv_values(ne11); + + const uint32_t n_jobs = MIN(nr, octx->n_threads); + octx->src0_nrows_per_thread = (nr + n_jobs - 1) / n_jobs; + + switch(octx->dst.type) { + case HTP_TYPE_F32: + worker_pool_run_func(octx->ctx->worker_pool, set_rows_work_f32_f32, octx, n_jobs); + break; + case HTP_TYPE_F16: + worker_pool_run_func(octx->ctx->worker_pool, set_rows_work_f16_f32, octx, n_jobs); + break; + default: + return HTP_STATUS_NO_SUPPORT; + } + + return HTP_STATUS_OK; +} diff --git a/llama.cpp/ggml/src/ggml-hexagon/htp/softmax-ops.c b/llama.cpp/ggml/src/ggml-hexagon/htp/softmax-ops.c new file mode 100644 index 0000000..e91a16d --- /dev/null +++ b/llama.cpp/ggml/src/ggml-hexagon/htp/softmax-ops.c @@ -0,0 +1,395 @@ +#pragma clang diagnostic ignored "-Wunused-variable" +#pragma clang diagnostic ignored "-Wunused-function" +#pragma clang diagnostic ignored "-Wunused-but-set-variable" + +#include <HAP_farf.h> +#include <HAP_perf.h> + +#include <math.h> +#include <string.h> + +#include "hex-dma.h" +#include "hvx-utils.h" + +#define GGML_COMMON_DECL_C +#include "ggml-common.h" +#include "htp-ctx.h" +#include "htp-msg.h" +#include "htp-ops.h" + +#define htp_softmax_preamble3 \ + const uint32_t ne00 = src0->ne[0]; \ + const uint32_t ne01 = src0->ne[1]; \ + const uint32_t ne02 = src0->ne[2]; \ + const uint32_t ne03 = src0->ne[3]; \ + \ + const uint32_t nb00 = src0->nb[0]; \ + const uint32_t nb01 = src0->nb[1]; \ + const uint32_t nb02 = src0->nb[2]; \ + const uint32_t nb03 = src0->nb[3]; \ + \ + const uint32_t ne10 = (src1->ne[0]) ? src1->ne[0] : 1; \ + const uint32_t ne11 = (src1->ne[0]) ? src1->ne[1] : 1; \ + const uint32_t ne12 = (src1->ne[0]) ? src1->ne[2] : 1; \ + const uint32_t ne13 = (src1->ne[0]) ? src1->ne[3] : 1; \ + \ + const uint32_t nb10 = (src1->ne[0]) ? src1->nb[0] : 1; \ + const uint32_t nb11 = (src1->ne[0]) ? src1->nb[1] : 1; \ + const uint32_t nb12 = (src1->ne[0]) ? src1->nb[2] : 1; \ + const uint32_t nb13 = (src1->ne[0]) ? 
src1->nb[3] : 1; \ + \ + const uint32_t ne0 = dst->ne[0]; \ + const uint32_t ne1 = dst->ne[1]; \ + const uint32_t ne2 = dst->ne[2]; \ + const uint32_t ne3 = dst->ne[3]; \ + \ + const uint32_t nb0 = dst->nb[0]; \ + const uint32_t nb1 = dst->nb[1]; \ + const uint32_t nb2 = dst->nb[2]; \ + const uint32_t nb3 = dst->nb[3]; + +struct softmax_th_ctx { + bool use_f16; + bool use_src1; + uint32_t n_head; + uint32_t n_head_log2; + + float scale; + float max_bias; + float m0; + float m1; + + struct htp_ops_context * octx; +}; + +static void init_softmax_ctx(struct softmax_th_ctx * softmax_ctx, struct htp_ops_context * octx) { + const struct htp_tensor * src0 = &octx->src0; + const struct htp_tensor * src1 = &octx->src1; + + memset(softmax_ctx, 0, sizeof(struct softmax_th_ctx)); + + memcpy(&softmax_ctx->scale, (float *) octx->op_params, sizeof(float)); + memcpy(&softmax_ctx->max_bias, (float *) octx->op_params + 1, sizeof(float)); + + softmax_ctx->n_head = src0->ne[2]; + softmax_ctx->n_head_log2 = 1u << (uint32_t) floor(log2(softmax_ctx->n_head)); + + softmax_ctx->m0 = powf(2.0f, -(softmax_ctx->max_bias) / softmax_ctx->n_head_log2); + softmax_ctx->m1 = powf(2.0f, -(softmax_ctx->max_bias / 2.0f) / softmax_ctx->n_head_log2); + + softmax_ctx->use_src1 = (src1->ne[0] != 0); + softmax_ctx->use_f16 = (src1->ne[0] != 0) && (src1->type == HTP_TYPE_F16); + + softmax_ctx->octx = octx; +} + +static void hvx_fast_softmax_prep_f32(const uint8_t * restrict src, + uint8_t * restrict dst, + const int num_elems, + float scale, + const uint8_t * restrict mask, + float slope) { + const uint8_t * restrict src_curr = src; + uint8_t * restrict dst_curr = dst; + const uint8_t * restrict mask_curr = mask; + + HVX_Vector scale_vec = hvx_vec_splat_f32(scale); + HVX_Vector slope_vec = hvx_vec_splat_f32(slope); + + int step_of_1 = num_elems >> 5; + + #pragma unroll(4) + for (int i = 0; i < step_of_1; i++) { + HVX_Vector v1 = *(HVX_Vector *) src_curr; + + HVX_Vector v3 = *(HVX_Vector *) mask_curr; + + HVX_Vector v2 = Q6_Vqf32_vmpy_VsfVsf(v1, scale_vec); + + HVX_Vector v4 = Q6_Vqf32_vmpy_VsfVsf(v3, slope_vec); + + HVX_Vector v5 = Q6_Vqf32_vadd_Vqf32Vqf32(v2, v4); + + *(HVX_Vector *) dst_curr = Q6_Vsf_equals_Vqf32(v5); + + src_curr += VLEN; + dst_curr += VLEN; + mask_curr += VLEN; + } +} + +static void hvx_fast_softmax_f32(const uint8_t * restrict src, + uint8_t * restrict dst, + uint8_t * restrict pad, + const int num_elems) { + const HVX_Vector * restrict v_src = (HVX_Vector *) src; + HVX_Vector * restrict v_pad = (HVX_Vector *) pad; + HVX_Vector * restrict v_dst = (HVX_Vector *) dst; + + HVX_Vector sum_vec = Q6_V_vsplat_R(0x00000000); + HVX_Vector max_vec = hvx_vec_splat_f32(((const float *) src)[0]); + HVX_Vector zero_v = Q6_V_vzero(); + HVX_Vector one_v = hvx_vec_splat_f32(1.0); + + int step_of_1 = num_elems >> 5; + + #pragma unroll(4) + for (int i = 0; i < step_of_1; i++) { + HVX_Vector v1 = v_src[i]; + max_vec = Q6_Vsf_vmax_VsfVsf(max_vec, v1); + } + + HVX_Vector v = hvx_vec_reduce_max_f32(max_vec); + max_vec = hvx_vec_repl4(v); + + #pragma unroll(4) + for (int i = 0; i < step_of_1; i++) { + HVX_Vector v1 = v_src[i]; + HVX_Vector v2 = Q6_Vqf32_vsub_VsfVsf(v1, max_vec); + + HVX_Vector v3 = hvx_vec_exp_f32(Q6_Vsf_equals_Vqf32(v2)); + + sum_vec = Q6_Vqf32_vadd_VsfVsf(Q6_Vsf_equals_Vqf32(sum_vec), v3); + + v_pad[i] = v3; + } + + v = hvx_vec_reduce_sum_f32(Q6_Vsf_equals_Vqf32(sum_vec)); + sum_vec = hvx_vec_repl4(v); + + HVX_VectorPred pos_sum = Q6_Q_vcmp_gt_VwVw(sum_vec, zero_v); + HVX_Vector v4 = hvx_vec_inverse_f32(sum_vec); + 
HVX_Vector scale_vec = Q6_V_vmux_QVV(pos_sum, v4, one_v); + + #pragma unroll(4) + for (int i = 0; i < step_of_1; i++) { + HVX_Vector v1 = v_pad[i]; + HVX_Vector v2 = Q6_Vqf32_vmpy_VsfVsf(v1, scale_vec); + v_dst[i] = Q6_Vsf_equals_Vqf32(v2); + } +} + +static float hvx_softmax_f32(const uint8_t * restrict src, + uint8_t * restrict dst, + uint8_t * restrict spad, + const int num_elems, + const float max) { + hvx_sub_scalar_f32(spad, src, max, num_elems); + + hvx_exp_f32(spad, dst, num_elems, false); + + float sum = hvx_reduce_sum_f32(dst, num_elems); + + return sum; +} + +static void softmax_htp_f32(int nth, int ith, struct softmax_th_ctx * softmax_ctx, int opt_path) { + struct htp_ops_context * octx = softmax_ctx->octx; + + const struct htp_tensor * src0 = &octx->src0; + const struct htp_tensor * src1 = &octx->src1; + const struct htp_tensor * dst = &octx->dst; + + htp_softmax_preamble3; + + uint8_t * src0_spad_data = octx->src0_spad.data + (ith * nb01); + uint8_t * src1_spad_data = octx->src1_spad.data + (ith * nb01); + uint8_t * dst_spad_data = octx->dst_spad.data + (ith * nb1); + + float * wp0 = (float *) src0_spad_data; + float * wp1 = (float *) src1_spad_data; + float * wp2 = (float *) dst_spad_data; + + for (uint32_t i03 = 0; i03 < ne03; i03++) { + for (uint32_t i02 = 0; i02 < ne02; i02++) { + for (uint32_t i01 = ith; i01 < ne01; i01 += nth) { + const uint32_t i11 = i01; + const uint32_t i12 = i02 % ne12; + const uint32_t i13 = i03 % ne13; + + // ALiBi + const uint32_t h = i02; // head + + const float slope = (softmax_ctx->max_bias > 0.0f) ? + h < softmax_ctx->n_head_log2 ? + powf(softmax_ctx->m0, h + 1) : + powf(softmax_ctx->m1, 2 * (h - softmax_ctx->n_head_log2) + 1) : + 1.0f; + + float * sp = (float *) ((char *) octx->src0.data + i01 * nb01 + i02 * nb02 + i03 * nb03); + float * dp = (float *) ((char *) octx->dst.data + i01 * nb1 + i02 * nb2 + i03 * nb3); + + // broadcast the mask across rows + __fp16 * mp_f16 = (softmax_ctx->use_src1) ? + (__fp16 *) ((char *) octx->src1.data + i11 * nb11 + i12 * nb12 + i13 * nb13) : + NULL; + float * mp_f32 = (softmax_ctx->use_src1) ? + (float *) ((char *) octx->src1.data + i11 * nb11 + i12 * nb12 + i13 * nb13) : + NULL; + + if ((1 == opt_path) && (mp_f32) && !(softmax_ctx->use_f16)) { + hvx_fast_softmax_prep_f32((const uint8_t *) sp, (uint8_t *) wp0, ne00, softmax_ctx->scale, + (const uint8_t *) mp_f32, slope); + } else { + hvx_scale_f32((uint8_t *) wp0, (const uint8_t *) sp, ne00, softmax_ctx->scale); + if (mp_f32) { + if (softmax_ctx->use_f16) { + for (int i = 0; i < ne00; ++i) { + wp0[i] += slope * (float) mp_f16[i]; + } + } else { + for (int i = 0; i < ne00; ++i) { + wp0[i] += slope * mp_f32[i]; + } + } + } + } + + if (1 == opt_path) { + hvx_fast_softmax_f32((const uint8_t *) wp0, (uint8_t *) dp, (uint8_t *) wp1, ne00); + } else { + float max = hvx_reduce_max_f32((const uint8_t *) wp0, ne00); + float sum = hvx_softmax_f32((const uint8_t *) wp0, (uint8_t *) wp2, (uint8_t *) wp1, ne00, max); + sum = sum > 0.0 ? 
(1.0 / sum) : 1; + hvx_scale_f32((uint8_t *) dp, (const uint8_t *) wp2, ne00, sum); + } + } + } + } +} + +static void softmax_job_f32_per_thread(struct softmax_th_ctx * softmax_ctx, int nth, int ith) { + struct htp_ops_context * octx = softmax_ctx->octx; + + const struct htp_tensor * src0 = &octx->src0; + const struct htp_tensor * src1 = &octx->src1; + struct htp_tensor * dst = &octx->dst; + + htp_softmax_preamble3; + + const uint32_t src0_nrows = ne01 * ne02 * ne03; // src0 rows + const uint32_t src0_nrows_per_thread = octx->src0_nrows_per_thread; + + const uint32_t src0_start_row = src0_nrows_per_thread * ith; + const uint32_t src0_end_row = MIN(src0_start_row + src0_nrows_per_thread, src0_nrows); + + // no work for this thread + if (src0_start_row >= src0_end_row) { + return; + } + + uint64_t t1, t2; + t1 = HAP_perf_get_qtimer_count(); + + int is_aligned = 1; + int opt_path = 0; + if (!hex_is_aligned((void *) src0->data, VLEN) || !hex_is_aligned((void *) dst->data, VLEN)) { + is_aligned = 0; + FARF(HIGH, "softmax-f32: unaligned addresses in elementwise op, possibly slower execution\n"); + } + if ((1 == is_aligned) && !(nb01 & (VLEN - 1))) { + opt_path = 1; + } + + softmax_htp_f32(nth, ith, softmax_ctx, opt_path); + + t2 = HAP_perf_get_qtimer_count(); + + FARF(HIGH, "softmax-f32 %d/%d/%d/%d: %ux%ux%ux%u (%u:%u) x %ux%ux%ux%u -> %ux%ux%ux%u usec %u\n", ith, nth, + softmax_ctx->use_f16, opt_path, ne00, ne01, ne02, ne03, src0_start_row, src0_end_row, ne10, ne11, ne12, ne13, + ne0, ne1, ne2, ne3, (unsigned) HAP_perf_qtimer_count_to_us(t2 - t1)); +} + +static void softmax_job_dispatcher_f32(unsigned int n, unsigned int i, void * p_data) { + struct softmax_th_ctx * p_softmax_ctx = (struct softmax_th_ctx *) p_data; + softmax_job_f32_per_thread(p_softmax_ctx, n, i); +} + +static int execute_op_softmax_f32(struct htp_ops_context * octx) { + int err = HTP_STATUS_OK; + + const struct htp_tensor * src0 = &octx->src0; + const struct htp_tensor * src1 = &octx->src1; + struct htp_tensor * dst = &octx->dst; + + worker_callback_t op_func; + const char * op_type = NULL; + + struct softmax_th_ctx softmax_ctx; + + switch (octx->op) { + case HTP_OP_SOFTMAX: + op_func = softmax_job_dispatcher_f32; + op_type = "softmax-f32"; + + init_softmax_ctx(&softmax_ctx, octx); + break; + + default: + FARF(ERROR, "Unsupported Op %u\n", octx->op); + return HTP_STATUS_NO_SUPPORT; + } + + const uint32_t n_threads = octx->n_threads; + + const size_t src0_row_size = src0->nb[1]; + const size_t src1_row_size = src0_row_size; + const size_t dst_row_size = dst->nb[1]; + + // VTCM scratchpads for all tensors + // N rows per thread, padded to HVX vector size + octx->dst_spad.size = hex_round_up(dst_row_size, 128) * n_threads; + octx->src0_spad.size = hex_round_up(src0_row_size, 128) * n_threads; + octx->src1_spad.size = hex_round_up(src1_row_size, 128) * n_threads; + + size_t spad_size = octx->src0_spad.size + octx->src1_spad.size + octx->dst_spad.size; + + if (src1->ne[0]) { + FARF(HIGH, + "%s: %ux%ux%ux%u x %ux%ux%ux%u -> %ux%ux%ux%u : src0-spad-size %u src1-spad-size %u dst-spad-size %u\n", + op_type, src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3], src1->ne[0], src1->ne[1], src1->ne[2], + src1->ne[3], dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3], octx->src0_spad.size, octx->src1_spad.size, + octx->dst_spad.size); + } else { + FARF(HIGH, "%s: %ux%ux%ux%u -> %ux%ux%ux%u : src0-spad-size %u src1-spad-size %u dst-spad-size %u\n", op_type, + src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3], dst->ne[0], dst->ne[1], dst->ne[2], 
dst->ne[3],
+             octx->src0_spad.size, octx->src1_spad.size, octx->dst_spad.size);
+    }
+
+    // Make sure the reserved vtcm size is sufficient
+    if (octx->ctx->vtcm_size < spad_size) {
+        FARF(ERROR, "%s : current VTCM reservation %zu is too small, needed %zu\n", op_type, octx->ctx->vtcm_size,
+             spad_size);
+        return HTP_STATUS_VTCM_TOO_SMALL;
+    }
+
+    octx->src0_spad.data = octx->ctx->vtcm_base;
+    octx->src1_spad.data = octx->src0_spad.data + octx->src0_spad.size;
+    octx->dst_spad.data = octx->src1_spad.data + octx->src1_spad.size;
+
+    uint32_t src0_nrows = src0->ne[1] * src0->ne[2] * src0->ne[3];
+
+    if (!(octx->flags & HTP_OPFLAGS_SKIP_COMPUTE)) {
+        uint32_t n_jobs = MIN(n_threads, src0_nrows);
+        octx->src0_nrows_per_thread = (src0_nrows + n_jobs - 1) / n_jobs;
+        worker_pool_run_func(octx->ctx->worker_pool, op_func, &softmax_ctx, n_jobs);
+    }
+
+    return err;
+}
+
+int op_softmax(struct htp_ops_context * octx) {
+    int err = HTP_STATUS_OK;
+
+    switch (octx->src0.type) {
+        case HTP_TYPE_F32:
+            err = execute_op_softmax_f32(octx);
+            break;
+
+        default:
+            err = HTP_STATUS_NO_SUPPORT;
+            break;
+    }
+
+    return err;
+}
diff --git a/llama.cpp/ggml/src/ggml-hexagon/htp/sum-rows-ops.c b/llama.cpp/ggml/src/ggml-hexagon/htp/sum-rows-ops.c
new file mode 100644
index 0000000..62e45da
--- /dev/null
+++ b/llama.cpp/ggml/src/ggml-hexagon/htp/sum-rows-ops.c
@@ -0,0 +1,115 @@
+#pragma clang diagnostic ignored "-Wunused-variable"
+#pragma clang diagnostic ignored "-Wunused-function"
+#pragma clang diagnostic ignored "-Wunused-but-set-variable"
+
+#include <HAP_farf.h>
+#include <HAP_perf.h>
+
+#include <string.h>
+#include <math.h>
+
+#include "hex-dma.h"
+#include "hvx-utils.h"
+
+#define GGML_COMMON_DECL_C
+#include "ggml-common.h"
+#include "htp-ctx.h"
+#include "htp-msg.h"
+#include "htp-ops.h"
+
+
+#define sum_rows_preamble \
+    struct htp_tensor * src0 = &octx->src0; \
+    struct htp_tensor * dst = &octx->dst; \
+    \
+    const uint32_t ne00 = src0->ne[0]; \
+    const uint32_t ne01 = src0->ne[1]; \
+    const uint32_t ne02 = src0->ne[2]; \
+    const uint32_t ne03 = src0->ne[3]; \
+    \
+    const uint32_t nb00 = src0->nb[0]; \
+    const uint32_t nb01 = src0->nb[1]; \
+    const uint32_t nb02 = src0->nb[2]; \
+    const uint32_t nb03 = src0->nb[3]; \
+    \
+    const uint32_t ne0 = dst->ne[0]; \
+    const uint32_t ne1 = dst->ne[1]; \
+    const uint32_t ne2 = dst->ne[2]; \
+    const uint32_t ne3 = dst->ne[3]; \
+    \
+    const uint32_t nb0 = dst->nb[0]; \
+    const uint32_t nb1 = dst->nb[1]; \
+    const uint32_t nb2 = dst->nb[2]; \
+    const uint32_t nb3 = dst->nb[3];
+
+static int sum_rows_thread_f32(struct htp_ops_context * octx, const int nth, const int ith) {
+    sum_rows_preamble;
+
+    const uint32_t src0_nrows_per_thread = octx->src0_nrows_per_thread;
+    const size_t src0_row_size = nb01;
+    const size_t dst_row_size = nb1;
+
+    const uint32_t src0_nrows = ne01 * ne02 * ne03; // src0 rows
+
+    const uint32_t src0_start_row = src0_nrows_per_thread * ith;
+    const uint32_t src0_end_row = MIN(src0_start_row + src0_nrows_per_thread, src0_nrows);
+
+    // no work for this thread
+    if (src0_start_row >= src0_end_row) {
+        return HTP_STATUS_OK;
+    }
+
+    int opt_path = 0;
+    if (hex_is_aligned((void *) src0->data, VLEN) && !(nb01 & (VLEN - 1))) {
+        opt_path = 1;
+    }
+
+    const uint8_t * restrict data_src = (const uint8_t *) src0->data;
+    uint8_t * restrict data_dst = (uint8_t *) dst->data;
+
+    const float * restrict src_th = (float *) (data_src + (src0_start_row * src0_row_size));
+    float * restrict dst_th = (float *) (data_dst + (src0_start_row * dst_row_size));
+
+    const uint32_t nrows_th = src0_end_row - src0_start_row;
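+
+    // The last thread's slice can be shorter than src0_nrows_per_thread,
+    // so reduce only the rows actually assigned to this thread.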
for (uint32_t ir = 0; ir < src0_end_row - src0_start_row; ir++) { + const float * restrict src_local = src_th + (ir * ne00); + + if (ir + 1 < src0_end_row - src0_start_row) { + hex_l2fetch(src_local + ne00, src0_row_size, src0_row_size, 1); + } + + if (1 == opt_path) { + dst_th[ir] = hvx_reduce_sum_f32_a((const uint8_t *) src_local, ne00); + } else { + dst_th[ir] = hvx_reduce_sum_f32((const uint8_t *) src_local, ne00); + } + } + + return HTP_STATUS_OK; +} + +static void sum_rows_work_f32(unsigned int n, unsigned int i, void *data) { + sum_rows_thread_f32((struct htp_ops_context *) data, n, i); +} + +int op_sum_rows(struct htp_ops_context * octx) { + sum_rows_preamble; + + if (octx->src0.type != HTP_TYPE_F32) { + return HTP_STATUS_NO_SUPPORT; + } + + if (octx->flags & HTP_OPFLAGS_SKIP_COMPUTE) { + return HTP_STATUS_OK; + } + + const int n_threads = octx->n_threads; + const uint32_t src0_nrows = ne01 * ne02 * ne03; + + uint32_t n_jobs = MIN(n_threads, src0_nrows); + octx->src0_nrows_per_thread = (src0_nrows + n_jobs - 1) / n_jobs; + + worker_pool_run_func(octx->ctx->worker_pool, sum_rows_work_f32, octx, n_jobs); + + return HTP_STATUS_OK; +} + diff --git a/llama.cpp/ggml/src/ggml-hexagon/htp/unary-ops.c b/llama.cpp/ggml/src/ggml-hexagon/htp/unary-ops.c new file mode 100644 index 0000000..ce879bf --- /dev/null +++ b/llama.cpp/ggml/src/ggml-hexagon/htp/unary-ops.c @@ -0,0 +1,342 @@ +#pragma clang diagnostic ignored "-Wunused-variable" +#pragma clang diagnostic ignored "-Wunused-function" +#pragma clang diagnostic ignored "-Wunused-but-set-variable" + +#include <HAP_farf.h> +#include <HAP_perf.h> + +#include <math.h> +#include <string.h> + +#include "hex-dma.h" +#include "hvx-utils.h" + +#define GGML_COMMON_DECL_C +#include "ggml-common.h" +#include "htp-ctx.h" +#include "htp-msg.h" +#include "htp-ops.h" + +#define htp_unary_preamble \ + const uint32_t ne00 = src->ne[0]; \ + const uint32_t ne01 = src->ne[1]; \ + const uint32_t ne02 = src->ne[2]; \ + const uint32_t ne03 = src->ne[3]; \ + \ + const uint32_t ne0 = dst->ne[0]; \ + const uint32_t ne1 = dst->ne[1]; \ + const uint32_t ne2 = dst->ne[2]; \ + const uint32_t ne3 = dst->ne[3]; \ + \ + const uint32_t nb00 = src->nb[0]; \ + const uint32_t nb01 = src->nb[1]; \ + const uint32_t nb02 = src->nb[2]; \ + const uint32_t nb03 = src->nb[3]; \ + \ + const uint32_t nb0 = dst->nb[0]; \ + const uint32_t nb1 = dst->nb[1]; \ + const uint32_t nb2 = dst->nb[2]; \ + const uint32_t nb3 = dst->nb[3]; + +static void hvx_fast_rms_norm_f32(const uint8_t * restrict src, + uint8_t * restrict dst, + uint8_t * restrict pad, + const int num_elems, + float epsilon) { + const HVX_Vector * restrict v_src = (HVX_Vector *) src; + HVX_Vector * restrict v_dst = (HVX_Vector *) dst; + + HVX_Vector sum_v = Q6_V_vsplat_R(0x00000000); + HVX_Vector epsilon_v = hvx_vec_splat_f32(epsilon); + + int step_of_1 = num_elems >> 5; + #pragma unroll(4) + for (int i = 0; i < step_of_1; i++) { + HVX_Vector v1 = v_src[i]; + HVX_Vector v2 = Q6_Vqf32_vmpy_VsfVsf(v1, v1); + sum_v = Q6_Vqf32_vadd_Vqf32Vqf32(sum_v, v2); + } + + HVX_Vector reduced_sum = hvx_vec_reduce_sum_f32(Q6_Vsf_equals_Vqf32(sum_v)); + sum_v = hvx_vec_repl4(reduced_sum); + + HVX_Vector t_v = hvx_vec_splat_f32((float) num_elems); + HVX_Vector denom_v = hvx_vec_inverse_f32(t_v); + HVX_Vector mean_v = Q6_Vqf32_vmpy_VsfVsf(sum_v, denom_v); + HVX_Vector mean_epsilon_v = Q6_Vqf32_vadd_Vqf32Vsf(mean_v, epsilon_v); + + HVX_Vector scale_v = hvx_vec_rsqrt_f32(Q6_Vsf_equals_Vqf32(mean_epsilon_v)); + + #pragma unroll(4) + for (int i = 0; i < step_of_1; i++) 
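+    // second pass over the row: scale every vector by the broadcast 1/rms factor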
{ + HVX_Vector v1 = v_src[i]; + HVX_Vector v2 = Q6_Vqf32_vmpy_VsfVsf(v1, scale_v); + v_dst[i] = Q6_Vsf_equals_Vqf32(v2); + } +} + +static void scale_htp_f32(const float * restrict src, + float * restrict dst, + uint8_t * restrict spad, + const uint32_t num_rows, + const uint32_t row_elems, + const size_t row_size, + int32_t * op_params, + int opt_path) { + float scale = 0.f; + float bias = 0.f; + memcpy(&scale, &op_params[0], sizeof(float)); + memcpy(&bias, &op_params[1], sizeof(float)); + + for (uint32_t ir = 0; ir < num_rows; ir++) { + const float * restrict src_local = src + (ir * row_elems); + float * restrict dst_local = dst + (ir * row_elems); + + if (ir + 1 < num_rows) { + hex_l2fetch(src_local + row_elems, row_size, row_size, 1); + } + + hvx_scale_offset_f32((uint8_t *) dst_local, (const uint8_t *) src_local, row_elems, scale, bias); + } +} + +static void rms_norm_htp_f32(const float * restrict src, + float * restrict dst, + uint8_t * restrict spad, + const uint32_t num_rows, + const uint32_t row_elems, + const size_t row_size, + int32_t * op_params, + int opt_path) { + float epsilon = 0.f; + memcpy(&epsilon, op_params, sizeof(float)); + + for (uint32_t ir = 0; ir < num_rows; ir++) { + const float * restrict src_local = src + (ir * row_elems); + float * restrict dst_local = dst + (ir * row_elems); + + if (ir + 1 < num_rows) { + hex_l2fetch(src_local + row_elems, row_size, row_size, 1); + } + + if (1 == opt_path) { + hvx_fast_rms_norm_f32((const uint8_t *) src_local, (uint8_t *) dst_local, spad, row_elems, epsilon); + } else { + float sum = hvx_sum_of_squares_f32((const uint8_t *) src_local, row_elems); + + const float mean = sum / row_elems; + const float scale = 1.0f / sqrtf(mean + epsilon); + + hvx_scale_f32((uint8_t *) dst_local, (const uint8_t *) src_local, row_elems, scale); + } + } +} + +static void sqr_htp_f32(const float * restrict src, + float * restrict dst, + uint8_t * restrict spad, + const uint32_t num_rows, + const uint32_t row_elems, + const size_t row_size, + int32_t * op_params, + int opt_path) { + + for (uint32_t ir = 0; ir < num_rows; ir++) { + const float * restrict src_local = src + (ir * row_elems); + float * restrict dst_local = dst + (ir * row_elems); + + if (ir + 1 < num_rows) { + hex_l2fetch(src_local + row_elems, row_size, row_size, 1); + } + + if (1 == opt_path) { + hvx_sqr_f32_aa((uint8_t *) dst_local, (const uint8_t *) src_local, row_elems); + } else { + hvx_sqr_f32((uint8_t *) dst_local, (const uint8_t *) src_local, row_elems); + } + } +} + +static void sqrt_htp_f32(const float * restrict src, + float * restrict dst, + uint8_t * restrict spad, + const uint32_t num_rows, + const uint32_t row_elems, + const size_t row_size, + int32_t * op_params, + int opt_path) { + + for (uint32_t ir = 0; ir < num_rows; ir++) { + const float * restrict src_local = src + (ir * row_elems); + float * restrict dst_local = dst + (ir * row_elems); + + if (ir + 1 < num_rows) { + hex_l2fetch(src_local + row_elems, row_size, row_size, 1); + } + + if (1 == opt_path) { + hvx_sqrt_f32_aa((uint8_t *) dst_local, (const uint8_t *) src_local, row_elems); + } else { + hvx_sqrt_f32((uint8_t *) dst_local, (const uint8_t *) src_local, row_elems); + } + } +} + +static void unary_job_f32_per_thread(const struct htp_tensor * src, + struct htp_tensor * dst, + uint8_t * spad, + int htp_op, + int32_t * op_params, + uint32_t nth, + uint32_t ith, + uint32_t src0_nrows_per_thread) { + htp_unary_preamble; + + const size_t src0_row_size = nb01; + const size_t dst_row_size = nb1; + + const uint32_t 
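+    // rows are flattened over ne01*ne02*ne03, so one linear row range covers the whole tensor: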
src0_nrows = ne01 * ne02 * ne03; // src0 rows + + const uint32_t src0_start_row = src0_nrows_per_thread * ith; + const uint32_t src0_end_row = MIN(src0_start_row + src0_nrows_per_thread, src0_nrows); + + // no work for this thread + if (src0_start_row >= src0_end_row) { + return; + } + + uint64_t t1, t2; + t1 = HAP_perf_get_qtimer_count(); + + int is_aligned = 1; + int opt_path = 0; + if ((0 == hex_is_aligned((void *) src->data, VLEN)) || (0 == hex_is_aligned((void *) dst->data, VLEN))) { + is_aligned = 0; + } + if ((1 == is_aligned) && !(nb01 & (VLEN - 1))) { + opt_path = 1; + } + + const uint8_t * restrict data_src = (const uint8_t *) src->data; + uint8_t * restrict data_dst = (uint8_t *) dst->data; + + const float * restrict src_th = (float *) (data_src + (src0_start_row * src0_row_size)); + float * restrict dst_th = (float *) (data_dst + (src0_start_row * dst_row_size)); + uint8_t * restrict spad_th = (uint8_t *) spad + (ith * nb01); + + switch (htp_op) { + case HTP_OP_RMS_NORM: + rms_norm_htp_f32(src_th, dst_th, spad_th, src0_end_row - src0_start_row, ne0, nb1, op_params, opt_path); + break; + case HTP_OP_SCALE: + scale_htp_f32(src_th, dst_th, spad_th, src0_end_row - src0_start_row, ne0, nb1, op_params, opt_path); + break; + case HTP_OP_SQR: + sqr_htp_f32(src_th, dst_th, spad_th, src0_end_row - src0_start_row, ne0, nb1, op_params, opt_path); + break; + case HTP_OP_SQRT: + sqrt_htp_f32(src_th, dst_th, spad_th, src0_end_row - src0_start_row, ne0, nb1, op_params, opt_path); + break; + + default: + break; + } + + t2 = HAP_perf_get_qtimer_count(); + + FARF(HIGH, "unary-f32 %d/%d/%d: %ux%ux%ux%u (%u:%u) -> %ux%ux%ux%u usec %u\n", ith, nth, opt_path, src->ne[0], + src->ne[1], src->ne[2], src->ne[3], src0_start_row, src0_end_row, dst->ne[0], dst->ne[1], dst->ne[2], + dst->ne[3], (unsigned) HAP_perf_qtimer_count_to_us(t2 - t1)); +} + +static void unary_job_dispatcher_f32(unsigned int n, unsigned int i, void * data) { + struct htp_ops_context * octx = (struct htp_ops_context *) data; + + unary_job_f32_per_thread(&octx->src0, &octx->dst, octx->src0_spad.data, octx->op, octx->op_params, n, i, + octx->src0_nrows_per_thread); +} + +static int execute_op_unary_f32(struct htp_ops_context * octx) { + int err = HTP_STATUS_OK; + + const struct htp_tensor * src0 = &octx->src0; + struct htp_tensor * dst = &octx->dst; + + worker_callback_t unary_op_func; + const char * op_type = NULL; + + switch (octx->op) { + case HTP_OP_RMS_NORM: + unary_op_func = unary_job_dispatcher_f32; + op_type = "rmsnorm-f32"; + break; + case HTP_OP_SCALE: + unary_op_func = unary_job_dispatcher_f32; + op_type = "scale-f32"; + break; + case HTP_OP_SQR: + unary_op_func = unary_job_dispatcher_f32; + op_type = "sqr-f32"; + break; + case HTP_OP_SQRT: + unary_op_func = unary_job_dispatcher_f32; + op_type = "sqrt-f32"; + break; + + default: + FARF(ERROR, "Unsupported unary Op %u\n", octx->op); + return HTP_STATUS_NO_SUPPORT; + } + + const int n_threads = octx->n_threads; + const uint32_t src0_nrows = src0->ne[1] * src0->ne[2] * src0->ne[3]; + + const size_t src0_row_size = src0->nb[1]; + const size_t dst_row_size = dst->nb[1]; + + // VTCM scratchpads for all tensors + octx->dst_spad.size = hex_round_up(dst_row_size, 128) * n_threads; + octx->src0_spad.size = hex_round_up(src0_row_size, 128) * n_threads; + + size_t spad_size = octx->src0_spad.size + octx->dst_spad.size; + + FARF(HIGH, "%s: (%ux%ux%ux%u) -> (%ux%ux%ux%u) : src0-spad-size %u src1-spad-size %u dst-spad-size %u\n", op_type, + src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3], 
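+         // (unary ops never allocate a src1 spad; the src1-spad-size logged here is whatever the context holds)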
dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3], + octx->src0_spad.size, octx->src1_spad.size, octx->dst_spad.size); + + // Make sure the reserved vtcm size is sufficient + if (octx->ctx->vtcm_size < spad_size) { + FARF(ERROR, "unary-%s : current VTCM reservation %zu is too small, needed %zu\n", op_type, octx->ctx->vtcm_size, + spad_size); + return HTP_STATUS_VTCM_TOO_SMALL; + } + + octx->src0_spad.data = octx->ctx->vtcm_base; + octx->dst_spad.data = octx->src0_spad.data + octx->src0_spad.size; + + if (!(octx->flags & HTP_OPFLAGS_SKIP_COMPUTE)) { + uint32_t n_jobs = MIN(n_threads, src0_nrows); + + octx->src0_nrows_per_thread = (src0_nrows + n_jobs - 1) / n_jobs; + + worker_pool_run_func(octx->ctx->worker_pool, unary_op_func, octx, n_jobs); + } + + return err; +} + +int op_unary(struct htp_ops_context * octx) { + int err = HTP_STATUS_OK; + + switch (octx->src0.type) { + case HTP_TYPE_F32: + err = execute_op_unary_f32(octx); + break; + + default: + err = HTP_STATUS_NO_SUPPORT; + break; + } + + return err; +} diff --git a/llama.cpp/ggml/src/ggml-hexagon/htp/worker-pool.c b/llama.cpp/ggml/src/ggml-hexagon/htp/worker-pool.c new file mode 100644 index 0000000..894815f --- /dev/null +++ b/llama.cpp/ggml/src/ggml-hexagon/htp/worker-pool.c @@ -0,0 +1,293 @@ +#include "worker-pool.h" + +#include <qurt.h> +#include <stdatomic.h> +#include <stdint.h> +#include <stdio.h> +#include <stdlib.h> +#include <string.h> + +#include "HAP_farf.h" + +#define WORKER_THREAD_STACK_SZ (2 * 16384) +#define LOWEST_USABLE_QURT_PRIO (254) + +struct worker_pool_s; + +// per-thread context handed to each worker at launch +typedef struct { + struct worker_pool_s * pool; + unsigned int id; +} worker_context_t; + +// shared state for one instance of the worker pool +typedef struct worker_pool_s { + worker_pool_job_t job[MAX_NUM_WORKERS]; // list of job descriptors + qurt_thread_t thread[MAX_NUM_WORKERS]; // thread ID's of the workers + worker_context_t context[MAX_NUM_WORKERS]; // worker contexts + void * stack[MAX_NUM_WORKERS]; // thread stack pointers + unsigned int n_threads; // number of workers in this pool + + atomic_uint seqn; // seqno used to detect new jobs + atomic_uint next_job; // next job index + atomic_uint n_pending; // number of pending jobs + atomic_uint n_jobs; // number of current jobs + atomic_bool killed; // threads need to exit +} worker_pool_t; + +static void worker_pool_main(void * context) { + worker_context_t * me = (worker_context_t *) context; + worker_pool_t * pool = me->pool; + + FARF(HIGH, "worker-pool: thread %u started", me->id); + + unsigned int prev_seqn = 0; + while (!atomic_load(&pool->killed)) { + unsigned int seqn = atomic_load(&pool->seqn); + if (seqn == prev_seqn) { + // Nothing to do + qurt_futex_wait(&pool->seqn, prev_seqn); + continue; + } + + // New job + prev_seqn = seqn; + + unsigned int n = atomic_load(&pool->n_jobs); + unsigned int i = atomic_fetch_add(&pool->next_job, 1); + if (i >= n) { + // Spurious wakeup + continue; + } + + pool->job[i].func(n, i, pool->job[i].data); + + atomic_fetch_sub(&pool->n_pending, 1); + } + + FARF(HIGH, "worker-pool: thread %u stopped", me->id); +} + +AEEResult worker_pool_init_with_stack_size(worker_pool_context_t * context, uint32_t n_threads, uint32_t stack_size) { + int err = 0; + + if (NULL == context) { + FARF(ERROR, "NULL context passed to worker pool init."); + return AEE_EBADPARM; + } + + // the descriptor arrays are fixed-size; reject requests that would overflow them + if (n_threads > MAX_NUM_WORKERS) { + FARF(ERROR, "worker-pool: %u threads requested, max is %u", n_threads, MAX_NUM_WORKERS); + return AEE_EBADPARM; + } + + // Allocations + int size = (stack_size * n_threads) + (sizeof(worker_pool_t)); + + unsigned char * 
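+    // single blob: the n_threads worker stacks come first, the pool descriptor
+    // sits at the end; worker_pool_release() frees the whole blob via stack[0]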
mem_blob = (unsigned char *) malloc(size); + if (!mem_blob) { + FARF(ERROR, "Could not allocate memory for worker pool!"); + return AEE_ENOMEMORY; + } + + worker_pool_t * me = (worker_pool_t *) (mem_blob + stack_size * n_threads); + + // name for the first worker, useful in debugging threads + char name[19]; + snprintf(name, 12, "0x%8x:", (unsigned int) (uintptr_t) me); + strcat(name, "worker0"); + me->n_threads = n_threads; + + // initializations + for (unsigned int i = 0; i < me->n_threads; i++) { + me->stack[i] = NULL; + me->thread[i] = 0; + + me->context[i].id = i; + me->context[i].pool = me; + } + + // initialize job queue + me->n_pending = 0; + me->n_jobs = 0; + me->next_job = 0; + me->seqn = 0; + me->killed = 0; + + // launch the workers + qurt_thread_attr_t attr; + qurt_thread_attr_init(&attr); + + for (unsigned int i = 0; i < me->n_threads; i++) { + // set up stack + me->stack[i] = mem_blob; + mem_blob += stack_size; + qurt_thread_attr_set_stack_addr(&attr, me->stack[i]); + qurt_thread_attr_set_stack_size(&attr, stack_size); + + // set up name + qurt_thread_attr_set_name(&attr, name); + name[17] = (name[17] + 1); + // name threads context:worker0, context:worker1, ... (wraps past '9', but the thread count should stay below that anyway) + if (name[17] > '9') { + name[17] = '0'; + } + + // set up priority - by default, match the creating thread's prio + int prio = qurt_thread_get_priority(qurt_thread_get_id()); + + if (prio < 1) { + prio = 1; + } + if (prio > LOWEST_USABLE_QURT_PRIO) { + prio = LOWEST_USABLE_QURT_PRIO; + } + + qurt_thread_attr_set_priority(&attr, prio); + + // launch + err = qurt_thread_create(&me->thread[i], &attr, worker_pool_main, (void *) &me->context[i]); + if (err) { + FARF(ERROR, "Could not launch worker threads!"); + worker_pool_release((worker_pool_context_t *) &me); + return AEE_EQURTTHREADCREATE; + } + } + *context = (worker_pool_context_t *) me; + return AEE_SUCCESS; +} + +AEEResult worker_pool_init(worker_pool_context_t * context, uint32_t n_threads) { + return worker_pool_init_with_stack_size(context, n_threads, WORKER_THREAD_STACK_SZ); +} + +// clean up worker pool +void worker_pool_release(worker_pool_context_t * context) { + worker_pool_t * me = (worker_pool_t *) *context; + + // nothing to release if no pool exists 
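+    // shutdown: raise 'killed', bump the seqno and wake all futex waiters so
+    // every worker re-checks the flag and exits, then join before freeing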
+ if (NULL == me) { + return; + } + + atomic_store(&me->killed, 1); + atomic_fetch_add(&me->seqn, 1); + qurt_futex_wake(&me->seqn, me->n_threads); + + // de-initializations + for (unsigned int i = 0; i < me->n_threads; i++) { + if (me->thread[i]) { + int status; + (void) qurt_thread_join(me->thread[i], &status); + } + } + + // free allocated memory (everything was allocated as a single buffer starting at stack[0]) + if (me->stack[0]) { + free(me->stack[0]); + } + + *context = NULL; +} + +// run jobs +AEEResult worker_pool_run_jobs(worker_pool_context_t context, worker_pool_job_t * job, unsigned int n) { + worker_pool_t * me = (worker_pool_t *) context; + if (NULL == me) { + FARF(ERROR, "worker-pool: invalid context"); + return AEE_EBADPARM; + } + + if (n > me->n_threads) { + FARF(ERROR, "worker-pool: invalid number of jobs %u for n-threads %u", n, me->n_threads); + return AEE_EBADPARM; + } + + memcpy(me->job, job, sizeof(worker_pool_job_t) * n); + + if (n > 1) { + atomic_store(&me->next_job, 1); + atomic_store(&me->n_jobs, n); + atomic_store(&me->n_pending, n - 1); + + // wake up workers + atomic_fetch_add(&me->seqn, 1); + qurt_futex_wake(&me->seqn, n - 1); + } + + // main thread runs job #0 + me->job[0].func(n, 0, me->job[0].data); + + if (n > 1) { + while (atomic_load(&me->n_pending)) + ; + } + + return AEE_SUCCESS; +} + +// run func +AEEResult worker_pool_run_func(worker_pool_context_t context, worker_callback_t func, void * data, unsigned int n) { + worker_pool_job_t job[n]; + + for (unsigned int i = 0; i < n; i++) { + job[i].func = func; + job[i].data = data; + } + + return worker_pool_run_jobs(context, job, n); +} + +AEEResult worker_pool_set_thread_priority(worker_pool_context_t context, unsigned int prio) { + worker_pool_t * me = (worker_pool_t *) context; + + // if no worker pool exists, return error. + if (!me) { + return AEE_ENOMORE; + } + + int result = AEE_SUCCESS; + if (prio < 1) { + prio = 1; + } + if (prio > LOWEST_USABLE_QURT_PRIO) { + prio = LOWEST_USABLE_QURT_PRIO; + } + + for (unsigned int i = 0; i < me->n_threads; i++) { + int res = qurt_thread_set_priority(me->thread[i], (unsigned short) prio); + if (0 != res) { + result = AEE_EBADPARM; + FARF(ERROR, "QURT failed to set priority of thread %d, ERROR = %d", me->thread[i], res); + } + } + + return result; +} + +AEEResult worker_pool_retrieve_thread_id(worker_pool_context_t context, unsigned int * tids) { + worker_pool_t * me = (worker_pool_t *) context; + if (!me) { + FARF(ERROR, "worker-pool: invalid context"); + return AEE_EBADPARM; + } + + for (unsigned int i = 0; i < me->n_threads; i++) { + tids[i] = me->thread[i]; + } + + return AEE_SUCCESS; +} + +AEEResult worker_pool_get_thread_priority(worker_pool_context_t context, unsigned int * prio) { + worker_pool_t * me = (worker_pool_t *) context; + if (!me) { + FARF(ERROR, "worker-pool: invalid context"); + return AEE_EBADPARM; + } + + int priority = qurt_thread_get_priority(me->thread[0]); + if (priority > 0) { + *prio = priority; + return AEE_SUCCESS; + } else { + *prio = 0; + return AEE_EBADSTATE; + } +} diff --git a/llama.cpp/ggml/src/ggml-hexagon/htp/worker-pool.h b/llama.cpp/ggml/src/ggml-hexagon/htp/worker-pool.h new file mode 100644 index 0000000..6f8c905 --- /dev/null +++ b/llama.cpp/ggml/src/ggml-hexagon/htp/worker-pool.h @@ -0,0 +1,57 @@ +#ifndef HTP_WORKER_POOL_H +#define HTP_WORKER_POOL_H + +// Macro that makes these functions visible when built as a shared library. 
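+// (default visibility keeps these entry points exported even when built with -fvisibility=hidden)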
#define WORKERPOOL_API __attribute__((visibility("default"))) + +#include <AEEStdDef.h> +#include <AEEStdErr.h> +#include <stdint.h> + +#ifdef __cplusplus +extern "C" { +#endif + +/// signature of callbacks to be invoked by worker threads +typedef void (*worker_callback_t)(unsigned int n, unsigned int i, void *); + +/// Typedef of worker_pool context +typedef void * worker_pool_context_t; + +/// descriptor for requested callback +typedef struct { + worker_callback_t func; + void * data; +} worker_pool_job_t; + +/// Maximum supported number of worker threads. +#define MAX_NUM_WORKERS 10 + +// Initialize worker pool. +WORKERPOOL_API AEEResult worker_pool_init(worker_pool_context_t * context, uint32_t n_threads); + +// Initialize worker pool with custom stack size +WORKERPOOL_API AEEResult worker_pool_init_with_stack_size(worker_pool_context_t * context, + uint32_t n_threads, + uint32_t stack_size); + +// Kill worker threads and release worker pool resources +WORKERPOOL_API void worker_pool_release(worker_pool_context_t * context); + +// Run jobs with the worker pool. +WORKERPOOL_API AEEResult worker_pool_run_jobs(worker_pool_context_t context, worker_pool_job_t * job, unsigned int n); + +WORKERPOOL_API AEEResult worker_pool_run_func(worker_pool_context_t context, + worker_callback_t func, + void * data, + unsigned int n); + +WORKERPOOL_API AEEResult worker_pool_set_thread_priority(worker_pool_context_t context, unsigned int prio); +WORKERPOOL_API AEEResult worker_pool_get_thread_priority(worker_pool_context_t context, unsigned int * prio); +WORKERPOOL_API AEEResult worker_pool_retrieve_thread_id(worker_pool_context_t context, unsigned int * tids); + +#ifdef __cplusplus +} +#endif + +#endif // #ifndef HTP_WORKER_POOL_H
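Taken together, the ops files above all drive this pool the same way: clamp the job count to the row count, ceil-divide rows per job, then let each callback recompute and clamp its own row range. The sketch below shows that pattern against the worker-pool API exactly as declared in worker-pool.h; row_job, sum_rows_cb, and run_sum_rows are illustrative names invented for this example (the real ops pass a struct htp_ops_context and use the HVX reduction helpers rather than this scalar loop).

    #include "worker-pool.h"

    // hypothetical job context; illustrative only
    struct row_job {
        const float * src;             // n_rows contiguous rows of n_cols floats
        float *       dst;             // one float per row
        unsigned int  n_cols;
        unsigned int  n_rows;
        unsigned int  rows_per_thread; // ceil(n_rows / n_jobs)
    };

    static void sum_rows_cb(unsigned int n, unsigned int i, void * data) {
        struct row_job * job = (struct row_job *) data;

        unsigned int start = i * job->rows_per_thread;
        unsigned int end   = start + job->rows_per_thread;
        if (end > job->n_rows) {
            end = job->n_rows; // the last job clamps, same as the MIN() in the ops
        }

        for (unsigned int r = start; r < end; r++) {
            float sum = 0.f; // scalar stand-in for hvx_reduce_sum_f32()
            for (unsigned int c = 0; c < job->n_cols; c++) {
                sum += job->src[r * job->n_cols + c];
            }
            job->dst[r] = sum;
        }
    }

    static AEEResult run_sum_rows(const float * src, float * dst,
                                  unsigned int n_rows, unsigned int n_cols,
                                  unsigned int n_threads) {
        if (n_rows == 0) {
            return AEE_SUCCESS;
        }

        worker_pool_context_t pool = NULL;
        AEEResult err = worker_pool_init(&pool, n_threads); // n_threads <= MAX_NUM_WORKERS
        if (err != AEE_SUCCESS) {
            return err;
        }

        // never launch more jobs than rows; ceil-divide the rows across the jobs
        unsigned int n_jobs = (n_rows < n_threads) ? n_rows : n_threads;
        struct row_job job = { src, dst, n_cols, n_rows, (n_rows + n_jobs - 1) / n_jobs };

        // blocks until every job is done; the calling thread runs job 0 itself
        err = worker_pool_run_func(pool, sum_rows_cb, &job, n_jobs);

        worker_pool_release(&pool);
        return err;
    }

In the backend proper the pool is created once and kept in the session context (the ops reach it through octx->ctx->worker_pool), which is why worker_pool_run_jobs spins on n_pending between dispatches instead of tearing the threads down after each op.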
