Diffstat (limited to 'llama.cpp/ggml/src/ggml-hexagon')
-rw-r--r--  llama.cpp/ggml/src/ggml-hexagon/CMakeLists.txt  117
-rw-r--r--  llama.cpp/ggml/src/ggml-hexagon/ggml-hexagon.cpp  3286
-rw-r--r--  llama.cpp/ggml/src/ggml-hexagon/htp-drv.cpp  418
-rw-r--r--  llama.cpp/ggml/src/ggml-hexagon/htp-drv.h  121
-rw-r--r--  llama.cpp/ggml/src/ggml-hexagon/htp/CMakeLists.txt  45
-rw-r--r--  llama.cpp/ggml/src/ggml-hexagon/htp/act-ops.c  823
-rw-r--r--  llama.cpp/ggml/src/ggml-hexagon/htp/argsort-ops.c  281
-rw-r--r--  llama.cpp/ggml/src/ggml-hexagon/htp/binary-ops.c  827
-rw-r--r--  llama.cpp/ggml/src/ggml-hexagon/htp/cmake-toolchain.cmake  157
-rw-r--r--  llama.cpp/ggml/src/ggml-hexagon/htp/cpy-ops.c  251
-rw-r--r--  llama.cpp/ggml/src/ggml-hexagon/htp/flash-attn-ops.c  684
-rw-r--r--  llama.cpp/ggml/src/ggml-hexagon/htp/get-rows-ops.c  106
-rw-r--r--  llama.cpp/ggml/src/ggml-hexagon/htp/hex-dma.c  63
-rw-r--r--  llama.cpp/ggml/src/ggml-hexagon/htp/hex-dma.h  156
-rw-r--r--  llama.cpp/ggml/src/ggml-hexagon/htp/hex-dump.h  77
-rw-r--r--  llama.cpp/ggml/src/ggml-hexagon/htp/hex-fastdiv.h  37
-rw-r--r--  llama.cpp/ggml/src/ggml-hexagon/htp/hex-utils.h  51
-rw-r--r--  llama.cpp/ggml/src/ggml-hexagon/htp/htp-ctx.h  35
-rw-r--r--  llama.cpp/ggml/src/ggml-hexagon/htp/htp-msg.h  154
-rw-r--r--  llama.cpp/ggml/src/ggml-hexagon/htp/htp-ops.h  91
-rw-r--r--  llama.cpp/ggml/src/ggml-hexagon/htp/htp_iface.idl  16
-rw-r--r--  llama.cpp/ggml/src/ggml-hexagon/htp/hvx-arith.h  470
-rw-r--r--  llama.cpp/ggml/src/ggml-hexagon/htp/hvx-base.h  173
-rw-r--r--  llama.cpp/ggml/src/ggml-hexagon/htp/hvx-copy.h  245
-rw-r--r--  llama.cpp/ggml/src/ggml-hexagon/htp/hvx-div.h  116
-rw-r--r--  llama.cpp/ggml/src/ggml-hexagon/htp/hvx-dump.h  129
-rw-r--r--  llama.cpp/ggml/src/ggml-hexagon/htp/hvx-exp.h  215
-rw-r--r--  llama.cpp/ggml/src/ggml-hexagon/htp/hvx-floor.h  100
-rw-r--r--  llama.cpp/ggml/src/ggml-hexagon/htp/hvx-inverse.h  176
-rw-r--r--  llama.cpp/ggml/src/ggml-hexagon/htp/hvx-reduce.h  266
-rw-r--r--  llama.cpp/ggml/src/ggml-hexagon/htp/hvx-scale.h  133
-rw-r--r--  llama.cpp/ggml/src/ggml-hexagon/htp/hvx-sigmoid.h  141
-rw-r--r--  llama.cpp/ggml/src/ggml-hexagon/htp/hvx-sqrt.h  126
-rw-r--r--  llama.cpp/ggml/src/ggml-hexagon/htp/hvx-types.h  36
-rw-r--r--  llama.cpp/ggml/src/ggml-hexagon/htp/hvx-utils.h  18
-rw-r--r--  llama.cpp/ggml/src/ggml-hexagon/htp/main.c  1150
-rw-r--r--  llama.cpp/ggml/src/ggml-hexagon/htp/matmul-ops.c  2665
-rw-r--r--  llama.cpp/ggml/src/ggml-hexagon/htp/rope-ops.c  480
-rw-r--r--  llama.cpp/ggml/src/ggml-hexagon/htp/set-rows-ops.c  164
-rw-r--r--  llama.cpp/ggml/src/ggml-hexagon/htp/softmax-ops.c  395
-rw-r--r--  llama.cpp/ggml/src/ggml-hexagon/htp/sum-rows-ops.c  115
-rw-r--r--  llama.cpp/ggml/src/ggml-hexagon/htp/unary-ops.c  342
-rw-r--r--  llama.cpp/ggml/src/ggml-hexagon/htp/worker-pool.c  293
-rw-r--r--  llama.cpp/ggml/src/ggml-hexagon/htp/worker-pool.h  57
-rw-r--r--  llama.cpp/ggml/src/ggml-hexagon/libdl.h  79
-rw-r--r--  llama.cpp/ggml/src/ggml-hexagon/libggml-htp.inf  38
-rw-r--r--  llama.cpp/ggml/src/ggml-hexagon/op-desc.h  153
47 files changed, 16071 insertions, 0 deletions
diff --git a/llama.cpp/ggml/src/ggml-hexagon/CMakeLists.txt b/llama.cpp/ggml/src/ggml-hexagon/CMakeLists.txt
new file mode 100644
index 0000000..f3a5835
--- /dev/null
+++ b/llama.cpp/ggml/src/ggml-hexagon/CMakeLists.txt
@@ -0,0 +1,117 @@
+file(TO_CMAKE_PATH "${HEXAGON_SDK_ROOT}" HEXAGON_SDK_ROOT)
+file(TO_CMAKE_PATH "${HEXAGON_TOOLS_ROOT}" HEXAGON_TOOLS_ROOT)
+
+if (NOT IS_DIRECTORY "${HEXAGON_SDK_ROOT}")
+    message(FATAL_ERROR "Make sure HEXAGON_SDK_ROOT points to the correct Hexagon SDK installation.")
+endif()
+
+if (NOT IS_DIRECTORY "${HEXAGON_TOOLS_ROOT}")
+    message("Trying to read HEXAGON_TOOLS_ROOT from hexagon_sdk.json")
+ file(READ "${HEXAGON_SDK_ROOT}/hexagon_sdk.json" HEXAGON_SDK_CONFIG_PATH)
+ string(JSON HEXAGON_TOOLS_PATH GET ${HEXAGON_SDK_CONFIG_PATH} "root" "tools" "info" 0 "path")
+ message("Found HEXAGON_TOOLS_PATH: ${HEXAGON_TOOLS_PATH}")
+ set(HEXAGON_TOOLS_ROOT "${HEXAGON_SDK_ROOT}/${HEXAGON_TOOLS_PATH}")
+ file(TO_CMAKE_PATH "${HEXAGON_TOOLS_ROOT}" HEXAGON_TOOLS_ROOT)
+ if (NOT IS_DIRECTORY "${HEXAGON_TOOLS_ROOT}")
+        message(FATAL_ERROR "Make sure HEXAGON_TOOLS_ROOT points to the correct Hexagon SDK installation.")
+ endif()
+endif()
+
+message(STATUS "hexagon: using ${HEXAGON_SDK_ROOT} and ${HEXAGON_TOOLS_ROOT} for building libggml-htp skels")
+
+include(${HEXAGON_SDK_ROOT}/build/cmake/hexagon_fun.cmake)
+include(ExternalProject)
+
+option(GGML_HEXAGON_HTP_DEBUG "ggml-hexagon: enable HTP debug output" OFF)
+set(GGML_HEXAGON_HTP_CERT "$ENV{HEXAGON_HTP_CERT}" CACHE PATH "ggml-hexagon: enable HTP library signing using certificate")
+set(GGML_HEXAGON_FP32_QUANTIZE_GROUP_SIZE 128 CACHE STRING "ggml-hexagon: quantize group size (32, 64, or 128)")
+
+add_library(htp_iface OBJECT
+ ${CMAKE_CURRENT_BINARY_DIR}/htp_iface_stub.c)
+
+set_target_properties(htp_iface PROPERTIES POSITION_INDEPENDENT_CODE ON)
+target_include_directories(htp_iface PUBLIC
+ ${HEXAGON_SDK_ROOT}/incs
+ ${HEXAGON_SDK_ROOT}/incs/stddef
+ ${HEXAGON_SDK_ROOT}/utils/examples
+ ${CMAKE_CURRENT_SOURCE_DIR}/htp
+ ${CMAKE_CURRENT_BINARY_DIR})
+
+build_idl(htp/htp_iface.idl htp_iface)
+
+if (CMAKE_SYSTEM_NAME MATCHES Android)
+ target_link_options(htp_iface PUBLIC -llog -ldl)
+elseif (CMAKE_SYSTEM_NAME MATCHES Windows)
+ target_precompile_headers(htp_iface PUBLIC <sal.h>)
+else()
+ target_link_options(htp_iface PUBLIC -ldl)
+endif()
+
+set(TARGET_NAME ggml-hexagon)
+ggml_add_backend_library(${TARGET_NAME}
+ ggml-hexagon.cpp
+ htp-drv.cpp
+ htp-drv.h
+ libdl.h
+ ../../include/ggml-hexagon.h)
+
+target_link_libraries(${TARGET_NAME} PRIVATE htp_iface)
+target_include_directories(${TARGET_NAME} PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/htp ${CMAKE_CURRENT_BINARY_DIR})
+
+# Build HTP skels
+set(HTP_SKELS)
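+# build_htp_skel(V) cross-compiles the DSP-side skel (libggml-htp-${V}.so) for Hexagon
+# arch V as an ExternalProject using the SDK toolchain file, and appends the resulting
+# library to HTP_SKELS so it gets installed below.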
+function(build_htp_skel V)
+ ExternalProject_Add(htp-${V}
+ SOURCE_DIR ${CMAKE_CURRENT_SOURCE_DIR}/htp BUILD_ALWAYS ON
+ BUILD_BYPRODUCTS ${CMAKE_CURRENT_BINARY_DIR}/libggml-htp-${V}.so
+ CMAKE_ARGS
+ -DCMAKE_BUILD_TYPE=Release
+ -DCMAKE_TOOLCHAIN_FILE=${CMAKE_CURRENT_SOURCE_DIR}/htp/cmake-toolchain.cmake
+ -DCMAKE_INSTALL_LIBDIR=${CMAKE_CURRENT_BINARY_DIR}
+ -DHEXAGON_SDK_ROOT=${HEXAGON_SDK_ROOT}
+ -DHEXAGON_TOOLS_ROOT=${HEXAGON_TOOLS_ROOT}
+ -DHEXAGON_HTP_DEBUG=${GGML_HEXAGON_HTP_DEBUG}
+ -DGGML_HEXAGON_FP32_QUANTIZE_GROUP_SIZE=${GGML_HEXAGON_FP32_QUANTIZE_GROUP_SIZE}
+ -DDSP_VERSION=${V}
+ -DPREBUILT_LIB_DIR="toolv19_${V}")
+ list(APPEND HTP_SKELS ${CMAKE_CURRENT_BINARY_DIR}/libggml-htp-${V}.so)
+ set(HTP_SKELS ${HTP_SKELS} PARENT_SCOPE)
+endfunction()
+
+build_htp_skel(v68)
+build_htp_skel(v69)
+build_htp_skel(v73)
+build_htp_skel(v75)
+build_htp_skel(v79)
+build_htp_skel(v81)
+
+# Install Hexagon skels required at runtime
+install(FILES ${HTP_SKELS} TYPE LIB)
+
+if (CMAKE_SYSTEM_NAME MATCHES Windows AND GGML_HEXAGON_HTP_CERT)
+ file(TO_CMAKE_PATH "$ENV{WINDOWS_SDK_BIN}/arm64" WINSDK_BIN0_ARM64)
+ file(TO_CMAKE_PATH "$ENV{WINDOWS_SDK_BIN}/x86" WINSDK_BIN0_X86)
+ file(TO_CMAKE_PATH "$ENV{WindowsSdkVerBinPath}/arm64" WINSDK_BIN1_ARM64)
+ file(TO_CMAKE_PATH "$ENV{WindowsSdkVerBinPath}/x86" WINSDK_BIN1_X86)
+
+ set(WINSDK_PATHS ${WINSDK_BIN0_ARM64} ${WINSDK_BIN0_X86} ${WINSDK_BIN1_ARM64} ${WINSDK_BIN1_X86})
+
+ find_program(INF2CAT NAMES inf2cat.exe PATHS ${WINSDK_PATHS} REQUIRED)
+ find_program(SIGNTOOL NAMES signtool.exe PATHS ${WINSDK_PATHS} REQUIRED)
+
+ message(STATUS "hexagon: using ${GGML_HEXAGON_HTP_CERT} to sign libggml-htp skels")
+
+ set(LIBGGML_HTP_CAT ${CMAKE_CURRENT_BINARY_DIR}/libggml-htp.cat)
+ add_custom_target(libggml-htp-cat
+ BYPRODUCTS ${LIBGGML_HTP_CAT}
+ DEPENDS libggml-htp.inf ${HTP_SKELS}
+ COMMAND ${CMAKE_COMMAND} -E copy ${CMAKE_CURRENT_SOURCE_DIR}/libggml-htp.inf ${CMAKE_CURRENT_BINARY_DIR}
+ COMMAND ${INF2CAT} /driver:${CMAKE_CURRENT_BINARY_DIR} /os:10_25H2_ARM64
+ COMMAND ${SIGNTOOL} sign /fd sha256 /f ${GGML_HEXAGON_HTP_CERT} ${LIBGGML_HTP_CAT}
+ COMMENT "generating and signing libggml-htp.cat file"
+ VERBATIM
+ )
+
+ add_dependencies(${TARGET_NAME} libggml-htp-cat)
+ install(FILES ${LIBGGML_HTP_CAT} TYPE LIB)
+endif()
diff --git a/llama.cpp/ggml/src/ggml-hexagon/ggml-hexagon.cpp b/llama.cpp/ggml/src/ggml-hexagon/ggml-hexagon.cpp
new file mode 100644
index 0000000..54f9986
--- /dev/null
+++ b/llama.cpp/ggml/src/ggml-hexagon/ggml-hexagon.cpp
@@ -0,0 +1,3286 @@
+#include <assert.h>
+#include <inttypes.h>
+#include <stdio.h>
+#include <stdlib.h>
+#include <string.h>
+#include <time.h>
+
+#include <atomic>
+#include <chrono>
+#include <cstddef>
+#include <mutex>
+#include <stdexcept>
+#include <string>
+
+#ifdef _WIN32
+# include <sal.h>
+#else
+# include <semaphore.h>
+# include <unistd.h>
+#endif
+
+#pragma clang diagnostic ignored "-Wnested-anon-types"
+#pragma clang diagnostic ignored "-Wgnu-anonymous-struct"
+
+#include <AEEStdErr.h>
+#include <dspqueue.h>
+#include <rpcmem.h>
+
+#define GGML_COMMON_IMPL_CPP
+#include "ggml-backend-impl.h"
+#include "ggml-common.h"
+#include "ggml-hexagon.h"
+#include "ggml-impl.h"
+#include "ggml-quants.h"
+#include "op-desc.h"
+#include "htp-msg.h"
+#include "htp_iface.h"
+#include "htp-drv.h"
+
+static size_t opt_ndev = 1;
+static size_t opt_nhvx = 0; // use all
+static int opt_arch = 0; // autodetect
+static int opt_etm = 0;
+static int opt_verbose = 0;
+static int opt_profile = 0;
+static int opt_hostbuf = 1; // hostbuf ON by default
+static int opt_experimental = 0;
+
+// Enable all stages by default
+static int opt_opmask = HTP_OPMASK_QUEUE | HTP_OPMASK_QUANTIZE | HTP_OPMASK_COMPUTE;
+static int opt_opsync = 0; // synchronous ops
+
+#define HEX_VERBOSE(...) \
+ if (opt_verbose) GGML_LOG_DEBUG(__VA_ARGS__)
+
+static inline uint64_t hex_is_aligned(void * addr, uint32_t align) {
+ return ((size_t) addr & (align - 1)) == 0;
+}
+
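+// round n up to the next multiple of m, e.g. hex_round_up(100, 64) == 128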
+static inline size_t hex_round_up(size_t n, size_t m) {
+ return m * ((n + m - 1) / m);
+}
+
+static const char * status_to_str(uint32_t status) {
+ switch (status) {
+ case HTP_STATUS_OK:
+ return "OK";
+ case HTP_STATUS_NO_SUPPORT:
+ return "NO-SUPPORT";
+ case HTP_STATUS_INVAL_PARAMS:
+ return "INVAL-PARAMS";
+ case HTP_STATUS_VTCM_TOO_SMALL:
+ return "VTCM-TOO-SMALL";
+ case HTP_STATUS_INTERNAL_ERR:
+ return "INTERNAL-ERROR";
+ default:
+ return "UNKNOWN";
+ }
+}
+
+// ** debug helpers
+
+static void ggml_hexagon_dump_op_exec(const std::string &sess_name, const ggml_tensor * op, const uint32_t req_flags) {
+ if (!opt_verbose) return;
+
+ op_desc desc(op);
+ GGML_LOG_DEBUG("ggml-hex: %s execute-op %s: %s : %s : %s : %s : %s : flags 0x%x\n", sess_name.c_str(),
+ ggml_op_name(op->op), desc.names, desc.dims, desc.types, desc.strides, desc.buffs, req_flags);
+}
+
+static void ggml_hexagon_dump_op_supp(const std::string &sess_name, const struct ggml_tensor * op, bool supp) {
+ if (!opt_verbose) return;
+
+ op_desc desc(op);
+ GGML_LOG_DEBUG("ggml-hex: %s supports-op %s : %s : %s : %s : %s : %s : %s\n", sess_name.c_str(),
+ ggml_op_name(op->op), desc.names, desc.dims, desc.types, desc.strides, desc.buffs, supp ? "yes" : "no");
+}
+
+static void ggml_hexagon_dump_op_prof(const std::string &sess_name, const ggml_tensor * op,
+ uint32_t op_usec, uint32_t op_cycles, uint32_t op_pkts, uint64_t call_usec) {
+ if (!opt_profile) return;
+
+ op_desc desc(op);
+ GGML_LOG_DEBUG("ggml-hex: %s profile-op %s: %s : %s : %s : %s : %s : op-usec %u op-cycles %u op-pkts %u (%f) call-usec %llu\n", sess_name.c_str(),
+ ggml_op_name(op->op), desc.names, desc.dims, desc.types, desc.strides, desc.buffs,
+ op_usec, op_cycles, op_pkts, (float) op_cycles / op_pkts, (unsigned long long) call_usec);
+}
+
+// ** backend sessions
+
+struct ggml_hexagon_session {
+ ggml_hexagon_session(int dev_id, ggml_backend_dev_t dev) noexcept(false);
+ ~ggml_hexagon_session() noexcept(true);
+
+ void allocate(int dev_id) noexcept(false);
+ void release() noexcept(true);
+
+ void enqueue(struct htp_general_req &req, struct dspqueue_buffer *bufs, uint32_t n_bufs, bool sync = false);
+ void flush();
+
+ ggml_backend_buffer_type buffer_type = {};
+ ggml_backend_buffer_type repack_buffer_type = {};
+
+ std::string name;
+ remote_handle64 handle;
+ dspqueue_t queue;
+ uint32_t session_id;
+ uint32_t domain_id;
+ uint64_t queue_id;
+ int dev_id;
+ bool valid_session;
+ bool valid_handle;
+ bool valid_queue;
+ bool valid_iface;
+ std::atomic<int> op_pending;
+ uint32_t prof_usecs;
+ uint32_t prof_cycles;
+ uint32_t prof_pkts;
+};
+
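+// Submit one request packet to the DSP queue. The buffer references in bufs tell the
+// FastRPC framework which shared buffers the op touches; the matching response is
+// drained in flush(), which also clears op_pending.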
+void ggml_hexagon_session::enqueue(struct htp_general_req &req, struct dspqueue_buffer *bufs, uint32_t n_bufs, bool sync) {
+    // Bump pending flag (cleared in session::flush once we get the response)
+ this->op_pending++; // atomic inc
+
+ int err = dspqueue_write(this->queue,
+ 0, // flags - the framework will autoset this
+ n_bufs, // number of buffers
+ bufs, // buffer references
+ sizeof(req), // Message length
+ (const uint8_t *) &req, // Message
+ DSPQUEUE_TIMEOUT // Timeout
+ );
+
+ if (err != 0) {
+ GGML_ABORT("ggml-hex: %s dspqueue_write failed: 0x%08x\n", this->name.c_str(), (unsigned) err);
+ }
+
+ if (sync) {
+ flush();
+ }
+}
+
+// Flush the HTP response queue, i.e. wait for all outstanding requests to complete
+void ggml_hexagon_session::flush() {
+ dspqueue_t q = this->queue;
+
+ // Repeatedly read packets from the queue until it's empty. We don't
+ // necessarily get a separate callback for each packet, and new packets
+ // may arrive while we're processing the previous one.
+
+ while (this->op_pending) {
+ struct htp_general_rsp rsp;
+ uint32_t rsp_size;
+ uint32_t flags;
+
+ struct dspqueue_buffer bufs[HTP_MAX_PACKET_BUFFERS];
+ uint32_t n_bufs;
+
+ // Read response packet from queue
+ int err = dspqueue_read(q, &flags,
+ HTP_MAX_PACKET_BUFFERS, // Maximum number of buffer references
+ &n_bufs, // Number of buffer references
+ bufs, // Buffer references
+ sizeof(rsp), // Max message length
+ &rsp_size, // Message length
+ (uint8_t *) &rsp, // Message
+ DSPQUEUE_TIMEOUT); // Timeout
+
+ if (err == AEE_EEXPIRED) {
+ // TODO: might need to bail out if the HTP is stuck on something
+ continue;
+ }
+
+ if (err != 0) {
+ GGML_ABORT("ggml-hex: dspqueue_read failed: 0x%08x\n", (unsigned) err);
+ }
+
+ // Basic sanity checks
+ if (rsp_size != sizeof(rsp)) {
+ GGML_ABORT("ggml-hex: dspcall : bad response (size)\n");
+ }
+
+ if (rsp.status != HTP_STATUS_OK) {
+ GGML_LOG_ERROR("ggml-hex: dspcall : dsp-rsp: %s\n", status_to_str(rsp.status));
+ // TODO: handle errors
+ }
+
+ // TODO: update profiling implementation, currently only works for opt_opsync mode
+ this->prof_usecs = rsp.prof_usecs;
+ this->prof_cycles = rsp.prof_cycles;
+ this->prof_pkts = rsp.prof_pkts;
+
+ this->op_pending--; // atomic dec
+ }
+}
+
+// ** backend buffers
+
+struct ggml_backend_hexagon_buffer_type_context {
+ ggml_backend_hexagon_buffer_type_context(const std::string & name, ggml_hexagon_session * sess) {
+ this->sess = sess;
+ this->name = name;
+ }
+
+ ggml_hexagon_session * sess;
+ std::string name;
+};
+
+struct ggml_backend_hexagon_buffer_context {
+ bool mmap_to(ggml_hexagon_session * s) {
+        HEX_VERBOSE("ggml-hex: %s mmapping buffer: base %p domain-id %d session-id %d size %zu fd %d repack %d\n",
+ s->name.c_str(), (void *) this->base, s->domain_id, s->session_id, this->size, this->fd,
+ (int) this->repack);
+
+ int err = fastrpc_mmap(s->domain_id, this->fd, (void *) this->base, 0, this->size, FASTRPC_MAP_FD);
+ if (err != 0) {
+ GGML_LOG_ERROR("ggml-hex: buffer mapping failed : domain_id %d size %zu fd %d error 0x%08x\n",
+ s->domain_id, this->size, this->fd, (unsigned) err);
+ return false;
+ }
+
+ return true;
+ }
+
+ bool mmap() {
+ if (this->mapped) {
+ return true;
+ }
+ if (!mmap_to(this->sess)) {
+ return false;
+ }
+ this->mapped = true;
+ return true;
+ }
+
+ void munmap() {
+ if (!this->mapped) {
+ return;
+ }
+
+ fastrpc_munmap(this->sess->domain_id, this->fd, this->base, this->size);
+ this->mapped = false;
+ }
+
+ ggml_backend_hexagon_buffer_context(ggml_hexagon_session * sess, size_t size, bool repack) {
+ size += 4 * 1024; // extra page for padding
+
+ this->base = (uint8_t *) rpcmem_alloc2(RPCMEM_HEAP_ID_SYSTEM, RPCMEM_DEFAULT_FLAGS | RPCMEM_HEAP_NOREG, size);
+ if (!this->base) {
+ GGML_LOG_ERROR("ggml-hex: %s failed to allocate buffer : size %zu\n", sess->name.c_str(), size);
+ throw std::runtime_error("ggml-hex: rpcmem_alloc failed (see log for details)");
+ }
+
+ this->fd = rpcmem_to_fd(this->base);
+ if (this->fd < 0) {
+ GGML_LOG_ERROR("ggml-hex: %s failed to get FD for buffer %p\n", sess->name.c_str(), (void *) this->base);
+ rpcmem_free(this->base);
+ this->base = NULL;
+ throw std::runtime_error("ggml-hex: rpcmem_to_fd failed (see log for details)");
+ }
+
+ HEX_VERBOSE("ggml-hex: %s allocated buffer: base %p size %zu fd %d repack %d\n", sess->name.c_str(),
+ (void *) this->base, size, this->fd, (int) repack);
+
+ this->sess = sess;
+ this->size = size;
+ this->mapped = false;
+ this->repack = repack;
+ }
+
+ ~ggml_backend_hexagon_buffer_context() {
+ munmap();
+ if (this->base) {
+ rpcmem_free(this->base);
+ this->base = NULL;
+ }
+ }
+
+ ggml_hexagon_session * sess; // primary session
+ uint8_t * base;
+ size_t size;
+ int fd;
+ bool mapped; // mmap is done
+ bool repack; // repacked buffer
+};
+
+static ggml_hexagon_session * ggml_backend_hexagon_buffer_get_sess(ggml_backend_buffer_t buffer) {
+ return static_cast<ggml_backend_hexagon_buffer_type_context *>(buffer->buft->context)->sess;
+}
+
+static void ggml_backend_hexagon_buffer_free_buffer(ggml_backend_buffer_t buffer) {
+ auto ctx = static_cast<ggml_backend_hexagon_buffer_context *>(buffer->context);
+ delete ctx;
+}
+
+static void * ggml_backend_hexagon_buffer_get_base(ggml_backend_buffer_t buffer) {
+ auto ctx = static_cast<ggml_backend_hexagon_buffer_context *>(buffer->context);
+ return ctx->base;
+}
+
+static enum ggml_status ggml_backend_hexagon_buffer_init_tensor(ggml_backend_buffer_t buffer, ggml_tensor * tensor) {
+ auto ctx = static_cast<ggml_backend_hexagon_buffer_context *>(buffer->context);
+ auto sess = ctx->sess;
+
+ HEX_VERBOSE("ggml-hex: %s init-tensor %s : base %p data %p nbytes %zu usage %d repack %d\n", sess->name.c_str(),
+ tensor->name, (void *) ctx->base, tensor->data, ggml_nbytes(tensor), (int) buffer->usage,
+ (int) ctx->repack);
+
+ if (tensor->view_src != NULL && tensor->view_offs == 0) {
+ ; // nothing to do for the view
+ } else {
+ if (!ctx->mapped) {
+ ctx->mmap();
+ }
+ }
+ return GGML_STATUS_SUCCESS;
+}
+
+// ======== Q4x4x2 ====================
+struct x2_q4 {
+ int v[2];
+};
+
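+// split one packed q4_0 byte into its two 4-bit quants, mapped to signed values in [-8, 7]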
+static x2_q4 unpack_q4(uint8_t v) {
+ x2_q4 x = { (int) (v & 0x0f) - 8, (int) (v >> 4) - 8 };
+ return x;
+}
+
+static void dump_block_q4_0(const block_q4_0 * b, int i) {
+ HEX_VERBOSE("ggml-hex: repack q4_0 %d: %d %d %d %d ... %d %d %d %d : %.6f\n", i, unpack_q4(b->qs[0]).v[0],
+ unpack_q4(b->qs[1]).v[0], unpack_q4(b->qs[2]).v[0], unpack_q4(b->qs[3]).v[0], unpack_q4(b->qs[12]).v[1],
+ unpack_q4(b->qs[13]).v[1], unpack_q4(b->qs[14]).v[1], unpack_q4(b->qs[15]).v[1],
+ GGML_FP16_TO_FP32(b->d));
+}
+
+static void dump_packed_block_q4x4x2(const uint8_t * v, unsigned int i, size_t k) {
+ static const int qk = QK_Q4_0x4x2;
+ const int dblk_size = 8 * 2; // 8x __fp16
+ const int qblk_size = qk / 2; // int4
+ const int qrow_size = k / 2; // int4 (not padded)
+
+ const uint8_t * v_q = v + 0; // quants first
+ const uint8_t * v_d = v + qrow_size; // then scales
+
+ const uint8_t * q = v_q + i * qblk_size;
+ const ggml_half * d = (const ggml_half *) (v_d + i * dblk_size);
+
+ HEX_VERBOSE("ggml-hex: repack q4x4x2-%d: %d %d %d %d ... %d %d %d %d ... %d %d %d %d : %.6f %.6f %.6f %.6f\n", i,
+ unpack_q4(q[0]).v[0], unpack_q4(q[1]).v[0], unpack_q4(q[2]).v[0], unpack_q4(q[3]).v[0],
+ unpack_q4(q[60]).v[0], unpack_q4(q[61]).v[0], unpack_q4(q[62]).v[0], unpack_q4(q[63]).v[0],
+ unpack_q4(q[124]).v[0], unpack_q4(q[125]).v[0], unpack_q4(q[126]).v[0], unpack_q4(q[127]).v[0],
+ GGML_FP16_TO_FP32(d[0]), GGML_FP16_TO_FP32(d[1]), GGML_FP16_TO_FP32(d[2]), GGML_FP16_TO_FP32(d[3]));
+
+ HEX_VERBOSE("ggml-hex: repack q4x4x2-%d: %d %d %d %d ... %d %d %d %d ... %d %d %d %d : %.6f %.6f %.6f %.6f\n",
+ i + 1, unpack_q4(q[0]).v[1], unpack_q4(q[1]).v[1], unpack_q4(q[2]).v[1], unpack_q4(q[3]).v[1],
+ unpack_q4(q[60]).v[1], unpack_q4(q[61]).v[1], unpack_q4(q[62]).v[1], unpack_q4(q[63]).v[1],
+ unpack_q4(q[124]).v[1], unpack_q4(q[125]).v[1], unpack_q4(q[126]).v[1], unpack_q4(q[127]).v[1],
+ GGML_FP16_TO_FP32(d[4]), GGML_FP16_TO_FP32(d[5]), GGML_FP16_TO_FP32(d[6]), GGML_FP16_TO_FP32(d[7]));
+}
+
+static void unpack_q4_0_quants(uint8_t * qs, const block_q4_0 * x, unsigned int bi) {
+ static const int qk = QK4_0;
+
+ for (unsigned int i = 0; i < qk / 2; ++i) {
+ const int x0 = (x->qs[i] & 0x0F);
+ const int x1 = (x->qs[i] >> 4);
+ qs[bi * qk + i + 0] = x0;
+ qs[bi * qk + i + qk / 2] = x1;
+ }
+}
+
+static void pack_q4_0_quants(block_q4_0 * x, const uint8_t * qs, unsigned int bi) {
+ static const int qk = QK4_0;
+
+ for (unsigned int i = 0; i < qk / 2; ++i) {
+ const uint8_t x0 = qs[bi * qk + i + 0];
+ const uint8_t x1 = qs[bi * qk + i + qk / 2];
+ x->qs[i] = x0 | (x1 << 4);
+ }
+}
+
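+// Packed q4x4x2 row layout (as produced below): all quant nibbles for the row come first
+// (k/2 bytes), followed by one group of 8 fp16 scales per 256-element super-block.
+// Each super-block interleaves 8 consecutive q4_0 blocks; the low nibbles hold the first
+// 128 unpacked values and the high nibbles the second 128.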
+static void repack_row_q4x4x2(uint8_t * y, const block_q4_0 * x, int64_t k) {
+ static const int qk = QK_Q4_0x4x2;
+ const int nb = (k + qk - 1) / qk; // number of blocks (padded)
+
+ const int dblk_size = 8 * 2; // 8x __fp16
+ const int qblk_size = qk / 2; // int4
+ const int qrow_size = k / 2; // int4 (not padded to blocks)
+
+ uint8_t * y_q = y + 0; // quants first
+ uint8_t * y_d = y + qrow_size; // then scales
+
+ if (opt_verbose > 2) {
+ for (int i = 0; i < nb; i++) {
+ dump_block_q4_0(&x[i * 8 + 0], 0);
+ dump_block_q4_0(&x[i * 8 + 1], 1);
+ dump_block_q4_0(&x[i * 8 + 2], 2);
+ dump_block_q4_0(&x[i * 8 + 3], 3);
+ dump_block_q4_0(&x[i * 8 + 4], 4);
+ dump_block_q4_0(&x[i * 8 + 5], 5);
+ dump_block_q4_0(&x[i * 8 + 6], 6);
+ dump_block_q4_0(&x[i * 8 + 7], 7);
+ }
+ }
+
+ // Repack the quants
+ for (int i = 0; i < nb; i++) {
+ uint8_t qs[QK_Q4_0x4x2]; // unpacked quants
+ unpack_q4_0_quants(qs, &x[i * 8 + 0], 0);
+ unpack_q4_0_quants(qs, &x[i * 8 + 1], 1);
+ unpack_q4_0_quants(qs, &x[i * 8 + 2], 2);
+ unpack_q4_0_quants(qs, &x[i * 8 + 3], 3);
+ unpack_q4_0_quants(qs, &x[i * 8 + 4], 4);
+ unpack_q4_0_quants(qs, &x[i * 8 + 5], 5);
+ unpack_q4_0_quants(qs, &x[i * 8 + 6], 6);
+ unpack_q4_0_quants(qs, &x[i * 8 + 7], 7);
+
+ uint8_t * q = y_q + (i * qblk_size);
+ for (int j = 0; j < qk / 2; j++) {
+ q[j] = (qs[j + 128] << 4) | qs[j];
+ }
+ }
+
+ // Repack the scales
+ // Note: Do not combine with the loop above. For tensor sizes not multiple of 256 (QK_Q4_0x4x2)
+    // the last block is truncated and overridden by the scales.
+ for (int i = 0; i < nb; i++) {
+ // Repack the scales
+ ggml_half * d = (ggml_half *) (y_d + i * dblk_size);
+ d[0] = x[i * 8 + 0].d;
+ d[1] = x[i * 8 + 1].d;
+ d[2] = x[i * 8 + 2].d;
+ d[3] = x[i * 8 + 3].d;
+ d[4] = x[i * 8 + 4].d;
+ d[5] = x[i * 8 + 5].d;
+ d[6] = x[i * 8 + 6].d;
+ d[7] = x[i * 8 + 7].d;
+ }
+
+ if (opt_verbose > 1) {
+ for (int i = 0; i < nb; i++) {
+ dump_packed_block_q4x4x2(y, i, k);
+ }
+ }
+}
+
+static void unpack_row_q4x4x2(block_q4_0 * x, const uint8_t * y, int64_t k) {
+ static const int qk = QK_Q4_0x4x2;
+ const int nb = (k + qk - 1) / qk; // number of blocks (padded)
+
+ const int dblk_size = 8 * 2; // 8x __fp16
+ const int qblk_size = qk / 2; // int4
+ const int qrow_size = k / 2; // int4 (not padded to blocks)
+
+ const uint8_t * y_q = y + 0; // quants first
+ const uint8_t * y_d = y + qrow_size; // then scales
+
+ if (opt_verbose > 1) {
+ for (int i = 0; i < nb; i++) {
+ dump_packed_block_q4x4x2(y, i, k);
+ }
+ }
+
+ // Unpack the quants
+ for (int i = 0; i < nb; i++) {
+ uint8_t qs[QK_Q4_0x4x2]; // unpacked quants
+
+ const uint8_t * q = y_q + (i * qblk_size);
+ for (int j = 0; j < qk / 2; j++) {
+ qs[j] = q[j] & 0xf;
+ qs[j + 128] = q[j] >> 4;
+ }
+
+ pack_q4_0_quants(&x[i * 8 + 0], qs, 0);
+ pack_q4_0_quants(&x[i * 8 + 1], qs, 1);
+ pack_q4_0_quants(&x[i * 8 + 2], qs, 2);
+ pack_q4_0_quants(&x[i * 8 + 3], qs, 3);
+ pack_q4_0_quants(&x[i * 8 + 4], qs, 4);
+ pack_q4_0_quants(&x[i * 8 + 5], qs, 5);
+ pack_q4_0_quants(&x[i * 8 + 6], qs, 6);
+ pack_q4_0_quants(&x[i * 8 + 7], qs, 7);
+ }
+
+    // Unpack the scales
+    // Note: Do not combine with the loop above. For tensor sizes not multiple of 256 (QK_Q4_0x4x2)
+    // the last block is truncated and overridden by the scales.
+ for (int i = 0; i < nb; i++) {
+ // Unpack the scales
+ const ggml_half * d = (const ggml_half *) (y_d + i * dblk_size);
+ x[i * 8 + 0].d = d[0];
+ x[i * 8 + 1].d = d[1];
+ x[i * 8 + 2].d = d[2];
+ x[i * 8 + 3].d = d[3];
+ x[i * 8 + 4].d = d[4];
+ x[i * 8 + 5].d = d[5];
+ x[i * 8 + 6].d = d[6];
+ x[i * 8 + 7].d = d[7];
+ }
+
+ if (opt_verbose > 2) {
+ for (int i = 0; i < nb; i++) {
+ dump_block_q4_0(&x[i * 8 + 0], 0);
+ dump_block_q4_0(&x[i * 8 + 1], 1);
+ dump_block_q4_0(&x[i * 8 + 2], 2);
+ dump_block_q4_0(&x[i * 8 + 3], 3);
+ dump_block_q4_0(&x[i * 8 + 4], 4);
+ dump_block_q4_0(&x[i * 8 + 5], 5);
+ dump_block_q4_0(&x[i * 8 + 6], 6);
+ dump_block_q4_0(&x[i * 8 + 7], 7);
+ }
+ }
+}
+
+static void init_row_q4x4x2(block_q4_0 * x, int64_t k) {
+ static const int qk = QK_Q4_0x4x2;
+ const int nb = (k + qk - 1) / qk; // number of blocks (padded)
+
+ // Init the quants such that they unpack into zeros
+ uint8_t qs[QK_Q4_0x4x2]; // unpacked quants
+ memset(qs, 8, sizeof(qs));
+
+ for (int i = 0; i < nb; i++) {
+ pack_q4_0_quants(&x[i * 8 + 0], qs, 0);
+ pack_q4_0_quants(&x[i * 8 + 1], qs, 1);
+ pack_q4_0_quants(&x[i * 8 + 2], qs, 2);
+ pack_q4_0_quants(&x[i * 8 + 3], qs, 3);
+ pack_q4_0_quants(&x[i * 8 + 4], qs, 4);
+ pack_q4_0_quants(&x[i * 8 + 5], qs, 5);
+ pack_q4_0_quants(&x[i * 8 + 6], qs, 6);
+ pack_q4_0_quants(&x[i * 8 + 7], qs, 7);
+ }
+
+ // Init the scales
+ // Note: Do not combine with the loop above. For tensor sizes not multiple of 256 (QK_Q4_0x4x2)
+    // the last block is truncated and overridden by the scales.
+ for (int i = 0; i < nb; i++) {
+        // Zero the scales
+ x[i * 8 + 0].d = 0;
+ x[i * 8 + 1].d = 0;
+ x[i * 8 + 2].d = 0;
+ x[i * 8 + 3].d = 0;
+ x[i * 8 + 4].d = 0;
+ x[i * 8 + 5].d = 0;
+ x[i * 8 + 6].d = 0;
+ x[i * 8 + 7].d = 0;
+ }
+}
+
+// repack q4_0 data into q4x4x2 tensor
+static void repack_q4_0_q4x4x2(ggml_tensor * t, const void * data, size_t size) {
+ int64_t nrows = ggml_nrows(t);
+
+ size_t row_size = ggml_row_size(t->type, t->ne[0]);
+ size_t row_size_pd = ggml_row_size(t->type, hex_round_up(t->ne[0], QK_Q4_0x4x2)); // extra elements for the pad
+ size_t row_size_rp = row_size * 2; // extra space for tmp pad (if any)
+
+ // Ensure we don't try to read more data than is available in the source buffer 'data'
+ // or write more than the tensor can hold.
+ const size_t total_tensor_size = (size_t)nrows * row_size;
+ const size_t n_bytes_to_copy = size < total_tensor_size ? size : total_tensor_size;
+
+ // Calculate how many full rows and how many remaining bytes we need to process.
+ const int64_t n_full_rows = n_bytes_to_copy / row_size;
+ const size_t n_rem_bytes = n_bytes_to_copy % row_size;
+
+ void * buf_pd = ggml_aligned_malloc(row_size_pd);
+ GGML_ASSERT(buf_pd != NULL);
+
+ void * buf_rp = ggml_aligned_malloc(row_size_rp);
+ GGML_ASSERT(buf_rp != NULL);
+
+ HEX_VERBOSE("ggml-hex: repack-q4_0-q4x4x2 %s : data %p size %zu dims %ldx%ld row-size %zu\n", t->name, data, size,
+ t->ne[0], nrows, row_size);
+
+ init_row_q4x4x2((block_q4_0 *) buf_pd, t->ne[0]); // init padded buffer to make sure the tail is all zeros
+
+ // 1. Process all the full rows
+ for (int64_t i = 0; i < n_full_rows; i++) {
+ const uint8_t * src = (const uint8_t *) data + (i * row_size);
+ uint8_t * dst = (uint8_t *) t->data + (i * row_size);
+
+ memcpy(buf_pd, src, row_size);
+ repack_row_q4x4x2((uint8_t *) buf_rp, (const block_q4_0 *) buf_pd, t->ne[0]);
+ memcpy(dst, buf_rp, row_size);
+ }
+
+ // 2. Process the final, potentially partial, row
+ if (n_rem_bytes > 0) {
+ const int64_t i = n_full_rows;
+ const uint8_t * src = (const uint8_t *) data + (i * row_size);
+ uint8_t * dst = (uint8_t *) t->data + (i * row_size);
+
+ // re-init the row because we are potentially copying a partial row
+ init_row_q4x4x2((block_q4_0 *) buf_pd, t->ne[0]);
+
+ // Copy only the remaining bytes from the source.
+ memcpy(buf_pd, src, n_rem_bytes);
+
+ // Repack the entire buffer
+ repack_row_q4x4x2((uint8_t *) buf_rp, (const block_q4_0 *) buf_pd, t->ne[0]);
+
+ // Write only the corresponding remaining bytes to the destination tensor.
+ memcpy(dst, buf_rp, n_rem_bytes);
+ }
+
+ ggml_aligned_free(buf_pd, row_size_pd);
+ ggml_aligned_free(buf_rp, row_size_rp);
+}
+
+// repack q4x4x2 tensor into q4_0 data
+static void repack_q4x4x2_q4_0(void * data, const ggml_tensor * t, size_t size) {
+ int64_t nrows = ggml_nrows(t);
+
+ size_t row_size = ggml_row_size(t->type, t->ne[0]);
+ size_t row_size_pd = ggml_row_size(t->type, hex_round_up(t->ne[0], QK_Q4_0x4x2)); // extra elements for the pad
+ size_t row_size_rp = row_size * 2; // extra space for tmp pad (if any)
+
+ // Ensure we don't try to copy more data than the tensor actually contains.
+ const size_t total_tensor_size = (size_t)nrows * row_size;
+ const size_t n_bytes_to_copy = size < total_tensor_size ? size : total_tensor_size;
+
+ // Calculate how many full rows and how many remaining bytes we need to process.
+ const int64_t n_full_rows = n_bytes_to_copy / row_size;
+ const size_t n_rem_bytes = n_bytes_to_copy % row_size;
+
+ void * buf_pd = ggml_aligned_malloc(row_size_pd);
+ GGML_ASSERT(buf_pd != NULL);
+
+ void * buf_rp = ggml_aligned_malloc(row_size_rp);
+ GGML_ASSERT(buf_rp != NULL);
+
+ HEX_VERBOSE("ggml-hex: repack-q4x4x2-q4_0 %s : data %p size %zu dims %ldx%ld row-size %zu\n", t->name, data, size,
+ t->ne[0], nrows, row_size);
+
+ memset(buf_pd, 0, row_size_pd); // clear-out padded buffer to make sure the tail is all zeros
+
+ // 1. Process all the full rows
+ for (int64_t i = 0; i < n_full_rows; i++) {
+ const uint8_t * src = (const uint8_t *) t->data + (i * row_size);
+ uint8_t * dst = (uint8_t *) data + (i * row_size);
+
+ memcpy(buf_pd, src, row_size);
+ unpack_row_q4x4x2((block_q4_0 *) buf_rp, (const uint8_t *) buf_pd, t->ne[0]);
+ memcpy(dst, buf_rp, row_size);
+ }
+
+ // 2. Process the final, potentially partial, row
+ if (n_rem_bytes > 0) {
+ const int64_t i = n_full_rows;
+ const uint8_t * src = (const uint8_t *) t->data + (i * row_size);
+ uint8_t * dst = (uint8_t *) data + (i * row_size);
+
+ // We still need to read and unpack the entire source row because quantization is block-based.
+ memcpy(buf_pd, src, row_size);
+ unpack_row_q4x4x2((block_q4_0 *) buf_rp, (const uint8_t *) buf_pd, t->ne[0]);
+
+ // But we only copy the remaining number of bytes to the destination.
+ memcpy(dst, buf_rp, n_rem_bytes);
+ }
+
+ ggml_aligned_free(buf_pd, row_size_pd);
+ ggml_aligned_free(buf_rp, row_size_rp);
+}
+
+// ======== Q8x4x2 ====================
+static void dump_block_q8_0(const block_q8_0 * b, int i) {
+ HEX_VERBOSE("ggml-hex: repack q8_0 %d: %d %d %d %d ... %d %d %d %d : %.6f\n", i, b->qs[0], b->qs[1], b->qs[2],
+ b->qs[3], b->qs[28], b->qs[29], b->qs[30], b->qs[31], GGML_FP16_TO_FP32(b->d));
+}
+
+static void dump_packed_block_q8x4x2(const uint8_t * v, unsigned int i, size_t k) {
+ static const int qk = QK_Q8_0x4x2;
+ const int dblk_size = 8 * 2; // 8x __fp16
+ const int qblk_size = qk; // int8
+ const int qrow_size = k; // int8 (not padded)
+
+ const uint8_t * v_q = v + 0; // quants first
+ const uint8_t * v_d = v + qrow_size; // then scales
+
+ const uint8_t * q = v_q + i * qblk_size;
+ const ggml_half * d = (const ggml_half *) (v_d + i * dblk_size);
+
+ HEX_VERBOSE("ggml-hex: repack q8x4x2-%d: %d %d %d %d ... %d %d %d %d ... %d %d %d %d : %.6f %.6f %.6f %.6f\n", i,
+ q[0], q[1], q[2], q[3], q[60], q[61], q[62], q[63], q[124], q[125], q[126], q[127],
+ GGML_FP16_TO_FP32(d[0]), GGML_FP16_TO_FP32(d[1]), GGML_FP16_TO_FP32(d[2]), GGML_FP16_TO_FP32(d[3]));
+
+ HEX_VERBOSE("ggml-hex: repack q8x4x2-%d: %d %d %d %d ... %d %d %d %d ... %d %d %d %d : %.6f %.6f %.6f %.6f\n",
+ i + 1, q[128], q[129], q[130], q[131], q[192], q[193], q[194], q[195], q[252], q[253], q[254], q[255],
+ GGML_FP16_TO_FP32(d[4]), GGML_FP16_TO_FP32(d[5]), GGML_FP16_TO_FP32(d[6]), GGML_FP16_TO_FP32(d[7]));
+}
+
+static void unpack_q8_0_quants(uint8_t * qs, const block_q8_0 * x, unsigned int bi) {
+ static const int qk = QK8_0;
+
+ for (unsigned int i = 0; i < qk; ++i) {
+ qs[bi * qk + i] = x->qs[i];
+ }
+}
+
+static void pack_q8_0_quants(block_q8_0 * x, const uint8_t * qs, unsigned int bi) {
+ static const int qk = QK8_0;
+
+ for (unsigned int i = 0; i < qk; ++i) {
+ x->qs[i] = qs[bi * qk + i];
+ }
+}
+
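+// Packed q8x4x2 row layout: all int8 quants for the row come first (k bytes), followed
+// by one group of 8 fp16 scales per 256-element super-block (8 consecutive q8_0 blocks).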
+static void repack_row_q8x4x2(uint8_t * y, const block_q8_0 * x, int64_t k) {
+ static const int qk = QK_Q8_0x4x2;
+ const int nb = (k + qk - 1) / qk; // number of blocks (padded)
+
+ const int dblk_size = 8 * 2; // 8x __fp16
+ const int qblk_size = qk; // int8
+ const int qrow_size = k; // int8 (not padded to blocks)
+
+ uint8_t * y_q = y + 0; // quants first
+ uint8_t * y_d = y + qrow_size; // then scales
+
+ if (opt_verbose > 2) {
+ for (int i = 0; i < nb; i++) {
+ dump_block_q8_0(&x[i * 8 + 0], 0);
+ dump_block_q8_0(&x[i * 8 + 1], 1);
+ dump_block_q8_0(&x[i * 8 + 2], 2);
+ dump_block_q8_0(&x[i * 8 + 3], 3);
+ dump_block_q8_0(&x[i * 8 + 4], 4);
+ dump_block_q8_0(&x[i * 8 + 5], 5);
+ dump_block_q8_0(&x[i * 8 + 6], 6);
+ dump_block_q8_0(&x[i * 8 + 7], 7);
+ }
+ }
+
+ // Repack the quants
+ for (int i = 0; i < nb; i++) {
+ uint8_t qs[QK_Q8_0x4x2]; // unpacked quants
+
+ unpack_q8_0_quants(qs, &x[i * 8 + 0], 0);
+ unpack_q8_0_quants(qs, &x[i * 8 + 1], 1);
+ unpack_q8_0_quants(qs, &x[i * 8 + 2], 2);
+ unpack_q8_0_quants(qs, &x[i * 8 + 3], 3);
+ unpack_q8_0_quants(qs, &x[i * 8 + 4], 4);
+ unpack_q8_0_quants(qs, &x[i * 8 + 5], 5);
+ unpack_q8_0_quants(qs, &x[i * 8 + 6], 6);
+ unpack_q8_0_quants(qs, &x[i * 8 + 7], 7);
+
+ uint8_t * q = y_q + (i * qblk_size);
+ for (int j = 0; j < qk; j++) {
+ q[j] = qs[j];
+ }
+ }
+
+ // Repack the scales
+    // Note: Do not combine with the loop above. For tensor sizes not multiple of 256 (QK_Q8_0x4x2)
+    // the last block is truncated and overridden by the scales.
+ for (int i = 0; i < nb; i++) {
+ // Repack the scales
+ ggml_half * d = (ggml_half *) (y_d + i * dblk_size);
+ d[0] = x[i * 8 + 0].d;
+ d[1] = x[i * 8 + 1].d;
+ d[2] = x[i * 8 + 2].d;
+ d[3] = x[i * 8 + 3].d;
+ d[4] = x[i * 8 + 4].d;
+ d[5] = x[i * 8 + 5].d;
+ d[6] = x[i * 8 + 6].d;
+ d[7] = x[i * 8 + 7].d;
+ }
+
+ if (opt_verbose > 1) {
+ for (int i = 0; i < nb; i++) {
+ dump_packed_block_q8x4x2(y, i, k);
+ }
+ }
+}
+
+static void unpack_row_q8x4x2(block_q8_0 * x, const uint8_t * y, int64_t k) {
+ static const int qk = QK_Q8_0x4x2;
+ const int nb = (k + qk - 1) / qk; // number of blocks (padded)
+
+ const int dblk_size = 8 * 2; // 8x __fp16
+ const int qblk_size = qk; // int8
+ const int qrow_size = k; // int8 (not padded to blocks)
+
+ const uint8_t * y_q = y + 0; // quants first
+ const uint8_t * y_d = y + qrow_size; // then scales
+
+ if (opt_verbose > 1) {
+ for (int i = 0; i < nb; i++) {
+ dump_packed_block_q8x4x2(y, i, k);
+ }
+ }
+
+ // Unpack the quants
+ for (int i = 0; i < nb; i++) {
+        uint8_t qs[QK_Q8_0x4x2]; // unpacked quants
+
+ const uint8_t * q = y_q + (i * qblk_size);
+ for (int j = 0; j < qk; j++) {
+ qs[j] = q[j];
+ }
+
+ pack_q8_0_quants(&x[i * 8 + 0], qs, 0);
+ pack_q8_0_quants(&x[i * 8 + 1], qs, 1);
+ pack_q8_0_quants(&x[i * 8 + 2], qs, 2);
+ pack_q8_0_quants(&x[i * 8 + 3], qs, 3);
+ pack_q8_0_quants(&x[i * 8 + 4], qs, 4);
+ pack_q8_0_quants(&x[i * 8 + 5], qs, 5);
+ pack_q8_0_quants(&x[i * 8 + 6], qs, 6);
+ pack_q8_0_quants(&x[i * 8 + 7], qs, 7);
+ }
+
+    // Unpack the scales
+    // Note: Do not combine with the loop above. For tensor sizes not multiple of 256 (QK_Q8_0x4x2)
+    // the last block is truncated and overridden by the scales.
+ for (int i = 0; i < nb; i++) {
+ // Unpack the scales
+ const ggml_half * d = (const ggml_half *) (y_d + i * dblk_size);
+ x[i * 8 + 0].d = d[0];
+ x[i * 8 + 1].d = d[1];
+ x[i * 8 + 2].d = d[2];
+ x[i * 8 + 3].d = d[3];
+ x[i * 8 + 4].d = d[4];
+ x[i * 8 + 5].d = d[5];
+ x[i * 8 + 6].d = d[6];
+ x[i * 8 + 7].d = d[7];
+ }
+
+ if (opt_verbose > 2) {
+ for (int i = 0; i < nb; i++) {
+ dump_block_q8_0(&x[i * 8 + 0], 0);
+ dump_block_q8_0(&x[i * 8 + 1], 1);
+ dump_block_q8_0(&x[i * 8 + 2], 2);
+ dump_block_q8_0(&x[i * 8 + 3], 3);
+ dump_block_q8_0(&x[i * 8 + 4], 4);
+ dump_block_q8_0(&x[i * 8 + 5], 5);
+ dump_block_q8_0(&x[i * 8 + 6], 6);
+ dump_block_q8_0(&x[i * 8 + 7], 7);
+ }
+ }
+}
+
+static void init_row_q8x4x2(block_q8_0 * x, int64_t k) {
+ static const int qk = QK_Q8_0x4x2;
+ const int nb = (k + qk - 1) / qk; // number of blocks (padded)
+
+ // Init the quants such that they unpack into zeros
+ uint8_t qs[QK_Q8_0x4x2]; // unpacked quants
+ memset(qs, 0, sizeof(qs));
+
+ for (int i = 0; i < nb; i++) {
+ pack_q8_0_quants(&x[i * 8 + 0], qs, 0);
+ pack_q8_0_quants(&x[i * 8 + 1], qs, 1);
+ pack_q8_0_quants(&x[i * 8 + 2], qs, 2);
+ pack_q8_0_quants(&x[i * 8 + 3], qs, 3);
+ pack_q8_0_quants(&x[i * 8 + 4], qs, 4);
+ pack_q8_0_quants(&x[i * 8 + 5], qs, 5);
+ pack_q8_0_quants(&x[i * 8 + 6], qs, 6);
+ pack_q8_0_quants(&x[i * 8 + 7], qs, 7);
+ }
+
+ // Init the scales
+ // Note: Do not combine with the loop above. For tensor sizes not multiple of 256 (QK_Q8_0x4x2)
+    // the last block is truncated and overridden by the scales.
+ for (int i = 0; i < nb; i++) {
+        // Zero the scales
+ x[i * 8 + 0].d = 0;
+ x[i * 8 + 1].d = 0;
+ x[i * 8 + 2].d = 0;
+ x[i * 8 + 3].d = 0;
+ x[i * 8 + 4].d = 0;
+ x[i * 8 + 5].d = 0;
+ x[i * 8 + 6].d = 0;
+ x[i * 8 + 7].d = 0;
+ }
+}
+
+// repack q8_0 data into q8x4x2 tensor
+static void repack_q8_0_q8x4x2(ggml_tensor * t, const void * data, size_t size) {
+ int64_t nrows = ggml_nrows(t);
+
+ size_t row_size = ggml_row_size(t->type, t->ne[0]);
+ size_t row_size_pd = ggml_row_size(t->type, hex_round_up(t->ne[0], QK_Q8_0x4x2)); // extra elements for the pad
+ size_t row_size_rp = row_size * 2; // extra space for tmp pad (if any)
+
+ // Ensure we don't try to read more data than is available in the source buffer 'data'
+ // or write more than the tensor can hold.
+ const size_t total_tensor_size = (size_t)nrows * row_size;
+ const size_t n_bytes_to_copy = size < total_tensor_size ? size : total_tensor_size;
+
+ // Calculate how many full rows and how many remaining bytes we need to process.
+ const int64_t n_full_rows = n_bytes_to_copy / row_size;
+ const size_t n_rem_bytes = n_bytes_to_copy % row_size;
+
+ void * buf_pd = ggml_aligned_malloc(row_size_pd);
+ GGML_ASSERT(buf_pd != NULL);
+
+ void * buf_rp = ggml_aligned_malloc(row_size_rp);
+ GGML_ASSERT(buf_rp != NULL);
+
+ HEX_VERBOSE("ggml-hex: repack-q8_0-q8x4x2 %s : data %p size %zu dims %ldx%ld row-size %zu\n", t->name, data, size,
+ t->ne[0], nrows, row_size);
+
+ init_row_q8x4x2((block_q8_0 *) buf_pd, t->ne[0]); // init padded buffer to make sure the tail is all zeros
+
+ // 1. Process all the full rows
+ for (int64_t i = 0; i < n_full_rows; i++) {
+ const uint8_t * src = (const uint8_t *) data + (i * row_size);
+ uint8_t * dst = (uint8_t *) t->data + (i * row_size);
+
+ memcpy(buf_pd, src, row_size);
+ repack_row_q8x4x2((uint8_t *) buf_rp, (const block_q8_0 *) buf_pd, t->ne[0]);
+ memcpy(dst, buf_rp, row_size);
+ }
+
+ // 2. Process the final, potentially partial, row
+ if (n_rem_bytes > 0) {
+ const int64_t i = n_full_rows;
+ const uint8_t * src = (const uint8_t *) data + (i * row_size);
+ uint8_t * dst = (uint8_t *) t->data + (i * row_size);
+
+ // re-init the row because we are potentially copying a partial row
+ init_row_q8x4x2((block_q8_0 *) buf_pd, t->ne[0]);
+
+ // Copy only the remaining bytes from the source.
+ memcpy(buf_pd, src, n_rem_bytes);
+
+ // Repack the entire buffer
+ repack_row_q8x4x2((uint8_t *) buf_rp, (const block_q8_0 *) buf_pd, t->ne[0]);
+
+ // Write only the corresponding remaining bytes to the destination tensor.
+ memcpy(dst, buf_rp, n_rem_bytes);
+ }
+
+ ggml_aligned_free(buf_pd, row_size_pd);
+ ggml_aligned_free(buf_rp, row_size_rp);
+}
+
+// repack q8x4x2 tensor into q8_0 data
+static void repack_q8x4x2_q8_0(void * data, const ggml_tensor * t, size_t size) {
+ int64_t nrows = ggml_nrows(t);
+
+ size_t row_size = ggml_row_size(t->type, t->ne[0]);
+ size_t row_size_pd = ggml_row_size(t->type, hex_round_up(t->ne[0], QK_Q8_0x4x2)); // extra elements for the pad
+ size_t row_size_rp = row_size * 2; // extra space for tmp pad (if any)
+
+ // Ensure we don't try to copy more data than the tensor actually contains.
+ const size_t total_tensor_size = (size_t)nrows * row_size;
+ const size_t n_bytes_to_copy = size < total_tensor_size ? size : total_tensor_size;
+
+ // Calculate how many full rows and how many remaining bytes we need to process.
+ const int64_t n_full_rows = n_bytes_to_copy / row_size;
+ const size_t n_rem_bytes = n_bytes_to_copy % row_size;
+
+ void * buf_pd = ggml_aligned_malloc(row_size_pd);
+ GGML_ASSERT(buf_pd != NULL);
+
+ void * buf_rp = ggml_aligned_malloc(row_size_rp);
+ GGML_ASSERT(buf_rp != NULL);
+
+ HEX_VERBOSE("ggml-hex: repack-q8x4x2-q8_0 %s : data %p size %zu dims %ldx%ld row-size %zu\n", t->name, data, size,
+ t->ne[0], nrows, row_size);
+
+ memset(buf_pd, 0, row_size_pd); // clear-out padded buffer to make sure the tail is all zeros
+
+ // 1. Process all the full rows
+ for (int64_t i = 0; i < n_full_rows; i++) {
+ const uint8_t * src = (const uint8_t *) t->data + (i * row_size);
+ uint8_t * dst = (uint8_t *) data + (i * row_size);
+
+ memcpy(buf_pd, src, row_size);
+ unpack_row_q8x4x2((block_q8_0 *) buf_rp, (const uint8_t *) buf_pd, t->ne[0]);
+ memcpy(dst, buf_rp, row_size);
+ }
+
+ // 2. Process the final, potentially partial, row
+ if (n_rem_bytes > 0) {
+ const int64_t i = n_full_rows;
+ const uint8_t * src = (const uint8_t *) t->data + (i * row_size);
+ uint8_t * dst = (uint8_t *) data + (i * row_size);
+
+ // We still need to read and unpack the entire source row because quantization is block-based.
+ memcpy(buf_pd, src, row_size);
+ unpack_row_q8x4x2((block_q8_0 *) buf_rp, (const uint8_t *) buf_pd, t->ne[0]);
+
+ // But we only copy the remaining number of bytes to the destination.
+ memcpy(dst, buf_rp, n_rem_bytes);
+ }
+
+ ggml_aligned_free(buf_pd, row_size_pd);
+ ggml_aligned_free(buf_rp, row_size_rp);
+}
+
+// ======== MXFP4x4x2 ====================
+struct x2_mxfp4 {
+ int v[2];
+};
+
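+// look up the two 4-bit MXFP4 codes packed in one byte via the kvalues_mxfp4 table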
+static x2_mxfp4 unpack_mxfp4(uint8_t v) {
+ x2_mxfp4 x;
+ x.v[0] = kvalues_mxfp4[(v & 0x0f)];
+ x.v[1] = kvalues_mxfp4[(v >> 4)];
+ return x;
+}
+
+static void dump_block_mxfp4(const block_mxfp4 * b, int i) {
+ HEX_VERBOSE("ggml-hex: repack mxfp4 %d: %d %d %d %d ... %d %d %d %d : %.6f\n", i, unpack_mxfp4(b->qs[0]).v[0],
+ unpack_mxfp4(b->qs[1]).v[0], unpack_mxfp4(b->qs[2]).v[0], unpack_mxfp4(b->qs[3]).v[0],
+ unpack_mxfp4(b->qs[12]).v[1], unpack_mxfp4(b->qs[13]).v[1], unpack_mxfp4(b->qs[14]).v[1],
+ unpack_mxfp4(b->qs[15]).v[1], GGML_E8M0_TO_FP32_HALF(b->e));
+}
+
+static void dump_packed_block_mxfp4x4x2(const uint8_t * v, unsigned int i, size_t k) {
+ static const int qk = QK_MXFP4x4x2;
+ const int eblk_size = 8 * 1; // 8x E8M0
+ const int qblk_size = qk / 2; // int4
+ const int qrow_size = k / 2; // int4 (not padded)
+
+ const uint8_t * v_q = v + 0; // quants first
+ const uint8_t * v_e = v + qrow_size; // then scales
+
+ const uint8_t * q = v_q + i * qblk_size;
+ const uint8_t * e = (const uint8_t *) (v_e + i * eblk_size);
+
+ HEX_VERBOSE("ggml-hex: repack mxfp4x4x2-%d: %d %d %d %d ... %d %d %d %d ... %d %d %d %d : %.6f %.6f %.6f %.6f\n", i,
+ unpack_mxfp4(q[0]).v[0], unpack_mxfp4(q[1]).v[0], unpack_mxfp4(q[2]).v[0], unpack_mxfp4(q[3]).v[0],
+ unpack_mxfp4(q[60]).v[0], unpack_mxfp4(q[61]).v[0], unpack_mxfp4(q[62]).v[0], unpack_mxfp4(q[63]).v[0],
+ unpack_mxfp4(q[124]).v[0], unpack_mxfp4(q[125]).v[0], unpack_mxfp4(q[126]).v[0],
+ unpack_mxfp4(q[127]).v[0], GGML_E8M0_TO_FP32_HALF(e[0]), GGML_E8M0_TO_FP32_HALF(e[1]),
+ GGML_E8M0_TO_FP32_HALF(e[2]), GGML_E8M0_TO_FP32_HALF(e[3]));
+
+ HEX_VERBOSE("ggml-hex: repack mxfp4x4x2-%d: %d %d %d %d ... %d %d %d %d ... %d %d %d %d : %.6f %.6f %.6f %.6f\n",
+ i + 1, unpack_mxfp4(q[0]).v[1], unpack_mxfp4(q[1]).v[1], unpack_mxfp4(q[2]).v[1],
+ unpack_mxfp4(q[3]).v[1], unpack_mxfp4(q[60]).v[1], unpack_mxfp4(q[61]).v[1], unpack_mxfp4(q[62]).v[1],
+ unpack_mxfp4(q[63]).v[1], unpack_mxfp4(q[124]).v[1], unpack_mxfp4(q[125]).v[1],
+ unpack_mxfp4(q[126]).v[1], unpack_mxfp4(q[127]).v[1], GGML_E8M0_TO_FP32_HALF(e[4]),
+ GGML_E8M0_TO_FP32_HALF(e[5]), GGML_E8M0_TO_FP32_HALF(e[6]), GGML_E8M0_TO_FP32_HALF(e[7]));
+}
+
+static void unpack_mxfp4_quants(uint8_t * qs, const block_mxfp4 * x, unsigned int bi) {
+ static const int qk = QK_MXFP4;
+
+ for (unsigned int i = 0; i < qk / 2; ++i) {
+ const uint8_t x0 = (x->qs[i] & 0x0F);
+ const uint8_t x1 = (x->qs[i] >> 4);
+ qs[bi * qk + i + 0] = x0;
+ qs[bi * qk + i + qk / 2] = x1;
+ }
+}
+
+static void pack_mxfp4_quants(block_mxfp4 * x, const uint8_t * qs, unsigned int bi) {
+    static const int qk = QK_MXFP4;
+
+ for (unsigned int i = 0; i < qk / 2; ++i) {
+ const uint8_t x0 = qs[bi * qk + i + 0];
+ const uint8_t x1 = qs[bi * qk + i + qk / 2];
+ x->qs[i] = x0 | (x1 << 4);
+ }
+}
+
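+// Packed mxfp4x4x2 row layout: all quant nibbles for the row come first (k/2 bytes),
+// followed by one group of 8 E8M0 scales per 256-element super-block (8 mxfp4 blocks).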
+static void repack_row_mxfp4x4x2(uint8_t * y, const block_mxfp4 * x, int64_t k) {
+ static const int qk = QK_MXFP4x4x2;
+ const int nb = (k + qk - 1) / qk; // number of blocks (padded)
+
+ const int eblk_size = 8 * 1; // 8x E8M0
+ const int qblk_size = qk / 2; // int4
+ const int qrow_size = k / 2; // int4 (not padded to blocks)
+
+ uint8_t * y_q = y + 0; // quants first
+ uint8_t * y_e = y + qrow_size; // then scales
+
+ if (opt_verbose > 2) {
+ for (int i = 0; i < nb; i++) {
+ dump_block_mxfp4(&x[i * 8 + 0], 0);
+ dump_block_mxfp4(&x[i * 8 + 1], 1);
+ dump_block_mxfp4(&x[i * 8 + 2], 2);
+ dump_block_mxfp4(&x[i * 8 + 3], 3);
+ dump_block_mxfp4(&x[i * 8 + 4], 4);
+ dump_block_mxfp4(&x[i * 8 + 5], 5);
+ dump_block_mxfp4(&x[i * 8 + 6], 6);
+ dump_block_mxfp4(&x[i * 8 + 7], 7);
+ }
+ }
+
+ // Repack the quants
+ for (int i = 0; i < nb; i++) {
+ uint8_t qs[QK_MXFP4x4x2]; // unpacked quants
+
+ unpack_mxfp4_quants(qs, &x[i * 8 + 0], 0);
+ unpack_mxfp4_quants(qs, &x[i * 8 + 1], 1);
+ unpack_mxfp4_quants(qs, &x[i * 8 + 2], 2);
+ unpack_mxfp4_quants(qs, &x[i * 8 + 3], 3);
+ unpack_mxfp4_quants(qs, &x[i * 8 + 4], 4);
+ unpack_mxfp4_quants(qs, &x[i * 8 + 5], 5);
+ unpack_mxfp4_quants(qs, &x[i * 8 + 6], 6);
+ unpack_mxfp4_quants(qs, &x[i * 8 + 7], 7);
+
+ uint8_t * q = y_q + (i * qblk_size);
+ for (int j = 0; j < qk / 2; j++) {
+ q[j] = (qs[j + 128] << 4) | qs[j];
+ }
+ }
+
+ // Repack the scales
+ // Note: Do not combine with the loop above. For tensor sizes not multiple of 256 (QK_MXFP4x4x2)
+    // the last block is truncated and overridden by the scales.
+ for (int i = 0; i < nb; i++) {
+ // Repack the scales
+ uint8_t * e = (uint8_t *) (y_e + i * eblk_size);
+ e[0] = x[i * 8 + 0].e;
+ e[1] = x[i * 8 + 1].e;
+ e[2] = x[i * 8 + 2].e;
+ e[3] = x[i * 8 + 3].e;
+ e[4] = x[i * 8 + 4].e;
+ e[5] = x[i * 8 + 5].e;
+ e[6] = x[i * 8 + 6].e;
+ e[7] = x[i * 8 + 7].e;
+ }
+
+ if (opt_verbose > 1) {
+ for (int i = 0; i < nb; i++) {
+ dump_packed_block_mxfp4x4x2(y, i, k);
+ }
+ }
+}
+
+static void unpack_row_mxfp4x4x2(block_mxfp4 * x, const uint8_t * y, int64_t k) {
+ static const int qk = QK_MXFP4x4x2;
+ const int nb = (k + qk - 1) / qk; // number of blocks (padded)
+
+ const int eblk_size = 8 * 1; // 8x E8M0
+ const int qblk_size = qk / 2; // int4
+ const int qrow_size = k / 2; // int4 (not padded to blocks)
+
+ const uint8_t * y_q = y + 0; // quants first
+ const uint8_t * y_e = y + qrow_size; // then scales
+
+ if (opt_verbose > 1) {
+ for (int i = 0; i < nb; i++) {
+ dump_packed_block_mxfp4x4x2(y, i, k);
+ }
+ }
+
+ // Unpack the quants
+ for (int i = 0; i < nb; i++) {
+ uint8_t qs[QK_MXFP4x4x2]; // unpacked quants
+
+ const uint8_t * q = y_q + (i * qblk_size);
+ for (int j = 0; j < qk / 2; j++) {
+ qs[j] = q[j] & 0xf;
+ qs[j + 128] = q[j] >> 4;
+ }
+
+ pack_mxfp4_quants(&x[i * 8 + 0], qs, 0);
+ pack_mxfp4_quants(&x[i * 8 + 1], qs, 1);
+ pack_mxfp4_quants(&x[i * 8 + 2], qs, 2);
+ pack_mxfp4_quants(&x[i * 8 + 3], qs, 3);
+ pack_mxfp4_quants(&x[i * 8 + 4], qs, 4);
+ pack_mxfp4_quants(&x[i * 8 + 5], qs, 5);
+ pack_mxfp4_quants(&x[i * 8 + 6], qs, 6);
+ pack_mxfp4_quants(&x[i * 8 + 7], qs, 7);
+ }
+
+    // Unpack the scales
+    // Note: Do not combine with the loop above. For tensor sizes not multiple of 256 (QK_MXFP4x4x2)
+    // the last block is truncated and overridden by the scales.
+ for (int i = 0; i < nb; i++) {
+ // Unpack the scales
+ const uint8_t * e = (const uint8_t *) (y_e + i * eblk_size);
+ x[i * 8 + 0].e = e[0];
+ x[i * 8 + 1].e = e[1];
+ x[i * 8 + 2].e = e[2];
+ x[i * 8 + 3].e = e[3];
+ x[i * 8 + 4].e = e[4];
+ x[i * 8 + 5].e = e[5];
+ x[i * 8 + 6].e = e[6];
+ x[i * 8 + 7].e = e[7];
+ }
+
+ if (opt_verbose > 2) {
+ for (int i = 0; i < nb; i++) {
+ dump_block_mxfp4(&x[i * 8 + 0], 0);
+ dump_block_mxfp4(&x[i * 8 + 1], 1);
+ dump_block_mxfp4(&x[i * 8 + 2], 2);
+ dump_block_mxfp4(&x[i * 8 + 3], 3);
+ dump_block_mxfp4(&x[i * 8 + 4], 4);
+ dump_block_mxfp4(&x[i * 8 + 5], 5);
+ dump_block_mxfp4(&x[i * 8 + 6], 6);
+ dump_block_mxfp4(&x[i * 8 + 7], 7);
+ }
+ }
+}
+
+static void init_row_mxfp4x4x2(block_mxfp4 * x, int64_t k) {
+ static const int qk = QK_MXFP4x4x2;
+ const int nb = (k + qk - 1) / qk; // number of blocks (padded)
+
+ // Init the quants such that they unpack into zeros
+ uint8_t qs[QK_MXFP4x4x2]; // unpacked quants
+ memset(qs, 0, sizeof(qs));
+
+ for (int i = 0; i < nb; i++) {
+ pack_mxfp4_quants(&x[i * 8 + 0], qs, 0);
+ pack_mxfp4_quants(&x[i * 8 + 1], qs, 1);
+ pack_mxfp4_quants(&x[i * 8 + 2], qs, 2);
+ pack_mxfp4_quants(&x[i * 8 + 3], qs, 3);
+ pack_mxfp4_quants(&x[i * 8 + 4], qs, 4);
+ pack_mxfp4_quants(&x[i * 8 + 5], qs, 5);
+ pack_mxfp4_quants(&x[i * 8 + 6], qs, 6);
+ pack_mxfp4_quants(&x[i * 8 + 7], qs, 7);
+ }
+
+ // Init the scales
+ // Note: Do not combine with the loop above. For tensor sizes not multiple of 256 (QK_MXFP4x4x2)
+    // the last block is truncated and overridden by the scales.
+ for (int i = 0; i < nb; i++) {
+        // Zero the scales
+ x[i * 8 + 0].e = 0;
+ x[i * 8 + 1].e = 0;
+ x[i * 8 + 2].e = 0;
+ x[i * 8 + 3].e = 0;
+ x[i * 8 + 4].e = 0;
+ x[i * 8 + 5].e = 0;
+ x[i * 8 + 6].e = 0;
+ x[i * 8 + 7].e = 0;
+ }
+}
+
+// repack mxfp4 data into mxfp4x4x2 tensor
+static void repack_mxfp4_mxfp4x4x2(ggml_tensor * t, const void * data, size_t size) {
+ int64_t nrows = ggml_nrows(t);
+
+ size_t row_size = ggml_row_size(t->type, t->ne[0]);
+ size_t row_size_pd = ggml_row_size(t->type, hex_round_up(t->ne[0], QK_MXFP4x4x2)); // extra elements for the pad
+ size_t row_size_rp = row_size * 2; // extra space for tmp pad (if any)
+
+ // Ensure we don't try to read more data than is available in the source buffer 'data'
+ // or write more than the tensor can hold.
+ const size_t total_tensor_size = (size_t)nrows * row_size;
+ const size_t n_bytes_to_copy = size < total_tensor_size ? size : total_tensor_size;
+
+ // Calculate how many full rows and how many remaining bytes we need to process.
+ const int64_t n_full_rows = n_bytes_to_copy / row_size;
+ const size_t n_rem_bytes = n_bytes_to_copy % row_size;
+
+ void * buf_pd = ggml_aligned_malloc(row_size_pd);
+ GGML_ASSERT(buf_pd != NULL);
+
+ void * buf_rp = ggml_aligned_malloc(row_size_rp);
+ GGML_ASSERT(buf_rp != NULL);
+
+ HEX_VERBOSE("ggml-hex: repack-mxfp4-mxfp4x4x2 %s : data %p size %zu dims %ldx%ld row-size %zu\n", t->name, data,
+ size, t->ne[0], nrows, row_size);
+
+ init_row_mxfp4x4x2((block_mxfp4 *) buf_pd, t->ne[0]); // init padded buffer to make sure the tail is all zeros
+
+ // 1. Process all the full rows
+ for (int64_t i = 0; i < n_full_rows; i++) {
+ const uint8_t * src = (const uint8_t *) data + (i * row_size);
+ uint8_t * dst = (uint8_t *) t->data + (i * row_size);
+
+ memcpy(buf_pd, src, row_size);
+ repack_row_mxfp4x4x2((uint8_t *) buf_rp, (const block_mxfp4 *) buf_pd, t->ne[0]);
+ memcpy(dst, buf_rp, row_size);
+ }
+
+ // 2. Process the final, potentially partial, row
+ if (n_rem_bytes > 0) {
+ const int64_t i = n_full_rows;
+ const uint8_t * src = (const uint8_t *) data + (i * row_size);
+ uint8_t * dst = (uint8_t *) t->data + (i * row_size);
+
+ // re-init the row because we are potentially copying a partial row
+ init_row_mxfp4x4x2((block_mxfp4 *) buf_pd, t->ne[0]);
+
+ // Copy only the remaining bytes from the source.
+ memcpy(buf_pd, src, n_rem_bytes);
+
+ // Repack the entire buffer (partial data + zero padding).
+ repack_row_mxfp4x4x2((uint8_t *) buf_rp, (const block_mxfp4 *) buf_pd, t->ne[0]);
+
+ // Write only the corresponding remaining bytes to the destination tensor.
+ memcpy(dst, buf_rp, n_rem_bytes);
+ }
+
+ ggml_aligned_free(buf_pd, row_size_pd);
+ ggml_aligned_free(buf_rp, row_size_rp);
+}
+
+// repack mxfp4x4x2 tensor into mxfp4 data
+static void repack_mxfp4x4x2_mxfp4(void * data, const ggml_tensor * t, size_t size) {
+ int64_t nrows = ggml_nrows(t);
+
+ size_t row_size = ggml_row_size(t->type, t->ne[0]);
+ size_t row_size_pd = ggml_row_size(t->type, hex_round_up(t->ne[0], QK_MXFP4x4x2)); // extra elements for the pad
+ size_t row_size_rp = row_size * 2; // extra space for tmp pad (if any)
+
+ // Ensure we don't try to copy more data than the tensor actually contains.
+ const size_t total_tensor_size = (size_t)nrows * row_size;
+ const size_t n_bytes_to_copy = size < total_tensor_size ? size : total_tensor_size;
+
+ // Calculate how many full rows and how many remaining bytes we need to process.
+ const int64_t n_full_rows = n_bytes_to_copy / row_size;
+ const size_t n_rem_bytes = n_bytes_to_copy % row_size;
+
+ void * buf_pd = ggml_aligned_malloc(row_size_pd);
+ GGML_ASSERT(buf_pd != NULL);
+
+ void * buf_rp = ggml_aligned_malloc(row_size_rp);
+ GGML_ASSERT(buf_rp != NULL);
+
+ HEX_VERBOSE("ggml-hex: repack-mxfp4x4x2-mxfp4 %s : data %p size %zu dims %ldx%ld row-size %zu\n", t->name, data,
+ size, t->ne[0], nrows, row_size);
+
+ memset(buf_pd, 0, row_size_pd); // clear-out padded buffer to make sure the tail is all zeros
+
+ // 1. Process all the full rows
+ for (int64_t i = 0; i < n_full_rows; i++) {
+ const uint8_t * src = (const uint8_t *) t->data + (i * row_size);
+ uint8_t * dst = (uint8_t *) data + (i * row_size);
+
+ memcpy(buf_pd, src, row_size);
+ unpack_row_mxfp4x4x2((block_mxfp4 *) buf_rp, (const uint8_t *) buf_pd, t->ne[0]);
+ memcpy(dst, buf_rp, row_size);
+ }
+
+ // 2. Process the final, potentially partial, row
+ if (n_rem_bytes > 0) {
+ const int64_t i = n_full_rows;
+ const uint8_t * src = (const uint8_t *) t->data + (i * row_size);
+ uint8_t * dst = (uint8_t *) data + (i * row_size);
+
+ // We still need to read and unpack the entire source row because the format is block-based.
+ memcpy(buf_pd, src, row_size);
+ unpack_row_mxfp4x4x2((block_mxfp4 *) buf_rp, (const uint8_t *) buf_pd, t->ne[0]);
+
+ // But we only copy the remaining number of bytes to the destination to respect the size limit.
+ memcpy(dst, buf_rp, n_rem_bytes);
+ }
+
+ ggml_aligned_free(buf_pd, row_size_pd);
+ ggml_aligned_free(buf_rp, row_size_rp);
+}
+
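+// set/get-tensor hooks: Q4_0, Q8_0 and MXFP4 data is repacked into the HTP layouts above
+// (whole-tensor transfers only, offset must be 0); all other types are a plain memcpy.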
+static void ggml_backend_hexagon_buffer_set_tensor(ggml_backend_buffer_t buffer,
+ ggml_tensor * tensor,
+ const void * data,
+ size_t offset,
+ size_t size) {
+ auto ctx = (ggml_backend_hexagon_buffer_context *) buffer->context;
+ auto sess = ctx->sess;
+
+ HEX_VERBOSE("ggml-hex: %s set-tensor %s : data %p offset %zu size %zu\n", sess->name.c_str(), tensor->name, data,
+ offset, size);
+
+ switch (tensor->type) {
+ case GGML_TYPE_Q4_0:
+ GGML_ASSERT(offset == 0);
+ GGML_ASSERT(offset + size <= ggml_nbytes(tensor));
+ repack_q4_0_q4x4x2(tensor, data, size);
+ break;
+
+ case GGML_TYPE_Q8_0:
+ GGML_ASSERT(offset == 0);
+ GGML_ASSERT(offset + size <= ggml_nbytes(tensor));
+ repack_q8_0_q8x4x2(tensor, data, size);
+ break;
+
+ case GGML_TYPE_MXFP4:
+ GGML_ASSERT(offset == 0);
+ GGML_ASSERT(offset + size <= ggml_nbytes(tensor));
+ repack_mxfp4_mxfp4x4x2(tensor, data, size);
+ break;
+
+ default:
+ memcpy((char *) tensor->data + offset, data, size);
+ break;
+ }
+}
+
+static void ggml_backend_hexagon_buffer_get_tensor(ggml_backend_buffer_t buffer,
+ const ggml_tensor * tensor,
+ void * data,
+ size_t offset,
+ size_t size) {
+ auto ctx = (ggml_backend_hexagon_buffer_context *) buffer->context;
+ auto sess = ctx->sess;
+
+ HEX_VERBOSE("ggml-hex: %s get-tensor %s : data %p offset %zu size %zu\n", sess->name.c_str(), tensor->name, data,
+ offset, size);
+
+ switch (tensor->type) {
+ case GGML_TYPE_Q4_0:
+ GGML_ASSERT(offset == 0);
+ GGML_ASSERT(offset + size <= ggml_nbytes(tensor));
+ repack_q4x4x2_q4_0(data, tensor, size);
+ break;
+
+ case GGML_TYPE_Q8_0:
+ GGML_ASSERT(offset == 0);
+ GGML_ASSERT(offset + size <= ggml_nbytes(tensor));
+ repack_q8x4x2_q8_0(data, tensor, size);
+ break;
+
+ case GGML_TYPE_MXFP4:
+ GGML_ASSERT(offset == 0);
+ GGML_ASSERT(offset + size <= ggml_nbytes(tensor));
+ repack_mxfp4x4x2_mxfp4(data, tensor, size);
+ break;
+
+ default:
+ memcpy(data, (const char *) tensor->data + offset, size);
+ break;
+ }
+}
+
+static bool ggml_backend_hexagon_buffer_cpy_tensor(ggml_backend_buffer_t buffer,
+ const struct ggml_tensor * src,
+ struct ggml_tensor * dst) {
+ GGML_UNUSED(buffer);
+ GGML_UNUSED(src);
+ GGML_UNUSED(dst);
+ // we might optimize this later, for now take the slow path (ie get/set_tensor)
+ return false;
+}
+
+static void ggml_backend_hexagon_buffer_clear(ggml_backend_buffer_t buffer, uint8_t value) {
+ auto ctx = (ggml_backend_hexagon_buffer_context *) buffer->context;
+ auto sess = ctx->sess;
+ HEX_VERBOSE("ggml-hex: %s clear-buff base %p size %zu\n", sess->name.c_str(), (void *) ctx->base, ctx->size);
+ memset(ctx->base, value, ctx->size);
+}
+
+static ggml_backend_buffer_i ggml_backend_hexagon_buffer_interface = {
+ /* .free_buffer = */ ggml_backend_hexagon_buffer_free_buffer,
+ /* .get_base = */ ggml_backend_hexagon_buffer_get_base,
+ /* .init_tensor = */ ggml_backend_hexagon_buffer_init_tensor,
+ /* .memset_tensor = */ NULL,
+ /* .set_tensor = */ ggml_backend_hexagon_buffer_set_tensor,
+ /* .get_tensor = */ ggml_backend_hexagon_buffer_get_tensor,
+ /* .cpy_tensor = */ ggml_backend_hexagon_buffer_cpy_tensor,
+ /* .clear = */ ggml_backend_hexagon_buffer_clear,
+ /* .reset = */ NULL,
+};
+
+// ** backend buffer type
+
+static const char * ggml_backend_hexagon_buffer_type_name(ggml_backend_buffer_type_t buffer_type) {
+ return static_cast<ggml_backend_hexagon_buffer_type_context *>(buffer_type->context)->name.c_str();
+}
+
+static ggml_backend_buffer_t ggml_backend_hexagon_buffer_type_alloc_buffer(
+ ggml_backend_buffer_type_t buffer_type, size_t size) {
+ auto sess = static_cast<ggml_backend_hexagon_buffer_type_context *>(buffer_type->context)->sess;
+ try {
+ ggml_backend_hexagon_buffer_context * ctx = new ggml_backend_hexagon_buffer_context(sess, size, false /*repack*/);
+ return ggml_backend_buffer_init(buffer_type, ggml_backend_hexagon_buffer_interface, ctx, size);
+ } catch (const std::exception & exc) {
+ GGML_LOG_ERROR("ggml-hex: %s failed to allocate buffer context: %s\n", sess->name.c_str(), exc.what());
+ return nullptr;
+ }
+}
+
+static ggml_backend_buffer_t ggml_backend_hexagon_repack_buffer_type_alloc_buffer(
+ ggml_backend_buffer_type_t buffer_type, size_t size) {
+ auto sess = static_cast<ggml_backend_hexagon_buffer_type_context *>(buffer_type->context)->sess;
+ try {
+ ggml_backend_hexagon_buffer_context * ctx = new ggml_backend_hexagon_buffer_context(sess, size, true /*repack*/);
+ return ggml_backend_buffer_init(buffer_type, ggml_backend_hexagon_buffer_interface, ctx, size);
+ } catch (const std::exception & exc) {
+ GGML_LOG_ERROR("ggml-hex: %s failed to allocate buffer context: %s\n", sess->name.c_str(), exc.what());
+ return nullptr;
+ }
+}
+
+static size_t ggml_backend_hexagon_buffer_type_get_alignment(ggml_backend_buffer_type_t buffer_type) {
+ return 128; // HVX alignment
+ GGML_UNUSED(buffer_type);
+}
+
+static size_t ggml_backend_hexagon_buffer_type_get_alloc_size(ggml_backend_buffer_type_t buft, const struct ggml_tensor * t) {
+ return ggml_nbytes(t);
+}
+
+static size_t ggml_backend_hexagon_buffer_type_get_max_size(ggml_backend_buffer_type_t buffer_type) {
+ return 1 * 1024 * 1024 * 1024; // 1GB per buffer
+ GGML_UNUSED(buffer_type);
+}
+
+static bool ggml_backend_hexagon_buffer_type_is_host(ggml_backend_buffer_type_t buft) {
+ return opt_hostbuf;
+ GGML_UNUSED(buft);
+}
+
+static bool ggml_backend_hexagon_repack_buffer_type_is_host(ggml_backend_buffer_type_t buft) {
+ return false;
+ GGML_UNUSED(buft);
+}
+
+static ggml_backend_buffer_type_i ggml_backend_hexagon_buffer_type_interface = {
+ /* .get_name = */ ggml_backend_hexagon_buffer_type_name,
+ /* .alloc_buffer = */ ggml_backend_hexagon_buffer_type_alloc_buffer,
+ /* .get_alignment = */ ggml_backend_hexagon_buffer_type_get_alignment,
+ /* .get_max_size = */ ggml_backend_hexagon_buffer_type_get_max_size,
+ /* .get_alloc_size = */ ggml_backend_hexagon_buffer_type_get_alloc_size,
+ /* .is_host = */ ggml_backend_hexagon_buffer_type_is_host,
+};
+
+static ggml_backend_buffer_type_i ggml_backend_hexagon_repack_buffer_type_interface = {
+ /* .get_name = */ ggml_backend_hexagon_buffer_type_name,
+ /* .alloc_buffer = */ ggml_backend_hexagon_repack_buffer_type_alloc_buffer,
+ /* .get_alignment = */ ggml_backend_hexagon_buffer_type_get_alignment,
+ /* .get_max_size = */ ggml_backend_hexagon_buffer_type_get_max_size,
+ /* .get_alloc_size = */ ggml_backend_hexagon_buffer_type_get_alloc_size,
+ /* .is_host = */ ggml_backend_hexagon_repack_buffer_type_is_host,
+};
+
+void ggml_hexagon_session::allocate(int dev_id) noexcept(false) {
+ this->valid_session = false;
+ this->valid_handle = false;
+ this->valid_queue = false;
+ this->valid_iface = false;
+
+ this->domain_id = 3; // Default for CDSP, updated after the session is created
+ this->session_id = 0; // Default for CDSP, updated after the session is created
+ this->dev_id = dev_id;
+ this->name = std::string("HTP") + std::to_string(dev_id);
+
+ this->op_pending = 0;
+ this->prof_usecs = 0;
+ this->prof_cycles = 0;
+ this->prof_pkts = 0;
+
+ GGML_LOG_INFO("ggml-hex: allocating new session: %s\n", this->name.c_str());
+
+ domain * my_domain = get_domain(this->domain_id);
+ if (my_domain == NULL) {
+ GGML_LOG_ERROR("ggml-hex: unable to get domain struct for CDSP\n");
+ throw std::runtime_error("ggml-hex: failed to get CDSP domain (see log for details)");
+ }
+
+ // Create new session
+ if (dev_id != 0) {
+ struct remote_rpc_reserve_new_session n;
+ n.domain_name_len = strlen(CDSP_DOMAIN_NAME);
+ n.domain_name = const_cast<char *>(CDSP_DOMAIN_NAME);
+ n.session_name = const_cast<char *>(this->name.c_str());
+ n.session_name_len = this->name.size();
+
+ int err = remote_session_control(FASTRPC_RESERVE_NEW_SESSION, (void *) &n, sizeof(n));
+ if (err != AEE_SUCCESS) {
+ GGML_LOG_ERROR("ggml-hex: failed to reserve new session %d : error 0x%x\n", dev_id, err);
+ throw std::runtime_error("ggml-hex: remote_session_control(new-sess) failed (see log for details)");
+ }
+
+ // Save the IDs
+ this->session_id = n.session_id;
+ this->domain_id = n.effective_domain_id;
+ this->valid_session = true;
+ }
+
+ // Get session URI
+
+ char session_uri[256];
+ {
+ char htp_uri[256];
+ snprintf(htp_uri, sizeof(htp_uri), "file:///libggml-htp-v%u.so?htp_iface_skel_handle_invoke&_modver=1.0", opt_arch);
+
+ struct remote_rpc_get_uri u = {};
+ u.session_id = this->session_id;
+ u.domain_name = const_cast<char *>(CDSP_DOMAIN_NAME);
+ u.domain_name_len = strlen(CDSP_DOMAIN_NAME);
+ u.module_uri = const_cast<char *>(htp_uri);
+ u.module_uri_len = strlen(htp_uri);
+ u.uri = session_uri;
+ u.uri_len = sizeof(session_uri);
+
+ int err = remote_session_control(FASTRPC_GET_URI, (void *) &u, sizeof(u));
+ if (err != AEE_SUCCESS) {
+ // fallback to single session uris
+ int htp_URI_domain_len = strlen(htp_uri) + MAX_DOMAIN_NAMELEN;
+
+ snprintf(session_uri, htp_URI_domain_len, "%s%s", htp_uri, my_domain->uri);
+
+ GGML_LOG_WARN("ggml-hex: failed to get URI for session %d : error 0x%x. Falling back to single session URI: %s\n", dev_id, err, session_uri);
+ }
+ }
+
+ // Enable Unsigned PD
+ {
+ struct remote_rpc_control_unsigned_module u;
+ u.domain = this->domain_id;
+ u.enable = 1;
+ int err = remote_session_control(DSPRPC_CONTROL_UNSIGNED_MODULE, (void *) &u, sizeof(u));
+ if (err != AEE_SUCCESS) {
+ GGML_LOG_ERROR("ggml-hex: failed to enable unsigned PD for session %d : error 0x%x\n", dev_id, err);
+ throw std::runtime_error("ggml-hex: remote_session_control(unsign) failed (see log for details)");
+ }
+ }
+
+ // Open session
+ int err = htp_iface_open(session_uri, &this->handle);
+ if (err != AEE_SUCCESS) {
+ GGML_LOG_ERROR("ggml-hex: failed to open session %d : error 0x%x\n", dev_id, err);
+ throw std::runtime_error("ggml-hex: failed to open session (see log for details)");
+ }
+
+ this->valid_handle = true;
+
+ GGML_LOG_INFO("ggml-hex: new session: %s : session-id %d domain-id %d uri %s handle 0x%lx\n", this->name.c_str(),
+ this->session_id, this->domain_id, session_uri, (unsigned long) this->handle);
+
+ // Enable FastRPC QoS mode
+ {
+ struct remote_rpc_control_latency l;
+ l.enable = 1;
+
+ int err = remote_handle64_control(this->handle, DSPRPC_CONTROL_LATENCY, (void *) &l, sizeof(l));
+ if (err != 0) {
+ GGML_LOG_WARN("ggml-hex: failed to enable fastrpc QOS mode: 0x%08x\n", (unsigned) err);
+ }
+ }
+
+ // Now let's setup the DSP queue
+ err = dspqueue_create(this->domain_id,
+ 0, // Flags
+ 128 * 1024, // Request queue size (in bytes)
+ 64 * 1024, // Response queue size (in bytes)
+ nullptr, // Read packet callback (we handle reads explicitly)
+ nullptr, // Error callback (we handle errors during reads)
+ (void *) this, // Callback context
+ &queue);
+ if (err != 0) {
+ GGML_LOG_ERROR("ggml-hex: %s dspqueue_create failed: 0x%08x\n", this->name.c_str(), (unsigned) err);
+ throw std::runtime_error("ggml-hex: failed to create dspqueue (see log for details)");
+ }
+
+ this->valid_queue = true;
+
+ // Export queue for use on the DSP
+ err = dspqueue_export(queue, &this->queue_id);
+ if (err != 0) {
+ GGML_LOG_ERROR("ggml-hex: dspqueue_export failed: 0x%08x\n", (unsigned) err);
+ throw std::runtime_error("ggml-hex: dspqueue export failed (see log for details)");
+ }
+
+ if (opt_etm) {
+ err = htp_iface_enable_etm(this->handle);
+ if (err != 0) {
+ GGML_LOG_ERROR("ggml-hex: failed to enable ETM tracing: 0x%08x\n", (unsigned) err);
+ }
+ }
+
+ // Start the DSP-side service. We need to pass the queue ID to the
+ // DSP in a FastRPC call; the DSP side will import the queue and start
+ // listening for packets in a callback.
+ err = htp_iface_start(this->handle, dev_id, this->queue_id, opt_nhvx);
+ if (err != 0) {
+ GGML_LOG_ERROR("ggml-hex: failed to start session: 0x%08x\n", (unsigned) err);
+ throw std::runtime_error("ggml-hex: iface start failed (see log for details)");
+ }
+ this->valid_iface = true;
+}
+
+void ggml_hexagon_session::release() noexcept(true) {
+ GGML_LOG_INFO("ggml-hex: releasing session: %s\n", this->name.c_str());
+
+ int err;
+
+ // Stop the DSP-side service and close the queue
+ if (this->valid_iface) {
+ err = htp_iface_stop(this->handle);
+ if (err != 0) {
+ GGML_ABORT("ggml-hex: htp_iface_stop failed: 0x%08x\n", (unsigned) err);
+ }
+ }
+
+ if (opt_etm) {
+ err = htp_iface_disable_etm(this->handle);
+ if (err != 0) {
+ GGML_LOG_ERROR("ggml-hex: warn : failed to disable ETM tracing: 0x%08x\n", (unsigned) err);
+ }
+ }
+
+ if (this->valid_queue) {
+ err = dspqueue_close(queue);
+ if (err != 0) {
+ GGML_ABORT("ggml-hex: dspqueue_close failed: 0x%08x\n", (unsigned) err);
+ }
+ }
+
+ if (this->valid_handle) {
+ htp_iface_close(this->handle);
+ }
+}
+
+ggml_hexagon_session::ggml_hexagon_session(int dev_id, ggml_backend_dev_t dev) noexcept(false) {
+ buffer_type.device = dev;
+ repack_buffer_type.device = dev;
+
+ try {
+ allocate(dev_id);
+
+ buffer_type.iface = ggml_backend_hexagon_buffer_type_interface;
+ buffer_type.context = new ggml_backend_hexagon_buffer_type_context(this->name, this);
+
+ repack_buffer_type.iface = ggml_backend_hexagon_repack_buffer_type_interface;
+ repack_buffer_type.context = new ggml_backend_hexagon_buffer_type_context(this->name + "-REPACK", this);
+ } catch (const std::exception & exc) {
+ release();
+ throw;
+ }
+}
+
+ggml_hexagon_session::~ggml_hexagon_session() noexcept(true) {
+ release();
+
+ delete static_cast<ggml_backend_hexagon_buffer_type_context *>(buffer_type.context);
+ delete static_cast<ggml_backend_hexagon_buffer_type_context *>(repack_buffer_type.context);
+}
+
+// ** backend interface
+
+static bool ggml_backend_buffer_is_hexagon(const struct ggml_backend_buffer * b) {
+ return b->buft->iface.get_alignment == ggml_backend_hexagon_buffer_type_get_alignment;
+}
+
+static inline bool ggml_backend_buffer_is_hexagon_repack(const struct ggml_backend_buffer * b) {
+ if (!opt_hostbuf) {
+ return ggml_backend_buffer_is_hexagon(b);
+ }
+ return b->buft->iface.alloc_buffer == ggml_backend_hexagon_repack_buffer_type_alloc_buffer;
+}
+
+static bool hex_supported_dims2(const struct ggml_tensor * x, const struct ggml_tensor * y) {
+ if (x->ne[0] != y->ne[0]) {
+ return false;
+ }
+ if (x->ne[1] != y->ne[1]) {
+ return false;
+ }
+ if (x->ne[2] != y->ne[2]) {
+ return false;
+ }
+ if (x->ne[3] != y->ne[3]) {
+ return false;
+ }
+
+ return true;
+}
+
+static bool ggml_hexagon_supported_flash_attn_ext(const struct ggml_hexagon_session * sess, const struct ggml_tensor * op) {
+ const struct ggml_tensor * src0 = op->src[0];
+ const struct ggml_tensor * src1 = op->src[1];
+ const struct ggml_tensor * src2 = op->src[2];
+ const struct ggml_tensor * src3 = op->src[3];
+ const struct ggml_tensor * src4 = op->src[4];
+ const struct ggml_tensor * dst = op;
+
+ // Q may be F32 or F16; K and V must be F16
+ if ((src0->type != GGML_TYPE_F16 && src0->type != GGML_TYPE_F32) || src1->type != GGML_TYPE_F16 || src2->type != GGML_TYPE_F16) {
+ return false;
+ }
+
+ if (src3 && src3->type != GGML_TYPE_F16) { // mask
+ return false;
+ }
+
+ if (src4 && src4->type != GGML_TYPE_F32) { // sinks
+ return false;
+ }
+
+ // dst may be F32 or F16; the HTP implementation writes either format
+ // and converts the output on the fly when needed.
+ if (dst->type != GGML_TYPE_F32 && dst->type != GGML_TYPE_F16) {
+ return false;
+ }
+
+ return opt_experimental;
+}
+
+static bool hex_supported_src0_type(ggml_type t) {
+ return t == GGML_TYPE_F32;
+}
+
+static bool hex_supported_src1_type(ggml_type t) {
+ return t == GGML_TYPE_F32;
+}
+
+static bool hex_supported_src2_type(ggml_type t) {
+ return t == GGML_TYPE_F32;
+}
+
+static bool hex_supported_src1_type2(ggml_type t) {
+ return t == GGML_TYPE_F16;
+}
+
+static bool hex_supported_src1_type3(ggml_type t) {
+ return t == GGML_TYPE_I32;
+}
+
+static bool hex_supported_dst_type(ggml_type t) {
+ return t == GGML_TYPE_F32;
+}
+
+static bool hex_supported_dims(const struct ggml_tensor * x, const struct ggml_tensor * y) {
+ // TODO: support broadcast for ne[2] and ne[3]
+ if (x->ne[0] != y->ne[0]) {
+ return false;
+ }
+ if (x->ne[2] != y->ne[2]) {
+ return false;
+ }
+ if (x->ne[3] != y->ne[3]) {
+ return false;
+ }
+ return true;
+}
+
+static bool ggml_hexagon_supported_mul_mat(const struct ggml_hexagon_session * sess, const struct ggml_tensor * dst) {
+ const struct ggml_tensor * src0 = dst->src[0];
+ const struct ggml_tensor * src1 = dst->src[1];
+
+ if (dst->type != GGML_TYPE_F32) {
+ return false;
+ }
+
+ if (src1->type != GGML_TYPE_F32 && src1->type != GGML_TYPE_F16) {
+ return false;
+ }
+
+ switch (src0->type) {
+ case GGML_TYPE_Q4_0:
+ case GGML_TYPE_Q8_0:
+ case GGML_TYPE_MXFP4:
+ if (src0->ne[0] % 32) {
+ return false;
+ }
+
+ if (src0->ne[1] > 16 * 1024) {
+ return false; // typically the lm-head which would be too large for VTCM
+ }
+
+ if ((src1->ne[2] != 1 || src1->ne[3] != 1)) {
+ return false;
+ }
+
+ // src0 (weights) must be repacked
+ if (src0->buffer && !ggml_backend_buffer_is_hexagon_repack(src0->buffer)) {
+ return false;
+ }
+ break;
+
+ case GGML_TYPE_F16:
+ if (src0->nb[1] < src0->nb[0]) {
+ GGML_LOG_DEBUG("ggml_hexagon_supported_mul_mat: permuted F16 src0 not supported\n");
+ return false;
+ }
+ break;
+
+ default:
+ return false;
+ }
+
+ return true;
+}
+
+static bool ggml_hexagon_supported_mul_mat_id(const struct ggml_hexagon_session * sess, const struct ggml_tensor * op) {
+ const struct ggml_tensor * src0 = op->src[0];
+ const struct ggml_tensor * src1 = op->src[1];
+ const struct ggml_tensor * src2 = op->src[2];
+ const struct ggml_tensor * dst = op;
+
+ if (src1->type != GGML_TYPE_F32 || dst->type != GGML_TYPE_F32 || src2->type != GGML_TYPE_I32) {
+ return false;
+ }
+
+ switch (src0->type) {
+ case GGML_TYPE_Q4_0:
+ case GGML_TYPE_Q8_0:
+ case GGML_TYPE_MXFP4:
+ if ((src0->ne[0] % 32)) {
+ return false;
+ }
+
+ // src0 (weights) must be repacked
+ if (src0->buffer && !ggml_backend_buffer_is_hexagon_repack(src0->buffer)) {
+ return false;
+ }
+ break;
+
+ default:
+ return false;
+ }
+
+ return true;
+}
+
+static bool ggml_hexagon_supported_binary(const struct ggml_hexagon_session * sess, const struct ggml_tensor * op) {
+ const struct ggml_tensor * src0 = op->src[0];
+ const struct ggml_tensor * src1 = op->src[1];
+ const struct ggml_tensor * dst = op;
+
+ if (!hex_supported_src0_type(src0->type)) {
+ return false;
+ }
+ if (!hex_supported_src1_type(src1->type)) {
+ return false;
+ }
+ if (!hex_supported_dst_type(dst->type)) {
+ return false;
+ }
+ if (!hex_supported_dims2(src0, dst)) {
+ return false;
+ }
+ if (!ggml_can_repeat(src1, src0)) {
+ return false;
+ }
+
+ return true;
+}
+
+static bool ggml_hexagon_supported_add_id(const struct ggml_hexagon_session * sess, const struct ggml_tensor * op) {
+ const struct ggml_tensor * src0 = op->src[0];
+ const struct ggml_tensor * src1 = op->src[1];
+ const struct ggml_tensor * dst = op;
+
+ if (!hex_supported_src0_type(src0->type)) {
+ return false;
+ }
+ if (!hex_supported_src1_type(src1->type)) {
+ return false;
+ }
+ if (!hex_supported_dst_type(dst->type)) {
+ return false;
+ }
+ if (!hex_supported_dims2(src0, dst)) {
+ return false;
+ }
+
+ // REVISIT: add support for non-contiguous tensors
+ if (!ggml_is_contiguous(src0) || !ggml_is_contiguous(src1) || !ggml_is_contiguous(dst)) {
+ return false;
+ }
+
+ return true;
+}
+
+static bool ggml_hexagon_supported_unary(const struct ggml_hexagon_session * sess, const struct ggml_tensor * op) {
+ const struct ggml_tensor * src0 = op->src[0];
+ const struct ggml_tensor * dst = op;
+
+ if (!hex_supported_src0_type(src0->type)) {
+ return false;
+ }
+ if (!hex_supported_dst_type(dst->type)) {
+ return false;
+ }
+ if (!hex_supported_dims2(src0, dst)) {
+ return false;
+ }
+
+ // TODO: add support for non-contiguous tensors
+ if (!ggml_is_contiguous(src0) || !ggml_is_contiguous(dst)) {
+ return false;
+ }
+
+ return true;
+}
+
+static bool ggml_hexagon_supported_sum_rows(const struct ggml_hexagon_session * sess, const struct ggml_tensor * op) {
+ const struct ggml_tensor * src0 = op->src[0];
+ const struct ggml_tensor * dst = op;
+
+ if (!hex_supported_src0_type(src0->type)) {
+ return false;
+ }
+ if (!hex_supported_dst_type(dst->type)) {
+ return false;
+ }
+
+ // TODO: add support for non-contiguous tensors
+ if (!ggml_is_contiguous(src0) || !ggml_is_contiguous(dst)) {
+ return false;
+ }
+
+ return true;
+}
+
+static bool ggml_hexagon_supported_activations(const struct ggml_hexagon_session * sess,
+ const struct ggml_tensor * op) {
+ const struct ggml_tensor * src0 = op->src[0];
+ const struct ggml_tensor * src1 = op->src[1];
+ const struct ggml_tensor * dst = op;
+
+ if (!hex_supported_src0_type(src0->type)) {
+ return false;
+ }
+ if (!hex_supported_dst_type(dst->type)) {
+ return false;
+ }
+
+ if (!ggml_is_contiguous(src0) || !ggml_is_contiguous(dst)) {
+ return false;
+ }
+
+ if (src1) {
+ if (!hex_supported_src1_type(src1->type)) {
+ return false;
+ }
+ if (!hex_supported_dims2(src0, src1)) {
+ return false;
+ }
+ if (!ggml_is_contiguous(src1)) {
+ return false;
+ }
+ }
+
+ return true;
+}
+
+static bool ggml_hexagon_supported_softmax(const struct ggml_hexagon_session * sess, const struct ggml_tensor * op) {
+ const struct ggml_tensor * src0 = op->src[0];
+ const struct ggml_tensor * src1 = op->src[1];
+ const struct ggml_tensor * src2 = op->src[2];
+ const struct ggml_tensor * dst = op;
+
+ if (src2) {
+ return false; // FIXME: add support for sinks
+ }
+
+ if (!hex_supported_src0_type(src0->type)) {
+ return false;
+ }
+ if (!hex_supported_dst_type(dst->type)) {
+ return false;
+ }
+
+ if (src1) {
+ if (!hex_supported_src1_type(src1->type) && !hex_supported_src1_type2(src1->type)) {
+ return false;
+ }
+ if (src0->ne[0] != src1->ne[0]) {
+ return false;
+ }
+ if (src1->ne[1] < src0->ne[1]) {
+ return false;
+ }
+ if (src0->ne[2] % src1->ne[2] != 0) {
+ return false;
+ }
+ if (src0->ne[3] % src1->ne[3] != 0) {
+ return false;
+ }
+ }
+
+ if (src1) {
+ if (!ggml_is_contiguous(src0) || !ggml_is_contiguous(src1) || !ggml_is_contiguous(dst)) {
+ return false;
+ }
+ } else {
+ if (!ggml_is_contiguous(src0) || !ggml_is_contiguous(dst)) {
+ return false;
+ }
+ }
+
+ return true;
+}
+
+static bool ggml_hexagon_supported_set_rows(const struct ggml_hexagon_session * sess, const struct ggml_tensor * op) {
+ const struct ggml_tensor * src0 = op->src[0]; // values
+ const struct ggml_tensor * src1 = op->src[1]; // indices
+ const struct ggml_tensor * dst = op;
+
+ if (src0->type != GGML_TYPE_F32) {
+ return false;
+ }
+
+ if (src1->type != GGML_TYPE_I32 && src1->type != GGML_TYPE_I64) {
+ return false;
+ }
+
+ if (dst->type != GGML_TYPE_F16) {
+ return false;
+ }
+
+ return true;
+}
+
+static bool ggml_hexagon_supported_get_rows(const struct ggml_hexagon_session * sess, const struct ggml_tensor * op) {
+ const struct ggml_tensor * src0 = op->src[0]; // values
+ const struct ggml_tensor * src1 = op->src[1]; // indices
+ const struct ggml_tensor * dst = op;
+
+ if (src0->type != GGML_TYPE_F32) {
+ return false;
+ }
+
+ if (src1->type != GGML_TYPE_I32 && src1->type != GGML_TYPE_I64) {
+ return false;
+ }
+
+ if (dst->type != GGML_TYPE_F32) {
+ return false;
+ }
+
+ return true;
+}
+
+static bool ggml_hexagon_supported_argsort(const struct ggml_hexagon_session * sess, const struct ggml_tensor * op) {
+ const struct ggml_tensor * src0 = op->src[0]; // values
+ const struct ggml_tensor * dst = op; // indices
+
+ if (src0->type != GGML_TYPE_F32) {
+ return false;
+ }
+
+ if (dst->type != GGML_TYPE_I32) {
+ return false;
+ }
+
+ if (src0->ne[0] > (16*1024)) {
+ // reject tensors with huge rows for now
+ return false;
+ }
+
+ return true;
+}
+
+static bool ggml_hexagon_supported_rope(const struct ggml_hexagon_session * sess, const struct ggml_tensor * op) {
+ const int32_t * op_params = &op->op_params[0];
+
+ int mode = op_params[2];
+
+ if ((mode & GGML_ROPE_TYPE_MROPE) || (mode & GGML_ROPE_TYPE_VISION)) {
+ return false;
+ }
+ if (mode & 1) {
+ return false;
+ }
+
+ const struct ggml_tensor * src0 = op->src[0];
+ const struct ggml_tensor * src1 = op->src[1];
+ const struct ggml_tensor * src2 = op->src[2];
+ const struct ggml_tensor * dst = op;
+
+ if (!hex_supported_src0_type(src0->type)) {
+ return false; // FIXME: add support for GGML_TYPE_F16 for src0
+ }
+ if (!hex_supported_dst_type(dst->type)) {
+ return false;
+ }
+ if (!hex_supported_src1_type3(src1->type)) {
+ return false;
+ }
+ if (src2) {
+ if (!hex_supported_src2_type(src2->type)) {
+ return false;
+ }
+ int n_dims = op_params[1];
+ if (src2->ne[0] < (n_dims / 2)) {
+ return false;
+ }
+ }
+
+ if (src2) {
+ if (!ggml_is_contiguous(src0) || !ggml_is_contiguous(src1) || !ggml_is_contiguous(src2) ||
+ !ggml_is_contiguous(dst)) {
+ return false;
+ }
+ } else {
+ if (!ggml_is_contiguous(src0) || !ggml_is_contiguous(src1) || !ggml_is_contiguous(dst)) {
+ return false;
+ }
+ }
+
+ return true;
+}
+
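+// Buffer direction relative to the DSP; used to pick the cache-maintenance flags
+// attached to each dspqueue buffer (see htp_req_buff_init).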
+enum dspqbuf_type {
+ DSPQBUF_TYPE_DSP_WRITE_CPU_READ = 0,
+ DSPQBUF_TYPE_CPU_WRITE_DSP_READ,
+ DSPQBUF_TYPE_CONSTANT,
+};
+
+static void dspqbuf_dump(dspqueue_buffer * d, const struct ggml_tensor * t, dspqbuf_type type) {
+ if (opt_verbose < 2) return;
+
+ auto buf = static_cast<ggml_backend_hexagon_buffer_context *>(t->buffer->context);
+ auto sess = buf->sess;
+
+ GGML_LOG_DEBUG("ggml-hex: %s dspqbuf : %s base-addr %p base-size %zu data %p offset %u size %u\n", sess->name.c_str(),
+ t->name, (void *) buf->base, buf->size, (void *) d->ptr, (unsigned int) d->offset,
+ (unsigned int) d->size);
+}
+
+// Init hexagon tensor from GGML tensor and Hexagon buffer
+static void htp_req_tensor_init(htp_tensor * h, const ggml_tensor * t) {
+ h->data = 0; // updated by the receiver
+ h->type = t->type;
+ h->ne[0] = t->ne[0];
+ h->ne[1] = t->ne[1];
+ h->ne[2] = t->ne[2];
+ h->ne[3] = t->ne[3];
+ h->nb[0] = t->nb[0];
+ h->nb[1] = t->nb[1];
+ h->nb[2] = t->nb[2];
+ h->nb[3] = t->nb[3];
+}
+
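+// Describe tensor t as a dspqueue buffer and fill in the matching htp_tensor header.
+// Returns the number of dspqueue buffers added (0 when t is NULL) so callers can
+// accumulate the packet's buffer count.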
+static size_t htp_req_buff_init(htp_tensor *h, dspqueue_buffer * d, const ggml_tensor * t, dspqbuf_type type) {
+ if (!t) {
+ return 0;
+ }
+
+ auto buf = static_cast<ggml_backend_hexagon_buffer_context *>(t->buffer->context);
+
+ memset(d, 0, sizeof(*d));
+ d->fd = buf->fd;
+ d->ptr = t->data;
+ d->offset = (uint8_t *) t->data - buf->base;
+ d->size = ggml_nbytes(t);
+
+ if (!d->size) {
+ // Some requests contain srcs where ggml_nbytes() returns 0 but the rest of the op is non-empty
+ d->size = 64;
+ }
+
+ switch (type) {
+ case DSPQBUF_TYPE_DSP_WRITE_CPU_READ:
+ // Flush CPU
+ d->flags = DSPQUEUE_BUFFER_FLAG_FLUSH_SENDER;
+ break;
+ case DSPQBUF_TYPE_CPU_WRITE_DSP_READ:
+ // Flush CPU, Invalidate DSP
+ d->flags = DSPQUEUE_BUFFER_FLAG_FLUSH_SENDER | DSPQUEUE_BUFFER_FLAG_INVALIDATE_RECIPIENT;
+ break;
+ default:
+ // Constant buffer, no cache maintenance
+ d->flags = 0;
+ break;
+ }
+
+ htp_req_tensor_init(h, t);
+
+ dspqbuf_dump(d, t, type);
+
+ return 1;
+}
+
+typedef size_t (*htp_req_init_func_t)(htp_general_req * req, dspqueue_buffer * bufs, const ggml_tensor * op);
+
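+// Build the HTP request for this op with _init_req_func, enqueue it on the session's
+// DSP queue (honoring the op-mask / op-sync debug options), and record basic profiling.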
+template <htp_req_init_func_t _init_req_func>
+static inline void ggml_hexagon_dispatch_op(ggml_hexagon_session *sess, const struct ggml_tensor * op, uint32_t flags) {
+ uint64_t t = ggml_time_us();
+
+ // Construct HTP request
+ htp_general_req req;
+ memset(&req, 0, sizeof(req));
+
+ req.flags = flags;
+ if (!(opt_opmask & HTP_OPMASK_QUANTIZE)) {
+ req.flags |= HTP_OPFLAGS_SKIP_QUANTIZE;
+ }
+ if (!(opt_opmask & HTP_OPMASK_COMPUTE)) {
+ req.flags |= HTP_OPFLAGS_SKIP_COMPUTE;
+ }
+
+ ggml_hexagon_dump_op_exec(sess->name, op, req.flags);
+
+ if ((opt_opmask & HTP_OPMASK_QUEUE)) {
+ dspqueue_buffer bufs[HTP_MAX_PACKET_BUFFERS];
+ size_t n_bufs = _init_req_func(&req, bufs, op);
+ sess->enqueue(req, bufs, n_bufs, opt_opsync);
+ }
+
+ t = ggml_time_us() - t;
+
+ ggml_hexagon_dump_op_prof(sess->name, op, sess->prof_usecs, sess->prof_cycles, sess->prof_pkts, t);
+}
+
+template <bool _is_src0_constant>
+static inline size_t init_binary_req(htp_general_req * req, dspqueue_buffer * bufs, const ggml_tensor * t) {
+ switch (t->op) {
+ case GGML_OP_MUL_MAT:
+ req->op = HTP_OP_MUL_MAT;
+ break;
+ case GGML_OP_MUL:
+ req->op = HTP_OP_MUL;
+ break;
+ case GGML_OP_ADD:
+ req->op = HTP_OP_ADD;
+ break;
+ case GGML_OP_SUB:
+ req->op = HTP_OP_SUB;
+ break;
+ case GGML_OP_DIV:
+ req->op = HTP_OP_DIV;
+ break;
+ default:
+ GGML_ABORT("ggml-hex: binary : unsupported op: %d\n", t->op);
+ break;
+ }
+
+ // src0: Weights (mulmat) or First Operand (binary op).
+ // If constant (e.g. weights), no cache management is needed.
+ // src1: Input Activations (mulmat) or Second Operand (binary op).
+
+ size_t n_bufs = 0;
+ n_bufs += htp_req_buff_init(&req->src0, &bufs[n_bufs], t->src[0], _is_src0_constant ? DSPQBUF_TYPE_CONSTANT : DSPQBUF_TYPE_CPU_WRITE_DSP_READ);
+ n_bufs += htp_req_buff_init(&req->src1, &bufs[n_bufs], t->src[1], DSPQBUF_TYPE_CPU_WRITE_DSP_READ);
+ n_bufs += htp_req_buff_init(&req->dst, &bufs[n_bufs], t, DSPQBUF_TYPE_DSP_WRITE_CPU_READ);
+
+ return n_bufs;
+}
+
+static inline size_t init_cpy_req(htp_general_req * req, dspqueue_buffer * bufs, const ggml_tensor * t) {
+ req->op = HTP_OP_CPY;
+
+ size_t n_bufs = 0;
+ n_bufs += htp_req_buff_init(&req->src0, &bufs[n_bufs], t->src[0], DSPQBUF_TYPE_CPU_WRITE_DSP_READ);
+ n_bufs += htp_req_buff_init(&req->dst, &bufs[n_bufs], t, DSPQBUF_TYPE_DSP_WRITE_CPU_READ);
+
+ return n_bufs;
+}
+
+static inline size_t init_get_rows_req(htp_general_req * req, dspqueue_buffer * bufs, const ggml_tensor * t) {
+ req->op = HTP_OP_GET_ROWS;
+
+ size_t n_bufs = 0;
+ n_bufs += htp_req_buff_init(&req->src0, &bufs[n_bufs], t->src[0], DSPQBUF_TYPE_CPU_WRITE_DSP_READ);
+ n_bufs += htp_req_buff_init(&req->src1, &bufs[n_bufs], t->src[1], DSPQBUF_TYPE_CPU_WRITE_DSP_READ);
+ n_bufs += htp_req_buff_init(&req->dst, &bufs[n_bufs], t, DSPQBUF_TYPE_DSP_WRITE_CPU_READ);
+
+ return n_bufs;
+}
+
+static inline size_t init_argsort_req(htp_general_req * req, dspqueue_buffer * bufs, const ggml_tensor * t) {
+ req->op = HTP_OP_ARGSORT;
+ memcpy(&req->op_params, &t->op_params, sizeof(t->op_params));
+
+ size_t n_bufs = 0;
+ n_bufs += htp_req_buff_init(&req->src0, &bufs[n_bufs], t->src[0], DSPQBUF_TYPE_CPU_WRITE_DSP_READ);
+ n_bufs += htp_req_buff_init(&req->dst, &bufs[n_bufs], t, DSPQBUF_TYPE_DSP_WRITE_CPU_READ);
+
+ return n_bufs;
+}
+
+template <bool _is_src0_constant>
+static inline size_t init_binary_id_req(htp_general_req * req, dspqueue_buffer * bufs, const ggml_tensor * t) {
+ switch (t->op) {
+ case GGML_OP_MUL_MAT_ID:
+ req->op = HTP_OP_MUL_MAT_ID;
+ break;
+ case GGML_OP_ADD_ID:
+ req->op = HTP_OP_ADD_ID;
+ break;
+ default:
+ GGML_ABORT("ggml-hex: unsupported op: %d\n", t->op);
+ }
+
+ // src0: Weights (mulmat) or Input Activations (other op).
+ // If constant, no cache management is needed.
+ // src1: Input Activations (mulmat) or Second Operand (binary op).
+ // src2: Expert IDs (mulmat) or Activated Experts (other op).
+
+ size_t n_bufs = 0;
+ n_bufs += htp_req_buff_init(&req->src0, &bufs[n_bufs], t->src[0], _is_src0_constant ? DSPQBUF_TYPE_CONSTANT : DSPQBUF_TYPE_CPU_WRITE_DSP_READ);
+ n_bufs += htp_req_buff_init(&req->src1, &bufs[n_bufs], t->src[1], DSPQBUF_TYPE_CPU_WRITE_DSP_READ);
+ n_bufs += htp_req_buff_init(&req->src2, &bufs[n_bufs], t->src[2], DSPQBUF_TYPE_CPU_WRITE_DSP_READ);
+ n_bufs += htp_req_buff_init(&req->dst, &bufs[n_bufs], t, DSPQBUF_TYPE_DSP_WRITE_CPU_READ);
+
+ return n_bufs;
+}
+
+static inline size_t init_set_rows_req(htp_general_req * req, dspqueue_buffer * bufs, const ggml_tensor * t) {
+ req->op = HTP_OP_SET_ROWS;
+
+ size_t n_bufs = 0;
+ n_bufs += htp_req_buff_init(&req->src0, &bufs[n_bufs], t->src[0], DSPQBUF_TYPE_CPU_WRITE_DSP_READ);
+ n_bufs += htp_req_buff_init(&req->src1, &bufs[n_bufs], t->src[1], DSPQBUF_TYPE_CPU_WRITE_DSP_READ);
+ n_bufs += htp_req_buff_init(&req->dst, &bufs[n_bufs], t, DSPQBUF_TYPE_DSP_WRITE_CPU_READ);
+
+ return n_bufs;
+}
+
+static inline size_t init_unary_req(htp_general_req * req, dspqueue_buffer * bufs, const ggml_tensor * t) {
+ memcpy(&req->op_params, &t->op_params, sizeof(t->op_params));
+
+ bool supported = false;
+
+ switch (t->op) {
+ case GGML_OP_RMS_NORM:
+ req->op = HTP_OP_RMS_NORM;
+ supported = true;
+ break;
+
+ case GGML_OP_SCALE:
+ req->op = HTP_OP_SCALE;
+ supported = true;
+ break;
+
+ case GGML_OP_SQR:
+ req->op = HTP_OP_SQR;
+ supported = true;
+ break;
+
+ case GGML_OP_SQRT:
+ req->op = HTP_OP_SQRT;
+ supported = true;
+ break;
+
+ case GGML_OP_UNARY:
+ if (ggml_get_unary_op(t) == GGML_UNARY_OP_SILU) {
+ req->op = HTP_OP_UNARY_SILU;
+ supported = true;
+ } else if (ggml_get_unary_op(t) == GGML_UNARY_OP_GELU) {
+ req->op = HTP_OP_UNARY_GELU;
+ supported = true;
+ }
+ break;
+
+ case GGML_OP_GLU:
+ if (ggml_get_glu_op(t) == GGML_GLU_OP_SWIGLU) {
+ req->op = HTP_OP_GLU_SWIGLU;
+ supported = true;
+ } else if (ggml_get_glu_op(t) == GGML_GLU_OP_SWIGLU_OAI) {
+ req->op = HTP_OP_GLU_SWIGLU_OAI;
+ supported = true;
+ } else if (ggml_get_glu_op(t) == GGML_GLU_OP_GEGLU) {
+ req->op = HTP_OP_GLU_GEGLU;
+ supported = true;
+ }
+ break;
+
+ case GGML_OP_SOFT_MAX:
+ req->op = HTP_OP_SOFTMAX;
+ supported = true;
+ break;
+
+ default:
+ break;
+ }
+
+ if (!supported) {
+ GGML_ABORT("ggml-hex: unary : unsupported op: %d\n", t->op);
+ }
+
+ size_t n_bufs = 0;
+ n_bufs += htp_req_buff_init(&req->src0, &bufs[n_bufs], t->src[0], DSPQBUF_TYPE_CPU_WRITE_DSP_READ);
+ n_bufs += htp_req_buff_init(&req->src1, &bufs[n_bufs], t->src[1], DSPQBUF_TYPE_CPU_WRITE_DSP_READ);
+ n_bufs += htp_req_buff_init(&req->dst, &bufs[n_bufs], t, DSPQBUF_TYPE_DSP_WRITE_CPU_READ);
+
+ return n_bufs;
+}
+
+static inline size_t init_sum_rows_req(htp_general_req * req, dspqueue_buffer * bufs, const ggml_tensor * t) {
+ memcpy(&req->op_params, &t->op_params, sizeof(t->op_params));
+ req->op = HTP_OP_SUM_ROWS;
+
+ size_t n_bufs = 0;
+ n_bufs += htp_req_buff_init(&req->src0, &bufs[n_bufs], t->src[0], DSPQBUF_TYPE_CPU_WRITE_DSP_READ);
+ n_bufs += htp_req_buff_init(&req->dst, &bufs[n_bufs], t, DSPQBUF_TYPE_DSP_WRITE_CPU_READ);
+
+ return n_bufs;
+}
+
+static inline size_t init_rope_req(htp_general_req * req, dspqueue_buffer * bufs, const ggml_tensor * t) {
+ memcpy(&req->op_params, &t->op_params, sizeof(t->op_params));
+ req->op = HTP_OP_ROPE;
+
+ size_t n_bufs = 0;
+ n_bufs += htp_req_buff_init(&req->src0, &bufs[n_bufs], t->src[0], DSPQBUF_TYPE_CPU_WRITE_DSP_READ);
+ n_bufs += htp_req_buff_init(&req->src1, &bufs[n_bufs], t->src[1], DSPQBUF_TYPE_CPU_WRITE_DSP_READ);
+ n_bufs += htp_req_buff_init(&req->src2, &bufs[n_bufs], t->src[2], DSPQBUF_TYPE_CPU_WRITE_DSP_READ);
+ n_bufs += htp_req_buff_init(&req->dst, &bufs[n_bufs], t, DSPQBUF_TYPE_DSP_WRITE_CPU_READ);
+
+ return n_bufs;
+}
+
+static inline size_t init_flash_attn_ext_req(htp_general_req * req, dspqueue_buffer * bufs, const ggml_tensor * t) {
+ memcpy(&req->op_params, &t->op_params, sizeof(t->op_params));
+ req->op = HTP_OP_FLASH_ATTN_EXT;
+
+ size_t n_bufs = 0;
+ n_bufs += htp_req_buff_init(&req->src0, &bufs[n_bufs], t->src[0], DSPQBUF_TYPE_CPU_WRITE_DSP_READ);
+ n_bufs += htp_req_buff_init(&req->src1, &bufs[n_bufs], t->src[1], DSPQBUF_TYPE_CPU_WRITE_DSP_READ);
+ n_bufs += htp_req_buff_init(&req->src2, &bufs[n_bufs], t->src[2], DSPQBUF_TYPE_CPU_WRITE_DSP_READ);
+ n_bufs += htp_req_buff_init(&req->src3, &bufs[n_bufs], t->src[3], DSPQBUF_TYPE_CPU_WRITE_DSP_READ);
+ n_bufs += htp_req_buff_init(&req->src4, &bufs[n_bufs], t->src[4], DSPQBUF_TYPE_CPU_WRITE_DSP_READ);
+ n_bufs += htp_req_buff_init(&req->dst, &bufs[n_bufs], t, DSPQBUF_TYPE_DSP_WRITE_CPU_READ);
+
+ return n_bufs;
+}
+
+static const char * ggml_backend_hexagon_name(ggml_backend_t backend) {
+ auto sess = static_cast<ggml_hexagon_session *>(backend->context);
+ return sess->name.c_str();
+}
+
+static void ggml_backend_hexagon_free(ggml_backend_t backend) {
+ // we just need to delete the backend here
+ // the sessions are allocated & freed as part of the registry
+ delete backend;
+}
+
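+// src1 can be reused when the previous op already dynamically quantized the same src1
+// into VTCM; in that case the quantize step is skipped for this op.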
+static inline bool op_reuse_src1(const ggml_tensor * op1, const ggml_tensor * op0) {
+ return (op0 && op0->src[1] == op1->src[1] && ggml_is_quantized(op0->src[0]->type));
+}
+
+static inline bool is_compute_op(ggml_tensor *node)
+{
+ return !ggml_op_is_empty(node->op) && !ggml_is_empty(node) && (node->flags & GGML_TENSOR_FLAG_COMPUTE);
+}
+
+// scan the graph and figure out last compute op index
+static inline int last_compute_op(ggml_cgraph * graph) {
+ int last = 0;
+ for (int i = 0; i < graph->n_nodes; ++i) {
+ if (is_compute_op(graph->nodes[i])) {
+ last = i;
+ }
+ }
+
+ return last;
+}
+
+static ggml_status ggml_backend_hexagon_graph_compute(ggml_backend_t backend, ggml_cgraph * graph) {
+ auto sess = static_cast<ggml_hexagon_session *>(backend->context);
+
+ HEX_VERBOSE("ggml-hex: %s graph-compute n_nodes %d\n", sess->name.c_str(), graph->n_nodes);
+
+ const int last = last_compute_op(graph);
+
+ const struct ggml_tensor * prev_op = nullptr; // prev executed op
+
+ for (int i = 0; i < graph->n_nodes; ++i) {
+ ggml_tensor * node = graph->nodes[i];
+
+ if (!is_compute_op(node)) {
+ continue;
+ }
+
+ uint32_t flags = 0;
+
+ // skip quantizer if src1 is reused
+ if (op_reuse_src1(node, prev_op)) {
+ flags |= HTP_OPFLAGS_SKIP_QUANTIZE;
+ }
+
+ prev_op = node;
+
+ // ask for early notification for the last Op
+ if (i == last) {
+ flags |= HTP_OPFLAGS_EARLY_WAKEUP;
+ }
+
+ switch (node->op) {
+ case GGML_OP_MUL_MAT:
+ if (ggml_is_quantized(node->src[0]->type)) {
+ ggml_hexagon_dispatch_op<init_binary_req<true>>(sess, node, flags);
+ } else {
+ ggml_hexagon_dispatch_op<init_binary_req<false>>(sess, node, flags);
+ }
+ break;
+ case GGML_OP_MUL_MAT_ID:
+ if (ggml_is_quantized(node->src[0]->type)) {
+ ggml_hexagon_dispatch_op<init_binary_id_req<true>>(sess, node, flags);
+ } else {
+ ggml_hexagon_dispatch_op<init_binary_id_req<false>>(sess, node, flags);
+ }
+ break;
+ case GGML_OP_MUL:
+ case GGML_OP_ADD:
+ case GGML_OP_SUB:
+ case GGML_OP_DIV:
+ ggml_hexagon_dispatch_op<init_binary_req<false>>(sess, node, flags);
+ break;
+ case GGML_OP_ADD_ID:
+ ggml_hexagon_dispatch_op<init_binary_id_req<false>>(sess, node, flags);
+ break;
+ case GGML_OP_RMS_NORM:
+ case GGML_OP_SCALE:
+ ggml_hexagon_dispatch_op<init_unary_req>(sess, node, flags);
+ break;
+ case GGML_OP_SQR:
+ case GGML_OP_SQRT:
+ ggml_hexagon_dispatch_op<init_unary_req>(sess, node, flags);
+ break;
+ case GGML_OP_SUM_ROWS:
+ ggml_hexagon_dispatch_op<init_sum_rows_req>(sess, node, flags);
+ break;
+ case GGML_OP_UNARY:
+ if ((ggml_get_unary_op(node) == GGML_UNARY_OP_SILU) ||
+ (ggml_get_unary_op(node) == GGML_UNARY_OP_GELU)) {
+ ggml_hexagon_dispatch_op<init_unary_req>(sess, node, flags);
+ }
+ break;
+ case GGML_OP_GLU:
+ if ((ggml_get_glu_op(node) == GGML_GLU_OP_SWIGLU) ||
+ (ggml_get_glu_op(node) == GGML_GLU_OP_SWIGLU_OAI) ||
+ (ggml_get_glu_op(node) == GGML_GLU_OP_GEGLU)) {
+ ggml_hexagon_dispatch_op<init_unary_req>(sess, node, flags);
+ }
+ break;
+ case GGML_OP_SOFT_MAX:
+ ggml_hexagon_dispatch_op<init_unary_req>(sess, node, flags);
+ break;
+
+ case GGML_OP_ROPE:
+ ggml_hexagon_dispatch_op<init_rope_req>(sess, node, flags);
+ break;
+
+ case GGML_OP_FLASH_ATTN_EXT:
+ ggml_hexagon_dispatch_op<init_flash_attn_ext_req>(sess, node, flags);
+ break;
+
+ case GGML_OP_SET_ROWS:
+ ggml_hexagon_dispatch_op<init_set_rows_req>(sess, node, flags);
+ break;
+
+ case GGML_OP_GET_ROWS:
+ ggml_hexagon_dispatch_op<init_get_rows_req>(sess, node, flags);
+ break;
+
+ case GGML_OP_CPY:
+ ggml_hexagon_dispatch_op<init_cpy_req>(sess, node, flags);
+ break;
+
+ case GGML_OP_ARGSORT:
+ ggml_hexagon_dispatch_op<init_argsort_req>(sess, node, flags);
+ break;
+
+ default:
+ GGML_ABORT("\nggml-hex: graph-compute %s is not supported\n", ggml_op_desc(node));
+ }
+ }
+
+ // Wait until all pending ops complete
+ sess->flush();
+
+ return GGML_STATUS_SUCCESS;
+}
+
+static void ggml_backend_hexagon_synchronize(ggml_backend_t backend) {
+ auto sess = static_cast<ggml_hexagon_session *>(backend->context);
+
+ HEX_VERBOSE("ggml-hex: %s synchronize\n", sess->name.c_str());
+
+ // Wait until all pending ops complete
+ sess->flush();
+}
+
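+// A graph node plus any ops fused into it; graph reordering operates on these units
+// so that fusion groups are kept intact.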
+struct node_info {
+ ggml_tensor * node;
+
+ std::vector<ggml_tensor *> fused;
+
+ ggml_op op() const {
+ return node->op;
+ }
+
+ const ggml_tensor * dst() const {
+ return fused.empty() ? node : fused.back();
+ }
+
+ const ggml_tensor * src0() const {
+ return node->src[0];
+ }
+
+ const ggml_tensor * src1() const {
+ return node->src[1];
+ }
+
+ bool is_empty() const {
+ return ggml_op_is_empty(node->op);
+ }
+
+ void add_fused(ggml_tensor * t) {
+ fused.push_back(t);
+ }
+
+ bool stackable() const {
+ switch (this->op()) {
+ case GGML_OP_MUL_MAT:
+ case GGML_OP_MUL_MAT_ID:
+ return ggml_is_quantized(this->src0()->type);
+ default:
+ return false;
+ }
+ }
+
+ bool same_input(const node_info& n) const {
+ return n.src1() == this->src1();
+ }
+};
+
+static std::vector<int> ggml_hexagon_graph_optimize_reorder(const std::vector<node_info> & nodes) {
+ const int n = nodes.size();
+
+ std::vector<int> res;
+ res.reserve(n);
+
+ std::vector<bool> used(n, false);
+
+ // The main goal here is to stack the MUL_MAT ops with the same src1 input.
+ // This allows us to reuse the dynamically quantized src1 in VTCM.
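+ // Example: consecutive projections that share the same activation (e.g. Q/K/V matmuls
+ // over the same input) are emitted back-to-back so the quantized activation stays
+ // resident in VTCM.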
+
+ // TODO: the current version might do incorrect reordering in cases where quantized src0
+ // input is an output of another Op.
+
+ for (int i0 = 0; i0 < n; i0++) {
+ if (used[i0]) {
+ continue;
+ }
+
+ res.push_back(i0);
+
+ const auto & node0 = nodes[i0];
+
+ if (!node0.stackable()) {
+ continue;
+ }
+
+ // how many nodes to look ahead for stackable nodes that can reuse VTCM
+ constexpr int N_FORWARD = 16;
+
+ for (int i1 = i0 + 1; i1 < i0 + N_FORWARD && i1 < n; i1++) {
+ if (used[i1]) {
+ continue;
+ }
+
+ const auto & node1 = nodes[i1];
+
+ if (node1.stackable() && node1.same_input(node0)) {
+ res.push_back(i1);
+ used[i1] = true;
+ }
+ }
+ }
+
+ return res;
+}
+
+static void ggml_backend_hexagon_graph_optimize(ggml_backend_t backend, ggml_cgraph * gf) {
+ const int n = gf->n_nodes;
+
+ constexpr int MAX_FUSE = 16;
+
+ enum ggml_op ops[MAX_FUSE];
+
+ std::vector<node_info> nodes;
+ nodes.reserve(gf->n_nodes);
+
+ // fuse nodes:
+ // we don't want to make reorders that break fusing, so we first pack all fusable tensors
+ // and perform the reorder over the fused nodes. after the reorder is done, we unfuse
+ for (int i = 0; i < n; i++) {
+ node_info node = {
+ /*.node =*/gf->nodes[i],
+ /*.fused =*/{},
+ };
+
+ // fuse only ops that start with these operations
+ // can be expanded when needed
+ if (node.op() == GGML_OP_ADD ||
+ node.op() == GGML_OP_NORM ||
+ node.op() == GGML_OP_RMS_NORM) {
+ ops[0] = node.op();
+
+ int f = i + 1;
+ while (f < n && f < i + MAX_FUSE) {
+ // conservatively allow fusing only these ops
+ // can be expanded when needed
+ if (gf->nodes[f]->op != GGML_OP_ADD &&
+ gf->nodes[f]->op != GGML_OP_MUL &&
+ gf->nodes[f]->op != GGML_OP_NORM &&
+ gf->nodes[f]->op != GGML_OP_RMS_NORM) {
+ break;
+ }
+ ops[f - i] = gf->nodes[f]->op;
+ f++;
+ }
+
+ f -= i;
+ for (; f > 1; f--) {
+ if (ggml_can_fuse(gf, i, ops, f)) {
+ break;
+ }
+ }
+
+ // add the fused tensors into the node info so we can unfuse them later
+ for (int k = 1; k < f; k++) {
+ ++i;
+
+ // the .dst() becomes the last fused tensor
+ node.add_fused(gf->nodes[i]);
+ }
+ }
+
+ nodes.push_back(std::move(node));
+ }
+
+ const auto order = ggml_hexagon_graph_optimize_reorder(nodes);
+
+ // unfuse
+ {
+ int j = 0;
+ for (const auto i : order) {
+ const auto & node = nodes[i];
+
+ gf->nodes[j++] = node.node;
+
+ for (auto * fused : node.fused) {
+ gf->nodes[j++] = fused;
+ }
+ }
+ }
+}
+
+static struct ggml_backend_i hexagon_backend_i = {
+ /* .get_name = */ ggml_backend_hexagon_name,
+ /* .free = */ ggml_backend_hexagon_free,
+ /* .set_tensor_async = */ NULL,
+ /* .get_tensor_async = */ NULL,
+ /* .cpy_tensor_async = */ NULL,
+ /* .synchronize = */ ggml_backend_hexagon_synchronize,
+ /* .graph_plan_create = */ NULL,
+ /* .graph_plan_free = */ NULL,
+ /* .graph_plan_update = */ NULL,
+ /* .graph_plan_compute = */ NULL,
+ /* .graph_compute = */ ggml_backend_hexagon_graph_compute,
+ /* .event_record = */ NULL,
+ /* .event_wait = */ NULL,
+ /* .graph_optimize = */ ggml_backend_hexagon_graph_optimize,
+};
+
+static ggml_guid_t ggml_backend_hexagon_guid() {
+ static ggml_guid guid = { 0x7b, 0x57, 0xdc, 0xaf, 0xde, 0x12, 0x1d, 0x49,
+ 0x11, 0x11, 0x11, 0x11, 0x11, 0x11, 0x11, 0x11 };
+ return &guid;
+}
+
+bool ggml_backend_is_hexagon(ggml_backend_t backend) {
+ return backend && backend->iface.get_name == ggml_backend_hexagon_name;
+}
+
+// device interface
+
+static ggml_backend_t ggml_backend_hexagon_device_init(ggml_backend_dev_t dev, const char * params) {
+ auto sess = static_cast<ggml_hexagon_session *>(dev->context);
+
+ return new ggml_backend{
+ /* .guid = */ ggml_backend_hexagon_guid(),
+ /* .interface = */ hexagon_backend_i,
+ /* .device = */ dev,
+ /* .context = */ sess,
+ };
+
+ GGML_UNUSED(params);
+}
+
+static const char * ggml_backend_hexagon_device_get_name(ggml_backend_dev_t dev) {
+ auto sess = static_cast<ggml_hexagon_session *>(dev->context);
+ return sess->name.c_str();
+
+ GGML_UNUSED(dev);
+}
+
+static const char * ggml_backend_hexagon_device_get_description(ggml_backend_dev_t dev) {
+ return "Hexagon";
+ GGML_UNUSED(dev);
+}
+
+static void ggml_backend_hexagon_device_get_memory(ggml_backend_dev_t dev, size_t * free, size_t * total) {
+ // ~2GB per session for now
+ *free = 2ULL * 1024 * 1024 * 1024;
+ *total = *free;
+
+ GGML_UNUSED(dev);
+}
+
+static enum ggml_backend_dev_type ggml_backend_hexagon_device_get_type(ggml_backend_dev_t dev) {
+ return GGML_BACKEND_DEVICE_TYPE_GPU;
+
+ GGML_UNUSED(dev);
+}
+
+static void ggml_backend_hexagon_device_get_props(ggml_backend_dev_t dev, struct ggml_backend_dev_props * props) {
+ props->name = ggml_backend_hexagon_device_get_name(dev);
+ props->description = ggml_backend_hexagon_device_get_description(dev);
+ props->type = ggml_backend_hexagon_device_get_type(dev);
+ ggml_backend_hexagon_device_get_memory(dev, &props->memory_free, &props->memory_total);
+ props->caps = {
+ /* .async = */ true,
+ /* .host_buffer = */ (bool) opt_hostbuf,
+ /* .buffer_from_host_ptr = */ false,
+ /* .events = */ false,
+ };
+}
+
+static ggml_backend_buffer_type_t ggml_backend_hexagon_device_get_buffer_type(ggml_backend_dev_t dev) {
+ auto sess = static_cast<ggml_hexagon_session *>(dev->context);
+ return &sess->buffer_type;
+}
+
+static ggml_backend_buffer_type_t ggml_backend_hexagon_device_get_repack_buffer_type(ggml_backend_dev_t dev) {
+ auto sess = static_cast<ggml_hexagon_session *>(dev->context);
+ return &sess->repack_buffer_type;
+}
+
+static bool ggml_hexagon_supported_buffer(ggml_hexagon_session *sess, const struct ggml_tensor * t) {
+ if (t && t->buffer) {
+ if (ggml_backend_buffer_is_hexagon(t->buffer) == false) return false; // not our buffer
+ if (ggml_backend_hexagon_buffer_get_sess(t->buffer) != sess) return false; // wrong session
+ }
+ return true;
+}
+
+static bool ggml_hexagon_supported_buffers(ggml_hexagon_session *sess, const struct ggml_tensor * t) {
+ // all srcs & dsts must be mapped to the same session
+ if (!ggml_hexagon_supported_buffer(sess, t)) {
+ return false;
+ }
+
+ for (int i = 0; i < GGML_MAX_SRC; i++) {
+ if (!ggml_hexagon_supported_buffer(sess, t->src[i])) {
+ return false;
+ }
+ }
+
+ return true;
+}
+
+static bool ggml_hexagon_supported_cpy(const struct ggml_hexagon_session * sess, const struct ggml_tensor * op) {
+ const struct ggml_tensor * src0 = op->src[0];
+ const struct ggml_tensor * dst = op;
+
+ // for now we can do f32 -> f16 and f16 -> f32 (without reshaping)
+ if (src0->type != GGML_TYPE_F32 && src0->type != GGML_TYPE_F16) return false;
+ if ( dst->type != GGML_TYPE_F32 && dst->type != GGML_TYPE_F16) return false;
+
+ const bool sametype = (src0->type == dst->type);
+ const bool transposed = ggml_is_transposed(src0) || ggml_is_transposed(dst);
+ const bool sameshape = !transposed && ggml_are_same_shape(src0, dst);
+
+ // same-type copies are supported for any shape (slow when reshaping is required)
+ if (sametype) return true;
+
+ // cannot handle re-shaping and type conversion at the same time
+ if (!sameshape) return false;
+
+ return true;
+}
+
+static bool ggml_backend_hexagon_device_supports_op(ggml_backend_dev_t dev, const struct ggml_tensor * op) {
+ auto sess = static_cast<ggml_hexagon_session *>(dev->context);
+
+ // all srcs & dsts must be mapped to the same session
+ if (!ggml_hexagon_supported_buffers(sess, op)) {
+ ggml_hexagon_dump_op_supp(sess->name, op, false);
+ return false;
+ }
+
+ bool supp = false;
+ switch (op->op) {
+ case GGML_OP_NONE:
+ case GGML_OP_RESHAPE:
+ case GGML_OP_VIEW:
+ case GGML_OP_PERMUTE:
+ case GGML_OP_TRANSPOSE:
+ supp = true;
+ break;
+
+ case GGML_OP_MUL_MAT:
+ supp = ggml_hexagon_supported_mul_mat(sess, op);
+ break;
+
+ case GGML_OP_MUL_MAT_ID:
+ supp = ggml_hexagon_supported_mul_mat_id(sess, op);
+ break;
+
+ case GGML_OP_MUL:
+ case GGML_OP_ADD:
+ case GGML_OP_SUB:
+ case GGML_OP_DIV:
+ supp = ggml_hexagon_supported_binary(sess, op);
+ break;
+
+ case GGML_OP_ADD_ID:
+ supp = ggml_hexagon_supported_add_id(sess, op);
+ break;
+
+ case GGML_OP_RMS_NORM:
+ case GGML_OP_SCALE:
+ supp = ggml_hexagon_supported_unary(sess, op);
+ break;
+
+ case GGML_OP_SQR:
+ case GGML_OP_SQRT:
+ supp = ggml_hexagon_supported_unary(sess, op);
+ break;
+
+ case GGML_OP_SUM_ROWS:
+ supp = ggml_hexagon_supported_sum_rows(sess, op);
+ break;
+
+ case GGML_OP_SOFT_MAX:
+ supp = ggml_hexagon_supported_softmax(sess, op);
+ break;
+
+ case GGML_OP_UNARY:
+ {
+ const auto unary_op = ggml_get_unary_op(op);
+ if (unary_op == GGML_UNARY_OP_SILU || unary_op == GGML_UNARY_OP_GELU) {
+ supp = ggml_hexagon_supported_activations(sess, op);
+ }
+ break;
+ }
+ case GGML_OP_GLU:
+ {
+ const auto glu_op = ggml_get_glu_op(op);
+ if ((glu_op == GGML_GLU_OP_SWIGLU) || (glu_op == GGML_GLU_OP_SWIGLU_OAI) || (glu_op == GGML_GLU_OP_GEGLU)) {
+ supp = ggml_hexagon_supported_activations(sess, op);
+ }
+ break;
+ }
+ case GGML_OP_ROPE:
+ supp = ggml_hexagon_supported_rope(sess, op);
+ break;
+
+ case GGML_OP_FLASH_ATTN_EXT:
+ supp = ggml_hexagon_supported_flash_attn_ext(sess, op);
+ break;
+
+ case GGML_OP_SET_ROWS:
+ supp = ggml_hexagon_supported_set_rows(sess, op);
+ break;
+
+ case GGML_OP_GET_ROWS:
+ supp = ggml_hexagon_supported_get_rows(sess, op);
+ break;
+
+ case GGML_OP_CPY:
+ supp = ggml_hexagon_supported_cpy(sess, op);
+ break;
+
+ case GGML_OP_ARGSORT:
+ supp = ggml_hexagon_supported_argsort(sess, op);
+ break;
+
+ default:
+ break;
+ }
+
+ ggml_hexagon_dump_op_supp(sess->name, op, supp);
+ return supp;
+}
+
+static bool ggml_backend_hexagon_device_supports_buft(ggml_backend_dev_t dev, ggml_backend_buffer_type_t buft) {
+ if (buft->iface.get_alignment != ggml_backend_hexagon_buffer_type_get_alignment) {
+ return false;
+ }
+
+ auto s0 = static_cast<ggml_hexagon_session *>(dev->context);
+ auto s1 = static_cast<ggml_backend_hexagon_buffer_type_context *>(buft->context)->sess;
+
+ // Need session/domain-id for buffers to be compatible
+ bool supp = (s0->session_id == s1->session_id);
+
+ HEX_VERBOSE("ggml-hex: %s device-supports-buft %s (%d)\n", s0->name.c_str(), s1->name.c_str(), (int) supp);
+
+ return supp;
+}
+
+static ggml_backend_buffer_type_t * ggml_backend_hexagon_device_get_extra_buffers_type(ggml_backend_dev_t dev) {
+ auto s0 = static_cast<ggml_hexagon_session *>(dev->context);
+ HEX_VERBOSE("ggml-hex: device-get-extra-buft : %s \n", s0->name.c_str());
+
+ static ggml_backend_buffer_type_t bufts[2];
+ bufts[0] = ggml_backend_hexagon_device_get_repack_buffer_type(dev);
+ bufts[1] = NULL;
+ return bufts;
+}
+
+static const struct ggml_backend_device_i ggml_backend_hexagon_device_i = {
+ /* .get_name = */ ggml_backend_hexagon_device_get_name,
+ /* .get_description = */ ggml_backend_hexagon_device_get_description,
+ /* .get_memory = */ ggml_backend_hexagon_device_get_memory,
+ /* .get_type = */ ggml_backend_hexagon_device_get_type,
+ /* .get_props = */ ggml_backend_hexagon_device_get_props,
+ /* .init_backend = */ ggml_backend_hexagon_device_init,
+ /* .get_buffer_type = */ ggml_backend_hexagon_device_get_buffer_type,
+ /* .get_host_buffer_type = */ NULL, // ggml_backend_hexagon_device_get_host_buffer_type,
+ /* .buffer_from_host_ptr = */ NULL, // ggml_backend_hexagon_device_buffer_from_ptr,
+ /* .supports_op = */ ggml_backend_hexagon_device_supports_op,
+ /* .supports_buft = */ ggml_backend_hexagon_device_supports_buft,
+ /* .offload_op = */ NULL, // ggml_backend_hexagon_device_offload_op,
+ /* .event_new = */ NULL,
+ /* .event_free = */ NULL,
+ /* .event_synchronize = */ NULL,
+};
+
+//** backend registry
+
+#define GGML_HEXAGON_MAX_SESSIONS 16
+
+struct ggml_hexagon_registry {
+ ggml_hexagon_registry(ggml_backend_reg_t reg);
+ ~ggml_hexagon_registry();
+
+ ggml_backend_device devices[GGML_HEXAGON_MAX_SESSIONS];
+};
+
+ggml_hexagon_registry::ggml_hexagon_registry(ggml_backend_reg_t reg) {
+ GGML_LOG_INFO("ggml-hex: Hexagon backend (experimental) : allocating new registry : ndev %zu\n", opt_ndev);
+
+ if (!opt_arch) {
+ int err = get_hex_arch_ver(CDSP_DOMAIN_ID, &opt_arch);
+ if (err != 0) {
+ GGML_LOG_ERROR("ggml-hex: failed to query HTP version (err %d) defaulting to v73\n", err);
+ opt_arch = 73;
+ }
+ }
+
+#if defined(__ANDROID__)
+ if (opt_arch < 75) {
+ opt_ndev = 1;
+ GGML_LOG_WARN("ggml-hex: forcing ndev to 1 for SoCs archs lower than v75.\n");
+ }
+#endif
+
+ GGML_LOG_INFO("ggml-hex: Hexagon Arch version v%d\n", opt_arch);
+
+ // Create devices / sessions
+ for (size_t i = 0; i < opt_ndev; i++) {
+ devices[i].iface = ggml_backend_hexagon_device_i;
+ devices[i].reg = reg;
+ try {
+ devices[i].context = new ggml_hexagon_session(i, &devices[i]);
+ } catch (const std::exception & exc) {
+ GGML_LOG_ERROR("ggml-hex: failed to create device/session %zu\n", i);
+ devices[i].context = nullptr;
+ }
+ }
+}
+
+ggml_hexagon_registry::~ggml_hexagon_registry() {
+ GGML_LOG_INFO("ggml-hex: releasing registry\n");
+
+ // Release devices / sessions
+ for (size_t i = 0; i < opt_ndev; i++) {
+ auto sess = static_cast<ggml_hexagon_session *>(devices[i].context);
+ delete sess;
+ }
+}
+
+static const char * ggml_backend_hexagon_reg_get_name(ggml_backend_reg_t reg) {
+ return "HTP";
+ GGML_UNUSED(reg);
+}
+
+static size_t ggml_backend_hexagon_reg_get_device_count(ggml_backend_reg_t reg) {
+ return opt_ndev;
+ GGML_UNUSED(reg);
+}
+
+static ggml_backend_dev_t ggml_backend_hexagon_reg_get_device(ggml_backend_reg_t reg, size_t index) {
+ auto hreg = static_cast<ggml_hexagon_registry *>(reg->context);
+
+ if (index >= opt_ndev || !hreg->devices[index].context) {
+ return nullptr;
+ }
+
+ return &hreg->devices[index];
+}
+
+static void * ggml_backend_hexagon_get_proc_address(ggml_backend_reg_t reg, const char * name) {
+ if (strcmp(name, "ggml_backend_dev_get_extra_bufts") == 0 && opt_hostbuf) {
+ ggml_backend_dev_get_extra_bufts_t fct = ggml_backend_hexagon_device_get_extra_buffers_type;
+ return (void *) fct;
+ }
+
+ return NULL;
+}
+
+static void ggml_hexagon_init(ggml_backend_reg * reg) {
+ // Basic sanity checks to make sure definitions match
+ static_assert((unsigned int) HTP_TYPE_Q4_0 == (unsigned int) GGML_TYPE_Q4_0,
+ "please update hexagon_type to match ggml_type");
+ static_assert((unsigned int) HTP_TYPE_Q8_0 == (unsigned int) GGML_TYPE_Q8_0,
+ "please update hexagon_type to match ggml_type");
+ static_assert((unsigned int) HTP_TYPE_MXFP4 == (unsigned int) GGML_TYPE_MXFP4,
+ "please update hexagon_type to match ggml_type");
+
+ const char * str_experimental = getenv("GGML_HEXAGON_EXPERIMENTAL");
+ const char * str_verbose = getenv("GGML_HEXAGON_VERBOSE");
+ const char * str_hostbuf = getenv("GGML_HEXAGON_HOSTBUF");
+ const char * str_opmask = getenv("GGML_HEXAGON_OPMASK");
+ const char * str_opsync = getenv("GGML_HEXAGON_OPSYNC");
+ const char * str_profile = getenv("GGML_HEXAGON_PROFILE");
+ const char * str_etm = getenv("GGML_HEXAGON_ETM");
+ const char * str_nhvx = getenv("GGML_HEXAGON_NHVX");
+ const char * str_ndev = getenv("GGML_HEXAGON_NDEV");
+ const char * str_arch = getenv("GGML_HEXAGON_ARCH");
+
+ opt_experimental = str_experimental ? atoi(str_experimental) : 0;
+ opt_verbose = str_verbose ? atoi(str_verbose) : 0;
+ opt_hostbuf = str_hostbuf ? atoi(str_hostbuf) : opt_hostbuf;
+ opt_opmask = str_opmask ? strtoul(str_opmask, NULL, 0) : opt_opmask;
+ opt_opsync = str_opsync ? atoi(str_opsync) : 0;
+ opt_profile = str_profile ? atoi(str_profile) : 0;
+ opt_etm = str_etm ? atoi(str_etm) : 0;
+ opt_nhvx = str_nhvx ? strtoul(str_nhvx, NULL, 0) : opt_nhvx;
+ opt_ndev = str_ndev ? strtoul(str_ndev, NULL, 0) : opt_ndev;
+
+ if (opt_ndev > GGML_HEXAGON_MAX_SESSIONS) {
+ opt_ndev = GGML_HEXAGON_MAX_SESSIONS;
+ }
+
+ if (str_arch) {
+ if (str_arch[0] == 'v') {
+ str_arch++;
+ }
+ opt_arch = strtoul(str_arch, NULL, 0);
+ }
+
+ opt_hostbuf = str_hostbuf ? atoi(str_hostbuf) : 1;
+
+ reg->context = new ggml_hexagon_registry(reg);
+
+ HEX_VERBOSE("ggml-hex: size-of-general-req %zu size-of-general-rsp %zu\n", sizeof(struct htp_general_req),
+ sizeof(struct htp_general_rsp));
+}
+
+static const struct ggml_backend_reg_i ggml_backend_hexagon_reg_i = {
+ /* .get_name = */ ggml_backend_hexagon_reg_get_name,
+ /* .get_device_count = */ ggml_backend_hexagon_reg_get_device_count,
+ /* .get_device = */ ggml_backend_hexagon_reg_get_device,
+ /* .get_proc_address = */ ggml_backend_hexagon_get_proc_address,
+};
+
+ggml_backend_reg_t ggml_backend_hexagon_reg(void) {
+ static bool initialized = false;
+
+ static ggml_backend_reg reg = { /* .api_version = */ GGML_BACKEND_API_VERSION,
+ /* .iface = */ ggml_backend_hexagon_reg_i,
+ /* .context = */ NULL };
+
+ {
+ static std::mutex mutex;
+ std::lock_guard<std::mutex> lock(mutex);
+ if (!initialized) {
+ auto nErr = htpdrv_init();
+ if (nErr != AEE_SUCCESS) {
+ return NULL;
+ }
+
+ ggml_hexagon_init(&reg);
+ }
+
+ initialized = true;
+ }
+
+ return &reg;
+}
+
+GGML_BACKEND_DL_IMPL(ggml_backend_hexagon_reg)
diff --git a/llama.cpp/ggml/src/ggml-hexagon/htp-drv.cpp b/llama.cpp/ggml/src/ggml-hexagon/htp-drv.cpp
new file mode 100644
index 0000000..2530bb0
--- /dev/null
+++ b/llama.cpp/ggml/src/ggml-hexagon/htp-drv.cpp
@@ -0,0 +1,418 @@
+// Driver shim: resolves the FastRPC / rpcmem / dspqueue APIs at runtime
+
+#pragma clang diagnostic ignored "-Wgnu-anonymous-struct"
+#pragma clang diagnostic ignored "-Wmissing-prototypes"
+#pragma clang diagnostic ignored "-Wsign-compare"
+
+#include <filesystem>
+#include <set>
+#include <sstream>
+#include <string>
+#ifdef _WIN32
+# define WIN32_LEAN_AND_MEAN
+# ifndef NOMINMAX
+# define NOMINMAX
+# endif
+# include <windows.h>
+# include <winevt.h>
+#else
+# include <dlfcn.h>
+# include <unistd.h>
+#endif
+#include "ggml-impl.h"
+#include "htp-drv.h"
+#include "libdl.h"
+
+#include <domain.h>
+
+//
+// Driver API types
+//
+
+typedef void * (*rpcmem_alloc_pfn_t)(int heapid, uint32_t flags, int size);
+typedef void * (*rpcmem_alloc2_pfn_t)(int heapid, uint32_t flags, size_t size);
+typedef void (*rpcmem_free_pfn_t)(void * po);
+typedef int (*rpcmem_to_fd_pfn_t)(void * po);
+
+typedef AEEResult (*dspqueue_create_pfn_t)(int domain,
+ uint32_t flags,
+ uint32_t req_queue_size,
+ uint32_t resp_queue_size,
+ dspqueue_callback_t packet_callback,
+ dspqueue_callback_t error_callback,
+ void * callback_context,
+ dspqueue_t * queue);
+typedef AEEResult (*dspqueue_close_pfn_t)(dspqueue_t queue);
+typedef AEEResult (*dspqueue_export_pfn_t)(dspqueue_t queue, uint64_t *queue_id);
+typedef AEEResult (*dspqueue_write_pfn_t)(dspqueue_t queue, uint32_t flags,
+ uint32_t num_buffers,
+ struct dspqueue_buffer *buffers,
+ uint32_t message_length,
+ const uint8_t *message,
+ uint32_t timeout_us);
+typedef AEEResult (*dspqueue_read_pfn_t)(dspqueue_t queue, uint32_t *flags,
+ uint32_t max_buffers, uint32_t *num_buffers,
+ struct dspqueue_buffer *buffers,
+ uint32_t max_message_length,
+ uint32_t *message_length, uint8_t *message,
+ uint32_t timeout_us);
+
+typedef int (*fastrpc_mmap_pfn_t)(int domain, int fd, void *addr, int offset, size_t length, enum fastrpc_map_flags flags);
+typedef int (*fastrpc_munmap_pfn_t)(int domain, int fd, void *addr, size_t length);
+
+typedef int (*remote_handle64_open_pfn_t)(const char* name, remote_handle64 *ph);
+typedef int (*remote_handle64_invoke_pfn_t)(remote_handle64 h, uint32_t dwScalars, remote_arg *pra);
+typedef int (*remote_handle64_close_pfn_t)(remote_handle64 h);
+typedef int (*remote_handle_control_pfn_t)(uint32_t req, void* data, uint32_t datalen);
+typedef int (*remote_handle64_control_pfn_t)(remote_handle64 h, uint32_t req, void* data, uint32_t datalen);
+typedef int (*remote_session_control_pfn_t)(uint32_t req, void *data, uint32_t datalen);
+
+//
+// Driver API pfns
+//
+
+rpcmem_alloc_pfn_t rpcmem_alloc_pfn = nullptr;
+rpcmem_alloc2_pfn_t rpcmem_alloc2_pfn = nullptr;
+rpcmem_free_pfn_t rpcmem_free_pfn = nullptr;
+rpcmem_to_fd_pfn_t rpcmem_to_fd_pfn = nullptr;
+
+fastrpc_mmap_pfn_t fastrpc_mmap_pfn = nullptr;
+fastrpc_munmap_pfn_t fastrpc_munmap_pfn = nullptr;
+
+dspqueue_create_pfn_t dspqueue_create_pfn = nullptr;
+dspqueue_close_pfn_t dspqueue_close_pfn = nullptr;
+dspqueue_export_pfn_t dspqueue_export_pfn = nullptr;
+dspqueue_write_pfn_t dspqueue_write_pfn = nullptr;
+dspqueue_read_pfn_t dspqueue_read_pfn = nullptr;
+
+remote_handle64_open_pfn_t remote_handle64_open_pfn = nullptr;
+remote_handle64_invoke_pfn_t remote_handle64_invoke_pfn = nullptr;
+remote_handle64_close_pfn_t remote_handle64_close_pfn = nullptr;
+remote_handle_control_pfn_t remote_handle_control_pfn = nullptr;
+remote_handle64_control_pfn_t remote_handle64_control_pfn = nullptr;
+remote_session_control_pfn_t remote_session_control_pfn = nullptr;
+
+//
+// Driver API
+//
+
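+// Thin forwarding wrappers around the dynamically resolved driver entry points.
+// htpdrv_init() must have populated the *_pfn pointers before any of these are
+// called; only rpcmem_alloc2 is treated as optional and has a fallback.
+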
+void * rpcmem_alloc(int heapid, uint32_t flags, int size) {
+ return rpcmem_alloc_pfn(heapid, flags, size);
+}
+
+void * rpcmem_alloc2(int heapid, uint32_t flags, size_t size) {
+ if (rpcmem_alloc2_pfn) {
+ return rpcmem_alloc2_pfn(heapid, flags, size);
+ } else {
+ GGML_LOG_INFO("ggml-hex: rpcmem_alloc2 not found, falling back to rpcmem_alloc\n");
+ return rpcmem_alloc_pfn(heapid, flags, size);
+ }
+}
+
+void rpcmem_free(void * po) {
+ return rpcmem_free_pfn(po);
+}
+
+int rpcmem_to_fd(void * po) {
+ return rpcmem_to_fd_pfn(po);
+}
+
+HTPDRV_API int fastrpc_mmap(int domain, int fd, void * addr, int offset, size_t length, enum fastrpc_map_flags flags) {
+ return fastrpc_mmap_pfn(domain, fd, addr, offset, length, flags);
+}
+
+HTPDRV_API int fastrpc_munmap(int domain, int fd, void * addr, size_t length) {
+ return fastrpc_munmap_pfn(domain, fd, addr, length);
+}
+
+AEEResult dspqueue_create(int domain,
+ uint32_t flags,
+ uint32_t req_queue_size,
+ uint32_t resp_queue_size,
+ dspqueue_callback_t packet_callback,
+ dspqueue_callback_t error_callback,
+ void * callback_context,
+ dspqueue_t * queue) {
+ return dspqueue_create_pfn(domain, flags, req_queue_size, resp_queue_size, packet_callback, error_callback,
+ callback_context, queue);
+}
+
+AEEResult dspqueue_close(dspqueue_t queue) {
+ return dspqueue_close_pfn(queue);
+}
+
+AEEResult dspqueue_export(dspqueue_t queue, uint64_t * queue_id) {
+ return dspqueue_export_pfn(queue, queue_id);
+}
+
+AEEResult dspqueue_write(dspqueue_t queue,
+ uint32_t flags,
+ uint32_t num_buffers,
+ struct dspqueue_buffer * buffers,
+ uint32_t message_length,
+ const uint8_t * message,
+ uint32_t timeout_us) {
+ return dspqueue_write_pfn(queue, flags, num_buffers, buffers, message_length, message, timeout_us);
+}
+
+AEEResult dspqueue_read(dspqueue_t queue,
+ uint32_t * flags,
+ uint32_t max_buffers,
+ uint32_t * num_buffers,
+ struct dspqueue_buffer * buffers,
+ uint32_t max_message_length,
+ uint32_t * message_length,
+ uint8_t * message,
+ uint32_t timeout_us) {
+ return dspqueue_read_pfn(queue, flags, max_buffers, num_buffers, buffers, max_message_length, message_length,
+ message, timeout_us);
+}
+
+HTPDRV_API int remote_handle64_open(const char * name, remote_handle64 * ph) {
+ return remote_handle64_open_pfn(name, ph);
+}
+
+HTPDRV_API int remote_handle64_invoke(remote_handle64 h, uint32_t dwScalars, remote_arg * pra) {
+ return remote_handle64_invoke_pfn(h, dwScalars, pra);
+}
+
+HTPDRV_API int remote_handle64_close(remote_handle64 h) {
+ return remote_handle64_close_pfn(h);
+}
+
+HTPDRV_API int remote_handle_control(uint32_t req, void * data, uint32_t datalen) {
+ return remote_handle_control_pfn(req, data, datalen);
+}
+
+HTPDRV_API int remote_handle64_control(remote_handle64 h, uint32_t req, void * data, uint32_t datalen) {
+ return remote_handle64_control_pfn(h, req, data, datalen);
+}
+
+HTPDRV_API int remote_session_control(uint32_t req, void * data, uint32_t datalen) {
+ return remote_session_control_pfn(req, data, datalen);
+}
+
+#ifdef _WIN32
+
+static std::string wstr_to_str(std::wstring_view wstr) {
+ std::string result;
+ if (wstr.empty()) {
+ return result;
+ }
+ auto bytes_needed = WideCharToMultiByte(CP_UTF8, WC_ERR_INVALID_CHARS,
+ wstr.data(), (int) wstr.size(),
+ nullptr, 0, nullptr, nullptr);
+ if (bytes_needed == 0) {
+ GGML_LOG_ERROR("ggml-hex: WideCharToMultiByte failed. Error %lu\n", GetLastError());
+ throw std::runtime_error("Invalid wstring input");
+ }
+
+ result.resize(bytes_needed, '\0');
+ int bytes_written = WideCharToMultiByte(CP_UTF8, WC_ERR_INVALID_CHARS,
+ wstr.data(), (int) wstr.size(),
+ result.data(), bytes_needed,
+ nullptr, nullptr);
+ if (bytes_written == 0) {
+ GGML_LOG_ERROR("ggml-hex: WideCharToMultiByte failed. Error %lu\n", GetLastError());
+ throw std::runtime_error("Wstring conversion failed");
+ }
+ return result;
+}
+
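+// On Windows the FastRPC user-mode library (libcdsprpc.dll) ships with the qcnspmcdm
+// driver package, so its directory is discovered by querying the service's binary
+// path from the Service Control Manager.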
+static std::string get_driver_path() {
+ std::wstring serviceName = L"qcnspmcdm";
+ std::string result;
+
+ // Get a handle to the SCM database.
+ SC_HANDLE schSCManager = OpenSCManagerW(NULL, NULL, STANDARD_RIGHTS_READ);
+ if (nullptr == schSCManager) {
+ GGML_LOG_ERROR("ggml-hex: Failed to open SCManager. Error: %lu\n", GetLastError());
+ return result;
+ }
+
+ // Get a handle to the service.
+ SC_HANDLE schService = OpenServiceW(schSCManager, // SCM database
+ serviceName.c_str(), // name of service
+ SERVICE_QUERY_CONFIG); // need query config access
+
+ if (nullptr == schService) {
+ GGML_LOG_ERROR("ggml-hex: Failed to open qcnspmcdm service. Error: %lu\n", GetLastError());
+ CloseServiceHandle(schSCManager);
+ return result;
+ }
+
+ // Store the size of buffer used as an output.
+ DWORD bufferSize;
+ if (!QueryServiceConfigW(schService, NULL, 0, &bufferSize) &&
+ (GetLastError() != ERROR_INSUFFICIENT_BUFFER)) {
+ GGML_LOG_ERROR("ggml-hex: Failed to query service config. Error: %lu\n", GetLastError());
+ CloseServiceHandle(schService);
+ CloseServiceHandle(schSCManager);
+ return result;
+ }
+ // Get the configuration of the service.
+ LPQUERY_SERVICE_CONFIGW serviceConfig =
+ static_cast<LPQUERY_SERVICE_CONFIGW>(LocalAlloc(LMEM_FIXED, bufferSize));
+ if (!QueryServiceConfigW(schService, serviceConfig, bufferSize, &bufferSize)) {
+        GGML_LOG_ERROR("ggml-hex: Failed to query service config. Error: %lu\n", GetLastError());
+ LocalFree(serviceConfig);
+ CloseServiceHandle(schService);
+ CloseServiceHandle(schSCManager);
+ return result;
+ }
+
+    // Read the driver file path and take its parent directory
+ std::wstring driverPath = std::wstring(serviceConfig->lpBinaryPathName);
+ driverPath = driverPath.substr(0, driverPath.find_last_of(L"\\"));
+
+ // Clean up resources
+ LocalFree(serviceConfig);
+ CloseServiceHandle(schService);
+ CloseServiceHandle(schSCManager);
+
+    // The driver path may contain a non-filesystem prefix, e.g.:
+    // \SystemRoot\System32\DriverStore\FileRepository\qcadsprpc8280.inf_arm64_c2b9460c9a072f37
+    // "\SystemRoot" must be replaced with the actual Windows directory (e.g. C:\Windows)
+ const std::wstring systemRootPlaceholder = L"\\SystemRoot";
+ if (0 != driverPath.compare(0, systemRootPlaceholder.length(), systemRootPlaceholder)) {
+ GGML_LOG_ERROR("ggml-hex: String pattern not found in driver path.\n");
+ return result;
+ }
+
+ // Replace \SystemRoot with an absolute path from system ENV windir
+ const std::wstring systemRootEnv = L"windir";
+
+    // Query the number of wide characters this variable requires
+ DWORD numWords = GetEnvironmentVariableW(systemRootEnv.c_str(), NULL, 0);
+ if (numWords == 0) {
+ GGML_LOG_ERROR("ggml-hex: Failed get systemRoot environment variable\n");
+ return result;
+ }
+
+ // Query the actual system root name from environment variable
+ std::vector<wchar_t> systemRoot(numWords + 1);
+ numWords = GetEnvironmentVariableW(systemRootEnv.c_str(), systemRoot.data(), numWords + 1);
+ if (numWords == 0) {
+ GGML_LOG_ERROR("ggml-hex: Failed to read windir environment variable\n");
+ return result;
+ }
+ driverPath.replace(0, systemRootPlaceholder.length(), std::wstring(systemRoot.data()));
+
+ return wstr_to_str(driverPath);
+}
+
+#endif
+
+using dl_handle_ptr = std::unique_ptr<dl_handle, dl_handle_deleter>;
+
+int htpdrv_init() {
+ static dl_handle_ptr lib_cdsp_rpc_handle = nullptr;
+ static bool initialized = false;
+#ifdef _WIN32
+ std::string drv_path = get_driver_path() + "\\" + "libcdsprpc.dll";
+#else
+ std::string drv_path = "libcdsprpc.so";
+#endif
+ if (initialized) {
+ GGML_LOG_INFO("ggml-hex: Driver already loaded\n");
+ return AEE_SUCCESS;
+ }
+ GGML_LOG_INFO("ggml-hex: Loading driver %s\n", drv_path.c_str());
+
+ fs::path path{ drv_path.c_str() };
+ dl_handle_ptr handle { dl_load_library(path) };
+ if (!handle) {
+ GGML_LOG_ERROR("ggml-hex: failed to load %s: %s\n", path.u8string().c_str(), dl_error());
+ return AEE_EUNABLETOLOAD;
+ }
+
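+// Resolve `symbol` from the driver library into the function pointer `pfn`. When
+// `ignore` is false a missing symbol aborts initialization; when true the pointer is
+// simply left NULL and callers are expected to fall back (e.g. rpcmem_alloc2).
+// Note: this macro locally shadows the libc dlsym() name for the rest of this file.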
+#define dlsym(drv, type, pfn, symbol, ignore) \
+ do { \
+ pfn = (type) dl_get_sym(drv, #symbol); \
+ if (!ignore && nullptr == pfn) { \
+ GGML_LOG_ERROR("ggml-hex: failed to dlsym %s\n", #symbol); \
+ return AEE_EUNABLETOLOAD; \
+ } \
+ } while (0)
+
+ dlsym(handle.get(), rpcmem_alloc_pfn_t, rpcmem_alloc_pfn, rpcmem_alloc, false);
+ dlsym(handle.get(), rpcmem_alloc2_pfn_t, rpcmem_alloc2_pfn, rpcmem_alloc2, true);
+ dlsym(handle.get(), rpcmem_free_pfn_t, rpcmem_free_pfn, rpcmem_free, false);
+ dlsym(handle.get(), rpcmem_to_fd_pfn_t, rpcmem_to_fd_pfn, rpcmem_to_fd, false);
+ dlsym(handle.get(), fastrpc_mmap_pfn_t, fastrpc_mmap_pfn, fastrpc_mmap, false);
+ dlsym(handle.get(), fastrpc_munmap_pfn_t, fastrpc_munmap_pfn, fastrpc_munmap, false);
+ dlsym(handle.get(), dspqueue_create_pfn_t, dspqueue_create_pfn, dspqueue_create, false);
+ dlsym(handle.get(), dspqueue_close_pfn_t, dspqueue_close_pfn, dspqueue_close, false);
+ dlsym(handle.get(), dspqueue_export_pfn_t, dspqueue_export_pfn, dspqueue_export, false);
+ dlsym(handle.get(), dspqueue_write_pfn_t, dspqueue_write_pfn, dspqueue_write, false);
+ dlsym(handle.get(), dspqueue_read_pfn_t, dspqueue_read_pfn, dspqueue_read, false);
+ dlsym(handle.get(), remote_handle64_open_pfn_t, remote_handle64_open_pfn, remote_handle64_open, false);
+ dlsym(handle.get(), remote_handle64_invoke_pfn_t, remote_handle64_invoke_pfn, remote_handle64_invoke, false);
+ dlsym(handle.get(), remote_handle_control_pfn_t, remote_handle_control_pfn, remote_handle_control, false);
+ dlsym(handle.get(), remote_handle64_control_pfn_t, remote_handle64_control_pfn, remote_handle64_control, false);
+ dlsym(handle.get(), remote_session_control_pfn_t, remote_session_control_pfn, remote_session_control, false);
+ dlsym(handle.get(), remote_handle64_close_pfn_t, remote_handle64_close_pfn, remote_handle64_close, false);
+
+ lib_cdsp_rpc_handle = std::move(handle);
+ initialized = true;
+
+ return AEE_SUCCESS;
+}
+
+domain * get_domain(int domain_id) {
+ int i = 0;
+ int size = sizeof(supported_domains) / sizeof(domain);
+
+ for (i = 0; i < size; i++) {
+ if (supported_domains[i].id == domain_id) {
+ return &supported_domains[i];
+ }
+ }
+
+ return NULL;
+}
+
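+// Query the Hexagon architecture version through the FastRPC capability API. The DSP
+// reports the version as a hex-encoded value (e.g. 0x73 for v73), hence the explicit
+// mapping below; unrecognized versions yield -1.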
+int get_hex_arch_ver(int domain, int * arch) {
+ if (!remote_handle_control_pfn) {
+ GGML_LOG_ERROR("ggml-hex: remote_handle_control is not supported on this device\n");
+ return AEE_EUNSUPPORTEDAPI;
+ }
+
+ struct remote_dsp_capability arch_ver;
+ arch_ver.domain = (uint32_t) domain;
+ arch_ver.attribute_ID = ARCH_VER;
+ arch_ver.capability = (uint32_t) 0;
+
+ int err = remote_handle_control(DSPRPC_GET_DSP_INFO, &arch_ver, sizeof(arch_ver));
+ if ((err & 0xff) == (AEE_EUNSUPPORTEDAPI & 0xff)) {
+ GGML_LOG_ERROR("ggml-hex: FastRPC capability API is not supported on this device\n");
+ return AEE_EUNSUPPORTEDAPI;
+ }
+
+ if (err != AEE_SUCCESS) {
+ GGML_LOG_ERROR("ggml-hex: FastRPC capability query failed (err %d)\n", err);
+ return err;
+ }
+
+ switch (arch_ver.capability & 0xff) {
+ case 0x68:
+ *arch = 68;
+ return 0;
+ case 0x69:
+ *arch = 69;
+ return 0;
+ case 0x73:
+ *arch = 73;
+ return 0;
+ case 0x75:
+ *arch = 75;
+ return 0;
+ case 0x79:
+ *arch = 79;
+ return 0;
+ case 0x81:
+ *arch = 81;
+ return 0;
+ }
+ return -1;
+}
diff --git a/llama.cpp/ggml/src/ggml-hexagon/htp-drv.h b/llama.cpp/ggml/src/ggml-hexagon/htp-drv.h
new file mode 100644
index 0000000..6eba7ba
--- /dev/null
+++ b/llama.cpp/ggml/src/ggml-hexagon/htp-drv.h
@@ -0,0 +1,121 @@
+#pragma once
+
+#ifdef __cplusplus
+extern "C" {
+#endif
+
+#ifdef _WIN32
+# pragma clang diagnostic ignored "-Wignored-attributes"
+#endif
+
+#include <AEEStdErr.h>
+#include <rpcmem.h>
+#include <remote.h>
+#include <dspqueue.h>
+
+#if defined(_WIN32) && !defined(__MINGW32__)
+# ifdef GGML_BACKEND_BUILD
+# define HTPDRV_API __declspec(dllexport) extern
+# else
+# define HTPDRV_API __declspec(dllimport) extern
+# endif
+#else
+# define HTPDRV_API __attribute__ ((visibility ("default"))) extern
+#endif
+
+/* Offset to differentiate HLOS and Hexagon error codes.
+ Stores the value of AEE_EOFFSET for Hexagon. */
+#ifndef DSP_OFFSET
+# define DSP_OFFSET 0x80000400
+#endif
+
+/* Errno for connection reset by peer. */
+#ifndef ECONNRESET
+# ifdef __hexagon__
+# define ECONNRESET 104
+# endif
+#endif
+
+/* Abstraction of different OS specific sleep APIs.
+ SLEEP accepts input in seconds. */
+#ifndef SLEEP
+# ifdef __hexagon__
+# define SLEEP(x) \
+ { /* Do nothing for simulator. */ \
+ }
+# else
+# ifdef _WIN32
+# define SLEEP(x) Sleep(1000 * x) /* Sleep accepts input in milliseconds. */
+# else
+# define SLEEP(x) sleep(x) /* sleep accepts input in seconds. */
+# endif
+# endif
+#endif
+
+/* Include windows specific header files. */
+#ifdef _WIN32
+# include <windows.h>
+# include <sysinfoapi.h>
+# define _CRT_SECURE_NO_WARNINGS 1
+# define _WINSOCK_DEPRECATED_NO_WARNINGS 1
+#endif
+
+/* Includes and defines for all HLOS except windows */
+#if !defined(__hexagon__) && !defined(_WIN32)
+# include "unistd.h"
+
+# include <sys/time.h>
+#endif
+
+/* Includes and defines for Hexagon and all HLOS except Windows. */
+#if !defined(_WIN32)
+/* Weak reference to remote symbol for compilation. */
+# pragma weak remote_session_control
+# pragma weak remote_handle_control
+# pragma weak remote_handle64_control
+# pragma weak fastrpc_mmap
+# pragma weak fastrpc_munmap
+# pragma weak rpcmem_alloc2
+#endif
+
+#if !defined(_WIN32)
+# pragma weak remote_system_request
+#endif
+
+#ifdef _WIN32
+# define DSPQUEUE_TIMEOUT DSPQUEUE_TIMEOUT_NONE
+#else
+# define DSPQUEUE_TIMEOUT 1000000
+#endif
+
+/**
+ * htpdrv_init API: driver interface entry point
+ *
+ * @return AEE error code as defined in the Hexagon SDK.
+ */
+HTPDRV_API int htpdrv_init(void);
+
+/**
+ * get_domain API: get domain struct from domain value.
+ *
+ * @param[in] domain_id id of the domain to look up
+ * @return Pointer to the domain struct if the domain is supported, NULL otherwise.
+ *
+ */
+HTPDRV_API domain * get_domain(int domain_id);
+
+/**
+ * get_hex_arch_ver API: query the Hexagon processor architecture version information
+ *
+ * @param[in] domain domain id to query
+ * @param[out] arch architecture version (73, 75, ...)
+ * @return 0 if the query is successful,
+ *         non-zero error code otherwise.
+ *
+ */
+HTPDRV_API int get_hex_arch_ver(int domain, int * arch);
+
+#ifdef __cplusplus
+}
+#endif
diff --git a/llama.cpp/ggml/src/ggml-hexagon/htp/CMakeLists.txt b/llama.cpp/ggml/src/ggml-hexagon/htp/CMakeLists.txt
new file mode 100644
index 0000000..2c23b60
--- /dev/null
+++ b/llama.cpp/ggml/src/ggml-hexagon/htp/CMakeLists.txt
@@ -0,0 +1,45 @@
+cmake_minimum_required(VERSION 3.22.2)
+project(ggml-htp C CXX ASM)
+
+include(${HEXAGON_SDK_ROOT}/build/cmake/hexagon_fun.cmake)
+
+include_directories(
+ ${HEXAGON_SDK_ROOT}/incs
+ ${HEXAGON_SDK_ROOT}/incs/stddef
+ ${CMAKE_CURRENT_SOURCE_DIR}/../../../include
+ ${CMAKE_CURRENT_SOURCE_DIR}/../..
+ ${CMAKE_CURRENT_SOURCE_DIR}/..
+ ${CMAKE_CURRENT_SOURCE_DIR}
+ ${CMAKE_CURRENT_BINARY_DIR})
+
+set(HTP_LIB ggml-htp-${DSP_VERSION})
+
+add_library(${HTP_LIB} SHARED
+ main.c
+ htp_iface_skel.c
+ worker-pool.c
+ hex-dma.c
+ matmul-ops.c
+ binary-ops.c
+ unary-ops.c
+ sum-rows-ops.c
+ softmax-ops.c
+ act-ops.c
+ rope-ops.c
+ flash-attn-ops.c
+ set-rows-ops.c
+ get-rows-ops.c
+ cpy-ops.c
+ argsort-ops.c
+)
+
+target_compile_definitions(${HTP_LIB} PRIVATE
+ $<IF:$<BOOL:${HEXAGON_HTP_DEBUG}>,HTP_DEBUG=1,NDEBUG=1>
+ $<IF:$<BOOL:${HEXAGON_HTP_DEBUG}>,FARF_HIGH=1,>
+ FP32_QUANTIZE_GROUP_SIZE=${GGML_HEXAGON_FP32_QUANTIZE_GROUP_SIZE})
+
+build_idl(htp_iface.idl ${HTP_LIB})
+
+set_target_properties(${HTP_LIB} PROPERTIES EXPORT_COMPILE_COMMANDS ON)
+
+install(TARGETS ${HTP_LIB})
diff --git a/llama.cpp/ggml/src/ggml-hexagon/htp/act-ops.c b/llama.cpp/ggml/src/ggml-hexagon/htp/act-ops.c
new file mode 100644
index 0000000..950d836
--- /dev/null
+++ b/llama.cpp/ggml/src/ggml-hexagon/htp/act-ops.c
@@ -0,0 +1,823 @@
+#pragma clang diagnostic ignored "-Wunused-variable"
+#pragma clang diagnostic ignored "-Wunused-function"
+#pragma clang diagnostic ignored "-Wunused-but-set-variable"
+
+#include <HAP_farf.h>
+#include <HAP_perf.h>
+
+#include <math.h>
+#include <string.h>
+
+#include "hex-dma.h"
+#include "hvx-utils.h"
+
+#define GGML_COMMON_DECL_C
+#include "ggml-common.h"
+#include "htp-ctx.h"
+#include "htp-msg.h"
+#include "htp-ops.h"
+
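+// The htp_act_preamble* macros expand the ggml-style tensor extents (ne*) and row
+// strides in bytes (nb*) into local constants for the 3-tensor (src0, src1, dst) and
+// 2-tensor (src0, dst) activation kernels below.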
+#define htp_act_preamble3 \
+ const uint32_t ne00 = src0->ne[0]; \
+ const uint32_t ne01 = src0->ne[1]; \
+ const uint32_t ne02 = src0->ne[2]; \
+ const uint32_t ne03 = src0->ne[3]; \
+ \
+ const uint32_t ne10 = src1->ne[0]; \
+ const uint32_t ne11 = src1->ne[1]; \
+ const uint32_t ne12 = src1->ne[2]; \
+ const uint32_t ne13 = src1->ne[3]; \
+ \
+ const uint32_t ne0 = dst->ne[0]; \
+ const uint32_t ne1 = dst->ne[1]; \
+ const uint32_t ne2 = dst->ne[2]; \
+ const uint32_t ne3 = dst->ne[3]; \
+ \
+ const uint32_t nb00 = src0->nb[0]; \
+ const uint32_t nb01 = src0->nb[1]; \
+ const uint32_t nb02 = src0->nb[2]; \
+ const uint32_t nb03 = src0->nb[3]; \
+ \
+ const uint32_t nb10 = src1->nb[0]; \
+ const uint32_t nb11 = src1->nb[1]; \
+ const uint32_t nb12 = src1->nb[2]; \
+ const uint32_t nb13 = src1->nb[3]; \
+ \
+ const uint32_t nb0 = dst->nb[0]; \
+ const uint32_t nb1 = dst->nb[1]; \
+ const uint32_t nb2 = dst->nb[2]; \
+ const uint32_t nb3 = dst->nb[3];
+
+#define htp_act_preamble2 \
+ const uint32_t ne00 = src0->ne[0]; \
+ const uint32_t ne01 = src0->ne[1]; \
+ const uint32_t ne02 = src0->ne[2]; \
+ const uint32_t ne03 = src0->ne[3]; \
+ \
+ const uint32_t ne0 = dst->ne[0]; \
+ const uint32_t ne1 = dst->ne[1]; \
+ const uint32_t ne2 = dst->ne[2]; \
+ const uint32_t ne3 = dst->ne[3]; \
+ \
+ const uint32_t nb00 = src0->nb[0]; \
+ const uint32_t nb01 = src0->nb[1]; \
+ const uint32_t nb02 = src0->nb[2]; \
+ const uint32_t nb03 = src0->nb[3]; \
+ \
+ const uint32_t nb0 = dst->nb[0]; \
+ const uint32_t nb1 = dst->nb[1]; \
+ const uint32_t nb2 = dst->nb[2]; \
+ const uint32_t nb3 = dst->nb[3];
+
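+// All per-thread kernels in this file share the same DMA double-buffering scheme:
+//   1. prime the queue with up to two blocks of source rows (DDR -> VTCM), each
+//      preceded by a dummy VTCM -> DDR descriptor so pops alternate dst, src, ...
+//   2. for each block: pop the staged buffers, run the HVX math row by row, push the
+//      results back to DDR, and prefetch the block two iterations ahead (N+2).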
+static void glu_swiglu_f32_per_thread(const struct htp_tensor * src0,
+ const struct htp_tensor * src1,
+ struct htp_tensor * dst,
+ const int32_t * op_params,
+ struct htp_spad * src0_spad,
+ struct htp_spad * src1_spad,
+ struct htp_spad * dst_spad,
+ uint32_t nth,
+ uint32_t ith,
+ uint32_t src0_nrows_per_thread,
+ dma_queue * dma_queue) {
+ htp_act_preamble3;
+
+ size_t src0_row_size = nb01;
+ size_t src1_row_size = nb11;
+ size_t dst_row_size = nb1;
+
+ const uint32_t src0_nrows = ne01 * ne02 * ne03; // src0 rows
+
+ const uint32_t src0_start_row = src0_nrows_per_thread * ith;
+ const uint32_t src0_end_row = MIN(src0_start_row + src0_nrows_per_thread, src0_nrows);
+
+ // no work for this thread
+ if (src0_start_row >= src0_end_row) {
+ return;
+ }
+
+ uint64_t t1, t2;
+ t1 = HAP_perf_get_qtimer_count();
+
+ const uint8_t * restrict data_src0 = (const uint8_t *) src0->data;
+ const uint8_t * restrict data_src1 = (const uint8_t *) src1->data;
+ uint8_t * restrict data_dst = (uint8_t *) dst->data;
+
+ const bool src1_valid = src1->ne[0];
+ const int nc = (src1_valid) ? ne00 : ne00 / 2;
+ if (!src1_valid) {
+ const int32_t swapped = op_params[1];
+ data_src1 = data_src0;
+ src1_row_size = src0_row_size;
+
+ const size_t nc_in_bytes = nc * SIZEOF_FP32;
+ data_src0 += swapped ? nc_in_bytes : 0;
+ data_src1 += swapped ? 0 : nc_in_bytes;
+ }
+
+ const size_t src0_row_size_aligned = hex_round_up(src0_row_size, VLEN);
+ const size_t src1_row_size_aligned = hex_round_up(src1_row_size, VLEN);
+ const size_t dst_row_size_aligned = hex_round_up(dst_row_size, VLEN);
+
+ uint8_t * restrict src0_spad_data = src0_spad->data + (ith * src0_spad->size_per_thread);
+ uint8_t * restrict src1_spad_data = src1_spad->data + (ith * src1_spad->size_per_thread);
+ uint8_t * restrict dst_spad_data = dst_spad->data + (ith * dst_spad->size_per_thread);
+
+ // While given src0_spad->size_per_thread, divide it to two ping-pong buffer for src0
+ size_t src0_spad_half_size = src0_spad->size_per_thread / 2;
+ size_t src1_spad_half_size = src1_spad->size_per_thread / 2;
+ size_t dst_spad_half_size = dst_spad->size_per_thread / 2;
+
+ const int BLOCK = src0_spad_half_size / src0_row_size_aligned; // How many rows can we process in one block
+ if (BLOCK == 0) {
+ FARF(ERROR,
+ "swiglu-f32 : current VTCM reservation %zu is too small for even 1 row per thread, needed at least %zu\n",
+ src0_spad->size_per_thread, src0_row_size_aligned);
+ return;
+ }
+
+ // See discussion: https://github.com/ggml-org/llama.cpp/pull/18151#issuecomment-3678235379
+ for (uint32_t ir = src0_start_row, spad_idx = 0; ir < src0_end_row && spad_idx < 2; ir += BLOCK, spad_idx++) {
+ const uint32_t block_size = MIN(BLOCK, src0_end_row - ir);
+
+        // Dummy DMA transaction for sequencing (interleaving dst, src, dst, ...)
+ dma_queue_push_vtcm_to_ddr(dma_queue,
+ dma_make_ptr(data_dst, dst_spad_data + (spad_idx * dst_spad_half_size)),
+ dst_row_size, dst_row_size_aligned, 0);
+
+ dma_queue_push_ddr_to_vtcm(dma_queue,
+ dma_make_ptr(src0_spad_data + (spad_idx * src0_spad_half_size), data_src0 + (ir * src0_row_size)),
+ src0_row_size_aligned, src0_row_size, block_size);
+ dma_queue_push_ddr_to_vtcm(dma_queue,
+ dma_make_ptr(src1_spad_data + (spad_idx * src1_spad_half_size), data_src1 + (ir * src1_row_size)),
+ src1_row_size_aligned, src1_row_size, block_size);
+ }
+
+ for (uint32_t ir = src0_start_row; ir < src0_end_row; ir += BLOCK) {
+ const uint32_t block_size = MIN(BLOCK, src0_end_row - ir);
+
+ float * dst_spad = (float *) dma_queue_pop(dma_queue).src;
+ float * src0_spad = (float *) dma_queue_pop(dma_queue).dst;
+ float * src1_spad = (float *) dma_queue_pop(dma_queue).dst;
+
+ for (uint32_t ib = 0; ib < block_size; ib++) {
+ const float * src0_spad_ptr = src0_spad + ib * (src0_row_size_aligned / sizeof(float));
+ const float * src1_spad_ptr = src1_spad + ib * (src1_row_size_aligned / sizeof(float));
+ float * dst_spad_ptr = dst_spad + ib * (dst_row_size_aligned / sizeof(float));
+
+            // swiglu(x0, x1) = x0 * sigmoid(x0) * x1 (i.e. silu(x0) * x1)
+ hvx_sigmoid_f32_aa((uint8_t *) dst_spad_ptr, (const uint8_t *) src0_spad_ptr, nc);
+ hvx_mul_mul_f32_aa((uint8_t *) dst_spad_ptr, (const uint8_t *) src0_spad_ptr, (const uint8_t *) dst_spad_ptr,
+ (const uint8_t *) src1_spad_ptr, nc);
+ }
+
+ dma_queue_push_vtcm_to_ddr(dma_queue, dma_make_ptr(data_dst + (ir * dst_row_size), dst_spad), dst_row_size,
+ dst_row_size_aligned, block_size);
+
+ // prefetch N+2 loop iteration if any
+ const uint32_t pref_block = (ir + BLOCK * 2);
+ if (pref_block < src0_end_row) {
+ const uint32_t pref_block_size = MIN(BLOCK, src0_end_row - pref_block);
+ dma_queue_push_ddr_to_vtcm(dma_queue, dma_make_ptr(src0_spad, data_src0 + (pref_block * src0_row_size)),
+ src0_row_size_aligned, src0_row_size, pref_block_size);
+ dma_queue_push_ddr_to_vtcm(dma_queue, dma_make_ptr(src1_spad, data_src1 + (pref_block * src1_row_size)),
+ src1_row_size_aligned, src1_row_size, pref_block_size);
+ }
+ }
+
+ dma_queue_flush(dma_queue);
+
+ t2 = HAP_perf_get_qtimer_count();
+
+ FARF(HIGH, "swiglu-f32 %d/%d: %ux%ux%ux%u (%u:%u) x %ux%ux%ux%u -> %ux%ux%ux%u usec %u\n", ith, nth,
+ ne00, ne01, ne02, ne03, src0_start_row, src0_end_row, ne10, ne11, ne12, ne13, ne0, ne1, ne2, ne3,
+ (unsigned) HAP_perf_qtimer_count_to_us(t2 - t1));
+}
+
+static void glu_swiglu_oai_f32_per_thread(const struct htp_tensor * src0,
+ const struct htp_tensor * src1,
+ struct htp_tensor * dst,
+ const int32_t * op_params,
+ struct htp_spad * src0_spad,
+ struct htp_spad * src1_spad,
+ struct htp_spad * dst_spad,
+ uint32_t nth,
+ uint32_t ith,
+ uint32_t src0_nrows_per_thread,
+ dma_queue * dma_queue) {
+ htp_act_preamble3;
+
+ uint64_t t1, t2;
+ t1 = HAP_perf_get_qtimer_count();
+
+ size_t src0_row_size = nb01;
+ size_t src1_row_size = nb11;
+ size_t dst_row_size = nb1;
+
+ const uint32_t src0_nrows = ne01 * ne02 * ne03; // src0 rows
+
+ const uint32_t src0_start_row = src0_nrows_per_thread * ith;
+ const uint32_t src0_end_row = MIN(src0_start_row + src0_nrows_per_thread, src0_nrows);
+
+ // no work for this thread
+ if (src0_start_row >= src0_end_row) {
+ return;
+ }
+
+ const uint8_t * restrict data_src0 = (const uint8_t *) src0->data;
+ const uint8_t * restrict data_src1 = (const uint8_t *) src1->data;
+ uint8_t * restrict data_dst = (uint8_t *) dst->data;
+
+ const bool src1_valid = src1->ne[0];
+ const int nc = (src1_valid) ? ne00 : ne00 / 2;
+ if (!src1_valid) {
+ const int32_t swapped = op_params[1];
+ data_src1 = data_src0;
+ src1_row_size = src0_row_size;
+
+ const size_t nc_in_bytes = nc * SIZEOF_FP32;
+ data_src0 += swapped ? nc_in_bytes : 0;
+ data_src1 += swapped ? 0 : nc_in_bytes;
+ }
+
+ const size_t src0_row_size_aligned = hex_round_up(src0_row_size, VLEN);
+ const size_t src1_row_size_aligned = hex_round_up(src1_row_size, VLEN);
+ const size_t dst_row_size_aligned = hex_round_up(dst_row_size, VLEN);
+
+ uint8_t * restrict src0_spad_data = src0_spad->data + (ith * src0_spad->size_per_thread);
+ uint8_t * restrict src1_spad_data = src1_spad->data + (ith * src1_spad->size_per_thread);
+ uint8_t * restrict dst_spad_data = dst_spad->data + (ith * dst_spad->size_per_thread);
+
+ // While given src0_spad->size_per_thread, divide it to two ping-pong buffer for src0
+ size_t src0_spad_half_size = src0_spad->size_per_thread / 2;
+ size_t src1_spad_half_size = src1_spad->size_per_thread / 2;
+ size_t dst_spad_half_size = dst_spad->size_per_thread / 2;
+
+ const int BLOCK = src0_spad_half_size / src0_row_size_aligned; // How many rows can we process in one block
+ if (BLOCK == 0) {
+ FARF(ERROR,
+ "swiglu-oai-f32 : current VTCM reservation %zu is too small for even 1 row per thread, needed at least "
+ "%zu\n",
+ src0_spad->size_per_thread, src0_row_size_aligned);
+ return;
+ }
+ const float alpha = ((const float *) (op_params))[2];
+ const float limit = ((const float *) (op_params))[3];
+
+ // See discussion: https://github.com/ggml-org/llama.cpp/pull/18151#issuecomment-3678235379
+ for (uint32_t ir = src0_start_row, spad_idx = 0; ir < src0_end_row && spad_idx < 2; ir += BLOCK, spad_idx++) {
+ const uint32_t block_size = MIN(BLOCK, src0_end_row - ir);
+
+        // Dummy DMA transaction for sequencing (interleaving dst, src, dst, ...)
+ dma_queue_push_vtcm_to_ddr(dma_queue, dma_make_ptr(data_dst, dst_spad_data + (spad_idx * dst_spad_half_size)),
+ dst_row_size, dst_row_size_aligned, 0);
+
+ dma_queue_push_ddr_to_vtcm(
+ dma_queue,
+ dma_make_ptr(src0_spad_data + (spad_idx * src0_spad_half_size), data_src0 + (ir * src0_row_size)),
+ src0_row_size_aligned, src0_row_size, block_size);
+ dma_queue_push_ddr_to_vtcm(
+ dma_queue,
+ dma_make_ptr(src1_spad_data + (spad_idx * src1_spad_half_size), data_src1 + (ir * src1_row_size)),
+ src1_row_size_aligned, src1_row_size, block_size);
+ }
+
+ for (uint32_t ir = src0_start_row; ir < src0_end_row; ir += BLOCK) {
+ const uint32_t block_size = MIN(BLOCK, src0_end_row - ir);
+
+ float * dst_spad = (float *) dma_queue_pop(dma_queue).src;
+ float * src0_spad = (float *) dma_queue_pop(dma_queue).dst;
+ float * src1_spad = (float *) dma_queue_pop(dma_queue).dst;
+
+ for (uint32_t ib = 0; ib < block_size; ib++) {
+ const float * src0_spad_ptr = src0_spad + ib * (src0_row_size_aligned / sizeof(float));
+ const float * src1_spad_ptr = src1_spad + ib * (src1_row_size_aligned / sizeof(float));
+ float * dst_spad_ptr = dst_spad + ib * (dst_row_size_aligned / sizeof(float));
+
+ // x (src0_spad_data) = std::min(src0_p[k], limit);
+ hvx_min_scalar_f32((uint8_t *) src0_spad_ptr, (const uint8_t *) src0_spad_ptr, limit, nc);
+ // y1 (src1_spad_data) = std::clamp(src1_p[k], -limit, limit);
+ hvx_clamp_scalar_f32((uint8_t *) src1_spad_ptr, (const uint8_t *) src1_spad_ptr, -limit, limit, nc);
+ // y (src1_spad_data) = y1 + 1.f
+ hvx_add_scalar_f32((uint8_t *) src1_spad_ptr, (const uint8_t *) src1_spad_ptr, 1.0, nc);
+ // x1 (dst_spad_data) = alpha * (x)
+ hvx_mul_scalar_f32((uint8_t *) dst_spad_ptr, (const uint8_t *) src0_spad_ptr, alpha, nc);
+ // x2 (dst_spad_data) = sigmoid(x1) = 1/(1+exp(-x1))
+ hvx_sigmoid_f32_aa((uint8_t *) dst_spad_ptr, (const uint8_t *) dst_spad_ptr, nc);
+ // out = x * sigmoid(alpha * x) * (y + 1.f)
+ hvx_mul_mul_f32_aa((uint8_t *) dst_spad_ptr, (const uint8_t *) src0_spad_ptr, (const uint8_t *) dst_spad_ptr,
+ (const uint8_t *) src1_spad_ptr, nc);
+ }
+
+ dma_queue_push_vtcm_to_ddr(dma_queue, dma_make_ptr(data_dst + (ir * dst_row_size), dst_spad), dst_row_size,
+ dst_row_size_aligned, block_size);
+
+ // prefetch N+2 loop iteration if any
+ const uint32_t pref_block = (ir + BLOCK * 2);
+ if (pref_block < src0_end_row) {
+ const uint32_t pref_block_size = MIN(BLOCK, src0_end_row - pref_block);
+ dma_queue_push_ddr_to_vtcm(dma_queue, dma_make_ptr(src0_spad, data_src0 + (pref_block * src0_row_size)),
+ src0_row_size_aligned, src0_row_size, pref_block_size);
+ dma_queue_push_ddr_to_vtcm(dma_queue, dma_make_ptr(src1_spad, data_src1 + (pref_block * src1_row_size)),
+ src1_row_size_aligned, src1_row_size, pref_block_size);
+ }
+ }
+
+ dma_queue_flush(dma_queue);
+
+ t2 = HAP_perf_get_qtimer_count();
+
+ FARF(HIGH, "swiglu-oai-f32 %d/%d: %ux%ux%ux%u (%u:%u) x %ux%ux%ux%u -> %ux%ux%ux%u usec %u\n", ith, nth, src0->ne[0],
+ src0->ne[1], src0->ne[2], src0->ne[3], src0_start_row, src0_end_row, src1->ne[0], src1->ne[1], src1->ne[2],
+ src1->ne[3], dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3], (unsigned) HAP_perf_qtimer_count_to_us(t2 - t1));
+}
+
+static void unary_gelu_f32_per_thread(const struct htp_tensor * src0,
+ struct htp_tensor * dst,
+ const int32_t * op_params,
+ struct htp_spad * src0_spad,
+ struct htp_spad * dst_spad,
+ uint32_t nth,
+ uint32_t ith,
+ uint32_t src0_nrows_per_thread,
+ dma_queue * dma_queue) {
+ htp_act_preamble2;
+
+ uint64_t t1, t2;
+ t1 = HAP_perf_get_qtimer_count();
+
+ const size_t src0_row_size = nb01;
+ const size_t dst_row_size = nb1;
+ const size_t src0_row_size_aligned = hex_round_up(src0_row_size, VLEN);
+ const size_t dst_row_size_aligned = hex_round_up(dst_row_size, VLEN);
+
+ const uint32_t src0_nrows = ne01 * ne02 * ne03;
+
+ const uint32_t src0_start_row = src0_nrows_per_thread * ith;
+ const uint32_t src0_end_row = MIN(src0_start_row + src0_nrows_per_thread, src0_nrows);
+
+ // no work for this thread
+ if (src0_start_row >= src0_end_row) {
+ return;
+ }
+
+ const uint8_t * data_src0 = (const uint8_t *) src0->data;
+ uint8_t * data_dst = (uint8_t *) dst->data;
+
+ uint8_t * src0_spad_data = src0_spad->data + (ith * src0_spad->size_per_thread);
+ uint8_t * dst_spad_data = dst_spad->data + (ith * dst_spad->size_per_thread);
+
+ // While given src0_spad->size_per_thread, divide it to two ping-pong buffer for src0
+ size_t src0_spad_half_size = src0_spad->size_per_thread / 2;
+ size_t dst_spad_half_size = dst_spad->size_per_thread / 2;
+
+    // gelu(x) is approximated as x * sigmoid(1.702 * x)
+ const int BLOCK = src0_spad_half_size / src0_row_size_aligned; // How many rows can we process in one block
+
+ if (BLOCK == 0) {
+ FARF(ERROR, "gelu-f32 : current VTCM reservation %zu is too small for even 1 row per thread, needed at least %zu\n",
+ src0_spad->size_per_thread, src0_row_size_aligned);
+ return;
+ }
+
+ // See discussion: https://github.com/ggml-org/llama.cpp/pull/18151#issuecomment-3678235379
+ for (uint32_t ir = src0_start_row, spad_idx = 0; ir < src0_end_row && spad_idx < 2; ir += BLOCK, spad_idx++) {
+ const uint32_t block_size = MIN(BLOCK, src0_end_row - ir);
+
+        // Dummy DMA transaction for sequencing (interleaving dst, src, dst, ...)
+ dma_queue_push_vtcm_to_ddr(dma_queue,
+ dma_make_ptr(data_dst, dst_spad_data + (spad_idx * dst_spad_half_size)),
+ dst_row_size, dst_row_size_aligned, 0);
+
+ dma_queue_push_ddr_to_vtcm(dma_queue,
+ dma_make_ptr(src0_spad_data + (spad_idx * src0_spad_half_size), data_src0 + (ir * src0_row_size)),
+ src0_row_size_aligned, src0_row_size, block_size);
+ }
+
+ for (uint32_t ir = src0_start_row; ir < src0_end_row; ir += BLOCK) {
+ const uint32_t block_size = MIN(BLOCK, src0_end_row - ir);
+
+ float* dst_spad = (float *) dma_queue_pop(dma_queue).src;
+ float* src0_spad = (float *) dma_queue_pop(dma_queue).dst;
+
+ for (uint32_t ib = 0; ib < block_size; ib++) {
+ const float* src0_spad_ptr = src0_spad + ib * (src0_row_size_aligned / sizeof(float));
+ float* dst_spad_ptr = dst_spad + ib * (dst_row_size_aligned / sizeof(float));
+
+ // gelu = x * sigmoid(1.702 * x) // current implementation
+ hvx_mul_scalar_f32((uint8_t *) dst_spad_ptr, (const uint8_t *) src0_spad_ptr, (float) 1.702, ne0);
+ hvx_sigmoid_f32_aa((uint8_t *) dst_spad_ptr, (const uint8_t *) dst_spad_ptr, ne0);
+ hvx_mul_f32_aaa((uint8_t *) dst_spad_ptr, (const uint8_t *) src0_spad_ptr, (const uint8_t *) dst_spad_ptr, ne0);
+ }
+
+ dma_queue_push_vtcm_to_ddr(dma_queue,
+ dma_make_ptr(data_dst + (ir * dst_row_size), dst_spad),
+ dst_row_size, dst_row_size_aligned, block_size);
+
+ // prefetch N+2 loop iteration if any
+ const uint32_t pref_block = (ir + BLOCK * 2);
+ if (pref_block < src0_end_row) {
+ const uint32_t pref_block_size = MIN(BLOCK, src0_end_row - pref_block);
+ dma_queue_push_ddr_to_vtcm(dma_queue,
+ dma_make_ptr(src0_spad, data_src0 + (pref_block * src0_row_size)),
+ src0_row_size_aligned, src0_row_size, pref_block_size);
+ }
+ }
+
+ dma_queue_flush(dma_queue);
+
+ t2 = HAP_perf_get_qtimer_count();
+
+ FARF(HIGH, "gelu-f32 %d/%d: %ux%ux%ux%u (%u:%u) -> %ux%ux%ux%u usec %u\n", ith, nth, ne00, ne01, ne02,
+ ne03, src0_start_row, src0_end_row, ne0, ne1, ne2, ne3, (unsigned) HAP_perf_qtimer_count_to_us(t2 - t1));
+}
+
+static void unary_gelu_f32(unsigned int n, unsigned int i, void * data) {
+ struct htp_ops_context * octx = (struct htp_ops_context *) data;
+ unary_gelu_f32_per_thread(&octx->src0, &octx->dst, octx->op_params, &octx->src0_spad, &octx->dst_spad, n, i,
+ octx->src0_nrows_per_thread, octx->ctx->dma[i]);
+}
+
+static void unary_silu_f32_per_thread(const struct htp_tensor * src0,
+ struct htp_tensor * dst,
+ const int32_t * op_params,
+ struct htp_spad * src0_spad,
+ struct htp_spad * dst_spad,
+ uint32_t nth,
+ uint32_t ith,
+ uint32_t src0_nrows_per_thread,
+ dma_queue * dma_queue) {
+ htp_act_preamble2;
+
+ uint64_t t1, t2;
+ t1 = HAP_perf_get_qtimer_count();
+
+ const size_t src0_row_size = nb01;
+ const size_t dst_row_size = nb1;
+ const size_t src0_row_size_aligned = hex_round_up(src0_row_size, VLEN);
+ const size_t dst_row_size_aligned = hex_round_up(dst_row_size, VLEN);
+
+ const uint32_t src0_nrows = ne01 * ne02 * ne03;
+
+ const uint32_t src0_start_row = src0_nrows_per_thread * ith;
+ const uint32_t src0_end_row = MIN(src0_start_row + src0_nrows_per_thread, src0_nrows);
+
+ // no work for this thread
+ if (src0_start_row >= src0_end_row) {
+ return;
+ }
+
+ const uint8_t * data_src0 = (const uint8_t *) src0->data;
+ uint8_t * data_dst = (uint8_t *) dst->data;
+
+ uint8_t * src0_spad_data = src0_spad->data + (ith * src0_spad->size_per_thread);
+ uint8_t * dst_spad_data = dst_spad->data + (ith * dst_spad->size_per_thread);
+
+ // While given src0_spad->size_per_thread, divide it to two ping-pong buffer for src0
+ size_t src0_spad_half_size = src0_spad->size_per_thread / 2;
+ size_t dst_spad_half_size = dst_spad->size_per_thread / 2;
+
+ const int BLOCK = src0_spad_half_size / src0_row_size_aligned; // How many rows can we process in one block
+
+ if (BLOCK == 0) {
+ FARF(ERROR, "silu-f32 : current VTCM reservation %zu is too small for even 1 row per thread, needed at least %zu\n",
+ src0_spad->size_per_thread, src0_row_size_aligned);
+ return;
+ }
+
+ // See discussion: https://github.com/ggml-org/llama.cpp/pull/18151#issuecomment-3678235379
+ for (uint32_t ir = src0_start_row, spad_idx = 0; ir < src0_end_row && spad_idx < 2; ir += BLOCK, spad_idx++) {
+ const uint32_t block_size = MIN(BLOCK, src0_end_row - ir);
+
+        // Dummy DMA transaction for sequencing (interleaving dst, src, dst, ...)
+ dma_queue_push_vtcm_to_ddr(dma_queue,
+ dma_make_ptr(data_dst, dst_spad_data + (spad_idx * dst_spad_half_size)),
+ dst_row_size, dst_row_size_aligned, 0);
+
+ dma_queue_push_ddr_to_vtcm(dma_queue,
+ dma_make_ptr(src0_spad_data + (spad_idx * src0_spad_half_size), data_src0 + (ir * src0_row_size)),
+ src0_row_size_aligned, src0_row_size, block_size);
+ }
+
+ for (uint32_t ir = src0_start_row; ir < src0_end_row; ir += BLOCK) {
+ const uint32_t block_size = MIN(BLOCK, src0_end_row - ir);
+
+ float* dst_spad = (float *) dma_queue_pop(dma_queue).src;
+ float* src0_spad = (float *) dma_queue_pop(dma_queue).dst;
+
+ for (uint32_t ib = 0; ib < block_size; ib++) {
+ const float* src0_spad_ptr = src0_spad + ib * (src0_row_size_aligned / sizeof(float));
+ float* dst_spad_ptr = dst_spad + ib * (dst_row_size_aligned / sizeof(float));
+
+ // silu = x * sigmoid(x)
+ hvx_sigmoid_f32_aa((uint8_t *) dst_spad_ptr, (const uint8_t *) src0_spad_ptr, ne0);
+ hvx_mul_f32_aaa((uint8_t *) dst_spad_ptr, (const uint8_t *) src0_spad_ptr, (const uint8_t *) dst_spad_ptr, ne0);
+ }
+
+ dma_queue_push_vtcm_to_ddr(dma_queue,
+ dma_make_ptr(data_dst + (ir * dst_row_size), dst_spad),
+ dst_row_size, dst_row_size_aligned, block_size);
+
+ // prefetch N+2 loop iteration if any
+ const uint32_t pref_block = (ir + BLOCK * 2);
+ if (pref_block < src0_end_row) {
+ const uint32_t pref_block_size = MIN(BLOCK, src0_end_row - pref_block);
+ dma_queue_push_ddr_to_vtcm(dma_queue,
+ dma_make_ptr(src0_spad, data_src0 + (pref_block * src0_row_size)),
+ src0_row_size_aligned, src0_row_size, pref_block_size);
+ }
+ }
+
+ dma_queue_flush(dma_queue);
+
+ t2 = HAP_perf_get_qtimer_count();
+
+ FARF(HIGH, "silu-f32 %d/%d: %ux%ux%ux%u (%u:%u) -> %ux%ux%ux%u usec %u\n", ith, nth, ne00, ne01, ne02,
+ ne03, src0_start_row, src0_end_row, ne0, ne1, ne2, ne3, (unsigned) HAP_perf_qtimer_count_to_us(t2 - t1));
+}
+
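+// Constants for the tanh approximation of GELU used by the geglu kernel:
+//   gelu(x) ~= 0.5 * x * (1 + tanh(SQRT_2_OVER_PI * x * (1 + GELU_COEF_A * x * x)))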
+static const float GELU_COEF_A = 0.044715f;
+static const float SQRT_2_OVER_PI = 0.79788456080286535587989211986876f;
+
+static void glu_geglu_f32_per_thread(const struct htp_tensor * src0,
+ const struct htp_tensor * src1,
+ struct htp_tensor * dst,
+ const int32_t * op_params,
+ struct htp_spad * src0_spad,
+ struct htp_spad * src1_spad,
+ struct htp_spad * dst_spad,
+ uint32_t nth,
+ uint32_t ith,
+ uint32_t src0_nrows_per_thread,
+ dma_queue * dma_queue) {
+ htp_act_preamble3;
+
+ size_t src0_row_size = nb01;
+ size_t src1_row_size = nb11;
+ size_t dst_row_size = nb1;
+
+ uint64_t t1, t2;
+ t1 = HAP_perf_get_qtimer_count();
+
+ const uint32_t src0_nrows = ne01 * ne02 * ne03; // src0 rows
+
+ const uint32_t src0_start_row = src0_nrows_per_thread * ith;
+ const uint32_t src0_end_row = MIN(src0_start_row + src0_nrows_per_thread, src0_nrows);
+
+ // no work for this thread
+ if (src0_start_row >= src0_end_row) {
+ return;
+ }
+
+ const uint8_t * restrict data_src0 = (const uint8_t *) src0->data;
+ const uint8_t * restrict data_src1 = (const uint8_t *) src1->data;
+ uint8_t * restrict data_dst = (uint8_t *) dst->data;
+
+ const bool src1_valid = src1->ne[0];
+ const int nc = (src1_valid) ? ne00 : ne00 / 2;
+ if (!src1_valid) {
+ const int32_t swapped = op_params[1];
+ data_src1 = data_src0;
+ src1_row_size = src0_row_size;
+
+ const size_t nc_in_bytes = nc * SIZEOF_FP32;
+ data_src0 += swapped ? nc_in_bytes : 0;
+ data_src1 += swapped ? 0 : nc_in_bytes;
+ }
+
+ const size_t src0_row_size_aligned = hex_round_up(src0_row_size, VLEN);
+ const size_t src1_row_size_aligned = hex_round_up(src1_row_size, VLEN);
+ const size_t dst_row_size_aligned = hex_round_up(dst_row_size, VLEN);
+
+ uint8_t * restrict src0_spad_data = src0_spad->data + (ith * src0_spad->size_per_thread);
+ uint8_t * restrict src1_spad_data = src1_spad->data + (ith * src1_spad->size_per_thread);
+ uint8_t * restrict dst_spad_data = dst_spad->data + (ith * dst_spad->size_per_thread);
+
+ // While given src0_spad->size_per_thread, divide it to two ping-pong buffer for src0
+ size_t src0_spad_half_size = src0_spad->size_per_thread / 2;
+ size_t src1_spad_half_size = src1_spad->size_per_thread / 2;
+ size_t dst_spad_half_size = dst_spad->size_per_thread / 2;
+
+ const int BLOCK = src0_spad_half_size / src0_row_size_aligned; // How many rows can we process in one block
+ if (BLOCK == 0) {
+ FARF(ERROR,
+ "geglu-f32 : current VTCM reservation %zu is too small for even 1 row per thread, needed at least %zu\n",
+ src0_spad->size_per_thread, src0_row_size_aligned);
+ return;
+ }
+
+ // See discussion: https://github.com/ggml-org/llama.cpp/pull/18151#issuecomment-3678235379
+ for (uint32_t ir = src0_start_row, spad_idx = 0; ir < src0_end_row && spad_idx < 2; ir += BLOCK, spad_idx++) {
+ const uint32_t block_size = MIN(BLOCK, src0_end_row - ir);
+
+        // Dummy DMA transaction for sequencing (interleaving dst, src, dst, ...)
+ dma_queue_push_vtcm_to_ddr(dma_queue,
+ dma_make_ptr(data_dst, dst_spad_data + (spad_idx * dst_spad_half_size)),
+ dst_row_size, dst_row_size_aligned, 0);
+
+ dma_queue_push_ddr_to_vtcm(dma_queue,
+ dma_make_ptr(src0_spad_data + (spad_idx * src0_spad_half_size), data_src0 + (ir * src0_row_size)),
+ src0_row_size_aligned, src0_row_size, block_size);
+ dma_queue_push_ddr_to_vtcm(dma_queue,
+ dma_make_ptr(src1_spad_data + (spad_idx * src1_spad_half_size), data_src1 + (ir * src1_row_size)),
+ src1_row_size_aligned, src1_row_size, block_size);
+ }
+
+ for (uint32_t ir = src0_start_row; ir < src0_end_row; ir += BLOCK) {
+ const uint32_t block_size = MIN(BLOCK, src0_end_row - ir);
+
+ float * dst_spad = (float *) dma_queue_pop(dma_queue).src;
+ float * src0_spad = (float *) dma_queue_pop(dma_queue).dst;
+ float * src1_spad = (float *) dma_queue_pop(dma_queue).dst;
+
+ for (uint32_t ib = 0; ib < block_size; ib++) {
+ const uint8_t * src0_spad_ptr = (const uint8_t *)(src0_spad + ib * (src0_row_size_aligned / sizeof(float)));
+ const uint8_t * src1_spad_ptr = (const uint8_t *)(src1_spad + ib * (src1_row_size_aligned / sizeof(float)));
+ uint8_t * dst_spad_ptr = (uint8_t *)(dst_spad + ib * (dst_row_size_aligned / sizeof(float)));
+
+ // geglu tanh implementation
+ // geglu(x, g) = gelu(x) * g
+ // gelu(x) = 0.5f*x*(1.0f + tanhf(SQRT_2_OVER_PI*x*(1.0f + GELU_COEF_A*x*x)))
+ hvx_mul_f32_aaa(dst_spad_ptr, src0_spad_ptr, src0_spad_ptr, nc); // res = x*x
+ hvx_mul_scalar_f32_aa(dst_spad_ptr, (const uint8_t *)dst_spad_ptr, GELU_COEF_A, nc); // res = res * GELU_COEF_A
+ hvx_add_scalar_f32_aa(dst_spad_ptr, (const uint8_t *)dst_spad_ptr, 1.0f, nc); // res = res + 1.0f
+ hvx_mul_f32_aaa(dst_spad_ptr, src0_spad_ptr, (const uint8_t *)dst_spad_ptr, nc); // res = res * x
+            hvx_mul_scalar_f32_aa(dst_spad_ptr, (const uint8_t*)dst_spad_ptr, SQRT_2_OVER_PI, nc); // res = res * SQRT_2_OVER_PI
+ hvx_tanh_f32_aa((uint8_t *) dst_spad_ptr, (const uint8_t *) dst_spad_ptr, nc); // res = tanh(res)
+ hvx_add_scalar_f32_aa(dst_spad_ptr, (const uint8_t*)dst_spad_ptr, 1.0f, nc); // res = res + 1.0f
+ hvx_mul_f32_aaa(dst_spad_ptr, src0_spad_ptr, (const uint8_t *)dst_spad_ptr, nc); // res = res * x
+            hvx_mul_scalar_f32_aa(dst_spad_ptr, (const uint8_t *)dst_spad_ptr, 0.5f, nc); // res = res * 0.5f
+ hvx_mul_f32_aaa(dst_spad_ptr, (const uint8_t *)dst_spad_ptr, src1_spad_ptr, nc); // res = res * g
+ }
+
+ dma_queue_push_vtcm_to_ddr(dma_queue, dma_make_ptr(data_dst + (ir * dst_row_size), dst_spad), dst_row_size,
+ dst_row_size_aligned, block_size);
+
+ // prefetch N+2 loop iteration if any
+ const uint32_t pref_block = (ir + BLOCK * 2);
+ if (pref_block < src0_end_row) {
+ const uint32_t pref_block_size = MIN(BLOCK, src0_end_row - pref_block);
+ dma_queue_push_ddr_to_vtcm(dma_queue, dma_make_ptr(src0_spad, data_src0 + (pref_block * src0_row_size)),
+ src0_row_size_aligned, src0_row_size, pref_block_size);
+ dma_queue_push_ddr_to_vtcm(dma_queue, dma_make_ptr(src1_spad, data_src1 + (pref_block * src1_row_size)),
+ src1_row_size_aligned, src1_row_size, pref_block_size);
+ }
+ }
+
+ dma_queue_flush(dma_queue);
+
+ t2 = HAP_perf_get_qtimer_count();
+
+ FARF(HIGH, "geglu-f32 %d/%d: %ux%ux%ux%u (%u:%u) x %ux%ux%ux%u -> %ux%ux%ux%u usec %u\n", ith, nth,
+ ne00, ne01, ne02, ne03, src0_start_row, src0_end_row, ne10, ne11, ne12, ne13, ne0, ne1, ne2, ne3,
+ (unsigned) HAP_perf_qtimer_count_to_us(t2 - t1));
+}
+
+static void unary_silu_f32(unsigned int n, unsigned int i, void * data) {
+ struct htp_ops_context * octx = (struct htp_ops_context *) data;
+ unary_silu_f32_per_thread(&octx->src0, &octx->dst, octx->op_params, &octx->src0_spad, &octx->dst_spad, n, i,
+ octx->src0_nrows_per_thread, octx->ctx->dma[i]);
+}
+
+static void glu_swiglu_f32(unsigned int n, unsigned int i, void * data) {
+ struct htp_ops_context * octx = (struct htp_ops_context *) data;
+ glu_swiglu_f32_per_thread(&octx->src0, &octx->src1, &octx->dst, octx->op_params, &octx->src0_spad,
+ &octx->src1_spad, &octx->dst_spad, n, i, octx->src0_nrows_per_thread, octx->ctx->dma[i]);
+}
+
+static void glu_swiglu_oai_f32(unsigned int n, unsigned int i, void * data) {
+ struct htp_ops_context * octx = (struct htp_ops_context *) data;
+ glu_swiglu_oai_f32_per_thread(&octx->src0, &octx->src1, &octx->dst, octx->op_params, &octx->src0_spad,
+ &octx->src1_spad, &octx->dst_spad, n, i, octx->src0_nrows_per_thread, octx->ctx->dma[i]);
+}
+
+static void glu_geglu_f32(unsigned int n, unsigned int i, void * data) {
+ struct htp_ops_context * octx = (struct htp_ops_context *) data;
+ glu_geglu_f32_per_thread(&octx->src0, &octx->src1, &octx->dst, octx->op_params, &octx->src0_spad,
+ &octx->src1_spad, &octx->dst_spad, n, i, octx->src0_nrows_per_thread, octx->ctx->dma[i]);
+}
+
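+// Dispatcher for the f32 activation ops: checks that src0/dst rows are contiguous,
+// picks the per-op worker, carves the VTCM reservation into per-thread src0/src1/dst
+// scratchpads sized to as many VLEN-aligned rows as fit, and fans the rows out evenly
+// across the worker pool.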
+static int execute_op_activations_f32(struct htp_ops_context * octx) {
+ int err = HTP_STATUS_OK;
+
+ const struct htp_tensor * src0 = &octx->src0;
+ const struct htp_tensor * src1 = &octx->src1;
+ struct htp_tensor * dst = &octx->dst;
+
+ if (((src0->ne[0] * SIZEOF_FP32) != src0->nb[1]) || ((dst->ne[0] * SIZEOF_FP32) != dst->nb[1])) {
+ FARF(ERROR, "Non-contiguous tensors are not supported at this time \n");
+ return HTP_STATUS_NO_SUPPORT;
+ }
+
+ worker_callback_t act_op_func;
+ const char * op_type = NULL;
+
+ switch (octx->op) {
+ case HTP_OP_UNARY_SILU:
+ act_op_func = unary_silu_f32;
+ op_type = "silu-f32";
+ break;
+
+ case HTP_OP_GLU_SWIGLU:
+ act_op_func = glu_swiglu_f32;
+ op_type = "swiglu-f32";
+ break;
+
+ case HTP_OP_GLU_SWIGLU_OAI:
+ act_op_func = glu_swiglu_oai_f32;
+ op_type = "swiglu-oai-f32";
+ break;
+ case HTP_OP_UNARY_GELU:
+ act_op_func = unary_gelu_f32;
+ op_type = "gelu-f32";
+ break;
+
+ case HTP_OP_GLU_GEGLU:
+ act_op_func = glu_geglu_f32;
+ op_type = "geglu-f32";
+ break;
+ default:
+ FARF(ERROR, "Unsupported activations Op %u\n", octx->op);
+ return HTP_STATUS_NO_SUPPORT;
+ }
+
+ const uint32_t n_threads = octx->n_threads;
+ const uint32_t src0_nrows = src0->ne[1] * src0->ne[2] * src0->ne[3];
+
+ size_t src0_row_size = src0->nb[1];
+ size_t src1_row_size = src1->nb[1]; // zero bytes if src1 is not used
+ size_t dst_row_size = dst->nb[1];
+
+ const bool src1_valid = src1->ne[0];
+ if (!src1_valid) {
+ src1_row_size = src0_row_size;
+ }
+
+ const size_t src0_row_size_aligned = hex_round_up(src0_row_size, VLEN);
+ const size_t src1_row_size_aligned = hex_round_up(src1_row_size, VLEN);
+ const size_t dst_row_size_aligned = hex_round_up(dst_row_size, VLEN);
+ // VTCM scratchpads for all tensors
+ // N rows per thread, padded to HVX vector size
+
+ size_t spad_size_per_row = (src0_row_size_aligned + src1_row_size_aligned) + dst_row_size_aligned;
+    size_t vtcm_row_per_thread = octx->ctx->vtcm_size / (n_threads * spad_size_per_row);
+
+ // Make sure the reserved vtcm size is sufficient
+    if (vtcm_row_per_thread == 0) {
+        FARF(ERROR,
+             "act-%s : current VTCM reservation %zu is too small for even 1 row per thread, needed at least %zu\n",
+             op_type, octx->ctx->vtcm_size, spad_size_per_row * n_threads);
+        return HTP_STATUS_VTCM_TOO_SMALL;
+    }
+
+ octx->src0_spad.size_per_thread = src0_row_size_aligned * vtcm_row_per_thread;
+ octx->src1_spad.size_per_thread = src1_row_size_aligned * vtcm_row_per_thread;
+ octx->dst_spad.size_per_thread = dst_row_size_aligned * vtcm_row_per_thread;
+
+ octx->dst_spad.size = n_threads* octx->dst_spad.size_per_thread;
+ octx->src0_spad.size = n_threads* octx->src0_spad.size_per_thread;
+ octx->src1_spad.size = n_threads* octx->src1_spad.size_per_thread;
+
+ octx->src0_spad.data = octx->ctx->vtcm_base;
+ octx->src1_spad.data = octx->src0_spad.data + octx->src0_spad.size;
+ octx->dst_spad.data = octx->src1_spad.data + octx->src1_spad.size;
+
+ if (src1->ne[0]) {
+ FARF(HIGH, "%s: %ux%ux%ux%u x %ux%ux%ux%u -> %ux%ux%ux%u : src0-spad-size %u src1-spad-size %u dst-spad-size %u\n",
+ op_type, src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3], src1->ne[0], src1->ne[1], src1->ne[2],
+ src1->ne[3], dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3], octx->src0_spad.size, octx->src1_spad.size,
+ octx->dst_spad.size);
+ } else {
+ FARF(HIGH, "%s: %ux%ux%ux%u -> %ux%ux%ux%u : src0-spad-size %u src1-spad-size %u dst-spad-size %u\n", op_type,
+ src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3], dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3],
+ octx->src0_spad.size, octx->src1_spad.size, octx->dst_spad.size);
+ }
+
+ if (!(octx->flags & HTP_OPFLAGS_SKIP_COMPUTE)) {
+ uint32_t n_jobs = MIN(n_threads, src0_nrows);
+ octx->src0_nrows_per_thread = (src0_nrows + n_jobs - 1) / n_jobs;
+ worker_pool_run_func(octx->ctx->worker_pool, act_op_func, octx, n_jobs);
+ }
+
+ return err;
+}
+
+int op_activations(struct htp_ops_context * octx) {
+ int err = HTP_STATUS_OK;
+
+ switch (octx->src0.type) {
+ case HTP_TYPE_F32:
+ err = execute_op_activations_f32(octx);
+ break;
+
+ default:
+ err = HTP_STATUS_NO_SUPPORT;
+ break;
+ }
+
+ return err;
+}
diff --git a/llama.cpp/ggml/src/ggml-hexagon/htp/argsort-ops.c b/llama.cpp/ggml/src/ggml-hexagon/htp/argsort-ops.c
new file mode 100644
index 0000000..a4cee98
--- /dev/null
+++ b/llama.cpp/ggml/src/ggml-hexagon/htp/argsort-ops.c
@@ -0,0 +1,281 @@
+#include <string.h>
+#include <stdlib.h>
+#include <math.h>
+#include <HAP_farf.h>
+#include <HAP_perf.h>
+
+#define GGML_COMMON_DECL_C
+#include "ggml-common.h"
+#include "ggml.h"
+
+#include "hvx-utils.h"
+#include "hex-dma.h"
+
+#include "htp-ctx.h"
+#include "htp-msg.h"
+#include "htp-ops.h"
+
+#ifndef MIN
+#define MIN(a, b) ((a) < (b) ? (a) : (b))
+#endif
+
+struct htp_argsort_context {
+ struct htp_ops_context * octx;
+ uint32_t nrows_per_thread;
+};
+
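+// Returns true iff every f32 lane of x is strictly greater than the corresponding
+// lane of y (an HVX vector holds 32 f32 lanes, hence the compare count of 32).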
+static inline bool all_greater_f32(HVX_Vector x, HVX_Vector y)
+{
+ const HVX_Vector one = Q6_V_vsplat_R(1);
+ const HVX_Vector zero = Q6_V_vzero();
+
+ HVX_VectorPred pred = Q6_Q_vcmp_gt_VsfVsf(x, y);
+ HVX_Vector matches = Q6_V_vmux_QVV(pred, one, zero);
+ HVX_Vector sum = hvx_vec_reduce_sum_i32(matches);
+ return hvx_vec_get_i32(sum) == 32;
+}
+
+// Sorts values and mirrors swaps to indices.
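+// Hoare-style partition: the scalar pointer scans are accelerated by skipping whole
+// 32-element blocks that provably lie on the correct side of the pivot.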
+static void quicksort_values_indices_asc(float * values, int32_t * indices, int left, int right) {
+ if (left >= right) return;
+
+ int pivot_idx = (left + right) / 2;
+ float pivot = values[pivot_idx];
+ int i = left;
+ int j = right;
+
+ HVX_Vector pivot_vec = hvx_vec_splat_f32(pivot);
+ while (i <= j) {
+ // Vectorized scan for i
+ while (i <= j) {
+ // Check if we have at least one full vector
+ if (i + 32 <= j) {
+ HVX_Vector vals_vec = *(HVX_UVector *)(values + i);
+ if (all_greater_f32(pivot_vec, vals_vec)) {
+ // If all elements are < pivot, we can skip this whole block
+ i += 32;
+ continue;
+ }
+ }
+
+ // Scalar fallback / cleanup
+ if (values[i] < pivot) {
+ i++;
+ } else {
+ break;
+ }
+ }
+
+ // Vectorized scan for j
+ while (i <= j) {
+ if (j - 32 >= i) {
+ // Load 32 elements ending at j.
+ // Since we want `values[j] > pivot`, let's load from j-31 to j.
+ HVX_Vector vals_vec = *(HVX_UVector *)(values + j - 31);
+ if (all_greater_f32(vals_vec, pivot_vec)) {
+ j -= 32;
+ continue;
+ }
+ }
+
+ if (values[j] > pivot) {
+ j--;
+ } else {
+ break;
+ }
+ }
+
+ if (i <= j) {
+ float tmp_val = values[i];
+ values[i] = values[j];
+ values[j] = tmp_val;
+
+ int32_t tmp_idx = indices[i];
+ indices[i] = indices[j];
+ indices[j] = tmp_idx;
+ i++;
+ j--;
+ }
+ }
+
+ if (left < j) quicksort_values_indices_asc(values, indices, left, j);
+ if (i < right) quicksort_values_indices_asc(values, indices, i, right);
+}
+
+static void quicksort_values_indices_desc(float * values, int32_t * indices, int left, int right) {
+ if (left >= right) return;
+
+ int pivot_idx = (left + right) / 2;
+ float pivot = values[pivot_idx];
+ int i = left;
+ int j = right;
+
+ HVX_Vector pivot_vec = hvx_vec_splat_f32(pivot);
+
+ while (i <= j) {
+ // Vectorized scan for i (values[i] > pivot)
+ while (i <= j) {
+ if (i + 32 <= j) {
+ HVX_Vector vals_vec = *(HVX_UVector *)(values + i);
+ if (all_greater_f32(vals_vec, pivot_vec)) {
+ i += 32;
+ continue;
+ }
+ }
+
+ if (values[i] > pivot) {
+ i++;
+ } else {
+ break;
+ }
+ }
+
+ // Vectorized scan for j (values[j] < pivot)
+ while (i <= j) {
+ if (j - 32 >= i) {
+ HVX_Vector vals_vec = *(HVX_UVector *)(values + j - 31);
+ if (all_greater_f32(pivot_vec, vals_vec)) {
+ j -= 32;
+ continue;
+ }
+ }
+
+ if (values[j] < pivot) {
+ j--;
+ } else {
+ break;
+ }
+ }
+
+ if (i <= j) {
+ float tmp_val = values[i];
+ values[i] = values[j];
+ values[j] = tmp_val;
+
+ int32_t tmp_idx = indices[i];
+ indices[i] = indices[j];
+ indices[j] = tmp_idx;
+ i++;
+ j--;
+ }
+ }
+
+ if (left < j) quicksort_values_indices_desc(values, indices, left, j);
+ if (i < right) quicksort_values_indices_desc(values, indices, i, right);
+}
+
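+// Per-thread worker: each thread argsorts a contiguous range of rows. A row of values
+// is staged in VTCM, an index row (0..ne00-1) is materialized next to it, swaps are
+// mirrored between the two during the sort, and the sorted indices are written back.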
+static void htp_argsort_f32(unsigned int n, unsigned int i, void * data) {
+ struct htp_argsort_context * actx = (struct htp_argsort_context *)data;
+ struct htp_ops_context * octx = actx->octx;
+
+ // Unpack context
+ const struct htp_tensor * src0 = &octx->src0;
+ const struct htp_tensor * dst = &octx->dst;
+
+ // Scratchpad memory
+ uint8_t * spad = octx->src0_spad.data + octx->src0_spad.size_per_thread * i;
+
+ // Dimensions
+ uint32_t ne00 = src0->ne[0];
+ uint32_t ne01 = src0->ne[1];
+ uint32_t ne02 = src0->ne[2];
+ uint32_t ne03 = src0->ne[3];
+
+ uint32_t nb01 = src0->nb[1];
+ //uint32_t nb02 = src0->nb[2];
+ //uint32_t nb03 = src0->nb[3];
+
+ uint32_t nb1 = dst->nb[1];
+ //uint32_t nb2 = dst->nb[2];
+ //uint32_t nb3 = dst->nb[3];
+
+ // Sort order
+ enum ggml_sort_order order = (enum ggml_sort_order) octx->op_params[0];
+
+ // Rows to process
+ uint32_t total_rows = ne01 * ne02 * ne03;
+ uint32_t rows_per_thread = actx->nrows_per_thread;
+ uint32_t start_row = rows_per_thread * i;
+ uint32_t end_row = MIN(start_row + rows_per_thread, total_rows);
+
+ // Scratchpad layout:
+ // We need space for one row of float data (values) and one row of int32 indices.
+ // values: ne00 * sizeof(float)
+ // indices: ne00 * sizeof(int32_t)
+    // Each buffer is padded to a multiple of 128 bytes.
+
+ size_t values_size = hex_round_up(ne00 * sizeof(float), 128);
+ float * values_buf = (float *) spad;
+ int32_t * indices_buf = (int32_t *) (spad + values_size);
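+    // Example: with ne00 = 100, values_size = hex_round_up(400, 128) = 512, so
+    // values_buf occupies the first 512 bytes of the scratchpad and indices_buf
+    // starts at spad + 512.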
+
+ for (uint32_t r = start_row; r < end_row; r++) {
+ uint32_t src_offset = r * nb01;
+ uint32_t dst_offset = r * nb1;
+
+ uint8_t * src_ptr = (uint8_t *) src0->data + src_offset;
+ uint8_t * dst_ptr = (uint8_t *) dst->data + dst_offset;
+
+ hex_l2fetch(src_ptr, ne00 * sizeof(float), ne00 * sizeof(float), 1);
+ hvx_copy_f32_au((uint8_t*)values_buf, src_ptr, ne00);
+
+ // Initialize indices
+ for (uint32_t j = 0; j < ne00; j++) {
+ indices_buf[j] = j;
+ }
+
+ // Sort values and mirror swaps to indices
+ if (order == GGML_SORT_ORDER_ASC) {
+ quicksort_values_indices_asc(values_buf, indices_buf, 0, ne00 - 1);
+ } else {
+ quicksort_values_indices_desc(values_buf, indices_buf, 0, ne00 - 1);
+ }
+
+ // Copy indices back to DDR
+ hvx_copy_f32_ua(dst_ptr, (const uint8_t *) indices_buf, ne00);
+ }
+}
+
+int op_argsort(struct htp_ops_context * octx) {
+ // Check supported types
+ if (octx->src0.type != HTP_TYPE_F32) {
+ return HTP_STATUS_NO_SUPPORT;
+ }
+
+ // Allocate scratchpad
+ // We need 1 row of float + 1 row of int32 per thread.
+ uint32_t ne00 = octx->src0.ne[0];
+ size_t values_size = hex_round_up(ne00 * sizeof(float), 128);
+ size_t indices_size = hex_round_up(ne00 * sizeof(int32_t), 128);
+ size_t spad_per_thread = values_size + indices_size;
+
+ // Make sure we round up to 256 for alignment requirements
+ spad_per_thread = hex_round_up(spad_per_thread, 256);
+
+ size_t total_spad_size = spad_per_thread * octx->n_threads;
+
+ if (octx->ctx->vtcm_size < total_spad_size) {
+ FARF(ERROR, "argsort: VTCM size too small. Needed %zu, have %zu", total_spad_size, octx->ctx->vtcm_size);
+ return HTP_STATUS_VTCM_TOO_SMALL;
+ }
+
+ octx->src0_spad.data = octx->ctx->vtcm_base;
+ octx->src0_spad.size = total_spad_size;
+ octx->src0_spad.size_per_thread = spad_per_thread;
+
+ FARF(HIGH, "argsort: %ux%ux%ux%u -> %ux%ux%ux%u (0x%x, 0x%x)",
+ octx->src0.ne[0], octx->src0.ne[1], octx->src0.ne[2], octx->src0.ne[3],
+ octx->dst.ne[0], octx->dst.ne[1], octx->dst.ne[2], octx->dst.ne[3],
+ octx->src0.data, octx->dst.data);
+
+ uint32_t total_rows = octx->src0.ne[1] * octx->src0.ne[2] * octx->src0.ne[3];
+ uint32_t n_jobs = MIN(total_rows, octx->n_threads);
+
+ struct htp_argsort_context actx;
+ actx.octx = octx;
+ actx.nrows_per_thread = (total_rows + n_jobs - 1) / n_jobs;
+
+ // Run jobs
+ worker_pool_run_func(octx->ctx->worker_pool, htp_argsort_f32, &actx, n_jobs);
+
+ return HTP_STATUS_OK;
+}
diff --git a/llama.cpp/ggml/src/ggml-hexagon/htp/binary-ops.c b/llama.cpp/ggml/src/ggml-hexagon/htp/binary-ops.c
new file mode 100644
index 0000000..00dbcf8
--- /dev/null
+++ b/llama.cpp/ggml/src/ggml-hexagon/htp/binary-ops.c
@@ -0,0 +1,827 @@
+#pragma clang diagnostic ignored "-Wunused-variable"
+#pragma clang diagnostic ignored "-Wunused-function"
+#pragma clang diagnostic ignored "-Wunused-but-set-variable"
+
+#include <HAP_farf.h>
+#include <HAP_perf.h>
+
+#include <math.h>
+#include <string.h>
+
+#include "hex-dma.h"
+#include "hvx-utils.h"
+
+#define GGML_COMMON_DECL_C
+#include "ggml-common.h"
+#include "htp-ctx.h"
+#include "htp-msg.h"
+#include "htp-ops.h"
+
+#ifndef MIN
+#define MIN(a, b) ((a) < (b) ? (a) : (b))
+#endif
+
+// Context for binary operations
+struct htp_binary_context {
+ struct htp_ops_context * octx;
+ struct fastdiv_values dim1_div;
+ struct fastdiv_values dim2_div;
+ struct fastdiv_values dim12_div;
+
+ struct fastdiv_values src1_dim1_div; // ne11
+ struct fastdiv_values src1_dim2_div; // ne12
+ struct fastdiv_values src1_dim3_div; // ne13
+
+ uint32_t nrows_per_thread;
+ bool split_at_ne01;
+ bool split_at_ne02;
+
+ // Precomputed values
+ uint32_t block_max;
+ size_t src0_row_size_aligned;
+ size_t src1_row_size_aligned;
+ size_t dst_row_size_aligned;
+ uint32_t src1_fetch_rows; // 1 or block_max
+ uint32_t src1_dma_stride; // 0 or stride
+};
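+
+// The context above is shared by all binary_job_* kernels, which use the same
+// double-buffered (ping-pong) VTCM pipeline: each thread owns two halves of its
+// scratchpad (spad_idx toggles between them), the preamble issues DMA fetches for
+// the first two blocks, and the main loop pops a completed block, computes into
+// the dst scratchpad, pushes the result back to DDR, and immediately prefetches
+// the next block into the half it just freed.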
+
+#define htp_binary_preamble \
+ const struct htp_tensor * src0 = &octx->src0; \
+ const struct htp_tensor * src1 = &octx->src1; \
+ struct htp_tensor * dst = &octx->dst; \
+ \
+ const uint32_t ne00 = src0->ne[0]; \
+ const uint32_t ne01 = src0->ne[1]; \
+ const uint32_t ne02 = src0->ne[2]; \
+ const uint32_t ne03 = src0->ne[3]; \
+ \
+ const uint32_t ne10 = src1->ne[0]; \
+ const uint32_t ne11 = src1->ne[1]; \
+ const uint32_t ne12 = src1->ne[2]; \
+ const uint32_t ne13 = src1->ne[3]; \
+ \
+ const uint32_t nb01 = src0->nb[1]; \
+ const uint32_t nb02 = src0->nb[2]; \
+ const uint32_t nb03 = src0->nb[3]; \
+ \
+ const uint32_t nb11 = src1->nb[1]; \
+ const uint32_t nb12 = src1->nb[2]; \
+ const uint32_t nb13 = src1->nb[3]; \
+ \
+ const uint32_t nb1 = dst->nb[1]; \
+ const uint32_t nb2 = dst->nb[2]; \
+ const uint32_t nb3 = dst->nb[3];
+
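+// Number of consecutive rows to process as one DMA block: capped by the rows left
+// in this thread's range, by the VTCM capacity (block_max), and, when requested,
+// clipped at the ne01 / ne02 boundaries so a block never straddles a dimension
+// with non-contiguous strides or a broadcast src1 dimension.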
+static inline uint32_t calc_block_size(struct htp_binary_context * bctx, uint32_t ir, uint32_t end_row,
+ uint32_t ne01, uint32_t ne02) {
+ uint32_t i03, i02, i01, rem;
+ i03 = fastdiv(ir, &bctx->dim12_div);
+ rem = ir - i03 * (ne02 * ne01);
+ i02 = fastdiv(rem, &bctx->dim1_div);
+ i01 = rem - i02 * ne01;
+
+ uint32_t rows_left = end_row - ir;
+ uint32_t block_limit = rows_left;
+
+ if (bctx->split_at_ne01) {
+ block_limit = MIN(block_limit, ne01 - i01);
+ }
+ if (bctx->split_at_ne02) {
+ uint32_t rows_in_plane = (ne02 * ne01) - rem;
+ block_limit = MIN(block_limit, rows_in_plane);
+ }
+
+ return MIN(bctx->block_max, block_limit);
+}
+
+// Macro for scalar op switch; DIV is applied as a multiply by the reciprocal of the scalar
+#define COMPUTE_SCALAR_OP(DST, SRC, VAL, N) \
+ switch (octx->op) { \
+ case HTP_OP_ADD: hvx_add_scalar_f32_aa(DST, SRC, VAL, N); break; \
+ case HTP_OP_SUB: hvx_sub_scalar_f32_aa(DST, SRC, VAL, N); break; \
+ case HTP_OP_MUL: hvx_mul_scalar_f32_aa(DST, SRC, VAL, N); break; \
+ case HTP_OP_DIV: hvx_mul_scalar_f32_aa(DST, SRC, 1.0f / (VAL), N); break; \
+ default: break; \
+ }
+
+// Macro for vector op switch (All Aligned)
+#define COMPUTE_VECTOR_OP_AAA(DST, SRC0, SRC1, N) \
+ switch (octx->op) { \
+ case HTP_OP_ADD: hvx_add_f32_aaa(DST, SRC0, SRC1, N); break; \
+ case HTP_OP_SUB: hvx_sub_f32_aaa(DST, SRC0, SRC1, N); break; \
+ case HTP_OP_MUL: hvx_mul_f32_aaa(DST, SRC0, SRC1, N); break; \
+ case HTP_OP_DIV: hvx_div_f32_aaa(DST, SRC0, SRC1, N); break; \
+ default: break; \
+ }
+
+// Macro for vector op switch (Dst Aligned, Src0 Aligned, Src1 Unaligned)
+#define COMPUTE_VECTOR_OP_AAU(DST, SRC0, SRC1, N) \
+ switch (octx->op) { \
+ case HTP_OP_ADD: hvx_add_f32_aau(DST, SRC0, SRC1, N); break; \
+ case HTP_OP_SUB: hvx_sub_f32_aau(DST, SRC0, SRC1, N); break; \
+ case HTP_OP_MUL: hvx_mul_f32_aau(DST, SRC0, SRC1, N); break; \
+ case HTP_OP_DIV: hvx_div_f32_aau(DST, SRC0, SRC1, N); break; \
+ default: break; \
+ }
+
+// Macro for vector op switch (All Unaligned - generic loop used in element repeat)
+#define COMPUTE_VECTOR_OP_UUU(DST, SRC0, SRC1, N) \
+ switch (octx->op) { \
+ case HTP_OP_ADD: hvx_add_f32_uuu(DST, SRC0, SRC1, N); break; \
+ case HTP_OP_SUB: hvx_sub_f32_uuu(DST, SRC0, SRC1, N); break; \
+ case HTP_OP_MUL: hvx_mul_f32_uuu(DST, SRC0, SRC1, N); break; \
+ case HTP_OP_DIV: hvx_div_f32_uuu(DST, SRC0, SRC1, N); break; \
+ default: break; \
+ }
+
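+// Kernel variants (selected in execute_op_binary_f32):
+//   1. scalar          - src1 has a single element per row (ne10 == 1)
+//   2. same shape      - ne10 == ne00 and each src1 dim either matches src0 or is 1
+//   3. row broadcast   - same shape with a single src1 row, kept resident in VTCM
+//   4. complex         - ne10 == ne00 with a non-trivial broadcast pattern
+//   5. element repeat  - ne10 != ne00, the src1 row is tiled across the dst row
+//   6. add_id          - src1 rows gathered through src2 indices
+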
+// 1. Scalar src1 (ne10 == 1)
+static void binary_job_scalar(unsigned int nth, unsigned int ith, void * data) {
+ struct htp_binary_context * bctx = (struct htp_binary_context *) data;
+ struct htp_ops_context * octx = bctx->octx;
+ htp_binary_preamble;
+
+ const uint32_t total_rows = ne01 * ne02 * ne03;
+ const uint32_t start_row = bctx->nrows_per_thread * ith;
+ const uint32_t end_row = MIN(start_row + bctx->nrows_per_thread, total_rows);
+ if (start_row >= end_row) return;
+
+ uint8_t * src0_spad_base = octx->src0_spad.data + (ith * octx->src0_spad.size_per_thread);
+ uint8_t * dst_spad_base = octx->dst_spad.data + (ith * octx->dst_spad.size_per_thread);
+ size_t src0_spad_half = octx->src0_spad.size_per_thread / 2;
+ size_t dst_spad_half = octx->dst_spad.size_per_thread / 2;
+
+ dma_queue * q = octx->ctx->dma[ith];
+ uint32_t ir_prefetch = start_row;
+ int spad_idx = 0;
+
+ // Preamble
+ for (int k = 0; k < 2 && ir_prefetch < end_row; k++) {
+ uint32_t current_block_size = calc_block_size(bctx, ir_prefetch, end_row, ne01, ne02);
+ uint32_t i03, i02, i01, rem;
+ i03 = fastdiv(ir_prefetch, &bctx->dim12_div);
+ rem = ir_prefetch - i03 * (ne02 * ne01);
+ i02 = fastdiv(rem, &bctx->dim1_div);
+ i01 = rem - i02 * ne01;
+
+ uint8_t * src0_curr = (uint8_t *)src0->data + i03 * nb03 + i02 * nb02 + i01 * nb01;
+ uint8_t * dst_curr = (uint8_t *)dst->data + i03 * nb3 + i02 * nb2 + i01 * nb1;
+
+ uint8_t * s0_spad = src0_spad_base + spad_idx * src0_spad_half;
+ uint8_t * d_spad = dst_spad_base + spad_idx * dst_spad_half;
+
+ dma_queue_push_vtcm_to_ddr(q, dma_make_ptr(dst_curr, d_spad), nb1, bctx->dst_row_size_aligned, 0);
+ dma_queue_push(q, dma_make_ptr(s0_spad, src0_curr), bctx->src0_row_size_aligned, nb01, ne00 * sizeof(float), current_block_size);
+ ir_prefetch += current_block_size;
+ spad_idx ^= 1;
+ }
+
+ // Main loop
+ for (uint32_t ir = start_row; ir < end_row; ) {
+ uint32_t current_block_size = calc_block_size(bctx, ir, end_row, ne01, ne02);
+
+ uint8_t * d_spad = (uint8_t *) dma_queue_pop(q).src;
+ uint8_t * s0_spad = (uint8_t *) dma_queue_pop(q).dst;
+
+ uint32_t i03, i02, i01, rem;
+ i03 = fastdiv(ir, &bctx->dim12_div);
+ rem = ir - i03 * (ne02 * ne01);
+ i02 = fastdiv(rem, &bctx->dim1_div);
+ i01 = rem - i02 * ne01;
+
+ // src1 indices (broadcast/repeat)
+ uint32_t i13 = fastmodulo(i03, ne13, &bctx->src1_dim3_div);
+ uint32_t i12 = fastmodulo(i02, ne12, &bctx->src1_dim2_div);
+ uint32_t i11 = fastmodulo(i01, ne11, &bctx->src1_dim1_div);
+
+ uint8_t * src1_ptr = (uint8_t *)src1->data + i13 * nb13 + i12 * nb12 + i11 * nb11;
+ uint32_t s1_stride = (ne11 == 1) ? 0 : nb11;
+
+ for (uint32_t r = 0; r < current_block_size; r++) {
+ uint8_t * r_src0 = s0_spad + r * bctx->src0_row_size_aligned;
+ uint8_t * r_dst = d_spad + r * bctx->dst_row_size_aligned;
+ float val = *(float *)src1_ptr;
+ src1_ptr += s1_stride;
+ COMPUTE_SCALAR_OP(r_dst, r_src0, val, ne00);
+ }
+
+ uint8_t * dst_curr = (uint8_t *)dst->data + i03 * nb3 + i02 * nb2 + i01 * nb1;
+ dma_queue_push(q, dma_make_ptr(dst_curr, d_spad), nb1, bctx->dst_row_size_aligned, ne00 * sizeof(float), current_block_size);
+
+ if (ir_prefetch < end_row) {
+ uint32_t next_block_size = calc_block_size(bctx, ir_prefetch, end_row, ne01, ne02);
+ uint32_t p03, p02, p01, prem;
+ p03 = fastdiv(ir_prefetch, &bctx->dim12_div);
+ prem = ir_prefetch - p03 * (ne02 * ne01);
+ p02 = fastdiv(prem, &bctx->dim1_div);
+ p01 = prem - p02 * ne01;
+ uint8_t * s0_next = (uint8_t *)src0->data + p03 * nb03 + p02 * nb02 + p01 * nb01;
+
+ dma_queue_push(q, dma_make_ptr(s0_spad, s0_next), bctx->src0_row_size_aligned, nb01, ne00 * sizeof(float), next_block_size);
+ ir_prefetch += next_block_size;
+ }
+ ir += current_block_size;
+ }
+ dma_queue_flush(q);
+}
+
+// 2. Vector Same Shape (ne1x == ne0x) or Simple Broadcast
+static void binary_job_vector_same_shape(unsigned int nth, unsigned int ith, void * data) {
+ struct htp_binary_context * bctx = (struct htp_binary_context *) data;
+ struct htp_ops_context * octx = bctx->octx;
+ htp_binary_preamble;
+
+ const uint32_t total_rows = ne01 * ne02 * ne03;
+ const uint32_t start_row = bctx->nrows_per_thread * ith;
+ const uint32_t end_row = MIN(start_row + bctx->nrows_per_thread, total_rows);
+ if (start_row >= end_row) return;
+
+ uint8_t * src0_spad_base = octx->src0_spad.data + (ith * octx->src0_spad.size_per_thread);
+ uint8_t * src1_spad_base = octx->src1_spad.data + (ith * octx->src1_spad.size_per_thread);
+ uint8_t * dst_spad_base = octx->dst_spad.data + (ith * octx->dst_spad.size_per_thread);
+
+ size_t src0_spad_half = octx->src0_spad.size_per_thread / 2;
+ size_t src1_spad_half = octx->src1_spad.size_per_thread / 2;
+ size_t dst_spad_half = octx->dst_spad.size_per_thread / 2;
+
+ dma_queue * q = octx->ctx->dma[ith];
+ uint32_t ir_prefetch = start_row;
+ int spad_idx = 0;
+
+ for (int k = 0; k < 2 && ir_prefetch < end_row; k++) {
+ uint32_t current_block_size = calc_block_size(bctx, ir_prefetch, end_row, ne01, ne02);
+ uint32_t i03, i02, i01, rem;
+ i03 = fastdiv(ir_prefetch, &bctx->dim12_div);
+ rem = ir_prefetch - i03 * (ne02 * ne01);
+ i02 = fastdiv(rem, &bctx->dim1_div);
+ i01 = rem - i02 * ne01;
+
+ uint32_t i13 = (ne13 == 1) ? 0 : i03;
+ uint32_t i12 = (ne12 == 1) ? 0 : i02;
+ uint32_t i11 = (ne11 == 1) ? 0 : i01;
+
+ uint8_t * src0_curr = (uint8_t *)src0->data + i03 * nb03 + i02 * nb02 + i01 * nb01;
+ uint8_t * src1_base = (uint8_t *)src1->data + i13 * nb13 + i12 * nb12 + i11 * nb11;
+ uint8_t * dst_curr = (uint8_t *)dst->data + i03 * nb3 + i02 * nb2 + i01 * nb1;
+
+ uint8_t * s0_spad = src0_spad_base + spad_idx * src0_spad_half;
+ uint8_t * s1_spad = src1_spad_base + spad_idx * src1_spad_half;
+ uint8_t * d_spad = dst_spad_base + spad_idx * dst_spad_half;
+
+ dma_queue_push_vtcm_to_ddr(q, dma_make_ptr(dst_curr, d_spad), nb1, bctx->dst_row_size_aligned, 0);
+ dma_queue_push(q, dma_make_ptr(s0_spad, src0_curr), bctx->src0_row_size_aligned, nb01, ne00 * sizeof(float), current_block_size);
+ dma_queue_push(q, dma_make_ptr(s1_spad, src1_base), bctx->src1_row_size_aligned, bctx->src1_dma_stride, ne00 * sizeof(float), current_block_size);
+ ir_prefetch += current_block_size;
+ spad_idx ^= 1;
+ }
+
+ for (uint32_t ir = start_row; ir < end_row; ) {
+ uint32_t current_block_size = calc_block_size(bctx, ir, end_row, ne01, ne02);
+ uint8_t * d_spad = (uint8_t *) dma_queue_pop(q).src;
+ uint8_t * s0_spad = (uint8_t *) dma_queue_pop(q).dst;
+ uint8_t * s1_spad = (uint8_t *) dma_queue_pop(q).dst;
+
+ for (uint32_t r = 0; r < current_block_size; r++) {
+ uint8_t * r_src0 = s0_spad + r * bctx->src0_row_size_aligned;
+ uint8_t * r_src1 = s1_spad + r * bctx->src1_row_size_aligned;
+ uint8_t * r_dst = d_spad + r * bctx->dst_row_size_aligned;
+ COMPUTE_VECTOR_OP_AAA(r_dst, r_src0, r_src1, ne00);
+ }
+
+ uint32_t i03, i02, i01, rem;
+ i03 = fastdiv(ir, &bctx->dim12_div);
+ rem = ir - i03 * (ne02 * ne01);
+ i02 = fastdiv(rem, &bctx->dim1_div);
+ i01 = rem - i02 * ne01;
+ uint8_t * dst_curr = (uint8_t *)dst->data + i03 * nb3 + i02 * nb2 + i01 * nb1;
+ dma_queue_push(q, dma_make_ptr(dst_curr, d_spad), nb1, bctx->dst_row_size_aligned, ne00 * sizeof(float), current_block_size);
+
+ if (ir_prefetch < end_row) {
+ uint32_t next_block_size = calc_block_size(bctx, ir_prefetch, end_row, ne01, ne02);
+ uint32_t p03, p02, p01, prem;
+ p03 = fastdiv(ir_prefetch, &bctx->dim12_div);
+ prem = ir_prefetch - p03 * (ne02 * ne01);
+ p02 = fastdiv(prem, &bctx->dim1_div);
+ p01 = prem - p02 * ne01;
+
+ uint32_t p13 = (ne13 == 1) ? 0 : p03;
+ uint32_t p12 = (ne12 == 1) ? 0 : p02;
+ uint32_t p11 = (ne11 == 1) ? 0 : p01;
+
+ uint8_t * s0_next = (uint8_t *)src0->data + p03 * nb03 + p02 * nb02 + p01 * nb01;
+ uint8_t * s1_next = (uint8_t *)src1->data + p13 * nb13 + p12 * nb12 + p11 * nb11;
+
+ dma_queue_push(q, dma_make_ptr(s0_spad, s0_next), bctx->src0_row_size_aligned, nb01, ne00 * sizeof(float), next_block_size);
+ dma_queue_push(q, dma_make_ptr(s1_spad, s1_next), bctx->src1_row_size_aligned, bctx->src1_dma_stride, ne00 * sizeof(float), next_block_size);
+
+ ir_prefetch += next_block_size;
+ }
+ ir += current_block_size;
+ }
+ dma_queue_flush(q);
+}
+
+// 3. Row Broadcast (ne11 == 1, ne12 == 1, single row src1)
+static void binary_job_vector_row_broadcast(unsigned int nth, unsigned int ith, void * data) {
+ struct htp_binary_context * bctx = (struct htp_binary_context *) data;
+ struct htp_ops_context * octx = bctx->octx;
+ htp_binary_preamble;
+
+ const uint32_t total_rows = ne01 * ne02 * ne03;
+ const uint32_t start_row = bctx->nrows_per_thread * ith;
+ const uint32_t end_row = MIN(start_row + bctx->nrows_per_thread, total_rows);
+ if (start_row >= end_row) return;
+
+ uint8_t * src0_spad_base = octx->src0_spad.data + (ith * octx->src0_spad.size_per_thread);
+ uint8_t * src1_spad = octx->src1_spad.data + (ith * octx->src1_spad.size_per_thread);
+ uint8_t * dst_spad_base = octx->dst_spad.data + (ith * octx->dst_spad.size_per_thread);
+
+ size_t src0_spad_half = octx->src0_spad.size_per_thread / 2;
+ size_t dst_spad_half = octx->dst_spad.size_per_thread / 2;
+
+ dma_queue * q = octx->ctx->dma[ith];
+ uint32_t ir_prefetch = start_row;
+ int spad_idx = 0;
+
+ void * s1_ptr = (void *) src1_spad;
+
+ for (int k = 0; k < 2 && ir_prefetch < end_row; k++) {
+ uint32_t current_block_size = calc_block_size(bctx, ir_prefetch, end_row, ne01, ne02);
+ uint32_t i03, i02, i01, rem;
+ i03 = fastdiv(ir_prefetch, &bctx->dim12_div);
+ rem = ir_prefetch - i03 * (ne02 * ne01);
+ i02 = fastdiv(rem, &bctx->dim1_div);
+ i01 = rem - i02 * ne01;
+
+ uint8_t * src0_curr = (uint8_t *)src0->data + i03 * nb03 + i02 * nb02 + i01 * nb01;
+ uint8_t * dst_curr = (uint8_t *)dst->data + i03 * nb3 + i02 * nb2 + i01 * nb1;
+
+ uint8_t * s0_spad = src0_spad_base + spad_idx * src0_spad_half;
+ uint8_t * d_spad = dst_spad_base + spad_idx * dst_spad_half;
+
+ dma_queue_push_vtcm_to_ddr(q, dma_make_ptr(dst_curr, d_spad), nb1, bctx->dst_row_size_aligned, 0);
+ dma_queue_push(q, dma_make_ptr(s0_spad, src0_curr), bctx->src0_row_size_aligned, nb01, ne00 * sizeof(float), current_block_size);
+ ir_prefetch += current_block_size;
+ spad_idx ^= 1;
+ }
+
+ for (uint32_t ir = start_row; ir < end_row; ) {
+ uint32_t current_block_size = calc_block_size(bctx, ir, end_row, ne01, ne02);
+ uint8_t * d_spad = (uint8_t *) dma_queue_pop(q).src;
+ uint8_t * s0_spad = (uint8_t *) dma_queue_pop(q).dst;
+
+ for (uint32_t r = 0; r < current_block_size; r++) {
+ uint8_t * r_src0 = s0_spad + r * bctx->src0_row_size_aligned;
+ uint8_t * r_src1 = (uint8_t *)s1_ptr; // Constant
+ uint8_t * r_dst = d_spad + r * bctx->dst_row_size_aligned;
+ COMPUTE_VECTOR_OP_AAA(r_dst, r_src0, r_src1, ne00);
+ }
+
+ uint32_t i03, i02, i01, rem;
+ i03 = fastdiv(ir, &bctx->dim12_div);
+ rem = ir - i03 * (ne02 * ne01);
+ i02 = fastdiv(rem, &bctx->dim1_div);
+ i01 = rem - i02 * ne01;
+ uint8_t * dst_curr = (uint8_t *)dst->data + i03 * nb3 + i02 * nb2 + i01 * nb1;
+ dma_queue_push(q, dma_make_ptr(dst_curr, d_spad), nb1, bctx->dst_row_size_aligned, ne00 * sizeof(float), current_block_size);
+
+ if (ir_prefetch < end_row) {
+ uint32_t next_block_size = calc_block_size(bctx, ir_prefetch, end_row, ne01, ne02);
+ uint32_t p03, p02, p01, prem;
+ p03 = fastdiv(ir_prefetch, &bctx->dim12_div);
+ prem = ir_prefetch - p03 * (ne02 * ne01);
+ p02 = fastdiv(prem, &bctx->dim1_div);
+ p01 = prem - p02 * ne01;
+ uint8_t * s0_next = (uint8_t *)src0->data + p03 * nb03 + p02 * nb02 + p01 * nb01;
+ dma_queue_push(q, dma_make_ptr(s0_spad, s0_next), bctx->src0_row_size_aligned, nb01, ne00 * sizeof(float), next_block_size);
+ ir_prefetch += next_block_size;
+ }
+ ir += current_block_size;
+ }
+ dma_queue_flush(q);
+}
+
+// 4. Vector Complex (ne10 == ne00, complex broadcast)
+static void binary_job_vector_complex(unsigned int nth, unsigned int ith, void * data) {
+ struct htp_binary_context * bctx = (struct htp_binary_context *) data;
+ struct htp_ops_context * octx = bctx->octx;
+ htp_binary_preamble;
+
+ const uint32_t total_rows = ne01 * ne02 * ne03;
+ const uint32_t start_row = bctx->nrows_per_thread * ith;
+ const uint32_t end_row = MIN(start_row + bctx->nrows_per_thread, total_rows);
+ if (start_row >= end_row) return;
+
+ uint8_t * src0_spad_base = octx->src0_spad.data + (ith * octx->src0_spad.size_per_thread);
+ uint8_t * dst_spad_base = octx->dst_spad.data + (ith * octx->dst_spad.size_per_thread);
+ size_t src0_spad_half = octx->src0_spad.size_per_thread / 2;
+ size_t dst_spad_half = octx->dst_spad.size_per_thread / 2;
+
+ dma_queue * q = octx->ctx->dma[ith];
+ uint32_t ir_prefetch = start_row;
+ int spad_idx = 0;
+
+ for (int k = 0; k < 2 && ir_prefetch < end_row; k++) {
+ uint32_t current_block_size = calc_block_size(bctx, ir_prefetch, end_row, ne01, ne02);
+ uint32_t i03, i02, i01, rem;
+ i03 = fastdiv(ir_prefetch, &bctx->dim12_div);
+ rem = ir_prefetch - i03 * (ne02 * ne01);
+ i02 = fastdiv(rem, &bctx->dim1_div);
+ i01 = rem - i02 * ne01;
+
+ uint8_t * src0_curr = (uint8_t *)src0->data + i03 * nb03 + i02 * nb02 + i01 * nb01;
+ uint8_t * dst_curr = (uint8_t *)dst->data + i03 * nb3 + i02 * nb2 + i01 * nb1;
+
+ uint8_t * s0_spad = src0_spad_base + spad_idx * src0_spad_half;
+ uint8_t * d_spad = dst_spad_base + spad_idx * dst_spad_half;
+
+ dma_queue_push_vtcm_to_ddr(q, dma_make_ptr(dst_curr, d_spad), nb1, bctx->dst_row_size_aligned, 0);
+ dma_queue_push(q, dma_make_ptr(s0_spad, src0_curr), bctx->src0_row_size_aligned, nb01, ne00 * sizeof(float), current_block_size);
+ ir_prefetch += current_block_size;
+ spad_idx ^= 1;
+ }
+
+ for (uint32_t ir = start_row; ir < end_row; ) {
+ uint32_t current_block_size = calc_block_size(bctx, ir, end_row, ne01, ne02);
+ uint8_t * d_spad = (uint8_t *) dma_queue_pop(q).src;
+ uint8_t * s0_spad = (uint8_t *) dma_queue_pop(q).dst;
+
+ uint32_t i03, i02, i01, rem;
+ i03 = fastdiv(ir, &bctx->dim12_div);
+ rem = ir - i03 * (ne02 * ne01);
+ i02 = fastdiv(rem, &bctx->dim1_div);
+ i01 = rem - i02 * ne01;
+
+ for (uint32_t r = 0; r < current_block_size; r++) {
+ uint32_t r_i01 = i01 + r;
+ uint32_t i13 = fastmodulo(i03, ne13, &bctx->src1_dim3_div);
+ uint32_t i12 = fastmodulo(i02, ne12, &bctx->src1_dim2_div);
+ uint32_t i11 = fastmodulo(r_i01, ne11, &bctx->src1_dim1_div);
+
+ uint8_t * r_src0 = s0_spad + r * bctx->src0_row_size_aligned;
+ uint8_t * r_src1 = (uint8_t *)src1->data + i13 * nb13 + i12 * nb12 + i11 * nb11;
+ uint8_t * r_dst = d_spad + r * bctx->dst_row_size_aligned;
+
+ // Read src1 from DDR (unaligned)
+ COMPUTE_VECTOR_OP_AAU(r_dst, r_src0, r_src1, ne00);
+ }
+
+ uint8_t * dst_curr = (uint8_t *)dst->data + i03 * nb3 + i02 * nb2 + i01 * nb1;
+ dma_queue_push(q, dma_make_ptr(dst_curr, d_spad), nb1, bctx->dst_row_size_aligned, ne00 * sizeof(float), current_block_size);
+
+ if (ir_prefetch < end_row) {
+ uint32_t next_block_size = calc_block_size(bctx, ir_prefetch, end_row, ne01, ne02);
+ uint32_t p03, p02, p01, prem;
+ p03 = fastdiv(ir_prefetch, &bctx->dim12_div);
+ prem = ir_prefetch - p03 * (ne02 * ne01);
+ p02 = fastdiv(prem, &bctx->dim1_div);
+ p01 = prem - p02 * ne01;
+ uint8_t * s0_next = (uint8_t *)src0->data + p03 * nb03 + p02 * nb02 + p01 * nb01;
+ dma_queue_push(q, dma_make_ptr(s0_spad, s0_next), bctx->src0_row_size_aligned, nb01, ne00 * sizeof(float), next_block_size);
+ ir_prefetch += next_block_size;
+ }
+ ir += current_block_size;
+ }
+ dma_queue_flush(q);
+}
+
+// 5. Element Repeat (ne10 != ne00)
+static void binary_job_element_repeat(unsigned int nth, unsigned int ith, void * data) {
+ struct htp_binary_context * bctx = (struct htp_binary_context *) data;
+ struct htp_ops_context * octx = bctx->octx;
+ htp_binary_preamble;
+
+ const uint32_t total_rows = ne01 * ne02 * ne03;
+ const uint32_t start_row = bctx->nrows_per_thread * ith;
+ const uint32_t end_row = MIN(start_row + bctx->nrows_per_thread, total_rows);
+ if (start_row >= end_row) return;
+
+ uint8_t * src0_spad_base = octx->src0_spad.data + (ith * octx->src0_spad.size_per_thread);
+ uint8_t * dst_spad_base = octx->dst_spad.data + (ith * octx->dst_spad.size_per_thread);
+ size_t src0_spad_half = octx->src0_spad.size_per_thread / 2;
+ size_t dst_spad_half = octx->dst_spad.size_per_thread / 2;
+
+ dma_queue * q = octx->ctx->dma[ith];
+ uint32_t ir_prefetch = start_row;
+ int spad_idx = 0;
+
+ for (int k = 0; k < 2 && ir_prefetch < end_row; k++) {
+ uint32_t current_block_size = calc_block_size(bctx, ir_prefetch, end_row, ne01, ne02);
+ uint32_t i03, i02, i01, rem;
+ i03 = fastdiv(ir_prefetch, &bctx->dim12_div);
+ rem = ir_prefetch - i03 * (ne02 * ne01);
+ i02 = fastdiv(rem, &bctx->dim1_div);
+ i01 = rem - i02 * ne01;
+
+ uint8_t * src0_curr = (uint8_t *)src0->data + i03 * nb03 + i02 * nb02 + i01 * nb01;
+ uint8_t * dst_curr = (uint8_t *)dst->data + i03 * nb3 + i02 * nb2 + i01 * nb1;
+
+ uint8_t * s0_spad = src0_spad_base + spad_idx * src0_spad_half;
+ uint8_t * d_spad = dst_spad_base + spad_idx * dst_spad_half;
+
+ dma_queue_push_vtcm_to_ddr(q, dma_make_ptr(dst_curr, d_spad), nb1, bctx->dst_row_size_aligned, 0);
+ dma_queue_push(q, dma_make_ptr(s0_spad, src0_curr), bctx->src0_row_size_aligned, nb01, ne00 * sizeof(float), current_block_size);
+ ir_prefetch += current_block_size;
+ spad_idx ^= 1;
+ }
+
+ for (uint32_t ir = start_row; ir < end_row; ) {
+ uint32_t current_block_size = calc_block_size(bctx, ir, end_row, ne01, ne02);
+ uint8_t * d_spad = (uint8_t *) dma_queue_pop(q).src;
+ uint8_t * s0_spad = (uint8_t *) dma_queue_pop(q).dst;
+
+ uint32_t i03, i02, i01, rem;
+ i03 = fastdiv(ir, &bctx->dim12_div);
+ rem = ir - i03 * (ne02 * ne01);
+ i02 = fastdiv(rem, &bctx->dim1_div);
+ i01 = rem - i02 * ne01;
+
+ for (uint32_t r = 0; r < current_block_size; r++) {
+ uint32_t r_i01 = i01 + r;
+ uint32_t i13 = fastmodulo(i03, ne13, &bctx->src1_dim3_div);
+ uint32_t i12 = fastmodulo(i02, ne12, &bctx->src1_dim2_div);
+ uint32_t i11 = fastmodulo(r_i01, ne11, &bctx->src1_dim1_div);
+
+ uint8_t * r_src0 = s0_spad + r * bctx->src0_row_size_aligned;
+ uint8_t * r_src1_row = (uint8_t *)src1->data + i13 * nb13 + i12 * nb12 + i11 * nb11;
+ uint8_t * r_dst = d_spad + r * bctx->dst_row_size_aligned;
+
+ // Repeat src1 row
+ for (uint32_t c = 0; c < ne00; c += ne10) {
+ uint32_t len = MIN(ne10, ne00 - c);
+                // Use the fully-unaligned variant: the per-segment offsets are not guaranteed to be vector-aligned
+ COMPUTE_VECTOR_OP_UUU(r_dst + c * sizeof(float), r_src0 + c * sizeof(float), r_src1_row, len);
+ }
+ }
+
+ uint8_t * dst_curr = (uint8_t *)dst->data + i03 * nb3 + i02 * nb2 + i01 * nb1;
+ dma_queue_push(q, dma_make_ptr(dst_curr, d_spad), nb1, bctx->dst_row_size_aligned, ne00 * sizeof(float), current_block_size);
+
+ if (ir_prefetch < end_row) {
+ uint32_t next_block_size = calc_block_size(bctx, ir_prefetch, end_row, ne01, ne02);
+ uint32_t p03, p02, p01, prem;
+ p03 = fastdiv(ir_prefetch, &bctx->dim12_div);
+ prem = ir_prefetch - p03 * (ne02 * ne01);
+ p02 = fastdiv(prem, &bctx->dim1_div);
+ p01 = prem - p02 * ne01;
+ uint8_t * s0_next = (uint8_t *)src0->data + p03 * nb03 + p02 * nb02 + p01 * nb01;
+ dma_queue_push(q, dma_make_ptr(s0_spad, s0_next), bctx->src0_row_size_aligned, nb01, ne00 * sizeof(float), next_block_size);
+ ir_prefetch += next_block_size;
+ }
+ ir += current_block_size;
+ }
+ dma_queue_flush(q);
+}
+
+// 6. ADD_ID (src1 gathered via src2 indices)
+static void binary_job_add_id(unsigned int nth, unsigned int ith, void * data) {
+ struct htp_binary_context * bctx = (struct htp_binary_context *) data;
+ struct htp_ops_context * octx = bctx->octx;
+
+ const struct htp_tensor * src0 = &octx->src0;
+ const struct htp_tensor * src1 = &octx->src1;
+ const struct htp_tensor * src2 = &octx->src2;
+ struct htp_tensor * dst = &octx->dst;
+
+ const uint32_t ne00 = src0->ne[0];
+ const uint32_t ne01 = src0->ne[1];
+ const uint32_t ne02 = src0->ne[2];
+ const uint32_t ne03 = src0->ne[3];
+ const uint32_t ne11 = src1->ne[1]; // for bounds check
+
+ const uint32_t nb01 = src0->nb[1];
+ const uint32_t nb02 = src0->nb[2];
+ const uint32_t nb03 = src0->nb[3];
+ const uint32_t nb11 = src1->nb[1]; // src1 row stride
+ const uint32_t nb1 = dst->nb[1];
+ const uint32_t nb2 = dst->nb[2];
+ const uint32_t nb3 = dst->nb[3];
+
+ const uint32_t total_rows = ne01 * ne02 * ne03;
+ const uint32_t start_row = bctx->nrows_per_thread * ith;
+ const uint32_t end_row = MIN(start_row + bctx->nrows_per_thread, total_rows);
+ if (start_row >= end_row) return;
+
+ uint8_t * src0_spad_base = octx->src0_spad.data + (ith * octx->src0_spad.size_per_thread);
+ uint8_t * dst_spad_base = octx->dst_spad.data + (ith * octx->dst_spad.size_per_thread);
+ size_t src0_spad_half = octx->src0_spad.size_per_thread / 2;
+ size_t dst_spad_half = octx->dst_spad.size_per_thread / 2;
+
+ dma_queue * q = octx->ctx->dma[ith];
+ uint32_t ir_prefetch = start_row;
+ int spad_idx = 0;
+
+ for (int k = 0; k < 2 && ir_prefetch < end_row; k++) {
+ uint32_t current_block_size = calc_block_size(bctx, ir_prefetch, end_row, ne01, ne02);
+ uint32_t i03, i02, i01, rem;
+ i03 = fastdiv(ir_prefetch, &bctx->dim12_div);
+ rem = ir_prefetch - i03 * (ne02 * ne01);
+ i02 = fastdiv(rem, &bctx->dim1_div);
+ i01 = rem - i02 * ne01;
+
+ uint8_t * src0_curr = (uint8_t *)src0->data + i03 * nb03 + i02 * nb02 + i01 * nb01;
+ uint8_t * dst_curr = (uint8_t *)dst->data + i03 * nb3 + i02 * nb2 + i01 * nb1;
+
+ uint8_t * s0_spad = src0_spad_base + spad_idx * src0_spad_half;
+ uint8_t * d_spad = dst_spad_base + spad_idx * dst_spad_half;
+
+ dma_queue_push_vtcm_to_ddr(q, dma_make_ptr(dst_curr, d_spad), nb1, bctx->dst_row_size_aligned, 0);
+ dma_queue_push(q, dma_make_ptr(s0_spad, src0_curr), bctx->src0_row_size_aligned, nb01, ne00 * sizeof(float), current_block_size);
+ ir_prefetch += current_block_size;
+ spad_idx ^= 1;
+ }
+
+ for (uint32_t ir = start_row; ir < end_row; ) {
+ uint32_t current_block_size = calc_block_size(bctx, ir, end_row, ne01, ne02);
+ uint8_t * d_spad = (uint8_t *) dma_queue_pop(q).src;
+ uint8_t * s0_spad = (uint8_t *) dma_queue_pop(q).dst;
+
+ uint32_t i03, i02, i01, rem;
+ i03 = fastdiv(ir, &bctx->dim12_div);
+ rem = ir - i03 * (ne02 * ne01);
+ i02 = fastdiv(rem, &bctx->dim1_div);
+ i01 = rem - i02 * ne01;
+
+ for (uint32_t r = 0; r < current_block_size; r++) {
+ uint32_t r_i01 = i01 + r; // linear within block since we split at ne01
+
+ const int32_t idx = *(int32_t *)((char *)src2->data + r_i01 * src2->nb[0] + i02 * src2->nb[1]);
+
+ uint8_t * r_src1 = (uint8_t *)src1->data + idx * nb11;
+ uint8_t * r_src0 = s0_spad + r * bctx->src0_row_size_aligned;
+ uint8_t * r_dst = d_spad + r * bctx->dst_row_size_aligned;
+
+ hvx_add_f32_aau(r_dst, r_src0, r_src1, ne00);
+ }
+
+ uint8_t * dst_curr = (uint8_t *)dst->data + i03 * nb3 + i02 * nb2 + i01 * nb1;
+ dma_queue_push(q, dma_make_ptr(dst_curr, d_spad), nb1, bctx->dst_row_size_aligned, ne00 * sizeof(float), current_block_size);
+
+ if (ir_prefetch < end_row) {
+ uint32_t next_block_size = calc_block_size(bctx, ir_prefetch, end_row, ne01, ne02);
+ uint32_t p03, p02, p01, prem;
+ p03 = fastdiv(ir_prefetch, &bctx->dim12_div);
+ prem = ir_prefetch - p03 * (ne02 * ne01);
+ p02 = fastdiv(prem, &bctx->dim1_div);
+ p01 = prem - p02 * ne01;
+ uint8_t * s0_next = (uint8_t *)src0->data + p03 * nb03 + p02 * nb02 + p01 * nb01;
+ dma_queue_push(q, dma_make_ptr(s0_spad, s0_next), bctx->src0_row_size_aligned, nb01, ne00 * sizeof(float), next_block_size);
+ ir_prefetch += next_block_size;
+ }
+ ir += current_block_size;
+ }
+ dma_queue_flush(q);
+}
+
+static int execute_op_binary_f32(struct htp_ops_context * octx) {
+ const struct htp_tensor * src0 = &octx->src0;
+ const struct htp_tensor * src1 = &octx->src1;
+ struct htp_tensor * dst = &octx->dst;
+
+ const uint32_t n_threads = octx->n_threads;
+ const uint32_t src0_nrows = src0->ne[1] * src0->ne[2] * src0->ne[3];
+
+ // Use packed row sizes for VTCM allocation
+ const size_t src0_row_size = src0->ne[0] * sizeof(float);
+ const size_t src1_row_size = src1->ne[0] * sizeof(float);
+ const size_t dst_row_size = dst->ne[0] * sizeof(float);
+
+ // Align to VLEN
+ const size_t src0_row_size_aligned = hex_round_up(src0_row_size, VLEN);
+ const size_t dst_row_size_aligned = hex_round_up(dst_row_size, VLEN);
+ size_t src1_row_size_aligned = hex_round_up(src1_row_size, VLEN);
+
+ bool is_add_id = (octx->op == HTP_OP_ADD_ID);
+ bool is_scalar = !is_add_id && (src1->ne[0] == 1);
+
+    // Determine which kernel we will use; this drives both the scratchpad sizing and the dispatch below
+ bool use_vector_same = !is_add_id && !is_scalar && src1->ne[0] == src0->ne[0] &&
+ (src1->ne[1] == src0->ne[1] || src1->ne[1] == 1) &&
+ (src1->ne[2] == src0->ne[2] || src1->ne[2] == 1) &&
+ (src1->ne[3] == src0->ne[3] || src1->ne[3] == 1);
+
+ bool is_row_bcast = use_vector_same && (src1->ne[1] == 1 && src1->ne[2] == 1 && src1->ne[3] == 1);
+ bool use_complex = !is_add_id && !is_scalar && !use_vector_same && (src1->ne[0] == src0->ne[0]);
+ bool use_repeat = !is_add_id && !is_scalar && !use_vector_same && (src1->ne[0] != src0->ne[0]);
+
+ size_t spad_row_total;
+ if (is_scalar) {
+ spad_row_total = 2 * (src0_row_size_aligned + dst_row_size_aligned);
+ } else if (is_row_bcast) {
+ spad_row_total = 2 * (src0_row_size_aligned + dst_row_size_aligned);
+ } else if (use_vector_same) {
+ spad_row_total = 2 * (src0_row_size_aligned + src1_row_size_aligned + dst_row_size_aligned);
+ } else if (is_add_id) {
+ spad_row_total = 2 * (src0_row_size_aligned + dst_row_size_aligned); // src1 read directly
+ } else {
+ spad_row_total = 2 * (src0_row_size_aligned + dst_row_size_aligned);
+ }
+
+ size_t rows_per_buffer = octx->ctx->vtcm_size / (n_threads * spad_row_total);
+ // Adjust for static src1 in row_bcast case
+ if (is_row_bcast) {
+ size_t needed_static = src1_row_size_aligned;
+ if (octx->ctx->vtcm_size < needed_static) return HTP_STATUS_VTCM_TOO_SMALL;
+ size_t avail = octx->ctx->vtcm_size - needed_static;
+ rows_per_buffer = avail / (n_threads * spad_row_total);
+ }
+
+ if (rows_per_buffer < 1) {
+ FARF(ERROR, "binary-f32: VTCM too small\n");
+ return HTP_STATUS_VTCM_TOO_SMALL;
+ }
+
+ octx->src0_spad.size_per_thread = rows_per_buffer * 2 * src0_row_size_aligned;
+ octx->dst_spad.size_per_thread = rows_per_buffer * 2 * dst_row_size_aligned;
+
+ if (is_scalar || use_complex || use_repeat || is_add_id) {
+ octx->src1_spad.size_per_thread = 0;
+ } else if (is_row_bcast) {
+ octx->src1_spad.size_per_thread = 0;
+ } else {
+ octx->src1_spad.size_per_thread = rows_per_buffer * 2 * src1_row_size_aligned;
+ }
+
+ octx->src0_spad.size = n_threads * octx->src0_spad.size_per_thread;
+ if (is_row_bcast) {
+ octx->src1_spad.size = src1_row_size_aligned;
+ } else {
+ octx->src1_spad.size = n_threads * octx->src1_spad.size_per_thread;
+ }
+ octx->dst_spad.size = n_threads * octx->dst_spad.size_per_thread;
+
+ if (octx->ctx->vtcm_size < (octx->src0_spad.size + octx->src1_spad.size + octx->dst_spad.size)) {
+ return HTP_STATUS_VTCM_TOO_SMALL;
+ }
+
+ octx->src0_spad.data = octx->ctx->vtcm_base;
+ octx->src1_spad.data = octx->src0_spad.data + octx->src0_spad.size;
+ octx->dst_spad.data = octx->src1_spad.data + octx->src1_spad.size;
+
+ if ((octx->flags & HTP_OPFLAGS_SKIP_COMPUTE)) {
+ return HTP_STATUS_OK;
+ }
+
+ uint32_t n_jobs = MIN(n_threads, src0_nrows);
+
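+    // For the row-broadcast kernel, stage the single src1 row into VTCM up front
+    // using thread 0's DMA queue; it is shared read-only by all worker threads.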
+ dma_queue * q = octx->ctx->dma[0];
+ if (is_row_bcast) {
+ dma_queue_push(q, dma_make_ptr(octx->src1_spad.data, (const void *) src1->data), src1_row_size_aligned, 0, src1->ne[0] * sizeof(float), 1);
+ }
+
+ struct htp_binary_context bctx;
+ bctx.octx = octx;
+ bctx.nrows_per_thread = (src0_nrows + n_jobs - 1) / n_jobs;
+ bctx.block_max = rows_per_buffer;
+ bctx.src0_row_size_aligned = src0_row_size_aligned;
+ bctx.src1_row_size_aligned = src1_row_size_aligned;
+ bctx.dst_row_size_aligned = dst_row_size_aligned;
+
+ bctx.dim1_div = init_fastdiv_values(src0->ne[1]);
+ bctx.dim2_div = init_fastdiv_values(src0->ne[2]);
+ bctx.dim12_div = init_fastdiv_values(src0->ne[1] * src0->ne[2]);
+
+ bctx.src1_dim1_div = init_fastdiv_values(src1->ne[1]);
+ bctx.src1_dim2_div = init_fastdiv_values(src1->ne[2]);
+ bctx.src1_dim3_div = init_fastdiv_values(src1->ne[3]);
+
+ bool src0_contig_dim1 = (src0->nb[2] == src0->ne[1] * src0->nb[1]);
+ bool dst_contig_dim1 = (dst->nb[2] == src0->ne[1] * dst->nb[1]);
+
+ bool src0_contig_dim2 = (src0->nb[3] == src0->ne[2] * src0->nb[2]);
+ bool dst_contig_dim2 = (dst->nb[3] == src0->ne[2] * dst->nb[2]);
+
+ bctx.split_at_ne01 = (src0->ne[2] > 1) &&
+ ((src1->ne[1] > 1) || (src1->ne[2] > 1) || !src0_contig_dim1 || !dst_contig_dim1);
+
+ bctx.split_at_ne02 = (src0->ne[3] > 1) &&
+ ((src1->ne[2] > 1) || (src1->ne[3] > 1) || !src0_contig_dim2 || !dst_contig_dim2);
+
+ // Precompute specific kernel parameters
+ if (use_vector_same) {
+ bctx.src1_dma_stride = (src1->ne[1] == 1) ? 0 : src1->nb[1];
+ bctx.src1_fetch_rows = (src1->ne[1] == 1) ? 1 : rows_per_buffer;
+ }
+
+ worker_callback_t worker_func;
+ if (is_add_id) worker_func = binary_job_add_id;
+ else if (is_scalar) worker_func = binary_job_scalar;
+ else if (is_row_bcast) worker_func = binary_job_vector_row_broadcast;
+ else if (use_vector_same) worker_func = binary_job_vector_same_shape;
+ else if (use_complex) worker_func = binary_job_vector_complex;
+ else worker_func = binary_job_element_repeat;
+
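+    // Pop the src1 row transfer issued above so the shared VTCM copy is in place
+    // before the worker threads start reading it.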
+ if (is_row_bcast) {
+ dma_queue_pop(q);
+ }
+
+ worker_pool_run_func(octx->ctx->worker_pool, worker_func, &bctx, n_jobs);
+
+ return HTP_STATUS_OK;
+}
+
+int op_binary(struct htp_ops_context * octx) {
+ if (octx->src0.type == HTP_TYPE_F32) {
+ return execute_op_binary_f32(octx);
+ }
+ return HTP_STATUS_NO_SUPPORT;
+}
diff --git a/llama.cpp/ggml/src/ggml-hexagon/htp/cmake-toolchain.cmake b/llama.cpp/ggml/src/ggml-hexagon/htp/cmake-toolchain.cmake
new file mode 100644
index 0000000..7fa236e
--- /dev/null
+++ b/llama.cpp/ggml/src/ggml-hexagon/htp/cmake-toolchain.cmake
@@ -0,0 +1,157 @@
+if (HEXAGON_TOOLCHAIN_INCLUDED)
+ return()
+endif()
+set(HEXAGON_TOOLCHAIN_INCLUDED true)
+
+#Cross Compiling for Hexagon
+set(HEXAGON TRUE)
+set(CMAKE_SYSTEM_NAME QURT)
+set(CMAKE_SYSTEM_PROCESSOR Hexagon)
+set(CMAKE_SYSTEM_VERSION "1") #${HEXAGON_PLATFORM_LEVEL})
+set(CMAKE_FIND_ROOT_PATH_MODE_PROGRAM NEVER)
+set(CMAKE_FIND_ROOT_PATH_MODE_LIBRARY ONLY)
+set(CMAKE_FIND_ROOT_PATH_MODE_INCLUDE ONLY)
+set(CMAKE_FIND_ROOT_PATH_MODE_PACKAGE ONLY)
+set(CUSTOM_RUNELF_PATH "")
+
+#Preserve backward compatibility with the EAI addon.
+if (NOT HEXAGON_SDK_ROOT)
+ set(HEXAGON_SDK_ROOT $ENV{HEXAGON_SDK_ROOT})
+endif()
+
+if (NOT HEXAGON_TOOLS_ROOT)
+ if (DEFINED ENV{HEXAGON_TOOLS_ROOT})
+ set(HEXAGON_TOOLS_ROOT $ENV{HEXAGON_TOOLS_ROOT})
+ endif()
+ if(NOT HEXAGON_TOOLS_ROOT)
+ set(HEXAGON_TOOLS_ROOT $ENV{DEFAULT_HEXAGON_TOOLS_ROOT})
+ endif()
+endif()
+
+file(TO_CMAKE_PATH "${HEXAGON_TOOLS_ROOT}" HEXAGON_TOOLS_ROOT)
+file(TO_CMAKE_PATH "${HEXAGON_SDK_ROOT}" HEXAGON_SDK_ROOT)
+
+#Get the Binary extension of the Hexagon Toolchain
+if(CMAKE_HOST_SYSTEM_NAME STREQUAL Windows)
+ set(HEXAGON_TOOLCHAIN_SUFFIX .exe)
+endif()
+message(DEBUG "CMAKE_HOST_SYSTEM_NAME:${CMAKE_HOST_SYSTEM_NAME}")
+
+include(${HEXAGON_SDK_ROOT}/build/cmake/hexagon_arch.cmake)
+
+set(HEXAGON_TOOLCHAIN ${HEXAGON_TOOLS_ROOT})
+set(HEXAGON_LIB_DIR "${HEXAGON_TOOLCHAIN}/Tools/target/hexagon/lib")
+set(HEXAGON_ISS_DIR ${HEXAGON_TOOLCHAIN}/Tools/lib/iss)
+
+set(CMAKE_TRY_COMPILE_PLATFORM_VARIABLES
+ HEXAGON_SDK_ROOT
+ HEXAGON_TOOLS_ROOT
+)
+
+#QURT Related includes and linker flags
+set(V_ARCH ${HEXAGON_ARCH})
+set(_QURT_INSTALL_DIR "${HEXAGON_SDK_ROOT}/rtos/qurt/ADSP${V_ARCH}MP${V_ARCH_EXTN}")
+set(_QURT_INSTALL_DIR "${HEXAGON_SDK_ROOT}/rtos/qurt/compute${V_ARCH}${V_ARCH_EXTN}")
+
+if( ${TREE} MATCHES PAKMAN )
+ set(_QURT_INSTALL_DIR "${QURT_IMAGE_DIR}/compute${V_ARCH}${V_ARCH_EXTN}")
+endif()
+message(DEBUG "_QURT_INSTALL_DIR:${_QURT_INSTALL_DIR}")
+set(RTOS_DIR ${_QURT_INSTALL_DIR})
+set(QCC_DIR "${HEXAGON_QCC_DIR}/${V_ARCH}/G0")
+set(TARGET_DIR "${HEXAGON_LIB_DIR}/${V_ARCH}/G0")
+
+include_directories(
+ ${_QURT_INSTALL_DIR}/include
+ ${_QURT_INSTALL_DIR}/include/qurt
+ ${_QURT_INSTALL_DIR}/include/posix
+ )
+
+set(QURT_START_LINK_LIBS)
+set(QURT_START_LINK_LIBS
+ "${TARGET_DIR}/init.o"
+ "${RTOS_DIR}/lib/crt1.o"
+ "${RTOS_DIR}/lib/debugmon.o"
+ "${RTOS_DIR}/lib/libqurt.a"
+ "${TARGET_DIR}/libc.a"
+ "${TARGET_DIR}/libqcc.a"
+ "${TARGET_DIR}/libhexagon.a"
+ "${RTOS_DIR}/lib/libqurtcfs.a"
+ "${RTOS_DIR}/lib/libtimer_island.a"
+ "${RTOS_DIR}/lib/libtimer_main.a"
+ "${RTOS_DIR}/lib/libposix.a"
+ )
+STRING(REPLACE ";" " " QURT_START_LINK_LIBS "${QURT_START_LINK_LIBS}")
+
+set(QURT_END_LINK_LIBS
+ ${TARGET_DIR}/fini.o
+ )
+
+#Non QURT related includes and linker flags
+
+set(TARGET_DIR_NOOS "${HEXAGON_TOOLCHAIN}/Tools/target/hexagon/lib/${HEXAGON_ARCH}")
+
+if (NOT NO_WRAP_MEM_API)
+ set(WRAP_MALLOC -Wl,--wrap=malloc)
+ set(WRAP_CALLOC -Wl,--wrap=calloc)
+ set(WRAP_FREE -Wl,--wrap=free)
+ set(WRAP_REALLOC -Wl,--wrap=realloc)
+ set(WRAP_MEMALIGN -Wl,--wrap=memalign)
+endif()
+
+set(PIC_SHARED_LD_FLAGS
+ -mcpu=${V_ARCH} -m${V_ARCH} -mhvx=${V_ARCH}
+ -G0
+ -fpic
+ -Wl,-Bsymbolic
+ -Wl,-L${TARGET_DIR_NOOS}/G0/pic
+ -Wl,-L${HEXAGON_TOOLCHAIN}/Tools/target/hexagon/lib/
+ -Wl,--no-threads ${WRAP_MALLOC} ${WRAP_CALLOC} ${WRAP_FREE} ${WRAP_REALLOC} ${WRAP_MEMALIGN}
+ -shared
+ "-o <TARGET> <SONAME_FLAG><TARGET_SONAME>"
+ "<LINK_FLAGS>"
+ -Wl,--start-group
+ "<OBJECTS>"
+ "<LINK_LIBRARIES>"
+ -Wl,--end-group
+ -lc
+ )
+STRING(REPLACE ";" " " PIC_SHARED_LD_FLAGS "${PIC_SHARED_LD_FLAGS}")
+
+set(HEXAGON_PIC_SHARED_LINK_OPTIONS "${PIC_SHARED_LD_FLAGS}")
+
+#System include paths
+include_directories(SYSTEM ${HEXAGON_SDK_ROOT}/incs)
+include_directories(SYSTEM ${HEXAGON_SDK_ROOT}/incs/stddef)
+include_directories(SYSTEM ${HEXAGON_SDK_ROOT}/ipc/fastrpc/incs)
+
+#LLVM toolchain setup
+#Compiler paths, options and architecture
+set(CMAKE_C_COMPILER ${HEXAGON_TOOLCHAIN}/Tools/bin/hexagon-clang${HEXAGON_TOOLCHAIN_SUFFIX})
+set(CMAKE_CXX_COMPILER ${HEXAGON_TOOLCHAIN}/Tools/bin/hexagon-clang++${HEXAGON_TOOLCHAIN_SUFFIX})
+set(CMAKE_AR ${HEXAGON_TOOLCHAIN}/Tools/bin/hexagon-ar${HEXAGON_TOOLCHAIN_SUFFIX})
+set(CMAKE_ASM_COMPILER ${HEXAGON_TOOLCHAIN}/Tools/bin/hexagon-clang++${HEXAGON_TOOLCHAIN_SUFFIX})
+set(HEXAGON_LINKER ${CMAKE_C_COMPILER})
+set(CMAKE_PREFIX_PATH ${HEXAGON_TOOLCHAIN}/Tools/target/hexagon)
+
+set(CMAKE_SHARED_LIBRARY_SONAME_C_FLAG "-Wl,-soname,")
+set(CMAKE_SHARED_LIBRARY_SONAME_CXX_FLAG "-Wl,-soname,")
+
+#Compiler Options
+set(COMMON_FLAGS "-mcpu=hexagon${V_ARCH} -m${V_ARCH} -mhvx=${V_ARCH} -fvectorize -Wall -Werror -fno-zero-initialized-in-bss -G0 -fdata-sections -fpic ${XQF_ARGS}")
+
+set(CMAKE_CXX_FLAGS_DEBUG "${COMMON_FLAGS} -O0 -D_DEBUG -g")
+set(CMAKE_CXX_FLAGS_RELWITHDEBINFO "${COMMON_FLAGS} -O3 -g")
+set(CMAKE_CXX_FLAGS_RELEASE "${COMMON_FLAGS} -O3")
+
+set(CMAKE_C_FLAGS_DEBUG "${COMMON_FLAGS} -O0 -D_DEBUG -g")
+set(CMAKE_C_FLAGS_RELWITHDEBINFO "${COMMON_FLAGS} -O3 -g")
+set(CMAKE_C_FLAGS_RELEASE "${COMMON_FLAGS} -O3")
+
+set(CMAKE_ASM_FLAGS_DEBUG "${COMMON_FLAGS} ${CMAKE_CXX_FLAGS_DEBUG}")
+set(CMAKE_ASM_FLAGS_RELEASE "${COMMON_FLAGS} ${CMAKE_CXX_FLAGS_RELEASE}")
+set(CMAKE_ASM_FLAGS_RELWITHDEBINFO "${COMMON_FLAGS} ${CMAKE_CXX_FLAGS_RELWITHDEBINFO}" )
+
+#Linker Options
+set(CMAKE_C_CREATE_SHARED_LIBRARY "${HEXAGON_LINKER} ${HEXAGON_PIC_SHARED_LINK_OPTIONS}")
+set(CMAKE_CXX_CREATE_SHARED_LIBRARY "${HEXAGON_LINKER} ${HEXAGON_PIC_SHARED_LINK_OPTIONS}")
diff --git a/llama.cpp/ggml/src/ggml-hexagon/htp/cpy-ops.c b/llama.cpp/ggml/src/ggml-hexagon/htp/cpy-ops.c
new file mode 100644
index 0000000..559ca18
--- /dev/null
+++ b/llama.cpp/ggml/src/ggml-hexagon/htp/cpy-ops.c
@@ -0,0 +1,251 @@
+#pragma clang diagnostic ignored "-Wunused-variable"
+#pragma clang diagnostic ignored "-Wunused-function"
+#pragma clang diagnostic ignored "-Wunused-but-set-variable"
+
+#include <HAP_farf.h>
+#include <HAP_perf.h>
+
+#include <math.h>
+#include <string.h>
+
+#define GGML_COMMON_DECL_C
+#include "ggml-common.h"
+#include "htp-ctx.h"
+#include "htp-msg.h"
+#include "htp-ops.h"
+#include "hvx-utils.h"
+
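+// Per-op copy context. op_cpy() fills in the type/block geometry and picks one of
+// the cpy_thread_* callbacks below based on src/dst types and shapes; every
+// callback parallelizes over src0 rows using src0_nrows_per_thread.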
+struct htp_copy_context {
+ struct htp_ops_context * octx;
+
+ uint32_t src0_type_size;
+ uint32_t src0_block_size;
+
+ uint32_t dst_type_size;
+ uint32_t dst_block_size;
+
+ uint32_t src0_blocks_per_row;
+ uint32_t dst_blocks_per_row;
+
+ uint32_t src0_nrows_per_thread;
+
+ void (*copy)(struct htp_copy_context * ct, struct htp_ops_context * octx, int nth, int ith);
+};
+
+#define cpy_preamble \
+ struct htp_tensor *src0 = &octx->src0; \
+ struct htp_tensor *dst = &octx->dst; \
+ \
+ const uint32_t ne00 = src0->ne[0]; \
+ const uint32_t ne01 = src0->ne[1]; \
+ const uint32_t ne02 = src0->ne[2]; \
+ const uint32_t ne03 = src0->ne[3]; \
+ \
+ const uint32_t nb00 = src0->nb[0]; \
+ const uint32_t nb01 = src0->nb[1]; \
+ const uint32_t nb02 = src0->nb[2]; \
+ const uint32_t nb03 = src0->nb[3]; \
+ \
+ const uint32_t ne0 = dst->ne[0]; \
+ const uint32_t ne1 = dst->ne[1]; \
+ const uint32_t ne2 = dst->ne[2]; \
+ const uint32_t ne3 = dst->ne[3]; \
+ \
+ const uint32_t nb0 = dst->nb[0]; \
+ const uint32_t nb1 = dst->nb[1]; \
+ const uint32_t nb2 = dst->nb[2]; \
+ const uint32_t nb3 = dst->nb[3]; \
+ \
+ const uint32_t nr = ne01;
+
+static void cpy_thread_sametype_sameshape(struct htp_copy_context * ct, struct htp_ops_context * octx, const int nth, const int ith) {
+ cpy_preamble;
+
+ // parallelize by src0 rows
+ const uint32_t dr = ct->src0_nrows_per_thread;
+ const uint32_t ir0 = dr * ith;
+ const uint32_t ir1 = (ir0 + dr) < nr ? (ir0 + dr) : nr;
+
+ // copy by rows
+ for (uint32_t i03 = 0; i03 < ne03; i03++) {
+ for (uint32_t i02 = 0; i02 < ne02; i02++) {
+ #pragma unroll(2)
+ for (uint32_t i01 = ir0; i01 < ir1; i01++) {
+ uint8_t* dst_ptr = (uint8_t*) dst->data + i01*nb1 + i02*nb2 + i03*nb3;
+ uint8_t* src0_ptr = (uint8_t*) src0->data + i01*nb01 + i02*nb02 + i03*nb03;
+ hex_l2fetch(src0_ptr, ne00 * ct->src0_type_size, nb01, 2);
+ hvx_copy_uu(dst_ptr, src0_ptr, ne00, ct->src0_type_size);
+ }
+ }
+ }
+}
+
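+// Same-type copy with a shape change: k10/i11/i12/i13 track the flattened write
+// position in dst (in blocks of dst_type_size). The skip loops before and after
+// the [ir0, ir1) row range advance those counters past the rows handled by other
+// threads so each thread lands on the correct dst offset.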
+static void cpy_thread_sametype_reshape(struct htp_copy_context * ct, struct htp_ops_context * octx, int nth, int ith) {
+ cpy_preamble;
+
+ // parallelize by src0 rows
+ const uint32_t dr = ct->src0_nrows_per_thread;
+ const uint32_t ir0 = dr * ith;
+ const uint32_t ir1 = (ir0 + dr) < nr ? (ir0 + dr) : nr;
+
+ // dst counters
+ int64_t k10 = 0;
+ int64_t i11 = 0;
+ int64_t i12 = 0;
+ int64_t i13 = 0;
+
+ // number of blocks in a row
+ const int64_t nk00 = ct->src0_blocks_per_row;
+ const int64_t nk0 = ct->dst_blocks_per_row;
+
+ for (int64_t i03 = 0; i03 < ne03; i03++) {
+ for (int64_t i02 = 0; i02 < ne02; i02++) {
+ k10 += nk00 * ir0;
+ while (k10 >= nk0) {
+ k10 -= nk0;
+ if (++i11 == ne1) {
+ i11 = 0;
+ if (++i12 == ne2) {
+ i12 = 0;
+ if (++i13 == ne3) {
+ i13 = 0;
+ }
+ }
+ }
+ }
+ for (int64_t i01 = ir0; i01 < ir1; i01++) {
+ for (int64_t k00 = 0; k00 < nk00; k00++) {
+ const char * src0_ptr = ((char *) src0->data + k00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
+ char * dst_ptr = ((char *) dst->data + k10*nb0 + i11*nb1 + i12*nb2 + i13*nb3);
+ memcpy(dst_ptr, src0_ptr, ct->dst_type_size);
+
+ if (++k10 == nk0) {
+ k10 = 0;
+ if (++i11 == ne1) {
+ i11 = 0;
+ if (++i12 == ne2) {
+ i12 = 0;
+ if (++i13 == ne3) {
+ i13 = 0;
+ }
+ }
+ }
+ }
+ }
+ }
+ k10 += nk00 * (ne01 - ir1);
+ while (k10 >= nk0) {
+ k10 -= nk0;
+ if (++i11 == ne1) {
+ i11 = 0;
+ if (++i12 == ne2) {
+ i12 = 0;
+ if (++i13 == ne3) {
+ i13 = 0;
+ }
+ }
+ }
+ }
+ }
+ }
+}
+
+static void cpy_thread_f16_f32_sameshape(struct htp_copy_context * ct, struct htp_ops_context * octx, const int nth, const int ith) {
+ cpy_preamble;
+
+ // parallelize by src0 rows
+ const uint32_t dr = ct->src0_nrows_per_thread;
+ const uint32_t ir0 = dr * ith;
+ const uint32_t ir1 = (ir0 + dr) < nr ? (ir0 + dr) : nr;
+
+ // copy by rows
+ for (uint32_t i03 = 0; i03 < ne03; i03++) {
+ for (uint32_t i02 = 0; i02 < ne02; i02++) {
+ #pragma unroll(2)
+ for (uint32_t i01 = ir0; i01 < ir1; i01++) {
+ uint8_t* dst_ptr = (uint8_t*) dst->data + i01*nb1 + i02*nb2 + i03*nb3;
+ uint8_t* src0_ptr = (uint8_t*) src0->data + i01*nb01 + i02*nb02 + i03*nb03;
+ hex_l2fetch(src0_ptr, ne00 * sizeof(float), nb01, 2);
+ hvx_copy_f16_f32_uu(dst_ptr, src0_ptr, ne00);
+ }
+ }
+ }
+}
+
+static void cpy_thread_f32_f16_sameshape(struct htp_copy_context * ct, struct htp_ops_context * octx, const int nth, const int ith) {
+ cpy_preamble;
+
+ // parallelize by src0 rows
+ const uint32_t dr = ct->src0_nrows_per_thread;
+ const uint32_t ir0 = dr * ith;
+ const uint32_t ir1 = (ir0 + dr) < nr ? (ir0 + dr) : nr;
+
+ // copy by rows
+ for (uint32_t i03 = 0; i03 < ne03; i03++) {
+ for (uint32_t i02 = 0; i02 < ne02; i02++) {
+ #pragma unroll(2)
+ for (uint32_t i01 = ir0; i01 < ir1; i01++) {
+ uint8_t* dst_ptr = (uint8_t*) dst->data + i01*nb1 + i02*nb2 + i03*nb3;
+ uint8_t* src0_ptr = (uint8_t*) src0->data + i01*nb01 + i02*nb02 + i03*nb03;
+ hex_l2fetch(src0_ptr, ne00 * sizeof(__fp16), nb01, 2);
+ hvx_copy_f32_f16_uu(dst_ptr, src0_ptr, ne00);
+ }
+ }
+ }
+}
+
+static void cpy_work_func(unsigned int n, unsigned int i, void *data) {
+ struct htp_copy_context *ct = (struct htp_copy_context *) data;
+ ct->copy(ct, ct->octx, n, i);
+}
+
+int op_cpy(struct htp_ops_context * octx) {
+ cpy_preamble;
+
+ struct htp_copy_context ct;
+ ct.octx = octx;
+
+ switch (src0->type) {
+ case HTP_TYPE_F32: ct.src0_type_size = 4; ct.src0_block_size = 1; ct.src0_blocks_per_row = ne00 / 1; break;
+ case HTP_TYPE_F16: ct.src0_type_size = 2; ct.src0_block_size = 1; ct.src0_blocks_per_row = ne00 / 1; break;
+ default:
+ return HTP_STATUS_NO_SUPPORT;
+ }
+
+ switch (dst->type) {
+ case HTP_TYPE_F32: ct.dst_type_size = 4; ct.dst_block_size = 1; ct.dst_blocks_per_row = ne0 / 1; break;
+ case HTP_TYPE_F16: ct.dst_type_size = 2; ct.dst_block_size = 1; ct.dst_blocks_per_row = ne0 / 1; break;
+ default:
+ return HTP_STATUS_NO_SUPPORT;
+ }
+
+ if (octx->flags & HTP_OPFLAGS_SKIP_COMPUTE) {
+ return HTP_STATUS_OK;
+ }
+
+ const bool sametype = (src0->type == dst->type);
+ const bool transposed = (nb00 > nb01) || (nb0 > nb1);
+ const bool sameshape = !transposed && (ne00 == ne0 && ne01 == ne1 && ne02 == ne2 && ne03 == ne3);
+
+ const uint32_t n_jobs = MIN(nr, octx->n_threads);
+ ct.src0_nrows_per_thread = (nr + n_jobs - 1) / n_jobs;
+
+ if (sametype && sameshape) {
+ ct.copy = cpy_thread_sametype_sameshape;
+ } else if (sameshape) {
+ /**/ if (dst->type == HTP_TYPE_F16 && src0->type == HTP_TYPE_F32)
+ ct.copy = cpy_thread_f16_f32_sameshape;
+ else if (dst->type == HTP_TYPE_F32 && src0->type == HTP_TYPE_F16)
+ ct.copy = cpy_thread_f32_f16_sameshape;
+ else
+ return HTP_STATUS_NO_SUPPORT;
+ } else if (sametype) {
+ ct.copy = cpy_thread_sametype_reshape;
+ } else {
+ return HTP_STATUS_NO_SUPPORT;
+ }
+
+ worker_pool_run_func(octx->ctx->worker_pool, cpy_work_func, &ct, n_jobs);
+
+ return HTP_STATUS_OK;
+}
diff --git a/llama.cpp/ggml/src/ggml-hexagon/htp/flash-attn-ops.c b/llama.cpp/ggml/src/ggml-hexagon/htp/flash-attn-ops.c
new file mode 100644
index 0000000..c184637
--- /dev/null
+++ b/llama.cpp/ggml/src/ggml-hexagon/htp/flash-attn-ops.c
@@ -0,0 +1,684 @@
+#pragma clang diagnostic ignored "-Wunused-variable"
+#pragma clang diagnostic ignored "-Wunused-function"
+#pragma clang diagnostic ignored "-Wunused-but-set-variable"
+
+#include <assert.h>
+#include <HAP_farf.h>
+#include <HAP_perf.h>
+#include <math.h>
+#include <string.h>
+
+#include "hex-dma.h"
+#include "hvx-utils.h"
+
+#define GGML_COMMON_DECL_C
+#include "ggml-common.h"
+#include "htp-ctx.h"
+#include "htp-msg.h"
+#include "htp-ops.h"
+
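+// Convert 64 FP32 elements (two HVX vectors) into one FP16 HVX vector: subtracting
+// a zero vector in qf32 serves as the sf -> qf32 conversion step, the qf32 pair is
+// then narrowed to half floats and vdeal'd back into element order.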
+static inline HVX_Vector hvx_load_f32_to_f16(const HVX_Vector * restrict src, const HVX_Vector zero) {
+ HVX_Vector y0_qf = Q6_Vqf32_vsub_VsfVsf(src[0], zero); // 32 elements
+ HVX_Vector y1_qf = Q6_Vqf32_vsub_VsfVsf(src[1], zero); // 32 elements
+ return Q6_Vh_vdeal_Vh(Q6_Vhf_equals_Wqf32(Q6_W_vcombine_VV(y1_qf, y0_qf)));
+}
+
+// Dot product of FP32 and FP16 vectors, accumulating to float
+static inline void hvx_dot_f32_f16_aa(float * restrict r, const void * restrict y, const void * restrict x, unsigned int n, float s) {
+ const HVX_Vector * restrict vy = (const HVX_Vector * restrict) y; // fp32
+ const HVX_Vector * restrict vx = (const HVX_Vector * restrict) x; // fp16
+
+ uint32_t nvec = n / VLEN_FP16; // num full fp16 hvx vectors
+ uint32_t nloe = n % VLEN_FP16; // leftover elements
+
+ const HVX_Vector zero = Q6_V_vsplat_R(0);
+ HVX_Vector rsum = Q6_V_vsplat_R(0);
+
+ uint32_t i = 0;
+
+ #pragma unroll(4)
+ for (i = 0; i < nvec; i++) {
+ // Load y (fp32) and convert into fp16
+ HVX_Vector y_hf = hvx_load_f32_to_f16(&vy[i*2], zero);
+
+ // Load x (fp16)
+ HVX_Vector x_hf = vx[i];
+
+ HVX_VectorPair xy_qf = Q6_Wqf32_vmpy_VhfVhf(x_hf, y_hf);
+
+ rsum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(xy_qf), Q6_V_hi_W(xy_qf)), rsum));
+ }
+
+ if (nloe) {
+ // Load y (fp32) and convert into fp16
+ HVX_Vector y_hf = hvx_load_f32_to_f16(&vy[i*2], zero);
+
+ // Load x (fp16)
+ HVX_Vector x_hf = vx[i];
+
+ // Zero-out unused elements
+        // Note that we need to clear both x and y because they may contain NaNs
+ HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe * 2);
+ x_hf = Q6_V_vand_QV(bmask, x_hf);
+ y_hf = Q6_V_vand_QV(bmask, y_hf);
+
+ HVX_VectorPair xy_qf = Q6_Wqf32_vmpy_VhfVhf(x_hf, y_hf);
+
+ rsum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(xy_qf), Q6_V_hi_W(xy_qf)), rsum));
+ }
+
+ rsum = Q6_Vqf32_vmpy_VsfVsf(hvx_vec_splat_f32(s), hvx_vec_reduce_sum_f32(rsum));
+ hvx_vec_store_u(r, 4, Q6_Vsf_equals_Vqf32(rsum));
+}
+
+// Two-row variant: dot products of one FP32 vector against two FP16 rows,
+// reusing the converted FP32 data, accumulating to float
+static inline void hvx_dot_f32_f16_aa_rx2(float * restrict r,
+ const void * restrict y,
+ const void * restrict x0,
+ const void * restrict x1,
+ unsigned int n,
+ float s) {
+ const HVX_Vector * restrict vy = (const HVX_Vector * restrict) y; // fp32
+ const HVX_Vector * restrict vx0 = (const HVX_Vector * restrict) x0; // fp16
+ const HVX_Vector * restrict vx1 = (const HVX_Vector * restrict) x1; // fp16
+
+ uint32_t nvec = n / VLEN_FP16; // num full fp16 hvx vectors
+ uint32_t nloe = n % VLEN_FP16; // leftover elements
+
+ const HVX_Vector zero = Q6_V_vsplat_R(0);
+ HVX_Vector rsum0 = Q6_V_vsplat_R(0);
+ HVX_Vector rsum1 = Q6_V_vsplat_R(0);
+
+ uint32_t i = 0;
+
+ #pragma unroll(2)
+ for (i = 0; i < nvec; i++) {
+ // Load y (fp32) and convert into fp16
+ HVX_Vector y_hf = hvx_load_f32_to_f16(&vy[i*2], zero);
+ // Load x (fp16)
+ HVX_Vector x0_hf = vx0[i];
+ HVX_Vector x1_hf = vx1[i];
+
+ HVX_VectorPair xy0_qf = Q6_Wqf32_vmpy_VhfVhf(x0_hf, y_hf);
+ HVX_VectorPair xy1_qf = Q6_Wqf32_vmpy_VhfVhf(x1_hf, y_hf);
+
+ rsum0 = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(xy0_qf), Q6_V_hi_W(xy0_qf)), rsum0));
+ rsum1 = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(xy1_qf), Q6_V_hi_W(xy1_qf)), rsum1));
+ }
+
+ if (nloe) {
+ // Load y (fp32) and convert into fp16
+ HVX_Vector y_hf = hvx_load_f32_to_f16(&vy[i*2], zero);
+
+ // Load x (fp16)
+ HVX_Vector x0_hf = vx0[i];
+ HVX_Vector x1_hf = vx1[i];
+
+ // Zero-out unused elements
+        // Note that we need to clear both x and y because they may contain NaNs
+ HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe * 2);
+ x0_hf = Q6_V_vand_QV(bmask, x0_hf);
+ x1_hf = Q6_V_vand_QV(bmask, x1_hf);
+ y_hf = Q6_V_vand_QV(bmask, y_hf);
+
+ HVX_VectorPair xy0_qf = Q6_Wqf32_vmpy_VhfVhf(x0_hf, y_hf);
+ HVX_VectorPair xy1_qf = Q6_Wqf32_vmpy_VhfVhf(x1_hf, y_hf);
+
+ rsum0 = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(xy0_qf), Q6_V_hi_W(xy0_qf)), rsum0));
+ rsum1 = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(xy1_qf), Q6_V_hi_W(xy1_qf)), rsum1));
+ }
+
+ HVX_Vector rsum = Q6_Vqf32_vmpy_VsfVsf(hvx_vec_splat_f32(s), hvx_vec_reduce_sum_f32x2(rsum0, rsum1));
+ hvx_vec_store_u(r, 8, Q6_Vsf_equals_Vqf32(rsum));
+}
+
+// Dot product of two F16 vectors, accumulating to float
+static inline void hvx_dot_f16_f16_aa(float * restrict r, const void * restrict x, const void * restrict y, unsigned int n, float s) {
+ const HVX_Vector * restrict vx = (const HVX_Vector * restrict) x; // fp16
+ const HVX_Vector * restrict vy = (const HVX_Vector * restrict) y; // fp16
+
+ uint32_t nvec = n / VLEN_FP16; // num full fp16 hvx vectors
+ uint32_t nloe = n % VLEN_FP16; // leftover elements
+
+ const HVX_Vector zero = Q6_V_vsplat_R(0);
+ HVX_Vector rsum = Q6_V_vsplat_R(0);
+
+ uint32_t i = 0;
+
+ #pragma unroll(4)
+ for (i = 0; i < nvec; i++) {
+ HVX_Vector y_hf = vy[i];
+ HVX_Vector x_hf = vx[i];
+
+ HVX_VectorPair xy_qf = Q6_Wqf32_vmpy_VhfVhf(x_hf, y_hf);
+
+ rsum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(xy_qf), Q6_V_hi_W(xy_qf)), rsum));
+ }
+
+ if (nloe) {
+ HVX_Vector y_hf = vy[i];
+
+ // Load x (fp16) and zero-out unused elements
+ HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe * 2);
+ HVX_Vector x_hf = Q6_V_vand_QV(bmask, vx[i]);
+
+ HVX_VectorPair xy_qf = Q6_Wqf32_vmpy_VhfVhf(x_hf, y_hf);
+
+ rsum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(xy_qf), Q6_V_hi_W(xy_qf)), rsum));
+ }
+
+ rsum = Q6_Vqf32_vmpy_VsfVsf(hvx_vec_splat_f32(s), hvx_vec_reduce_sum_f32(rsum));
+ hvx_vec_store_u(r, 4, Q6_Vsf_equals_Vqf32(rsum));
+}
+
+static inline void hvx_dot_f16_f16_aa_rx2(float * restrict r,
+ const void * restrict y,
+ const void * restrict x0,
+ const void * restrict x1,
+ unsigned int n,
+ float s) {
+ const HVX_Vector * restrict vx0 = (const HVX_Vector * restrict) x0; // fp16
+ const HVX_Vector * restrict vx1 = (const HVX_Vector * restrict) x1; // fp16
+ const HVX_Vector * restrict vy = (const HVX_Vector * restrict) y; // fp16
+
+ uint32_t nvec = n / VLEN_FP16; // num full fp16 hvx vectors
+ uint32_t nloe = n % VLEN_FP16; // leftover elements
+
+ const HVX_Vector zero = Q6_V_vsplat_R(0);
+ HVX_Vector rsum0 = Q6_V_vsplat_R(0);
+ HVX_Vector rsum1 = Q6_V_vsplat_R(0);
+
+ uint32_t i = 0;
+
+ #pragma unroll(4)
+ for (i = 0; i < nvec; i++) {
+ HVX_Vector y_hf = vy[i];
+ HVX_Vector x0_hf = vx0[i];
+ HVX_Vector x1_hf = vx1[i];
+
+ HVX_VectorPair xy0_qf = Q6_Wqf32_vmpy_VhfVhf(x0_hf, y_hf);
+ HVX_VectorPair xy1_qf = Q6_Wqf32_vmpy_VhfVhf(x1_hf, y_hf);
+
+ rsum0 = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(xy0_qf), Q6_V_hi_W(xy0_qf)), rsum0));
+ rsum1 = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(xy1_qf), Q6_V_hi_W(xy1_qf)), rsum1));
+ }
+
+ if (nloe) {
+ HVX_Vector y_hf = vy[i];
+
+ // Load x (fp16) and zero-out unused elements
+ HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe * 2);
+ HVX_Vector x0_hf = Q6_V_vand_QV(bmask, vx0[i]);
+ HVX_Vector x1_hf = Q6_V_vand_QV(bmask, vx1[i]);
+
+ HVX_VectorPair xy0_qf = Q6_Wqf32_vmpy_VhfVhf(x0_hf, y_hf);
+ HVX_VectorPair xy1_qf = Q6_Wqf32_vmpy_VhfVhf(x1_hf, y_hf);
+
+ rsum0 = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(xy0_qf), Q6_V_hi_W(xy0_qf)), rsum0));
+ rsum1 = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(xy1_qf), Q6_V_hi_W(xy1_qf)), rsum1));
+ }
+
+ HVX_Vector rsum = Q6_Vqf32_vmpy_VsfVsf(hvx_vec_splat_f32(s), hvx_vec_reduce_sum_f32x2(rsum0, rsum1));
+ hvx_vec_store_u(r, 8, Q6_Vsf_equals_Vqf32(rsum));
+}
+
+// MAD: y (F32) += x (F16) * s (float)
+static inline void hvx_mad_f32_f16_aa(float * restrict y, const void * restrict x, int n, float s) {
+ const HVX_Vector * restrict ptr_x = (const HVX_Vector *) x;
+ HVX_Vector * restrict ptr_y = (HVX_Vector *) y;
+
+ uint32_t nvec = n / VLEN_FP16; // num full fp16 hvx vectors
+ uint32_t nloe = n % VLEN_FP16; // leftover elements
+
+ HVX_Vector S = hvx_vec_splat_f16(s);
+
+ uint32_t i = 0;
+ #pragma unroll(4)
+ for (i = 0; i < nvec; ++i) {
+ // Multiply x * s -> pair of F32 vectors
+ HVX_VectorPair xs_p = Q6_Wqf32_vmpy_VhfVhf(Q6_Vh_vshuff_Vh(ptr_x[i]), S);
+ ptr_y[i*2] = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(Q6_V_lo_W(xs_p), ptr_y[i*2]));
+ ptr_y[i*2+1] = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(Q6_V_hi_W(xs_p), ptr_y[i*2+1]));
+ }
+
+ if (nloe) {
+ HVX_VectorPair xs_p = Q6_Wqf32_vmpy_VhfVhf(Q6_Vh_vshuff_Vh(ptr_x[i]), S);
+
+ HVX_Vector xs = Q6_V_lo_W(xs_p);
+ i = 2 * i; // index for ptr_y
+
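+        // store the full low half first (VLEN_FP32 f32 elements), then the partial high half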
+        if (nloe >= VLEN_FP32) {
+            ptr_y[i] = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(xs, ptr_y[i]));
+            nloe -= VLEN_FP32; ++i; xs = Q6_V_hi_W(xs_p);
+        }
+
+ if (nloe) {
+ HVX_Vector xy = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(xs, ptr_y[i]));
+ hvx_vec_store_a(&ptr_y[i], nloe * 4, xy);
+ }
+ }
+}
+
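+// Number of K/V rows fetched and processed per DMA block (blocks are double-buffered in VTCM)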
+#define FLASH_ATTN_BLOCK_SIZE 128
+
+static void flash_attn_ext_f16_thread(struct htp_ops_context * octx, int ith, int nth) {
+ const struct htp_tensor * q = &octx->src0;
+ const struct htp_tensor * k = &octx->src1;
+ const struct htp_tensor * v = &octx->src2;
+ const struct htp_tensor * mask = (octx->src3.data) ? &octx->src3 : NULL;
+ const struct htp_tensor * sinks = (octx->src4.data) ? &octx->src4 : NULL;
+ struct htp_tensor * dst = &octx->dst;
+
+ const uint32_t neq0 = q->ne[0];
+ const uint32_t neq1 = q->ne[1];
+ const uint32_t neq2 = q->ne[2];
+ const uint32_t neq3 = q->ne[3];
+
+ const uint32_t nek0 = k->ne[0];
+ const uint32_t nek1 = k->ne[1];
+ const uint32_t nek2 = k->ne[2];
+ const uint32_t nek3 = k->ne[3];
+
+ const uint32_t nev0 = v->ne[0];
+ const uint32_t nev1 = v->ne[1];
+ const uint32_t nev2 = v->ne[2];
+ const uint32_t nev3 = v->ne[3];
+
+ const uint32_t nbq1 = q->nb[1];
+ const uint32_t nbq2 = q->nb[2];
+ const uint32_t nbq3 = q->nb[3];
+
+ const uint32_t nbk1 = k->nb[1];
+ const uint32_t nbk2 = k->nb[2];
+ const uint32_t nbk3 = k->nb[3];
+
+ const uint32_t nbv1 = v->nb[1];
+ const uint32_t nbv2 = v->nb[2];
+ const uint32_t nbv3 = v->nb[3];
+
+ const uint32_t ne1 = dst->ne[1];
+ const uint32_t ne2 = dst->ne[2];
+ const uint32_t ne3 = dst->ne[3];
+
+ const uint32_t nb1 = dst->nb[1];
+ const uint32_t nb2 = dst->nb[2];
+ const uint32_t nb3 = dst->nb[3];
+
+ float scale = 1.0f;
+ float max_bias = 0.0f;
+ float logit_softcap = 0.0f;
+
+ memcpy(&scale, (float *) octx->op_params + 0, sizeof(float));
+ memcpy(&max_bias, (float *) octx->op_params + 1, sizeof(float));
+ memcpy(&logit_softcap, (float *) octx->op_params + 2, sizeof(float));
+
+ if (logit_softcap != 0) {
+ scale /= logit_softcap;
+ }
+
+ // total rows in q
+ const uint32_t nr = neq1*neq2*neq3;
+
+ const uint32_t dr = (nr + nth - 1) / nth;
+ const uint32_t ir0 = dr * ith;
+ const uint32_t ir1 = MIN(ir0 + dr, nr);
+
+ if (ir0 >= ir1) return;
+
+ dma_queue * dma = octx->ctx->dma[ith];
+
+ const uint32_t DK = nek0;
+ const uint32_t DV = nev0;
+
+ const size_t size_q_row = DK * ((q->type == HTP_TYPE_F32) ? 4 : 2);
+ const size_t size_q_row_padded = hex_round_up(size_q_row, 128);
+
+ const size_t size_k_row = DK * sizeof(__fp16);
+ const size_t size_v_row = DV * sizeof(__fp16);
+ const size_t size_m_row = FLASH_ATTN_BLOCK_SIZE * sizeof(__fp16); // Treat block as one row for mask
+
+ const size_t size_k_row_padded = hex_round_up(size_k_row, 128);
+ const size_t size_v_row_padded = hex_round_up(size_v_row, 128);
+
+ const size_t size_k_block = size_k_row_padded * FLASH_ATTN_BLOCK_SIZE;
+ const size_t size_v_block = size_v_row_padded * FLASH_ATTN_BLOCK_SIZE;
+ const size_t size_m_block = hex_round_up(FLASH_ATTN_BLOCK_SIZE * sizeof(__fp16), 128);
+
+ // Scratchpad buffers for Q, K, V, Mask, and VKQ32 accumulator
+ uint8_t * spad_q = octx->src0_spad.data + octx->src0_spad.size_per_thread * ith;
+ uint8_t * spad_k = octx->src1_spad.data + octx->src1_spad.size_per_thread * ith;
+ uint8_t * spad_v = octx->src2_spad.data + octx->src2_spad.size_per_thread * ith;
+ uint8_t * spad_m = octx->src3_spad.data + octx->src3_spad.size_per_thread * ith;
+ uint8_t * spad_a = octx->dst_spad.data + octx->dst_spad.size_per_thread * ith;
+
+ const uint32_t n_head = neq2;
+ const uint32_t n_head_log2 = 1u << (uint32_t) floor(log2(n_head));
+ const float m0 = powf(2.0f, -(max_bias ) / n_head_log2);
+ const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
+
+ for (uint32_t ir = ir0; ir < ir1; ++ir) {
+ const uint32_t iq3 = fastdiv(ir, &octx->src0_div21);
+ const uint32_t iq2 = fastdiv(ir - iq3*neq2*neq1, &octx->src0_div1);
+ const uint32_t iq1 = (ir - iq3*neq2*neq1 - iq2 * neq1);
+
+ const uint32_t ik3 = fastdiv(iq3, &octx->broadcast_rk3);
+ const uint32_t ik2 = fastdiv(iq2, &octx->broadcast_rk2);
+
+ const uint32_t iv3 = fastdiv(iq3, &octx->broadcast_rv3);
+ const uint32_t iv2 = fastdiv(iq2, &octx->broadcast_rv2);
+
+ // Fetch Q row
+ const uint8_t * q_row_ptr = (const uint8_t *) q->data + (iq1*nbq1 + iq2*nbq2 + iq3*nbq3);
+ dma_queue_push(dma, dma_make_ptr(spad_q, q_row_ptr), size_q_row_padded, nbq1, size_q_row, 1);
+
+ const uint32_t h = iq2; // head index
+ const float slope = (max_bias > 0.0f) ? (h < n_head_log2 ? powf(m0, h + 1) : powf(m1, 2*(h - n_head_log2) + 1)) : 1.0f;
+
+ float S = 0.0f; // sum
+ float M = -INFINITY; // maximum KQ value
+
+ // Clear accumulator
+ hvx_splat_f32_a(spad_a, 0, DV);
+ float * VKQ32 = (float *) spad_a;
+
+ const __fp16 * mp_base = NULL;
+ if (mask) {
+ const uint32_t im2 = fastmodulo(iq2, mask->ne[2], &octx->src3_div2);
+ const uint32_t im3 = fastmodulo(iq3, mask->ne[3], &octx->src3_div3);
+ mp_base = (const __fp16 *) ((const uint8_t *) mask->data + iq1*mask->nb[1] + im2*mask->nb[2] + im3*mask->nb[3]);
+ }
+
+ const uint32_t n_blocks = (nek1 + FLASH_ATTN_BLOCK_SIZE - 1) / FLASH_ATTN_BLOCK_SIZE;
+
+ // Prefetch first two blocks
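+        // (K, V and mask scratchpads hold two blocks each; slot ib % 2 is the ping-pong
+        //  buffer, so DMA transfers for upcoming blocks overlap with compute on the current one)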
+ for (uint32_t ib = 0; ib < MIN(n_blocks, 2); ++ib) {
+ const uint32_t ic_start = ib * FLASH_ATTN_BLOCK_SIZE;
+ const uint32_t current_block_size = MIN(FLASH_ATTN_BLOCK_SIZE, nek1 - ic_start);
+
+ // K
+ const uint8_t * k_src = (const uint8_t *) k->data + (ic_start*nbk1 + ik2*nbk2 + ik3*nbk3);
+ uint8_t * k_dst = spad_k + (ib % 2) * size_k_block;
+ dma_queue_push(dma, dma_make_ptr(k_dst, k_src), size_k_row_padded, nbk1, size_k_row, current_block_size);
+
+ // V
+ const uint8_t * v_src = (const uint8_t *) v->data + (ic_start*nbv1 + iv2*nbv2 + iv3*nbv3);
+ uint8_t * v_dst = spad_v + (ib % 2) * size_v_block;
+ dma_queue_push(dma, dma_make_ptr(v_dst, v_src), size_v_row_padded, nbv1, size_v_row, current_block_size);
+
+ // Mask
+ if (mask) {
+ const uint8_t * m_src = (const uint8_t *) (mp_base + ic_start);
+ uint8_t * m_dst = spad_m + (ib % 2) * size_m_block;
+ // Mask is 1D contiguous for this row
+ dma_queue_push(dma, dma_make_ptr(m_dst, m_src), current_block_size * 2, current_block_size * 2, current_block_size * 2, 1);
+ }
+ }
+
+ const uint8_t * q_ptr_vtcm = dma_queue_pop(dma).dst;
+
+ for (uint32_t ib = 0; ib < n_blocks; ++ib) {
+ const uint32_t ic_start = ib * FLASH_ATTN_BLOCK_SIZE;
+ const uint32_t current_block_size = MIN(FLASH_ATTN_BLOCK_SIZE, nek1 - ic_start);
+
+ // Wait for DMA
+ uint8_t * k_base = dma_queue_pop(dma).dst; // K
+ uint8_t * v_base = dma_queue_pop(dma).dst; // V
+ __fp16 * m_base = mask ? dma_queue_pop(dma).dst : NULL; // M
+
+ // Inner loop processing the block from VTCM
+ uint32_t ic = 0;
+
+ const bool is_q_fp32 = (q->type == HTP_TYPE_F32);
+
+ // Process in blocks of 32 (VLEN_FP32)
+ static_assert(FLASH_ATTN_BLOCK_SIZE / VLEN_FP32 <= 4, "FLASH_ATTN_BLOCK_SIZE changed, fix HVX_Vector_x4 usage");
+ HVX_Vector_x4 scores_x4;
+ HVX_Vector v_max = hvx_vec_splat_f32(-INFINITY);
+ for (uint32_t iv = 0; ic + VLEN_FP32 <= current_block_size; ic += VLEN_FP32, ++iv) {
+ // 1. Compute scores
+ float __attribute__((aligned(VLEN))) scores_arr[VLEN_FP32];
+ for (int j = 0; j < VLEN_FP32; j += 2) {
+ const uint32_t cur_ic = ic + j;
+ const uint8_t * k_ptr = k_base + cur_ic * size_k_row_padded;
+ if (is_q_fp32) {
+ hvx_dot_f32_f16_aa_rx2(&scores_arr[j], q_ptr_vtcm, k_ptr, k_ptr + size_k_row_padded, DK, scale);
+ } else {
+ hvx_dot_f16_f16_aa_rx2(&scores_arr[j], q_ptr_vtcm, k_ptr, k_ptr + size_k_row_padded, DK, scale);
+ }
+ }
+
+ HVX_Vector scores = *(HVX_Vector *) scores_arr;
+
+ // 2. Softcap
+ if (logit_softcap != 0.0f) {
+ scores = hvx_vec_tanh_f32(scores);
+ scores = Q6_Vqf32_vmpy_VsfVsf(scores, hvx_vec_splat_f32(logit_softcap));
+ scores = Q6_Vsf_equals_Vqf32(scores);
+ }
+
+ // 3. Mask
+ if (mask) {
+ const __fp16 * mp = m_base + ic;
+ HVX_Vector m_vals_f16 = *(const HVX_UVector *) mp;
+
+ HVX_Vector one_f16 = Q6_Vh_vsplat_R(0x3c00);
+ HVX_VectorPair m_vals_f32_pair = Q6_Wqf32_vmpy_VhfVhf(Q6_Vh_vshuff_Vh(m_vals_f16), one_f16);
+
+ HVX_Vector m_vals_f32 = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(m_vals_f32_pair));
+
+ HVX_Vector slope_vec = hvx_vec_splat_f32(slope);
+ HVX_Vector add_val = Q6_Vqf32_vmpy_VsfVsf(m_vals_f32, slope_vec);
+ scores = Q6_Vqf32_vadd_VsfVsf(scores, Q6_Vsf_equals_Vqf32(add_val));
+ scores = Q6_Vsf_equals_Vqf32(scores);
+ }
+
+ scores_x4.v[iv] = scores;
+ v_max = Q6_Vsf_vmax_VsfVsf(scores, v_max);
+ }
+
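+            // Online-softmax update: with m_block the max of the scores computed above,
+            //   M_new  = max(M_old, m_block)
+            //   VKQ32 *= exp(M_old - M_new)            (rescale the running accumulator)
+            //   S      = S * exp(M_old - M_new) + sum_j exp(score_j - M_new)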
+ {
+ // 4. Online Softmax Update
+ v_max = hvx_vec_reduce_max_f32(v_max);
+ float m_block = hvx_vec_get_f32(v_max);
+ float M_old = M;
+ float M_new = (m_block > M) ? m_block : M;
+ M = M_new;
+
+ const float ms = expf(M_old - M_new);
+ hvx_scale_f32_aa((uint8_t *) VKQ32, (const uint8_t *) VKQ32, DV, ms);
+
+ HVX_Vector M_new_vec = hvx_vec_splat_f32(M_new);
+ HVX_Vector p_sum_vec = hvx_vec_splat_f32(0.0f);
+ for (uint32_t ic2 = 0, iv = 0; ic2 + VLEN_FP32 <= current_block_size; ic2 += VLEN_FP32, ++iv) {
+ HVX_Vector scores = scores_x4.v[iv];
+ HVX_Vector scores_shifted = Q6_Vqf32_vsub_VsfVsf(scores, M_new_vec);
+ HVX_Vector P = hvx_vec_exp_f32(Q6_Vsf_equals_Vqf32(scores_shifted));
+
+ p_sum_vec = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_VsfVsf(p_sum_vec, P));
+
+ // 5. Accumulate V
+ float __attribute__((aligned(VLEN))) p_arr[VLEN_FP32];
+ *(HVX_Vector*)p_arr = P;
+
+ for (int j = 0; j < VLEN_FP32; ++j) {
+ const uint32_t cur_ic = ic2 + j;
+ const uint8_t * v_ptr = v_base + cur_ic * size_v_row_padded;
+ hvx_mad_f32_f16_aa(VKQ32, v_ptr, DV, p_arr[j]);
+ }
+ }
+
+ p_sum_vec = hvx_vec_reduce_sum_f32(p_sum_vec);
+ S = S * ms + hvx_vec_get_f32(p_sum_vec);
+ }
+
+ // Leftover
+ for (; ic < current_block_size; ++ic) {
+ float s_val;
+ const uint8_t * k_ptr = k_base + ic * size_k_row_padded;
+
+ if (is_q_fp32) {
+ hvx_dot_f32_f16_aa(&s_val, q_ptr_vtcm, k_ptr, DK, scale);
+ } else {
+ hvx_dot_f16_f16_aa(&s_val, q_ptr_vtcm, k_ptr, DK, scale);
+ }
+
+ if (logit_softcap != 0.0f) {
+ s_val = logit_softcap * tanhf(s_val);
+ }
+
+ if (mask) {
+ const float m_val = m_base[ic];
+ s_val += slope * m_val;
+ }
+
+ const float Mold = M;
+ float ms = 1.0f;
+ float vs = 1.0f;
+
+ if (s_val > M) {
+ M = s_val;
+ ms = expf(Mold - M);
+ hvx_scale_f32_aa((uint8_t *) VKQ32, (const uint8_t *) VKQ32, DV, ms);
+ } else {
+ vs = expf(s_val - M);
+ }
+
+ const uint8_t * v_ptr = v_base + ic * size_v_row_padded;
+
+ hvx_mad_f32_f16_aa(VKQ32, v_ptr, DV, vs);
+
+ S = S * ms + vs;
+ }
+
+ // Issue DMA for next+1 block (if exists)
+            // Issue DMA for the block after next (ib + 2), if it exists
+ const uint32_t next_ib = ib + 2;
+ const uint32_t next_ic_start = next_ib * FLASH_ATTN_BLOCK_SIZE;
+ const uint32_t next_block_size = MIN(FLASH_ATTN_BLOCK_SIZE, nek1 - next_ic_start);
+
+ // K
+ const uint8_t * k_src = (const uint8_t *) k->data + (next_ic_start*nbk1 + ik2*nbk2 + ik3*nbk3);
+ dma_queue_push(dma, dma_make_ptr(k_base, k_src), size_k_row_padded, nbk1, size_k_row, next_block_size);
+
+ // V
+ const uint8_t * v_src = (const uint8_t *) v->data + (next_ic_start*nbv1 + iv2*nbv2 + iv3*nbv3);
+ dma_queue_push(dma, dma_make_ptr(v_base, v_src), size_v_row_padded, nbv1, size_v_row, next_block_size);
+
+ // Mask
+ if (mask) {
+ const uint8_t * m_src = (const uint8_t *) (mp_base + next_ic_start);
+ dma_queue_push(dma, dma_make_ptr(m_base, m_src), next_block_size * 2, next_block_size * 2, next_block_size * 2, 1);
+ }
+ }
+ }
+
+ // sinks
+ if (sinks) {
+ const float s = ((float *)((char *) sinks->data))[h];
+
+ float ms = 1.0f;
+ float vs = 1.0f;
+
+ if (s > M) {
+ ms = expf(M - s);
+ hvx_scale_f32_aa((uint8_t *) VKQ32, (const uint8_t *) VKQ32, DV, ms);
+ } else {
+ vs = expf(s - M);
+ }
+
+ S = S * ms + vs;
+ }
+
+ const float S_inv = S == 0.0f ? 0.0f : 1.0f/S;
+ hvx_scale_f32_aa((uint8_t *) VKQ32, (const uint8_t *) VKQ32, DV, S_inv);
+
+ // Store result
+ // dst indices
+ const int i1 = iq1;
+ const int i2 = iq2;
+ const int i3 = iq3;
+
+ // dst is permuted
+ uint8_t * dst_ptr = (uint8_t *) dst->data + (i3*ne2*ne1 + i2 + i1*ne1) * nb1;
+
+ if (dst->type == HTP_TYPE_F32) {
+ hvx_copy_f32_ua(dst_ptr, (uint8_t *) VKQ32, DV);
+ } else if (dst->type == HTP_TYPE_F16) {
+ hvx_copy_f16_f32_ua(dst_ptr, (uint8_t *) VKQ32, DV);
+ }
+ }
+}
+
+static void htp_flash_attn_ext_job(unsigned int n, unsigned int i, void * data) {
+ struct htp_ops_context * octx = data;
+ flash_attn_ext_f16_thread(octx, i, n);
+}
+
+int op_flash_attn_ext(struct htp_ops_context * octx) {
+ const struct htp_tensor * q = &octx->src0;
+ const struct htp_tensor * k = &octx->src1;
+ const struct htp_tensor * v = &octx->src2;
+ const struct htp_tensor * mask = (octx->src3.type != HTP_TYPE_COUNT) ? &octx->src3 : NULL;
+ struct htp_tensor * dst = &octx->dst;
+
+ // Check support
+ if ((q->type != HTP_TYPE_F16 && q->type != HTP_TYPE_F32) ||
+ k->type != HTP_TYPE_F16 ||
+ v->type != HTP_TYPE_F16) {
+ return HTP_STATUS_NO_SUPPORT;
+ }
+
+ octx->src0_div21 = init_fastdiv_values(q->ne[2] * q->ne[1]);
+ octx->src0_div1 = init_fastdiv_values(q->ne[1]);
+
+ octx->broadcast_rk2 = init_fastdiv_values(q->ne[2]/k->ne[2]);
+ octx->broadcast_rk3 = init_fastdiv_values(q->ne[3]/k->ne[3]);
+ octx->broadcast_rv2 = init_fastdiv_values(q->ne[2]/v->ne[2]);
+ octx->broadcast_rv3 = init_fastdiv_values(q->ne[3]/v->ne[3]);
+
+ if (mask) {
+ octx->src3_div2 = init_fastdiv_values(mask->ne[2]);
+ octx->src3_div3 = init_fastdiv_values(mask->ne[3]);
+ }
+
+ size_t size_q_row_padded = hex_round_up(q->ne[0] * (q->type == HTP_TYPE_F32 ? 4 : 2), 128);
+ size_t size_k_row_padded = hex_round_up(k->ne[0] * sizeof(__fp16), 128);
+ size_t size_v_row_padded = hex_round_up(v->ne[0] * sizeof(__fp16), 128);
+
+ size_t size_q_block = size_q_row_padded * 1; // single row for now
+ size_t size_k_block = size_k_row_padded * FLASH_ATTN_BLOCK_SIZE;
+ size_t size_v_block = size_v_row_padded * FLASH_ATTN_BLOCK_SIZE;
+ size_t size_m_block = hex_round_up(FLASH_ATTN_BLOCK_SIZE * sizeof(__fp16), 128);
+
+ size_t size_vkq_acc = hex_round_up(v->ne[0] * sizeof(float), 128); // VKQ32
+
+ octx->src0_spad.size_per_thread = size_q_block * 1;
+ octx->src1_spad.size_per_thread = size_k_block * 2;
+ octx->src2_spad.size_per_thread = size_v_block * 2;
+ octx->src3_spad.size_per_thread = mask ? size_m_block * 2 : 0;
+ octx->dst_spad.size_per_thread = size_vkq_acc;
+
+ octx->src0_spad.size = octx->src0_spad.size_per_thread * octx->n_threads;
+ octx->src1_spad.size = octx->src1_spad.size_per_thread * octx->n_threads;
+ octx->src2_spad.size = octx->src2_spad.size_per_thread * octx->n_threads;
+ octx->src3_spad.size = octx->src3_spad.size_per_thread * octx->n_threads;
+ octx->dst_spad.size = octx->dst_spad.size_per_thread * octx->n_threads;
+
+ size_t total_spad = octx->src0_spad.size + octx->src1_spad.size + octx->src2_spad.size + octx->src3_spad.size + octx->dst_spad.size;
+
+ if (octx->ctx->vtcm_size < total_spad) {
+ return HTP_STATUS_VTCM_TOO_SMALL;
+ }
+
+ octx->src0_spad.data = octx->ctx->vtcm_base;
+ octx->src1_spad.data = octx->src0_spad.data + octx->src0_spad.size;
+ octx->src2_spad.data = octx->src1_spad.data + octx->src1_spad.size;
+ octx->src3_spad.data = octx->src2_spad.data + octx->src2_spad.size;
+ octx->dst_spad.data = octx->src3_spad.data + octx->src3_spad.size;
+
+ if (!(octx->flags & HTP_OPFLAGS_SKIP_COMPUTE)) {
+ worker_pool_run_func(octx->ctx->worker_pool, htp_flash_attn_ext_job, octx, octx->n_threads);
+ }
+
+ return HTP_STATUS_OK;
+}
diff --git a/llama.cpp/ggml/src/ggml-hexagon/htp/get-rows-ops.c b/llama.cpp/ggml/src/ggml-hexagon/htp/get-rows-ops.c
new file mode 100644
index 0000000..a657cd2
--- /dev/null
+++ b/llama.cpp/ggml/src/ggml-hexagon/htp/get-rows-ops.c
@@ -0,0 +1,106 @@
+#pragma clang diagnostic ignored "-Wunused-variable"
+#pragma clang diagnostic ignored "-Wunused-function"
+#pragma clang diagnostic ignored "-Wunused-but-set-variable"
+
+#include <HAP_farf.h>
+#include <HAP_perf.h>
+
+#include <math.h>
+#include <string.h>
+
+#define GGML_COMMON_DECL_C
+#include "ggml-common.h"
+#include "htp-ctx.h"
+#include "htp-msg.h"
+#include "htp-ops.h"
+#include "hvx-utils.h"
+
+#define get_rows_preamble \
+ const uint32_t ne00 = octx->src0.ne[0]; \
+ const uint32_t ne01 = octx->src0.ne[1]; \
+ const uint32_t ne02 = octx->src0.ne[2]; \
+ const uint32_t ne03 = octx->src0.ne[3]; \
+ \
+ const uint32_t ne10 = octx->src1.ne[0]; \
+ const uint32_t ne11 = octx->src1.ne[1]; \
+ const uint32_t ne12 = octx->src1.ne[2]; \
+ \
+ const uint32_t nb01 = octx->src0.nb[1]; \
+ const uint32_t nb02 = octx->src0.nb[2]; \
+ const uint32_t nb03 = octx->src0.nb[3]; \
+ \
+ const uint32_t nb10 = octx->src1.nb[0]; \
+ const uint32_t nb11 = octx->src1.nb[1]; \
+ const uint32_t nb12 = octx->src1.nb[2]; \
+ \
+ const uint32_t nb1 = octx->dst.nb[1]; \
+ const uint32_t nb2 = octx->dst.nb[2]; \
+ const uint32_t nb3 = octx->dst.nb[3]; \
+ \
+ const uint32_t nr = ne10 * ne11 * ne12;
+
+static int get_rows_thread_f32_f32(struct htp_ops_context * octx, const int nth, const int ith) {
+ get_rows_preamble;
+
+ // parallelize by src1 elements (which correspond to dst rows)
+ const uint32_t dr = octx->src1_nrows_per_thread;
+ const uint32_t ir0 = dr * ith;
+ const uint32_t ir1 = (ir0 + dr < nr) ? (ir0 + dr) : nr;
+
+ const bool is_i32 = (octx->src1.type == HTP_TYPE_I32);
+
+ for (uint32_t i = ir0; i < ir1; ++i) {
+ const uint32_t i12 = fastdiv(i, &octx->get_rows_div_ne10_ne11);
+ const uint32_t rem = i - i12 * ne11 * ne10;
+ const uint32_t i11 = fastdiv(rem, &octx->get_rows_div_ne10);
+ const uint32_t i10 = rem - i11 * ne10;
+
+ const uintptr_t src1_addr = octx->src1.data + i10*nb10 + i11*nb11 + i12*nb12;
+
+ uint32_t i01 = is_i32 ? *(int32_t *)src1_addr : *(int64_t *)src1_addr;
+
+ if (i01 >= ne01) {
+ // invalid index, skip for now to avoid crash
+ continue;
+ }
+
+ const uintptr_t src0_ptr = octx->src0.data + i01*nb01 + i11*nb02 + i12*nb03;
+ const uintptr_t dst_ptr = octx->dst.data + i10*nb1 + i11*nb2 + i12*nb3;
+ hvx_copy_f32_uu((uint8_t *)dst_ptr, (const uint8_t *)src0_ptr, ne00);
+ }
+
+ return HTP_STATUS_OK;
+}
+
+static void get_rows_work_f32_f32(unsigned int n, unsigned int i, void *data) {
+ get_rows_thread_f32_f32((struct htp_ops_context *) data, n, i);
+}
+
+int op_get_rows(struct htp_ops_context * octx) {
+ get_rows_preamble;
+
+ if (octx->src0.type != HTP_TYPE_F32) {
+ return HTP_STATUS_NO_SUPPORT;
+ }
+
+ if (octx->dst.type != HTP_TYPE_F32) {
+ return HTP_STATUS_NO_SUPPORT;
+ }
+
+ if (octx->src1.type != HTP_TYPE_I32 && octx->src1.type != HTP_TYPE_I64) {
+ return HTP_STATUS_NO_SUPPORT;
+ }
+
+ if (octx->flags & HTP_OPFLAGS_SKIP_COMPUTE) {
+ return HTP_STATUS_OK;
+ }
+
+ octx->get_rows_div_ne10 = init_fastdiv_values(octx->src1.ne[0]);
+ octx->get_rows_div_ne10_ne11 = init_fastdiv_values(octx->src1.ne[0] * octx->src1.ne[1]);
+
+ const uint32_t n_jobs = MIN(nr, octx->n_threads);
+ octx->src1_nrows_per_thread = (nr + n_jobs - 1) / n_jobs;
+
+ worker_pool_run_func(octx->ctx->worker_pool, get_rows_work_f32_f32, octx, n_jobs);
+ return HTP_STATUS_OK;
+}
diff --git a/llama.cpp/ggml/src/ggml-hexagon/htp/hex-dma.c b/llama.cpp/ggml/src/ggml-hexagon/htp/hex-dma.c
new file mode 100644
index 0000000..44e1be4
--- /dev/null
+++ b/llama.cpp/ggml/src/ggml-hexagon/htp/hex-dma.c
@@ -0,0 +1,63 @@
+#include "hex-dma.h"
+
+#include <stdbool.h>
+#include <stdlib.h>
+#include <string.h>
+
+#pragma clang diagnostic ignored "-Wunused-function"
+
+static inline uint32_t pow2_ceil(uint32_t x) {
+ if (x <= 1) {
+ return 1;
+ }
+ int p = 2;
+ x--;
+ while (x >>= 1) {
+ p <<= 1;
+ }
+ return p;
+}
+
+dma_queue * dma_queue_create(size_t capacity) {
+ dma_queue * q = (dma_queue *) memalign(32, sizeof(dma_queue));
+ if (q == NULL) {
+ FARF(ERROR, "%s: failed to allocate DMA queue\n", __FUNCTION__);
+ return NULL;
+ }
+
+ capacity = pow2_ceil(capacity);
+
+ memset(q, 0, sizeof(dma_queue));
+ q->capacity = capacity;
+ q->idx_mask = capacity - 1;
+
+    q->desc = (hexagon_udma_descriptor_type1_t *) memalign(64, capacity * sizeof(hexagon_udma_descriptor_type1_t));
+    q->dptr = (dma_ptr *) memalign(4, capacity * sizeof(dma_ptr));
+    if (!q->desc || !q->dptr) {
+        FARF(ERROR, "%s: failed to allocate DMA queue items\n", __FUNCTION__);
+        dma_queue_delete(q);
+        return NULL;
+    }
+
+    memset(q->desc, 0, capacity * sizeof(hexagon_udma_descriptor_type1_t));
+    memset(q->dptr, 0, capacity * sizeof(dma_ptr));
+
+    q->tail = &q->desc[capacity - 1];
+
+ FARF(HIGH, "dma-queue: capacity %u\n", capacity);
+
+ return q;
+}
+
+void dma_queue_delete(dma_queue * q) {
+ if (!q) {
+ return;
+ }
+ free(q->desc);
+ free(q->dptr);
+ free(q);
+}
+
+void dma_queue_flush(dma_queue * q) {
+ while (dma_queue_pop(q).dst != NULL) ;
+}
diff --git a/llama.cpp/ggml/src/ggml-hexagon/htp/hex-dma.h b/llama.cpp/ggml/src/ggml-hexagon/htp/hex-dma.h
new file mode 100644
index 0000000..d1ddb0e
--- /dev/null
+++ b/llama.cpp/ggml/src/ggml-hexagon/htp/hex-dma.h
@@ -0,0 +1,156 @@
+#ifndef HTP_DMA_H
+#define HTP_DMA_H
+
+#include <HAP_farf.h>
+#include <hexagon_types.h>
+#include <stdbool.h>
+#include <stdint.h>
+
+#ifdef __cplusplus
+extern "C" {
+#endif
+
+typedef struct {
+ void *dst;
+ const void *src;
+} dma_ptr;
+
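+// Ring of linked type-1 (2D) UDMA descriptors. dma_queue_push() links and issues a
+// transfer; dma_queue_pop() polls until the oldest descriptor completes and returns
+// its dst/src pointers.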
+typedef struct {
+ hexagon_udma_descriptor_type1_t * desc; // descriptor pointers
+ hexagon_udma_descriptor_type1_t * tail; // tail pointer
+ dma_ptr * dptr; // dst/src pointers
+ uint32_t push_idx;
+ uint32_t pop_idx;
+ uint32_t capacity;
+ uint32_t idx_mask;
+} dma_queue;
+
+dma_queue * dma_queue_create(size_t capacity);
+void dma_queue_delete(dma_queue * q);
+void dma_queue_flush(dma_queue * q);
+
+// TODO: technically we don't need these and could use Q6_dmstart/wait/etc instead
+// but those do not seem to always compile properly.
+static inline void dmstart(void * next) {
+ asm volatile(" release(%0):at" : : "r"(next));
+ asm volatile(" dmstart(%0)" : : "r"(next));
+}
+
+static inline void dmlink(void * cur, void * next) {
+ asm volatile(" release(%0):at" : : "r"(next));
+ asm volatile(" dmlink(%0, %1)" : : "r"(cur), "r"(next));
+}
+
+static inline unsigned int dmpoll(void) {
+ unsigned int ret = 0;
+ asm volatile(" %0 = dmpoll" : "=r"(ret) : : "memory");
+ return ret;
+}
+
+static inline unsigned int dmwait(void) {
+ unsigned int ret = 0;
+ asm volatile(" %0 = dmwait" : "=r"(ret) : : "memory");
+ return ret;
+}
+
+static inline dma_ptr dma_make_ptr(void *dst, const void *src)
+{
+ dma_ptr p = { dst, src };
+ return p;
+}
+
+static inline bool dma_queue_push(dma_queue * q,
+ dma_ptr dptr,
+ size_t dst_row_size,
+ size_t src_row_size,
+ size_t width, // width in bytes. number of bytes to transfer per row
+ size_t nrows) {
+ if (((q->push_idx + 1) & q->idx_mask) == q->pop_idx) {
+ FARF(ERROR, "dma-push: queue full\n");
+ return false;
+ }
+
+ hexagon_udma_descriptor_type1_t * desc = &q->desc[q->push_idx];
+
+ desc->next = NULL;
+ desc->length = 0;
+ desc->desctype = HEXAGON_UDMA_DESC_DESCTYPE_TYPE1;
+#if __HVX_ARCH__ >= 73
+    desc->dstbypass = 1;
+    desc->srcbypass = 1;
+#else
+    desc->dstbypass = 0;
+    desc->srcbypass = 1;
+#endif
+ desc->order = 0;
+ desc->dstate = HEXAGON_UDMA_DESC_DSTATE_INCOMPLETE;
+ desc->src = (void *) dptr.src;
+ desc->dst = (void *) dptr.dst;
+ desc->allocation = 0;
+ desc->padding = 0;
+ desc->roiwidth = width;
+ desc->roiheight = nrows;
+ desc->srcstride = src_row_size;
+ desc->dststride = dst_row_size;
+ desc->srcwidthoffset = 0;
+ desc->dstwidthoffset = 0;
+
+ q->dptr[q->push_idx] = dptr;
+
+ dmlink(q->tail, desc);
+ q->tail = desc;
+
+ // FARF(ERROR, "dma-push: i %u len %u dst %p src %p\n", q->push_idx, len, dst, src);
+ q->push_idx = (q->push_idx + 1) & q->idx_mask;
+ return true;
+}
+
+static inline bool dma_queue_push_ddr_to_vtcm(dma_queue * q,
+ dma_ptr dptr,
+ size_t dst_row_size,
+ size_t src_row_size,
+ size_t nrows) {
+ return dma_queue_push(q, dptr, dst_row_size, src_row_size, src_row_size, nrows);
+}
+
+
+static inline bool dma_queue_push_vtcm_to_ddr(dma_queue * q,
+ dma_ptr dptr,
+ size_t dst_row_size,
+ size_t src_row_size,
+ size_t nrows) {
+ return dma_queue_push(q, dptr, dst_row_size, src_row_size, dst_row_size, nrows);
+}
+
+static inline dma_ptr dma_queue_pop(dma_queue * q) {
+ dma_ptr dptr = { NULL };
+
+ if (q->push_idx == q->pop_idx) {
+ return dptr;
+ }
+
+ hexagon_udma_descriptor_type1_t * desc = &q->desc[q->pop_idx];
+
+ // Wait for desc to complete
+ while (1) {
+ dmpoll();
+ if (desc->dstate == HEXAGON_UDMA_DESC_DSTATE_COMPLETE) {
+ break;
+ }
+ // FARF(ERROR, "dma-pop: waiting for DMA : %u\n", q->pop_idx);
+ }
+
+ dptr = q->dptr[q->pop_idx];
+
+ // FARF(ERROR, "dma-pop: i %u dst %p\n", q->pop_idx, dst);
+ q->pop_idx = (q->pop_idx + 1) & q->idx_mask;
+ return dptr;
+}
+
+#ifdef __cplusplus
+} // extern "C"
+#endif
+
+#endif /* HTP_DMA_H */
diff --git a/llama.cpp/ggml/src/ggml-hexagon/htp/hex-dump.h b/llama.cpp/ggml/src/ggml-hexagon/htp/hex-dump.h
new file mode 100644
index 0000000..e3badb5
--- /dev/null
+++ b/llama.cpp/ggml/src/ggml-hexagon/htp/hex-dump.h
@@ -0,0 +1,77 @@
+#ifndef HEX_DUMP_H
+#define HEX_DUMP_H
+
+#include <HAP_farf.h>
+
+static inline void hex_dump_int8_line(char * pref, const int8_t * x, int n) {
+ char str[1024], *p = str, *p_end = str + sizeof(str);
+ p += snprintf(p, p_end - p, "%s: ", pref);
+ for (int i = 0; i < n && p < p_end; i++) {
+ p += snprintf(p, p_end - p, "%d, ", x[i]);
+ }
+ FARF(HIGH, "%s\n", str);
+}
+
+static inline void hex_dump_uint8_line(char * pref, const uint8_t * x, uint32_t n) {
+ char str[1024], *p = str, *p_end = str + sizeof(str);
+ p += snprintf(p, p_end - p, "%s: ", pref);
+ for (int i = 0; i < n && p < p_end; i++) {
+ p += snprintf(p, p_end - p, "%d, ", x[i]);
+ }
+ FARF(HIGH, "%s\n", str);
+}
+
+static inline void hex_dump_int32_line(char * pref, const int32_t * x, uint32_t n) {
+ char str[1024], *p = str, *p_end = str + sizeof(str);
+ p += snprintf(p, p_end - p, "%s: ", pref);
+    for (int i = 0; i < n && p < p_end; i++) {
+ p += snprintf(p, p_end - p, "%d, ", (int) x[i]);
+ }
+ FARF(HIGH, "%s\n", str);
+}
+
+static inline void hex_dump_f16_line(char * pref, const __fp16 * x, uint32_t n) {
+ char str[1024], *p = str, *p_end = str + sizeof(str);
+ p += snprintf(p, p_end - p, "%s: ", pref);
+    for (int i = 0; i < n && p < p_end; i++) {
+ p += snprintf(p, p_end - p, "%.6f, ", (float) x[i]);
+ }
+ FARF(HIGH, "%s\n", str);
+}
+
+static inline void hex_dump_f32_line(char * pref, const float * x, uint32_t n) {
+ char str[1024], *p = str, *p_end = str + sizeof(str);
+ p += snprintf(p, p_end - p, "%s: ", pref);
+    for (int i = 0; i < n && p < p_end; i++) {
+ p += snprintf(p, p_end - p, "%.6f, ", x[i]);
+ }
+ FARF(HIGH, "%s\n", str);
+}
+
+static inline void hex_dump_f32(char * pref, const float * x, uint32_t n) {
+ uint32_t n0 = n / 16;
+ uint32_t n1 = n % 16;
+
+ uint32_t i = 0;
+ for (; i < n0; i++) {
+ hex_dump_f32_line(pref, x + (16 * i), 16);
+ }
+ if (n1) {
+ hex_dump_f32_line(pref, x + (16 * i), n1);
+ }
+}
+
+static inline void hex_dump_f16(char * pref, const __fp16 * x, uint32_t n) {
+ uint32_t n0 = n / 16;
+ uint32_t n1 = n % 16;
+
+ uint32_t i = 0;
+ for (; i < n0; i++) {
+ hex_dump_f16_line(pref, x + (16 * i), 16);
+ }
+ if (n1) {
+ hex_dump_f16_line(pref, x + (16 * i), n1);
+ }
+}
+
+#endif /* HEX_DUMP_H */
diff --git a/llama.cpp/ggml/src/ggml-hexagon/htp/hex-fastdiv.h b/llama.cpp/ggml/src/ggml-hexagon/htp/hex-fastdiv.h
new file mode 100644
index 0000000..b7b5867
--- /dev/null
+++ b/llama.cpp/ggml/src/ggml-hexagon/htp/hex-fastdiv.h
@@ -0,0 +1,37 @@
+#ifndef HEX_FASTDIV_H
+#define HEX_FASTDIV_H
+
+#include <stdint.h>
+
+// See https://gmplib.org/~tege/divcnst-pldi94.pdf figure 4.1.
+// Precompute mp (m' in the paper) and L such that division
+// can be computed using a multiply (high 32b of 64b result)
+// and a shift:
+//
+// n/d = (mulhi(n, mp) + n) >> L;
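+//
+// Example: d = 6 gives L = 3 and mp = floor(2^32 * (2^3 - 6) / 6) + 1 = 0x55555556;
+// for n = 25: mulhi(25, mp) = 8 and (8 + 25) >> 3 = 4 = 25/6.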
+struct fastdiv_values {
+ uint32_t mp;
+ uint32_t l;
+};
+
+static inline struct fastdiv_values init_fastdiv_values(uint32_t d) {
+ struct fastdiv_values result = { 0, 0 };
+ // compute L = ceil(log2(d));
+ while (result.l < 32 && ((uint32_t) 1 << result.l) < d) {
+ ++(result.l);
+ }
+
+ result.mp = (uint32_t) (((uint64_t) 1 << 32) * (((uint64_t) 1 << result.l) - d) / d + 1);
+ return result;
+}
+
+static inline uint32_t fastdiv(uint32_t n, const struct fastdiv_values * vals) {
+ // Compute high 32 bits of n * mp
+ const uint32_t hi = (uint32_t) (((uint64_t) n * vals->mp) >> 32); // mulhi(n, mp)
+ // add n, apply bit shift
+ return (hi + n) >> vals->l;
+}
+
+static inline uint32_t fastmodulo(uint32_t n, uint32_t d, const struct fastdiv_values * vals) {
+ return n - fastdiv(n, vals) * d;
+}
+
+#endif /* HEX_FASTDIV_H */
diff --git a/llama.cpp/ggml/src/ggml-hexagon/htp/hex-utils.h b/llama.cpp/ggml/src/ggml-hexagon/htp/hex-utils.h
new file mode 100644
index 0000000..fb8a25a
--- /dev/null
+++ b/llama.cpp/ggml/src/ggml-hexagon/htp/hex-utils.h
@@ -0,0 +1,51 @@
+#ifndef HEX_UTILS_H
+#define HEX_UTILS_H
+
+#include <stdbool.h>
+#include <stdint.h>
+
+#include "hexagon_types.h"
+
+#include "hex-fastdiv.h"
+#include "hex-dump.h"
+
+#ifndef MAX
+#define MAX(a, b) ((a) > (b) ? (a) : (b))
+#endif
+
+#ifndef MIN
+#define MIN(a, b) ((a) < (b) ? (a) : (b))
+#endif
+
+static inline uint64_t hex_get_cycles() {
+ uint64_t cycles = 0;
+ asm volatile(" %0 = c15:14\n" : "=r"(cycles));
+ return cycles;
+}
+
+static inline uint64_t hex_get_pktcnt() {
+ uint64_t pktcnt;
+ asm volatile(" %0 = c19:18\n" : "=r"(pktcnt));
+ return pktcnt;
+}
+
+static inline int32_t hex_is_aligned(void * addr, uint32_t align) {
+ return ((size_t) addr & (align - 1)) == 0;
+}
+
+static inline int32_t hex_is_one_chunk(void * addr, uint32_t n, uint32_t chunk_size) {
+ uint32_t left_off = (size_t) addr & (chunk_size - 1);
+ uint32_t right_off = left_off + n;
+ return right_off <= chunk_size;
+}
+
+static inline uint32_t hex_round_up(uint32_t n, uint32_t m) {
+ return m * ((n + m - 1) / m);
+}
+
+static inline void hex_l2fetch(const void * p, uint32_t width, uint32_t stride, uint32_t height) {
+ const uint64_t control = Q6_P_combine_RR(stride, Q6_R_combine_RlRl(width, height));
+ Q6_l2fetch_AP((void *) p, control);
+}
+
+#endif /* HEX_UTILS_H */
diff --git a/llama.cpp/ggml/src/ggml-hexagon/htp/htp-ctx.h b/llama.cpp/ggml/src/ggml-hexagon/htp/htp-ctx.h
new file mode 100644
index 0000000..a707d98
--- /dev/null
+++ b/llama.cpp/ggml/src/ggml-hexagon/htp/htp-ctx.h
@@ -0,0 +1,35 @@
+#ifndef HTP_CTX_H
+#define HTP_CTX_H
+
+#include "hex-dma.h"
+#include "worker-pool.h"
+
+#include <assert.h>
+#include <dspqueue.h>
+#include <stdatomic.h>
+#include <stdint.h>
+
+#define HTP_MAX_NTHREADS 10
+
+// Main context for htp DSP backend
+struct htp_context {
+ dspqueue_t queue;
+ dma_queue * dma[HTP_MAX_NTHREADS];
+ worker_pool_context_t worker_pool;
+ uint32_t n_threads;
+
+ int thread_id;
+ int thread_prio;
+
+ uint8_t * vtcm_base;
+ size_t vtcm_size;
+ uint32_t vtcm_rctx;
+
+ atomic_bool vtcm_valid;
+ atomic_bool vtcm_inuse;
+ atomic_bool vtcm_needs_release;
+
+ uint32_t opmask;
+};
+
+#endif /* HTP_CTX_H */
diff --git a/llama.cpp/ggml/src/ggml-hexagon/htp/htp-msg.h b/llama.cpp/ggml/src/ggml-hexagon/htp/htp-msg.h
new file mode 100644
index 0000000..25403bb
--- /dev/null
+++ b/llama.cpp/ggml/src/ggml-hexagon/htp/htp-msg.h
@@ -0,0 +1,154 @@
+#ifndef HTP_MSG_H
+#define HTP_MSG_H
+
+#include <assert.h>
+
+// ggml-common.h must be included prior to this header
+
+// Mask to enable various stages of the Ops.
+// Used for debugging and profiling.
+enum {
+ HTP_OPMASK_QUEUE = (1 << 0), // Enable Queueing (ie calls into the DSP)
+ HTP_OPMASK_QUANTIZE = (1 << 1), // Enable Quantize
+ HTP_OPMASK_COMPUTE = (1 << 2), // Enable Compute
+};
+
+// Op flags
+enum {
+ HTP_OPFLAGS_SKIP_QUANTIZE = (1 << 0), // Skip dynamic quantization (reuse quantized tensors)
+ HTP_OPFLAGS_SKIP_COMPUTE = (1 << 1), // Skip actual computation (used for profiling)
+ HTP_OPFLAGS_EARLY_WAKEUP = (1 << 2) // Send early wakeup notification
+};
+
+enum htp_status {
+ HTP_STATUS_OK = 1,
+ HTP_STATUS_INTERNAL_ERR = 2,
+ HTP_STATUS_NO_SUPPORT = 3,
+ HTP_STATUS_INVAL_PARAMS = 4,
+ HTP_STATUS_VTCM_TOO_SMALL = 5,
+};
+
+// The values must match the ggml_type.
+// Duplicated here because we can't include full ggml.h in the htp build.
+// We have some static_asserts in the cpp code to ensure things are in sync.
+enum htp_data_type {
+ HTP_TYPE_F32 = 0,
+ HTP_TYPE_F16 = 1,
+ HTP_TYPE_Q4_0 = 2,
+ HTP_TYPE_Q8_0 = 8,
+ HTP_TYPE_I32 = 26,
+ HTP_TYPE_I64 = 27,
+ HTP_TYPE_MXFP4 = 39,
+ HTP_TYPE_COUNT
+};
+
+// Do not reorder first 4 (used as an index)
+enum htp_op {
+ HTP_OP_MUL = 0,
+ HTP_OP_ADD = 1,
+ HTP_OP_SUB = 2,
+ HTP_OP_DIV = 3,
+ HTP_OP_MUL_MAT,
+ HTP_OP_MUL_MAT_ID,
+ HTP_OP_RMS_NORM,
+ HTP_OP_UNARY_SILU,
+ HTP_OP_UNARY_GELU,
+ HTP_OP_GLU_SWIGLU,
+ HTP_OP_GLU_SWIGLU_OAI,
+ HTP_OP_GLU_GEGLU,
+ HTP_OP_SOFTMAX,
+ HTP_OP_ADD_ID,
+ HTP_OP_ROPE,
+ HTP_OP_FLASH_ATTN_EXT,
+ HTP_OP_SET_ROWS,
+ HTP_OP_GET_ROWS,
+ HTP_OP_SCALE,
+ HTP_OP_CPY,
+ HTP_OP_ARGSORT,
+ HTP_OP_SQR,
+ HTP_OP_SQRT,
+ HTP_OP_SUM_ROWS,
+ INVALID
+};
+
+static inline size_t htp_t_block_size(uint32_t t) {
+ switch (t) {
+ case HTP_TYPE_F32:
+ return 1;
+ case HTP_TYPE_F16:
+ return 1;
+ case HTP_TYPE_Q4_0:
+ return QK4_0;
+ case HTP_TYPE_Q8_0:
+ return QK8_0;
+ case HTP_TYPE_MXFP4:
+ return QK_MXFP4;
+ default:
+ assert(0 && "unsupported HTP data type");
+ }
+ return 0;
+}
+
+static inline size_t htp_type_nbytes(uint32_t t) {
+ switch (t) {
+ case HTP_TYPE_F32:
+ return 4;
+ case HTP_TYPE_F16:
+ return 2;
+ case HTP_TYPE_Q4_0:
+ return sizeof(block_q4_0);
+ case HTP_TYPE_Q8_0:
+ return sizeof(block_q8_0);
+ case HTP_TYPE_MXFP4:
+ return sizeof(block_mxfp4);
+ default:
+ assert(0 && "unsupported HTP data type");
+ }
+ return 0;
+}
+
+// Internal types
+#define QK_Q4_0x4x2 256 // 4x Q4_0 blocks packed with next 4x Q4_0 blocks (size in bytes 128)
+#define QK_Q8_0x4x2 256 // 4x Q8_0 blocks concat with next 4x Q8_0 blocks
+#define QK_MXFP4x4x2 256 // 4x MXFP4 blocks concat with next 4x MXFP4 blocks
+
+#define HTP_MAX_DIMS 4
+
+struct htp_tensor {
+ uint32_t data; // Buffer offset in the messages, and data pointer on the NSP
+ uint32_t type; // Data type
+ uint32_t ne[HTP_MAX_DIMS]; // Number of elements
+ uint32_t nb[HTP_MAX_DIMS]; // Stride in bytes (see ggml.h ggml_tensor)
+};
+
+#define HTP_MAX_OP_PARAMS 64
+
+struct htp_general_req {
+ uint32_t op; // GGML/HTP Op
+ int32_t op_params[HTP_MAX_OP_PARAMS / sizeof(int32_t)];
+ // Params for the op, e.g. epsilon of RMS norm
+ uint32_t flags; // Request flags
+
+ struct htp_tensor src0; // Input0 tensor
+ struct htp_tensor src1; // Input1 tensor
+ struct htp_tensor src2; // Input2 tensor
+ struct htp_tensor src3; // Input3 tensor
+ struct htp_tensor src4; // Input4 tensor
+ struct htp_tensor dst; // Output tensor
+
+ // should be multiple of 64 bytes (cacheline)
+};
+
+struct htp_general_rsp {
+ uint32_t op; // GGML/HTP Op
+ uint32_t status; // HTP_STATUS_...
+ uint32_t prof_usecs; // Number of usec per request
+ uint32_t prof_cycles; // Number of cycles per request
+ uint32_t prof_pkts; // Number of instruction packets per request
+ uint8_t unused[44]; // Pad to 64 bytes
+};
+
+#define HTP_MAX_MESSAGE_SIZE sizeof(struct htp_general_req)
+#define HTP_MAX_PACKET_BUFFERS 8
+
+#endif /* HTP_MSG_H */
diff --git a/llama.cpp/ggml/src/ggml-hexagon/htp/htp-ops.h b/llama.cpp/ggml/src/ggml-hexagon/htp/htp-ops.h
new file mode 100644
index 0000000..f1ad24d
--- /dev/null
+++ b/llama.cpp/ggml/src/ggml-hexagon/htp/htp-ops.h
@@ -0,0 +1,91 @@
+#ifndef HTP_OPS_H
+#define HTP_OPS_H
+
+#include "htp-ctx.h"
+#include "htp-msg.h"
+#include "worker-pool.h"
+
+#include <assert.h>
+#include <stdint.h>
+
+#include <hex-fastdiv.h>
+
+// ggml-common.h must be included prior to this header
+
+struct htp_spad {
+ uint8_t * data;
+ size_t stride;
+ size_t size;
+ size_t size_per_thread;
+};
+
+struct htp_ops_context {
+ struct htp_context * ctx;
+
+ enum htp_op op;
+ int32_t op_params[HTP_MAX_OP_PARAMS / sizeof(int32_t)];
+
+ struct htp_tensor src0;
+ struct htp_tensor src1;
+ struct htp_tensor src2;
+ struct htp_tensor src3;
+ struct htp_tensor src4;
+ struct htp_tensor dst;
+
+ struct htp_spad src0_spad;
+ struct htp_spad src1_spad;
+ struct htp_spad src2_spad;
+ struct htp_spad src3_spad;
+ struct htp_spad dst_spad;
+
+ worker_pool_context_t * wpool; // worker pool
+ uint32_t n_threads; // num threads
+
+ uint32_t src0_nrows_per_thread;
+ uint32_t src1_nrows_per_thread;
+
+ struct fastdiv_values src0_div1; // fastdiv values for ne1
+ struct fastdiv_values src0_div2; // fastdiv values for ne2
+ struct fastdiv_values src0_div3; // fastdiv values for ne3
+ struct fastdiv_values src0_div21; // fastdiv values for ne2 * ne1
+
+ struct fastdiv_values src1_div1; // fastdiv values for ne1
+ struct fastdiv_values src1_div2; // fastdiv values for ne2
+ struct fastdiv_values src1_div3; // fastdiv values for ne3
+ struct fastdiv_values src1_div21; // fastdiv values for ne2 * ne1
+
+ struct fastdiv_values src3_div1; // fastdiv values for ne1
+ struct fastdiv_values src3_div2; // fastdiv values for ne2
+ struct fastdiv_values src3_div3; // fastdiv values for ne3
+ struct fastdiv_values src3_div21; // fastdiv values for ne2 * ne1
+
+ struct fastdiv_values broadcast_rk2;
+ struct fastdiv_values broadcast_rk3;
+ struct fastdiv_values broadcast_rv2;
+ struct fastdiv_values broadcast_rv3;
+
+ struct fastdiv_values set_rows_div_ne12; // fastdiv values for ne12
+ struct fastdiv_values set_rows_div_ne11; // fastdiv values for ne11
+
+ struct fastdiv_values get_rows_div_ne10; // fastdiv values for ne10
+ struct fastdiv_values get_rows_div_ne10_ne11; // fastdiv values for ne10 * ne11
+
+ uint32_t flags;
+};
+
+int op_matmul(struct htp_ops_context * octx);
+int op_matmul_id(struct htp_ops_context * octx);
+int op_binary(struct htp_ops_context * octx);
+int op_unary(struct htp_ops_context * octx);
+int op_sum_rows(struct htp_ops_context * octx);
+int op_activations(struct htp_ops_context * octx);
+int op_softmax(struct htp_ops_context * octx);
+int op_add_id(struct htp_ops_context * octx);
+int op_rope(struct htp_ops_context * octx);
+int op_flash_attn_ext(struct htp_ops_context * octx);
+int op_set_rows(struct htp_ops_context * octx);
+int op_get_rows(struct htp_ops_context * octx);
+int op_cpy(struct htp_ops_context * octx);
+int op_argsort(struct htp_ops_context * octx);
+
+#endif /* HTP_OPS_H */
diff --git a/llama.cpp/ggml/src/ggml-hexagon/htp/htp_iface.idl b/llama.cpp/ggml/src/ggml-hexagon/htp/htp_iface.idl
new file mode 100644
index 0000000..9ebd937
--- /dev/null
+++ b/llama.cpp/ggml/src/ggml-hexagon/htp/htp_iface.idl
@@ -0,0 +1,16 @@
+// FastRPC IDL interface for GGML HTP
+
+#ifndef HTP_IDL
+#define HTP_IDL
+
+#include "AEEStdDef.idl"
+#include "remote.idl"
+
+interface htp_iface : remote_handle64 {
+ AEEResult start(in uint32 sess_id, in uint64 dsp_queue_id, in uint32 n_hvx);
+ AEEResult stop();
+ AEEResult enable_etm();
+ AEEResult disable_etm();
+};
+
+#endif /* HTP_IDL */
diff --git a/llama.cpp/ggml/src/ggml-hexagon/htp/hvx-arith.h b/llama.cpp/ggml/src/ggml-hexagon/htp/hvx-arith.h
new file mode 100644
index 0000000..2577cdd
--- /dev/null
+++ b/llama.cpp/ggml/src/ggml-hexagon/htp/hvx-arith.h
@@ -0,0 +1,470 @@
+#ifndef HVX_ARITH_H
+#define HVX_ARITH_H
+
+#include <assert.h>
+#include <stddef.h>
+#include <stdint.h>
+#include <math.h>
+
+#include "hvx-base.h"
+#include "hex-utils.h"
+
+//
+// Binary operations (add, mul, sub)
+//
+
+#define hvx_arith_loop_body(dst_type, src0_type, src1_type, vec_store, vec_op) \
+ do { \
+ dst_type * restrict vdst = (dst_type *) dst; \
+ src0_type * restrict vsrc0 = (src0_type *) src0; \
+ src1_type * restrict vsrc1 = (src1_type *) src1; \
+ \
+ const uint32_t elem_size = sizeof(float); \
+ const uint32_t epv = 128 / elem_size; \
+ const uint32_t nvec = n / epv; \
+ const uint32_t nloe = n % epv; \
+ \
+ uint32_t i = 0; \
+ \
+ _Pragma("unroll(4)") \
+ for (; i < nvec; i++) { \
+ vdst[i] = vec_op(vsrc0[i], vsrc1[i]); \
+ } \
+ if (nloe) { \
+ HVX_Vector v = vec_op(vsrc0[i], vsrc1[i]); \
+ vec_store((void *) &vdst[i], nloe * elem_size, v); \
+ } \
+ } while(0)
+
+#if __HVX_ARCH__ < 79
+#define HVX_OP_ADD(a, b) Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_VsfVsf(a, b))
+#define HVX_OP_SUB(a, b) Q6_Vsf_equals_Vqf32(Q6_Vqf32_vsub_VsfVsf(a, b))
+#define HVX_OP_MUL(a, b) Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(a, b))
+#else
+#define HVX_OP_ADD(a, b) Q6_Vsf_vadd_VsfVsf(a, b)
+#define HVX_OP_SUB(a, b) Q6_Vsf_vsub_VsfVsf(a, b)
+#define HVX_OP_MUL(a, b) Q6_Vsf_vmpy_VsfVsf(a, b)
+#endif
+
+// Generic macro to define alignment permutations for an op
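+// Suffix encodes the alignment of (dst, src0, src1): 'a' = 128-byte aligned, 'u' = unaligned.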
+#define DEFINE_HVX_BINARY_OP_VARIANTS(OP_NAME, OP_MACRO) \
+static inline void OP_NAME##_aaa(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, uint32_t n) { \
+ assert((uintptr_t) dst % 128 == 0); \
+ assert((uintptr_t) src0 % 128 == 0); \
+ assert((uintptr_t) src1 % 128 == 0); \
+ hvx_arith_loop_body(HVX_Vector, HVX_Vector, HVX_Vector, hvx_vec_store_a, OP_MACRO); \
+} \
+static inline void OP_NAME##_aau(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, uint32_t n) { \
+ assert((uintptr_t) dst % 128 == 0); \
+ assert((uintptr_t) src0 % 128 == 0); \
+ hvx_arith_loop_body(HVX_Vector, HVX_Vector, HVX_UVector, hvx_vec_store_a, OP_MACRO); \
+} \
+static inline void OP_NAME##_aua(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, uint32_t n) { \
+ assert((uintptr_t) dst % 128 == 0); \
+ assert((uintptr_t) src1 % 128 == 0); \
+ hvx_arith_loop_body(HVX_Vector, HVX_UVector, HVX_Vector, hvx_vec_store_a, OP_MACRO); \
+} \
+static inline void OP_NAME##_auu(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, uint32_t n) { \
+ assert((uintptr_t) dst % 128 == 0); \
+ hvx_arith_loop_body(HVX_Vector, HVX_UVector, HVX_UVector, hvx_vec_store_a, OP_MACRO); \
+} \
+static inline void OP_NAME##_uaa(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, uint32_t n) { \
+ assert((uintptr_t) src0 % 128 == 0); \
+ assert((uintptr_t) src1 % 128 == 0); \
+ hvx_arith_loop_body(HVX_UVector, HVX_Vector, HVX_Vector, hvx_vec_store_u, OP_MACRO); \
+} \
+static inline void OP_NAME##_uau(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, uint32_t n) { \
+ assert((uintptr_t) src0 % 128 == 0); \
+ hvx_arith_loop_body(HVX_UVector, HVX_Vector, HVX_UVector, hvx_vec_store_u, OP_MACRO); \
+} \
+static inline void OP_NAME##_uua(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, uint32_t n) { \
+ assert((uintptr_t) src1 % 128 == 0); \
+ hvx_arith_loop_body(HVX_UVector, HVX_UVector, HVX_Vector, hvx_vec_store_u, OP_MACRO); \
+} \
+static inline void OP_NAME##_uuu(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, uint32_t n) { \
+ hvx_arith_loop_body(HVX_UVector, HVX_UVector, HVX_UVector, hvx_vec_store_u, OP_MACRO); \
+} \
+
+DEFINE_HVX_BINARY_OP_VARIANTS(hvx_add_f32, HVX_OP_ADD)
+DEFINE_HVX_BINARY_OP_VARIANTS(hvx_sub_f32, HVX_OP_SUB)
+DEFINE_HVX_BINARY_OP_VARIANTS(hvx_mul_f32, HVX_OP_MUL)
+
+// Dispatcher logic
+#define HVX_BINARY_DISPATCHER(OP_NAME) \
+static inline void OP_NAME(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, const uint32_t num_elems) { \
+ if (hex_is_aligned((void *) dst, 128)) { \
+ if (hex_is_aligned((void *) src0, 128)) { \
+ if (hex_is_aligned((void *) src1, 128)) OP_NAME##_aaa(dst, src0, src1, num_elems); \
+ else OP_NAME##_aau(dst, src0, src1, num_elems); \
+ } else { \
+ if (hex_is_aligned((void *) src1, 128)) OP_NAME##_aua(dst, src0, src1, num_elems); \
+ else OP_NAME##_auu(dst, src0, src1, num_elems); \
+ } \
+ } else { \
+ if (hex_is_aligned((void *) src0, 128)) { \
+ if (hex_is_aligned((void *) src1, 128)) OP_NAME##_uaa(dst, src0, src1, num_elems); \
+ else OP_NAME##_uau(dst, src0, src1, num_elems); \
+ } else { \
+ if (hex_is_aligned((void *) src1, 128)) OP_NAME##_uua(dst, src0, src1, num_elems); \
+ else OP_NAME##_uuu(dst, src0, src1, num_elems); \
+ } \
+ } \
+}
+
+HVX_BINARY_DISPATCHER(hvx_add_f32)
+HVX_BINARY_DISPATCHER(hvx_sub_f32)
+HVX_BINARY_DISPATCHER(hvx_mul_f32)
+
+// Mul-Mul Optimized
+static inline void hvx_mul_mul_f32_aa(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, const uint8_t * restrict src2, const uint32_t num_elems) {
+ assert((unsigned long) dst % 128 == 0);
+ assert((unsigned long) src0 % 128 == 0);
+ assert((unsigned long) src1 % 128 == 0);
+ assert((unsigned long) src2 % 128 == 0);
+
+ HVX_Vector * restrict vdst = (HVX_Vector *) dst;
+ HVX_Vector * restrict vsrc0 = (HVX_Vector *) src0;
+ HVX_Vector * restrict vsrc1 = (HVX_Vector *) src1;
+ HVX_Vector * restrict vsrc2 = (HVX_Vector *) src2;
+
+ const uint32_t elem_size = sizeof(float);
+ const uint32_t epv = 128 / elem_size;
+ const uint32_t nvec = num_elems / epv;
+ const uint32_t nloe = num_elems % epv;
+
+ uint32_t i = 0;
+
+ _Pragma("unroll(4)")
+ for (; i < nvec; i++) {
+ HVX_Vector v1 = HVX_OP_MUL(vsrc0[i], vsrc1[i]);
+ vdst[i] = HVX_OP_MUL(v1, vsrc2[i]);
+ }
+
+ if (nloe) {
+ HVX_Vector v1 = HVX_OP_MUL(vsrc0[i], vsrc1[i]);
+ HVX_Vector v2 = HVX_OP_MUL(v1, vsrc2[i]);
+ hvx_vec_store_a((void *) &vdst[i], nloe * elem_size, v2);
+ }
+}
+
+// Scalar Operations
+
+#define hvx_scalar_loop_body(dst_type, src_type, vec_store, scalar_op_macro) \
+ do { \
+ dst_type * restrict vdst = (dst_type *) dst; \
+ src_type * restrict vsrc = (src_type *) src; \
+ \
+ const uint32_t elem_size = sizeof(float); \
+ const uint32_t epv = 128 / elem_size; \
+ const uint32_t nvec = n / epv; \
+ const uint32_t nloe = n % epv; \
+ \
+ uint32_t i = 0; \
+ \
+ _Pragma("unroll(4)") \
+ for (; i < nvec; i++) { \
+ HVX_Vector v = vsrc[i]; \
+ vdst[i] = scalar_op_macro(v); \
+ } \
+ if (nloe) { \
+ HVX_Vector v = vsrc[i]; \
+ v = scalar_op_macro(v); \
+ vec_store((void *) &vdst[i], nloe * elem_size, v); \
+ } \
+ } while(0)
+
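+// Scalar add; lanes that are exactly +INF bypass the add and remain +INF.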
+#define HVX_OP_ADD_SCALAR(v) \
+ ({ \
+ const HVX_VectorPred pred_inf = Q6_Q_vcmp_eq_VwVw(inf, v); \
+ HVX_Vector out = HVX_OP_ADD(v, val_vec); \
+ Q6_V_vmux_QVV(pred_inf, inf, out); \
+ })
+
+#define HVX_OP_MUL_SCALAR(v) HVX_OP_MUL(v, val_vec)
+#define HVX_OP_SUB_SCALAR(v) HVX_OP_SUB(v, val_vec)
+
+// Add Scalar Variants
+
+static inline void hvx_add_scalar_f32_aa(uint8_t * restrict dst, const uint8_t * restrict src, const float val, uint32_t n) {
+ const HVX_Vector val_vec = hvx_vec_splat_f32(val);
+ const HVX_Vector inf = hvx_vec_splat_f32(INFINITY);
+ assert((unsigned long) dst % 128 == 0);
+ assert((unsigned long) src % 128 == 0);
+ hvx_scalar_loop_body(HVX_Vector, HVX_Vector, hvx_vec_store_a, HVX_OP_ADD_SCALAR);
+}
+
+static inline void hvx_add_scalar_f32_au(uint8_t * restrict dst, const uint8_t * restrict src, const float val, uint32_t n) {
+ const HVX_Vector val_vec = hvx_vec_splat_f32(val);
+ const HVX_Vector inf = hvx_vec_splat_f32(INFINITY);
+ assert((unsigned long) dst % 128 == 0);
+ hvx_scalar_loop_body(HVX_Vector, HVX_UVector, hvx_vec_store_a, HVX_OP_ADD_SCALAR);
+}
+
+static inline void hvx_add_scalar_f32_ua(uint8_t * restrict dst, const uint8_t * restrict src, const float val, uint32_t n) {
+ const HVX_Vector val_vec = hvx_vec_splat_f32(val);
+ const HVX_Vector inf = hvx_vec_splat_f32(INFINITY);
+ assert((unsigned long) src % 128 == 0);
+ hvx_scalar_loop_body(HVX_UVector, HVX_Vector, hvx_vec_store_u, HVX_OP_ADD_SCALAR);
+}
+
+static inline void hvx_add_scalar_f32_uu(uint8_t * restrict dst, const uint8_t * restrict src, const float val, uint32_t n) {
+ const HVX_Vector val_vec = hvx_vec_splat_f32(val);
+ static const float kInf = INFINITY;
+ const HVX_Vector inf = hvx_vec_splat_f32(kInf);
+ hvx_scalar_loop_body(HVX_UVector, HVX_UVector, hvx_vec_store_u, HVX_OP_ADD_SCALAR);
+}
+
+// Sub Scalar Variants
+
+static inline void hvx_sub_scalar_f32_aa(uint8_t * restrict dst, const uint8_t * restrict src, const float val, uint32_t n) {
+ const HVX_Vector val_vec = hvx_vec_splat_f32(val);
+ assert((unsigned long) dst % 128 == 0);
+ assert((unsigned long) src % 128 == 0);
+ hvx_scalar_loop_body(HVX_Vector, HVX_Vector, hvx_vec_store_a, HVX_OP_SUB_SCALAR);
+}
+
+static inline void hvx_sub_scalar_f32_au(uint8_t * restrict dst, const uint8_t * restrict src, const float val, uint32_t n) {
+ const HVX_Vector val_vec = hvx_vec_splat_f32(val);
+ assert((unsigned long) dst % 128 == 0);
+ hvx_scalar_loop_body(HVX_Vector, HVX_UVector, hvx_vec_store_a, HVX_OP_SUB_SCALAR);
+}
+
+static inline void hvx_sub_scalar_f32_ua(uint8_t * restrict dst, const uint8_t * restrict src, const float val, uint32_t n) {
+ const HVX_Vector val_vec = hvx_vec_splat_f32(val);
+ assert((unsigned long) src % 128 == 0);
+ hvx_scalar_loop_body(HVX_UVector, HVX_Vector, hvx_vec_store_u, HVX_OP_SUB_SCALAR);
+}
+
+static inline void hvx_sub_scalar_f32_uu(uint8_t * restrict dst, const uint8_t * restrict src, const float val, uint32_t n) {
+ const HVX_Vector val_vec = hvx_vec_splat_f32(val);
+ hvx_scalar_loop_body(HVX_UVector, HVX_UVector, hvx_vec_store_u, HVX_OP_SUB_SCALAR);
+}
+
+// Mul Scalar Variants
+
+static inline void hvx_mul_scalar_f32_aa(uint8_t * restrict dst, const uint8_t * restrict src, const float val, uint32_t n) {
+ const HVX_Vector val_vec = hvx_vec_splat_f32(val);
+ assert((unsigned long) dst % 128 == 0);
+ assert((unsigned long) src % 128 == 0);
+ hvx_scalar_loop_body(HVX_Vector, HVX_Vector, hvx_vec_store_a, HVX_OP_MUL_SCALAR);
+}
+
+static inline void hvx_mul_scalar_f32_au(uint8_t * restrict dst, const uint8_t * restrict src, const float val, uint32_t n) {
+ const HVX_Vector val_vec = hvx_vec_splat_f32(val);
+ assert((unsigned long) dst % 128 == 0);
+ hvx_scalar_loop_body(HVX_Vector, HVX_UVector, hvx_vec_store_a, HVX_OP_MUL_SCALAR);
+}
+
+static inline void hvx_mul_scalar_f32_ua(uint8_t * restrict dst, const uint8_t * restrict src, const float val, uint32_t n) {
+ const HVX_Vector val_vec = hvx_vec_splat_f32(val);
+ assert((unsigned long) src % 128 == 0);
+ hvx_scalar_loop_body(HVX_UVector, HVX_Vector, hvx_vec_store_u, HVX_OP_MUL_SCALAR);
+}
+
+static inline void hvx_mul_scalar_f32_uu(uint8_t * restrict dst, const uint8_t * restrict src, const float val, uint32_t n) {
+ const HVX_Vector val_vec = hvx_vec_splat_f32(val);
+ hvx_scalar_loop_body(HVX_UVector, HVX_UVector, hvx_vec_store_u, HVX_OP_MUL_SCALAR);
+}
+
+static inline void hvx_add_scalar_f32(uint8_t * restrict dst, const uint8_t * restrict src, const float val, const int num_elems) {
+ if (hex_is_aligned((void *) dst, 128) && hex_is_aligned((void *) src, 128)) {
+ hvx_add_scalar_f32_aa(dst, src, val, num_elems);
+ } else if (hex_is_aligned((void *) dst, 128)) {
+ hvx_add_scalar_f32_au(dst, src, val, num_elems);
+ } else if (hex_is_aligned((void *) src, 128)) {
+ hvx_add_scalar_f32_ua(dst, src, val, num_elems);
+ } else {
+ hvx_add_scalar_f32_uu(dst, src, val, num_elems);
+ }
+}
+
+static inline void hvx_mul_scalar_f32(uint8_t * restrict dst, const uint8_t * restrict src, const float val, const int num_elems) {
+ if (hex_is_aligned((void *) dst, 128) && hex_is_aligned((void *) src, 128)) {
+ hvx_mul_scalar_f32_aa(dst, src, val, num_elems);
+ } else if (hex_is_aligned((void *) dst, 128)) {
+ hvx_mul_scalar_f32_au(dst, src, val, num_elems);
+ } else if (hex_is_aligned((void *) src, 128)) {
+ hvx_mul_scalar_f32_ua(dst, src, val, num_elems);
+ } else {
+ hvx_mul_scalar_f32_uu(dst, src, val, num_elems);
+ }
+}
+
+static inline void hvx_sub_scalar_f32(uint8_t * restrict dst, const uint8_t * restrict src, const float val, const int num_elems) {
+ if (hex_is_aligned((void *) dst, 128) && hex_is_aligned((void *) src, 128)) {
+ hvx_sub_scalar_f32_aa(dst, src, val, num_elems);
+ } else if (hex_is_aligned((void *) dst, 128)) {
+ hvx_sub_scalar_f32_au(dst, src, val, num_elems);
+ } else if (hex_is_aligned((void *) src, 128)) {
+ hvx_sub_scalar_f32_ua(dst, src, val, num_elems);
+ } else {
+ hvx_sub_scalar_f32_uu(dst, src, val, num_elems);
+ }
+}
+
+// MIN Scalar variants
+
+#define HVX_OP_MIN_SCALAR(v) Q6_Vsf_vmin_VsfVsf(val_vec, v)
+
+static inline void hvx_min_scalar_f32_aa(uint8_t * restrict dst, const uint8_t * restrict src, const float val, uint32_t n) {
+ const HVX_Vector val_vec = hvx_vec_splat_f32(val);
+ assert((unsigned long) dst % 128 == 0);
+ assert((unsigned long) src % 128 == 0);
+ hvx_scalar_loop_body(HVX_Vector, HVX_Vector, hvx_vec_store_a, HVX_OP_MIN_SCALAR);
+}
+
+static inline void hvx_min_scalar_f32_au(uint8_t * restrict dst, const uint8_t * restrict src, const float val, uint32_t n) {
+ const HVX_Vector val_vec = hvx_vec_splat_f32(val);
+ assert((unsigned long) dst % 128 == 0);
+ hvx_scalar_loop_body(HVX_Vector, HVX_UVector, hvx_vec_store_a, HVX_OP_MIN_SCALAR);
+}
+
+static inline void hvx_min_scalar_f32_ua(uint8_t * restrict dst, const uint8_t * restrict src, const float val, uint32_t n) {
+ const HVX_Vector val_vec = hvx_vec_splat_f32(val);
+ assert((unsigned long) src % 128 == 0);
+ hvx_scalar_loop_body(HVX_UVector, HVX_Vector, hvx_vec_store_u, HVX_OP_MIN_SCALAR);
+}
+
+static inline void hvx_min_scalar_f32_uu(uint8_t * restrict dst, const uint8_t * restrict src, const float val, uint32_t n) {
+ const HVX_Vector val_vec = hvx_vec_splat_f32(val);
+ hvx_scalar_loop_body(HVX_UVector, HVX_UVector, hvx_vec_store_u, HVX_OP_MIN_SCALAR);
+}
+
+static inline void hvx_min_scalar_f32(uint8_t * restrict dst, const uint8_t * restrict src, const float val, const int num_elems) {
+ if (hex_is_aligned((void *) dst, 128) && hex_is_aligned((void *) src, 128)) {
+ hvx_min_scalar_f32_aa(dst, src, val, num_elems);
+ } else if (hex_is_aligned((void *) dst, 128)) {
+ hvx_min_scalar_f32_au(dst, src, val, num_elems);
+ } else if (hex_is_aligned((void *) src, 128)) {
+ hvx_min_scalar_f32_ua(dst, src, val, num_elems);
+ } else {
+ hvx_min_scalar_f32_uu(dst, src, val, num_elems);
+ }
+}
+
+// CLAMP Scalar variants
+
+#define HVX_OP_CLAMP_SCALAR(v) \
+ ({ \
+ HVX_VectorPred pred_cap_right = Q6_Q_vcmp_gt_VsfVsf(v, max_vec); \
+ HVX_VectorPred pred_cap_left = Q6_Q_vcmp_gt_VsfVsf(min_vec, v); \
+ HVX_Vector tmp = Q6_V_vmux_QVV(pred_cap_right, max_vec, v); \
+ Q6_V_vmux_QVV(pred_cap_left, min_vec, tmp); \
+ })
+
+static inline void hvx_clamp_scalar_f32_aa(uint8_t * restrict dst, const uint8_t * restrict src, const float min, const float max, uint32_t n) {
+ const HVX_Vector min_vec = hvx_vec_splat_f32(min);
+ const HVX_Vector max_vec = hvx_vec_splat_f32(max);
+ assert((unsigned long) dst % 128 == 0);
+ assert((unsigned long) src % 128 == 0);
+ hvx_scalar_loop_body(HVX_Vector, HVX_Vector, hvx_vec_store_a, HVX_OP_CLAMP_SCALAR);
+}
+
+static inline void hvx_clamp_scalar_f32_au(uint8_t * restrict dst, const uint8_t * restrict src, const float min, const float max, uint32_t n) {
+ const HVX_Vector min_vec = hvx_vec_splat_f32(min);
+ const HVX_Vector max_vec = hvx_vec_splat_f32(max);
+ assert((unsigned long) dst % 128 == 0);
+ hvx_scalar_loop_body(HVX_Vector, HVX_UVector, hvx_vec_store_a, HVX_OP_CLAMP_SCALAR);
+}
+
+static inline void hvx_clamp_scalar_f32_ua(uint8_t * restrict dst, const uint8_t * restrict src, const float min, const float max, uint32_t n) {
+ const HVX_Vector min_vec = hvx_vec_splat_f32(min);
+ const HVX_Vector max_vec = hvx_vec_splat_f32(max);
+ assert((unsigned long) src % 128 == 0);
+ hvx_scalar_loop_body(HVX_UVector, HVX_Vector, hvx_vec_store_u, HVX_OP_CLAMP_SCALAR);
+}
+
+static inline void hvx_clamp_scalar_f32_uu(uint8_t * restrict dst, const uint8_t * restrict src, const float min, const float max, uint32_t n) {
+ const HVX_Vector min_vec = hvx_vec_splat_f32(min);
+ const HVX_Vector max_vec = hvx_vec_splat_f32(max);
+ hvx_scalar_loop_body(HVX_UVector, HVX_UVector, hvx_vec_store_u, HVX_OP_CLAMP_SCALAR);
+}
+
+static inline void hvx_clamp_scalar_f32(uint8_t * restrict dst, const uint8_t * restrict src, const float min, const float max, const int num_elems) {
+ if (hex_is_aligned((void *) dst, 128) && hex_is_aligned((void *) src, 128)) {
+ hvx_clamp_scalar_f32_aa(dst, src, min, max, num_elems);
+ } else if (hex_is_aligned((void *) dst, 128)) {
+ hvx_clamp_scalar_f32_au(dst, src, min, max, num_elems);
+ } else if (hex_is_aligned((void *) src, 128)) {
+ hvx_clamp_scalar_f32_ua(dst, src, min, max, num_elems);
+ } else {
+ hvx_clamp_scalar_f32_uu(dst, src, min, max, num_elems);
+ }
+}
+
+//
+// Square
+//
+
+#define hvx_sqr_loop_body(dst_type, src_type, vec_store) \
+ do { \
+ dst_type * restrict vdst = (dst_type *) dst; \
+ src_type * restrict vsrc = (src_type *) src; \
+ \
+ const uint32_t elem_size = sizeof(float); \
+ const uint32_t epv = 128 / elem_size; \
+ const uint32_t nvec = n / epv; \
+ const uint32_t nloe = n % epv; \
+ \
+ uint32_t i = 0; \
+ \
+ _Pragma("unroll(4)") \
+ for (; i < nvec; i++) { \
+ vdst[i] = HVX_OP_MUL(vsrc[i], vsrc[i]); \
+ } \
+ if (nloe) { \
+ HVX_Vector v = HVX_OP_MUL(vsrc[i], vsrc[i]); \
+ vec_store((void *) &vdst[i], nloe * elem_size, v); \
+ } \
+ } while(0)
+
+static inline void hvx_sqr_f32_aa(uint8_t * restrict dst, const uint8_t * restrict src, uint32_t n) {
+ assert((unsigned long) dst % 128 == 0);
+ assert((unsigned long) src % 128 == 0);
+ hvx_sqr_loop_body(HVX_Vector, HVX_Vector, hvx_vec_store_a);
+}
+
+static inline void hvx_sqr_f32_au(uint8_t * restrict dst, const uint8_t * restrict src, uint32_t n) {
+ assert((unsigned long) dst % 128 == 0);
+    hvx_sqr_loop_body(HVX_Vector, HVX_UVector, hvx_vec_store_a);
+}
+
+static inline void hvx_sqr_f32_ua(uint8_t * restrict dst, const uint8_t * restrict src, uint32_t n) {
+ assert((unsigned long) src % 128 == 0);
+ hvx_sqr_loop_body(HVX_UVector, HVX_Vector, hvx_vec_store_u);
+}
+
+static inline void hvx_sqr_f32_uu(uint8_t * restrict dst, const uint8_t * restrict src, uint32_t n) {
+ hvx_sqr_loop_body(HVX_UVector, HVX_UVector, hvx_vec_store_u);
+}
+
+static inline void hvx_sqr_f32(uint8_t * restrict dst, const uint8_t * restrict src, const uint32_t num_elems) {
+ if (hex_is_aligned((void *) dst, 128)) {
+ if (hex_is_aligned((void *) src, 128)) {
+ hvx_sqr_f32_aa(dst, src, num_elems);
+ } else {
+ hvx_sqr_f32_au(dst, src, num_elems);
+ }
+ } else {
+ if (hex_is_aligned((void *) src, 128)) {
+ hvx_sqr_f32_ua(dst, src, num_elems);
+ } else {
+ hvx_sqr_f32_uu(dst, src, num_elems);
+ }
+ }
+}
+
+#undef HVX_OP_ADD
+#undef HVX_OP_SUB
+#undef HVX_OP_MUL
+#undef hvx_arith_loop_body
+#undef HVX_OP_ADD_SCALAR
+#undef HVX_OP_SUB_SCALAR
+#undef HVX_OP_MUL_SCALAR
+#undef hvx_scalar_loop_body
+#undef HVX_OP_MIN_SCALAR
+#undef HVX_OP_CLAMP_SCALAR
+#undef DEFINE_HVX_BINARY_OP_VARIANTS
+#undef HVX_BINARY_DISPATCHER
+
+#endif // HVX_ARITH_H
diff --git a/llama.cpp/ggml/src/ggml-hexagon/htp/hvx-base.h b/llama.cpp/ggml/src/ggml-hexagon/htp/hvx-base.h
new file mode 100644
index 0000000..12a1b7f
--- /dev/null
+++ b/llama.cpp/ggml/src/ggml-hexagon/htp/hvx-base.h
@@ -0,0 +1,173 @@
+#ifndef HVX_BASE_H
+#define HVX_BASE_H
+
+#include <stdbool.h>
+#include <stdint.h>
+
+#include "hex-utils.h"
+#include "hvx-types.h"
+
+static inline void hvx_vec_store_u(void * restrict dst, uint32_t n, HVX_Vector v) {
+ // Rotate as needed.
+ v = Q6_V_vlalign_VVR(v, v, (size_t) dst);
+
+ uint32_t left_off = (size_t) dst & 127;
+ uint32_t right_off = left_off + n;
+
+ HVX_VectorPred ql_not = Q6_Q_vsetq_R((size_t) dst);
+ HVX_VectorPred qr = Q6_Q_vsetq2_R(right_off);
+
+ if (right_off > 128) {
+ Q6_vmem_QRIV(qr, (HVX_Vector *) dst + 1, v);
+ // all 1's
+ qr = Q6_Q_vcmp_eq_VbVb(v, v);
+ }
+
+ ql_not = Q6_Q_or_QQn(ql_not, qr);
+ Q6_vmem_QnRIV(ql_not, (HVX_Vector *) dst, v);
+}
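+
+// Illustrative note (not part of the HVX path): semantically the routine
+// above is a partial, possibly unaligned store of the first n bytes of v,
+// i.e. roughly the scalar sketch (plain C, <string.h>)
+//
+//   void store_u_ref(void * dst, uint32_t n, const uint8_t v[128]) {
+//       memcpy(dst, v, n);   // n <= 128, no bytes outside [dst, dst+n) touched
+//   }
+//
+// implemented with a byte rotate plus one or two predicated vector stores.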
+
+static inline void hvx_vec_store_a(void * restrict dst, uint32_t n, HVX_Vector v) {
+ assert((unsigned long) dst % 128 == 0);
+ HVX_VectorPred m = Q6_Q_or_QQn(Q6_Q_vsetq_R((unsigned long) dst), Q6_Q_vsetq2_R(n));
+ Q6_vmem_QnRIV(m, (HVX_Vector *) dst, v);
+}
+
+static inline HVX_Vector hvx_vec_splat_f32(float v) {
+ union { float f; uint32_t i; } u = { .f = v };
+ return Q6_V_vsplat_R(u.i);
+}
+
+static inline HVX_Vector hvx_vec_splat_f16(float v) {
+ union { __fp16 f; uint16_t i; } u = { .f = v };
+ return Q6_Vh_vsplat_R(u.i);
+}
+
+static inline HVX_Vector hvx_vec_repl4(HVX_Vector v) {
+ // vdelta control to replicate first 4 bytes across all elements
+ static const uint8_t __attribute__((aligned(128))) repl[128] = {
+ 0x00, 0x00, 0x00, 0x00, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04,
+ 0x10, 0x10, 0x10, 0x10, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04,
+ 0x20, 0x20, 0x20, 0x20, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04,
+ 0x10, 0x10, 0x10, 0x10, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04,
+ 0x40, 0x40, 0x40, 0x40, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04,
+ 0x10, 0x10, 0x10, 0x10, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04,
+ 0x20, 0x20, 0x20, 0x20, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04,
+ 0x10, 0x10, 0x10, 0x10, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04,
+ };
+
+ HVX_Vector ctrl = *(HVX_Vector *) repl;
+ return Q6_V_vdelta_VV(v, ctrl);
+}
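+
+// Illustrative note (not part of the HVX path): with the control table above
+// the vdelta network computes, per byte,
+//
+//   for (int i = 0; i < 128; i++) out[i] = in[i % 4];
+//
+// i.e. it broadcasts the first 32-bit lane to all 32 lanes of the vector.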
+
+static inline float hvx_vec_get_f32(HVX_Vector v) {
+ float __attribute__((aligned(128))) x;
+ hvx_vec_store_a(&x, 4, v);
+ return x;
+}
+
+static inline int32_t hvx_vec_get_i32(HVX_Vector v) {
+ int32_t __attribute__((aligned(128))) x;
+ hvx_vec_store_a(&x, 4, v);
+ return x;
+}
+
+static inline HVX_Vector hvx_vec_abs_f16(HVX_Vector v) {
+ // abs by clearing the fp16 sign bit
+ HVX_Vector mask = Q6_Vh_vsplat_R(0x7fff);
+ return Q6_V_vand_VV(v, mask);
+}
+
+static inline HVX_Vector hvx_vec_neg_f16(HVX_Vector v) {
+ // neg by setting the fp16 sign bit
+ HVX_Vector mask = Q6_Vh_vsplat_R(0x8000);
+ return Q6_V_vxor_VV(v, mask);
+}
+
+static inline HVX_Vector hvx_vec_abs_f32(HVX_Vector v) {
+ // abs by clearing the fp32 sign bit
+ HVX_Vector mask = Q6_V_vsplat_R(0x7fffffff);
+ return Q6_V_vand_VV(v, mask);
+}
+
+static inline HVX_Vector hvx_vec_neg_f32(HVX_Vector v) {
+#if __HVX_ARCH__ > 75
+ return Q6_Vsf_vfneg_Vsf(v);
+#else
+ // neg by setting the fp32 sign bit
+ HVX_Vector mask = Q6_V_vsplat_R(0x80000000);
+ return Q6_V_vxor_VV(v, mask);
+#endif // __HVX_ARCH__ > 75
+}
+
+static inline HVX_VectorPred hvx_vec_is_nan_f16(HVX_Vector v) {
+ const HVX_Vector vnan_exp = Q6_Vh_vsplat_R(0x7C00);
+ const HVX_Vector vnan_frac = Q6_Vh_vsplat_R(0x7FFF);
+
+ // get pred of which are NaN, i.e., exponent bits all 1s and fraction bits non 0s
+ HVX_VectorPred p_exp = Q6_Q_vcmp_eq_VhVh(Q6_V_vand_VV(v, vnan_exp), vnan_exp);
+ HVX_VectorPred p_frac = Q6_Q_not_Q(Q6_Q_vcmp_eq_VhVh(Q6_V_vand_VV(v, vnan_frac), vnan_exp));
+ return Q6_Q_and_QQ(p_exp, p_frac);
+}
+
+static inline HVX_Vector hvx_vec_f32_to_f16(HVX_Vector v0, HVX_Vector v1) {
+ const HVX_Vector zero = Q6_V_vsplat_R(0);
+ HVX_Vector q0 = Q6_Vqf32_vadd_VsfVsf(v0, zero);
+ HVX_Vector q1 = Q6_Vqf32_vadd_VsfVsf(v1, zero);
+ HVX_Vector v = Q6_Vh_vdeal_Vh(Q6_Vhf_equals_Wqf32(Q6_W_vcombine_VV(q1, q0)));
+
+#if __HVX_ARCH__ < 79
+ // replace NaNs with -INF, older arches produce NaNs for (-INF + 0.0)
+ const HVX_Vector neg_inf = hvx_vec_splat_f16(-INFINITY);
+ HVX_VectorPred nan = hvx_vec_is_nan_f16(v);
+ v = Q6_V_vmux_QVV(nan, neg_inf, v);
+#endif
+
+ return v;
+}
+
+/* Q6_Vsf_equals_Vw is only available on v73+.*/
+#if __HVX_ARCH__ < 73
+static inline HVX_Vector hvx_vec_i32_to_qf32(HVX_Vector const in)
+{
+ HVX_Vector const vzero = Q6_V_vzero();
+ HVX_VectorPred is_zero = Q6_Q_vcmp_eq_VwVw(in, vzero);
+ HVX_Vector lshift = Q6_Vw_vnormamt_Vw(in);
+ HVX_Vector normalized = Q6_Vw_vasl_VwVw(in, lshift);
+ HVX_Vector vexp = Q6_Vw_vsub_VwVw(Q6_V_vsplat_R(0x7f + 30), lshift);
+ HVX_Vector mant = Q6_V_vand_VV(Q6_V_vsplat_R(0xFFFFFF00), normalized);
+ HVX_Vector ret = Q6_V_vmux_QVV(is_zero, vzero, Q6_Vw_vadd_VwVw(mant, vexp));
+ return ret;
+}
+
+static inline HVX_Vector Q6_Vsf_equals_Vw(HVX_Vector const in)
+{
+ return Q6_Vsf_equals_Vqf32(hvx_vec_i32_to_qf32(in));
+}
+#endif
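+
+// Illustrative note (not part of the HVX path): the fallback above provides
+// what newer architectures expose as a single instruction, i.e. the reference
+//
+//   float i32_to_f32_ref(int32_t x) { return (float) x; }
+//
+// (up to rounding of values that need more than 24 significand bits). It
+// counts leading sign bits to normalize the integer, packs the mantissa and
+// biased exponent into the qf32 layout, and lets Q6_Vsf_equals_Vqf32 produce
+// the final fp32 value.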
+
+static inline HVX_Vector hvx_vec_i16_from_hf_rnd_sat(HVX_Vector vin) {
+ // This looks complicated.
+ // Ideally should just be Q6_Vh_equals_Vhf(vin)
+ // but that instruction does not do proper rounding.
+
+ // convert to qf32, multiplying by 1.0 in the process.
+ HVX_VectorPair v32 = Q6_Wqf32_vmpy_VhfVhf(vin, Q6_Vh_vsplat_R(0x3C00));
+
+ // 'in-range' values are +/32752.
+ // add 192K to it, convert to sf
+ HVX_Vector v192K = Q6_V_vsplat_R(0x48400000);
+ HVX_Vector vsf_0 = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(Q6_V_lo_W(v32), v192K));
+ HVX_Vector vsf_1 = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(Q6_V_hi_W(v32), v192K));
+
+ // for in-range cases, result is {163858... 229360} so the exponent is always 144.
+ // if we extract bits 21..0 as a signed quantity, and round 6 bits off, that will be the answer.
+ // Start by <<10 to get the final 'sign' bit in bit 15...
+ vsf_0 = Q6_Vw_vasl_VwR(vsf_0, 10);
+ vsf_1 = Q6_Vw_vasl_VwR(vsf_1, 10);
+
+ // now round down to 16
+ return Q6_Vh_vround_VwVw_sat(vsf_1, vsf_0);
+}
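+
+// Illustrative note (not part of the HVX path): adding 192K (0x48400000 ==
+// 196608.0f) forces every in-range input onto exponent 144, so x*64 (plus a
+// constant offset) lands in known mantissa bit positions; the final vround
+// then drops those 6 fraction bits with rounding. The reference behaviour
+// being reproduced is simply (plain C, <math.h>)
+//
+//   int16_t hf_to_i16_ref(float x) {     // x already widened from fp16
+//       long r = lrintf(x);              // round to nearest
+//       if (r >  32767) r =  32767;      // saturate
+//       if (r < -32768) r = -32768;
+//       return (int16_t) r;
+//   }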
+
+#endif /* HVX_BASE_H */
diff --git a/llama.cpp/ggml/src/ggml-hexagon/htp/hvx-copy.h b/llama.cpp/ggml/src/ggml-hexagon/htp/hvx-copy.h
new file mode 100644
index 0000000..ae0dbed
--- /dev/null
+++ b/llama.cpp/ggml/src/ggml-hexagon/htp/hvx-copy.h
@@ -0,0 +1,245 @@
+#ifndef HVX_COPY_H
+#define HVX_COPY_H
+
+#include <assert.h>
+#include <stddef.h>
+#include <stdint.h>
+
+#include "hvx-base.h"
+
+#define hvx_splat_loop_body(dst_type, vec_store) \
+ do { \
+ dst_type * restrict vdst = (dst_type *) dst; \
+ \
+ uint32_t nvec = n / (128 / elem_size); \
+ uint32_t nloe = n % (128 / elem_size); \
+ \
+ uint32_t i = 0; \
+ \
+ _Pragma("unroll(4)") \
+ for (; i < nvec; i++) { \
+ vdst[i] = src; \
+ } \
+ if (nloe) { \
+ vec_store((void *) &vdst[i], nloe * elem_size, src); \
+ } \
+ } while(0)
+
+static inline void hvx_splat_a(uint8_t * restrict dst, HVX_Vector src, uint32_t n, uint32_t elem_size) {
+ assert((unsigned long) dst % 128 == 0);
+ hvx_splat_loop_body(HVX_Vector, hvx_vec_store_a);
+}
+
+static inline void hvx_splat_u(uint8_t * restrict dst, HVX_Vector src, uint32_t n, uint32_t elem_size) {
+ hvx_splat_loop_body(HVX_UVector, hvx_vec_store_u);
+}
+
+static inline void hvx_splat_f32_a(uint8_t * restrict dst, float v, uint32_t n) {
+ hvx_splat_a(dst, hvx_vec_splat_f32(v), n, sizeof(float));
+}
+
+static inline void hvx_splat_f32_u(uint8_t * restrict dst, float v, uint32_t n) {
+ hvx_splat_u(dst, hvx_vec_splat_f32(v), n, sizeof(float));
+}
+
+static inline void hvx_splat_f16_a(uint8_t * restrict dst, float v, uint32_t n) {
+    hvx_splat_a(dst, hvx_vec_splat_f16(v), n, sizeof(__fp16));
+}
+
+static inline void hvx_splat_f16_u(uint8_t * restrict dst, float v, uint32_t n) {
+ hvx_splat_u(dst, hvx_vec_splat_f16(v), n, sizeof(__fp16));
+}
+
+#define hvx_copy_loop_body(dst_type, src_type, vec_store) \
+ do { \
+ dst_type * restrict vdst = (dst_type *) dst; \
+ src_type * restrict vsrc = (src_type *) src; \
+ \
+ const uint32_t epv = 128 / elem_size; \
+ const uint32_t nvec = n / epv; \
+ const uint32_t nloe = n % epv; \
+ \
+ uint32_t i = 0; \
+ \
+ _Pragma("unroll(4)") \
+ for (; i < nvec; i++) { vdst[i] = vsrc[i]; } \
+ if (nloe) { \
+ vec_store((void *) &vdst[i], nloe * elem_size, vsrc[i]); \
+ } \
+ } while(0)
+
+// Generic copy routines
+static inline void hvx_copy_aa(uint8_t * restrict dst, const uint8_t * restrict src, uint32_t n, uint32_t elem_size) {
+ assert((unsigned long) dst % 128 == 0);
+ assert((unsigned long) src % 128 == 0);
+ hvx_copy_loop_body(HVX_Vector, HVX_Vector, hvx_vec_store_a);
+}
+
+static inline void hvx_copy_au(uint8_t * restrict dst, const uint8_t * restrict src, uint32_t n, uint32_t elem_size) {
+ assert((unsigned long) dst % 128 == 0);
+ hvx_copy_loop_body(HVX_Vector, HVX_UVector, hvx_vec_store_a);
+}
+
+static inline void hvx_copy_ua(uint8_t * restrict dst, const uint8_t * restrict src, uint32_t n, uint32_t elem_size) {
+ assert((unsigned long) src % 128 == 0);
+ hvx_copy_loop_body(HVX_UVector, HVX_Vector, hvx_vec_store_u);
+}
+
+static inline void hvx_copy_uu(uint8_t * restrict dst, const uint8_t * restrict src, uint32_t n, uint32_t elem_size) {
+ hvx_copy_loop_body(HVX_UVector, HVX_UVector, hvx_vec_store_u);
+}
+
+// copy n fp16 elements : source and destination are aligned to HVX Vector (128)
+static inline void hvx_copy_f16_aa(uint8_t * restrict dst, const uint8_t * restrict src, uint32_t n) {
+ hvx_copy_aa(dst, src, n, sizeof(__fp16));
+}
+
+// copy n fp16 elements : source is potentially unaligned, destination is aligned
+static inline void hvx_copy_f16_au(uint8_t * restrict dst, const uint8_t * restrict src, uint32_t n) {
+ hvx_copy_au(dst, src, n, sizeof(__fp16));
+}
+
+// copy n fp16 elements : source is aligned, destination is potentially unaligned
+static inline void hvx_copy_f16_ua(uint8_t * restrict dst, const uint8_t * restrict src, uint32_t n) {
+ hvx_copy_ua(dst, src, n, sizeof(__fp16));
+}
+
+// copy n fp16 elements : source and destination are potentially unaligned
+static inline void hvx_copy_f16_uu(uint8_t * restrict dst, const uint8_t * restrict src, uint32_t n) {
+ hvx_copy_uu(dst, src, n, sizeof(__fp16));
+}
+
+// copy n fp32 elements : source and destination are aligned to HVX Vector (128)
+static inline void hvx_copy_f32_aa(uint8_t * restrict dst, const uint8_t * restrict src, uint32_t n) {
+ hvx_copy_aa(dst, src, n, sizeof(float));
+}
+
+// copy n fp32 elements : source is aligned, destination is unaligned
+static inline void hvx_copy_f32_ua(uint8_t * restrict dst, const uint8_t * restrict src, uint32_t n) {
+ hvx_copy_ua(dst, src, n, sizeof(float));
+}
+
+// copy n fp32 elements : source is unaligned, destination is aligned
+static inline void hvx_copy_f32_au(uint8_t * restrict dst, const uint8_t * restrict src, uint32_t n) {
+ hvx_copy_au(dst, src, n, sizeof(float));
+}
+
+// copy n fp32 elements : source is unaligned, destination is unaligned
+static inline void hvx_copy_f32_uu(uint8_t * restrict dst, const uint8_t * restrict src, uint32_t n) {
+ hvx_copy_uu(dst, src, n, sizeof(float));
+}
+
+//// fp32 -> fp16
+
+#define hvx_copy_f16_f32_loop_body(dst_type, src_type, vec_store) \
+ do { \
+ dst_type * restrict vdst = (dst_type *) dst; \
+ src_type * restrict vsrc = (src_type *) src; \
+ \
+ const uint32_t elem_size = sizeof(__fp16); \
+ const uint32_t epv = 128 / elem_size; \
+ const uint32_t nvec = n / epv; \
+ const uint32_t nloe = n % epv; \
+ \
+ uint32_t i = 0; \
+ \
+ _Pragma("unroll(4)") \
+ for (; i < nvec; i++) { \
+ vdst[i] = hvx_vec_f32_to_f16(vsrc[i*2+0], vsrc[i*2+1]); \
+ } \
+ if (nloe) { \
+ HVX_Vector v = hvx_vec_f32_to_f16(vsrc[i*2+0], vsrc[i*2+1]); \
+ vec_store((void *) &vdst[i], nloe * elem_size, v); \
+ } \
+ } while(0)
+
+// copy/convert n fp32 elements into n fp16 elements : source is aligned, destination is aligned
+static inline void hvx_copy_f16_f32_aa(uint8_t * restrict dst, const uint8_t * restrict src, uint32_t n) {
+ assert((unsigned long) dst % 128 == 0);
+ assert((unsigned long) src % 128 == 0);
+ hvx_copy_f16_f32_loop_body(HVX_Vector, HVX_Vector, hvx_vec_store_a);
+}
+
+// copy/convert n fp32 elements into n fp16 elements : source is unaligned, destination is aligned
+static inline void hvx_copy_f16_f32_au(uint8_t * restrict dst, const uint8_t * restrict src, uint32_t n) {
+ assert((unsigned long) dst % 128 == 0);
+ hvx_copy_f16_f32_loop_body(HVX_Vector, HVX_UVector, hvx_vec_store_a);
+}
+
+// copy/convert n fp32 elements into n fp16 elements : source is aligned, destination is unaligned
+static inline void hvx_copy_f16_f32_ua(uint8_t * restrict dst, const uint8_t * restrict src, uint32_t n) {
+ assert((unsigned long) src % 128 == 0);
+ hvx_copy_f16_f32_loop_body(HVX_UVector, HVX_Vector, hvx_vec_store_u);
+}
+
+// copy/convert n fp32 elements into n fp16 elements : source is unaligned, destination is unaligned
+static inline void hvx_copy_f16_f32_uu(uint8_t * restrict dst, const uint8_t * restrict src, uint32_t n) {
+ hvx_copy_f16_f32_loop_body(HVX_UVector, HVX_UVector, hvx_vec_store_u);
+}
+
+//// fp16 -> fp32
+
+#define hvx_copy_f32_f16_loop_body(dst_type, src_type, vec_store) \
+ do { \
+ dst_type * restrict vdst = (dst_type *) dst; \
+ src_type * restrict vsrc = (src_type *) src; \
+ \
+ const HVX_Vector one = hvx_vec_splat_f16(1.0); \
+ \
+ const uint32_t elem_size = sizeof(__fp16); \
+ const uint32_t epv = 128 / elem_size; \
+ const uint32_t nvec = n / epv; \
+ uint32_t nloe = n % epv; \
+ \
+ uint32_t i = 0; \
+ \
+ _Pragma("unroll(4)") \
+ for (i = 0; i < nvec; ++i) { \
+ HVX_VectorPair p = Q6_Wqf32_vmpy_VhfVhf(Q6_Vh_vshuff_Vh(vsrc[i]), one); \
+ vdst[i*2] = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(p)); \
+ vdst[i*2+1] = Q6_Vsf_equals_Vqf32(Q6_V_hi_W(p)); \
+ } \
+ \
+ if (nloe) { \
+ HVX_VectorPair p = Q6_Wqf32_vmpy_VhfVhf(Q6_Vh_vshuff_Vh(vsrc[i]), one); \
+ \
+ HVX_Vector vd = Q6_V_lo_W(p); \
+ i = 2 * i; \
+ \
+ if (nloe >= 32) { \
+ vdst[i] = Q6_Vsf_equals_Vqf32(vd); \
+ nloe -= 32; ++i; vd = Q6_V_hi_W(p); \
+ } \
+ \
+ if (nloe) { \
+ vd = Q6_Vsf_equals_Vqf32(vd); \
+ hvx_vec_store_u(&vdst[i], nloe * sizeof(float), vd); \
+ } \
+ } \
+ } while(0)
+
+// copy/convert n fp16 elements into n fp32 elements : source is aligned, destination is aligned
+static inline void hvx_copy_f32_f16_aa(uint8_t * restrict dst, const uint8_t * restrict src, uint32_t n) {
+ assert((unsigned long) dst % 128 == 0);
+ assert((unsigned long) src % 128 == 0);
+ hvx_copy_f32_f16_loop_body(HVX_Vector, HVX_Vector, hvx_vec_store_a);
+}
+
+// copy/convert n fp16 elements into n fp32 elements : source is unaligned, destination is aligned
+static inline void hvx_copy_f32_f16_au(uint8_t * restrict dst, const uint8_t * restrict src, uint32_t n) {
+ assert((unsigned long) dst % 128 == 0);
+ hvx_copy_f32_f16_loop_body(HVX_Vector, HVX_UVector, hvx_vec_store_a);
+}
+
+// copy/convert n fp16 elements into n fp32 elements : source is aligned, destination is unaligned
+static inline void hvx_copy_f32_f16_ua(uint8_t * restrict dst, const uint8_t * restrict src, uint32_t n) {
+ assert((unsigned long) src % 128 == 0);
+ hvx_copy_f32_f16_loop_body(HVX_UVector, HVX_Vector, hvx_vec_store_u);
+}
+
+// copy/convert n fp16 elements into n fp32 elements : source is unaligned, destination is unaligned
+static inline void hvx_copy_f32_f16_uu(uint8_t * restrict dst, const uint8_t * restrict src, uint32_t n) {
+ hvx_copy_f32_f16_loop_body(HVX_UVector, HVX_UVector, hvx_vec_store_u);
+}
+
+#endif // HVX_COPY_H
diff --git a/llama.cpp/ggml/src/ggml-hexagon/htp/hvx-div.h b/llama.cpp/ggml/src/ggml-hexagon/htp/hvx-div.h
new file mode 100644
index 0000000..7dae012
--- /dev/null
+++ b/llama.cpp/ggml/src/ggml-hexagon/htp/hvx-div.h
@@ -0,0 +1,116 @@
+#ifndef HVX_DIV_H
+#define HVX_DIV_H
+
+#include <HAP_farf.h>
+
+#include <math.h>
+#include <string.h>
+#include <assert.h>
+#include <stddef.h>
+#include <stdint.h>
+
+#include "hvx-base.h"
+#include "hex-utils.h"
+#include "hvx-inverse.h"
+#include "hvx-arith.h"
+
+#if __HVX_ARCH__ < 79
+#define HVX_OP_MUL(a, b) Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(a, b))
+#else
+#define HVX_OP_MUL(a, b) Q6_Vsf_vmpy_VsfVsf(a, b)
+#endif
+
+#define hvx_div_f32_loop_body(dst_type, src0_type, src1_type, vec_store) \
+ do { \
+ dst_type * restrict vdst = (dst_type *) dst; \
+ src0_type * restrict vsrc0 = (src0_type *) src0; \
+ src1_type * restrict vsrc1 = (src1_type *) src1; \
+ \
+ const HVX_Vector nan_inf_mask = Q6_V_vsplat_R(0x7f800000); \
+ \
+ const uint32_t nvec = n / VLEN_FP32; \
+ const uint32_t nloe = n % VLEN_FP32; \
+ \
+ uint32_t i = 0; \
+ \
+ _Pragma("unroll(4)") \
+ for (; i < nvec; i++) { \
+ HVX_Vector inv_src1 = hvx_vec_inverse_f32_guard(vsrc1[i], nan_inf_mask); \
+ HVX_Vector res = HVX_OP_MUL(vsrc0[i], inv_src1); \
+ vdst[i] = res; \
+ } \
+ if (nloe) { \
+ HVX_Vector inv_src1 = hvx_vec_inverse_f32_guard(vsrc1[i], nan_inf_mask); \
+ HVX_Vector res = HVX_OP_MUL(vsrc0[i], inv_src1); \
+ vec_store((void *) &vdst[i], nloe * SIZEOF_FP32, res); \
+ } \
+ } while(0)
+
+// 3-letter suffix variants
+static inline void hvx_div_f32_aaa(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, uint32_t n) {
+ assert((uintptr_t) dst % 128 == 0);
+ assert((uintptr_t) src0 % 128 == 0);
+ assert((uintptr_t) src1 % 128 == 0);
+ hvx_div_f32_loop_body(HVX_Vector, HVX_Vector, HVX_Vector, hvx_vec_store_a);
+}
+
+static inline void hvx_div_f32_aau(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, uint32_t n) {
+ assert((uintptr_t) dst % 128 == 0);
+ assert((uintptr_t) src0 % 128 == 0);
+ hvx_div_f32_loop_body(HVX_Vector, HVX_Vector, HVX_UVector, hvx_vec_store_a);
+}
+
+static inline void hvx_div_f32_aua(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, uint32_t n) {
+ assert((uintptr_t) dst % 128 == 0);
+ assert((uintptr_t) src1 % 128 == 0);
+ hvx_div_f32_loop_body(HVX_Vector, HVX_UVector, HVX_Vector, hvx_vec_store_a);
+}
+
+static inline void hvx_div_f32_auu(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, uint32_t n) {
+ assert((uintptr_t) dst % 128 == 0);
+ hvx_div_f32_loop_body(HVX_Vector, HVX_UVector, HVX_UVector, hvx_vec_store_a);
+}
+
+static inline void hvx_div_f32_uaa(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, uint32_t n) {
+ assert((uintptr_t) src0 % 128 == 0);
+ assert((uintptr_t) src1 % 128 == 0);
+ hvx_div_f32_loop_body(HVX_UVector, HVX_Vector, HVX_Vector, hvx_vec_store_u);
+}
+
+static inline void hvx_div_f32_uau(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, uint32_t n) {
+ assert((uintptr_t) src0 % 128 == 0);
+ hvx_div_f32_loop_body(HVX_UVector, HVX_Vector, HVX_UVector, hvx_vec_store_u);
+}
+
+static inline void hvx_div_f32_uua(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, uint32_t n) {
+ assert((uintptr_t) src1 % 128 == 0);
+ hvx_div_f32_loop_body(HVX_UVector, HVX_UVector, HVX_Vector, hvx_vec_store_u);
+}
+
+static inline void hvx_div_f32_uuu(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, uint32_t n) {
+ hvx_div_f32_loop_body(HVX_UVector, HVX_UVector, HVX_UVector, hvx_vec_store_u);
+}
+
+static inline void hvx_div_f32(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, const uint32_t num_elems) {
+ if (hex_is_aligned((void *) dst, 128)) {
+ if (hex_is_aligned((void *) src0, 128)) {
+ if (hex_is_aligned((void *) src1, 128)) hvx_div_f32_aaa(dst, src0, src1, num_elems);
+ else hvx_div_f32_aau(dst, src0, src1, num_elems);
+ } else {
+ if (hex_is_aligned((void *) src1, 128)) hvx_div_f32_aua(dst, src0, src1, num_elems);
+ else hvx_div_f32_auu(dst, src0, src1, num_elems);
+ }
+ } else {
+ if (hex_is_aligned((void *) src0, 128)) {
+ if (hex_is_aligned((void *) src1, 128)) hvx_div_f32_uaa(dst, src0, src1, num_elems);
+ else hvx_div_f32_uau(dst, src0, src1, num_elems);
+ } else {
+ if (hex_is_aligned((void *) src1, 128)) hvx_div_f32_uua(dst, src0, src1, num_elems);
+ else hvx_div_f32_uuu(dst, src0, src1, num_elems);
+ }
+ }
+}
+
+#undef HVX_OP_MUL
+
+#endif // HVX_DIV_H
diff --git a/llama.cpp/ggml/src/ggml-hexagon/htp/hvx-dump.h b/llama.cpp/ggml/src/ggml-hexagon/htp/hvx-dump.h
new file mode 100644
index 0000000..85201fc
--- /dev/null
+++ b/llama.cpp/ggml/src/ggml-hexagon/htp/hvx-dump.h
@@ -0,0 +1,129 @@
+#ifndef HVX_DUMP_H
+#define HVX_DUMP_H
+
+#include <HAP_farf.h>
+
+#include <stdbool.h>
+#include <stdint.h>
+
+#include "hex-utils.h"
+#include "hvx-types.h"
+
+static void hvx_vec_dump_f16_n(char * pref, HVX_Vector v, uint32_t n) {
+ HVX_VectorAlias u = { .v = v };
+
+ const uint32_t n0 = n / 16;
+ const uint32_t n1 = n % 16;
+ int i = 0;
+ for (; i < n0; i++) {
+ hex_dump_f16_line(pref, u.fp16 + (16 * i), 16);
+ }
+ if (n1) {
+ hex_dump_f16_line(pref, u.fp16 + (16 * i), n1);
+ }
+}
+
+static void hvx_vec_dump_f16(char * pref, HVX_Vector v) {
+ hvx_vec_dump_f16_n(pref, v, 64);
+}
+
+static void hvx_vec_dump_f32_n(char * pref, HVX_Vector v, uint32_t n) {
+ HVX_VectorAlias u = { .v = v };
+
+ const uint32_t n0 = n / 16;
+ const uint32_t n1 = n % 16;
+ int i = 0;
+ for (; i < n0; i++) {
+ hex_dump_f32_line(pref, u.fp32 + (16 * i), 16);
+ }
+ if (n1) {
+ hex_dump_f32_line(pref, u.fp32 + (16 * i), n1);
+ }
+}
+
+static void hvx_vec_dump_f32_hmt(char * pref, HVX_Vector v) {
+ union {
+ HVX_Vector v;
+ float d[32];
+ } u = { .v = v };
+
+ FARF(HIGH, "%s: %.6f %.6f %.6f %.6f ... %.6f %.6f %.6f %.6f ... %.6f %.6f %.6f %.6f\n", pref, u.d[0], u.d[1],
+ u.d[2], u.d[3], u.d[12], u.d[13], u.d[14], u.d[15], u.d[28], u.d[29], u.d[30], u.d[31]);
+}
+
+static void hvx_vec_dump_f32(char * pref, HVX_Vector v) {
+ hvx_vec_dump_f32_n(pref, v, 32);
+}
+
+static void hvx_vec_dump_int32(char * pref, HVX_Vector v) {
+ union {
+ HVX_Vector v;
+ int32_t d[32];
+ } u = { .v = v };
+
+ for (int i = 0; i < 32 / 16; i++) {
+ hex_dump_int32_line(pref, u.d + (16 * i), 16);
+ }
+}
+
+static void hvx_vec_dump_int32_hmt(char * pref, HVX_Vector v) {
+ union {
+ HVX_Vector v;
+ int32_t d[32];
+ } u = { .v = v };
+
+ FARF(HIGH, "%s: %d %d %d %d ... %d %d %d %d ... %d %d %d %d\n", pref, u.d[0], u.d[1], u.d[2], u.d[3], u.d[12],
+ u.d[13], u.d[14], u.d[15], u.d[28], u.d[29], u.d[30], u.d[31]);
+}
+
+static void hvx_vec_dump_int8_hmt(char * pref, HVX_Vector v) {
+ union {
+ HVX_Vector v;
+ int8_t d[128];
+ } u = { .v = v };
+
+ FARF(HIGH, "%s: %d %d %d %d ... %d %d %d %d ... %d %d %d %d\n", pref, u.d[0], u.d[1], u.d[2], u.d[3], u.d[60],
+ u.d[61], u.d[62], u.d[63], u.d[124], u.d[125], u.d[126], u.d[127]);
+}
+
+static void hvx_vec_dump_int8(char * pref, HVX_Vector v) {
+ union {
+ HVX_Vector v;
+ int8_t d[128];
+ } u = { .v = v };
+
+ for (int i = 0; i < 128 / 16; i++) {
+ hex_dump_int8_line(pref, u.d + (16 * i), 16);
+ }
+}
+
+static void hvx_vec_dump_uint8(char * pref, HVX_Vector v) {
+ union {
+ HVX_Vector v;
+ uint8_t d[128];
+ } u = { .v = v };
+
+ for (int i = 0; i < 128 / 16; i++) {
+ hex_dump_uint8_line(pref, u.d + (16 * i), 16);
+ }
+}
+
+static bool hvx_vec_eq(HVX_Vector v0, HVX_Vector v1, size_t n) {
+ typedef union {
+ HVX_Vector v;
+ int8_t d[128];
+ } U;
+
+ U u0 = { .v = v0 };
+ U u1 = { .v = v1 };
+
+ for (int i = 0; i < n; i++) {
+ if (u0.d[i] != u1.d[i]) {
+ return false;
+ }
+ }
+
+ return true;
+}
+
+#endif /* HVX_DUMP_H */
diff --git a/llama.cpp/ggml/src/ggml-hexagon/htp/hvx-exp.h b/llama.cpp/ggml/src/ggml-hexagon/htp/hvx-exp.h
new file mode 100644
index 0000000..44dfe23
--- /dev/null
+++ b/llama.cpp/ggml/src/ggml-hexagon/htp/hvx-exp.h
@@ -0,0 +1,215 @@
+#ifndef HVX_EXP_H
+#define HVX_EXP_H
+
+#include <stdbool.h>
+#include <stdint.h>
+
+#include "hvx-base.h"
+#include "hvx-floor.h"
+
+#define EXP_COEFF_5 (0x39506967) // 0.000198757 = 1/(7!)
+#define EXP_COEFF_4 (0x3AB743CE) // 0.0013982 = 1/(6!)
+#define EXP_COEFF_3 (0x3C088908) // 0.00833345 = 1/(5!)
+#define EXP_COEFF_2 (0x3D2AA9C1) // 0.0416658 = 1/(4!)
+#define EXP_COEFF_1 (0x3E2AAAAA) // 0.16666667 = 1/(3!)
+#define EXP_COEFF_0 (0x3F000000) // 0.5 = 1/(2!)
+#define EXP_LOGN2 (0x3F317218) // ln(2) = 0.6931471805
+#define EXP_LOG2E (0x3FB8AA3B) // log2(e) = 1/ln(2) = 1.4426950408
+#define EXP_ONE (0x3f800000) // 1.0
+#define EXP_RANGE_R (0x41a00000) // 20.0
+#define EXP_RANGE_L (0xc1a00000) // -20.0
+
+static inline HVX_Vector hvx_vec_exp_f32(HVX_Vector in_vec) {
+ HVX_Vector z_qf32_v;
+ HVX_Vector x_v;
+ HVX_Vector x_qf32_v;
+ HVX_Vector y_v;
+ HVX_Vector k_v;
+ HVX_Vector f_v;
+ HVX_Vector epsilon_v;
+ HVX_Vector log2e = Q6_V_vsplat_R(EXP_LOG2E);
+ HVX_Vector logn2 = Q6_V_vsplat_R(EXP_LOGN2);
+ HVX_Vector E_const;
+ HVX_Vector zero_v = Q6_V_vzero();
+
+ // exp(x) is approximated as follows:
+ // f = floor(x/ln(2)) = floor(x*log2(e))
+ // epsilon = x - f*ln(2)
+ // exp(x) = exp(epsilon+f*ln(2))
+ // = exp(epsilon)*exp(f*ln(2))
+ // = exp(epsilon)*2^f
+ //
+ // Since epsilon is close to zero, it can be approximated with its Taylor series:
+ // exp(x) ~= 1+x+x^2/2!+x^3/3!+...+x^n/n!+...
+    // Keeping the first eight terms, we get:
+ // exp(x) ~= 1+x+e0*x^2+e1*x^3+e2*x^4+e3*x^5+e4*x^6+e5*x^7
+ // = 1+x+(E0+(E1+(E2+(E3+(E4+E5*x)*x)*x)*x)*x)*x^2
+
+ HVX_Vector temp_v = in_vec;
+
+ // Clamp inputs to (-20.0, 20.0)
+ HVX_VectorPred pred_cap_right = Q6_Q_vcmp_gt_VsfVsf(in_vec, Q6_V_vsplat_R(EXP_RANGE_R));
+ HVX_VectorPred pred_cap_left = Q6_Q_vcmp_gt_VsfVsf(Q6_V_vsplat_R(EXP_RANGE_L), in_vec);
+
+ in_vec = Q6_V_vmux_QVV(pred_cap_right, Q6_V_vsplat_R(EXP_RANGE_R), temp_v);
+    in_vec = Q6_V_vmux_QVV(pred_cap_left, Q6_V_vsplat_R(EXP_RANGE_L), in_vec);
+
+ epsilon_v = Q6_Vqf32_vmpy_VsfVsf(log2e, in_vec);
+ epsilon_v = Q6_Vsf_equals_Vqf32(epsilon_v);
+
+ // f_v is the floating point result and k_v is the integer result
+ f_v = hvx_vec_floor_f32(epsilon_v);
+ k_v = hvx_vec_truncate_f32(f_v);
+
+ x_qf32_v = Q6_Vqf32_vadd_VsfVsf(in_vec, zero_v);
+
+ // x = x - f_v * logn2;
+ epsilon_v = Q6_Vqf32_vmpy_VsfVsf(f_v, logn2);
+ x_qf32_v = Q6_Vqf32_vsub_Vqf32Vqf32(x_qf32_v, epsilon_v);
+ // normalize before every QFloat's vmpy
+ x_qf32_v = Q6_Vqf32_vadd_Vqf32Vsf(x_qf32_v, zero_v);
+
+ // z = x * x;
+ z_qf32_v = Q6_Vqf32_vmpy_Vqf32Vqf32(x_qf32_v, x_qf32_v);
+ z_qf32_v = Q6_Vqf32_vadd_Vqf32Vsf(z_qf32_v, zero_v);
+
+ x_v = Q6_Vsf_equals_Vqf32(x_qf32_v);
+
+ // y = E4 + E5 * x;
+ E_const = Q6_V_vsplat_R(EXP_COEFF_5);
+ y_v = Q6_Vqf32_vmpy_VsfVsf(E_const, x_v);
+ E_const = Q6_V_vsplat_R(EXP_COEFF_4);
+ y_v = Q6_Vqf32_vadd_Vqf32Vsf(y_v, E_const);
+ y_v = Q6_Vqf32_vadd_Vqf32Vsf(y_v, zero_v);
+
+ // y = E3 + y * x;
+ E_const = Q6_V_vsplat_R(EXP_COEFF_3);
+ y_v = Q6_Vqf32_vmpy_Vqf32Vqf32(y_v, x_qf32_v);
+ y_v = Q6_Vqf32_vadd_Vqf32Vsf(y_v, E_const);
+ y_v = Q6_Vqf32_vadd_Vqf32Vsf(y_v, zero_v);
+
+ // y = E2 + y * x;
+ E_const = Q6_V_vsplat_R(EXP_COEFF_2);
+ y_v = Q6_Vqf32_vmpy_Vqf32Vqf32(y_v, x_qf32_v);
+ y_v = Q6_Vqf32_vadd_Vqf32Vsf(y_v, E_const);
+ y_v = Q6_Vqf32_vadd_Vqf32Vsf(y_v, zero_v);
+
+ // y = E1 + y * x;
+ E_const = Q6_V_vsplat_R(EXP_COEFF_1);
+ y_v = Q6_Vqf32_vmpy_Vqf32Vqf32(y_v, x_qf32_v);
+ y_v = Q6_Vqf32_vadd_Vqf32Vsf(y_v, E_const);
+ y_v = Q6_Vqf32_vadd_Vqf32Vsf(y_v, zero_v);
+
+ // y = E0 + y * x;
+ E_const = Q6_V_vsplat_R(EXP_COEFF_0);
+ y_v = Q6_Vqf32_vmpy_Vqf32Vqf32(y_v, x_qf32_v);
+ y_v = Q6_Vqf32_vadd_Vqf32Vsf(y_v, E_const);
+ y_v = Q6_Vqf32_vadd_Vqf32Vsf(y_v, zero_v);
+
+ // y = x + y * z;
+ y_v = Q6_Vqf32_vmpy_Vqf32Vqf32(y_v, z_qf32_v);
+ y_v = Q6_Vqf32_vadd_Vqf32Vqf32(y_v, x_qf32_v);
+ y_v = Q6_Vqf32_vadd_Vqf32Vsf(y_v, zero_v);
+
+ // y = y + 1.0;
+ y_v = Q6_Vqf32_vadd_Vqf32Vsf(y_v, Q6_V_vsplat_R(EXP_ONE));
+
+ // insert exponents
+ // y = ldexpf(y, k);
+ // y_v += k_v; // qf32
+ // modify exponent
+
+ y_v = Q6_Vsf_equals_Vqf32(y_v);
+
+ // add k_v to the exponent of y_v
+ HVX_Vector y_v_exponent = Q6_Vw_vasl_VwR(y_v, 1);
+
+ y_v_exponent = Q6_Vuw_vlsr_VuwR(y_v_exponent, IEEE_VSF_MANTLEN + 1);
+ y_v_exponent = Q6_Vw_vadd_VwVw(k_v, y_v_exponent);
+
+ // exponent cannot be negative; if overflow is detected, result is set to zero
+ HVX_VectorPred qy_v_negative_exponent = Q6_Q_vcmp_gt_VwVw(zero_v, y_v_exponent);
+
+ y_v = Q6_Vw_vaslacc_VwVwR(y_v, k_v, IEEE_VSF_MANTLEN);
+
+ y_v = Q6_V_vmux_QVV(qy_v_negative_exponent, zero_v, y_v);
+
+ return y_v;
+}
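+
+// Illustrative scalar sketch of the same approximation (not part of the HVX
+// path; plain C, <math.h>); the constants match the EXP_COEFF_* macros above:
+//
+//   float exp_ref(float x) {
+//       if (x >  20.0f) x =  20.0f;            // same input clamp
+//       if (x < -20.0f) x = -20.0f;
+//       float f = floorf(x * 1.4426950408f);   // f = floor(x * log2(e))
+//       x -= f * 0.6931471805f;                // x -= f * ln(2), now |x| < ln(2)
+//       float y = 0.000198757f;                // ~1/7!
+//       y = y * x + 0.0013982f;                // ~1/6!
+//       y = y * x + 0.00833345f;               // ~1/5!
+//       y = y * x + 0.0416658f;                // ~1/4!
+//       y = y * x + 0.16666667f;               // ~1/3!
+//       y = y * x + 0.5f;                      // 1/2!
+//       y = y * x * x + x + 1.0f;              // 1 + x + x^2 * poly(x)
+//       return ldexpf(y, (int) f);             // scale by 2^f
+//   }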
+
+static inline HVX_Vector hvx_vec_exp_f32_guard(HVX_Vector in_vec, HVX_Vector max_exp, HVX_Vector inf) {
+ const HVX_VectorPred pred0 = Q6_Q_vcmp_gt_VsfVsf(in_vec, max_exp);
+
+ HVX_Vector out = hvx_vec_exp_f32(in_vec);
+
+ return Q6_V_vmux_QVV(pred0, inf, out);
+}
+
+static inline void hvx_exp_f32(const uint8_t * restrict src, uint8_t * restrict dst, const int num_elems, bool negate) {
+ int left_over = num_elems & (VLEN_FP32 - 1);
+ int num_elems_whole = num_elems - left_over;
+
+ int unaligned_addr = 0;
+ int unaligned_loop = 0;
+ if ((0 == hex_is_aligned((void *) src, VLEN)) || (0 == hex_is_aligned((void *) dst, VLEN))) {
+ unaligned_addr = 1;
+ }
+ // assert((0 == unaligned_addr) || (0 == num_elems_whole));
+ if ((1 == unaligned_addr) && (num_elems_whole != 0)) {
+ unaligned_loop = 1;
+ }
+
+ HVX_Vector vec_out = Q6_V_vzero();
+
+ static const float kInf = INFINITY;
+    static const float kMaxExp = 88.02f; // ~log(2^127); exp(x) overflows fp32 above this
+
+ const HVX_Vector max_exp = hvx_vec_splat_f32(kMaxExp);
+ const HVX_Vector inf = hvx_vec_splat_f32(kInf);
+
+ if (0 == unaligned_loop) {
+ HVX_Vector * p_vec_in1 = (HVX_Vector *) src;
+ HVX_Vector * p_vec_out = (HVX_Vector *) dst;
+
+ #pragma unroll(4)
+ for (int i = 0; i < num_elems_whole; i += VLEN_FP32) {
+ if (true == negate) {
+ HVX_Vector neg_vec_in = hvx_vec_neg_f32(*p_vec_in1++);
+ *p_vec_out++ = hvx_vec_exp_f32_guard(neg_vec_in, max_exp, inf);
+ } else {
+ *p_vec_out++ = hvx_vec_exp_f32_guard(*p_vec_in1++, max_exp, inf);
+ }
+ }
+ } else {
+ #pragma unroll(4)
+ for (int i = 0; i < num_elems_whole; i += VLEN_FP32) {
+ HVX_Vector in = *(HVX_UVector *) (src + i * SIZEOF_FP32);
+
+ if (true == negate) {
+ HVX_Vector neg_vec_in = hvx_vec_neg_f32(in);
+ *(HVX_UVector *) (dst + i * SIZEOF_FP32) = hvx_vec_exp_f32_guard(neg_vec_in, max_exp, inf);
+ } else {
+ *(HVX_UVector *) (dst + i * SIZEOF_FP32) = hvx_vec_exp_f32_guard(in, max_exp, inf);
+ }
+ }
+ }
+
+ if (left_over > 0) {
+ const float * srcf = (float *) src + num_elems_whole;
+ float * dstf = (float *) dst + num_elems_whole;
+
+ HVX_Vector in = *(HVX_UVector *) srcf;
+
+ if (true == negate) {
+ HVX_Vector neg_vec_in = hvx_vec_neg_f32(in);
+
+ vec_out = hvx_vec_exp_f32_guard(neg_vec_in, max_exp, inf);
+ } else {
+ vec_out = hvx_vec_exp_f32_guard(in, max_exp, inf);
+ }
+
+ hvx_vec_store_u((void *) dstf, left_over * SIZEOF_FP32, vec_out);
+ }
+}
+
+#endif /* HVX_EXP_H */
diff --git a/llama.cpp/ggml/src/ggml-hexagon/htp/hvx-floor.h b/llama.cpp/ggml/src/ggml-hexagon/htp/hvx-floor.h
new file mode 100644
index 0000000..6a1bfde
--- /dev/null
+++ b/llama.cpp/ggml/src/ggml-hexagon/htp/hvx-floor.h
@@ -0,0 +1,100 @@
+#ifndef HVX_FLOOR_H
+#define HVX_FLOOR_H
+
+#include <stdbool.h>
+#include <stdint.h>
+
+#include "hvx-base.h"
+
+#define IEEE_VSF_EXPLEN (8)
+#define IEEE_VSF_EXPBIAS (127)
+#define IEEE_VSF_EXPMASK (0xFF)
+#define IEEE_VSF_MANTLEN (23)
+#define IEEE_VSF_MANTMASK (0x7FFFFF)
+#define IEEE_VSF_MIMPMASK (0x800000)
+
+static inline HVX_Vector hvx_vec_truncate_f32(HVX_Vector in_vec) {
+ HVX_Vector mask_mant_v = Q6_V_vsplat_R(IEEE_VSF_MANTMASK);
+ HVX_Vector mask_impl_v = Q6_V_vsplat_R(IEEE_VSF_MIMPMASK);
+ HVX_Vector const_zero_v = Q6_V_vzero();
+
+ HVX_VectorPred q_negative = Q6_Q_vcmp_gt_VwVw(const_zero_v, in_vec);
+
+ HVX_Vector expval_v = in_vec >> IEEE_VSF_MANTLEN;
+ expval_v &= IEEE_VSF_EXPMASK;
+ expval_v -= IEEE_VSF_EXPBIAS;
+
+ // negative exp == fractional value
+ HVX_VectorPred q_negexp = Q6_Q_vcmp_gt_VwVw(const_zero_v, expval_v);
+
+ HVX_Vector rshift_v = IEEE_VSF_MANTLEN - expval_v; // fractional bits - exp shift
+
+ HVX_Vector mant_v = in_vec & mask_mant_v; // obtain mantissa
+ HVX_Vector vout = Q6_Vw_vadd_VwVw(mant_v, mask_impl_v); // add implicit 1.0
+
+ vout = Q6_Vw_vasr_VwVw(vout, rshift_v); // shift to obtain truncated integer
+ vout = Q6_V_vmux_QVV(q_negexp, const_zero_v, vout); // expval<0 -> 0
+
+ HVX_Vector neg_vout = -vout;
+
+ vout = Q6_V_vmux_QVV(q_negative, neg_vout, vout); // handle negatives
+
+ return (vout);
+}
+
+static inline HVX_Vector hvx_vec_floor_f32(HVX_Vector in_vec) {
+ HVX_Vector mask_mant_v = Q6_V_vsplat_R(IEEE_VSF_MANTMASK);
+ HVX_Vector mask_impl_v = Q6_V_vsplat_R(IEEE_VSF_MIMPMASK);
+ HVX_Vector const_mnlen_v = Q6_V_vsplat_R(IEEE_VSF_MANTLEN);
+ HVX_Vector const_zero_v = Q6_V_vzero();
+ HVX_Vector const_negone_v = Q6_V_vsplat_R(0xbf800000); // -1 IEEE vsf
+
+ HVX_VectorPred q_negative = Q6_Q_vcmp_gt_VwVw(const_zero_v, in_vec);
+
+ HVX_Vector expval_v = in_vec >> IEEE_VSF_MANTLEN;
+ expval_v &= IEEE_VSF_EXPMASK;
+ expval_v -= IEEE_VSF_EXPBIAS;
+
+ HVX_VectorPred q_negexp = Q6_Q_vcmp_gt_VwVw(const_zero_v, expval_v);
+ HVX_VectorPred q_expltmn = Q6_Q_vcmp_gt_VwVw(const_mnlen_v, expval_v);
+ HVX_VectorPred q_negexp_pos = Q6_Q_vcmp_gtand_QVwVw(q_negexp, in_vec, const_zero_v);
+ HVX_VectorPred q_negexp_neg = Q6_Q_vcmp_gtand_QVwVw(q_negexp, const_zero_v, in_vec);
+
+ // if expval < 0 (q_negexp) // <0, floor is 0
+ // if vin > 0
+ // floor = 0
+ // if vin < 0
+ // floor = -1
+ // if expval < mant_len (q_expltmn) // >0, but fraction may exist
+ // get sign (q_negative)
+ // mask >> expval // fraction bits to mask off
+ // vout = ~(mask) // apply mask to remove fraction
+ // if (qneg) // negative floor is one less (more, sign bit for neg)
+ // vout += ((impl_mask) >> expval)
+ // if (mask && vin)
+ // vout = vin
+ // else // already an integer
+ // ; // no change
+
+ // compute floor
+ mask_mant_v >>= expval_v;
+ HVX_Vector neg_addin_v = mask_impl_v >> expval_v;
+ HVX_Vector vout_neg_addin = Q6_Vw_vadd_VwVw(in_vec, neg_addin_v);
+ HVX_Vector vout = Q6_V_vmux_QVV(q_negative, vout_neg_addin, in_vec);
+
+ HVX_Vector mask_chk_v = Q6_V_vand_VV(in_vec, mask_mant_v); // chk if bits set
+ HVX_VectorPred q_integral = Q6_Q_vcmp_eq_VwVw(const_zero_v, mask_chk_v);
+
+ HVX_Vector not_mask_v = Q6_V_vnot_V(mask_mant_v); // frac bits to clear
+ HVX_Vector vfrfloor_v = Q6_V_vand_VV(vout, not_mask_v); // clear frac bits
+
+ vout = in_vec;
+ vout = Q6_V_vmux_QVV(q_expltmn, vfrfloor_v, vout); // expval<mant
+ vout = Q6_V_vmux_QVV(q_integral, in_vec, vout); // integral values
+ vout = Q6_V_vmux_QVV(q_negexp_pos, const_zero_v, vout); // expval<0 x>0 -> 0
+ vout = Q6_V_vmux_QVV(q_negexp_neg, const_negone_v, vout); // expval<0 x<0 -> -1
+
+ return vout;
+}
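+
+// Illustrative scalar sketch of the same bit manipulation (not part of the
+// HVX path; edge cases such as -0.0 aside, floorf() is the reference):
+//
+//   float floor_ref(float x) {
+//       int32_t bits;  memcpy(&bits, &x, 4);
+//       int32_t e = ((bits >> 23) & 0xFF) - 127;
+//       if (e < 0)  return (x < 0.0f) ? -1.0f : (x > 0.0f ? 0.0f : x);
+//       if (e >= 23) return x;                   // already integral
+//       int32_t frac = 0x7FFFFF >> e;            // fraction bits to clear
+//       if ((bits & frac) == 0) return x;        // already integral
+//       if (x < 0.0f) bits += 0x800000 >> e;     // negative: step one lower
+//       bits &= ~frac;                           // drop the fraction
+//       float r;  memcpy(&r, &bits, 4);
+//       return r;
+//   }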
+
+#endif /* HVX_FLOOR_H */
diff --git a/llama.cpp/ggml/src/ggml-hexagon/htp/hvx-inverse.h b/llama.cpp/ggml/src/ggml-hexagon/htp/hvx-inverse.h
new file mode 100644
index 0000000..49f3efa
--- /dev/null
+++ b/llama.cpp/ggml/src/ggml-hexagon/htp/hvx-inverse.h
@@ -0,0 +1,176 @@
+#ifndef HVX_INVERSE_H
+#define HVX_INVERSE_H
+
+#include <HAP_farf.h>
+
+#include <math.h>
+#include <string.h>
+#include <assert.h>
+#include <stddef.h>
+#include <stdint.h>
+
+#include "hvx-base.h"
+
+// ====================================================
+// FUNCTION: 1/(x+1) y(0) = 1, y(0.5) = 0.6667, y(1) = 0.5
+// Order:3; continuity: True; Ends forced: True
+// Mode: unsigned; Result fractional bits: 14
+// Peak Error: 1.1295e-04 Rms Error: 2.8410e-05 Mean Error: 1.1370e-05
+// 32769 -32706 31252 -10589
+// 32590 -30635 22793 -4493
+// 32066 -27505 16481 -2348
+// 31205 -24054 11849 -1306
+
+static inline HVX_Vector hvx_vec_recip_xp1_O3_unsigned(HVX_Vector vx) {
+ // input is 0..0xffff representing 0.0 .. 1.0
+ HVX_Vector p;
+ p = Q6_Vh_vlut4_VuhPh(vx, 0xFAE6F6D4EE73D6A3ull);
+ p = Q6_Vh_vmpa_VhVhVuhPuh_sat(p, vx, 0x2E49406159097A14ull);
+ p = Q6_Vh_vmps_VhVhVuhPuh_sat(p, vx, 0x5DF66B7177AB7FC2ull);
+ p = Q6_Vh_vmpa_VhVhVuhPuh_sat(p, vx, 0x79E57D427F4E8001ull);
+ return p; // signed result, 14 fractional bits
+}
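+
+// Illustrative note (not part of the HVX path): in floating point the routine
+// above collapses to the reference
+//
+//   // x in [0,1) mapped from the 0..0xffff input; result has 14 fraction bits
+//   int16_t recip_xp1_ref(float x) { return (int16_t) lrintf(16384.0f / (1.0f + x)); }
+//
+// which the HVX version reproduces with the 4-segment cubic fit listed in the
+// coefficient table, evaluated in 16-bit fixed point.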
+
+// Find reciprocal of fp16.
+// (1) first, convert to fp32, multiplying by 1.0; this is done to
+// handle denormals. Ignoring sign and zero, result should be at
+// least 5.9604645e-08 (32-bit code 0x33800000) and at most 131008 (0x47ffe000)
+// (exponent in range [103,143])
+// (2) extract the mantissa into 16-bit unsigned; find reciprocal using a fitted poly
+// (3) put this, along with '253-exp' (exp from (1)) together to make an qf32
+// (4) convert that to fp16
+// (5) put sign back in. Also, if the original value (w/o sign) was <0x81, replace
+// the result with the max value.
+static inline HVX_Vector hvx_vec_inverse_f16(HVX_Vector vals) {
+ HVX_Vector em_mask = Q6_Vh_vsplat_R(0x7FFF);
+ HVX_Vector avals = Q6_V_vand_VV(vals, em_mask);
+ HVX_VectorPred is_neg = Q6_Q_vcmp_gt_VhVh(avals, vals);
+    // too small to take the reciprocal of? for 'standard' fp16 the threshold code is 0x101
+ HVX_VectorPred is_small = Q6_Q_vcmp_gt_VhVh(Q6_Vh_vsplat_R(0x101), avals);
+
+ HVX_VectorPair to_qf32 = Q6_Wqf32_vmpy_VhfVhf(avals, Q6_Vh_vsplat_R(0x3C00)); // *1.0
+ HVX_Vector to_f32_0 = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(to_qf32));
+ HVX_Vector to_f32_1 = Q6_Vsf_equals_Vqf32(Q6_V_hi_W(to_qf32));
+
+ // bits 22..13 contain the mantissa now (w/o hidden bit); move to bit 14..5 of a 16-bit vector
+ HVX_Vector mant_u16 = Q6_Vh_vshuffo_VhVh(Q6_Vw_vasl_VwR(to_f32_1, 9), Q6_Vw_vasl_VwR(to_f32_0, 9));
+ // likewise extract the upper 16 from each, containing the exponents in range 103..142
+ HVX_Vector exp_u16 = Q6_Vh_vshuffo_VhVh(to_f32_1, to_f32_0);
+ //Get exponent in IEEE 32-bit representation
+ exp_u16 = Q6_Vuh_vlsr_VuhR(exp_u16, 7);
+
+ // so, mant_u16 contains an unbiased mantissa in upper 10 bits of each u16 lane
+ // We can consider it to be x-1.0, with 16 fractional bits, where 'x' is in range [1.0,2.0)
+ // Use poly to transform to 1/x, with 14 fractional bits
+ //
+ HVX_Vector rm = hvx_vec_recip_xp1_O3_unsigned(mant_u16);
+
+ HVX_Vector vcl0 = Q6_Vuh_vcl0_Vuh(rm); //count leading zeros
+
+    // Get mantissa for 16-bit representation
+ HVX_Vector mant_recip = Q6_V_vand_VV(Q6_Vh_vasr_VhR(Q6_Vh_vasl_VhVh(rm, vcl0), 5), Q6_Vh_vsplat_R(0x03FF));
+
+ //Compute Reciprocal Exponent
+ HVX_Vector exp_recip =
+ Q6_Vh_vsub_VhVh(Q6_Vh_vsub_VhVh(Q6_Vh_vsplat_R(254), exp_u16), Q6_Vh_vsub_VhVh(vcl0, Q6_Vh_vsplat_R(1)));
+ //Convert it for 16-bit representation
+ exp_recip = Q6_Vh_vadd_VhVh_sat(Q6_Vh_vsub_VhVh(exp_recip, Q6_Vh_vsplat_R(127)), Q6_Vh_vsplat_R(15));
+ exp_recip = Q6_Vh_vasl_VhR(exp_recip, 10);
+
+ //Merge exponent and mantissa for reciprocal
+ HVX_Vector recip = Q6_V_vor_VV(exp_recip, mant_recip);
+ // map 'small' inputs to standard largest value 0x7bff
+ recip = Q6_V_vmux_QVV(is_small, Q6_Vh_vsplat_R(0x7bff), recip);
+ // add sign back
+ recip = Q6_V_vandor_VQR(recip, is_neg, 0x80008000);
+ return recip;
+}
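+
+// Illustrative note (not part of the HVX path): steps (1)-(5) above amount to
+// the reference (plain C, <math.h>)
+//
+//   __fp16 inverse_f16_ref(__fp16 x) {
+//       float xf = (float) x;
+//       if (fabsf(xf) < 257.0f / 16777216.0f)          // fp16 codes below 0x101
+//           return (__fp16) copysignf(65504.0f, xf);   // largest normal fp16
+//       return (__fp16) (1.0f / xf);
+//   }
+//
+// with the mantissa reciprocal computed by the fitted polynomial rather than
+// a divide, so results can differ slightly in the low mantissa bits.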
+
+static inline HVX_Vector hvx_vec_inverse_f32(HVX_Vector v_sf) {
+ HVX_Vector inv_aprox_sf = Q6_V_vsplat_R(0x7EEEEBB3);
+ HVX_Vector two_sf = hvx_vec_splat_f32(2.0);
+
+ // First approximation
+ HVX_Vector i_sf = Q6_Vw_vsub_VwVw(inv_aprox_sf, v_sf);
+
+ HVX_Vector r_qf;
+
+ // Refine
+ r_qf = Q6_Vqf32_vmpy_VsfVsf(
+ i_sf, Q6_Vsf_equals_Vqf32(Q6_Vqf32_vsub_VsfVsf(two_sf, Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(i_sf, v_sf)))));
+ r_qf = Q6_Vqf32_vmpy_Vqf32Vqf32(
+ r_qf, Q6_Vqf32_vsub_VsfVsf(two_sf, Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(Q6_Vsf_equals_Vqf32(r_qf), v_sf))));
+ r_qf = Q6_Vqf32_vmpy_Vqf32Vqf32(
+ r_qf, Q6_Vqf32_vsub_VsfVsf(two_sf, Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(Q6_Vsf_equals_Vqf32(r_qf), v_sf))));
+
+ return Q6_Vsf_equals_Vqf32(r_qf);
+}
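+
+// Illustrative note (not part of the HVX path): 0x7EEEEBB3 is the usual
+// "magic constant" initial guess obtained by an integer subtract of the bit
+// pattern, and each refinement line above is one Newton-Raphson step. Scalar
+// sketch of the same scheme:
+//
+//   float inverse_ref(float x) {
+//       int32_t bits;  memcpy(&bits, &x, 4);
+//       bits = 0x7EEEEBB3 - bits;                            // initial guess of 1/x
+//       float y;  memcpy(&y, &bits, 4);
+//       for (int i = 0; i < 3; i++) y = y * (2.0f - x * y);  // Newton-Raphson
+//       return y;
+//   }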
+
+static inline HVX_Vector hvx_vec_inverse_f32_guard(HVX_Vector v_sf, HVX_Vector nan_inf_mask) {
+ HVX_Vector out = hvx_vec_inverse_f32(v_sf);
+
+ HVX_Vector masked_out = Q6_V_vand_VV(out, nan_inf_mask);
+ const HVX_VectorPred pred = Q6_Q_vcmp_eq_VwVw(nan_inf_mask, masked_out);
+
+ return Q6_V_vmux_QVV(pred, Q6_V_vzero(), out);
+}
+
+#define hvx_inverse_f32_loop_body(dst_type, src_type, vec_store) \
+ do { \
+ dst_type * restrict vdst = (dst_type *) dst; \
+ src_type * restrict vsrc = (src_type *) src; \
+ \
+ const HVX_Vector nan_inf_mask = Q6_V_vsplat_R(0x7f800000); \
+ \
+ const uint32_t nvec = n / VLEN_FP32; \
+ const uint32_t nloe = n % VLEN_FP32; \
+ \
+ uint32_t i = 0; \
+ \
+ _Pragma("unroll(4)") \
+ for (; i < nvec; i++) { \
+ vdst[i] = hvx_vec_inverse_f32_guard(vsrc[i], nan_inf_mask); \
+ } \
+ if (nloe) { \
+ HVX_Vector v = hvx_vec_inverse_f32_guard(vsrc[i], nan_inf_mask); \
+ vec_store((void *) &vdst[i], nloe * SIZEOF_FP32, v); \
+ } \
+ } while(0)
+
+static inline void hvx_inverse_f32_aa(uint8_t * restrict dst, const uint8_t * restrict src, uint32_t n) {
+ assert((unsigned long) dst % 128 == 0);
+ assert((unsigned long) src % 128 == 0);
+ hvx_inverse_f32_loop_body(HVX_Vector, HVX_Vector, hvx_vec_store_a);
+}
+
+static inline void hvx_inverse_f32_au(uint8_t * restrict dst, const uint8_t * restrict src, uint32_t n) {
+ assert((unsigned long) dst % 128 == 0);
+ hvx_inverse_f32_loop_body(HVX_Vector, HVX_UVector, hvx_vec_store_a);
+}
+
+static inline void hvx_inverse_f32_ua(uint8_t * restrict dst, const uint8_t * restrict src, uint32_t n) {
+ assert((unsigned long) src % 128 == 0);
+ hvx_inverse_f32_loop_body(HVX_UVector, HVX_Vector, hvx_vec_store_u);
+}
+
+static inline void hvx_inverse_f32_uu(uint8_t * restrict dst, const uint8_t * restrict src, uint32_t n) {
+ hvx_inverse_f32_loop_body(HVX_UVector, HVX_UVector, hvx_vec_store_u);
+}
+
+static inline void hvx_inverse_f32(uint8_t * restrict dst, uint8_t * restrict src, const int num_elems) {
+ if ((unsigned long) dst % 128 == 0) {
+ if ((unsigned long) src % 128 == 0) {
+ hvx_inverse_f32_aa(dst, src, num_elems);
+ } else {
+ hvx_inverse_f32_au(dst, src, num_elems);
+ }
+ } else {
+ if ((unsigned long) src % 128 == 0) {
+ hvx_inverse_f32_ua(dst, src, num_elems);
+ } else {
+ hvx_inverse_f32_uu(dst, src, num_elems);
+ }
+ }
+}
+
+#endif // HVX_INVERSE_H
diff --git a/llama.cpp/ggml/src/ggml-hexagon/htp/hvx-reduce.h b/llama.cpp/ggml/src/ggml-hexagon/htp/hvx-reduce.h
new file mode 100644
index 0000000..1ca7c05
--- /dev/null
+++ b/llama.cpp/ggml/src/ggml-hexagon/htp/hvx-reduce.h
@@ -0,0 +1,266 @@
+#ifndef HVX_REDUCE_H
+#define HVX_REDUCE_H
+
+#include <math.h>
+#include <stdbool.h>
+#include <stdint.h>
+#include <assert.h>
+
+#include "hex-utils.h"
+#include "hvx-base.h"
+#include "hvx-types.h"
+
+static inline HVX_Vector hvx_vec_reduce_sum_n_i32(HVX_Vector in, unsigned int n) {
+ unsigned int total = n * 4; // total vec nbytes
+ unsigned int width = 4; // int32
+
+ HVX_Vector sum = in, sum_t;
+ while (width < total) {
+ sum_t = Q6_V_vror_VR(sum, width); // rotate right
+ sum = Q6_Vw_vadd_VwVw(sum_t, sum); // elementwise sum
+ width = width << 1;
+ }
+ return sum;
+}
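+
+// Illustrative note (not part of the HVX path): the rotate-and-add loop above
+// is a log2(n)-step tree reduction; when it finishes, lane i holds the sum of
+// lanes i .. i+n-1 (mod 32), so for n == 32 every lane holds the full sum.
+// Scalar sketch of the same pattern:
+//
+//   void rotate_reduce_ref(int32_t v[32], unsigned n) {   // n is a power of two
+//       for (unsigned w = 1; w < n; w <<= 1) {
+//           int32_t t[32];
+//           for (unsigned i = 0; i < 32; i++) t[i] = v[i] + v[(i + w) % 32];
+//           for (unsigned i = 0; i < 32; i++) v[i] = t[i];
+//       }
+//   }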
+
+static inline HVX_Vector hvx_vec_reduce_sum_i32(HVX_Vector in) {
+ return hvx_vec_reduce_sum_n_i32(in, 32);
+}
+
+static inline HVX_Vector hvx_vec_reduce_sum_n_qf32(HVX_Vector in, unsigned int n) {
+ unsigned int total = n * 4; // total vec nbytes
+ unsigned int width = 4; // fp32 nbytes
+
+ HVX_Vector sum = in, sum_t;
+ while (width < total) {
+ sum_t = Q6_V_vror_VR(Q6_Vsf_equals_Vqf32(sum), width); // rotate right
+ sum = Q6_Vqf32_vadd_Vqf32Vsf(sum, sum_t); // elementwise sum
+ width = width << 1;
+ }
+ return sum;
+}
+
+static inline HVX_Vector hvx_vec_reduce_sum_qf32(HVX_Vector in) {
+ return hvx_vec_reduce_sum_n_qf32(in, 32);
+}
+
+#if __HVX_ARCH__ > 75
+
+static inline HVX_Vector hvx_vec_reduce_sum_f32x2(HVX_Vector in0, HVX_Vector in1) {
+ HVX_VectorPair sump = Q6_W_vshuff_VVR(in1, in0, 4);
+ HVX_Vector sum_sf = Q6_Vsf_vadd_VsfVsf(Q6_V_lo_W(sump), Q6_V_hi_W(sump));
+
+ sum_sf = Q6_Vsf_vadd_VsfVsf(sum_sf, Q6_V_vror_VR(sum_sf, VLEN / 2));
+ sum_sf = Q6_Vsf_vadd_VsfVsf(sum_sf, Q6_V_vror_VR(sum_sf, VLEN / 4));
+ sum_sf = Q6_Vsf_vadd_VsfVsf(sum_sf, Q6_V_vror_VR(sum_sf, VLEN / 8));
+ sum_sf = Q6_Vsf_vadd_VsfVsf(sum_sf, Q6_V_vror_VR(sum_sf, VLEN / 16));
+ return sum_sf;
+}
+
+static inline HVX_Vector hvx_vec_reduce_sum_n_f32(HVX_Vector in, unsigned int n) {
+ unsigned int total = n * 4; // total vec nbytes
+ unsigned int width = 4; // fp32 nbytes
+
+ HVX_Vector sum = in, sum_t;
+ while (width < total) {
+ sum_t = Q6_V_vror_VR(sum, width); // rotate right
+ sum = Q6_Vsf_vadd_VsfVsf(sum, sum_t); // elementwise sum
+ width = width << 1;
+ }
+ return sum;
+}
+
+#else
+
+static inline HVX_Vector hvx_vec_reduce_sum_f32x2(HVX_Vector in0, HVX_Vector in1) {
+ HVX_VectorPair sump = Q6_W_vshuff_VVR(in1, in0, 4);
+ HVX_Vector sum_qf = Q6_Vqf32_vadd_VsfVsf(Q6_V_lo_W(sump), Q6_V_hi_W(sump));
+
+ sum_qf = Q6_Vqf32_vadd_Vqf32Vsf(sum_qf, Q6_V_vror_VR(Q6_Vsf_equals_Vqf32(sum_qf), VLEN / 2));
+ sum_qf = Q6_Vqf32_vadd_Vqf32Vsf(sum_qf, Q6_V_vror_VR(Q6_Vsf_equals_Vqf32(sum_qf), VLEN / 4));
+ sum_qf = Q6_Vqf32_vadd_Vqf32Vsf(sum_qf, Q6_V_vror_VR(Q6_Vsf_equals_Vqf32(sum_qf), VLEN / 8));
+ sum_qf = Q6_Vqf32_vadd_Vqf32Vsf(sum_qf, Q6_V_vror_VR(Q6_Vsf_equals_Vqf32(sum_qf), VLEN / 16));
+ return Q6_Vsf_equals_Vqf32(sum_qf);
+}
+
+static inline HVX_Vector hvx_vec_reduce_sum_n_f32(HVX_Vector in, unsigned int n) {
+ unsigned int total = n * 4; // total vec nbytes
+ unsigned int width = 4; // fp32 nbytes
+
+ HVX_Vector sum = in, sum_t;
+ while (width < total) {
+ sum_t = Q6_V_vror_VR(sum, width); // rotate right
+ sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_VsfVsf(sum, sum_t)); // elementwise sum
+ width = width << 1;
+ }
+ return sum;
+}
+
+#endif
+
+static inline HVX_Vector hvx_vec_reduce_sum_f32(HVX_Vector in) {
+ return hvx_vec_reduce_sum_n_f32(in, 32);
+}
+
+static inline HVX_Vector hvx_vec_reduce_max_f16(HVX_Vector in) {
+ unsigned total = 128; // total vec nbytes
+ unsigned width = 2; // fp16 nbytes
+
+ HVX_Vector _max = in, _max_t;
+ while (width < total) {
+ _max_t = Q6_V_vror_VR(_max, width); // rotate right
+ _max = Q6_Vhf_vmax_VhfVhf(_max_t, _max); // elementwise max
+ width = width << 1;
+ }
+
+ return _max;
+}
+
+static inline HVX_Vector hvx_vec_reduce_max2_f16(HVX_Vector in, HVX_Vector _max) {
+ unsigned total = 128; // total vec nbytes
+    unsigned width = 2;   // fp16 nbytes
+
+ HVX_Vector _max_t;
+
+ _max = Q6_Vhf_vmax_VhfVhf(in, _max);
+ while (width < total) {
+ _max_t = Q6_V_vror_VR(_max, width); // rotate right
+ _max = Q6_Vhf_vmax_VhfVhf(_max_t, _max); // elementwise max
+ width = width << 1;
+ }
+
+ return _max;
+}
+
+static inline HVX_Vector hvx_vec_reduce_max_f32(HVX_Vector in) {
+ unsigned total = 128; // total vec nbytes
+ unsigned width = 4; // fp32 nbytes
+
+ HVX_Vector _max = in, _max_t;
+ while (width < total) {
+ _max_t = Q6_V_vror_VR(_max, width); // rotate right
+ _max = Q6_Vsf_vmax_VsfVsf(_max_t, _max); // elementwise max
+ width = width << 1;
+ }
+
+ return _max;
+}
+
+static inline HVX_Vector hvx_vec_reduce_max2_f32(HVX_Vector in, HVX_Vector _max) {
+ unsigned total = 128; // total vec nbytes
+ unsigned width = 4; // fp32 nbytes
+
+ HVX_Vector _max_t;
+
+ _max = Q6_Vsf_vmax_VsfVsf(in, _max);
+ while (width < total) {
+ _max_t = Q6_V_vror_VR(_max, width); // rotate right
+ _max = Q6_Vsf_vmax_VsfVsf(_max_t, _max); // elementwise max
+ width = width << 1;
+ }
+
+ return _max;
+}
+
+#define hvx_reduce_loop_body(src_type, init_vec, pad_vec, vec_op, reduce_op, scalar_reduce) \
+ do { \
+ src_type * restrict vsrc = (src_type *) src; \
+ HVX_Vector acc = init_vec; \
+ \
+ const uint32_t elem_size = sizeof(float); \
+ const uint32_t epv = 128 / elem_size; \
+ const uint32_t nvec = num_elems / epv; \
+ const uint32_t nloe = num_elems % epv; \
+ \
+ uint32_t i = 0; \
+ _Pragma("unroll(4)") \
+ for (; i < nvec; i++) { \
+ acc = vec_op(acc, vsrc[i]); \
+ } \
+ if (nloe) { \
+ const float * srcf = (const float *) src + i * epv; \
+ HVX_Vector in = *(HVX_UVector *) srcf; \
+ HVX_Vector temp = Q6_V_valign_VVR(in, pad_vec, nloe * elem_size); \
+ acc = vec_op(acc, temp); \
+ } \
+ HVX_Vector v = reduce_op(acc); \
+ return scalar_reduce(v); \
+ } while(0)
+
+#define HVX_REDUCE_MAX_OP(acc, val) Q6_Vsf_vmax_VsfVsf(acc, val)
+#define HVX_REDUCE_SUM_OP(acc, val) Q6_Vqf32_vadd_VsfVsf(Q6_Vsf_equals_Vqf32(acc), val)
+#define HVX_SUM_SQ_OP(acc, val) Q6_Vqf32_vadd_Vqf32Vqf32(acc, Q6_Vqf32_vmpy_VsfVsf(val, val))
+#define HVX_REDUCE_MAX_SCALAR(v) hvx_vec_get_f32(v)
+#define HVX_REDUCE_SUM_SCALAR(v) hvx_vec_get_f32(Q6_Vsf_equals_Vqf32(v))
+
+// Max variants
+
+static inline float hvx_reduce_max_f32_a(const uint8_t * restrict src, const int num_elems) {
+ HVX_Vector init_vec = hvx_vec_splat_f32(((const float *) src)[0]);
+ assert((unsigned long) src % 128 == 0);
+ hvx_reduce_loop_body(HVX_Vector, init_vec, init_vec, HVX_REDUCE_MAX_OP, hvx_vec_reduce_max_f32, HVX_REDUCE_MAX_SCALAR);
+}
+
+static inline float hvx_reduce_max_f32_u(const uint8_t * restrict src, const int num_elems) {
+ HVX_Vector init_vec = hvx_vec_splat_f32(((const float *) src)[0]);
+ hvx_reduce_loop_body(HVX_UVector, init_vec, init_vec, HVX_REDUCE_MAX_OP, hvx_vec_reduce_max_f32, HVX_REDUCE_MAX_SCALAR);
+}
+
+static inline float hvx_reduce_max_f32(const uint8_t * restrict src, const int num_elems) {
+ if (hex_is_aligned((void *) src, 128)) {
+ return hvx_reduce_max_f32_a(src, num_elems);
+ } else {
+ return hvx_reduce_max_f32_u(src, num_elems);
+ }
+}
+
+// Sum variants
+
+static inline float hvx_reduce_sum_f32_a(const uint8_t * restrict src, const int num_elems) {
+ HVX_Vector init_vec = Q6_V_vsplat_R(0);
+ assert((unsigned long) src % 128 == 0);
+ hvx_reduce_loop_body(HVX_Vector, init_vec, init_vec, HVX_REDUCE_SUM_OP, hvx_vec_reduce_sum_qf32, HVX_REDUCE_SUM_SCALAR);
+}
+
+static inline float hvx_reduce_sum_f32_u(const uint8_t * restrict src, const int num_elems) {
+ HVX_Vector init_vec = Q6_V_vsplat_R(0);
+ hvx_reduce_loop_body(HVX_UVector, init_vec, init_vec, HVX_REDUCE_SUM_OP, hvx_vec_reduce_sum_qf32, HVX_REDUCE_SUM_SCALAR);
+}
+
+static inline float hvx_reduce_sum_f32(const uint8_t * restrict src, const int num_elems) {
+ if (hex_is_aligned((void *) src, 128)) {
+ return hvx_reduce_sum_f32_a(src, num_elems);
+ } else {
+ return hvx_reduce_sum_f32_u(src, num_elems);
+ }
+}
+
+// Sum of squares variants
+
+static inline float hvx_sum_of_squares_f32_a(const uint8_t * restrict src, const int num_elems) {
+ HVX_Vector init_vec = Q6_V_vsplat_R(0);
+ assert((uintptr_t) src % 128 == 0);
+ hvx_reduce_loop_body(HVX_Vector, init_vec, init_vec, HVX_SUM_SQ_OP, hvx_vec_reduce_sum_qf32, HVX_REDUCE_SUM_SCALAR);
+}
+
+static inline float hvx_sum_of_squares_f32_u(const uint8_t * restrict src, const int num_elems) {
+ HVX_Vector init_vec = Q6_V_vsplat_R(0);
+ hvx_reduce_loop_body(HVX_UVector, init_vec, init_vec, HVX_SUM_SQ_OP, hvx_vec_reduce_sum_qf32, HVX_REDUCE_SUM_SCALAR);
+}
+
+static inline float hvx_sum_of_squares_f32(const uint8_t * restrict src, const int num_elems) {
+ if (hex_is_aligned((void *) src, 128)) {
+ return hvx_sum_of_squares_f32_a(src, num_elems);
+ } else {
+ return hvx_sum_of_squares_f32_u(src, num_elems);
+ }
+}
+
+#undef hvx_reduce_loop_body
+#undef HVX_REDUCE_MAX_OP
+#undef HVX_REDUCE_SUM_OP
+#undef HVX_REDUCE_MAX_SCALAR
+#undef HVX_REDUCE_SUM_SCALAR
+#undef HVX_SUM_SQ_OP
+
+#endif /* HVX_REDUCE_H */
diff --git a/llama.cpp/ggml/src/ggml-hexagon/htp/hvx-scale.h b/llama.cpp/ggml/src/ggml-hexagon/htp/hvx-scale.h
new file mode 100644
index 0000000..c65c986
--- /dev/null
+++ b/llama.cpp/ggml/src/ggml-hexagon/htp/hvx-scale.h
@@ -0,0 +1,133 @@
+#ifndef HVX_SCALE_H
+#define HVX_SCALE_H
+
+#include <assert.h>
+#include <stddef.h>
+#include <stdint.h>
+
+#include "hvx-base.h"
+
+#define hvx_scale_f32_loop_body(dst_type, src_type, vec_store) \
+ do { \
+ dst_type * restrict vdst = (dst_type *) dst; \
+ src_type * restrict vsrc = (src_type *) src; \
+ \
+ HVX_Vector vs = hvx_vec_splat_f32(scale); \
+ \
+ const uint32_t elem_size = sizeof(float); \
+ const uint32_t epv = 128 / elem_size; \
+ const uint32_t nvec = n / epv; \
+ const uint32_t nloe = n % epv; \
+ \
+ uint32_t i = 0; \
+ \
+ _Pragma("unroll(4)") \
+ for (; i < nvec; ++i) { \
+ HVX_Vector v = Q6_Vqf32_vmpy_VsfVsf(vsrc[i], vs); \
+ vdst[i] = Q6_Vsf_equals_Vqf32(v); \
+ } \
+ if (nloe) { \
+ HVX_Vector v = Q6_Vqf32_vmpy_VsfVsf(vsrc[i], vs); \
+ vec_store((void *) &vdst[i], nloe * elem_size, Q6_Vsf_equals_Vqf32(v)); \
+ } \
+ } while(0)
+
+static inline void hvx_scale_f32_aa(uint8_t * restrict dst, const uint8_t * restrict src, const int n, const float scale) {
+ assert((size_t) dst % 128 == 0);
+ assert((size_t) src % 128 == 0);
+ hvx_scale_f32_loop_body(HVX_Vector, HVX_Vector, hvx_vec_store_a);
+}
+
+static inline void hvx_scale_f32_au(uint8_t * restrict dst, const uint8_t * restrict src, const int n, const float scale) {
+ assert((size_t) dst % 128 == 0);
+ hvx_scale_f32_loop_body(HVX_Vector, HVX_UVector, hvx_vec_store_a);
+}
+
+static inline void hvx_scale_f32_ua(uint8_t * restrict dst, const uint8_t * restrict src, const int n, const float scale) {
+ assert((size_t) src % 128 == 0);
+ hvx_scale_f32_loop_body(HVX_UVector, HVX_Vector, hvx_vec_store_u);
+}
+
+static inline void hvx_scale_f32_uu(uint8_t * restrict dst, const uint8_t * restrict src, const int n, const float scale) {
+ hvx_scale_f32_loop_body(HVX_UVector, HVX_UVector, hvx_vec_store_u);
+}
+
+static inline void hvx_scale_f32(uint8_t * restrict dst, const uint8_t * restrict src, const int n, const float scale) {
+ if (((size_t) dst & 127) == 0) {
+ if (((size_t) src & 127) == 0) {
+ hvx_scale_f32_aa(dst, src, n, scale);
+ } else {
+ hvx_scale_f32_au(dst, src, n, scale);
+ }
+ } else {
+ if (((size_t) src & 127) == 0) {
+ hvx_scale_f32_ua(dst, src, n, scale);
+ } else {
+ hvx_scale_f32_uu(dst, src, n, scale);
+ }
+ }
+}
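+
+// Illustrative scalar equivalent (a minimal sketch, not used by the kernels above):
+// hvx_scale_f32() computes dst[i] = src[i] * scale over n floats. The _aa/_au/_ua/_uu
+// suffixes only encode whether dst/src are 128-byte aligned; the same convention is
+// used by the scale+offset variants below, which compute dst[i] = src[i] * scale + offset.
+static inline void hvx_ref_scale_f32(float * restrict dst, const float * restrict src, const int n, const float scale) {
+    for (int i = 0; i < n; i++) {
+        dst[i] = src[i] * scale;
+    }
+}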
+
+#define hvx_scale_offset_f32_loop_body(dst_type, src_type, vec_store) \
+ do { \
+ dst_type * restrict vdst = (dst_type *) dst; \
+ src_type * restrict vsrc = (src_type *) src; \
+ \
+ HVX_Vector vs = hvx_vec_splat_f32(scale); \
+ HVX_Vector vo = hvx_vec_splat_f32(offset); \
+ \
+ const uint32_t elem_size = sizeof(float); \
+ const uint32_t epv = 128 / elem_size; \
+ const uint32_t nvec = n / epv; \
+ const uint32_t nloe = n % epv; \
+ \
+ uint32_t i = 0; \
+ \
+ _Pragma("unroll(4)") \
+ for (; i < nvec; ++i) { \
+ HVX_Vector v = Q6_Vqf32_vadd_Vqf32Vsf(Q6_Vqf32_vmpy_VsfVsf(vsrc[i], vs), vo); \
+ vdst[i] = Q6_Vsf_equals_Vqf32(v); \
+ } \
+ if (nloe) { \
+ HVX_Vector v = Q6_Vqf32_vadd_Vqf32Vsf(Q6_Vqf32_vmpy_VsfVsf(vsrc[i], vs), vo); \
+ vec_store((void *) &vdst[i], nloe * elem_size, Q6_Vsf_equals_Vqf32(v)); \
+ } \
+ } while(0)
+
+static inline void hvx_scale_offset_f32_aa(uint8_t * restrict dst, const uint8_t * restrict src, const int n, const float scale, const float offset) {
+ assert((size_t) dst % 128 == 0);
+ assert((size_t) src % 128 == 0);
+ hvx_scale_offset_f32_loop_body(HVX_Vector, HVX_Vector, hvx_vec_store_a);
+}
+
+static inline void hvx_scale_offset_f32_au(uint8_t * restrict dst, const uint8_t * restrict src, const int n, const float scale, const float offset) {
+ assert((size_t) dst % 128 == 0);
+ hvx_scale_offset_f32_loop_body(HVX_Vector, HVX_UVector, hvx_vec_store_a);
+}
+
+static inline void hvx_scale_offset_f32_ua(uint8_t * restrict dst, const uint8_t * restrict src, const int n, const float scale, const float offset) {
+ assert((size_t) src % 128 == 0);
+ hvx_scale_offset_f32_loop_body(HVX_UVector, HVX_Vector, hvx_vec_store_u);
+}
+
+static inline void hvx_scale_offset_f32_uu(uint8_t * restrict dst, const uint8_t * restrict src, const int n, const float scale, const float offset) {
+ hvx_scale_offset_f32_loop_body(HVX_UVector, HVX_UVector, hvx_vec_store_u);
+}
+
+static inline void hvx_scale_offset_f32(uint8_t * restrict dst, const uint8_t * restrict src, const int n, const float scale, const float offset) {
+ if (((size_t) dst & 127) == 0) {
+ if (((size_t) src & 127) == 0) {
+ hvx_scale_offset_f32_aa(dst, src, n, scale, offset);
+ } else {
+ hvx_scale_offset_f32_au(dst, src, n, scale, offset);
+ }
+ } else {
+ if (((size_t) src & 127) == 0) {
+ hvx_scale_offset_f32_ua(dst, src, n, scale, offset);
+ } else {
+ hvx_scale_offset_f32_uu(dst, src, n, scale, offset);
+ }
+ }
+}
+
+#endif // HVX_SCALE_H
diff --git a/llama.cpp/ggml/src/ggml-hexagon/htp/hvx-sigmoid.h b/llama.cpp/ggml/src/ggml-hexagon/htp/hvx-sigmoid.h
new file mode 100644
index 0000000..0951932
--- /dev/null
+++ b/llama.cpp/ggml/src/ggml-hexagon/htp/hvx-sigmoid.h
@@ -0,0 +1,141 @@
+#ifndef HVX_SIGMOID_H
+#define HVX_SIGMOID_H
+
+#include "hvx-base.h"
+
+#define FAST_SIGMOID_LOG2F (0x3fb8aa3b) // 1.442695022
+#define FAST_SIGMOID_C1 (0x3d009076) // 0.03138777
+#define FAST_SIGMOID_C2 (0x3e8d74bd) // 0.276281267
+#define FAST_SIGMOID_C3 (0x3f000000) // 0.5
+
+static inline HVX_Vector hvx_vec_fast_sigmoid_f32(HVX_Vector v) {
+ v = Q6_Vqf32_vmpy_VsfVsf(v, Q6_V_vsplat_R(FAST_SIGMOID_LOG2F));
+ v = Q6_Vqf32_vmpy_VsfVsf(Q6_Vsf_equals_Vqf32(v), Q6_V_vsplat_R(FAST_SIGMOID_C3));
+
+ HVX_Vector in_int = hvx_vec_truncate_f32(Q6_Vsf_equals_Vqf32(v));
+ HVX_Vector x = Q6_Vqf32_vsub_Vqf32Vsf(v, Q6_Vsf_equals_Vw(in_int));
+ HVX_Vector xx = Q6_Vqf32_vmpy_Vqf32Vqf32(x, x);
+
+ HVX_Vector v1 = Q6_Vqf32_vmpy_VsfVsf(Q6_Vsf_equals_Vqf32(xx), Q6_V_vsplat_R(FAST_SIGMOID_C2));
+ v1 = Q6_Vqf32_vadd_Vqf32Vsf(v1, Q6_V_vsplat_R(FAST_SIGMOID_LOG2F));
+
+ HVX_Vector v2 = Q6_Vqf32_vmpy_VsfVsf(Q6_Vsf_equals_Vqf32(x), Q6_V_vsplat_R(FAST_SIGMOID_C1));
+ v2 = Q6_Vqf32_vmpy_Vqf32Vqf32(v2, xx);
+ v2 = Q6_Vqf32_vadd_Vqf32Vqf32(v2, x);
+
+ HVX_Vector v3 = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vqf32(v2, v1));
+ HVX_Vector v3_exponent = Q6_Vw_vasl_VwR(v3, 1);
+ v3_exponent = Q6_Vuw_vlsr_VuwR(v3_exponent, 24);
+ v3_exponent = Q6_Vw_vadd_VwVw(in_int, v3_exponent);
+ v3 = Q6_Vw_vaslacc_VwVwR(v3, in_int, 24);
+
+ HVX_Vector v4 = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vsub_Vqf32Vqf32(v2, v1));
+ HVX_Vector v5 = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vsub_VsfVsf(v3, v4));
+
+ HVX_Vector res = hvx_vec_inverse_f32(v5);
+ res = Q6_Vqf32_vmpy_VsfVsf(v3, res);
+
+ return Q6_Vsf_equals_Vqf32(res);
+}
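+
+// What the routine above approximates, in scalar terms (a sketch):
+//
+//   sigmoid(x) = 1 / (1 + exp(-x)) = exp(x/2) / (exp(x/2) + exp(-x/2))
+//
+// With t = x * log2(e) / 2 the two exponentials become 2^t and 2^-t, so the kernel
+// splits t into an integer part (folded into the fp32 exponent field with integer
+// shifts) and a fractional part (covered by the small FAST_SIGMOID_* polynomial),
+// then finishes with one vector reciprocal instead of calling expf per element.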
+
+static inline HVX_Vector hvx_vec_fast_sigmoid_f32_guard(HVX_Vector v,
+ HVX_Vector one,
+ HVX_Vector max_exp,
+ HVX_Vector min_exp) {
+ const HVX_VectorPred pred_max = Q6_Q_vcmp_gt_VsfVsf(max_exp, v);
+ const HVX_VectorPred pred_min = Q6_Q_vcmp_gt_VsfVsf(v, min_exp);
+
+ HVX_Vector out = hvx_vec_fast_sigmoid_f32(v);
+ out = Q6_V_vmux_QVV(pred_max, out, one);
+ return Q6_V_vmux_QVV(pred_min, out, Q6_V_vzero());
+}
+
+static inline HVX_Vector hvx_vec_tanh_f32(HVX_Vector x) {
+ // tanh(x) = 2 * sigmoid(2x) - 1
+ HVX_Vector two = hvx_vec_splat_f32(2.0f);
+ HVX_Vector one = hvx_vec_splat_f32(1.0f);
+ HVX_Vector x2 = Q6_Vqf32_vmpy_VsfVsf(x, two);
+
+ HVX_Vector max_exp = hvx_vec_splat_f32(87.f);
+ HVX_Vector min_exp = hvx_vec_splat_f32(-87.f);
+
+ HVX_Vector sig2x = hvx_vec_fast_sigmoid_f32_guard(Q6_Vsf_equals_Vqf32(x2), one, max_exp, min_exp);
+
+ HVX_Vector res = Q6_Vqf32_vmpy_VsfVsf(sig2x, two);
+ res = Q6_Vqf32_vsub_Vqf32Vsf(res, one);
+ return Q6_Vsf_equals_Vqf32(res);
+}
+
+#define hvx_sigmoid_loop_body(dst_type, src_type, vec_store) \
+ do { \
+ dst_type * restrict vdst = (dst_type *) dst; \
+ src_type * restrict vsrc = (src_type *) src; \
+ \
+ const HVX_Vector one = hvx_vec_splat_f32(1.f); \
+ const HVX_Vector max_exp = hvx_vec_splat_f32(87.f); \
+ const HVX_Vector min_exp = hvx_vec_splat_f32(-87.f); \
+ \
+ const uint32_t epv = 128 / sizeof(float); \
+ const uint32_t nvec = n / epv; \
+ const uint32_t nloe = n % epv; \
+ \
+ uint32_t i = 0; \
+ \
+ _Pragma("unroll(4)") \
+ for (; i < nvec; i++) { \
+ vdst[i] = hvx_vec_fast_sigmoid_f32_guard(vsrc[i], one, max_exp, min_exp); \
+ } \
+ if (nloe) { \
+ HVX_Vector tmp = hvx_vec_fast_sigmoid_f32_guard(vsrc[i], one, max_exp, min_exp); \
+ vec_store((void *) &vdst[i], nloe * sizeof(float), tmp); \
+ } \
+ } while(0)
+
+#define hvx_tanh_loop_body(dst_type, src_type, vec_store) \
+ do { \
+ dst_type * restrict vdst = (dst_type *) dst; \
+ src_type * restrict vsrc = (src_type *) src; \
+ \
+ const uint32_t epv = 128 / sizeof(float); \
+ const uint32_t nvec = n / epv; \
+ const uint32_t nloe = n % epv; \
+ \
+ uint32_t i = 0; \
+ \
+ _Pragma("unroll(4)") \
+ for (; i < nvec; i++) { \
+ vdst[i] = hvx_vec_tanh_f32(vsrc[i]); \
+ } \
+ if (nloe) { \
+ HVX_Vector tmp = hvx_vec_tanh_f32(vsrc[i]); \
+ vec_store((void *) &vdst[i], nloe * sizeof(float), tmp); \
+ } \
+ } while(0)
+
+static inline void hvx_sigmoid_f32_aa(uint8_t * restrict dst, const uint8_t * restrict src, uint32_t n) {
+ assert((unsigned long) dst % 128 == 0);
+ assert((unsigned long) src % 128 == 0);
+ hvx_sigmoid_loop_body(HVX_Vector, HVX_Vector, hvx_vec_store_a);
+}
+
+static inline void hvx_sigmoid_f32_au(uint8_t * restrict dst, const uint8_t * restrict src, uint32_t n) {
+ assert((unsigned long) dst % 128 == 0);
+ hvx_sigmoid_loop_body(HVX_Vector, HVX_UVector, hvx_vec_store_a);
+}
+
+static inline void hvx_sigmoid_f32_ua(uint8_t * restrict dst, const uint8_t * restrict src, uint32_t n) {
+ assert((unsigned long) src % 128 == 0);
+ hvx_sigmoid_loop_body(HVX_UVector, HVX_Vector, hvx_vec_store_u);
+}
+
+static inline void hvx_sigmoid_f32_uu(uint8_t * restrict dst, const uint8_t * restrict src, uint32_t n) {
+ hvx_sigmoid_loop_body(HVX_UVector, HVX_UVector, hvx_vec_store_u);
+}
+
+static inline void hvx_tanh_f32_aa(uint8_t * restrict dst, const uint8_t * restrict src, uint32_t n) {
+ assert((unsigned long) dst % 128 == 0);
+ assert((unsigned long) src % 128 == 0);
+ hvx_tanh_loop_body(HVX_Vector, HVX_Vector, hvx_vec_store_a);
+}
+
+#endif /* HVX_SIGMOID_H */
diff --git a/llama.cpp/ggml/src/ggml-hexagon/htp/hvx-sqrt.h b/llama.cpp/ggml/src/ggml-hexagon/htp/hvx-sqrt.h
new file mode 100644
index 0000000..e31a100
--- /dev/null
+++ b/llama.cpp/ggml/src/ggml-hexagon/htp/hvx-sqrt.h
@@ -0,0 +1,126 @@
+#ifndef HVX_SQRT_H
+#define HVX_SQRT_H
+
+#include <stdbool.h>
+#include <stdint.h>
+
+#include "hex-utils.h"
+
+#include "hvx-base.h"
+
+#define RSQRT_CONST 0x5f3759df // Constant for fast inverse square root calculation
+#define RSQRT_ONE_HALF 0x3f000000 // 0.5
+#define RSQRT_THREE_HALVES 0x3fc00000 // 1.5
+
+#if __HVX_ARCH__ < 79
+#define HVX_OP_MUL(a, b) Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(a, b))
+#else
+#define HVX_OP_MUL(a, b) Q6_Vsf_vmpy_VsfVsf(a, b)
+#endif
+
+static inline HVX_Vector hvx_vec_rsqrt_f32(HVX_Vector in_vec) {
+    // Algorithm:
+ // x2 = input*0.5
+ // y = * (long *) &input
+ // y = 0x5f3759df - (y>>1)
+ // y = y*(threehalfs - x2*y*y)
+
+ HVX_Vector rsqrtconst = Q6_V_vsplat_R(RSQRT_CONST);
+ HVX_Vector onehalf = Q6_V_vsplat_R(RSQRT_ONE_HALF);
+ HVX_Vector threehalfs = Q6_V_vsplat_R(RSQRT_THREE_HALVES);
+
+ HVX_Vector x2, y, ypower2, temp;
+
+ x2 = Q6_Vqf32_vmpy_VsfVsf(in_vec, onehalf);
+ x2 = Q6_Vqf32_vadd_Vqf32Vsf(x2, Q6_V_vzero());
+
+ y = Q6_Vw_vasr_VwR(in_vec, 1);
+ y = Q6_Vw_vsub_VwVw(rsqrtconst, y);
+
+ // 1st iteration
+ ypower2 = Q6_Vqf32_vmpy_VsfVsf(y, y);
+ ypower2 = Q6_Vqf32_vadd_Vqf32Vsf(ypower2, Q6_V_vzero());
+ temp = Q6_Vqf32_vmpy_Vqf32Vqf32(x2, ypower2);
+ temp = Q6_Vqf32_vsub_VsfVsf(threehalfs, Q6_Vsf_equals_Vqf32(temp));
+ temp = Q6_Vqf32_vmpy_VsfVsf(y, Q6_Vsf_equals_Vqf32(temp));
+
+ // 2nd iteration
+ y = Q6_Vqf32_vadd_Vqf32Vsf(temp, Q6_V_vzero());
+ ypower2 = Q6_Vqf32_vmpy_Vqf32Vqf32(y, y);
+ ypower2 = Q6_Vqf32_vadd_Vqf32Vsf(ypower2, Q6_V_vzero());
+ temp = Q6_Vqf32_vmpy_Vqf32Vqf32(x2, ypower2);
+ temp = Q6_Vqf32_vsub_VsfVsf(threehalfs, Q6_Vsf_equals_Vqf32(temp));
+ temp = Q6_Vqf32_vmpy_Vqf32Vqf32(y, temp);
+
+ // 3rd iteration
+ y = Q6_Vqf32_vadd_Vqf32Vsf(temp, Q6_V_vzero());
+ ypower2 = Q6_Vqf32_vmpy_Vqf32Vqf32(y, y);
+ ypower2 = Q6_Vqf32_vadd_Vqf32Vsf(ypower2, Q6_V_vzero());
+ temp = Q6_Vqf32_vmpy_Vqf32Vqf32(x2, ypower2);
+ temp = Q6_Vqf32_vsub_VsfVsf(threehalfs, Q6_Vsf_equals_Vqf32(temp));
+ temp = Q6_Vqf32_vmpy_Vqf32Vqf32(y, temp);
+
+ return Q6_Vsf_equals_Vqf32(temp);
+}
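+
+// Scalar reference for the same trick (a minimal sketch, not used by the HVX code):
+// magic-constant initial guess followed by Newton-Raphson steps
+// y' = y * (1.5 - 0.5 * x * y * y); the vector version above runs three steps.
+static inline float hvx_ref_rsqrt_f32(float x) {
+    union { float f; uint32_t u; } v = { .f = x };
+    v.u = RSQRT_CONST - (v.u >> 1); // initial guess from the bit pattern
+    float y = v.f;
+    const float x2 = 0.5f * x;
+    y = y * (1.5f - x2 * y * y); // Newton-Raphson step 1
+    y = y * (1.5f - x2 * y * y); // step 2
+    y = y * (1.5f - x2 * y * y); // step 3
+    return y;
+}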
+
+// Compute sqrt(x) as x*inv_sqrt(x)
+#define hvx_sqrt_f32_loop_body(dst_type, src_type, vec_store) \
+ do { \
+ dst_type * restrict vdst = (dst_type *) dst; \
+ src_type * restrict vsrc = (src_type *) src; \
+ \
+ const uint32_t nvec = n / VLEN_FP32; \
+ const uint32_t nloe = n % VLEN_FP32; \
+ \
+ uint32_t i = 0; \
+ \
+ _Pragma("unroll(4)") \
+ for (; i < nvec; i++) { \
+ HVX_Vector inv_sqrt = hvx_vec_rsqrt_f32(vsrc[i]); \
+ HVX_Vector sqrt_res = HVX_OP_MUL(inv_sqrt, vsrc[i]); \
+ vdst[i] = sqrt_res; \
+ } \
+ if (nloe) { \
+ HVX_Vector inv_sqrt = hvx_vec_rsqrt_f32(vsrc[i]); \
+ HVX_Vector sqrt_res = HVX_OP_MUL(inv_sqrt, vsrc[i]); \
+ vec_store((void *) &vdst[i], nloe * SIZEOF_FP32, sqrt_res); \
+ } \
+ } while(0)
+
+static inline void hvx_sqrt_f32_aa(uint8_t * restrict dst, const uint8_t * restrict src, uint32_t n) {
+ assert((unsigned long) dst % 128 == 0);
+ assert((unsigned long) src % 128 == 0);
+ hvx_sqrt_f32_loop_body(HVX_Vector, HVX_Vector, hvx_vec_store_a);
+}
+
+static inline void hvx_sqrt_f32_au(uint8_t * restrict dst, const uint8_t * restrict src, uint32_t n) {
+ assert((unsigned long) dst % 128 == 0);
+ hvx_sqrt_f32_loop_body(HVX_Vector, HVX_UVector, hvx_vec_store_a);
+}
+
+static inline void hvx_sqrt_f32_ua(uint8_t * restrict dst, const uint8_t * restrict src, uint32_t n) {
+ assert((unsigned long) src % 128 == 0);
+ hvx_sqrt_f32_loop_body(HVX_UVector, HVX_Vector, hvx_vec_store_u);
+}
+
+static inline void hvx_sqrt_f32_uu(uint8_t * restrict dst, const uint8_t * restrict src, uint32_t n) {
+ hvx_sqrt_f32_loop_body(HVX_UVector, HVX_UVector, hvx_vec_store_u);
+}
+
+static inline void hvx_sqrt_f32(uint8_t * restrict dst, const uint8_t * restrict src, const int num_elems) {
+ if ((unsigned long) dst % 128 == 0) {
+ if ((unsigned long) src % 128 == 0) {
+ hvx_sqrt_f32_aa(dst, src, num_elems);
+ } else {
+ hvx_sqrt_f32_au(dst, src, num_elems);
+ }
+ } else {
+ if ((unsigned long) src % 128 == 0) {
+ hvx_sqrt_f32_ua(dst, src, num_elems);
+ } else {
+ hvx_sqrt_f32_uu(dst, src, num_elems);
+ }
+ }
+}
+
+#endif /* HVX_SQRT_H */
diff --git a/llama.cpp/ggml/src/ggml-hexagon/htp/hvx-types.h b/llama.cpp/ggml/src/ggml-hexagon/htp/hvx-types.h
new file mode 100644
index 0000000..d495a59
--- /dev/null
+++ b/llama.cpp/ggml/src/ggml-hexagon/htp/hvx-types.h
@@ -0,0 +1,36 @@
+#ifndef HVX_TYPES_H
+#define HVX_TYPES_H
+
+#include <stdbool.h>
+#include <stdint.h>
+
+#include <hexagon_types.h>
+
+#define SIZEOF_FP32 (4)
+#define SIZEOF_FP16 (2)
+#define VLEN (128)
+#define VLEN_FP32 (VLEN / SIZEOF_FP32)
+#define VLEN_FP16 (VLEN / SIZEOF_FP16)
+
+typedef union {
+ HVX_Vector v;
+ uint8_t b[VLEN];
+ uint16_t h[VLEN_FP16];
+ uint32_t w[VLEN_FP32];
+ __fp16 fp16[VLEN_FP16];
+ float fp32[VLEN_FP32];
+} __attribute__((aligned(VLEN), packed)) HVX_VectorAlias;
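+
+// Illustrative use of the alias union (a sketch): it gives scalar code lane-wise
+// access to a vector, e.g. for debugging or validating results.
+//
+//   HVX_VectorAlias u;
+//   u.v = some_vector;
+//   float lane0 = u.fp32[0]; // first fp32 lane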
+
+typedef struct {
+ HVX_Vector v[2];
+} HVX_Vector_x2;
+
+typedef struct {
+ HVX_Vector v[4];
+} HVX_Vector_x4;
+
+typedef struct {
+ HVX_Vector v[8];
+} HVX_Vector_x8;
+
+#endif /* HVX_TYPES_H */
diff --git a/llama.cpp/ggml/src/ggml-hexagon/htp/hvx-utils.h b/llama.cpp/ggml/src/ggml-hexagon/htp/hvx-utils.h
new file mode 100644
index 0000000..a518ad3
--- /dev/null
+++ b/llama.cpp/ggml/src/ggml-hexagon/htp/hvx-utils.h
@@ -0,0 +1,18 @@
+#ifndef HVX_UTILS_H
+#define HVX_UTILS_H
+
+#include "hex-utils.h"
+
+#include "hvx-types.h"
+#include "hvx-copy.h"
+#include "hvx-scale.h"
+#include "hvx-exp.h"
+#include "hvx-inverse.h"
+#include "hvx-reduce.h"
+#include "hvx-sigmoid.h"
+#include "hvx-sqrt.h"
+#include "hvx-arith.h"
+#include "hvx-div.h"
+#include "hvx-base.h"
+
+#endif /* HVX_UTILS_H */
diff --git a/llama.cpp/ggml/src/ggml-hexagon/htp/main.c b/llama.cpp/ggml/src/ggml-hexagon/htp/main.c
new file mode 100644
index 0000000..62708ee
--- /dev/null
+++ b/llama.cpp/ggml/src/ggml-hexagon/htp/main.c
@@ -0,0 +1,1150 @@
+#pragma clang diagnostic ignored "-Wgnu-zero-variadic-macro-arguments"
+#pragma clang diagnostic ignored "-Wunused-function"
+
+#include <HAP_farf.h>
+#include <HAP_perf.h>
+#include <AEEStdErr.h>
+#include <dspqueue.h>
+#include <HAP_compute_res.h>
+#include <HAP_etm_config.h>
+#include <HAP_mem.h>
+#include <HAP_power.h>
+#include <HAP_ps.h>
+#include <qurt.h>
+#include <qurt_thread.h>
+#include <remote.h>
+#include <string.h>
+
+#include "hex-dma.h"
+#include "hex-utils.h"
+
+#define GGML_COMMON_DECL_C
+#include "ggml-common.h"
+#include "htp-ctx.h"
+#include "htp-msg.h"
+#include "htp-ops.h"
+#include "worker-pool.h"
+
+AEEResult htp_iface_open(const char * uri, remote_handle64 * handle) {
+ struct htp_context * ctx;
+ int err = 0;
+
+ ctx = calloc(1, sizeof(*ctx));
+ if (ctx == NULL) {
+ return AEE_ENOMEMORY;
+ }
+
+ // Use the context structure as a handle
+ *handle = (remote_handle64) ctx;
+
+ // Enable FARF logs
+ HAP_setFARFRuntimeLoggingParams(0xffff, NULL, 0);
+
+ // Set client class
+ {
+ HAP_power_request_t request;
+ memset(&request, 0, sizeof(HAP_power_request_t));
+ request.type = HAP_power_set_apptype;
+ request.apptype = HAP_POWER_COMPUTE_CLIENT_CLASS;
+
+ if ((err = HAP_power_set((void *) ctx, &request)) != 0) {
+ return err;
+ }
+ }
+
+ {
+ HAP_power_request_t request;
+ memset(&request, 0, sizeof(request));
+
+ request.type = HAP_power_set_DCVS_v3;
+ request.dcvs_v3.set_dcvs_enable = TRUE;
+ request.dcvs_v3.dcvs_enable = TRUE;
+ request.dcvs_v3.dcvs_option = HAP_DCVS_V2_PERFORMANCE_MODE;
+ request.dcvs_v3.set_bus_params = TRUE;
+ request.dcvs_v3.bus_params.min_corner = HAP_DCVS_VCORNER_MAX;
+ request.dcvs_v3.bus_params.max_corner = HAP_DCVS_VCORNER_MAX;
+ request.dcvs_v3.bus_params.target_corner = HAP_DCVS_VCORNER_MAX;
+ request.dcvs_v3.set_core_params = TRUE;
+ request.dcvs_v3.core_params.min_corner = HAP_DCVS_VCORNER_MAX;
+ request.dcvs_v3.core_params.max_corner = HAP_DCVS_VCORNER_MAX;
+ request.dcvs_v3.core_params.target_corner = HAP_DCVS_VCORNER_MAX;
+ request.dcvs_v3.set_sleep_disable = TRUE;
+ request.dcvs_v3.sleep_disable = TRUE;
+ if ((err = HAP_power_set((void *) ctx, &request)) != 0) {
+ return err;
+ }
+
+ memset(&request, 0, sizeof(request));
+ request.type = HAP_power_set_HVX;
+ request.hvx.power_up = TRUE;
+ if ((err = HAP_power_set((void *) ctx, &request)) != 0) {
+ return err;
+ }
+ }
+
+ {
+ // Power on HMX
+ HAP_power_request_t request;
+ memset(&request, 0, sizeof(HAP_power_request_t));
+ request.type = HAP_power_set_HMX;
+ request.hmx.power_up = TRUE;
+ FARF(ALWAYS, "Powering HMX on\n");
+        err = HAP_power_set((void *) ctx, &request);
+ if (err != AEE_SUCCESS) {
+ FARF(ERROR, "Error powering on HMX.");
+ return err;
+ }
+ }
+
+ return AEE_SUCCESS;
+}
+
+AEEResult htp_iface_close(remote_handle64 handle) {
+ struct htp_context * ctx = (struct htp_context *) handle;
+
+ if (!ctx) {
+ return AEE_EBADPARM;
+ }
+
+ if (ctx->queue) {
+ FARF(ERROR, "Closing handle with queue still open");
+ return AEE_EITEMBUSY;
+ }
+
+ free(ctx);
+ return AEE_SUCCESS;
+}
+
+AEEResult htp_iface_enable_etm(remote_handle64 handle) {
+ int err = HAP_user_etm_enable();
+ if (err) {
+ if (err == AEE_EVERSIONNOTSUPPORT) {
+ FARF(ERROR, "API HAP_user_etm_enable is not supported\n");
+ } else {
+ FARF(ERROR, "Error executing HAP_user_etm_enable with error code : 0x%x\n", err);
+ }
+ }
+ return err;
+}
+
+AEEResult htp_iface_disable_etm(remote_handle64 handle) {
+ int err = HAP_user_etm_disable();
+ if (err) {
+ if (err == AEE_EVERSIONNOTSUPPORT) {
+ FARF(ERROR, "API HAP_user_etm_disable is not supported\n");
+ } else {
+ FARF(ERROR, "Error executing HAP_user_etm_disable with error code : 0x%x\n", err);
+ }
+ }
+ return err;
+}
+
+static int vtcm_acquire(struct htp_context * ctx) {
+ int err;
+ if (!ctx->vtcm_valid) {
+ // Temporarily bump thread priority to make sure it's higher than other sessions.
+ // This way the resource manager will notify the other thread to release VTCM.
+        // Note that we need to reacquire VTCM at normal priority for this to work next time.
+ qurt_thread_set_priority(qurt_thread_get_id(), ctx->thread_prio - 10);
+ err = HAP_compute_res_acquire_cached(ctx->vtcm_rctx, 1000000);
+ if (err != 0) {
+ FARF(ERROR, "Failed to acquire VTCM: 0x%08x", (unsigned)err);
+ abort();
+ }
+ HAP_compute_res_release_cached(ctx->vtcm_rctx);
+ qurt_thread_set_priority(qurt_thread_get_id(), ctx->thread_prio);
+
+ err = HAP_compute_res_acquire_cached(ctx->vtcm_rctx, 1000000);
+ if (err != 0) {
+ FARF(ERROR, "Failed to acquire VTCM: 0x%08x", (unsigned)err);
+ abort();
+ }
+ ctx->vtcm_valid = true;
+ }
+
+ ctx->vtcm_inuse = true;
+ return 0;
+}
+
+static int vtcm_release(struct htp_context * ctx) {
+ ctx->vtcm_inuse = false;
+
+ if (ctx->vtcm_valid && ctx->vtcm_needs_release) {
+ ctx->vtcm_valid = false;
+ ctx->vtcm_needs_release = false;
+ HAP_compute_res_release_cached(ctx->vtcm_rctx);
+ }
+
+ return 0;
+}
+
+static int vtcm_release_callback(unsigned int rctx, void * state) {
+ struct htp_context * ctx = (struct htp_context *) state;
+
+ if (!ctx || ctx->vtcm_rctx != rctx) {
+ return AEE_EBADPARM;
+ }
+
+    // If VTCM is not in use (not processing Ops), release it right here;
+    // otherwise mark it so we release it once we're done with the current Op.
+
+    if (ctx->vtcm_inuse) {
+        ctx->vtcm_needs_release = true;
+        return 0;
+    }
+
+ ctx->vtcm_valid = false;
+ HAP_compute_res_release_cached(ctx->vtcm_rctx);
+
+ return 0;
+}
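+
+// Typical usage from an op handler (a sketch; the real callers are the proc_*_req
+// functions below): vtcm_acquire() marks VTCM as in use and re-acquires it if another
+// session grabbed it, vtcm_release() drops the in-use flag and honors a release
+// request that arrived while the op was running.
+//
+//   if (vtcm_acquire(ctx) == AEE_SUCCESS) {
+//       /* ... run op using ctx->vtcm_base / ctx->vtcm_size ... */
+//       vtcm_release(ctx);
+//   }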
+
+static int vtcm_alloc(struct htp_context * ctx) {
+ unsigned int vtcm_size = 8 * 1024 * 1024; // 8MB default
+ HAP_compute_res_query_VTCM(0, &vtcm_size, NULL, NULL, NULL);
+
+ compute_res_attr_t attr;
+ HAP_compute_res_attr_init(&attr);
+ HAP_compute_res_attr_set_serialize(&attr, 0);
+ HAP_compute_res_attr_set_cache_mode(&attr, 1);
+ HAP_compute_res_attr_set_vtcm_param_v2(&attr, vtcm_size, 0, vtcm_size);
+ HAP_compute_res_attr_set_release_callback(&attr, vtcm_release_callback, (void *) ctx);
+ HAP_compute_res_attr_set_hmx_param(&attr, 1);
+
+ // Allocate VTCM for scratch pads
+ uint32_t rctx = HAP_compute_res_acquire(&attr, 1000000 /* timeout */);
+ if (!rctx) {
+ FARF(ERROR, "failed to allocate %zu bytes VTCM\n", ctx->vtcm_size);
+ return AEE_ENOMEMORY;
+ }
+
+ void * vtcm_ptr;
+ if (HAP_compute_res_attr_get_vtcm_ptr_v2(&attr, &vtcm_ptr, &vtcm_size) != 0) {
+ HAP_compute_res_release(rctx);
+ FARF(ERROR, "failed to allocate %zu bytes VTCM (new)\n", ctx->vtcm_size);
+ return AEE_ENOMEMORY;
+ }
+
+ ctx->vtcm_base = (uint8_t *) vtcm_ptr;
+ ctx->vtcm_size = vtcm_size;
+ ctx->vtcm_rctx = rctx;
+ ctx->vtcm_valid = false;
+ ctx->vtcm_inuse = false;
+ ctx->vtcm_needs_release = false;
+
+ return 0;
+}
+
+static void vtcm_free(struct htp_context * ctx) {
+ if (ctx->vtcm_rctx) {
+ HAP_compute_res_release(ctx->vtcm_rctx);
+ ctx->vtcm_base = 0;
+ ctx->vtcm_rctx = 0;
+ }
+}
+
+static void htp_packet_callback(dspqueue_t queue, int error, void * context);
+static void htp_error_callback(dspqueue_t queue, int error, void * context);
+
+AEEResult htp_iface_start(remote_handle64 handle, uint32 sess_id, uint64 dsp_queue_id, uint32 n_hvx) {
+ struct htp_context * ctx = (struct htp_context *) handle;
+
+ if (!ctx) {
+ return AEE_EBADPARM;
+ }
+
+ if (ctx->queue) {
+ FARF(ERROR, "Queue already open");
+ return AEE_EITEMBUSY;
+ }
+
+ // Import queue created on the CPU
+ int err = dspqueue_import(dsp_queue_id, // Queue ID from dspqueue_export
+ htp_packet_callback, // Packet callback
+ htp_error_callback, // Error callback; no errors expected on the DSP
+ (void *) ctx, // Callback context
+ &ctx->queue);
+
+ if (err) {
+ FARF(ERROR, "Queue import failed with 0x%08x", (unsigned) err);
+ return err;
+ }
+
+ ctx->thread_id = qurt_thread_get_id();
+ ctx->thread_prio = qurt_thread_get_priority(ctx->thread_id);
+
+ // allocate VTCM
+ err = vtcm_alloc(ctx);
+ if (err != AEE_SUCCESS) {
+ FARF(ERROR, "Unable to allocate VTCM");
+ return AEE_ENOMEMORY;
+ }
+
+ qurt_sysenv_max_hthreads_t hw_threads;
+ qurt_sysenv_get_max_hw_threads(&hw_threads);
+ uint32_t hw_nhvx = (qurt_hvx_get_units() >> 8) & 0xFF;
+
+ if (n_hvx == 0) {
+ n_hvx = hw_nhvx;
+ }
+ if (n_hvx > hw_threads.max_hthreads) {
+ n_hvx = hw_threads.max_hthreads;
+ }
+ if (n_hvx > HTP_MAX_NTHREADS) {
+ n_hvx = HTP_MAX_NTHREADS;
+ }
+
+ ctx->n_threads = n_hvx;
+ for (int i = 0; i < ctx->n_threads; i++) {
+ // see discussion https://github.com/ggml-org/llama.cpp/pull/18151#discussion_r2632388541
+ ctx->dma[i] = dma_queue_create(64);
+ }
+
+ // init worker pool
+ err = worker_pool_init(&ctx->worker_pool, n_hvx);
+ if (err != AEE_SUCCESS) {
+ FARF(ERROR, "Unable to create worker pool");
+ return err;
+ }
+
+ FARF(HIGH, "session %u started: n-hvx %u vtcm-size %zu vtcm-rctx %u n-threads %u thread-id %d thread-prio %d \n",
+ sess_id, hw_nhvx, ctx->vtcm_size, ctx->vtcm_rctx, ctx->n_threads, ctx->thread_id, ctx->thread_prio);
+
+ return AEE_SUCCESS;
+}
+
+AEEResult htp_iface_stop(remote_handle64 handle) {
+ struct htp_context * ctx = (struct htp_context *) handle;
+ if (!ctx) {
+ return AEE_EBADPARM;
+ }
+
+ if (!ctx->queue) {
+ FARF(ERROR, "Queue not open");
+ return AEE_EBADSTATE;
+ }
+
+ // Close queue. dspqueue_close() will also wait for callbacks to finish.
+ int err = dspqueue_close(ctx->queue);
+ ctx->queue = NULL;
+ if (err != 0) {
+ FARF(ERROR, "Queue close failed with 0x%08x", (unsigned) err);
+ return err;
+ }
+
+ if (ctx->worker_pool) {
+ // Release worker pool
+ worker_pool_release(&ctx->worker_pool);
+ }
+
+ for (int i = 0; i < ctx->n_threads; i++) {
+ dma_queue_delete(ctx->dma[i]);
+ }
+
+ vtcm_free(ctx);
+
+ return AEE_SUCCESS;
+}
+
+static void htp_error_callback(dspqueue_t queue, int error, void * context) {
+ // No errors expected on the DSP.
+ FARF(ERROR, "Error callback: 0x%08x", (unsigned) error);
+}
+
+struct profile_data {
+ uint64_t usecs;
+ uint64_t cycles;
+ uint64_t pkts;
+};
+
+static inline void profile_start(struct profile_data * d) {
+ d->usecs = HAP_perf_get_qtimer_count();
+ d->cycles = hex_get_cycles();
+ d->pkts = hex_get_pktcnt();
+}
+
+static inline void profile_stop(struct profile_data * d) {
+ d->usecs = HAP_perf_qtimer_count_to_us(HAP_perf_get_qtimer_count() - d->usecs);
+ d->cycles = hex_get_cycles() - d->cycles;
+ d->pkts = hex_get_pktcnt() - d->pkts;
+}
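+
+// Usage sketch (illustrative only): the op handlers below bracket each op with these
+// helpers and return the elapsed microseconds, core cycles and packet count to the
+// host in the response message.
+//
+//   struct profile_data prof;
+//   profile_start(&prof);
+//   /* ... run op ... */
+//   profile_stop(&prof);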
+
+static int send_htp_rsp(struct htp_context * c,
+ uint32_t op,
+ uint32_t status,
+ struct dspqueue_buffer * bufs,
+ size_t n_bufs,
+ struct profile_data * prof) {
+ // Prep response struct
+ struct htp_general_rsp rsp;
+ rsp.op = op;
+ rsp.status = status;
+ rsp.prof_usecs = prof->usecs;
+ rsp.prof_cycles = prof->cycles;
+ rsp.prof_pkts = prof->pkts;
+
+ int err = dspqueue_write(c->queue,
+ 0, // Flags
+ n_bufs,
+ bufs, // Buffer references
+ sizeof(rsp),
+ (const uint8_t *) &rsp, // Message
+ DSPQUEUE_TIMEOUT_NONE);
+
+ if (err != 0) {
+ FARF(ERROR, "dspqueue_write failed: 0x%08x", (unsigned) err);
+ }
+
+ return err;
+}
+
+static void proc_matmul_req(struct htp_context * ctx,
+ struct htp_general_req * req,
+ struct dspqueue_buffer * bufs,
+ size_t n_bufs) {
+ struct dspqueue_buffer rsp_bufs[1];
+
+    // We wrote to the output buffer, so we also need to flush it
+ rsp_bufs[0].fd = bufs[2].fd;
+ rsp_bufs[0].ptr = bufs[2].ptr;
+ rsp_bufs[0].size = bufs[2].size;
+ rsp_bufs[0].offset = bufs[2].offset;
+ rsp_bufs[0].flags = (DSPQUEUE_BUFFER_FLAG_FLUSH_SENDER | // Flush HTP
+ DSPQUEUE_BUFFER_FLAG_INVALIDATE_RECIPIENT); // Invalidate CPU
+
+ // Setup Op context
+ struct htp_ops_context octx = { 0 };
+ octx.ctx = ctx;
+ octx.src0 = req->src0;
+ octx.src1 = req->src1;
+ octx.dst = req->dst;
+ octx.flags = req->flags;
+ octx.op = req->op;
+
+ // Update data pointers
+ octx.src0.data = (uint32_t) bufs[0].ptr;
+ octx.src1.data = (uint32_t) bufs[1].ptr;
+ octx.dst.data = (uint32_t) bufs[2].ptr;
+ octx.n_threads = ctx->n_threads;
+
+ struct profile_data prof;
+ profile_start(&prof);
+
+ uint32_t rsp_status = HTP_STATUS_INTERNAL_ERR;
+ if (vtcm_acquire(ctx) == AEE_SUCCESS) {
+ rsp_status = op_matmul(&octx);
+ vtcm_release(ctx);
+ }
+
+ profile_stop(&prof);
+ send_htp_rsp(ctx, req->op, rsp_status, rsp_bufs, 1, &prof);
+}
+
+static void proc_argsort_req(struct htp_context * ctx, struct htp_general_req * req, struct dspqueue_buffer * bufs) {
+ struct dspqueue_buffer rsp_bufs[1];
+
+    // We wrote to the output buffer, so we also need to flush it
+ rsp_bufs[0].fd = bufs[1].fd;
+ rsp_bufs[0].ptr = bufs[1].ptr;
+ rsp_bufs[0].offset = bufs[1].offset;
+ rsp_bufs[0].size = bufs[1].size;
+ rsp_bufs[0].flags = (DSPQUEUE_BUFFER_FLAG_FLUSH_SENDER | // Flush HTP
+ DSPQUEUE_BUFFER_FLAG_INVALIDATE_RECIPIENT); // Invalidate CPU
+
+ // Setup Op context
+ struct htp_ops_context octx = { 0 };
+ octx.ctx = ctx;
+ octx.src0 = req->src0;
+ octx.dst = req->dst;
+ octx.flags = req->flags;
+ octx.op = req->op;
+
+ memcpy(octx.op_params, req->op_params, sizeof(octx.op_params));
+
+ // Update data pointers
+ octx.src0.data = (uint32_t) bufs[0].ptr;
+ octx.dst.data = (uint32_t) bufs[1].ptr;
+ octx.n_threads = ctx->n_threads;
+
+ struct profile_data prof;
+ profile_start(&prof);
+
+ uint32_t rsp_status = HTP_STATUS_INTERNAL_ERR;
+ if (vtcm_acquire(ctx) == AEE_SUCCESS) {
+ rsp_status = op_argsort(&octx);
+ vtcm_release(ctx);
+ }
+
+ profile_stop(&prof);
+ send_htp_rsp(ctx, req->op, rsp_status, rsp_bufs, 1, &prof);
+}
+
+static void proc_cpy_req(struct htp_context * ctx, struct htp_general_req * req, struct dspqueue_buffer * bufs) {
+ struct dspqueue_buffer rsp_bufs[1];
+
+    // We wrote to the output buffer, so we also need to flush it
+ rsp_bufs[0].fd = bufs[1].fd;
+ rsp_bufs[0].ptr = bufs[1].ptr;
+ rsp_bufs[0].offset = bufs[1].offset;
+ rsp_bufs[0].size = bufs[1].size;
+ rsp_bufs[0].flags = (DSPQUEUE_BUFFER_FLAG_FLUSH_SENDER | // Flush HTP
+ DSPQUEUE_BUFFER_FLAG_INVALIDATE_RECIPIENT); // Invalidate CPU
+
+ // Setup Op context
+ struct htp_ops_context octx = { 0 };
+ octx.ctx = ctx;
+ octx.src0 = req->src0;
+ octx.dst = req->dst;
+ octx.flags = req->flags;
+ octx.op = req->op;
+
+ // Update data pointers
+ octx.src0.data = (uint32_t) bufs[0].ptr;
+ octx.dst.data = (uint32_t) bufs[1].ptr;
+ octx.n_threads = ctx->n_threads;
+
+ struct profile_data prof;
+ profile_start(&prof);
+
+ uint32_t rsp_status = HTP_STATUS_INTERNAL_ERR;
+ if (vtcm_acquire(ctx) == AEE_SUCCESS) {
+ rsp_status = op_cpy(&octx);
+ vtcm_release(ctx);
+ }
+
+ profile_stop(&prof);
+ send_htp_rsp(ctx, req->op, rsp_status, rsp_bufs, 1, &prof);
+}
+
+static void proc_get_rows_req(struct htp_context * ctx, struct htp_general_req * req, struct dspqueue_buffer * bufs) {
+ struct dspqueue_buffer rsp_bufs[1];
+
+    // We wrote to the output buffer, so we also need to flush it
+ rsp_bufs[0].fd = bufs[2].fd;
+ rsp_bufs[0].ptr = bufs[2].ptr;
+ rsp_bufs[0].offset = bufs[2].offset;
+ rsp_bufs[0].size = bufs[2].size;
+ rsp_bufs[0].flags = (DSPQUEUE_BUFFER_FLAG_FLUSH_SENDER | // Flush HTP
+ DSPQUEUE_BUFFER_FLAG_INVALIDATE_RECIPIENT); // Invalidate CPU
+
+ // Setup Op context
+ struct htp_ops_context octx = { 0 };
+ octx.ctx = ctx;
+ octx.src0 = req->src0;
+ octx.src1 = req->src1;
+ octx.dst = req->dst;
+ octx.flags = req->flags;
+ octx.op = req->op;
+
+ // Update data pointers
+ octx.src0.data = (uint32_t) bufs[0].ptr;
+ octx.src1.data = (uint32_t) bufs[1].ptr;
+ octx.dst.data = (uint32_t) bufs[2].ptr;
+ octx.n_threads = ctx->n_threads;
+
+ struct profile_data prof;
+ profile_start(&prof);
+
+ uint32_t rsp_status = HTP_STATUS_INTERNAL_ERR;
+ if (vtcm_acquire(ctx) == AEE_SUCCESS) {
+ rsp_status = op_get_rows(&octx);
+ vtcm_release(ctx);
+ }
+
+ profile_stop(&prof);
+ send_htp_rsp(ctx, req->op, rsp_status, rsp_bufs, 1, &prof);
+}
+
+static void proc_matmul_id_req(struct htp_context * ctx,
+ struct htp_general_req * req,
+ struct dspqueue_buffer * bufs,
+ size_t n_bufs) {
+ struct dspqueue_buffer rsp_bufs[1];
+
+    // We wrote to the output buffer, so we also need to flush it
+ rsp_bufs[0].fd = bufs[3].fd;
+ rsp_bufs[0].ptr = bufs[3].ptr;
+ rsp_bufs[0].size = bufs[3].size;
+ rsp_bufs[0].offset = bufs[3].offset;
+ rsp_bufs[0].flags = (DSPQUEUE_BUFFER_FLAG_FLUSH_SENDER | // Flush HTP
+ DSPQUEUE_BUFFER_FLAG_INVALIDATE_RECIPIENT); // Invalidate CPU
+
+ // Setup Op context
+ struct htp_ops_context octx = { 0 };
+ octx.ctx = ctx;
+ octx.src0 = req->src0;
+ octx.src1 = req->src1;
+ octx.src2 = req->src2;
+ octx.dst = req->dst;
+ octx.flags = req->flags;
+ octx.op = req->op;
+
+ // Update data pointers
+ octx.src0.data = (uint32_t) bufs[0].ptr;
+ octx.src1.data = (uint32_t) bufs[1].ptr;
+ octx.src2.data = (uint32_t) bufs[2].ptr;
+ octx.dst.data = (uint32_t) bufs[3].ptr;
+ octx.n_threads = ctx->n_threads;
+
+ struct profile_data prof;
+ profile_start(&prof);
+
+ uint32_t rsp_status = HTP_STATUS_INTERNAL_ERR;
+ if (vtcm_acquire(ctx) == AEE_SUCCESS) {
+ rsp_status = op_matmul_id(&octx);
+ vtcm_release(ctx);
+ }
+
+ profile_stop(&prof);
+ send_htp_rsp(ctx, req->op, rsp_status, rsp_bufs, 1, &prof);
+}
+
+static void proc_binary_req(struct htp_context * ctx, struct htp_general_req * req, struct dspqueue_buffer * bufs) {
+ struct dspqueue_buffer rsp_bufs[1];
+
+    // We wrote to the output buffer, so we also need to flush it
+ rsp_bufs[0].fd = bufs[2].fd;
+ rsp_bufs[0].ptr = bufs[2].ptr;
+ rsp_bufs[0].offset = bufs[2].offset;
+ rsp_bufs[0].size = bufs[2].size;
+ rsp_bufs[0].flags = (DSPQUEUE_BUFFER_FLAG_FLUSH_SENDER | // Flush HTP
+ DSPQUEUE_BUFFER_FLAG_INVALIDATE_RECIPIENT); // Invalidate CPU
+
+ // Setup Op context
+ struct htp_ops_context octx = { 0 };
+ octx.ctx = ctx;
+ octx.src0 = req->src0;
+ octx.src1 = req->src1;
+ octx.dst = req->dst;
+ octx.flags = req->flags;
+ octx.op = req->op;
+
+ // Update data pointers
+ octx.src0.data = (uint32_t) bufs[0].ptr;
+ octx.src1.data = (uint32_t) bufs[1].ptr;
+ octx.dst.data = (uint32_t) bufs[2].ptr;
+ octx.n_threads = ctx->n_threads;
+
+ struct profile_data prof;
+ profile_start(&prof);
+
+ uint32_t rsp_status = HTP_STATUS_INTERNAL_ERR;
+ if (vtcm_acquire(ctx) == AEE_SUCCESS) {
+ rsp_status = op_binary(&octx);
+ vtcm_release(ctx);
+ }
+
+ profile_stop(&prof);
+ send_htp_rsp(ctx, req->op, rsp_status, rsp_bufs, 1, &prof);
+}
+
+static void proc_add_id_req(struct htp_context * ctx, struct htp_general_req * req, struct dspqueue_buffer * bufs) {
+ struct dspqueue_buffer rsp_bufs[1];
+
+    // We wrote to the output buffer, so we also need to flush it
+ rsp_bufs[0].fd = bufs[3].fd;
+ rsp_bufs[0].ptr = bufs[3].ptr;
+ rsp_bufs[0].offset = bufs[3].offset;
+ rsp_bufs[0].size = bufs[3].size;
+ rsp_bufs[0].flags = (DSPQUEUE_BUFFER_FLAG_FLUSH_SENDER | // Flush HTP
+ DSPQUEUE_BUFFER_FLAG_INVALIDATE_RECIPIENT); // Invalidate CPU
+
+ // Setup Op context
+ struct htp_ops_context octx = { 0 };
+ octx.ctx = ctx;
+ octx.src0 = req->src0;
+ octx.src1 = req->src1;
+ octx.src2 = req->src2;
+ octx.dst = req->dst;
+ octx.flags = req->flags;
+ octx.op = req->op;
+
+ // Update data pointers
+ octx.src0.data = (uint32_t) bufs[0].ptr;
+ octx.src1.data = (uint32_t) bufs[1].ptr;
+ octx.src2.data = (uint32_t) bufs[2].ptr;
+ octx.dst.data = (uint32_t) bufs[3].ptr;
+ octx.n_threads = ctx->n_threads;
+
+ struct profile_data prof;
+ profile_start(&prof);
+
+ uint32_t rsp_status = HTP_STATUS_INTERNAL_ERR;
+ if (vtcm_acquire(ctx) == AEE_SUCCESS) {
+ rsp_status = op_binary(&octx);
+ vtcm_release(ctx);
+ }
+
+ profile_stop(&prof);
+ send_htp_rsp(ctx, req->op, rsp_status, rsp_bufs, 1, &prof);
+}
+
+static void proc_unary_req(struct htp_context * ctx, struct htp_general_req * req, struct dspqueue_buffer * bufs) {
+ struct dspqueue_buffer rsp_bufs[HTP_MAX_PACKET_BUFFERS];
+
+    // We wrote to the output buffer, so we also need to flush it
+ rsp_bufs[0].fd = bufs[1].fd;
+ rsp_bufs[0].ptr = bufs[1].ptr;
+ rsp_bufs[0].offset = bufs[1].offset;
+ rsp_bufs[0].size = bufs[1].size;
+ rsp_bufs[0].flags = (DSPQUEUE_BUFFER_FLAG_FLUSH_SENDER | // Flush HTP
+ DSPQUEUE_BUFFER_FLAG_INVALIDATE_RECIPIENT); // Invalidate CPU
+
+ // Setup Op context
+ struct htp_ops_context octx = { 0 };
+ octx.ctx = ctx;
+ octx.src0 = req->src0;
+ octx.dst = req->dst;
+ octx.flags = req->flags;
+ octx.op = req->op;
+
+ memcpy(octx.op_params, req->op_params, sizeof(octx.op_params));
+
+ // Update data pointers
+ octx.src0.data = (uint32_t) bufs[0].ptr;
+ octx.dst.data = (uint32_t) bufs[1].ptr;
+ octx.n_threads = ctx->n_threads;
+
+ struct profile_data prof;
+ profile_start(&prof);
+
+ uint32_t rsp_status = HTP_STATUS_INTERNAL_ERR;
+ if (vtcm_acquire(ctx) == AEE_SUCCESS) {
+ rsp_status = op_unary(&octx);
+ vtcm_release(ctx);
+ }
+
+ profile_stop(&prof);
+ send_htp_rsp(ctx, req->op, rsp_status, rsp_bufs, 1, &prof);
+}
+
+static void proc_sum_rows_req(struct htp_context * ctx, struct htp_general_req * req, struct dspqueue_buffer * bufs) {
+ struct dspqueue_buffer rsp_bufs[HTP_MAX_PACKET_BUFFERS];
+
+    // We wrote to the output buffer, so we also need to flush it
+ rsp_bufs[0].fd = bufs[1].fd;
+ rsp_bufs[0].ptr = bufs[1].ptr;
+ rsp_bufs[0].offset = bufs[1].offset;
+ rsp_bufs[0].size = bufs[1].size;
+ rsp_bufs[0].flags = (DSPQUEUE_BUFFER_FLAG_FLUSH_SENDER | // Flush HTP
+ DSPQUEUE_BUFFER_FLAG_INVALIDATE_RECIPIENT); // Invalidate CPU
+
+ // Setup Op context
+ struct htp_ops_context octx = { 0 };
+ octx.ctx = ctx;
+ octx.src0 = req->src0;
+ octx.dst = req->dst;
+ octx.flags = req->flags;
+ octx.op = req->op;
+
+ memcpy(octx.op_params, req->op_params, sizeof(octx.op_params));
+
+ // Update data pointers
+ octx.src0.data = (uint32_t) bufs[0].ptr;
+ octx.dst.data = (uint32_t) bufs[1].ptr;
+ octx.n_threads = ctx->n_threads;
+
+ struct profile_data prof;
+ profile_start(&prof);
+
+ uint32_t rsp_status = HTP_STATUS_INTERNAL_ERR;
+ if (vtcm_acquire(ctx) == AEE_SUCCESS) {
+ rsp_status = op_sum_rows(&octx);
+ vtcm_release(ctx);
+ }
+
+ profile_stop(&prof);
+ send_htp_rsp(ctx, req->op, rsp_status, rsp_bufs, 1, &prof);
+}
+
+static void proc_activations_req(struct htp_context * ctx,
+ struct htp_general_req * req,
+ struct dspqueue_buffer * bufs,
+ uint32_t n_bufs) {
+ struct dspqueue_buffer rsp_bufs[HTP_MAX_PACKET_BUFFERS];
+
+ int write_idx = (n_bufs == 3) ? 2 : 1;
+
+    // We wrote to the output buffer, so we also need to flush it
+ rsp_bufs[0].fd = bufs[write_idx].fd;
+ rsp_bufs[0].ptr = bufs[write_idx].ptr;
+ rsp_bufs[0].offset = bufs[write_idx].offset;
+ rsp_bufs[0].size = bufs[write_idx].size;
+ rsp_bufs[0].flags = (DSPQUEUE_BUFFER_FLAG_FLUSH_SENDER | // Flush HTP
+ DSPQUEUE_BUFFER_FLAG_INVALIDATE_RECIPIENT); // Invalidate CPU
+
+ // Setup Op context
+ struct htp_ops_context octx = { 0 };
+ octx.ctx = ctx;
+ octx.src0 = req->src0;
+ if (3 == n_bufs) {
+ octx.src1 = req->src1;
+ }
+ octx.dst = req->dst;
+ octx.flags = req->flags;
+ octx.op = req->op;
+
+ memcpy(octx.op_params, req->op_params, sizeof(octx.op_params));
+
+ // Update data pointers
+ octx.src0.data = (uint32_t) bufs[0].ptr;
+ if (3 == n_bufs) {
+ octx.src1.data = (uint32_t) bufs[1].ptr;
+ octx.dst.data = (uint32_t) bufs[2].ptr;
+ } else {
+ octx.dst.data = (uint32_t) bufs[1].ptr;
+ }
+ octx.n_threads = ctx->n_threads;
+
+ struct profile_data prof;
+ profile_start(&prof);
+
+ uint32_t rsp_status = HTP_STATUS_INTERNAL_ERR;
+ if (vtcm_acquire(ctx) == AEE_SUCCESS) {
+ if (octx.op == HTP_OP_SOFTMAX) {
+ rsp_status = op_softmax(&octx);
+ } else {
+ rsp_status = op_activations(&octx);
+ }
+ vtcm_release(ctx);
+ }
+
+ profile_stop(&prof);
+ send_htp_rsp(ctx, req->op, rsp_status, rsp_bufs, 1, &prof);
+}
+
+static void proc_rope_req(struct htp_context * ctx,
+ struct htp_general_req * req,
+ struct dspqueue_buffer * bufs,
+ uint32_t n_bufs) {
+ struct dspqueue_buffer rsp_bufs[HTP_MAX_PACKET_BUFFERS];
+
+ int write_idx = n_bufs - 1;
+
+    // We wrote to the output buffer, so we also need to flush it
+ rsp_bufs[0].fd = bufs[write_idx].fd;
+ rsp_bufs[0].ptr = bufs[write_idx].ptr;
+ rsp_bufs[0].offset = bufs[write_idx].offset;
+ rsp_bufs[0].size = bufs[write_idx].size;
+ rsp_bufs[0].flags = (DSPQUEUE_BUFFER_FLAG_FLUSH_SENDER | // Flush HTP
+ DSPQUEUE_BUFFER_FLAG_INVALIDATE_RECIPIENT); // Invalidate CPU
+
+ // Setup Op context
+ struct htp_ops_context octx = { 0 };
+ octx.ctx = ctx;
+ octx.src0 = req->src0;
+ octx.src1 = req->src1;
+ if (4 == n_bufs) {
+ octx.src2 = req->src2;
+ }
+ octx.dst = req->dst;
+ octx.flags = req->flags;
+ octx.op = req->op;
+
+ memcpy(octx.op_params, req->op_params, sizeof(octx.op_params));
+
+ // Update data pointers
+ octx.src0.data = (uint32_t) bufs[0].ptr;
+ octx.src1.data = (uint32_t) bufs[1].ptr;
+ if (4 == n_bufs) {
+ octx.src2.data = (uint32_t) bufs[2].ptr;
+ octx.dst.data = (uint32_t) bufs[3].ptr;
+ } else {
+ octx.dst.data = (uint32_t) bufs[2].ptr;
+ }
+ octx.n_threads = ctx->n_threads;
+
+ struct profile_data prof;
+ profile_start(&prof);
+
+ uint32_t rsp_status = HTP_STATUS_INTERNAL_ERR;
+ if (vtcm_acquire(ctx) == AEE_SUCCESS) {
+ rsp_status = op_rope(&octx);
+ vtcm_release(ctx);
+ }
+
+ profile_stop(&prof);
+ send_htp_rsp(ctx, req->op, rsp_status, rsp_bufs, 1, &prof);
+}
+
+static void proc_set_rows_req(struct htp_context * ctx, struct htp_general_req * req, struct dspqueue_buffer * bufs) {
+ struct dspqueue_buffer rsp_bufs[1];
+
+    // We wrote to the output buffer, so we also need to flush it
+ rsp_bufs[0].fd = bufs[2].fd;
+ rsp_bufs[0].ptr = bufs[2].ptr;
+ rsp_bufs[0].offset = bufs[2].offset;
+ rsp_bufs[0].size = bufs[2].size;
+ rsp_bufs[0].flags = (DSPQUEUE_BUFFER_FLAG_FLUSH_SENDER | // Flush HTP
+ DSPQUEUE_BUFFER_FLAG_INVALIDATE_RECIPIENT); // Invalidate CPU
+
+ // Setup Op context
+ struct htp_ops_context octx = { 0 };
+ octx.ctx = ctx;
+ octx.src0 = req->src0;
+ octx.src1 = req->src1;
+ octx.dst = req->dst;
+ octx.flags = req->flags;
+ octx.op = req->op;
+
+ // Update data pointers
+ octx.src0.data = (uint32_t) bufs[0].ptr;
+ octx.src1.data = (uint32_t) bufs[1].ptr;
+ octx.dst.data = (uint32_t) bufs[2].ptr;
+ octx.n_threads = ctx->n_threads;
+
+ struct profile_data prof;
+ profile_start(&prof);
+
+ uint32_t rsp_status = HTP_STATUS_INTERNAL_ERR;
+ if (vtcm_acquire(ctx) == AEE_SUCCESS) {
+ rsp_status = op_set_rows(&octx);
+ vtcm_release(ctx);
+ }
+
+ profile_stop(&prof);
+ send_htp_rsp(ctx, req->op, rsp_status, rsp_bufs, 1, &prof);
+}
+
+static void proc_flash_attn_ext_req(struct htp_context * ctx,
+ struct htp_general_req * req,
+ struct dspqueue_buffer * bufs,
+ uint32_t n_bufs) {
+ // Setup Op context
+ struct htp_ops_context octx;
+ memset(&octx, 0, sizeof(octx));
+
+ octx.ctx = ctx;
+ octx.n_threads = ctx->n_threads;
+
+ octx.src0 = req->src0;
+ octx.src1 = req->src1;
+ octx.src2 = req->src2;
+ octx.src3 = req->src3;
+ octx.src4 = req->src4;
+ octx.dst = req->dst;
+ octx.flags = req->flags;
+ octx.op = req->op;
+
+ memcpy(octx.op_params, req->op_params, sizeof(octx.op_params));
+
+ // Update data pointers
+ octx.src0.data = (uint32_t) bufs[0].ptr;
+ octx.src1.data = (uint32_t) bufs[1].ptr;
+ octx.src2.data = (uint32_t) bufs[2].ptr;
+
+ int last_buf = 3;
+
+ if (octx.src3.ne[0]) {
+ octx.src3.data = (uint32_t) bufs[last_buf++].ptr; // mask is valid
+ }
+
+ if (octx.src4.ne[0]) {
+ octx.src4.data = (uint32_t) bufs[last_buf++].ptr; // sinks is valid
+ }
+
+ octx.dst.data = (uint32_t) bufs[last_buf].ptr;
+
+ struct profile_data prof;
+ profile_start(&prof);
+
+ uint32_t rsp_status = HTP_STATUS_INTERNAL_ERR;
+ if (vtcm_acquire(ctx) == AEE_SUCCESS) {
+ rsp_status = op_flash_attn_ext(&octx);
+ vtcm_release(ctx);
+ }
+
+ profile_stop(&prof);
+
+ struct dspqueue_buffer rsp_buf = bufs[last_buf];
+ rsp_buf.flags = (DSPQUEUE_BUFFER_FLAG_FLUSH_SENDER | // Flush HTP
+ DSPQUEUE_BUFFER_FLAG_INVALIDATE_RECIPIENT); // Invalidate CPU
+
+    send_htp_rsp(ctx, req->op, rsp_status, &rsp_buf, 1, &prof);
+}
+
+static void htp_packet_callback(dspqueue_t queue, int error, void * context) {
+ struct htp_context * ctx = (struct htp_context *) context;
+
+ // Repeatedly read packets from the queue until it's empty. We don't
+ // necessarily get a separate callback for each packet, and new packets
+ // may arrive while we're processing the previous one. This ensures we
+ // keep the DSP busy as much as possible and avoid waiting for the CPU.
+
+ while (1) {
+ struct htp_general_req req;
+ uint32_t req_size;
+
+ struct dspqueue_buffer bufs[HTP_MAX_PACKET_BUFFERS];
+ uint32_t n_bufs;
+ uint32_t flags;
+
+ // Read packet from queue
+ int err = dspqueue_read_noblock(queue, &flags,
+ HTP_MAX_PACKET_BUFFERS, // Maximum number of buffer references
+ &n_bufs, // Number of buffer references
+ bufs, // Buffer references
+ sizeof(req), // Max message length
+ &req_size, // Message length
+ (uint8_t *) &req); // Message
+
+ if (err == AEE_EWOULDBLOCK) {
+ // Consumed all packets available for now
+ return;
+ }
+
+ if (err != 0) {
+ FARF(ERROR, "dspqueue_read_noblock failed: 0x%08x", (unsigned) err);
+ return;
+ }
+
+ if (req_size != sizeof(req)) {
+ FARF(ERROR, "Invalid request size");
+ continue;
+ }
+
+ if (req.flags & HTP_OPFLAGS_EARLY_WAKEUP) {
+ // Host wants early notification
+ dspqueue_write_early_wakeup_noblock(ctx->queue, 10, 0);
+ }
+
+ // Process packet based on its message type
+ switch (req.op) {
+ case HTP_OP_MUL_MAT:
+ if (n_bufs != 3) {
+ FARF(ERROR, "Bad matmul-req buffer list");
+ continue;
+ }
+ proc_matmul_req(ctx, &req, bufs, n_bufs);
+ break;
+
+ case HTP_OP_MUL_MAT_ID:
+ if (n_bufs != 4) {
+ FARF(ERROR, "Bad matmul-id-req buffer list");
+ continue;
+ }
+ proc_matmul_id_req(ctx, &req, bufs, n_bufs);
+ break;
+
+ case HTP_OP_MUL:
+ case HTP_OP_ADD:
+ case HTP_OP_SUB:
+ case HTP_OP_DIV:
+ if (n_bufs != 3) {
+ FARF(ERROR, "Bad binary-req buffer list");
+ continue;
+ }
+ proc_binary_req(ctx, &req, bufs);
+ break;
+
+ case HTP_OP_RMS_NORM:
+ case HTP_OP_SCALE:
+ if (n_bufs != 2) {
+ FARF(ERROR, "Bad unary-req buffer list");
+ continue;
+ }
+
+ proc_unary_req(ctx, &req, bufs);
+ break;
+
+ case HTP_OP_SQR:
+ case HTP_OP_SQRT:
+ if (n_bufs != 2) {
+ FARF(ERROR, "Bad unary-req buffer list");
+ continue;
+ }
+
+ proc_unary_req(ctx, &req, bufs);
+ break;
+
+ case HTP_OP_SUM_ROWS:
+ if (n_bufs != 2) {
+ FARF(ERROR, "Bad unary-req buffer list");
+ continue;
+ }
+
+ proc_sum_rows_req(ctx, &req, bufs);
+ break;
+
+ case HTP_OP_UNARY_SILU:
+ case HTP_OP_UNARY_GELU:
+ if (n_bufs != 2) {
+ FARF(ERROR, "Bad act-req buffer list");
+ continue;
+ }
+ proc_activations_req(ctx, &req, bufs, n_bufs);
+ break;
+
+ case HTP_OP_GLU_SWIGLU:
+ case HTP_OP_GLU_SWIGLU_OAI:
+ case HTP_OP_SOFTMAX:
+ case HTP_OP_GLU_GEGLU:
+ if ((n_bufs != 2) && (n_bufs != 3)) {
+ FARF(ERROR, "Bad act-req buffer list");
+ continue;
+ }
+ proc_activations_req(ctx, &req, bufs, n_bufs);
+ break;
+
+ case HTP_OP_ADD_ID:
+ if (n_bufs != 4) {
+ FARF(ERROR, "Bad add-id-req buffer list");
+ continue;
+ }
+ proc_add_id_req(ctx, &req, bufs);
+ break;
+
+ case HTP_OP_ROPE:
+ if ((n_bufs != 3) && (n_bufs != 4)) {
+ FARF(ERROR, "Bad rope-req buffer list");
+ continue;
+ }
+ proc_rope_req(ctx, &req, bufs, n_bufs);
+ break;
+
+ case HTP_OP_FLASH_ATTN_EXT:
+ if (!(n_bufs >= 4 && n_bufs <= 6)) {
+ FARF(ERROR, "Bad flash-attn-ext-req buffer list");
+ continue;
+ }
+ proc_flash_attn_ext_req(ctx, &req, bufs, n_bufs);
+ break;
+
+ case HTP_OP_SET_ROWS:
+ if (n_bufs != 3) {
+ FARF(ERROR, "Bad set-rows-req buffer list");
+ continue;
+ }
+ proc_set_rows_req(ctx, &req, bufs);
+ break;
+
+ case HTP_OP_GET_ROWS:
+ if (n_bufs != 3) {
+ FARF(ERROR, "Bad get-rows-req buffer list");
+ continue;
+ }
+ proc_get_rows_req(ctx, &req, bufs);
+ break;
+
+ case HTP_OP_CPY:
+ if (n_bufs != 2) {
+ FARF(ERROR, "Bad cpy-req buffer list");
+ continue;
+ }
+ proc_cpy_req(ctx, &req, bufs);
+ break;
+
+ case HTP_OP_ARGSORT:
+ if (n_bufs != 2) {
+ FARF(ERROR, "Bad argsort-req buffer list");
+ continue;
+ }
+ proc_argsort_req(ctx, &req, bufs);
+ break;
+
+ default:
+ FARF(ERROR, "Unknown Op %u", req.op);
+ break;
+ }
+ }
+}
diff --git a/llama.cpp/ggml/src/ggml-hexagon/htp/matmul-ops.c b/llama.cpp/ggml/src/ggml-hexagon/htp/matmul-ops.c
new file mode 100644
index 0000000..c360abe
--- /dev/null
+++ b/llama.cpp/ggml/src/ggml-hexagon/htp/matmul-ops.c
@@ -0,0 +1,2665 @@
+#pragma clang diagnostic ignored "-Wgnu-zero-variadic-macro-arguments"
+#pragma clang diagnostic ignored "-Wunused-function"
+#pragma clang diagnostic ignored "-Wunused-variable"
+#pragma clang diagnostic ignored "-Wunused-but-set-variable"
+
+#include <HAP_farf.h>
+#include <HAP_perf.h>
+
+#include <math.h>
+#include <string.h>
+
+#include "hex-dma.h"
+#include "hvx-utils.h"
+#include "hvx-dump.h"
+
+#define GGML_COMMON_DECL_C
+#include "ggml-common.h"
+#include "htp-ctx.h"
+#include "htp-msg.h"
+#include "htp-ops.h"
+
+#define MM_SPAD_SRC0_NROWS 16
+#define MM_SPAD_SRC1_NROWS 16
+#define MM_SPAD_DST_NROWS 2
+
+struct htp_matmul_context {
+ const char * type;
+ struct htp_ops_context * octx;
+
+ void (*vec_dot_1x1)(const int n, float * restrict s0,
+ const void * restrict vx0,
+ const void * restrict vy0);
+
+ void (*vec_dot_2x1)(const int n, float * restrict s0,
+ const void * restrict vx0, const void * restrict vx1,
+ const void * restrict vy0);
+
+ void (*vec_dot_2x2)(const int n, float * restrict s0, float * restrict s1,
+ const void * restrict vx0, const void * restrict vx1,
+ const void * restrict vy0, const void * restrict vy1);
+
+ // Precomputed values
+ uint32_t src0_nrows_per_thread;
+ uint32_t src1_nrows_per_thread;
+
+ struct fastdiv_values mm_div_ne12_ne1;
+ struct fastdiv_values mm_div_ne1;
+ struct fastdiv_values mm_div_r2;
+ struct fastdiv_values mm_div_r3;
+};
+
+// vdelta control to replicate first 4x fp32 values across lanes
+static const uint8_t __attribute__((aligned(128))) repl_4x_f32[128] = {
+ 0x00, 0x00, 0x00, 0x00, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, 0x10, 0x10, 0x10,
+ 0x10, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, 0x04, 0x04, 0x04, 0x04, 0x20, 0x20,
+ 0x20, 0x20, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, 0x10, 0x10, 0x10, 0x10, 0x04,
+ 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, 0x40, 0x40, 0x40, 0x40,
+ 0x44, 0x44, 0x44, 0x44, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, 0x10, 0x10, 0x10, 0x10, 0x04, 0x04, 0x04,
+ 0x04, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, 0x20, 0x20, 0x20, 0x20, 0x04, 0x04,
+ 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, 0x10, 0x10, 0x10, 0x10,
+};
+
+// vdelta control to replicate and interleave first 8x fp32 values across lanes
+static const uint8_t __attribute__((aligned(128))) repl_interleave_8x_f32[128] = {
+ 0x00, 0x00, 0x00, 0x00, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, 0x00, 0x00, 0x00,
+ 0x00, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, 0x04, 0x04, 0x04, 0x04, 0x20, 0x20,
+ 0x20, 0x20, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, 0x20, 0x20, 0x20, 0x20, 0x04,
+ 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, 0x40, 0x40, 0x40, 0x40,
+ 0x44, 0x44, 0x44, 0x44, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, 0x40, 0x40, 0x40, 0x40, 0x44, 0x44, 0x44,
+ 0x44, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, 0x20, 0x20, 0x20, 0x20, 0x04, 0x04,
+ 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, 0x20, 0x20, 0x20, 0x20,
+};
+
+// vdelta control to replicate first fp32 value across all elements
+static const uint8_t __attribute__((aligned(128))) repl_1x_f32[128] = {
+ 0x00, 0x00, 0x00, 0x00, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, 0x10, 0x10, 0x10,
+ 0x10, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, 0x20, 0x20, 0x20, 0x20, 0x04, 0x04,
+ 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, 0x10, 0x10, 0x10, 0x10, 0x04, 0x04, 0x04, 0x04, 0x08,
+ 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, 0x40, 0x40, 0x40, 0x40, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08,
+ 0x04, 0x04, 0x04, 0x04, 0x10, 0x10, 0x10, 0x10, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04,
+ 0x04, 0x20, 0x20, 0x20, 0x20, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, 0x10, 0x10,
+ 0x10, 0x10, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04,
+};
+
+// vdelta control to replicate first fp16 value across all elements
+static const uint8_t __attribute__((aligned(128))) repl_1x_f16[128] = {
+ 0x00, 0x00, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02, 0x08, 0x08, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02, 0x10, 0x10, 0x02,
+ 0x02, 0x04, 0x04, 0x02, 0x02, 0x08, 0x08, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02, 0x20, 0x20, 0x02, 0x02, 0x04, 0x04,
+ 0x02, 0x02, 0x08, 0x08, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02, 0x10, 0x10, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02, 0x08,
+ 0x08, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02, 0x40, 0x40, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02, 0x08, 0x08, 0x02, 0x02,
+ 0x04, 0x04, 0x02, 0x02, 0x10, 0x10, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02, 0x08, 0x08, 0x02, 0x02, 0x04, 0x04, 0x02,
+ 0x02, 0x20, 0x20, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02, 0x08, 0x08, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02, 0x10, 0x10,
+ 0x02, 0x02, 0x04, 0x04, 0x02, 0x02, 0x08, 0x08, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02,
+};
+
+// vdelta control to replicate first 2x fp16 values across all elements
+static const uint8_t __attribute__((aligned(128))) repl_2x_f16[128] = {
+ 0x00, 0x00, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02, 0x08, 0x08, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02,
+ 0x10, 0x10, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02, 0x08, 0x08, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02,
+ 0x20, 0x20, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02, 0x08, 0x08, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02,
+ 0x10, 0x10, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02, 0x08, 0x08, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02,
+ 0x00, 0x00, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02, 0x08, 0x08, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02,
+ 0x10, 0x10, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02, 0x08, 0x08, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02,
+ 0x20, 0x20, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02, 0x08, 0x08, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02,
+ 0x10, 0x10, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02, 0x08, 0x08, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02,
+};
+
+// vdelta control to expand first 32 e8m0 values into 32 uint32 elements
+static const uint8_t __attribute__((aligned(128))) expand_x32_e8m0[128] = {
+ 0x00, 0x00, 0x00, 0x00, 0x01, 0x04, 0x00, 0x00, 0x02, 0x00, 0x08, 0x08, 0x01, 0x02, 0x00, 0x04, 0x04, 0x00, 0x00,
+ 0x00, 0x11, 0x10, 0x10, 0x10, 0x02, 0x00, 0x04, 0x00, 0x01, 0x02, 0x08, 0x08, 0x08, 0x08, 0x00, 0x00, 0x01, 0x04,
+ 0x00, 0x00, 0x22, 0x20, 0x20, 0x20, 0x21, 0x22, 0x20, 0x24, 0x04, 0x00, 0x00, 0x00, 0x09, 0x08, 0x00, 0x00, 0x02,
+ 0x00, 0x04, 0x00, 0x11, 0x12, 0x10, 0x10, 0x10, 0x10, 0x10, 0x10, 0x01, 0x04, 0x00, 0x00, 0x02, 0x00, 0x08, 0x08,
+ 0x01, 0x02, 0x00, 0x04, 0x44, 0x40, 0x40, 0x40, 0x41, 0x40, 0x40, 0x40, 0x42, 0x40, 0x44, 0x40, 0x41, 0x42, 0x48,
+ 0x48, 0x08, 0x08, 0x00, 0x00, 0x01, 0x04, 0x00, 0x00, 0x12, 0x10, 0x10, 0x10, 0x01, 0x02, 0x00, 0x04, 0x04, 0x00,
+ 0x00, 0x00, 0x09, 0x08, 0x00, 0x00, 0x22, 0x20, 0x24, 0x20, 0x21, 0x22, 0x20, 0x20,
+};
+
+static const uint8_t __attribute__((aligned(VLEN))) kvalues_mxfp4_lut[] = {
+ 0, 0, 1, 0, 2, 0, 3, 0, 4, 0, 6, 0, 8, 0, 12, 0, 0, 0, 0xff, 0, 0xfe, 0, 0xfd, 0, 0xfc, 0,
+ 0xfa, 0, 0xf8, 0, 0xf4, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+};
+
+// q4x4x2 and q8x4x2 are the flat q4/8_0 formats where all quants are stored first followed by all scales
+
+static inline size_t q8x4x2_row_size(uint32_t ne) {
+ // ensures perfect alignment of quants and full row
+ const uint32_t qk = QK_Q8_0x4x2;
+ const uint32_t nb = (ne + qk - 1) / qk;
+ return hex_round_up(ne + nb * 8 * sizeof(__fp16), 128);
+}
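+
+// Worked example (a sketch; assumes QK_Q8_0x4x2 == 256, i.e. 8 q8_0 blocks of 32):
+//   ne = 4096 -> nb = 16 super-blocks
+//   bytes   = 4096 int8 quants + 16 * 8 * sizeof(__fp16) = 4096 + 256 = 4352
+//   4352 is already a multiple of 128, so q8x4x2_row_size(4096) == 4352.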
+
+static inline HVX_Vector_x8 hvx_vec_load_q4x4x8(const uint8_t * restrict ptr) {
+ const HVX_Vector * restrict vptr = (const HVX_Vector *) ptr;
+
+ HVX_Vector v0_1 = vptr[0]; // first 256 elements (128 bytes)
+ HVX_Vector v2_3 = vptr[1]; // ...
+ HVX_Vector v4_5 = vptr[2]; // ...
+ HVX_Vector v6_7 = vptr[3]; // ...
+
+ const HVX_Vector mask_h4 = Q6_Vb_vsplat_R(0x0F);
+ const HVX_Vector i8 = Q6_Vb_vsplat_R(8);
+
+ HVX_Vector v0 = Q6_V_vand_VV(v0_1, mask_h4); // & 0x0F
+ HVX_Vector v1 = Q6_Vub_vlsr_VubR(v0_1, 4); // >> 4
+ HVX_Vector v2 = Q6_V_vand_VV(v2_3, mask_h4); // & 0x0F
+ HVX_Vector v3 = Q6_Vub_vlsr_VubR(v2_3, 4); // >> 4
+ HVX_Vector v4 = Q6_V_vand_VV(v4_5, mask_h4); // & 0x0F
+ HVX_Vector v5 = Q6_Vub_vlsr_VubR(v4_5, 4); // >> 4
+ HVX_Vector v6 = Q6_V_vand_VV(v6_7, mask_h4); // & 0x0F
+ HVX_Vector v7 = Q6_Vub_vlsr_VubR(v6_7, 4); // >> 4
+
+ // Convert uint4 to int4 (i.e. x - 8)
+ v0 = Q6_Vb_vsub_VbVb(v0, i8);
+ v1 = Q6_Vb_vsub_VbVb(v1, i8);
+ v2 = Q6_Vb_vsub_VbVb(v2, i8);
+ v3 = Q6_Vb_vsub_VbVb(v3, i8);
+ v4 = Q6_Vb_vsub_VbVb(v4, i8);
+ v5 = Q6_Vb_vsub_VbVb(v5, i8);
+ v6 = Q6_Vb_vsub_VbVb(v6, i8);
+ v7 = Q6_Vb_vsub_VbVb(v7, i8);
+
+ HVX_Vector_x8 r = { v0, v1, v2, v3, v4, v5, v6, v7 };
+ return r;
+}
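+
+// Scalar reference for the unpacking above (a sketch, not used by the kernels):
+// each packed byte holds two 4-bit quants, stored unsigned with an implicit -8 bias;
+// low nibbles fill the first 128 elements of a 256-element group, high nibbles the rest.
+static inline void hvx_ref_unpack_q4_block(const uint8_t * restrict packed, int8_t * restrict out) {
+    for (int i = 0; i < 128; i++) {
+        out[i]       = (int8_t) (packed[i] & 0x0F) - 8;
+        out[i + 128] = (int8_t) (packed[i] >> 4) - 8;
+    }
+}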
+
+static inline HVX_Vector_x8 hvx_vec_load_mxfp4x4x8(const uint8_t * restrict ptr) {
+ const HVX_Vector * restrict vptr = (const HVX_Vector *) ptr;
+
+ HVX_Vector v0_1 = vptr[0]; // first 256 elements (128 bytes)
+ HVX_Vector v2_3 = vptr[1]; // ...
+ HVX_Vector v4_5 = vptr[2]; // ...
+ HVX_Vector v6_7 = vptr[3]; // ...
+
+ const HVX_Vector mask_h4 = Q6_Vb_vsplat_R(0x0F);
+ const HVX_Vector lut = *(const HVX_Vector *) kvalues_mxfp4_lut;
+
+ HVX_Vector v0 = Q6_V_vand_VV(v0_1, mask_h4); // & 0x0F
+ HVX_Vector v1 = Q6_Vub_vlsr_VubR(v0_1, 4); // >> 4
+ HVX_Vector v2 = Q6_V_vand_VV(v2_3, mask_h4); // & 0x0F
+ HVX_Vector v3 = Q6_Vub_vlsr_VubR(v2_3, 4); // >> 4
+ HVX_Vector v4 = Q6_V_vand_VV(v4_5, mask_h4); // & 0x0F
+ HVX_Vector v5 = Q6_Vub_vlsr_VubR(v4_5, 4); // >> 4
+ HVX_Vector v6 = Q6_V_vand_VV(v6_7, mask_h4); // & 0x0F
+ HVX_Vector v7 = Q6_Vub_vlsr_VubR(v6_7, 4); // >> 4
+
+ v0 = Q6_Vb_vlut32_VbVbI(v0, lut, 0);
+ v1 = Q6_Vb_vlut32_VbVbI(v1, lut, 0);
+ v2 = Q6_Vb_vlut32_VbVbI(v2, lut, 0);
+ v3 = Q6_Vb_vlut32_VbVbI(v3, lut, 0);
+ v4 = Q6_Vb_vlut32_VbVbI(v4, lut, 0);
+ v5 = Q6_Vb_vlut32_VbVbI(v5, lut, 0);
+ v6 = Q6_Vb_vlut32_VbVbI(v6, lut, 0);
+ v7 = Q6_Vb_vlut32_VbVbI(v7, lut, 0);
+
+ HVX_Vector_x8 r = { v0, v1, v2, v3, v4, v5, v6, v7 };
+ return r;
+}
+
+static inline HVX_Vector_x8 hvx_vec_load_q8x4x8(const uint8_t * restrict ptr) {
+ const HVX_Vector * restrict vptr = (const HVX_Vector *) ptr;
+
+ HVX_Vector v0 = vptr[0]; // first 128 vals
+ HVX_Vector v1 = vptr[1]; // ...
+ HVX_Vector v2 = vptr[2]; // ...
+ HVX_Vector v3 = vptr[3]; // ...
+ HVX_Vector v4 = vptr[4]; // ...
+ HVX_Vector v5 = vptr[5]; // ...
+ HVX_Vector v6 = vptr[6]; // ...
+ HVX_Vector v7 = vptr[7]; // ...
+
+ HVX_Vector_x8 r = { v0, v1, v2, v3, v4, v5, v6, v7 };
+ return r;
+}
+
+// Reduce-multiply 1024 x 1024 int8 elements (32x q4/8 blocks held in 8x HVX vectors).
+// Accumulate each 32-element block into a single int32 value.
+// Return a single HVX vector with 32x int32 accumulators.
+// This version is parameterized to support fewer than 1024 elements.
+// The if() checks are optimized out at compile time -- make sure to pass n as a compile-time constant.
+
+static inline HVX_Vector hvx_vec_rmpy_x8_n(HVX_Vector_x8 x, HVX_Vector_x8 y, unsigned int n) {
+ HVX_Vector r0 = Q6_V_vsplat_R(0);
+ HVX_Vector r1 = Q6_V_vsplat_R(0);
+ HVX_Vector r2 = Q6_V_vsplat_R(0);
+ HVX_Vector r3 = Q6_V_vsplat_R(0);
+ HVX_Vector r4 = Q6_V_vsplat_R(0);
+ HVX_Vector r5 = Q6_V_vsplat_R(0);
+ HVX_Vector r6 = Q6_V_vsplat_R(0);
+ HVX_Vector r7 = Q6_V_vsplat_R(0);
+
+ HVX_VectorPair p3;
+ HVX_VectorPair p2;
+ HVX_VectorPair p1;
+ HVX_VectorPair p0;
+
+ if (n >= 128) { r0 = Q6_Vw_vrmpy_VbVb(x.v[0], y.v[0]); }
+ if (n >= 256) { r1 = Q6_Vw_vrmpy_VbVb(x.v[1], y.v[1]); }
+ if (n >= 384) { r2 = Q6_Vw_vrmpy_VbVb(x.v[2], y.v[2]); }
+ if (n >= 512) { r3 = Q6_Vw_vrmpy_VbVb(x.v[3], y.v[3]); }
+ if (n >= 640) { r4 = Q6_Vw_vrmpy_VbVb(x.v[4], y.v[4]); }
+ if (n >= 768) { r5 = Q6_Vw_vrmpy_VbVb(x.v[5], y.v[5]); }
+ if (n >= 896) { r6 = Q6_Vw_vrmpy_VbVb(x.v[6], y.v[6]); }
+ if (n >= 1024) { r7 = Q6_Vw_vrmpy_VbVb(x.v[7], y.v[7]); }
+
+ if (n >= 128) { p0 = Q6_W_vdeal_VVR(r1, r0, -4); }
+ if (n >= 384) { p1 = Q6_W_vdeal_VVR(r3, r2, -4); }
+ if (n >= 640) { p2 = Q6_W_vdeal_VVR(r5, r4, -4); }
+ if (n >= 896) { p3 = Q6_W_vdeal_VVR(r7, r6, -4); }
+
+ if (n >= 128) { r0 = Q6_Vw_vadd_VwVw(Q6_V_lo_W(p0), Q6_V_hi_W(p0)); }
+ if (n >= 384) { r1 = Q6_Vw_vadd_VwVw(Q6_V_lo_W(p1), Q6_V_hi_W(p1)); }
+ if (n >= 640) { r2 = Q6_Vw_vadd_VwVw(Q6_V_lo_W(p2), Q6_V_hi_W(p2)); }
+ if (n >= 896) { r3 = Q6_Vw_vadd_VwVw(Q6_V_lo_W(p3), Q6_V_hi_W(p3)); }
+
+ if (n >= 128) { p0 = Q6_W_vdeal_VVR(r1, r0, -4); }
+ if (n >= 640) { p1 = Q6_W_vdeal_VVR(r3, r2, -4); }
+
+ if (n >= 128) { r0 = Q6_Vw_vadd_VwVw(Q6_V_lo_W(p0), Q6_V_hi_W(p0)); }
+ if (n >= 640) { r1 = Q6_Vw_vadd_VwVw(Q6_V_lo_W(p1), Q6_V_hi_W(p1)); }
+
+ if (n >= 128) { p0 = Q6_W_vdeal_VVR(r1, r0, -4); }
+ if (n >= 128) { r0 = Q6_Vw_vadd_VwVw(Q6_V_lo_W(p0), Q6_V_hi_W(p0)); }
+
+ return r0;
+}
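+
+// Note on the reduction above: each Q6_Vw_vrmpy_VbVb produces 32 int32 lanes (one per 4-byte
+// group), so after the rmpy stage every 32-element q4/q8 block is spread across 8 adjacent
+// lanes of one vector. Each Q6_W_vdeal_VVR(.., -4) + vadd round folds a pair of vectors while
+// regrouping lanes, so after three rounds (8 -> 4 -> 2 -> 1 vectors) the surviving vector
+// holds one int32 accumulator per 32-element block.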
+
+static inline HVX_Vector hvx_vec_rmpy_x8_full(HVX_Vector_x8 x, HVX_Vector_x8 y) {
+ return hvx_vec_rmpy_x8_n(x, y, 1024);
+}
+
+// Handle the most common cases where the row length is not a multiple of 1024.
+static inline HVX_Vector hvx_vec_rmpy_x8_nloe(HVX_Vector_x8 x, HVX_Vector_x8 y, unsigned int n) {
+ if (n <= 256) { return hvx_vec_rmpy_x8_n(x, y, 256); };
+ if (n <= 512) { return hvx_vec_rmpy_x8_n(x, y, 512); };
+ if (n <= 768) { return hvx_vec_rmpy_x8_n(x, y, 768); };
+ return hvx_vec_rmpy_x8_n(x, y, 1024);
+}
+
+static void vec_dot_q4x4x2_q8x4x2_1x1(const int n, float * restrict s0, const void * restrict vx0, const void * restrict vy0) {
+ assert(n % 32 == 0); // min sub-block size
+ assert((unsigned long) vx0 % 128 == 0);
+ assert((unsigned long) vy0 % 128 == 0);
+
+ const uint32_t qk = QK_Q4_0x4x2 * 4;
+
+ const uint32_t x_dblk_size = 8 * 4 * 2; // 32x __fp16
+ const uint32_t x_qblk_size = qk / 2; // int4
+ const uint32_t x_qrow_size = n / 2; // int4 (not padded)
+
+ const uint32_t y_dblk_size = 8 * 4 * 2; // 32x __fp16
+ const uint32_t y_qblk_size = qk; // int8
+ const uint32_t y_qrow_size = n; // int8 (not padded)
+
+ const uint8_t * restrict r0_x_q = ((const uint8_t *) vx0 + 0); // quants first
+ const uint8_t * restrict r0_x_d = ((const uint8_t *) vx0 + x_qrow_size); // then scales
+
+ const uint8_t * restrict y_q = ((const uint8_t *) vy0 + 0); // quants first
+ const uint8_t * restrict y_d = ((const uint8_t *) vy0 + y_qrow_size); // then scales
+
+ // Row sum (sf)
+ HVX_Vector r0_sum = Q6_V_vsplat_R(0);
+
+ // Multiply and accumulate into int32.
+ // Compute combined scale (fp32).
+ // Apply scale to acc and accumulate into the row sum (qf32).
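+    //
+    // Per 32-element sub-block b this is the usual q4_0 x q8_0 dot product:
+    //   acc_b  = sum_{j=0..31} (q4[b][j] - 8) * q8[b][j]        (int32, via vrmpy)
+    //   sum   += (float) acc_b * d_x[b] * d_y[b]                (d_x, d_y: fp16 scales)
+    // with 32 such sub-blocks (one 1024-element super-block) handled per loop iteration.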
+
+ const uint32_t nb = n / qk; // num full blocks
+    const uint32_t nloe = n % qk;     // num leftover elements
+
+ uint32_t i = 0;
+ for (; i < nb; i++) {
+ HVX_Vector_x8 vy_q = hvx_vec_load_q8x4x8(y_q + i * y_qblk_size);
+ HVX_Vector_x8 r0_q = hvx_vec_load_q4x4x8(r0_x_q + i * x_qblk_size);
+
+ HVX_Vector r0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r0_q, vy_q));
+
+ HVX_Vector vy_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y_d + i * y_dblk_size));
+ HVX_Vector r0_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r0_x_d + i * x_dblk_size));
+
+ HVX_Vector r0_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_d, vy_d)));
+
+ HVX_Vector r0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_ia, r0_dd);
+
+ r0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_fa, r0_sum));
+ }
+
+    // Process leftovers: we still load the full 4x4x2 block but zero out unused scales/blocks
+ if (nloe) {
+ HVX_Vector_x8 vy_q = hvx_vec_load_q8x4x8(y_q + i * y_qblk_size);
+ HVX_Vector_x8 r0_q = hvx_vec_load_q4x4x8(r0_x_q + i * x_qblk_size);
+
+ HVX_Vector r0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_nloe(r0_q, vy_q, nloe));
+
+ HVX_Vector vy_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y_d + i * y_dblk_size));
+ HVX_Vector r0_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r0_x_d + i * x_dblk_size));
+
+ HVX_Vector r0_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_d, vy_d)));
+
+ // Zero out unused scales
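+        // (Q6_Q_vsetq_R(nloe / 8) enables the first nloe/8 bytes: both the int32 accumulators
+        // and the fp32 combined scales hold one 4-byte lane per 32-element sub-block, so the
+        // valid prefix is (nloe / 32) * 4 == nloe / 8 bytes.)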
+ HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe / 8);
+ r0_dd = Q6_V_vand_QV(bmask, r0_dd);
+ r0_ia = Q6_V_vand_QV(bmask, r0_ia);
+
+ HVX_Vector r0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_ia, r0_dd);
+
+ r0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_fa, r0_sum));
+ }
+
+ r0_sum = hvx_vec_reduce_sum_f32(r0_sum);
+
+ hvx_vec_store_u(s0, 4, r0_sum);
+}
+
+static void vec_dot_q4x4x2_q8x4x2_2x1(const int n, float * restrict s0,
+ const void * restrict vx0, const void * restrict vx1,
+ const void * restrict vy0) {
+ assert(n % 32 == 0); // min sub-block size
+ assert((unsigned long) vx0 % 128 == 0);
+ assert((unsigned long) vx1 % 128 == 0);
+ assert((unsigned long) vy0 % 128 == 0);
+
+ const uint32_t qk = QK_Q4_0x4x2 * 4;
+
+ const uint32_t x_dblk_size = 8 * 4 * 2; // 32x __fp16
+ const uint32_t x_qblk_size = qk / 2; // int4
+ const uint32_t x_qrow_size = n / 2; // int4 (not padded)
+
+ const uint32_t y_dblk_size = 8 * 4 * 2; // 32x __fp16
+ const uint32_t y_qblk_size = qk; // int8
+ const uint32_t y_qrow_size = n; // int8 (not padded)
+
+ const uint8_t * restrict r0_x_q = ((const uint8_t *) vx0) + 0; // quants first
+ const uint8_t * restrict r0_x_d = ((const uint8_t *) vx0) + x_qrow_size; // then scales
+ const uint8_t * restrict r1_x_q = ((const uint8_t *) vx1) + 0; // quants first
+ const uint8_t * restrict r1_x_d = ((const uint8_t *) vx1) + x_qrow_size; // then scales
+
+ const uint8_t * restrict y_q = ((const uint8_t *) vy0 + 0); // quants first
+ const uint8_t * restrict y_d = ((const uint8_t *) vy0 + y_qrow_size); // then scales
+
+ // Row sum (sf)
+ HVX_Vector r0_sum = Q6_V_vsplat_R(0);
+ HVX_Vector r1_sum = Q6_V_vsplat_R(0);
+
+ // Multiply and accumulate into int32.
+ // Compute combined scale (fp32).
+ // Apply scale to acc and accumulate into the row sum (qf32).
+
+ const uint32_t nb = n / qk; // num full blocks
+    const uint32_t nloe = n % qk;     // num leftover elements
+
+ uint32_t i = 0;
+ for (; i < nb; i++) {
+ HVX_Vector_x8 vy_q = hvx_vec_load_q8x4x8(y_q + i * y_qblk_size);
+ HVX_Vector_x8 r0_q = hvx_vec_load_q4x4x8(r0_x_q + i * x_qblk_size);
+ HVX_Vector_x8 r1_q = hvx_vec_load_q4x4x8(r1_x_q + i * x_qblk_size);
+
+ HVX_Vector r0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r0_q, vy_q));
+ HVX_Vector r1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r1_q, vy_q));
+
+ HVX_Vector vy_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y_d + i * y_dblk_size));
+ HVX_Vector r0_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r0_x_d + i * x_dblk_size));
+ HVX_Vector r1_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r1_x_d + i * x_dblk_size));
+
+ HVX_Vector r0_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_d, vy_d)));
+ HVX_Vector r1_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r1_d, vy_d)));
+
+ HVX_Vector r0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_ia, r0_dd);
+ HVX_Vector r1_fa = Q6_Vqf32_vmpy_VsfVsf(r1_ia, r1_dd);
+
+ r0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_fa, r0_sum));
+ r1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_fa, r1_sum));
+ }
+
+    // Process leftovers: we still load the full 4x4x2 block but zero out unused scales/blocks
+ if (nloe) {
+ HVX_Vector_x8 vy_q = hvx_vec_load_q8x4x8(y_q + i * y_qblk_size);
+ HVX_Vector_x8 r0_q = hvx_vec_load_q4x4x8(r0_x_q + i * x_qblk_size);
+ HVX_Vector_x8 r1_q = hvx_vec_load_q4x4x8(r1_x_q + i * x_qblk_size);
+
+ HVX_Vector r0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_nloe(r0_q, vy_q, nloe));
+ HVX_Vector r1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_nloe(r1_q, vy_q, nloe));
+
+ HVX_Vector vy_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y_d + i * y_dblk_size));
+ HVX_Vector r0_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r0_x_d + i * x_dblk_size));
+ HVX_Vector r1_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r1_x_d + i * x_dblk_size));
+
+ HVX_Vector r0_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_d, vy_d)));
+ HVX_Vector r1_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r1_d, vy_d)));
+
+ // Zero out unused scales
+ HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe / 8);
+ r0_dd = Q6_V_vand_QV(bmask, r0_dd);
+ r1_dd = Q6_V_vand_QV(bmask, r1_dd);
+ r0_ia = Q6_V_vand_QV(bmask, r0_ia);
+ r1_ia = Q6_V_vand_QV(bmask, r1_ia);
+
+ HVX_Vector r0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_ia, r0_dd);
+ HVX_Vector r1_fa = Q6_Vqf32_vmpy_VsfVsf(r1_ia, r1_dd);
+
+ r0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_fa, r0_sum));
+ r1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_fa, r1_sum));
+ }
+
+ HVX_Vector rsum = hvx_vec_reduce_sum_f32x2(r0_sum, r1_sum);
+ hvx_vec_store_u(s0, 8, rsum);
+}
+
+static void vec_dot_q4x4x2_q8x4x2_2x2(const int n, float * restrict s0, float * restrict s1,
+ const void * restrict vx0, const void * restrict vx1,
+ const void * restrict vy0, const void * restrict vy1) {
+ assert(n % 32 == 0);
+ assert((unsigned long) vx0 % 128 == 0);
+ assert((unsigned long) vx1 % 128 == 0);
+ assert((unsigned long) vy0 % 128 == 0);
+ assert((unsigned long) vy1 % 128 == 0);
+
+ const uint32_t qk = QK_Q4_0x4x2 * 4;
+
+ const uint32_t x_dblk_size = 8 * 4 * 2; // 32x __fp16
+ const uint32_t x_qblk_size = qk / 2; // int4
+ const uint32_t x_qrow_size = n / 2; // int4 (not padded)
+
+ const uint32_t y_dblk_size = 8 * 4 * 2; // 32x __fp16
+ const uint32_t y_qblk_size = qk; // int8
+ const uint32_t y_qrow_size = n; // int8 (not padded)
+
+ const uint8_t * restrict r0_x_q = ((const uint8_t *) vx0) + 0; // quants first
+ const uint8_t * restrict r0_x_d = ((const uint8_t *) vx0) + x_qrow_size; // then scales
+ const uint8_t * restrict r1_x_q = ((const uint8_t *) vx1) + 0; // quants first
+ const uint8_t * restrict r1_x_d = ((const uint8_t *) vx1) + x_qrow_size; // then scales
+
+ const uint8_t * restrict y0_q = ((const uint8_t *) vy0) + 0; // quants first
+ const uint8_t * restrict y0_d = ((const uint8_t *) vy0) + y_qrow_size; // then scales
+ const uint8_t * restrict y1_q = ((const uint8_t *) vy1) + 0; // quants first
+ const uint8_t * restrict y1_d = ((const uint8_t *) vy1) + y_qrow_size; // then scales
+
+ // Row sums (sf) - 4 accumulators for 2×2 tile
+ HVX_Vector r0_c0_sum = Q6_V_vsplat_R(0);
+ HVX_Vector r0_c1_sum = Q6_V_vsplat_R(0);
+ HVX_Vector r1_c0_sum = Q6_V_vsplat_R(0);
+ HVX_Vector r1_c1_sum = Q6_V_vsplat_R(0);
+
+ const uint32_t nb = n / qk; // num full blocks
+ const uint32_t nloe = n % qk; // num leftover elements
+
+ uint32_t i = 0;
+ for (; i < nb; i++) {
+ // Load src1 columns (reused across both src0 rows)
+ HVX_Vector_x8 vy0_q = hvx_vec_load_q8x4x8(y0_q + i * y_qblk_size);
+ HVX_Vector_x8 vy1_q = hvx_vec_load_q8x4x8(y1_q + i * y_qblk_size);
+
+ // Load src0 rows (reused across both src1 columns)
+ HVX_Vector_x8 r0_q = hvx_vec_load_q4x4x8(r0_x_q + i * x_qblk_size);
+ HVX_Vector_x8 r1_q = hvx_vec_load_q4x4x8(r1_x_q + i * x_qblk_size);
+
+ // Compute 4 dot products: r0×c0, r0×c1, r1×c0, r1×c1
+ HVX_Vector r0_c0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r0_q, vy0_q));
+ HVX_Vector r0_c1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r0_q, vy1_q));
+ HVX_Vector r1_c0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r1_q, vy0_q));
+ HVX_Vector r1_c1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r1_q, vy1_q));
+
+ // Load scales
+ HVX_Vector vy0_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y0_d + i * y_dblk_size));
+ HVX_Vector vy1_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y1_d + i * y_dblk_size));
+ HVX_Vector r0_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r0_x_d + i * x_dblk_size));
+ HVX_Vector r1_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r1_x_d + i * x_dblk_size));
+
+ // Compute combined scales
+ HVX_Vector r0_c0_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_d, vy0_d)));
+ HVX_Vector r0_c1_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_d, vy1_d)));
+ HVX_Vector r1_c0_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r1_d, vy0_d)));
+ HVX_Vector r1_c1_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r1_d, vy1_d)));
+
+ // Apply scales and accumulate
+ HVX_Vector r0_c0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_c0_ia, r0_c0_dd);
+ HVX_Vector r0_c1_fa = Q6_Vqf32_vmpy_VsfVsf(r0_c1_ia, r0_c1_dd);
+ HVX_Vector r1_c0_fa = Q6_Vqf32_vmpy_VsfVsf(r1_c0_ia, r1_c0_dd);
+ HVX_Vector r1_c1_fa = Q6_Vqf32_vmpy_VsfVsf(r1_c1_ia, r1_c1_dd);
+
+ r0_c0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_c0_fa, r0_c0_sum));
+ r0_c1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_c1_fa, r0_c1_sum));
+ r1_c0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_c0_fa, r1_c0_sum));
+ r1_c1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_c1_fa, r1_c1_sum));
+ }
+
+ // Process leftovers
+ if (nloe) {
+ HVX_Vector_x8 vy0_q = hvx_vec_load_q8x4x8(y0_q + i * y_qblk_size);
+ HVX_Vector_x8 vy1_q = hvx_vec_load_q8x4x8(y1_q + i * y_qblk_size);
+ HVX_Vector_x8 r0_q = hvx_vec_load_q4x4x8(r0_x_q + i * x_qblk_size);
+ HVX_Vector_x8 r1_q = hvx_vec_load_q4x4x8(r1_x_q + i * x_qblk_size);
+
+ HVX_Vector r0_c0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_nloe(r0_q, vy0_q, nloe));
+ HVX_Vector r0_c1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_nloe(r0_q, vy1_q, nloe));
+ HVX_Vector r1_c0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_nloe(r1_q, vy0_q, nloe));
+ HVX_Vector r1_c1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_nloe(r1_q, vy1_q, nloe));
+
+ HVX_Vector vy0_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y0_d + i * y_dblk_size));
+ HVX_Vector vy1_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y1_d + i * y_dblk_size));
+ HVX_Vector r0_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r0_x_d + i * x_dblk_size));
+ HVX_Vector r1_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r1_x_d + i * x_dblk_size));
+
+ HVX_Vector r0_c0_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_d, vy0_d)));
+ HVX_Vector r0_c1_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_d, vy1_d)));
+ HVX_Vector r1_c0_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r1_d, vy0_d)));
+ HVX_Vector r1_c1_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r1_d, vy1_d)));
+
+ // Zero out unused scales
+ HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe / 8);
+ r0_c0_dd = Q6_V_vand_QV(bmask, r0_c0_dd);
+ r0_c1_dd = Q6_V_vand_QV(bmask, r0_c1_dd);
+ r1_c0_dd = Q6_V_vand_QV(bmask, r1_c0_dd);
+ r1_c1_dd = Q6_V_vand_QV(bmask, r1_c1_dd);
+ r0_c0_ia = Q6_V_vand_QV(bmask, r0_c0_ia);
+ r0_c1_ia = Q6_V_vand_QV(bmask, r0_c1_ia);
+ r1_c0_ia = Q6_V_vand_QV(bmask, r1_c0_ia);
+ r1_c1_ia = Q6_V_vand_QV(bmask, r1_c1_ia);
+
+ HVX_Vector r0_c0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_c0_ia, r0_c0_dd);
+ HVX_Vector r0_c1_fa = Q6_Vqf32_vmpy_VsfVsf(r0_c1_ia, r0_c1_dd);
+ HVX_Vector r1_c0_fa = Q6_Vqf32_vmpy_VsfVsf(r1_c0_ia, r1_c0_dd);
+ HVX_Vector r1_c1_fa = Q6_Vqf32_vmpy_VsfVsf(r1_c1_ia, r1_c1_dd);
+
+ r0_c0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_c0_fa, r0_c0_sum));
+ r0_c1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_c1_fa, r0_c1_sum));
+ r1_c0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_c0_fa, r1_c0_sum));
+ r1_c1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_c1_fa, r1_c1_sum));
+ }
+
+ // Reduce and store results
+ HVX_Vector r0_r1_c0_sum = hvx_vec_reduce_sum_f32x2(r0_c0_sum, r1_c0_sum);
+ HVX_Vector r0_r1_c1_sum = hvx_vec_reduce_sum_f32x2(r0_c1_sum, r1_c1_sum);
+
+ hvx_vec_store_u(s0, 8, r0_r1_c0_sum); // row0,col0 row1,col0
+ hvx_vec_store_u(s1, 8, r0_r1_c1_sum); // row0,col1 row1,col1
+}
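+
+// The 2x2 kernel above loads each src0 row and each src1 column once per super-block but
+// produces four dot products from them, roughly halving the load traffic per output compared
+// to four separate 1x1 calls.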
+
+static void vec_dot_q8x4x2_q8x4x2_1x1(const int n, float * restrict s0, const void * restrict vx0, const void * restrict vy0) {
+ assert(n % 32 == 0); // min sub-block size
+ assert((unsigned long) vx0 % 128 == 0);
+ assert((unsigned long) vy0 % 128 == 0);
+
+    const uint32_t qk = QK_Q8_0x4x2 * 4;
+
+ const uint32_t x_dblk_size = 8 * 4 * 2; // 32x __fp16
+ const uint32_t x_qblk_size = qk; // int8
+ const uint32_t x_qrow_size = n; // int8 (not padded)
+
+ const uint32_t y_dblk_size = 8 * 4 * 2; // 32x __fp16
+ const uint32_t y_qblk_size = qk; // int8
+ const uint32_t y_qrow_size = n; // int8 (not padded)
+
+ const uint8_t * restrict r0_x_q = ((const uint8_t *) vx0 + 0); // quants first
+ const uint8_t * restrict r0_x_d = ((const uint8_t *) vx0 + x_qrow_size); // then scales
+
+ const uint8_t * restrict y_q = ((const uint8_t *) vy0 + 0); // quants first
+ const uint8_t * restrict y_d = ((const uint8_t *) vy0 + y_qrow_size); // then scales
+
+ // Row sum (sf)
+ HVX_Vector r0_sum = Q6_V_vsplat_R(0);
+
+ // Multiply and accumulate into int32.
+ // Compute combined scale (fp32).
+ // Apply scale to acc and accumulate into the row sum (qf32).
+
+ const uint32_t nb = n / qk; // num full blocks
+    int32_t nloe = n % qk;            // num leftover elements (must be signed)
+
+ uint32_t i = 0;
+ for (; i < nb; i++) {
+ HVX_Vector_x8 vy_q = hvx_vec_load_q8x4x8(y_q + i * y_qblk_size);
+ HVX_Vector_x8 r0_q = hvx_vec_load_q8x4x8(r0_x_q + i * x_qblk_size);
+
+ HVX_Vector r0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r0_q, vy_q));
+
+ HVX_Vector vy_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y_d + i * y_dblk_size));
+ HVX_Vector r0_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r0_x_d + i * x_dblk_size));
+
+ HVX_Vector r0_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_d, vy_d)));
+
+ HVX_Vector r0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_ia, r0_dd);
+
+ r0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_fa, r0_sum));
+ }
+
+    // Process leftovers: we still load the full 4x4x2 block but zero out unused scales/blocks
+ if (nloe) {
+ HVX_Vector_x8 vy_q = hvx_vec_load_q8x4x8(y_q + i * y_qblk_size);
+ HVX_Vector_x8 r0_q = hvx_vec_load_q8x4x8(r0_x_q + i * x_qblk_size);
+
+ HVX_Vector r0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_nloe(r0_q, vy_q, nloe));
+
+ HVX_Vector vy_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y_d + i * y_dblk_size));
+ HVX_Vector r0_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r0_x_d + i * x_dblk_size));
+
+ HVX_Vector r0_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_d, vy_d)));
+
+ // Zero out unused scales
+ HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe / 8);
+ r0_dd = Q6_V_vand_QV(bmask, r0_dd);
+ r0_ia = Q6_V_vand_QV(bmask, r0_ia);
+
+ HVX_Vector r0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_ia, r0_dd);
+
+ r0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_fa, r0_sum));
+ }
+
+ r0_sum = hvx_vec_reduce_sum_f32(r0_sum);
+
+ hvx_vec_store_u(s0, 4, r0_sum);
+}
+
+static void vec_dot_q8x4x2_q8x4x2_2x1(const int n, float * restrict s0,
+ const void * restrict vx0, const void * restrict vx1,
+ const void * restrict vy0) {
+ assert(n % 32 == 0); // min sub-block size
+ assert((unsigned long) vx0 % 128 == 0);
+ assert((unsigned long) vx1 % 128 == 0);
+ assert((unsigned long) vy0 % 128 == 0);
+
+    const uint32_t qk = QK_Q8_0x4x2 * 4;
+
+ const uint32_t x_dblk_size = 8 * 4 * 2; // 32x __fp16
+ const uint32_t x_qblk_size = qk; // int8
+ const uint32_t x_qrow_size = n; // int8 (not padded)
+
+ const uint32_t y_dblk_size = 8 * 4 * 2; // 32x __fp16
+ const uint32_t y_qblk_size = qk; // int8
+ const uint32_t y_qrow_size = n; // int8 (not padded)
+
+ const uint8_t * restrict r0_x_q = ((const uint8_t *) vx0) + 0; // quants first
+ const uint8_t * restrict r0_x_d = ((const uint8_t *) vx0) + x_qrow_size; // then scales
+ const uint8_t * restrict r1_x_q = ((const uint8_t *) vx1) + 0; // quants first
+ const uint8_t * restrict r1_x_d = ((const uint8_t *) vx1) + x_qrow_size; // then scales
+
+ const uint8_t * restrict y_q = ((const uint8_t *) vy0 + 0); // quants first
+ const uint8_t * restrict y_d = ((const uint8_t *) vy0 + y_qrow_size); // then scales
+
+    // Row sum (sf)
+ HVX_Vector r0_sum = Q6_V_vsplat_R(0);
+ HVX_Vector r1_sum = Q6_V_vsplat_R(0);
+
+ // Multiply and accumulate into int32.
+ // Compute combined scale (fp32).
+ // Apply scale to acc and accumulate into the row sum (qf32).
+
+ const uint32_t nb = n / qk; // num full blocks
+    int32_t nloe = n % qk;            // num leftover elements (must be signed)
+
+ uint32_t i = 0;
+ for (; i < nb; i++) {
+ HVX_Vector_x8 vy_q = hvx_vec_load_q8x4x8(y_q + i * y_qblk_size);
+ HVX_Vector_x8 r0_q = hvx_vec_load_q8x4x8(r0_x_q + i * x_qblk_size);
+ HVX_Vector_x8 r1_q = hvx_vec_load_q8x4x8(r1_x_q + i * x_qblk_size);
+
+ HVX_Vector r0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r0_q, vy_q));
+ HVX_Vector r1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r1_q, vy_q));
+
+ HVX_Vector vy_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y_d + i * y_dblk_size));
+ HVX_Vector r0_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r0_x_d + i * x_dblk_size));
+ HVX_Vector r1_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r1_x_d + i * x_dblk_size));
+
+ HVX_Vector r0_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_d, vy_d)));
+ HVX_Vector r1_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r1_d, vy_d)));
+
+ HVX_Vector r0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_ia, r0_dd);
+ HVX_Vector r1_fa = Q6_Vqf32_vmpy_VsfVsf(r1_ia, r1_dd);
+
+ r0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_fa, r0_sum));
+ r1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_fa, r1_sum));
+ }
+
+    // Process leftovers: we still load the full 4x4x2 block but zero out unused scales/blocks
+ if (nloe) {
+ HVX_Vector_x8 vy_q = hvx_vec_load_q8x4x8(y_q + i * y_qblk_size);
+ HVX_Vector_x8 r0_q = hvx_vec_load_q8x4x8(r0_x_q + i * x_qblk_size);
+ HVX_Vector_x8 r1_q = hvx_vec_load_q8x4x8(r1_x_q + i * x_qblk_size);
+
+ HVX_Vector r0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_nloe(r0_q, vy_q, nloe));
+ HVX_Vector r1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_nloe(r1_q, vy_q, nloe));
+
+ HVX_Vector vy_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y_d + i * y_dblk_size));
+ HVX_Vector r0_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r0_x_d + i * x_dblk_size));
+ HVX_Vector r1_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r1_x_d + i * x_dblk_size));
+
+ HVX_Vector r0_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_d, vy_d)));
+ HVX_Vector r1_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r1_d, vy_d)));
+
+ // Zero out unused scales
+ HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe / 8);
+ r0_dd = Q6_V_vand_QV(bmask, r0_dd);
+ r1_dd = Q6_V_vand_QV(bmask, r1_dd);
+ r0_ia = Q6_V_vand_QV(bmask, r0_ia);
+ r1_ia = Q6_V_vand_QV(bmask, r1_ia);
+
+ HVX_Vector r0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_ia, r0_dd);
+ HVX_Vector r1_fa = Q6_Vqf32_vmpy_VsfVsf(r1_ia, r1_dd);
+
+ r0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_fa, r0_sum));
+ r1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_fa, r1_sum));
+ }
+
+ HVX_Vector rsum = hvx_vec_reduce_sum_f32x2(r0_sum, r1_sum);
+ hvx_vec_store_u(s0, 8, rsum);
+}
+
+static void vec_dot_q8x4x2_q8x4x2_2x2(const int n, float * restrict s0, float * restrict s1,
+ const void * restrict vx0, const void * restrict vx1,
+ const void * restrict vy0, const void * restrict vy1) {
+ assert(n % 32 == 0);
+ assert((unsigned long) vx0 % 128 == 0);
+ assert((unsigned long) vx1 % 128 == 0);
+ assert((unsigned long) vy0 % 128 == 0);
+ assert((unsigned long) vy1 % 128 == 0);
+
+ const uint32_t qk = QK_Q8_0x4x2 * 4;
+
+ const uint32_t x_dblk_size = 8 * 4 * 2; // 32x __fp16
+ const uint32_t x_qblk_size = qk; // int8
+ const uint32_t x_qrow_size = n; // int8 (not padded)
+
+ const uint32_t y_dblk_size = 8 * 4 * 2; // 32x __fp16
+ const uint32_t y_qblk_size = qk; // int8
+ const uint32_t y_qrow_size = n; // int8 (not padded)
+
+ const uint8_t * restrict r0_x_q = ((const uint8_t *) vx0) + 0; // quants first
+ const uint8_t * restrict r0_x_d = ((const uint8_t *) vx0) + x_qrow_size; // then scales
+ const uint8_t * restrict r1_x_q = ((const uint8_t *) vx1) + 0; // quants first
+ const uint8_t * restrict r1_x_d = ((const uint8_t *) vx1) + x_qrow_size; // then scales
+
+ const uint8_t * restrict y0_q = ((const uint8_t *) vy0) + 0; // quants first
+ const uint8_t * restrict y0_d = ((const uint8_t *) vy0) + y_qrow_size; // then scales
+ const uint8_t * restrict y1_q = ((const uint8_t *) vy1) + 0; // quants first
+ const uint8_t * restrict y1_d = ((const uint8_t *) vy1) + y_qrow_size; // then scales
+
+ // Row sums (sf) - 4 accumulators for 2×2 tile
+ HVX_Vector r0_c0_sum = Q6_V_vsplat_R(0);
+ HVX_Vector r0_c1_sum = Q6_V_vsplat_R(0);
+ HVX_Vector r1_c0_sum = Q6_V_vsplat_R(0);
+ HVX_Vector r1_c1_sum = Q6_V_vsplat_R(0);
+
+ const uint32_t nb = n / qk; // num full blocks
+ const uint32_t nloe = n % qk; // num leftover elements
+
+ uint32_t i = 0;
+ for (; i < nb; i++) {
+ // Load src1 columns (reused across both src0 rows)
+ HVX_Vector_x8 vy0_q = hvx_vec_load_q8x4x8(y0_q + i * y_qblk_size);
+ HVX_Vector_x8 vy1_q = hvx_vec_load_q8x4x8(y1_q + i * y_qblk_size);
+
+ // Load src0 rows (reused across both src1 columns)
+ HVX_Vector_x8 r0_q = hvx_vec_load_q8x4x8(r0_x_q + i * x_qblk_size);
+ HVX_Vector_x8 r1_q = hvx_vec_load_q8x4x8(r1_x_q + i * x_qblk_size);
+
+ // Compute 4 dot products: r0×c0, r0×c1, r1×c0, r1×c1
+ HVX_Vector r0_c0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r0_q, vy0_q));
+ HVX_Vector r0_c1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r0_q, vy1_q));
+ HVX_Vector r1_c0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r1_q, vy0_q));
+ HVX_Vector r1_c1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r1_q, vy1_q));
+
+ // Load scales
+ HVX_Vector vy0_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y0_d + i * y_dblk_size));
+ HVX_Vector vy1_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y1_d + i * y_dblk_size));
+ HVX_Vector r0_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r0_x_d + i * x_dblk_size));
+ HVX_Vector r1_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r1_x_d + i * x_dblk_size));
+
+ // Compute combined scales
+ HVX_Vector r0_c0_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_d, vy0_d)));
+ HVX_Vector r0_c1_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_d, vy1_d)));
+ HVX_Vector r1_c0_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r1_d, vy0_d)));
+ HVX_Vector r1_c1_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r1_d, vy1_d)));
+
+ // Apply scales and accumulate
+ HVX_Vector r0_c0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_c0_ia, r0_c0_dd);
+ HVX_Vector r0_c1_fa = Q6_Vqf32_vmpy_VsfVsf(r0_c1_ia, r0_c1_dd);
+ HVX_Vector r1_c0_fa = Q6_Vqf32_vmpy_VsfVsf(r1_c0_ia, r1_c0_dd);
+ HVX_Vector r1_c1_fa = Q6_Vqf32_vmpy_VsfVsf(r1_c1_ia, r1_c1_dd);
+
+ r0_c0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_c0_fa, r0_c0_sum));
+ r0_c1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_c1_fa, r0_c1_sum));
+ r1_c0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_c0_fa, r1_c0_sum));
+ r1_c1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_c1_fa, r1_c1_sum));
+ }
+
+ // Process leftovers
+ if (nloe) {
+ HVX_Vector_x8 vy0_q = hvx_vec_load_q8x4x8(y0_q + i * y_qblk_size);
+ HVX_Vector_x8 vy1_q = hvx_vec_load_q8x4x8(y1_q + i * y_qblk_size);
+ HVX_Vector_x8 r0_q = hvx_vec_load_q8x4x8(r0_x_q + i * x_qblk_size);
+ HVX_Vector_x8 r1_q = hvx_vec_load_q8x4x8(r1_x_q + i * x_qblk_size);
+
+ HVX_Vector r0_c0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_nloe(r0_q, vy0_q, nloe));
+ HVX_Vector r0_c1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_nloe(r0_q, vy1_q, nloe));
+ HVX_Vector r1_c0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_nloe(r1_q, vy0_q, nloe));
+ HVX_Vector r1_c1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_nloe(r1_q, vy1_q, nloe));
+
+ HVX_Vector vy0_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y0_d + i * y_dblk_size));
+ HVX_Vector vy1_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y1_d + i * y_dblk_size));
+ HVX_Vector r0_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r0_x_d + i * x_dblk_size));
+ HVX_Vector r1_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r1_x_d + i * x_dblk_size));
+
+ HVX_Vector r0_c0_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_d, vy0_d)));
+ HVX_Vector r0_c1_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_d, vy1_d)));
+ HVX_Vector r1_c0_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r1_d, vy0_d)));
+ HVX_Vector r1_c1_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r1_d, vy1_d)));
+
+ // Zero out unused scales
+ HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe / 8);
+ r0_c0_dd = Q6_V_vand_QV(bmask, r0_c0_dd);
+ r0_c1_dd = Q6_V_vand_QV(bmask, r0_c1_dd);
+ r1_c0_dd = Q6_V_vand_QV(bmask, r1_c0_dd);
+ r1_c1_dd = Q6_V_vand_QV(bmask, r1_c1_dd);
+ r0_c0_ia = Q6_V_vand_QV(bmask, r0_c0_ia);
+ r0_c1_ia = Q6_V_vand_QV(bmask, r0_c1_ia);
+ r1_c0_ia = Q6_V_vand_QV(bmask, r1_c0_ia);
+ r1_c1_ia = Q6_V_vand_QV(bmask, r1_c1_ia);
+
+ HVX_Vector r0_c0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_c0_ia, r0_c0_dd);
+ HVX_Vector r0_c1_fa = Q6_Vqf32_vmpy_VsfVsf(r0_c1_ia, r0_c1_dd);
+ HVX_Vector r1_c0_fa = Q6_Vqf32_vmpy_VsfVsf(r1_c0_ia, r1_c0_dd);
+ HVX_Vector r1_c1_fa = Q6_Vqf32_vmpy_VsfVsf(r1_c1_ia, r1_c1_dd);
+
+ r0_c0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_c0_fa, r0_c0_sum));
+ r0_c1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_c1_fa, r0_c1_sum));
+ r1_c0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_c0_fa, r1_c0_sum));
+ r1_c1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_c1_fa, r1_c1_sum));
+ }
+
+ // Reduce and store results
+ HVX_Vector r0_r1_c0_sum = hvx_vec_reduce_sum_f32x2(r0_c0_sum, r1_c0_sum);
+ HVX_Vector r0_r1_c1_sum = hvx_vec_reduce_sum_f32x2(r0_c1_sum, r1_c1_sum);
+
+ hvx_vec_store_u(&s0[0], 8, r0_r1_c0_sum); // row0,col0 row1,col0
+ hvx_vec_store_u(&s1[0], 8, r0_r1_c1_sum); // row0,col1 row1,col1
+}
+
+static void vec_dot_mxfp4x4x2_q8x4x2_1x1(const int n, float * restrict s0, const void * restrict vx0, const void * restrict vy0) {
+ assert(n % 32 == 0); // min sub-block size
+ assert((unsigned long) vx0 % 128 == 0);
+ assert((unsigned long) vy0 % 128 == 0);
+
+ const uint32_t qk = QK_MXFP4x4x2 * 4;
+
+ const uint32_t x_dblk_size = 8 * 4 * 1; // 32x e8m0
+ const uint32_t x_qblk_size = qk / 2; // fp4
+ const uint32_t x_qrow_size = n / 2; // fp4 (not padded)
+
+ const uint32_t y_dblk_size = 8 * 4 * 2; // 32x __fp16
+ const uint32_t y_qblk_size = qk; // int8
+ const uint32_t y_qrow_size = n; // int8 (not padded)
+
+ const uint8_t * restrict r0_x_q = ((const uint8_t *) vx0 + 0); // quants first
+ const uint8_t * restrict r0_x_d = ((const uint8_t *) vx0 + x_qrow_size); // then scales
+
+ const uint8_t * restrict y_q = ((const uint8_t *) vy0 + 0); // quants first
+ const uint8_t * restrict y_d = ((const uint8_t *) vy0 + y_qrow_size); // then scales
+
+ // Row sum (sf)
+ HVX_Vector r0_sum = Q6_V_vsplat_R(0);
+
+ // Multiply and accumulate into int32.
+ // Compute combined scale (fp32).
+ // Apply scale to acc and accumulate into the row sum (qf32).
+
+ const uint32_t nb = n / qk; // num full blocks
+    int32_t nloe = n % qk;            // num leftover elements (must be signed)
+
+ uint32_t i = 0;
+ for (; i < nb; i++) {
+ HVX_Vector_x8 vy_q = hvx_vec_load_q8x4x8(y_q + i * y_qblk_size);
+ HVX_Vector_x8 r0_q = hvx_vec_load_mxfp4x4x8(r0_x_q + i * x_qblk_size);
+
+ HVX_Vector r0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r0_q, vy_q));
+
+ HVX_Vector vy_d = *(const HVX_UVector *) (y_d + i * y_dblk_size);
+ HVX_Vector r0_d = *(const HVX_UVector *) (r0_x_d + i * x_dblk_size);
+
+ // Convert vy_d from fp16 to fp32 while applying 0.5 scaling which is used for e8m0 halving
+ HVX_Vector half = Q6_Vh_vsplat_R(0x3800); // 0.5 in fp16
+ vy_d = Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(Q6_Vh_vshuff_Vh(vy_d), half));
+ vy_d = Q6_Vsf_equals_Vqf32(vy_d);
+
+ // Convert rX_d scales from e8m0 to fp32
+ // Expand and zero-pad 32x uint8 e8m0 values to uint32s : 0 0 0 0, 0 0 0 1, 0 0 0 2, ...
+ // Left shift with zero fill to create FP32
+ // FIXME: might need to handle zero as a special case (see ggml-cpu code)
+ HVX_Vector expand = *(const HVX_Vector *) expand_x32_e8m0;
+ HVX_Vector e8m0_mask = Q6_V_vsplat_R(0x000000ff);
+ r0_d = Q6_V_vdelta_VV(r0_d, expand);
+ r0_d = Q6_V_vand_VV(r0_d, e8m0_mask);
+ r0_d = Q6_Vw_vasl_VwR(r0_d, 23);
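+
+        // e8m0 is a bare 8-bit exponent: placing it into bits 30..23 of an IEEE-754 word
+        // yields 2^(e - 127), e.g. the byte 0x7F becomes 0x3F800000 == 1.0f; the zero byte
+        // maps to 0.0f here, hence the FIXME above.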
+
+ HVX_Vector r0_dd = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(r0_d, vy_d));
+
+ HVX_Vector r0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_ia, r0_dd);
+
+ r0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_fa, r0_sum));
+ }
+
+ // Process leftovers
+ if (nloe) {
+ HVX_Vector_x8 vy_q = hvx_vec_load_q8x4x8(y_q + i * y_qblk_size);
+ HVX_Vector_x8 r0_q = hvx_vec_load_mxfp4x4x8(r0_x_q + i * x_qblk_size);
+
+ HVX_Vector r0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r0_q, vy_q));
+
+ HVX_Vector vy_d = *(const HVX_UVector *) (y_d + i * y_dblk_size);
+ HVX_Vector r0_d = *(const HVX_UVector *) (r0_x_d + i * x_dblk_size);
+
+ // Convert vy_d from fp16 to fp32 while applying 0.5 scaling which is used for e8m0 halving
+ HVX_Vector half = Q6_Vh_vsplat_R(0x3800); // 0.5 in fp16
+ vy_d = Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(Q6_Vh_vshuff_Vh(vy_d), half));
+ vy_d = Q6_Vsf_equals_Vqf32(vy_d);
+
+ // Convert rX_d scales from e8m0 to fp32
+ // Expand and zero-pad 32x uint8 e8m0 values to uint32s : 0 0 0 0, 0 0 0 1, 0 0 0 2, ...
+ // Left shift with zero fill to create FP32
+ // FIXME: might need to handle zero as a special case (see ggml-cpu code)
+ HVX_Vector expand = *(const HVX_Vector *) expand_x32_e8m0;
+ HVX_Vector e8m0_mask = Q6_V_vsplat_R(0x000000ff);
+ r0_d = Q6_V_vdelta_VV(r0_d, expand);
+ r0_d = Q6_V_vand_VV(r0_d, e8m0_mask);
+ r0_d = Q6_Vw_vasl_VwR(r0_d, 23);
+
+ HVX_Vector r0_dd = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(r0_d, vy_d));
+
+ // Zero-out unused scales
+ HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe / 8);
+ r0_dd = Q6_V_vand_QV(bmask, r0_dd);
+ r0_ia = Q6_V_vand_QV(bmask, r0_ia);
+
+ HVX_Vector r0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_ia, r0_dd);
+
+ r0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_fa, r0_sum));
+ }
+
+ r0_sum = hvx_vec_reduce_sum_f32(r0_sum);
+
+ hvx_vec_store_u(s0, 4, r0_sum);
+}
+
+static void vec_dot_mxfp4x4x2_q8x4x2_2x1(const int n, float * restrict s0,
+ const void * restrict vx0, const void * restrict vx1,
+ const void * restrict vy0) {
+ assert(n % 32 == 0); // min sub-block size
+ assert((unsigned long) vx0 % 128 == 0);
+ assert((unsigned long) vx1 % 128 == 0);
+ assert((unsigned long) vy0 % 128 == 0);
+
+ const uint32_t qk = QK_MXFP4x4x2 * 4;
+
+ const uint32_t x_dblk_size = 8 * 4 * 1; // 32x e8m0
+ const uint32_t x_qblk_size = qk / 2; // fp4
+ const uint32_t x_qrow_size = n / 2; // fp4 (not padded)
+
+ const uint32_t y_dblk_size = 8 * 4 * 2; // 32x __fp16
+ const uint32_t y_qblk_size = qk; // int8
+ const uint32_t y_qrow_size = n; // int8 (not padded)
+
+ const uint8_t * restrict r0_x_q = ((const uint8_t *) vx0) + 0; // quants first
+ const uint8_t * restrict r0_x_d = ((const uint8_t *) vx0) + x_qrow_size; // then scales
+ const uint8_t * restrict r1_x_q = ((const uint8_t *) vx1) + 0; // quants first
+ const uint8_t * restrict r1_x_d = ((const uint8_t *) vx1) + x_qrow_size; // then scales
+
+ const uint8_t * restrict y_q = ((const uint8_t *) vy0) + 0; // quants first
+ const uint8_t * restrict y_d = ((const uint8_t *) vy0) + y_qrow_size; // then scales
+
+ // Row sum (sf)
+ HVX_Vector r0_sum = Q6_V_vsplat_R(0);
+ HVX_Vector r1_sum = Q6_V_vsplat_R(0);
+
+ // Multiply and accumulate into int32.
+ // Compute combined scale (fp32).
+    // Apply scale to acc and accumulate into the row sum (qf32).
+
+ const uint32_t nb = n / qk; // num full blocks
+    int32_t nloe = n % qk;            // num leftover elements (must be signed)
+
+ uint32_t i = 0;
+ for (; i < nb; i++) {
+ HVX_Vector_x8 vy_q = hvx_vec_load_q8x4x8(y_q + i * y_qblk_size);
+ HVX_Vector_x8 r0_q = hvx_vec_load_mxfp4x4x8(r0_x_q + i * x_qblk_size);
+ HVX_Vector_x8 r1_q = hvx_vec_load_mxfp4x4x8(r1_x_q + i * x_qblk_size);
+
+ HVX_Vector r0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r0_q, vy_q));
+ HVX_Vector r1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r1_q, vy_q));
+
+ HVX_Vector vy_d = *(const HVX_UVector *) (y_d + i * y_dblk_size);
+ HVX_Vector r0_d = *(const HVX_UVector *) (r0_x_d + i * x_dblk_size);
+ HVX_Vector r1_d = *(const HVX_UVector *) (r1_x_d + i * x_dblk_size);
+
+ // Convert vy_d from fp16 to fp32 while applying 0.5 scaling which is used for e8m0 halving
+ HVX_Vector half = Q6_Vh_vsplat_R(0x3800); // 0.5 in fp16
+ vy_d = Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(Q6_Vh_vshuff_Vh(vy_d), half));
+ vy_d = Q6_Vsf_equals_Vqf32(vy_d);
+
+ // Convert rX_d scales from e8m0 to fp32
+ // Expand and zero-pad 32x uint8 e8m0 values to uint32s : 0 0 0 0, 0 0 0 1, 0 0 0 2, ...
+ // Left shift with zero fill to create FP32
+ // FIXME: might need to handle zero as a special case (see ggml-cpu code)
+ HVX_Vector expand = *(const HVX_Vector *) expand_x32_e8m0;
+ HVX_Vector e8m0_mask = Q6_V_vsplat_R(0x000000ff);
+ r0_d = Q6_V_vdelta_VV(r0_d, expand);
+ r0_d = Q6_V_vand_VV(r0_d, e8m0_mask);
+ r0_d = Q6_Vw_vasl_VwR(r0_d, 23);
+ r1_d = Q6_V_vdelta_VV(r1_d, expand);
+ r1_d = Q6_V_vand_VV(r1_d, e8m0_mask);
+ r1_d = Q6_Vw_vasl_VwR(r1_d, 23);
+
+ HVX_Vector r0_dd = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(r0_d, vy_d));
+ HVX_Vector r1_dd = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(r1_d, vy_d));
+
+ HVX_Vector r0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_ia, r0_dd);
+ HVX_Vector r1_fa = Q6_Vqf32_vmpy_VsfVsf(r1_ia, r1_dd);
+
+ r0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_fa, r0_sum));
+ r1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_fa, r1_sum));
+ }
+
+ // Process leftovers
+ if (nloe) {
+ HVX_Vector_x8 vy_q = hvx_vec_load_q8x4x8(y_q + i * y_qblk_size);
+ HVX_Vector_x8 r0_q = hvx_vec_load_mxfp4x4x8(r0_x_q + i * x_qblk_size);
+ HVX_Vector_x8 r1_q = hvx_vec_load_mxfp4x4x8(r1_x_q + i * x_qblk_size);
+
+ HVX_Vector r0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r0_q, vy_q));
+ HVX_Vector r1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r1_q, vy_q));
+
+ HVX_Vector vy_d = *(const HVX_UVector *) (y_d + i * y_dblk_size);
+ HVX_Vector r0_d = *(const HVX_UVector *) (r0_x_d + i * x_dblk_size);
+ HVX_Vector r1_d = *(const HVX_UVector *) (r1_x_d + i * x_dblk_size);
+
+ // Convert vy_d from fp16 to fp32 while applying 0.5 scaling which is used for e8m0 halving
+ HVX_Vector half = Q6_Vh_vsplat_R(0x3800); // 0.5 in fp16
+ vy_d = Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(Q6_Vh_vshuff_Vh(vy_d), half));
+ vy_d = Q6_Vsf_equals_Vqf32(vy_d);
+
+ // Convert rX_d scales from e8m0 to fp32
+ // Expand and zero-pad 32x uint8 e8m0 values to uint32s : 0 0 0 0, 0 0 0 1, 0 0 0 2, ...
+ // Left shift with zero fill to create FP32
+ // FIXME: might need to handle zero as a special case (see ggml-cpu code)
+ HVX_Vector expand = *(const HVX_Vector *) expand_x32_e8m0;
+ HVX_Vector e8m0_mask = Q6_V_vsplat_R(0x000000ff);
+ r0_d = Q6_V_vdelta_VV(r0_d, expand);
+ r0_d = Q6_V_vand_VV(r0_d, e8m0_mask);
+ r0_d = Q6_Vw_vasl_VwR(r0_d, 23);
+ r1_d = Q6_V_vdelta_VV(r1_d, expand);
+ r1_d = Q6_V_vand_VV(r1_d, e8m0_mask);
+ r1_d = Q6_Vw_vasl_VwR(r1_d, 23);
+
+ HVX_Vector r0_dd = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(r0_d, vy_d));
+ HVX_Vector r1_dd = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(r1_d, vy_d));
+
+ // Zero-out unused values
+ HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe / 8);
+ r0_dd = Q6_V_vand_QV(bmask, r0_dd);
+ r1_dd = Q6_V_vand_QV(bmask, r1_dd);
+ r0_ia = Q6_V_vand_QV(bmask, r0_ia);
+ r1_ia = Q6_V_vand_QV(bmask, r1_ia);
+
+ HVX_Vector r0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_ia, r0_dd);
+ HVX_Vector r1_fa = Q6_Vqf32_vmpy_VsfVsf(r1_ia, r1_dd);
+
+ r0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_fa, r0_sum));
+ r1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_fa, r1_sum));
+ }
+
+ HVX_Vector rsum = hvx_vec_reduce_sum_f32x2(r0_sum, r1_sum);
+ hvx_vec_store_u(s0, 8, rsum);
+}
+
+static void vec_dot_mxfp4x4x2_q8x4x2_2x2(const int n, float * restrict s0, float * restrict s1,
+ const void * restrict vx0, const void * restrict vx1,
+ const void * restrict vy0, const void * restrict vy1) {
+ assert(n % 32 == 0);
+ assert((unsigned long) vx0 % 128 == 0);
+ assert((unsigned long) vx1 % 128 == 0);
+ assert((unsigned long) vy0 % 128 == 0);
+ assert((unsigned long) vy1 % 128 == 0);
+
+ const uint32_t qk = QK_MXFP4x4x2 * 4;
+
+ const uint32_t x_dblk_size = 8 * 4 * 1; // 32x e8m0
+ const uint32_t x_qblk_size = qk / 2; // fp4
+ const uint32_t x_qrow_size = n / 2; // fp4 (not padded)
+
+ const uint32_t y_dblk_size = 8 * 4 * 2; // 32x __fp16
+ const uint32_t y_qblk_size = qk; // int8
+ const uint32_t y_qrow_size = n; // int8 (not padded)
+
+ const uint8_t * restrict r0_x_q = ((const uint8_t *) vx0) + 0; // quants first
+ const uint8_t * restrict r0_x_d = ((const uint8_t *) vx0) + x_qrow_size; // then scales
+ const uint8_t * restrict r1_x_q = ((const uint8_t *) vx1) + 0; // quants first
+ const uint8_t * restrict r1_x_d = ((const uint8_t *) vx1) + x_qrow_size; // then scales
+
+ const uint8_t * restrict y0_q = ((const uint8_t *) vy0) + 0; // quants first
+ const uint8_t * restrict y0_d = ((const uint8_t *) vy0) + y_qrow_size; // then scales
+ const uint8_t * restrict y1_q = ((const uint8_t *) vy1) + 0; // quants first
+ const uint8_t * restrict y1_d = ((const uint8_t *) vy1) + y_qrow_size; // then scales
+
+ // Row sums (sf) - 4 accumulators for 2×2 tile
+ HVX_Vector r0_c0_sum = Q6_V_vsplat_R(0);
+ HVX_Vector r0_c1_sum = Q6_V_vsplat_R(0);
+ HVX_Vector r1_c0_sum = Q6_V_vsplat_R(0);
+ HVX_Vector r1_c1_sum = Q6_V_vsplat_R(0);
+
+ const uint32_t nb = n / qk; // num full blocks
+ const uint32_t nloe = n % qk; // num leftover elements
+
+ uint32_t i = 0;
+ for (; i < nb; i++) {
+ // Load src1 columns (reused across both src0 rows)
+ HVX_Vector_x8 vy0_q = hvx_vec_load_q8x4x8(y0_q + i * y_qblk_size);
+ HVX_Vector_x8 vy1_q = hvx_vec_load_q8x4x8(y1_q + i * y_qblk_size);
+
+ // Load src0 rows (reused across both src1 columns)
+ HVX_Vector_x8 r0_q = hvx_vec_load_mxfp4x4x8(r0_x_q + i * x_qblk_size);
+ HVX_Vector_x8 r1_q = hvx_vec_load_mxfp4x4x8(r1_x_q + i * x_qblk_size);
+
+ // Compute 4 dot products: r0×c0, r0×c1, r1×c0, r1×c1
+ HVX_Vector r0_c0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r0_q, vy0_q));
+ HVX_Vector r0_c1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r0_q, vy1_q));
+ HVX_Vector r1_c0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r1_q, vy0_q));
+ HVX_Vector r1_c1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r1_q, vy1_q));
+
+ // Load scales
+ HVX_Vector vy0_d = *(const HVX_UVector *) (y0_d + i * y_dblk_size);
+ HVX_Vector vy1_d = *(const HVX_UVector *) (y1_d + i * y_dblk_size);
+ HVX_Vector r0_d = *(const HVX_UVector *) (r0_x_d + i * x_dblk_size);
+ HVX_Vector r1_d = *(const HVX_UVector *) (r1_x_d + i * x_dblk_size);
+
+ // Convert vy_d from fp16 to fp32 while applying 0.5 scaling which is used for e8m0 halving
+ HVX_Vector half = Q6_Vh_vsplat_R(0x3800); // 0.5 in fp16
+ vy0_d = Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(Q6_Vh_vshuff_Vh(vy0_d), half));
+ vy0_d = Q6_Vsf_equals_Vqf32(vy0_d);
+ vy1_d = Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(Q6_Vh_vshuff_Vh(vy1_d), half));
+ vy1_d = Q6_Vsf_equals_Vqf32(vy1_d);
+
+ // Convert rX_d scales from e8m0 to fp32
+ // Expand and zero-pad 32x uint8 e8m0 values to uint32s : 0 0 0 0, 0 0 0 1, 0 0 0 2, ...
+ // Left shift with zero fill to create FP32
+ // FIXME: might need to handle zero as a special case (see ggml-cpu code)
+ HVX_Vector expand = *(const HVX_Vector *) expand_x32_e8m0;
+ HVX_Vector e8m0_mask = Q6_V_vsplat_R(0x000000ff);
+ r0_d = Q6_V_vdelta_VV(r0_d, expand);
+ r0_d = Q6_V_vand_VV(r0_d, e8m0_mask);
+ r0_d = Q6_Vw_vasl_VwR(r0_d, 23);
+ r1_d = Q6_V_vdelta_VV(r1_d, expand);
+ r1_d = Q6_V_vand_VV(r1_d, e8m0_mask);
+ r1_d = Q6_Vw_vasl_VwR(r1_d, 23);
+
+ // Compute combined scales
+ HVX_Vector r0_c0_dd = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(r0_d, vy0_d));
+ HVX_Vector r0_c1_dd = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(r0_d, vy1_d));
+ HVX_Vector r1_c0_dd = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(r1_d, vy0_d));
+ HVX_Vector r1_c1_dd = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(r1_d, vy1_d));
+
+ // Apply scales and accumulate
+ HVX_Vector r0_c0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_c0_ia, r0_c0_dd);
+ HVX_Vector r0_c1_fa = Q6_Vqf32_vmpy_VsfVsf(r0_c1_ia, r0_c1_dd);
+ HVX_Vector r1_c0_fa = Q6_Vqf32_vmpy_VsfVsf(r1_c0_ia, r1_c0_dd);
+ HVX_Vector r1_c1_fa = Q6_Vqf32_vmpy_VsfVsf(r1_c1_ia, r1_c1_dd);
+
+ r0_c0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_c0_fa, r0_c0_sum));
+ r0_c1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_c1_fa, r0_c1_sum));
+ r1_c0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_c0_fa, r1_c0_sum));
+ r1_c1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_c1_fa, r1_c1_sum));
+ }
+
+ // Process leftovers
+ if (nloe) {
+ HVX_Vector_x8 vy0_q = hvx_vec_load_q8x4x8(y0_q + i * y_qblk_size);
+ HVX_Vector_x8 vy1_q = hvx_vec_load_q8x4x8(y1_q + i * y_qblk_size);
+ HVX_Vector_x8 r0_q = hvx_vec_load_mxfp4x4x8(r0_x_q + i * x_qblk_size);
+ HVX_Vector_x8 r1_q = hvx_vec_load_mxfp4x4x8(r1_x_q + i * x_qblk_size);
+
+ HVX_Vector r0_c0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_nloe(r0_q, vy0_q, nloe));
+ HVX_Vector r0_c1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_nloe(r0_q, vy1_q, nloe));
+ HVX_Vector r1_c0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_nloe(r1_q, vy0_q, nloe));
+ HVX_Vector r1_c1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_nloe(r1_q, vy1_q, nloe));
+
+ HVX_Vector vy0_d = *(const HVX_UVector *) (y0_d + i * y_dblk_size);
+ HVX_Vector vy1_d = *(const HVX_UVector *) (y1_d + i * y_dblk_size);
+ HVX_Vector r0_d = *(const HVX_UVector *) (r0_x_d + i * x_dblk_size);
+ HVX_Vector r1_d = *(const HVX_UVector *) (r1_x_d + i * x_dblk_size);
+
+ // Convert vy_d from fp16 to fp32 while applying 0.5 scaling which is used for e8m0 halving
+ HVX_Vector half = Q6_Vh_vsplat_R(0x3800); // 0.5 in fp16
+ vy0_d = Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(Q6_Vh_vshuff_Vh(vy0_d), half));
+ vy0_d = Q6_Vsf_equals_Vqf32(vy0_d);
+ vy1_d = Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(Q6_Vh_vshuff_Vh(vy1_d), half));
+ vy1_d = Q6_Vsf_equals_Vqf32(vy1_d);
+
+ // Convert rX_d scales from e8m0 to fp32
+ // Expand and zero-pad 32x uint8 e8m0 values to uint32s : 0 0 0 0, 0 0 0 1, 0 0 0 2, ...
+ // Left shift with zero fill to create FP32
+ // FIXME: might need to handle zero as a special case (see ggml-cpu code)
+ HVX_Vector expand = *(const HVX_Vector *) expand_x32_e8m0;
+ HVX_Vector e8m0_mask = Q6_V_vsplat_R(0x000000ff);
+ r0_d = Q6_V_vdelta_VV(r0_d, expand);
+ r0_d = Q6_V_vand_VV(r0_d, e8m0_mask);
+ r0_d = Q6_Vw_vasl_VwR(r0_d, 23);
+ r1_d = Q6_V_vdelta_VV(r1_d, expand);
+ r1_d = Q6_V_vand_VV(r1_d, e8m0_mask);
+ r1_d = Q6_Vw_vasl_VwR(r1_d, 23);
+
+ HVX_Vector r0_c0_dd = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(r0_d, vy0_d));
+ HVX_Vector r0_c1_dd = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(r0_d, vy1_d));
+ HVX_Vector r1_c0_dd = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(r1_d, vy0_d));
+ HVX_Vector r1_c1_dd = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(r1_d, vy1_d));
+
+ // Zero out unused scales
+ HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe / 8);
+ r0_c0_dd = Q6_V_vand_QV(bmask, r0_c0_dd);
+ r0_c1_dd = Q6_V_vand_QV(bmask, r0_c1_dd);
+ r1_c0_dd = Q6_V_vand_QV(bmask, r1_c0_dd);
+ r1_c1_dd = Q6_V_vand_QV(bmask, r1_c1_dd);
+ r0_c0_ia = Q6_V_vand_QV(bmask, r0_c0_ia);
+ r0_c1_ia = Q6_V_vand_QV(bmask, r0_c1_ia);
+ r1_c0_ia = Q6_V_vand_QV(bmask, r1_c0_ia);
+ r1_c1_ia = Q6_V_vand_QV(bmask, r1_c1_ia);
+
+ HVX_Vector r0_c0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_c0_ia, r0_c0_dd);
+ HVX_Vector r0_c1_fa = Q6_Vqf32_vmpy_VsfVsf(r0_c1_ia, r0_c1_dd);
+ HVX_Vector r1_c0_fa = Q6_Vqf32_vmpy_VsfVsf(r1_c0_ia, r1_c0_dd);
+ HVX_Vector r1_c1_fa = Q6_Vqf32_vmpy_VsfVsf(r1_c1_ia, r1_c1_dd);
+
+ r0_c0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_c0_fa, r0_c0_sum));
+ r0_c1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_c1_fa, r0_c1_sum));
+ r1_c0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_c0_fa, r1_c0_sum));
+ r1_c1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_c1_fa, r1_c1_sum));
+ }
+
+ // Reduce and store results
+ HVX_Vector r0_r1_c0_sum = hvx_vec_reduce_sum_f32x2(r0_c0_sum, r1_c0_sum);
+ HVX_Vector r0_r1_c1_sum = hvx_vec_reduce_sum_f32x2(r0_c1_sum, r1_c1_sum);
+
+ hvx_vec_store_u(&s0[0], 8, r0_r1_c0_sum); // row0,col0 row1,col0
+ hvx_vec_store_u(&s1[0], 8, r0_r1_c1_sum); // row0,col1 row1,col1
+}
+
+static void vec_dot_f16_f16_aa_1x1(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
+ const HVX_Vector * restrict x = (const HVX_Vector *) vx;
+ const HVX_Vector * restrict y = (const HVX_Vector *) vy;
+
+ uint32_t nvec = n / VLEN_FP16; // num full fp16 hvx vectors
+ uint32_t nloe = n % VLEN_FP16; // leftover elements
+
+ HVX_Vector rsum = Q6_V_vsplat_R(0);
+
+ uint32_t i = 0;
+
+ #pragma unroll(4)
+ for (i = 0; i < nvec; i++) {
+ HVX_VectorPair xy_qf = Q6_Wqf32_vmpy_VhfVhf(x[i], y[i]);
+ rsum = Q6_Vqf32_vadd_Vqf32Vqf32(rsum, Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(xy_qf), Q6_V_hi_W(xy_qf)));
+ }
+
+ if (nloe) {
+ HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe * 2);
+ HVX_Vector x_hf = Q6_V_vand_QV(bmask, x[i]);
+ HVX_Vector y_hf = Q6_V_vand_QV(bmask, y[i]);
+
+ HVX_VectorPair xy_qf = Q6_Wqf32_vmpy_VhfVhf(x_hf, y_hf);
+ rsum = Q6_Vqf32_vadd_Vqf32Vqf32(rsum, Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(xy_qf), Q6_V_hi_W(xy_qf)));
+ }
+
+ rsum = hvx_vec_reduce_sum_f32(Q6_Vsf_equals_Vqf32(rsum));
+ hvx_vec_store_u(&s[0], 4, rsum);
+}
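+
+// For reference, the HVX kernel above computes the same result as this scalar loop
+// (a sketch, assuming the host compiler supports __fp16; kept as a comment so it does not
+// affect the build):
+//
+//   static void vec_dot_f16_f16_ref(int n, float * restrict s,
+//                                   const __fp16 * restrict x, const __fp16 * restrict y) {
+//       float sum = 0.0f;
+//       for (int i = 0; i < n; i++) {
+//           sum += (float) x[i] * (float) y[i];   // widen to fp32, accumulate in fp32
+//       }
+//       *s = sum;
+//   }
+//
+// modulo the different accumulation order and the qf32 intermediate rounding.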
+
+static void vec_dot_f16_f16_aa_2x1(const int n, float * restrict s0,
+ const void * restrict vx0, const void * restrict vx1,
+ const void * restrict vy0) {
+ const HVX_Vector * restrict x0 = (const HVX_Vector *) vx0;
+ const HVX_Vector * restrict x1 = (const HVX_Vector *) vx1;
+ const HVX_Vector * restrict y = (const HVX_Vector *) vy0;
+
+ uint32_t nvec = n / VLEN_FP16;
+ uint32_t nloe = n % VLEN_FP16;
+
+ HVX_Vector rsum0 = Q6_V_vsplat_R(0);
+ HVX_Vector rsum1 = Q6_V_vsplat_R(0);
+
+ uint32_t i = 0;
+
+ #pragma unroll(2)
+ for (i = 0; i < nvec; i++) {
+ HVX_Vector y_hf = y[i];
+ HVX_VectorPair xy0_qf = Q6_Wqf32_vmpy_VhfVhf(x0[i], y_hf);
+ HVX_VectorPair xy1_qf = Q6_Wqf32_vmpy_VhfVhf(x1[i], y_hf);
+
+ rsum0 = Q6_Vqf32_vadd_Vqf32Vqf32(rsum0, Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(xy0_qf), Q6_V_hi_W(xy0_qf)));
+ rsum1 = Q6_Vqf32_vadd_Vqf32Vqf32(rsum1, Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(xy1_qf), Q6_V_hi_W(xy1_qf)));
+ }
+
+ if (nloe) {
+ HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe * 2);
+ HVX_Vector x0_hf = Q6_V_vand_QV(bmask, x0[i]);
+ HVX_Vector x1_hf = Q6_V_vand_QV(bmask, x1[i]);
+ HVX_Vector y_hf = Q6_V_vand_QV(bmask, y[i]);
+
+ HVX_VectorPair xy0_qf = Q6_Wqf32_vmpy_VhfVhf(x0_hf, y_hf);
+ HVX_VectorPair xy1_qf = Q6_Wqf32_vmpy_VhfVhf(x1_hf, y_hf);
+
+ rsum0 = Q6_Vqf32_vadd_Vqf32Vqf32(rsum0, Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(xy0_qf), Q6_V_hi_W(xy0_qf)));
+ rsum1 = Q6_Vqf32_vadd_Vqf32Vqf32(rsum1, Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(xy1_qf), Q6_V_hi_W(xy1_qf)));
+ }
+
+ HVX_Vector rsum = hvx_vec_reduce_sum_f32x2(Q6_Vsf_equals_Vqf32(rsum0), Q6_Vsf_equals_Vqf32(rsum1));
+ hvx_vec_store_u(s0, 8, rsum);
+}
+
+static void vec_dot_f16_f16_aa_2x2(const int n, float * restrict s0, float * restrict s1,
+ const void * restrict vx0, const void * restrict vx1,
+ const void * restrict vy0, const void * restrict vy1) {
+ const HVX_Vector * restrict x0 = (const HVX_Vector *) vx0;
+ const HVX_Vector * restrict x1 = (const HVX_Vector *) vx1;
+ const HVX_Vector * restrict y0 = (const HVX_Vector *) vy0;
+ const HVX_Vector * restrict y1 = (const HVX_Vector *) vy1;
+
+ uint32_t nvec = n / VLEN_FP16;
+ uint32_t nloe = n % VLEN_FP16;
+
+ // Row sums (sf) - 4 accumulators for 2×2 tile
+ HVX_Vector r0_c0_sum = Q6_V_vsplat_R(0);
+ HVX_Vector r0_c1_sum = Q6_V_vsplat_R(0);
+ HVX_Vector r1_c0_sum = Q6_V_vsplat_R(0);
+ HVX_Vector r1_c1_sum = Q6_V_vsplat_R(0);
+
+ uint32_t i = 0;
+
+ #pragma unroll(2)
+ for (i = 0; i < nvec; i++) {
+ HVX_Vector r0_hf = x0[i];
+ HVX_Vector r1_hf = x1[i];
+ HVX_Vector c0_hf = y0[i];
+ HVX_Vector c1_hf = y1[i];
+
+ // Compute 4 dot products: r0×c0, r0×c1, r1×c0, r1×c1
+ HVX_VectorPair r0_c0_qf_p = Q6_Wqf32_vmpy_VhfVhf(r0_hf, c0_hf);
+ HVX_VectorPair r0_c1_qf_p = Q6_Wqf32_vmpy_VhfVhf(r0_hf, c1_hf);
+ HVX_VectorPair r1_c0_qf_p = Q6_Wqf32_vmpy_VhfVhf(r1_hf, c0_hf);
+ HVX_VectorPair r1_c1_qf_p = Q6_Wqf32_vmpy_VhfVhf(r1_hf, c1_hf);
+
+ HVX_Vector r0_c0_qf = Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(r0_c0_qf_p), Q6_V_hi_W(r0_c0_qf_p));
+ HVX_Vector r0_c1_qf = Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(r0_c1_qf_p), Q6_V_hi_W(r0_c1_qf_p));
+ HVX_Vector r1_c0_qf = Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(r1_c0_qf_p), Q6_V_hi_W(r1_c0_qf_p));
+ HVX_Vector r1_c1_qf = Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(r1_c1_qf_p), Q6_V_hi_W(r1_c1_qf_p));
+
+ r0_c0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_c0_qf, r0_c0_sum));
+ r0_c1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_c1_qf, r0_c1_sum));
+ r1_c0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_c0_qf, r1_c0_sum));
+ r1_c1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_c1_qf, r1_c1_sum));
+ }
+
+ if (nloe) {
+ HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe * 2);
+
+ HVX_Vector r0_hf = Q6_V_vand_QV(bmask, x0[i]);
+ HVX_Vector r1_hf = Q6_V_vand_QV(bmask, x1[i]);
+ HVX_Vector c0_hf = Q6_V_vand_QV(bmask, y0[i]);
+ HVX_Vector c1_hf = Q6_V_vand_QV(bmask, y1[i]);
+
+ HVX_VectorPair r0_c0_qf_p = Q6_Wqf32_vmpy_VhfVhf(r0_hf, c0_hf);
+ HVX_VectorPair r0_c1_qf_p = Q6_Wqf32_vmpy_VhfVhf(r0_hf, c1_hf);
+ HVX_VectorPair r1_c0_qf_p = Q6_Wqf32_vmpy_VhfVhf(r1_hf, c0_hf);
+ HVX_VectorPair r1_c1_qf_p = Q6_Wqf32_vmpy_VhfVhf(r1_hf, c1_hf);
+
+ HVX_Vector r0_c0_qf = Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(r0_c0_qf_p), Q6_V_hi_W(r0_c0_qf_p));
+ HVX_Vector r0_c1_qf = Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(r0_c1_qf_p), Q6_V_hi_W(r0_c1_qf_p));
+ HVX_Vector r1_c0_qf = Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(r1_c0_qf_p), Q6_V_hi_W(r1_c0_qf_p));
+ HVX_Vector r1_c1_qf = Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(r1_c1_qf_p), Q6_V_hi_W(r1_c1_qf_p));
+
+ r0_c0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_c0_qf, r0_c0_sum));
+ r0_c1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_c1_qf, r0_c1_sum));
+ r1_c0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_c0_qf, r1_c0_sum));
+ r1_c1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_c1_qf, r1_c1_sum));
+
+ }
+
+ // Reduce and store results
+ HVX_Vector r0_r1_c0_sum = hvx_vec_reduce_sum_f32x2(r0_c0_sum, r1_c0_sum);
+ HVX_Vector r0_r1_c1_sum = hvx_vec_reduce_sum_f32x2(r0_c1_sum, r1_c1_sum);
+
+ hvx_vec_store_u(&s0[0], 8, r0_r1_c0_sum); // row0,col0 row1,col0
+ hvx_vec_store_u(&s1[0], 8, r0_r1_c1_sum); // row0,col1 row1,col1
+}
+
+static void vec_dot_f16_f16_uu_1x1(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
+ const HVX_UVector * restrict x = (const HVX_UVector *) vx;
+ const HVX_UVector * restrict y = (const HVX_UVector *) vy;
+
+ uint32_t nvec = n / VLEN_FP16; // num full fp16 hvx vectors
+ uint32_t nloe = n % VLEN_FP16; // leftover elements
+
+ HVX_Vector rsum = Q6_V_vsplat_R(0);
+
+ uint32_t i = 0;
+
+ #pragma unroll(4)
+ for (i = 0; i < nvec; i++) {
+ HVX_VectorPair xy_qf = Q6_Wqf32_vmpy_VhfVhf(x[i], y[i]);
+ rsum = Q6_Vqf32_vadd_Vqf32Vqf32(rsum, Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(xy_qf), Q6_V_hi_W(xy_qf)));
+ }
+
+ if (nloe) {
+ HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe * 2);
+ HVX_Vector x_hf = Q6_V_vand_QV(bmask, x[i]);
+ HVX_Vector y_hf = Q6_V_vand_QV(bmask, y[i]);
+
+ HVX_VectorPair xy_qf = Q6_Wqf32_vmpy_VhfVhf(x_hf, y_hf);
+ rsum = Q6_Vqf32_vadd_Vqf32Vqf32(rsum, Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(xy_qf), Q6_V_hi_W(xy_qf)));
+ }
+
+ rsum = hvx_vec_reduce_sum_f32(Q6_Vsf_equals_Vqf32(rsum));
+ hvx_vec_store_u(&s[0], 4, rsum);
+}
+
+static void vec_dot_f16_f32_uu_1x1(const int n, float * restrict s, const void * restrict x, const void * restrict y) {
+ const HVX_UVector * restrict vx = (const HVX_UVector * restrict) x;
+ const HVX_UVector * restrict vy = (const HVX_UVector * restrict) y;
+
+ uint32_t nvec = n / VLEN_FP16; // num full fp16 hvx vectors
+ uint32_t nloe = n % VLEN_FP16; // leftover elements
+
+ const HVX_Vector zero = Q6_V_vsplat_R(0);
+
+ HVX_Vector rsum = Q6_V_vsplat_R(0);
+
+ uint32_t i = 0;
+
+ #pragma unroll(2)
+ for (i = 0; i < nvec; i++) {
+ // Load y (fp32) and convert into fp16
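+        // (subtracting a zero vector converts IEEE fp32 (sf) into qf32, so the pair can
+        //  then be repacked into a single fp16 vector below)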
+ HVX_Vector y0_qf = Q6_Vqf32_vsub_VsfVsf(vy[i*2+0], zero); // 32 elements
+ HVX_Vector y1_qf = Q6_Vqf32_vsub_VsfVsf(vy[i*2+1], zero); // 32 elements
+ HVX_Vector y_hf = Q6_Vh_vdeal_Vh(Q6_Vhf_equals_Wqf32(Q6_W_vcombine_VV(y1_qf, y0_qf)));
+
+ // Load x (fp16)
+ HVX_Vector x_hf = vx[i];
+
+ HVX_VectorPair xy_qf = Q6_Wqf32_vmpy_VhfVhf(x_hf, y_hf);
+
+ rsum = Q6_Vqf32_vadd_Vqf32Vqf32(rsum, Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(xy_qf), Q6_V_hi_W(xy_qf)));
+ }
+
+ if (nloe) {
+ // Load y (fp32) and convert into fp16
+ HVX_Vector y0_qf = Q6_Vqf32_vsub_VsfVsf(vy[i*2+0], zero); // 32 elements
+ HVX_Vector y1_qf = Q6_Vqf32_vsub_VsfVsf(vy[i*2+1], zero); // 32 elements
+ HVX_Vector y_hf = Q6_Vh_vdeal_Vh(Q6_Vhf_equals_Wqf32(Q6_W_vcombine_VV(y1_qf, y0_qf)));
+
+ // Load x (fp16)
+ HVX_Vector x_hf = vx[i];
+
+ // Zero-out unused elements
+ // Note that we need to clear both x and y because they may contain NANs
+ HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe * 2);
+ x_hf = Q6_V_vand_QV(bmask, x_hf);
+ y_hf = Q6_V_vand_QV(bmask, y_hf);
+
+ HVX_VectorPair xy_qf = Q6_Wqf32_vmpy_VhfVhf(x_hf, y_hf);
+
+ rsum = Q6_Vqf32_vadd_Vqf32Vqf32(rsum, Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(xy_qf), Q6_V_hi_W(xy_qf)));
+ }
+
+ // Convert into fp32 and reduce
+ rsum = hvx_vec_reduce_sum_f32(Q6_Vsf_equals_Vqf32(rsum));
+ hvx_vec_store_u(&s[0], 4, rsum);
+}
+
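+// Naming below follows the usual ggml convention: neXY is the element count of
+// dimension Y of srcX (ne0..ne3 for dst) and nbXY is the matching stride in bytes.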
+#define htp_matmul_tensors_preamble \
+ struct htp_tensor * restrict src0 = &octx->src0; \
+ struct htp_tensor * restrict src1 = &octx->src1; \
+ struct htp_tensor * restrict src2 = &octx->src2; \
+ struct htp_tensor * restrict dst = &octx->dst; \
+ struct htp_spad * restrict src0_spad = &octx->src0_spad; \
+ struct htp_spad * restrict src1_spad = &octx->src1_spad; \
+ struct htp_spad * restrict dst_spad = &octx->dst_spad; \
+ \
+ const uint32_t ne00 = src0->ne[0]; \
+ const uint32_t ne01 = src0->ne[1]; \
+ const uint32_t ne02 = src0->ne[2]; \
+ const uint32_t ne03 = src0->ne[3]; \
+ \
+ const uint32_t ne10 = src1->ne[0]; \
+ const uint32_t ne11 = src1->ne[1]; \
+ const uint32_t ne12 = src1->ne[2]; \
+ const uint32_t ne13 = src1->ne[3]; \
+ \
+ const uint32_t ne20 = src2->ne[0]; \
+ const uint32_t ne21 = src2->ne[1]; \
+ const uint32_t ne22 = src2->ne[2]; \
+ const uint32_t ne23 = src2->ne[3]; \
+ \
+ const uint32_t ne0 = dst->ne[0]; \
+ const uint32_t ne1 = dst->ne[1]; \
+ const uint32_t ne2 = dst->ne[2]; \
+ const uint32_t ne3 = dst->ne[3]; \
+ \
+ const uint32_t nb00 = src0->nb[0]; \
+ const uint32_t nb01 = src0->nb[1]; \
+ const uint32_t nb02 = src0->nb[2]; \
+ const uint32_t nb03 = src0->nb[3]; \
+ \
+ const uint32_t nb10 = src1->nb[0]; \
+ const uint32_t nb11 = src1->nb[1]; \
+ const uint32_t nb12 = src1->nb[2]; \
+ const uint32_t nb13 = src1->nb[3]; \
+ \
+ const uint32_t nb0 = dst->nb[0]; \
+ const uint32_t nb1 = dst->nb[1]; \
+ const uint32_t nb2 = dst->nb[2]; \
+ const uint32_t nb3 = dst->nb[3];
+
+#define htp_matmul_preamble \
+ struct htp_matmul_context * mmctx = data; \
+ struct htp_ops_context * octx = mmctx->octx; \
+ htp_matmul_tensors_preamble; \
+ dma_queue *dma_queue = octx->ctx->dma[ith]; \
+ uint32_t src0_nrows_per_thread = mmctx->src0_nrows_per_thread;
+
+// *** matmul with support for 4d tensors and full broadcasting
+
+static void matmul_4d(unsigned int nth, unsigned int ith, void * data) {
+ htp_matmul_preamble;
+
+ uint64_t t1, t2;
+ t1 = HAP_perf_get_qtimer_count();
+
+ assert(ne12 % ne02 == 0);
+ assert(ne13 % ne03 == 0);
+
+    // nr0 is the size of the first dimension of the result (i.e. the number of src0 rows, ne01)
+    const uint32_t nr0 = ne0;
+
+    // nr1 is the combined size of the remaining dimensions of the result
+    const uint32_t nr1 = ne1 * ne2 * ne3;
+
+ // distribute the thread work across the inner or outer loop based on which one is larger
+ uint32_t nchunk0 = nr0 > nr1 ? nth : 1; // parallelize by src0 rows
+ uint32_t nchunk1 = nr0 > nr1 ? 1 : nth; // parallelize by src1 rows
+
+ // The number of elements in each chunk
+ const uint32_t dr0 = (nr0 + nchunk0 - 1) / nchunk0;
+ const uint32_t dr1 = (nr1 + nchunk1 - 1) / nchunk1;
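+    // e.g. (illustrative numbers) with nth = 4, nr0 = 4096, nr1 = 32:
+    // nchunk0 = 4, nchunk1 = 1, dr0 = 1024, dr1 = 32, so each thread owns a
+    // contiguous slice of 1024 result rows and iterates the full outer range.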
+
+ uint32_t current_chunk = ith;
+
+ const uint32_t ith0 = current_chunk % nchunk0;
+ const uint32_t ith1 = current_chunk / nchunk0;
+
+ const uint32_t ir0_start = dr0 * ith0;
+ const uint32_t ir0_end = MIN(ir0_start + dr0, nr0);
+
+ const uint32_t ir1_start = dr1 * ith1;
+ const uint32_t ir1_end = MIN(ir1_start + dr1, nr1);
+
+ // no work for this thread
+ if (ir0_start >= ir0_end || ir1_start >= ir1_end) {
+ return;
+ }
+
+ // block-tiling attempt
+ const uint32_t blck_0 = 64;
+ const uint32_t blck_1 = 64;
+
+ for (uint32_t iir1 = ir1_start; iir1 < ir1_end; iir1 += blck_1) {
+ for (uint32_t iir0 = ir0_start; iir0 < ir0_end; iir0 += blck_0) {
+ for (uint32_t ir1 = iir1; ir1 < MIN(iir1 + blck_1, ir1_end); ir1++) {
+ const uint32_t i13 = fastdiv(ir1, &mmctx->mm_div_ne12_ne1);
+ const uint32_t i12 = fastdiv(ir1 - i13 * ne12 * ne1, &mmctx->mm_div_ne1);
+ const uint32_t i11 = (ir1 - i13 * ne12 * ne1 - i12 * ne1);
+
+ // broadcast src0 into src1
+ const uint32_t i03 = fastdiv(i13, &mmctx->mm_div_r3);
+ const uint32_t i02 = fastdiv(i12, &mmctx->mm_div_r2);
+
+ const uint32_t i1 = i11;
+ const uint32_t i2 = i12;
+ const uint32_t i3 = i13;
+
+ const uint8_t * restrict src0_base = (const uint8_t *) src0->data + (0 + i02 * nb02 + i03 * nb03);
+ const uint8_t * restrict src1_col = (const uint8_t *) src1->data + (i11 * nb11 + i12 * nb12 + i13 * nb13);
+ float * dst_col = (float *) ((uint8_t * restrict) dst->data + (i1 * nb1 + i2 * nb2 + i3 * nb3));
+
+ const uint32_t ir0_block_end = MIN(iir0 + blck_0, ir0_end);
+ for (uint32_t ir0 = iir0; ir0 < ir0_block_end; ir0++) {
+ const uint8_t * restrict src0_row = src0_base + ir0 * nb01;
+ mmctx->vec_dot_1x1(ne00, &dst_col[ir0], src0_row, src1_col);
+ }
+ }
+ }
+ }
+
+ t2 = HAP_perf_get_qtimer_count();
+
+ FARF(HIGH, "matmul-4d %d/%d: %ux%ux%ux%u (%u:%u %u:%u) * %ux%ux%ux%u -> %ux%ux%ux%u usec %u\n", ith, nth,
+ src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3], ir0_start, ir0_end, ir1_start, ir1_end, src1->ne[0],
+ src1->ne[1], src1->ne[2], src1->ne[3], dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3],
+ (unsigned) HAP_perf_qtimer_count_to_us(t2 - t1));
+}
+
+// src1 tensor is already in VTCM spad
+static void matmul_2d(unsigned int nth, unsigned int ith, void * data) {
+ htp_matmul_preamble;
+
+ const uint32_t src0_nrows = ne01 * ne02 * ne03; // src0 rows
+ const uint32_t src1_nrows = ne11 * ne12 * ne13; // src1 rows
+
+ const uint32_t src0_start_row = src0_nrows_per_thread * ith;
+ const uint32_t src0_end_row = MIN(src0_start_row + src0_nrows_per_thread, src0_nrows);
+ const uint32_t src0_end_row_x2 = src0_start_row + ((src0_end_row - src0_start_row) & ~1U);
+
+ // no work for this thread
+ if (src0_start_row >= src0_end_row) {
+ return;
+ }
+
+ const size_t dst_row_size = nb1;
+ const size_t src0_row_size = nb01;
+ const size_t src1_row_size = nb11;
+
+ const size_t src0_stride = src0_spad->stride;
+ const size_t src1_stride = src1_spad->stride;
+
+ // Per-thread VTCM scratchpads for all tensors
+ // Note that the entire src1 tensor is already in VTCM
+ // For other tensors we allocate N rows per thread, padded to HVX vector size
+ uint8_t * restrict spad_dst = dst_spad->data + dst_spad->size_per_thread * ith;
+ uint8_t * restrict spad_src0 = src0_spad->data + src0_spad->size_per_thread * ith;
+ uint8_t * restrict src1_data = src1_spad->data;
+
+ volatile uint64_t t1, t2;
+ t1 = HAP_perf_get_qtimer_count();
+
+ const uint8_t * restrict src0_row = (const uint8_t *) src0->data;
+
+ // Prefill spad with src0 rows
+ #pragma unroll(4)
+ for (uint32_t ir0 = src0_start_row; ir0 < src0_end_row_x2; ir0 += 2) {
+ const int is0 = (ir0 - src0_start_row);
+ if (is0 >= MM_SPAD_SRC0_NROWS) {
+ break;
+ }
+ dma_queue_push_ddr_to_vtcm(dma_queue, dma_make_ptr(spad_src0 + is0 * src0_stride, src0_row + ir0 * src0_row_size),
+ src0_stride, src0_row_size, 2);
+ }
+
+ // Process src0 rows
+ for (uint32_t ir0 = src0_start_row; ir0 < src0_end_row_x2; ir0 += 2) {
+ const uint8_t * ss0 = dma_queue_pop(dma_queue).dst;
+
+ // Process src1 columns in pairs (2×2 tiling)
+ uint32_t ir1 = 0;
+ for (; ir1 + 1 < src1_nrows; ir1 += 2) {
+ const uint8_t * restrict src1_col0 = (const uint8_t *) (src1_data + (ir1+0) * src1_stride);
+ const uint8_t * restrict src1_col1 = (const uint8_t *) (src1_data + (ir1+1) * src1_stride);
+ float * restrict dst_row0 = (float *) (dst->data + ((ir1+0) * dst_row_size));
+ float * restrict dst_row1 = (float *) (dst->data + ((ir1+1) * dst_row_size));
+ mmctx->vec_dot_2x2(ne00, &dst_row0[ir0], &dst_row1[ir0], ss0, ss0 + src0_stride, src1_col0, src1_col1);
+ }
+
+ // Handle remaining src1 rows (fallback to 2×1)
+ for (; ir1 < src1_nrows; ++ir1) {
+ const uint8_t * restrict src1_col = (const uint8_t *) (src1_data + ir1 * src1_stride);
+ float * restrict dst_row = (float *) (dst->data + (ir1 * dst_row_size));
+ mmctx->vec_dot_2x1(ne00, &dst_row[ir0], ss0, ss0 + src0_stride, src1_col);
+ }
+
+        // Prefetch the row that is MM_SPAD_SRC0_NROWS ahead of the current one
+ const int pr0 = (ir0 + MM_SPAD_SRC0_NROWS);
+ const int is0 = (pr0 - src0_start_row) % MM_SPAD_SRC0_NROWS;
+ if (pr0 < src0_end_row_x2) {
+ dma_queue_push_ddr_to_vtcm(dma_queue, dma_make_ptr(spad_src0 + is0 * src0_stride, src0_row + pr0 * src0_row_size),
+ src0_stride, src0_row_size, 2);
+ }
+ }
+
+ // Process the last row (if any)
+ if (src0_end_row != src0_end_row_x2) {
+ uint32_t ir0 = src0_end_row_x2;
+ const int is0 = (ir0 - src0_start_row);
+ dma_queue_push_ddr_to_vtcm(dma_queue, dma_make_ptr(spad_src0 + is0 * src0_stride, src0_row + ir0 * src0_row_size),
+ src0_stride, src0_row_size, 1);
+ const uint8_t * ss0 = dma_queue_pop(dma_queue).dst;
+
+ #pragma unroll(2)
+ for (uint32_t ir1 = 0; ir1 < src1_nrows; ++ir1) {
+ const uint8_t * restrict src1_col = (const uint8_t *) (src1_data + ir1 * src1_stride);
+ float * restrict dst_row = (float *) (dst->data + (ir1 * dst_row_size));
+ mmctx->vec_dot_1x1(ne00, &dst_row[ir0], ss0, src1_col);
+ }
+ }
+
+ t2 = HAP_perf_get_qtimer_count();
+
+ FARF(HIGH, "matmul-%s %d/%d: %ux%ux%ux%u (%u:%u) * %ux%ux%ux%u -> %ux%ux%ux%u usec %u\n", mmctx->type, ith, nth,
+ src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3], src0_start_row, src0_end_row, src1->ne[0], src1->ne[1],
+ src1->ne[2], src1->ne[3], dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3],
+ (unsigned) HAP_perf_qtimer_count_to_us(t2 - t1));
+}
+
+// q8x4x2 src1 tensor is already in VTCM spad
+static void matvec_2d(unsigned int nth, unsigned int ith, void * data) {
+ htp_matmul_preamble;
+
+ const uint32_t src0_nrows = ne01;
+
+ const uint32_t src0_start_row = src0_nrows_per_thread * ith;
+ const uint32_t src0_end_row = MIN(src0_start_row + src0_nrows_per_thread, src0_nrows);
+ const uint32_t src0_end_row_x2 = src0_start_row + ((src0_end_row - src0_start_row) & ~1U);
+
+ // no work for this thread
+ if (src0_start_row >= src0_end_row) {
+ return;
+ }
+
+ const size_t dst_row_size = nb1;
+ const size_t src0_row_size = nb01;
+ const size_t src1_row_size = nb11;
+
+ const size_t src0_stride = src0_spad->stride;
+ const size_t src1_stride = src1_spad->stride;
+
+ // Per-thread VTCM scratchpads for all tensors
+ // Note that the entire src1 tensor is already in VTCM
+ // For other tensors we allocate N rows per thread, padded to HVX vector size
+ uint8_t * spad_dst = dst_spad->data + dst_spad->size_per_thread * ith;
+ uint8_t * spad_src0 = src0_spad->data + src0_spad->size_per_thread * ith;
+ uint8_t * src1_data = src1_spad->data;
+
+ uint64_t t1, t2;
+ t1 = HAP_perf_get_qtimer_count();
+
+ float * tmp = (float *) spad_dst;
+
+ const uint8_t * restrict src0_row = (const uint8_t *) src0->data;
+ const uint8_t * restrict src1_col = (const uint8_t *) src1_data;
+ float * restrict dst_col = (float *) dst->data;
+
+ // Prefill spad with 2x src0 rows
+ #pragma unroll(2)
+ for (uint32_t ir0 = src0_start_row; ir0 < src0_end_row_x2; ir0 += 2) {
+ const uint32_t is0 = (ir0 - src0_start_row);
+ if (is0 >= MM_SPAD_SRC0_NROWS) {
+ break;
+ }
+ dma_queue_push_ddr_to_vtcm(dma_queue, dma_make_ptr(spad_src0 + is0 * src0_stride, src0_row + ir0 * src0_row_size),
+ src0_stride, src0_row_size, 2);
+ }
+
+ // Process src0 rows
+ for (uint32_t ir0 = src0_start_row; ir0 < src0_end_row_x2; ir0 += 2) {
+ const uint8_t * ss0 = dma_queue_pop(dma_queue).dst;
+ mmctx->vec_dot_2x1(ne00, &tmp[ir0 - src0_start_row], ss0, ss0 + src0_stride, src1_col);
+
+        // Prefetch the row that is MM_SPAD_SRC0_NROWS ahead of the current one
+ const uint32_t pr0 = (ir0 + MM_SPAD_SRC0_NROWS);
+ const uint32_t is0 = (pr0 - src0_start_row) % MM_SPAD_SRC0_NROWS;
+ if (pr0 < src0_end_row_x2) {
+ dma_queue_push_ddr_to_vtcm(dma_queue, dma_make_ptr(spad_src0 + is0 * src0_stride, src0_row + pr0 * src0_row_size),
+ src0_stride, src0_row_size, 2);
+ }
+ }
+
+ // Process the last row (if any)
+ if (src0_end_row != src0_end_row_x2) {
+ const uint32_t ir0 = src0_end_row_x2;
+ const uint32_t is0 = (ir0 - src0_start_row);
+ dma_queue_push_ddr_to_vtcm(dma_queue, dma_make_ptr(spad_src0 + is0 * src0_stride, src0_row + ir0 * src0_row_size),
+ src0_stride, src0_row_size, 1);
+ const uint8_t * ss0 = dma_queue_pop(dma_queue).dst;
+ mmctx->vec_dot_1x1(ne00, &tmp[ir0 - src0_start_row], ss0, src1_col);
+ }
+
+ hvx_copy_f32_ua((uint8_t *) &dst_col[src0_start_row], (uint8_t *) tmp, src0_end_row - src0_start_row);
+
+ t2 = HAP_perf_get_qtimer_count();
+
+ FARF(HIGH, "matvec-%s %u/%u: %ux%ux%ux%u (%u:%u) * %ux%ux%ux%u -> %ux%ux%ux%u usec %u\n", mmctx->type, ith, nth,
+ src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3], src0_start_row, src0_end_row, src1->ne[0], src1->ne[1],
+ src1->ne[2], src1->ne[3], dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3],
+ (unsigned) HAP_perf_qtimer_count_to_us(t2 - t1));
+}
+
+#define MMID_MATRIX_ROW(row_id, i1) matrix_rows[(row_id) * ids->ne[0] * ids->ne[1] + (i1)]
+
+struct mmid_row_mapping {
+ uint32_t i1;
+ uint32_t i2;
+};
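+// Illustrative example (hypothetical routing): with n_expert_used = 2 and tokens
+// t0 -> {e0, e2}, t1 -> {e0, e1}, t2 -> {e2, e1}, the grouping loop in op_matmul_id()
+// produces matrix_row_counts = {e0: 2, e1: 2, e2: 2} and, for instance,
+// MMID_MATRIX_ROW(e2, 0) = { .i1 = 1, .i2 = 0 } (expert slot 1 of token 0).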
+
+// src1 tensor is already in VTCM spad
+static void matmul_id(unsigned int nth, unsigned int ith, void * data) {
+ htp_matmul_preamble;
+
+ struct htp_tensor * restrict ids = &octx->src2;
+ struct htp_spad * restrict src2_spad = &octx->src2_spad;
+
+ uint64_t t1, t2;
+ t1 = HAP_perf_get_qtimer_count();
+
+ const uint32_t src0_nrows = ne01; // src0 rows per expert
+ const uint32_t src1_nrows = ne11;
+
+ const uint32_t src0_start_row = src0_nrows_per_thread * ith;
+ const uint32_t src0_end_row = MIN(src0_start_row + src0_nrows_per_thread, src0_nrows);
+ const uint32_t src0_end_row_x2 = src0_start_row + ((src0_end_row - src0_start_row) & ~1U);
+
+ // no work for this thread
+ if (src0_start_row >= src0_end_row) {
+ return;
+ }
+
+ const uint32_t n_ids = ids->ne[0]; // n_expert_used
+ const uint32_t n_as = ne02; // n_expert
+
+ const size_t matrix_row_counts_size = n_as * sizeof(uint32_t);
+ const size_t matrix_row_map_size = n_as * ids->ne[0] * ids->ne[1] * sizeof(struct mmid_row_mapping);
+
+ const uint32_t * matrix_row_counts = (const uint32_t *) src2_spad->data + 0;
+ const struct mmid_row_mapping * matrix_rows = (const void *) src2_spad->data + matrix_row_counts_size;
+
+ const size_t dst_row_size = nb1;
+ const size_t src0_row_size = nb01;
+ const size_t src1_row_size = q8x4x2_row_size(ne10);
+
+ const size_t src0_row_size_padded = hex_round_up(src0_row_size, 128);
+
+ // Per-thread VTCM scratchpads for all tensors
+ // Note that the entire src1 tensor is already in VTCM
+ // For other tensors we allocate N rows per thread, padded to HVX vector size
+ uint8_t * restrict spad_dst = dst_spad->data + dst_spad->size_per_thread * ith;
+ uint8_t * restrict spad_src0 = src0_spad->data + src0_spad->size_per_thread * ith;
+ uint8_t * restrict src1_data = src1_spad->data;
+
+ for (uint32_t cur_a = 0; cur_a < n_as; ++cur_a) {
+ const int32_t cne1 = matrix_row_counts[cur_a];
+
+ if (cne1 == 0) {
+ continue;
+ }
+
+ const uint8_t * src0_row = (const uint8_t *) src0->data + (0 + cur_a * nb02 + 0);
+
+ // Prefill spad with src0 rows
+ #pragma unroll(4)
+ for (uint32_t ir0 = src0_start_row; ir0 < src0_end_row_x2; ir0 += 2) {
+ const int is0 = (ir0 - src0_start_row);
+ if (is0 >= MM_SPAD_SRC0_NROWS) {
+ break;
+ }
+ dma_queue_push_ddr_to_vtcm(dma_queue, dma_make_ptr(spad_src0 + is0 * src0_row_size_padded, src0_row + ir0 * src0_row_size),
+ src0_row_size_padded, src0_row_size, 2);
+ }
+
+ // Process src0 rows
+ for (uint32_t ir0 = src0_start_row; ir0 < src0_end_row_x2; ir0 += 2) {
+ const uint8_t * ss0 = dma_queue_pop(dma_queue).dst;
+
+ for (uint32_t cid = 0; cid < cne1; ++cid) {
+ struct mmid_row_mapping row_mapping = MMID_MATRIX_ROW(cur_a, cid);
+ const int rm1 = row_mapping.i1; // expert idx
+ const int rm2 = row_mapping.i2; // token idx
+
+ const uint32_t ir1 = src1_nrows == 1 ? 0 : rm1; // src1 row idx
+ const uint8_t * restrict src1_col = (const uint8_t *) (src1_data + (ir1 + rm2 * ne11 + 0) * src1_row_size);
+ float * dst_row = (float *) (dst->data + (rm1 * nb1 + rm2 * nb2 + 0));
+
+ mmctx->vec_dot_2x1(ne00, &dst_row[ir0], ss0, ss0 + src0_row_size_padded, src1_col);
+ }
+
+            // Prefetch the row that is MM_SPAD_SRC0_NROWS ahead of the current one
+ const int pr0 = (ir0 + MM_SPAD_SRC0_NROWS);
+ const int is0 = (pr0 - src0_start_row) % MM_SPAD_SRC0_NROWS;
+ if (pr0 < src0_end_row_x2) {
+ dma_queue_push_ddr_to_vtcm(dma_queue, dma_make_ptr(spad_src0 + is0 * src0_row_size_padded, src0_row + pr0 * src0_row_size),
+ src0_row_size_padded, src0_row_size, 2);
+ }
+ }
+
+ // Process the last row (if any)
+ if (src0_end_row != src0_end_row_x2) {
+ uint32_t ir0 = src0_end_row_x2;
+ const uint32_t is0 = (ir0 - src0_start_row);
+ dma_queue_push_ddr_to_vtcm(dma_queue, dma_make_ptr(spad_src0 + is0 * src0_row_size_padded, src0_row + ir0 * src0_row_size),
+ src0_row_size_padded, src0_row_size, 1);
+ const uint8_t * ss0 = dma_queue_pop(dma_queue).dst;
+
+ for (uint32_t cid = 0; cid < cne1; ++cid) {
+ struct mmid_row_mapping row_mapping = MMID_MATRIX_ROW(cur_a, cid);
+ const int rm1 = row_mapping.i1; // expert idx
+ const int rm2 = row_mapping.i2; // token idx
+
+ const uint32_t ir1 = src1_nrows == 1 ? 0 : rm1; // src1 row idx
+ const uint8_t * restrict src1_col = (const uint8_t *) (src1_data + (ir1 + rm2 * ne11 + 0) * src1_row_size);
+ float * dst_row = (float *) (dst->data + (rm1 * nb1 + rm2 * nb2 + 0));
+
+ mmctx->vec_dot_1x1(ne00, &dst_row[ir0], ss0, src1_col);
+ }
+ }
+ }
+
+ t2 = HAP_perf_get_qtimer_count();
+
+ FARF(HIGH, "matmul-id-%s %d/%d: %ux%ux%ux%u (%u:%u) * %ux%ux%ux%u (%ux%ux%ux%u) -> %ux%ux%ux%u usec %u\n", mmctx->type,
+ ith, nth, src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3], src0_start_row, src0_end_row, src1->ne[0],
+ src1->ne[1], src1->ne[2], src1->ne[3], ids->ne[0], ids->ne[1], ids->ne[2], ids->ne[3], dst->ne[0], dst->ne[1],
+ dst->ne[2], dst->ne[3], (unsigned) HAP_perf_qtimer_count_to_us(t2 - t1));
+}
+
+// src1 tensor is already in VTCM spad
+static void matvec_id(unsigned int nth, unsigned int ith, void * data) {
+ htp_matmul_preamble;
+
+ struct htp_tensor * restrict ids = &octx->src2;
+ struct htp_spad * restrict src2_spad = &octx->src2_spad;
+
+ uint64_t t1, t2;
+ t1 = HAP_perf_get_qtimer_count();
+
+ const uint32_t src0_nrows = ne01; // src0 rows per expert
+
+ const uint32_t src0_start_row = src0_nrows_per_thread * ith;
+ const uint32_t src0_end_row = MIN(src0_start_row + src0_nrows_per_thread, src0_nrows);
+ const uint32_t src0_end_row_x2 = src0_start_row + ((src0_end_row - src0_start_row) & ~1U);
+
+ // no work for this thread
+ if (src0_start_row >= src0_end_row) {
+ return;
+ }
+
+ assert(ne13 % ne03 == 0);
+
+ const size_t dst_row_size = nb1;
+ const size_t src0_row_size = nb01;
+ const size_t src1_row_size = q8x4x2_row_size(ne10);
+
+ const size_t src0_row_size_padded = hex_round_up(src0_row_size, 128);
+
+ const uint32_t n_aids = src2->ne[0]; // num activated experts
+ const uint32_t n_ids = ne02; // num experts
+
+ // Per-thread VTCM scratchpads for all tensors
+ // Note that the entire src1 tensor is already in VTCM
+ // For other tensors we allocate N rows per thread, padded to HVX vector size
+ uint8_t * restrict spad_dst = dst_spad->data + dst_spad->size_per_thread * ith;
+ uint8_t * restrict spad_src0 = src0_spad->data + src0_spad->size_per_thread * ith;
+ uint8_t * restrict src1_data = src1_spad->data;
+
+ for (uint32_t ie1 = 0; ie1 < n_aids; ++ie1) { // for each expert
+ const uint32_t eid = *(const int32_t *) ((const uint8_t *) src2->data + ie1 * src2->nb[0]);
+ assert(eid < n_ids);
+
+ const uint8_t * restrict src0_row = (const uint8_t *) src0->data + eid * nb02;
+ const uint8_t * restrict src1_col = (const uint8_t *) src1_data;
+ float * restrict dst_row = (float *) (dst->data + ie1 * nb1);
+
+ // Prefill spad with src0 rows
+ #pragma unroll(4)
+ for (uint32_t ir0 = src0_start_row; ir0 < src0_end_row_x2; ir0 += 2) {
+ const int is0 = (ir0 - src0_start_row);
+ if (is0 >= MM_SPAD_SRC0_NROWS) {
+ break;
+ }
+ dma_queue_push_ddr_to_vtcm(dma_queue, dma_make_ptr(spad_src0 + is0 * src0_row_size_padded, src0_row + ir0 * src0_row_size),
+ src0_row_size_padded, src0_row_size, 2);
+ }
+
+ // Process src0 rows
+ for (uint32_t ir0 = src0_start_row; ir0 < src0_end_row_x2; ir0 += 2) {
+ const uint8_t * ss0 = dma_queue_pop(dma_queue).dst;
+ mmctx->vec_dot_2x1(ne00, &dst_row[ir0], ss0, ss0 + src0_row_size_padded, src1_col);
+
+            // Prefetch the row that is MM_SPAD_SRC0_NROWS ahead of the current one
+ const int pr0 = (ir0 + MM_SPAD_SRC0_NROWS);
+ const int is0 = (pr0 - src0_start_row) % MM_SPAD_SRC0_NROWS;
+ if (pr0 < src0_end_row_x2) {
+ dma_queue_push_ddr_to_vtcm(dma_queue, dma_make_ptr(spad_src0 + is0 * src0_row_size_padded, src0_row + pr0 * src0_row_size),
+ src0_row_size_padded, src0_row_size, 2);
+ }
+ }
+
+ // Process the last row (if any)
+ if (src0_end_row != src0_end_row_x2) {
+ uint32_t ir0 = src0_end_row_x2;
+ const uint32_t is0 = (ir0 - src0_start_row);
+ dma_queue_push_ddr_to_vtcm(dma_queue, dma_make_ptr(spad_src0 + is0 * src0_row_size_padded, src0_row + ir0 * src0_row_size),
+ src0_row_size_padded, src0_row_size, 1);
+ const uint8_t * ss0 = dma_queue_pop(dma_queue).dst;
+ mmctx->vec_dot_1x1(ne00, &dst_row[ir0], ss0, src1_col);
+ }
+ }
+
+ t2 = HAP_perf_get_qtimer_count();
+
+ FARF(HIGH, "matvec-id-%s %d/%d: %ux%ux%ux%u (%u:%u) * %ux%ux%ux%u (%ux%ux%ux%u) -> %ux%ux%ux%u usec %u\n", mmctx->type,
+ ith, nth, src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3], src0_start_row, src0_end_row, src1->ne[0],
+ src1->ne[1], src1->ne[2], src1->ne[3], src2->ne[0], src2->ne[1], src2->ne[2], src2->ne[3], dst->ne[0],
+ dst->ne[1], dst->ne[2], dst->ne[3], (unsigned) HAP_perf_qtimer_count_to_us(t2 - t1));
+}
+
+// *** dynamic quant
+
+static inline void quantize_block_f32_q8x1(float * restrict x, uint8_t * restrict y_q, uint8_t * restrict y_d) {
+ assert((unsigned long) x % 128 == 0);
+ assert((unsigned long) y_q % 128 == 0);
+
+ HVX_Vector * vx = (HVX_Vector *) x;
+ HVX_Vector zero = Q6_V_vsplat_R(0);
+
+ // Use reduce max fp32 to find max(abs(e)) first
+ HVX_Vector vmax0_sf = hvx_vec_reduce_max_f32(hvx_vec_abs_f32(vx[0]));
+ HVX_Vector vmax1_sf = hvx_vec_reduce_max_f32(hvx_vec_abs_f32(vx[1]));
+ HVX_Vector vmax2_sf = hvx_vec_reduce_max_f32(hvx_vec_abs_f32(vx[2]));
+ HVX_Vector vmax3_sf = hvx_vec_reduce_max_f32(hvx_vec_abs_f32(vx[3]));
+ // Load and convert into QF32
+ HVX_Vector vx0_qf = Q6_Vqf32_vsub_VsfVsf(vx[0], zero); // 32 elements
+ HVX_Vector vx1_qf = Q6_Vqf32_vsub_VsfVsf(vx[1], zero); // 32 elements
+ HVX_Vector vx2_qf = Q6_Vqf32_vsub_VsfVsf(vx[2], zero); // 32 elements
+ HVX_Vector vx3_qf = Q6_Vqf32_vsub_VsfVsf(vx[3], zero); // 32 elements
+
+ // Convert to QF32
+ HVX_Vector vmax0_qf = Q6_Vqf32_vsub_VsfVsf(vmax0_sf, zero);
+ HVX_Vector vmax1_qf = Q6_Vqf32_vsub_VsfVsf(vmax1_sf, zero);
+ HVX_Vector vmax2_qf = Q6_Vqf32_vsub_VsfVsf(vmax2_sf, zero);
+ HVX_Vector vmax3_qf = Q6_Vqf32_vsub_VsfVsf(vmax3_sf, zero);
+
+ // Combine and convert to fp16
+ HVX_Vector vmax01_hf = Q6_Vh_vdeal_Vh(Q6_Vhf_equals_Wqf32(Q6_W_vcombine_VV(vmax1_qf, vmax0_qf)));
+ HVX_Vector vmax23_hf = Q6_Vh_vdeal_Vh(Q6_Vhf_equals_Wqf32(Q6_W_vcombine_VV(vmax3_qf, vmax2_qf)));
+
+ // Convert into fp16
+ HVX_Vector vx01_hf = Q6_Vh_vdeal_Vh(Q6_Vhf_equals_Wqf32(Q6_W_vcombine_VV(vx1_qf, vx0_qf)));
+ HVX_Vector vx23_hf = Q6_Vh_vdeal_Vh(Q6_Vhf_equals_Wqf32(Q6_W_vcombine_VV(vx3_qf, vx2_qf)));
+
+ // Replicate first fp16 scale across all lanes
+ HVX_Vector ctrl = *(const HVX_Vector *) repl_2x_f16;
+ vmax01_hf = Q6_V_vdelta_VV(vmax01_hf, ctrl);
+ vmax23_hf = Q6_V_vdelta_VV(vmax23_hf, ctrl);
+
+ HVX_Vector vd01_qf16 = Q6_Vqf16_vmpy_VhfVhf(vmax01_hf, Q6_Vh_vsplat_R(0x2008)); // 1.0 / 127.0
+ HVX_Vector vd23_qf16 = Q6_Vqf16_vmpy_VhfVhf(vmax23_hf, Q6_Vh_vsplat_R(0x2008)); // 1.0 / 127.0
+ HVX_Vector vd01_hf = Q6_Vhf_equals_Vqf16(vd01_qf16);
+ HVX_Vector vd23_hf = Q6_Vhf_equals_Vqf16(vd23_qf16);
+
+ hvx_vec_store_u(y_d + 0, 2, vd01_hf);
+ HVX_Vector rotated_vd_hf = Q6_V_vror_VR(vd01_hf, 64);
+ hvx_vec_store_u(y_d + 2, 2, rotated_vd_hf);
+
+ hvx_vec_store_u(y_d + 4, 2, vd23_hf);
+ rotated_vd_hf = Q6_V_vror_VR(vd23_hf, 64);
+ hvx_vec_store_u(y_d + 6, 2, rotated_vd_hf);
+
+ // Divide input by the scale
+ HVX_Vector vd01_inv_hf = hvx_vec_inverse_f16(vd01_hf);
+ HVX_Vector vd23_inv_hf = hvx_vec_inverse_f16(vd23_hf);
+ vx01_hf = Q6_Vhf_equals_Vqf16(Q6_Vqf16_vmpy_VhfVhf(vx01_hf, vd01_inv_hf));
+ vx23_hf = Q6_Vhf_equals_Vqf16(Q6_Vqf16_vmpy_VhfVhf(vx23_hf, vd23_inv_hf));
+
+ // Convert to int8
+ HVX_Vector vx01_i16 = hvx_vec_i16_from_hf_rnd_sat(vx01_hf);
+ HVX_Vector vx23_i16 = hvx_vec_i16_from_hf_rnd_sat(vx23_hf);
+ HVX_Vector vx_i8 = Q6_Vb_vpack_VhVh_sat(vx23_i16, vx01_i16);
+
+ *(HVX_Vector *) y_q = vx_i8;
+}
+
+static inline void quantize_block_f32_q8x2(float * restrict x, uint8_t * restrict y_q, uint8_t * restrict y_d) {
+ assert((unsigned long) x % 128 == 0);
+ assert((unsigned long) y_q % 128 == 0);
+
+ HVX_Vector * vx = (HVX_Vector *) x;
+
+ // Load and convert into QF32
+ HVX_Vector zero = Q6_V_vsplat_R(0);
+ HVX_Vector vx0_qf = Q6_Vqf32_vsub_VsfVsf(vx[0], zero); // 32 elements
+ HVX_Vector vx1_qf = Q6_Vqf32_vsub_VsfVsf(vx[1], zero); // 32 elements
+ HVX_Vector vx2_qf = Q6_Vqf32_vsub_VsfVsf(vx[2], zero); // 32 elements
+ HVX_Vector vx3_qf = Q6_Vqf32_vsub_VsfVsf(vx[3], zero); // 32 elements
+
+ // Convert into fp16
+ HVX_Vector vx01_hf = Q6_Vh_vdeal_Vh(Q6_Vhf_equals_Wqf32(Q6_W_vcombine_VV(vx1_qf, vx0_qf)));
+ HVX_Vector vx23_hf = Q6_Vh_vdeal_Vh(Q6_Vhf_equals_Wqf32(Q6_W_vcombine_VV(vx3_qf, vx2_qf)));
+
+ // Compute max and scale
+ HVX_Vector vmax01_hf = hvx_vec_reduce_max_f16(hvx_vec_abs_f16(vx01_hf));
+ HVX_Vector vmax23_hf = hvx_vec_reduce_max_f16(hvx_vec_abs_f16(vx23_hf));
+
+ // Replicate first fp16 scale across all lanes
+ HVX_Vector ctrl = *(const HVX_Vector *) repl_1x_f16;
+ vmax01_hf = Q6_V_vdelta_VV(vmax01_hf, ctrl);
+ vmax23_hf = Q6_V_vdelta_VV(vmax23_hf, ctrl);
+
+ HVX_Vector vd01_qf16 = Q6_Vqf16_vmpy_VhfVhf(vmax01_hf, Q6_Vh_vsplat_R(0x2008)); // 1.0 / 127.0
+ HVX_Vector vd23_qf16 = Q6_Vqf16_vmpy_VhfVhf(vmax23_hf, Q6_Vh_vsplat_R(0x2008)); // 1.0 / 127.0
+ HVX_Vector vd01_hf = Q6_Vhf_equals_Vqf16(vd01_qf16);
+ HVX_Vector vd23_hf = Q6_Vhf_equals_Vqf16(vd23_qf16);
+
+ hvx_vec_store_u(y_d + 0, 4, vd01_hf);
+ hvx_vec_store_u(y_d + 4, 4, vd23_hf);
+
+ // Divide input by the scale
+ HVX_Vector vd01_inv_hf = hvx_vec_inverse_f16(vd01_hf);
+ HVX_Vector vd23_inv_hf = hvx_vec_inverse_f16(vd23_hf);
+ vx01_hf = Q6_Vhf_equals_Vqf16(Q6_Vqf16_vmpy_VhfVhf(vx01_hf, vd01_inv_hf));
+ vx23_hf = Q6_Vhf_equals_Vqf16(Q6_Vqf16_vmpy_VhfVhf(vx23_hf, vd23_inv_hf));
+
+ // Convert to int8
+ HVX_Vector vx01_i16 = hvx_vec_i16_from_hf_rnd_sat(vx01_hf);
+ HVX_Vector vx23_i16 = hvx_vec_i16_from_hf_rnd_sat(vx23_hf);
+ HVX_Vector vx_i8 = Q6_Vb_vpack_VhVh_sat(vx23_i16, vx01_i16);
+
+ *(HVX_Vector *) y_q = vx_i8;
+}
+
+static inline void quantize_block_f32_q8x4(float * restrict x, uint8_t * restrict y_q, uint8_t * restrict y_d) {
+ assert((unsigned long) x % 128 == 0);
+ assert((unsigned long) y_q % 128 == 0);
+
+ HVX_Vector * vx = (HVX_Vector *) x;
+
+ // Load and convert into QF32
+ HVX_Vector zero = Q6_V_vsplat_R(0);
+ HVX_Vector vx0_qf = Q6_Vqf32_vsub_VsfVsf(vx[0], zero); // 32 elements
+ HVX_Vector vx1_qf = Q6_Vqf32_vsub_VsfVsf(vx[1], zero); // 32 elements
+ HVX_Vector vx2_qf = Q6_Vqf32_vsub_VsfVsf(vx[2], zero); // 32 elements
+ HVX_Vector vx3_qf = Q6_Vqf32_vsub_VsfVsf(vx[3], zero); // 32 elements
+
+ // Convert into fp16
+ HVX_Vector vx01_hf = Q6_Vh_vdeal_Vh(Q6_Vhf_equals_Wqf32(Q6_W_vcombine_VV(vx1_qf, vx0_qf)));
+ HVX_Vector vx23_hf = Q6_Vh_vdeal_Vh(Q6_Vhf_equals_Wqf32(Q6_W_vcombine_VV(vx3_qf, vx2_qf)));
+
+ // Compute max and scale
+ HVX_Vector vmax_hf = hvx_vec_reduce_max_f16(hvx_vec_abs_f16(vx01_hf));
+ vmax_hf = hvx_vec_reduce_max2_f16(hvx_vec_abs_f16(vx23_hf), vmax_hf);
+
+ // Replicate first fp16 scale across all lanes
+ HVX_Vector ctrl = *(const HVX_Vector *) repl_1x_f16;
+ vmax_hf = Q6_V_vdelta_VV(vmax_hf, ctrl);
+
+ HVX_Vector vd_qf16 = Q6_Vqf16_vmpy_VhfVhf(vmax_hf, Q6_Vh_vsplat_R(0x2008)); // 1.0 / 127.0
+ HVX_Vector vd_hf = Q6_Vhf_equals_Vqf16(vd_qf16);
+
+ *(HVX_UVector *) y_d = vd_hf;
+
+ // Divide input by the scale
+ HVX_Vector vd_inv_hf = hvx_vec_inverse_f16(vd_hf);
+ vx01_hf = Q6_Vhf_equals_Vqf16(Q6_Vqf16_vmpy_VhfVhf(vx01_hf, vd_inv_hf));
+ vx23_hf = Q6_Vhf_equals_Vqf16(Q6_Vqf16_vmpy_VhfVhf(vx23_hf, vd_inv_hf));
+
+ // Convert to int8
+ HVX_Vector vx01_i16 = hvx_vec_i16_from_hf_rnd_sat(vx01_hf);
+ HVX_Vector vx23_i16 = hvx_vec_i16_from_hf_rnd_sat(vx23_hf);
+ HVX_Vector vx_i8 = Q6_Vb_vpack_VhVh_sat(vx23_i16, vx01_i16);
+
+ *(HVX_Vector *) y_q = vx_i8;
+}
+
+// Note: overwrites the input x, which is reused as a temporary buffer for the scales
+static void quantize_row_f32_q8x4x2(float * restrict x, uint8_t * restrict y, uint32_t k) {
+ assert(k % 32 == 0);
+ const uint32_t qk = QK_Q8_0x4x2;
+ const uint32_t nb = (k + qk - 1) / qk;
+
+ const uint32_t qrow_size = k; // int8
+
+ const uint32_t dblk_size = 8 * 2; // 8x __fp16
+ const uint32_t qblk_size = QK_Q8_0x4x2; // int8
+
+ uint8_t * restrict y_q = (y + 0); // quants first
+ uint8_t * restrict y_d = (y + qrow_size); // then scales
+
+    // The temporary scales overwrite the input, since we are working off of the aligned temp buffer in VTCM
+ uint8_t * restrict t_d = (uint8_t *) x;
+
+ for (uint32_t i = 0; i < nb; i++) {
+#if FP32_QUANTIZE_GROUP_SIZE == 32
+ quantize_block_f32_q8x1(x + (i*2 + 0) * qk/2, y_q + (i*2 + 0) * qblk_size/2, t_d + (i*2 + 0) * dblk_size/2);
+ quantize_block_f32_q8x1(x + (i*2 + 1) * qk/2, y_q + (i*2 + 1) * qblk_size/2, t_d + (i*2 + 1) * dblk_size/2);
+#elif FP32_QUANTIZE_GROUP_SIZE == 64
+ quantize_block_f32_q8x2(x + (i*2 + 0) * qk/2, y_q + (i*2 + 0) * qblk_size/2, t_d + (i*2 + 0) * dblk_size/2);
+ quantize_block_f32_q8x2(x + (i*2 + 1) * qk/2, y_q + (i*2 + 1) * qblk_size/2, t_d + (i*2 + 1) * dblk_size/2);
+#elif FP32_QUANTIZE_GROUP_SIZE == 128
+ quantize_block_f32_q8x4(x + (i*2 + 0) * qk/2, y_q + (i*2 + 0) * qblk_size/2, t_d + (i*2 + 0) * dblk_size/2);
+ quantize_block_f32_q8x4(x + (i*2 + 1) * qk/2, y_q + (i*2 + 1) * qblk_size/2, t_d + (i*2 + 1) * dblk_size/2);
+#else
+#error "FP32_QUANTIZE_GROUP_SIZE must be 32, 64, or 128"
+#endif
+ }
+
+ // now copy the scales into final location
+ hvx_copy_f16_ua(y_d, t_d, nb * 8);
+}
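+// Scalar sketch of the per-group quantization above (illustrative only):
+//
+//   for each group of FP32_QUANTIZE_GROUP_SIZE floats x[]:
+//       d    = max(|x[i]|) / 127;              // one fp16 scale per group
+//       q[i] = round_to_i8_sat(x[i] / d);      // int8 quants
+//
+// The resulting row layout is all int8 quants first (k bytes), followed by the fp16
+// scales (8 entries per QK_Q8_0x4x2 block), which is what q8x4x2_row_size() accounts for.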
+
+static void quantize_f32_q8x4x2(unsigned int nth, unsigned int ith, void * data) {
+ struct htp_matmul_context * mmctx = data;
+ struct htp_ops_context * octx = mmctx->octx;
+
+ const struct htp_tensor * src = &octx->src1;
+ uint8_t * restrict dst = octx->src1_spad.data;
+ struct htp_spad * spad = &octx->src0_spad;
+ uint32_t nrows_per_thread = mmctx->src1_nrows_per_thread;
+
+ uint64_t t1 = HAP_perf_get_qtimer_count();
+
+ const uint32_t ne0 = src->ne[0];
+ const uint32_t ne1 = src->ne[1];
+ const uint32_t ne2 = src->ne[2];
+ const uint32_t ne3 = src->ne[3];
+
+ const uint32_t nrows = ne1 * ne2 * ne3; // total n_rows
+
+ const uint32_t ir_first = nrows_per_thread * ith; // first row
+ const uint32_t ir_last = MIN(ir_first + nrows_per_thread, nrows); // last row
+
+ const size_t src_row_size = src->nb[1];
+ const size_t dst_row_size = q8x4x2_row_size(ne0);
+
+ uint8_t * restrict src_data = (uint8_t *) src->data + (src_row_size * ir_first);
+ uint8_t * restrict dst_data = (uint8_t *) dst + (dst_row_size * ir_first);
+ uint8_t * restrict tmp_data = (uint8_t *) spad->data + (spad->size_per_thread * ith);
+
+ const size_t src_row_size_padded = hex_round_up(src_row_size, QK_Q8_0x4x2 * sizeof(float));
+ memset(tmp_data, 0, src_row_size_padded); // zero-out temp row data for padding
+
+ for (uint32_t i = ir_first; i < ir_last; ++i) {
+ hex_l2fetch(src_data, src_row_size, src_row_size, 2);
+ hvx_copy_f32_aa(tmp_data, src_data, ne0);
+
+ // FARF(HIGH, "quantize-q8x4-row: %u\n", i);
+ quantize_row_f32_q8x4x2((float *) tmp_data, dst_data, ne0);
+ dst_data += dst_row_size;
+ src_data += src_row_size;
+ }
+
+ uint64_t t2 = HAP_perf_get_qtimer_count();
+
+ FARF(HIGH, "quantize-f32-q8x4: %u/%u : n-rows %u (%u:%u) row-size %u -> %u usec %u\n", ith, nth, nrows, ir_first,
+ ir_last, src_row_size, dst_row_size, (unsigned) HAP_perf_qtimer_count_to_us(t2 - t1));
+}
+
+static void quantize_f32_f16(unsigned int nth, unsigned int ith, void * data) {
+ struct htp_matmul_context * mmctx = data;
+ struct htp_ops_context * octx = mmctx->octx;
+
+ const struct htp_tensor * src = &octx->src1;
+ uint8_t * restrict dst = octx->src1_spad.data;
+ uint32_t nrows_per_thread = mmctx->src1_nrows_per_thread;
+ uint32_t dst_stride = octx->src1_spad.stride;
+
+ uint64_t t1 = HAP_perf_get_qtimer_count();
+
+ const uint32_t ne0 = src->ne[0];
+ const uint32_t ne1 = src->ne[1];
+ const uint32_t ne2 = src->ne[2];
+ const uint32_t ne3 = src->ne[3];
+
+ const uint32_t nrows = ne1 * ne2 * ne3; // total n_rows
+
+ const uint32_t ir_first = nrows_per_thread * ith; // first row
+ const uint32_t ir_last = MIN(ir_first + nrows_per_thread, nrows); // last row
+
+ const size_t src_row_size = ne0 * sizeof(float);
+ const size_t src_stride = src->nb[1];
+
+ uint8_t * restrict src_data = (uint8_t *) src->data + (src_stride * ir_first);
+ uint8_t * restrict dst_data = (uint8_t *) dst + (dst_stride * ir_first);
+
+ for (uint32_t i = ir_first; i < ir_last; ++i) {
+ hex_l2fetch(src_data, src_row_size, src_stride, 2);
+ hvx_copy_f16_f32_au(dst_data, src_data, ne0);
+
+ dst_data += dst_stride;
+ src_data += src_stride;
+ }
+
+ uint64_t t2 = HAP_perf_get_qtimer_count();
+
+ FARF(HIGH, "quantize-f32-f16: %u/%u : n-rows %u (%u:%u) row-size %u (%u) -> %u usec %u\n", ith, nth, nrows, ir_first,
+ ir_last, src_row_size, src_stride, dst_stride, (unsigned) HAP_perf_qtimer_count_to_us(t2 - t1));
+}
+
+// TODO: this is just a plain copy; it should be done via DMA during the op setup
+static void quantize_f16_f16(unsigned int nth, unsigned int ith, void * data) {
+ struct htp_matmul_context * mmctx = data;
+ struct htp_ops_context * octx = mmctx->octx;
+
+ const struct htp_tensor * src = &octx->src1;
+ uint8_t * restrict dst = octx->src1_spad.data;
+ uint32_t nrows_per_thread = mmctx->src1_nrows_per_thread;
+ uint32_t dst_stride = octx->src1_spad.stride;
+
+ uint64_t t1 = HAP_perf_get_qtimer_count();
+
+ const uint32_t ne0 = src->ne[0];
+ const uint32_t ne1 = src->ne[1];
+ const uint32_t ne2 = src->ne[2];
+ const uint32_t ne3 = src->ne[3];
+
+ const uint32_t nrows = ne1 * ne2 * ne3; // total n_rows
+
+ const uint32_t ir_first = nrows_per_thread * ith; // first row
+ const uint32_t ir_last = MIN(ir_first + nrows_per_thread, nrows); // last row
+
+ const size_t src_row_size = ne0 * sizeof(float);
+ const size_t src_stride = src->nb[1];
+
+ uint8_t * restrict src_data = (uint8_t *) src->data + (src_stride * ir_first);
+ uint8_t * restrict dst_data = (uint8_t *) dst + (dst_stride * ir_first);
+
+ for (uint32_t i = ir_first; i < ir_last; ++i) {
+ hex_l2fetch(src_data, src_row_size, src_stride, 2);
+ hvx_copy_f16_au(dst_data, src_data, ne0);
+
+ dst_data += dst_stride;
+ src_data += src_stride;
+ }
+
+ uint64_t t2 = HAP_perf_get_qtimer_count();
+
+ FARF(HIGH, "quantize-f16-f16: %u/%u : n-rows %u (%u:%u) row-size %u (%u) -> %u usec %u\n", ith, nth, nrows, ir_first,
+ ir_last, src_row_size, src_stride, dst_stride, (unsigned) HAP_perf_qtimer_count_to_us(t2 - t1));
+}
+
+
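+// A tensor view counts as permuted when its byte strides are not non-decreasing,
+// e.g. a transposed view with nb[0] > nb[1]; contiguous row-major tensors satisfy
+// nb[0] <= nb[1] <= nb[2] <= nb[3].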
+static inline bool htp_is_permuted(const struct htp_tensor * t) {
+ return t->nb[0] > t->nb[1] || t->nb[1] > t->nb[2] || t->nb[2] > t->nb[3];
+}
+
+static int htp_mminit_vec_dot(struct htp_matmul_context * mmctx, enum htp_data_type type) {
+ switch (type) {
+ case HTP_TYPE_Q4_0:
+ mmctx->type = "q4x4x2-f32";
+ mmctx->vec_dot_1x1 = vec_dot_q4x4x2_q8x4x2_1x1;
+ mmctx->vec_dot_2x1 = vec_dot_q4x4x2_q8x4x2_2x1;
+ mmctx->vec_dot_2x2 = vec_dot_q4x4x2_q8x4x2_2x2;
+ return 0;
+ case HTP_TYPE_Q8_0:
+ mmctx->type = "q8x4x2-f32";
+ mmctx->vec_dot_1x1 = vec_dot_q8x4x2_q8x4x2_1x1;
+ mmctx->vec_dot_2x1 = vec_dot_q8x4x2_q8x4x2_2x1;
+ mmctx->vec_dot_2x2 = vec_dot_q8x4x2_q8x4x2_2x2;
+ return 0;
+ case HTP_TYPE_MXFP4:
+ mmctx->type = "mxfp4x4x2-f32";
+ mmctx->vec_dot_1x1 = vec_dot_mxfp4x4x2_q8x4x2_1x1;
+ mmctx->vec_dot_2x1 = vec_dot_mxfp4x4x2_q8x4x2_2x1;
+ mmctx->vec_dot_2x2 = vec_dot_mxfp4x4x2_q8x4x2_2x2;
+ return 0;
+ default:
+ return -1;
+ }
+}
+
+static void htp_mminit_spad(struct htp_ops_context * octx,
+ size_t dst_row_size,
+ size_t src0_row_size_padded,
+ size_t src1_row_size,
+ uint32_t src1_nrows,
+ size_t src2_spad_size_per_thread) {
+ octx->dst_spad.size_per_thread = hex_round_up(MM_SPAD_DST_NROWS * dst_row_size, 256);
+ octx->src0_spad.size_per_thread = hex_round_up(MM_SPAD_SRC0_NROWS * src0_row_size_padded, 256);
+ octx->src1_spad.size_per_thread = hex_round_up(src1_row_size * src1_nrows, 256);
+
+ if (src2_spad_size_per_thread > 0) {
+ octx->src2_spad.size_per_thread = src2_spad_size_per_thread;
+ octx->src2_spad.size = octx->src2_spad.size_per_thread;
+ }
+
+    // The src0 spad is also used by the dynamic quantizer to store padded src1 rows
+ size_t src1_row_size_padded = hex_round_up(src1_row_size, QK_Q8_0x4x2 * sizeof(float));
+ if (octx->src0_spad.size_per_thread < src1_row_size_padded) {
+ octx->src0_spad.size_per_thread = src1_row_size_padded;
+ }
+
+ octx->src1_spad.size = octx->src1_spad.size_per_thread;
+ octx->src0_spad.size = octx->src0_spad.size_per_thread * octx->n_threads;
+ octx->dst_spad.size = octx->dst_spad.size_per_thread * octx->n_threads;
+}
+
+int op_matmul(struct htp_ops_context * octx) {
+ htp_matmul_tensors_preamble;
+
+ struct htp_matmul_context mmctx_struct = {0};
+ struct htp_matmul_context * mmctx = &mmctx_struct;
+ mmctx->octx = octx;
+
+ const uint32_t src0_nrows = ne01 * ne02 * ne03;
+ const uint32_t src1_nrows = ne11 * ne12 * ne13;
+
+ // Compute src0_nrows_per_thread
+ mmctx->src0_nrows_per_thread = (src0_nrows + octx->n_threads - 1) / octx->n_threads;
+ mmctx->src0_nrows_per_thread += (mmctx->src0_nrows_per_thread & 1); // round up to even
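+    // e.g. (illustrative) src0_nrows = 4097 with n_threads = 4: 4097/4 rounds up to
+    // 1025 and then to 1026, keeping per-thread row counts even since the 2x2/2x1
+    // kernels consume src0 rows in pairs; the final thread simply gets a shorter tail.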
+
+ const size_t src0_row_size = nb01;
+ const size_t dst_row_size = nb1;
+ size_t src1_row_size = nb11;
+
+ const size_t src0_row_size_padded = hex_round_up(src0_row_size, 128);
+ size_t src1_row_size_padded;
+
+ worker_callback_t quant_job_func;
+ worker_callback_t matmul_job_func = src1_nrows > 1 ? matmul_2d : matvec_2d;
+
+ bool need_quant = !(octx->flags & HTP_OPFLAGS_SKIP_QUANTIZE);
+
+ if (src0->type == HTP_TYPE_F16) {
+ // Try optimized f16-f16 path first (src1 in VTCM)
+ const size_t f16_src1_row_size = hex_round_up(ne10 * 2, 128);
+ const size_t f16_src1_spad_size = hex_round_up(f16_src1_row_size * src1_nrows, 256);
+ const size_t f16_src0_spad_size = hex_round_up(MM_SPAD_SRC0_NROWS * src0_row_size_padded, 256) * octx->n_threads;
+ const size_t f16_dst_spad_size = hex_round_up(MM_SPAD_DST_NROWS * dst_row_size, 256) * octx->n_threads;
+
+ const size_t f16_total_size = f16_src1_spad_size + f16_src0_spad_size + f16_dst_spad_size;
+
+ // Default matmul implementation does not support multi-batch src0 (N-vs-N broadcasting).
+ // It only supports 1-vs-N broadcasting (src0 is 2D) or standard 2D matmul.
+ const bool is_batched = (ne02 > 1) || (ne03 > 1);
+ const bool is_permuted = htp_is_permuted(&octx->src0) || htp_is_permuted(&octx->src1);
+
+ if (!is_batched && !is_permuted && f16_total_size <= octx->ctx->vtcm_size) {
+ // Optimized path
+ quant_job_func = (src1->type == HTP_TYPE_F32) ? quantize_f32_f16 : quantize_f16_f16;
+ mmctx->type = "f16-f16";
+ mmctx->vec_dot_1x1 = vec_dot_f16_f16_aa_1x1;
+ mmctx->vec_dot_2x1 = vec_dot_f16_f16_aa_2x1;
+ mmctx->vec_dot_2x2 = vec_dot_f16_f16_aa_2x2;
+
+ src1_row_size = f16_src1_row_size; // row size post quantization
+
+ octx->dst_spad.size_per_thread = hex_round_up(MM_SPAD_DST_NROWS * dst_row_size, 256);
+ octx->src0_spad.size_per_thread = hex_round_up(MM_SPAD_SRC0_NROWS * src0_row_size_padded, 256);
+ octx->src1_spad.size_per_thread = hex_round_up(src1_row_size * src1_nrows, 256);
+
+ octx->src1_spad.size = octx->src1_spad.size_per_thread;
+ octx->src0_spad.size = octx->src0_spad.size_per_thread * octx->n_threads;
+ octx->dst_spad.size = octx->dst_spad.size_per_thread * octx->n_threads;
+ } else {
+        // Fall back to the f16/f32 (DDR) path if src1 doesn't fit in VTCM or broadcasting is required
+ quant_job_func = NULL;
+ if (src1->type == HTP_TYPE_F32) {
+ mmctx->type = "f16-f32";
+ mmctx->vec_dot_1x1 = vec_dot_f16_f32_uu_1x1;
+ matmul_job_func = matmul_4d;
+ } else {
+ mmctx->type = "f16-f16";
+ mmctx->vec_dot_1x1 = vec_dot_f16_f16_uu_1x1;
+ matmul_job_func = matmul_4d;
+ }
+
+ src1_row_size = nb11; // original row size in DDR
+
+ octx->dst_spad.size_per_thread = hex_round_up(MM_SPAD_DST_NROWS * dst_row_size, 256);
+ octx->src0_spad.size_per_thread = hex_round_up(MM_SPAD_SRC0_NROWS * src0_row_size, 256);
+ octx->src1_spad.size_per_thread = hex_round_up(MM_SPAD_SRC1_NROWS * src1_row_size, 256);
+
+ octx->src0_spad.size = octx->src0_spad.size_per_thread * octx->n_threads;
+ octx->src1_spad.size = octx->src1_spad.size_per_thread * octx->n_threads;
+ octx->dst_spad.size = octx->dst_spad.size_per_thread * octx->n_threads;
+
+ // Init fastdiv for matmul_4d (supports broadcasting)
+ mmctx->mm_div_ne12_ne1 = init_fastdiv_values(src1->ne[2] * dst->ne[1]);
+ mmctx->mm_div_ne1 = init_fastdiv_values(dst->ne[1]);
+ mmctx->mm_div_r2 = init_fastdiv_values(src1->ne[2] / src0->ne[2]);
+ mmctx->mm_div_r3 = init_fastdiv_values(src1->ne[3] / src0->ne[3]);
+
+ need_quant = false;
+ }
+ } else {
+ if (htp_mminit_vec_dot(mmctx, src0->type) != 0) {
+ return HTP_STATUS_NO_SUPPORT;
+ }
+
+ quant_job_func = quantize_f32_q8x4x2;
+ src1_row_size = q8x4x2_row_size(ne10);
+ htp_mminit_spad(octx, dst_row_size, src0_row_size_padded, src1_row_size, src1_nrows, 0);
+ }
+
+ // VTCM scratchpads for all tensors
+ size_t spad_size = octx->src1_spad.size + octx->src0_spad.size + octx->dst_spad.size;
+
+ FARF(HIGH, "matmul-%s : src0-spad-size %u src1-spad-size %u dst-spad-size %u (%zu)\n", mmctx->type,
+ octx->src0_spad.size, octx->src1_spad.size, octx->dst_spad.size, spad_size);
+
+ FARF(HIGH, "matmul-%s : %ux%ux%ux%u * %ux%ux%ux%u-> %ux%ux%ux%u (0x%p, 0x%p, 0x%p)\n", mmctx->type, src0->ne[0],
+ src0->ne[1], src0->ne[2], src0->ne[3], src1->ne[0], src1->ne[1], src1->ne[2], src1->ne[3], dst->ne[0],
+ dst->ne[1], dst->ne[2], dst->ne[3], src0->data, src1->data, dst->data);
+
+ // Make sure the reserved vtcm size is sufficient
+ if (octx->ctx->vtcm_size < spad_size) {
+ FARF(ERROR, "matmul-%s : current VTCM reservation %zu is too small, needed %zu\n", mmctx->type,
+ octx->ctx->vtcm_size, spad_size);
+ return HTP_STATUS_VTCM_TOO_SMALL;
+ }
+
+ octx->src0_spad.data = octx->ctx->vtcm_base;
+ octx->src1_spad.data = octx->src0_spad.data + octx->src0_spad.size;
+ octx->dst_spad.data = octx->src1_spad.data + octx->src1_spad.size;
+
+ octx->src0_spad.stride = src0_row_size_padded;
+ octx->src1_spad.stride = src1_row_size;
+
+ if (need_quant) {
+ const uint32_t n_quant_jobs = MIN(src1_nrows, octx->n_threads);
+ mmctx->src1_nrows_per_thread = (src1_nrows + n_quant_jobs - 1) / n_quant_jobs;
+ worker_pool_run_func(octx->ctx->worker_pool, quant_job_func, mmctx, n_quant_jobs);
+ }
+
+ if (!(octx->flags & HTP_OPFLAGS_SKIP_COMPUTE)) {
+ const uint32_t n_matmul_jobs = octx->n_threads;
+ worker_pool_run_func(octx->ctx->worker_pool, matmul_job_func, mmctx, n_matmul_jobs);
+ }
+
+ return HTP_STATUS_OK;
+}
+
+int op_matmul_id(struct htp_ops_context * octx) {
+ htp_matmul_tensors_preamble;
+
+ struct htp_matmul_context mmctx_struct = {0};
+ struct htp_matmul_context * mmctx = &mmctx_struct;
+ mmctx->octx = octx;
+
+ struct htp_tensor * restrict ids = &octx->src2;
+
+ const size_t src0_row_size = nb01;
+ const size_t dst_row_size = nb1;
+
+ const size_t src0_row_size_padded = hex_round_up(src0_row_size, 128);
+
+ const uint32_t src0_nrows = ne01; // per expert
+ const uint32_t src1_nrows = ne11 * ne12 * ne13;
+
+ worker_callback_t quant_job_func;
+ worker_callback_t matmul_id_job_func = src1_nrows > 1 ? matmul_id : matvec_id;
+
+ // Compute src0_nrows_per_thread
+ mmctx->src0_nrows_per_thread = (src0_nrows + octx->n_threads - 1) / octx->n_threads;
+ mmctx->src0_nrows_per_thread += (mmctx->src0_nrows_per_thread & 1); // round up to even
+
+ size_t src1_row_size;
+ size_t src1_row_size_padded;
+
+ // row groups
+ const int n_ids = ids->ne[0]; // n_expert_used
+ const int n_as = ne02; // n_expert
+
+ size_t matrix_row_counts_size = n_as * sizeof(uint32_t);
+ size_t matrix_row_map_size = n_as * ids->ne[0] * ids->ne[1] * sizeof(struct mmid_row_mapping);
+
+ if (htp_mminit_vec_dot(mmctx, src0->type) != 0) {
+ return HTP_STATUS_NO_SUPPORT;
+ }
+
+ quant_job_func = quantize_f32_q8x4x2;
+ src1_row_size = q8x4x2_row_size(ne10);
+
+ const size_t src2_spad_size_per_thread = hex_round_up(matrix_row_counts_size + matrix_row_map_size, 256);
+ htp_mminit_spad(octx, dst_row_size, src0_row_size_padded, src1_row_size, src1_nrows, src2_spad_size_per_thread);
+
+ size_t spad_size = octx->src2_spad.size + octx->src1_spad.size + octx->src0_spad.size + octx->dst_spad.size;
+
+ FARF(HIGH, "matmul-id-%s : src0-spad-size %u src1-spad-size %u src2-spad-size %u dst-spad-size %u (%zu)\n", mmctx->type,
+ octx->src0_spad.size, octx->src1_spad.size, octx->src2_spad.size, octx->dst_spad.size, spad_size);
+
+ FARF(HIGH, "matmul-id-%s : %ux%ux%ux%u * %ux%ux%ux%u (%ux%ux%ux%u) -> %ux%ux%ux%u (0x%p, 0x%p, 0x%p)\n", mmctx->type,
+ src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3], src1->ne[0], src1->ne[1], src1->ne[2], src1->ne[3],
+ ids->ne[0], ids->ne[1], ids->ne[2], ids->ne[3], dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3], src0->data,
+ src1->data, dst->data);
+
+ // Make sure the reserved vtcm size is sufficient
+ if (octx->ctx->vtcm_size < spad_size) {
+ FARF(ERROR, "matmul-id-%s : current VTCM reservation %zu is too small, needed %zu\n", mmctx->type, octx->ctx->vtcm_size, spad_size);
+ return HTP_STATUS_VTCM_TOO_SMALL;
+ }
+
+ octx->src0_spad.data = octx->ctx->vtcm_base;
+ octx->src1_spad.data = octx->src0_spad.data + octx->src0_spad.size;
+ octx->src2_spad.data = octx->src1_spad.data + octx->src1_spad.size;
+ octx->dst_spad.data = octx->src2_spad.data + octx->src2_spad.size;
+
+ octx->src0_spad.stride = src0_row_size_padded;
+ octx->src1_spad.stride = src1_row_size;
+
+ if (src1_nrows > 1) {
+ // initialize matrix_row_counts and map
+ uint32_t * matrix_row_counts = (uint32_t *) octx->src2_spad.data + 0;
+ struct mmid_row_mapping * matrix_rows = (void *) octx->src2_spad.data + matrix_row_counts_size;
+
+ memset(matrix_row_counts, 0, n_as * sizeof(uint32_t));
+
+ // group rows by src0 matrix
+ for (uint32_t iid1 = 0; iid1 < ids->ne[1]; ++iid1) { // token idx
+ for (uint32_t id = 0; id < n_ids; ++id) { // expert idx
+ const uint32_t i02 = *(const uint32_t *) ((const uint8_t *) ids->data + iid1 * ids->nb[1] + id * ids->nb[0]);
+
+ assert(i02 >= 0 && i02 < n_as);
+
+ MMID_MATRIX_ROW(i02, matrix_row_counts[i02]) = (struct mmid_row_mapping) { id, iid1 };
+ matrix_row_counts[i02] += 1;
+ }
+ }
+ }
+
+ // Setup worker pool callbacks
+ if (!(octx->flags & HTP_OPFLAGS_SKIP_QUANTIZE)) {
+ const uint32_t n_quant_jobs = MIN(src1_nrows, octx->n_threads);
+ mmctx->src1_nrows_per_thread = (src1_nrows + n_quant_jobs - 1) / n_quant_jobs;
+ worker_pool_run_func(octx->ctx->worker_pool, quant_job_func, mmctx, n_quant_jobs);
+ }
+
+ if (!(octx->flags & HTP_OPFLAGS_SKIP_COMPUTE)) {
+ const uint32_t n_matmul_jobs = octx->n_threads;
+ worker_pool_run_func(octx->ctx->worker_pool, matmul_id_job_func, mmctx, n_matmul_jobs);
+ }
+
+ return HTP_STATUS_OK;
+}
diff --git a/llama.cpp/ggml/src/ggml-hexagon/htp/rope-ops.c b/llama.cpp/ggml/src/ggml-hexagon/htp/rope-ops.c
new file mode 100644
index 0000000..943ca5c
--- /dev/null
+++ b/llama.cpp/ggml/src/ggml-hexagon/htp/rope-ops.c
@@ -0,0 +1,480 @@
+#pragma clang diagnostic ignored "-Wunused-variable"
+#pragma clang diagnostic ignored "-Wunused-function"
+#pragma clang diagnostic ignored "-Wunused-but-set-variable"
+
+#include <HAP_farf.h>
+#include <HAP_perf.h>
+
+#include <math.h>
+#include <string.h>
+
+#include "hex-dma.h"
+#include "hvx-utils.h"
+
+#define GGML_COMMON_DECL_C
+#include "ggml-common.h"
+#include "htp-ctx.h"
+#include "htp-msg.h"
+#include "htp-ops.h"
+
+// Redefine GGML_ROPE_TYPE_NORMAL & GGML_ROPE_TYPE_NEOX here since we can't include ggml.h
+#define HTP_ROPE_TYPE_NORMAL 0
+#define HTP_ROPE_TYPE_NEOX 2
+
+#define htp_rope_preamble \
+ const uint32_t ne00 = src0->ne[0]; \
+ const uint32_t ne01 = src0->ne[1]; \
+ const uint32_t ne02 = src0->ne[2]; \
+ const uint32_t ne03 = src0->ne[3]; \
+ \
+ const uint32_t ne0 = dst->ne[0]; \
+ const uint32_t ne1 = dst->ne[1]; \
+ const uint32_t ne2 = dst->ne[2]; \
+ const uint32_t ne3 = dst->ne[3]; \
+ \
+ const uint32_t nb00 = src0->nb[0]; \
+ const uint32_t nb01 = src0->nb[1]; \
+ const uint32_t nb02 = src0->nb[2]; \
+ const uint32_t nb03 = src0->nb[3]; \
+ \
+ const uint32_t nb0 = dst->nb[0]; \
+ const uint32_t nb1 = dst->nb[1]; \
+ const uint32_t nb2 = dst->nb[2]; \
+ const uint32_t nb3 = dst->nb[3];
+
+struct rope_th_ctx {
+ int32_t n_dims;
+ int32_t mode;
+ int32_t n_ctx_orig;
+ int32_t sections[4];
+
+ float freq_base;
+ float freq_scale;
+ float ext_factor;
+ float attn_factor;
+ float beta_fast;
+ float beta_slow;
+ float theta_scale;
+ float corr_dims[2];
+
+ struct htp_ops_context * octx;
+};
+
+static float rope_yarn_ramp(const float low, const float high, const int i0) {
+ const float y = (i0 / 2 - low) / MAX(0.001f, high - low);
+
+ return (1 - MIN(1, MAX(0, y)));
+}
+
+static void rope_cache_init(const float theta_base,
+ const float freq_scale,
+ const float * freq_factors,
+ float * corr_dims,
+ const uint32_t ne0,
+ const float ext_factor,
+ const float mscale,
+ float * cache,
+ const float theta_scale) {
+ // ref: https://github.com/jquesnelle/yarn/blob/master/scaled_rope/LlamaYaRNScaledRotaryEmbedding.py
+ float theta = theta_base;
+
+ for (uint32_t i0 = 0; i0 < ne0; i0 += 2) {
+ const float ff = freq_factors ? freq_factors[i0 / 2] : 1.0f;
+
+ float theta_extrap = theta / ff;
+
+ // Get n-d rotational scaling corrected for extrapolation
+ float theta_interp = freq_scale * theta_extrap;
+ float theta_final = theta_interp;
+ float mscale_final = mscale;
+
+ if (ext_factor != 0.0f) {
+ float ramp_mix = rope_yarn_ramp(corr_dims[0], corr_dims[1], i0) * ext_factor;
+ theta_final = theta_interp * (1 - ramp_mix) + theta_extrap * ramp_mix;
+
+ // Get n-d magnitude scaling corrected for interpolation
+ mscale_final *= 1.0f + 0.1f * logf(1.0f / freq_scale);
+ }
+
+ cache[i0 + 0] = cosf(theta_final) * mscale_final;
+ cache[i0 + 1] = sinf(theta_final) * mscale_final;
+
+ theta *= theta_scale;
+ }
+}
+
+#ifndef M_PI
+#define M_PI 3.1415926535897932384626433
+#endif
+
+static void rope_corr_dims(int n_dims,
+ int n_ctx_orig,
+ float freq_base,
+ float beta_fast,
+ float beta_slow,
+ float * dims) {
+ float start = floorf(n_dims * logf(n_ctx_orig / (beta_fast * 2 * (float) M_PI)) / (2 * logf(freq_base)));
+ float end = ceilf(n_dims * logf(n_ctx_orig / (beta_slow * 2 * (float) M_PI)) / (2 * logf(freq_base)));
+ dims[0] = MAX(0, start);
+ dims[1] = MIN(n_dims - 1, end);
+}
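+// Sketch of the formula above: rotation pair j (= i0/2 in rope_cache_init) uses
+// theta_j = p * base^(-2j/n_dims), i.e. it completes n_ctx_orig * base^(-2j/n_dims) / (2*pi)
+// full turns over the original context. Solving "turns == beta" for j gives
+//   j = n_dims * log(n_ctx_orig / (beta * 2*pi)) / (2 * log(base))
+// evaluated with beta_fast (start) and beta_slow (end) and clamped to [0, n_dims - 1].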
+
+static void init_rope_ctx(struct rope_th_ctx * rope_ctx, struct htp_ops_context * octx) {
+ memset(rope_ctx, 0, sizeof(struct rope_th_ctx));
+
+ const int32_t * op_params = &octx->op_params[0];
+
+ rope_ctx->n_dims = ((const int32_t *) op_params)[1];
+ rope_ctx->mode = ((const int32_t *) op_params)[2];
+ rope_ctx->n_ctx_orig = ((const int32_t *) op_params)[4];
+
+ memcpy(&rope_ctx->freq_base, (int32_t *) op_params + 5, sizeof(float));
+ memcpy(&rope_ctx->freq_scale, (int32_t *) op_params + 6, sizeof(float));
+ memcpy(&rope_ctx->ext_factor, (int32_t *) op_params + 7, sizeof(float));
+ memcpy(&rope_ctx->attn_factor, (int32_t *) op_params + 8, sizeof(float));
+ memcpy(&rope_ctx->beta_fast, (int32_t *) op_params + 9, sizeof(float));
+ memcpy(&rope_ctx->beta_slow, (int32_t *) op_params + 10, sizeof(float));
+ memcpy(&rope_ctx->sections, (int32_t *) op_params + 11, sizeof(int) * 4);
+
+ rope_ctx->theta_scale = powf(rope_ctx->freq_base, -2.0f / rope_ctx->n_dims);
+
+ rope_corr_dims(rope_ctx->n_dims, rope_ctx->n_ctx_orig, rope_ctx->freq_base, rope_ctx->beta_fast,
+ rope_ctx->beta_slow, rope_ctx->corr_dims);
+
+ rope_ctx->octx = octx;
+ FARF(HIGH, "rope-f32 n_dims:%d, ext_factor:%.6f, theta_scale:%.6f, attn_factor:%.6f\n", rope_ctx->n_dims,
+ rope_ctx->ext_factor, rope_ctx->theta_scale, rope_ctx->attn_factor);
+}
+
+static void hvx_calc_rope_neox_f32(const float * restrict src0,
+ float * restrict dst,
+ const int num_elems,
+ const float * restrict theta_cache) {
+    // for (int i = 0; i < num_elems; i += 2) {
+    //     const float cos_theta = theta_cache[i + 0];
+    //     const float sin_theta = theta_cache[i + 1];
+
+    //     const float x0 = src[0];
+    //     const float x1 = src[num_elems/2];
+
+    //     dst[0]           = x0*cos_theta - x1*sin_theta;
+    //     dst[num_elems/2] = x0*sin_theta + x1*cos_theta;
+
+    //     src += 1;
+    //     dst += 1;
+    // }
+
+ const uint8_t * restrict src0_curr = (const uint8_t *) src0;
+ const uint8_t * restrict theta_curr = (const uint8_t *) theta_cache;
+ uint8_t * restrict dst_curr = (uint8_t *) dst;
+
+ int step_of_1 = num_elems >> 6; // 6 because we process two vectors at once
+ int half_size = (sizeof(float) * (num_elems / 2));
+
+ for (int i = 0; i < step_of_1; i++) {
+ HVX_Vector v0 = *(HVX_Vector *) src0_curr;
+ HVX_Vector v1 = *(HVX_Vector *) (src0_curr + half_size);
+
+ HVX_Vector v2 = *(HVX_Vector *) theta_curr;
+ HVX_Vector v3 = *(HVX_Vector *) (theta_curr + VLEN);
+
+ HVX_VectorPair vcos_sin = Q6_W_vdeal_VVR(v3, v2, -4); // vcos_sin[0] = cos_theta, vcos_sin[1] = sin_theta
+
+ HVX_Vector vx0_c = Q6_Vqf32_vmpy_VsfVsf(v0, Q6_V_lo_W(vcos_sin));
+ HVX_Vector vx0_s = Q6_Vqf32_vmpy_VsfVsf(v0, Q6_V_hi_W(vcos_sin));
+ HVX_Vector vx1_c = Q6_Vqf32_vmpy_VsfVsf(v1, Q6_V_lo_W(vcos_sin));
+ HVX_Vector vx1_s = Q6_Vqf32_vmpy_VsfVsf(v1, Q6_V_hi_W(vcos_sin));
+
+ HVX_Vector v4 = Q6_Vqf32_vsub_Vqf32Vqf32(vx0_c, vx1_s);
+ HVX_Vector v5 = Q6_Vqf32_vadd_Vqf32Vqf32(vx0_s, vx1_c);
+
+ *(HVX_Vector *) dst_curr = Q6_Vsf_equals_Vqf32(v4);
+ *(HVX_Vector *) (dst_curr + half_size) = Q6_Vsf_equals_Vqf32(v5);
+
+ src0_curr += VLEN;
+ theta_curr += 2 * VLEN;
+ dst_curr += VLEN;
+ }
+}
+
+static void hvx_calc_rope_f32(const float * restrict src0,
+ float * restrict dst,
+ const int num_elems,
+ const float * restrict theta_cache) {
+    // for (int i = 0; i < num_elems; i += 2) {
+    //     const float cos_theta = theta_cache[i + 0];
+    //     const float sin_theta = theta_cache[i + 1];
+
+    //     const float x0 = src[0];
+    //     const float x1 = src[1];
+
+    //     dst[0] = x0*cos_theta - x1*sin_theta;
+    //     dst[1] = x0*sin_theta + x1*cos_theta;
+
+    //     src += 2;
+    //     dst += 2;
+    // }
+
+ const uint8_t * restrict src0_curr = (const uint8_t *) src0;
+ const uint8_t * restrict theta_curr = (const uint8_t *) theta_cache;
+ uint8_t * restrict dst_curr = (uint8_t *) dst;
+
+ int step_of_1 = num_elems >> 6; // 6 because we process two vectors at once
+
+ for (int i = 0; i < step_of_1; i++) {
+ HVX_Vector v0 = *(HVX_Vector *) src0_curr;
+ HVX_Vector v1 = *(HVX_Vector *) (src0_curr + VLEN);
+
+ HVX_Vector v2 = *(HVX_Vector *) theta_curr;
+ HVX_Vector v3 = *(HVX_Vector *) (theta_curr + VLEN);
+
+ HVX_VectorPair vx0_x1 = Q6_W_vdeal_VVR(v1, v0, -4); // vx0_x1[0] = x0, vx0_x1[1] = x1
+ HVX_VectorPair vcos_sin = Q6_W_vdeal_VVR(v3, v2, -4); // vcos_sin[0] = cos_theta, vcos_sin[1] = sin_theta
+
+ HVX_Vector vx0_c = Q6_Vqf32_vmpy_VsfVsf(Q6_V_lo_W(vx0_x1), Q6_V_lo_W(vcos_sin));
+ HVX_Vector vx0_s = Q6_Vqf32_vmpy_VsfVsf(Q6_V_lo_W(vx0_x1), Q6_V_hi_W(vcos_sin));
+ HVX_Vector vx1_c = Q6_Vqf32_vmpy_VsfVsf(Q6_V_hi_W(vx0_x1), Q6_V_lo_W(vcos_sin));
+ HVX_Vector vx1_s = Q6_Vqf32_vmpy_VsfVsf(Q6_V_hi_W(vx0_x1), Q6_V_hi_W(vcos_sin));
+
+ HVX_Vector v4 = Q6_Vqf32_vsub_Vqf32Vqf32(vx0_c, vx1_s);
+ HVX_Vector v5 = Q6_Vqf32_vadd_Vqf32Vqf32(vx0_s, vx1_c);
+
+ HVX_VectorPair vstore = Q6_W_vshuff_VVR(Q6_Vsf_equals_Vqf32(v5), Q6_Vsf_equals_Vqf32(v4), -4);
+
+ *(HVX_Vector *) dst_curr = Q6_V_lo_W(vstore);
+ *(HVX_Vector *) (dst_curr + VLEN) = Q6_V_hi_W(vstore);
+
+ src0_curr += 2 * VLEN;
+ theta_curr += 2 * VLEN;
+ dst_curr += 2 * VLEN;
+ }
+}
+
+static void rope_hex_f32(struct rope_th_ctx * rope_ctx,
+ const uint32_t ir0,
+ const uint32_t ir1,
+ int nth,
+ int ith,
+ const int opt_path) {
+ struct htp_ops_context * octx = rope_ctx->octx;
+
+ const struct htp_tensor * src0 = &octx->src0;
+ const struct htp_tensor * src1 = &octx->src1;
+ const struct htp_tensor * src2 = &octx->src2;
+ struct htp_tensor * dst = &octx->dst;
+
+ const int32_t mode = rope_ctx->mode;
+ const bool is_neox = mode & HTP_ROPE_TYPE_NEOX;
+
+ htp_rope_preamble;
+
+ const int32_t * pos = (const int32_t *) src1->data;
+
+ float * wp0 = (float *) (octx->src0_spad.data + (ith * nb01));
+
+ const float * freq_factors = NULL;
+ if (src2 != NULL) {
+ freq_factors = (const float *) src2->data;
+ }
+
+ const uint32_t i1_end = MIN(ir1, ne1);
+ const int32_t half_dims = rope_ctx->n_dims / 2;
+ const size_t remain_bytes = (ne0 - rope_ctx->n_dims) * sizeof(float);
+ for (uint32_t i3 = 0; i3 < ne3; i3++) { // batch
+ for (uint32_t i2 = 0; i2 < ne2; i2++) { // seq-len
+ const int32_t p = pos[i2];
+
+ rope_cache_init(p, rope_ctx->freq_scale, freq_factors, rope_ctx->corr_dims, ne0, rope_ctx->ext_factor,
+ rope_ctx->attn_factor, wp0, rope_ctx->theta_scale);
+
+ for (uint32_t i1 = ir0; i1 < i1_end; i1++) { // attn-heads
+ const float * src = (float *) ((char *) src0->data + i3 * nb03 + i2 * nb02 + i1 * nb01);
+ float * dst_data = (float *) ((char *) dst->data + i3 * nb3 + i2 * nb2 + i1 * nb1);
+
+ const float * src_loc = src;
+ float * dst_data_loc = dst_data;
+
+ if (1 == opt_path) {
+ if (is_neox) {
+ hvx_calc_rope_neox_f32(src_loc, dst_data_loc, rope_ctx->n_dims, wp0);
+ } else {
+ hvx_calc_rope_f32(src_loc, dst_data_loc, rope_ctx->n_dims, wp0);
+ }
+
+ src_loc += rope_ctx->n_dims;
+ dst_data_loc += rope_ctx->n_dims;
+ } else {
+ for (uint32_t i0 = 0; i0 < rope_ctx->n_dims; i0 += 2) {
+ const float cos_theta = wp0[i0 + 0];
+ const float sin_theta = wp0[i0 + 1];
+
+ if (is_neox) {
+ const float x0 = src_loc[0];
+ const float x1 = src_loc[half_dims];
+
+ dst_data_loc[0] = x0 * cos_theta - x1 * sin_theta;
+ dst_data_loc[half_dims] = x0 * sin_theta + x1 * cos_theta;
+
+ src_loc += 1;
+ dst_data_loc += 1;
+ } else {
+ const float x0 = src_loc[0];
+ const float x1 = src_loc[1];
+
+ dst_data_loc[0] = x0 * cos_theta - x1 * sin_theta;
+ dst_data_loc[1] = x0 * sin_theta + x1 * cos_theta;
+
+ src_loc += 2;
+ dst_data_loc += 2;
+ }
+ }
+
+ src_loc += (is_neox ? half_dims : 0);
+ dst_data_loc += (is_neox ? half_dims : 0);
+ }
+
+ // TODO: use simd to speed up the remaining elements copy
+ memcpy(dst_data_loc, src_loc, remain_bytes);
+ }
+ }
+ }
+}
+
+static void rope_job_f32_per_thread(struct rope_th_ctx * rope_ctx, int nth, int ith) {
+ struct htp_ops_context * octx = rope_ctx->octx;
+
+ const struct htp_tensor * src0 = &octx->src0;
+ const struct htp_tensor * src1 = &octx->src1;
+ struct htp_tensor * dst = &octx->dst;
+
+ htp_rope_preamble;
+
+ const uint32_t src0_nrows = ne01 * ne02 * ne03; // src0 rows
+ const uint32_t src0_nrows_per_thread = octx->src0_nrows_per_thread;
+
+ const uint32_t src0_start_row = src0_nrows_per_thread * ith;
+ const uint32_t src0_end_row = MIN(src0_start_row + src0_nrows_per_thread, src0_nrows);
+
+ // no work for this thread
+ if (src0_start_row >= src0_end_row) {
+ return;
+ }
+
+ uint64_t t1, t2;
+ t1 = HAP_perf_get_qtimer_count();
+
+ int is_aligned = 1;
+ int opt_path = 0;
+ if ((0 == hex_is_aligned((void *) src0->data, VLEN)) || (0 == hex_is_aligned((void *) src1->data, VLEN)) ||
+ (0 == hex_is_aligned((void *) dst->data, VLEN))) {
+ FARF(HIGH, "rope-f32: unaligned addresses in rope op, possibly slower execution\n");
+ is_aligned = 0;
+ }
+ if ((1 == is_aligned) && !(nb01 & (VLEN - 1))) {
+ opt_path = 1;
+ }
+
+ rope_hex_f32(rope_ctx, src0_start_row, src0_end_row, nth, ith, opt_path);
+
+ t2 = HAP_perf_get_qtimer_count();
+
+ FARF(HIGH, "rope-f32: %d/%d/%d: (%u:%u) usec %u\n", ith, nth, opt_path, src0_start_row, src0_end_row,
+ (unsigned) HAP_perf_qtimer_count_to_us(t2 - t1));
+}
+
+static void rope_job_dispatcher_f32(unsigned int n, unsigned int i, void * data) {
+ struct rope_th_ctx * rope_ctx = (struct rope_th_ctx *) data;
+
+ rope_job_f32_per_thread(rope_ctx, n, i);
+}
+
+static int execute_op_rope_f32(struct htp_ops_context * octx) {
+ int err = HTP_STATUS_OK;
+
+ const struct htp_tensor * src0 = &octx->src0;
+ const struct htp_tensor * src1 = &octx->src1;
+ const struct htp_tensor * src2 = &octx->src2;
+ struct htp_tensor * dst = &octx->dst;
+
+ worker_callback_t op_func;
+ const char * op_type = NULL;
+
+ struct rope_th_ctx rope_ctx;
+
+ switch (octx->op) {
+ case HTP_OP_ROPE:
+ op_func = rope_job_dispatcher_f32;
+ op_type = "rope-f32";
+
+ init_rope_ctx(&rope_ctx, octx);
+ break;
+
+ default:
+ FARF(ERROR, "Unsupported Op %u\n", octx->op);
+ return HTP_STATUS_NO_SUPPORT;
+ }
+
+ const uint32_t n_threads = octx->n_threads;
+
+ const size_t src0_row_size = src0->nb[1];
+ const size_t src1_row_size = src0_row_size;
+ const size_t dst_row_size = dst->nb[1];
+
+ // VTCM scratchpads for all tensors
+ // N rows per thread, padded to HVX vector size
+ octx->dst_spad.size = hex_round_up(dst_row_size, 128) * n_threads;
+ octx->src0_spad.size = hex_round_up(src0_row_size, 128) * n_threads;
+ octx->src1_spad.size = hex_round_up(src1_row_size, 128) * n_threads;
+
+ size_t spad_size = octx->src0_spad.size + octx->src1_spad.size + octx->dst_spad.size;
+
+ if (src2->ne[0]) {
+ FARF(HIGH,
+ "%s: %ux%ux%ux%u (x %ux%ux%ux%u x %ux%ux%ux%u) -> %ux%ux%ux%u : src0-spad-size %u src1-spad-size %u "
+ "dst-spad-size %u\n",
+ op_type, src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3], src1->ne[0], src1->ne[1], src1->ne[2],
+ src1->ne[3], src2->ne[0], src2->ne[1], src2->ne[2], src2->ne[3], dst->ne[0], dst->ne[1], dst->ne[2],
+ dst->ne[3], octx->src0_spad.size, octx->src1_spad.size, octx->dst_spad.size);
+ } else {
+ FARF(HIGH,
+ "%s: %ux%ux%ux%u (%ux%ux%ux%u) -> %ux%ux%ux%u : src0-spad-size %u src1-spad-size %u dst-spad-size %u\n",
+ op_type, src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3], src1->ne[0], src1->ne[1], src1->ne[2],
+ src1->ne[3], dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3], octx->src0_spad.size, octx->src1_spad.size,
+ octx->dst_spad.size);
+ }
+
+ // Make sure the reserved vtcm size is sufficient
+ if (octx->ctx->vtcm_size < spad_size) {
+ FARF(ERROR, "%s : current VTCM reservation %zu is too small, needed %zu\n", op_type, octx->ctx->vtcm_size,
+ spad_size);
+ return HTP_STATUS_VTCM_TOO_SMALL;
+ }
+
+ octx->src0_spad.data = octx->ctx->vtcm_base;
+ octx->src1_spad.data = octx->src0_spad.data + octx->src0_spad.size;
+ octx->dst_spad.data = octx->src1_spad.data + octx->src1_spad.size;
+
+ uint32_t src0_nrows = src0->ne[1] * src0->ne[2] * src0->ne[3];
+
+ if (!(octx->flags & HTP_OPFLAGS_SKIP_COMPUTE)) {
+ uint32_t n_jobs = MIN(n_threads, src0_nrows);
+ octx->src0_nrows_per_thread = (src0_nrows + n_jobs - 1) / n_jobs;
+ worker_pool_run_func(octx->ctx->worker_pool, op_func, &rope_ctx, n_jobs);
+ }
+
+ return err;
+}
+
+int op_rope(struct htp_ops_context * octx) {
+ int err = HTP_STATUS_OK;
+
+ switch (octx->src0.type) {
+ case HTP_TYPE_F32:
+ err = execute_op_rope_f32(octx);
+ break;
+
+ default:
+ err = HTP_STATUS_NO_SUPPORT;
+ break;
+ }
+
+ return err;
+}
diff --git a/llama.cpp/ggml/src/ggml-hexagon/htp/set-rows-ops.c b/llama.cpp/ggml/src/ggml-hexagon/htp/set-rows-ops.c
new file mode 100644
index 0000000..904484d
--- /dev/null
+++ b/llama.cpp/ggml/src/ggml-hexagon/htp/set-rows-ops.c
@@ -0,0 +1,164 @@
+#pragma clang diagnostic ignored "-Wunused-variable"
+#pragma clang diagnostic ignored "-Wunused-function"
+#pragma clang diagnostic ignored "-Wunused-but-set-variable"
+
+#include <HAP_farf.h>
+#include <HAP_perf.h>
+
+#include <math.h>
+#include <string.h>
+
+#include "hex-dma.h"
+#include "hvx-utils.h"
+
+#define GGML_COMMON_DECL_C
+#include "ggml-common.h"
+#include "htp-ctx.h"
+#include "htp-msg.h"
+#include "htp-ops.h"
+
+#define set_rows_preamble \
+ const uint32_t ne00 = octx->src0.ne[0]; \
+ const uint32_t ne01 = octx->src0.ne[1]; \
+ const uint32_t ne02 = octx->src0.ne[2]; \
+ const uint32_t ne03 = octx->src0.ne[3]; \
+ \
+ const uint32_t ne10 = octx->src1.ne[0]; \
+ const uint32_t ne11 = octx->src1.ne[1]; \
+ const uint32_t ne12 = octx->src1.ne[2]; \
+ \
+ const uint32_t nb01 = octx->src0.nb[1]; \
+ const uint32_t nb02 = octx->src0.nb[2]; \
+ const uint32_t nb03 = octx->src0.nb[3]; \
+ \
+ const uint32_t nb10 = octx->src1.nb[0]; \
+ const uint32_t nb11 = octx->src1.nb[1]; \
+ const uint32_t nb12 = octx->src1.nb[2]; \
+ \
+ const uint32_t nb1 = octx->dst.nb[1]; \
+ const uint32_t nb2 = octx->dst.nb[2]; \
+ const uint32_t nb3 = octx->dst.nb[3]; \
+ \
+ const uint32_t ne1 = octx->dst.ne[1]; \
+ \
+ const uint32_t nr = ne01;
+
+static int set_rows_thread_f32_f32(struct htp_ops_context * octx, const int nth, const int ith) {
+ set_rows_preamble;
+
+ // parallelize by rows of src0
+ const uint32_t dr = octx->src0_nrows_per_thread;
+ const uint32_t ir0 = dr * ith;
+ const uint32_t ir1 = (ir0 + dr < nr) ? (ir0 + dr) : nr;
+
+ const bool is_i32 = (octx->src1.type == HTP_TYPE_I32);
+
+ for (uint32_t i03 = 0; i03 < ne03; ++i03) {
+ for (uint32_t i02 = 0; i02 < ne02; ++i02) {
+ for (uint32_t i = ir0; i < ir1; ++i) {
+ const uint32_t i12 = fastmodulo(i03, ne12, &octx->set_rows_div_ne12);
+ const uint32_t i11 = fastmodulo(i02, ne11, &octx->set_rows_div_ne11);
+ const uint32_t i10 = i;
+
+ const uintptr_t src1_addr = octx->src1.data + i10*nb10 + i11*nb11 + i12*nb12;
+
+ uint32_t i1 = is_i32 ? *(int32_t *)src1_addr : *(int64_t *)src1_addr;
+ if (i1 >= ne1) {
+ // ignore invalid indices
+ continue;
+ }
+
+ const uintptr_t src0_ptr = octx->src0.data + i*nb01 + i02*nb02 + i03*nb03;
+ const uintptr_t dst_ptr = octx->dst.data + i1*nb1 + i02*nb2 + i03*nb3;
+
+ // copy row
+ hvx_copy_f32_uu((uint8_t *)dst_ptr, (const uint8_t *)src0_ptr, ne00);
+ }
+ }
+ }
+
+ return HTP_STATUS_OK;
+}
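+
+/*
+ * Note: i03/i02 are folded into the index tensor's i12/i11 dims with
+ * fastmodulo(), i.e. src1 is broadcast across the outer dims of src0
+ * (i12 = i03 % ne12, i11 = i02 % ne11) using the fastdiv constants that
+ * op_set_rows() initializes below.
+ */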
+
+static int set_rows_thread_f16_f32(struct htp_ops_context * octx, const int nth, const int ith) {
+ set_rows_preamble;
+
+ // parallelize by rows of src0
+ const uint32_t dr = octx->src0_nrows_per_thread;
+ const uint32_t ir0 = dr * ith;
+ const uint32_t ir1 = (ir0 + dr < nr) ? (ir0 + dr) : nr;
+
+ const bool is_i32 = (octx->src1.type == HTP_TYPE_I32);
+
+ for (uint32_t i03 = 0; i03 < ne03; ++i03) {
+ for (uint32_t i02 = 0; i02 < ne02; ++i02) {
+ for (uint32_t i = ir0; i < ir1; ++i) {
+ const uint32_t i12 = fastmodulo(i03, ne12, &octx->set_rows_div_ne12);
+ const uint32_t i11 = fastmodulo(i02, ne11, &octx->set_rows_div_ne11);
+ const uint32_t i10 = i;
+
+ const uintptr_t src1_addr = octx->src1.data + i10*nb10 + i11*nb11 + i12*nb12;
+
+ uint32_t i1 = is_i32 ? *(int32_t *)src1_addr : *(int64_t *)src1_addr;
+ if (i1 >= ne1) {
+ // ignore invalid indices
+ continue;
+ }
+
+ const uint8_t* src0_ptr = (const uint8_t *) octx->src0.data + i*nb01 + i02*nb02 + i03*nb03;
+ uint8_t* dst_ptr = (uint8_t *) octx->dst.data + i1*nb1 + i02*nb2 + i03*nb3;
+
+ hvx_copy_f16_f32_uu(dst_ptr, src0_ptr, ne00);
+ }
+ }
+ }
+
+ return HTP_STATUS_OK;
+}
+
+static void set_rows_work_f16_f32(unsigned int n, unsigned int i, void *data) {
+ set_rows_thread_f16_f32((struct htp_ops_context *) data, n, i);
+}
+
+static void set_rows_work_f32_f32(unsigned int n, unsigned int i, void *data) {
+ set_rows_thread_f32_f32((struct htp_ops_context *) data, n, i);
+}
+
+int op_set_rows(struct htp_ops_context * octx) {
+ set_rows_preamble;
+
+ if (octx->src0.type != HTP_TYPE_F32) {
+ return HTP_STATUS_NO_SUPPORT;
+ }
+
+ if (octx->dst.type != HTP_TYPE_F32 && octx->dst.type != HTP_TYPE_F16) {
+ return HTP_STATUS_NO_SUPPORT;
+ }
+
+ if (octx->src1.type != HTP_TYPE_I32 && octx->src1.type != HTP_TYPE_I64) {
+ return HTP_STATUS_NO_SUPPORT;
+ }
+
+ if (octx->flags & HTP_OPFLAGS_SKIP_COMPUTE) {
+ return HTP_STATUS_OK;
+ }
+
+ octx->set_rows_div_ne12 = init_fastdiv_values(ne12);
+ octx->set_rows_div_ne11 = init_fastdiv_values(ne11);
+
+ const uint32_t n_jobs = MIN(nr, octx->n_threads);
+ octx->src0_nrows_per_thread = (nr + n_jobs - 1) / n_jobs;
+
+ switch(octx->dst.type) {
+ case HTP_TYPE_F32:
+ worker_pool_run_func(octx->ctx->worker_pool, set_rows_work_f32_f32, octx, n_jobs);
+ break;
+ case HTP_TYPE_F16:
+ worker_pool_run_func(octx->ctx->worker_pool, set_rows_work_f16_f32, octx, n_jobs);
+ break;
+ default:
+ return HTP_STATUS_NO_SUPPORT;
+ }
+
+ return HTP_STATUS_OK;
+}
diff --git a/llama.cpp/ggml/src/ggml-hexagon/htp/softmax-ops.c b/llama.cpp/ggml/src/ggml-hexagon/htp/softmax-ops.c
new file mode 100644
index 0000000..e91a16d
--- /dev/null
+++ b/llama.cpp/ggml/src/ggml-hexagon/htp/softmax-ops.c
@@ -0,0 +1,395 @@
+#pragma clang diagnostic ignored "-Wunused-variable"
+#pragma clang diagnostic ignored "-Wunused-function"
+#pragma clang diagnostic ignored "-Wunused-but-set-variable"
+
+#include <HAP_farf.h>
+#include <HAP_perf.h>
+
+#include <math.h>
+#include <string.h>
+
+#include "hex-dma.h"
+#include "hvx-utils.h"
+
+#define GGML_COMMON_DECL_C
+#include "ggml-common.h"
+#include "htp-ctx.h"
+#include "htp-msg.h"
+#include "htp-ops.h"
+
+#define htp_softmax_preamble3 \
+ const uint32_t ne00 = src0->ne[0]; \
+ const uint32_t ne01 = src0->ne[1]; \
+ const uint32_t ne02 = src0->ne[2]; \
+ const uint32_t ne03 = src0->ne[3]; \
+ \
+ const uint32_t nb00 = src0->nb[0]; \
+ const uint32_t nb01 = src0->nb[1]; \
+ const uint32_t nb02 = src0->nb[2]; \
+ const uint32_t nb03 = src0->nb[3]; \
+ \
+ const uint32_t ne10 = (src1->ne[0]) ? src1->ne[0] : 1; \
+ const uint32_t ne11 = (src1->ne[0]) ? src1->ne[1] : 1; \
+ const uint32_t ne12 = (src1->ne[0]) ? src1->ne[2] : 1; \
+ const uint32_t ne13 = (src1->ne[0]) ? src1->ne[3] : 1; \
+ \
+ const uint32_t nb10 = (src1->ne[0]) ? src1->nb[0] : 1; \
+ const uint32_t nb11 = (src1->ne[0]) ? src1->nb[1] : 1; \
+ const uint32_t nb12 = (src1->ne[0]) ? src1->nb[2] : 1; \
+ const uint32_t nb13 = (src1->ne[0]) ? src1->nb[3] : 1; \
+ \
+ const uint32_t ne0 = dst->ne[0]; \
+ const uint32_t ne1 = dst->ne[1]; \
+ const uint32_t ne2 = dst->ne[2]; \
+ const uint32_t ne3 = dst->ne[3]; \
+ \
+ const uint32_t nb0 = dst->nb[0]; \
+ const uint32_t nb1 = dst->nb[1]; \
+ const uint32_t nb2 = dst->nb[2]; \
+ const uint32_t nb3 = dst->nb[3];
+
+struct softmax_th_ctx {
+ bool use_f16;
+ bool use_src1;
+ uint32_t n_head;
+ uint32_t n_head_log2;
+
+ float scale;
+ float max_bias;
+ float m0;
+ float m1;
+
+ struct htp_ops_context * octx;
+};
+
+static void init_softmax_ctx(struct softmax_th_ctx * softmax_ctx, struct htp_ops_context * octx) {
+ const struct htp_tensor * src0 = &octx->src0;
+ const struct htp_tensor * src1 = &octx->src1;
+
+ memset(softmax_ctx, 0, sizeof(struct softmax_th_ctx));
+
+ memcpy(&softmax_ctx->scale, (float *) octx->op_params, sizeof(float));
+ memcpy(&softmax_ctx->max_bias, (float *) octx->op_params + 1, sizeof(float));
+
+ softmax_ctx->n_head = src0->ne[2];
+ softmax_ctx->n_head_log2 = 1u << (uint32_t) floor(log2(softmax_ctx->n_head));
+
+ softmax_ctx->m0 = powf(2.0f, -(softmax_ctx->max_bias) / softmax_ctx->n_head_log2);
+ softmax_ctx->m1 = powf(2.0f, -(softmax_ctx->max_bias / 2.0f) / softmax_ctx->n_head_log2);
+
+ softmax_ctx->use_src1 = (src1->ne[0] != 0);
+ softmax_ctx->use_f16 = (src1->ne[0] != 0) && (src1->type == HTP_TYPE_F16);
+
+ softmax_ctx->octx = octx;
+}
+
+static void hvx_fast_softmax_prep_f32(const uint8_t * restrict src,
+ uint8_t * restrict dst,
+ const int num_elems,
+ float scale,
+ const uint8_t * restrict mask,
+ float slope) {
+ const uint8_t * restrict src_curr = src;
+ uint8_t * restrict dst_curr = dst;
+ const uint8_t * restrict mask_curr = mask;
+
+ HVX_Vector scale_vec = hvx_vec_splat_f32(scale);
+ HVX_Vector slope_vec = hvx_vec_splat_f32(slope);
+
+ int step_of_1 = num_elems >> 5;
+
+ #pragma unroll(4)
+ for (int i = 0; i < step_of_1; i++) {
+ HVX_Vector v1 = *(HVX_Vector *) src_curr;
+
+ HVX_Vector v3 = *(HVX_Vector *) mask_curr;
+
+ HVX_Vector v2 = Q6_Vqf32_vmpy_VsfVsf(v1, scale_vec);
+
+ HVX_Vector v4 = Q6_Vqf32_vmpy_VsfVsf(v3, slope_vec);
+
+ HVX_Vector v5 = Q6_Vqf32_vadd_Vqf32Vqf32(v2, v4);
+
+ *(HVX_Vector *) dst_curr = Q6_Vsf_equals_Vqf32(v5);
+
+ src_curr += VLEN;
+ dst_curr += VLEN;
+ mask_curr += VLEN;
+ }
+}
+
+static void hvx_fast_softmax_f32(const uint8_t * restrict src,
+ uint8_t * restrict dst,
+ uint8_t * restrict pad,
+ const int num_elems) {
+ const HVX_Vector * restrict v_src = (HVX_Vector *) src;
+ HVX_Vector * restrict v_pad = (HVX_Vector *) pad;
+ HVX_Vector * restrict v_dst = (HVX_Vector *) dst;
+
+ HVX_Vector sum_vec = Q6_V_vsplat_R(0x00000000);
+ HVX_Vector max_vec = hvx_vec_splat_f32(((const float *) src)[0]);
+ HVX_Vector zero_v = Q6_V_vzero();
+ HVX_Vector one_v = hvx_vec_splat_f32(1.0);
+
+ int step_of_1 = num_elems >> 5;
+
+ #pragma unroll(4)
+ for (int i = 0; i < step_of_1; i++) {
+ HVX_Vector v1 = v_src[i];
+ max_vec = Q6_Vsf_vmax_VsfVsf(max_vec, v1);
+ }
+
+ HVX_Vector v = hvx_vec_reduce_max_f32(max_vec);
+ max_vec = hvx_vec_repl4(v);
+
+ #pragma unroll(4)
+ for (int i = 0; i < step_of_1; i++) {
+ HVX_Vector v1 = v_src[i];
+ HVX_Vector v2 = Q6_Vqf32_vsub_VsfVsf(v1, max_vec);
+
+ HVX_Vector v3 = hvx_vec_exp_f32(Q6_Vsf_equals_Vqf32(v2));
+
+ sum_vec = Q6_Vqf32_vadd_VsfVsf(Q6_Vsf_equals_Vqf32(sum_vec), v3);
+
+ v_pad[i] = v3;
+ }
+
+ v = hvx_vec_reduce_sum_f32(Q6_Vsf_equals_Vqf32(sum_vec));
+ sum_vec = hvx_vec_repl4(v);
+
+ HVX_VectorPred pos_sum = Q6_Q_vcmp_gt_VwVw(sum_vec, zero_v);
+ HVX_Vector v4 = hvx_vec_inverse_f32(sum_vec);
+ HVX_Vector scale_vec = Q6_V_vmux_QVV(pos_sum, v4, one_v);
+
+ #pragma unroll(4)
+ for (int i = 0; i < step_of_1; i++) {
+ HVX_Vector v1 = v_pad[i];
+ HVX_Vector v2 = Q6_Vqf32_vmpy_VsfVsf(v1, scale_vec);
+ v_dst[i] = Q6_Vsf_equals_Vqf32(v2);
+ }
+}
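+
+/*
+ * Note: hvx_fast_softmax_f32() is the usual three-pass, numerically stable
+ * softmax kept in vector registers and the VTCM pad buffer:
+ *   1) m   = max_i x_i
+ *   2) p_i = exp(x_i - m),  s = sum_i p_i   (p_i staged in pad)
+ *   3) y_i = p_i / s                        (falls back to scale 1.0 if s <= 0)
+ */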
+
+static float hvx_softmax_f32(const uint8_t * restrict src,
+ uint8_t * restrict dst,
+ uint8_t * restrict spad,
+ const int num_elems,
+ const float max) {
+ hvx_sub_scalar_f32(spad, src, max, num_elems);
+
+ hvx_exp_f32(spad, dst, num_elems, false);
+
+ float sum = hvx_reduce_sum_f32(dst, num_elems);
+
+ return sum;
+}
+
+static void softmax_htp_f32(int nth, int ith, struct softmax_th_ctx * softmax_ctx, int opt_path) {
+ struct htp_ops_context * octx = softmax_ctx->octx;
+
+ const struct htp_tensor * src0 = &octx->src0;
+ const struct htp_tensor * src1 = &octx->src1;
+ const struct htp_tensor * dst = &octx->dst;
+
+ htp_softmax_preamble3;
+
+ uint8_t * src0_spad_data = octx->src0_spad.data + (ith * nb01);
+ uint8_t * src1_spad_data = octx->src1_spad.data + (ith * nb01);
+ uint8_t * dst_spad_data = octx->dst_spad.data + (ith * nb1);
+
+ float * wp0 = (float *) src0_spad_data;
+ float * wp1 = (float *) src1_spad_data;
+ float * wp2 = (float *) dst_spad_data;
+
+ for (uint32_t i03 = 0; i03 < ne03; i03++) {
+ for (uint32_t i02 = 0; i02 < ne02; i02++) {
+ for (uint32_t i01 = ith; i01 < ne01; i01 += nth) {
+ const uint32_t i11 = i01;
+ const uint32_t i12 = i02 % ne12;
+ const uint32_t i13 = i03 % ne13;
+
+ // ALiBi
+ const uint32_t h = i02; // head
+
+ const float slope = (softmax_ctx->max_bias > 0.0f) ?
+ h < softmax_ctx->n_head_log2 ?
+ powf(softmax_ctx->m0, h + 1) :
+ powf(softmax_ctx->m1, 2 * (h - softmax_ctx->n_head_log2) + 1) :
+ 1.0f;
+
+ float * sp = (float *) ((char *) octx->src0.data + i01 * nb01 + i02 * nb02 + i03 * nb03);
+ float * dp = (float *) ((char *) octx->dst.data + i01 * nb1 + i02 * nb2 + i03 * nb3);
+
+ // broadcast the mask across rows
+ __fp16 * mp_f16 = (softmax_ctx->use_src1) ?
+ (__fp16 *) ((char *) octx->src1.data + i11 * nb11 + i12 * nb12 + i13 * nb13) :
+ NULL;
+ float * mp_f32 = (softmax_ctx->use_src1) ?
+ (float *) ((char *) octx->src1.data + i11 * nb11 + i12 * nb12 + i13 * nb13) :
+ NULL;
+
+ if ((1 == opt_path) && (mp_f32) && !(softmax_ctx->use_f16)) {
+ hvx_fast_softmax_prep_f32((const uint8_t *) sp, (uint8_t *) wp0, ne00, softmax_ctx->scale,
+ (const uint8_t *) mp_f32, slope);
+ } else {
+ hvx_scale_f32((uint8_t *) wp0, (const uint8_t *) sp, ne00, softmax_ctx->scale);
+ if (mp_f32) {
+ if (softmax_ctx->use_f16) {
+ for (int i = 0; i < ne00; ++i) {
+ wp0[i] += slope * (float) mp_f16[i];
+ }
+ } else {
+ for (int i = 0; i < ne00; ++i) {
+ wp0[i] += slope * mp_f32[i];
+ }
+ }
+ }
+ }
+
+ if (1 == opt_path) {
+ hvx_fast_softmax_f32((const uint8_t *) wp0, (uint8_t *) dp, (uint8_t *) wp1, ne00);
+ } else {
+ float max = hvx_reduce_max_f32((const uint8_t *) wp0, ne00);
+ float sum = hvx_softmax_f32((const uint8_t *) wp0, (uint8_t *) wp2, (uint8_t *) wp1, ne00, max);
+ sum = sum > 0.0 ? (1.0 / sum) : 1;
+ hvx_scale_f32((uint8_t *) dp, (const uint8_t *) wp2, ne00, sum);
+ }
+ }
+ }
+ }
+}
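+
+/*
+ * Note: the ALiBi slope above follows the standard per-head schedule,
+ *   slope(h) = m0^(h+1)                      for h <  n_head_log2
+ *   slope(h) = m1^(2*(h - n_head_log2) + 1)  for h >= n_head_log2
+ * with m0 = 2^(-max_bias / n_head_log2), m1 = 2^(-max_bias / (2*n_head_log2)),
+ * and reduces to slope = 1.0 when max_bias == 0 (no positional bias).
+ */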
+
+static void softmax_job_f32_per_thread(struct softmax_th_ctx * softmax_ctx, int nth, int ith) {
+ struct htp_ops_context * octx = softmax_ctx->octx;
+
+ const struct htp_tensor * src0 = &octx->src0;
+ const struct htp_tensor * src1 = &octx->src1;
+ struct htp_tensor * dst = &octx->dst;
+
+ htp_softmax_preamble3;
+
+ const uint32_t src0_nrows = ne01 * ne02 * ne03; // src0 rows
+ const uint32_t src0_nrows_per_thread = octx->src0_nrows_per_thread;
+
+ const uint32_t src0_start_row = src0_nrows_per_thread * ith;
+ const uint32_t src0_end_row = MIN(src0_start_row + src0_nrows_per_thread, src0_nrows);
+
+ // no work for this thread
+ if (src0_start_row >= src0_end_row) {
+ return;
+ }
+
+ uint64_t t1, t2;
+ t1 = HAP_perf_get_qtimer_count();
+
+ int is_aligned = 1;
+ int opt_path = 0;
+ if (!hex_is_aligned((void *) src0->data, VLEN) || !hex_is_aligned((void *) dst->data, VLEN)) {
+ is_aligned = 0;
+        FARF(HIGH, "softmax-f32: unaligned addresses in softmax op, possibly slower execution\n");
+ }
+ if ((1 == is_aligned) && !(nb01 & (VLEN - 1))) {
+ opt_path = 1;
+ }
+
+ softmax_htp_f32(nth, ith, softmax_ctx, opt_path);
+
+ t2 = HAP_perf_get_qtimer_count();
+
+ FARF(HIGH, "softmax-f32 %d/%d/%d/%d: %ux%ux%ux%u (%u:%u) x %ux%ux%ux%u -> %ux%ux%ux%u usec %u\n", ith, nth,
+ softmax_ctx->use_f16, opt_path, ne00, ne01, ne02, ne03, src0_start_row, src0_end_row, ne10, ne11, ne12, ne13,
+ ne0, ne1, ne2, ne3, (unsigned) HAP_perf_qtimer_count_to_us(t2 - t1));
+}
+
+static void softmax_job_dispatcher_f32(unsigned int n, unsigned int i, void * p_data) {
+ struct softmax_th_ctx * p_softmax_ctx = (struct softmax_th_ctx *) p_data;
+ softmax_job_f32_per_thread(p_softmax_ctx, n, i);
+}
+
+static int execute_op_softmax_f32(struct htp_ops_context * octx) {
+ int err = HTP_STATUS_OK;
+
+ const struct htp_tensor * src0 = &octx->src0;
+ const struct htp_tensor * src1 = &octx->src1;
+ struct htp_tensor * dst = &octx->dst;
+
+ worker_callback_t op_func;
+ const char * op_type = NULL;
+
+ struct softmax_th_ctx softmax_ctx;
+
+ switch (octx->op) {
+ case HTP_OP_SOFTMAX:
+ op_func = softmax_job_dispatcher_f32;
+ op_type = "softmax-f32";
+
+ init_softmax_ctx(&softmax_ctx, octx);
+ break;
+
+ default:
+ FARF(ERROR, "Unsupported Op %u\n", octx->op);
+ return HTP_STATUS_NO_SUPPORT;
+ }
+
+ const uint32_t n_threads = octx->n_threads;
+
+ const size_t src0_row_size = src0->nb[1];
+ const size_t src1_row_size = src0_row_size;
+ const size_t dst_row_size = dst->nb[1];
+
+ // VTCM scratchpads for all tensors
+ // N rows per thread, padded to HVX vector size
+ octx->dst_spad.size = hex_round_up(dst_row_size, 128) * n_threads;
+ octx->src0_spad.size = hex_round_up(src0_row_size, 128) * n_threads;
+ octx->src1_spad.size = hex_round_up(src1_row_size, 128) * n_threads;
+
+ size_t spad_size = octx->src0_spad.size + octx->src1_spad.size + octx->dst_spad.size;
+
+ if (src1->ne[0]) {
+ FARF(HIGH,
+ "%s: %ux%ux%ux%u x %ux%ux%ux%u -> %ux%ux%ux%u : src0-spad-size %u src1-spad-size %u dst-spad-size %u\n",
+ op_type, src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3], src1->ne[0], src1->ne[1], src1->ne[2],
+ src1->ne[3], dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3], octx->src0_spad.size, octx->src1_spad.size,
+ octx->dst_spad.size);
+ } else {
+ FARF(HIGH, "%s: %ux%ux%ux%u -> %ux%ux%ux%u : src0-spad-size %u src1-spad-size %u dst-spad-size %u\n", op_type,
+ src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3], dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3],
+ octx->src0_spad.size, octx->src1_spad.size, octx->dst_spad.size);
+ }
+
+ // Make sure the reserved vtcm size is sufficient
+ if (octx->ctx->vtcm_size < spad_size) {
+ FARF(ERROR, "%s : current VTCM reservation %zu is too small, needed %zu\n", op_type, octx->ctx->vtcm_size,
+ spad_size);
+ return HTP_STATUS_VTCM_TOO_SMALL;
+ }
+
+ octx->src0_spad.data = octx->ctx->vtcm_base;
+ octx->src1_spad.data = octx->src0_spad.data + octx->src0_spad.size;
+ octx->dst_spad.data = octx->src1_spad.data + octx->src1_spad.size;
+
+ uint32_t src0_nrows = src0->ne[1] * src0->ne[2] * src0->ne[3];
+
+ if (!(octx->flags & HTP_OPFLAGS_SKIP_COMPUTE)) {
+ uint32_t n_jobs = MIN(n_threads, src0_nrows);
+ octx->src0_nrows_per_thread = (src0_nrows + n_jobs - 1) / n_jobs;
+ worker_pool_run_func(octx->ctx->worker_pool, op_func, &softmax_ctx, n_jobs);
+ }
+
+ return err;
+}
+
+int op_softmax(struct htp_ops_context * octx) {
+ int err = HTP_STATUS_OK;
+
+ switch (octx->src0.type) {
+ case HTP_TYPE_F32:
+ err = execute_op_softmax_f32(octx);
+ break;
+
+ default:
+ err = HTP_STATUS_NO_SUPPORT;
+ break;
+ }
+
+ return err;
+}
diff --git a/llama.cpp/ggml/src/ggml-hexagon/htp/sum-rows-ops.c b/llama.cpp/ggml/src/ggml-hexagon/htp/sum-rows-ops.c
new file mode 100644
index 0000000..62e45da
--- /dev/null
+++ b/llama.cpp/ggml/src/ggml-hexagon/htp/sum-rows-ops.c
@@ -0,0 +1,115 @@
+#pragma clang diagnostic ignored "-Wunused-variable"
+#pragma clang diagnostic ignored "-Wunused-function"
+#pragma clang diagnostic ignored "-Wunused-but-set-variable"
+
+#include <HAP_farf.h>
+#include <HAP_perf.h>
+
+#include <string.h>
+#include <math.h>
+
+#include "hex-dma.h"
+#include "hvx-utils.h"
+
+#define GGML_COMMON_DECL_C
+#include "ggml-common.h"
+#include "htp-ctx.h"
+#include "htp-msg.h"
+#include "htp-ops.h"
+
+
+#define sum_rows_preamble \
+ struct htp_tensor *src0 = &octx->src0;\
+ struct htp_tensor *dst = &octx->dst; \
+ \
+ const uint32_t ne00 = src0->ne[0]; \
+ const uint32_t ne01 = src0->ne[1]; \
+ const uint32_t ne02 = src0->ne[2]; \
+ const uint32_t ne03 = src0->ne[3]; \
+ \
+ const uint32_t nb00 = src0->nb[0]; \
+ const uint32_t nb01 = src0->nb[1]; \
+ const uint32_t nb02 = src0->nb[2]; \
+ const uint32_t nb03 = src0->nb[3]; \
+ \
+ const uint32_t ne0 = dst->ne[0]; \
+ const uint32_t ne1 = dst->ne[1]; \
+ const uint32_t ne2 = dst->ne[2]; \
+ const uint32_t ne3 = dst->ne[3]; \
+ \
+ const uint32_t nb0 = dst->nb[0]; \
+ const uint32_t nb1 = dst->nb[1]; \
+ const uint32_t nb2 = dst->nb[2]; \
+ const uint32_t nb3 = dst->nb[3]; \
+
+static int sum_rows_thread_f32(struct htp_ops_context * octx, const int nth, const int ith) {
+ sum_rows_preamble;
+
+ const uint32_t src0_nrows_per_thread = octx->src0_nrows_per_thread;
+ const size_t src0_row_size = nb01;
+ const size_t dst_row_size = nb1;
+
+ const uint32_t src0_nrows = ne01 * ne02 * ne03; // src0 rows
+
+ const uint32_t src0_start_row = src0_nrows_per_thread * ith;
+ const uint32_t src0_end_row = MIN(src0_start_row + src0_nrows_per_thread, src0_nrows);
+
+ // no work for this thread
+ if (src0_start_row >= src0_end_row) {
+ return HTP_STATUS_OK;
+ }
+
+    int opt_path = 0;
+    // take the aligned fast path only when src0 actually is VLEN-aligned and the row size is a vector multiple
+    if ((0 != hex_is_aligned((void *) src0->data, VLEN)) && !(nb01 & (VLEN - 1))) {
+        opt_path = 1;
+    }
+
+ const uint8_t * restrict data_src = (const uint8_t *) src0->data;
+ uint8_t * restrict data_dst = (uint8_t *) dst->data;
+
+ const float * restrict src_th = (float *) (data_src + (src0_start_row * src0_row_size));
+ float * restrict dst_th = (float *) (data_dst + (src0_start_row * dst_row_size));
+
+ for (uint32_t ir = 0; ir < src0_nrows_per_thread; ir++) {
+ const float * restrict src_local = src_th + (ir * ne00);
+
+ if (ir + 1 < src0_nrows_per_thread) {
+ hex_l2fetch(src_local + ne00, src0_row_size, src0_row_size, 1);
+ }
+
+ if (1 == opt_path) {
+ dst_th[ir] = hvx_reduce_sum_f32_a((const uint8_t *) src_local, ne00);
+ } else {
+ dst_th[ir] = hvx_reduce_sum_f32((const uint8_t *) src_local, ne00);
+ }
+ }
+
+ return HTP_STATUS_OK;
+}
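+
+/*
+ * Note: hex_l2fetch() prefetches the next row into L2 while HVX reduces the
+ * current one, overlapping the row loads with the reduction; the _a variant of
+ * the reduction assumes VLEN-aligned rows.
+ */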
+
+static void sum_rows_work_f32(unsigned int n, unsigned int i, void *data) {
+ sum_rows_thread_f32((struct htp_ops_context *) data, n, i);
+}
+
+int op_sum_rows(struct htp_ops_context * octx) {
+ sum_rows_preamble;
+
+ if (octx->src0.type != HTP_TYPE_F32) {
+ return HTP_STATUS_NO_SUPPORT;
+ }
+
+ if (octx->flags & HTP_OPFLAGS_SKIP_COMPUTE) {
+ return HTP_STATUS_OK;
+ }
+
+ const int n_threads = octx->n_threads;
+ const uint32_t src0_nrows = ne01 * ne02 * ne03;
+
+ uint32_t n_jobs = MIN(n_threads, src0_nrows);
+ octx->src0_nrows_per_thread = (src0_nrows + n_jobs - 1) / n_jobs;
+
+ worker_pool_run_func(octx->ctx->worker_pool, sum_rows_work_f32, octx, n_jobs);
+
+ return HTP_STATUS_OK;
+}
+
diff --git a/llama.cpp/ggml/src/ggml-hexagon/htp/unary-ops.c b/llama.cpp/ggml/src/ggml-hexagon/htp/unary-ops.c
new file mode 100644
index 0000000..ce879bf
--- /dev/null
+++ b/llama.cpp/ggml/src/ggml-hexagon/htp/unary-ops.c
@@ -0,0 +1,342 @@
+#pragma clang diagnostic ignored "-Wunused-variable"
+#pragma clang diagnostic ignored "-Wunused-function"
+#pragma clang diagnostic ignored "-Wunused-but-set-variable"
+
+#include <HAP_farf.h>
+#include <HAP_perf.h>
+
+#include <math.h>
+#include <string.h>
+
+#include "hex-dma.h"
+#include "hvx-utils.h"
+
+#define GGML_COMMON_DECL_C
+#include "ggml-common.h"
+#include "htp-ctx.h"
+#include "htp-msg.h"
+#include "htp-ops.h"
+
+#define htp_unary_preamble \
+ const uint32_t ne00 = src->ne[0]; \
+ const uint32_t ne01 = src->ne[1]; \
+ const uint32_t ne02 = src->ne[2]; \
+ const uint32_t ne03 = src->ne[3]; \
+ \
+ const uint32_t ne0 = dst->ne[0]; \
+ const uint32_t ne1 = dst->ne[1]; \
+ const uint32_t ne2 = dst->ne[2]; \
+ const uint32_t ne3 = dst->ne[3]; \
+ \
+ const uint32_t nb00 = src->nb[0]; \
+ const uint32_t nb01 = src->nb[1]; \
+ const uint32_t nb02 = src->nb[2]; \
+ const uint32_t nb03 = src->nb[3]; \
+ \
+ const uint32_t nb0 = dst->nb[0]; \
+ const uint32_t nb1 = dst->nb[1]; \
+ const uint32_t nb2 = dst->nb[2]; \
+ const uint32_t nb3 = dst->nb[3];
+
+static void hvx_fast_rms_norm_f32(const uint8_t * restrict src,
+ uint8_t * restrict dst,
+ uint8_t * restrict pad,
+ const int num_elems,
+ float epsilon) {
+ const HVX_Vector * restrict v_src = (HVX_Vector *) src;
+ HVX_Vector * restrict v_dst = (HVX_Vector *) dst;
+
+ HVX_Vector sum_v = Q6_V_vsplat_R(0x00000000);
+ HVX_Vector epsilon_v = hvx_vec_splat_f32(epsilon);
+
+ int step_of_1 = num_elems >> 5;
+ #pragma unroll(4)
+ for (int i = 0; i < step_of_1; i++) {
+ HVX_Vector v1 = v_src[i];
+ HVX_Vector v2 = Q6_Vqf32_vmpy_VsfVsf(v1, v1);
+ sum_v = Q6_Vqf32_vadd_Vqf32Vqf32(sum_v, v2);
+ }
+
+ HVX_Vector reduced_sum = hvx_vec_reduce_sum_f32(Q6_Vsf_equals_Vqf32(sum_v));
+ sum_v = hvx_vec_repl4(reduced_sum);
+
+ HVX_Vector t_v = hvx_vec_splat_f32((float) num_elems);
+ HVX_Vector denom_v = hvx_vec_inverse_f32(t_v);
+ HVX_Vector mean_v = Q6_Vqf32_vmpy_VsfVsf(sum_v, denom_v);
+ HVX_Vector mean_epsilon_v = Q6_Vqf32_vadd_Vqf32Vsf(mean_v, epsilon_v);
+
+ HVX_Vector scale_v = hvx_vec_rsqrt_f32(Q6_Vsf_equals_Vqf32(mean_epsilon_v));
+
+ #pragma unroll(4)
+ for (int i = 0; i < step_of_1; i++) {
+ HVX_Vector v1 = v_src[i];
+ HVX_Vector v2 = Q6_Vqf32_vmpy_VsfVsf(v1, scale_v);
+ v_dst[i] = Q6_Vsf_equals_Vqf32(v2);
+ }
+}
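+
+/*
+ * Note: this computes RMS normalization, y_i = x_i / sqrt(mean(x^2) + eps), in
+ * two passes: a fused square-and-accumulate reduction, then a broadcast of
+ * 1/sqrt(mean + eps) (via hvx_vec_rsqrt_f32) applied to every element.
+ */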
+
+static void scale_htp_f32(const float * restrict src,
+ float * restrict dst,
+ uint8_t * restrict spad,
+ const uint32_t num_rows,
+ const uint32_t row_elems,
+ const size_t row_size,
+ int32_t * op_params,
+ int opt_path) {
+ float scale = 0.f;
+ float bias = 0.f;
+ memcpy(&scale, &op_params[0], sizeof(float));
+ memcpy(&bias, &op_params[1], sizeof(float));
+
+ for (uint32_t ir = 0; ir < num_rows; ir++) {
+ const float * restrict src_local = src + (ir * row_elems);
+ float * restrict dst_local = dst + (ir * row_elems);
+
+ if (ir + 1 < num_rows) {
+ hex_l2fetch(src_local + row_elems, row_size, row_size, 1);
+ }
+
+ hvx_scale_offset_f32((uint8_t *) dst_local, (const uint8_t *) src_local, row_elems, scale, bias);
+ }
+}
+
+static void rms_norm_htp_f32(const float * restrict src,
+ float * restrict dst,
+ uint8_t * restrict spad,
+ const uint32_t num_rows,
+ const uint32_t row_elems,
+ const size_t row_size,
+ int32_t * op_params,
+ int opt_path) {
+ float epsilon = 0.f;
+ memcpy(&epsilon, op_params, sizeof(float));
+
+ for (uint32_t ir = 0; ir < num_rows; ir++) {
+ const float * restrict src_local = src + (ir * row_elems);
+ float * restrict dst_local = dst + (ir * row_elems);
+
+ if (ir + 1 < num_rows) {
+ hex_l2fetch(src_local + row_elems, row_size, row_size, 1);
+ }
+
+ if (1 == opt_path) {
+ hvx_fast_rms_norm_f32((const uint8_t *) src_local, (uint8_t *) dst_local, spad, row_elems, epsilon);
+ } else {
+ float sum = hvx_sum_of_squares_f32((const uint8_t *) src_local, row_elems);
+
+ const float mean = sum / row_elems;
+ const float scale = 1.0f / sqrtf(mean + epsilon);
+
+ hvx_scale_f32((uint8_t *) dst_local, (const uint8_t *) src_local, row_elems, scale);
+ }
+ }
+}
+
+static void sqr_htp_f32(const float * restrict src,
+ float * restrict dst,
+ uint8_t * restrict spad,
+ const uint32_t num_rows,
+ const uint32_t row_elems,
+ const size_t row_size,
+ int32_t * op_params,
+ int opt_path) {
+
+ for (uint32_t ir = 0; ir < num_rows; ir++) {
+ const float * restrict src_local = src + (ir * row_elems);
+ float * restrict dst_local = dst + (ir * row_elems);
+
+ if (ir + 1 < num_rows) {
+ hex_l2fetch(src_local + row_elems, row_size, row_size, 1);
+ }
+
+ if (1 == opt_path) {
+ hvx_sqr_f32_aa((uint8_t *) dst_local, (const uint8_t *) src_local, row_elems);
+ } else {
+ hvx_sqr_f32((uint8_t *) dst_local, (const uint8_t *) src_local, row_elems);
+ }
+ }
+}
+
+static void sqrt_htp_f32(const float * restrict src,
+ float * restrict dst,
+ uint8_t * restrict spad,
+ const uint32_t num_rows,
+ const uint32_t row_elems,
+ const size_t row_size,
+ int32_t * op_params,
+ int opt_path) {
+
+ for (uint32_t ir = 0; ir < num_rows; ir++) {
+ const float * restrict src_local = src + (ir * row_elems);
+ float * restrict dst_local = dst + (ir * row_elems);
+
+ if (ir + 1 < num_rows) {
+ hex_l2fetch(src_local + row_elems, row_size, row_size, 1);
+ }
+
+ if (1 == opt_path) {
+ hvx_sqrt_f32_aa((uint8_t *) dst_local, (const uint8_t *) src_local, row_elems);
+ } else {
+ hvx_sqrt_f32((uint8_t *) dst_local, (const uint8_t *) src_local, row_elems);
+ }
+ }
+}
+
+static void unary_job_f32_per_thread(const struct htp_tensor * src,
+ struct htp_tensor * dst,
+ uint8_t * spad,
+ int htp_op,
+ int32_t * op_params,
+ uint32_t nth,
+ uint32_t ith,
+ uint32_t src0_nrows_per_thread) {
+ htp_unary_preamble;
+
+ const size_t src0_row_size = nb01;
+ const size_t dst_row_size = nb1;
+
+ const uint32_t src0_nrows = ne01 * ne02 * ne03; // src0 rows
+
+ const uint32_t src0_start_row = src0_nrows_per_thread * ith;
+ const uint32_t src0_end_row = MIN(src0_start_row + src0_nrows_per_thread, src0_nrows);
+
+ // no work for this thread
+ if (src0_start_row >= src0_end_row) {
+ return;
+ }
+
+ uint64_t t1, t2;
+ t1 = HAP_perf_get_qtimer_count();
+
+ int is_aligned = 1;
+ int opt_path = 0;
+ if ((0 == hex_is_aligned((void *) src->data, VLEN)) || (0 == hex_is_aligned((void *) dst->data, VLEN))) {
+ is_aligned = 0;
+ }
+ if ((1 == is_aligned) && !(nb01 & (VLEN - 1))) {
+ opt_path = 1;
+ }
+
+ const uint8_t * restrict data_src = (const uint8_t *) src->data;
+ uint8_t * restrict data_dst = (uint8_t *) dst->data;
+
+ const float * restrict src_th = (float *) (data_src + (src0_start_row * src0_row_size));
+ float * restrict dst_th = (float *) (data_dst + (src0_start_row * dst_row_size));
+ uint8_t * restrict spad_th = (uint8_t *) spad + (ith * nb01);
+
+ switch (htp_op) {
+ case HTP_OP_RMS_NORM:
+ rms_norm_htp_f32(src_th, dst_th, spad_th, src0_end_row - src0_start_row, ne0, nb1, op_params, opt_path);
+ break;
+ case HTP_OP_SCALE:
+ scale_htp_f32(src_th, dst_th, spad_th, src0_end_row - src0_start_row, ne0, nb1, op_params, opt_path);
+ break;
+ case HTP_OP_SQR:
+ sqr_htp_f32(src_th, dst_th, spad_th, src0_end_row - src0_start_row, ne0, nb1, op_params, opt_path);
+ break;
+ case HTP_OP_SQRT:
+ sqrt_htp_f32(src_th, dst_th, spad_th, src0_end_row - src0_start_row, ne0, nb1, op_params, opt_path);
+ break;
+
+ default:
+ break;
+ }
+
+ t2 = HAP_perf_get_qtimer_count();
+
+ FARF(HIGH, "unary-f32 %d/%d/%d: %ux%ux%ux%u (%u:%u) -> %ux%ux%ux%u usec %u\n", ith, nth, opt_path, src->ne[0],
+ src->ne[1], src->ne[2], src->ne[3], src0_start_row, src0_end_row, dst->ne[0], dst->ne[1], dst->ne[2],
+ dst->ne[3], (unsigned) HAP_perf_qtimer_count_to_us(t2 - t1));
+}
+
+static void unary_job_dispatcher_f32(unsigned int n, unsigned int i, void * data) {
+ struct htp_ops_context * octx = (struct htp_ops_context *) data;
+
+ unary_job_f32_per_thread(&octx->src0, &octx->dst, octx->src0_spad.data, octx->op, octx->op_params, n, i,
+ octx->src0_nrows_per_thread);
+}
+
+static int execute_op_unary_f32(struct htp_ops_context * octx) {
+ int err = HTP_STATUS_OK;
+
+ const struct htp_tensor * src0 = &octx->src0;
+ struct htp_tensor * dst = &octx->dst;
+
+ worker_callback_t unary_op_func;
+ const char * op_type = NULL;
+
+ switch (octx->op) {
+ case HTP_OP_RMS_NORM:
+ unary_op_func = unary_job_dispatcher_f32;
+ op_type = "rmsnorm-f32";
+ break;
+ case HTP_OP_SCALE:
+ unary_op_func = unary_job_dispatcher_f32;
+ op_type = "scale-f32";
+ break;
+ case HTP_OP_SQR:
+ unary_op_func = unary_job_dispatcher_f32;
+ op_type = "sqr-f32";
+ break;
+ case HTP_OP_SQRT:
+ unary_op_func = unary_job_dispatcher_f32;
+ op_type = "sqrt-f32";
+ break;
+
+ default:
+ FARF(ERROR, "Unsupported unary Op %u\n", octx->op);
+ return HTP_STATUS_NO_SUPPORT;
+ }
+
+ const int n_threads = octx->n_threads;
+ const uint32_t src0_nrows = src0->ne[1] * src0->ne[2] * src0->ne[3];
+
+ const size_t src0_row_size = src0->nb[1];
+ const size_t dst_row_size = dst->nb[1];
+
+ // VTCM scratchpads for all tensors
+ octx->dst_spad.size = hex_round_up(dst_row_size, 128) * n_threads;
+ octx->src0_spad.size = hex_round_up(src0_row_size, 128) * n_threads;
+
+ size_t spad_size = octx->src0_spad.size + octx->dst_spad.size;
+
+    FARF(HIGH, "%s: (%ux%ux%ux%u) -> (%ux%ux%ux%u) : src0-spad-size %u dst-spad-size %u\n", op_type,
+         src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3], dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3],
+         octx->src0_spad.size, octx->dst_spad.size);
+
+ // Make sure the reserved vtcm size is sufficient
+ if (octx->ctx->vtcm_size < spad_size) {
+        FARF(ERROR, "%s : current VTCM reservation %zu is too small, needed %zu\n", op_type, octx->ctx->vtcm_size,
+             spad_size);
+ return HTP_STATUS_VTCM_TOO_SMALL;
+ }
+
+ octx->src0_spad.data = octx->ctx->vtcm_base;
+ octx->dst_spad.data = octx->src0_spad.data + octx->src0_spad.size;
+
+ if (!(octx->flags & HTP_OPFLAGS_SKIP_COMPUTE)) {
+ uint32_t n_jobs = MIN(n_threads, src0_nrows);
+
+ octx->src0_nrows_per_thread = (src0_nrows + n_jobs - 1) / n_jobs;
+
+ worker_pool_run_func(octx->ctx->worker_pool, unary_op_func, octx, n_jobs);
+ }
+
+ return err;
+}
+
+int op_unary(struct htp_ops_context * octx) {
+ int err = HTP_STATUS_OK;
+
+ switch (octx->src0.type) {
+ case HTP_TYPE_F32:
+ err = execute_op_unary_f32(octx);
+ break;
+
+ default:
+ err = HTP_STATUS_NO_SUPPORT;
+ break;
+ }
+
+ return err;
+}
diff --git a/llama.cpp/ggml/src/ggml-hexagon/htp/worker-pool.c b/llama.cpp/ggml/src/ggml-hexagon/htp/worker-pool.c
new file mode 100644
index 0000000..894815f
--- /dev/null
+++ b/llama.cpp/ggml/src/ggml-hexagon/htp/worker-pool.c
@@ -0,0 +1,293 @@
+#include "worker-pool.h"
+
+#include <qurt.h>
+#include <stdatomic.h>
+#include <stdint.h>
+#include <stdio.h>
+#include <stdlib.h>
+#include <string.h>
+
+#include "HAP_farf.h"
+
+#define WORKER_THREAD_STACK_SZ (2 * 16384)
+#define LOWEST_USABLE_QURT_PRIO (254)
+
+struct worker_pool_s;
+
+// internal structure kept in thread-local storage per instance of worker pool
+typedef struct {
+ struct worker_pool_s * pool;
+ unsigned int id;
+} worker_context_t;
+
+// internal structure kept in thread-local storage per instance of worker pool
+typedef struct worker_pool_s {
+ worker_pool_job_t job[MAX_NUM_WORKERS]; // list of job descriptors
+ qurt_thread_t thread[MAX_NUM_WORKERS]; // thread ID's of the workers
+ worker_context_t context[MAX_NUM_WORKERS]; // worker contexts
+ void * stack[MAX_NUM_WORKERS]; // thread stack pointers
+ unsigned int n_threads; // number of workers in this pool
+
+ atomic_uint seqn; // seqno used to detect new jobs
+ atomic_uint next_job; // next job index
+ atomic_uint n_pending; // number of pending jobs
+ atomic_uint n_jobs; // number of current jobs
+ atomic_bool killed; // threads need to exit
+} worker_pool_t;
+
+static void worker_pool_main(void * context) {
+ worker_context_t * me = (worker_context_t *) context;
+ worker_pool_t * pool = me->pool;
+
+ FARF(HIGH, "worker-pool: thread %u started", me->id);
+
+ unsigned int prev_seqn = 0;
+ while (!atomic_load(&pool->killed)) {
+ unsigned int seqn = atomic_load(&pool->seqn);
+ if (seqn == prev_seqn) {
+ // Nothing to do
+ qurt_futex_wait(&pool->seqn, prev_seqn);
+ continue;
+ }
+
+ // New job
+ prev_seqn = seqn;
+
+ unsigned int n = atomic_load(&pool->n_jobs);
+ unsigned int i = atomic_fetch_add(&pool->next_job, 1);
+ if (i >= n) {
+            // Spurious wakeup
+ continue;
+ }
+
+ pool->job[i].func(n, i, pool->job[i].data);
+
+ atomic_fetch_sub(&pool->n_pending, 1);
+ }
+
+ FARF(HIGH, "worker-pool: thread %u stopped", me->id);
+}
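+
+/*
+ * Note: job dispatch is futex-based. worker_pool_run_jobs() bumps seqn and
+ * wakes the sleeping workers; each woken thread claims a job index with an
+ * atomic fetch-add on next_job, runs it, and decrements n_pending. A wakeup
+ * that finds no job left (i >= n) is treated as spurious and loops back to the
+ * futex wait.
+ */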
+
+AEEResult worker_pool_init_with_stack_size(worker_pool_context_t * context, uint32_t n_threads, uint32_t stack_size) {
+ int err = 0;
+
+ if (NULL == context) {
+ FARF(ERROR, "NULL context passed to worker_pool_init().");
+ return AEE_EBADPARM;
+ }
+
+ // Allocations
+ int size = (stack_size * n_threads) + (sizeof(worker_pool_t));
+
+ unsigned char * mem_blob = (unsigned char *) malloc(size);
+ if (!mem_blob) {
+ FARF(ERROR, "Could not allocate memory for worker pool!!");
+ return AEE_ENOMEMORY;
+ }
+
+ worker_pool_t * me = (worker_pool_t *) (mem_blob + stack_size * n_threads);
+
+ // name for the first worker, useful in debugging threads
+ char name[19];
+ snprintf(name, 12, "0x%8x:", (int) me);
+ strcat(name, "worker0");
+ me->n_threads = n_threads;
+
+ // initializations
+ for (unsigned int i = 0; i < me->n_threads; i++) {
+ me->stack[i] = NULL;
+ me->thread[i] = 0;
+
+ me->context[i].id = i;
+ me->context[i].pool = me;
+ }
+
+ // initialize job queue
+ me->n_pending = 0;
+ me->n_jobs = 0;
+ me->next_job = 0;
+ me->seqn = 0;
+ me->killed = 0;
+
+ // launch the workers
+ qurt_thread_attr_t attr;
+ qurt_thread_attr_init(&attr);
+
+ for (unsigned int i = 0; i < me->n_threads; i++) {
+ // set up stack
+ me->stack[i] = mem_blob;
+ mem_blob += stack_size;
+ qurt_thread_attr_set_stack_addr(&attr, me->stack[i]);
+ qurt_thread_attr_set_stack_size(&attr, stack_size);
+
+ // set up name
+ qurt_thread_attr_set_name(&attr, name);
+ name[17] = (name[17] + 1);
+        // threads are named <ctx-addr>:worker0, <ctx-addr>:worker1, ...; the digit wraps after '9', but the thread count should stay below that anyway
+ if (name[17] > '9') {
+ name[17] = '0';
+ }
+
+ // set up priority - by default, match the creating thread's prio
+ int prio = qurt_thread_get_priority(qurt_thread_get_id());
+
+ if (prio < 1) {
+ prio = 1;
+ }
+ if (prio > LOWEST_USABLE_QURT_PRIO) {
+ prio = LOWEST_USABLE_QURT_PRIO;
+ }
+
+ qurt_thread_attr_set_priority(&attr, prio);
+
+ // launch
+ err = qurt_thread_create(&me->thread[i], &attr, worker_pool_main, (void *) &me->context[i]);
+ if (err) {
+ FARF(ERROR, "Could not launch worker threads!");
+ worker_pool_release((worker_pool_context_t *) &me);
+ return AEE_EQURTTHREADCREATE;
+ }
+ }
+ *context = (worker_pool_context_t *) me;
+ return AEE_SUCCESS;
+}
+
+AEEResult worker_pool_init(worker_pool_context_t * context, uint32_t n_threads) {
+ return worker_pool_init_with_stack_size(context, n_threads, WORKER_THREAD_STACK_SZ);
+}
+
+// clean up worker pool
+void worker_pool_release(worker_pool_context_t * context) {
+ worker_pool_t * me = (worker_pool_t *) *context;
+
+    // if no worker pool exists, there is nothing to release
+ if (NULL == me) {
+ return;
+ }
+
+ atomic_store(&me->killed, 1);
+ atomic_fetch_add(&me->seqn, 1);
+ qurt_futex_wake(&me->seqn, me->n_threads);
+
+ // de-initializations
+ for (unsigned int i = 0; i < me->n_threads; i++) {
+ if (me->thread[i]) {
+ int status;
+ (void) qurt_thread_join(me->thread[i], &status);
+ }
+ }
+
+    // free allocated memory (stacks and pool struct were allocated as a single buffer starting at stack[0])
+ if (me->stack[0]) {
+ free(me->stack[0]);
+ }
+
+ *context = NULL;
+}
+
+// run jobs
+AEEResult worker_pool_run_jobs(worker_pool_context_t context, worker_pool_job_t * job, unsigned int n) {
+ worker_pool_t * me = (worker_pool_t *) context;
+ if (NULL == me) {
+ FARF(ERROR, "worker-pool: invalid context");
+ return AEE_EBADPARM;
+ }
+
+ if (n > me->n_threads) {
+ FARF(ERROR, "worker-pool: invalid number of jobs %u for n-threads %u", n, me->n_threads);
+ return AEE_EBADPARM;
+ }
+
+ memcpy(me->job, job, sizeof(worker_pool_job_t) * n);
+
+ if (n > 1) {
+ atomic_store(&me->next_job, 1);
+ atomic_store(&me->n_jobs, n);
+ atomic_store(&me->n_pending, n - 1);
+
+ // wake up workers
+ atomic_fetch_add(&me->seqn, 1);
+ qurt_futex_wake(&me->seqn, n - 1);
+ }
+
+ // main thread runs job #0
+ me->job[0].func(n, 0, me->job[0].data);
+
+ if (n > 1) {
+ while (atomic_load(&me->n_pending))
+ ;
+ }
+
+ return 0;
+}
+
+// run func
+AEEResult worker_pool_run_func(worker_pool_context_t context, worker_callback_t func, void * data, unsigned int n) {
+ worker_pool_job_t job[n];
+
+ for (unsigned int i = 0; i < n; i++) {
+ job[i].func = func;
+ job[i].data = data;
+ }
+
+ return worker_pool_run_jobs(context, job, n);
+}
+
+AEEResult worker_pool_set_thread_priority(worker_pool_context_t context, unsigned int prio) {
+ worker_pool_t * me = (worker_pool_t *) context;
+
+ // if no worker pool exists, return error.
+ if (!me) {
+ return AEE_ENOMORE;
+ }
+
+ int result = AEE_SUCCESS;
+ if (prio < 1) {
+ prio = 1;
+ }
+ if (prio > LOWEST_USABLE_QURT_PRIO) {
+ prio = LOWEST_USABLE_QURT_PRIO;
+ }
+
+ for (unsigned int i = 0; i < me->n_threads; i++) {
+ int res = qurt_thread_set_priority(me->thread[i], (unsigned short) prio);
+ if (0 != res) {
+ result = AEE_EBADPARM;
+ FARF(ERROR, "QURT failed to set priority of thread %d, ERROR = %d", me->thread[i], res);
+ }
+ }
+
+ return result;
+}
+
+AEEResult worker_pool_retrieve_thread_id(worker_pool_context_t context, unsigned int * tids) {
+ worker_pool_t * me = (worker_pool_t *) context;
+ if (!me) {
+ FARF(ERROR, "worker-pool: invalid context");
+        return AEE_EBADPARM;
+    }
+
+    for (unsigned int i = 0; i < me->n_threads; i++) {
+ tids[i] = me->thread[i];
+ }
+
+ return AEE_SUCCESS;
+}
+
+AEEResult worker_pool_get_thread_priority(worker_pool_context_t context, unsigned int * prio) {
+ worker_pool_t * me = (worker_pool_t *) context;
+ if (!me) {
+ FARF(ERROR, "worker-pool: invalid context");
+ return AEE_EBADPARM;
+ }
+
+ int priority = qurt_thread_get_priority(me->thread[0]);
+ if (priority > 0) {
+ *prio = priority;
+ return 0;
+ } else {
+ *prio = 0;
+ return AEE_EBADSTATE;
+ }
+}
diff --git a/llama.cpp/ggml/src/ggml-hexagon/htp/worker-pool.h b/llama.cpp/ggml/src/ggml-hexagon/htp/worker-pool.h
new file mode 100644
index 0000000..6f8c905
--- /dev/null
+++ b/llama.cpp/ggml/src/ggml-hexagon/htp/worker-pool.h
@@ -0,0 +1,57 @@
+#ifndef HTP_WORKER_POOL_H
+#define HTP_WORKER_POOL_H
+
+// Macro that makes these functions visible in the shared-library case.
+#define WORKERPOOL_API __attribute__((visibility("default")))
+
+#include <AEEStdDef.h>
+#include <AEEStdErr.h>
+#include <stdint.h>
+
+#ifdef __cplusplus
+extern "C" {
+#endif
+
+/// signature of callbacks to be invoked by worker threads
+typedef void (*worker_callback_t)(unsigned int n, unsigned int i, void *);
+
+/// Typedef of worker_pool context
+typedef void * worker_pool_context_t;
+
+/// descriptor for requested callback
+typedef struct {
+ worker_callback_t func;
+ void * data;
+} worker_pool_job_t;
+
+/// Maximum supported number of worker threads.
+#define MAX_NUM_WORKERS 10
+
+// Initialize worker pool.
+WORKERPOOL_API AEEResult worker_pool_init(worker_pool_context_t * context, uint32_t n_threads);
+
+// Initialize worker pool with custom stack size
+WORKERPOOL_API AEEResult worker_pool_init_with_stack_size(worker_pool_context_t * context,
+ uint32_t n_threads,
+ uint32_t stack_size);
+
+// Kill worker threads and release worker pool resources
+WORKERPOOL_API void worker_pool_release(worker_pool_context_t * context);
+
+// Run jobs with the worker pool.
+WORKERPOOL_API AEEResult worker_pool_run_jobs(worker_pool_context_t context, worker_pool_job_t * job, unsigned int n);
+
+WORKERPOOL_API AEEResult worker_pool_run_func(worker_pool_context_t context,
+ worker_callback_t func,
+ void * data,
+ unsigned int n);
+
+WORKERPOOL_API AEEResult worker_pool_set_thread_priority(worker_pool_context_t context, unsigned int prio);
+WORKERPOOL_API AEEResult worker_pool_get_thread_priority(worker_pool_context_t context, unsigned int * prio);
+WORKERPOOL_API AEEResult worker_pool_retrieve_thread_id(worker_pool_context_t context, unsigned int * tids);
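+
+// Typical usage (sketch only; my_callback/my_data are placeholders, error
+// handling omitted):
+//
+//   static void my_callback(unsigned int n, unsigned int i, void * data) { /* job i of n */ }
+//
+//   worker_pool_context_t pool = NULL;
+//   if (worker_pool_init(&pool, 4) == AEE_SUCCESS) {
+//       worker_pool_run_func(pool, my_callback, my_data, 4);  // n jobs must be <= n_threads
+//       worker_pool_release(&pool);
+//   }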
+
+#ifdef __cplusplus
+}
+#endif
+
+#endif // #ifndef HTP_WORKER_POOL_H
diff --git a/llama.cpp/ggml/src/ggml-hexagon/libdl.h b/llama.cpp/ggml/src/ggml-hexagon/libdl.h
new file mode 100644
index 0000000..8ca5016
--- /dev/null
+++ b/llama.cpp/ggml/src/ggml-hexagon/libdl.h
@@ -0,0 +1,79 @@
+#pragma once
+
+#ifdef _WIN32
+# define WIN32_LEAN_AND_MEAN
+# ifndef NOMINMAX
+# define NOMINMAX
+# endif
+# include <windows.h>
+# include <winevt.h>
+#else
+# include <dlfcn.h>
+# include <unistd.h>
+#endif
+#include <filesystem>
+
+namespace fs = std::filesystem;
+
+#ifdef _WIN32
+
+using dl_handle = std::remove_pointer_t<HMODULE>;
+
+struct dl_handle_deleter {
+ void operator()(HMODULE handle) {
+ FreeLibrary(handle);
+ }
+};
+
+static inline dl_handle * dl_load_library(const fs::path & path) {
+ // suppress error dialogs for missing DLLs
+ DWORD old_mode = SetErrorMode(SEM_FAILCRITICALERRORS);
+ SetErrorMode(old_mode | SEM_FAILCRITICALERRORS);
+
+ HMODULE handle = LoadLibraryW(path.wstring().c_str());
+
+ SetErrorMode(old_mode);
+
+ return handle;
+}
+
+static inline void * dl_get_sym(dl_handle * handle, const char * name) {
+ DWORD old_mode = SetErrorMode(SEM_FAILCRITICALERRORS);
+ SetErrorMode(old_mode | SEM_FAILCRITICALERRORS);
+
+ void * p = (void *) GetProcAddress(handle, name);
+
+ SetErrorMode(old_mode);
+
+ return p;
+}
+
+static inline const char * dl_error() {
+ return "";
+}
+
+#else
+
+using dl_handle = void;
+
+struct dl_handle_deleter {
+ void operator()(void * handle) {
+ dlclose(handle);
+ }
+};
+
+static inline dl_handle * dl_load_library(const fs::path & path) {
+ dl_handle * handle = dlopen(path.string().c_str(), RTLD_NOW | RTLD_LOCAL);
+ return handle;
+}
+
+static inline void * dl_get_sym(dl_handle * handle, const char * name) {
+ return dlsym(handle, name);
+}
+
+static inline const char * dl_error() {
+ const char *rslt = dlerror();
+ return rslt != nullptr ? rslt : "";
+}
+
+#endif
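+
+// Typical usage (sketch only; requires <memory>, the .so name and symbol are examples):
+//
+//   using dl_handle_ptr = std::unique_ptr<dl_handle, dl_handle_deleter>;
+//   dl_handle_ptr lib { dl_load_library("libggml-htp-v75.so") };
+//   if (lib) {
+//       void * sym = dl_get_sym(lib.get(), "some_symbol");  // placeholder symbol name
+//   }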
diff --git a/llama.cpp/ggml/src/ggml-hexagon/libggml-htp.inf b/llama.cpp/ggml/src/ggml-hexagon/libggml-htp.inf
new file mode 100644
index 0000000..656d2d9
--- /dev/null
+++ b/llama.cpp/ggml/src/ggml-hexagon/libggml-htp.inf
@@ -0,0 +1,38 @@
+[Version]
+Signature = "$WINDOWS NT$"
+Class = ComputeAccelerator
+ClassGuid = {F01A9D53-3FF6-48D2-9F97-C8A7004BE10C}
+Provider = %GGML%
+DriverVer = 01/01/2026,1.0.0.0
+CatalogFile = libggml-htp.cat
+PnpLockDown = 1
+
+[DestinationDirs]
+Drivers_Dir = 6
+
+[SourceDisksNames]
+1 = %DiskId%
+
+[SourceDisksFiles]
+libggml-htp-v68.so = 1
+libggml-htp-v69.so = 1
+libggml-htp-v73.so = 1
+libggml-htp-v75.so = 1
+libggml-htp-v81.so = 1
+
+[ControlFlags]
+ExcludeFromSelect = *
+
+[DefaultInstall.NTarm64]
+CopyFiles=Drivers_Dir
+
+[Drivers_Dir]
+libggml-htp-v68.so,,,0x10 ;COPYFLG_NO_OVERWRITE
+libggml-htp-v69.so,,,0x10 ;COPYFLG_NO_OVERWRITE
+libggml-htp-v73.so,,,0x10 ;COPYFLG_NO_OVERWRITE
+libggml-htp-v75.so,,,0x10 ;COPYFLG_NO_OVERWRITE
+libggml-htp-v81.so,,,0x10 ;COPYFLG_NO_OVERWRITE
+
+[Strings]
+GGML = 'GGML'
+DiskId = 'GGML HTP library'
diff --git a/llama.cpp/ggml/src/ggml-hexagon/op-desc.h b/llama.cpp/ggml/src/ggml-hexagon/op-desc.h
new file mode 100644
index 0000000..a1e8ddd
--- /dev/null
+++ b/llama.cpp/ggml/src/ggml-hexagon/op-desc.h
@@ -0,0 +1,153 @@
+#ifndef OP_DESC_H
+#define OP_DESC_H
+
+#define GGML_COMMON_IMPL_CPP
+#include "ggml-backend-impl.h"
+#include "ggml-common.h"
+
+#include <string>
+#include <stdio.h>
+
+struct op_desc {
+ char strides[64 * GGML_MAX_SRC];
+ char dims[64 * GGML_MAX_SRC];
+ char types[16 * GGML_MAX_SRC];
+ char buffs[64 * GGML_MAX_SRC];
+ char names[64 * GGML_MAX_SRC];
+
+ int format_tensor_dims(char * str, const struct ggml_tensor * t) {
+ if (t->ne[2] == 1 && t->ne[3] == 1) {
+ return sprintf(str, "%d:%d", (int) t->ne[0], (int) t->ne[1]);
+ } else {
+ return sprintf(str, "%d:%d:%d:%d", (int) t->ne[0], (int) t->ne[1], (int) t->ne[2], (int) t->ne[3]);
+ }
+ }
+
+ void format_op_dims(char * str, const struct ggml_tensor * t) {
+ char * p = str;
+
+ // append src0 and src1 (if any)
+ if (t->src[0]) {
+ p += format_tensor_dims(p, t->src[0]);
+
+ for (int i = 1; i < GGML_MAX_SRC && t->src[i]; i++) {
+ p += sprintf(p, " x ");
+ p += format_tensor_dims(p, t->src[i]);
+ }
+
+ p += sprintf(p, " -> ");
+ }
+
+ // format self dims separately for better visual alignment
+ char self[64];
+ format_tensor_dims(self, t);
+
+ p += sprintf(p, "%s", self);
+ }
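+
+    // Example (hypothetical shapes): a mul_mat of a 4096x4096 src0 with a
+    // 4096x32 src1 formats as "4096:4096 x 4096:32 -> 4096:32"; tensors with
+    // ne[2] or ne[3] > 1 print all four dims.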
+
+ int format_tensor_strides(char * str, const struct ggml_tensor * t) {
+ const char * c = ggml_is_contiguous(t) ? "" : "!";
+
+ if (t->ne[2] == 1 && t->ne[3] == 1) {
+ return sprintf(str, "%zu:%zu%s", (size_t) t->nb[0], (size_t) t->nb[1], c);
+ } else {
+ return sprintf(str, "%zu:%zu:%zu:%zu%s", (size_t) t->nb[0], (size_t) t->nb[1], (size_t) t->nb[2], (size_t) t->nb[3], c);
+ }
+ }
+
+ void format_op_strides(char * str, const struct ggml_tensor * t) {
+ char * p = str;
+
+ // append src0 and src1 (if any)
+ if (t->src[0]) {
+ p += format_tensor_strides(p, t->src[0]);
+
+ for (int i = 1; i < GGML_MAX_SRC && t->src[i]; i++) {
+ p += sprintf(p, " x ");
+ p += format_tensor_strides(p, t->src[i]);
+ }
+
+ p += sprintf(p, " -> ");
+ }
+
+ // format self dims separately for better visual alignment
+ char self[64];
+ format_tensor_strides(self, t);
+
+ p += sprintf(p, "%s", self);
+ }
+
+ void format_op_types(char * str, const struct ggml_tensor * t) {
+ char * p = str;
+
+ // append src0 and src1 (if any)
+ if (t->src[0]) {
+ p += sprintf(p, "%s", ggml_type_name(t->src[0]->type));
+
+ for (int i = 1; i < GGML_MAX_SRC && t->src[i]; i++) {
+ p += sprintf(p, " x ");
+ p += sprintf(p, "%s", ggml_type_name(t->src[i]->type));
+ }
+
+ p += sprintf(p, " -> ");
+ }
+
+ p += sprintf(p, "%s", ggml_type_name(t->type));
+ }
+
+ const char * tensor_buff_name(const struct ggml_tensor * t) {
+ if (t->buffer) {
+ return ggml_backend_buffer_name(t->buffer);
+ }
+ return "NONE";
+ }
+
+ void format_op_buffs(char * str, const struct ggml_tensor * t) {
+ char * p = str;
+
+ // append src0 and src1 (if any)
+ if (t->src[0]) {
+ p += sprintf(p, "%s", tensor_buff_name(t->src[0]));
+
+ for (int i = 1; i < GGML_MAX_SRC && t->src[i]; i++) {
+ p += sprintf(p, " x ");
+ p += sprintf(p, "%s", tensor_buff_name(t->src[i]));
+ }
+
+ p += sprintf(p, " -> ");
+ }
+
+ p += sprintf(p, "%s", tensor_buff_name(t));
+ }
+
+ void format_op_names(char * str, const struct ggml_tensor * t) {
+ char * p = str;
+
+ // append src0 and src1 (if any)
+ if (t->src[0]) {
+ p += sprintf(p, "%s", t->src[0]->name);
+
+ for (int i = 1; i < GGML_MAX_SRC && t->src[i]; i++) {
+ p += sprintf(p, " x ");
+ p += sprintf(p, "%s", t->src[i]->name);
+ }
+
+ p += sprintf(p, " -> ");
+ }
+
+ p += sprintf(p, "%s", t->name);
+ }
+
+ void format(const ggml_tensor * op) {
+ format_op_dims(dims, op);
+ format_op_strides(strides, op);
+ format_op_types(types, op);
+ format_op_buffs(buffs, op);
+ format_op_names(names, op);
+ }
+
+ op_desc() {}
+ op_desc(const ggml_tensor * op) { format(op); }
+};
+
+#endif // OP_DESC_H