1// sample drv interface
  2
  3#pragma clang diagnostic ignored "-Wgnu-anonymous-struct"
  4#pragma clang diagnostic ignored "-Wmissing-prototypes"
  5#pragma clang diagnostic ignored "-Wsign-compare"
  6
#include <climits>
#include <filesystem>
#include <memory>
#include <set>
#include <sstream>
#include <stdexcept>
#include <string>
#include <string_view>
#include <type_traits>
#include <vector>
#ifdef _WIN32
#   define WIN32_LEAN_AND_MEAN
#   ifndef NOMINMAX
#       define NOMINMAX
#   endif
#   include <windows.h>
#   include <winevt.h>
#else
#    include <dlfcn.h>
#    include <unistd.h>
#endif
#include "ggml-impl.h"
#include "htp-drv.h"
#include "libdl.h"

#include <domain.h>
 27
 28//
 29// Driver API types
 30//
 31
 32typedef void * (*rpcmem_alloc_pfn_t)(int heapid, uint32_t flags, int size);
 33typedef void * (*rpcmem_alloc2_pfn_t)(int heapid, uint32_t flags, size_t size);
 34typedef void   (*rpcmem_free_pfn_t)(void * po);
 35typedef int    (*rpcmem_to_fd_pfn_t)(void * po);
 36
 37typedef AEEResult (*dspqueue_create_pfn_t)(int                 domain,
 38                                           uint32_t            flags,
 39                                           uint32_t            req_queue_size,
 40                                           uint32_t            resp_queue_size,
 41                                           dspqueue_callback_t packet_callback,
 42                                           dspqueue_callback_t error_callback,
 43                                           void *              callback_context,
 44                                           dspqueue_t *        queue);
 45typedef AEEResult (*dspqueue_close_pfn_t)(dspqueue_t queue);
 46typedef AEEResult (*dspqueue_export_pfn_t)(dspqueue_t queue, uint64_t *queue_id);
 47typedef AEEResult (*dspqueue_write_pfn_t)(dspqueue_t queue, uint32_t flags,
 48                                          uint32_t num_buffers,
 49                                          struct dspqueue_buffer *buffers,
 50                                          uint32_t message_length,
 51                                          const uint8_t *message,
 52                                          uint32_t timeout_us);
 53typedef AEEResult (*dspqueue_read_pfn_t)(dspqueue_t queue, uint32_t *flags,
 54                                         uint32_t max_buffers, uint32_t *num_buffers,
 55                                         struct dspqueue_buffer *buffers,
 56                                         uint32_t max_message_length,
 57                                         uint32_t *message_length, uint8_t *message,
 58                                         uint32_t timeout_us);
 59
 60typedef int (*fastrpc_mmap_pfn_t)(int domain, int fd, void *addr, int offset, size_t length, enum fastrpc_map_flags flags);
 61typedef int (*fastrpc_munmap_pfn_t)(int domain, int fd, void *addr, size_t length);
 62
 63typedef int (*remote_handle64_open_pfn_t)(const char* name, remote_handle64 *ph);
 64typedef int (*remote_handle64_invoke_pfn_t)(remote_handle64 h, uint32_t dwScalars, remote_arg *pra);
 65typedef int (*remote_handle64_close_pfn_t)(remote_handle h);
 66typedef int (*remote_handle_control_pfn_t)(uint32_t req, void* data, uint32_t datalen);
 67typedef int (*remote_handle64_control_pfn_t)(remote_handle64 h, uint32_t req, void* data, uint32_t datalen);
 68typedef int (*remote_session_control_pfn_t)(uint32_t req, void *data, uint32_t datalen);
 69
 70//
 71// Driver API pfns
 72//
 73
 74rpcmem_alloc_pfn_t  rpcmem_alloc_pfn  = nullptr;
 75rpcmem_alloc2_pfn_t rpcmem_alloc2_pfn = nullptr;
 76rpcmem_free_pfn_t   rpcmem_free_pfn   = nullptr;
 77rpcmem_to_fd_pfn_t  rpcmem_to_fd_pfn  = nullptr;
 78
 79fastrpc_mmap_pfn_t   fastrpc_mmap_pfn   = nullptr;
 80fastrpc_munmap_pfn_t fastrpc_munmap_pfn = nullptr;
 81
 82dspqueue_create_pfn_t dspqueue_create_pfn = nullptr;
 83dspqueue_close_pfn_t  dspqueue_close_pfn  = nullptr;
 84dspqueue_export_pfn_t dspqueue_export_pfn = nullptr;
 85dspqueue_write_pfn_t  dspqueue_write_pfn  = nullptr;
 86dspqueue_read_pfn_t   dspqueue_read_pfn   = nullptr;
 87
 88remote_handle64_open_pfn_t    remote_handle64_open_pfn    = nullptr;
 89remote_handle64_invoke_pfn_t  remote_handle64_invoke_pfn  = nullptr;
 90remote_handle64_close_pfn_t   remote_handle64_close_pfn   = nullptr;
 91remote_handle_control_pfn_t   remote_handle_control_pfn   = nullptr;
 92remote_handle64_control_pfn_t remote_handle64_control_pfn = nullptr;
 93remote_session_control_pfn_t  remote_session_control_pfn  = nullptr;
 94
 95//
 96// Driver API
 97//
 98
 99void * rpcmem_alloc(int heapid, uint32_t flags, int size) {
100    return rpcmem_alloc_pfn(heapid, flags, size);
101}
102
103void * rpcmem_alloc2(int heapid, uint32_t flags, size_t size) {
104    if (rpcmem_alloc2_pfn) {
105        return rpcmem_alloc2_pfn(heapid, flags, size);
106    } else {
107        GGML_LOG_INFO("ggml-hex: rpcmem_alloc2 not found, falling back to rpcmem_alloc\n");
108        return rpcmem_alloc_pfn(heapid, flags, size);
109    }
110}
111
112void rpcmem_free(void * po) {
113    return rpcmem_free_pfn(po);
114}
115
116int rpcmem_to_fd(void * po) {
117    return rpcmem_to_fd_pfn(po);
118}
119
120HTPDRV_API int fastrpc_mmap(int domain, int fd, void * addr, int offset, size_t length, enum fastrpc_map_flags flags) {
121    return fastrpc_mmap_pfn(domain, fd, addr, offset, length, flags);
122}
123
124HTPDRV_API int fastrpc_munmap(int domain, int fd, void * addr, size_t length) {
125    return fastrpc_munmap_pfn(domain, fd, addr, length);
126}
127
128AEEResult dspqueue_create(int                 domain,
129                          uint32_t            flags,
130                          uint32_t            req_queue_size,
131                          uint32_t            resp_queue_size,
132                          dspqueue_callback_t packet_callback,
133                          dspqueue_callback_t error_callback,
134                          void *              callback_context,
135                          dspqueue_t *        queue) {
136    return dspqueue_create_pfn(domain, flags, req_queue_size, resp_queue_size, packet_callback, error_callback,
137                               callback_context, queue);
138}
139
140AEEResult dspqueue_close(dspqueue_t queue) {
141    return dspqueue_close_pfn(queue);
142}
143
144AEEResult dspqueue_export(dspqueue_t queue, uint64_t * queue_id) {
145    return dspqueue_export_pfn(queue, queue_id);
146}
147
148AEEResult dspqueue_write(dspqueue_t               queue,
149                         uint32_t                 flags,
150                         uint32_t                 num_buffers,
151                         struct dspqueue_buffer * buffers,
152                         uint32_t                 message_length,
153                         const uint8_t *          message,
154                         uint32_t                 timeout_us) {
155    return dspqueue_write_pfn(queue, flags, num_buffers, buffers, message_length, message, timeout_us);
156}
157
158AEEResult dspqueue_read(dspqueue_t               queue,
159                        uint32_t *               flags,
160                        uint32_t                 max_buffers,
161                        uint32_t *               num_buffers,
162                        struct dspqueue_buffer * buffers,
163                        uint32_t                 max_message_length,
164                        uint32_t *               message_length,
165                        uint8_t *                message,
166                        uint32_t                 timeout_us) {
167    return dspqueue_read_pfn(queue, flags, max_buffers, num_buffers, buffers, max_message_length, message_length,
168                             message, timeout_us);
169}
170
171HTPDRV_API int remote_handle64_open(const char * name, remote_handle64 * ph) {
172    return remote_handle64_open_pfn(name, ph);
173}
174
175HTPDRV_API int remote_handle64_invoke(remote_handle64 h, uint32_t dwScalars, remote_arg * pra) {
176    return remote_handle64_invoke_pfn(h, dwScalars, pra);
177}
178
179HTPDRV_API int remote_handle64_close(remote_handle64 h) {
180    return remote_handle64_close_pfn(h);
181}
182
183HTPDRV_API int remote_handle_control(uint32_t req, void * data, uint32_t datalen) {
184    return remote_handle_control_pfn(req, data, datalen);
185}
186
187HTPDRV_API int remote_handle64_control(remote_handle64 h, uint32_t req, void * data, uint32_t datalen) {
188    return remote_handle64_control_pfn(h, req, data, datalen);
189}
190
191HTPDRV_API int remote_session_control(uint32_t req, void * data, uint32_t datalen) {
192    return remote_session_control_pfn(req, data, datalen);
193}
194
195#ifdef _WIN32
196
197static std::string wstr_to_str(std::wstring_view wstr) {
198    std::string result;
199    if (wstr.empty()) {
200        return result;
201    }
202    auto bytes_needed = WideCharToMultiByte(CP_UTF8, WC_ERR_INVALID_CHARS,
203                                            wstr.data(), (int) wstr.size(),
204                                            nullptr, 0, nullptr, nullptr);
205    if (bytes_needed == 0) {
206        GGML_LOG_ERROR("ggml-hex: WideCharToMultiByte failed. Error %lu\n", GetLastError());
207        throw std::runtime_error("Invalid wstring input");
208    }
209
210    result.resize(bytes_needed, '\0');
211    int bytes_written = WideCharToMultiByte(CP_UTF8, WC_ERR_INVALID_CHARS,
212                                            wstr.data(), (int) wstr.size(),
213                                            result.data(), bytes_needed,
214                                            nullptr, nullptr);
215    if (bytes_written == 0) {
216        GGML_LOG_ERROR("ggml-hex: WideCharToMultiByte failed. Error %lu\n", GetLastError());
217        throw std::runtime_error("Wstring conversion failed");
218    }
219    return result;
220}
221
222static std::string get_driver_path() {
223    std::wstring serviceName = L"qcnspmcdm";
224    std::string result;
225
226    // Get a handle to the SCM database.
227    SC_HANDLE schSCManager = OpenSCManagerW(NULL, NULL, STANDARD_RIGHTS_READ);
228    if (nullptr == schSCManager) {
229        GGML_LOG_ERROR("ggml-hex: Failed to open SCManager. Error: %lu\n", GetLastError());
230        return result;
231    }
232
233    // Get a handle to the service.
234    SC_HANDLE schService = OpenServiceW(schSCManager,           // SCM database
235                                        serviceName.c_str(),    // name of service
236                                        SERVICE_QUERY_CONFIG);  // need query config access
237
238    if (nullptr == schService) {
239        GGML_LOG_ERROR("ggml-hex: Failed to open qcnspmcdm service. Error: %lu\n", GetLastError());
240        CloseServiceHandle(schSCManager);
241        return result;
242    }
243
244    // Store the size of buffer used as an output.
245    DWORD bufferSize;
246    if (!QueryServiceConfigW(schService, NULL, 0, &bufferSize) &&
247        (GetLastError() != ERROR_INSUFFICIENT_BUFFER)) {
248        GGML_LOG_ERROR("ggml-hex: Failed to query service config. Error: %lu\n", GetLastError());
249        CloseServiceHandle(schService);
250        CloseServiceHandle(schSCManager);
251        return result;
252    }
253    // Get the configuration of the service.
254    LPQUERY_SERVICE_CONFIGW serviceConfig =
255        static_cast<LPQUERY_SERVICE_CONFIGW>(LocalAlloc(LMEM_FIXED, bufferSize));
256    if (!QueryServiceConfigW(schService, serviceConfig, bufferSize, &bufferSize)) {
257        fprintf(stderr, "ggml-hex: Failed to query service config. Error: %lu\n", GetLastError());
258        LocalFree(serviceConfig);
259        CloseServiceHandle(schService);
260        CloseServiceHandle(schSCManager);
261        return result;
262    }
263
264    // Read the driver file path get its parent directory
265    std::wstring driverPath = std::wstring(serviceConfig->lpBinaryPathName);
266    driverPath = driverPath.substr(0, driverPath.find_last_of(L"\\"));
267
268    // Clean up resources
269    LocalFree(serviceConfig);
270    CloseServiceHandle(schService);
271    CloseServiceHandle(schSCManager);
272
273    // Driver path would contain invalid path string, like:
274    // \SystemRoot\System32\DriverStore\FileRepository\qcadsprpc8280.inf_arm64_c2b9460c9a072f37
275    // "\SystemRoot" should be replace with a correct one (e.g. C:\Windows)
276    const std::wstring systemRootPlaceholder = L"\\SystemRoot";
277    if (0 != driverPath.compare(0, systemRootPlaceholder.length(), systemRootPlaceholder)) {
278        GGML_LOG_ERROR("ggml-hex: String pattern not found in driver path.\n");
279        return result;
280    }
281
282    // Replace \SystemRoot with an absolute path from system ENV windir
283    const std::wstring systemRootEnv = L"windir";
284
285    // Query the number of wide charactors this variable requires
286    DWORD numWords = GetEnvironmentVariableW(systemRootEnv.c_str(), NULL, 0);
287    if (numWords == 0) {
288        GGML_LOG_ERROR("ggml-hex: Failed get systemRoot environment variable\n");
289        return result;
290    }
291
292    // Query the actual system root name from environment variable
293    std::vector<wchar_t> systemRoot(numWords + 1);
294    numWords = GetEnvironmentVariableW(systemRootEnv.c_str(), systemRoot.data(), numWords + 1);
295    if (numWords == 0) {
296        GGML_LOG_ERROR("ggml-hex: Failed to read windir environment variable\n");
297        return result;
298    }
299    driverPath.replace(0, systemRootPlaceholder.length(), std::wstring(systemRoot.data()));
300
301    return wstr_to_str(driverPath);
302}
303
304#endif
305
306using dl_handle_ptr = std::unique_ptr<dl_handle, dl_handle_deleter>;
307
308int htpdrv_init() {
309    static dl_handle_ptr lib_cdsp_rpc_handle = nullptr;
310    static bool initialized = false;
311#ifdef _WIN32
312    std::string drv_path = get_driver_path() + "\\" + "libcdsprpc.dll";
313#else
314    std::string drv_path = "libcdsprpc.so";
315#endif
316    if (initialized) {
317        GGML_LOG_INFO("ggml-hex: Driver already loaded\n");
318        return AEE_SUCCESS;
319    }
320    GGML_LOG_INFO("ggml-hex: Loading driver %s\n", drv_path.c_str());
321
322    fs::path path{ drv_path.c_str() };
323    dl_handle_ptr handle { dl_load_library(path) };
324    if (!handle) {
325        GGML_LOG_ERROR("ggml-hex: failed to load %s: %s\n", path.u8string().c_str(), dl_error());
326        return AEE_EUNABLETOLOAD;
327    }
328
329#define dlsym(drv, type, pfn, symbol, ignore)                               \
330    do {                                                                    \
331        pfn = (type) dl_get_sym(drv, #symbol);                              \
332        if (!ignore && nullptr == pfn) {                                    \
333            GGML_LOG_ERROR("ggml-hex: failed to dlsym %s\n", #symbol);      \
334            return AEE_EUNABLETOLOAD;                                       \
335        }                                                                   \
336    } while (0)
337
338    dlsym(handle.get(), rpcmem_alloc_pfn_t, rpcmem_alloc_pfn, rpcmem_alloc, false);
339    dlsym(handle.get(), rpcmem_alloc2_pfn_t, rpcmem_alloc2_pfn, rpcmem_alloc2, true);
340    dlsym(handle.get(), rpcmem_free_pfn_t, rpcmem_free_pfn, rpcmem_free, false);
341    dlsym(handle.get(), rpcmem_to_fd_pfn_t, rpcmem_to_fd_pfn, rpcmem_to_fd, false);
342    dlsym(handle.get(), fastrpc_mmap_pfn_t, fastrpc_mmap_pfn, fastrpc_mmap, false);
343    dlsym(handle.get(), fastrpc_munmap_pfn_t, fastrpc_munmap_pfn, fastrpc_munmap, false);
344    dlsym(handle.get(), dspqueue_create_pfn_t, dspqueue_create_pfn, dspqueue_create, false);
345    dlsym(handle.get(), dspqueue_close_pfn_t, dspqueue_close_pfn, dspqueue_close, false);
346    dlsym(handle.get(), dspqueue_export_pfn_t, dspqueue_export_pfn, dspqueue_export, false);
347    dlsym(handle.get(), dspqueue_write_pfn_t, dspqueue_write_pfn, dspqueue_write, false);
348    dlsym(handle.get(), dspqueue_read_pfn_t, dspqueue_read_pfn, dspqueue_read, false);
349    dlsym(handle.get(), remote_handle64_open_pfn_t, remote_handle64_open_pfn, remote_handle64_open, false);
350    dlsym(handle.get(), remote_handle64_invoke_pfn_t, remote_handle64_invoke_pfn, remote_handle64_invoke, false);
351    dlsym(handle.get(), remote_handle_control_pfn_t, remote_handle_control_pfn, remote_handle_control, false);
352    dlsym(handle.get(), remote_handle64_control_pfn_t, remote_handle64_control_pfn, remote_handle64_control, false);
353    dlsym(handle.get(), remote_session_control_pfn_t, remote_session_control_pfn, remote_session_control, false);
354    dlsym(handle.get(), remote_handle64_close_pfn_t, remote_handle64_close_pfn, remote_handle64_close, false);
355
356    lib_cdsp_rpc_handle = std::move(handle);
357    initialized         = true;
358
359    return AEE_SUCCESS;
360}
361
362domain * get_domain(int domain_id) {
363    int i    = 0;
364    int size = sizeof(supported_domains) / sizeof(domain);
365
366    for (i = 0; i < size; i++) {
367        if (supported_domains[i].id == domain_id) {
368            return &supported_domains[i];
369        }
370    }
371
372    return NULL;
373}
374
375int get_hex_arch_ver(int domain, int * arch) {
376    if (!remote_handle_control_pfn) {
377        GGML_LOG_ERROR("ggml-hex: remote_handle_control is not supported on this device\n");
378        return AEE_EUNSUPPORTEDAPI;
379    }
380
381    struct remote_dsp_capability arch_ver;
382    arch_ver.domain       = (uint32_t) domain;
383    arch_ver.attribute_ID = ARCH_VER;
384    arch_ver.capability   = (uint32_t) 0;
385
386    int err = remote_handle_control(DSPRPC_GET_DSP_INFO, &arch_ver, sizeof(arch_ver));
387    if ((err & 0xff) == (AEE_EUNSUPPORTEDAPI & 0xff)) {
388        GGML_LOG_ERROR("ggml-hex: FastRPC capability API is not supported on this device\n");
389        return AEE_EUNSUPPORTEDAPI;
390    }
391
392    if (err != AEE_SUCCESS) {
393        GGML_LOG_ERROR("ggml-hex: FastRPC capability query failed (err %d)\n", err);
394        return err;
395    }
396
397    switch (arch_ver.capability & 0xff) {
398        case 0x68:
399            *arch = 68;
400            return 0;
401        case 0x69:
402            *arch = 69;
403            return 0;
404        case 0x73:
405            *arch = 73;
406            return 0;
407        case 0x75:
408            *arch = 75;
409            return 0;
410        case 0x79:
411            *arch = 79;
412            return 0;
413        case 0x81:
414            *arch = 81;
415            return 0;
416    }
417    return -1;
418}