1#include <assert.h>
   2#include <inttypes.h>
   3#include <stdio.h>
   4#include <stdlib.h>
   5#include <string.h>
   6#include <time.h>
   7
   8#include <atomic>
   9#include <chrono>
  10#include <cstddef>
  11#include <mutex>
  12#include <stdexcept>
  13#include <string>
  14
  15#ifdef _WIN32
  16#    include <sal.h>
  17#else
  18#    include <semaphore.h>
  19#    include <unistd.h>
  20#endif
  21
  22#pragma clang diagnostic ignored "-Wnested-anon-types"
  23#pragma clang diagnostic ignored "-Wgnu-anonymous-struct"
  24
  25#include <AEEStdErr.h>
  26#include <dspqueue.h>
  27#include <rpcmem.h>
  28
  29#define GGML_COMMON_IMPL_CPP
  30#include "ggml-backend-impl.h"
  31#include "ggml-common.h"
  32#include "ggml-hexagon.h"
  33#include "ggml-impl.h"
  34#include "ggml-quants.h"
  35#include "op-desc.h"
  36#include "htp-msg.h"
  37#include "htp_iface.h"
  38#include "htp-drv.h"
  39
  40static size_t opt_ndev         = 1;
  41static size_t opt_nhvx         = 0; // use all
  42static int    opt_arch         = 0; // autodetect
  43static int    opt_etm          = 0;
  44static int    opt_verbose      = 0;
  45static int    opt_profile      = 0;
  46static int    opt_hostbuf      = 1; // hostbuf ON by default
  47static int    opt_experimental = 0;
  48
  49// Enable all stages by default
  50static int opt_opmask = HTP_OPMASK_QUEUE | HTP_OPMASK_QUANTIZE | HTP_OPMASK_COMPUTE;
  51static int opt_opsync = 0;  // synchronous ops
  52
  53#define HEX_VERBOSE(...) \
  54    if (opt_verbose) GGML_LOG_DEBUG(__VA_ARGS__)
  55
  56static inline uint64_t hex_is_aligned(void * addr, uint32_t align) {
  57    return ((size_t) addr & (align - 1)) == 0;
  58}
  59
  60static inline size_t hex_round_up(size_t n, size_t m) {
  61    return m * ((n + m - 1) / m);
  62}
  63
  64static const char * status_to_str(uint32_t status) {
  65    switch (status) {
  66        case HTP_STATUS_OK:
  67            return "OK";
  68        case HTP_STATUS_NO_SUPPORT:
  69            return "NO-SUPPORT";
  70        case HTP_STATUS_INVAL_PARAMS:
  71            return "INVAL-PARAMS";
  72        case HTP_STATUS_VTCM_TOO_SMALL:
  73            return "VTCM-TOO-SMALL";
  74        case HTP_STATUS_INTERNAL_ERR:
  75            return "INTERNAL-ERROR";
  76        default:
  77            return "UNKNOWN";
  78    }
  79}
  80
  81// ** debug helpers
  82
  83static void ggml_hexagon_dump_op_exec(const std::string &sess_name, const ggml_tensor * op, const uint32_t req_flags) {
  84    if (!opt_verbose) return;
  85
  86    op_desc desc(op);
  87    GGML_LOG_DEBUG("ggml-hex: %s execute-op %s: %s : %s : %s : %s : %s : flags 0x%x\n", sess_name.c_str(),
  88                ggml_op_name(op->op), desc.names, desc.dims, desc.types, desc.strides, desc.buffs, req_flags);
  89}
  90
  91static void ggml_hexagon_dump_op_supp(const std::string &sess_name, const struct ggml_tensor * op, bool supp) {
  92    if (!opt_verbose) return;
  93
  94    op_desc desc(op);
  95    GGML_LOG_DEBUG("ggml-hex: %s supports-op %s : %s : %s : %s : %s : %s : %s\n", sess_name.c_str(),
  96                ggml_op_name(op->op), desc.names, desc.dims, desc.types, desc.strides, desc.buffs, supp ? "yes" : "no");
  97}
  98
  99static void ggml_hexagon_dump_op_prof(const std::string &sess_name, const ggml_tensor * op,
 100                                      uint32_t op_usec, uint32_t op_cycles, uint32_t op_pkts, uint64_t call_usec) {
 101    if (!opt_profile) return;
 102
 103    op_desc desc(op);
 104    GGML_LOG_DEBUG("ggml-hex: %s profile-op %s: %s : %s : %s : %s : %s : op-usec %u op-cycles %u op-pkts %u (%f) call-usec %llu\n", sess_name.c_str(),
 105                ggml_op_name(op->op), desc.names, desc.dims, desc.types, desc.strides, desc.buffs,
 106                op_usec, op_cycles, op_pkts, (float) op_cycles / op_pkts, (unsigned long long) call_usec);
 107}
 108
 109// ** backend sessions
 110
 111struct ggml_hexagon_session {
 112    ggml_hexagon_session(int dev_id, ggml_backend_dev_t dev) noexcept(false);
 113    ~ggml_hexagon_session() noexcept(true);
 114
 115    void allocate(int dev_id) noexcept(false);
 116    void release() noexcept(true);
 117
 118    void enqueue(struct htp_general_req &req, struct dspqueue_buffer *bufs, uint32_t n_bufs, bool sync = false);
 119    void flush();
 120
 121    ggml_backend_buffer_type buffer_type        = {};
 122    ggml_backend_buffer_type repack_buffer_type = {};
 123
 124    std::string      name;
 125    remote_handle64  handle;
 126    dspqueue_t       queue;
 127    uint32_t         session_id;
 128    uint32_t         domain_id;
 129    uint64_t         queue_id;
 130    int              dev_id;
 131    bool             valid_session;
 132    bool             valid_handle;
 133    bool             valid_queue;
 134    bool             valid_iface;
 135    std::atomic<int> op_pending;
 136    uint32_t         prof_usecs;
 137    uint32_t         prof_cycles;
 138    uint32_t         prof_pkts;
 139};
 140
 141void ggml_hexagon_session::enqueue(struct htp_general_req &req, struct dspqueue_buffer *bufs, uint32_t n_bufs, bool sync) {
 142    // Bump pending flag (cleared in the session::flush once we get the responce)
 143    this->op_pending++;  // atomic inc
 144
 145    int err = dspqueue_write(this->queue,
 146                             0,                       // flags - the framework will autoset this
 147                             n_bufs,                  // number of buffers
 148                             bufs,                    // buffer references
 149                             sizeof(req),             // Message length
 150                             (const uint8_t *) &req,  // Message
 151                             DSPQUEUE_TIMEOUT         // Timeout
 152    );
 153
 154    if (err != 0) {
 155        GGML_ABORT("ggml-hex: %s dspqueue_write failed: 0x%08x\n", this->name.c_str(), (unsigned) err);
 156    }
 157
 158    if (sync) {
 159        flush();
 160    }
 161}
 162
 163// Flush HTP response queue i.e wait for all outstanding requests to complete
 164void ggml_hexagon_session::flush() {
 165    dspqueue_t q = this->queue;
 166
 167    // Repeatedly read packets from the queue until it's empty. We don't
 168    // necessarily get a separate callback for each packet, and new packets
 169    // may arrive while we're processing the previous one.
 170
 171    while (this->op_pending) {
 172        struct htp_general_rsp rsp;
 173        uint32_t               rsp_size;
 174        uint32_t               flags;
 175
 176        struct dspqueue_buffer bufs[HTP_MAX_PACKET_BUFFERS];
 177        uint32_t               n_bufs;
 178
 179        // Read response packet from queue
 180        int err = dspqueue_read(q, &flags,
 181                                HTP_MAX_PACKET_BUFFERS,  // Maximum number of buffer references
 182                                &n_bufs,                 // Number of buffer references
 183                                bufs,                    // Buffer references
 184                                sizeof(rsp),             // Max message length
 185                                &rsp_size,               // Message length
 186                                (uint8_t *) &rsp,        // Message
 187                                DSPQUEUE_TIMEOUT);       // Timeout
 188
 189        if (err == AEE_EEXPIRED) {
 190            // TODO: might need to bail out if the HTP is stuck on something
 191            continue;
 192        }
 193
 194        if (err != 0) {
 195            GGML_ABORT("ggml-hex: dspqueue_read failed: 0x%08x\n", (unsigned) err);
 196        }
 197
 198        // Basic sanity checks
 199        if (rsp_size != sizeof(rsp)) {
 200            GGML_ABORT("ggml-hex: dspcall : bad response (size)\n");
 201        }
 202
 203        if (rsp.status != HTP_STATUS_OK) {
 204            GGML_LOG_ERROR("ggml-hex: dspcall : dsp-rsp: %s\n", status_to_str(rsp.status));
 205            // TODO: handle errors
 206        }
 207
 208        // TODO: update profiling implementation, currently only works for opt_opsync mode
 209        this->prof_usecs  = rsp.prof_usecs;
 210        this->prof_cycles = rsp.prof_cycles;
 211        this->prof_pkts   = rsp.prof_pkts;
 212
 213        this->op_pending--;  // atomic dec
 214    }
 215}
 216
 217// ** backend buffers
 218
 219struct ggml_backend_hexagon_buffer_type_context {
 220    ggml_backend_hexagon_buffer_type_context(const std::string & name, ggml_hexagon_session * sess) {
 221        this->sess = sess;
 222        this->name = name;
 223    }
 224
 225    ggml_hexagon_session * sess;
 226    std::string            name;
 227};
 228
 229struct ggml_backend_hexagon_buffer_context {
 230    bool mmap_to(ggml_hexagon_session * s) {
 231        HEX_VERBOSE("ggml-hex: %s mmaping buffer: base %p domain-id %d session-id %d size %zu fd %d repack %d\n",
 232                    s->name.c_str(), (void *) this->base, s->domain_id, s->session_id, this->size, this->fd,
 233                    (int) this->repack);
 234
 235        int err = fastrpc_mmap(s->domain_id, this->fd, (void *) this->base, 0, this->size, FASTRPC_MAP_FD);
 236        if (err != 0) {
 237            GGML_LOG_ERROR("ggml-hex: buffer mapping failed : domain_id %d size %zu fd %d error 0x%08x\n",
 238                    s->domain_id, this->size, this->fd, (unsigned) err);
 239            return false;
 240        }
 241
 242        return true;
 243    }
 244
 245    bool mmap() {
 246        if (this->mapped) {
 247            return true;
 248        }
 249        if (!mmap_to(this->sess)) {
 250            return false;
 251        }
 252        this->mapped = true;
 253        return true;
 254    }
 255
 256    void munmap() {
 257        if (!this->mapped) {
 258            return;
 259        }
 260
 261        fastrpc_munmap(this->sess->domain_id, this->fd, this->base, this->size);
 262        this->mapped = false;
 263    }
 264
 265    ggml_backend_hexagon_buffer_context(ggml_hexagon_session * sess, size_t size, bool repack) {
 266        size += 4 * 1024;  // extra page for padding
 267
 268        this->base = (uint8_t *) rpcmem_alloc2(RPCMEM_HEAP_ID_SYSTEM, RPCMEM_DEFAULT_FLAGS | RPCMEM_HEAP_NOREG, size);
 269        if (!this->base) {
 270            GGML_LOG_ERROR("ggml-hex: %s failed to allocate buffer : size %zu\n", sess->name.c_str(), size);
 271            throw std::runtime_error("ggml-hex: rpcmem_alloc failed (see log for details)");
 272        }
 273
 274        this->fd = rpcmem_to_fd(this->base);
 275        if (this->fd < 0) {
 276            GGML_LOG_ERROR("ggml-hex: %s failed to get FD for buffer %p\n", sess->name.c_str(), (void *) this->base);
 277            rpcmem_free(this->base);
 278            this->base = NULL;
 279            throw std::runtime_error("ggml-hex: rpcmem_to_fd failed (see log for details)");
 280        }
 281
 282        HEX_VERBOSE("ggml-hex: %s allocated buffer: base %p size %zu fd %d repack %d\n", sess->name.c_str(),
 283                    (void *) this->base, size, this->fd, (int) repack);
 284
 285        this->sess   = sess;
 286        this->size   = size;
 287        this->mapped = false;
 288        this->repack = repack;
 289    }
 290
 291    ~ggml_backend_hexagon_buffer_context() {
 292        munmap();
 293        if (this->base) {
 294            rpcmem_free(this->base);
 295            this->base = NULL;
 296        }
 297    }
 298
 299    ggml_hexagon_session * sess;  // primary session
 300    uint8_t *              base;
 301    size_t                 size;
 302    int                    fd;
 303    bool                   mapped;  // mmap is done
 304    bool                   repack;  // repacked buffer
 305};
 306
 307static ggml_hexagon_session * ggml_backend_hexagon_buffer_get_sess(ggml_backend_buffer_t buffer) {
 308    return static_cast<ggml_backend_hexagon_buffer_type_context *>(buffer->buft->context)->sess;
 309}
 310
 311static void ggml_backend_hexagon_buffer_free_buffer(ggml_backend_buffer_t buffer) {
 312    auto ctx = static_cast<ggml_backend_hexagon_buffer_context *>(buffer->context);
 313    delete ctx;
 314}
 315
 316static void * ggml_backend_hexagon_buffer_get_base(ggml_backend_buffer_t buffer) {
 317    auto ctx = static_cast<ggml_backend_hexagon_buffer_context *>(buffer->context);
 318    return ctx->base;
 319}
 320
 321static enum ggml_status ggml_backend_hexagon_buffer_init_tensor(ggml_backend_buffer_t buffer, ggml_tensor * tensor) {
 322    auto ctx  = static_cast<ggml_backend_hexagon_buffer_context *>(buffer->context);
 323    auto sess = ctx->sess;
 324
 325    HEX_VERBOSE("ggml-hex: %s init-tensor %s : base %p data %p nbytes %zu usage %d repack %d\n", sess->name.c_str(),
 326                tensor->name, (void *) ctx->base, tensor->data, ggml_nbytes(tensor), (int) buffer->usage,
 327                (int) ctx->repack);
 328
 329    if (tensor->view_src != NULL && tensor->view_offs == 0) {
 330        ; // nothing to do for the view
 331    } else {
 332        if (!ctx->mapped) {
 333            ctx->mmap();
 334        }
 335    }
 336    return GGML_STATUS_SUCCESS;
 337}
 338
 339// ======== Q4x4x2 ====================
 340struct x2_q4 {
 341    int v[2];
 342};
 343
 344static x2_q4 unpack_q4(uint8_t v) {
 345    x2_q4 x = { (int) (v & 0x0f) - 8, (int) (v >> 4) - 8 };
 346    return x;
 347}
 348
 349static void dump_block_q4_0(const block_q4_0 * b, int i) {
 350    HEX_VERBOSE("ggml-hex: repack q4_0 %d: %d %d %d %d ... %d %d %d %d : %.6f\n", i, unpack_q4(b->qs[0]).v[0],
 351                unpack_q4(b->qs[1]).v[0], unpack_q4(b->qs[2]).v[0], unpack_q4(b->qs[3]).v[0], unpack_q4(b->qs[12]).v[1],
 352                unpack_q4(b->qs[13]).v[1], unpack_q4(b->qs[14]).v[1], unpack_q4(b->qs[15]).v[1],
 353                GGML_FP16_TO_FP32(b->d));
 354}
 355
 356static void dump_packed_block_q4x4x2(const uint8_t * v, unsigned int i, size_t k) {
 357    static const int qk        = QK_Q4_0x4x2;
 358    const int        dblk_size = 8 * 2;   // 8x __fp16
 359    const int        qblk_size = qk / 2;  // int4
 360    const int        qrow_size = k / 2;   // int4 (not padded)
 361
 362    const uint8_t * v_q = v + 0;          // quants first
 363    const uint8_t * v_d = v + qrow_size;  // then scales
 364
 365    const uint8_t *   q = v_q + i * qblk_size;
 366    const ggml_half * d = (const ggml_half *) (v_d + i * dblk_size);
 367
 368    HEX_VERBOSE("ggml-hex: repack q4x4x2-%d: %d %d %d %d ... %d %d %d %d ... %d %d %d %d : %.6f %.6f %.6f %.6f\n", i,
 369                unpack_q4(q[0]).v[0], unpack_q4(q[1]).v[0], unpack_q4(q[2]).v[0], unpack_q4(q[3]).v[0],
 370                unpack_q4(q[60]).v[0], unpack_q4(q[61]).v[0], unpack_q4(q[62]).v[0], unpack_q4(q[63]).v[0],
 371                unpack_q4(q[124]).v[0], unpack_q4(q[125]).v[0], unpack_q4(q[126]).v[0], unpack_q4(q[127]).v[0],
 372                GGML_FP16_TO_FP32(d[0]), GGML_FP16_TO_FP32(d[1]), GGML_FP16_TO_FP32(d[2]), GGML_FP16_TO_FP32(d[3]));
 373
 374    HEX_VERBOSE("ggml-hex: repack q4x4x2-%d: %d %d %d %d ... %d %d %d %d ... %d %d %d %d : %.6f %.6f %.6f %.6f\n",
 375                i + 1, unpack_q4(q[0]).v[1], unpack_q4(q[1]).v[1], unpack_q4(q[2]).v[1], unpack_q4(q[3]).v[1],
 376                unpack_q4(q[60]).v[1], unpack_q4(q[61]).v[1], unpack_q4(q[62]).v[1], unpack_q4(q[63]).v[1],
 377                unpack_q4(q[124]).v[1], unpack_q4(q[125]).v[1], unpack_q4(q[126]).v[1], unpack_q4(q[127]).v[1],
 378                GGML_FP16_TO_FP32(d[4]), GGML_FP16_TO_FP32(d[5]), GGML_FP16_TO_FP32(d[6]), GGML_FP16_TO_FP32(d[7]));
 379}
 380
 381static void unpack_q4_0_quants(uint8_t * qs, const block_q4_0 * x, unsigned int bi) {
 382    static const int qk = QK4_0;
 383
 384    for (unsigned int i = 0; i < qk / 2; ++i) {
 385        const int x0             = (x->qs[i] & 0x0F);
 386        const int x1             = (x->qs[i] >> 4);
 387        qs[bi * qk + i + 0]      = x0;
 388        qs[bi * qk + i + qk / 2] = x1;
 389    }
 390}
 391
 392static void pack_q4_0_quants(block_q4_0 * x, const uint8_t * qs, unsigned int bi) {
 393    static const int qk = QK4_0;
 394
 395    for (unsigned int i = 0; i < qk / 2; ++i) {
 396        const uint8_t x0 = qs[bi * qk + i + 0];
 397        const uint8_t x1 = qs[bi * qk + i + qk / 2];
 398        x->qs[i]         = x0 | (x1 << 4);
 399    }
 400}
 401
 402static void repack_row_q4x4x2(uint8_t * y, const block_q4_0 * x, int64_t k) {
 403    static const int qk = QK_Q4_0x4x2;
 404    const int        nb = (k + qk - 1) / qk;  // number of blocks (padded)
 405
 406    const int dblk_size = 8 * 2;              // 8x __fp16
 407    const int qblk_size = qk / 2;             // int4
 408    const int qrow_size = k / 2;              // int4 (not padded to blocks)
 409
 410    uint8_t * y_q = y + 0;                    // quants first
 411    uint8_t * y_d = y + qrow_size;            // then scales
 412
 413    if (opt_verbose > 2) {
 414        for (int i = 0; i < nb; i++) {
 415            dump_block_q4_0(&x[i * 8 + 0], 0);
 416            dump_block_q4_0(&x[i * 8 + 1], 1);
 417            dump_block_q4_0(&x[i * 8 + 2], 2);
 418            dump_block_q4_0(&x[i * 8 + 3], 3);
 419            dump_block_q4_0(&x[i * 8 + 4], 4);
 420            dump_block_q4_0(&x[i * 8 + 5], 5);
 421            dump_block_q4_0(&x[i * 8 + 6], 6);
 422            dump_block_q4_0(&x[i * 8 + 7], 7);
 423        }
 424    }
 425
 426    // Repack the quants
 427    for (int i = 0; i < nb; i++) {
 428        uint8_t qs[QK_Q4_0x4x2];  // unpacked quants
 429        unpack_q4_0_quants(qs, &x[i * 8 + 0], 0);
 430        unpack_q4_0_quants(qs, &x[i * 8 + 1], 1);
 431        unpack_q4_0_quants(qs, &x[i * 8 + 2], 2);
 432        unpack_q4_0_quants(qs, &x[i * 8 + 3], 3);
 433        unpack_q4_0_quants(qs, &x[i * 8 + 4], 4);
 434        unpack_q4_0_quants(qs, &x[i * 8 + 5], 5);
 435        unpack_q4_0_quants(qs, &x[i * 8 + 6], 6);
 436        unpack_q4_0_quants(qs, &x[i * 8 + 7], 7);
 437
 438        uint8_t * q = y_q + (i * qblk_size);
 439        for (int j = 0; j < qk / 2; j++) {
 440            q[j] = (qs[j + 128] << 4) | qs[j];
 441        }
 442    }
 443
 444    // Repack the scales
 445    // Note: Do not combine with the loop above. For tensor sizes not multiple of 256 (QK_Q4_0x4x2)
 446    // the last block is truncated and overriden by the scales.
 447    for (int i = 0; i < nb; i++) {
 448        // Repack the scales
 449        ggml_half * d = (ggml_half *) (y_d + i * dblk_size);
 450        d[0]          = x[i * 8 + 0].d;
 451        d[1]          = x[i * 8 + 1].d;
 452        d[2]          = x[i * 8 + 2].d;
 453        d[3]          = x[i * 8 + 3].d;
 454        d[4]          = x[i * 8 + 4].d;
 455        d[5]          = x[i * 8 + 5].d;
 456        d[6]          = x[i * 8 + 6].d;
 457        d[7]          = x[i * 8 + 7].d;
 458    }
 459
 460    if (opt_verbose > 1) {
 461        for (int i = 0; i < nb; i++) {
 462            dump_packed_block_q4x4x2(y, i, k);
 463        }
 464    }
 465}
 466
 467static void unpack_row_q4x4x2(block_q4_0 * x, const uint8_t * y, int64_t k) {
 468    static const int qk = QK_Q4_0x4x2;
 469    const int        nb = (k + qk - 1) / qk;  // number of blocks (padded)
 470
 471    const int dblk_size = 8 * 2;              // 8x __fp16
 472    const int qblk_size = qk / 2;             // int4
 473    const int qrow_size = k / 2;              // int4 (not padded to blocks)
 474
 475    const uint8_t * y_q = y + 0;              // quants first
 476    const uint8_t * y_d = y + qrow_size;      // then scales
 477
 478    if (opt_verbose > 1) {
 479        for (int i = 0; i < nb; i++) {
 480            dump_packed_block_q4x4x2(y, i, k);
 481        }
 482    }
 483
 484    // Unpack the quants
 485    for (int i = 0; i < nb; i++) {
 486        uint8_t qs[QK_Q4_0x4x2];  // unpacked quants
 487
 488        const uint8_t * q = y_q + (i * qblk_size);
 489        for (int j = 0; j < qk / 2; j++) {
 490            qs[j]       = q[j] & 0xf;
 491            qs[j + 128] = q[j] >> 4;
 492        }
 493
 494        pack_q4_0_quants(&x[i * 8 + 0], qs, 0);
 495        pack_q4_0_quants(&x[i * 8 + 1], qs, 1);
 496        pack_q4_0_quants(&x[i * 8 + 2], qs, 2);
 497        pack_q4_0_quants(&x[i * 8 + 3], qs, 3);
 498        pack_q4_0_quants(&x[i * 8 + 4], qs, 4);
 499        pack_q4_0_quants(&x[i * 8 + 5], qs, 5);
 500        pack_q4_0_quants(&x[i * 8 + 6], qs, 6);
 501        pack_q4_0_quants(&x[i * 8 + 7], qs, 7);
 502    }
 503
 504    // Repack the scales
 505    // Note: Do not combine with the loop above. For tensor sizes not multiple of 256 (QK_Q4_0x4x2)
 506    // the last block is truncated and overriden by the scales.
 507    for (int i = 0; i < nb; i++) {
 508        // Unpack the scales
 509        const ggml_half * d = (const ggml_half *) (y_d + i * dblk_size);
 510        x[i * 8 + 0].d      = d[0];
 511        x[i * 8 + 1].d      = d[1];
 512        x[i * 8 + 2].d      = d[2];
 513        x[i * 8 + 3].d      = d[3];
 514        x[i * 8 + 4].d      = d[4];
 515        x[i * 8 + 5].d      = d[5];
 516        x[i * 8 + 6].d      = d[6];
 517        x[i * 8 + 7].d      = d[7];
 518    }
 519
 520    if (opt_verbose > 2) {
 521        for (int i = 0; i < nb; i++) {
 522            dump_block_q4_0(&x[i * 8 + 0], 0);
 523            dump_block_q4_0(&x[i * 8 + 1], 1);
 524            dump_block_q4_0(&x[i * 8 + 2], 2);
 525            dump_block_q4_0(&x[i * 8 + 3], 3);
 526            dump_block_q4_0(&x[i * 8 + 4], 4);
 527            dump_block_q4_0(&x[i * 8 + 5], 5);
 528            dump_block_q4_0(&x[i * 8 + 6], 6);
 529            dump_block_q4_0(&x[i * 8 + 7], 7);
 530        }
 531    }
 532}
 533
 534static void init_row_q4x4x2(block_q4_0 * x, int64_t k) {
 535    static const int qk = QK_Q4_0x4x2;
 536    const int        nb = (k + qk - 1) / qk;  // number of blocks (padded)
 537
 538    // Init the quants such that they unpack into zeros
 539    uint8_t qs[QK_Q4_0x4x2];  // unpacked quants
 540    memset(qs, 8, sizeof(qs));
 541
 542    for (int i = 0; i < nb; i++) {
 543        pack_q4_0_quants(&x[i * 8 + 0], qs, 0);
 544        pack_q4_0_quants(&x[i * 8 + 1], qs, 1);
 545        pack_q4_0_quants(&x[i * 8 + 2], qs, 2);
 546        pack_q4_0_quants(&x[i * 8 + 3], qs, 3);
 547        pack_q4_0_quants(&x[i * 8 + 4], qs, 4);
 548        pack_q4_0_quants(&x[i * 8 + 5], qs, 5);
 549        pack_q4_0_quants(&x[i * 8 + 6], qs, 6);
 550        pack_q4_0_quants(&x[i * 8 + 7], qs, 7);
 551    }
 552
 553    // Init the scales
 554    // Note: Do not combine with the loop above. For tensor sizes not multiple of 256 (QK_Q4_0x4x2)
 555    // the last block is truncated and overriden by the scales.
 556    for (int i = 0; i < nb; i++) {
 557        // Unpack the scales
 558        x[i * 8 + 0].d = 0;
 559        x[i * 8 + 1].d = 0;
 560        x[i * 8 + 2].d = 0;
 561        x[i * 8 + 3].d = 0;
 562        x[i * 8 + 4].d = 0;
 563        x[i * 8 + 5].d = 0;
 564        x[i * 8 + 6].d = 0;
 565        x[i * 8 + 7].d = 0;
 566    }
 567}
 568
 569// repack q4_0 data into q4x4x2 tensor
 570static void repack_q4_0_q4x4x2(ggml_tensor * t, const void * data, size_t size) {
 571    int64_t nrows = ggml_nrows(t);
 572
 573    size_t row_size    = ggml_row_size(t->type, t->ne[0]);
 574    size_t row_size_pd = ggml_row_size(t->type, hex_round_up(t->ne[0], QK_Q4_0x4x2));  // extra elements for the pad
 575    size_t row_size_rp = row_size * 2;  // extra space for tmp pad (if any)
 576
 577    // Ensure we don't try to read more data than is available in the source buffer 'data'
 578    // or write more than the tensor can hold.
 579    const size_t total_tensor_size = (size_t)nrows * row_size;
 580    const size_t n_bytes_to_copy = size < total_tensor_size ? size : total_tensor_size;
 581
 582    // Calculate how many full rows and how many remaining bytes we need to process.
 583    const int64_t n_full_rows = n_bytes_to_copy / row_size;
 584    const size_t  n_rem_bytes = n_bytes_to_copy % row_size;
 585
 586    void * buf_pd = ggml_aligned_malloc(row_size_pd);
 587    GGML_ASSERT(buf_pd != NULL);
 588
 589    void * buf_rp = ggml_aligned_malloc(row_size_rp);
 590    GGML_ASSERT(buf_rp != NULL);
 591
 592    HEX_VERBOSE("ggml-hex: repack-q4_0-q4x4x2 %s : data %p size %zu dims %ldx%ld row-size %zu\n", t->name, data, size,
 593                t->ne[0], nrows, row_size);
 594
 595    init_row_q4x4x2((block_q4_0 *) buf_pd, t->ne[0]);  // init padded buffer to make sure the tail is all zeros
 596
 597    // 1. Process all the full rows
 598    for (int64_t i = 0; i < n_full_rows; i++) {
 599        const uint8_t * src = (const uint8_t *) data + (i * row_size);
 600        uint8_t *       dst = (uint8_t *) t->data + (i * row_size);
 601
 602        memcpy(buf_pd, src, row_size);
 603        repack_row_q4x4x2((uint8_t *) buf_rp, (const block_q4_0 *) buf_pd, t->ne[0]);
 604        memcpy(dst, buf_rp, row_size);
 605    }
 606
 607    // 2. Process the final, potentially partial, row
 608    if (n_rem_bytes > 0) {
 609        const int64_t i = n_full_rows;
 610        const uint8_t * src = (const uint8_t *) data + (i * row_size);
 611        uint8_t *       dst = (uint8_t *) t->data + (i * row_size);
 612
 613        // re-init the row because we are potentially copying a partial row
 614        init_row_q4x4x2((block_q4_0 *) buf_pd, t->ne[0]);
 615
 616        // Copy only the remaining bytes from the source.
 617        memcpy(buf_pd, src, n_rem_bytes);
 618
 619        // Repack the entire buffer
 620        repack_row_q4x4x2((uint8_t *) buf_rp, (const block_q4_0 *) buf_pd, t->ne[0]);
 621
 622        // Write only the corresponding remaining bytes to the destination tensor.
 623        memcpy(dst, buf_rp, n_rem_bytes);
 624    }
 625
 626    ggml_aligned_free(buf_pd, row_size_pd);
 627    ggml_aligned_free(buf_rp, row_size_rp);
 628}
 629
 630// repack q4x4x2 tensor into q4_0 data
 631static void repack_q4x4x2_q4_0(void * data, const ggml_tensor * t, size_t size) {
 632    int64_t nrows = ggml_nrows(t);
 633
 634    size_t row_size    = ggml_row_size(t->type, t->ne[0]);
 635    size_t row_size_pd = ggml_row_size(t->type, hex_round_up(t->ne[0], QK_Q4_0x4x2));  // extra elements for the pad
 636    size_t row_size_rp = row_size * 2;  // extra space for tmp pad (if any)
 637
 638    // Ensure we don't try to copy more data than the tensor actually contains.
 639    const size_t total_tensor_size = (size_t)nrows * row_size;
 640    const size_t n_bytes_to_copy = size < total_tensor_size ? size : total_tensor_size;
 641
 642    // Calculate how many full rows and how many remaining bytes we need to process.
 643    const int64_t n_full_rows = n_bytes_to_copy / row_size;
 644    const size_t  n_rem_bytes = n_bytes_to_copy % row_size;
 645
 646    void * buf_pd = ggml_aligned_malloc(row_size_pd);
 647    GGML_ASSERT(buf_pd != NULL);
 648
 649    void * buf_rp = ggml_aligned_malloc(row_size_rp);
 650    GGML_ASSERT(buf_rp != NULL);
 651
 652    HEX_VERBOSE("ggml-hex: repack-q4x4x2-q4_0 %s : data %p size %zu dims %ldx%ld row-size %zu\n", t->name, data, size,
 653                t->ne[0], nrows, row_size);
 654
 655    memset(buf_pd, 0, row_size_pd);  // clear-out padded buffer to make sure the tail is all zeros
 656
 657    // 1. Process all the full rows
 658    for (int64_t i = 0; i < n_full_rows; i++) {
 659        const uint8_t * src = (const uint8_t *) t->data + (i * row_size);
 660        uint8_t *       dst = (uint8_t *) data + (i * row_size);
 661
 662        memcpy(buf_pd, src, row_size);
 663        unpack_row_q4x4x2((block_q4_0 *) buf_rp, (const uint8_t *) buf_pd, t->ne[0]);
 664        memcpy(dst, buf_rp, row_size);
 665    }
 666
 667    // 2. Process the final, potentially partial, row
 668    if (n_rem_bytes > 0) {
 669        const int64_t i = n_full_rows;
 670        const uint8_t * src = (const uint8_t *) t->data + (i * row_size);
 671        uint8_t *       dst = (uint8_t *) data + (i * row_size);
 672
 673        // We still need to read and unpack the entire source row because quantization is block-based.
 674        memcpy(buf_pd, src, row_size);
 675        unpack_row_q4x4x2((block_q4_0 *) buf_rp, (const uint8_t *) buf_pd, t->ne[0]);
 676
 677        // But we only copy the remaining number of bytes to the destination.
 678        memcpy(dst, buf_rp, n_rem_bytes);
 679    }
 680
 681    ggml_aligned_free(buf_pd, row_size_pd);
 682    ggml_aligned_free(buf_rp, row_size_rp);
 683}
 684
 685// ======== Q8x4x2 ====================
 686static void dump_block_q8_0(const block_q8_0 * b, int i) {
 687    HEX_VERBOSE("ggml-hex: repack q8_0 %d: %d %d %d %d ... %d %d %d %d : %.6f\n", i, b->qs[0], b->qs[1], b->qs[2],
 688                b->qs[3], b->qs[28], b->qs[29], b->qs[30], b->qs[31], GGML_FP16_TO_FP32(b->d));
 689}
 690
 691static void dump_packed_block_q8x4x2(const uint8_t * v, unsigned int i, size_t k) {
 692    static const int qk        = QK_Q8_0x4x2;
 693    const int        dblk_size = 8 * 2;   // 8x __fp16
 694    const int        qblk_size = qk;      // int8
 695    const int        qrow_size = k;       // int8 (not padded)
 696
 697    const uint8_t * v_q = v + 0;          // quants first
 698    const uint8_t * v_d = v + qrow_size;  // then scales
 699
 700    const uint8_t *   q = v_q + i * qblk_size;
 701    const ggml_half * d = (const ggml_half *) (v_d + i * dblk_size);
 702
 703    HEX_VERBOSE("ggml-hex: repack q8x4x2-%d: %d %d %d %d ... %d %d %d %d ... %d %d %d %d : %.6f %.6f %.6f %.6f\n", i,
 704                q[0], q[1], q[2], q[3], q[60], q[61], q[62], q[63], q[124], q[125], q[126], q[127],
 705                GGML_FP16_TO_FP32(d[0]), GGML_FP16_TO_FP32(d[1]), GGML_FP16_TO_FP32(d[2]), GGML_FP16_TO_FP32(d[3]));
 706
 707    HEX_VERBOSE("ggml-hex: repack q8x4x2-%d: %d %d %d %d ... %d %d %d %d ... %d %d %d %d : %.6f %.6f %.6f %.6f\n",
 708                i + 1, q[128], q[129], q[130], q[131], q[192], q[193], q[194], q[195], q[252], q[253], q[254], q[255],
 709                GGML_FP16_TO_FP32(d[4]), GGML_FP16_TO_FP32(d[5]), GGML_FP16_TO_FP32(d[6]), GGML_FP16_TO_FP32(d[7]));
 710}
 711
 712static void unpack_q8_0_quants(uint8_t * qs, const block_q8_0 * x, unsigned int bi) {
 713    static const int qk = QK8_0;
 714
 715    for (unsigned int i = 0; i < qk; ++i) {
 716        qs[bi * qk + i] = x->qs[i];
 717    }
 718}
 719
 720static void pack_q8_0_quants(block_q8_0 * x, const uint8_t * qs, unsigned int bi) {
 721    static const int qk = QK8_0;
 722
 723    for (unsigned int i = 0; i < qk; ++i) {
 724        x->qs[i] = qs[bi * qk + i];
 725    }
 726}
 727
 728static void repack_row_q8x4x2(uint8_t * y, const block_q8_0 * x, int64_t k) {
 729    static const int qk = QK_Q8_0x4x2;
 730    const int        nb = (k + qk - 1) / qk;  // number of blocks (padded)
 731
 732    const int dblk_size = 8 * 2;              // 8x __fp16
 733    const int qblk_size = qk;                 // int8
 734    const int qrow_size = k;                  // int8 (not padded to blocks)
 735
 736    uint8_t * y_q = y + 0;                    // quants first
 737    uint8_t * y_d = y + qrow_size;            // then scales
 738
 739    if (opt_verbose > 2) {
 740        for (int i = 0; i < nb; i++) {
 741            dump_block_q8_0(&x[i * 8 + 0], 0);
 742            dump_block_q8_0(&x[i * 8 + 1], 1);
 743            dump_block_q8_0(&x[i * 8 + 2], 2);
 744            dump_block_q8_0(&x[i * 8 + 3], 3);
 745            dump_block_q8_0(&x[i * 8 + 4], 4);
 746            dump_block_q8_0(&x[i * 8 + 5], 5);
 747            dump_block_q8_0(&x[i * 8 + 6], 6);
 748            dump_block_q8_0(&x[i * 8 + 7], 7);
 749        }
 750    }
 751
 752    // Repack the quants
 753    for (int i = 0; i < nb; i++) {
 754        uint8_t qs[QK_Q8_0x4x2];  // unpacked quants
 755
 756        unpack_q8_0_quants(qs, &x[i * 8 + 0], 0);
 757        unpack_q8_0_quants(qs, &x[i * 8 + 1], 1);
 758        unpack_q8_0_quants(qs, &x[i * 8 + 2], 2);
 759        unpack_q8_0_quants(qs, &x[i * 8 + 3], 3);
 760        unpack_q8_0_quants(qs, &x[i * 8 + 4], 4);
 761        unpack_q8_0_quants(qs, &x[i * 8 + 5], 5);
 762        unpack_q8_0_quants(qs, &x[i * 8 + 6], 6);
 763        unpack_q8_0_quants(qs, &x[i * 8 + 7], 7);
 764
 765        uint8_t * q = y_q + (i * qblk_size);
 766        for (int j = 0; j < qk; j++) {
 767            q[j] = qs[j];
 768        }
 769    }
 770
 771    // Repack the scales
 772    // Note: Do not combine with the loop above. For tensor sizes not multiple of 256 (QK_Q4_0x4x2)
 773    // the last block is truncated and overriden by the scales.
 774    for (int i = 0; i < nb; i++) {
 775        // Repack the scales
 776        ggml_half * d = (ggml_half *) (y_d + i * dblk_size);
 777        d[0]          = x[i * 8 + 0].d;
 778        d[1]          = x[i * 8 + 1].d;
 779        d[2]          = x[i * 8 + 2].d;
 780        d[3]          = x[i * 8 + 3].d;
 781        d[4]          = x[i * 8 + 4].d;
 782        d[5]          = x[i * 8 + 5].d;
 783        d[6]          = x[i * 8 + 6].d;
 784        d[7]          = x[i * 8 + 7].d;
 785    }
 786
 787    if (opt_verbose > 1) {
 788        for (int i = 0; i < nb; i++) {
 789            dump_packed_block_q8x4x2(y, i, k);
 790        }
 791    }
 792}
 793
 794static void unpack_row_q8x4x2(block_q8_0 * x, const uint8_t * y, int64_t k) {
 795    static const int qk = QK_Q8_0x4x2;
 796    const int        nb = (k + qk - 1) / qk;  // number of blocks (padded)
 797
 798    const int dblk_size = 8 * 2;              // 8x __fp16
 799    const int qblk_size = qk;                 // int8
 800    const int qrow_size = k;                  // int8 (not padded to blocks)
 801
 802    const uint8_t * y_q = y + 0;              // quants first
 803    const uint8_t * y_d = y + qrow_size;      // then scales
 804
 805    if (opt_verbose > 1) {
 806        for (int i = 0; i < nb; i++) {
 807            dump_packed_block_q8x4x2(y, i, k);
 808        }
 809    }
 810
 811    // Unpack the quants
 812    for (int i = 0; i < nb; i++) {
 813        uint8_t qs[QK_Q4_0x4x2];  // unpacked quants
 814
 815        const uint8_t * q = y_q + (i * qblk_size);
 816        for (int j = 0; j < qk; j++) {
 817            qs[j] = q[j];
 818        }
 819
 820        pack_q8_0_quants(&x[i * 8 + 0], qs, 0);
 821        pack_q8_0_quants(&x[i * 8 + 1], qs, 1);
 822        pack_q8_0_quants(&x[i * 8 + 2], qs, 2);
 823        pack_q8_0_quants(&x[i * 8 + 3], qs, 3);
 824        pack_q8_0_quants(&x[i * 8 + 4], qs, 4);
 825        pack_q8_0_quants(&x[i * 8 + 5], qs, 5);
 826        pack_q8_0_quants(&x[i * 8 + 6], qs, 6);
 827        pack_q8_0_quants(&x[i * 8 + 7], qs, 7);
 828    }
 829
 830    // Repack the scales
 831    // Note: Do not combine with the loop above. For tensor sizes not multiple of 256 (QK_Q4_0x4x2)
 832    // the last block is truncated and overriden by the scales.
 833    for (int i = 0; i < nb; i++) {
 834        // Unpack the scales
 835        const ggml_half * d = (const ggml_half *) (y_d + i * dblk_size);
 836        x[i * 8 + 0].d      = d[0];
 837        x[i * 8 + 1].d      = d[1];
 838        x[i * 8 + 2].d      = d[2];
 839        x[i * 8 + 3].d      = d[3];
 840        x[i * 8 + 4].d      = d[4];
 841        x[i * 8 + 5].d      = d[5];
 842        x[i * 8 + 6].d      = d[6];
 843        x[i * 8 + 7].d      = d[7];
 844    }
 845
 846    if (opt_verbose > 2) {
 847        for (int i = 0; i < nb; i++) {
 848            dump_block_q8_0(&x[i * 8 + 0], 0);
 849            dump_block_q8_0(&x[i * 8 + 1], 1);
 850            dump_block_q8_0(&x[i * 8 + 2], 2);
 851            dump_block_q8_0(&x[i * 8 + 3], 3);
 852            dump_block_q8_0(&x[i * 8 + 4], 4);
 853            dump_block_q8_0(&x[i * 8 + 5], 5);
 854            dump_block_q8_0(&x[i * 8 + 6], 6);
 855            dump_block_q8_0(&x[i * 8 + 7], 7);
 856        }
 857    }
 858}
 859
 860static void init_row_q8x4x2(block_q8_0 * x, int64_t k) {
 861    static const int qk = QK_Q8_0x4x2;
 862    const int        nb = (k + qk - 1) / qk;  // number of blocks (padded)
 863
 864    // Init the quants such that they unpack into zeros
 865    uint8_t qs[QK_Q8_0x4x2];  // unpacked quants
 866    memset(qs, 0, sizeof(qs));
 867
 868    for (int i = 0; i < nb; i++) {
 869        pack_q8_0_quants(&x[i * 8 + 0], qs, 0);
 870        pack_q8_0_quants(&x[i * 8 + 1], qs, 1);
 871        pack_q8_0_quants(&x[i * 8 + 2], qs, 2);
 872        pack_q8_0_quants(&x[i * 8 + 3], qs, 3);
 873        pack_q8_0_quants(&x[i * 8 + 4], qs, 4);
 874        pack_q8_0_quants(&x[i * 8 + 5], qs, 5);
 875        pack_q8_0_quants(&x[i * 8 + 6], qs, 6);
 876        pack_q8_0_quants(&x[i * 8 + 7], qs, 7);
 877    }
 878
 879    // Init the scales
 880    // Note: Do not combine with the loop above. For tensor sizes not multiple of 256 (QK_Q8_0x4x2)
 881    // the last block is truncated and overriden by the scales.
 882    for (int i = 0; i < nb; i++) {
 883        // Unpack the scales
 884        x[i * 8 + 0].d = 0;
 885        x[i * 8 + 1].d = 0;
 886        x[i * 8 + 2].d = 0;
 887        x[i * 8 + 3].d = 0;
 888        x[i * 8 + 4].d = 0;
 889        x[i * 8 + 5].d = 0;
 890        x[i * 8 + 6].d = 0;
 891        x[i * 8 + 7].d = 0;
 892    }
 893}
 894
 895// repack q8_0 data into q8x4x2 tensor
 896static void repack_q8_0_q8x4x2(ggml_tensor * t, const void * data, size_t size) {
 897    int64_t nrows = ggml_nrows(t);
 898
 899    size_t row_size    = ggml_row_size(t->type, t->ne[0]);
 900    size_t row_size_pd = ggml_row_size(t->type, hex_round_up(t->ne[0], QK_Q8_0x4x2));  // extra elements for the pad
 901    size_t row_size_rp = row_size * 2;  // extra space for tmp pad (if any)
 902
 903    // Ensure we don't try to read more data than is available in the source buffer 'data'
 904    // or write more than the tensor can hold.
 905    const size_t total_tensor_size = (size_t)nrows * row_size;
 906    const size_t n_bytes_to_copy = size < total_tensor_size ? size : total_tensor_size;
 907
 908    // Calculate how many full rows and how many remaining bytes we need to process.
 909    const int64_t n_full_rows = n_bytes_to_copy / row_size;
 910    const size_t  n_rem_bytes = n_bytes_to_copy % row_size;
 911
 912    void * buf_pd = ggml_aligned_malloc(row_size_pd);
 913    GGML_ASSERT(buf_pd != NULL);
 914
 915    void * buf_rp = ggml_aligned_malloc(row_size_rp);
 916    GGML_ASSERT(buf_rp != NULL);
 917
 918    HEX_VERBOSE("ggml-hex: repack-q8_0-q8x4x2 %s : data %p size %zu dims %ldx%ld row-size %zu\n", t->name, data, size,
 919                t->ne[0], nrows, row_size);
 920
 921    init_row_q8x4x2((block_q8_0 *) buf_pd, t->ne[0]);  // init padded buffer to make sure the tail is all zeros
 922
 923    // 1. Process all the full rows
 924    for (int64_t i = 0; i < n_full_rows; i++) {
 925        const uint8_t * src = (const uint8_t *) data + (i * row_size);
 926        uint8_t *       dst = (uint8_t *) t->data + (i * row_size);
 927
 928        memcpy(buf_pd, src, row_size);
 929        repack_row_q8x4x2((uint8_t *) buf_rp, (const block_q8_0 *) buf_pd, t->ne[0]);
 930        memcpy(dst, buf_rp, row_size);
 931    }
 932
 933    // 2. Process the final, potentially partial, row
 934    if (n_rem_bytes > 0) {
 935        const int64_t i = n_full_rows;
 936        const uint8_t * src = (const uint8_t *) data + (i * row_size);
 937        uint8_t *       dst = (uint8_t *) t->data + (i * row_size);
 938
 939        // re-init the row because we are potentially copying a partial row
 940        init_row_q8x4x2((block_q8_0 *) buf_pd, t->ne[0]);
 941
 942        // Copy only the remaining bytes from the source.
 943        memcpy(buf_pd, src, n_rem_bytes);
 944
 945        // Repack the entire buffer
 946        repack_row_q8x4x2((uint8_t *) buf_rp, (const block_q8_0 *) buf_pd, t->ne[0]);
 947
 948        // Write only the corresponding remaining bytes to the destination tensor.
 949        memcpy(dst, buf_rp, n_rem_bytes);
 950    }
 951
 952    ggml_aligned_free(buf_pd, row_size_pd);
 953    ggml_aligned_free(buf_rp, row_size_rp);
 954}
 955
 956// repack q8x4x2 tensor into q8_0 data
 957static void repack_q8x4x2_q8_0(void * data, const ggml_tensor * t, size_t size) {
 958    int64_t nrows = ggml_nrows(t);
 959
 960    size_t row_size    = ggml_row_size(t->type, t->ne[0]);
 961    size_t row_size_pd = ggml_row_size(t->type, hex_round_up(t->ne[0], QK_Q8_0x4x2));  // extra elements for the pad
 962    size_t row_size_rp = row_size * 2;  // extra space for tmp pad (if any)
 963
 964    // Ensure we don't try to copy more data than the tensor actually contains.
 965    const size_t total_tensor_size = (size_t)nrows * row_size;
 966    const size_t n_bytes_to_copy = size < total_tensor_size ? size : total_tensor_size;
 967
 968    // Calculate how many full rows and how many remaining bytes we need to process.
 969    const int64_t n_full_rows = n_bytes_to_copy / row_size;
 970    const size_t  n_rem_bytes = n_bytes_to_copy % row_size;
 971
 972    void * buf_pd = ggml_aligned_malloc(row_size_pd);
 973    GGML_ASSERT(buf_pd != NULL);
 974
 975    void * buf_rp = ggml_aligned_malloc(row_size_rp);
 976    GGML_ASSERT(buf_rp != NULL);
 977
 978    HEX_VERBOSE("ggml-hex: repack-q8x4x2-q8_0 %s : data %p size %zu dims %ldx%ld row-size %zu\n", t->name, data, size,
 979                t->ne[0], nrows, row_size);
 980
 981    memset(buf_pd, 0, row_size_pd);  // clear-out padded buffer to make sure the tail is all zeros
 982
 983    // 1. Process all the full rows
 984    for (int64_t i = 0; i < n_full_rows; i++) {
 985        const uint8_t * src = (const uint8_t *) t->data + (i * row_size);
 986        uint8_t *       dst = (uint8_t *) data + (i * row_size);
 987
 988        memcpy(buf_pd, src, row_size);
 989        unpack_row_q8x4x2((block_q8_0 *) buf_rp, (const uint8_t *) buf_pd, t->ne[0]);
 990        memcpy(dst, buf_rp, row_size);
 991    }
 992
 993    // 2. Process the final, potentially partial, row
 994    if (n_rem_bytes > 0) {
 995        const int64_t i = n_full_rows;
 996        const uint8_t * src = (const uint8_t *) t->data + (i * row_size);
 997        uint8_t *       dst = (uint8_t *) data + (i * row_size);
 998
 999        // We still need to read and unpack the entire source row because quantization is block-based.
1000        memcpy(buf_pd, src, row_size);
1001        unpack_row_q8x4x2((block_q8_0 *) buf_rp, (const uint8_t *) buf_pd, t->ne[0]);
1002
1003        // But we only copy the remaining number of bytes to the destination.
1004        memcpy(dst, buf_rp, n_rem_bytes);
1005    }
1006
1007    ggml_aligned_free(buf_pd, row_size_pd);
1008    ggml_aligned_free(buf_rp, row_size_rp);
1009}
1010
1011// ======== MXFP4x4x2 ====================
// Pair of values decoded from one packed MXFP4 byte by unpack_mxfp4():
// v[0] comes from the low nibble, v[1] from the high nibble.
struct x2_mxfp4 {
    int v[2];
};
1015
1016static x2_mxfp4 unpack_mxfp4(uint8_t v) {
1017    x2_mxfp4 x;
1018    x.v[0] = kvalues_mxfp4[(v & 0x0f)];
1019    x.v[1] = kvalues_mxfp4[(v >> 4)];
1020    return x;
1021}
1022
1023static void dump_block_mxfp4(const block_mxfp4 * b, int i) {
1024    HEX_VERBOSE("ggml-hex: repack mxfp4 %d: %d %d %d %d ... %d %d %d %d : %.6f\n", i, unpack_mxfp4(b->qs[0]).v[0],
1025                unpack_mxfp4(b->qs[1]).v[0], unpack_mxfp4(b->qs[2]).v[0], unpack_mxfp4(b->qs[3]).v[0],
1026                unpack_mxfp4(b->qs[12]).v[1], unpack_mxfp4(b->qs[13]).v[1], unpack_mxfp4(b->qs[14]).v[1],
1027                unpack_mxfp4(b->qs[15]).v[1], GGML_E8M0_TO_FP32_HALF(b->e));
1028}
1029
1030static void dump_packed_block_mxfp4x4x2(const uint8_t * v, unsigned int i, size_t k) {
1031    static const int qk        = QK_MXFP4x4x2;
1032    const int        eblk_size = 8 * 1;   // 8x E8M0
1033    const int        qblk_size = qk / 2;  // int4
1034    const int        qrow_size = k / 2;   // int4 (not padded)
1035
1036    const uint8_t * v_q = v + 0;          // quants first
1037    const uint8_t * v_e = v + qrow_size;  // then scales
1038
1039    const uint8_t * q = v_q + i * qblk_size;
1040    const uint8_t * e = (const uint8_t *) (v_e + i * eblk_size);
1041
1042    HEX_VERBOSE("ggml-hex: repack mxfp4x4x2-%d: %d %d %d %d ... %d %d %d %d ... %d %d %d %d : %.6f %.6f %.6f %.6f\n", i,
1043                unpack_mxfp4(q[0]).v[0], unpack_mxfp4(q[1]).v[0], unpack_mxfp4(q[2]).v[0], unpack_mxfp4(q[3]).v[0],
1044                unpack_mxfp4(q[60]).v[0], unpack_mxfp4(q[61]).v[0], unpack_mxfp4(q[62]).v[0], unpack_mxfp4(q[63]).v[0],
1045                unpack_mxfp4(q[124]).v[0], unpack_mxfp4(q[125]).v[0], unpack_mxfp4(q[126]).v[0],
1046                unpack_mxfp4(q[127]).v[0], GGML_E8M0_TO_FP32_HALF(e[0]), GGML_E8M0_TO_FP32_HALF(e[1]),
1047                GGML_E8M0_TO_FP32_HALF(e[2]), GGML_E8M0_TO_FP32_HALF(e[3]));
1048
1049    HEX_VERBOSE("ggml-hex: repack mxfp4x4x2-%d: %d %d %d %d ... %d %d %d %d ... %d %d %d %d : %.6f %.6f %.6f %.6f\n",
1050                i + 1, unpack_mxfp4(q[0]).v[1], unpack_mxfp4(q[1]).v[1], unpack_mxfp4(q[2]).v[1],
1051                unpack_mxfp4(q[3]).v[1], unpack_mxfp4(q[60]).v[1], unpack_mxfp4(q[61]).v[1], unpack_mxfp4(q[62]).v[1],
1052                unpack_mxfp4(q[63]).v[1], unpack_mxfp4(q[124]).v[1], unpack_mxfp4(q[125]).v[1],
1053                unpack_mxfp4(q[126]).v[1], unpack_mxfp4(q[127]).v[1], GGML_E8M0_TO_FP32_HALF(e[4]),
1054                GGML_E8M0_TO_FP32_HALF(e[5]), GGML_E8M0_TO_FP32_HALF(e[6]), GGML_E8M0_TO_FP32_HALF(e[7]));
1055}
1056
1057static void unpack_mxfp4_quants(uint8_t * qs, const block_mxfp4 * x, unsigned int bi) {
1058    static const int qk = QK_MXFP4;
1059
1060    for (unsigned int i = 0; i < qk / 2; ++i) {
1061        const uint8_t x0         = (x->qs[i] & 0x0F);
1062        const uint8_t x1         = (x->qs[i] >> 4);
1063        qs[bi * qk + i + 0]      = x0;
1064        qs[bi * qk + i + qk / 2] = x1;
1065    }
1066}
1067
1068static void pack_mxfp4_quants(block_mxfp4 * x, const uint8_t * qs, unsigned int bi) {
1069    static const int qk = QK4_0;
1070
1071    for (unsigned int i = 0; i < qk / 2; ++i) {
1072        const uint8_t x0 = qs[bi * qk + i + 0];
1073        const uint8_t x1 = qs[bi * qk + i + qk / 2];
1074        x->qs[i]         = x0 | (x1 << 4);
1075    }
1076}
1077
1078static void repack_row_mxfp4x4x2(uint8_t * y, const block_mxfp4 * x, int64_t k) {
1079    static const int qk = QK_MXFP4x4x2;
1080    const int        nb = (k + qk - 1) / qk;  // number of blocks (padded)
1081
1082    const int eblk_size = 8 * 1;              // 8x E8M0
1083    const int qblk_size = qk / 2;             // int4
1084    const int qrow_size = k / 2;              // int4 (not padded to blocks)
1085
1086    uint8_t * y_q = y + 0;                    // quants first
1087    uint8_t * y_e = y + qrow_size;            // then scales
1088
1089    if (opt_verbose > 2) {
1090        for (int i = 0; i < nb; i++) {
1091            dump_block_mxfp4(&x[i * 8 + 0], 0);
1092            dump_block_mxfp4(&x[i * 8 + 1], 1);
1093            dump_block_mxfp4(&x[i * 8 + 2], 2);
1094            dump_block_mxfp4(&x[i * 8 + 3], 3);
1095            dump_block_mxfp4(&x[i * 8 + 4], 4);
1096            dump_block_mxfp4(&x[i * 8 + 5], 5);
1097            dump_block_mxfp4(&x[i * 8 + 6], 6);
1098            dump_block_mxfp4(&x[i * 8 + 7], 7);
1099        }
1100    }
1101
1102    // Repack the quants
1103    for (int i = 0; i < nb; i++) {
1104        uint8_t qs[QK_MXFP4x4x2];  // unpacked quants
1105
1106        unpack_mxfp4_quants(qs, &x[i * 8 + 0], 0);
1107        unpack_mxfp4_quants(qs, &x[i * 8 + 1], 1);
1108        unpack_mxfp4_quants(qs, &x[i * 8 + 2], 2);
1109        unpack_mxfp4_quants(qs, &x[i * 8 + 3], 3);
1110        unpack_mxfp4_quants(qs, &x[i * 8 + 4], 4);
1111        unpack_mxfp4_quants(qs, &x[i * 8 + 5], 5);
1112        unpack_mxfp4_quants(qs, &x[i * 8 + 6], 6);
1113        unpack_mxfp4_quants(qs, &x[i * 8 + 7], 7);
1114
1115        uint8_t * q = y_q + (i * qblk_size);
1116        for (int j = 0; j < qk / 2; j++) {
1117            q[j] = (qs[j + 128] << 4) | qs[j];
1118        }
1119    }
1120
1121    // Repack the scales
1122    // Note: Do not combine with the loop above. For tensor sizes not multiple of 256 (QK_MXFP4x4x2)
1123    // the last block is truncated and overriden by the scales.
1124    for (int i = 0; i < nb; i++) {
1125        // Repack the scales
1126        uint8_t * e = (uint8_t *) (y_e + i * eblk_size);
1127        e[0]        = x[i * 8 + 0].e;
1128        e[1]        = x[i * 8 + 1].e;
1129        e[2]        = x[i * 8 + 2].e;
1130        e[3]        = x[i * 8 + 3].e;
1131        e[4]        = x[i * 8 + 4].e;
1132        e[5]        = x[i * 8 + 5].e;
1133        e[6]        = x[i * 8 + 6].e;
1134        e[7]        = x[i * 8 + 7].e;
1135    }
1136
1137    if (opt_verbose > 1) {
1138        for (int i = 0; i < nb; i++) {
1139            dump_packed_block_mxfp4x4x2(y, i, k);
1140        }
1141    }
1142}
1143
1144static void unpack_row_mxfp4x4x2(block_mxfp4 * x, const uint8_t * y, int64_t k) {
1145    static const int qk = QK_MXFP4x4x2;
1146    const int        nb = (k + qk - 1) / qk;  // number of blocks (padded)
1147
1148    const int eblk_size = 8 * 1;              // 8x E8M0
1149    const int qblk_size = qk / 2;             // int4
1150    const int qrow_size = k / 2;              // int4 (not padded to blocks)
1151
1152    const uint8_t * y_q = y + 0;              // quants first
1153    const uint8_t * y_e = y + qrow_size;      // then scales
1154
1155    if (opt_verbose > 1) {
1156        for (int i = 0; i < nb; i++) {
1157            dump_packed_block_mxfp4x4x2(y, i, k);
1158        }
1159    }
1160
1161    // Unpack the quants
1162    for (int i = 0; i < nb; i++) {
1163        uint8_t qs[QK_MXFP4x4x2];  // unpacked quants
1164
1165        const uint8_t * q = y_q + (i * qblk_size);
1166        for (int j = 0; j < qk / 2; j++) {
1167            qs[j]       = q[j] & 0xf;
1168            qs[j + 128] = q[j] >> 4;
1169        }
1170
1171        pack_mxfp4_quants(&x[i * 8 + 0], qs, 0);
1172        pack_mxfp4_quants(&x[i * 8 + 1], qs, 1);
1173        pack_mxfp4_quants(&x[i * 8 + 2], qs, 2);
1174        pack_mxfp4_quants(&x[i * 8 + 3], qs, 3);
1175        pack_mxfp4_quants(&x[i * 8 + 4], qs, 4);
1176        pack_mxfp4_quants(&x[i * 8 + 5], qs, 5);
1177        pack_mxfp4_quants(&x[i * 8 + 6], qs, 6);
1178        pack_mxfp4_quants(&x[i * 8 + 7], qs, 7);
1179    }
1180
1181    // Repack the scales
1182    // Note: Do not combine with the loop above. For tensor sizes not multiple of 256 (QK_MXFP4_0x4x2)
1183    // the last block is truncated and overriden by the scales.
1184    for (int i = 0; i < nb; i++) {
1185        // Unpack the scales
1186        const uint8_t * e = (const uint8_t *) (y_e + i * eblk_size);
1187        x[i * 8 + 0].e    = e[0];
1188        x[i * 8 + 1].e    = e[1];
1189        x[i * 8 + 2].e    = e[2];
1190        x[i * 8 + 3].e    = e[3];
1191        x[i * 8 + 4].e    = e[4];
1192        x[i * 8 + 5].e    = e[5];
1193        x[i * 8 + 6].e    = e[6];
1194        x[i * 8 + 7].e    = e[7];
1195    }
1196
1197    if (opt_verbose > 2) {
1198        for (int i = 0; i < nb; i++) {
1199            dump_block_mxfp4(&x[i * 8 + 0], 0);
1200            dump_block_mxfp4(&x[i * 8 + 1], 1);
1201            dump_block_mxfp4(&x[i * 8 + 2], 2);
1202            dump_block_mxfp4(&x[i * 8 + 3], 3);
1203            dump_block_mxfp4(&x[i * 8 + 4], 4);
1204            dump_block_mxfp4(&x[i * 8 + 5], 5);
1205            dump_block_mxfp4(&x[i * 8 + 6], 6);
1206            dump_block_mxfp4(&x[i * 8 + 7], 7);
1207        }
1208    }
1209}
1210
1211static void init_row_mxfp4x4x2(block_mxfp4 * x, int64_t k) {
1212    static const int qk = QK_MXFP4x4x2;
1213    const int        nb = (k + qk - 1) / qk;  // number of blocks (padded)
1214
1215    // Init the quants such that they unpack into zeros
1216    uint8_t qs[QK_MXFP4x4x2];  // unpacked quants
1217    memset(qs, 0, sizeof(qs));
1218
1219    for (int i = 0; i < nb; i++) {
1220        pack_mxfp4_quants(&x[i * 8 + 0], qs, 0);
1221        pack_mxfp4_quants(&x[i * 8 + 1], qs, 1);
1222        pack_mxfp4_quants(&x[i * 8 + 2], qs, 2);
1223        pack_mxfp4_quants(&x[i * 8 + 3], qs, 3);
1224        pack_mxfp4_quants(&x[i * 8 + 4], qs, 4);
1225        pack_mxfp4_quants(&x[i * 8 + 5], qs, 5);
1226        pack_mxfp4_quants(&x[i * 8 + 6], qs, 6);
1227        pack_mxfp4_quants(&x[i * 8 + 7], qs, 7);
1228    }
1229
1230    // Init the scales
1231    // Note: Do not combine with the loop above. For tensor sizes not multiple of 256 (QK_MXFP4x4x2)
1232    // the last block is truncated and overriden by the scales.
1233    for (int i = 0; i < nb; i++) {
1234        // Unpack the scales
1235        x[i * 8 + 0].e = 0;
1236        x[i * 8 + 1].e = 0;
1237        x[i * 8 + 2].e = 0;
1238        x[i * 8 + 3].e = 0;
1239        x[i * 8 + 4].e = 0;
1240        x[i * 8 + 5].e = 0;
1241        x[i * 8 + 6].e = 0;
1242        x[i * 8 + 7].e = 0;
1243    }
1244}
1245
1246// repack mxfp4 data into mxfp4x4x2 tensor
1247static void repack_mxfp4_mxfp4x4x2(ggml_tensor * t, const void * data, size_t size) {
1248    int64_t nrows = ggml_nrows(t);
1249
1250    size_t row_size    = ggml_row_size(t->type, t->ne[0]);
1251    size_t row_size_pd = ggml_row_size(t->type, hex_round_up(t->ne[0], QK_MXFP4x4x2));  // extra elements for the pad
1252    size_t row_size_rp = row_size * 2;  // extra space for tmp pad (if any)
1253
1254    // Ensure we don't try to read more data than is available in the source buffer 'data'
1255    // or write more than the tensor can hold.
1256    const size_t total_tensor_size = (size_t)nrows * row_size;
1257    const size_t n_bytes_to_copy = size < total_tensor_size ? size : total_tensor_size;
1258
1259    // Calculate how many full rows and how many remaining bytes we need to process.
1260    const int64_t n_full_rows = n_bytes_to_copy / row_size;
1261    const size_t  n_rem_bytes = n_bytes_to_copy % row_size;
1262
1263    void * buf_pd = ggml_aligned_malloc(row_size_pd);
1264    GGML_ASSERT(buf_pd != NULL);
1265
1266    void * buf_rp = ggml_aligned_malloc(row_size_rp);
1267    GGML_ASSERT(buf_rp != NULL);
1268
1269    HEX_VERBOSE("ggml-hex: repack-mxfp4-mxfp4x4x2 %s : data %p size %zu dims %ldx%ld row-size %zu\n", t->name, data,
1270                size, t->ne[0], nrows, row_size);
1271
1272    init_row_mxfp4x4x2((block_mxfp4 *) buf_pd, t->ne[0]);  // init padded buffer to make sure the tail is all zeros
1273
1274    // 1. Process all the full rows
1275    for (int64_t i = 0; i < n_full_rows; i++) {
1276        const uint8_t * src = (const uint8_t *) data + (i * row_size);
1277        uint8_t *       dst = (uint8_t *) t->data + (i * row_size);
1278
1279        memcpy(buf_pd, src, row_size);
1280        repack_row_mxfp4x4x2((uint8_t *) buf_rp, (const block_mxfp4 *) buf_pd, t->ne[0]);
1281        memcpy(dst, buf_rp, row_size);
1282    }
1283
1284    // 2. Process the final, potentially partial, row
1285    if (n_rem_bytes > 0) {
1286        const int64_t i = n_full_rows;
1287        const uint8_t * src = (const uint8_t *) data + (i * row_size);
1288        uint8_t *       dst = (uint8_t *) t->data + (i * row_size);
1289
1290        // re-init the row because we are potentially copying a partial row
1291        init_row_mxfp4x4x2((block_mxfp4 *) buf_pd, t->ne[0]);
1292
1293        // Copy only the remaining bytes from the source.
1294        memcpy(buf_pd, src, n_rem_bytes);
1295
1296        // Repack the entire buffer (partial data + zero padding).
1297        repack_row_mxfp4x4x2((uint8_t *) buf_rp, (const block_mxfp4 *) buf_pd, t->ne[0]);
1298
1299        // Write only the corresponding remaining bytes to the destination tensor.
1300        memcpy(dst, buf_rp, n_rem_bytes);
1301    }
1302
1303    ggml_aligned_free(buf_pd, row_size_pd);
1304    ggml_aligned_free(buf_rp, row_size_rp);
1305}
1306
1307// repack mxfp4x4x2 tensor into mxfp4 data
1308static void repack_mxfp4x4x2_mxfp4(void * data, const ggml_tensor * t, size_t size) {
1309    int64_t nrows = ggml_nrows(t);
1310
1311    size_t row_size    = ggml_row_size(t->type, t->ne[0]);
1312    size_t row_size_pd = ggml_row_size(t->type, hex_round_up(t->ne[0], QK_MXFP4x4x2));  // extra elements for the pad
1313    size_t row_size_rp = row_size * 2;  // extra space for tmp pad (if any)
1314
1315    // Ensure we don't try to copy more data than the tensor actually contains.
1316    const size_t total_tensor_size = (size_t)nrows * row_size;
1317    const size_t n_bytes_to_copy = size < total_tensor_size ? size : total_tensor_size;
1318
1319    // Calculate how many full rows and how many remaining bytes we need to process.
1320    const int64_t n_full_rows = n_bytes_to_copy / row_size;
1321    const size_t  n_rem_bytes = n_bytes_to_copy % row_size;
1322
1323    void * buf_pd = ggml_aligned_malloc(row_size_pd);
1324    GGML_ASSERT(buf_pd != NULL);
1325
1326    void * buf_rp = ggml_aligned_malloc(row_size_rp);
1327    GGML_ASSERT(buf_rp != NULL);
1328
1329    HEX_VERBOSE("ggml-hex: repack-mxfp4x4x2-mxfp4 %s : data %p size %zu dims %ldx%ld row-size %zu\n", t->name, data,
1330                size, t->ne[0], nrows, row_size);
1331
1332    memset(buf_pd, 0, row_size_pd);  // clear-out padded buffer to make sure the tail is all zeros
1333
1334    // 1. Process all the full rows
1335    for (int64_t i = 0; i < n_full_rows; i++) {
1336        const uint8_t * src = (const uint8_t *) t->data + (i * row_size);
1337        uint8_t *       dst = (uint8_t *) data + (i * row_size);
1338
1339        memcpy(buf_pd, src, row_size);
1340        unpack_row_mxfp4x4x2((block_mxfp4 *) buf_rp, (const uint8_t *) buf_pd, t->ne[0]);
1341        memcpy(dst, buf_rp, row_size);
1342    }
1343
1344    // 2. Process the final, potentially partial, row
1345    if (n_rem_bytes > 0) {
1346        const int64_t i = n_full_rows;
1347        const uint8_t * src = (const uint8_t *) t->data + (i * row_size);
1348        uint8_t *       dst = (uint8_t *) data + (i * row_size);
1349
1350        // We still need to read and unpack the entire source row because the format is block-based.
1351        memcpy(buf_pd, src, row_size);
1352        unpack_row_mxfp4x4x2((block_mxfp4 *) buf_rp, (const uint8_t *) buf_pd, t->ne[0]);
1353
1354        // But we only copy the remaining number of bytes to the destination to respect the size limit.
1355        memcpy(dst, buf_rp, n_rem_bytes);
1356    }
1357
1358    ggml_aligned_free(buf_pd, row_size_pd);
1359    ggml_aligned_free(buf_rp, row_size_rp);
1360}
1361
1362static void ggml_backend_hexagon_buffer_set_tensor(ggml_backend_buffer_t buffer,
1363                                                   ggml_tensor *         tensor,
1364                                                   const void *          data,
1365                                                   size_t                offset,
1366                                                   size_t                size) {
1367    auto ctx  = (ggml_backend_hexagon_buffer_context *) buffer->context;
1368    auto sess = ctx->sess;
1369
1370    HEX_VERBOSE("ggml-hex: %s set-tensor %s : data %p offset %zu size %zu\n", sess->name.c_str(), tensor->name, data,
1371                offset, size);
1372
1373    switch (tensor->type) {
1374        case GGML_TYPE_Q4_0:
1375            GGML_ASSERT(offset == 0);
1376            GGML_ASSERT(offset + size <= ggml_nbytes(tensor));
1377            repack_q4_0_q4x4x2(tensor, data, size);
1378            break;
1379
1380        case GGML_TYPE_Q8_0:
1381            GGML_ASSERT(offset == 0);
1382            GGML_ASSERT(offset + size <= ggml_nbytes(tensor));
1383            repack_q8_0_q8x4x2(tensor, data, size);
1384            break;
1385
1386        case GGML_TYPE_MXFP4:
1387            GGML_ASSERT(offset == 0);
1388            GGML_ASSERT(offset + size <= ggml_nbytes(tensor));
1389            repack_mxfp4_mxfp4x4x2(tensor, data, size);
1390            break;
1391
1392        default:
1393            memcpy((char *) tensor->data + offset, data, size);
1394            break;
1395    }
1396}
1397
1398static void ggml_backend_hexagon_buffer_get_tensor(ggml_backend_buffer_t buffer,
1399                                                   const ggml_tensor *   tensor,
1400                                                   void *                data,
1401                                                   size_t                offset,
1402                                                   size_t                size) {
1403    auto ctx  = (ggml_backend_hexagon_buffer_context *) buffer->context;
1404    auto sess = ctx->sess;
1405
1406    HEX_VERBOSE("ggml-hex: %s get-tensor %s : data %p offset %zu size %zu\n", sess->name.c_str(), tensor->name, data,
1407                offset, size);
1408
1409    switch (tensor->type) {
1410        case GGML_TYPE_Q4_0:
1411            GGML_ASSERT(offset == 0);
1412            GGML_ASSERT(offset + size <= ggml_nbytes(tensor));
1413            repack_q4x4x2_q4_0(data, tensor, size);
1414            break;
1415
1416        case GGML_TYPE_Q8_0:
1417            GGML_ASSERT(offset == 0);
1418            GGML_ASSERT(offset + size <= ggml_nbytes(tensor));
1419            repack_q8x4x2_q8_0(data, tensor, size);
1420            break;
1421
1422        case GGML_TYPE_MXFP4:
1423            GGML_ASSERT(offset == 0);
1424            GGML_ASSERT(offset + size <= ggml_nbytes(tensor));
1425            repack_mxfp4x4x2_mxfp4(data, tensor, size);
1426            break;
1427
1428        default:
1429            memcpy(data, (const char *) tensor->data + offset, size);
1430            break;
1431    }
1432}
1433
1434static bool ggml_backend_hexagon_buffer_cpy_tensor(ggml_backend_buffer_t      buffer,
1435                                                   const struct ggml_tensor * src,
1436                                                   struct ggml_tensor *       dst) {
1437    GGML_UNUSED(buffer);
1438    GGML_UNUSED(src);
1439    GGML_UNUSED(dst);
1440    // we might optimize this later, for now take the slow path (ie get/set_tensor)
1441    return false;
1442}
1443
1444static void ggml_backend_hexagon_buffer_clear(ggml_backend_buffer_t buffer, uint8_t value) {
1445    auto ctx  = (ggml_backend_hexagon_buffer_context *) buffer->context;
1446    auto sess = ctx->sess;
1447    HEX_VERBOSE("ggml-hex: %s clear-buff base %p size %zu\n", sess->name.c_str(), (void *) ctx->base, ctx->size);
1448    memset(ctx->base, value, ctx->size);
1449}
1450
// Buffer vtable shared by both hexagon buffer types. The regular and repack
// alloc_buffer functions both pass it to ggml_backend_buffer_init().
// memset_tensor and reset are left NULL (not implemented by this backend).
static ggml_backend_buffer_i ggml_backend_hexagon_buffer_interface = {
    /* .free_buffer     = */ ggml_backend_hexagon_buffer_free_buffer,
    /* .get_base        = */ ggml_backend_hexagon_buffer_get_base,
    /* .init_tensor     = */ ggml_backend_hexagon_buffer_init_tensor,
    /* .memset_tensor   = */ NULL,
    /* .set_tensor      = */ ggml_backend_hexagon_buffer_set_tensor,
    /* .get_tensor      = */ ggml_backend_hexagon_buffer_get_tensor,
    /* .cpy_tensor      = */ ggml_backend_hexagon_buffer_cpy_tensor,
    /* .clear           = */ ggml_backend_hexagon_buffer_clear,
    /* .reset           = */ NULL,
};
1462
1463// ** backend buffer type
1464
1465static const char * ggml_backend_hexagon_buffer_type_name(ggml_backend_buffer_type_t buffer_type) {
1466    return static_cast<ggml_backend_hexagon_buffer_type_context *>(buffer_type->context)->name.c_str();
1467}
1468
1469static ggml_backend_buffer_t ggml_backend_hexagon_buffer_type_alloc_buffer(
1470            ggml_backend_buffer_type_t buffer_type, size_t size) {
1471    auto sess = static_cast<ggml_backend_hexagon_buffer_type_context *>(buffer_type->context)->sess;
1472    try {
1473        ggml_backend_hexagon_buffer_context * ctx = new ggml_backend_hexagon_buffer_context(sess, size, false /*repack*/);
1474        return ggml_backend_buffer_init(buffer_type, ggml_backend_hexagon_buffer_interface, ctx, size);
1475    } catch (const std::exception & exc) {
1476        GGML_LOG_ERROR("ggml-hex: %s failed to allocate buffer context: %s\n", sess->name.c_str(), exc.what());
1477        return nullptr;
1478    }
1479}
1480
1481static ggml_backend_buffer_t ggml_backend_hexagon_repack_buffer_type_alloc_buffer(
1482            ggml_backend_buffer_type_t buffer_type, size_t size) {
1483    auto sess = static_cast<ggml_backend_hexagon_buffer_type_context *>(buffer_type->context)->sess;
1484    try {
1485        ggml_backend_hexagon_buffer_context * ctx = new ggml_backend_hexagon_buffer_context(sess, size, true /*repack*/);
1486        return ggml_backend_buffer_init(buffer_type, ggml_backend_hexagon_buffer_interface, ctx, size);
1487    } catch (const std::exception & exc) {
1488        GGML_LOG_ERROR("ggml-hex: %s failed to allocate buffer context: %s\n", sess->name.c_str(), exc.what());
1489        return nullptr;
1490    }
1491}
1492
1493static size_t ggml_backend_hexagon_buffer_type_get_alignment(ggml_backend_buffer_type_t buffer_type) {
1494    return 128;  // HVX alignment
1495    GGML_UNUSED(buffer_type);
1496}
1497
1498static size_t ggml_backend_hexagon_buffer_type_get_alloc_size(ggml_backend_buffer_type_t buft, const struct ggml_tensor * t) {
1499    return ggml_nbytes(t);
1500}
1501
1502static size_t ggml_backend_hexagon_buffer_type_get_max_size(ggml_backend_buffer_type_t buffer_type) {
1503    return 1 * 1024 * 1024 * 1024;  // 1GB per buffer
1504    GGML_UNUSED(buffer_type);
1505}
1506
1507static bool ggml_backend_hexagon_buffer_type_is_host(ggml_backend_buffer_type_t buft) {
1508    return opt_hostbuf;
1509    GGML_UNUSED(buft);
1510}
1511
1512static bool ggml_backend_hexagon_repack_buffer_type_is_host(ggml_backend_buffer_type_t buft) {
1513    return false;
1514    GGML_UNUSED(buft);
1515}
1516
// Callback table for the regular Hexagon buffer type: allocates with repack=false,
// and host access follows opt_hostbuf. Entries are positional and must follow the
// field order of ggml_backend_buffer_type_i.
static ggml_backend_buffer_type_i ggml_backend_hexagon_buffer_type_interface = {
    /* .get_name         = */ ggml_backend_hexagon_buffer_type_name,
    /* .alloc_buffer     = */ ggml_backend_hexagon_buffer_type_alloc_buffer,
    /* .get_alignment    = */ ggml_backend_hexagon_buffer_type_get_alignment,
    /* .get_max_size     = */ ggml_backend_hexagon_buffer_type_get_max_size,
    /* .get_alloc_size   = */ ggml_backend_hexagon_buffer_type_get_alloc_size,
    /* .is_host          = */ ggml_backend_hexagon_buffer_type_is_host,
};
1525
// Callback table for the repack Hexagon buffer type: allocates with repack=true and
// is never reported as host-accessible. The alloc_buffer entry is what
// ggml_backend_buffer_is_hexagon_repack() compares against. Entries are positional.
static ggml_backend_buffer_type_i ggml_backend_hexagon_repack_buffer_type_interface = {
    /* .get_name         = */ ggml_backend_hexagon_buffer_type_name,
    /* .alloc_buffer     = */ ggml_backend_hexagon_repack_buffer_type_alloc_buffer,
    /* .get_alignment    = */ ggml_backend_hexagon_buffer_type_get_alignment,
    /* .get_max_size     = */ ggml_backend_hexagon_buffer_type_get_max_size,
    /* .get_alloc_size   = */ ggml_backend_hexagon_buffer_type_get_alloc_size,
    /* .is_host          = */ ggml_backend_hexagon_repack_buffer_type_is_host,
};
1534
1535void ggml_hexagon_session::allocate(int dev_id) noexcept(false) {
1536    this->valid_session = false;
1537    this->valid_handle  = false;
1538    this->valid_queue   = false;
1539    this->valid_iface   = false;
1540
1541    this->domain_id  = 3;  // Default for CDSP, updated after the session is created
1542    this->session_id = 0;  // Default for CDSP, updated after the session is created
1543    this->dev_id     = dev_id;
1544    this->name       = std::string("HTP") + std::to_string(dev_id);
1545
1546    this->op_pending  = 0;
1547    this->prof_usecs  = 0;
1548    this->prof_cycles = 0;
1549    this->prof_pkts   = 0;
1550
1551    GGML_LOG_INFO("ggml-hex: allocating new session: %s\n", this->name.c_str());
1552
1553    domain * my_domain = get_domain(this->domain_id);
1554    if (my_domain == NULL) {
1555        GGML_LOG_ERROR("ggml-hex: unable to get domain struct for CDSP\n");
1556        throw std::runtime_error("ggml-hex: failed to get CDSP domain (see log for details)");
1557    }
1558
1559    // Create new session
1560    if (dev_id != 0) {
1561        struct remote_rpc_reserve_new_session n;
1562        n.domain_name_len  = strlen(CDSP_DOMAIN_NAME);
1563        n.domain_name      = const_cast<char *>(CDSP_DOMAIN_NAME);
1564        n.session_name     = const_cast<char *>(this->name.c_str());
1565        n.session_name_len = this->name.size();
1566
1567        int err = remote_session_control(FASTRPC_RESERVE_NEW_SESSION, (void *) &n, sizeof(n));
1568        if (err != AEE_SUCCESS) {
1569            GGML_LOG_ERROR("ggml-hex: failed to reserve new session %d : error 0x%x\n", dev_id, err);
1570            throw std::runtime_error("ggml-hex: remote_session_control(new-sess) failed (see log for details)");
1571        }
1572
1573        // Save the IDs
1574        this->session_id    = n.session_id;
1575        this->domain_id     = n.effective_domain_id;
1576        this->valid_session = true;
1577    }
1578
1579    // Get session URI
1580
1581    char session_uri[256];
1582    {
1583        char htp_uri[256];
1584        snprintf(htp_uri, sizeof(htp_uri), "file:///libggml-htp-v%u.so?htp_iface_skel_handle_invoke&_modver=1.0", opt_arch);
1585
1586        struct remote_rpc_get_uri u = {};
1587        u.session_id      = this->session_id;
1588        u.domain_name     = const_cast<char *>(CDSP_DOMAIN_NAME);
1589        u.domain_name_len = strlen(CDSP_DOMAIN_NAME);
1590        u.module_uri      = const_cast<char *>(htp_uri);
1591        u.module_uri_len  = strlen(htp_uri);
1592        u.uri             = session_uri;
1593        u.uri_len         = sizeof(session_uri);
1594
1595        int err = remote_session_control(FASTRPC_GET_URI, (void *) &u, sizeof(u));
1596        if (err != AEE_SUCCESS) {
1597            // fallback to single session uris
1598            int htp_URI_domain_len = strlen(htp_uri) + MAX_DOMAIN_NAMELEN;
1599
1600            snprintf(session_uri, htp_URI_domain_len, "%s%s", htp_uri, my_domain->uri);
1601
1602            GGML_LOG_WARN("ggml-hex: failed to get URI for session %d : error 0x%x. Falling back to single session URI: %s\n", dev_id, err, session_uri);
1603        }
1604    }
1605
1606    // Enable Unsigned PD
1607    {
1608        struct remote_rpc_control_unsigned_module u;
1609        u.domain = this->domain_id;
1610        u.enable = 1;
1611        int err  = remote_session_control(DSPRPC_CONTROL_UNSIGNED_MODULE, (void *) &u, sizeof(u));
1612        if (err != AEE_SUCCESS) {
1613            GGML_LOG_ERROR("ggml-hex: failed to enable unsigned PD for session %d : error 0x%x\n", dev_id, err);
1614            throw std::runtime_error("ggml-hex: remote_session_control(unsign) failed (see log for details)");
1615        }
1616    }
1617
1618    // Open session
1619    int err = htp_iface_open(session_uri, &this->handle);
1620    if (err != AEE_SUCCESS) {
1621        GGML_LOG_ERROR("ggml-hex: failed to open session %d : error 0x%x\n", dev_id, err);
1622        throw std::runtime_error("ggml-hex: failed to open session (see log for details)");
1623    }
1624
1625    this->valid_handle = true;
1626
1627    GGML_LOG_INFO("ggml-hex: new session: %s : session-id %d domain-id %d uri %s handle 0x%lx\n", this->name.c_str(),
1628                  this->session_id, this->domain_id, session_uri, (unsigned long) this->handle);
1629
1630    // Enable FastRPC QoS mode
1631    {
1632        struct remote_rpc_control_latency l;
1633        l.enable = 1;
1634
1635        int err = remote_handle64_control(this->handle, DSPRPC_CONTROL_LATENCY, (void *) &l, sizeof(l));
1636        if (err != 0) {
1637            GGML_LOG_WARN("ggml-hex: failed to enable fastrpc QOS mode: 0x%08x\n", (unsigned) err);
1638        }
1639    }
1640
1641    // Now let's setup the DSP queue
1642    err = dspqueue_create(this->domain_id,
1643                          0,              // Flags
1644                          128 * 1024,     // Request  queue size (in bytes)
1645                          64 * 1024,      // Response queue size (in bytes)
1646                          nullptr,        // Read packet callback (we handle reads explicitly)
1647                          nullptr,        // Error callback (we handle errors during reads)
1648                          (void *) this,  // Callback context
1649                          &queue);
1650    if (err != 0) {
1651        GGML_LOG_ERROR("ggml-hex: %s dspqueue_create failed: 0x%08x\n", this->name.c_str(), (unsigned) err);
1652        throw std::runtime_error("ggml-hex: failed to create dspqueue (see log for details)");
1653    }
1654
1655    this->valid_queue = true;
1656
1657    // Export queue for use on the DSP
1658    err = dspqueue_export(queue, &this->queue_id);
1659    if (err != 0) {
1660        GGML_LOG_ERROR("ggml-hex: dspqueue_export failed: 0x%08x\n", (unsigned) err);
1661        throw std::runtime_error("ggml-hex: dspqueue export failed (see log for details)");
1662    }
1663
1664    if (opt_etm) {
1665        err = htp_iface_enable_etm(this->handle);
1666        if (err != 0) {
1667            GGML_LOG_ERROR("ggml-hex: failed to enable ETM tracing: 0x%08x\n", (unsigned) err);
1668        }
1669    }
1670
1671    // Start the DSP-side service. We need to pass the queue ID to the
1672    // DSP in a FastRPC call; the DSP side will import the queue and start
1673    // listening for packets in a callback.
1674    err = htp_iface_start(this->handle, dev_id, this->queue_id, opt_nhvx);
1675    if (err != 0) {
1676        GGML_LOG_ERROR("ggml-hex: failed to start session: 0x%08x\n", (unsigned) err);
1677        throw std::runtime_error("ggml-hex: iface start failed (see log for details)");
1678    }
1679    this->valid_iface = true;
1680}
1681
1682void ggml_hexagon_session::release() noexcept(true) {
1683    GGML_LOG_INFO("ggml-hex: releasing session: %s\n", this->name.c_str());
1684
1685    int err;
1686
1687    // Stop the DSP-side service and close the queue
1688    if (this->valid_iface) {
1689        err = htp_iface_stop(this->handle);
1690        if (err != 0) {
1691            GGML_ABORT("ggml-hex: htp_iface_stop failed: 0x%08x\n", (unsigned) err);
1692        }
1693    }
1694
1695    if (opt_etm) {
1696        err = htp_iface_disable_etm(this->handle);
1697        if (err != 0) {
1698            GGML_LOG_ERROR("ggml-hex: warn : failed to disable ETM tracing: 0x%08x\n", (unsigned) err);
1699        }
1700    }
1701
1702    if (this->valid_queue) {
1703        err = dspqueue_close(queue);
1704        if (err != 0) {
1705            GGML_ABORT("ggml-hex: dspqueue_close failed: 0x%08x\n", (unsigned) err);
1706        }
1707    }
1708
1709    if (this->valid_handle) {
1710        htp_iface_close(this->handle);
1711    }
1712}
1713
1714ggml_hexagon_session::ggml_hexagon_session(int dev_id, ggml_backend_dev_t dev) noexcept(false) {
1715    buffer_type.device        = dev;
1716    repack_buffer_type.device = dev;
1717
1718    try {
1719        allocate(dev_id);
1720
1721        buffer_type.iface   = ggml_backend_hexagon_buffer_type_interface;
1722        buffer_type.context = new ggml_backend_hexagon_buffer_type_context(this->name, this);
1723
1724        repack_buffer_type.iface   = ggml_backend_hexagon_repack_buffer_type_interface;
1725        repack_buffer_type.context = new ggml_backend_hexagon_buffer_type_context(this->name + "-REPACK", this);
1726    } catch (const std::exception & exc) {
1727        release();
1728        throw;
1729    }
1730}
1731
1732ggml_hexagon_session::~ggml_hexagon_session() noexcept(true) {
1733    release();
1734
1735    delete static_cast<ggml_backend_hexagon_buffer_type_context *>(buffer_type.context);
1736    delete static_cast<ggml_backend_hexagon_buffer_type_context *>(repack_buffer_type.context);
1737}
1738
1739// ** backend interface
1740
1741static bool ggml_backend_buffer_is_hexagon(const struct ggml_backend_buffer * b) {
1742    return b->buft->iface.get_alignment == ggml_backend_hexagon_buffer_type_get_alignment;
1743}
1744
1745static inline bool ggml_backend_buffer_is_hexagon_repack(const struct ggml_backend_buffer * b) {
1746    if (!opt_hostbuf) {
1747        return ggml_backend_buffer_is_hexagon(b);
1748    }
1749    return b->buft->iface.alloc_buffer == ggml_backend_hexagon_repack_buffer_type_alloc_buffer;
1750}
1751
1752static bool hex_supported_dims2(const struct ggml_tensor * x, const struct ggml_tensor * y) {
1753    if (x->ne[0] != y->ne[0]) {
1754        return false;
1755    }
1756    if (x->ne[1] != y->ne[1]) {
1757        return false;
1758    }
1759    if (x->ne[2] != y->ne[2]) {
1760        return false;
1761    }
1762    if (x->ne[3] != y->ne[3]) {
1763        return false;
1764    }
1765
1766    return true;
1767}
1768
1769static bool ggml_hexagon_supported_flash_attn_ext(const struct ggml_hexagon_session * sess, const struct ggml_tensor * op) {
1770    const struct ggml_tensor * src0 = op->src[0];
1771    const struct ggml_tensor * src1 = op->src[1];
1772    const struct ggml_tensor * src2 = op->src[2];
1773    const struct ggml_tensor * src3 = op->src[3];
1774    const struct ggml_tensor * src4 = op->src[4];
1775    const struct ggml_tensor * dst  = op;
1776
1777    // Check for F16 support only as requested
1778    if ((src0->type != GGML_TYPE_F16 && src0->type != GGML_TYPE_F32) || src1->type != GGML_TYPE_F16 || src2->type != GGML_TYPE_F16) {
1779        return false;
1780    }
1781
1782    if (src3 && src3->type != GGML_TYPE_F16) {  // mask
1783        return false;
1784    }
1785
1786    if (src4 && src4->type != GGML_TYPE_F32) {  // sinks
1787        return false;
1788    }
1789
1790    // For now we support F32 or F16 output as htp backend often converts output on the fly if needed,
1791    // but the op implementation writes to F16 or F32.
1792    // Let's assume dst can be F32 or F16.
1793    if (dst->type != GGML_TYPE_F32 && dst->type != GGML_TYPE_F16) {
1794        return false;
1795    }
1796
1797    return opt_experimental;
1798}
1799
1800static bool hex_supported_src0_type(ggml_type t) {
1801    return t == GGML_TYPE_F32;
1802}
1803
1804static bool hex_supported_src1_type(ggml_type t) {
1805    return t == GGML_TYPE_F32;
1806}
1807
1808static bool hex_supported_src2_type(ggml_type t) {
1809    return t == GGML_TYPE_F32;
1810}
1811
1812static bool hex_supported_src1_type2(ggml_type t) {
1813    return t == GGML_TYPE_F16;
1814}
1815
1816static bool hex_supported_src1_type3(ggml_type t) {
1817    return t == GGML_TYPE_I32;
1818}
1819
1820static bool hex_supported_dst_type(ggml_type t) {
1821    return t == GGML_TYPE_F32;
1822}
1823
1824static bool hex_supported_dims(const struct ggml_tensor * x, const struct ggml_tensor * y) {
1825    // TODO: support broadcast for ne[2 and 3]
1826    if (x->ne[0] != y->ne[0]) {
1827        return false;
1828    }
1829    if (x->ne[2] != y->ne[2]) {
1830        return false;
1831    }
1832    if (x->ne[3] != y->ne[3]) {
1833        return false;
1834    }
1835    return true;
1836}
1837
1838static bool ggml_hexagon_supported_mul_mat(const struct ggml_hexagon_session * sess, const struct ggml_tensor * dst) {
1839    const struct ggml_tensor * src0 = dst->src[0];
1840    const struct ggml_tensor * src1 = dst->src[1];
1841
1842    if (dst->type != GGML_TYPE_F32) {
1843        return false;
1844    }
1845
1846    if (src1->type != GGML_TYPE_F32 && src1->type != GGML_TYPE_F16) {
1847        return false;
1848    }
1849
1850    switch (src0->type) {
1851        case GGML_TYPE_Q4_0:
1852        case GGML_TYPE_Q8_0:
1853        case GGML_TYPE_MXFP4:
1854            if (src0->ne[0] % 32) {
1855                return false;
1856            }
1857
1858            if (src0->ne[1] > 16 * 1024) {
1859                return false;  // typically the lm-head which would be too large for VTCM
1860            }
1861
1862            if ((src1->ne[2] != 1 || src1->ne[3] != 1)) {
1863                return false;
1864            }
1865
1866            // src0 (weights) must be repacked
1867            if (src0->buffer && !ggml_backend_buffer_is_hexagon_repack(src0->buffer)) {
1868                return false;
1869            }
1870            break;
1871
1872        case GGML_TYPE_F16:
1873            if (src0->nb[1] < src0->nb[0]) {
1874                GGML_LOG_DEBUG("ggml_hexagon_supported_mul_mat: permuted F16 src0 not supported\n");
1875                return false;
1876            }
1877            break;
1878
1879        default:
1880            return false;
1881    }
1882
1883    return true;
1884}
1885
1886static bool ggml_hexagon_supported_mul_mat_id(const struct ggml_hexagon_session * sess, const struct ggml_tensor * op) {
1887    const struct ggml_tensor * src0 = op->src[0];
1888    const struct ggml_tensor * src1 = op->src[1];
1889    const struct ggml_tensor * src2 = op->src[2];
1890    const struct ggml_tensor * dst  = op;
1891
1892    if (src1->type != GGML_TYPE_F32 || dst->type != GGML_TYPE_F32 || src2->type != GGML_TYPE_I32) {
1893        return false;
1894    }
1895
1896    switch (src0->type) {
1897        case GGML_TYPE_Q4_0:
1898        case GGML_TYPE_Q8_0:
1899        case GGML_TYPE_MXFP4:
1900            if ((src0->ne[0] % 32)) {
1901                return false;
1902            }
1903
1904            // src0 (weights) must be repacked
1905            if (src0->buffer && !ggml_backend_buffer_is_hexagon_repack(src0->buffer)) {
1906                return false;
1907            }
1908            break;
1909
1910        default:
1911            return false;
1912    }
1913
1914    return true;
1915}
1916
1917static bool ggml_hexagon_supported_binary(const struct ggml_hexagon_session * sess, const struct ggml_tensor * op) {
1918    const struct ggml_tensor * src0 = op->src[0];
1919    const struct ggml_tensor * src1 = op->src[1];
1920    const struct ggml_tensor * dst  = op;
1921
1922    if (!hex_supported_src0_type(src0->type)) {
1923        return false;
1924    }
1925    if (!hex_supported_src1_type(src1->type)) {
1926        return false;
1927    }
1928    if (!hex_supported_dst_type(dst->type)) {
1929        return false;
1930    }
1931    if (!hex_supported_dims2(src0, dst)) {
1932        return false;
1933    }
1934    if (!ggml_can_repeat(src1, src0)) {
1935        return false;
1936    }
1937
1938    return true;
1939}
1940
1941static bool ggml_hexagon_supported_add_id(const struct ggml_hexagon_session * sess, const struct ggml_tensor * op) {
1942    const struct ggml_tensor * src0 = op->src[0];
1943    const struct ggml_tensor * src1 = op->src[1];
1944    const struct ggml_tensor * dst  = op;
1945
1946    if (!hex_supported_src0_type(src0->type)) {
1947        return false;
1948    }
1949    if (!hex_supported_src1_type(src1->type)) {
1950        return false;
1951    }
1952    if (!hex_supported_dst_type(dst->type)) {
1953        return false;
1954    }
1955    if (!hex_supported_dims2(src0, dst)) {
1956        return false;
1957    }
1958
1959    // REVISIT: add support for non-contigiuos tensors
1960    if (!ggml_is_contiguous(src0) || !ggml_is_contiguous(src1) || !ggml_is_contiguous(dst)) {
1961        return false;
1962    }
1963
1964    return true;
1965}
1966
1967static bool ggml_hexagon_supported_unary(const struct ggml_hexagon_session * sess, const struct ggml_tensor * op) {
1968    const struct ggml_tensor * src0 = op->src[0];
1969    const struct ggml_tensor * dst  = op;
1970
1971    if (!hex_supported_src0_type(src0->type)) {
1972        return false;
1973    }
1974    if (!hex_supported_dst_type(dst->type)) {
1975        return false;
1976    }
1977    if (!hex_supported_dims2(src0, dst)) {
1978        return false;
1979    }
1980
1981    // TODO: add support for non-contigiuos tensors
1982    if (!ggml_is_contiguous(src0) || !ggml_is_contiguous(dst)) {
1983        return false;
1984    }
1985
1986    return true;
1987}
1988
1989static bool ggml_hexagon_supported_sum_rows(const struct ggml_hexagon_session * sess, const struct ggml_tensor * op) {
1990    const struct ggml_tensor * src0 = op->src[0];
1991    const struct ggml_tensor * dst  = op;
1992
1993    if (!hex_supported_src0_type(src0->type)) {
1994        return false;
1995    }
1996    if (!hex_supported_dst_type(dst->type)) {
1997        return false;
1998    }
1999
2000    // TODO: add support for non-contigiuos tensors
2001    if (!ggml_is_contiguous(src0) || !ggml_is_contiguous(dst)) {
2002        return false;
2003    }
2004
2005    return true;
2006}
2007
2008static bool ggml_hexagon_supported_activations(const struct ggml_hexagon_session * sess,
2009                                               const struct ggml_tensor *          op) {
2010    const struct ggml_tensor * src0 = op->src[0];
2011    const struct ggml_tensor * src1 = op->src[1];
2012    const struct ggml_tensor * dst  = op;
2013
2014    if (!hex_supported_src0_type(src0->type)) {
2015        return false;
2016    }
2017    if (!hex_supported_dst_type(dst->type)) {
2018        return false;
2019    }
2020
2021    if (!ggml_is_contiguous(src0) || !ggml_is_contiguous(dst)) {
2022        return false;
2023    }
2024
2025    if (src1) {
2026        if (!hex_supported_src1_type(src1->type)) {
2027            return false;
2028        }
2029        if (!hex_supported_dims2(src0, src1)) {
2030            return false;
2031        }
2032        if (!ggml_is_contiguous(src1)) {
2033            return false;
2034        }
2035    }
2036
2037    return true;
2038}
2039
2040static bool ggml_hexagon_supported_softmax(const struct ggml_hexagon_session * sess, const struct ggml_tensor * op) {
2041    const struct ggml_tensor * src0 = op->src[0];
2042    const struct ggml_tensor * src1 = op->src[1];
2043    const struct ggml_tensor * src2 = op->src[2];
2044    const struct ggml_tensor * dst  = op;
2045
2046    if (src2) {
2047        return false;  // FIXME: add support for sinks
2048    }
2049
2050    if (!hex_supported_src0_type(src0->type)) {
2051        return false;
2052    }
2053    if (!hex_supported_dst_type(dst->type)) {
2054        return false;
2055    }
2056
2057    if (src1) {
2058        if (!hex_supported_src1_type(src1->type) && !hex_supported_src1_type2(src1->type)) {
2059            return false;
2060        }
2061        if (src0->ne[0] != src1->ne[0]) {
2062            return false;
2063        }
2064        if (src1->ne[1] < src0->ne[1]) {
2065            return false;
2066        }
2067        if (src0->ne[2] % src1->ne[2] != 0) {
2068            return false;
2069        }
2070        if (src0->ne[3] % src1->ne[3] != 0) {
2071            return false;
2072        }
2073    }
2074
2075    if (src1) {
2076        if (!ggml_is_contiguous(src0) || !ggml_is_contiguous(src1) || !ggml_is_contiguous(dst)) {
2077            return false;
2078        }
2079    } else {
2080        if (!ggml_is_contiguous(src0) || !ggml_is_contiguous(dst)) {
2081            return false;
2082        }
2083    }
2084
2085    return true;
2086}
2087
2088static bool ggml_hexagon_supported_set_rows(const struct ggml_hexagon_session * sess, const struct ggml_tensor * op) {
2089    const struct ggml_tensor * src0 = op->src[0]; // values
2090    const struct ggml_tensor * src1 = op->src[1]; // indices
2091    const struct ggml_tensor * dst  = op;
2092
2093    if (src0->type != GGML_TYPE_F32) {
2094        return false;
2095    }
2096
2097    if (src1->type != GGML_TYPE_I32 && src1->type != GGML_TYPE_I64) {
2098        return false;
2099    }
2100
2101    if (dst->type != GGML_TYPE_F16) {
2102        return false;
2103    }
2104
2105    return true;
2106}
2107
2108static bool ggml_hexagon_supported_get_rows(const struct ggml_hexagon_session * sess, const struct ggml_tensor * op) {
2109    const struct ggml_tensor * src0 = op->src[0]; // values
2110    const struct ggml_tensor * src1 = op->src[1]; // indices
2111    const struct ggml_tensor * dst  = op;
2112
2113    if (src0->type != GGML_TYPE_F32) {
2114        return false;
2115    }
2116
2117    if (src1->type != GGML_TYPE_I32 && src1->type != GGML_TYPE_I64) {
2118        return false;
2119    }
2120
2121    if (dst->type != GGML_TYPE_F32) {
2122        return false;
2123    }
2124
2125    return true;
2126}
2127
2128static bool ggml_hexagon_supported_argsort(const struct ggml_hexagon_session * sess, const struct ggml_tensor * op) {
2129    const struct ggml_tensor * src0 = op->src[0]; // values
2130    const struct ggml_tensor * dst  = op;         // indices
2131
2132    if (src0->type != GGML_TYPE_F32) {
2133        return false;
2134    }
2135
2136    if (dst->type != GGML_TYPE_I32) {
2137        return false;
2138    }
2139
2140    if (src0->ne[0] > (16*1024)) {
2141        // reject tensors with huge rows for now
2142        return false;
2143    }
2144
2145    return true;
2146}
2147
2148static bool ggml_hexagon_supported_rope(const struct ggml_hexagon_session * sess, const struct ggml_tensor * op) {
2149    const int32_t * op_params = &op->op_params[0];
2150
2151    int mode = op_params[2];
2152
2153    if ((mode & GGML_ROPE_TYPE_MROPE) || (mode & GGML_ROPE_TYPE_VISION)) {
2154        return false;
2155    }
2156    if (mode & 1) {
2157        return false;
2158    }
2159
2160    const struct ggml_tensor * src0 = op->src[0];
2161    const struct ggml_tensor * src1 = op->src[1];
2162    const struct ggml_tensor * src2 = op->src[2];
2163    const struct ggml_tensor * dst  = op;
2164
2165    if (!hex_supported_src0_type(src0->type)) {
2166        return false;  // FIXME: add support for GGML_TYPE_F16 for src0
2167    }
2168    if (!hex_supported_dst_type(dst->type)) {
2169        return false;
2170    }
2171    if (!hex_supported_src1_type3(src1->type)) {
2172        return false;
2173    }
2174    if (src2) {
2175        if (!hex_supported_src2_type(src2->type)) {
2176            return false;
2177        }
2178        int n_dims = op_params[1];
2179        if (src2->ne[0] < (n_dims / 2)) {
2180            return false;
2181        }
2182    }
2183
2184    if (src2) {
2185        if (!ggml_is_contiguous(src0) || !ggml_is_contiguous(src1) || !ggml_is_contiguous(src2) ||
2186            !ggml_is_contiguous(dst)) {
2187            return false;
2188        }
2189    } else {
2190        if (!ggml_is_contiguous(src0) || !ggml_is_contiguous(src1) || !ggml_is_contiguous(dst)) {
2191            return false;
2192        }
2193    }
2194
2195    return true;
2196}
2197
// Access pattern of a buffer attached to a dspqueue request. It selects the
// cache-maintenance flags applied in htp_req_buff_init().
enum dspqbuf_type {
    DSPQBUF_TYPE_DSP_WRITE_CPU_READ = 0,  // output: written on the DSP, read back by the CPU (sender flush)
    DSPQBUF_TYPE_CPU_WRITE_DSP_READ,      // input: written by the CPU (sender flush + recipient invalidate)
    DSPQBUF_TYPE_CONSTANT,                // constant data (e.g. weights): no cache maintenance
};
2203
2204static void dspqbuf_dump(dspqueue_buffer * d, const struct ggml_tensor * t, dspqbuf_type type) {
2205    if (opt_verbose < 2) return;
2206
2207    auto buf  = static_cast<ggml_backend_hexagon_buffer_context *>(t->buffer->context);
2208    auto sess = buf->sess;
2209
2210    GGML_LOG_DEBUG("ggml-hex: %s dspqbuf : %s base-addr %p base-size %zu data %p offset %u size %u\n", sess->name.c_str(),
2211                t->name, (void *) buf->base, buf->size, (void *) d->ptr, (unsigned int) d->offset,
2212                (unsigned int) d->size);
2213}
2214
2215// Init hexagon tensor from GGML tensor and Hexagon buffer
2216static void htp_req_tensor_init(htp_tensor * h, const ggml_tensor * t) {
2217    h->data  = 0;  // updated by the receiver
2218    h->type  = t->type;
2219    h->ne[0] = t->ne[0];
2220    h->ne[1] = t->ne[1];
2221    h->ne[2] = t->ne[2];
2222    h->ne[3] = t->ne[3];
2223    h->nb[0] = t->nb[0];
2224    h->nb[1] = t->nb[1];
2225    h->nb[2] = t->nb[2];
2226    h->nb[3] = t->nb[3];
2227}
2228
2229static size_t htp_req_buff_init(htp_tensor *h, dspqueue_buffer * d, const ggml_tensor * t, dspqbuf_type type) {
2230    if (!t) {
2231        return 0;
2232    }
2233
2234    auto buf = static_cast<ggml_backend_hexagon_buffer_context *>(t->buffer->context);
2235
2236    memset(d, 0, sizeof(*d));
2237    d->fd     = buf->fd;
2238    d->ptr    = t->data;
2239    d->offset = (uint8_t *) t->data - buf->base;
2240    d->size   = ggml_nbytes(t);
2241
2242    if (!d->size) {
2243        // Some requests contain srcs where ggml_nbytes() returns 0 but the rest of the op is non-empty
2244        d->size = 64;
2245    }
2246
2247    switch (type) {
2248        case DSPQBUF_TYPE_DSP_WRITE_CPU_READ:
2249            // Flush CPU
2250            d->flags = DSPQUEUE_BUFFER_FLAG_FLUSH_SENDER;
2251            break;
2252        case DSPQBUF_TYPE_CPU_WRITE_DSP_READ:
2253            // Flush CPU, Invalidate DSP
2254            d->flags = DSPQUEUE_BUFFER_FLAG_FLUSH_SENDER | DSPQUEUE_BUFFER_FLAG_INVALIDATE_RECIPIENT;
2255            break;
2256        default:
2257            // Constant buffer, no cache maintenance
2258            d->flags = 0;
2259            break;
2260    }
2261
2262    htp_req_tensor_init(h, t);
2263
2264    dspqbuf_dump(d, t, type);
2265
2266    return 1;
2267}
2268
2269typedef size_t (*htp_req_init_func_t)(htp_general_req * req, dspqueue_buffer * bufs, const ggml_tensor * op);
2270
2271template <htp_req_init_func_t _init_req_func>
2272static inline void ggml_hexagon_dispatch_op(ggml_hexagon_session *sess, const struct ggml_tensor * op, uint32_t flags) {
2273    uint64_t t = ggml_time_us();
2274
2275    // Construct HTP request
2276    htp_general_req req;
2277    memset(&req, 0, sizeof(req));
2278
2279    req.flags = flags;
2280    if (!(opt_opmask & HTP_OPMASK_QUANTIZE)) {
2281        req.flags |= HTP_OPFLAGS_SKIP_QUANTIZE;
2282    }
2283    if (!(opt_opmask & HTP_OPMASK_COMPUTE)) {
2284        req.flags |= HTP_OPFLAGS_SKIP_COMPUTE;
2285    }
2286
2287    ggml_hexagon_dump_op_exec(sess->name, op, req.flags);
2288
2289    if ((opt_opmask & HTP_OPMASK_QUEUE)) {
2290        dspqueue_buffer bufs[HTP_MAX_PACKET_BUFFERS];
2291        size_t n_bufs = _init_req_func(&req, bufs, op);
2292        sess->enqueue(req, bufs, n_bufs, opt_opsync);
2293    }
2294
2295    t = ggml_time_us() - t;
2296
2297    ggml_hexagon_dump_op_prof(sess->name, op, sess->prof_usecs, sess->prof_cycles, sess->prof_pkts, t);
2298}
2299
2300template <bool _is_src0_constant>
2301static inline size_t init_binary_req(htp_general_req * req, dspqueue_buffer * bufs, const ggml_tensor * t) {
2302    switch (t->op) {
2303        case GGML_OP_MUL_MAT:
2304            req->op = HTP_OP_MUL_MAT;
2305            break;
2306        case GGML_OP_MUL:
2307            req->op = HTP_OP_MUL;
2308            break;
2309        case GGML_OP_ADD:
2310            req->op = HTP_OP_ADD;
2311            break;
2312        case GGML_OP_SUB:
2313            req->op = HTP_OP_SUB;
2314            break;
2315        case GGML_OP_DIV:
2316            req->op = HTP_OP_DIV;
2317            break;
2318        default:
2319            GGML_ABORT("ggml-hex: binary : unsupported op: %d\n", t->op);
2320            break;
2321    }
2322
2323    // src0: Weights (mulmat) or First Operand (binary op).
2324    // If constant (e.g. weights), no cache management is needed.
2325    // src1: Input Activations (mulmat) or Second Operand (binary op).
2326
2327    size_t n_bufs = 0;
2328    n_bufs += htp_req_buff_init(&req->src0, &bufs[n_bufs], t->src[0], _is_src0_constant ? DSPQBUF_TYPE_CONSTANT : DSPQBUF_TYPE_CPU_WRITE_DSP_READ);
2329    n_bufs += htp_req_buff_init(&req->src1, &bufs[n_bufs], t->src[1], DSPQBUF_TYPE_CPU_WRITE_DSP_READ);
2330    n_bufs += htp_req_buff_init(&req->dst,  &bufs[n_bufs], t,         DSPQBUF_TYPE_DSP_WRITE_CPU_READ);
2331
2332    return n_bufs;
2333}
2334
2335static inline size_t init_cpy_req(htp_general_req * req, dspqueue_buffer * bufs, const ggml_tensor * t) {
2336    req->op = HTP_OP_CPY;
2337
2338    size_t n_bufs = 0;
2339    n_bufs += htp_req_buff_init(&req->src0, &bufs[n_bufs], t->src[0], DSPQBUF_TYPE_CPU_WRITE_DSP_READ);
2340    n_bufs += htp_req_buff_init(&req->dst,  &bufs[n_bufs], t,         DSPQBUF_TYPE_DSP_WRITE_CPU_READ);
2341
2342    return n_bufs;
2343}
2344
2345static inline size_t init_get_rows_req(htp_general_req * req, dspqueue_buffer * bufs, const ggml_tensor * t) {
2346    req->op = HTP_OP_GET_ROWS;
2347
2348    size_t n_bufs = 0;
2349    n_bufs += htp_req_buff_init(&req->src0, &bufs[n_bufs], t->src[0], DSPQBUF_TYPE_CPU_WRITE_DSP_READ);
2350    n_bufs += htp_req_buff_init(&req->src1, &bufs[n_bufs], t->src[1], DSPQBUF_TYPE_CPU_WRITE_DSP_READ);
2351    n_bufs += htp_req_buff_init(&req->dst,  &bufs[n_bufs], t,         DSPQBUF_TYPE_DSP_WRITE_CPU_READ);
2352
2353    return n_bufs;
2354}
2355
2356static inline size_t init_argsort_req(htp_general_req * req, dspqueue_buffer * bufs, const ggml_tensor * t) {
2357    req->op = HTP_OP_ARGSORT;
2358    memcpy(&req->op_params, &t->op_params, sizeof(t->op_params));
2359
2360    size_t n_bufs = 0;
2361    n_bufs += htp_req_buff_init(&req->src0, &bufs[n_bufs], t->src[0], DSPQBUF_TYPE_CPU_WRITE_DSP_READ);
2362    n_bufs += htp_req_buff_init(&req->dst,  &bufs[n_bufs], t,         DSPQBUF_TYPE_DSP_WRITE_CPU_READ);
2363
2364    return n_bufs;
2365}
2366
2367template <bool _is_src0_constant>
2368static inline size_t init_binary_id_req(htp_general_req * req, dspqueue_buffer * bufs, const ggml_tensor * t) {
2369    switch (t->op) {
2370        case GGML_OP_MUL_MAT_ID:
2371            req->op = HTP_OP_MUL_MAT_ID;
2372            break;
2373        case GGML_OP_ADD_ID:
2374            req->op = HTP_OP_ADD_ID;
2375            break;
2376        default:
2377            GGML_ABORT("ggml-hex: unsupported op: %d\n", t->op);
2378    }
2379
2380    // src0: Weights (mulmat) or Input Activations (other op).
2381    // If constant, no cache management is needed.
2382    // src1: Input Activations (mulmat) or Second Operand (binary op).
2383    // src2: Expert IDs (mulmat) or Activated Experts (other op).
2384
2385    size_t n_bufs = 0;
2386    n_bufs += htp_req_buff_init(&req->src0, &bufs[n_bufs], t->src[0], _is_src0_constant ? DSPQBUF_TYPE_CONSTANT : DSPQBUF_TYPE_CPU_WRITE_DSP_READ);
2387    n_bufs += htp_req_buff_init(&req->src1, &bufs[n_bufs], t->src[1], DSPQBUF_TYPE_CPU_WRITE_DSP_READ);
2388    n_bufs += htp_req_buff_init(&req->src2, &bufs[n_bufs], t->src[2], DSPQBUF_TYPE_CPU_WRITE_DSP_READ);
2389    n_bufs += htp_req_buff_init(&req->dst,  &bufs[n_bufs], t,         DSPQBUF_TYPE_DSP_WRITE_CPU_READ);
2390
2391    return n_bufs;
2392}
2393
2394static inline size_t init_set_rows_req(htp_general_req * req, dspqueue_buffer * bufs, const ggml_tensor * t) {
2395    req->op = HTP_OP_SET_ROWS;
2396
2397    size_t n_bufs = 0;
2398    n_bufs += htp_req_buff_init(&req->src0, &bufs[n_bufs], t->src[0], DSPQBUF_TYPE_CPU_WRITE_DSP_READ);
2399    n_bufs += htp_req_buff_init(&req->src1, &bufs[n_bufs], t->src[1], DSPQBUF_TYPE_CPU_WRITE_DSP_READ);
2400    n_bufs += htp_req_buff_init(&req->dst,  &bufs[n_bufs], t,         DSPQBUF_TYPE_DSP_WRITE_CPU_READ);
2401
2402    return n_bufs;
2403}
2404
2405static inline size_t init_unary_req(htp_general_req * req, dspqueue_buffer * bufs, const ggml_tensor * t) {
2406    memcpy(&req->op_params, &t->op_params, sizeof(t->op_params));
2407
2408    bool supported = false;
2409
2410    switch (t->op) {
2411        case GGML_OP_RMS_NORM:
2412            req->op   = HTP_OP_RMS_NORM;
2413            supported = true;
2414            break;
2415
2416        case GGML_OP_SCALE:
2417            req->op   = HTP_OP_SCALE;
2418            supported = true;
2419            break;
2420
2421        case GGML_OP_SQR:
2422            req->op   = HTP_OP_SQR;
2423            supported = true;
2424            break;
2425
2426        case GGML_OP_SQRT:
2427            req->op   = HTP_OP_SQRT;
2428            supported = true;
2429            break;
2430
2431        case GGML_OP_UNARY:
2432            if (ggml_get_unary_op(t) == GGML_UNARY_OP_SILU) {
2433                req->op   = HTP_OP_UNARY_SILU;
2434                supported = true;
2435            } else if (ggml_get_unary_op(t) == GGML_UNARY_OP_GELU) {
2436                req->op   = HTP_OP_UNARY_GELU;
2437                supported = true;
2438            }
2439            break;
2440
2441        case GGML_OP_GLU:
2442            if (ggml_get_glu_op(t) == GGML_GLU_OP_SWIGLU) {
2443                req->op   = HTP_OP_GLU_SWIGLU;
2444                supported = true;
2445            } else if (ggml_get_glu_op(t) == GGML_GLU_OP_SWIGLU_OAI) {
2446                req->op   = HTP_OP_GLU_SWIGLU_OAI;
2447                supported = true;
2448            } else if (ggml_get_glu_op(t) == GGML_GLU_OP_GEGLU) {
2449                req->op   = HTP_OP_GLU_GEGLU;
2450                supported = true;
2451            }
2452            break;
2453
2454        case GGML_OP_SOFT_MAX:
2455            req->op   = HTP_OP_SOFTMAX;
2456            supported = true;
2457            break;
2458
2459        default:
2460            break;
2461    }
2462
2463    if (!supported) {
2464        GGML_ABORT("ggml-hex: unary : unsupported op: %d\n", t->op);
2465    }
2466
2467    size_t n_bufs = 0;
2468    n_bufs += htp_req_buff_init(&req->src0, &bufs[n_bufs], t->src[0], DSPQBUF_TYPE_CPU_WRITE_DSP_READ);
2469    n_bufs += htp_req_buff_init(&req->src1, &bufs[n_bufs], t->src[1], DSPQBUF_TYPE_CPU_WRITE_DSP_READ);
2470    n_bufs += htp_req_buff_init(&req->dst,  &bufs[n_bufs], t,         DSPQBUF_TYPE_DSP_WRITE_CPU_READ);
2471
2472    return n_bufs;
2473}
2474
2475static inline size_t init_sum_rows_req(htp_general_req * req, dspqueue_buffer * bufs, const ggml_tensor * t) {
2476    memcpy(&req->op_params, &t->op_params, sizeof(t->op_params));
2477    req->op = HTP_OP_SUM_ROWS;
2478
2479    size_t n_bufs = 0;
2480    n_bufs += htp_req_buff_init(&req->src0, &bufs[n_bufs], t->src[0], DSPQBUF_TYPE_CPU_WRITE_DSP_READ);
2481    n_bufs += htp_req_buff_init(&req->dst,  &bufs[n_bufs], t,         DSPQBUF_TYPE_DSP_WRITE_CPU_READ);
2482
2483    return n_bufs;
2484}
2485
2486static inline size_t init_rope_req(htp_general_req * req, dspqueue_buffer * bufs, const ggml_tensor * t) {
2487    memcpy(&req->op_params, &t->op_params, sizeof(t->op_params));
2488    req->op = HTP_OP_ROPE;
2489
2490    size_t n_bufs = 0;
2491    n_bufs += htp_req_buff_init(&req->src0, &bufs[n_bufs], t->src[0], DSPQBUF_TYPE_CPU_WRITE_DSP_READ);
2492    n_bufs += htp_req_buff_init(&req->src1, &bufs[n_bufs], t->src[1], DSPQBUF_TYPE_CPU_WRITE_DSP_READ);
2493    n_bufs += htp_req_buff_init(&req->src2, &bufs[n_bufs], t->src[2], DSPQBUF_TYPE_CPU_WRITE_DSP_READ);
2494    n_bufs += htp_req_buff_init(&req->dst,  &bufs[n_bufs], t,         DSPQBUF_TYPE_DSP_WRITE_CPU_READ);
2495
2496    return n_bufs;
2497}
2498
2499static inline size_t init_flash_attn_ext_req(htp_general_req * req, dspqueue_buffer * bufs, const ggml_tensor * t) {
2500    memcpy(&req->op_params, &t->op_params, sizeof(t->op_params));
2501    req->op = HTP_OP_FLASH_ATTN_EXT;
2502
2503    size_t n_bufs = 0;
2504    n_bufs += htp_req_buff_init(&req->src0, &bufs[n_bufs], t->src[0], DSPQBUF_TYPE_CPU_WRITE_DSP_READ);
2505    n_bufs += htp_req_buff_init(&req->src1, &bufs[n_bufs], t->src[1], DSPQBUF_TYPE_CPU_WRITE_DSP_READ);
2506    n_bufs += htp_req_buff_init(&req->src2, &bufs[n_bufs], t->src[2], DSPQBUF_TYPE_CPU_WRITE_DSP_READ);
2507    n_bufs += htp_req_buff_init(&req->src3, &bufs[n_bufs], t->src[3], DSPQBUF_TYPE_CPU_WRITE_DSP_READ);
2508    n_bufs += htp_req_buff_init(&req->src4, &bufs[n_bufs], t->src[4], DSPQBUF_TYPE_CPU_WRITE_DSP_READ);
2509    n_bufs += htp_req_buff_init(&req->dst,  &bufs[n_bufs], t,         DSPQBUF_TYPE_DSP_WRITE_CPU_READ);
2510
2511    return n_bufs;
2512}
2513
2514static const char * ggml_backend_hexagon_name(ggml_backend_t backend) {
2515    auto sess = static_cast<ggml_hexagon_session *>(backend->context);
2516    return sess->name.c_str();
2517}
2518
2519static void ggml_backend_hexagon_free(ggml_backend_t backend) {
2520    // we just need to delete the backend here
2521    // the sessions are allocated & freed as part of the registry
2522    delete backend;
2523}
2524
2525static inline bool op_reuse_src1(const ggml_tensor * op1, const ggml_tensor * op0) {
2526    return (op0 && op0->src[1] == op1->src[1] && ggml_is_quantized(op0->src[0]->type));
2527}
2528
2529static inline bool is_compute_op(ggml_tensor *node)
2530{
2531    return !ggml_op_is_empty(node->op) && !ggml_is_empty(node) && (node->flags & GGML_TENSOR_FLAG_COMPUTE);
2532}
2533
2534// scan the graph and figure out last compute op index
2535static inline int last_compute_op(ggml_cgraph * graph) {
2536    int last = 0;
2537    for (int i = 0; i < graph->n_nodes; ++i) {
2538        if (is_compute_op(graph->nodes[i])) {
2539            last = i;
2540        }
2541    }
2542
2543    return last;
2544}
2545
2546static ggml_status ggml_backend_hexagon_graph_compute(ggml_backend_t backend, ggml_cgraph * graph) {
2547    auto sess = static_cast<ggml_hexagon_session *>(backend->context);
2548
2549    HEX_VERBOSE("ggml-hex: %s graph-compute n_nodes %d\n", sess->name.c_str(), graph->n_nodes);
2550
2551    const int last = last_compute_op(graph);
2552
2553    const struct ggml_tensor * prev_op = nullptr;  // prev executed op
2554
2555    for (int i = 0; i < graph->n_nodes; ++i) {
2556        ggml_tensor * node = graph->nodes[i];
2557
2558        if (!is_compute_op(node)) {
2559            continue;
2560        }
2561
2562        uint32_t flags = 0;
2563
2564        // skip quantizer if src1 is reused
2565        if (op_reuse_src1(node, prev_op)) {
2566            flags |= HTP_OPFLAGS_SKIP_QUANTIZE;
2567        }
2568
2569        prev_op = node;
2570
2571        // ask for early notification for the last Op
2572        if (i == last) {
2573            flags |= HTP_OPFLAGS_EARLY_WAKEUP;
2574        }
2575
2576        switch (node->op) {
2577            case GGML_OP_MUL_MAT:
2578                if (ggml_is_quantized(node->src[0]->type)) {
2579                    ggml_hexagon_dispatch_op<init_binary_req<true>>(sess, node, flags);
2580                } else {
2581                    ggml_hexagon_dispatch_op<init_binary_req<false>>(sess, node, flags);
2582                }
2583                break;
2584            case GGML_OP_MUL_MAT_ID:
2585                if (ggml_is_quantized(node->src[0]->type)) {
2586                    ggml_hexagon_dispatch_op<init_binary_id_req<true>>(sess, node, flags);
2587                } else {
2588                    ggml_hexagon_dispatch_op<init_binary_id_req<false>>(sess, node, flags);
2589                }
2590                break;
2591            case GGML_OP_MUL:
2592            case GGML_OP_ADD:
2593            case GGML_OP_SUB:
2594            case GGML_OP_DIV:
2595                ggml_hexagon_dispatch_op<init_binary_req<false>>(sess, node, flags);
2596                break;
2597            case GGML_OP_ADD_ID:
2598                ggml_hexagon_dispatch_op<init_binary_id_req<false>>(sess, node, flags);
2599                break;
2600            case GGML_OP_RMS_NORM:
2601            case GGML_OP_SCALE:
2602                ggml_hexagon_dispatch_op<init_unary_req>(sess, node, flags);
2603                break;
2604            case GGML_OP_SQR:
2605            case GGML_OP_SQRT:
2606                ggml_hexagon_dispatch_op<init_unary_req>(sess, node, flags);
2607                break;
2608            case GGML_OP_SUM_ROWS:
2609                ggml_hexagon_dispatch_op<init_sum_rows_req>(sess, node, flags);
2610                break;
2611            case GGML_OP_UNARY:
2612                if ((ggml_get_unary_op(node) == GGML_UNARY_OP_SILU) ||
2613                        (ggml_get_unary_op(node) == GGML_UNARY_OP_GELU)) {
2614                    ggml_hexagon_dispatch_op<init_unary_req>(sess, node, flags);
2615                }
2616                break;
2617            case GGML_OP_GLU:
2618                if ((ggml_get_glu_op(node) == GGML_GLU_OP_SWIGLU) ||
2619                        (ggml_get_glu_op(node) == GGML_GLU_OP_SWIGLU_OAI) ||
2620                        (ggml_get_glu_op(node) == GGML_GLU_OP_GEGLU)) {
2621                    ggml_hexagon_dispatch_op<init_unary_req>(sess, node, flags);
2622                }
2623                break;
2624            case GGML_OP_SOFT_MAX:
2625                ggml_hexagon_dispatch_op<init_unary_req>(sess, node, flags);
2626                break;
2627
2628            case GGML_OP_ROPE:
2629                ggml_hexagon_dispatch_op<init_rope_req>(sess, node, flags);
2630                break;
2631
2632            case GGML_OP_FLASH_ATTN_EXT:
2633                ggml_hexagon_dispatch_op<init_flash_attn_ext_req>(sess, node, flags);
2634                break;
2635
2636            case GGML_OP_SET_ROWS:
2637                ggml_hexagon_dispatch_op<init_set_rows_req>(sess, node, flags);
2638                break;
2639
2640            case GGML_OP_GET_ROWS:
2641                ggml_hexagon_dispatch_op<init_get_rows_req>(sess, node, flags);
2642                break;
2643
2644            case GGML_OP_CPY:
2645                ggml_hexagon_dispatch_op<init_cpy_req>(sess, node, flags);
2646                break;
2647
2648            case GGML_OP_ARGSORT:
2649                ggml_hexagon_dispatch_op<init_argsort_req>(sess, node, flags);
2650                break;
2651
2652            default:
2653                GGML_ABORT("\nggml-hex: graph-compute %s is not supported\n", ggml_op_desc(node));
2654        }
2655    }
2656
2657    // Wait until all pending ops complete
2658    sess->flush();
2659
2660    return GGML_STATUS_SUCCESS;
2661}
2662
2663static void ggml_backend_hexagon_synchronize(ggml_backend_t backend) {
2664    auto sess = static_cast<ggml_hexagon_session *>(backend->context);
2665
2666    HEX_VERBOSE("ggml-hex: %s synchronize\n", sess->name.c_str());
2667
2668    // Wait until all pending ops complete
2669    sess->flush();
2670}
2671
2672struct node_info {
2673    ggml_tensor * node;
2674
2675    std::vector<ggml_tensor *> fused;
2676
2677    ggml_op op() const {
2678        return node->op;
2679    }
2680
2681    const ggml_tensor * dst() const {
2682        return fused.empty() ? node : fused.back();
2683    }
2684
2685    const ggml_tensor * src0() const {
2686        return node->src[0];
2687    }
2688
2689    const ggml_tensor * src1() const {
2690        return node->src[1];
2691    }
2692
2693    bool is_empty() const {
2694        return ggml_op_is_empty(node->op);
2695    }
2696
2697    void add_fused(ggml_tensor * t) {
2698        fused.push_back(t);
2699    }
2700
2701    bool stackable() const {
2702        switch (this->op()) {
2703            case GGML_OP_MUL_MAT:
2704            case GGML_OP_MUL_MAT_ID:
2705                return ggml_is_quantized(this->src0()->type);
2706            default:
2707                return false;
2708        }
2709    }
2710
2711    bool same_input(const node_info& n) const {
2712        return n.src1() == this->src1();
2713    }
2714};
2715
2716static std::vector<int> ggml_hexagon_graph_optimize_reorder(const std::vector<node_info> & nodes) {
2717    const int n = nodes.size();
2718
2719    std::vector<int> res;
2720    res.reserve(n);
2721
2722    std::vector<bool> used(n, false);
2723
2724    // The main goal here is to stack the MUL_MAT ops with the same src1 input.
2725    // This allows use to reuse dynamically quantized src1 in VTCM.
2726
2727    // TODO: the current version might do incorrect reodering in cases where quantized src0
2728    //       input is an output of another Op.
2729
2730    for (int i0 = 0; i0 < n; i0++) {
2731        if (used[i0]) {
2732            continue;
2733        }
2734
2735        res.push_back(i0);
2736
2737        const auto & node0 = nodes[i0];
2738
2739        if (!node0.stackable()) {
2740            continue;
2741        }
2742
2743        // that many nodes forward to search for stackable nodes that can reuse VTCM
2744        constexpr int N_FORWARD = 16;
2745
2746        for (int i1 = i0 + 1; i1 < i0 + N_FORWARD && i1 < n; i1++) {
2747            if (used[i1]) {
2748                continue;
2749            }
2750
2751            const auto & node1 = nodes[i1];
2752
2753            if (node1.stackable() && node1.same_input(node0)) {
2754                res.push_back(i1);
2755                used[i1] = true;
2756            }
2757        }
2758    }
2759
2760    return res;
2761}
2762
// graph_optimize hook: reorder the nodes of `gf` in place so that stackable
// matmuls sharing src1 run back to back (see ggml_hexagon_graph_optimize_reorder),
// without breaking sequences that ggml_can_fuse() accepts.
//
// Steps:
//   1. pack fusable runs (starting with ADD / NORM / RMS_NORM and continuing
//      with ADD / MUL / NORM / RMS_NORM, at most MAX_FUSE - 1 nodes scanned)
//      into a single node_info;
//   2. compute a new order over the packed units;
//   3. unpack the units back into gf->nodes in the new order.
// The number of nodes written back equals gf->n_nodes.
// `backend` is not used.
static void ggml_backend_hexagon_graph_optimize(ggml_backend_t backend, ggml_cgraph * gf) {
    const int n = gf->n_nodes;

    constexpr int MAX_FUSE = 16;

    // op sequence of the current fusion candidate; ops[0] is the head node's op
    enum ggml_op ops[MAX_FUSE];

    std::vector<node_info> nodes;
    nodes.reserve(gf->n_nodes);

    // fuse nodes:
    // we don't want to make reorders that break fusing, so we first pack all fusable tensors
    //   and perform the reorder over the fused nodes. after the reorder is done, we unfuse
    for (int i = 0; i < n; i++) {
        node_info node = {
            /*.node =*/gf->nodes[i],
            /*.fused =*/{},
        };

        // fuse only ops that start with these operations
        // can be expanded when needed
        if (node.op() == GGML_OP_ADD ||
            node.op() == GGML_OP_NORM ||
            node.op() == GGML_OP_RMS_NORM) {
            ops[0] = node.op();

            // collect the longest run of allowed ops after node i
            // (f < i + MAX_FUSE keeps ops[f - i] inside the array)
            int f = i + 1;
            while (f < n && f < i + MAX_FUSE) {
                // conservatively allow fusing only these ops
                // can be expanded when needed
                if (gf->nodes[f]->op != GGML_OP_ADD &&
                    gf->nodes[f]->op != GGML_OP_MUL &&
                    gf->nodes[f]->op != GGML_OP_NORM &&
                    gf->nodes[f]->op != GGML_OP_RMS_NORM) {
                    break;
                }
                ops[f - i] = gf->nodes[f]->op;
                f++;
            }

            // f becomes the run length; shrink it until ggml_can_fuse() accepts
            // the first f ops, ending at 1 (nothing fused) if none is accepted
            f -= i;
            for (; f > 1; f--) {
                if (ggml_can_fuse(gf, i, ops, f)) {
                    break;
                }
            }

            // add the fused tensors into the node info so we can unfuse them later
            // (advancing i here makes the outer loop skip the consumed nodes)
            for (int k = 1; k < f; k++) {
                ++i;

                // the .dst() becomes the last fused tensor
                node.add_fused(gf->nodes[i]);
            }
        }

        nodes.push_back(std::move(node));
    }

    const auto order = ggml_hexagon_graph_optimize_reorder(nodes);

    // unfuse: write each unit's head node followed by its fused nodes
    {
        int j = 0;
        for (const auto i : order) {
            const auto & node = nodes[i];

            gf->nodes[j++] = node.node;

            for (auto * fused : node.fused) {
                gf->nodes[j++] = fused;
            }
        }
    }
}
2838
// Backend interface table; copied by value into every ggml_backend created by
// ggml_backend_hexagon_device_init(). Entries are positional and must follow
// the field order of struct ggml_backend_i. Async tensor transfers, graph
// plans and events are not implemented (NULL).
static struct ggml_backend_i hexagon_backend_i = {
    /* .get_name                = */ ggml_backend_hexagon_name,
    /* .free                    = */ ggml_backend_hexagon_free,
    /* .set_tensor_async        = */ NULL,
    /* .get_tensor_async        = */ NULL,
    /* .cpy_tensor_async        = */ NULL,
    /* .synchronize             = */ ggml_backend_hexagon_synchronize,
    /* .graph_plan_create       = */ NULL,
    /* .graph_plan_free         = */ NULL,
    /* .graph_plan_update       = */ NULL,
    /* .graph_plan_compute      = */ NULL,
    /* .graph_compute           = */ ggml_backend_hexagon_graph_compute,
    /* .event_record            = */ NULL,
    /* .event_wait              = */ NULL,
    /* .graph_optimize          = */ ggml_backend_hexagon_graph_optimize,
};
2855
2856static ggml_guid_t ggml_backend_hexagon_guid() {
2857    static ggml_guid guid = { 0x7b, 0x57, 0xdc, 0xaf, 0xde, 0x12, 0x1d, 0x49,
2858                              0x11, 0x11, 0x11, 0x11, 0x11, 0x11, 0x11, 0x11 };
2859    return &guid;
2860}
2861
2862bool ggml_backend_is_hexagon(ggml_backend_t backend) {
2863    return backend && backend->iface.get_name == ggml_backend_hexagon_name;
2864}
2865
2866// device interface
2867
2868static ggml_backend_t ggml_backend_hexagon_device_init(ggml_backend_dev_t dev, const char * params) {
2869    auto sess = static_cast<ggml_hexagon_session *>(dev->context);
2870
2871    return new ggml_backend{
2872        /* .guid      = */ ggml_backend_hexagon_guid(),
2873        /* .interface = */ hexagon_backend_i,
2874        /* .device    = */ dev,
2875        /* .context   = */ sess,
2876    };
2877
2878    GGML_UNUSED(params);
2879}
2880
2881static const char * ggml_backend_hexagon_device_get_name(ggml_backend_dev_t dev) {
2882    auto sess = static_cast<ggml_hexagon_session *>(dev->context);
2883    return sess->name.c_str();
2884
2885    GGML_UNUSED(dev);
2886}
2887
2888static const char * ggml_backend_hexagon_device_get_description(ggml_backend_dev_t dev) {
2889    return "Hexagon";
2890    GGML_UNUSED(dev);
2891}
2892
2893static void ggml_backend_hexagon_device_get_memory(ggml_backend_dev_t dev, size_t * free, size_t * total) {
2894    // ~2GB per session for now
2895    *free  = 2ULL * 1024 * 1024 * 1024;
2896    *total = *free;
2897
2898    GGML_UNUSED(dev);
2899}
2900
2901static enum ggml_backend_dev_type ggml_backend_hexagon_device_get_type(ggml_backend_dev_t dev) {
2902    return GGML_BACKEND_DEVICE_TYPE_GPU;
2903
2904    GGML_UNUSED(dev);
2905}
2906
2907static void ggml_backend_hexagon_device_get_props(ggml_backend_dev_t dev, struct ggml_backend_dev_props * props) {
2908    props->name        = ggml_backend_hexagon_device_get_name(dev);
2909    props->description = ggml_backend_hexagon_device_get_description(dev);
2910    props->type        = ggml_backend_hexagon_device_get_type(dev);
2911    ggml_backend_hexagon_device_get_memory(dev, &props->memory_free, &props->memory_total);
2912    props->caps = {
2913        /* .async                 = */ true,
2914        /* .host_buffer           = */ (bool) opt_hostbuf,
2915        /* .buffer_from_host_ptr  = */ false,
2916        /* .events                = */ false,
2917    };
2918}
2919
2920static ggml_backend_buffer_type_t ggml_backend_hexagon_device_get_buffer_type(ggml_backend_dev_t dev) {
2921    auto sess = static_cast<ggml_hexagon_session *>(dev->context);
2922    return &sess->buffer_type;
2923}
2924
2925static ggml_backend_buffer_type_t ggml_backend_hexagon_device_get_repack_buffer_type(ggml_backend_dev_t dev) {
2926    auto sess = static_cast<ggml_hexagon_session *>(dev->context);
2927    return &sess->repack_buffer_type;
2928}
2929
2930static bool ggml_hexagon_supported_buffer(ggml_hexagon_session *sess, const struct ggml_tensor * t) {
2931    if (t && t->buffer) {
2932        if (ggml_backend_buffer_is_hexagon(t->buffer)      == false) return false; // not our buffer
2933        if (ggml_backend_hexagon_buffer_get_sess(t->buffer) != sess) return false; // wrong session
2934    }
2935    return true;
2936}
2937
2938static bool ggml_hexagon_supported_buffers(ggml_hexagon_session *sess, const struct ggml_tensor * t) {
2939    // all srcs & dsts must be mapped to the same session
2940    if (!ggml_hexagon_supported_buffer(sess, t)) {
2941        return false;
2942    }
2943
2944    for (int i = 0; i < GGML_MAX_SRC; i++) {
2945        if (!ggml_hexagon_supported_buffer(sess, t->src[i])) {
2946            return false;
2947        }
2948    }
2949
2950    return true;
2951}
2952
2953static bool ggml_hexagon_supported_cpy(const struct ggml_hexagon_session * sess, const struct ggml_tensor * op) {
2954    const struct ggml_tensor * src0 = op->src[0];
2955    const struct ggml_tensor * dst  = op;
2956
2957    // for now we can do f32 -> f16 and f16 -> f32 (without reshaping)
2958    if (src0->type != GGML_TYPE_F32 && src0->type != GGML_TYPE_F16) return false;
2959    if ( dst->type != GGML_TYPE_F32 &&  dst->type != GGML_TYPE_F16) return false;
2960
2961    const bool sametype   = (src0->type == dst->type);
2962    const bool transposed = ggml_is_transposed(src0) || ggml_is_transposed(dst);
2963    const bool sameshape  = !transposed && ggml_are_same_shape(src0, dst);
2964
2965    // can handle any shape and any same-type (pretty slow if reshaping is required)
2966    if (sametype) return true;
2967
2968    // cannot handle re-shaping and type conversion at the same time
2969    if (!sameshape) return false;
2970
2971    return true;
2972}
2973
2974static bool ggml_backend_hexagon_device_supports_op(ggml_backend_dev_t dev, const struct ggml_tensor * op) {
2975    auto sess = static_cast<ggml_hexagon_session *>(dev->context);
2976
2977    // all srcs & dsts must be mapped to the same session
2978    if (!ggml_hexagon_supported_buffers(sess, op)) {
2979        ggml_hexagon_dump_op_supp(sess->name, op, false);
2980        return false;
2981    }
2982
2983    bool supp = false;
2984    switch (op->op) {
2985        case GGML_OP_NONE:
2986        case GGML_OP_RESHAPE:
2987        case GGML_OP_VIEW:
2988        case GGML_OP_PERMUTE:
2989        case GGML_OP_TRANSPOSE:
2990            supp = true;
2991            break;
2992
2993        case GGML_OP_MUL_MAT:
2994            supp = ggml_hexagon_supported_mul_mat(sess, op);
2995            break;
2996
2997        case GGML_OP_MUL_MAT_ID:
2998            supp = ggml_hexagon_supported_mul_mat_id(sess, op);
2999            break;
3000
3001        case GGML_OP_MUL:
3002        case GGML_OP_ADD:
3003        case GGML_OP_SUB:
3004        case GGML_OP_DIV:
3005            supp = ggml_hexagon_supported_binary(sess, op);
3006            break;
3007
3008        case GGML_OP_ADD_ID:
3009            supp = ggml_hexagon_supported_add_id(sess, op);
3010            break;
3011
3012        case GGML_OP_RMS_NORM:
3013        case GGML_OP_SCALE:
3014            supp = ggml_hexagon_supported_unary(sess, op);
3015            break;
3016
3017        case GGML_OP_SQR:
3018        case GGML_OP_SQRT:
3019            supp = ggml_hexagon_supported_unary(sess, op);
3020            break;
3021
3022        case GGML_OP_SUM_ROWS:
3023            supp = ggml_hexagon_supported_sum_rows(sess, op);
3024            break;
3025
3026        case GGML_OP_SOFT_MAX:
3027            supp = ggml_hexagon_supported_softmax(sess, op);
3028            break;
3029
3030        case GGML_OP_UNARY:
3031            {
3032                const auto unary_op = ggml_get_unary_op(op);
3033                if (unary_op == GGML_UNARY_OP_SILU || unary_op == GGML_UNARY_OP_GELU) {
3034                    supp = ggml_hexagon_supported_activations(sess, op);
3035                }
3036                break;
3037            }
3038        case GGML_OP_GLU:
3039            {
3040                const auto glu_op = ggml_get_glu_op(op);
3041                if ((glu_op == GGML_GLU_OP_SWIGLU) || (glu_op == GGML_GLU_OP_SWIGLU_OAI) || (glu_op == GGML_GLU_OP_GEGLU)) {
3042                    supp = ggml_hexagon_supported_activations(sess, op);
3043                }
3044                break;
3045            }
3046        case GGML_OP_ROPE:
3047            supp = ggml_hexagon_supported_rope(sess, op);
3048            break;
3049
3050        case GGML_OP_FLASH_ATTN_EXT:
3051            supp = ggml_hexagon_supported_flash_attn_ext(sess, op);
3052            break;
3053
3054        case GGML_OP_SET_ROWS:
3055            supp = ggml_hexagon_supported_set_rows(sess, op);
3056            break;
3057
3058        case GGML_OP_GET_ROWS:
3059            supp = ggml_hexagon_supported_get_rows(sess, op);
3060            break;
3061
3062        case GGML_OP_CPY:
3063            supp = ggml_hexagon_supported_cpy(sess, op);
3064            break;
3065
3066        case GGML_OP_ARGSORT:
3067            supp = ggml_hexagon_supported_argsort(sess, op);
3068            break;
3069
3070        default:
3071            break;
3072    }
3073
3074    ggml_hexagon_dump_op_supp(sess->name, op, supp);
3075    return supp;
3076}
3077
3078static bool ggml_backend_hexagon_device_supports_buft(ggml_backend_dev_t dev, ggml_backend_buffer_type_t buft) {
3079    if (buft->iface.get_alignment != ggml_backend_hexagon_buffer_type_get_alignment) {
3080        return false;
3081    }
3082
3083    auto s0 = static_cast<ggml_hexagon_session *>(dev->context);
3084    auto s1 = static_cast<ggml_backend_hexagon_buffer_type_context *>(buft->context)->sess;
3085
3086    // Need session/domain-id for buffers to be compatible
3087    bool supp = (s0->session_id == s1->session_id);
3088
3089    HEX_VERBOSE("ggml-hex: %s device-supports-buft %s (%d)\n", s0->name.c_str(), s1->name.c_str(), (int) supp);
3090
3091    return supp;
3092}
3093
3094static ggml_backend_buffer_type_t * ggml_backend_hexagon_device_get_extra_buffers_type(ggml_backend_dev_t dev) {
3095    auto s0 = static_cast<ggml_hexagon_session *>(dev->context);
3096    HEX_VERBOSE("ggml-hex: device-get-extra-buft : %s \n", s0->name.c_str());
3097
3098    static ggml_backend_buffer_type_t bufts[2];
3099    bufts[0] = ggml_backend_hexagon_device_get_repack_buffer_type(dev);
3100    bufts[1] = NULL;
3101    return bufts;
3102}
3103
// Device interface table shared by all Hexagon devices (assigned to each
// registry device in the registry constructor). Entries are positional and must
// follow the field order of struct ggml_backend_device_i. Host buffer type,
// buffer_from_host_ptr, offload_op and events are not provided (NULL); the
// commented-out names mark hooks that are not implemented.
static const struct ggml_backend_device_i ggml_backend_hexagon_device_i = {
    /* .get_name             = */ ggml_backend_hexagon_device_get_name,
    /* .get_description      = */ ggml_backend_hexagon_device_get_description,
    /* .get_memory           = */ ggml_backend_hexagon_device_get_memory,
    /* .get_type             = */ ggml_backend_hexagon_device_get_type,
    /* .get_props            = */ ggml_backend_hexagon_device_get_props,
    /* .init_backend         = */ ggml_backend_hexagon_device_init,
    /* .get_buffer_type      = */ ggml_backend_hexagon_device_get_buffer_type,
    /* .get_host_buffer_type = */ NULL,  // ggml_backend_hexagon_device_get_host_buffer_type,
    /* .buffer_from_host_ptr = */ NULL,  // ggml_backend_hexagon_device_buffer_from_ptr,
    /* .supports_op          = */ ggml_backend_hexagon_device_supports_op,
    /* .supports_buft        = */ ggml_backend_hexagon_device_supports_buft,
    /* .offload_op           = */ NULL,  // ggml_backend_hexagon_device_offload_op,
    /* .event_new            = */ NULL,
    /* .event_free           = */ NULL,
    /* .event_synchronize    = */ NULL,
};
3121
3122//** backend registry
3123
// Capacity of ggml_hexagon_registry::devices; opt_ndev must not exceed it.
#define GGML_HEXAGON_MAX_SESSIONS 16

