summaryrefslogtreecommitdiff
path: root/llama.cpp/ggml/src/ggml-hexagon/htp-drv.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'llama.cpp/ggml/src/ggml-hexagon/htp-drv.cpp')
-rw-r--r--llama.cpp/ggml/src/ggml-hexagon/htp-drv.cpp418
1 files changed, 418 insertions, 0 deletions
diff --git a/llama.cpp/ggml/src/ggml-hexagon/htp-drv.cpp b/llama.cpp/ggml/src/ggml-hexagon/htp-drv.cpp
new file mode 100644
index 0000000..2530bb0
--- /dev/null
+++ b/llama.cpp/ggml/src/ggml-hexagon/htp-drv.cpp
@@ -0,0 +1,418 @@
+// sample drv interface
+
+#pragma clang diagnostic ignored "-Wgnu-anonymous-struct"
+#pragma clang diagnostic ignored "-Wmissing-prototypes"
+#pragma clang diagnostic ignored "-Wsign-compare"
+
+#include <filesystem>
+#include <limits>
+#include <set>
+#include <sstream>
+#include <string>
+#ifdef _WIN32
+# define WIN32_LEAN_AND_MEAN
+# ifndef NOMINMAX
+# define NOMINMAX
+# endif
+# include <windows.h>
+# include <winevt.h>
+#else
+# include <dlfcn.h>
+# include <unistd.h>
+#endif
+#include "ggml-impl.h"
+#include "htp-drv.h"
+#include "libdl.h"
+
+#include <domain.h>
+
+//
+// Driver API types
+//
+
+// Function-pointer types for the driver entry points that htpdrv_init()
+// resolves by name. Each signature mirrors the same-named wrapper defined
+// further down in this file.
+
+// Shared-memory allocation API. rpcmem_alloc takes an int size,
+// rpcmem_alloc2 takes a size_t size.
+typedef void * (*rpcmem_alloc_pfn_t)(int heapid, uint32_t flags, int size);
+typedef void * (*rpcmem_alloc2_pfn_t)(int heapid, uint32_t flags, size_t size);
+typedef void (*rpcmem_free_pfn_t)(void * po);
+typedef int (*rpcmem_to_fd_pfn_t)(void * po);
+
+// dspqueue API: create / close / export a queue and write / read packets.
+typedef AEEResult (*dspqueue_create_pfn_t)(int domain,
+ uint32_t flags,
+ uint32_t req_queue_size,
+ uint32_t resp_queue_size,
+ dspqueue_callback_t packet_callback,
+ dspqueue_callback_t error_callback,
+ void * callback_context,
+ dspqueue_t * queue);
+typedef AEEResult (*dspqueue_close_pfn_t)(dspqueue_t queue);
+typedef AEEResult (*dspqueue_export_pfn_t)(dspqueue_t queue, uint64_t *queue_id);
+typedef AEEResult (*dspqueue_write_pfn_t)(dspqueue_t queue, uint32_t flags,
+ uint32_t num_buffers,
+ struct dspqueue_buffer *buffers,
+ uint32_t message_length,
+ const uint8_t *message,
+ uint32_t timeout_us);
+typedef AEEResult (*dspqueue_read_pfn_t)(dspqueue_t queue, uint32_t *flags,
+ uint32_t max_buffers, uint32_t *num_buffers,
+ struct dspqueue_buffer *buffers,
+ uint32_t max_message_length,
+ uint32_t *message_length, uint8_t *message,
+ uint32_t timeout_us);
+
+// FastRPC buffer mapping API.
+typedef int (*fastrpc_mmap_pfn_t)(int domain, int fd, void *addr, int offset, size_t length, enum fastrpc_map_flags flags);
+typedef int (*fastrpc_munmap_pfn_t)(int domain, int fd, void *addr, size_t length);
+
+// Remote-handle / session control API. Each signature matches the same-named
+// wrapper defined below. The close entry point takes the 64-bit handle type,
+// the same type remote_handle64_open() produces and the wrapper passes in;
+// declaring it with the narrower remote_handle type would truncate the handle
+// on the way to the driver.
+typedef int (*remote_handle64_open_pfn_t)(const char* name, remote_handle64 *ph);
+typedef int (*remote_handle64_invoke_pfn_t)(remote_handle64 h, uint32_t dwScalars, remote_arg *pra);
+typedef int (*remote_handle64_close_pfn_t)(remote_handle64 h);
+typedef int (*remote_handle_control_pfn_t)(uint32_t req, void* data, uint32_t datalen);
+typedef int (*remote_handle64_control_pfn_t)(remote_handle64 h, uint32_t req, void* data, uint32_t datalen);
+typedef int (*remote_session_control_pfn_t)(uint32_t req, void *data, uint32_t datalen);
+
+//
+// Driver API pfns
+//
+
+// Resolved driver entry points. All start out null and are assigned by
+// htpdrv_init(); the wrappers below call through them without a null check,
+// except rpcmem_alloc2(), which tests rpcmem_alloc2_pfn because that symbol is
+// looked up as optional.
+rpcmem_alloc_pfn_t rpcmem_alloc_pfn = nullptr;
+rpcmem_alloc2_pfn_t rpcmem_alloc2_pfn = nullptr;
+rpcmem_free_pfn_t rpcmem_free_pfn = nullptr;
+rpcmem_to_fd_pfn_t rpcmem_to_fd_pfn = nullptr;
+
+fastrpc_mmap_pfn_t fastrpc_mmap_pfn = nullptr;
+fastrpc_munmap_pfn_t fastrpc_munmap_pfn = nullptr;
+
+dspqueue_create_pfn_t dspqueue_create_pfn = nullptr;
+dspqueue_close_pfn_t dspqueue_close_pfn = nullptr;
+dspqueue_export_pfn_t dspqueue_export_pfn = nullptr;
+dspqueue_write_pfn_t dspqueue_write_pfn = nullptr;
+dspqueue_read_pfn_t dspqueue_read_pfn = nullptr;
+
+remote_handle64_open_pfn_t remote_handle64_open_pfn = nullptr;
+remote_handle64_invoke_pfn_t remote_handle64_invoke_pfn = nullptr;
+remote_handle64_close_pfn_t remote_handle64_close_pfn = nullptr;
+remote_handle_control_pfn_t remote_handle_control_pfn = nullptr;
+remote_handle64_control_pfn_t remote_handle64_control_pfn = nullptr;
+remote_session_control_pfn_t remote_session_control_pfn = nullptr;
+
+//
+// Driver API
+//
+
+// Forwards to the driver's rpcmem_alloc through rpcmem_alloc_pfn and returns
+// its result unchanged. The pointer is only set by htpdrv_init(), so that
+// must have succeeded before this is called.
+void * rpcmem_alloc(int heapid, uint32_t flags, int size) {
+ return rpcmem_alloc_pfn(heapid, flags, size);
+}
+
+// Allocates through the driver's rpcmem_alloc2 when that symbol was resolved,
+// otherwise falls back to rpcmem_alloc.
+//
+// The fallback entry point takes an `int` size, so a request larger than
+// INT_MAX cannot be represented; it is rejected with nullptr instead of being
+// silently truncated to a smaller (or negative) size.
+//
+// htpdrv_init() must have succeeded before this is called.
+void * rpcmem_alloc2(int heapid, uint32_t flags, size_t size) {
+    if (rpcmem_alloc2_pfn) {
+        return rpcmem_alloc2_pfn(heapid, flags, size);
+    }
+
+    GGML_LOG_INFO("ggml-hex: rpcmem_alloc2 not found, falling back to rpcmem_alloc\n");
+
+    if (size > static_cast<size_t>(std::numeric_limits<int>::max())) {
+        GGML_LOG_ERROR("ggml-hex: rpcmem_alloc cannot allocate %zu bytes (size exceeds INT_MAX)\n", size);
+        return nullptr;
+    }
+    return rpcmem_alloc_pfn(heapid, flags, static_cast<int>(size));
+}
+
+// Forwards to the driver's rpcmem_free through rpcmem_free_pfn.
+// Requires a prior successful htpdrv_init().
+void rpcmem_free(void * po) {
+ return rpcmem_free_pfn(po);
+}
+
+// Forwards to the driver's rpcmem_to_fd through rpcmem_to_fd_pfn and returns
+// its result unchanged. Requires a prior successful htpdrv_init().
+int rpcmem_to_fd(void * po) {
+ return rpcmem_to_fd_pfn(po);
+}
+
+// Forwards to the driver's fastrpc_mmap through fastrpc_mmap_pfn; all
+// arguments and the return value pass through unchanged.
+// Requires a prior successful htpdrv_init().
+HTPDRV_API int fastrpc_mmap(int domain, int fd, void * addr, int offset, size_t length, enum fastrpc_map_flags flags) {
+ return fastrpc_mmap_pfn(domain, fd, addr, offset, length, flags);
+}
+
+// Forwards to the driver's fastrpc_munmap through fastrpc_munmap_pfn.
+// Requires a prior successful htpdrv_init().
+HTPDRV_API int fastrpc_munmap(int domain, int fd, void * addr, size_t length) {
+ return fastrpc_munmap_pfn(domain, fd, addr, length);
+}
+
+// The dspqueue_* functions below are plain forwarders: each passes its
+// arguments to the matching dspqueue_*_pfn pointer and returns the driver's
+// AEEResult unchanged. The pointers are only set by htpdrv_init(), so that
+// must have succeeded before any of these is called.
+
+// Forwards to the driver's dspqueue_create.
+AEEResult dspqueue_create(int domain,
+ uint32_t flags,
+ uint32_t req_queue_size,
+ uint32_t resp_queue_size,
+ dspqueue_callback_t packet_callback,
+ dspqueue_callback_t error_callback,
+ void * callback_context,
+ dspqueue_t * queue) {
+ return dspqueue_create_pfn(domain, flags, req_queue_size, resp_queue_size, packet_callback, error_callback,
+ callback_context, queue);
+}
+
+// Forwards to the driver's dspqueue_close.
+AEEResult dspqueue_close(dspqueue_t queue) {
+ return dspqueue_close_pfn(queue);
+}
+
+// Forwards to the driver's dspqueue_export.
+AEEResult dspqueue_export(dspqueue_t queue, uint64_t * queue_id) {
+ return dspqueue_export_pfn(queue, queue_id);
+}
+
+// Forwards to the driver's dspqueue_write.
+AEEResult dspqueue_write(dspqueue_t queue,
+ uint32_t flags,
+ uint32_t num_buffers,
+ struct dspqueue_buffer * buffers,
+ uint32_t message_length,
+ const uint8_t * message,
+ uint32_t timeout_us) {
+ return dspqueue_write_pfn(queue, flags, num_buffers, buffers, message_length, message, timeout_us);
+}
+
+// Forwards to the driver's dspqueue_read.
+AEEResult dspqueue_read(dspqueue_t queue,
+ uint32_t * flags,
+ uint32_t max_buffers,
+ uint32_t * num_buffers,
+ struct dspqueue_buffer * buffers,
+ uint32_t max_message_length,
+ uint32_t * message_length,
+ uint8_t * message,
+ uint32_t timeout_us) {
+ return dspqueue_read_pfn(queue, flags, max_buffers, num_buffers, buffers, max_message_length, message_length,
+ message, timeout_us);
+}
+
+// The remote_* functions below are plain forwarders: each passes its
+// arguments to the matching *_pfn pointer and returns the driver's result
+// unchanged. The pointers are only set by htpdrv_init(), so that must have
+// succeeded before any of these is called.
+
+// Forwards to the driver's remote_handle64_open.
+HTPDRV_API int remote_handle64_open(const char * name, remote_handle64 * ph) {
+ return remote_handle64_open_pfn(name, ph);
+}
+
+// Forwards to the driver's remote_handle64_invoke.
+HTPDRV_API int remote_handle64_invoke(remote_handle64 h, uint32_t dwScalars, remote_arg * pra) {
+ return remote_handle64_invoke_pfn(h, dwScalars, pra);
+}
+
+// Forwards to the driver's remote_handle64_close.
+HTPDRV_API int remote_handle64_close(remote_handle64 h) {
+ return remote_handle64_close_pfn(h);
+}
+
+// Forwards to the driver's remote_handle_control.
+HTPDRV_API int remote_handle_control(uint32_t req, void * data, uint32_t datalen) {
+ return remote_handle_control_pfn(req, data, datalen);
+}
+
+// Forwards to the driver's remote_handle64_control.
+HTPDRV_API int remote_handle64_control(remote_handle64 h, uint32_t req, void * data, uint32_t datalen) {
+ return remote_handle64_control_pfn(h, req, data, datalen);
+}
+
+// Forwards to the driver's remote_session_control.
+HTPDRV_API int remote_session_control(uint32_t req, void * data, uint32_t datalen) {
+ return remote_session_control_pfn(req, data, datalen);
+}
+
+#ifdef _WIN32
+
+// Converts a UTF-16 wide string to a UTF-8 std::string using
+// WideCharToMultiByte. An empty input yields an empty string. Throws
+// std::runtime_error (after logging the Win32 error code) when either the
+// sizing call or the conversion call reports failure.
+static std::string wstr_to_str(std::wstring_view wstr) {
+    if (wstr.empty()) {
+        return std::string();
+    }
+
+    const wchar_t * src     = wstr.data();
+    const int       src_len = (int) wstr.size();
+
+    // First call: no output buffer, just ask how many bytes are required.
+    const auto utf8_len = WideCharToMultiByte(CP_UTF8, WC_ERR_INVALID_CHARS, src, src_len, nullptr, 0, nullptr, nullptr);
+    if (utf8_len == 0) {
+        GGML_LOG_ERROR("ggml-hex: WideCharToMultiByte failed. Error %lu\n", GetLastError());
+        throw std::runtime_error("Invalid wstring input");
+    }
+
+    // Second call: convert into a buffer of exactly the reported size.
+    std::string utf8;
+    utf8.resize(utf8_len, '\0');
+    const int converted =
+        WideCharToMultiByte(CP_UTF8, WC_ERR_INVALID_CHARS, src, src_len, utf8.data(), utf8_len, nullptr, nullptr);
+    if (converted == 0) {
+        GGML_LOG_ERROR("ggml-hex: WideCharToMultiByte failed. Error %lu\n", GetLastError());
+        throw std::runtime_error("Wstring conversion failed");
+    }
+
+    return utf8;
+}
+
+// Asks the Service Control Manager for the configured binary path of the
+// "qcnspmcdm" service and returns the directory part of it (everything before
+// the last backslash). Logs and returns an empty string when any step fails.
+static std::wstring query_driver_dir() {
+    // Closes a service-manager handle on scope exit so that no early return
+    // can leak it.
+    struct sc_handle_guard {
+        SC_HANDLE h;
+
+        explicit sc_handle_guard(SC_HANDLE handle) : h(handle) {}
+
+        ~sc_handle_guard() {
+            if (h != nullptr) {
+                CloseServiceHandle(h);
+            }
+        }
+
+        sc_handle_guard(const sc_handle_guard &)             = delete;
+        sc_handle_guard & operator=(const sc_handle_guard &) = delete;
+    };
+
+    // Declared in this order so the service handle is closed before the SCM
+    // database handle.
+    sc_handle_guard scm(OpenSCManagerW(NULL, NULL, STANDARD_RIGHTS_READ));
+    if (scm.h == nullptr) {
+        GGML_LOG_ERROR("ggml-hex: Failed to open SCManager. Error: %lu\n", GetLastError());
+        return {};
+    }
+
+    sc_handle_guard service(OpenServiceW(scm.h, L"qcnspmcdm", SERVICE_QUERY_CONFIG));
+    if (service.h == nullptr) {
+        GGML_LOG_ERROR("ggml-hex: Failed to open qcnspmcdm service. Error: %lu\n", GetLastError());
+        return {};
+    }
+
+    // Size probe: with a zero-length buffer the call is expected to fail with
+    // ERROR_INSUFFICIENT_BUFFER and report the required size.
+    DWORD bufferSize = 0;
+    if (!QueryServiceConfigW(service.h, NULL, 0, &bufferSize)) {
+        const DWORD probeError = GetLastError();
+        if (probeError != ERROR_INSUFFICIENT_BUFFER) {
+            GGML_LOG_ERROR("ggml-hex: Failed to query service config. Error: %lu\n", probeError);
+            return {};
+        }
+    }
+    if (bufferSize < sizeof(QUERY_SERVICE_CONFIGW)) {
+        GGML_LOG_ERROR("ggml-hex: Service config size %lu is too small\n", bufferSize);
+        return {};
+    }
+
+    // Fetch the configuration into a buffer that is released automatically.
+    std::vector<unsigned char> configBuffer(bufferSize);
+    LPQUERY_SERVICE_CONFIGW    serviceConfig = reinterpret_cast<LPQUERY_SERVICE_CONFIGW>(configBuffer.data());
+    if (!QueryServiceConfigW(service.h, serviceConfig, bufferSize, &bufferSize)) {
+        GGML_LOG_ERROR("ggml-hex: Failed to query service config. Error: %lu\n", GetLastError());
+        return {};
+    }
+
+    if (serviceConfig->lpBinaryPathName == nullptr || serviceConfig->lpBinaryPathName[0] == L'\0') {
+        GGML_LOG_ERROR("ggml-hex: qcnspmcdm service config has no binary path\n");
+        return {};
+    }
+
+    // Read the driver file path and keep only its parent directory.
+    std::wstring driverPath(serviceConfig->lpBinaryPathName);
+    return driverPath.substr(0, driverPath.find_last_of(L"\\"));
+}
+
+// Returns the directory that holds the qcnspmcdm driver binary as a UTF-8
+// string, or an empty string (after logging) when it cannot be determined.
+// May throw std::runtime_error from wstr_to_str() if the final UTF-8
+// conversion fails.
+static std::string get_driver_path() {
+    std::string result;
+
+    std::wstring driverPath = query_driver_dir();
+    if (driverPath.empty()) {
+        return result;  // failure already logged
+    }
+
+    // The configured path is not directly usable; it looks like:
+    // \SystemRoot\System32\DriverStore\FileRepository\qcadsprpc8280.inf_arm64_c2b9460c9a072f37
+    // "\SystemRoot" has to be replaced with the real location (e.g. C:\Windows).
+    const std::wstring systemRootPlaceholder = L"\\SystemRoot";
+    if (0 != driverPath.compare(0, systemRootPlaceholder.length(), systemRootPlaceholder)) {
+        GGML_LOG_ERROR("ggml-hex: String pattern not found in driver path.\n");
+        return result;
+    }
+
+    // Replace \SystemRoot with the absolute path from the windir environment variable.
+    const wchar_t * systemRootEnv = L"windir";
+
+    // Query the number of wide characters this variable requires.
+    const DWORD numWords = GetEnvironmentVariableW(systemRootEnv, NULL, 0);
+    if (numWords == 0) {
+        GGML_LOG_ERROR("ggml-hex: Failed get systemRoot environment variable\n");
+        return result;
+    }
+
+    // Read the value. A result >= the buffer size means the buffer was too
+    // small (the variable changed between the two calls), so its contents must
+    // not be used.
+    std::vector<wchar_t> systemRoot(numWords + 1);
+    const DWORD written = GetEnvironmentVariableW(systemRootEnv, systemRoot.data(), static_cast<DWORD>(systemRoot.size()));
+    if (written == 0 || written >= systemRoot.size()) {
+        GGML_LOG_ERROR("ggml-hex: Failed to read windir environment variable\n");
+        return result;
+    }
+    driverPath.replace(0, systemRootPlaceholder.length(), systemRoot.data(), written);
+
+    return wstr_to_str(driverPath);
+}
+
+#endif
+
+using dl_handle_ptr = std::unique_ptr<dl_handle, dl_handle_deleter>;
+
+int htpdrv_init() {
+ static dl_handle_ptr lib_cdsp_rpc_handle = nullptr;
+ static bool initialized = false;
+#ifdef _WIN32
+ std::string drv_path = get_driver_path() + "\\" + "libcdsprpc.dll";
+#else
+ std::string drv_path = "libcdsprpc.so";
+#endif
+ if (initialized) {
+ GGML_LOG_INFO("ggml-hex: Driver already loaded\n");
+ return AEE_SUCCESS;
+ }
+ GGML_LOG_INFO("ggml-hex: Loading driver %s\n", drv_path.c_str());
+
+ fs::path path{ drv_path.c_str() };
+ dl_handle_ptr handle { dl_load_library(path) };
+ if (!handle) {
+ GGML_LOG_ERROR("ggml-hex: failed to load %s: %s\n", path.u8string().c_str(), dl_error());
+ return AEE_EUNABLETOLOAD;
+ }
+
+#define dlsym(drv, type, pfn, symbol, ignore) \
+ do { \
+ pfn = (type) dl_get_sym(drv, #symbol); \
+ if (!ignore && nullptr == pfn) { \
+ GGML_LOG_ERROR("ggml-hex: failed to dlsym %s\n", #symbol); \
+ return AEE_EUNABLETOLOAD; \
+ } \
+ } while (0)
+
+ dlsym(handle.get(), rpcmem_alloc_pfn_t, rpcmem_alloc_pfn, rpcmem_alloc, false);
+ dlsym(handle.get(), rpcmem_alloc2_pfn_t, rpcmem_alloc2_pfn, rpcmem_alloc2, true);
+ dlsym(handle.get(), rpcmem_free_pfn_t, rpcmem_free_pfn, rpcmem_free, false);
+ dlsym(handle.get(), rpcmem_to_fd_pfn_t, rpcmem_to_fd_pfn, rpcmem_to_fd, false);
+ dlsym(handle.get(), fastrpc_mmap_pfn_t, fastrpc_mmap_pfn, fastrpc_mmap, false);
+ dlsym(handle.get(), fastrpc_munmap_pfn_t, fastrpc_munmap_pfn, fastrpc_munmap, false);
+ dlsym(handle.get(), dspqueue_create_pfn_t, dspqueue_create_pfn, dspqueue_create, false);
+ dlsym(handle.get(), dspqueue_close_pfn_t, dspqueue_close_pfn, dspqueue_close, false);
+ dlsym(handle.get(), dspqueue_export_pfn_t, dspqueue_export_pfn, dspqueue_export, false);
+ dlsym(handle.get(), dspqueue_write_pfn_t, dspqueue_write_pfn, dspqueue_write, false);
+ dlsym(handle.get(), dspqueue_read_pfn_t, dspqueue_read_pfn, dspqueue_read, false);
+ dlsym(handle.get(), remote_handle64_open_pfn_t, remote_handle64_open_pfn, remote_handle64_open, false);
+ dlsym(handle.get(), remote_handle64_invoke_pfn_t, remote_handle64_invoke_pfn, remote_handle64_invoke, false);
+ dlsym(handle.get(), remote_handle_control_pfn_t, remote_handle_control_pfn, remote_handle_control, false);
+ dlsym(handle.get(), remote_handle64_control_pfn_t, remote_handle64_control_pfn, remote_handle64_control, false);
+ dlsym(handle.get(), remote_session_control_pfn_t, remote_session_control_pfn, remote_session_control, false);
+ dlsym(handle.get(), remote_handle64_close_pfn_t, remote_handle64_close_pfn, remote_handle64_close, false);
+
+ lib_cdsp_rpc_handle = std::move(handle);
+ initialized = true;
+
+ return AEE_SUCCESS;
+}
+
+// Looks up domain_id in the supported_domains table and returns a pointer to
+// the matching entry, or a null pointer when no entry has that id.
+domain * get_domain(int domain_id) {
+    for (auto & entry : supported_domains) {
+        if (entry.id == domain_id) {
+            return &entry;
+        }
+    }
+    return nullptr;
+}
+
+// Queries the ARCH_VER capability of `domain` through remote_handle_control()
+// and stores the architecture version in *arch.
+//
+// Returns 0 on success with *arch set to one of 68, 69, 73, 75, 79 or 81.
+// Returns AEE_EUNSUPPORTEDAPI when the control entry point is unavailable or
+// the driver reports the capability API as unsupported, the driver's error
+// code when the query fails otherwise, and -1 (leaving *arch untouched) when
+// the low byte of the reported capability is not a recognised value.
+int get_hex_arch_ver(int domain, int * arch) {
+    if (remote_handle_control_pfn == nullptr) {
+        GGML_LOG_ERROR("ggml-hex: remote_handle_control is not supported on this device\n");
+        return AEE_EUNSUPPORTEDAPI;
+    }
+
+    struct remote_dsp_capability query;
+    query.domain       = (uint32_t) domain;
+    query.attribute_ID = ARCH_VER;
+    query.capability   = (uint32_t) 0;
+
+    const int status = remote_handle_control(DSPRPC_GET_DSP_INFO, &query, sizeof(query));
+
+    // Only the low byte of the status is compared against the error code.
+    const bool api_unsupported = (status & 0xff) == (AEE_EUNSUPPORTEDAPI & 0xff);
+    if (api_unsupported) {
+        GGML_LOG_ERROR("ggml-hex: FastRPC capability API is not supported on this device\n");
+        return AEE_EUNSUPPORTEDAPI;
+    }
+    if (status != AEE_SUCCESS) {
+        GGML_LOG_ERROR("ggml-hex: FastRPC capability query failed (err %d)\n", status);
+        return status;
+    }
+
+    // The low byte reads as the version number written in hex digits
+    // (0x73 -> 73); only the values listed here are accepted.
+    static const struct {
+        uint32_t code;
+        int      version;
+    } known_archs[] = {
+        { 0x68, 68 },
+        { 0x69, 69 },
+        { 0x73, 73 },
+        { 0x75, 75 },
+        { 0x79, 79 },
+        { 0x81, 81 },
+    };
+
+    const uint32_t code = query.capability & 0xff;
+    for (const auto & known : known_archs) {
+        if (known.code == code) {
+            *arch = known.version;
+            return 0;
+        }
+    }
+    return -1;
+}