1#include <assert.h>
2#include <inttypes.h>
3#include <stdio.h>
4#include <stdlib.h>
5#include <string.h>
6#include <time.h>
7
8#include <atomic>
9#include <chrono>
10#include <cstddef>
11#include <mutex>
12#include <stdexcept>
13#include <string>
14
15#ifdef _WIN32
16# include <sal.h>
17#else
18# include <semaphore.h>
19# include <unistd.h>
20#endif
21
22#pragma clang diagnostic ignored "-Wnested-anon-types"
23#pragma clang diagnostic ignored "-Wgnu-anonymous-struct"
24
25#include <AEEStdErr.h>
26#include <dspqueue.h>
27#include <rpcmem.h>
28
29#define GGML_COMMON_IMPL_CPP
30#include "ggml-backend-impl.h"
31#include "ggml-common.h"
32#include "ggml-hexagon.h"
33#include "ggml-impl.h"
34#include "ggml-quants.h"
35#include "op-desc.h"
36#include "htp-msg.h"
37#include "htp_iface.h"
38#include "htp-drv.h"
39
40static size_t opt_ndev = 1;
41static size_t opt_nhvx = 0; // use all
42static int opt_arch = 0; // autodetect
43static int opt_etm = 0;
44static int opt_verbose = 0;
45static int opt_profile = 0;
46static int opt_hostbuf = 1; // hostbuf ON by default
47static int opt_experimental = 0;
48
49// Enable all stages by default
50static int opt_opmask = HTP_OPMASK_QUEUE | HTP_OPMASK_QUANTIZE | HTP_OPMASK_COMPUTE;
51static int opt_opsync = 0; // synchronous ops
52
53#define HEX_VERBOSE(...) \
54 if (opt_verbose) GGML_LOG_DEBUG(__VA_ARGS__)
55
56static inline uint64_t hex_is_aligned(void * addr, uint32_t align) {
57 return ((size_t) addr & (align - 1)) == 0;
58}
59
60static inline size_t hex_round_up(size_t n, size_t m) {
61 return m * ((n + m - 1) / m);
62}
63
64static const char * status_to_str(uint32_t status) {
65 switch (status) {
66 case HTP_STATUS_OK:
67 return "OK";
68 case HTP_STATUS_NO_SUPPORT:
69 return "NO-SUPPORT";
70 case HTP_STATUS_INVAL_PARAMS:
71 return "INVAL-PARAMS";
72 case HTP_STATUS_VTCM_TOO_SMALL:
73 return "VTCM-TOO-SMALL";
74 case HTP_STATUS_INTERNAL_ERR:
75 return "INTERNAL-ERROR";
76 default:
77 return "UNKNOWN";
78 }
79}
80
81// ** debug helpers
82
83static void ggml_hexagon_dump_op_exec(const std::string &sess_name, const ggml_tensor * op, const uint32_t req_flags) {
84 if (!opt_verbose) return;
85
86 op_desc desc(op);
87 GGML_LOG_DEBUG("ggml-hex: %s execute-op %s: %s : %s : %s : %s : %s : flags 0x%x\n", sess_name.c_str(),
88 ggml_op_name(op->op), desc.names, desc.dims, desc.types, desc.strides, desc.buffs, req_flags);
89}
90
91static void ggml_hexagon_dump_op_supp(const std::string &sess_name, const struct ggml_tensor * op, bool supp) {
92 if (!opt_verbose) return;
93
94 op_desc desc(op);
95 GGML_LOG_DEBUG("ggml-hex: %s supports-op %s : %s : %s : %s : %s : %s : %s\n", sess_name.c_str(),
96 ggml_op_name(op->op), desc.names, desc.dims, desc.types, desc.strides, desc.buffs, supp ? "yes" : "no");
97}
98
99static void ggml_hexagon_dump_op_prof(const std::string &sess_name, const ggml_tensor * op,
100 uint32_t op_usec, uint32_t op_cycles, uint32_t op_pkts, uint64_t call_usec) {
101 if (!opt_profile) return;
102
103 op_desc desc(op);
104 GGML_LOG_DEBUG("ggml-hex: %s profile-op %s: %s : %s : %s : %s : %s : op-usec %u op-cycles %u op-pkts %u (%f) call-usec %llu\n", sess_name.c_str(),
105 ggml_op_name(op->op), desc.names, desc.dims, desc.types, desc.strides, desc.buffs,
106 op_usec, op_cycles, op_pkts, (float) op_cycles / op_pkts, (unsigned long long) call_usec);
107}
108
109// ** backend sessions
110
111struct ggml_hexagon_session {
112 ggml_hexagon_session(int dev_id, ggml_backend_dev_t dev) noexcept(false);
113 ~ggml_hexagon_session() noexcept(true);
114
115 void allocate(int dev_id) noexcept(false);
116 void release() noexcept(true);
117
118 void enqueue(struct htp_general_req &req, struct dspqueue_buffer *bufs, uint32_t n_bufs, bool sync = false);
119 void flush();
120
121 ggml_backend_buffer_type buffer_type = {};
122 ggml_backend_buffer_type repack_buffer_type = {};
123
124 std::string name;
125 remote_handle64 handle;
126 dspqueue_t queue;
127 uint32_t session_id;
128 uint32_t domain_id;
129 uint64_t queue_id;
130 int dev_id;
131 bool valid_session;
132 bool valid_handle;
133 bool valid_queue;
134 bool valid_iface;
135 std::atomic<int> op_pending;
136 uint32_t prof_usecs;
137 uint32_t prof_cycles;
138 uint32_t prof_pkts;
139};
140
141void ggml_hexagon_session::enqueue(struct htp_general_req &req, struct dspqueue_buffer *bufs, uint32_t n_bufs, bool sync) {
142 // Bump pending flag (cleared in the session::flush once we get the responce)
143 this->op_pending++; // atomic inc
144
145 int err = dspqueue_write(this->queue,
146 0, // flags - the framework will autoset this
147 n_bufs, // number of buffers
148 bufs, // buffer references
149 sizeof(req), // Message length
150 (const uint8_t *) &req, // Message
151 DSPQUEUE_TIMEOUT // Timeout
152 );
153
154 if (err != 0) {
155 GGML_ABORT("ggml-hex: %s dspqueue_write failed: 0x%08x\n", this->name.c_str(), (unsigned) err);
156 }
157
158 if (sync) {
159 flush();
160 }
161}
162
163// Flush HTP response queue i.e wait for all outstanding requests to complete
164void ggml_hexagon_session::flush() {
165 dspqueue_t q = this->queue;
166
167 // Repeatedly read packets from the queue until it's empty. We don't
168 // necessarily get a separate callback for each packet, and new packets
169 // may arrive while we're processing the previous one.
170
171 while (this->op_pending) {
172 struct htp_general_rsp rsp;
173 uint32_t rsp_size;
174 uint32_t flags;
175
176 struct dspqueue_buffer bufs[HTP_MAX_PACKET_BUFFERS];
177 uint32_t n_bufs;
178
179 // Read response packet from queue
180 int err = dspqueue_read(q, &flags,
181 HTP_MAX_PACKET_BUFFERS, // Maximum number of buffer references
182 &n_bufs, // Number of buffer references
183 bufs, // Buffer references
184 sizeof(rsp), // Max message length
185 &rsp_size, // Message length
186 (uint8_t *) &rsp, // Message
187 DSPQUEUE_TIMEOUT); // Timeout
188
189 if (err == AEE_EEXPIRED) {
190 // TODO: might need to bail out if the HTP is stuck on something
191 continue;
192 }
193
194 if (err != 0) {
195 GGML_ABORT("ggml-hex: dspqueue_read failed: 0x%08x\n", (unsigned) err);
196 }
197
198 // Basic sanity checks
199 if (rsp_size != sizeof(rsp)) {
200 GGML_ABORT("ggml-hex: dspcall : bad response (size)\n");
201 }
202
203 if (rsp.status != HTP_STATUS_OK) {
204 GGML_LOG_ERROR("ggml-hex: dspcall : dsp-rsp: %s\n", status_to_str(rsp.status));
205 // TODO: handle errors
206 }
207
208 // TODO: update profiling implementation, currently only works for opt_opsync mode
209 this->prof_usecs = rsp.prof_usecs;
210 this->prof_cycles = rsp.prof_cycles;
211 this->prof_pkts = rsp.prof_pkts;
212
213 this->op_pending--; // atomic dec
214 }
215}
216
217// ** backend buffers
218
219struct ggml_backend_hexagon_buffer_type_context {
220 ggml_backend_hexagon_buffer_type_context(const std::string & name, ggml_hexagon_session * sess) {
221 this->sess = sess;
222 this->name = name;
223 }
224
225 ggml_hexagon_session * sess;
226 std::string name;
227};
228
229struct ggml_backend_hexagon_buffer_context {
230 bool mmap_to(ggml_hexagon_session * s) {
231 HEX_VERBOSE("ggml-hex: %s mmaping buffer: base %p domain-id %d session-id %d size %zu fd %d repack %d\n",
232 s->name.c_str(), (void *) this->base, s->domain_id, s->session_id, this->size, this->fd,
233 (int) this->repack);
234
235 int err = fastrpc_mmap(s->domain_id, this->fd, (void *) this->base, 0, this->size, FASTRPC_MAP_FD);
236 if (err != 0) {
237 GGML_LOG_ERROR("ggml-hex: buffer mapping failed : domain_id %d size %zu fd %d error 0x%08x\n",
238 s->domain_id, this->size, this->fd, (unsigned) err);
239 return false;
240 }
241
242 return true;
243 }
244
245 bool mmap() {
246 if (this->mapped) {
247 return true;
248 }
249 if (!mmap_to(this->sess)) {
250 return false;
251 }
252 this->mapped = true;
253 return true;
254 }
255
256 void munmap() {
257 if (!this->mapped) {
258 return;
259 }
260
261 fastrpc_munmap(this->sess->domain_id, this->fd, this->base, this->size);
262 this->mapped = false;
263 }
264
265 ggml_backend_hexagon_buffer_context(ggml_hexagon_session * sess, size_t size, bool repack) {
266 size += 4 * 1024; // extra page for padding
267
268 this->base = (uint8_t *) rpcmem_alloc2(RPCMEM_HEAP_ID_SYSTEM, RPCMEM_DEFAULT_FLAGS | RPCMEM_HEAP_NOREG, size);
269 if (!this->base) {
270 GGML_LOG_ERROR("ggml-hex: %s failed to allocate buffer : size %zu\n", sess->name.c_str(), size);
271 throw std::runtime_error("ggml-hex: rpcmem_alloc failed (see log for details)");
272 }
273
274 this->fd = rpcmem_to_fd(this->base);
275 if (this->fd < 0) {
276 GGML_LOG_ERROR("ggml-hex: %s failed to get FD for buffer %p\n", sess->name.c_str(), (void *) this->base);
277 rpcmem_free(this->base);
278 this->base = NULL;
279 throw std::runtime_error("ggml-hex: rpcmem_to_fd failed (see log for details)");
280 }
281
282 HEX_VERBOSE("ggml-hex: %s allocated buffer: base %p size %zu fd %d repack %d\n", sess->name.c_str(),
283 (void *) this->base, size, this->fd, (int) repack);
284
285 this->sess = sess;
286 this->size = size;
287 this->mapped = false;
288 this->repack = repack;
289 }
290
291 ~ggml_backend_hexagon_buffer_context() {
292 munmap();
293 if (this->base) {
294 rpcmem_free(this->base);
295 this->base = NULL;
296 }
297 }
298
299 ggml_hexagon_session * sess; // primary session
300 uint8_t * base;
301 size_t size;
302 int fd;
303 bool mapped; // mmap is done
304 bool repack; // repacked buffer
305};
306
307static ggml_hexagon_session * ggml_backend_hexagon_buffer_get_sess(ggml_backend_buffer_t buffer) {
308 return static_cast<ggml_backend_hexagon_buffer_type_context *>(buffer->buft->context)->sess;
309}
310
311static void ggml_backend_hexagon_buffer_free_buffer(ggml_backend_buffer_t buffer) {
312 auto ctx = static_cast<ggml_backend_hexagon_buffer_context *>(buffer->context);
313 delete ctx;
314}
315
316static void * ggml_backend_hexagon_buffer_get_base(ggml_backend_buffer_t buffer) {
317 auto ctx = static_cast<ggml_backend_hexagon_buffer_context *>(buffer->context);
318 return ctx->base;
319}
320
321static enum ggml_status ggml_backend_hexagon_buffer_init_tensor(ggml_backend_buffer_t buffer, ggml_tensor * tensor) {
322 auto ctx = static_cast<ggml_backend_hexagon_buffer_context *>(buffer->context);
323 auto sess = ctx->sess;
324
325 HEX_VERBOSE("ggml-hex: %s init-tensor %s : base %p data %p nbytes %zu usage %d repack %d\n", sess->name.c_str(),
326 tensor->name, (void *) ctx->base, tensor->data, ggml_nbytes(tensor), (int) buffer->usage,
327 (int) ctx->repack);
328
329 if (tensor->view_src != NULL && tensor->view_offs == 0) {
330 ; // nothing to do for the view
331 } else {
332 if (!ctx->mapped) {
333 ctx->mmap();
334 }
335 }
336 return GGML_STATUS_SUCCESS;
337}
338
339// ======== Q4x4x2 ====================
340struct x2_q4 {
341 int v[2];
342};
343
344static x2_q4 unpack_q4(uint8_t v) {
345 x2_q4 x = { (int) (v & 0x0f) - 8, (int) (v >> 4) - 8 };
346 return x;
347}
348
349static void dump_block_q4_0(const block_q4_0 * b, int i) {
350 HEX_VERBOSE("ggml-hex: repack q4_0 %d: %d %d %d %d ... %d %d %d %d : %.6f\n", i, unpack_q4(b->qs[0]).v[0],
351 unpack_q4(b->qs[1]).v[0], unpack_q4(b->qs[2]).v[0], unpack_q4(b->qs[3]).v[0], unpack_q4(b->qs[12]).v[1],
352 unpack_q4(b->qs[13]).v[1], unpack_q4(b->qs[14]).v[1], unpack_q4(b->qs[15]).v[1],
353 GGML_FP16_TO_FP32(b->d));
354}
355
356static void dump_packed_block_q4x4x2(const uint8_t * v, unsigned int i, size_t k) {
357 static const int qk = QK_Q4_0x4x2;
358 const int dblk_size = 8 * 2; // 8x __fp16
359 const int qblk_size = qk / 2; // int4
360 const int qrow_size = k / 2; // int4 (not padded)
361
362 const uint8_t * v_q = v + 0; // quants first
363 const uint8_t * v_d = v + qrow_size; // then scales
364
365 const uint8_t * q = v_q + i * qblk_size;
366 const ggml_half * d = (const ggml_half *) (v_d + i * dblk_size);
367
368 HEX_VERBOSE("ggml-hex: repack q4x4x2-%d: %d %d %d %d ... %d %d %d %d ... %d %d %d %d : %.6f %.6f %.6f %.6f\n", i,
369 unpack_q4(q[0]).v[0], unpack_q4(q[1]).v[0], unpack_q4(q[2]).v[0], unpack_q4(q[3]).v[0],
370 unpack_q4(q[60]).v[0], unpack_q4(q[61]).v[0], unpack_q4(q[62]).v[0], unpack_q4(q[63]).v[0],
371 unpack_q4(q[124]).v[0], unpack_q4(q[125]).v[0], unpack_q4(q[126]).v[0], unpack_q4(q[127]).v[0],
372 GGML_FP16_TO_FP32(d[0]), GGML_FP16_TO_FP32(d[1]), GGML_FP16_TO_FP32(d[2]), GGML_FP16_TO_FP32(d[3]));
373
374 HEX_VERBOSE("ggml-hex: repack q4x4x2-%d: %d %d %d %d ... %d %d %d %d ... %d %d %d %d : %.6f %.6f %.6f %.6f\n",
375 i + 1, unpack_q4(q[0]).v[1], unpack_q4(q[1]).v[1], unpack_q4(q[2]).v[1], unpack_q4(q[3]).v[1],
376 unpack_q4(q[60]).v[1], unpack_q4(q[61]).v[1], unpack_q4(q[62]).v[1], unpack_q4(q[63]).v[1],
377 unpack_q4(q[124]).v[1], unpack_q4(q[125]).v[1], unpack_q4(q[126]).v[1], unpack_q4(q[127]).v[1],
378 GGML_FP16_TO_FP32(d[4]), GGML_FP16_TO_FP32(d[5]), GGML_FP16_TO_FP32(d[6]), GGML_FP16_TO_FP32(d[7]));
379}
380
381static void unpack_q4_0_quants(uint8_t * qs, const block_q4_0 * x, unsigned int bi) {
382 static const int qk = QK4_0;
383
384 for (unsigned int i = 0; i < qk / 2; ++i) {
385 const int x0 = (x->qs[i] & 0x0F);
386 const int x1 = (x->qs[i] >> 4);
387 qs[bi * qk + i + 0] = x0;
388 qs[bi * qk + i + qk / 2] = x1;
389 }
390}
391
392static void pack_q4_0_quants(block_q4_0 * x, const uint8_t * qs, unsigned int bi) {
393 static const int qk = QK4_0;
394
395 for (unsigned int i = 0; i < qk / 2; ++i) {
396 const uint8_t x0 = qs[bi * qk + i + 0];
397 const uint8_t x1 = qs[bi * qk + i + qk / 2];
398 x->qs[i] = x0 | (x1 << 4);
399 }
400}
401
402static void repack_row_q4x4x2(uint8_t * y, const block_q4_0 * x, int64_t k) {
403 static const int qk = QK_Q4_0x4x2;
404 const int nb = (k + qk - 1) / qk; // number of blocks (padded)
405
406 const int dblk_size = 8 * 2; // 8x __fp16
407 const int qblk_size = qk / 2; // int4
408 const int qrow_size = k / 2; // int4 (not padded to blocks)
409
410 uint8_t * y_q = y + 0; // quants first
411 uint8_t * y_d = y + qrow_size; // then scales
412
413 if (opt_verbose > 2) {
414 for (int i = 0; i < nb; i++) {
415 dump_block_q4_0(&x[i * 8 + 0], 0);
416 dump_block_q4_0(&x[i * 8 + 1], 1);
417 dump_block_q4_0(&x[i * 8 + 2], 2);
418 dump_block_q4_0(&x[i * 8 + 3], 3);
419 dump_block_q4_0(&x[i * 8 + 4], 4);
420 dump_block_q4_0(&x[i * 8 + 5], 5);
421 dump_block_q4_0(&x[i * 8 + 6], 6);
422 dump_block_q4_0(&x[i * 8 + 7], 7);
423 }
424 }
425
426 // Repack the quants
427 for (int i = 0; i < nb; i++) {
428 uint8_t qs[QK_Q4_0x4x2]; // unpacked quants
429 unpack_q4_0_quants(qs, &x[i * 8 + 0], 0);
430 unpack_q4_0_quants(qs, &x[i * 8 + 1], 1);
431 unpack_q4_0_quants(qs, &x[i * 8 + 2], 2);
432 unpack_q4_0_quants(qs, &x[i * 8 + 3], 3);
433 unpack_q4_0_quants(qs, &x[i * 8 + 4], 4);
434 unpack_q4_0_quants(qs, &x[i * 8 + 5], 5);
435 unpack_q4_0_quants(qs, &x[i * 8 + 6], 6);
436 unpack_q4_0_quants(qs, &x[i * 8 + 7], 7);
437
438 uint8_t * q = y_q + (i * qblk_size);
439 for (int j = 0; j < qk / 2; j++) {
440 q[j] = (qs[j + 128] << 4) | qs[j];
441 }
442 }
443
444 // Repack the scales
445 // Note: Do not combine with the loop above. For tensor sizes not multiple of 256 (QK_Q4_0x4x2)
446 // the last block is truncated and overriden by the scales.
447 for (int i = 0; i < nb; i++) {
448 // Repack the scales
449 ggml_half * d = (ggml_half *) (y_d + i * dblk_size);
450 d[0] = x[i * 8 + 0].d;
451 d[1] = x[i * 8 + 1].d;
452 d[2] = x[i * 8 + 2].d;
453 d[3] = x[i * 8 + 3].d;
454 d[4] = x[i * 8 + 4].d;
455 d[5] = x[i * 8 + 5].d;
456 d[6] = x[i * 8 + 6].d;
457 d[7] = x[i * 8 + 7].d;
458 }
459
460 if (opt_verbose > 1) {
461 for (int i = 0; i < nb; i++) {
462 dump_packed_block_q4x4x2(y, i, k);
463 }
464 }
465}
466
467static void unpack_row_q4x4x2(block_q4_0 * x, const uint8_t * y, int64_t k) {
468 static const int qk = QK_Q4_0x4x2;
469 const int nb = (k + qk - 1) / qk; // number of blocks (padded)
470
471 const int dblk_size = 8 * 2; // 8x __fp16
472 const int qblk_size = qk / 2; // int4
473 const int qrow_size = k / 2; // int4 (not padded to blocks)
474
475 const uint8_t * y_q = y + 0; // quants first
476 const uint8_t * y_d = y + qrow_size; // then scales
477
478 if (opt_verbose > 1) {
479 for (int i = 0; i < nb; i++) {
480 dump_packed_block_q4x4x2(y, i, k);
481 }
482 }
483
484 // Unpack the quants
485 for (int i = 0; i < nb; i++) {
486 uint8_t qs[QK_Q4_0x4x2]; // unpacked quants
487
488 const uint8_t * q = y_q + (i * qblk_size);
489 for (int j = 0; j < qk / 2; j++) {
490 qs[j] = q[j] & 0xf;
491 qs[j + 128] = q[j] >> 4;
492 }
493
494 pack_q4_0_quants(&x[i * 8 + 0], qs, 0);
495 pack_q4_0_quants(&x[i * 8 + 1], qs, 1);
496 pack_q4_0_quants(&x[i * 8 + 2], qs, 2);
497 pack_q4_0_quants(&x[i * 8 + 3], qs, 3);
498 pack_q4_0_quants(&x[i * 8 + 4], qs, 4);
499 pack_q4_0_quants(&x[i * 8 + 5], qs, 5);
500 pack_q4_0_quants(&x[i * 8 + 6], qs, 6);
501 pack_q4_0_quants(&x[i * 8 + 7], qs, 7);
502 }
503
504 // Repack the scales
505 // Note: Do not combine with the loop above. For tensor sizes not multiple of 256 (QK_Q4_0x4x2)
506 // the last block is truncated and overriden by the scales.
507 for (int i = 0; i < nb; i++) {
508 // Unpack the scales
509 const ggml_half * d = (const ggml_half *) (y_d + i * dblk_size);
510 x[i * 8 + 0].d = d[0];
511 x[i * 8 + 1].d = d[1];
512 x[i * 8 + 2].d = d[2];
513 x[i * 8 + 3].d = d[3];
514 x[i * 8 + 4].d = d[4];
515 x[i * 8 + 5].d = d[5];
516 x[i * 8 + 6].d = d[6];
517 x[i * 8 + 7].d = d[7];
518 }
519
520 if (opt_verbose > 2) {
521 for (int i = 0; i < nb; i++) {
522 dump_block_q4_0(&x[i * 8 + 0], 0);
523 dump_block_q4_0(&x[i * 8 + 1], 1);
524 dump_block_q4_0(&x[i * 8 + 2], 2);
525 dump_block_q4_0(&x[i * 8 + 3], 3);
526 dump_block_q4_0(&x[i * 8 + 4], 4);
527 dump_block_q4_0(&x[i * 8 + 5], 5);
528 dump_block_q4_0(&x[i * 8 + 6], 6);
529 dump_block_q4_0(&x[i * 8 + 7], 7);
530 }
531 }
532}
533
534static void init_row_q4x4x2(block_q4_0 * x, int64_t k) {
535 static const int qk = QK_Q4_0x4x2;
536 const int nb = (k + qk - 1) / qk; // number of blocks (padded)
537
538 // Init the quants such that they unpack into zeros
539 uint8_t qs[QK_Q4_0x4x2]; // unpacked quants
540 memset(qs, 8, sizeof(qs));
541
542 for (int i = 0; i < nb; i++) {
543 pack_q4_0_quants(&x[i * 8 + 0], qs, 0);
544 pack_q4_0_quants(&x[i * 8 + 1], qs, 1);
545 pack_q4_0_quants(&x[i * 8 + 2], qs, 2);
546 pack_q4_0_quants(&x[i * 8 + 3], qs, 3);
547 pack_q4_0_quants(&x[i * 8 + 4], qs, 4);
548 pack_q4_0_quants(&x[i * 8 + 5], qs, 5);
549 pack_q4_0_quants(&x[i * 8 + 6], qs, 6);
550 pack_q4_0_quants(&x[i * 8 + 7], qs, 7);
551 }
552
553 // Init the scales
554 // Note: Do not combine with the loop above. For tensor sizes not multiple of 256 (QK_Q4_0x4x2)
555 // the last block is truncated and overriden by the scales.
556 for (int i = 0; i < nb; i++) {
557 // Unpack the scales
558 x[i * 8 + 0].d = 0;
559 x[i * 8 + 1].d = 0;
560 x[i * 8 + 2].d = 0;
561 x[i * 8 + 3].d = 0;
562 x[i * 8 + 4].d = 0;
563 x[i * 8 + 5].d = 0;
564 x[i * 8 + 6].d = 0;
565 x[i * 8 + 7].d = 0;
566 }
567}
568
569// repack q4_0 data into q4x4x2 tensor
570static void repack_q4_0_q4x4x2(ggml_tensor * t, const void * data, size_t size) {
571 int64_t nrows = ggml_nrows(t);
572
573 size_t row_size = ggml_row_size(t->type, t->ne[0]);
574 size_t row_size_pd = ggml_row_size(t->type, hex_round_up(t->ne[0], QK_Q4_0x4x2)); // extra elements for the pad
575 size_t row_size_rp = row_size * 2; // extra space for tmp pad (if any)
576
577 // Ensure we don't try to read more data than is available in the source buffer 'data'
578 // or write more than the tensor can hold.
579 const size_t total_tensor_size = (size_t)nrows * row_size;
580 const size_t n_bytes_to_copy = size < total_tensor_size ? size : total_tensor_size;
581
582 // Calculate how many full rows and how many remaining bytes we need to process.
583 const int64_t n_full_rows = n_bytes_to_copy / row_size;
584 const size_t n_rem_bytes = n_bytes_to_copy % row_size;
585
586 void * buf_pd = ggml_aligned_malloc(row_size_pd);
587 GGML_ASSERT(buf_pd != NULL);
588
589 void * buf_rp = ggml_aligned_malloc(row_size_rp);
590 GGML_ASSERT(buf_rp != NULL);
591
592 HEX_VERBOSE("ggml-hex: repack-q4_0-q4x4x2 %s : data %p size %zu dims %ldx%ld row-size %zu\n", t->name, data, size,
593 t->ne[0], nrows, row_size);
594
595 init_row_q4x4x2((block_q4_0 *) buf_pd, t->ne[0]); // init padded buffer to make sure the tail is all zeros
596
597 // 1. Process all the full rows
598 for (int64_t i = 0; i < n_full_rows; i++) {
599 const uint8_t * src = (const uint8_t *) data + (i * row_size);
600 uint8_t * dst = (uint8_t *) t->data + (i * row_size);
601
602 memcpy(buf_pd, src, row_size);
603 repack_row_q4x4x2((uint8_t *) buf_rp, (const block_q4_0 *) buf_pd, t->ne[0]);
604 memcpy(dst, buf_rp, row_size);
605 }
606
607 // 2. Process the final, potentially partial, row
608 if (n_rem_bytes > 0) {
609 const int64_t i = n_full_rows;
610 const uint8_t * src = (const uint8_t *) data + (i * row_size);
611 uint8_t * dst = (uint8_t *) t->data + (i * row_size);
612
613 // re-init the row because we are potentially copying a partial row
614 init_row_q4x4x2((block_q4_0 *) buf_pd, t->ne[0]);
615
616 // Copy only the remaining bytes from the source.
617 memcpy(buf_pd, src, n_rem_bytes);
618
619 // Repack the entire buffer
620 repack_row_q4x4x2((uint8_t *) buf_rp, (const block_q4_0 *) buf_pd, t->ne[0]);
621
622 // Write only the corresponding remaining bytes to the destination tensor.
623 memcpy(dst, buf_rp, n_rem_bytes);
624 }
625
626 ggml_aligned_free(buf_pd, row_size_pd);
627 ggml_aligned_free(buf_rp, row_size_rp);
628}
629
630// repack q4x4x2 tensor into q4_0 data
631static void repack_q4x4x2_q4_0(void * data, const ggml_tensor * t, size_t size) {
632 int64_t nrows = ggml_nrows(t);
633
634 size_t row_size = ggml_row_size(t->type, t->ne[0]);
635 size_t row_size_pd = ggml_row_size(t->type, hex_round_up(t->ne[0], QK_Q4_0x4x2)); // extra elements for the pad
636 size_t row_size_rp = row_size * 2; // extra space for tmp pad (if any)
637
638 // Ensure we don't try to copy more data than the tensor actually contains.
639 const size_t total_tensor_size = (size_t)nrows * row_size;
640 const size_t n_bytes_to_copy = size < total_tensor_size ? size : total_tensor_size;
641
642 // Calculate how many full rows and how many remaining bytes we need to process.
643 const int64_t n_full_rows = n_bytes_to_copy / row_size;
644 const size_t n_rem_bytes = n_bytes_to_copy % row_size;
645
646 void * buf_pd = ggml_aligned_malloc(row_size_pd);
647 GGML_ASSERT(buf_pd != NULL);
648
649 void * buf_rp = ggml_aligned_malloc(row_size_rp);
650 GGML_ASSERT(buf_rp != NULL);
651
652 HEX_VERBOSE("ggml-hex: repack-q4x4x2-q4_0 %s : data %p size %zu dims %ldx%ld row-size %zu\n", t->name, data, size,
653 t->ne[0], nrows, row_size);
654
655 memset(buf_pd, 0, row_size_pd); // clear-out padded buffer to make sure the tail is all zeros
656
657 // 1. Process all the full rows
658 for (int64_t i = 0; i < n_full_rows; i++) {
659 const uint8_t * src = (const uint8_t *) t->data + (i * row_size);
660 uint8_t * dst = (uint8_t *) data + (i * row_size);
661
662 memcpy(buf_pd, src, row_size);
663 unpack_row_q4x4x2((block_q4_0 *) buf_rp, (const uint8_t *) buf_pd, t->ne[0]);
664 memcpy(dst, buf_rp, row_size);
665 }
666
667 // 2. Process the final, potentially partial, row
668 if (n_rem_bytes > 0) {
669 const int64_t i = n_full_rows;
670 const uint8_t * src = (const uint8_t *) t->data + (i * row_size);
671 uint8_t * dst = (uint8_t *) data + (i * row_size);
672
673 // We still need to read and unpack the entire source row because quantization is block-based.
674 memcpy(buf_pd, src, row_size);
675 unpack_row_q4x4x2((block_q4_0 *) buf_rp, (const uint8_t *) buf_pd, t->ne[0]);
676
677 // But we only copy the remaining number of bytes to the destination.
678 memcpy(dst, buf_rp, n_rem_bytes);
679 }
680
681 ggml_aligned_free(buf_pd, row_size_pd);
682 ggml_aligned_free(buf_rp, row_size_rp);
683}
684
685// ======== Q8x4x2 ====================
686static void dump_block_q8_0(const block_q8_0 * b, int i) {
687 HEX_VERBOSE("ggml-hex: repack q8_0 %d: %d %d %d %d ... %d %d %d %d : %.6f\n", i, b->qs[0], b->qs[1], b->qs[2],
688 b->qs[3], b->qs[28], b->qs[29], b->qs[30], b->qs[31], GGML_FP16_TO_FP32(b->d));
689}
690
691static void dump_packed_block_q8x4x2(const uint8_t * v, unsigned int i, size_t k) {
692 static const int qk = QK_Q8_0x4x2;
693 const int dblk_size = 8 * 2; // 8x __fp16
694 const int qblk_size = qk; // int8
695 const int qrow_size = k; // int8 (not padded)
696
697 const uint8_t * v_q = v + 0; // quants first
698 const uint8_t * v_d = v + qrow_size; // then scales
699
700 const uint8_t * q = v_q + i * qblk_size;
701 const ggml_half * d = (const ggml_half *) (v_d + i * dblk_size);
702
703 HEX_VERBOSE("ggml-hex: repack q8x4x2-%d: %d %d %d %d ... %d %d %d %d ... %d %d %d %d : %.6f %.6f %.6f %.6f\n", i,
704 q[0], q[1], q[2], q[3], q[60], q[61], q[62], q[63], q[124], q[125], q[126], q[127],
705 GGML_FP16_TO_FP32(d[0]), GGML_FP16_TO_FP32(d[1]), GGML_FP16_TO_FP32(d[2]), GGML_FP16_TO_FP32(d[3]));
706
707 HEX_VERBOSE("ggml-hex: repack q8x4x2-%d: %d %d %d %d ... %d %d %d %d ... %d %d %d %d : %.6f %.6f %.6f %.6f\n",
708 i + 1, q[128], q[129], q[130], q[131], q[192], q[193], q[194], q[195], q[252], q[253], q[254], q[255],
709 GGML_FP16_TO_FP32(d[4]), GGML_FP16_TO_FP32(d[5]), GGML_FP16_TO_FP32(d[6]), GGML_FP16_TO_FP32(d[7]));
710}
711
712static void unpack_q8_0_quants(uint8_t * qs, const block_q8_0 * x, unsigned int bi) {
713 static const int qk = QK8_0;
714
715 for (unsigned int i = 0; i < qk; ++i) {
716 qs[bi * qk + i] = x->qs[i];
717 }
718}
719
720static void pack_q8_0_quants(block_q8_0 * x, const uint8_t * qs, unsigned int bi) {
721 static const int qk = QK8_0;
722
723 for (unsigned int i = 0; i < qk; ++i) {
724 x->qs[i] = qs[bi * qk + i];
725 }
726}
727
728static void repack_row_q8x4x2(uint8_t * y, const block_q8_0 * x, int64_t k) {
729 static const int qk = QK_Q8_0x4x2;
730 const int nb = (k + qk - 1) / qk; // number of blocks (padded)
731
732 const int dblk_size = 8 * 2; // 8x __fp16
733 const int qblk_size = qk; // int8
734 const int qrow_size = k; // int8 (not padded to blocks)
735
736 uint8_t * y_q = y + 0; // quants first
737 uint8_t * y_d = y + qrow_size; // then scales
738
739 if (opt_verbose > 2) {
740 for (int i = 0; i < nb; i++) {
741 dump_block_q8_0(&x[i * 8 + 0], 0);
742 dump_block_q8_0(&x[i * 8 + 1], 1);
743 dump_block_q8_0(&x[i * 8 + 2], 2);
744 dump_block_q8_0(&x[i * 8 + 3], 3);
745 dump_block_q8_0(&x[i * 8 + 4], 4);
746 dump_block_q8_0(&x[i * 8 + 5], 5);
747 dump_block_q8_0(&x[i * 8 + 6], 6);
748 dump_block_q8_0(&x[i * 8 + 7], 7);
749 }
750 }
751
752 // Repack the quants
753 for (int i = 0; i < nb; i++) {
754 uint8_t qs[QK_Q8_0x4x2]; // unpacked quants
755
756 unpack_q8_0_quants(qs, &x[i * 8 + 0], 0);
757 unpack_q8_0_quants(qs, &x[i * 8 + 1], 1);
758 unpack_q8_0_quants(qs, &x[i * 8 + 2], 2);
759 unpack_q8_0_quants(qs, &x[i * 8 + 3], 3);
760 unpack_q8_0_quants(qs, &x[i * 8 + 4], 4);
761 unpack_q8_0_quants(qs, &x[i * 8 + 5], 5);
762 unpack_q8_0_quants(qs, &x[i * 8 + 6], 6);
763 unpack_q8_0_quants(qs, &x[i * 8 + 7], 7);
764
765 uint8_t * q = y_q + (i * qblk_size);
766 for (int j = 0; j < qk; j++) {
767 q[j] = qs[j];
768 }
769 }
770
771 // Repack the scales
772 // Note: Do not combine with the loop above. For tensor sizes not multiple of 256 (QK_Q4_0x4x2)
773 // the last block is truncated and overriden by the scales.
774 for (int i = 0; i < nb; i++) {
775 // Repack the scales
776 ggml_half * d = (ggml_half *) (y_d + i * dblk_size);
777 d[0] = x[i * 8 + 0].d;
778 d[1] = x[i * 8 + 1].d;
779 d[2] = x[i * 8 + 2].d;
780 d[3] = x[i * 8 + 3].d;
781 d[4] = x[i * 8 + 4].d;
782 d[5] = x[i * 8 + 5].d;
783 d[6] = x[i * 8 + 6].d;
784 d[7] = x[i * 8 + 7].d;
785 }
786
787 if (opt_verbose > 1) {
788 for (int i = 0; i < nb; i++) {
789 dump_packed_block_q8x4x2(y, i, k);
790 }
791 }
792}
793
794static void unpack_row_q8x4x2(block_q8_0 * x, const uint8_t * y, int64_t k) {
795 static const int qk = QK_Q8_0x4x2;
796 const int nb = (k + qk - 1) / qk; // number of blocks (padded)
797
798 const int dblk_size = 8 * 2; // 8x __fp16
799 const int qblk_size = qk; // int8
800 const int qrow_size = k; // int8 (not padded to blocks)
801
802 const uint8_t * y_q = y + 0; // quants first
803 const uint8_t * y_d = y + qrow_size; // then scales
804
805 if (opt_verbose > 1) {
806 for (int i = 0; i < nb; i++) {
807 dump_packed_block_q8x4x2(y, i, k);
808 }
809 }
810
811 // Unpack the quants
812 for (int i = 0; i < nb; i++) {
813 uint8_t qs[QK_Q4_0x4x2]; // unpacked quants
814
815 const uint8_t * q = y_q + (i * qblk_size);
816 for (int j = 0; j < qk; j++) {
817 qs[j] = q[j];
818 }
819
820 pack_q8_0_quants(&x[i * 8 + 0], qs, 0);
821 pack_q8_0_quants(&x[i * 8 + 1], qs, 1);
822 pack_q8_0_quants(&x[i * 8 + 2], qs, 2);
823 pack_q8_0_quants(&x[i * 8 + 3], qs, 3);
824 pack_q8_0_quants(&x[i * 8 + 4], qs, 4);
825 pack_q8_0_quants(&x[i * 8 + 5], qs, 5);
826 pack_q8_0_quants(&x[i * 8 + 6], qs, 6);
827 pack_q8_0_quants(&x[i * 8 + 7], qs, 7);
828 }
829
830 // Repack the scales
831 // Note: Do not combine with the loop above. For tensor sizes not multiple of 256 (QK_Q4_0x4x2)
832 // the last block is truncated and overriden by the scales.
833 for (int i = 0; i < nb; i++) {
834 // Unpack the scales
835 const ggml_half * d = (const ggml_half *) (y_d + i * dblk_size);
836 x[i * 8 + 0].d = d[0];
837 x[i * 8 + 1].d = d[1];
838 x[i * 8 + 2].d = d[2];
839 x[i * 8 + 3].d = d[3];
840 x[i * 8 + 4].d = d[4];
841 x[i * 8 + 5].d = d[5];
842 x[i * 8 + 6].d = d[6];
843 x[i * 8 + 7].d = d[7];
844 }
845
846 if (opt_verbose > 2) {
847 for (int i = 0; i < nb; i++) {
848 dump_block_q8_0(&x[i * 8 + 0], 0);
849 dump_block_q8_0(&x[i * 8 + 1], 1);
850 dump_block_q8_0(&x[i * 8 + 2], 2);
851 dump_block_q8_0(&x[i * 8 + 3], 3);
852 dump_block_q8_0(&x[i * 8 + 4], 4);
853 dump_block_q8_0(&x[i * 8 + 5], 5);
854 dump_block_q8_0(&x[i * 8 + 6], 6);
855 dump_block_q8_0(&x[i * 8 + 7], 7);
856 }
857 }
858}
859
860static void init_row_q8x4x2(block_q8_0 * x, int64_t k) {
861 static const int qk = QK_Q8_0x4x2;
862 const int nb = (k + qk - 1) / qk; // number of blocks (padded)
863
864 // Init the quants such that they unpack into zeros
865 uint8_t qs[QK_Q8_0x4x2]; // unpacked quants
866 memset(qs, 0, sizeof(qs));
867
868 for (int i = 0; i < nb; i++) {
869 pack_q8_0_quants(&x[i * 8 + 0], qs, 0);
870 pack_q8_0_quants(&x[i * 8 + 1], qs, 1);
871 pack_q8_0_quants(&x[i * 8 + 2], qs, 2);
872 pack_q8_0_quants(&x[i * 8 + 3], qs, 3);
873 pack_q8_0_quants(&x[i * 8 + 4], qs, 4);
874 pack_q8_0_quants(&x[i * 8 + 5], qs, 5);
875 pack_q8_0_quants(&x[i * 8 + 6], qs, 6);
876 pack_q8_0_quants(&x[i * 8 + 7], qs, 7);
877 }
878
879 // Init the scales
880 // Note: Do not combine with the loop above. For tensor sizes not multiple of 256 (QK_Q8_0x4x2)
881 // the last block is truncated and overriden by the scales.
882 for (int i = 0; i < nb; i++) {
883 // Unpack the scales
884 x[i * 8 + 0].d = 0;
885 x[i * 8 + 1].d = 0;
886 x[i * 8 + 2].d = 0;
887 x[i * 8 + 3].d = 0;
888 x[i * 8 + 4].d = 0;
889 x[i * 8 + 5].d = 0;
890 x[i * 8 + 6].d = 0;
891 x[i * 8 + 7].d = 0;
892 }
893}
894
895// repack q8_0 data into q8x4x2 tensor
896static void repack_q8_0_q8x4x2(ggml_tensor * t, const void * data, size_t size) {
897 int64_t nrows = ggml_nrows(t);
898
899 size_t row_size = ggml_row_size(t->type, t->ne[0]);
900 size_t row_size_pd = ggml_row_size(t->type, hex_round_up(t->ne[0], QK_Q8_0x4x2)); // extra elements for the pad
901 size_t row_size_rp = row_size * 2; // extra space for tmp pad (if any)
902
903 // Ensure we don't try to read more data than is available in the source buffer 'data'
904 // or write more than the tensor can hold.
905 const size_t total_tensor_size = (size_t)nrows * row_size;
906 const size_t n_bytes_to_copy = size < total_tensor_size ? size : total_tensor_size;
907
908 // Calculate how many full rows and how many remaining bytes we need to process.
909 const int64_t n_full_rows = n_bytes_to_copy / row_size;
910 const size_t n_rem_bytes = n_bytes_to_copy % row_size;
911
912 void * buf_pd = ggml_aligned_malloc(row_size_pd);
913 GGML_ASSERT(buf_pd != NULL);
914
915 void * buf_rp = ggml_aligned_malloc(row_size_rp);
916 GGML_ASSERT(buf_rp != NULL);
917
918 HEX_VERBOSE("ggml-hex: repack-q8_0-q8x4x2 %s : data %p size %zu dims %ldx%ld row-size %zu\n", t->name, data, size,
919 t->ne[0], nrows, row_size);
920
921 init_row_q8x4x2((block_q8_0 *) buf_pd, t->ne[0]); // init padded buffer to make sure the tail is all zeros
922
923 // 1. Process all the full rows
924 for (int64_t i = 0; i < n_full_rows; i++) {
925 const uint8_t * src = (const uint8_t *) data + (i * row_size);
926 uint8_t * dst = (uint8_t *) t->data + (i * row_size);
927
928 memcpy(buf_pd, src, row_size);
929 repack_row_q8x4x2((uint8_t *) buf_rp, (const block_q8_0 *) buf_pd, t->ne[0]);
930 memcpy(dst, buf_rp, row_size);
931 }
932
933 // 2. Process the final, potentially partial, row
934 if (n_rem_bytes > 0) {
935 const int64_t i = n_full_rows;
936 const uint8_t * src = (const uint8_t *) data + (i * row_size);
937 uint8_t * dst = (uint8_t *) t->data + (i * row_size);
938
939 // re-init the row because we are potentially copying a partial row
940 init_row_q8x4x2((block_q8_0 *) buf_pd, t->ne[0]);
941
942 // Copy only the remaining bytes from the source.
943 memcpy(buf_pd, src, n_rem_bytes);
944
945 // Repack the entire buffer
946 repack_row_q8x4x2((uint8_t *) buf_rp, (const block_q8_0 *) buf_pd, t->ne[0]);
947
948 // Write only the corresponding remaining bytes to the destination tensor.
949 memcpy(dst, buf_rp, n_rem_bytes);
950 }
951
952 ggml_aligned_free(buf_pd, row_size_pd);
953 ggml_aligned_free(buf_rp, row_size_rp);
954}
955
956// repack q8x4x2 tensor into q8_0 data
957static void repack_q8x4x2_q8_0(void * data, const ggml_tensor * t, size_t size) {
958 int64_t nrows = ggml_nrows(t);
959
960 size_t row_size = ggml_row_size(t->type, t->ne[0]);
961 size_t row_size_pd = ggml_row_size(t->type, hex_round_up(t->ne[0], QK_Q8_0x4x2)); // extra elements for the pad
962 size_t row_size_rp = row_size * 2; // extra space for tmp pad (if any)
963
964 // Ensure we don't try to copy more data than the tensor actually contains.
965 const size_t total_tensor_size = (size_t)nrows * row_size;
966 const size_t n_bytes_to_copy = size < total_tensor_size ? size : total_tensor_size;
967
968 // Calculate how many full rows and how many remaining bytes we need to process.
969 const int64_t n_full_rows = n_bytes_to_copy / row_size;
970 const size_t n_rem_bytes = n_bytes_to_copy % row_size;
971
972 void * buf_pd = ggml_aligned_malloc(row_size_pd);
973 GGML_ASSERT(buf_pd != NULL);
974
975 void * buf_rp = ggml_aligned_malloc(row_size_rp);
976 GGML_ASSERT(buf_rp != NULL);
977
978 HEX_VERBOSE("ggml-hex: repack-q8x4x2-q8_0 %s : data %p size %zu dims %ldx%ld row-size %zu\n", t->name, data, size,
979 t->ne[0], nrows, row_size);
980
981 memset(buf_pd, 0, row_size_pd); // clear-out padded buffer to make sure the tail is all zeros
982
983 // 1. Process all the full rows
984 for (int64_t i = 0; i < n_full_rows; i++) {
985 const uint8_t * src = (const uint8_t *) t->data + (i * row_size);
986 uint8_t * dst = (uint8_t *) data + (i * row_size);
987
988 memcpy(buf_pd, src, row_size);
989 unpack_row_q8x4x2((block_q8_0 *) buf_rp, (const uint8_t *) buf_pd, t->ne[0]);
990 memcpy(dst, buf_rp, row_size);
991 }
992
993 // 2. Process the final, potentially partial, row
994 if (n_rem_bytes > 0) {
995 const int64_t i = n_full_rows;
996 const uint8_t * src = (const uint8_t *) t->data + (i * row_size);
997 uint8_t * dst = (uint8_t *) data + (i * row_size);
998
999 // We still need to read and unpack the entire source row because quantization is block-based.
1000 memcpy(buf_pd, src, row_size);
1001 unpack_row_q8x4x2((block_q8_0 *) buf_rp, (const uint8_t *) buf_pd, t->ne[0]);
1002
1003 // But we only copy the remaining number of bytes to the destination.
1004 memcpy(dst, buf_rp, n_rem_bytes);
1005 }
1006
1007 ggml_aligned_free(buf_pd, row_size_pd);
1008 ggml_aligned_free(buf_rp, row_size_rp);
1009}
1010
1011// ======== MXFP4x4x2 ====================
1012struct x2_mxfp4 {
1013 int v[2];
1014};
1015
1016static x2_mxfp4 unpack_mxfp4(uint8_t v) {
1017 x2_mxfp4 x;
1018 x.v[0] = kvalues_mxfp4[(v & 0x0f)];
1019 x.v[1] = kvalues_mxfp4[(v >> 4)];
1020 return x;
1021}
1022
1023static void dump_block_mxfp4(const block_mxfp4 * b, int i) {
1024 HEX_VERBOSE("ggml-hex: repack mxfp4 %d: %d %d %d %d ... %d %d %d %d : %.6f\n", i, unpack_mxfp4(b->qs[0]).v[0],
1025 unpack_mxfp4(b->qs[1]).v[0], unpack_mxfp4(b->qs[2]).v[0], unpack_mxfp4(b->qs[3]).v[0],
1026 unpack_mxfp4(b->qs[12]).v[1], unpack_mxfp4(b->qs[13]).v[1], unpack_mxfp4(b->qs[14]).v[1],
1027 unpack_mxfp4(b->qs[15]).v[1], GGML_E8M0_TO_FP32_HALF(b->e));
1028}
1029
1030static void dump_packed_block_mxfp4x4x2(const uint8_t * v, unsigned int i, size_t k) {
1031 static const int qk = QK_MXFP4x4x2;
1032 const int eblk_size = 8 * 1; // 8x E8M0
1033 const int qblk_size = qk / 2; // int4
1034 const int qrow_size = k / 2; // int4 (not padded)
1035
1036 const uint8_t * v_q = v + 0; // quants first
1037 const uint8_t * v_e = v + qrow_size; // then scales
1038
1039 const uint8_t * q = v_q + i * qblk_size;
1040 const uint8_t * e = (const uint8_t *) (v_e + i * eblk_size);
1041
1042 HEX_VERBOSE("ggml-hex: repack mxfp4x4x2-%d: %d %d %d %d ... %d %d %d %d ... %d %d %d %d : %.6f %.6f %.6f %.6f\n", i,
1043 unpack_mxfp4(q[0]).v[0], unpack_mxfp4(q[1]).v[0], unpack_mxfp4(q[2]).v[0], unpack_mxfp4(q[3]).v[0],
1044 unpack_mxfp4(q[60]).v[0], unpack_mxfp4(q[61]).v[0], unpack_mxfp4(q[62]).v[0], unpack_mxfp4(q[63]).v[0],
1045 unpack_mxfp4(q[124]).v[0], unpack_mxfp4(q[125]).v[0], unpack_mxfp4(q[126]).v[0],
1046 unpack_mxfp4(q[127]).v[0], GGML_E8M0_TO_FP32_HALF(e[0]), GGML_E8M0_TO_FP32_HALF(e[1]),
1047 GGML_E8M0_TO_FP32_HALF(e[2]), GGML_E8M0_TO_FP32_HALF(e[3]));
1048
1049 HEX_VERBOSE("ggml-hex: repack mxfp4x4x2-%d: %d %d %d %d ... %d %d %d %d ... %d %d %d %d : %.6f %.6f %.6f %.6f\n",
1050 i + 1, unpack_mxfp4(q[0]).v[1], unpack_mxfp4(q[1]).v[1], unpack_mxfp4(q[2]).v[1],
1051 unpack_mxfp4(q[3]).v[1], unpack_mxfp4(q[60]).v[1], unpack_mxfp4(q[61]).v[1], unpack_mxfp4(q[62]).v[1],
1052 unpack_mxfp4(q[63]).v[1], unpack_mxfp4(q[124]).v[1], unpack_mxfp4(q[125]).v[1],
1053 unpack_mxfp4(q[126]).v[1], unpack_mxfp4(q[127]).v[1], GGML_E8M0_TO_FP32_HALF(e[4]),
1054 GGML_E8M0_TO_FP32_HALF(e[5]), GGML_E8M0_TO_FP32_HALF(e[6]), GGML_E8M0_TO_FP32_HALF(e[7]));
1055}
1056
1057static void unpack_mxfp4_quants(uint8_t * qs, const block_mxfp4 * x, unsigned int bi) {
1058 static const int qk = QK_MXFP4;
1059
1060 for (unsigned int i = 0; i < qk / 2; ++i) {
1061 const uint8_t x0 = (x->qs[i] & 0x0F);
1062 const uint8_t x1 = (x->qs[i] >> 4);
1063 qs[bi * qk + i + 0] = x0;
1064 qs[bi * qk + i + qk / 2] = x1;
1065 }
1066}
1067
1068static void pack_mxfp4_quants(block_mxfp4 * x, const uint8_t * qs, unsigned int bi) {
1069 static const int qk = QK4_0;
1070
1071 for (unsigned int i = 0; i < qk / 2; ++i) {
1072 const uint8_t x0 = qs[bi * qk + i + 0];
1073 const uint8_t x1 = qs[bi * qk + i + qk / 2];
1074 x->qs[i] = x0 | (x1 << 4);
1075 }
1076}
1077
1078static void repack_row_mxfp4x4x2(uint8_t * y, const block_mxfp4 * x, int64_t k) {
1079 static const int qk = QK_MXFP4x4x2;
1080 const int nb = (k + qk - 1) / qk; // number of blocks (padded)
1081
1082 const int eblk_size = 8 * 1; // 8x E8M0
1083 const int qblk_size = qk / 2; // int4
1084 const int qrow_size = k / 2; // int4 (not padded to blocks)
1085
1086 uint8_t * y_q = y + 0; // quants first
1087 uint8_t * y_e = y + qrow_size; // then scales
1088
1089 if (opt_verbose > 2) {
1090 for (int i = 0; i < nb; i++) {
1091 dump_block_mxfp4(&x[i * 8 + 0], 0);
1092 dump_block_mxfp4(&x[i * 8 + 1], 1);
1093 dump_block_mxfp4(&x[i * 8 + 2], 2);
1094 dump_block_mxfp4(&x[i * 8 + 3], 3);
1095 dump_block_mxfp4(&x[i * 8 + 4], 4);
1096 dump_block_mxfp4(&x[i * 8 + 5], 5);
1097 dump_block_mxfp4(&x[i * 8 + 6], 6);
1098 dump_block_mxfp4(&x[i * 8 + 7], 7);
1099 }
1100 }
1101
1102 // Repack the quants
1103 for (int i = 0; i < nb; i++) {
1104 uint8_t qs[QK_MXFP4x4x2]; // unpacked quants
1105
1106 unpack_mxfp4_quants(qs, &x[i * 8 + 0], 0);
1107 unpack_mxfp4_quants(qs, &x[i * 8 + 1], 1);
1108 unpack_mxfp4_quants(qs, &x[i * 8 + 2], 2);
1109 unpack_mxfp4_quants(qs, &x[i * 8 + 3], 3);
1110 unpack_mxfp4_quants(qs, &x[i * 8 + 4], 4);
1111 unpack_mxfp4_quants(qs, &x[i * 8 + 5], 5);
1112 unpack_mxfp4_quants(qs, &x[i * 8 + 6], 6);
1113 unpack_mxfp4_quants(qs, &x[i * 8 + 7], 7);
1114
1115 uint8_t * q = y_q + (i * qblk_size);
1116 for (int j = 0; j < qk / 2; j++) {
1117 q[j] = (qs[j + 128] << 4) | qs[j];
1118 }
1119 }
1120
1121 // Repack the scales
1122 // Note: Do not combine with the loop above. For tensor sizes not multiple of 256 (QK_MXFP4x4x2)
1123 // the last block is truncated and overriden by the scales.
1124 for (int i = 0; i < nb; i++) {
1125 // Repack the scales
1126 uint8_t * e = (uint8_t *) (y_e + i * eblk_size);
1127 e[0] = x[i * 8 + 0].e;
1128 e[1] = x[i * 8 + 1].e;
1129 e[2] = x[i * 8 + 2].e;
1130 e[3] = x[i * 8 + 3].e;
1131 e[4] = x[i * 8 + 4].e;
1132 e[5] = x[i * 8 + 5].e;
1133 e[6] = x[i * 8 + 6].e;
1134 e[7] = x[i * 8 + 7].e;
1135 }
1136
1137 if (opt_verbose > 1) {
1138 for (int i = 0; i < nb; i++) {
1139 dump_packed_block_mxfp4x4x2(y, i, k);
1140 }
1141 }
1142}
1143
1144static void unpack_row_mxfp4x4x2(block_mxfp4 * x, const uint8_t * y, int64_t k) {
1145 static const int qk = QK_MXFP4x4x2;
1146 const int nb = (k + qk - 1) / qk; // number of blocks (padded)
1147
1148 const int eblk_size = 8 * 1; // 8x E8M0
1149 const int qblk_size = qk / 2; // int4
1150 const int qrow_size = k / 2; // int4 (not padded to blocks)
1151
1152 const uint8_t * y_q = y + 0; // quants first
1153 const uint8_t * y_e = y + qrow_size; // then scales
1154
1155 if (opt_verbose > 1) {
1156 for (int i = 0; i < nb; i++) {
1157 dump_packed_block_mxfp4x4x2(y, i, k);
1158 }
1159 }
1160
1161 // Unpack the quants
1162 for (int i = 0; i < nb; i++) {
1163 uint8_t qs[QK_MXFP4x4x2]; // unpacked quants
1164
1165 const uint8_t * q = y_q + (i * qblk_size);
1166 for (int j = 0; j < qk / 2; j++) {
1167 qs[j] = q[j] & 0xf;
1168 qs[j + 128] = q[j] >> 4;
1169 }
1170
1171 pack_mxfp4_quants(&x[i * 8 + 0], qs, 0);
1172 pack_mxfp4_quants(&x[i * 8 + 1], qs, 1);
1173 pack_mxfp4_quants(&x[i * 8 + 2], qs, 2);
1174 pack_mxfp4_quants(&x[i * 8 + 3], qs, 3);
1175 pack_mxfp4_quants(&x[i * 8 + 4], qs, 4);
1176 pack_mxfp4_quants(&x[i * 8 + 5], qs, 5);
1177 pack_mxfp4_quants(&x[i * 8 + 6], qs, 6);
1178 pack_mxfp4_quants(&x[i * 8 + 7], qs, 7);
1179 }
1180
1181 // Repack the scales
1182 // Note: Do not combine with the loop above. For tensor sizes not multiple of 256 (QK_MXFP4_0x4x2)
1183 // the last block is truncated and overriden by the scales.
1184 for (int i = 0; i < nb; i++) {
1185 // Unpack the scales
1186 const uint8_t * e = (const uint8_t *) (y_e + i * eblk_size);
1187 x[i * 8 + 0].e = e[0];
1188 x[i * 8 + 1].e = e[1];
1189 x[i * 8 + 2].e = e[2];
1190 x[i * 8 + 3].e = e[3];
1191 x[i * 8 + 4].e = e[4];
1192 x[i * 8 + 5].e = e[5];
1193 x[i * 8 + 6].e = e[6];
1194 x[i * 8 + 7].e = e[7];
1195 }
1196
1197 if (opt_verbose > 2) {
1198 for (int i = 0; i < nb; i++) {
1199 dump_block_mxfp4(&x[i * 8 + 0], 0);
1200 dump_block_mxfp4(&x[i * 8 + 1], 1);
1201 dump_block_mxfp4(&x[i * 8 + 2], 2);
1202 dump_block_mxfp4(&x[i * 8 + 3], 3);
1203 dump_block_mxfp4(&x[i * 8 + 4], 4);
1204 dump_block_mxfp4(&x[i * 8 + 5], 5);
1205 dump_block_mxfp4(&x[i * 8 + 6], 6);
1206 dump_block_mxfp4(&x[i * 8 + 7], 7);
1207 }
1208 }
1209}
1210
1211static void init_row_mxfp4x4x2(block_mxfp4 * x, int64_t k) {
1212 static const int qk = QK_MXFP4x4x2;
1213 const int nb = (k + qk - 1) / qk; // number of blocks (padded)
1214
1215 // Init the quants such that they unpack into zeros
1216 uint8_t qs[QK_MXFP4x4x2]; // unpacked quants
1217 memset(qs, 0, sizeof(qs));
1218
1219 for (int i = 0; i < nb; i++) {
1220 pack_mxfp4_quants(&x[i * 8 + 0], qs, 0);
1221 pack_mxfp4_quants(&x[i * 8 + 1], qs, 1);
1222 pack_mxfp4_quants(&x[i * 8 + 2], qs, 2);
1223 pack_mxfp4_quants(&x[i * 8 + 3], qs, 3);
1224 pack_mxfp4_quants(&x[i * 8 + 4], qs, 4);
1225 pack_mxfp4_quants(&x[i * 8 + 5], qs, 5);
1226 pack_mxfp4_quants(&x[i * 8 + 6], qs, 6);
1227 pack_mxfp4_quants(&x[i * 8 + 7], qs, 7);
1228 }
1229
1230 // Init the scales
1231 // Note: Do not combine with the loop above. For tensor sizes not multiple of 256 (QK_MXFP4x4x2)
1232 // the last block is truncated and overriden by the scales.
1233 for (int i = 0; i < nb; i++) {
1234 // Unpack the scales
1235 x[i * 8 + 0].e = 0;
1236 x[i * 8 + 1].e = 0;
1237 x[i * 8 + 2].e = 0;
1238 x[i * 8 + 3].e = 0;
1239 x[i * 8 + 4].e = 0;
1240 x[i * 8 + 5].e = 0;
1241 x[i * 8 + 6].e = 0;
1242 x[i * 8 + 7].e = 0;
1243 }
1244}
1245
1246// repack mxfp4 data into mxfp4x4x2 tensor
1247static void repack_mxfp4_mxfp4x4x2(ggml_tensor * t, const void * data, size_t size) {
1248 int64_t nrows = ggml_nrows(t);
1249
1250 size_t row_size = ggml_row_size(t->type, t->ne[0]);
1251 size_t row_size_pd = ggml_row_size(t->type, hex_round_up(t->ne[0], QK_MXFP4x4x2)); // extra elements for the pad
1252 size_t row_size_rp = row_size * 2; // extra space for tmp pad (if any)
1253
1254 // Ensure we don't try to read more data than is available in the source buffer 'data'
1255 // or write more than the tensor can hold.
1256 const size_t total_tensor_size = (size_t)nrows * row_size;
1257 const size_t n_bytes_to_copy = size < total_tensor_size ? size : total_tensor_size;
1258
1259 // Calculate how many full rows and how many remaining bytes we need to process.
1260 const int64_t n_full_rows = n_bytes_to_copy / row_size;
1261 const size_t n_rem_bytes = n_bytes_to_copy % row_size;
1262
1263 void * buf_pd = ggml_aligned_malloc(row_size_pd);
1264 GGML_ASSERT(buf_pd != NULL);
1265
1266 void * buf_rp = ggml_aligned_malloc(row_size_rp);
1267 GGML_ASSERT(buf_rp != NULL);
1268
1269 HEX_VERBOSE("ggml-hex: repack-mxfp4-mxfp4x4x2 %s : data %p size %zu dims %ldx%ld row-size %zu\n", t->name, data,
1270 size, t->ne[0], nrows, row_size);
1271
1272 init_row_mxfp4x4x2((block_mxfp4 *) buf_pd, t->ne[0]); // init padded buffer to make sure the tail is all zeros
1273
1274 // 1. Process all the full rows
1275 for (int64_t i = 0; i < n_full_rows; i++) {
1276 const uint8_t * src = (const uint8_t *) data + (i * row_size);
1277 uint8_t * dst = (uint8_t *) t->data + (i * row_size);
1278
1279 memcpy(buf_pd, src, row_size);
1280 repack_row_mxfp4x4x2((uint8_t *) buf_rp, (const block_mxfp4 *) buf_pd, t->ne[0]);
1281 memcpy(dst, buf_rp, row_size);
1282 }
1283
1284 // 2. Process the final, potentially partial, row
1285 if (n_rem_bytes > 0) {
1286 const int64_t i = n_full_rows;
1287 const uint8_t * src = (const uint8_t *) data + (i * row_size);
1288 uint8_t * dst = (uint8_t *) t->data + (i * row_size);
1289
1290 // re-init the row because we are potentially copying a partial row
1291 init_row_mxfp4x4x2((block_mxfp4 *) buf_pd, t->ne[0]);
1292
1293 // Copy only the remaining bytes from the source.
1294 memcpy(buf_pd, src, n_rem_bytes);
1295
1296 // Repack the entire buffer (partial data + zero padding).
1297 repack_row_mxfp4x4x2((uint8_t *) buf_rp, (const block_mxfp4 *) buf_pd, t->ne[0]);
1298
1299 // Write only the corresponding remaining bytes to the destination tensor.
1300 memcpy(dst, buf_rp, n_rem_bytes);
1301 }
1302
1303 ggml_aligned_free(buf_pd, row_size_pd);
1304 ggml_aligned_free(buf_rp, row_size_rp);
1305}
1306
1307// repack mxfp4x4x2 tensor into mxfp4 data
1308static void repack_mxfp4x4x2_mxfp4(void * data, const ggml_tensor * t, size_t size) {
1309 int64_t nrows = ggml_nrows(t);
1310
1311 size_t row_size = ggml_row_size(t->type, t->ne[0]);
1312 size_t row_size_pd = ggml_row_size(t->type, hex_round_up(t->ne[0], QK_MXFP4x4x2)); // extra elements for the pad
1313 size_t row_size_rp = row_size * 2; // extra space for tmp pad (if any)
1314
1315 // Ensure we don't try to copy more data than the tensor actually contains.
1316 const size_t total_tensor_size = (size_t)nrows * row_size;
1317 const size_t n_bytes_to_copy = size < total_tensor_size ? size : total_tensor_size;
1318
1319 // Calculate how many full rows and how many remaining bytes we need to process.
1320 const int64_t n_full_rows = n_bytes_to_copy / row_size;
1321 const size_t n_rem_bytes = n_bytes_to_copy % row_size;
1322
1323 void * buf_pd = ggml_aligned_malloc(row_size_pd);
1324 GGML_ASSERT(buf_pd != NULL);
1325
1326 void * buf_rp = ggml_aligned_malloc(row_size_rp);
1327 GGML_ASSERT(buf_rp != NULL);
1328
1329 HEX_VERBOSE("ggml-hex: repack-mxfp4x4x2-mxfp4 %s : data %p size %zu dims %ldx%ld row-size %zu\n", t->name, data,
1330 size, t->ne[0], nrows, row_size);
1331
1332 memset(buf_pd, 0, row_size_pd); // clear-out padded buffer to make sure the tail is all zeros
1333
1334 // 1. Process all the full rows
1335 for (int64_t i = 0; i < n_full_rows; i++) {
1336 const uint8_t * src = (const uint8_t *) t->data + (i * row_size);
1337 uint8_t * dst = (uint8_t *) data + (i * row_size);
1338
1339 memcpy(buf_pd, src, row_size);
1340 unpack_row_mxfp4x4x2((block_mxfp4 *) buf_rp, (const uint8_t *) buf_pd, t->ne[0]);
1341 memcpy(dst, buf_rp, row_size);
1342 }
1343
1344 // 2. Process the final, potentially partial, row
1345 if (n_rem_bytes > 0) {
1346 const int64_t i = n_full_rows;
1347 const uint8_t * src = (const uint8_t *) t->data + (i * row_size);
1348 uint8_t * dst = (uint8_t *) data + (i * row_size);
1349
1350 // We still need to read and unpack the entire source row because the format is block-based.
1351 memcpy(buf_pd, src, row_size);
1352 unpack_row_mxfp4x4x2((block_mxfp4 *) buf_rp, (const uint8_t *) buf_pd, t->ne[0]);
1353
1354 // But we only copy the remaining number of bytes to the destination to respect the size limit.
1355 memcpy(dst, buf_rp, n_rem_bytes);
1356 }
1357
1358 ggml_aligned_free(buf_pd, row_size_pd);
1359 ggml_aligned_free(buf_rp, row_size_rp);
1360}
1361
1362static void ggml_backend_hexagon_buffer_set_tensor(ggml_backend_buffer_t buffer,
1363 ggml_tensor * tensor,
1364 const void * data,
1365 size_t offset,
1366 size_t size) {
1367 auto ctx = (ggml_backend_hexagon_buffer_context *) buffer->context;
1368 auto sess = ctx->sess;
1369
1370 HEX_VERBOSE("ggml-hex: %s set-tensor %s : data %p offset %zu size %zu\n", sess->name.c_str(), tensor->name, data,
1371 offset, size);
1372
1373 switch (tensor->type) {
1374 case GGML_TYPE_Q4_0:
1375 GGML_ASSERT(offset == 0);
1376 GGML_ASSERT(offset + size <= ggml_nbytes(tensor));
1377 repack_q4_0_q4x4x2(tensor, data, size);
1378 break;
1379
1380 case GGML_TYPE_Q8_0:
1381 GGML_ASSERT(offset == 0);
1382 GGML_ASSERT(offset + size <= ggml_nbytes(tensor));
1383 repack_q8_0_q8x4x2(tensor, data, size);
1384 break;
1385
1386 case GGML_TYPE_MXFP4:
1387 GGML_ASSERT(offset == 0);
1388 GGML_ASSERT(offset + size <= ggml_nbytes(tensor));
1389 repack_mxfp4_mxfp4x4x2(tensor, data, size);
1390 break;
1391
1392 default:
1393 memcpy((char *) tensor->data + offset, data, size);
1394 break;
1395 }
1396}
1397
1398static void ggml_backend_hexagon_buffer_get_tensor(ggml_backend_buffer_t buffer,
1399 const ggml_tensor * tensor,
1400 void * data,
1401 size_t offset,
1402 size_t size) {
1403 auto ctx = (ggml_backend_hexagon_buffer_context *) buffer->context;
1404 auto sess = ctx->sess;
1405
1406 HEX_VERBOSE("ggml-hex: %s get-tensor %s : data %p offset %zu size %zu\n", sess->name.c_str(), tensor->name, data,
1407 offset, size);
1408
1409 switch (tensor->type) {
1410 case GGML_TYPE_Q4_0:
1411 GGML_ASSERT(offset == 0);
1412 GGML_ASSERT(offset + size <= ggml_nbytes(tensor));
1413 repack_q4x4x2_q4_0(data, tensor, size);
1414 break;
1415
1416 case GGML_TYPE_Q8_0:
1417 GGML_ASSERT(offset == 0);
1418 GGML_ASSERT(offset + size <= ggml_nbytes(tensor));
1419 repack_q8x4x2_q8_0(data, tensor, size);
1420 break;
1421
1422 case GGML_TYPE_MXFP4:
1423 GGML_ASSERT(offset == 0);
1424 GGML_ASSERT(offset + size <= ggml_nbytes(tensor));
1425 repack_mxfp4x4x2_mxfp4(data, tensor, size);
1426 break;
1427
1428 default:
1429 memcpy(data, (const char *) tensor->data + offset, size);
1430 break;
1431 }
1432}
1433
1434static bool ggml_backend_hexagon_buffer_cpy_tensor(ggml_backend_buffer_t buffer,
1435 const struct ggml_tensor * src,
1436 struct ggml_tensor * dst) {
1437 GGML_UNUSED(buffer);
1438 GGML_UNUSED(src);
1439 GGML_UNUSED(dst);
1440 // we might optimize this later, for now take the slow path (ie get/set_tensor)
1441 return false;
1442}
1443
1444static void ggml_backend_hexagon_buffer_clear(ggml_backend_buffer_t buffer, uint8_t value) {
1445 auto ctx = (ggml_backend_hexagon_buffer_context *) buffer->context;
1446 auto sess = ctx->sess;
1447 HEX_VERBOSE("ggml-hex: %s clear-buff base %p size %zu\n", sess->name.c_str(), (void *) ctx->base, ctx->size);
1448 memset(ctx->base, value, ctx->size);
1449}
1450
1451static ggml_backend_buffer_i ggml_backend_hexagon_buffer_interface = {
1452 /* .free_buffer = */ ggml_backend_hexagon_buffer_free_buffer,
1453 /* .get_base = */ ggml_backend_hexagon_buffer_get_base,
1454 /* .init_tensor = */ ggml_backend_hexagon_buffer_init_tensor,
1455 /* .memset_tensor = */ NULL,
1456 /* .set_tensor = */ ggml_backend_hexagon_buffer_set_tensor,
1457 /* .get_tensor = */ ggml_backend_hexagon_buffer_get_tensor,
1458 /* .cpy_tensor = */ ggml_backend_hexagon_buffer_cpy_tensor,
1459 /* .clear = */ ggml_backend_hexagon_buffer_clear,
1460 /* .reset = */ NULL,
1461};
1462
1463// ** backend buffer type
1464
1465static const char * ggml_backend_hexagon_buffer_type_name(ggml_backend_buffer_type_t buffer_type) {
1466 return static_cast<ggml_backend_hexagon_buffer_type_context *>(buffer_type->context)->name.c_str();
1467}
1468
1469static ggml_backend_buffer_t ggml_backend_hexagon_buffer_type_alloc_buffer(
1470 ggml_backend_buffer_type_t buffer_type, size_t size) {
1471 auto sess = static_cast<ggml_backend_hexagon_buffer_type_context *>(buffer_type->context)->sess;
1472 try {
1473 ggml_backend_hexagon_buffer_context * ctx = new ggml_backend_hexagon_buffer_context(sess, size, false /*repack*/);
1474 return ggml_backend_buffer_init(buffer_type, ggml_backend_hexagon_buffer_interface, ctx, size);
1475 } catch (const std::exception & exc) {
1476 GGML_LOG_ERROR("ggml-hex: %s failed to allocate buffer context: %s\n", sess->name.c_str(), exc.what());
1477 return nullptr;
1478 }
1479}
1480
1481static ggml_backend_buffer_t ggml_backend_hexagon_repack_buffer_type_alloc_buffer(
1482 ggml_backend_buffer_type_t buffer_type, size_t size) {
1483 auto sess = static_cast<ggml_backend_hexagon_buffer_type_context *>(buffer_type->context)->sess;
1484 try {
1485 ggml_backend_hexagon_buffer_context * ctx = new ggml_backend_hexagon_buffer_context(sess, size, true /*repack*/);
1486 return ggml_backend_buffer_init(buffer_type, ggml_backend_hexagon_buffer_interface, ctx, size);
1487 } catch (const std::exception & exc) {
1488 GGML_LOG_ERROR("ggml-hex: %s failed to allocate buffer context: %s\n", sess->name.c_str(), exc.what());
1489 return nullptr;
1490 }
1491}
1492
1493static size_t ggml_backend_hexagon_buffer_type_get_alignment(ggml_backend_buffer_type_t buffer_type) {
1494 return 128; // HVX alignment
1495 GGML_UNUSED(buffer_type);
1496}
1497
1498static size_t ggml_backend_hexagon_buffer_type_get_alloc_size(ggml_backend_buffer_type_t buft, const struct ggml_tensor * t) {
1499 return ggml_nbytes(t);
1500}
1501
1502static size_t ggml_backend_hexagon_buffer_type_get_max_size(ggml_backend_buffer_type_t buffer_type) {
1503 return 1 * 1024 * 1024 * 1024; // 1GB per buffer
1504 GGML_UNUSED(buffer_type);
1505}
1506
1507static bool ggml_backend_hexagon_buffer_type_is_host(ggml_backend_buffer_type_t buft) {
1508 return opt_hostbuf;
1509 GGML_UNUSED(buft);
1510}
1511
1512static bool ggml_backend_hexagon_repack_buffer_type_is_host(ggml_backend_buffer_type_t buft) {
1513 return false;
1514 GGML_UNUSED(buft);
1515}
1516
1517static ggml_backend_buffer_type_i ggml_backend_hexagon_buffer_type_interface = {
1518 /* .get_name = */ ggml_backend_hexagon_buffer_type_name,
1519 /* .alloc_buffer = */ ggml_backend_hexagon_buffer_type_alloc_buffer,
1520 /* .get_alignment = */ ggml_backend_hexagon_buffer_type_get_alignment,
1521 /* .get_max_size = */ ggml_backend_hexagon_buffer_type_get_max_size,
1522 /* .get_alloc_size = */ ggml_backend_hexagon_buffer_type_get_alloc_size,
1523 /* .is_host = */ ggml_backend_hexagon_buffer_type_is_host,
1524};
1525
1526static ggml_backend_buffer_type_i ggml_backend_hexagon_repack_buffer_type_interface = {
1527 /* .get_name = */ ggml_backend_hexagon_buffer_type_name,
1528 /* .alloc_buffer = */ ggml_backend_hexagon_repack_buffer_type_alloc_buffer,
1529 /* .get_alignment = */ ggml_backend_hexagon_buffer_type_get_alignment,
1530 /* .get_max_size = */ ggml_backend_hexagon_buffer_type_get_max_size,
1531 /* .get_alloc_size = */ ggml_backend_hexagon_buffer_type_get_alloc_size,
1532 /* .is_host = */ ggml_backend_hexagon_repack_buffer_type_is_host,
1533};
1534
1535void ggml_hexagon_session::allocate(int dev_id) noexcept(false) {
1536 this->valid_session = false;
1537 this->valid_handle = false;
1538 this->valid_queue = false;
1539 this->valid_iface = false;
1540
1541 this->domain_id = 3; // Default for CDSP, updated after the session is created
1542 this->session_id = 0; // Default for CDSP, updated after the session is created
1543 this->dev_id = dev_id;
1544 this->name = std::string("HTP") + std::to_string(dev_id);
1545
1546 this->op_pending = 0;
1547 this->prof_usecs = 0;
1548 this->prof_cycles = 0;
1549 this->prof_pkts = 0;
1550
1551 GGML_LOG_INFO("ggml-hex: allocating new session: %s\n", this->name.c_str());
1552
1553 domain * my_domain = get_domain(this->domain_id);
1554 if (my_domain == NULL) {
1555 GGML_LOG_ERROR("ggml-hex: unable to get domain struct for CDSP\n");
1556 throw std::runtime_error("ggml-hex: failed to get CDSP domain (see log for details)");
1557 }
1558
1559 // Create new session
1560 if (dev_id != 0) {
1561 struct remote_rpc_reserve_new_session n;
1562 n.domain_name_len = strlen(CDSP_DOMAIN_NAME);
1563 n.domain_name = const_cast<char *>(CDSP_DOMAIN_NAME);
1564 n.session_name = const_cast<char *>(this->name.c_str());
1565 n.session_name_len = this->name.size();
1566
1567 int err = remote_session_control(FASTRPC_RESERVE_NEW_SESSION, (void *) &n, sizeof(n));
1568 if (err != AEE_SUCCESS) {
1569 GGML_LOG_ERROR("ggml-hex: failed to reserve new session %d : error 0x%x\n", dev_id, err);
1570 throw std::runtime_error("ggml-hex: remote_session_control(new-sess) failed (see log for details)");
1571 }
1572
1573 // Save the IDs
1574 this->session_id = n.session_id;
1575 this->domain_id = n.effective_domain_id;
1576 this->valid_session = true;
1577 }
1578
1579 // Get session URI
1580
1581 char session_uri[256];
1582 {
1583 char htp_uri[256];
1584 snprintf(htp_uri, sizeof(htp_uri), "file:///libggml-htp-v%u.so?htp_iface_skel_handle_invoke&_modver=1.0", opt_arch);
1585
1586 struct remote_rpc_get_uri u = {};
1587 u.session_id = this->session_id;
1588 u.domain_name = const_cast<char *>(CDSP_DOMAIN_NAME);
1589 u.domain_name_len = strlen(CDSP_DOMAIN_NAME);
1590 u.module_uri = const_cast<char *>(htp_uri);
1591 u.module_uri_len = strlen(htp_uri);
1592 u.uri = session_uri;
1593 u.uri_len = sizeof(session_uri);
1594
1595 int err = remote_session_control(FASTRPC_GET_URI, (void *) &u, sizeof(u));
1596 if (err != AEE_SUCCESS) {
1597 // fallback to single session uris
1598 int htp_URI_domain_len = strlen(htp_uri) + MAX_DOMAIN_NAMELEN;
1599
1600 snprintf(session_uri, htp_URI_domain_len, "%s%s", htp_uri, my_domain->uri);
1601
1602 GGML_LOG_WARN("ggml-hex: failed to get URI for session %d : error 0x%x. Falling back to single session URI: %s\n", dev_id, err, session_uri);
1603 }
1604 }
1605
1606 // Enable Unsigned PD
1607 {
1608 struct remote_rpc_control_unsigned_module u;
1609 u.domain = this->domain_id;
1610 u.enable = 1;
1611 int err = remote_session_control(DSPRPC_CONTROL_UNSIGNED_MODULE, (void *) &u, sizeof(u));
1612 if (err != AEE_SUCCESS) {
1613 GGML_LOG_ERROR("ggml-hex: failed to enable unsigned PD for session %d : error 0x%x\n", dev_id, err);
1614 throw std::runtime_error("ggml-hex: remote_session_control(unsign) failed (see log for details)");
1615 }
1616 }
1617
1618 // Open session
1619 int err = htp_iface_open(session_uri, &this->handle);
1620 if (err != AEE_SUCCESS) {
1621 GGML_LOG_ERROR("ggml-hex: failed to open session %d : error 0x%x\n", dev_id, err);
1622 throw std::runtime_error("ggml-hex: failed to open session (see log for details)");
1623 }
1624
1625 this->valid_handle = true;
1626
1627 GGML_LOG_INFO("ggml-hex: new session: %s : session-id %d domain-id %d uri %s handle 0x%lx\n", this->name.c_str(),
1628 this->session_id, this->domain_id, session_uri, (unsigned long) this->handle);
1629
1630 // Enable FastRPC QoS mode
1631 {
1632 struct remote_rpc_control_latency l;
1633 l.enable = 1;
1634
1635 int err = remote_handle64_control(this->handle, DSPRPC_CONTROL_LATENCY, (void *) &l, sizeof(l));
1636 if (err != 0) {
1637 GGML_LOG_WARN("ggml-hex: failed to enable fastrpc QOS mode: 0x%08x\n", (unsigned) err);
1638 }
1639 }
1640
1641 // Now let's setup the DSP queue
1642 err = dspqueue_create(this->domain_id,
1643 0, // Flags
1644 128 * 1024, // Request queue size (in bytes)
1645 64 * 1024, // Response queue size (in bytes)
1646 nullptr, // Read packet callback (we handle reads explicitly)
1647 nullptr, // Error callback (we handle errors during reads)
1648 (void *) this, // Callback context
1649 &queue);
1650 if (err != 0) {
1651 GGML_LOG_ERROR("ggml-hex: %s dspqueue_create failed: 0x%08x\n", this->name.c_str(), (unsigned) err);
1652 throw std::runtime_error("ggml-hex: failed to create dspqueue (see log for details)");
1653 }
1654
1655 this->valid_queue = true;
1656
1657 // Export queue for use on the DSP
1658 err = dspqueue_export(queue, &this->queue_id);
1659 if (err != 0) {
1660 GGML_LOG_ERROR("ggml-hex: dspqueue_export failed: 0x%08x\n", (unsigned) err);
1661 throw std::runtime_error("ggml-hex: dspqueue export failed (see log for details)");
1662 }
1663
1664 if (opt_etm) {
1665 err = htp_iface_enable_etm(this->handle);
1666 if (err != 0) {
1667 GGML_LOG_ERROR("ggml-hex: failed to enable ETM tracing: 0x%08x\n", (unsigned) err);
1668 }
1669 }
1670
1671 // Start the DSP-side service. We need to pass the queue ID to the
1672 // DSP in a FastRPC call; the DSP side will import the queue and start
1673 // listening for packets in a callback.
1674 err = htp_iface_start(this->handle, dev_id, this->queue_id, opt_nhvx);
1675 if (err != 0) {
1676 GGML_LOG_ERROR("ggml-hex: failed to start session: 0x%08x\n", (unsigned) err);
1677 throw std::runtime_error("ggml-hex: iface start failed (see log for details)");
1678 }
1679 this->valid_iface = true;
1680}
1681
1682void ggml_hexagon_session::release() noexcept(true) {
1683 GGML_LOG_INFO("ggml-hex: releasing session: %s\n", this->name.c_str());
1684
1685 int err;
1686
1687 // Stop the DSP-side service and close the queue
1688 if (this->valid_iface) {
1689 err = htp_iface_stop(this->handle);
1690 if (err != 0) {
1691 GGML_ABORT("ggml-hex: htp_iface_stop failed: 0x%08x\n", (unsigned) err);
1692 }
1693 }
1694
1695 if (opt_etm) {
1696 err = htp_iface_disable_etm(this->handle);
1697 if (err != 0) {
1698 GGML_LOG_ERROR("ggml-hex: warn : failed to disable ETM tracing: 0x%08x\n", (unsigned) err);
1699 }
1700 }
1701
1702 if (this->valid_queue) {
1703 err = dspqueue_close(queue);
1704 if (err != 0) {
1705 GGML_ABORT("ggml-hex: dspqueue_close failed: 0x%08x\n", (unsigned) err);
1706 }
1707 }
1708
1709 if (this->valid_handle) {
1710 htp_iface_close(this->handle);
1711 }
1712}
1713
1714ggml_hexagon_session::ggml_hexagon_session(int dev_id, ggml_backend_dev_t dev) noexcept(false) {
1715 buffer_type.device = dev;
1716 repack_buffer_type.device = dev;
1717
1718 try {
1719 allocate(dev_id);
1720
1721 buffer_type.iface = ggml_backend_hexagon_buffer_type_interface;
1722 buffer_type.context = new ggml_backend_hexagon_buffer_type_context(this->name, this);
1723
1724 repack_buffer_type.iface = ggml_backend_hexagon_repack_buffer_type_interface;
1725 repack_buffer_type.context = new ggml_backend_hexagon_buffer_type_context(this->name + "-REPACK", this);
1726 } catch (const std::exception & exc) {
1727 release();
1728 throw;
1729 }
1730}
1731
1732ggml_hexagon_session::~ggml_hexagon_session() noexcept(true) {
1733 release();
1734
1735 delete static_cast<ggml_backend_hexagon_buffer_type_context *>(buffer_type.context);
1736 delete static_cast<ggml_backend_hexagon_buffer_type_context *>(repack_buffer_type.context);
1737}
1738
1739// ** backend interface
1740
1741static bool ggml_backend_buffer_is_hexagon(const struct ggml_backend_buffer * b) {
1742 return b->buft->iface.get_alignment == ggml_backend_hexagon_buffer_type_get_alignment;
1743}
1744
1745static inline bool ggml_backend_buffer_is_hexagon_repack(const struct ggml_backend_buffer * b) {
1746 if (!opt_hostbuf) {
1747 return ggml_backend_buffer_is_hexagon(b);
1748 }
1749 return b->buft->iface.alloc_buffer == ggml_backend_hexagon_repack_buffer_type_alloc_buffer;
1750}
1751
1752static bool hex_supported_dims2(const struct ggml_tensor * x, const struct ggml_tensor * y) {
1753 if (x->ne[0] != y->ne[0]) {
1754 return false;
1755 }
1756 if (x->ne[1] != y->ne[1]) {
1757 return false;
1758 }
1759 if (x->ne[2] != y->ne[2]) {
1760 return false;
1761 }
1762 if (x->ne[3] != y->ne[3]) {
1763 return false;
1764 }
1765
1766 return true;
1767}
1768
1769static bool ggml_hexagon_supported_flash_attn_ext(const struct ggml_hexagon_session * sess, const struct ggml_tensor * op) {
1770 const struct ggml_tensor * src0 = op->src[0];
1771 const struct ggml_tensor * src1 = op->src[1];
1772 const struct ggml_tensor * src2 = op->src[2];
1773 const struct ggml_tensor * src3 = op->src[3];
1774 const struct ggml_tensor * src4 = op->src[4];
1775 const struct ggml_tensor * dst = op;
1776
1777 // Check for F16 support only as requested
1778 if ((src0->type != GGML_TYPE_F16 && src0->type != GGML_TYPE_F32) || src1->type != GGML_TYPE_F16 || src2->type != GGML_TYPE_F16) {
1779 return false;
1780 }
1781
1782 if (src3 && src3->type != GGML_TYPE_F16) { // mask
1783 return false;
1784 }
1785
1786 if (src4 && src4->type != GGML_TYPE_F32) { // sinks
1787 return false;
1788 }
1789
1790 // For now we support F32 or F16 output as htp backend often converts output on the fly if needed,
1791 // but the op implementation writes to F16 or F32.
1792 // Let's assume dst can be F32 or F16.
1793 if (dst->type != GGML_TYPE_F32 && dst->type != GGML_TYPE_F16) {
1794 return false;
1795 }
1796
1797 return opt_experimental;
1798}
1799
1800static bool hex_supported_src0_type(ggml_type t) {
1801 return t == GGML_TYPE_F32;
1802}
1803
1804static bool hex_supported_src1_type(ggml_type t) {
1805 return t == GGML_TYPE_F32;
1806}
1807
1808static bool hex_supported_src2_type(ggml_type t) {
1809 return t == GGML_TYPE_F32;
1810}
1811
1812static bool hex_supported_src1_type2(ggml_type t) {
1813 return t == GGML_TYPE_F16;
1814}
1815
1816static bool hex_supported_src1_type3(ggml_type t) {
1817 return t == GGML_TYPE_I32;
1818}
1819
1820static bool hex_supported_dst_type(ggml_type t) {
1821 return t == GGML_TYPE_F32;
1822}
1823
1824static bool hex_supported_dims(const struct ggml_tensor * x, const struct ggml_tensor * y) {
1825 // TODO: support broadcast for ne[2 and 3]
1826 if (x->ne[0] != y->ne[0]) {
1827 return false;
1828 }
1829 if (x->ne[2] != y->ne[2]) {
1830 return false;
1831 }
1832 if (x->ne[3] != y->ne[3]) {
1833 return false;
1834 }
1835 return true;
1836}
1837
1838static bool ggml_hexagon_supported_mul_mat(const struct ggml_hexagon_session * sess, const struct ggml_tensor * dst) {
1839 const struct ggml_tensor * src0 = dst->src[0];
1840 const struct ggml_tensor * src1 = dst->src[1];
1841
1842 if (dst->type != GGML_TYPE_F32) {
1843 return false;
1844 }
1845
1846 if (src1->type != GGML_TYPE_F32 && src1->type != GGML_TYPE_F16) {
1847 return false;
1848 }
1849
1850 switch (src0->type) {
1851 case GGML_TYPE_Q4_0:
1852 case GGML_TYPE_Q8_0:
1853 case GGML_TYPE_MXFP4:
1854 if (src0->ne[0] % 32) {
1855 return false;
1856 }
1857
1858 if (src0->ne[1] > 16 * 1024) {
1859 return false; // typically the lm-head which would be too large for VTCM
1860 }
1861
1862 if ((src1->ne[2] != 1 || src1->ne[3] != 1)) {
1863 return false;
1864 }
1865
1866 // src0 (weights) must be repacked
1867 if (src0->buffer && !ggml_backend_buffer_is_hexagon_repack(src0->buffer)) {
1868 return false;
1869 }
1870 break;
1871
1872 case GGML_TYPE_F16:
1873 if (src0->nb[1] < src0->nb[0]) {
1874 GGML_LOG_DEBUG("ggml_hexagon_supported_mul_mat: permuted F16 src0 not supported\n");
1875 return false;
1876 }
1877 break;
1878
1879 default:
1880 return false;
1881 }
1882
1883 return true;
1884}
1885
1886static bool ggml_hexagon_supported_mul_mat_id(const struct ggml_hexagon_session * sess, const struct ggml_tensor * op) {
1887 const struct ggml_tensor * src0 = op->src[0];
1888 const struct ggml_tensor * src1 = op->src[1];
1889 const struct ggml_tensor * src2 = op->src[2];
1890 const struct ggml_tensor * dst = op;
1891
1892 if (src1->type != GGML_TYPE_F32 || dst->type != GGML_TYPE_F32 || src2->type != GGML_TYPE_I32) {
1893 return false;
1894 }
1895
1896 switch (src0->type) {
1897 case GGML_TYPE_Q4_0:
1898 case GGML_TYPE_Q8_0:
1899 case GGML_TYPE_MXFP4:
1900 if ((src0->ne[0] % 32)) {
1901 return false;
1902 }
1903
1904 // src0 (weights) must be repacked
1905 if (src0->buffer && !ggml_backend_buffer_is_hexagon_repack(src0->buffer)) {
1906 return false;
1907 }
1908 break;
1909
1910 default:
1911 return false;
1912 }
1913
1914 return true;
1915}
1916
1917static bool ggml_hexagon_supported_binary(const struct ggml_hexagon_session * sess, const struct ggml_tensor * op) {
1918 const struct ggml_tensor * src0 = op->src[0];
1919 const struct ggml_tensor * src1 = op->src[1];
1920 const struct ggml_tensor * dst = op;
1921
1922 if (!hex_supported_src0_type(src0->type)) {
1923 return false;
1924 }
1925 if (!hex_supported_src1_type(src1->type)) {
1926 return false;
1927 }
1928 if (!hex_supported_dst_type(dst->type)) {
1929 return false;
1930 }
1931 if (!hex_supported_dims2(src0, dst)) {
1932 return false;
1933 }
1934 if (!ggml_can_repeat(src1, src0)) {
1935 return false;
1936 }
1937
1938 return true;
1939}
1940
1941static bool ggml_hexagon_supported_add_id(const struct ggml_hexagon_session * sess, const struct ggml_tensor * op) {
1942 const struct ggml_tensor * src0 = op->src[0];
1943 const struct ggml_tensor * src1 = op->src[1];
1944 const struct ggml_tensor * dst = op;
1945
1946 if (!hex_supported_src0_type(src0->type)) {
1947 return false;
1948 }
1949 if (!hex_supported_src1_type(src1->type)) {
1950 return false;
1951 }
1952 if (!hex_supported_dst_type(dst->type)) {
1953 return false;
1954 }
1955 if (!hex_supported_dims2(src0, dst)) {
1956 return false;
1957 }
1958
1959 // REVISIT: add support for non-contigiuos tensors
1960 if (!ggml_is_contiguous(src0) || !ggml_is_contiguous(src1) || !ggml_is_contiguous(dst)) {
1961 return false;
1962 }
1963
1964 return true;
1965}
1966
1967static bool ggml_hexagon_supported_unary(const struct ggml_hexagon_session * sess, const struct ggml_tensor * op) {
1968 const struct ggml_tensor * src0 = op->src[0];
1969 const struct ggml_tensor * dst = op;
1970
1971 if (!hex_supported_src0_type(src0->type)) {
1972 return false;
1973 }
1974 if (!hex_supported_dst_type(dst->type)) {
1975 return false;
1976 }
1977 if (!hex_supported_dims2(src0, dst)) {
1978 return false;
1979 }
1980
1981 // TODO: add support for non-contigiuos tensors
1982 if (!ggml_is_contiguous(src0) || !ggml_is_contiguous(dst)) {
1983 return false;
1984 }
1985
1986 return true;
1987}
1988
1989static bool ggml_hexagon_supported_sum_rows(const struct ggml_hexagon_session * sess, const struct ggml_tensor * op) {
1990 const struct ggml_tensor * src0 = op->src[0];
1991 const struct ggml_tensor * dst = op;
1992
1993 if (!hex_supported_src0_type(src0->type)) {
1994 return false;
1995 }
1996 if (!hex_supported_dst_type(dst->type)) {
1997 return false;
1998 }
1999
2000 // TODO: add support for non-contigiuos tensors
2001 if (!ggml_is_contiguous(src0) || !ggml_is_contiguous(dst)) {
2002 return false;
2003 }
2004
2005 return true;
2006}
2007
2008static bool ggml_hexagon_supported_activations(const struct ggml_hexagon_session * sess,
2009 const struct ggml_tensor * op) {
2010 const struct ggml_tensor * src0 = op->src[0];
2011 const struct ggml_tensor * src1 = op->src[1];
2012 const struct ggml_tensor * dst = op;
2013
2014 if (!hex_supported_src0_type(src0->type)) {
2015 return false;
2016 }
2017 if (!hex_supported_dst_type(dst->type)) {
2018 return false;
2019 }
2020
2021 if (!ggml_is_contiguous(src0) || !ggml_is_contiguous(dst)) {
2022 return false;
2023 }
2024
2025 if (src1) {
2026 if (!hex_supported_src1_type(src1->type)) {
2027 return false;
2028 }
2029 if (!hex_supported_dims2(src0, src1)) {
2030 return false;
2031 }
2032 if (!ggml_is_contiguous(src1)) {
2033 return false;
2034 }
2035 }
2036
2037 return true;
2038}
2039
2040static bool ggml_hexagon_supported_softmax(const struct ggml_hexagon_session * sess, const struct ggml_tensor * op) {
2041 const struct ggml_tensor * src0 = op->src[0];
2042 const struct ggml_tensor * src1 = op->src[1];
2043 const struct ggml_tensor * src2 = op->src[2];
2044 const struct ggml_tensor * dst = op;
2045
2046 if (src2) {
2047 return false; // FIXME: add support for sinks
2048 }
2049
2050 if (!hex_supported_src0_type(src0->type)) {
2051 return false;
2052 }
2053 if (!hex_supported_dst_type(dst->type)) {
2054 return false;
2055 }
2056
2057 if (src1) {
2058 if (!hex_supported_src1_type(src1->type) && !hex_supported_src1_type2(src1->type)) {
2059 return false;
2060 }
2061 if (src0->ne[0] != src1->ne[0]) {
2062 return false;
2063 }
2064 if (src1->ne[1] < src0->ne[1]) {
2065 return false;
2066 }
2067 if (src0->ne[2] % src1->ne[2] != 0) {
2068 return false;
2069 }
2070 if (src0->ne[3] % src1->ne[3] != 0) {
2071 return false;
2072 }
2073 }
2074
2075 if (src1) {
2076 if (!ggml_is_contiguous(src0) || !ggml_is_contiguous(src1) || !ggml_is_contiguous(dst)) {
2077 return false;
2078 }
2079 } else {
2080 if (!ggml_is_contiguous(src0) || !ggml_is_contiguous(dst)) {
2081 return false;
2082 }
2083 }
2084
2085 return true;
2086}
2087
2088static bool ggml_hexagon_supported_set_rows(const struct ggml_hexagon_session * sess, const struct ggml_tensor * op) {
2089 const struct ggml_tensor * src0 = op->src[0]; // values
2090 const struct ggml_tensor * src1 = op->src[1]; // indices
2091 const struct ggml_tensor * dst = op;
2092
2093 if (src0->type != GGML_TYPE_F32) {
2094 return false;
2095 }
2096
2097 if (src1->type != GGML_TYPE_I32 && src1->type != GGML_TYPE_I64) {
2098 return false;
2099 }
2100
2101 if (dst->type != GGML_TYPE_F16) {
2102 return false;
2103 }
2104
2105 return true;
2106}
2107
2108static bool ggml_hexagon_supported_get_rows(const struct ggml_hexagon_session * sess, const struct ggml_tensor * op) {
2109 const struct ggml_tensor * src0 = op->src[0]; // values
2110 const struct ggml_tensor * src1 = op->src[1]; // indices
2111 const struct ggml_tensor * dst = op;
2112
2113 if (src0->type != GGML_TYPE_F32) {
2114 return false;
2115 }
2116
2117 if (src1->type != GGML_TYPE_I32 && src1->type != GGML_TYPE_I64) {
2118 return false;
2119 }
2120
2121 if (dst->type != GGML_TYPE_F32) {
2122 return false;
2123 }
2124
2125 return true;
2126}
2127
2128static bool ggml_hexagon_supported_argsort(const struct ggml_hexagon_session * sess, const struct ggml_tensor * op) {
2129 const struct ggml_tensor * src0 = op->src[0]; // values
2130 const struct ggml_tensor * dst = op; // indices
2131
2132 if (src0->type != GGML_TYPE_F32) {
2133 return false;
2134 }
2135
2136 if (dst->type != GGML_TYPE_I32) {
2137 return false;
2138 }
2139
2140 if (src0->ne[0] > (16*1024)) {
2141 // reject tensors with huge rows for now
2142 return false;
2143 }
2144
2145 return true;
2146}
2147
2148static bool ggml_hexagon_supported_rope(const struct ggml_hexagon_session * sess, const struct ggml_tensor * op) {
2149 const int32_t * op_params = &op->op_params[0];
2150
2151 int mode = op_params[2];
2152
2153 if ((mode & GGML_ROPE_TYPE_MROPE) || (mode & GGML_ROPE_TYPE_VISION)) {
2154 return false;
2155 }
2156 if (mode & 1) {
2157 return false;
2158 }
2159
2160 const struct ggml_tensor * src0 = op->src[0];
2161 const struct ggml_tensor * src1 = op->src[1];
2162 const struct ggml_tensor * src2 = op->src[2];
2163 const struct ggml_tensor * dst = op;
2164
2165 if (!hex_supported_src0_type(src0->type)) {
2166 return false; // FIXME: add support for GGML_TYPE_F16 for src0
2167 }
2168 if (!hex_supported_dst_type(dst->type)) {
2169 return false;
2170 }
2171 if (!hex_supported_src1_type3(src1->type)) {
2172 return false;
2173 }
2174 if (src2) {
2175 if (!hex_supported_src2_type(src2->type)) {
2176 return false;
2177 }
2178 int n_dims = op_params[1];
2179 if (src2->ne[0] < (n_dims / 2)) {
2180 return false;
2181 }
2182 }
2183
2184 if (src2) {
2185 if (!ggml_is_contiguous(src0) || !ggml_is_contiguous(src1) || !ggml_is_contiguous(src2) ||
2186 !ggml_is_contiguous(dst)) {
2187 return false;
2188 }
2189 } else {
2190 if (!ggml_is_contiguous(src0) || !ggml_is_contiguous(src1) || !ggml_is_contiguous(dst)) {
2191 return false;
2192 }
2193 }
2194
2195 return true;
2196}
2197
2198enum dspqbuf_type {
2199 DSPQBUF_TYPE_DSP_WRITE_CPU_READ = 0,
2200 DSPQBUF_TYPE_CPU_WRITE_DSP_READ,
2201 DSPQBUF_TYPE_CONSTANT,
2202};
2203
2204static void dspqbuf_dump(dspqueue_buffer * d, const struct ggml_tensor * t, dspqbuf_type type) {
2205 if (opt_verbose < 2) return;
2206
2207 auto buf = static_cast<ggml_backend_hexagon_buffer_context *>(t->buffer->context);
2208 auto sess = buf->sess;
2209
2210 GGML_LOG_DEBUG("ggml-hex: %s dspqbuf : %s base-addr %p base-size %zu data %p offset %u size %u\n", sess->name.c_str(),
2211 t->name, (void *) buf->base, buf->size, (void *) d->ptr, (unsigned int) d->offset,
2212 (unsigned int) d->size);
2213}
2214
2215// Init hexagon tensor from GGML tensor and Hexagon buffer
2216static void htp_req_tensor_init(htp_tensor * h, const ggml_tensor * t) {
2217 h->data = 0; // updated by the receiver
2218 h->type = t->type;
2219 h->ne[0] = t->ne[0];
2220 h->ne[1] = t->ne[1];
2221 h->ne[2] = t->ne[2];
2222 h->ne[3] = t->ne[3];
2223 h->nb[0] = t->nb[0];
2224 h->nb[1] = t->nb[1];
2225 h->nb[2] = t->nb[2];
2226 h->nb[3] = t->nb[3];
2227}
2228
2229static size_t htp_req_buff_init(htp_tensor *h, dspqueue_buffer * d, const ggml_tensor * t, dspqbuf_type type) {
2230 if (!t) {
2231 return 0;
2232 }
2233
2234 auto buf = static_cast<ggml_backend_hexagon_buffer_context *>(t->buffer->context);
2235
2236 memset(d, 0, sizeof(*d));
2237 d->fd = buf->fd;
2238 d->ptr = t->data;
2239 d->offset = (uint8_t *) t->data - buf->base;
2240 d->size = ggml_nbytes(t);
2241
2242 if (!d->size) {
2243 // Some requests contain srcs where ggml_nbytes() returns 0 but the rest of the op is non-empty
2244 d->size = 64;
2245 }
2246
2247 switch (type) {
2248 case DSPQBUF_TYPE_DSP_WRITE_CPU_READ:
2249 // Flush CPU
2250 d->flags = DSPQUEUE_BUFFER_FLAG_FLUSH_SENDER;
2251 break;
2252 case DSPQBUF_TYPE_CPU_WRITE_DSP_READ:
2253 // Flush CPU, Invalidate DSP
2254 d->flags = DSPQUEUE_BUFFER_FLAG_FLUSH_SENDER | DSPQUEUE_BUFFER_FLAG_INVALIDATE_RECIPIENT;
2255 break;
2256 default:
2257 // Constant buffer, no cache maintenance
2258 d->flags = 0;
2259 break;
2260 }
2261
2262 htp_req_tensor_init(h, t);
2263
2264 dspqbuf_dump(d, t, type);
2265
2266 return 1;
2267}
2268
2269typedef size_t (*htp_req_init_func_t)(htp_general_req * req, dspqueue_buffer * bufs, const ggml_tensor * op);
2270
2271template <htp_req_init_func_t _init_req_func>
2272static inline void ggml_hexagon_dispatch_op(ggml_hexagon_session *sess, const struct ggml_tensor * op, uint32_t flags) {
2273 uint64_t t = ggml_time_us();
2274
2275 // Construct HTP request
2276 htp_general_req req;
2277 memset(&req, 0, sizeof(req));
2278
2279 req.flags = flags;
2280 if (!(opt_opmask & HTP_OPMASK_QUANTIZE)) {
2281 req.flags |= HTP_OPFLAGS_SKIP_QUANTIZE;
2282 }
2283 if (!(opt_opmask & HTP_OPMASK_COMPUTE)) {
2284 req.flags |= HTP_OPFLAGS_SKIP_COMPUTE;
2285 }
2286
2287 ggml_hexagon_dump_op_exec(sess->name, op, req.flags);
2288
2289 if ((opt_opmask & HTP_OPMASK_QUEUE)) {
2290 dspqueue_buffer bufs[HTP_MAX_PACKET_BUFFERS];
2291 size_t n_bufs = _init_req_func(&req, bufs, op);
2292 sess->enqueue(req, bufs, n_bufs, opt_opsync);
2293 }
2294
2295 t = ggml_time_us() - t;
2296
2297 ggml_hexagon_dump_op_prof(sess->name, op, sess->prof_usecs, sess->prof_cycles, sess->prof_pkts, t);
2298}
2299
2300template <bool _is_src0_constant>
2301static inline size_t init_binary_req(htp_general_req * req, dspqueue_buffer * bufs, const ggml_tensor * t) {
2302 switch (t->op) {
2303 case GGML_OP_MUL_MAT:
2304 req->op = HTP_OP_MUL_MAT;
2305 break;
2306 case GGML_OP_MUL:
2307 req->op = HTP_OP_MUL;
2308 break;
2309 case GGML_OP_ADD:
2310 req->op = HTP_OP_ADD;
2311 break;
2312 case GGML_OP_SUB:
2313 req->op = HTP_OP_SUB;
2314 break;
2315 case GGML_OP_DIV:
2316 req->op = HTP_OP_DIV;
2317 break;
2318 default:
2319 GGML_ABORT("ggml-hex: binary : unsupported op: %d\n", t->op);
2320 break;
2321 }
2322
2323 // src0: Weights (mulmat) or First Operand (binary op).
2324 // If constant (e.g. weights), no cache management is needed.
2325 // src1: Input Activations (mulmat) or Second Operand (binary op).
2326
2327 size_t n_bufs = 0;
2328 n_bufs += htp_req_buff_init(&req->src0, &bufs[n_bufs], t->src[0], _is_src0_constant ? DSPQBUF_TYPE_CONSTANT : DSPQBUF_TYPE_CPU_WRITE_DSP_READ);
2329 n_bufs += htp_req_buff_init(&req->src1, &bufs[n_bufs], t->src[1], DSPQBUF_TYPE_CPU_WRITE_DSP_READ);
2330 n_bufs += htp_req_buff_init(&req->dst, &bufs[n_bufs], t, DSPQBUF_TYPE_DSP_WRITE_CPU_READ);
2331
2332 return n_bufs;
2333}
2334
2335static inline size_t init_cpy_req(htp_general_req * req, dspqueue_buffer * bufs, const ggml_tensor * t) {
2336 req->op = HTP_OP_CPY;
2337
2338 size_t n_bufs = 0;
2339 n_bufs += htp_req_buff_init(&req->src0, &bufs[n_bufs], t->src[0], DSPQBUF_TYPE_CPU_WRITE_DSP_READ);
2340 n_bufs += htp_req_buff_init(&req->dst, &bufs[n_bufs], t, DSPQBUF_TYPE_DSP_WRITE_CPU_READ);
2341
2342 return n_bufs;
2343}
2344
2345static inline size_t init_get_rows_req(htp_general_req * req, dspqueue_buffer * bufs, const ggml_tensor * t) {
2346 req->op = HTP_OP_GET_ROWS;
2347
2348 size_t n_bufs = 0;
2349 n_bufs += htp_req_buff_init(&req->src0, &bufs[n_bufs], t->src[0], DSPQBUF_TYPE_CPU_WRITE_DSP_READ);
2350 n_bufs += htp_req_buff_init(&req->src1, &bufs[n_bufs], t->src[1], DSPQBUF_TYPE_CPU_WRITE_DSP_READ);
2351 n_bufs += htp_req_buff_init(&req->dst, &bufs[n_bufs], t, DSPQBUF_TYPE_DSP_WRITE_CPU_READ);
2352
2353 return n_bufs;
2354}
2355
2356static inline size_t init_argsort_req(htp_general_req * req, dspqueue_buffer * bufs, const ggml_tensor * t) {
2357 req->op = HTP_OP_ARGSORT;
2358 memcpy(&req->op_params, &t->op_params, sizeof(t->op_params));
2359
2360 size_t n_bufs = 0;
2361 n_bufs += htp_req_buff_init(&req->src0, &bufs[n_bufs], t->src[0], DSPQBUF_TYPE_CPU_WRITE_DSP_READ);
2362 n_bufs += htp_req_buff_init(&req->dst, &bufs[n_bufs], t, DSPQBUF_TYPE_DSP_WRITE_CPU_READ);
2363
2364 return n_bufs;
2365}
2366
2367template <bool _is_src0_constant>
2368static inline size_t init_binary_id_req(htp_general_req * req, dspqueue_buffer * bufs, const ggml_tensor * t) {
2369 switch (t->op) {
2370 case GGML_OP_MUL_MAT_ID:
2371 req->op = HTP_OP_MUL_MAT_ID;
2372 break;
2373 case GGML_OP_ADD_ID:
2374 req->op = HTP_OP_ADD_ID;
2375 break;
2376 default:
2377 GGML_ABORT("ggml-hex: unsupported op: %d\n", t->op);
2378 }
2379
2380 // src0: Weights (mulmat) or Input Activations (other op).
2381 // If constant, no cache management is needed.
2382 // src1: Input Activations (mulmat) or Second Operand (binary op).
2383 // src2: Expert IDs (mulmat) or Activated Experts (other op).
2384
2385 size_t n_bufs = 0;
2386 n_bufs += htp_req_buff_init(&req->src0, &bufs[n_bufs], t->src[0], _is_src0_constant ? DSPQBUF_TYPE_CONSTANT : DSPQBUF_TYPE_CPU_WRITE_DSP_READ);
2387 n_bufs += htp_req_buff_init(&req->src1, &bufs[n_bufs], t->src[1], DSPQBUF_TYPE_CPU_WRITE_DSP_READ);
2388 n_bufs += htp_req_buff_init(&req->src2, &bufs[n_bufs], t->src[2], DSPQBUF_TYPE_CPU_WRITE_DSP_READ);
2389 n_bufs += htp_req_buff_init(&req->dst, &bufs[n_bufs], t, DSPQBUF_TYPE_DSP_WRITE_CPU_READ);
2390
2391 return n_bufs;
2392}
2393
2394static inline size_t init_set_rows_req(htp_general_req * req, dspqueue_buffer * bufs, const ggml_tensor * t) {
2395 req->op = HTP_OP_SET_ROWS;
2396
2397 size_t n_bufs = 0;
2398 n_bufs += htp_req_buff_init(&req->src0, &bufs[n_bufs], t->src[0], DSPQBUF_TYPE_CPU_WRITE_DSP_READ);
2399 n_bufs += htp_req_buff_init(&req->src1, &bufs[n_bufs], t->src[1], DSPQBUF_TYPE_CPU_WRITE_DSP_READ);
2400 n_bufs += htp_req_buff_init(&req->dst, &bufs[n_bufs], t, DSPQBUF_TYPE_DSP_WRITE_CPU_READ);
2401
2402 return n_bufs;
2403}
2404
2405static inline size_t init_unary_req(htp_general_req * req, dspqueue_buffer * bufs, const ggml_tensor * t) {
2406 memcpy(&req->op_params, &t->op_params, sizeof(t->op_params));
2407
2408 bool supported = false;
2409
2410 switch (t->op) {
2411 case GGML_OP_RMS_NORM:
2412 req->op = HTP_OP_RMS_NORM;
2413 supported = true;
2414 break;
2415
2416 case GGML_OP_SCALE:
2417 req->op = HTP_OP_SCALE;
2418 supported = true;
2419 break;
2420
2421 case GGML_OP_SQR:
2422 req->op = HTP_OP_SQR;
2423 supported = true;
2424 break;
2425
2426 case GGML_OP_SQRT:
2427 req->op = HTP_OP_SQRT;
2428 supported = true;
2429 break;
2430
2431 case GGML_OP_UNARY:
2432 if (ggml_get_unary_op(t) == GGML_UNARY_OP_SILU) {
2433 req->op = HTP_OP_UNARY_SILU;
2434 supported = true;
2435 } else if (ggml_get_unary_op(t) == GGML_UNARY_OP_GELU) {
2436 req->op = HTP_OP_UNARY_GELU;
2437 supported = true;
2438 }
2439 break;
2440
2441 case GGML_OP_GLU:
2442 if (ggml_get_glu_op(t) == GGML_GLU_OP_SWIGLU) {
2443 req->op = HTP_OP_GLU_SWIGLU;
2444 supported = true;
2445 } else if (ggml_get_glu_op(t) == GGML_GLU_OP_SWIGLU_OAI) {
2446 req->op = HTP_OP_GLU_SWIGLU_OAI;
2447 supported = true;
2448 } else if (ggml_get_glu_op(t) == GGML_GLU_OP_GEGLU) {
2449 req->op = HTP_OP_GLU_GEGLU;
2450 supported = true;
2451 }
2452 break;
2453
2454 case GGML_OP_SOFT_MAX:
2455 req->op = HTP_OP_SOFTMAX;
2456 supported = true;
2457 break;
2458
2459 default:
2460 break;
2461 }
2462
2463 if (!supported) {
2464 GGML_ABORT("ggml-hex: unary : unsupported op: %d\n", t->op);
2465 }
2466
2467 size_t n_bufs = 0;
2468 n_bufs += htp_req_buff_init(&req->src0, &bufs[n_bufs], t->src[0], DSPQBUF_TYPE_CPU_WRITE_DSP_READ);
2469 n_bufs += htp_req_buff_init(&req->src1, &bufs[n_bufs], t->src[1], DSPQBUF_TYPE_CPU_WRITE_DSP_READ);
2470 n_bufs += htp_req_buff_init(&req->dst, &bufs[n_bufs], t, DSPQBUF_TYPE_DSP_WRITE_CPU_READ);
2471
2472 return n_bufs;
2473}
2474
2475static inline size_t init_sum_rows_req(htp_general_req * req, dspqueue_buffer * bufs, const ggml_tensor * t) {
2476 memcpy(&req->op_params, &t->op_params, sizeof(t->op_params));
2477 req->op = HTP_OP_SUM_ROWS;
2478
2479 size_t n_bufs = 0;
2480 n_bufs += htp_req_buff_init(&req->src0, &bufs[n_bufs], t->src[0], DSPQBUF_TYPE_CPU_WRITE_DSP_READ);
2481 n_bufs += htp_req_buff_init(&req->dst, &bufs[n_bufs], t, DSPQBUF_TYPE_DSP_WRITE_CPU_READ);
2482
2483 return n_bufs;
2484}
2485
2486static inline size_t init_rope_req(htp_general_req * req, dspqueue_buffer * bufs, const ggml_tensor * t) {
2487 memcpy(&req->op_params, &t->op_params, sizeof(t->op_params));
2488 req->op = HTP_OP_ROPE;
2489
2490 size_t n_bufs = 0;
2491 n_bufs += htp_req_buff_init(&req->src0, &bufs[n_bufs], t->src[0], DSPQBUF_TYPE_CPU_WRITE_DSP_READ);
2492 n_bufs += htp_req_buff_init(&req->src1, &bufs[n_bufs], t->src[1], DSPQBUF_TYPE_CPU_WRITE_DSP_READ);
2493 n_bufs += htp_req_buff_init(&req->src2, &bufs[n_bufs], t->src[2], DSPQBUF_TYPE_CPU_WRITE_DSP_READ);
2494 n_bufs += htp_req_buff_init(&req->dst, &bufs[n_bufs], t, DSPQBUF_TYPE_DSP_WRITE_CPU_READ);
2495
2496 return n_bufs;
2497}
2498
2499static inline size_t init_flash_attn_ext_req(htp_general_req * req, dspqueue_buffer * bufs, const ggml_tensor * t) {
2500 memcpy(&req->op_params, &t->op_params, sizeof(t->op_params));
2501 req->op = HTP_OP_FLASH_ATTN_EXT;
2502
2503 size_t n_bufs = 0;
2504 n_bufs += htp_req_buff_init(&req->src0, &bufs[n_bufs], t->src[0], DSPQBUF_TYPE_CPU_WRITE_DSP_READ);
2505 n_bufs += htp_req_buff_init(&req->src1, &bufs[n_bufs], t->src[1], DSPQBUF_TYPE_CPU_WRITE_DSP_READ);
2506 n_bufs += htp_req_buff_init(&req->src2, &bufs[n_bufs], t->src[2], DSPQBUF_TYPE_CPU_WRITE_DSP_READ);
2507 n_bufs += htp_req_buff_init(&req->src3, &bufs[n_bufs], t->src[3], DSPQBUF_TYPE_CPU_WRITE_DSP_READ);
2508 n_bufs += htp_req_buff_init(&req->src4, &bufs[n_bufs], t->src[4], DSPQBUF_TYPE_CPU_WRITE_DSP_READ);
2509 n_bufs += htp_req_buff_init(&req->dst, &bufs[n_bufs], t, DSPQBUF_TYPE_DSP_WRITE_CPU_READ);
2510
2511 return n_bufs;
2512}
2513
2514static const char * ggml_backend_hexagon_name(ggml_backend_t backend) {
2515 auto sess = static_cast<ggml_hexagon_session *>(backend->context);
2516 return sess->name.c_str();
2517}
2518
2519static void ggml_backend_hexagon_free(ggml_backend_t backend) {
2520 // we just need to delete the backend here
2521 // the sessions are allocated & freed as part of the registry
2522 delete backend;
2523}
2524
2525static inline bool op_reuse_src1(const ggml_tensor * op1, const ggml_tensor * op0) {
2526 return (op0 && op0->src[1] == op1->src[1] && ggml_is_quantized(op0->src[0]->type));
2527}
2528
2529static inline bool is_compute_op(ggml_tensor *node)
2530{
2531 return !ggml_op_is_empty(node->op) && !ggml_is_empty(node) && (node->flags & GGML_TENSOR_FLAG_COMPUTE);
2532}
2533
2534// scan the graph and figure out last compute op index
2535static inline int last_compute_op(ggml_cgraph * graph) {
2536 int last = 0;
2537 for (int i = 0; i < graph->n_nodes; ++i) {
2538 if (is_compute_op(graph->nodes[i])) {
2539 last = i;
2540 }
2541 }
2542
2543 return last;
2544}
2545
2546static ggml_status ggml_backend_hexagon_graph_compute(ggml_backend_t backend, ggml_cgraph * graph) {
2547 auto sess = static_cast<ggml_hexagon_session *>(backend->context);
2548
2549 HEX_VERBOSE("ggml-hex: %s graph-compute n_nodes %d\n", sess->name.c_str(), graph->n_nodes);
2550
2551 const int last = last_compute_op(graph);
2552
2553 const struct ggml_tensor * prev_op = nullptr; // prev executed op
2554
2555 for (int i = 0; i < graph->n_nodes; ++i) {
2556 ggml_tensor * node = graph->nodes[i];
2557
2558 if (!is_compute_op(node)) {
2559 continue;
2560 }
2561
2562 uint32_t flags = 0;
2563
2564 // skip quantizer if src1 is reused
2565 if (op_reuse_src1(node, prev_op)) {
2566 flags |= HTP_OPFLAGS_SKIP_QUANTIZE;
2567 }
2568
2569 prev_op = node;
2570
2571 // ask for early notification for the last Op
2572 if (i == last) {
2573 flags |= HTP_OPFLAGS_EARLY_WAKEUP;
2574 }
2575
2576 switch (node->op) {
2577 case GGML_OP_MUL_MAT:
2578 if (ggml_is_quantized(node->src[0]->type)) {
2579 ggml_hexagon_dispatch_op<init_binary_req<true>>(sess, node, flags);
2580 } else {
2581 ggml_hexagon_dispatch_op<init_binary_req<false>>(sess, node, flags);
2582 }
2583 break;
2584 case GGML_OP_MUL_MAT_ID:
2585 if (ggml_is_quantized(node->src[0]->type)) {
2586 ggml_hexagon_dispatch_op<init_binary_id_req<true>>(sess, node, flags);
2587 } else {
2588 ggml_hexagon_dispatch_op<init_binary_id_req<false>>(sess, node, flags);
2589 }
2590 break;
2591 case GGML_OP_MUL:
2592 case GGML_OP_ADD:
2593 case GGML_OP_SUB:
2594 case GGML_OP_DIV:
2595 ggml_hexagon_dispatch_op<init_binary_req<false>>(sess, node, flags);
2596 break;
2597 case GGML_OP_ADD_ID:
2598 ggml_hexagon_dispatch_op<init_binary_id_req<false>>(sess, node, flags);
2599 break;
2600 case GGML_OP_RMS_NORM:
2601 case GGML_OP_SCALE:
2602 ggml_hexagon_dispatch_op<init_unary_req>(sess, node, flags);
2603 break;
2604 case GGML_OP_SQR:
2605 case GGML_OP_SQRT:
2606 ggml_hexagon_dispatch_op<init_unary_req>(sess, node, flags);
2607 break;
2608 case GGML_OP_SUM_ROWS:
2609 ggml_hexagon_dispatch_op<init_sum_rows_req>(sess, node, flags);
2610 break;
2611 case GGML_OP_UNARY:
2612 if ((ggml_get_unary_op(node) == GGML_UNARY_OP_SILU) ||
2613 (ggml_get_unary_op(node) == GGML_UNARY_OP_GELU)) {
2614 ggml_hexagon_dispatch_op<init_unary_req>(sess, node, flags);
2615 }
2616 break;
2617 case GGML_OP_GLU:
2618 if ((ggml_get_glu_op(node) == GGML_GLU_OP_SWIGLU) ||
2619 (ggml_get_glu_op(node) == GGML_GLU_OP_SWIGLU_OAI) ||
2620 (ggml_get_glu_op(node) == GGML_GLU_OP_GEGLU)) {
2621 ggml_hexagon_dispatch_op<init_unary_req>(sess, node, flags);
2622 }
2623 break;
2624 case GGML_OP_SOFT_MAX:
2625 ggml_hexagon_dispatch_op<init_unary_req>(sess, node, flags);
2626 break;
2627
2628 case GGML_OP_ROPE:
2629 ggml_hexagon_dispatch_op<init_rope_req>(sess, node, flags);
2630 break;
2631
2632 case GGML_OP_FLASH_ATTN_EXT:
2633 ggml_hexagon_dispatch_op<init_flash_attn_ext_req>(sess, node, flags);
2634 break;
2635
2636 case GGML_OP_SET_ROWS:
2637 ggml_hexagon_dispatch_op<init_set_rows_req>(sess, node, flags);
2638 break;
2639
2640 case GGML_OP_GET_ROWS:
2641 ggml_hexagon_dispatch_op<init_get_rows_req>(sess, node, flags);
2642 break;
2643
2644 case GGML_OP_CPY:
2645 ggml_hexagon_dispatch_op<init_cpy_req>(sess, node, flags);
2646 break;
2647
2648 case GGML_OP_ARGSORT:
2649 ggml_hexagon_dispatch_op<init_argsort_req>(sess, node, flags);
2650 break;
2651
2652 default:
2653 GGML_ABORT("\nggml-hex: graph-compute %s is not supported\n", ggml_op_desc(node));
2654 }
2655 }
2656
2657 // Wait until all pending ops complete
2658 sess->flush();
2659
2660 return GGML_STATUS_SUCCESS;
2661}
2662
2663static void ggml_backend_hexagon_synchronize(ggml_backend_t backend) {
2664 auto sess = static_cast<ggml_hexagon_session *>(backend->context);
2665
2666 HEX_VERBOSE("ggml-hex: %s synchronize\n", sess->name.c_str());
2667
2668 // Wait until all pending ops complete
2669 sess->flush();
2670}
2671
2672struct node_info {
2673 ggml_tensor * node;
2674
2675 std::vector<ggml_tensor *> fused;
2676
2677 ggml_op op() const {
2678 return node->op;
2679 }
2680
2681 const ggml_tensor * dst() const {
2682 return fused.empty() ? node : fused.back();
2683 }
2684
2685 const ggml_tensor * src0() const {
2686 return node->src[0];
2687 }
2688
2689 const ggml_tensor * src1() const {
2690 return node->src[1];
2691 }
2692
2693 bool is_empty() const {
2694 return ggml_op_is_empty(node->op);
2695 }
2696
2697 void add_fused(ggml_tensor * t) {
2698 fused.push_back(t);
2699 }
2700
2701 bool stackable() const {
2702 switch (this->op()) {
2703 case GGML_OP_MUL_MAT:
2704 case GGML_OP_MUL_MAT_ID:
2705 return ggml_is_quantized(this->src0()->type);
2706 default:
2707 return false;
2708 }
2709 }
2710
2711 bool same_input(const node_info& n) const {
2712 return n.src1() == this->src1();
2713 }
2714};
2715
2716static std::vector<int> ggml_hexagon_graph_optimize_reorder(const std::vector<node_info> & nodes) {
2717 const int n = nodes.size();
2718
2719 std::vector<int> res;
2720 res.reserve(n);
2721
2722 std::vector<bool> used(n, false);
2723
2724 // The main goal here is to stack the MUL_MAT ops with the same src1 input.
2725 // This allows use to reuse dynamically quantized src1 in VTCM.
2726
2727 // TODO: the current version might do incorrect reodering in cases where quantized src0
2728 // input is an output of another Op.
2729
2730 for (int i0 = 0; i0 < n; i0++) {
2731 if (used[i0]) {
2732 continue;
2733 }
2734
2735 res.push_back(i0);
2736
2737 const auto & node0 = nodes[i0];
2738
2739 if (!node0.stackable()) {
2740 continue;
2741 }
2742
2743 // that many nodes forward to search for stackable nodes that can reuse VTCM
2744 constexpr int N_FORWARD = 16;
2745
2746 for (int i1 = i0 + 1; i1 < i0 + N_FORWARD && i1 < n; i1++) {
2747 if (used[i1]) {
2748 continue;
2749 }
2750
2751 const auto & node1 = nodes[i1];
2752
2753 if (node1.stackable() && node1.same_input(node0)) {
2754 res.push_back(i1);
2755 used[i1] = true;
2756 }
2757 }
2758 }
2759
2760 return res;
2761}
2762
2763static void ggml_backend_hexagon_graph_optimize(ggml_backend_t backend, ggml_cgraph * gf) {
2764 const int n = gf->n_nodes;
2765
2766 constexpr int MAX_FUSE = 16;
2767
2768 enum ggml_op ops[MAX_FUSE];
2769
2770 std::vector<node_info> nodes;
2771 nodes.reserve(gf->n_nodes);
2772
2773 // fuse nodes:
2774 // we don't want to make reorders that break fusing, so we first pack all fusable tensors
2775 // and perform the reorder over the fused nodes. after the reorder is done, we unfuse
2776 for (int i = 0; i < n; i++) {
2777 node_info node = {
2778 /*.node =*/gf->nodes[i],
2779 /*.fused =*/{},
2780 };
2781
2782 // fuse only ops that start with these operations
2783 // can be expanded when needed
2784 if (node.op() == GGML_OP_ADD ||
2785 node.op() == GGML_OP_NORM ||
2786 node.op() == GGML_OP_RMS_NORM) {
2787 ops[0] = node.op();
2788
2789 int f = i + 1;
2790 while (f < n && f < i + MAX_FUSE) {
2791 // conservatively allow fusing only these ops
2792 // can be expanded when needed
2793 if (gf->nodes[f]->op != GGML_OP_ADD &&
2794 gf->nodes[f]->op != GGML_OP_MUL &&
2795 gf->nodes[f]->op != GGML_OP_NORM &&
2796 gf->nodes[f]->op != GGML_OP_RMS_NORM) {
2797 break;
2798 }
2799 ops[f - i] = gf->nodes[f]->op;
2800 f++;
2801 }
2802
2803 f -= i;
2804 for (; f > 1; f--) {
2805 if (ggml_can_fuse(gf, i, ops, f)) {
2806 break;
2807 }
2808 }
2809
2810 // add the fused tensors into the node info so we can unfuse them later
2811 for (int k = 1; k < f; k++) {
2812 ++i;
2813
2814 // the .dst() becomes the last fused tensor
2815 node.add_fused(gf->nodes[i]);
2816 }
2817 }
2818
2819 nodes.push_back(std::move(node));
2820 }
2821
2822 const auto order = ggml_hexagon_graph_optimize_reorder(nodes);
2823
2824 // unfuse
2825 {
2826 int j = 0;
2827 for (const auto i : order) {
2828 const auto & node = nodes[i];
2829
2830 gf->nodes[j++] = node.node;
2831
2832 for (auto * fused : node.fused) {
2833 gf->nodes[j++] = fused;
2834 }
2835 }
2836 }
2837}
2838
2839static struct ggml_backend_i hexagon_backend_i = {
2840 /* .get_name = */ ggml_backend_hexagon_name,
2841 /* .free = */ ggml_backend_hexagon_free,
2842 /* .set_tensor_async = */ NULL,
2843 /* .get_tensor_async = */ NULL,
2844 /* .cpy_tensor_async = */ NULL,
2845 /* .synchronize = */ ggml_backend_hexagon_synchronize,
2846 /* .graph_plan_create = */ NULL,
2847 /* .graph_plan_free = */ NULL,
2848 /* .graph_plan_update = */ NULL,
2849 /* .graph_plan_compute = */ NULL,
2850 /* .graph_compute = */ ggml_backend_hexagon_graph_compute,
2851 /* .event_record = */ NULL,
2852 /* .event_wait = */ NULL,
2853 /* .graph_optimize = */ ggml_backend_hexagon_graph_optimize,
2854};
2855
2856static ggml_guid_t ggml_backend_hexagon_guid() {
2857 static ggml_guid guid = { 0x7b, 0x57, 0xdc, 0xaf, 0xde, 0x12, 0x1d, 0x49,
2858 0x11, 0x11, 0x11, 0x11, 0x11, 0x11, 0x11, 0x11 };
2859 return &guid;
2860}
2861
2862bool ggml_backend_is_hexagon(ggml_backend_t backend) {
2863 return backend && backend->iface.get_name == ggml_backend_hexagon_name;
2864}
2865
2866// device interface
2867
2868static ggml_backend_t ggml_backend_hexagon_device_init(ggml_backend_dev_t dev, const char * params) {
2869 auto sess = static_cast<ggml_hexagon_session *>(dev->context);
2870
2871 return new ggml_backend{
2872 /* .guid = */ ggml_backend_hexagon_guid(),
2873 /* .interface = */ hexagon_backend_i,
2874 /* .device = */ dev,
2875 /* .context = */ sess,
2876 };
2877
2878 GGML_UNUSED(params);
2879}
2880
2881static const char * ggml_backend_hexagon_device_get_name(ggml_backend_dev_t dev) {
2882 auto sess = static_cast<ggml_hexagon_session *>(dev->context);
2883 return sess->name.c_str();
2884
2885 GGML_UNUSED(dev);
2886}
2887
2888static const char * ggml_backend_hexagon_device_get_description(ggml_backend_dev_t dev) {
2889 return "Hexagon";
2890 GGML_UNUSED(dev);
2891}
2892
2893static void ggml_backend_hexagon_device_get_memory(ggml_backend_dev_t dev, size_t * free, size_t * total) {
2894 // ~2GB per session for now
2895 *free = 2ULL * 1024 * 1024 * 1024;
2896 *total = *free;
2897
2898 GGML_UNUSED(dev);
2899}
2900
2901static enum ggml_backend_dev_type ggml_backend_hexagon_device_get_type(ggml_backend_dev_t dev) {
2902 return GGML_BACKEND_DEVICE_TYPE_GPU;
2903
2904 GGML_UNUSED(dev);
2905}
2906
2907static void ggml_backend_hexagon_device_get_props(ggml_backend_dev_t dev, struct ggml_backend_dev_props * props) {
2908 props->name = ggml_backend_hexagon_device_get_name(dev);
2909 props->description = ggml_backend_hexagon_device_get_description(dev);
2910 props->type = ggml_backend_hexagon_device_get_type(dev);
2911 ggml_backend_hexagon_device_get_memory(dev, &props->memory_free, &props->memory_total);
2912 props->caps = {
2913 /* .async = */ true,
2914 /* .host_buffer = */ (bool) opt_hostbuf,
2915 /* .buffer_from_host_ptr = */ false,
2916 /* .events = */ false,
2917 };
2918}
2919
2920static ggml_backend_buffer_type_t ggml_backend_hexagon_device_get_buffer_type(ggml_backend_dev_t dev) {
2921 auto sess = static_cast<ggml_hexagon_session *>(dev->context);
2922 return &sess->buffer_type;
2923}
2924
2925static ggml_backend_buffer_type_t ggml_backend_hexagon_device_get_repack_buffer_type(ggml_backend_dev_t dev) {
2926 auto sess = static_cast<ggml_hexagon_session *>(dev->context);
2927 return &sess->repack_buffer_type;
2928}
2929
2930static bool ggml_hexagon_supported_buffer(ggml_hexagon_session *sess, const struct ggml_tensor * t) {
2931 if (t && t->buffer) {
2932 if (ggml_backend_buffer_is_hexagon(t->buffer) == false) return false; // not our buffer
2933 if (ggml_backend_hexagon_buffer_get_sess(t->buffer) != sess) return false; // wrong session
2934 }
2935 return true;
2936}
2937
2938static bool ggml_hexagon_supported_buffers(ggml_hexagon_session *sess, const struct ggml_tensor * t) {
2939 // all srcs & dsts must be mapped to the same session
2940 if (!ggml_hexagon_supported_buffer(sess, t)) {
2941 return false;
2942 }
2943
2944 for (int i = 0; i < GGML_MAX_SRC; i++) {
2945 if (!ggml_hexagon_supported_buffer(sess, t->src[i])) {
2946 return false;
2947 }
2948 }
2949
2950 return true;
2951}
2952
2953static bool ggml_hexagon_supported_cpy(const struct ggml_hexagon_session * sess, const struct ggml_tensor * op) {
2954 const struct ggml_tensor * src0 = op->src[0];
2955 const struct ggml_tensor * dst = op;
2956
2957 // for now we can do f32 -> f16 and f16 -> f32 (without reshaping)
2958 if (src0->type != GGML_TYPE_F32 && src0->type != GGML_TYPE_F16) return false;
2959 if ( dst->type != GGML_TYPE_F32 && dst->type != GGML_TYPE_F16) return false;
2960
2961 const bool sametype = (src0->type == dst->type);
2962 const bool transposed = ggml_is_transposed(src0) || ggml_is_transposed(dst);
2963 const bool sameshape = !transposed && ggml_are_same_shape(src0, dst);
2964
2965 // can handle any shape and any same-type (pretty slow if reshaping is required)
2966 if (sametype) return true;
2967
2968 // cannot handle re-shaping and type conversion at the same time
2969 if (!sameshape) return false;
2970
2971 return true;
2972}
2973
2974static bool ggml_backend_hexagon_device_supports_op(ggml_backend_dev_t dev, const struct ggml_tensor * op) {
2975 auto sess = static_cast<ggml_hexagon_session *>(dev->context);
2976
2977 // all srcs & dsts must be mapped to the same session
2978 if (!ggml_hexagon_supported_buffers(sess, op)) {
2979 ggml_hexagon_dump_op_supp(sess->name, op, false);
2980 return false;
2981 }
2982
2983 bool supp = false;
2984 switch (op->op) {
2985 case GGML_OP_NONE:
2986 case GGML_OP_RESHAPE:
2987 case GGML_OP_VIEW:
2988 case GGML_OP_PERMUTE:
2989 case GGML_OP_TRANSPOSE:
2990 supp = true;
2991 break;
2992
2993 case GGML_OP_MUL_MAT:
2994 supp = ggml_hexagon_supported_mul_mat(sess, op);
2995 break;
2996
2997 case GGML_OP_MUL_MAT_ID:
2998 supp = ggml_hexagon_supported_mul_mat_id(sess, op);
2999 break;
3000
3001 case GGML_OP_MUL:
3002 case GGML_OP_ADD:
3003 case GGML_OP_SUB:
3004 case GGML_OP_DIV:
3005 supp = ggml_hexagon_supported_binary(sess, op);
3006 break;
3007
3008 case GGML_OP_ADD_ID:
3009 supp = ggml_hexagon_supported_add_id(sess, op);
3010 break;
3011
3012 case GGML_OP_RMS_NORM:
3013 case GGML_OP_SCALE:
3014 supp = ggml_hexagon_supported_unary(sess, op);
3015 break;
3016
3017 case GGML_OP_SQR:
3018 case GGML_OP_SQRT:
3019 supp = ggml_hexagon_supported_unary(sess, op);
3020 break;
3021
3022 case GGML_OP_SUM_ROWS:
3023 supp = ggml_hexagon_supported_sum_rows(sess, op);
3024 break;
3025
3026 case GGML_OP_SOFT_MAX:
3027 supp = ggml_hexagon_supported_softmax(sess, op);
3028 break;
3029
3030 case GGML_OP_UNARY:
3031 {
3032 const auto unary_op = ggml_get_unary_op(op);
3033 if (unary_op == GGML_UNARY_OP_SILU || unary_op == GGML_UNARY_OP_GELU) {
3034 supp = ggml_hexagon_supported_activations(sess, op);
3035 }
3036 break;
3037 }
3038 case GGML_OP_GLU:
3039 {
3040 const auto glu_op = ggml_get_glu_op(op);
3041 if ((glu_op == GGML_GLU_OP_SWIGLU) || (glu_op == GGML_GLU_OP_SWIGLU_OAI) || (glu_op == GGML_GLU_OP_GEGLU)) {
3042 supp = ggml_hexagon_supported_activations(sess, op);
3043 }
3044 break;
3045 }
3046 case GGML_OP_ROPE:
3047 supp = ggml_hexagon_supported_rope(sess, op);
3048 break;
3049
3050 case GGML_OP_FLASH_ATTN_EXT:
3051 supp = ggml_hexagon_supported_flash_attn_ext(sess, op);
3052 break;
3053
3054 case GGML_OP_SET_ROWS:
3055 supp = ggml_hexagon_supported_set_rows(sess, op);
3056 break;
3057
3058 case GGML_OP_GET_ROWS:
3059 supp = ggml_hexagon_supported_get_rows(sess, op);
3060 break;
3061
3062 case GGML_OP_CPY:
3063 supp = ggml_hexagon_supported_cpy(sess, op);
3064 break;
3065
3066 case GGML_OP_ARGSORT:
3067 supp = ggml_hexagon_supported_argsort(sess, op);
3068 break;
3069
3070 default:
3071 break;
3072 }
3073
3074 ggml_hexagon_dump_op_supp(sess->name, op, supp);
3075 return supp;
3076}
3077
3078static bool ggml_backend_hexagon_device_supports_buft(ggml_backend_dev_t dev, ggml_backend_buffer_type_t buft) {
3079 if (buft->iface.get_alignment != ggml_backend_hexagon_buffer_type_get_alignment) {
3080 return false;
3081 }
3082
3083 auto s0 = static_cast<ggml_hexagon_session *>(dev->context);
3084 auto s1 = static_cast<ggml_backend_hexagon_buffer_type_context *>(buft->context)->sess;
3085
3086 // Need session/domain-id for buffers to be compatible
3087 bool supp = (s0->session_id == s1->session_id);
3088
3089 HEX_VERBOSE("ggml-hex: %s device-supports-buft %s (%d)\n", s0->name.c_str(), s1->name.c_str(), (int) supp);
3090
3091 return supp;
3092}
3093
3094static ggml_backend_buffer_type_t * ggml_backend_hexagon_device_get_extra_buffers_type(ggml_backend_dev_t dev) {
3095 auto s0 = static_cast<ggml_hexagon_session *>(dev->context);
3096 HEX_VERBOSE("ggml-hex: device-get-extra-buft : %s \n", s0->name.c_str());
3097
3098 static ggml_backend_buffer_type_t bufts[2];
3099 bufts[0] = ggml_backend_hexagon_device_get_repack_buffer_type(dev);
3100 bufts[1] = NULL;
3101 return bufts;
3102}
3103
3104static const struct ggml_backend_device_i ggml_backend_hexagon_device_i = {
3105 /* .get_name = */ ggml_backend_hexagon_device_get_name,
3106 /* .get_description = */ ggml_backend_hexagon_device_get_description,
3107 /* .get_memory = */ ggml_backend_hexagon_device_get_memory,
3108 /* .get_type = */ ggml_backend_hexagon_device_get_type,
3109 /* .get_props = */ ggml_backend_hexagon_device_get_props,
3110 /* .init_backend = */ ggml_backend_hexagon_device_init,
3111 /* .get_buffer_type = */ ggml_backend_hexagon_device_get_buffer_type,
3112 /* .get_host_buffer_type = */ NULL, // ggml_backend_hexagon_device_get_host_buffer_type,
3113 /* .buffer_from_host_ptr = */ NULL, // ggml_backend_hexagon_device_buffer_from_ptr,
3114 /* .supports_op = */ ggml_backend_hexagon_device_supports_op,
3115 /* .supports_buft = */ ggml_backend_hexagon_device_supports_buft,
3116 /* .offload_op = */ NULL, // ggml_backend_hexagon_device_offload_op,
3117 /* .event_new = */ NULL,
3118 /* .event_free = */ NULL,
3119 /* .event_synchronize = */ NULL,
3120};
3121
3122//** backend registry
3123
3124#define GGML_HEXAGON_MAX_SESSIONS 16
3125
3126struct ggml_hexagon_registry {
3127 ggml_hexagon_registry(ggml_backend_reg_t reg);
3128 ~ggml_hexagon_registry();
3129
3130 ggml_backend_device devices[GGML_HEXAGON_MAX_SESSIONS];
3131};
3132
3133ggml_hexagon_registry::ggml_hexagon_registry(ggml_backend_reg_t reg) {
3134 GGML_LOG_INFO("ggml-hex: Hexagon backend (experimental) : allocating new registry : ndev %zu\n", opt_ndev);
3135
3136 if (!opt_arch) {
3137 int err = get_hex_arch_ver(CDSP_DOMAIN_ID, &opt_arch);
3138 if (err != 0) {
3139 GGML_LOG_ERROR("ggml-hex: failed to query HTP version (err %d) defaulting to v73\n", err);
3140 opt_arch = 73;
3141 }
3142 }
3143
3144#if defined(__ANDROID__)
3145 if (opt_arch < 75) {
3146 opt_ndev = 1;
3147 GGML_LOG_WARN("ggml-hex: forcing ndev to 1 for SoCs archs lower than v75.\n");
3148 }
3149#endif
3150
3151 GGML_LOG_INFO("ggml-hex: Hexagon Arch version v%d\n", opt_arch);
3152
3153 // Create devices / sessions
3154 for (size_t i = 0; i < opt_ndev; i++) {
3155 devices[i].iface = ggml_backend_hexagon_device_i;
3156 devices[i].reg = reg;
3157 try {
3158 devices[i].context = new ggml_hexagon_session(i, &devices[i]);
3159 } catch (const std::exception & exc) {
3160 GGML_LOG_ERROR("ggml-hex: failed to create device/session %zu\n", i);
3161 devices[i].context = nullptr;
3162 }
3163 }
3164}
3165
3166ggml_hexagon_registry::~ggml_hexagon_registry() {
3167 GGML_LOG_INFO("ggml-hex: releasing registry\n");
3168
3169 // Release devices / sessions
3170 for (size_t i = 0; i < opt_ndev; i++) {
3171 auto sess = static_cast<ggml_hexagon_session *>(devices[i].context);
3172 delete sess;
3173 }
3174}
3175
3176static const char * ggml_backend_hexagon_reg_get_name(ggml_backend_reg_t reg) {
3177 return "HTP";
3178 GGML_UNUSED(reg);
3179}
3180
3181static size_t ggml_backend_hexagon_reg_get_device_count(ggml_backend_reg_t reg) {
3182 return opt_ndev;
3183 GGML_UNUSED(reg);
3184}
3185
3186static ggml_backend_dev_t ggml_backend_hexagon_reg_get_device(ggml_backend_reg_t reg, size_t index) {
3187 auto hreg = static_cast<ggml_hexagon_registry *>(reg->context);
3188
3189 if (index >= opt_ndev || !hreg->devices[index].context) {
3190 return nullptr;
3191 }
3192
3193 return &hreg->devices[index];
3194}
3195
3196static void * ggml_backend_hexagon_get_proc_address(ggml_backend_reg_t reg, const char * name) {
3197 if (strcmp(name, "ggml_backend_dev_get_extra_bufts") == 0 && opt_hostbuf) {
3198 ggml_backend_dev_get_extra_bufts_t fct = ggml_backend_hexagon_device_get_extra_buffers_type;
3199 return (void *) fct;
3200 }
3201
3202 return NULL;
3203}
3204
3205static void ggml_hexagon_init(ggml_backend_reg * reg) {
3206 // Basic sanity checks to make sure definitions match
3207 static_assert((unsigned int) HTP_TYPE_Q4_0 == (unsigned int) GGML_TYPE_Q4_0,
3208 "please update hexagon_type to match ggml_type");
3209 static_assert((unsigned int) HTP_TYPE_Q8_0 == (unsigned int) GGML_TYPE_Q8_0,
3210 "please update hexagon_type to match ggml_type");
3211 static_assert((unsigned int) HTP_TYPE_MXFP4 == (unsigned int) GGML_TYPE_MXFP4,
3212 "please update hexagon_type to match ggml_type");
3213
3214 const char * str_experimental = getenv("GGML_HEXAGON_EXPERIMENTAL");
3215 const char * str_verbose = getenv("GGML_HEXAGON_VERBOSE");
3216 const char * str_hostbuf = getenv("GGML_HEXAGON_HOSTBUF");
3217 const char * str_opmask = getenv("GGML_HEXAGON_OPMASK");
3218 const char * str_opsync = getenv("GGML_HEXAGON_OPSYNC");
3219 const char * str_profile = getenv("GGML_HEXAGON_PROFILE");
3220 const char * str_etm = getenv("GGML_HEXAGON_ETM");
3221 const char * str_nhvx = getenv("GGML_HEXAGON_NHVX");
3222 const char * str_ndev = getenv("GGML_HEXAGON_NDEV");
3223 const char * str_arch = getenv("GGML_HEXAGON_ARCH");
3224
3225 opt_experimental = str_experimental ? atoi(str_experimental) : 0;
3226 opt_verbose = str_verbose ? atoi(str_verbose) : 0;
3227 opt_hostbuf = str_hostbuf ? atoi(str_hostbuf) : opt_hostbuf;
3228 opt_opmask = str_opmask ? strtoul(str_opmask, NULL, 0) : opt_opmask;
3229 opt_opsync = str_opsync ? atoi(str_opsync) : 0;
3230 opt_profile = str_profile ? atoi(str_profile) : 0;
3231 opt_etm = str_etm ? atoi(str_etm) : 0;
3232 opt_nhvx = str_nhvx ? strtoul(str_nhvx, NULL, 0) : opt_nhvx;
3233 opt_ndev = str_ndev ? strtoul(str_ndev, NULL, 0) : opt_ndev;
3234
3235 if (opt_ndev > GGML_HEXAGON_MAX_SESSIONS) {
3236 opt_ndev = GGML_HEXAGON_MAX_SESSIONS;
3237 }
3238
3239 if (str_arch) {
3240 if (str_arch[0] == 'v') {
3241 str_arch++;
3242 }
3243 opt_arch = strtoul(str_arch, NULL, 0);
3244 }
3245
3246 opt_hostbuf = str_hostbuf ? atoi(str_hostbuf) : 1;
3247
3248 reg->context = new ggml_hexagon_registry(reg);
3249
3250 HEX_VERBOSE("ggml-hex: size-of-general-req %zu size-of-general-rsp %zu\n", sizeof(struct htp_general_req),
3251 sizeof(struct htp_general_rsp));
3252}
3253
3254static const struct ggml_backend_reg_i ggml_backend_hexagon_reg_i = {
3255 /* .get_name = */ ggml_backend_hexagon_reg_get_name,
3256 /* .get_device_count = */ ggml_backend_hexagon_reg_get_device_count,
3257 /* .get_device = */ ggml_backend_hexagon_reg_get_device,
3258 /* .get_proc_address = */ ggml_backend_hexagon_get_proc_address,
3259};
3260
3261ggml_backend_reg_t ggml_backend_hexagon_reg(void) {
3262 static bool initialized = false;
3263
3264 static ggml_backend_reg reg = { /* .api_version = */ GGML_BACKEND_API_VERSION,
3265 /* .iface = */ ggml_backend_hexagon_reg_i,
3266 /* .context = */ NULL };
3267
3268 {
3269 static std::mutex mutex;
3270 std::lock_guard<std::mutex> lock(mutex);
3271 if (!initialized) {
3272 auto nErr = htpdrv_init();
3273 if (nErr != AEE_SUCCESS) {
3274 return NULL;
3275 }
3276
3277 ggml_hexagon_init(®);
3278 }
3279
3280 initialized = true;
3281 }
3282
3283 return ®
3284}
3285
3286GGML_BACKEND_DL_IMPL(ggml_backend_hexagon_reg)