// Backend registry state (stored in the ggml_backend_reg context).
// The constructor creates one session per device for the first opt_ndev slots
// (a slot whose session failed to initialize keeps a null context); the
// destructor deletes those sessions.
struct ggml_hexagon_registry {
    ggml_hexagon_registry(ggml_backend_reg_t reg);
    ~ggml_hexagon_registry();

    ggml_backend_device devices[GGML_HEXAGON_MAX_SESSIONS];
};
3132
3133ggml_hexagon_registry::ggml_hexagon_registry(ggml_backend_reg_t reg) {
3134    GGML_LOG_INFO("ggml-hex: Hexagon backend (experimental) : allocating new registry : ndev %zu\n", opt_ndev);
3135
3136    if (!opt_arch) {
3137        int err = get_hex_arch_ver(CDSP_DOMAIN_ID, &opt_arch);
3138        if (err != 0) {
3139            GGML_LOG_ERROR("ggml-hex: failed to query HTP version (err %d) defaulting to v73\n", err);
3140            opt_arch = 73;
3141        }
3142    }
3143
3144#if defined(__ANDROID__)
3145    if (opt_arch < 75) {
3146        opt_ndev = 1;
3147        GGML_LOG_WARN("ggml-hex: forcing ndev to 1 for SoCs archs lower than v75.\n");
3148    }
3149#endif
3150
3151    GGML_LOG_INFO("ggml-hex: Hexagon Arch version v%d\n", opt_arch);
3152
3153    // Create devices / sessions
3154    for (size_t i = 0; i < opt_ndev; i++) {
3155        devices[i].iface = ggml_backend_hexagon_device_i;
3156        devices[i].reg   = reg;
3157        try {
3158            devices[i].context = new ggml_hexagon_session(i, &devices[i]);
3159        } catch (const std::exception & exc) {
3160            GGML_LOG_ERROR("ggml-hex: failed to create device/session %zu\n", i);
3161            devices[i].context = nullptr;
3162        }
3163    }
3164}
3165
3166ggml_hexagon_registry::~ggml_hexagon_registry() {
3167    GGML_LOG_INFO("ggml-hex: releasing registry\n");
3168
3169    // Release devices / sessions
3170    for (size_t i = 0; i < opt_ndev; i++) {
3171        auto sess = static_cast<ggml_hexagon_session *>(devices[i].context);
3172        delete sess;
3173    }
3174}
3175
3176static const char * ggml_backend_hexagon_reg_get_name(ggml_backend_reg_t reg) {
3177    return "HTP";
3178    GGML_UNUSED(reg);
3179}
3180
3181static size_t ggml_backend_hexagon_reg_get_device_count(ggml_backend_reg_t reg) {
3182    return opt_ndev;
3183    GGML_UNUSED(reg);
3184}
3185
3186static ggml_backend_dev_t ggml_backend_hexagon_reg_get_device(ggml_backend_reg_t reg, size_t index) {
3187    auto hreg = static_cast<ggml_hexagon_registry *>(reg->context);
3188
3189    if (index >= opt_ndev || !hreg->devices[index].context) {
3190        return nullptr;
3191    }
3192
3193    return &hreg->devices[index];
3194}
3195
3196static void * ggml_backend_hexagon_get_proc_address(ggml_backend_reg_t reg, const char * name) {
3197    if (strcmp(name, "ggml_backend_dev_get_extra_bufts") == 0 && opt_hostbuf) {
3198        ggml_backend_dev_get_extra_bufts_t fct = ggml_backend_hexagon_device_get_extra_buffers_type;
3199        return (void *) fct;
3200    }
3201
3202    return NULL;
3203}
3204
3205static void ggml_hexagon_init(ggml_backend_reg * reg) {
3206    // Basic sanity checks to make sure definitions match
3207    static_assert((unsigned int) HTP_TYPE_Q4_0 == (unsigned int) GGML_TYPE_Q4_0,
3208                  "please update hexagon_type to match ggml_type");
3209    static_assert((unsigned int) HTP_TYPE_Q8_0 == (unsigned int) GGML_TYPE_Q8_0,
3210                  "please update hexagon_type to match ggml_type");
3211    static_assert((unsigned int) HTP_TYPE_MXFP4 == (unsigned int) GGML_TYPE_MXFP4,
3212                  "please update hexagon_type to match ggml_type");
3213
3214    const char * str_experimental = getenv("GGML_HEXAGON_EXPERIMENTAL");
3215    const char * str_verbose = getenv("GGML_HEXAGON_VERBOSE");
3216    const char * str_hostbuf = getenv("GGML_HEXAGON_HOSTBUF");
3217    const char * str_opmask  = getenv("GGML_HEXAGON_OPMASK");
3218    const char * str_opsync  = getenv("GGML_HEXAGON_OPSYNC");
3219    const char * str_profile = getenv("GGML_HEXAGON_PROFILE");
3220    const char * str_etm     = getenv("GGML_HEXAGON_ETM");
3221    const char * str_nhvx    = getenv("GGML_HEXAGON_NHVX");
3222    const char * str_ndev    = getenv("GGML_HEXAGON_NDEV");
3223    const char * str_arch    = getenv("GGML_HEXAGON_ARCH");
3224
3225    opt_experimental = str_experimental ? atoi(str_experimental) : 0;
3226    opt_verbose      = str_verbose ? atoi(str_verbose) : 0;
3227    opt_hostbuf      = str_hostbuf ? atoi(str_hostbuf) : opt_hostbuf;
3228    opt_opmask       = str_opmask  ? strtoul(str_opmask, NULL, 0) : opt_opmask;
3229    opt_opsync       = str_opsync  ? atoi(str_opsync)  : 0;
3230    opt_profile      = str_profile ? atoi(str_profile) : 0;
3231    opt_etm          = str_etm     ? atoi(str_etm) : 0;
3232    opt_nhvx         = str_nhvx    ? strtoul(str_nhvx, NULL, 0) : opt_nhvx;
3233    opt_ndev         = str_ndev    ? strtoul(str_ndev, NULL, 0) : opt_ndev;
3234
3235    if (opt_ndev > GGML_HEXAGON_MAX_SESSIONS) {
3236        opt_ndev = GGML_HEXAGON_MAX_SESSIONS;
3237    }
3238
3239    if (str_arch) {
3240        if (str_arch[0] == 'v') {
3241            str_arch++;
3242        }
3243        opt_arch = strtoul(str_arch, NULL, 0);
3244    }
3245
3246    opt_hostbuf = str_hostbuf ? atoi(str_hostbuf) : 1;
3247
3248    reg->context = new ggml_hexagon_registry(reg);
3249
3250    HEX_VERBOSE("ggml-hex: size-of-general-req %zu size-of-general-rsp %zu\n", sizeof(struct htp_general_req),
3251                sizeof(struct htp_general_rsp));
3252}
3253
// Registry callback table handed to ggml through ggml_backend_hexagon_reg().
// The initializer is positional, so entries must follow the member order of
// struct ggml_backend_reg_i; the inline comments name the intended member.
static const struct ggml_backend_reg_i ggml_backend_hexagon_reg_i = {
    /* .get_name         = */ ggml_backend_hexagon_reg_get_name,
    /* .get_device_count = */ ggml_backend_hexagon_reg_get_device_count,
    /* .get_device       = */ ggml_backend_hexagon_reg_get_device,
    /* .get_proc_address = */ ggml_backend_hexagon_get_proc_address,
};
3260
3261ggml_backend_reg_t ggml_backend_hexagon_reg(void) {
3262    static bool initialized = false;
3263
3264    static ggml_backend_reg reg = { /* .api_version = */ GGML_BACKEND_API_VERSION,
3265                                    /* .iface       = */ ggml_backend_hexagon_reg_i,
3266                                    /* .context     = */ NULL };
3267
3268    {
3269        static std::mutex           mutex;
3270        std::lock_guard<std::mutex> lock(mutex);
3271        if (!initialized) {
3272            auto nErr = htpdrv_init();
3273            if (nErr != AEE_SUCCESS) {
3274                return NULL;
3275            }
3276
3277            ggml_hexagon_init(&reg);
3278        }
3279
3280        initialized = true;
3281    }
3282
3283    return &reg;
3284}
3285
// NOTE(review): presumably exports ggml_backend_hexagon_reg as this backend's
// dynamic-loading entry point when built as a loadable module; confirm against
// the macro definition in ggml-backend-impl.h.
GGML_BACKEND_DL_IMPL(ggml_backend_hexagon_reg)