1#define _CRT_SECURE_NO_DEPRECATE // Disables "unsafe" warnings on Windows
   2#define _USE_MATH_DEFINES // For M_PI on MSVC
   3
   4#include "ggml-backend.h"
   5#include "ggml-impl.h"
   6#include "ggml-threading.h"
   7#include "ggml-cpu.h"
   8#include "ggml.h"
   9
  10// FIXME: required here for quantization functions
  11#include "ggml-quants.h"
  12
  13#ifdef GGML_USE_CPU_HBM
  14#include <hbwmalloc.h>
  15#endif
  16
  17#if defined(_MSC_VER) || defined(__MINGW32__)
  18#include <malloc.h> // using malloc.h with MSC/MINGW
  19#elif !defined(__FreeBSD__) && !defined(__NetBSD__) && !defined(__OpenBSD__)
  20#include <alloca.h>
  21#endif
  22
  23#include <assert.h>
  24#include <errno.h>
  25#include <time.h>
  26#include <math.h>
  27#include <stdlib.h>
  28#include <string.h>
  29#include <stdint.h>
  30#include <inttypes.h>
  31#include <stdio.h>
  32#include <float.h>
  33#include <limits.h>
  34#include <stdarg.h>
  35#include <signal.h>
  36#if defined(__gnu_linux__)
  37#include <syscall.h>
  38#endif
  39
  40#if defined(__APPLE__)
  41#include <unistd.h>
  42#include <mach/mach.h>
  43#include <TargetConditionals.h>
  44#endif
  45
  46#if defined(_WIN32)
  47#define WIN32_LEAN_AND_MEAN
  48#ifndef NOMINMAX
  49    #define NOMINMAX
  50#endif
  51#include <windows.h>
  52#endif
  53
  54#define UNUSED GGML_UNUSED
  55
  56// Needed for ggml_fp32_to_bf16_row()
  57#if defined(__AVX512BF16__)
  58#if defined(_MSC_VER)
  59#define m512i(p) p
  60#else
  61#include <immintrin.h>
  62#define m512i(p) (__m512i)(p)
  63#endif // defined(_MSC_VER)
  64#endif // defined(__AVX512BF16__)
  65
  66#if defined(__linux__) || \
  67    defined(__FreeBSD__) || defined(__NetBSD__) || defined(__OpenBSD__) || \
  68    (defined(__APPLE__) && !TARGET_OS_TV && !TARGET_OS_WATCH)
  69
  70#include <unistd.h>
  71#include <sys/types.h>
  72#include <sys/stat.h>
  73#include <sys/wait.h>
  74#if defined(__linux__)
  75#include <sys/prctl.h>
  76#endif
  77
  78#if defined(__ANDROID__)
  79#include <unwind.h>
  80#include <dlfcn.h>
  81#include <stdio.h>
  82
  83struct backtrace_state {
  84    void ** current;
  85    void ** end;
  86};
  87
  88static _Unwind_Reason_Code unwind_callback(struct _Unwind_Context* context, void* arg) {
  89    struct backtrace_state * state = (struct backtrace_state *)arg;
  90    uintptr_t pc = _Unwind_GetIP(context);
  91    if (pc) {
  92        if (state->current == state->end) {
  93            return _URC_END_OF_STACK;
  94        } else {
  95            *state->current++ = (void*)pc;
  96        }
  97    }
  98    return _URC_NO_REASON;
  99}
 100
 101static void ggml_print_backtrace_symbols(void) {
 102    const int max = 100;
 103    void* buffer[max];
 104
 105    struct backtrace_state state = {buffer, buffer + max};
 106    _Unwind_Backtrace(unwind_callback, &state);
 107
 108    int count = state.current - buffer;
 109
 110    for (int idx = 0; idx < count; ++idx) {
 111        const void * addr = buffer[idx];
 112        const char * symbol = "";
 113
 114        Dl_info info;
 115        if (dladdr(addr, &info) && info.dli_sname) {
 116            symbol = info.dli_sname;
 117        }
 118
 119        fprintf(stderr, "%d: %p %s\n", idx, addr, symbol);
 120    }
 121}
 122#elif defined(__linux__) && defined(__GLIBC__)
 123#include <execinfo.h>
 124static void ggml_print_backtrace_symbols(void) {
 125    void * trace[100];
 126    int nptrs = backtrace(trace, sizeof(trace)/sizeof(trace[0]));
 127    backtrace_symbols_fd(trace, nptrs, STDERR_FILENO);
 128}
 129#elif defined(__APPLE__)
 130#include <execinfo.h>
 131static void ggml_print_backtrace_symbols(void) {
 132    void * trace[100];
 133    int nptrs = backtrace(trace, sizeof(trace)/sizeof(trace[0]));
 134    backtrace_symbols_fd(trace, nptrs, STDERR_FILENO);
 135}
 136#else
 137static void ggml_print_backtrace_symbols(void) {
 138    // platform not supported
 139}
 140#endif
 141
 142void ggml_print_backtrace(void) {
 143    const char * GGML_NO_BACKTRACE = getenv("GGML_NO_BACKTRACE");
 144    if (GGML_NO_BACKTRACE) {
 145        return;
 146    }
 147#if defined(__APPLE__)
 148    // On macOS, fork+debugger attachment is problematic due to:
 149    // 1. libdispatch "poisons" forked child processes
 150    // 2. lldb has issues attaching to parent from forked child
 151    // Use simple backtrace() instead to avoid Terminal.app crashes
 152    const char * GGML_BACKTRACE_LLDB = getenv("GGML_BACKTRACE_LLDB");
 153    if (!GGML_BACKTRACE_LLDB) {
 154        fprintf(stderr, "WARNING: Using native backtrace. Set GGML_BACKTRACE_LLDB for more info.\n");
 155        fprintf(stderr, "WARNING: GGML_BACKTRACE_LLDB may cause native MacOS Terminal.app to crash.\n");
 156        fprintf(stderr, "See: https://github.com/ggml-org/llama.cpp/pull/17869\n");
 157        ggml_print_backtrace_symbols();
 158        return;
 159    }
 160#endif
 161#if defined(__linux__)
 162    FILE * f = fopen("/proc/self/status", "r");
 163    size_t size = 0;
 164    char * line = NULL;
 165    ssize_t length = 0;
 166    while ((length = getline(&line, &size, f)) > 0) {
 167        if (!strncmp(line, "TracerPid:", sizeof("TracerPid:") - 1) &&
 168            (length != sizeof("TracerPid:\t0\n") - 1 || line[length - 2] != '0')) {
 169            // Already being debugged, and the breakpoint is the later abort()
 170            free(line);
 171            fclose(f);
 172            return;
 173        }
 174    }
 175    free(line);
 176    fclose(f);
 177    int lock[2] = { -1, -1 };
 178    (void) !pipe(lock); // Don't start gdb until after PR_SET_PTRACER
 179#endif
 180    const int parent_pid = getpid();
 181    const int child_pid = fork();
 182    if (child_pid < 0) { // error
 183#if defined(__linux__)
 184        close(lock[1]);
 185        close(lock[0]);
 186#endif
 187        return;
 188    } else if (child_pid == 0) { // child
 189        char attach[32];
 190        snprintf(attach, sizeof(attach), "attach %d", parent_pid);
 191#if defined(__linux__)
 192        close(lock[1]);
 193        (void) !read(lock[0], lock, 1);
 194        close(lock[0]);
 195#endif
 196        // try gdb
 197        execlp("gdb", "gdb", "--batch",
 198            "-ex", "set style enabled on",
 199            "-ex", attach,
 200            "-ex", "bt -frame-info source-and-location",
 201            "-ex", "detach",
 202            "-ex", "quit",
 203            (char *) NULL);
 204        // try lldb
 205        execlp("lldb", "lldb", "--batch",
 206            "-o", "bt",
 207            "-o", "quit",
 208            "-p", &attach[sizeof("attach ") - 1],
 209            (char *) NULL);
 210        // gdb failed, fallback to backtrace_symbols
 211        ggml_print_backtrace_symbols();
 212        _Exit(0);
 213    } else { // parent
 214#if defined(__linux__)
 215        prctl(PR_SET_PTRACER, child_pid);
 216        close(lock[1]);
 217        close(lock[0]);
 218#endif
 219        waitpid(child_pid, NULL, 0);
 220    }
 221}
 222#else
 223void ggml_print_backtrace(void) {
 224    // platform not supported
 225}
 226#endif
 227
 228static ggml_abort_callback_t g_abort_callback = NULL;
 229
 230// Set the abort callback (passing null will restore original abort functionality: printing a message to stdout)
 231GGML_API ggml_abort_callback_t ggml_set_abort_callback(ggml_abort_callback_t callback) {
 232    ggml_abort_callback_t ret_val = g_abort_callback;
 233    g_abort_callback = callback;
 234    return ret_val;
 235}
 236
 237void ggml_abort(const char * file, int line, const char * fmt, ...) {
 238    fflush(stdout);
 239
 240    char message[2048];
 241    int offset = snprintf(message, sizeof(message), "%s:%d: ", file, line);
 242
 243    va_list args;
 244    va_start(args, fmt);
 245    vsnprintf(message + offset, sizeof(message) - offset, fmt, args);
 246    va_end(args);
 247
 248    if (g_abort_callback) {
 249        g_abort_callback(message);
 250    } else {
 251        // default: print error and backtrace to stderr
 252        fprintf(stderr, "%s\n", message);
 253        ggml_print_backtrace();
 254    }
 255
 256    abort();
 257}
 258
 259// ggml_print_backtrace is registered with std::set_terminate by ggml.cpp
 260
 261//
 262// logging
 263//
 264
 265struct ggml_logger_state {
 266    ggml_log_callback log_callback;
 267    void * log_callback_user_data;
 268};
 269static struct ggml_logger_state g_logger_state = {ggml_log_callback_default, NULL};
 270
 271static void ggml_log_internal_v(enum ggml_log_level level, const char * format, va_list args) {
 272    if (format == NULL) {
 273        return;
 274    }
 275    va_list args_copy;
 276    va_copy(args_copy, args);
 277    char buffer[128];
 278    int len = vsnprintf(buffer, 128, format, args);
 279    if (len < 128) {
 280        g_logger_state.log_callback(level, buffer, g_logger_state.log_callback_user_data);
 281    } else {
 282        char * buffer2 = (char *) calloc(len + 1, sizeof(char));
 283        vsnprintf(buffer2, len + 1, format, args_copy);
 284        buffer2[len] = 0;
 285        g_logger_state.log_callback(level, buffer2, g_logger_state.log_callback_user_data);
 286        free(buffer2);
 287    }
 288    va_end(args_copy);
 289}
 290
 291void ggml_log_internal(enum ggml_log_level level, const char * format, ...) {
 292    va_list args;
 293    va_start(args, format);
 294    ggml_log_internal_v(level, format, args);
 295    va_end(args);
 296}
 297
 298void ggml_log_callback_default(enum ggml_log_level level, const char * text, void * user_data) {
 299    (void) level;
 300    (void) user_data;
 301    fputs(text, stderr);
 302    fflush(stderr);
 303}
 304
 305//
 306// end of logging block
 307//
 308
 309#ifdef GGML_USE_ACCELERATE
 310// uncomment to use vDSP for soft max computation
 311// note: not sure if it is actually faster
 312//#define GGML_SOFT_MAX_ACCELERATE
 313#endif
 314
 315
 316void * ggml_aligned_malloc(size_t size) {
 317#if defined(__s390x__)
 318    const int alignment = 256;
 319#else
 320    const int alignment = 64;
 321#endif
 322
 323#if defined(_MSC_VER) || defined(__MINGW32__)
 324    return _aligned_malloc(size, alignment);
 325#else
 326    if (size == 0) {
 327        GGML_LOG_WARN("Behavior may be unexpected when allocating 0 bytes for ggml_aligned_malloc!\n");
 328        return NULL;
 329    }
 330    void * aligned_memory = NULL;
 331  #ifdef GGML_USE_CPU_HBM
 332    int result = hbw_posix_memalign(&aligned_memory, alignment, size);
 333  #elif TARGET_OS_OSX
 334    GGML_UNUSED(alignment);
 335    kern_return_t alloc_status = vm_allocate((vm_map_t) mach_task_self(), (vm_address_t *) &aligned_memory, size, VM_FLAGS_ANYWHERE);
 336    int result = EFAULT;
 337    switch (alloc_status) {
 338        case KERN_SUCCESS:
 339            result = 0;
 340            break;
 341        case KERN_INVALID_ADDRESS:
 342            result = EINVAL;
 343            break;
 344        case KERN_NO_SPACE:
 345            result = ENOMEM;
 346            break;
 347        default:
 348            result = EFAULT;
 349            break;
 350    }
 351  #else
 352    int result = posix_memalign(&aligned_memory, alignment, size);
 353  #endif
 354    if (result != 0) {
 355        // Handle allocation failure
 356        const char *error_desc = "unknown allocation error";
 357        switch (result) {
 358            case EINVAL:
 359                error_desc = "invalid alignment value";
 360                break;
 361            case ENOMEM:
 362                error_desc = "insufficient memory";
 363                break;
 364        }
 365        GGML_LOG_ERROR("%s: %s (attempted to allocate %6.2f MB)\n", __func__, error_desc, size/(1024.0*1024.0));
 366        return NULL;
 367    }
 368    return aligned_memory;
 369#endif
 370}
 371
 372void ggml_aligned_free(void * ptr, size_t size) {
 373    GGML_UNUSED(size);
 374#if defined(_MSC_VER) || defined(__MINGW32__)
 375    _aligned_free(ptr);
 376#elif GGML_USE_CPU_HBM
 377    if (ptr != NULL) {
 378        hbw_free(ptr);
 379    }
 380#elif TARGET_OS_OSX
 381    if (ptr != NULL) {
 382        vm_deallocate((vm_map_t)mach_task_self(), (vm_address_t)ptr, size);
 383    }
 384#else
 385    free(ptr);
 386#endif
 387}
 388
 389
 390inline static void * ggml_malloc(size_t size) {
 391    if (size == 0) {
 392        GGML_LOG_WARN("Behavior may be unexpected when allocating 0 bytes for ggml_malloc!\n");
 393        return NULL;
 394    }
 395    void * result = malloc(size);
 396    if (result == NULL) {
 397        GGML_LOG_ERROR("%s: failed to allocate %6.2f MB\n", __func__, size/(1024.0*1024.0));
 398        GGML_ABORT("fatal error");
 399    }
 400    return result;
 401}
 402
 403// calloc
 404inline static void * ggml_calloc(size_t num, size_t size) {
 405    if (num == 0 || size == 0) {
 406        GGML_LOG_WARN("Behavior may be unexpected when allocating 0 bytes for ggml_calloc!\n");
 407        return NULL;
 408    }
 409    void * result = calloc(num, size);
 410    if (result == NULL) {
 411        GGML_LOG_ERROR("%s: failed to allocate %6.2f MB\n", __func__, size/(1024.0*1024.0));
 412        GGML_ABORT("fatal error");
 413    }
 414    return result;
 415}
 416
 417#define GGML_MALLOC(size)      ggml_malloc(size)
 418#define GGML_CALLOC(num, size) ggml_calloc(num, size)
 419
 420#define GGML_FREE(ptr) free(ptr)
 421
 422const char * ggml_status_to_string(enum ggml_status status) {
 423    switch (status) {
 424        case GGML_STATUS_ALLOC_FAILED: return "GGML status: error (failed to allocate memory)";
 425        case GGML_STATUS_FAILED:       return "GGML status: error (operation failed)";
 426        case GGML_STATUS_SUCCESS:      return "GGML status: success";
 427        case GGML_STATUS_ABORTED:      return "GGML status: warning (operation aborted)";
 428    }
 429
 430    return "GGML status: unknown";
 431}
 432
 433float ggml_fp16_to_fp32(ggml_fp16_t x) {
 434#define ggml_fp16_to_fp32 do_not_use__ggml_fp16_to_fp32__in_ggml
 435    return GGML_FP16_TO_FP32(x);
 436}
 437
 438ggml_fp16_t ggml_fp32_to_fp16(float x) {
 439#define ggml_fp32_to_fp16 do_not_use__ggml_fp32_to_fp16__in_ggml
 440    return GGML_FP32_TO_FP16(x);
 441}
 442
 443float ggml_bf16_to_fp32(ggml_bf16_t x) {
 444#define ggml_bf16_to_fp32 do_not_use__ggml_bf16_to_fp32__in_ggml
 445    return GGML_BF16_TO_FP32(x);  // it just left shifts
 446}
 447
 448ggml_bf16_t ggml_fp32_to_bf16(float x) {
 449#define ggml_fp32_to_bf16 do_not_use__ggml_fp32_to_bf16__in_ggml
 450    return GGML_FP32_TO_BF16(x);
 451}
 452
 453void ggml_fp16_to_fp32_row(const ggml_fp16_t * x, float * y, int64_t n) {
 454    for (int64_t i = 0; i < n; i++) {
 455        y[i] = GGML_FP16_TO_FP32(x[i]);
 456    }
 457}
 458
 459void ggml_fp32_to_fp16_row(const float * x, ggml_fp16_t * y, int64_t n) {
 460    int i = 0;
 461    for (; i < n; ++i) {
 462        y[i] = GGML_FP32_TO_FP16(x[i]);
 463    }
 464}
 465
 466void ggml_bf16_to_fp32_row(const ggml_bf16_t * x, float * y, int64_t n) {
 467    int i = 0;
 468    for (; i < n; ++i) {
 469        y[i] = GGML_BF16_TO_FP32(x[i]);
 470    }
 471}
 472
 473void ggml_fp32_to_bf16_row_ref(const float * x, ggml_bf16_t * y, int64_t n) {
 474    for (int i = 0; i < n; i++) {
 475        y[i] = ggml_compute_fp32_to_bf16(x[i]);
 476    }
 477}
 478
 479void ggml_fp32_to_bf16_row(const float * x, ggml_bf16_t * y, int64_t n) {
 480  int i = 0;
 481#if defined(__AVX512BF16__)
 482  // subnormals are flushed to zero on this platform
 483  for (; i + 32 <= n; i += 32) {
 484        _mm512_storeu_si512(
 485            (__m512i *)(y + i),
 486            m512i(_mm512_cvtne2ps_pbh(_mm512_loadu_ps(x + i + 16),
 487                                _mm512_loadu_ps(x + i))));
 488  }
 489#endif
 490    for (; i < n; i++) {
 491        y[i] = GGML_FP32_TO_BF16(x[i]);
 492    }
 493}
 494
 495bool ggml_guid_matches(ggml_guid_t guid_a, ggml_guid_t guid_b) {
 496    return memcmp(guid_a, guid_b, sizeof(ggml_guid)) == 0;
 497}
 498
 499const char * ggml_version(void) {
 500    return GGML_VERSION;
 501}
 502
 503const char * ggml_commit(void) {
 504    return GGML_COMMIT;
 505}
 506
 507//
 508// timing
 509//
 510
 511#if defined(_MSC_VER) || defined(__MINGW32__)
 512static int64_t timer_freq, timer_start;
 513void ggml_time_init(void) {
 514    LARGE_INTEGER t;
 515    QueryPerformanceFrequency(&t);
 516    timer_freq = t.QuadPart;
 517
 518    // The multiplication by 1000 or 1000000 below can cause an overflow if timer_freq
 519    // and the uptime is high enough.
 520    // We subtract the program start time to reduce the likelihood of that happening.
 521    QueryPerformanceCounter(&t);
 522    timer_start = t.QuadPart;
 523}
 524int64_t ggml_time_ms(void) {
 525    LARGE_INTEGER t;
 526    QueryPerformanceCounter(&t);
 527    return ((t.QuadPart-timer_start) * 1000) / timer_freq;
 528}
 529int64_t ggml_time_us(void) {
 530    LARGE_INTEGER t;
 531    QueryPerformanceCounter(&t);
 532    return ((t.QuadPart-timer_start) * 1000000) / timer_freq;
 533}
 534#else
 535void ggml_time_init(void) {}
 536int64_t ggml_time_ms(void) {
 537    struct timespec ts;
 538    clock_gettime(CLOCK_MONOTONIC, &ts);
 539    return (int64_t)ts.tv_sec*1000 + (int64_t)ts.tv_nsec/1000000;
 540}
 541
 542int64_t ggml_time_us(void) {
 543    struct timespec ts;
 544    clock_gettime(CLOCK_MONOTONIC, &ts);
 545    return (int64_t)ts.tv_sec*1000000 + (int64_t)ts.tv_nsec/1000;
 546}
 547#endif
 548
 549int64_t ggml_cycles(void) {
 550    return clock();
 551}
 552
 553int64_t ggml_cycles_per_ms(void) {
 554    return CLOCKS_PER_SEC/1000;
 555}
 556
 557//
 558// cross-platform UTF-8 file paths
 559//
 560
 561#ifdef _WIN32
 562static wchar_t * ggml_mbstowcs(const char * mbs) {
 563    int wlen = MultiByteToWideChar(CP_UTF8, 0, mbs, -1, NULL, 0);
 564    if (!wlen) {
 565        errno = EINVAL;
 566        return NULL;
 567    }
 568
 569    wchar_t * wbuf = GGML_MALLOC(wlen * sizeof(wchar_t));
 570    wlen = MultiByteToWideChar(CP_UTF8, 0, mbs, -1, wbuf, wlen);
 571    if (!wlen) {
 572        GGML_FREE(wbuf);
 573        errno = EINVAL;
 574        return NULL;
 575    }
 576
 577    return wbuf;
 578}
 579#endif
 580
 581FILE * ggml_fopen(const char * fname, const char * mode) {
 582#ifdef _WIN32
 583    FILE * file = NULL;
 584
 585    // convert fname (UTF-8)
 586    wchar_t * wfname = ggml_mbstowcs(fname);
 587    if (wfname) {
 588        // convert mode (ANSI)
 589        wchar_t * wmode = GGML_MALLOC((strlen(mode) + 1) * sizeof(wchar_t));
 590        wchar_t * wmode_p = wmode;
 591        do {
 592            *wmode_p++ = (wchar_t)*mode;
 593        } while (*mode++);
 594
 595        // open file
 596        file = _wfopen(wfname, wmode);
 597
 598        GGML_FREE(wfname);
 599        GGML_FREE(wmode);
 600    }
 601
 602    return file;
 603#else
 604    return fopen(fname, mode);
 605#endif
 606
 607}
 608
 609static const struct ggml_type_traits type_traits[GGML_TYPE_COUNT] = {
 610    [GGML_TYPE_I8] = {
 611        .type_name                = "i8",
 612        .blck_size                = 1,
 613        .type_size                = sizeof(int8_t),
 614        .is_quantized             = false,
 615    },
 616    [GGML_TYPE_I16] = {
 617        .type_name                = "i16",
 618        .blck_size                = 1,
 619        .type_size                = sizeof(int16_t),
 620        .is_quantized             = false,
 621    },
 622    [GGML_TYPE_I32] = {
 623        .type_name                = "i32",
 624        .blck_size                = 1,
 625        .type_size                = sizeof(int32_t),
 626        .is_quantized             = false,
 627    },
 628    [GGML_TYPE_I64] = {
 629        .type_name                = "i64",
 630        .blck_size                = 1,
 631        .type_size                = sizeof(int64_t),
 632        .is_quantized             = false,
 633    },
 634    [GGML_TYPE_F64] = {
 635        .type_name                = "f64",
 636        .blck_size                = 1,
 637        .type_size                = sizeof(double),
 638        .is_quantized             = false,
 639    },
 640    [GGML_TYPE_F32] = {
 641        .type_name                = "f32",
 642        .blck_size                = 1,
 643        .type_size                = sizeof(float),
 644        .is_quantized             = false,
 645    },
 646    [GGML_TYPE_F16] = {
 647        .type_name                = "f16",
 648        .blck_size                = 1,
 649        .type_size                = sizeof(ggml_fp16_t),
 650        .is_quantized             = false,
 651        .to_float                 = (ggml_to_float_t) ggml_fp16_to_fp32_row,
 652        .from_float_ref           = (ggml_from_float_t) ggml_fp32_to_fp16_row,
 653    },
 654    [GGML_TYPE_Q4_0] = {
 655        .type_name                = "q4_0",
 656        .blck_size                = QK4_0,
 657        .type_size                = sizeof(block_q4_0),
 658        .is_quantized             = true,
 659        .to_float                 = (ggml_to_float_t) dequantize_row_q4_0,
 660        .from_float_ref           = (ggml_from_float_t) quantize_row_q4_0_ref,
 661    },
 662    [GGML_TYPE_Q4_1] = {
 663        .type_name                = "q4_1",
 664        .blck_size                = QK4_1,
 665        .type_size                = sizeof(block_q4_1),
 666        .is_quantized             = true,
 667        .to_float                 = (ggml_to_float_t) dequantize_row_q4_1,
 668        .from_float_ref           = (ggml_from_float_t) quantize_row_q4_1_ref,
 669    },
 670    [4] = { // GGML_TYPE_Q4_2
 671        .type_name                = "DEPRECATED",
 672        .blck_size                = 0,
 673        .type_size                = 0,
 674        .is_quantized             = false,
 675    },
 676    [5] = { // GGML_TYPE_Q4_3
 677        .type_name                = "DEPRECATED",
 678        .blck_size                = 0,
 679        .type_size                = 0,
 680        .is_quantized             = false,
 681    },
 682    [GGML_TYPE_Q5_0] = {
 683        .type_name                = "q5_0",
 684        .blck_size                = QK5_0,
 685        .type_size                = sizeof(block_q5_0),
 686        .is_quantized             = true,
 687        .to_float                 = (ggml_to_float_t) dequantize_row_q5_0,
 688        .from_float_ref           = (ggml_from_float_t) quantize_row_q5_0_ref,
 689    },
 690    [GGML_TYPE_Q5_1] = {
 691        .type_name                = "q5_1",
 692        .blck_size                = QK5_1,
 693        .type_size                = sizeof(block_q5_1),
 694        .is_quantized             = true,
 695        .to_float                 = (ggml_to_float_t) dequantize_row_q5_1,
 696        .from_float_ref           = (ggml_from_float_t) quantize_row_q5_1_ref,
 697    },
 698    [GGML_TYPE_Q8_0] = {
 699        .type_name                = "q8_0",
 700        .blck_size                = QK8_0,
 701        .type_size                = sizeof(block_q8_0),
 702        .is_quantized             = true,
 703        .to_float                 = (ggml_to_float_t) dequantize_row_q8_0,
 704        .from_float_ref           = (ggml_from_float_t) quantize_row_q8_0_ref,
 705    },
 706    [GGML_TYPE_Q8_1] = {
 707        .type_name                = "q8_1",
 708        .blck_size                = QK8_1,
 709        .type_size                = sizeof(block_q8_1),
 710        .is_quantized             = true,
 711        .from_float_ref           = (ggml_from_float_t) quantize_row_q8_1_ref,
 712    },
 713    [GGML_TYPE_MXFP4] = {
 714        .type_name                = "mxfp4",
 715        .blck_size                = QK_MXFP4,
 716        .type_size                = sizeof(block_mxfp4),
 717        .is_quantized             = true,
 718        .to_float                 = (ggml_to_float_t) dequantize_row_mxfp4,
 719        .from_float_ref           = (ggml_from_float_t)quantize_row_mxfp4_ref,
 720    },
 721    [GGML_TYPE_Q2_K] = {
 722        .type_name                = "q2_K",
 723        .blck_size                = QK_K,
 724        .type_size                = sizeof(block_q2_K),
 725        .is_quantized             = true,
 726        .to_float                 = (ggml_to_float_t) dequantize_row_q2_K,
 727        .from_float_ref           = (ggml_from_float_t) quantize_row_q2_K_ref,
 728    },
 729    [GGML_TYPE_Q3_K] = {
 730        .type_name                = "q3_K",
 731        .blck_size                = QK_K,
 732        .type_size                = sizeof(block_q3_K),
 733        .is_quantized             = true,
 734        .to_float                 = (ggml_to_float_t) dequantize_row_q3_K,
 735        .from_float_ref           = (ggml_from_float_t) quantize_row_q3_K_ref,
 736    },
 737    [GGML_TYPE_Q4_K] = {
 738        .type_name                = "q4_K",
 739        .blck_size                = QK_K,
 740        .type_size                = sizeof(block_q4_K),
 741        .is_quantized             = true,
 742        .to_float                 = (ggml_to_float_t) dequantize_row_q4_K,
 743        .from_float_ref           = (ggml_from_float_t) quantize_row_q4_K_ref,
 744    },
 745    [GGML_TYPE_Q5_K] = {
 746        .type_name                = "q5_K",
 747        .blck_size                = QK_K,
 748        .type_size                = sizeof(block_q5_K),
 749        .is_quantized             = true,
 750        .to_float                 = (ggml_to_float_t) dequantize_row_q5_K,
 751        .from_float_ref           = (ggml_from_float_t) quantize_row_q5_K_ref,
 752    },
 753    [GGML_TYPE_Q6_K] = {
 754        .type_name                = "q6_K",
 755        .blck_size                = QK_K,
 756        .type_size                = sizeof(block_q6_K),
 757        .is_quantized             = true,
 758        .to_float                 = (ggml_to_float_t) dequantize_row_q6_K,
 759        .from_float_ref           = (ggml_from_float_t) quantize_row_q6_K_ref,
 760    },
 761    [GGML_TYPE_IQ2_XXS] = {
 762        .type_name                = "iq2_xxs",
 763        .blck_size                = QK_K,
 764        .type_size                = sizeof(block_iq2_xxs),
 765        .is_quantized             = true,
 766        .to_float                 = (ggml_to_float_t) dequantize_row_iq2_xxs,
 767        .from_float_ref           = NULL,
 768    },
 769    [GGML_TYPE_IQ2_XS] = {
 770        .type_name                = "iq2_xs",
 771        .blck_size                = QK_K,
 772        .type_size                = sizeof(block_iq2_xs),
 773        .is_quantized             = true,
 774        .to_float                 = (ggml_to_float_t) dequantize_row_iq2_xs,
 775        .from_float_ref           = NULL,
 776    },
 777    [GGML_TYPE_IQ3_XXS] = {
 778        .type_name                = "iq3_xxs",
 779        .blck_size                = QK_K,
 780        .type_size                = sizeof(block_iq3_xxs),
 781        .is_quantized             = true,
 782        .to_float                 = (ggml_to_float_t) dequantize_row_iq3_xxs,
 783        .from_float_ref           = (ggml_from_float_t)quantize_row_iq3_xxs_ref,
 784    },
 785    [GGML_TYPE_IQ3_S] = {
 786        .type_name                = "iq3_s",
 787        .blck_size                = QK_K,
 788        .type_size                = sizeof(block_iq3_s),
 789        .is_quantized             = true,
 790        .to_float                 = (ggml_to_float_t) dequantize_row_iq3_s,
 791        .from_float_ref           = (ggml_from_float_t)quantize_row_iq3_s_ref,
 792    },
 793    [GGML_TYPE_IQ2_S] = {
 794        .type_name                = "iq2_s",
 795        .blck_size                = QK_K,
 796        .type_size                = sizeof(block_iq2_s),
 797        .is_quantized             = true,
 798        .to_float                 = (ggml_to_float_t) dequantize_row_iq2_s,
 799        .from_float_ref           = (ggml_from_float_t)quantize_row_iq2_s_ref,
 800    },
 801    [GGML_TYPE_IQ1_S] = {
 802        .type_name                = "iq1_s",
 803        .blck_size                = QK_K,
 804        .type_size                = sizeof(block_iq1_s),
 805        .is_quantized             = true,
 806        .to_float                 = (ggml_to_float_t) dequantize_row_iq1_s,
 807        .from_float_ref           = NULL,
 808    },
 809    [GGML_TYPE_IQ1_M] = {
 810        .type_name                = "iq1_m",
 811        .blck_size                = QK_K,
 812        .type_size                = sizeof(block_iq1_m),
 813        .is_quantized             = true,
 814        .to_float                 = (ggml_to_float_t) dequantize_row_iq1_m,
 815        .from_float_ref           = NULL,
 816    },
 817    [GGML_TYPE_IQ4_NL] = {
 818        .type_name                = "iq4_nl",
 819        .blck_size                = QK4_NL,
 820        .type_size                = sizeof(block_iq4_nl),
 821        .is_quantized             = true,
 822        .to_float                 = (ggml_to_float_t) dequantize_row_iq4_nl,
 823        .from_float_ref           = (ggml_from_float_t)quantize_row_iq4_nl_ref,
 824    },
 825    [GGML_TYPE_IQ4_XS] = {
 826        .type_name                = "iq4_xs",
 827        .blck_size                = QK_K,
 828        .type_size                = sizeof(block_iq4_xs),
 829        .is_quantized             = true,
 830        .to_float                 = (ggml_to_float_t) dequantize_row_iq4_xs,
 831        .from_float_ref           = (ggml_from_float_t)quantize_row_iq4_xs_ref,
 832    },
 833    [GGML_TYPE_Q8_K] = {
 834        .type_name                = "q8_K",
 835        .blck_size                = QK_K,
 836        .type_size                = sizeof(block_q8_K),
 837        .is_quantized             = true,
 838    },
 839    [GGML_TYPE_BF16] = {
 840        .type_name                = "bf16",
 841        .blck_size                = 1,
 842        .type_size                = sizeof(ggml_bf16_t),
 843        .is_quantized             = false,
 844        .to_float                 = (ggml_to_float_t) ggml_bf16_to_fp32_row,
 845        .from_float_ref           = (ggml_from_float_t) ggml_fp32_to_bf16_row_ref,
 846    },
 847    [31] = { // GGML_TYPE_Q4_0_4_4
 848        .type_name                = "TYPE_Q4_0_4_4 REMOVED, use Q4_0 with runtime repacking",
 849        .blck_size                = 0,
 850        .type_size                = 0,
 851        .is_quantized             = false,
 852    },
 853    [32] = { // GGML_TYPE_Q4_0_4_8
 854        .type_name                = "TYPE_Q4_0_4_8 REMOVED, use Q4_0 with runtime repacking",
 855        .blck_size                = 0,
 856        .type_size                = 0,
 857        .is_quantized             = false,
 858    },
 859    [33] = { // GGML_TYPE_Q4_0_8_8
 860        .type_name                = "TYPE_Q4_0_8_8 REMOVED, use Q4_0 with runtime repacking",
 861        .blck_size                = 0,
 862        .type_size                = 0,
 863        .is_quantized             = false,
 864    },
 865    [GGML_TYPE_TQ1_0] = {
 866        .type_name                = "tq1_0",
 867        .blck_size                = QK_K,
 868        .type_size                = sizeof(block_tq1_0),
 869        .is_quantized             = true,
 870        .to_float                 = (ggml_to_float_t) dequantize_row_tq1_0,
 871        .from_float_ref           = (ggml_from_float_t) quantize_row_tq1_0_ref,
 872    },
 873    [GGML_TYPE_TQ2_0] = {
 874        .type_name                = "tq2_0",
 875        .blck_size                = QK_K,
 876        .type_size                = sizeof(block_tq2_0),
 877        .is_quantized             = true,
 878        .to_float                 = (ggml_to_float_t) dequantize_row_tq2_0,
 879        .from_float_ref           = (ggml_from_float_t) quantize_row_tq2_0_ref,
 880    },
 881    [36] = { // GGML_TYPE_IQ4_NL_4_4
 882        .type_name                = "TYPE_IQ4_NL_4_4 REMOVED, use IQ4_NL with runtime repacking",
 883        .blck_size                = 0,
 884        .type_size                = 0,
 885        .is_quantized             = false,
 886    },
 887    [37] = { // GGML_TYPE_IQ4_NL_4_8
 888        .type_name                = "TYPE_IQ4_NL_4_8 REMOVED, use IQ4_NL with runtime repacking",
 889        .blck_size                = 0,
 890        .type_size                = 0,
 891        .is_quantized             = false,
 892    },
 893    [38] = { // GGML_TYPE_IQ4_NL_8_8
 894        .type_name                = "TYPE_IQ4_NL_8_8 REMOVED, use IQ4_NL with runtime repacking",
 895        .blck_size                = 0,
 896        .type_size                = 0,
 897        .is_quantized             = false,
 898    },
 899};
 900
 901const struct ggml_type_traits * ggml_get_type_traits(enum ggml_type type) {
 902    GGML_ASSERT(type < GGML_TYPE_COUNT);
 903    return &type_traits[type];
 904}
 905
 906//
 907// ggml object
 908//
 909
 910struct ggml_object {
 911    size_t offs;
 912    size_t size;
 913
 914    struct ggml_object * next;
 915
 916    enum ggml_object_type type;
 917
 918    char padding[4];
 919};
 920
 921static const size_t GGML_OBJECT_SIZE = sizeof(struct ggml_object);
 922
 923//
 924// ggml context
 925//
 926
 927struct ggml_context {
 928    size_t mem_size;
 929    void * mem_buffer;
 930    bool   mem_buffer_owned;
 931    bool   no_alloc;
 932
 933    int    n_objects;
 934
 935    struct ggml_object * objects_begin;
 936    struct ggml_object * objects_end;
 937};
 938
 939//
 940// data types
 941//
 942
 943static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
 944    "NONE",
 945
 946    "DUP",
 947    "ADD",
 948    "ADD_ID",
 949    "ADD1",
 950    "ACC",
 951    "SUB",
 952    "MUL",
 953    "DIV",
 954    "SQR",
 955    "SQRT",
 956    "LOG",
 957    "SIN",
 958    "COS",
 959    "SUM",
 960    "SUM_ROWS",
 961    "CUMSUM",
 962    "MEAN",
 963    "ARGMAX",
 964    "COUNT_EQUAL",
 965    "REPEAT",
 966    "REPEAT_BACK",
 967    "CONCAT",
 968    "SILU_BACK",
 969    "NORM",
 970    "RMS_NORM",
 971    "RMS_NORM_BACK",
 972    "GROUP_NORM",
 973    "L2_NORM",
 974
 975    "MUL_MAT",
 976    "MUL_MAT_ID",
 977    "OUT_PROD",
 978
 979    "SCALE",
 980    "SET",
 981    "CPY",
 982    "CONT",
 983    "RESHAPE",
 984    "VIEW",
 985    "PERMUTE",
 986    "TRANSPOSE",
 987    "GET_ROWS",
 988    "GET_ROWS_BACK",
 989    "SET_ROWS",
 990    "DIAG",
 991    "DIAG_MASK_INF",
 992    "DIAG_MASK_ZERO",
 993    "SOFT_MAX",
 994    "SOFT_MAX_BACK",
 995    "ROPE",
 996    "ROPE_BACK",
 997    "CLAMP",
 998    "CONV_TRANSPOSE_1D",
 999    "IM2COL",
1000    "IM2COL_BACK",
1001    "IM2COL_3D",
1002    "CONV_2D",
1003    "CONV_3D",
1004    "CONV_2D_DW",
1005    "CONV_TRANSPOSE_2D",
1006    "POOL_1D",
1007    "POOL_2D",
1008    "POOL_2D_BACK",
1009    "UPSCALE",
1010    "PAD",
1011    "PAD_REFLECT_1D",
1012    "ROLL",
1013    "ARANGE",
1014    "TIMESTEP_EMBEDDING",
1015    "ARGSORT",
1016    "TOP_K",
1017    "LEAKY_RELU",
1018    "TRI",
1019    "FILL",
1020
1021    "FLASH_ATTN_EXT",
1022    "FLASH_ATTN_BACK",
1023    "SSM_CONV",
1024    "SSM_SCAN",
1025    "WIN_PART",
1026    "WIN_UNPART",
1027    "GET_REL_POS",
1028    "ADD_REL_POS",
1029    "RWKV_WKV6",
1030    "GATED_LINEAR_ATTN",
1031    "RWKV_WKV7",
1032    "SOLVE_TRI",
1033
1034    "UNARY",
1035
1036    "MAP_CUSTOM1",
1037    "MAP_CUSTOM2",
1038    "MAP_CUSTOM3",
1039
1040    "CUSTOM",
1041
1042    "CROSS_ENTROPY_LOSS",
1043    "CROSS_ENTROPY_LOSS_BACK",
1044    "OPT_STEP_ADAMW",
1045    "OPT_STEP_SGD",
1046
1047    "GLU",
1048};
1049
1050static_assert(GGML_OP_COUNT == 95, "GGML_OP_COUNT != 95");
1051
1052static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
1053    "none",
1054
1055    "x",
1056    "x+y",
1057    "x[i]+y",
1058    "x+y",
1059    "view(x,nb,offset)+=y->x",
1060    "x-y",
1061    "x*y",
1062    "x/y",
1063    "x^2",
1064    "โˆšx",
1065    "log(x)",
1066    "sin(x)",
1067    "cos(x)",
1068    "ฮฃx",
1069    "ฮฃx_k",
1070    "cumsum(x)",
1071    "ฮฃx/n",
1072    "argmax(x)",
1073    "count_equal(x)",
1074    "repeat(x)",
1075    "repeat_back(x)",
1076    "concat(x, y)",
1077    "silu_back(x)",
1078    "norm(x)",
1079    "rms_norm(x)",
1080    "rms_norm_back(x)",
1081    "group_norm(x)",
1082    "l2_norm(x)",
1083
1084    "X*Y",
1085    "X[i]*Y",
1086    "X*Y",
1087
1088    "x*v",
1089    "y-\\>view(x)",
1090    "x-\\>y",
1091    "cont(x)",
1092    "reshape(x)",
1093    "view(x)",
1094    "permute(x)",
1095    "transpose(x)",
1096    "get_rows(x)",
1097    "get_rows_back(x)",
1098    "set_rows(x)",
1099    "diag(x)",
1100    "diag_mask_inf(x)",
1101    "diag_mask_zero(x)",
1102    "soft_max(x)",
1103    "soft_max_back(x)",
1104    "rope(x)",
1105    "rope_back(x)",
1106    "clamp(x)",
1107    "conv_transpose_1d(x)",
1108    "im2col(x)",
1109    "im2col_back(x)",
1110    "im2col_3d(x)",
1111    "conv_2d(x)",
1112    "conv_3d(x)",
1113    "conv_2d_dw(x)",
1114    "conv_transpose_2d(x)",
1115    "pool_1d(x)",
1116    "pool_2d(x)",
1117    "pool_2d_back(x)",
1118    "upscale(x)",
1119    "pad(x)",
1120    "pad_reflect_1d(x)",
1121    "roll(x)",
1122    "arange(start, stop, step)",
1123    "timestep_embedding(timesteps, dim, max_period)",
1124    "argsort(x)",
1125    "top_k(x)",
1126    "leaky_relu(x)",
1127    "tri(x)",
1128    "fill(x, c)",
1129
1130    "flash_attn_ext(x)",
1131    "flash_attn_back(x)",
1132    "ssm_conv(x)",
1133    "ssm_scan(x)",
1134    "win_part(x)",
1135    "win_unpart(x)",
1136    "get_rel_pos(x)",
1137    "add_rel_pos(x)",
1138    "rwkv_wkv6(k, v, r, tf, td, s)",
1139    "gated_linear_attn(k, v, q, gate, s)",
1140    "rwkv_wkv7(r, w, k, v, a, b, s)",
1141    "A X = B, A triangular, solve X",
1142
1143    "unary(x)",
1144
1145    "map_custom(x)",
1146    "map_custom(x,y)",
1147    "map_custom(x,y,z)",
1148
1149    "custom(x)",
1150
1151    "cross_entropy_loss(x,y)",
1152    "cross_entropy_loss_back(x,y)",
1153    "adamw(x)",
1154    "sgd(x)",
1155
1156    "glu(x)",
1157};
1158
1159static_assert(GGML_OP_COUNT == 95, "GGML_OP_COUNT != 95");
1160
1161static_assert(GGML_OP_POOL_COUNT == 2, "GGML_OP_POOL_COUNT != 2");
1162
1163static const char * GGML_UNARY_OP_NAME[GGML_UNARY_OP_COUNT] = {
1164    "ABS",
1165    "SGN",
1166    "NEG",
1167    "STEP",
1168    "TANH",
1169    "ELU",
1170    "RELU",
1171    "SIGMOID",
1172    "GELU",
1173    "GELU_QUICK",
1174    "SILU",
1175    "HARDSWISH",
1176    "HARDSIGMOID",
1177    "EXP",
1178    "EXPM1",
1179    "SOFTPLUS",
1180    "GELU_ERF",
1181    "XIELU",
1182    "FLOOR",
1183    "CEIL",
1184    "ROUND",
1185    "TRUNC",
1186};
1187
1188static_assert(GGML_UNARY_OP_COUNT == 22, "GGML_UNARY_OP_COUNT != 22");
1189
1190static const char * GGML_GLU_OP_NAME[GGML_GLU_OP_COUNT] = {
1191    "REGLU",
1192    "GEGLU",
1193    "SWIGLU",
1194    "SWIGLU_OAI",
1195    "GEGLU_ERF",
1196    "GEGLU_QUICK",
1197};
1198
1199static_assert(GGML_GLU_OP_COUNT == 6, "GGML_GLU_OP_COUNT != 6");
1200
1201
1202static_assert(sizeof(struct ggml_object)%GGML_MEM_ALIGN == 0, "ggml_object size must be a multiple of GGML_MEM_ALIGN");
1203static_assert(sizeof(struct ggml_tensor)%GGML_MEM_ALIGN == 0, "ggml_tensor size must be a multiple of GGML_MEM_ALIGN");
1204
1205
1206////////////////////////////////////////////////////////////////////////////////
1207
1208void ggml_print_object(const struct ggml_object * obj) {
1209    GGML_LOG_INFO(" - ggml_object: type = %d, offset = %zu, size = %zu, next = %p\n",
1210            obj->type, obj->offs, obj->size, (const void *) obj->next);
1211}
1212
1213void ggml_print_objects(const struct ggml_context * ctx) {
1214    struct ggml_object * obj = ctx->objects_begin;
1215
1216    GGML_LOG_INFO("%s: objects in context %p:\n", __func__, (const void *) ctx);
1217
1218    while (obj != NULL) {
1219        ggml_print_object(obj);
1220        obj = obj->next;
1221    }
1222
1223    GGML_LOG_INFO("%s: --- end ---\n", __func__);
1224}
1225
1226int64_t ggml_nelements(const struct ggml_tensor * tensor) {
1227    static_assert(GGML_MAX_DIMS == 4, "GGML_MAX_DIMS is not 4 - update this function");
1228
1229    return tensor->ne[0]*tensor->ne[1]*tensor->ne[2]*tensor->ne[3];
1230}
1231
1232int64_t ggml_nrows(const struct ggml_tensor * tensor) {
1233    static_assert(GGML_MAX_DIMS == 4, "GGML_MAX_DIMS is not 4 - update this function");
1234
1235    return tensor->ne[1]*tensor->ne[2]*tensor->ne[3];
1236}
1237
1238size_t ggml_nbytes(const struct ggml_tensor * tensor) {
1239    for (int i = 0; i < GGML_MAX_DIMS; ++i) {
1240        if (tensor->ne[i] <= 0) {
1241            return 0;
1242        }
1243    }
1244
1245    size_t nbytes;
1246    const size_t blck_size = ggml_blck_size(tensor->type);
1247    if (blck_size == 1) {
1248        nbytes = ggml_type_size(tensor->type);
1249        for (int i = 0; i < GGML_MAX_DIMS; ++i) {
1250            nbytes += (tensor->ne[i] - 1)*tensor->nb[i];
1251        }
1252    }
1253    else {
1254        nbytes = tensor->ne[0]*tensor->nb[0]/blck_size;
1255        for (int i = 1; i < GGML_MAX_DIMS; ++i) {
1256            nbytes += (tensor->ne[i] - 1)*tensor->nb[i];
1257        }
1258    }
1259
1260    return nbytes;
1261}
1262
1263size_t ggml_nbytes_pad(const struct ggml_tensor * tensor) {
1264    return GGML_PAD(ggml_nbytes(tensor), GGML_MEM_ALIGN);
1265}
1266
1267int64_t ggml_blck_size(enum ggml_type type) {
1268    return type_traits[type].blck_size;
1269}
1270
1271size_t ggml_type_size(enum ggml_type type) {
1272    return type_traits[type].type_size;
1273}
1274
1275size_t ggml_row_size(enum ggml_type type, int64_t ne) {
1276    assert(ne % ggml_blck_size(type) == 0);
1277    return ggml_type_size(type)*ne/ggml_blck_size(type);
1278}
1279
1280double ggml_type_sizef(enum ggml_type type) {
1281    return ((double)(type_traits[type].type_size))/type_traits[type].blck_size;
1282}
1283
1284const char * ggml_type_name(enum ggml_type type) {
1285    return type < GGML_TYPE_COUNT ? type_traits[type].type_name : "NONE";
1286}
1287
1288bool ggml_is_quantized(enum ggml_type type) {
1289    return type_traits[type].is_quantized;
1290}
1291
1292const char * ggml_op_name(enum ggml_op op) {
1293    return GGML_OP_NAME[op];
1294}
1295
1296const char * ggml_op_symbol(enum ggml_op op) {
1297    return GGML_OP_SYMBOL[op];
1298}
1299
1300const char * ggml_unary_op_name(enum ggml_unary_op op) {
1301    return GGML_UNARY_OP_NAME[op];
1302}
1303
1304const char * ggml_glu_op_name(enum ggml_glu_op op) {
1305    return GGML_GLU_OP_NAME[op];
1306}
1307
1308const char * ggml_op_desc(const struct ggml_tensor * t) {
1309    if (t->op == GGML_OP_UNARY) {
1310        enum ggml_unary_op uop = ggml_get_unary_op(t);
1311        return ggml_unary_op_name(uop);
1312    }
1313    if (t->op == GGML_OP_GLU) {
1314        enum ggml_glu_op gop = ggml_get_glu_op(t);
1315        return ggml_glu_op_name(gop);
1316    }
1317    return ggml_op_name(t->op);
1318}
1319
1320size_t ggml_element_size(const struct ggml_tensor * tensor) {
1321    return ggml_type_size(tensor->type);
1322}
1323
1324bool ggml_is_scalar(const struct ggml_tensor * tensor) {
1325    static_assert(GGML_MAX_DIMS == 4, "GGML_MAX_DIMS is not 4 - update this function");
1326
1327    return tensor->ne[0] == 1 && tensor->ne[1] == 1 && tensor->ne[2] == 1 && tensor->ne[3] == 1;
1328}
1329
1330bool ggml_is_vector(const struct ggml_tensor * tensor) {
1331    static_assert(GGML_MAX_DIMS == 4, "GGML_MAX_DIMS is not 4 - update this function");
1332
1333    return tensor->ne[1] == 1 && tensor->ne[2] == 1 && tensor->ne[3] == 1;
1334}
1335
1336bool ggml_is_matrix(const struct ggml_tensor * tensor) {
1337    static_assert(GGML_MAX_DIMS == 4, "GGML_MAX_DIMS is not 4 - update this function");
1338
1339    return tensor->ne[2] == 1 && tensor->ne[3] == 1;
1340}
1341
1342bool ggml_is_3d(const struct ggml_tensor * tensor) {
1343    return tensor->ne[3] == 1;
1344}
1345
1346int ggml_n_dims(const struct ggml_tensor * tensor) {
1347    for (int i = GGML_MAX_DIMS - 1; i >= 1; --i) {
1348        if (tensor->ne[i] > 1) {
1349            return i + 1;
1350        }
1351    }
1352    return 1;
1353}
1354
1355enum ggml_type ggml_ftype_to_ggml_type(enum ggml_ftype ftype) {
1356    enum ggml_type wtype = GGML_TYPE_COUNT;
1357
1358    switch (ftype) {
1359        case GGML_FTYPE_ALL_F32:              wtype = GGML_TYPE_F32;   break;
1360        case GGML_FTYPE_MOSTLY_F16:           wtype = GGML_TYPE_F16;   break;
1361        case GGML_FTYPE_MOSTLY_BF16:          wtype = GGML_TYPE_BF16;  break;
1362        case GGML_FTYPE_MOSTLY_Q4_0:          wtype = GGML_TYPE_Q4_0;  break;
1363        case GGML_FTYPE_MOSTLY_Q4_1:          wtype = GGML_TYPE_Q4_1;  break;
1364        case GGML_FTYPE_MOSTLY_Q5_0:          wtype = GGML_TYPE_Q5_0;  break;
1365        case GGML_FTYPE_MOSTLY_Q5_1:          wtype = GGML_TYPE_Q5_1;  break;
1366        case GGML_FTYPE_MOSTLY_Q8_0:          wtype = GGML_TYPE_Q8_0;  break;
1367        case GGML_FTYPE_MOSTLY_MXFP4:         wtype = GGML_TYPE_MXFP4; break;
1368        case GGML_FTYPE_MOSTLY_Q2_K:          wtype = GGML_TYPE_Q2_K;  break;
1369        case GGML_FTYPE_MOSTLY_Q3_K:          wtype = GGML_TYPE_Q3_K;  break;
1370        case GGML_FTYPE_MOSTLY_Q4_K:          wtype = GGML_TYPE_Q4_K;  break;
1371        case GGML_FTYPE_MOSTLY_Q5_K:          wtype = GGML_TYPE_Q5_K;  break;
1372        case GGML_FTYPE_MOSTLY_Q6_K:          wtype = GGML_TYPE_Q6_K;  break;
1373        case GGML_FTYPE_MOSTLY_IQ2_XXS:       wtype = GGML_TYPE_IQ2_XXS;  break;
1374        case GGML_FTYPE_MOSTLY_IQ2_XS:        wtype = GGML_TYPE_IQ2_XS;   break;
1375        case GGML_FTYPE_MOSTLY_IQ3_XXS:       wtype = GGML_TYPE_IQ3_XXS;  break;
1376        case GGML_FTYPE_MOSTLY_IQ1_S:         wtype = GGML_TYPE_IQ1_S;    break;
1377        case GGML_FTYPE_MOSTLY_IQ1_M:         wtype = GGML_TYPE_IQ1_M;    break;
1378        case GGML_FTYPE_MOSTLY_IQ4_NL:        wtype = GGML_TYPE_IQ4_NL;   break;
1379        case GGML_FTYPE_MOSTLY_IQ4_XS:        wtype = GGML_TYPE_IQ4_XS;   break;
1380        case GGML_FTYPE_MOSTLY_IQ3_S:         wtype = GGML_TYPE_IQ3_S;    break;
1381        case GGML_FTYPE_MOSTLY_IQ2_S:         wtype = GGML_TYPE_IQ2_S;    break;
1382        case GGML_FTYPE_UNKNOWN:              wtype = GGML_TYPE_COUNT; break;
1383        case GGML_FTYPE_MOSTLY_Q4_1_SOME_F16: wtype = GGML_TYPE_COUNT; break;
1384    }
1385
1386    GGML_ASSERT(wtype != GGML_TYPE_COUNT);
1387
1388    return wtype;
1389}
1390
1391size_t ggml_tensor_overhead(void) {
1392    return GGML_OBJECT_SIZE + GGML_TENSOR_SIZE;
1393}
1394
1395bool ggml_is_transposed(const struct ggml_tensor * tensor) {
1396    return tensor->nb[0] > tensor->nb[1];
1397}
1398
1399static bool ggml_is_contiguous_n(const struct ggml_tensor * tensor, int n) {
1400    size_t next_nb = ggml_type_size(tensor->type);
1401    if (tensor->ne[0] != ggml_blck_size(tensor->type) && tensor->nb[0] != next_nb) {
1402        return false;
1403    }
1404    next_nb *= tensor->ne[0]/ggml_blck_size(tensor->type);
1405    for (int i = 1; i < GGML_MAX_DIMS; i++) {
1406        if (tensor->ne[i] != 1) {
1407            if (i > n) {
1408                if (tensor->nb[i] != next_nb) {
1409                    return false;
1410                }
1411                next_nb *= tensor->ne[i];
1412            } else {
1413                // this dimension does not need to be contiguous
1414                next_nb = tensor->ne[i]*tensor->nb[i];
1415            }
1416        }
1417    }
1418    return true;
1419}
1420
1421bool ggml_is_contiguous(const struct ggml_tensor * tensor) {
1422    return ggml_is_contiguous_0(tensor);
1423}
1424
1425bool ggml_is_contiguous_0(const struct ggml_tensor * tensor) {
1426    return ggml_is_contiguous_n(tensor, 0);
1427}
1428
1429bool ggml_is_contiguous_1(const struct ggml_tensor * tensor) {
1430    return ggml_is_contiguous_n(tensor, 1);
1431}
1432
1433bool ggml_is_contiguous_2(const struct ggml_tensor * tensor) {
1434    return ggml_is_contiguous_n(tensor, 2);
1435}
1436
1437bool ggml_is_contiguously_allocated(const struct ggml_tensor * tensor) {
1438    return ggml_nbytes(tensor) == ggml_nelements(tensor) * ggml_type_size(tensor->type)/ggml_blck_size(tensor->type);
1439}
1440
1441bool ggml_is_permuted(const struct ggml_tensor * tensor) {
1442    static_assert(GGML_MAX_DIMS == 4, "GGML_MAX_DIMS is not 4 - update this function");
1443
1444    return tensor->nb[0] > tensor->nb[1] || tensor->nb[1] > tensor->nb[2] || tensor->nb[2] > tensor->nb[3];
1445}
1446
1447bool ggml_is_contiguous_channels(const struct ggml_tensor * tensor) {
1448    return
1449        tensor->nb[0] > tensor->nb[2] &&
1450        tensor->nb[1] > tensor->nb[0] &&
1451        tensor->nb[2] == ggml_type_size(tensor->type);
1452}
1453
1454bool ggml_is_contiguous_rows(const struct ggml_tensor * tensor) {
1455    return
1456        tensor->ne[0] == ggml_blck_size(tensor->type) ||
1457        tensor->nb[0] == ggml_type_size(tensor->type);
1458}
1459
1460static inline bool ggml_is_padded_1d(const struct ggml_tensor * tensor) {
1461    static_assert(GGML_MAX_DIMS == 4, "GGML_MAX_DIMS is not 4 - update this function");
1462
1463    return
1464        tensor->nb[0] == ggml_type_size(tensor->type) &&
1465        tensor->nb[2] == tensor->nb[1]*tensor->ne[1] &&
1466        tensor->nb[3] == tensor->nb[2]*tensor->ne[2];
1467}
1468
1469bool ggml_is_empty(const struct ggml_tensor * tensor) {
1470    for (int i = 0; i < GGML_MAX_DIMS; ++i) {
1471        if (tensor->ne[i] == 0) {
1472            // empty if any dimension has no elements
1473            return true;
1474        }
1475    }
1476    return false;
1477}
1478
1479bool ggml_are_same_shape(const struct ggml_tensor * t0, const struct ggml_tensor * t1) {
1480    static_assert(GGML_MAX_DIMS == 4, "GGML_MAX_DIMS is not 4 - update this function");
1481
1482    return
1483        (t0->ne[0] == t1->ne[0]) &&
1484        (t0->ne[1] == t1->ne[1]) &&
1485        (t0->ne[2] == t1->ne[2]) &&
1486        (t0->ne[3] == t1->ne[3]);
1487}
1488
1489bool ggml_are_same_stride(const struct ggml_tensor * t0, const struct ggml_tensor * t1) {
1490    static_assert(GGML_MAX_DIMS == 4, "GGML_MAX_DIMS is not 4 - update this function");
1491
1492    return
1493        (t0->nb[0] == t1->nb[0]) &&
1494        (t0->nb[1] == t1->nb[1]) &&
1495        (t0->nb[2] == t1->nb[2]) &&
1496        (t0->nb[3] == t1->nb[3]);
1497}
1498
1499// check if t1 can be represented as a repetition of t0
1500bool ggml_can_repeat(const struct ggml_tensor * t0, const struct ggml_tensor * t1) {
1501    static_assert(GGML_MAX_DIMS == 4, "GGML_MAX_DIMS is not 4 - update this function");
1502
1503    return ggml_is_empty(t0) ? ggml_is_empty(t1) :
1504        (t1->ne[0]%t0->ne[0] == 0) &&
1505        (t1->ne[1]%t0->ne[1] == 0) &&
1506        (t1->ne[2]%t0->ne[2] == 0) &&
1507        (t1->ne[3]%t0->ne[3] == 0);
1508}
1509
1510static inline bool ggml_can_repeat_rows(const struct ggml_tensor * t0, const struct ggml_tensor * t1) {
1511    static_assert(GGML_MAX_DIMS == 4, "GGML_MAX_DIMS is not 4 - update this function");
1512
1513    return (t0->ne[0] == t1->ne[0]) && ggml_can_repeat(t0, t1);
1514}
1515
1516// assert that pointer is aligned to GGML_MEM_ALIGN
1517#define GGML_ASSERT_ALIGNED(ptr) \
1518    GGML_ASSERT(((uintptr_t) (ptr))%GGML_MEM_ALIGN == 0)
1519
1520////////////////////////////////////////////////////////////////////////////////
1521
1522struct ggml_context * ggml_init(struct ggml_init_params params) {
1523    static bool is_first_call = true;
1524
1525    ggml_critical_section_start();
1526
1527    if (is_first_call) {
1528        // initialize time system (required on Windows)
1529        ggml_time_init();
1530
1531        is_first_call = false;
1532    }
1533
1534    ggml_critical_section_end();
1535
1536    struct ggml_context * ctx = GGML_MALLOC(sizeof(struct ggml_context));
1537
1538    // allow to call ggml_init with 0 size
1539    if (params.mem_size == 0) {
1540        params.mem_size = GGML_MEM_ALIGN;
1541    }
1542
1543    const size_t mem_size = params.mem_buffer ? params.mem_size : GGML_PAD(params.mem_size, GGML_MEM_ALIGN);
1544
1545    *ctx = (struct ggml_context) {
1546        /*.mem_size           =*/ mem_size,
1547        /*.mem_buffer         =*/ params.mem_buffer ? params.mem_buffer : ggml_aligned_malloc(mem_size),
1548        /*.mem_buffer_owned   =*/ params.mem_buffer ? false : true,
1549        /*.no_alloc           =*/ params.no_alloc,
1550        /*.n_objects          =*/ 0,
1551        /*.objects_begin      =*/ NULL,
1552        /*.objects_end        =*/ NULL,
1553    };
1554
1555    GGML_ASSERT(ctx->mem_buffer != NULL);
1556
1557    GGML_ASSERT_ALIGNED(ctx->mem_buffer);
1558
1559    GGML_PRINT_DEBUG("%s: context initialized\n", __func__);
1560
1561    return ctx;
1562}
1563
1564void ggml_reset(struct ggml_context * ctx) {
1565    if (ctx == NULL) {
1566        return;
1567    }
1568
1569    ctx->n_objects     = 0;
1570    ctx->objects_begin = NULL;
1571    ctx->objects_end   = NULL;
1572}
1573
1574void ggml_free(struct ggml_context * ctx) {
1575    if (ctx == NULL) {
1576        return;
1577    }
1578
1579    if (ctx->mem_buffer_owned) {
1580        ggml_aligned_free(ctx->mem_buffer, ctx->mem_size);
1581    }
1582
1583    GGML_FREE(ctx);
1584}
1585
1586size_t ggml_used_mem(const struct ggml_context * ctx) {
1587    return ctx->objects_end == NULL ? 0 : ctx->objects_end->offs + ctx->objects_end->size;
1588}
1589
1590bool ggml_get_no_alloc(struct ggml_context * ctx) {
1591    return ctx->no_alloc;
1592}
1593
1594void ggml_set_no_alloc(struct ggml_context * ctx, bool no_alloc) {
1595    ctx->no_alloc = no_alloc;
1596}
1597
1598void * ggml_get_mem_buffer(const struct ggml_context * ctx) {
1599    return ctx->mem_buffer;
1600}
1601
1602size_t ggml_get_mem_size(const struct ggml_context * ctx) {
1603    return ctx->mem_size;
1604}
1605
1606size_t ggml_get_max_tensor_size(const struct ggml_context * ctx) {
1607    size_t max_size = 0;
1608
1609    for (struct ggml_tensor * tensor = ggml_get_first_tensor(ctx); tensor != NULL; tensor = ggml_get_next_tensor(ctx, tensor)) {
1610        size_t bytes = ggml_nbytes(tensor);
1611        max_size = MAX(max_size, bytes);
1612    }
1613
1614    return max_size;
1615}
1616
1617////////////////////////////////////////////////////////////////////////////////
1618
1619static struct ggml_object * ggml_new_object(struct ggml_context * ctx, enum ggml_object_type type, size_t size) {
1620    // always insert objects at the end of the context's memory pool
1621    struct ggml_object * obj_cur = ctx->objects_end;
1622
1623    const size_t cur_offs = obj_cur == NULL ? 0 : obj_cur->offs;
1624    const size_t cur_size = obj_cur == NULL ? 0 : obj_cur->size;
1625    const size_t cur_end  = cur_offs + cur_size;
1626
1627    // align to GGML_MEM_ALIGN
1628    size_t size_needed = GGML_PAD(size, GGML_MEM_ALIGN);
1629
1630    char * const mem_buffer = ctx->mem_buffer;
1631    struct ggml_object * const obj_new = (struct ggml_object *)(mem_buffer + cur_end);
1632
1633    if (cur_end + size_needed + GGML_OBJECT_SIZE > ctx->mem_size) {
1634        GGML_LOG_WARN("%s: not enough space in the context's memory pool (needed %zu, available %zu)\n",
1635                __func__, cur_end + size_needed + GGML_OBJECT_SIZE, ctx->mem_size);
1636#ifndef NDEBUG
1637        GGML_ABORT("not enough space in the context's memory pool");
1638#endif
1639        return NULL;
1640    }
1641
1642    *obj_new = (struct ggml_object) {
1643        .offs = cur_end + GGML_OBJECT_SIZE,
1644        .size = size_needed,
1645        .next = NULL,
1646        .type = type,
1647    };
1648
1649    GGML_ASSERT_ALIGNED(mem_buffer + obj_new->offs);
1650
1651    if (obj_cur != NULL) {
1652        obj_cur->next = obj_new;
1653    } else {
1654        // this is the first object in this context
1655        ctx->objects_begin = obj_new;
1656    }
1657
1658    ctx->objects_end = obj_new;
1659
1660    //printf("%s: inserted new object at %zu, size = %zu\n", __func__, cur_end, obj_new->size);
1661
1662    return obj_new;
1663}
1664
1665static struct ggml_tensor * ggml_new_tensor_impl(
1666        struct ggml_context * ctx,
1667        enum   ggml_type      type,
1668        int                   n_dims,
1669        const int64_t       * ne,
1670        struct ggml_tensor  * view_src,
1671        size_t                view_offs) {
1672
1673    GGML_ASSERT(type >= 0 && type < GGML_TYPE_COUNT);
1674    GGML_ASSERT(n_dims >= 1 && n_dims <= GGML_MAX_DIMS);
1675
1676    // find the base tensor and absolute offset
1677    if (view_src != NULL && view_src->view_src != NULL) {
1678        view_offs += view_src->view_offs;
1679        view_src   = view_src->view_src;
1680    }
1681
1682    size_t data_size = ggml_row_size(type, ne[0]);
1683    for (int i = 1; i < n_dims; i++) {
1684        data_size *= ne[i];
1685    }
1686
1687    GGML_ASSERT(view_src == NULL || data_size == 0 || data_size + view_offs <= ggml_nbytes(view_src));
1688
1689    void * data = view_src != NULL ? view_src->data : NULL;
1690    if (data != NULL) {
1691        data = (char *) data + view_offs;
1692    }
1693
1694    size_t obj_alloc_size = 0;
1695
1696    if (view_src == NULL && !ctx->no_alloc) {
1697        // allocate tensor data in the context's memory pool
1698        obj_alloc_size = data_size;
1699    }
1700
1701    struct ggml_object * const obj_new = ggml_new_object(ctx, GGML_OBJECT_TYPE_TENSOR, GGML_TENSOR_SIZE + obj_alloc_size);
1702    GGML_ASSERT(obj_new);
1703
1704    struct ggml_tensor * const result = (struct ggml_tensor *)((char *)ctx->mem_buffer + obj_new->offs);
1705
1706    *result = (struct ggml_tensor) {
1707        /*.type         =*/ type,
1708        /*.buffer       =*/ NULL,
1709        /*.ne           =*/ { 1, 1, 1, 1 },
1710        /*.nb           =*/ { 0, 0, 0, 0 },
1711        /*.op           =*/ GGML_OP_NONE,
1712        /*.op_params    =*/ { 0 },
1713        /*.flags        =*/ 0,
1714        /*.src          =*/ { NULL },
1715        /*.view_src     =*/ view_src,
1716        /*.view_offs    =*/ view_offs,
1717        /*.data         =*/ obj_alloc_size > 0 ? (void *)(result + 1) : data,
1718        /*.name         =*/ { 0 },
1719        /*.extra        =*/ NULL,
1720        /*.padding      =*/ { 0 },
1721    };
1722
1723    // TODO: this should not be needed as long as we don't rely on aligned SIMD loads
1724    //GGML_ASSERT_ALIGNED(result->data);
1725
1726    for (int i = 0; i < n_dims; i++) {
1727        result->ne[i] = ne[i];
1728    }
1729
1730    result->nb[0] = ggml_type_size(type);
1731    result->nb[1] = result->nb[0]*(result->ne[0]/ggml_blck_size(type));
1732    for (int i = 2; i < GGML_MAX_DIMS; i++) {
1733        result->nb[i] = result->nb[i - 1]*result->ne[i - 1];
1734    }
1735
1736    ctx->n_objects++;
1737
1738    return result;
1739}
1740
1741struct ggml_tensor * ggml_new_tensor(
1742        struct ggml_context * ctx,
1743        enum   ggml_type      type,
1744        int                   n_dims,
1745        const int64_t       * ne) {
1746    return ggml_new_tensor_impl(ctx, type, n_dims, ne, NULL, 0);
1747}
1748
1749struct ggml_tensor * ggml_new_tensor_1d(
1750        struct ggml_context * ctx,
1751        enum   ggml_type      type,
1752        int64_t ne0) {
1753    return ggml_new_tensor(ctx, type, 1, &ne0);
1754}
1755
1756struct ggml_tensor * ggml_new_tensor_2d(
1757        struct ggml_context * ctx,
1758        enum   ggml_type      type,
1759        int64_t ne0,
1760        int64_t ne1) {
1761    const int64_t ne[2] = { ne0, ne1 };
1762    return ggml_new_tensor(ctx, type, 2, ne);
1763}
1764
1765struct ggml_tensor * ggml_new_tensor_3d(
1766        struct ggml_context * ctx,
1767        enum   ggml_type      type,
1768        int64_t ne0,
1769        int64_t ne1,
1770        int64_t ne2) {
1771    const int64_t ne[3] = { ne0, ne1, ne2 };
1772    return ggml_new_tensor(ctx, type, 3, ne);
1773}
1774
1775struct ggml_tensor * ggml_new_tensor_4d(
1776        struct ggml_context * ctx,
1777        enum   ggml_type type,
1778        int64_t ne0,
1779        int64_t ne1,
1780        int64_t ne2,
1781        int64_t ne3) {
1782    const int64_t ne[4] = { ne0, ne1, ne2, ne3 };
1783    return ggml_new_tensor(ctx, type, 4, ne);
1784}
1785
1786void * ggml_new_buffer(struct ggml_context * ctx, size_t nbytes) {
1787    struct ggml_object * obj = ggml_new_object(ctx, GGML_OBJECT_TYPE_WORK_BUFFER, nbytes);
1788
1789    return (uint8_t *)ctx->mem_buffer + obj->offs;
1790}
1791
1792struct ggml_tensor * ggml_dup_tensor(struct ggml_context * ctx, const struct ggml_tensor * src) {
1793    return ggml_new_tensor(ctx, src->type, GGML_MAX_DIMS, src->ne);
1794}
1795
1796void ggml_unravel_index(const struct ggml_tensor * tensor, int64_t i, int64_t * i0, int64_t * i1, int64_t * i2, int64_t * i3) {
1797    const int64_t ne2 = tensor->ne[2];
1798    const int64_t ne1 = tensor->ne[1];
1799    const int64_t ne0 = tensor->ne[0];
1800
1801    const int64_t i3_ = (i/(ne2*ne1*ne0));
1802    const int64_t i2_ = (i - i3_*ne2*ne1*ne0)/(ne1*ne0);
1803    const int64_t i1_ = (i - i3_*ne2*ne1*ne0 - i2_*ne1*ne0)/ne0;
1804    const int64_t i0_ = (i - i3_*ne2*ne1*ne0 - i2_*ne1*ne0 - i1_*ne0);
1805
1806    if (i0) {
1807        * i0 = i0_;
1808    }
1809    if (i1) {
1810        * i1 = i1_;
1811    }
1812    if (i2) {
1813        * i2 = i2_;
1814    }
1815    if (i3) {
1816        * i3 = i3_;
1817    }
1818}
1819
1820void * ggml_get_data(const struct ggml_tensor * tensor) {
1821    return tensor->data;
1822}
1823
1824float * ggml_get_data_f32(const struct ggml_tensor * tensor) {
1825    assert(tensor->type == GGML_TYPE_F32);
1826    return (float *)(tensor->data);
1827}
1828
1829enum ggml_unary_op ggml_get_unary_op(const struct ggml_tensor * tensor) {
1830    GGML_ASSERT(tensor->op == GGML_OP_UNARY);
1831    return (enum ggml_unary_op) ggml_get_op_params_i32(tensor, 0);
1832}
1833
1834enum ggml_glu_op ggml_get_glu_op(const struct ggml_tensor * tensor) {
1835    GGML_ASSERT(tensor->op == GGML_OP_GLU);
1836    return (enum ggml_glu_op) ggml_get_op_params_i32(tensor, 0);
1837}
1838
1839const char * ggml_get_name(const struct ggml_tensor * tensor) {
1840    return tensor->name;
1841}
1842
1843struct ggml_tensor * ggml_set_name(struct ggml_tensor * tensor, const char * name) {
1844    size_t i;
1845    for (i = 0; i < sizeof(tensor->name) - 1 && name[i] != '\0'; i++) {
1846        tensor->name[i] = name[i];
1847    }
1848    tensor->name[i] = '\0';
1849    return tensor;
1850}
1851
1852struct ggml_tensor * ggml_format_name(struct ggml_tensor * tensor, const char * fmt, ...) {
1853    va_list args;
1854    va_start(args, fmt);
1855    vsnprintf(tensor->name, sizeof(tensor->name), fmt, args);
1856    va_end(args);
1857    return tensor;
1858}
1859
1860struct ggml_tensor * ggml_view_tensor(
1861        struct ggml_context * ctx,
1862        struct ggml_tensor  * src) {
1863    struct ggml_tensor * result = ggml_new_tensor_impl(ctx, src->type, GGML_MAX_DIMS, src->ne, src, 0);
1864    ggml_format_name(result, "%s (view)", src->name);
1865
1866    for (int i = 0; i < GGML_MAX_DIMS; i++) {
1867        result->nb[i] = src->nb[i];
1868    }
1869
1870    return result;
1871}
1872
1873struct ggml_tensor * ggml_get_first_tensor(const struct ggml_context * ctx) {
1874    struct ggml_object * obj = ctx->objects_begin;
1875
1876    char * const mem_buffer = ctx->mem_buffer;
1877
1878    while (obj != NULL) {
1879        if (obj->type == GGML_OBJECT_TYPE_TENSOR) {
1880            return (struct ggml_tensor *)(mem_buffer + obj->offs);
1881        }
1882
1883        obj = obj->next;
1884    }
1885
1886    return NULL;
1887}
1888
1889struct ggml_tensor * ggml_get_next_tensor(const struct ggml_context * ctx, struct ggml_tensor * tensor) {
1890    struct ggml_object * obj = (struct ggml_object *) ((char *)tensor - GGML_OBJECT_SIZE);
1891    obj = obj->next;
1892
1893    char * const mem_buffer = ctx->mem_buffer;
1894
1895    while (obj != NULL) {
1896        if (obj->type == GGML_OBJECT_TYPE_TENSOR) {
1897            return (struct ggml_tensor *)(mem_buffer + obj->offs);
1898        }
1899
1900        obj = obj->next;
1901    }
1902
1903    return NULL;
1904}
1905
1906struct ggml_tensor * ggml_get_tensor(struct ggml_context * ctx, const char * name) {
1907    struct ggml_object * obj = ctx->objects_begin;
1908
1909    char * const mem_buffer = ctx->mem_buffer;
1910
1911    while (obj != NULL) {
1912        if (obj->type == GGML_OBJECT_TYPE_TENSOR) {
1913            struct ggml_tensor * cur = (struct ggml_tensor *)(mem_buffer + obj->offs);
1914            if (strcmp(cur->name, name) == 0) {
1915                return cur;
1916            }
1917        }
1918
1919        obj = obj->next;
1920    }
1921
1922    return NULL;
1923}
1924
1925////////////////////////////////////////////////////////////////////////////////
1926
1927// ggml_dup
1928
1929static struct ggml_tensor * ggml_dup_impl(
1930        struct ggml_context * ctx,
1931        struct ggml_tensor  * a,
1932        bool                  inplace) {
1933    struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
1934
1935    result->op     = GGML_OP_DUP;
1936    result->src[0] = a;
1937
1938    return result;
1939}
1940
1941struct ggml_tensor * ggml_dup(
1942        struct ggml_context * ctx,
1943        struct ggml_tensor  * a) {
1944    return ggml_dup_impl(ctx, a, false);
1945}
1946
1947struct ggml_tensor * ggml_dup_inplace(
1948        struct ggml_context * ctx,
1949        struct ggml_tensor  * a) {
1950    return ggml_dup_impl(ctx, a, true);
1951}
1952
1953// ggml_add
1954
1955static struct ggml_tensor * ggml_add_impl(
1956        struct ggml_context * ctx,
1957        struct ggml_tensor  * a,
1958        struct ggml_tensor  * b,
1959        bool                  inplace) {
1960    GGML_ASSERT(ggml_can_repeat(b, a));
1961
1962    struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
1963
1964    result->op     = GGML_OP_ADD;
1965    result->src[0] = a;
1966    result->src[1] = b;
1967
1968    return result;
1969}
1970
1971struct ggml_tensor * ggml_add(
1972        struct ggml_context * ctx,
1973        struct ggml_tensor  * a,
1974        struct ggml_tensor  * b) {
1975    return ggml_add_impl(ctx, a, b, false);
1976}
1977
1978struct ggml_tensor * ggml_add_inplace(
1979        struct ggml_context * ctx,
1980        struct ggml_tensor  * a,
1981        struct ggml_tensor  * b) {
1982    return ggml_add_impl(ctx, a, b, true);
1983}
1984
1985// ggml_add_cast
1986
1987static struct ggml_tensor * ggml_add_cast_impl(
1988        struct ggml_context * ctx,
1989        struct ggml_tensor  * a,
1990        struct ggml_tensor  * b,
1991        enum   ggml_type      type) {
1992    // TODO: support less-strict constraint
1993    //       GGML_ASSERT(ggml_can_repeat(b, a));
1994    GGML_ASSERT(ggml_can_repeat_rows(b, a));
1995
1996    // currently only supported for quantized input and f16
1997    GGML_ASSERT(ggml_is_quantized(a->type) ||
1998                a->type == GGML_TYPE_F16 ||
1999                a->type == GGML_TYPE_BF16);
2000
2001    struct ggml_tensor * result = ggml_new_tensor(ctx, type, GGML_MAX_DIMS, a->ne);
2002
2003    result->op     = GGML_OP_ADD;
2004    result->src[0] = a;
2005    result->src[1] = b;
2006
2007    return result;
2008}
2009
2010struct ggml_tensor * ggml_add_cast(
2011        struct ggml_context * ctx,
2012        struct ggml_tensor  * a,
2013        struct ggml_tensor  * b,
2014        enum   ggml_type      type) {
2015    return ggml_add_cast_impl(ctx, a, b, type);
2016}
2017
2018struct ggml_tensor * ggml_add_id(
2019            struct ggml_context * ctx,
2020            struct ggml_tensor  * a,
2021            struct ggml_tensor  * b,
2022            struct ggml_tensor  * ids) {
2023
2024    GGML_ASSERT(a->ne[0] == b->ne[0]);
2025    GGML_ASSERT(a->ne[1] == ids->ne[0]);
2026    GGML_ASSERT(a->ne[2] == ids->ne[1]);
2027    GGML_ASSERT(ids->type == GGML_TYPE_I32);
2028
2029    struct ggml_tensor * result = ggml_dup_tensor(ctx, a);
2030
2031    result->op     = GGML_OP_ADD_ID;
2032    result->src[0] = a;
2033    result->src[1] = b;
2034    result->src[2] = ids;
2035
2036    return result;
2037}
2038
2039// ggml_add1
2040
2041static struct ggml_tensor * ggml_add1_impl(
2042        struct ggml_context * ctx,
2043        struct ggml_tensor  * a,
2044        struct ggml_tensor  * b,
2045        bool                  inplace) {
2046    GGML_ASSERT(ggml_is_scalar(b));
2047    GGML_ASSERT(ggml_is_padded_1d(a));
2048
2049    struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
2050
2051    result->op     = GGML_OP_ADD1;
2052    result->src[0] = a;
2053    result->src[1] = b;
2054
2055    return result;
2056}
2057
2058struct ggml_tensor * ggml_add1(
2059        struct ggml_context * ctx,
2060        struct ggml_tensor  * a,
2061        struct ggml_tensor  * b) {
2062    return ggml_add1_impl(ctx, a, b, false);
2063}
2064
2065struct ggml_tensor * ggml_add1_inplace(
2066        struct ggml_context * ctx,
2067        struct ggml_tensor  * a,
2068        struct ggml_tensor  * b) {
2069    return ggml_add1_impl(ctx, a, b, true);
2070}
2071
2072// ggml_acc
2073
2074static struct ggml_tensor * ggml_acc_impl(
2075        struct ggml_context * ctx,
2076        struct ggml_tensor  * a,
2077        struct ggml_tensor  * b,
2078        size_t                nb1,
2079        size_t                nb2,
2080        size_t                nb3,
2081        size_t                offset,
2082        bool                  inplace) {
2083    GGML_ASSERT(ggml_nelements(b) <= ggml_nelements(a));
2084    GGML_ASSERT(ggml_is_contiguous(a));
2085    GGML_ASSERT(a->type == GGML_TYPE_F32);
2086    GGML_ASSERT(b->type == GGML_TYPE_F32);
2087
2088    struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
2089
2090    int32_t params[] = { nb1, nb2, nb3, offset, inplace ? 1 : 0 };
2091    ggml_set_op_params(result, params, sizeof(params));
2092
2093    result->op     = GGML_OP_ACC;
2094    result->src[0] = a;
2095    result->src[1] = b;
2096
2097    return result;
2098}
2099
2100struct ggml_tensor * ggml_acc(
2101        struct ggml_context * ctx,
2102        struct ggml_tensor  * a,
2103        struct ggml_tensor  * b,
2104        size_t                nb1,
2105        size_t                nb2,
2106        size_t                nb3,
2107        size_t                offset) {
2108    return ggml_acc_impl(ctx, a, b, nb1, nb2, nb3, offset, false);
2109}
2110
2111struct ggml_tensor * ggml_acc_inplace(
2112        struct ggml_context * ctx,
2113        struct ggml_tensor  * a,
2114        struct ggml_tensor  * b,
2115        size_t                nb1,
2116        size_t                nb2,
2117        size_t                nb3,
2118        size_t                offset) {
2119    return ggml_acc_impl(ctx, a, b, nb1, nb2, nb3, offset, true);
2120}
2121
2122// ggml_sub
2123
2124static struct ggml_tensor * ggml_sub_impl(
2125        struct ggml_context * ctx,
2126        struct ggml_tensor  * a,
2127        struct ggml_tensor  * b,
2128        bool                  inplace) {
2129    GGML_ASSERT(ggml_can_repeat(b, a));
2130
2131    struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
2132
2133    result->op     = GGML_OP_SUB;
2134    result->src[0] = a;
2135    result->src[1] = b;
2136
2137    return result;
2138}
2139
2140struct ggml_tensor * ggml_sub(
2141        struct ggml_context * ctx,
2142        struct ggml_tensor  * a,
2143        struct ggml_tensor  * b) {
2144    return ggml_sub_impl(ctx, a, b, false);
2145}
2146
2147struct ggml_tensor * ggml_sub_inplace(
2148        struct ggml_context * ctx,
2149        struct ggml_tensor  * a,
2150        struct ggml_tensor  * b) {
2151    return ggml_sub_impl(ctx, a, b, true);
2152}
2153
2154// ggml_mul
2155
2156static struct ggml_tensor * ggml_mul_impl(
2157        struct ggml_context * ctx,
2158        struct ggml_tensor  * a,
2159        struct ggml_tensor  * b,
2160        bool                  inplace) {
2161    GGML_ASSERT(ggml_can_repeat(b, a));
2162
2163    struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
2164
2165    result->op     = GGML_OP_MUL;
2166    result->src[0] = a;
2167    result->src[1] = b;
2168
2169    return result;
2170}
2171
2172struct ggml_tensor * ggml_mul(
2173        struct ggml_context * ctx,
2174        struct ggml_tensor  * a,
2175        struct ggml_tensor  * b) {
2176    return ggml_mul_impl(ctx, a, b, false);
2177}
2178
2179struct ggml_tensor * ggml_mul_inplace(
2180        struct ggml_context * ctx,
2181        struct ggml_tensor  * a,
2182        struct ggml_tensor  * b) {
2183    return ggml_mul_impl(ctx, a, b, true);
2184}
2185
2186// ggml_div
2187
2188static struct ggml_tensor * ggml_div_impl(
2189        struct ggml_context * ctx,
2190        struct ggml_tensor  * a,
2191        struct ggml_tensor  * b,
2192        bool                  inplace) {
2193    GGML_ASSERT(ggml_can_repeat(b, a));
2194
2195    struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
2196
2197    result->op     = GGML_OP_DIV;
2198    result->src[0] = a;
2199    result->src[1] = b;
2200
2201    return result;
2202}
2203
2204struct ggml_tensor * ggml_div(
2205        struct ggml_context * ctx,
2206        struct ggml_tensor  * a,
2207        struct ggml_tensor  * b) {
2208    return ggml_div_impl(ctx, a, b, false);
2209}
2210
2211struct ggml_tensor * ggml_div_inplace(
2212        struct ggml_context * ctx,
2213        struct ggml_tensor  * a,
2214        struct ggml_tensor  * b) {
2215    return ggml_div_impl(ctx, a, b, true);
2216}
2217
2218// ggml_sqr
2219
2220static struct ggml_tensor * ggml_sqr_impl(
2221        struct ggml_context * ctx,
2222        struct ggml_tensor  * a,
2223        bool                  inplace) {
2224    struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
2225
2226    result->op     = GGML_OP_SQR;
2227    result->src[0] = a;
2228
2229    return result;
2230}
2231
2232struct ggml_tensor * ggml_sqr(
2233        struct ggml_context * ctx,
2234        struct ggml_tensor  * a) {
2235    return ggml_sqr_impl(ctx, a, false);
2236}
2237
2238struct ggml_tensor * ggml_sqr_inplace(
2239        struct ggml_context * ctx,
2240        struct ggml_tensor  * a) {
2241    return ggml_sqr_impl(ctx, a, true);
2242}
2243
2244// ggml_sqrt
2245
2246static struct ggml_tensor * ggml_sqrt_impl(
2247        struct ggml_context * ctx,
2248        struct ggml_tensor  * a,
2249        bool                  inplace) {
2250    struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
2251
2252    result->op     = GGML_OP_SQRT;
2253    result->src[0] = a;
2254
2255    return result;
2256}
2257
2258struct ggml_tensor * ggml_sqrt(
2259        struct ggml_context * ctx,
2260        struct ggml_tensor  * a) {
2261    return ggml_sqrt_impl(ctx, a, false);
2262}
2263
2264struct ggml_tensor * ggml_sqrt_inplace(
2265        struct ggml_context * ctx,
2266        struct ggml_tensor  * a) {
2267    return ggml_sqrt_impl(ctx, a, true);
2268}
2269
2270// ggml_log
2271
2272static struct ggml_tensor * ggml_log_impl(
2273        struct ggml_context * ctx,
2274        struct ggml_tensor  * a,
2275        bool                  inplace) {
2276    struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
2277
2278    result->op     = GGML_OP_LOG;
2279    result->src[0] = a;
2280
2281    return result;
2282}
2283
2284struct ggml_tensor * ggml_log(
2285        struct ggml_context * ctx,
2286        struct ggml_tensor  * a) {
2287    return ggml_log_impl(ctx, a, false);
2288}
2289
2290struct ggml_tensor * ggml_log_inplace(
2291        struct ggml_context * ctx,
2292        struct ggml_tensor  * a) {
2293    return ggml_log_impl(ctx, a, true);
2294}
2295
2296struct ggml_tensor * ggml_expm1(
2297        struct ggml_context * ctx,
2298        struct ggml_tensor  * a) {
2299    return ggml_unary(ctx, a, GGML_UNARY_OP_EXPM1);
2300}
2301
2302struct ggml_tensor * ggml_expm1_inplace(
2303        struct ggml_context * ctx,
2304        struct ggml_tensor  * a) {
2305    return ggml_unary_inplace(ctx, a, GGML_UNARY_OP_EXPM1);
2306}
2307
2308struct ggml_tensor * ggml_softplus(
2309        struct ggml_context * ctx,
2310        struct ggml_tensor  * a) {
2311    return ggml_unary(ctx, a, GGML_UNARY_OP_SOFTPLUS);
2312}
2313
2314struct ggml_tensor * ggml_softplus_inplace(
2315        struct ggml_context * ctx,
2316        struct ggml_tensor  * a) {
2317    return ggml_unary_inplace(ctx, a, GGML_UNARY_OP_SOFTPLUS);
2318}
2319
2320// ggml_sin
2321
2322static struct ggml_tensor * ggml_sin_impl(
2323        struct ggml_context * ctx,
2324        struct ggml_tensor  * a,
2325        bool                  inplace) {
2326    struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
2327
2328    result->op     = GGML_OP_SIN;
2329    result->src[0] = a;
2330
2331    return result;
2332}
2333
2334struct ggml_tensor * ggml_sin(
2335        struct ggml_context * ctx,
2336        struct ggml_tensor  * a) {
2337    return ggml_sin_impl(ctx, a, false);
2338}
2339
2340struct ggml_tensor * ggml_sin_inplace(
2341        struct ggml_context * ctx,
2342        struct ggml_tensor  * a) {
2343    return ggml_sin_impl(ctx, a, true);
2344}
2345
2346// ggml_cos
2347
2348static struct ggml_tensor * ggml_cos_impl(
2349        struct ggml_context * ctx,
2350        struct ggml_tensor  * a,
2351        bool                  inplace) {
2352    struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
2353
2354    result->op     = GGML_OP_COS;
2355    result->src[0] = a;
2356
2357    return result;
2358}
2359
2360struct ggml_tensor * ggml_cos(
2361        struct ggml_context * ctx,
2362        struct ggml_tensor  * a) {
2363    return ggml_cos_impl(ctx, a, false);
2364}
2365
2366struct ggml_tensor * ggml_cos_inplace(
2367        struct ggml_context * ctx,
2368        struct ggml_tensor  * a) {
2369    return ggml_cos_impl(ctx, a, true);
2370}
2371
2372// ggml_sum
2373
2374struct ggml_tensor * ggml_sum(
2375        struct ggml_context * ctx,
2376        struct ggml_tensor  * a) {
2377    struct ggml_tensor * result = ggml_new_tensor_1d(ctx, a->type, 1);
2378
2379    result->op     = GGML_OP_SUM;
2380    result->src[0] = a;
2381
2382    return result;
2383}
2384
2385// ggml_sum_rows
2386
2387struct ggml_tensor * ggml_sum_rows(
2388        struct ggml_context * ctx,
2389        struct ggml_tensor  * a) {
2390    int64_t ne[GGML_MAX_DIMS] = { 1 };
2391    for (int i = 1; i < GGML_MAX_DIMS; ++i) {
2392        ne[i] = a->ne[i];
2393    }
2394
2395    struct ggml_tensor * result = ggml_new_tensor(ctx, a->type, GGML_MAX_DIMS, ne);
2396
2397    result->op     = GGML_OP_SUM_ROWS;
2398    result->src[0] = a;
2399
2400    return result;
2401}
2402
2403// ggml_cumsum
2404
2405struct ggml_tensor * ggml_cumsum(
2406        struct ggml_context * ctx,
2407        struct ggml_tensor  * a) {
2408    GGML_ASSERT(a->type == GGML_TYPE_F32);
2409
2410    struct ggml_tensor * result = ggml_dup_tensor(ctx, a);
2411
2412    result->op     = GGML_OP_CUMSUM;
2413    result->src[0] = a;
2414
2415    return result;
2416}
2417
2418// ggml_mean
2419
2420struct ggml_tensor * ggml_mean(
2421        struct ggml_context * ctx,
2422        struct ggml_tensor  * a) {
2423    int64_t ne[4] = { 1, a->ne[1], a->ne[2], a->ne[3] };
2424    struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne);
2425
2426    result->op     = GGML_OP_MEAN;
2427    result->src[0] = a;
2428
2429    return result;
2430}
2431
2432// ggml_argmax
2433
2434struct ggml_tensor * ggml_argmax(
2435        struct ggml_context * ctx,
2436        struct ggml_tensor  * a) {
2437    GGML_ASSERT(ggml_is_matrix(a));
2438    GGML_ASSERT(a->ne[0] <= INT32_MAX);
2439
2440    struct ggml_tensor * result = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, a->ne[1]);
2441
2442    result->op     = GGML_OP_ARGMAX;
2443    result->src[0] = a;
2444
2445    return result;
2446}
2447
2448// ggml_count_equal
2449
2450struct ggml_tensor * ggml_count_equal(
2451        struct ggml_context * ctx,
2452        struct ggml_tensor  * a,
2453        struct ggml_tensor  * b) {
2454    GGML_ASSERT(ggml_are_same_shape(a, b));
2455
2456    struct ggml_tensor * result = ggml_new_tensor_1d(ctx, GGML_TYPE_I64, 1);
2457
2458    result->op     = GGML_OP_COUNT_EQUAL;
2459    result->src[0] = a;
2460    result->src[1] = b;
2461
2462    return result;
2463}
2464
2465// ggml_repeat
2466
2467struct ggml_tensor * ggml_repeat(
2468        struct ggml_context * ctx,
2469        struct ggml_tensor  * a,
2470        struct ggml_tensor  * b) {
2471    GGML_ASSERT(ggml_can_repeat(a, b));
2472
2473    struct ggml_tensor * result = ggml_new_tensor(ctx, a->type, GGML_MAX_DIMS, b->ne);
2474
2475    result->op     = GGML_OP_REPEAT;
2476    result->src[0] = a;
2477
2478    return result;
2479}
2480
2481struct ggml_tensor * ggml_repeat_4d(
2482        struct ggml_context * ctx,
2483        struct ggml_tensor * a,
2484        int64_t ne0, int64_t ne1, int64_t ne2, int64_t ne3) {
2485    const bool can_repeat = ggml_is_empty(a) || (
2486        (ne0 % a->ne[0] == 0) &&
2487        (ne1 % a->ne[1] == 0) &&
2488        (ne2 % a->ne[2] == 0) &&
2489        (ne3 % a->ne[3] == 0)
2490    );
2491    GGML_ASSERT(can_repeat);
2492
2493    struct ggml_tensor * result = ggml_new_tensor_4d(ctx, a->type, ne0, ne1, ne2, ne3);
2494
2495    result->op     = GGML_OP_REPEAT;
2496    result->src[0] = a;
2497
2498    return result;
2499}
2500
2501// ggml_repeat_back
2502
2503struct ggml_tensor * ggml_repeat_back(
2504        struct ggml_context * ctx,
2505        struct ggml_tensor  * a,
2506        struct ggml_tensor  * b) {
2507    GGML_ASSERT(ggml_can_repeat(b, a));
2508
2509    struct ggml_tensor * result = ggml_new_tensor(ctx, a->type, GGML_MAX_DIMS, b->ne);
2510
2511    result->op     = GGML_OP_REPEAT_BACK;
2512    result->src[0] = a;
2513
2514    return result;
2515}
2516
2517// ggml_concat
2518
2519struct ggml_tensor * ggml_concat(
2520    struct ggml_context * ctx,
2521    struct ggml_tensor  * a,
2522    struct ggml_tensor  * b,
2523    int                   dim) {
2524    GGML_ASSERT(dim >= 0 && dim < GGML_MAX_DIMS);
2525    GGML_ASSERT(a->type == b->type);
2526
2527    int64_t ne[GGML_MAX_DIMS];
2528    for (int d = 0; d < GGML_MAX_DIMS; ++d) {
2529        if (d == dim) {
2530            ne[d] = a->ne[d] + b->ne[d];
2531            continue;
2532        }
2533        GGML_ASSERT(a->ne[d] == b->ne[d]);
2534        ne[d] = a->ne[d];
2535    }
2536
2537    struct ggml_tensor * result = ggml_new_tensor(ctx, a->type, GGML_MAX_DIMS, ne);
2538
2539    ggml_set_op_params_i32(result, 0, dim);
2540
2541    result->op     = GGML_OP_CONCAT;
2542    result->src[0] = a;
2543    result->src[1] = b;
2544
2545    return result;
2546}
2547
2548// ggml_abs
2549
2550struct ggml_tensor * ggml_abs(
2551        struct ggml_context * ctx,
2552        struct ggml_tensor  * a) {
2553    return ggml_unary(ctx, a, GGML_UNARY_OP_ABS);
2554}
2555
2556struct ggml_tensor * ggml_abs_inplace(
2557        struct ggml_context * ctx,
2558        struct ggml_tensor  * a) {
2559    return ggml_unary_inplace(ctx, a, GGML_UNARY_OP_ABS);
2560}
2561
2562// ggml_sgn
2563
2564struct ggml_tensor * ggml_sgn(
2565        struct ggml_context * ctx,
2566        struct ggml_tensor  * a) {
2567    return ggml_unary(ctx, a, GGML_UNARY_OP_SGN);
2568}
2569
2570struct ggml_tensor * ggml_sgn_inplace(
2571        struct ggml_context * ctx,
2572        struct ggml_tensor  * a) {
2573    return ggml_unary_inplace(ctx, a, GGML_UNARY_OP_SGN);
2574}
2575
2576// ggml_neg
2577
2578struct ggml_tensor * ggml_neg(
2579        struct ggml_context * ctx,
2580        struct ggml_tensor  * a) {
2581    return ggml_unary(ctx, a, GGML_UNARY_OP_NEG);
2582}
2583
2584struct ggml_tensor * ggml_neg_inplace(
2585        struct ggml_context * ctx,
2586        struct ggml_tensor  * a) {
2587    return ggml_unary_inplace(ctx, a, GGML_UNARY_OP_NEG);
2588}
2589
2590// ggml_step
2591
2592struct ggml_tensor * ggml_step(
2593        struct ggml_context * ctx,
2594        struct ggml_tensor  * a) {
2595    return ggml_unary(ctx, a, GGML_UNARY_OP_STEP);
2596}
2597
2598struct ggml_tensor * ggml_step_inplace(
2599        struct ggml_context * ctx,
2600        struct ggml_tensor  * a) {
2601    return ggml_unary_inplace(ctx, a, GGML_UNARY_OP_STEP);
2602}
2603
2604// ggml_tanh
2605
2606struct ggml_tensor * ggml_tanh(
2607        struct ggml_context * ctx,
2608        struct ggml_tensor  * a) {
2609    return ggml_unary(ctx, a, GGML_UNARY_OP_TANH);
2610}
2611
2612struct ggml_tensor * ggml_tanh_inplace(
2613        struct ggml_context * ctx,
2614        struct ggml_tensor  * a) {
2615    return ggml_unary_inplace(ctx, a, GGML_UNARY_OP_TANH);
2616}
2617
2618// ggml_elu
2619
2620struct ggml_tensor * ggml_elu(
2621    struct ggml_context * ctx,
2622    struct ggml_tensor  * a) {
2623    return ggml_unary(ctx, a, GGML_UNARY_OP_ELU);
2624}
2625
2626struct ggml_tensor * ggml_elu_inplace(
2627    struct ggml_context * ctx,
2628    struct ggml_tensor  * a) {
2629    return ggml_unary_inplace(ctx, a, GGML_UNARY_OP_ELU);
2630}
2631
2632// ggml_relu
2633
2634struct ggml_tensor * ggml_relu(
2635        struct ggml_context * ctx,
2636        struct ggml_tensor  * a) {
2637    return ggml_unary(ctx, a, GGML_UNARY_OP_RELU);
2638}
2639
2640struct ggml_tensor * ggml_relu_inplace(
2641        struct ggml_context * ctx,
2642        struct ggml_tensor  * a) {
2643    return ggml_unary_inplace(ctx, a, GGML_UNARY_OP_RELU);
2644}
2645
2646// ggml_leaky_relu
2647
2648struct ggml_tensor * ggml_leaky_relu(
2649        struct ggml_context * ctx,
2650        struct ggml_tensor  * a,
2651        float                 negative_slope,
2652        bool                  inplace) {
2653    struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
2654
2655    ggml_set_op_params(result, &negative_slope, sizeof(negative_slope));
2656
2657    result->op     = GGML_OP_LEAKY_RELU;
2658    result->src[0] = a;
2659
2660    return result;
2661}
2662
2663// ggml_sigmoid
2664
2665struct ggml_tensor * ggml_sigmoid(
2666        struct ggml_context * ctx,
2667        struct ggml_tensor  * a) {
2668    return ggml_unary(ctx, a, GGML_UNARY_OP_SIGMOID);
2669}
2670
2671struct ggml_tensor * ggml_sigmoid_inplace(
2672        struct ggml_context * ctx,
2673        struct ggml_tensor  * a) {
2674    return ggml_unary_inplace(ctx, a, GGML_UNARY_OP_SIGMOID);
2675}
2676
2677// ggml_gelu
2678
2679struct ggml_tensor * ggml_gelu(
2680        struct ggml_context * ctx,
2681        struct ggml_tensor  * a) {
2682    return ggml_unary(ctx, a, GGML_UNARY_OP_GELU);
2683}
2684
2685struct ggml_tensor * ggml_gelu_inplace(
2686        struct ggml_context * ctx,
2687        struct ggml_tensor  * a) {
2688    return ggml_unary_inplace(ctx, a, GGML_UNARY_OP_GELU);
2689}
2690
2691// ggml_gelu_erf
2692
2693struct ggml_tensor * ggml_gelu_erf(
2694        struct ggml_context * ctx,
2695        struct ggml_tensor  * a) {
2696    return ggml_unary(ctx, a, GGML_UNARY_OP_GELU_ERF);
2697}
2698
2699struct ggml_tensor * ggml_gelu_erf_inplace(
2700        struct ggml_context * ctx,
2701        struct ggml_tensor  * a) {
2702    return ggml_unary_inplace(ctx, a, GGML_UNARY_OP_GELU_ERF);
2703}
2704
2705// ggml_gelu_quick
2706
2707struct ggml_tensor * ggml_gelu_quick(
2708        struct ggml_context * ctx,
2709        struct ggml_tensor  * a) {
2710    return ggml_unary(ctx, a, GGML_UNARY_OP_GELU_QUICK);
2711}
2712
2713struct ggml_tensor * ggml_gelu_quick_inplace(
2714        struct ggml_context * ctx,
2715        struct ggml_tensor  * a) {
2716    return ggml_unary_inplace(ctx, a, GGML_UNARY_OP_GELU_QUICK);
2717}
2718
2719// ggml_silu
2720
2721struct ggml_tensor * ggml_silu(
2722        struct ggml_context * ctx,
2723        struct ggml_tensor  * a) {
2724    return ggml_unary(ctx, a, GGML_UNARY_OP_SILU);
2725}
2726
2727struct ggml_tensor * ggml_silu_inplace(
2728        struct ggml_context * ctx,
2729        struct ggml_tensor  * a) {
2730    return ggml_unary_inplace(ctx, a, GGML_UNARY_OP_SILU);
2731}
2732
2733// ggml_xielu
2734
2735struct ggml_tensor * ggml_xielu(
2736        struct ggml_context * ctx,
2737        struct ggml_tensor  * a,
2738        float alpha_n,
2739        float alpha_p,
2740        float beta,
2741        float eps) {
2742    struct ggml_tensor * result = ggml_dup_tensor(ctx, a);
2743
2744    ggml_set_op_params_i32(result, 0, (int32_t) GGML_UNARY_OP_XIELU);
2745    ggml_set_op_params_f32(result, 1, beta + ggml_compute_softplus_f32(alpha_n));
2746    ggml_set_op_params_f32(result, 2, ggml_compute_softplus_f32(alpha_p));
2747    ggml_set_op_params_f32(result, 3, beta);
2748    ggml_set_op_params_f32(result, 4, eps);
2749
2750    result->op     = GGML_OP_UNARY;
2751    result->src[0] = a;
2752
2753    return result;
2754}
2755
2756// ggml_silu_back
2757
2758struct ggml_tensor * ggml_silu_back(
2759        struct ggml_context * ctx,
2760        struct ggml_tensor  * a,
2761        struct ggml_tensor  * b) {
2762    struct ggml_tensor * result = ggml_dup_tensor(ctx, a);
2763
2764    result->op     = GGML_OP_SILU_BACK;
2765    result->src[0] = a;
2766    result->src[1] = b;
2767
2768    return result;
2769}
2770
2771// ggml hardswish
2772
2773struct ggml_tensor * ggml_hardswish(
2774        struct ggml_context * ctx,
2775        struct ggml_tensor  * a) {
2776    return ggml_unary(ctx, a, GGML_UNARY_OP_HARDSWISH);
2777}
2778
2779// ggml hardsigmoid
2780
2781struct ggml_tensor * ggml_hardsigmoid(
2782        struct ggml_context * ctx,
2783        struct ggml_tensor  * a) {
2784    return ggml_unary(ctx, a, GGML_UNARY_OP_HARDSIGMOID);
2785}
2786
2787// ggml exp
2788
2789struct ggml_tensor * ggml_exp(
2790        struct ggml_context * ctx,
2791        struct ggml_tensor  * a) {
2792    return ggml_unary(ctx, a, GGML_UNARY_OP_EXP);
2793}
2794
2795struct ggml_tensor * ggml_exp_inplace(
2796        struct ggml_context * ctx,
2797        struct ggml_tensor  * a) {
2798    return ggml_unary_inplace(ctx, a, GGML_UNARY_OP_EXP);
2799}
2800
2801// ggml_glu
2802
2803static struct ggml_tensor * ggml_glu_impl(
2804        struct ggml_context * ctx,
2805        struct ggml_tensor  * a,
2806        struct ggml_tensor  * b,
2807        enum ggml_glu_op      op,
2808        bool                  swapped) {
2809    GGML_ASSERT(ggml_is_contiguous_1(a));
2810
2811    if (b) {
2812        GGML_ASSERT(ggml_is_contiguous_1(b));
2813        GGML_ASSERT(ggml_are_same_shape(a, b));
2814        GGML_ASSERT(a->type == b->type);
2815    }
2816
2817    int64_t ne[GGML_MAX_DIMS] = { a->ne[0] / 2 }; for (int i = 1; i < GGML_MAX_DIMS; i++) ne[i] = a->ne[i];
2818    struct ggml_tensor * result = ggml_new_tensor_impl(ctx, a->type, GGML_MAX_DIMS, b ? a->ne : ne, NULL, 0);
2819
2820    ggml_set_op_params_i32(result, 0, (int32_t) op);
2821    ggml_set_op_params_i32(result, 1, (int32_t) swapped);
2822
2823    result->op     = GGML_OP_GLU;
2824    result->src[0] = a;
2825    result->src[1] = b;
2826
2827    return result;
2828}
2829
2830// ggml_floor
2831
2832struct ggml_tensor * ggml_floor(
2833        struct ggml_context * ctx,
2834        struct ggml_tensor  * a) {
2835    return ggml_unary(ctx, a, GGML_UNARY_OP_FLOOR);
2836}
2837
2838struct ggml_tensor * ggml_floor_inplace(
2839        struct ggml_context * ctx,
2840        struct ggml_tensor  * a) {
2841    return ggml_unary_inplace(ctx, a, GGML_UNARY_OP_FLOOR);
2842}
2843
2844// ggml_ceil
2845
2846struct ggml_tensor * ggml_ceil(
2847        struct ggml_context * ctx,
2848        struct ggml_tensor  * a) {
2849    return ggml_unary(ctx, a, GGML_UNARY_OP_CEIL);
2850}
2851
2852struct ggml_tensor * ggml_ceil_inplace(
2853        struct ggml_context * ctx,
2854        struct ggml_tensor  * a) {
2855    return ggml_unary_inplace(ctx, a, GGML_UNARY_OP_CEIL);
2856}
2857
2858//ggml_round
2859
2860struct ggml_tensor * ggml_round(
2861        struct ggml_context * ctx,
2862        struct ggml_tensor  * a) {
2863    return ggml_unary(ctx, a, GGML_UNARY_OP_ROUND);
2864}
2865
2866struct ggml_tensor * ggml_round_inplace(
2867        struct ggml_context * ctx,
2868        struct ggml_tensor  * a) {
2869    return ggml_unary_inplace(ctx, a, GGML_UNARY_OP_ROUND);
2870}
2871
2872//ggml_trunc
2873
2874struct ggml_tensor * ggml_trunc(
2875        struct ggml_context * ctx,
2876        struct ggml_tensor  * a) {
2877    return ggml_unary(ctx, a, GGML_UNARY_OP_TRUNC);
2878}
2879
2880struct ggml_tensor * ggml_trunc_inplace(
2881        struct ggml_context * ctx,
2882        struct ggml_tensor  * a) {
2883    return ggml_unary_inplace(ctx, a, GGML_UNARY_OP_TRUNC);
2884}
2885
2886struct ggml_tensor * ggml_glu(
2887        struct ggml_context * ctx,
2888        struct ggml_tensor  * a,
2889        enum ggml_glu_op      op,
2890        bool                  swapped) {
2891    return ggml_glu_impl(ctx, a, NULL, op, swapped);
2892}
2893
2894struct ggml_tensor * ggml_glu_split(
2895        struct ggml_context * ctx,
2896        struct ggml_tensor  * a,
2897        struct ggml_tensor  * b,
2898        enum ggml_glu_op      op) {
2899    return ggml_glu_impl(ctx, a, b, op, false);
2900}
2901
2902// ggml_reglu
2903
2904struct ggml_tensor * ggml_reglu(
2905        struct ggml_context * ctx,
2906        struct ggml_tensor  * a) {
2907    return ggml_glu_impl(ctx, a, NULL, GGML_GLU_OP_REGLU, false);
2908}
2909
2910struct ggml_tensor * ggml_reglu_swapped(
2911        struct ggml_context * ctx,
2912        struct ggml_tensor  * a) {
2913    return ggml_glu_impl(ctx, a, NULL, GGML_GLU_OP_REGLU, true);
2914}
2915
2916struct ggml_tensor * ggml_reglu_split(
2917        struct ggml_context * ctx,
2918        struct ggml_tensor  * a,
2919        struct ggml_tensor  * b) {
2920    return ggml_glu_impl(ctx, a, b, GGML_GLU_OP_REGLU, false);
2921}
2922
2923// ggml_geglu
2924
2925struct ggml_tensor * ggml_geglu(
2926        struct ggml_context * ctx,
2927        struct ggml_tensor  * a) {
2928    return ggml_glu_impl(ctx, a, NULL, GGML_GLU_OP_GEGLU, false);
2929}
2930
2931struct ggml_tensor * ggml_geglu_swapped(
2932        struct ggml_context * ctx,
2933        struct ggml_tensor  * a) {
2934    return ggml_glu_impl(ctx, a, NULL, GGML_GLU_OP_GEGLU, true);
2935}
2936
2937struct ggml_tensor * ggml_geglu_split(
2938        struct ggml_context * ctx,
2939        struct ggml_tensor  * a,
2940        struct ggml_tensor  * b) {
2941    return ggml_glu_impl(ctx, a, b, GGML_GLU_OP_GEGLU, false);
2942}
2943
2944// ggml_swiglu
2945
2946struct ggml_tensor * ggml_swiglu(
2947        struct ggml_context * ctx,
2948        struct ggml_tensor  * a) {
2949    return ggml_glu_impl(ctx, a, NULL, GGML_GLU_OP_SWIGLU, false);
2950}
2951
2952struct ggml_tensor * ggml_swiglu_swapped(
2953        struct ggml_context * ctx,
2954        struct ggml_tensor  * a) {
2955    return ggml_glu_impl(ctx, a, NULL, GGML_GLU_OP_SWIGLU, true);
2956}
2957
2958struct ggml_tensor * ggml_swiglu_split(
2959        struct ggml_context * ctx,
2960        struct ggml_tensor  * a,
2961        struct ggml_tensor  * b) {
2962    return ggml_glu_impl(ctx, a, b, GGML_GLU_OP_SWIGLU, false);
2963}
2964
2965// ggml_geglu_erf
2966
2967struct ggml_tensor * ggml_geglu_erf(
2968        struct ggml_context * ctx,
2969        struct ggml_tensor  * a) {
2970    return ggml_glu_impl(ctx, a, NULL, GGML_GLU_OP_GEGLU_ERF, false);
2971}
2972
2973struct ggml_tensor * ggml_geglu_erf_swapped(
2974        struct ggml_context * ctx,
2975        struct ggml_tensor  * a) {
2976    return ggml_glu_impl(ctx, a, NULL, GGML_GLU_OP_GEGLU_ERF, true);
2977}
2978
2979struct ggml_tensor * ggml_geglu_erf_split(
2980        struct ggml_context * ctx,
2981        struct ggml_tensor  * a,
2982        struct ggml_tensor  * b) {
2983    return ggml_glu_impl(ctx, a, b, GGML_GLU_OP_GEGLU_ERF, false);
2984}
2985
2986// ggml_geglu_quick
2987
2988struct ggml_tensor * ggml_geglu_quick(
2989        struct ggml_context * ctx,
2990        struct ggml_tensor  * a) {
2991    return ggml_glu_impl(ctx, a, NULL, GGML_GLU_OP_GEGLU_QUICK, false);
2992}
2993
2994struct ggml_tensor * ggml_geglu_quick_swapped(
2995        struct ggml_context * ctx,
2996        struct ggml_tensor  * a) {
2997    return ggml_glu_impl(ctx, a, NULL, GGML_GLU_OP_GEGLU_QUICK, true);
2998}
2999
3000struct ggml_tensor * ggml_geglu_quick_split(
3001        struct ggml_context * ctx,
3002        struct ggml_tensor  * a,
3003        struct ggml_tensor  * b) {
3004    return ggml_glu_impl(ctx, a, b, GGML_GLU_OP_GEGLU_QUICK, false);
3005}
3006
3007struct ggml_tensor * ggml_swiglu_oai(
3008        struct ggml_context * ctx,
3009        struct ggml_tensor  * a,
3010        struct ggml_tensor  * b,
3011        float                 alpha,
3012        float                 limit) {
3013    struct ggml_tensor * result = ggml_glu_impl(ctx, a, b, GGML_GLU_OP_SWIGLU_OAI, false);
3014    ggml_set_op_params_f32(result, 2, alpha);
3015    ggml_set_op_params_f32(result, 3, limit);
3016
3017    return result;
3018}
3019
3020// ggml_norm
3021
3022static struct ggml_tensor * ggml_norm_impl(
3023        struct ggml_context * ctx,
3024        struct ggml_tensor  * a,
3025        float                 eps,
3026        bool                  inplace) {
3027    struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
3028
3029    ggml_set_op_params(result, &eps, sizeof(eps));
3030
3031    result->op     = GGML_OP_NORM;
3032    result->src[0] = a;
3033
3034    return result;
3035}
3036
3037struct ggml_tensor * ggml_norm(
3038        struct ggml_context * ctx,
3039        struct ggml_tensor  * a,
3040        float                 eps) {
3041    return ggml_norm_impl(ctx, a, eps, false);
3042}
3043
3044struct ggml_tensor * ggml_norm_inplace(
3045        struct ggml_context * ctx,
3046        struct ggml_tensor  * a,
3047        float                 eps) {
3048    return ggml_norm_impl(ctx, a, eps, true);
3049}
3050
3051// ggml_rms_norm
3052
3053static struct ggml_tensor * ggml_rms_norm_impl(
3054        struct ggml_context * ctx,
3055        struct ggml_tensor  * a,
3056        float                 eps,
3057        bool                  inplace) {
3058    struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
3059
3060    ggml_set_op_params(result, &eps, sizeof(eps));
3061
3062    result->op     = GGML_OP_RMS_NORM;
3063    result->src[0] = a;
3064
3065    return result;
3066}
3067
3068struct ggml_tensor * ggml_rms_norm(
3069        struct ggml_context * ctx,
3070        struct ggml_tensor  * a,
3071        float                 eps) {
3072    return ggml_rms_norm_impl(ctx, a, eps, false);
3073}
3074
3075struct ggml_tensor * ggml_rms_norm_inplace(
3076        struct ggml_context * ctx,
3077        struct ggml_tensor  * a,
3078        float                 eps) {
3079    return ggml_rms_norm_impl(ctx, a, eps, true);
3080}
3081
3082// ggml_rms_norm_back
3083
3084struct ggml_tensor * ggml_rms_norm_back(
3085        struct ggml_context * ctx,
3086        struct ggml_tensor  * a,
3087        struct ggml_tensor  * b,
3088        float                 eps) {
3089    struct ggml_tensor * result = ggml_dup_tensor(ctx, a);
3090
3091    ggml_set_op_params(result, &eps, sizeof(eps));
3092
3093    result->op     = GGML_OP_RMS_NORM_BACK;
3094    result->src[0] = a;
3095    result->src[1] = b;
3096
3097    return result;
3098}
3099
3100// ggml_group_norm
3101
3102static struct ggml_tensor * ggml_group_norm_impl(
3103        struct ggml_context * ctx,
3104        struct ggml_tensor  * a,
3105        int                   n_groups,
3106        float                 eps,
3107        bool                  inplace) {
3108    struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
3109
3110    ggml_set_op_params_i32(result, 0, n_groups);
3111    ggml_set_op_params_f32(result, 1, eps);
3112
3113    result->op     = GGML_OP_GROUP_NORM;
3114    result->src[0] = a;
3115
3116    return result;
3117}
3118
3119struct ggml_tensor * ggml_group_norm(
3120        struct ggml_context * ctx,
3121        struct ggml_tensor  * a,
3122        int                   n_groups,
3123        float                 eps) {
3124    return ggml_group_norm_impl(ctx, a, n_groups, eps, false);
3125}
3126
3127struct ggml_tensor * ggml_group_norm_inplace(
3128        struct ggml_context * ctx,
3129        struct ggml_tensor  * a,
3130        int                   n_groups,
3131        float                 eps) {
3132    return ggml_group_norm_impl(ctx, a, n_groups, eps, true);
3133}
3134
3135// ggml_l2_norm
3136
3137static struct ggml_tensor * ggml_l2_norm_impl(
3138        struct ggml_context * ctx,
3139        struct ggml_tensor  * a,
3140        float                 eps,
3141        bool                  inplace) {
3142    struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
3143
3144    ggml_set_op_params_f32(result, 0, eps);
3145
3146    result->op     = GGML_OP_L2_NORM;
3147    result->src[0] = a;
3148
3149    return result;
3150}
3151
3152struct ggml_tensor * ggml_l2_norm(
3153        struct ggml_context * ctx,
3154        struct ggml_tensor  * a,
3155        float                 eps) {
3156    return ggml_l2_norm_impl(ctx, a, eps, false);
3157}
3158
3159struct ggml_tensor * ggml_l2_norm_inplace(
3160        struct ggml_context * ctx,
3161        struct ggml_tensor  * a,
3162        float                 eps) {
3163    return ggml_l2_norm_impl(ctx, a, eps, true);
3164}
3165
3166// ggml_mul_mat
3167
3168static inline bool ggml_can_mul_mat(const struct ggml_tensor * t0, const struct ggml_tensor * t1) {
3169    static_assert(GGML_MAX_DIMS == 4, "GGML_MAX_DIMS is not 4 - update this function");
3170
3171    return (t0->ne[0]           == t1->ne[0])  &&
3172           (t1->ne[2]%t0->ne[2] == 0)          && // verify t0 is broadcastable
3173           (t1->ne[3]%t0->ne[3] == 0);
3174}
3175
3176struct ggml_tensor * ggml_mul_mat(
3177        struct ggml_context * ctx,
3178        struct ggml_tensor  * a,
3179        struct ggml_tensor  * b) {
3180    GGML_ASSERT(ggml_can_mul_mat(a, b));
3181    GGML_ASSERT(!ggml_is_transposed(a));
3182
3183    const int64_t ne[4] = { a->ne[1], b->ne[1], b->ne[2], b->ne[3] };
3184    struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne);
3185
3186    result->op     = GGML_OP_MUL_MAT;
3187    result->src[0] = a;
3188    result->src[1] = b;
3189
3190    return result;
3191}
3192
3193void ggml_mul_mat_set_prec(
3194        struct ggml_tensor * a,
3195        enum ggml_prec       prec) {
3196    GGML_ASSERT(a->op == GGML_OP_MUL_MAT);
3197
3198    const int32_t prec_i32 = (int32_t) prec;
3199
3200    ggml_set_op_params_i32(a, 0, prec_i32);
3201}
3202
3203// ggml_mul_mat_id
3204
3205/*
3206    c = ggml_mul_mat_id(ctx, as, b, ids);
3207
3208    as  -> [cols, rows, n_expert]
3209    b   -> [cols, n_expert_used, n_tokens]
3210    ids -> [n_expert_used, n_tokens] (i32)
3211    c   -> [rows, n_expert_used, n_tokens]
3212
3213    in b, n_expert_used can be broadcasted to match the n_expert_used of ids
3214
3215    c ~= as[:,:,i] @ b[:,i%r,t], i = ids[e,t] for all e,t in ids
3216*/
3217struct ggml_tensor * ggml_mul_mat_id(
3218        struct ggml_context * ctx,
3219        struct ggml_tensor  * as,
3220        struct ggml_tensor  * b,
3221        struct ggml_tensor  * ids) {
3222    GGML_ASSERT(!ggml_is_transposed(as));
3223    GGML_ASSERT(ids->type == GGML_TYPE_I32);
3224
3225    GGML_ASSERT(as->ne[3] == 1); // as is 3d (one matrix per expert)
3226    GGML_ASSERT(b->ne[3] == 1); // b is 3d
3227    GGML_ASSERT(ids->ne[2] == 1 && ids->ne[3] == 1); // ids is 2d
3228    GGML_ASSERT(ids->ne[1] == b->ne[2]); // must have an expert list per b row
3229    GGML_ASSERT(as->ne[0] == b->ne[0]); // can_mul_mat
3230    GGML_ASSERT(ids->ne[0] % b->ne[1] == 0); // can broadcast
3231
3232    const int64_t ne[4] = { as->ne[1], ids->ne[0], b->ne[2], 1 };
3233    struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne);
3234
3235    result->op     = GGML_OP_MUL_MAT_ID;
3236    result->src[0] = as;
3237    result->src[1] = b;
3238    result->src[2] = ids;
3239
3240    return result;
3241}
3242
3243// ggml_out_prod
3244
3245static inline bool ggml_can_out_prod(const struct ggml_tensor * t0, const struct ggml_tensor * t1) {
3246    static_assert(GGML_MAX_DIMS == 4, "GGML_MAX_DIMS is not 4 - update this function");
3247
3248    return (t0->ne[1] == t1->ne[1])   &&
3249           (t1->ne[2]%t0->ne[2] == 0) && // verify t0 is broadcastable
3250           (t1->ne[3]%t0->ne[3] == 0);
3251}
3252
3253struct ggml_tensor * ggml_out_prod(
3254        struct ggml_context * ctx,
3255        struct ggml_tensor  * a,
3256        struct ggml_tensor  * b) {
3257    GGML_ASSERT(ggml_can_out_prod(a, b));
3258    GGML_ASSERT(!ggml_is_transposed(a));
3259
3260    // a is broadcastable to b for ne[2] and ne[3] -> use b->ne[2] and b->ne[3]
3261    const int64_t ne[4] = { a->ne[0], b->ne[0], b->ne[2], b->ne[3] };
3262    struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne);
3263
3264    result->op     = GGML_OP_OUT_PROD;
3265    result->src[0] = a;
3266    result->src[1] = b;
3267
3268    return result;
3269}
3270
3271// ggml_scale
3272
3273static struct ggml_tensor * ggml_scale_impl(
3274        struct ggml_context * ctx,
3275        struct ggml_tensor  * a,
3276        float                 s,
3277        float                 b,
3278        bool                  inplace) {
3279    GGML_ASSERT(ggml_is_padded_1d(a));
3280
3281    struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
3282
3283    float params[2] = { s, b };
3284    ggml_set_op_params(result, &params, sizeof(params));
3285
3286    result->op     = GGML_OP_SCALE;
3287    result->src[0] = a;
3288
3289    return result;
3290}
3291
3292struct ggml_tensor * ggml_scale(
3293        struct ggml_context * ctx,
3294        struct ggml_tensor  * a,
3295        float                 s) {
3296    return ggml_scale_impl(ctx, a, s, 0.0, false);
3297}
3298
3299struct ggml_tensor * ggml_scale_inplace(
3300        struct ggml_context * ctx,
3301        struct ggml_tensor  * a,
3302        float                 s) {
3303    return ggml_scale_impl(ctx, a, s, 0.0, true);
3304}
3305
3306struct ggml_tensor * ggml_scale_bias(
3307        struct ggml_context * ctx,
3308        struct ggml_tensor  * a,
3309        float                 s,
3310        float                 b) {
3311    return ggml_scale_impl(ctx, a, s, b, false);
3312}
3313
3314struct ggml_tensor * ggml_scale_bias_inplace(
3315        struct ggml_context * ctx,
3316        struct ggml_tensor  * a,
3317        float                 s,
3318        float                 b) {
3319    return ggml_scale_impl(ctx, a, s, b, true);
3320}
3321
3322// ggml_set
3323
3324static struct ggml_tensor * ggml_set_impl(
3325        struct ggml_context * ctx,
3326        struct ggml_tensor  * a,
3327        struct ggml_tensor  * b,
3328        size_t                nb1,
3329        size_t                nb2,
3330        size_t                nb3,
3331        size_t                offset,
3332        bool                  inplace) {
3333    GGML_ASSERT(ggml_nelements(a) >= ggml_nelements(b));
3334
3335    // make a view of the destination
3336    struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
3337
3338    GGML_ASSERT(offset < (size_t)(1 << 30));
3339    int32_t params[] = { nb1, nb2, nb3, offset, inplace ? 1 : 0 };
3340    ggml_set_op_params(result, params, sizeof(params));
3341
3342    result->op     = GGML_OP_SET;
3343    result->src[0] = a;
3344    result->src[1] = b;
3345
3346    return result;
3347}
3348
3349struct ggml_tensor * ggml_set(
3350        struct ggml_context * ctx,
3351        struct ggml_tensor  * a,
3352        struct ggml_tensor  * b,
3353        size_t                nb1,
3354        size_t                nb2,
3355        size_t                nb3,
3356        size_t                offset) {
3357    return ggml_set_impl(ctx, a, b, nb1, nb2, nb3, offset, false);
3358}
3359
3360struct ggml_tensor * ggml_set_inplace(
3361        struct ggml_context * ctx,
3362        struct ggml_tensor  * a,
3363        struct ggml_tensor  * b,
3364        size_t                nb1,
3365        size_t                nb2,
3366        size_t                nb3,
3367        size_t                offset) {
3368    return ggml_set_impl(ctx, a, b, nb1, nb2, nb3, offset, true);
3369}
3370
3371struct ggml_tensor * ggml_set_1d(
3372        struct ggml_context * ctx,
3373        struct ggml_tensor  * a,
3374        struct ggml_tensor  * b,
3375        size_t                offset) {
3376    return ggml_set_impl(ctx, a, b, a->nb[1], a->nb[2], a->nb[3], offset, false);
3377}
3378
3379struct ggml_tensor * ggml_set_1d_inplace(
3380        struct ggml_context * ctx,
3381        struct ggml_tensor  * a,
3382        struct ggml_tensor  * b,
3383        size_t                offset) {
3384    return ggml_set_impl(ctx, a, b, a->nb[1], a->nb[2], a->nb[3], offset, true);
3385}
3386
3387struct ggml_tensor * ggml_set_2d(
3388        struct ggml_context * ctx,
3389        struct ggml_tensor  * a,
3390        struct ggml_tensor  * b,
3391        size_t                nb1,
3392        size_t                offset) {
3393    return ggml_set_impl(ctx, a, b, nb1, a->nb[2], a->nb[3], offset, false);
3394}
3395
3396struct ggml_tensor * ggml_set_2d_inplace(
3397        struct ggml_context * ctx,
3398        struct ggml_tensor  * a,
3399        struct ggml_tensor  * b,
3400        size_t                nb1,
3401        size_t                offset) {
3402    return ggml_set_impl(ctx, a, b, nb1, a->nb[2], a->nb[3], offset, true);
3403}
3404
3405// ggml_cpy
3406
3407static struct ggml_tensor * ggml_cpy_impl(
3408        struct ggml_context * ctx,
3409        struct ggml_tensor  * a,
3410        struct ggml_tensor  * b) {
3411    GGML_ASSERT(ggml_nelements(a) == ggml_nelements(b));
3412
3413    // make a view of the destination
3414    struct ggml_tensor * result = ggml_view_tensor(ctx, b);
3415    if (strlen(b->name) > 0) {
3416        ggml_format_name(result, "%s (copy of %s)", b->name, a->name);
3417    } else {
3418        ggml_format_name(result, "%s (copy)", a->name);
3419    }
3420
3421    result->op     = GGML_OP_CPY;
3422    result->src[0] = a;
3423    result->src[1] = b;
3424
3425    return result;
3426}
3427
3428struct ggml_tensor * ggml_cpy(
3429        struct ggml_context * ctx,
3430        struct ggml_tensor * a,
3431        struct ggml_tensor * b) {
3432    return ggml_cpy_impl(ctx, a, b);
3433}
3434
3435struct ggml_tensor * ggml_cast(
3436        struct ggml_context * ctx,
3437        struct ggml_tensor  * a,
3438        enum   ggml_type      type) {
3439    struct ggml_tensor * result = ggml_new_tensor(ctx, type, GGML_MAX_DIMS, a->ne);
3440    ggml_format_name(result, "%s (copy)", a->name);
3441
3442    result->op     = GGML_OP_CPY;
3443    result->src[0] = a;
3444    result->src[1] = result; // note: this self-reference might seem redundant, but it's actually needed by some
3445                             //       backends for consistency with ggml_cpy_impl() above
3446
3447    return result;
3448}
3449
3450// ggml_cont
3451
3452static struct ggml_tensor * ggml_cont_impl(
3453        struct ggml_context * ctx,
3454        struct ggml_tensor  * a) {
3455    struct ggml_tensor * result = ggml_dup_tensor(ctx, a);
3456    ggml_format_name(result, "%s (cont)", a->name);
3457
3458    result->op     = GGML_OP_CONT;
3459    result->src[0] = a;
3460
3461    return result;
3462}
3463
3464struct ggml_tensor * ggml_cont(
3465        struct ggml_context * ctx,
3466        struct ggml_tensor * a) {
3467    return ggml_cont_impl(ctx, a);
3468}
3469
3470// make contiguous, with new shape
3471GGML_API struct ggml_tensor * ggml_cont_1d(
3472        struct ggml_context * ctx,
3473        struct ggml_tensor  * a,
3474        int64_t               ne0) {
3475    return ggml_cont_4d(ctx, a, ne0, 1, 1, 1);
3476}
3477
3478GGML_API struct ggml_tensor * ggml_cont_2d(
3479        struct ggml_context * ctx,
3480        struct ggml_tensor  * a,
3481        int64_t               ne0,
3482        int64_t               ne1) {
3483    return ggml_cont_4d(ctx, a, ne0, ne1, 1, 1);
3484}
3485
3486GGML_API struct ggml_tensor * ggml_cont_3d(
3487        struct ggml_context * ctx,
3488        struct ggml_tensor  * a,
3489        int64_t               ne0,
3490        int64_t               ne1,
3491        int64_t               ne2) {
3492    return ggml_cont_4d(ctx, a, ne0, ne1, ne2, 1);
3493}
3494
3495struct ggml_tensor * ggml_cont_4d(
3496        struct ggml_context * ctx,
3497        struct ggml_tensor  * a,
3498        int64_t               ne0,
3499        int64_t               ne1,
3500        int64_t               ne2,
3501        int64_t               ne3) {
3502    GGML_ASSERT(ggml_nelements(a) == (ne0*ne1*ne2*ne3));
3503
3504    struct ggml_tensor * result = ggml_new_tensor_4d(ctx, a->type, ne0, ne1, ne2, ne3);
3505    ggml_format_name(result, "%s (cont)", a->name);
3506
3507    result->op     = GGML_OP_CONT;
3508    result->src[0] = a;
3509
3510    return result;
3511}
3512
3513// ggml_reshape
3514
3515struct ggml_tensor * ggml_reshape(
3516        struct ggml_context * ctx,
3517        struct ggml_tensor * a,
3518        struct ggml_tensor * b) {
3519    GGML_ASSERT(ggml_is_contiguous(a));
3520    // as only the shape of b is relevant, and not its memory layout, b is allowed to be non contiguous.
3521    GGML_ASSERT(ggml_nelements(a) == ggml_nelements(b));
3522
3523    struct ggml_tensor * result = ggml_new_tensor_impl(ctx, a->type, GGML_MAX_DIMS, b->ne, a, 0);
3524    ggml_format_name(result, "%s (reshaped)", a->name);
3525
3526    result->op     = GGML_OP_RESHAPE;
3527    result->src[0] = a;
3528
3529    return result;
3530}
3531
3532struct ggml_tensor * ggml_reshape_1d(
3533        struct ggml_context * ctx,
3534        struct ggml_tensor  * a,
3535        int64_t               ne0) {
3536    GGML_ASSERT(ggml_is_contiguous(a));
3537    GGML_ASSERT(ggml_nelements(a) == ne0);
3538
3539    const int64_t ne[1] = { ne0 };
3540    struct ggml_tensor * result = ggml_new_tensor_impl(ctx, a->type, 1, ne, a, 0);
3541    ggml_format_name(result, "%s (reshaped)", a->name);
3542
3543    result->op     = GGML_OP_RESHAPE;
3544    result->src[0] = a;
3545
3546    return result;
3547}
3548
3549struct ggml_tensor * ggml_reshape_2d(
3550        struct ggml_context * ctx,
3551        struct ggml_tensor  * a,
3552        int64_t               ne0,
3553        int64_t               ne1) {
3554    GGML_ASSERT(ggml_is_contiguous(a));
3555    GGML_ASSERT(ggml_nelements(a) == ne0*ne1);
3556
3557    const int64_t ne[2] = { ne0, ne1 };
3558    struct ggml_tensor * result = ggml_new_tensor_impl(ctx, a->type, 2, ne, a, 0);
3559    ggml_format_name(result, "%s (reshaped)", a->name);
3560
3561    result->op     = GGML_OP_RESHAPE;
3562    result->src[0] = a;
3563
3564    return result;
3565}
3566
3567struct ggml_tensor * ggml_reshape_3d(
3568        struct ggml_context * ctx,
3569        struct ggml_tensor  * a,
3570        int64_t               ne0,
3571        int64_t               ne1,
3572        int64_t               ne2) {
3573    GGML_ASSERT(ggml_is_contiguous(a));
3574    GGML_ASSERT(ggml_nelements(a) == ne0*ne1*ne2);
3575
3576    const int64_t ne[3] = { ne0, ne1, ne2 };
3577    struct ggml_tensor * result = ggml_new_tensor_impl(ctx, a->type, 3, ne, a, 0);
3578    ggml_format_name(result, "%s (reshaped)", a->name);
3579
3580    result->op     = GGML_OP_RESHAPE;
3581    result->src[0] = a;
3582
3583    return result;
3584}
3585
3586struct ggml_tensor * ggml_reshape_4d(
3587        struct ggml_context * ctx,
3588        struct ggml_tensor  * a,
3589        int64_t               ne0,
3590        int64_t               ne1,
3591        int64_t               ne2,
3592        int64_t               ne3) {
3593    GGML_ASSERT(ggml_is_contiguous(a));
3594    GGML_ASSERT(ggml_nelements(a) == ne0*ne1*ne2*ne3);
3595
3596    const int64_t ne[4] = { ne0, ne1, ne2, ne3 };
3597    struct ggml_tensor * result = ggml_new_tensor_impl(ctx, a->type, 4, ne, a, 0);
3598    ggml_format_name(result, "%s (reshaped)", a->name);
3599
3600    result->op     = GGML_OP_RESHAPE;
3601    result->src[0] = a;
3602
3603    return result;
3604}
3605
3606static struct ggml_tensor * ggml_view_impl(
3607        struct ggml_context * ctx,
3608        struct ggml_tensor  * a,
3609        int                   n_dims,
3610        const int64_t       * ne,
3611        size_t                offset) {
3612    struct ggml_tensor * result = ggml_new_tensor_impl(ctx, a->type, n_dims, ne, a, offset);
3613    ggml_format_name(result, "%s (view)", a->name);
3614
3615    ggml_set_op_params(result, &offset, sizeof(offset));
3616
3617    result->op     = GGML_OP_VIEW;
3618    result->src[0] = a;
3619
3620    return result;
3621}
3622
3623// ggml_view_1d
3624
3625struct ggml_tensor * ggml_view_1d(
3626        struct ggml_context * ctx,
3627        struct ggml_tensor  * a,
3628        int64_t               ne0,
3629        size_t                offset) {
3630    struct ggml_tensor * result = ggml_view_impl(ctx, a, 1, &ne0, offset);
3631
3632    return result;
3633}
3634
3635// ggml_view_2d
3636
3637struct ggml_tensor * ggml_view_2d(
3638        struct ggml_context * ctx,
3639        struct ggml_tensor  * a,
3640        int64_t               ne0,
3641        int64_t               ne1,
3642        size_t                nb1,
3643        size_t                offset) {
3644    const int64_t ne[2] = { ne0, ne1 };
3645
3646    struct ggml_tensor * result = ggml_view_impl(ctx, a, 2, ne, offset);
3647
3648    result->nb[1] = nb1;
3649    result->nb[2] = result->nb[1]*ne1;
3650    result->nb[3] = result->nb[2];
3651
3652    return result;
3653}
3654
3655// ggml_view_3d
3656
3657struct ggml_tensor * ggml_view_3d(
3658        struct ggml_context * ctx,
3659        struct ggml_tensor  * a,
3660        int64_t               ne0,
3661        int64_t               ne1,
3662        int64_t               ne2,
3663        size_t                nb1,
3664        size_t                nb2,
3665        size_t                offset) {
3666    const int64_t ne[3] = { ne0, ne1, ne2 };
3667
3668    struct ggml_tensor * result = ggml_view_impl(ctx, a, 3, ne, offset);
3669
3670    result->nb[1] = nb1;
3671    result->nb[2] = nb2;
3672    result->nb[3] = result->nb[2]*ne2;
3673
3674    return result;
3675}
3676
3677// ggml_view_4d
3678
3679struct ggml_tensor * ggml_view_4d(
3680        struct ggml_context * ctx,
3681        struct ggml_tensor  * a,
3682        int64_t               ne0,
3683        int64_t               ne1,
3684        int64_t               ne2,
3685        int64_t               ne3,
3686        size_t                nb1,
3687        size_t                nb2,
3688        size_t                nb3,
3689        size_t                offset) {
3690    const int64_t ne[4] = { ne0, ne1, ne2, ne3 };
3691
3692    struct ggml_tensor * result = ggml_view_impl(ctx, a, 4, ne, offset);
3693
3694    result->nb[1] = nb1;
3695    result->nb[2] = nb2;
3696    result->nb[3] = nb3;
3697
3698    return result;
3699}
3700
3701// ggml_permute
3702
3703struct ggml_tensor * ggml_permute(
3704        struct ggml_context * ctx,
3705        struct ggml_tensor  * a,
3706        int                   axis0,
3707        int                   axis1,
3708        int                   axis2,
3709        int                   axis3) {
3710    GGML_ASSERT(axis0 >= 0 && axis0 < GGML_MAX_DIMS);
3711    GGML_ASSERT(axis1 >= 0 && axis1 < GGML_MAX_DIMS);
3712    GGML_ASSERT(axis2 >= 0 && axis2 < GGML_MAX_DIMS);
3713    GGML_ASSERT(axis3 >= 0 && axis3 < GGML_MAX_DIMS);
3714
3715    GGML_ASSERT(axis0 != axis1);
3716    GGML_ASSERT(axis0 != axis2);
3717    GGML_ASSERT(axis0 != axis3);
3718    GGML_ASSERT(axis1 != axis2);
3719    GGML_ASSERT(axis1 != axis3);
3720    GGML_ASSERT(axis2 != axis3);
3721
3722    struct ggml_tensor * result = ggml_view_tensor(ctx, a);
3723    ggml_format_name(result, "%s (permuted)", a->name);
3724
3725    int ne[GGML_MAX_DIMS];
3726    int nb[GGML_MAX_DIMS];
3727
3728    ne[axis0] = a->ne[0];
3729    ne[axis1] = a->ne[1];
3730    ne[axis2] = a->ne[2];
3731    ne[axis3] = a->ne[3];
3732
3733    nb[axis0] = a->nb[0];
3734    nb[axis1] = a->nb[1];
3735    nb[axis2] = a->nb[2];
3736    nb[axis3] = a->nb[3];
3737
3738    result->ne[0] = ne[0];
3739    result->ne[1] = ne[1];
3740    result->ne[2] = ne[2];
3741    result->ne[3] = ne[3];
3742
3743    result->nb[0] = nb[0];
3744    result->nb[1] = nb[1];
3745    result->nb[2] = nb[2];
3746    result->nb[3] = nb[3];
3747
3748    result->op     = GGML_OP_PERMUTE;
3749    result->src[0] = a;
3750
3751    int32_t params[] = { axis0, axis1, axis2, axis3 };
3752    ggml_set_op_params(result, params, sizeof(params));
3753
3754    return result;
3755}
3756
3757// ggml_transpose
3758
3759struct ggml_tensor * ggml_transpose(
3760        struct ggml_context * ctx,
3761        struct ggml_tensor  * a) {
3762    struct ggml_tensor * result = ggml_view_tensor(ctx, a);
3763    ggml_format_name(result, "%s (transposed)", a->name);
3764
3765    result->ne[0] = a->ne[1];
3766    result->ne[1] = a->ne[0];
3767
3768    result->nb[0] = a->nb[1];
3769    result->nb[1] = a->nb[0];
3770
3771    result->op     = GGML_OP_TRANSPOSE;
3772    result->src[0] = a;
3773
3774    return result;
3775}
3776
3777// ggml_get_rows
3778
3779struct ggml_tensor * ggml_get_rows(
3780        struct ggml_context * ctx,
3781        struct ggml_tensor  * a,
3782        struct ggml_tensor  * b) {
3783    GGML_ASSERT(a->ne[2] == b->ne[1]);
3784    GGML_ASSERT(a->ne[3] == b->ne[2]);
3785    GGML_ASSERT(b->ne[3] == 1);
3786    GGML_ASSERT(b->type == GGML_TYPE_I32);
3787
3788    // TODO: implement non F32 return
3789    enum ggml_type type = GGML_TYPE_F32;
3790    if (a->type == GGML_TYPE_I32) {
3791        type = a->type;
3792    }
3793    struct ggml_tensor * result = ggml_new_tensor_4d(ctx, type, a->ne[0], b->ne[0], b->ne[1], b->ne[2]);
3794
3795    result->op     = GGML_OP_GET_ROWS;
3796    result->src[0] = a;
3797    result->src[1] = b;
3798
3799    return result;
3800}
3801
3802// ggml_get_rows_back
3803
3804struct ggml_tensor * ggml_get_rows_back(
3805        struct ggml_context * ctx,
3806        struct ggml_tensor  * a,
3807        struct ggml_tensor  * b,
3808        struct ggml_tensor  * c) {
3809    GGML_ASSERT(ggml_is_matrix(a) && ggml_is_vector(b) && b->type == GGML_TYPE_I32);
3810    GGML_ASSERT(ggml_is_matrix(c) && (a->ne[0] == c->ne[0]));
3811
3812    // TODO: implement non F32 return
3813    //struct ggml_tensor * result = ggml_new_tensor_2d(ctx, a->type, a->ne[0], b->ne[0]);
3814    struct ggml_tensor * result = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, c->ne[0], c->ne[1]);
3815
3816    result->op     = GGML_OP_GET_ROWS_BACK;
3817    result->src[0] = a;
3818    result->src[1] = b;
3819
3820    return result;
3821}
3822
3823// ggml_set_rows
3824
3825struct ggml_tensor * ggml_set_rows(
3826        struct ggml_context * ctx,
3827        struct ggml_tensor  * a,
3828        struct ggml_tensor  * b,
3829        struct ggml_tensor  * c) {
3830    GGML_ASSERT(a->ne[0] == b->ne[0]);
3831    GGML_ASSERT(a->ne[2] == b->ne[2]);
3832    GGML_ASSERT(a->ne[3] == b->ne[3]);
3833    GGML_ASSERT(b->ne[1] == c->ne[0]);
3834    GGML_ASSERT(b->ne[2] % c->ne[1] == 0);
3835    GGML_ASSERT(b->ne[3] % c->ne[2] == 0);
3836    GGML_ASSERT(c->ne[3] == 1);
3837    GGML_ASSERT(b->type == GGML_TYPE_F32);
3838    GGML_ASSERT(c->type == GGML_TYPE_I64 || c->type == GGML_TYPE_I32);
3839
3840    GGML_ASSERT(ggml_is_contiguous_rows(a));
3841    GGML_ASSERT(ggml_is_contiguous_rows(b));
3842
3843    struct ggml_tensor * result = ggml_view_tensor(ctx, a);
3844
3845    result->op     = GGML_OP_SET_ROWS;
3846    result->src[0] = b;
3847    result->src[1] = c;
3848    result->src[2] = a; // note: order is weird due to legacy reasons (https://github.com/ggml-org/llama.cpp/pull/16063#discussion_r2385795931)
3849
3850    return result;
3851}
3852
3853// ggml_diag
3854
3855struct ggml_tensor * ggml_diag(
3856        struct ggml_context * ctx,
3857        struct ggml_tensor  * a) {
3858    GGML_ASSERT(a->ne[1] == 1);
3859
3860    const int64_t ne[4] = { a->ne[0], a->ne[0], a->ne[2], a->ne[3] };
3861    struct ggml_tensor * result = ggml_new_tensor(ctx, a->type, 4, ne);
3862
3863    result->op     = GGML_OP_DIAG;
3864    result->src[0] = a;
3865
3866    return result;
3867}
3868
3869// ggml_diag_mask_inf
3870
3871static struct ggml_tensor * ggml_diag_mask_inf_impl(
3872        struct ggml_context * ctx,
3873        struct ggml_tensor  * a,
3874        int                   n_past,
3875        bool                  inplace) {
3876    struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
3877
3878    int32_t params[] = { n_past };
3879    ggml_set_op_params(result, params, sizeof(params));
3880
3881    result->op     = GGML_OP_DIAG_MASK_INF;
3882    result->src[0] = a;
3883
3884    return result;
3885}
3886
3887struct ggml_tensor * ggml_diag_mask_inf(
3888        struct ggml_context * ctx,
3889        struct ggml_tensor  * a,
3890        int                   n_past) {
3891    return ggml_diag_mask_inf_impl(ctx, a, n_past, false);
3892}
3893
3894struct ggml_tensor * ggml_diag_mask_inf_inplace(
3895        struct ggml_context * ctx,
3896        struct ggml_tensor  * a,
3897        int                   n_past) {
3898    return ggml_diag_mask_inf_impl(ctx, a, n_past, true);
3899}
3900
3901// ggml_diag_mask_zero
3902
3903static struct ggml_tensor * ggml_diag_mask_zero_impl(
3904        struct ggml_context * ctx,
3905        struct ggml_tensor  * a,
3906        int                   n_past,
3907        bool                  inplace) {
3908    struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
3909
3910    int32_t params[] = { n_past };
3911    ggml_set_op_params(result, params, sizeof(params));
3912
3913    result->op     = GGML_OP_DIAG_MASK_ZERO;
3914    result->src[0] = a;
3915
3916    return result;
3917}
3918
3919struct ggml_tensor * ggml_diag_mask_zero(
3920        struct ggml_context * ctx,
3921        struct ggml_tensor  * a,
3922        int                   n_past) {
3923    return ggml_diag_mask_zero_impl(ctx, a, n_past, false);
3924}
3925
3926struct ggml_tensor * ggml_diag_mask_zero_inplace(
3927        struct ggml_context * ctx,
3928        struct ggml_tensor  * a,
3929        int                   n_past) {
3930    return ggml_diag_mask_zero_impl(ctx, a, n_past, true);
3931}
3932
3933// ggml_soft_max
3934
3935static struct ggml_tensor * ggml_soft_max_impl(
3936        struct ggml_context * ctx,
3937        struct ggml_tensor  * a,
3938        struct ggml_tensor  * mask,
3939        float                 scale,
3940        float                 max_bias,
3941        bool                  inplace) {
3942    GGML_ASSERT(ggml_is_contiguous(a));
3943
3944    if (mask) {
3945        GGML_ASSERT(mask->type == GGML_TYPE_F16 || mask->type == GGML_TYPE_F32);
3946        GGML_ASSERT(ggml_is_contiguous(mask));
3947        GGML_ASSERT(mask->ne[0] == a->ne[0]);
3948        GGML_ASSERT(mask->ne[1] >= a->ne[1]);
3949        GGML_ASSERT(a->ne[2]%mask->ne[2] == 0);
3950        GGML_ASSERT(a->ne[3]%mask->ne[3] == 0);
3951    }
3952
3953    if (max_bias > 0.0f) {
3954        GGML_ASSERT(mask);
3955    }
3956
3957    struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
3958
3959    float params[] = { scale, max_bias };
3960    ggml_set_op_params(result, params, sizeof(params));
3961
3962    result->op     = GGML_OP_SOFT_MAX;
3963    result->src[0] = a;
3964    result->src[1] = mask;
3965
3966    return result;
3967}
3968
3969struct ggml_tensor * ggml_soft_max(
3970        struct ggml_context * ctx,
3971        struct ggml_tensor  * a) {
3972    return ggml_soft_max_impl(ctx, a, NULL, 1.0f, 0.0f, false);
3973}
3974
3975struct ggml_tensor * ggml_soft_max_inplace(
3976        struct ggml_context * ctx,
3977        struct ggml_tensor  * a) {
3978    return ggml_soft_max_impl(ctx, a, NULL, 1.0f, 0.0f, true);
3979}
3980
3981struct ggml_tensor * ggml_soft_max_ext(
3982        struct ggml_context * ctx,
3983        struct ggml_tensor  * a,
3984        struct ggml_tensor  * mask,
3985        float                 scale,
3986        float                 max_bias) {
3987    return ggml_soft_max_impl(ctx, a, mask, scale, max_bias, false);
3988}
3989
3990struct ggml_tensor * ggml_soft_max_ext_inplace(
3991        struct ggml_context * ctx,
3992        struct ggml_tensor  * a,
3993        struct ggml_tensor  * mask,
3994        float                 scale,
3995        float                 max_bias) {
3996    return ggml_soft_max_impl(ctx, a, mask, scale, max_bias, true);
3997}
3998
3999void ggml_soft_max_add_sinks(
4000        struct ggml_tensor * a,
4001        struct ggml_tensor * sinks) {
4002    if (!sinks) {
4003        a->src[2] = NULL;
4004        return;
4005    }
4006
4007    GGML_ASSERT(a->op == GGML_OP_SOFT_MAX);
4008    GGML_ASSERT(a->src[2] == NULL);
4009    GGML_ASSERT(a->src[0]->ne[2] == sinks->ne[0]);
4010    GGML_ASSERT(sinks->type == GGML_TYPE_F32);
4011
4012    a->src[2] = sinks;
4013}
4014
4015// ggml_soft_max_ext_back
4016
4017static struct ggml_tensor * ggml_soft_max_ext_back_impl(
4018        struct ggml_context * ctx,
4019        struct ggml_tensor  * a,
4020        struct ggml_tensor  * b,
4021        float                 scale,
4022        float                 max_bias,
4023        bool                  inplace) {
4024    struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
4025
4026    result->op     = GGML_OP_SOFT_MAX_BACK;
4027    result->src[0] = a;
4028    result->src[1] = b;
4029
4030    memcpy((float *) result->op_params + 0, &scale,    sizeof(float));
4031    memcpy((float *) result->op_params + 1, &max_bias, sizeof(float));
4032
4033    return result;
4034}
4035
4036struct ggml_tensor * ggml_soft_max_ext_back(
4037        struct ggml_context * ctx,
4038        struct ggml_tensor  * a,
4039        struct ggml_tensor  * b,
4040        float                 scale,
4041        float                 max_bias) {
4042    return ggml_soft_max_ext_back_impl(ctx, a, b, scale, max_bias, false);
4043}
4044
4045struct ggml_tensor * ggml_soft_max_ext_back_inplace(
4046        struct ggml_context * ctx,
4047        struct ggml_tensor  * a,
4048        struct ggml_tensor  * b,
4049        float                 scale,
4050        float                 max_bias) {
4051    return ggml_soft_max_ext_back_impl(ctx, a, b, scale, max_bias, true);
4052}
4053
4054// ggml_rope
4055
4056static struct ggml_tensor * ggml_rope_impl(
4057        struct ggml_context * ctx,
4058        struct ggml_tensor  * a,
4059        struct ggml_tensor  * b,
4060        struct ggml_tensor  * c,
4061        int                   n_dims,
4062        int                   sections[GGML_MROPE_SECTIONS],
4063        int                   mode,
4064        int                   n_ctx_orig,
4065        float                 freq_base,
4066        float                 freq_scale,
4067        float                 ext_factor,
4068        float                 attn_factor,
4069        float                 beta_fast,
4070        float                 beta_slow,
4071        bool                  inplace) {
4072    GGML_ASSERT((mode & 1) == 0 && "mode & 1 == 1 is no longer supported");
4073
4074    GGML_ASSERT(ggml_is_vector(b));
4075    GGML_ASSERT(b->type == GGML_TYPE_I32);
4076
4077    bool mrope_used = mode & GGML_ROPE_TYPE_MROPE;
4078    if (mrope_used) {
4079        GGML_ASSERT(a->ne[2] * 4 == b->ne[0]); // mrope expecting 4 position ids per token
4080    } else {
4081        GGML_ASSERT(a->ne[2] == b->ne[0]);
4082    }
4083
4084    if (c) {
4085        GGML_ASSERT(c->type == GGML_TYPE_F32);
4086        GGML_ASSERT(c->ne[0] >= n_dims / 2);
4087    }
4088
4089    struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
4090
4091    int32_t params[15] = { /*n_past*/ 0, n_dims, mode, /*n_ctx*/ 0, n_ctx_orig };
4092    memcpy(params +  5, &freq_base,    sizeof(float));
4093    memcpy(params +  6, &freq_scale,   sizeof(float));
4094    memcpy(params +  7, &ext_factor,   sizeof(float));
4095    memcpy(params +  8, &attn_factor,  sizeof(float));
4096    memcpy(params +  9, &beta_fast,    sizeof(float));
4097    memcpy(params + 10, &beta_slow,    sizeof(float));
4098    if (mrope_used && sections) {
4099        memcpy(params + 11, sections,  sizeof(int32_t) * GGML_MROPE_SECTIONS);
4100    } else {
4101        memset(params + 11, 0,         sizeof(int32_t) * GGML_MROPE_SECTIONS);
4102    }
4103    ggml_set_op_params(result, params, sizeof(params));
4104
4105    result->op     = GGML_OP_ROPE;
4106    result->src[0] = a;
4107    result->src[1] = b;
4108    result->src[2] = c;
4109
4110    return result;
4111}
4112
4113struct ggml_tensor * ggml_rope(
4114        struct ggml_context * ctx,
4115        struct ggml_tensor  * a,
4116        struct ggml_tensor  * b,
4117        int                   n_dims,
4118        int                   mode) {
4119    return ggml_rope_impl(
4120        ctx, a, b, NULL, n_dims, NULL, mode, 0, 10000.0f, 1.0f, 0.0f, 1.0f, 0.0f, 0.0f, false
4121    );
4122}
4123
4124struct ggml_tensor * ggml_rope_multi(
4125        struct ggml_context * ctx,
4126        struct ggml_tensor  * a,
4127        struct ggml_tensor  * b,
4128        struct ggml_tensor  * c,
4129        int                   n_dims,
4130        int                   sections[GGML_MROPE_SECTIONS],
4131        int                   mode,
4132        int                   n_ctx_orig,
4133        float                 freq_base,
4134        float                 freq_scale,
4135        float                 ext_factor,
4136        float                 attn_factor,
4137        float                 beta_fast,
4138        float                 beta_slow) {
4139    return ggml_rope_impl(
4140        ctx, a, b, c, n_dims, sections, mode, n_ctx_orig, freq_base, freq_scale,
4141        ext_factor, attn_factor, beta_fast, beta_slow, false
4142    );
4143}
4144
4145struct ggml_tensor * ggml_rope_multi_inplace(
4146        struct ggml_context * ctx,
4147        struct ggml_tensor  * a,
4148        struct ggml_tensor  * b,
4149        struct ggml_tensor  * c,
4150        int                   n_dims,
4151        int                   sections[GGML_MROPE_SECTIONS],
4152        int                   mode,
4153        int                   n_ctx_orig,
4154        float                 freq_base,
4155        float                 freq_scale,
4156        float                 ext_factor,
4157        float                 attn_factor,
4158        float                 beta_fast,
4159        float                 beta_slow) {
4160    return ggml_rope_impl(
4161        ctx, a, b, c, n_dims, sections, mode, n_ctx_orig, freq_base, freq_scale,
4162        ext_factor, attn_factor, beta_fast, beta_slow, true
4163    );
4164}
4165
4166struct ggml_tensor * ggml_rope_inplace(
4167        struct ggml_context * ctx,
4168        struct ggml_tensor  * a,
4169        struct ggml_tensor  * b,
4170        int                   n_dims,
4171        int                   mode) {
4172    return ggml_rope_impl(
4173        ctx, a, b, NULL, n_dims, NULL, mode, 0, 10000.0f, 1.0f, 0.0f, 1.0f, 0.0f, 0.0f, true
4174    );
4175}
4176
4177struct ggml_tensor * ggml_rope_ext(
4178        struct ggml_context * ctx,
4179        struct ggml_tensor  * a,
4180        struct ggml_tensor  * b,
4181        struct ggml_tensor  * c,
4182        int                   n_dims,
4183        int                   mode,
4184        int                   n_ctx_orig,
4185        float                 freq_base,
4186        float                 freq_scale,
4187        float                 ext_factor,
4188        float                 attn_factor,
4189        float                 beta_fast,
4190        float                 beta_slow) {
4191    return ggml_rope_impl(
4192        ctx, a, b, c, n_dims, NULL, mode, n_ctx_orig, freq_base, freq_scale,
4193        ext_factor, attn_factor, beta_fast, beta_slow, false
4194    );
4195}
4196
4197struct ggml_tensor * ggml_rope_ext_inplace(
4198        struct ggml_context * ctx,
4199        struct ggml_tensor  * a,
4200        struct ggml_tensor  * b,
4201        struct ggml_tensor  * c,
4202        int                   n_dims,
4203        int                   mode,
4204        int                   n_ctx_orig,
4205        float                 freq_base,
4206        float                 freq_scale,
4207        float                 ext_factor,
4208        float                 attn_factor,
4209        float                 beta_fast,
4210        float                 beta_slow) {
4211    return ggml_rope_impl(
4212        ctx, a, b, c, n_dims, NULL, mode, n_ctx_orig, freq_base, freq_scale,
4213        ext_factor, attn_factor, beta_fast, beta_slow, true
4214    );
4215}
4216
4217struct ggml_tensor * ggml_rope_custom(
4218        struct ggml_context * ctx,
4219        struct ggml_tensor  * a,
4220        struct ggml_tensor  * b,
4221        int                   n_dims,
4222        int                   mode,
4223        int                   n_ctx_orig,
4224        float                 freq_base,
4225        float                 freq_scale,
4226        float                 ext_factor,
4227        float                 attn_factor,
4228        float                 beta_fast,
4229        float                 beta_slow) {
4230    return ggml_rope_impl(
4231        ctx, a, b, NULL, n_dims, NULL, mode, n_ctx_orig, freq_base, freq_scale,
4232        ext_factor, attn_factor, beta_fast, beta_slow, false
4233    );
4234}
4235
4236struct ggml_tensor * ggml_rope_custom_inplace(
4237        struct ggml_context * ctx,
4238        struct ggml_tensor  * a,
4239        struct ggml_tensor  * b,
4240        int                   n_dims,
4241        int                   mode,
4242        int                   n_ctx_orig,
4243        float                 freq_base,
4244        float                 freq_scale,
4245        float                 ext_factor,
4246        float                 attn_factor,
4247        float                 beta_fast,
4248        float                 beta_slow) {
4249    return ggml_rope_impl(
4250        ctx, a, b, NULL, n_dims, NULL, mode, n_ctx_orig, freq_base, freq_scale,
4251        ext_factor, attn_factor, beta_fast, beta_slow, true
4252    );
4253}
4254
4255// Apparently solving `n_rot = 2pi * x * base^((2 * max_pos_emb) / n_dims)` for x, we get
4256// `corr_dim(n_rot) = n_dims * log(max_pos_emb / (n_rot * 2pi)) / (2 * log(base))`
4257static float ggml_rope_yarn_corr_dim(int n_dims, int n_ctx_orig, float n_rot, float base) {
4258    return n_dims * logf(n_ctx_orig / (n_rot * 2 * (float)M_PI)) / (2 * logf(base));
4259}
4260
4261void ggml_rope_yarn_corr_dims(
4262    int n_dims, int n_ctx_orig, float freq_base, float beta_fast, float beta_slow, float dims[2]
4263) {
4264    // start and end correction dims
4265    float start = floorf(ggml_rope_yarn_corr_dim(n_dims, n_ctx_orig, beta_fast, freq_base));
4266    float end   =  ceilf(ggml_rope_yarn_corr_dim(n_dims, n_ctx_orig, beta_slow, freq_base));
4267    dims[0] = MAX(0, start);
4268    dims[1] = MIN(n_dims - 1, end);
4269}
4270
4271// ggml_rope_back
4272
4273struct ggml_tensor * ggml_rope_ext_back(
4274        struct ggml_context * ctx,
4275        struct ggml_tensor  * a,
4276        struct ggml_tensor  * b,
4277        struct ggml_tensor  * c,
4278        int                   n_dims,
4279        int                   mode,
4280        int                   n_ctx_orig,
4281        float                 freq_base,
4282        float                 freq_scale,
4283        float                 ext_factor,
4284        float                 attn_factor,
4285        float                 beta_fast,
4286        float                 beta_slow) {
4287    struct ggml_tensor * result = ggml_rope_ext(
4288        ctx, a, b, c, n_dims, mode, n_ctx_orig, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow);
4289    result->op = GGML_OP_ROPE_BACK;
4290    return result;
4291}
4292
4293struct ggml_tensor * ggml_rope_multi_back(
4294        struct ggml_context * ctx,
4295        struct ggml_tensor  * a,
4296        struct ggml_tensor  * b,
4297        struct ggml_tensor  * c,
4298        int                   n_dims,
4299        int                   sections[4],
4300        int                   mode,
4301        int                   n_ctx_orig,
4302        float                 freq_base,
4303        float                 freq_scale,
4304        float                 ext_factor,
4305        float                 attn_factor,
4306        float                 beta_fast,
4307        float                 beta_slow) {
4308    struct ggml_tensor * result = ggml_rope_multi(
4309        ctx, a, b, c, n_dims, sections, mode, n_ctx_orig, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow);
4310    result->op = GGML_OP_ROPE_BACK;
4311    return result;
4312}
4313// ggml_clamp
4314
4315struct ggml_tensor * ggml_clamp(
4316        struct ggml_context * ctx,
4317        struct ggml_tensor  * a,
4318        float                 min,
4319        float                 max) {
4320    // TODO: when implement backward, fix this:
4321    struct ggml_tensor * result = ggml_view_tensor(ctx, a);
4322
4323    float params[] = { min, max };
4324    ggml_set_op_params(result, params, sizeof(params));
4325
4326    result->op     = GGML_OP_CLAMP;
4327    result->src[0] = a;
4328
4329    return result;
4330}
4331
4332static int64_t ggml_calc_conv_output_size(int64_t ins, int64_t ks, int s, int p, int d) {
4333    return (ins + 2 * p - d * (ks - 1) - 1) / s + 1;
4334}
4335
4336// im2col: [N, IC, IH, IW] => [N, OH, OW, IC*KH*KW]
4337// a: [OC๏ผŒIC, KH, KW]
4338// b: [N, IC, IH, IW]
4339// result: [N, OH, OW, IC*KH*KW]
4340struct ggml_tensor * ggml_im2col(
4341        struct ggml_context * ctx,
4342        struct ggml_tensor  * a,
4343        struct ggml_tensor  * b,
4344        int                   s0,
4345        int                   s1,
4346        int                   p0,
4347        int                   p1,
4348        int                   d0,
4349        int                   d1,
4350        bool                  is_2D,
4351        enum ggml_type        dst_type) {
4352    if (is_2D) {
4353        GGML_ASSERT(a->ne[2] == b->ne[2]);
4354    } else {
4355        //GGML_ASSERT(b->ne[1] % a->ne[1] == 0);
4356        GGML_ASSERT(b->ne[1] == a->ne[1]);
4357        GGML_ASSERT(b->ne[3] == 1);
4358    }
4359
4360    const int64_t OH = is_2D ? ggml_calc_conv_output_size(b->ne[1], a->ne[1], s1, p1, d1) : 0;
4361    const int64_t OW =         ggml_calc_conv_output_size(b->ne[0], a->ne[0], s0, p0, d0);
4362
4363    GGML_ASSERT((!is_2D || OH > 0) && "b too small compared to a");
4364    GGML_ASSERT((OW > 0)           && "b too small compared to a");
4365
4366    const int64_t ne[4] = {
4367        is_2D ? (a->ne[2] * a->ne[1] * a->ne[0]) : a->ne[1] * a->ne[0],
4368        OW,
4369        is_2D ? OH : b->ne[2],
4370        is_2D ?      b->ne[3] : 1,
4371    };
4372
4373    struct ggml_tensor * result = ggml_new_tensor(ctx, dst_type, 4, ne);
4374    int32_t params[] = { s0, s1, p0, p1, d0, d1, (is_2D ? 1 : 0) };
4375    ggml_set_op_params(result, params, sizeof(params));
4376
4377    result->op     = GGML_OP_IM2COL;
4378    result->src[0] = a;
4379    result->src[1] = b;
4380
4381    return result;
4382}
4383
4384struct ggml_tensor * ggml_im2col_back(
4385        struct ggml_context * ctx,
4386        struct ggml_tensor  * a,
4387        struct ggml_tensor  * b,
4388        int64_t             * ne,
4389        int                   s0,
4390        int                   s1,
4391        int                   p0,
4392        int                   p1,
4393        int                   d0,
4394        int                   d1,
4395        bool                  is_2D) {
4396    struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne);
4397    int32_t params[] = { s0, s1, p0, p1, d0, d1, (is_2D ? 1 : 0) };
4398    ggml_set_op_params(result, params, sizeof(params));
4399
4400    result->op     = GGML_OP_IM2COL_BACK;
4401    result->src[0] = a;
4402    result->src[1] = b;
4403
4404    return result;
4405}
4406
4407// ggml_conv_1d
4408
4409struct ggml_tensor * ggml_conv_1d(
4410        struct ggml_context * ctx,
4411        struct ggml_tensor  * a,
4412        struct ggml_tensor  * b,
4413        int                   s0,
4414        int                   p0,
4415        int                   d0) {
4416    struct ggml_tensor * im2col = ggml_im2col(ctx, a, b, s0, 0, p0, 0, d0, 0, false, GGML_TYPE_F16); // [N, OL, IC * K]
4417
4418    struct ggml_tensor * result =
4419        ggml_mul_mat(ctx,
4420                ggml_reshape_2d(ctx, im2col, im2col->ne[0], (im2col->ne[2] * im2col->ne[1])), // [N, OL, IC * K] => [N*OL, IC * K]
4421                ggml_reshape_2d(ctx, a, (a->ne[0] * a->ne[1]), a->ne[2]));                    // [OC๏ผŒIC, K] => [OC, IC * K]
4422
4423    result = ggml_reshape_3d(ctx, result, im2col->ne[1], a->ne[2], im2col->ne[2]); // [N, OC, OL]
4424
4425    return result;
4426}
4427
4428// ggml_conv_1d_ph
4429
4430struct ggml_tensor* ggml_conv_1d_ph(
4431        struct ggml_context * ctx,
4432        struct ggml_tensor  * a,
4433        struct ggml_tensor  * b,
4434        int                   s,
4435        int                   d) {
4436    return ggml_conv_1d(ctx, a, b, s, a->ne[0] / 2, d);
4437}
4438
4439// ggml_conv_1d_dw
4440
4441struct ggml_tensor * ggml_conv_1d_dw(
4442        struct ggml_context * ctx,
4443        struct ggml_tensor  * a,
4444        struct ggml_tensor  * b,
4445        int                   s0,
4446        int                   p0,
4447        int                   d0) {
4448    struct ggml_tensor * new_b = ggml_reshape_4d(ctx, b, b->ne[0], 1, b->ne[1], b->ne[2]);
4449
4450    struct ggml_tensor * im2col = ggml_im2col(ctx, a, new_b, s0, 0, p0, 0, d0, 0, false, GGML_TYPE_F16);
4451
4452    struct ggml_tensor * result = ggml_mul_mat(ctx, im2col, a);
4453
4454    result = ggml_reshape_3d(ctx, result, result->ne[0], result->ne[2], 1);
4455
4456    return result;
4457}
4458
4459// ggml_conv_1d_dw_ph
4460
4461struct ggml_tensor * ggml_conv_1d_dw_ph(
4462        struct ggml_context * ctx,
4463        struct ggml_tensor  * a,
4464        struct ggml_tensor  * b,
4465        int                   s0,
4466        int                   d0) {
4467    return ggml_conv_1d_dw(ctx, a, b, s0, a->ne[0] / 2, d0);
4468}
4469
4470// ggml_conv_transpose_1d
4471
4472static int64_t ggml_calc_conv_transpose_1d_output_size(int64_t ins, int64_t ks, int s, int p, int d) {
4473    return (ins - 1) * s - 2 * p + d * (ks - 1) + 1;
4474}
4475
4476GGML_API struct ggml_tensor * ggml_conv_transpose_1d(
4477        struct ggml_context * ctx,
4478        struct ggml_tensor  * a,
4479        struct ggml_tensor  * b,
4480        int                   s0,
4481        int                   p0,
4482        int                   d0) {
4483    GGML_ASSERT(ggml_is_matrix(b));
4484    GGML_ASSERT(a->ne[2] == b->ne[1]);
4485    GGML_ASSERT(a->ne[3] == 1);
4486
4487    GGML_ASSERT(p0 == 0);
4488    GGML_ASSERT(d0 == 1);
4489
4490    const int64_t ne[4] = {
4491        ggml_calc_conv_transpose_1d_output_size(b->ne[0], a->ne[0], s0, 0 /*p0*/, 1 /*d0*/),
4492        a->ne[1], b->ne[2], 1,
4493    };
4494    struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne);
4495
4496    int32_t params[] = { s0, p0, d0 };
4497    ggml_set_op_params(result, params, sizeof(params));
4498
4499    result->op     = GGML_OP_CONV_TRANSPOSE_1D;
4500    result->src[0] = a;
4501    result->src[1] = b;
4502
4503    return result;
4504}
4505
4506// ggml_conv_2d
4507
4508// a: [OC๏ผŒIC, KH, KW]
4509// b: [N, IC, IH, IW]
4510// result: [N, OC, OH, OW]
4511struct ggml_tensor * ggml_conv_2d(
4512        struct ggml_context * ctx,
4513        struct ggml_tensor  * a,
4514        struct ggml_tensor  * b,
4515        int                   s0,
4516        int                   s1,
4517        int                   p0,
4518        int                   p1,
4519        int                   d0,
4520        int                   d1) {
4521    struct ggml_tensor * im2col = ggml_im2col(ctx, a, b, s0, s1, p0, p1, d0, d1, true, a->type); // [N, OH, OW, IC * KH * KW]
4522
4523    struct ggml_tensor * result =
4524        ggml_mul_mat(ctx,
4525                ggml_reshape_2d(ctx, im2col, im2col->ne[0],  im2col->ne[3] * im2col->ne[2] * im2col->ne[1]), // [N, OH, OW, IC * KH * KW] => [N*OH*OW, IC * KH * KW]
4526                ggml_reshape_2d(ctx, a, (a->ne[0] * a->ne[1] * a->ne[2]),  a->ne[3]));                       // [OC๏ผŒIC, KH, KW] => [OC, IC * KH * KW]
4527
4528    result = ggml_reshape_4d(ctx, result, im2col->ne[1], im2col->ne[2], im2col->ne[3], a->ne[3]); // [OC, N, OH, OW]
4529    result = ggml_cont(ctx, ggml_permute(ctx, result, 0, 1, 3, 2)); // [N, OC, OH, OW]
4530
4531
4532    return result;
4533}
4534
4535// a: [OC*IC, KD, KH, KW]
4536// b: [N*IC, ID, IH, IW]
4537// result: [N*OD, OH, OW, IC * KD * KH * KW]
4538struct ggml_tensor * ggml_im2col_3d(
4539        struct ggml_context * ctx,
4540        struct ggml_tensor  * a,
4541        struct ggml_tensor  * b,
4542        int64_t               IC,
4543        int                   s0, // stride width
4544        int                   s1, // stride height
4545        int                   s2, // stride depth
4546        int                   p0, // padding width
4547        int                   p1, // padding height
4548        int                   p2, // padding depth
4549        int                   d0, // dilation width
4550        int                   d1, // dilation height
4551        int                   d2, // dilation depth
4552        enum ggml_type        dst_type) {
4553    const int64_t N = b->ne[3] / IC;
4554    const int64_t ID = b->ne[2];
4555    const int64_t IH = b->ne[1];
4556    const int64_t IW = b->ne[0];
4557
4558    const int64_t OC = a->ne[3] / IC;
4559    UNUSED(OC);
4560    const int64_t KD = a->ne[2];
4561    const int64_t KH = a->ne[1];
4562    const int64_t KW = a->ne[0];
4563    const int64_t OD = ggml_calc_conv_output_size(ID, KD, s2, p2, d2);
4564    const int64_t OH = ggml_calc_conv_output_size(IH, KH, s1, p1, d1);
4565    const int64_t OW = ggml_calc_conv_output_size(IW, KW, s0, p0, d0);
4566
4567    GGML_ASSERT((OD > 0)  && "b too small compared to a");
4568    GGML_ASSERT((OH > 0)  && "b too small compared to a");
4569    GGML_ASSERT((OW > 0)  && "b too small compared to a");
4570
4571
4572    const int64_t ne[4] = {KW*KH*KD*IC, OW, OH, OD*N};
4573
4574    struct ggml_tensor * result = ggml_new_tensor(ctx, dst_type, 4, ne);
4575    int32_t params[] = { s0, s1, s2, p0, p1, p2, d0, d1, d2, (int32_t)IC};
4576    ggml_set_op_params(result, params, sizeof(params));
4577
4578    result->op     = GGML_OP_IM2COL_3D;
4579    result->src[0] = a;
4580    result->src[1] = b;
4581
4582    return result;
4583}
4584
4585// a: [OC*IC, KD, KH, KW]
4586// b: [N*IC, ID, IH, IW]
4587// result: [N*OC, OD, OH, OW]
4588struct ggml_tensor * ggml_conv_3d(
4589        struct ggml_context * ctx,
4590        struct ggml_tensor  * a,
4591        struct ggml_tensor  * b,
4592        int64_t               IC,
4593        int                   s0, // stride width
4594        int                   s1, // stride height
4595        int                   s2, // stride depth
4596        int                   p0, // padding width
4597        int                   p1, // padding height
4598        int                   p2, // padding depth
4599        int                   d0, // dilation width
4600        int                   d1, // dilation height
4601        int                   d2  // dilation depth
4602        ) {
4603    struct ggml_tensor * im2col = ggml_im2col_3d(ctx, a, b, IC, s0, s1, s2, p0, p1, p2, d0, d1, d2, a->type); // [N*OD, OH, OW, IC * KD * KH * KW]
4604
4605    int64_t OC = a->ne[3] / IC;
4606    int64_t N = b->ne[3] / IC;
4607    struct ggml_tensor * result =
4608        ggml_mul_mat(ctx,
4609                ggml_reshape_2d(ctx, im2col, im2col->ne[0], im2col->ne[3] * im2col->ne[2] * im2col->ne[1]), // [N*OD, OH, OW, IC * KD * KH * KW] => [N*OD*OH*OW, IC * KD * KH * KW]
4610                ggml_reshape_2d(ctx, a, (a->ne[0] * a->ne[1] * a->ne[2] * IC), OC));                          // [OC*IC, KD, KH, KW] => [OC, IC * KD * KH * KW]
4611
4612    int64_t OD = im2col->ne[3] / N;
4613    result = ggml_reshape_4d(ctx, result, im2col->ne[1]*im2col->ne[2], OD, N, OC); // [OC, N*OD*OH*OW] => [OC, N, OD, OH*OW]
4614    result = ggml_cont(ctx, ggml_permute(ctx, result, 0, 1, 3, 2)); // [N, OC, OD, OH*OW]
4615    result = ggml_reshape_4d(ctx, result, im2col->ne[1], im2col->ne[2], OD, OC * N); // [N*OC, OD, OH, OW]
4616
4617    return result;
4618}
4619
4620// ggml_conv_2d_sk_p0
4621
4622struct ggml_tensor * ggml_conv_2d_sk_p0(
4623        struct ggml_context * ctx,
4624        struct ggml_tensor  * a,
4625        struct ggml_tensor  * b) {
4626    return ggml_conv_2d(ctx, a, b, a->ne[0], a->ne[1], 0, 0, 1, 1);
4627}
4628
4629// ggml_conv_2d_s1_ph
4630
4631struct ggml_tensor * ggml_conv_2d_s1_ph(
4632        struct ggml_context * ctx,
4633        struct ggml_tensor  * a,
4634        struct ggml_tensor  * b) {
4635    return ggml_conv_2d(ctx, a, b, 1, 1, a->ne[0] / 2, a->ne[1] / 2, 1, 1);
4636}
4637
4638// ggml_conv_2d_dw
4639
4640struct ggml_tensor * ggml_conv_2d_dw(
4641        struct ggml_context * ctx,
4642        struct ggml_tensor  * a,
4643        struct ggml_tensor  * b,
4644        int                   s0,
4645        int                   s1,
4646        int                   p0,
4647        int                   p1,
4648        int                   d0,
4649        int                   d1) {
4650    struct ggml_tensor * new_a = ggml_reshape_4d(ctx, a, a->ne[0], a->ne[1], 1, a->ne[2] * a->ne[3]);
4651    struct ggml_tensor * im2col = ggml_im2col(ctx, new_a,
4652                                        ggml_reshape_4d(ctx, b, b->ne[0], b->ne[1], 1, b->ne[2] * b->ne[3]),
4653                                        s0, s1, p0, p1, d0, d1, true, GGML_TYPE_F16); // [N * IC, OH, OW, KH * KW]
4654    struct ggml_tensor * new_b = ggml_reshape_4d(ctx, im2col, im2col->ne[0], im2col->ne[2] * im2col->ne[1], b->ne[2], b->ne[3]); // [N * IC, OH, OW, KH * KW] => [N, IC, OH * OW, KH * KW]
4655
4656    new_a = ggml_reshape_4d(ctx, new_a, (new_a->ne[0] * new_a->ne[1]), new_a->ne[2],  new_a->ne[3], 1);                       // [OC๏ผŒ1, KH, KW] => [1, OC, 1, KH * KW]
4657    struct ggml_tensor * result = ggml_mul_mat(ctx, new_a, new_b);
4658    result = ggml_reshape_4d(ctx, result, im2col->ne[1], im2col->ne[2], b->ne[2], b->ne[3]); // [N, OC, OH, OW]
4659
4660    return result;
4661}
4662
4663// ggml_conv_2d_dw_direct
4664
4665struct ggml_tensor * ggml_conv_2d_dw_direct(
4666        struct ggml_context * ctx,
4667        struct ggml_tensor  * a,
4668        struct ggml_tensor  * b,
4669        int                   stride0,
4670        int                   stride1,
4671        int                   pad0,
4672        int                   pad1,
4673        int                   dilation0,
4674        int                   dilation1) {
4675    GGML_ASSERT(a->ne[2] == 1);
4676    GGML_ASSERT(a->ne[3] == b->ne[2]);
4677    int64_t ne[4];
4678    ne[0] = ggml_calc_conv_output_size(b->ne[0], a->ne[0], stride0, pad0, dilation0);
4679    ne[1] = ggml_calc_conv_output_size(b->ne[1], a->ne[1], stride1, pad1, dilation1);
4680    ne[2] = b->ne[2];
4681    ne[3] = b->ne[3];
4682
4683    struct ggml_tensor * result = ggml_new_tensor(ctx, b->type, 4, ne);
4684
4685    if (ggml_is_contiguous_channels(b)) {
4686        // Result will be permuted the same way as input (CWHN order)
4687        const int64_t type_size = ggml_type_size(result->type);
4688        GGML_ASSERT(ggml_blck_size(result->type) == 1);
4689        result->nb[0] = result->ne[2] * type_size;
4690        result->nb[1] = result->ne[0] * result->nb[0];
4691        result->nb[2] = type_size;
4692    }
4693
4694    int32_t params[] = { stride0, stride1, pad0, pad1, dilation0, dilation1 };
4695    ggml_set_op_params(result, params, sizeof(params));
4696
4697    result->op     = GGML_OP_CONV_2D_DW;
4698    result->src[0] = a;
4699    result->src[1] = b;
4700    return result;
4701}
4702
4703// ggml_conv_2d_direct
4704
4705struct ggml_tensor * ggml_conv_2d_direct(
4706        struct ggml_context * ctx,
4707        struct ggml_tensor  * a,   // convolution kernel [KW, KH, IC, OC]
4708        struct ggml_tensor  * b,   // input data [W, H, C, N]
4709        int                   s0,  // stride dimension 0
4710        int                   s1,  // stride dimension 1
4711        int                   p0,  // padding dimension 0
4712        int                   p1,  // padding dimension 1
4713        int                   d0,  // dilation dimension 0
4714        int                   d1) {// dilation dimension 1
4715
4716    GGML_ASSERT(a->ne[2] == b->ne[2]);
4717    //GGML_ASSERT(a->type == b->type);
4718
4719    int64_t ne[4];
4720    ne[0] = ggml_calc_conv_output_size(b->ne[0], a->ne[0], s0, p0, d0);
4721    ne[1] = ggml_calc_conv_output_size(b->ne[1], a->ne[1], s1, p1, d1);
4722    ne[2] = a->ne[3];
4723    ne[3] = b->ne[3];
4724
4725    struct ggml_tensor * result = ggml_new_tensor(ctx, b->type, 4, ne);
4726
4727    ggml_set_op_params_i32(result, 0, s0);
4728    ggml_set_op_params_i32(result, 1, s1);
4729    ggml_set_op_params_i32(result, 2, p0);
4730    ggml_set_op_params_i32(result, 3, p1);
4731    ggml_set_op_params_i32(result, 4, d0);
4732    ggml_set_op_params_i32(result, 5, d1);
4733
4734    result->op = GGML_OP_CONV_2D;
4735    result->src[0] = a;
4736    result->src[1] = b;
4737
4738    return result;
4739}
4740
4741// ggml_conv_3d_direct
4742
4743struct ggml_tensor * ggml_conv_3d_direct(
4744        struct ggml_context * ctx,
4745        struct ggml_tensor  * a,
4746        struct ggml_tensor  * b,
4747        int                   s0,
4748        int                   s1,
4749        int                   s2,
4750        int                   p0,
4751        int                   p1,
4752        int                   p2,
4753        int                   d0,
4754        int                   d1,
4755        int                   d2,
4756        int                   c,
4757        int                   n,
4758        int                   oc) {
4759
4760    GGML_ASSERT(a->ne[3] == (int64_t) c * oc);
4761    GGML_ASSERT(b->ne[3] == (int64_t) c * n);
4762
4763    int64_t ne[4];
4764    ne[0] = ggml_calc_conv_output_size(b->ne[0], a->ne[0], s0, p0, d0);
4765    ne[1] = ggml_calc_conv_output_size(b->ne[1], a->ne[1], s1, p1, d1);
4766    ne[2] = ggml_calc_conv_output_size(b->ne[2], a->ne[2], s2, p2, d2);
4767    ne[3] = (int64_t) oc * n;
4768
4769    struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne);
4770
4771    ggml_set_op_params_i32(result, 0,  s0);
4772    ggml_set_op_params_i32(result, 1,  s1);
4773    ggml_set_op_params_i32(result, 2,  s2);
4774    ggml_set_op_params_i32(result, 3,  p0);
4775    ggml_set_op_params_i32(result, 4,  p1);
4776    ggml_set_op_params_i32(result, 5,  p2);
4777    ggml_set_op_params_i32(result, 6,  d0);
4778    ggml_set_op_params_i32(result, 7,  d1);
4779    ggml_set_op_params_i32(result, 8,  d2);
4780    ggml_set_op_params_i32(result, 9,  c);
4781    ggml_set_op_params_i32(result, 10, n);
4782    ggml_set_op_params_i32(result, 11, oc);
4783
4784    result->op = GGML_OP_CONV_3D;
4785    result->src[0] = a;
4786    result->src[1] = b;
4787
4788    return result;
4789}
4790
4791// ggml_conv_transpose_2d_p0
4792
4793static int64_t ggml_calc_conv_transpose_output_size(int64_t ins, int64_t ks, int s, int p) {
4794    return (ins - 1) * s - 2 * p + ks;
4795}
4796
4797struct ggml_tensor * ggml_conv_transpose_2d_p0(
4798        struct ggml_context * ctx,
4799        struct ggml_tensor  * a,
4800        struct ggml_tensor  * b,
4801        int                   stride) {
4802    GGML_ASSERT(a->ne[3] == b->ne[2]);
4803
4804    const int64_t ne[4] = {
4805        ggml_calc_conv_transpose_output_size(b->ne[0], a->ne[0], stride, 0 /*p0*/),
4806        ggml_calc_conv_transpose_output_size(b->ne[1], a->ne[1], stride, 0 /*p1*/),
4807        a->ne[2], b->ne[3],
4808    };
4809
4810    struct ggml_tensor* result = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne);
4811
4812    ggml_set_op_params_i32(result, 0, stride);
4813
4814    result->op     = GGML_OP_CONV_TRANSPOSE_2D;
4815    result->src[0] = a;
4816    result->src[1] = b;
4817
4818    return result;
4819}
4820
4821// ggml_pool_*
4822
4823static int64_t ggml_calc_pool_output_size(int64_t ins, int ks, int s, float p) {
4824    return (ins + 2 * p - ks) / s + 1;
4825}
4826
4827// ggml_pool_1d
4828
4829struct ggml_tensor * ggml_pool_1d(
4830        struct ggml_context * ctx,
4831        struct ggml_tensor  * a,
4832        enum ggml_op_pool     op,
4833        int                   k0,
4834        int                   s0,
4835        int                   p0) {
4836    const int64_t ne[4] = {
4837        ggml_calc_pool_output_size(a->ne[0], k0, s0, p0),
4838        a->ne[1],
4839        a->ne[2],
4840        a->ne[3],
4841    };
4842    GGML_ASSERT(ne[0] > 0);
4843
4844    struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne);
4845
4846    int32_t params[] = { op, k0, s0, p0 };
4847    ggml_set_op_params(result, params, sizeof(params));
4848
4849    result->op     = GGML_OP_POOL_1D;
4850    result->src[0] = a;
4851
4852    return result;
4853}
4854
4855// ggml_pool_2d
4856
4857struct ggml_tensor * ggml_pool_2d(
4858        struct ggml_context * ctx,
4859        struct ggml_tensor  * a,
4860        enum ggml_op_pool     op,
4861        int                   k0,
4862        int                   k1,
4863        int                   s0,
4864        int                   s1,
4865        float                 p0,
4866        float                 p1) {
4867    struct ggml_tensor * result;
4868    const int64_t ne[4] = {
4869        ggml_calc_pool_output_size(a->ne[0], k0, s0, p0),
4870        ggml_calc_pool_output_size(a->ne[1], k1, s1, p1),
4871        a->ne[2],
4872        a->ne[3],
4873    };
4874    GGML_ASSERT(ne[0] > 0);
4875    GGML_ASSERT(ne[1] > 0);
4876
4877    result = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne);
4878
4879    int32_t params[] = { op, k0, k1, s0, s1, p0, p1 };
4880    ggml_set_op_params(result, params, sizeof(params));
4881
4882    result->op     = GGML_OP_POOL_2D;
4883    result->src[0] = a;
4884
4885    return result;
4886}
4887
4888struct ggml_tensor * ggml_pool_2d_back(
4889        struct ggml_context * ctx,
4890        struct ggml_tensor  * a,
4891        struct ggml_tensor  * af,
4892        enum ggml_op_pool     op,
4893        int                   k0,
4894        int                   k1,
4895        int                   s0,
4896        int                   s1,
4897        float                 p0,
4898        float                 p1) {
4899    struct ggml_tensor * result;
4900    result = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, af->ne);
4901
4902    int32_t params[] = { op, k0, k1, s0, s1, p0, p1 };
4903    ggml_set_op_params(result, params, sizeof(params));
4904
4905    result->op     = GGML_OP_POOL_2D_BACK;
4906    result->src[0] = a;
4907    result->src[1] = af;
4908
4909    return result;
4910}
4911
4912// ggml_upscale / ggml_interpolate
4913
4914static struct ggml_tensor * ggml_interpolate_impl(
4915        struct ggml_context * ctx,
4916        struct ggml_tensor  * a,
4917        int64_t               ne0,
4918        int64_t               ne1,
4919        int64_t               ne2,
4920        int64_t               ne3,
4921        uint32_t              mode) {
4922    GGML_ASSERT((mode & 0xFF) < GGML_SCALE_MODE_COUNT);
4923    // TODO: implement antialias for modes other than bilinear
4924    GGML_ASSERT(!(mode & GGML_SCALE_FLAG_ANTIALIAS) || (mode & 0xFF) == GGML_SCALE_MODE_BILINEAR);
4925
4926    struct ggml_tensor * result = ggml_new_tensor_4d(ctx, a->type, ne0, ne1, ne2, ne3);
4927
4928    ggml_set_op_params_i32(result, 0, (int32_t)mode);
4929
4930    result->op     = GGML_OP_UPSCALE;
4931    result->src[0] = a;
4932
4933    return result;
4934}
4935
4936struct ggml_tensor * ggml_upscale(
4937        struct ggml_context * ctx,
4938        struct ggml_tensor  * a,
4939        int                   scale_factor,
4940        enum ggml_scale_mode  mode) {
4941    GGML_ASSERT(scale_factor > 1);
4942    return ggml_interpolate_impl(ctx, a, a->ne[0] * scale_factor, a->ne[1] * scale_factor, a->ne[2], a->ne[3], mode);
4943}
4944
4945struct ggml_tensor * ggml_upscale_ext(
4946        struct ggml_context * ctx,
4947        struct ggml_tensor  * a,
4948        int                   ne0,
4949        int                   ne1,
4950        int                   ne2,
4951        int                   ne3,
4952        enum ggml_scale_mode  mode) {
4953    return ggml_interpolate_impl(ctx, a, ne0, ne1, ne2, ne3, mode);
4954}
4955
4956struct ggml_tensor * ggml_interpolate(
4957        struct ggml_context * ctx,
4958        struct ggml_tensor  * a,
4959        int64_t               ne0,
4960        int64_t               ne1,
4961        int64_t               ne2,
4962        int64_t               ne3,
4963        uint32_t              mode) {
4964    return ggml_interpolate_impl(ctx, a, ne0, ne1, ne2, ne3, mode);
4965}
4966
4967// ggml_pad
4968
4969struct ggml_tensor * ggml_pad(
4970        struct ggml_context * ctx,
4971        struct ggml_tensor  * a,
4972        int                   p0,
4973        int                   p1,
4974        int                   p2,
4975        int                   p3) {
4976    return ggml_pad_ext(ctx, a, 0, p0, 0, p1, 0, p2, 0, p3);
4977}
4978
4979// ggml_pad_circular
4980
4981struct ggml_tensor * ggml_pad_circular(
4982        struct ggml_context * ctx,
4983        struct ggml_tensor  * a,
4984        int                   p0,
4985        int                   p1,
4986        int                   p2,
4987        int                   p3) {
4988    return ggml_pad_ext_circular(ctx, a, 0, p0, 0, p1, 0, p2, 0, p3);
4989}
4990
4991struct ggml_tensor * ggml_pad_ext(
4992            struct ggml_context * ctx,
4993            struct ggml_tensor  * a,
4994            int                  lp0,
4995            int                  rp0,
4996            int                  lp1,
4997            int                  rp1,
4998            int                  lp2,
4999            int                  rp2,
5000            int                  lp3,
5001            int                  rp3
5002            ) {
5003    struct ggml_tensor * result = ggml_new_tensor_4d(ctx, a->type,
5004            a->ne[0] + lp0 + rp0,
5005            a->ne[1] + lp1 + rp1,
5006            a->ne[2] + lp2 + rp2,
5007            a->ne[3] + lp3 + rp3);
5008
5009    ggml_set_op_params_i32(result, 0, lp0);
5010    ggml_set_op_params_i32(result, 1, rp0);
5011    ggml_set_op_params_i32(result, 2, lp1);
5012    ggml_set_op_params_i32(result, 3, rp1);
5013    ggml_set_op_params_i32(result, 4, lp2);
5014    ggml_set_op_params_i32(result, 5, rp2);
5015    ggml_set_op_params_i32(result, 6, lp3);
5016    ggml_set_op_params_i32(result, 7, rp3);
5017    ggml_set_op_params_i32(result, 8, 0); // not circular by default
5018
5019
5020    result->op     = GGML_OP_PAD;
5021    result->src[0] = a;
5022
5023    return result;
5024}
5025
5026// ggml_pad_ext_circular
5027
5028struct ggml_tensor * ggml_pad_ext_circular(
5029        struct ggml_context * ctx,
5030        struct ggml_tensor  * a,
5031        int                  lp0,
5032        int                  rp0,
5033        int                  lp1,
5034        int                  rp1,
5035        int                  lp2,
5036        int                  rp2,
5037        int                  lp3,
5038        int                  rp3
5039        ) {
5040    struct ggml_tensor * result = ggml_pad_ext(ctx, a, lp0, rp0, lp1, rp1, lp2, rp2, lp3, rp3);
5041    ggml_set_op_params_i32(result, 8, 1); // circular
5042    return result;
5043}
5044
5045// ggml_pad_reflect_1d
5046
5047struct ggml_tensor * ggml_pad_reflect_1d(
5048        struct ggml_context * ctx,
5049        struct ggml_tensor  * a,
5050        int                   p0,
5051        int                   p1) {
5052    GGML_ASSERT(p0 >= 0);
5053    GGML_ASSERT(p1 >= 0);
5054
5055    GGML_ASSERT(p0 < a->ne[0]); // padding length on each size must be less than the
5056    GGML_ASSERT(p1 < a->ne[0]); // existing length of the dimension being padded
5057
5058    GGML_ASSERT(ggml_is_contiguous(a));
5059    GGML_ASSERT(a->type == GGML_TYPE_F32);
5060
5061    struct ggml_tensor * result = ggml_new_tensor_4d(ctx, a->type,
5062            a->ne[0] + p0 + p1,
5063            a->ne[1],
5064            a->ne[2],
5065            a->ne[3]);
5066
5067    int32_t params[] = { p0, p1 };
5068    ggml_set_op_params(result, params, sizeof(params));
5069
5070    result->op     = GGML_OP_PAD_REFLECT_1D;
5071    result->src[0] = a;
5072
5073    return result;
5074}
5075
5076// ggml_roll
5077
5078struct ggml_tensor * ggml_roll(
5079        struct ggml_context * ctx,
5080        struct ggml_tensor  * a,
5081        int                   shift0,
5082        int                   shift1,
5083        int                   shift2,
5084        int                   shift3) {
5085    GGML_ASSERT(a->nb[0] == ggml_type_size(a->type));
5086    GGML_ASSERT(abs(shift0) < a->ne[0]);
5087    GGML_ASSERT(abs(shift1) < a->ne[1]);
5088    GGML_ASSERT(abs(shift2) < a->ne[2]);
5089    GGML_ASSERT(abs(shift3) < a->ne[3]);
5090
5091    struct ggml_tensor * result = ggml_dup_tensor(ctx, a);
5092
5093    ggml_set_op_params_i32(result, 0, shift0);
5094    ggml_set_op_params_i32(result, 1, shift1);
5095    ggml_set_op_params_i32(result, 2, shift2);
5096    ggml_set_op_params_i32(result, 3, shift3);
5097
5098    result->op     = GGML_OP_ROLL;
5099    result->src[0] = a;
5100
5101    return result;
5102}
5103
5104// ggml_timestep_embedding
5105
5106struct ggml_tensor * ggml_timestep_embedding(
5107        struct ggml_context * ctx,
5108        struct ggml_tensor  * timesteps,
5109        int                   dim,
5110        int                   max_period) {
5111
5112    struct ggml_tensor * result = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, dim, timesteps->ne[0]);
5113
5114    ggml_set_op_params_i32(result, 0, dim);
5115    ggml_set_op_params_i32(result, 1, max_period);
5116
5117    result->op     = GGML_OP_TIMESTEP_EMBEDDING;
5118    result->src[0] = timesteps;
5119
5120    return result;
5121}
5122
5123// ggml_tri
5124
5125struct ggml_tensor * ggml_tri(
5126    struct ggml_context * ctx,
5127    struct ggml_tensor  * a,
5128    enum ggml_tri_type    type) {
5129    GGML_ASSERT(a->type == GGML_TYPE_F32);
5130
5131    GGML_ASSERT(ggml_is_contiguous(a));
5132    GGML_ASSERT(a->ne[0] == a->ne[1]);
5133
5134    struct ggml_tensor * result = ggml_dup_tensor(ctx, a);
5135
5136    ggml_set_op_params_i32(result, 0, type);
5137
5138    result->op = GGML_OP_TRI;
5139    result->src[0] = a;
5140
5141    return result;
5142}
5143
5144// ggml_fill
5145
5146static struct ggml_tensor * ggml_fill_impl(
5147    struct ggml_context * ctx,
5148    struct ggml_tensor  * a,
5149    float                 c,
5150    bool                  inplace) {
5151    GGML_ASSERT(a->type == GGML_TYPE_F32);
5152    GGML_ASSERT(ggml_is_contiguous(a));
5153
5154    struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
5155
5156    ggml_set_op_params_f32(result, 0, c);
5157
5158    result->op = GGML_OP_FILL;
5159    result->src[0] = a;
5160
5161    return result;
5162}
5163
5164struct ggml_tensor * ggml_fill(
5165    struct ggml_context * ctx,
5166    struct ggml_tensor  * a,
5167    float                 c) {
5168    return ggml_fill_impl(ctx, a, c, false);
5169}
5170
5171struct ggml_tensor * ggml_fill_inplace(
5172    struct ggml_context * ctx,
5173    struct ggml_tensor  * a,
5174    float                 c) {
5175    return ggml_fill_impl(ctx, a, c, true);
5176}
5177
5178// ggml_argsort
5179
5180struct ggml_tensor * ggml_argsort(
5181        struct ggml_context  * ctx,
5182        struct ggml_tensor   * a,
5183        enum ggml_sort_order   order) {
5184    GGML_ASSERT(a->ne[0] <= INT32_MAX);
5185
5186    struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_I32, GGML_MAX_DIMS, a->ne);
5187
5188    ggml_set_op_params_i32(result, 0, (int32_t) order);
5189
5190    result->op     = GGML_OP_ARGSORT;
5191    result->src[0] = a;
5192
5193    return result;
5194}
5195
5196// ggml_argsort_top_k
5197
5198struct ggml_tensor * ggml_argsort_top_k(
5199        struct ggml_context * ctx,
5200        struct ggml_tensor  * a,
5201        int                   k) {
5202    GGML_ASSERT(a->ne[0] >= k);
5203
5204    struct ggml_tensor * result = ggml_argsort(ctx, a, GGML_SORT_ORDER_DESC);
5205
5206    result = ggml_view_4d(ctx, result,
5207                k, result->ne[1], result->ne[2], result->ne[3],
5208                   result->nb[1], result->nb[2], result->nb[3],
5209                0);
5210
5211    return result;
5212}
5213
5214// ggml_top_k
5215
5216struct ggml_tensor * ggml_top_k(
5217        struct ggml_context * ctx,
5218        struct ggml_tensor  * a,
5219        int                   k) {
5220    GGML_ASSERT(a->ne[0] >= k);
5221
5222    struct ggml_tensor * result = ggml_new_tensor_4d(ctx, GGML_TYPE_I32, k, a->ne[1], a->ne[2], a->ne[3]);
5223
5224    result->op     = GGML_OP_TOP_K;
5225    result->src[0] = a;
5226
5227    return result;
5228}
5229
5230// ggml_arange
5231
5232struct ggml_tensor * ggml_arange(
5233        struct ggml_context * ctx,
5234        float                 start,
5235        float                 stop,
5236        float                 step) {
5237    GGML_ASSERT(stop > start);
5238
5239    const int64_t steps = (int64_t) ceilf((stop - start) / step);
5240
5241    struct ggml_tensor * result = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, steps);
5242
5243    ggml_set_op_params_f32(result, 0, start);
5244    ggml_set_op_params_f32(result, 1, stop);
5245    ggml_set_op_params_f32(result, 2, step);
5246
5247    result->op = GGML_OP_ARANGE;
5248
5249    return result;
5250}
5251
5252// ggml_flash_attn_ext
5253
5254struct ggml_tensor * ggml_flash_attn_ext(
5255        struct ggml_context * ctx,
5256        struct ggml_tensor  * q,
5257        struct ggml_tensor  * k,
5258        struct ggml_tensor  * v,
5259        struct ggml_tensor  * mask,
5260        float                 scale,
5261        float                 max_bias,
5262        float                 logit_softcap) {
5263    GGML_ASSERT(ggml_can_mul_mat(k, q));
5264    // TODO: check if vT can be multiplied by (k*qT)
5265
5266    GGML_ASSERT(q->ne[3] == k->ne[3]);
5267    GGML_ASSERT(q->ne[3] == v->ne[3]);
5268
5269    if (mask) {
5270        GGML_ASSERT(ggml_is_contiguous(mask));
5271        //GGML_ASSERT(ggml_can_repeat_rows(mask, qk));
5272
5273        GGML_ASSERT(q->ne[2] % mask->ne[2] == 0);
5274        GGML_ASSERT(q->ne[3] % mask->ne[3] == 0);
5275    }
5276
5277    if (max_bias > 0.0f) {
5278        GGML_ASSERT(mask);
5279    }
5280
5281    // permute(0, 2, 1, 3)
5282    int64_t ne[4] = { v->ne[0], q->ne[2], q->ne[1], q->ne[3] };
5283    struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne);
5284
5285    float params[] = { scale, max_bias, logit_softcap };
5286    ggml_set_op_params(result, params, sizeof(params));
5287
5288    result->op     = GGML_OP_FLASH_ATTN_EXT;
5289    result->src[0] = q;
5290    result->src[1] = k;
5291    result->src[2] = v;
5292    result->src[3] = mask;
5293
5294    return result;
5295}
5296
5297void ggml_flash_attn_ext_set_prec(
5298        struct ggml_tensor * a,
5299        enum ggml_prec       prec) {
5300    GGML_ASSERT(a->op == GGML_OP_FLASH_ATTN_EXT);
5301
5302    const int32_t prec_i32 = (int32_t) prec;
5303
5304    ggml_set_op_params_i32(a, 3, prec_i32); // scale is on first pos, max_bias on second
5305}
5306
5307enum ggml_prec ggml_flash_attn_ext_get_prec(
5308        const struct ggml_tensor * a) {
5309    GGML_ASSERT(a->op == GGML_OP_FLASH_ATTN_EXT);
5310
5311    const int32_t prec_i32 = ggml_get_op_params_i32(a, 3);
5312
5313    return (enum ggml_prec) prec_i32;
5314}
5315
5316void ggml_flash_attn_ext_add_sinks(
5317        struct ggml_tensor * a,
5318        struct ggml_tensor * sinks) {
5319    if (!sinks) {
5320        a->src[4] = NULL;
5321        return;
5322    }
5323
5324    GGML_ASSERT(a->op == GGML_OP_FLASH_ATTN_EXT);
5325    GGML_ASSERT(a->src[4] == NULL);
5326    GGML_ASSERT(a->src[0]->ne[2] == sinks->ne[0]);
5327    GGML_ASSERT(sinks->type == GGML_TYPE_F32);
5328
5329    a->src[4] = sinks;
5330}
5331
5332// ggml_flash_attn_back
5333
5334struct ggml_tensor * ggml_flash_attn_back(
5335        struct ggml_context * ctx,
5336        struct ggml_tensor  * q,
5337        struct ggml_tensor  * k,
5338        struct ggml_tensor  * v,
5339        struct ggml_tensor  * d,
5340        bool                  masked) {
5341    GGML_ABORT("TODO: adapt to ggml_flash_attn_ext() changes");
5342
5343    GGML_ASSERT(ggml_can_mul_mat(k, q));
5344    // TODO: check if vT can be multiplied by (k*qT)
5345
5346    // d shape [D,N,ne2,ne3]
5347    // q shape [D,N,ne2,ne3]
5348    // k shape [D,M,kvne2,ne3]
5349    // v shape [M,D,kvne2,ne3]
5350
5351    const int64_t     D = q->ne[0];
5352    const int64_t     N = q->ne[1];
5353    const int64_t     M = k->ne[1];
5354    const int64_t   ne2 = q->ne[2];
5355    const int64_t   ne3 = q->ne[3];
5356    const int64_t kvne2 = k->ne[2];
5357
5358    GGML_ASSERT(k->ne[0] == D);
5359    GGML_ASSERT(v->ne[0] == M);
5360    GGML_ASSERT(v->ne[1] == D);
5361    GGML_ASSERT(d->ne[0] == D);
5362    GGML_ASSERT(d->ne[1] == N);
5363    GGML_ASSERT(k->ne[2] == kvne2);
5364    GGML_ASSERT(k->ne[3] == ne3);
5365    GGML_ASSERT(v->ne[2] == kvne2);
5366    GGML_ASSERT(v->ne[3] == ne3);
5367    GGML_ASSERT(d->ne[2] == ne2);
5368    GGML_ASSERT(d->ne[3] == ne3);
5369
5370    GGML_ASSERT(ne2 % kvne2 == 0);
5371
5372    // store gradients of q, k and v as continuous tensors concatenated in result.
5373    // note: v and gradv are actually transposed, i.e. v->ne[0] != D.
5374    const int64_t elem_q = ggml_nelements(q);
5375    const int64_t elem_k = ggml_nelements(k);
5376    const int64_t elem_v = ggml_nelements(v);
5377
5378    enum ggml_type result_type = GGML_TYPE_F32;
5379    GGML_ASSERT(ggml_blck_size(result_type) == 1);
5380    const size_t tsize = ggml_type_size(result_type);
5381
5382    const size_t offs_q = 0;
5383    const size_t offs_k = offs_q + GGML_PAD(elem_q * tsize, GGML_MEM_ALIGN);
5384    const size_t offs_v = offs_k + GGML_PAD(elem_k * tsize, GGML_MEM_ALIGN);
5385    const size_t end    = offs_v + GGML_PAD(elem_v * tsize, GGML_MEM_ALIGN);
5386
5387    const size_t nelements = (end + tsize - 1)/tsize;
5388
5389    struct ggml_tensor * result = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, nelements);
5390
5391    int32_t masked_i = masked ? 1 : 0;
5392    ggml_set_op_params(result, &masked_i, sizeof(masked_i));
5393
5394    result->op     = GGML_OP_FLASH_ATTN_BACK;
5395    result->src[0] = q;
5396    result->src[1] = k;
5397    result->src[2] = v;
5398    result->src[3] = d;
5399
5400    return result;
5401}
5402
5403// ggml_ssm_conv
5404
5405struct ggml_tensor * ggml_ssm_conv(
5406        struct ggml_context * ctx,
5407        struct ggml_tensor  * sx,
5408        struct ggml_tensor  * c) {
5409    GGML_ASSERT(ggml_is_3d(sx));
5410    GGML_ASSERT(ggml_is_matrix(c));
5411
5412    const int64_t d_conv  = c->ne[0];
5413    const int64_t d_inner = c->ne[1];
5414    const int64_t n_t     = sx->ne[0] - d_conv + 1; // tokens per sequence
5415    const int64_t n_s     = sx->ne[2];
5416
5417    // TODO: maybe support other strides than 1?
5418    GGML_ASSERT(sx->ne[0] == d_conv - 1 + n_t);
5419    GGML_ASSERT(sx->ne[1] == d_inner);
5420    GGML_ASSERT(n_t >= 0);
5421
5422    struct ggml_tensor * result = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, d_inner, n_t, n_s);
5423
5424    result->op     = GGML_OP_SSM_CONV;
5425    result->src[0] = sx;
5426    result->src[1] = c;
5427
5428    return result;
5429}
5430
5431// ggml_ssm_scan
5432
5433struct ggml_tensor * ggml_ssm_scan(
5434        struct ggml_context * ctx,
5435        struct ggml_tensor  * s,
5436        struct ggml_tensor  * x,
5437        struct ggml_tensor  * dt,
5438        struct ggml_tensor  * A,
5439        struct ggml_tensor  * B,
5440        struct ggml_tensor  * C,
5441        struct ggml_tensor  * ids) {
5442    GGML_ASSERT(ggml_is_contiguous(s));
5443    GGML_ASSERT(ggml_is_contiguous(dt));
5444    GGML_ASSERT(ggml_is_contiguous(A));
5445    GGML_ASSERT(x->nb[0] == ggml_type_size(x->type));
5446    GGML_ASSERT(B->nb[0] == ggml_type_size(B->type));
5447    GGML_ASSERT(C->nb[0] == ggml_type_size(C->type));
5448    GGML_ASSERT(x->nb[1] == x->ne[0]*x->nb[0]);
5449    GGML_ASSERT(B->nb[1] == B->ne[0]*B->nb[0]);
5450    GGML_ASSERT(C->nb[1] == C->ne[0]*C->nb[0]);
5451    GGML_ASSERT(ggml_are_same_shape(B, C));
5452    GGML_ASSERT(ids->type == GGML_TYPE_I32);
5453
5454    {
5455        const int64_t d_state      = s->ne[0];
5456        const int64_t head_dim     = x->ne[0];
5457        const int64_t n_head       = x->ne[1];
5458        const int64_t n_seq_tokens = x->ne[2];
5459        const int64_t n_seqs       = x->ne[3];
5460
5461        GGML_ASSERT(dt->ne[0] == n_head);
5462        GGML_ASSERT(dt->ne[1] == n_seq_tokens);
5463        GGML_ASSERT(dt->ne[2] == n_seqs);
5464        GGML_ASSERT(ggml_is_3d(dt));
5465        GGML_ASSERT(s->ne[1] == head_dim);
5466        GGML_ASSERT(s->ne[2] == n_head);
5467        GGML_ASSERT(B->ne[0] == d_state);
5468        GGML_ASSERT(B->ne[2] == n_seq_tokens);
5469        GGML_ASSERT(B->ne[3] == n_seqs);
5470        GGML_ASSERT(ids->ne[0] == n_seqs);
5471        GGML_ASSERT(ggml_is_vector(ids));
5472        GGML_ASSERT(A->ne[1] == n_head);
5473        GGML_ASSERT(ggml_is_matrix(A));
5474
5475        if (A->ne[0] != 1) {
5476            // Mamba-1 has more granular decay factors
5477            GGML_ASSERT(A->ne[0] == d_state);
5478        }
5479    }
5480
5481    // concatenated y + ssm_states
5482    struct ggml_tensor * result = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, ggml_nelements(x) + s->ne[0]*s->ne[1]*s->ne[2]*ids->ne[0]);
5483
5484    result->op   = GGML_OP_SSM_SCAN;
5485    result->src[0] = s;
5486    result->src[1] = x;
5487    result->src[2] = dt;
5488    result->src[3] = A;
5489    result->src[4] = B;
5490    result->src[5] = C;
5491    result->src[6] = ids;
5492
5493    return result;
5494}
5495
5496// ggml_win_part
5497
5498struct ggml_tensor * ggml_win_part(
5499        struct ggml_context * ctx,
5500        struct ggml_tensor  * a,
5501        int                   w) {
5502    GGML_ASSERT(a->ne[3] == 1);
5503    GGML_ASSERT(a->type  == GGML_TYPE_F32);
5504
5505    // padding
5506    const int px = (w - a->ne[1]%w)%w;
5507    const int py = (w - a->ne[2]%w)%w;
5508
5509    const int npx = (px + a->ne[1])/w;
5510    const int npy = (py + a->ne[2])/w;
5511    const int np  = npx*npy;
5512
5513    const int64_t ne[4] = { a->ne[0], w, w, np, };
5514    struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne);
5515
5516    int32_t params[] = { npx, npy, w };
5517    ggml_set_op_params(result, params, sizeof(params));
5518
5519    result->op     = GGML_OP_WIN_PART;
5520    result->src[0] = a;
5521
5522    return result;
5523}
5524
5525// ggml_win_unpart
5526
5527struct ggml_tensor * ggml_win_unpart(
5528        struct ggml_context * ctx,
5529        struct ggml_tensor  * a,
5530        int                   w0,
5531        int                   h0,
5532        int                   w) {
5533    GGML_ASSERT(a->type == GGML_TYPE_F32);
5534
5535    const int64_t ne[4] = { a->ne[0], w0, h0, 1, };
5536    struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, 3, ne);
5537
5538    int32_t params[] = { w };
5539    ggml_set_op_params(result, params, sizeof(params));
5540
5541    result->op     = GGML_OP_WIN_UNPART;
5542    result->src[0] = a;
5543
5544    return result;
5545}
5546
5547// ggml_get_rel_pos
5548
5549struct ggml_tensor * ggml_get_rel_pos(
5550        struct ggml_context * ctx,
5551        struct ggml_tensor  * a,
5552        int                   qh,
5553        int                   kh) {
5554    GGML_ASSERT(qh == kh);
5555    GGML_ASSERT(2*MAX(qh, kh) - 1 == a->ne[1]);
5556
5557    const int64_t ne[4] = { a->ne[0], kh, qh, 1, };
5558    struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F16, 3, ne);
5559
5560    result->op     = GGML_OP_GET_REL_POS;
5561    result->src[0] = a;
5562
5563    return result;
5564}
5565
5566// ggml_add_rel_pos
5567
5568static struct ggml_tensor * ggml_add_rel_pos_impl(
5569        struct ggml_context * ctx,
5570        struct ggml_tensor  * a,
5571        struct ggml_tensor  * pw,
5572        struct ggml_tensor  * ph,
5573        bool                  inplace) {
5574    GGML_ASSERT(ggml_are_same_shape(pw, ph));
5575    GGML_ASSERT(ggml_is_contiguous(a));
5576    GGML_ASSERT(ggml_is_contiguous(pw));
5577    GGML_ASSERT(ggml_is_contiguous(ph));
5578    GGML_ASSERT(ph->type == GGML_TYPE_F32);
5579    GGML_ASSERT(pw->type == GGML_TYPE_F32);
5580    GGML_ASSERT(pw->ne[3] == a->ne[2]);
5581    GGML_ASSERT(pw->ne[0]*pw->ne[0] == a->ne[0]);
5582    GGML_ASSERT(pw->ne[1]*pw->ne[2] == a->ne[1]);
5583
5584    struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
5585    ggml_set_op_params_i32(result, 0, inplace ? 1 : 0);
5586
5587    result->op     = GGML_OP_ADD_REL_POS;
5588    result->src[0] = a;
5589    result->src[1] = pw;
5590    result->src[2] = ph;
5591
5592    return result;
5593}
5594
5595struct ggml_tensor * ggml_add_rel_pos(
5596        struct ggml_context * ctx,
5597        struct ggml_tensor  * a,
5598        struct ggml_tensor  * pw,
5599        struct ggml_tensor  * ph) {
5600    return ggml_add_rel_pos_impl(ctx, a, pw, ph, false);
5601}
5602
5603struct ggml_tensor * ggml_add_rel_pos_inplace(
5604        struct ggml_context * ctx,
5605        struct ggml_tensor  * a,
5606        struct ggml_tensor  * pw,
5607        struct ggml_tensor  * ph) {
5608    return ggml_add_rel_pos_impl(ctx, a, pw, ph, true);
5609}
5610
5611// ggml_rwkv_wkv6
5612
5613struct ggml_tensor * ggml_rwkv_wkv6(
5614        struct ggml_context * ctx,
5615        struct ggml_tensor  * k,
5616        struct ggml_tensor  * v,
5617        struct ggml_tensor  * r,
5618        struct ggml_tensor  * tf,
5619        struct ggml_tensor  * td,
5620        struct ggml_tensor  * state) {
5621    GGML_ASSERT(ggml_is_contiguous(k));
5622    GGML_ASSERT(ggml_is_contiguous(v));
5623    GGML_ASSERT(ggml_is_contiguous(r));
5624    GGML_ASSERT(ggml_is_contiguous(tf));
5625    GGML_ASSERT(ggml_is_contiguous(td));
5626    GGML_ASSERT(ggml_is_contiguous(state));
5627
5628    const int64_t S = k->ne[0];
5629    const int64_t H = k->ne[1];
5630    const int64_t n_tokens = k->ne[2];
5631    const int64_t n_seqs = state->ne[1];
5632    {
5633        GGML_ASSERT(v->ne[0] == S && v->ne[1] == H && v->ne[2] == n_tokens);
5634        GGML_ASSERT(r->ne[0] == S && r->ne[1] == H && r->ne[2] == n_tokens);
5635        GGML_ASSERT(td->ne[0] == S && td->ne[1] == H && td->ne[2] == n_tokens);
5636        GGML_ASSERT(ggml_nelements(state) == S * S * H * n_seqs);
5637    }
5638
5639    // concat output and new_state
5640    const int64_t ne[4] = { S * H, n_tokens + S * n_seqs, 1, 1 };
5641    struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne);
5642
5643    result->op     = GGML_OP_RWKV_WKV6;
5644    result->src[0] = k;
5645    result->src[1] = v;
5646    result->src[2] = r;
5647    result->src[3] = tf;
5648    result->src[4] = td;
5649    result->src[5] = state;
5650
5651    return result;
5652}
5653
5654// ggml_gated_linear_attn
5655
5656struct ggml_tensor * ggml_gated_linear_attn(
5657        struct ggml_context * ctx,
5658        struct ggml_tensor  * k,
5659        struct ggml_tensor  * v,
5660        struct ggml_tensor  * q,
5661        struct ggml_tensor  * g,
5662        struct ggml_tensor  * state,
5663        float scale) {
5664    GGML_ASSERT(ggml_is_contiguous(k));
5665    GGML_ASSERT(ggml_is_contiguous(v));
5666    GGML_ASSERT(ggml_is_contiguous(q));
5667    GGML_ASSERT(ggml_is_contiguous(g));
5668    GGML_ASSERT(ggml_is_contiguous(state));
5669
5670    const int64_t S = k->ne[0];
5671    const int64_t H = k->ne[1];
5672    const int64_t n_tokens = k->ne[2];
5673    const int64_t n_seqs = state->ne[1];
5674    {
5675        GGML_ASSERT(v->ne[0] == S && v->ne[1] == H && v->ne[2] == n_tokens);
5676        GGML_ASSERT(q->ne[0] == S && q->ne[1] == H && q->ne[2] == n_tokens);
5677        GGML_ASSERT(g->ne[0] == S && g->ne[1] == H && g->ne[2] == n_tokens);
5678        GGML_ASSERT(ggml_nelements(state) == S * S * H * n_seqs);
5679    }
5680
5681    // concat output and new_state
5682    const int64_t ne[4] = { S * H, n_tokens + S * n_seqs, 1, 1 };
5683    struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne);
5684
5685    ggml_set_op_params_f32(result, 0, scale);
5686
5687    result->op     = GGML_OP_GATED_LINEAR_ATTN;
5688    result->src[0] = k;
5689    result->src[1] = v;
5690    result->src[2] = q;
5691    result->src[3] = g;
5692    result->src[4] = state;
5693
5694    return result;
5695}
5696
5697// ggml_rwkv_wkv7
5698
5699struct ggml_tensor * ggml_rwkv_wkv7(
5700        struct ggml_context * ctx,
5701        struct ggml_tensor  * r,
5702        struct ggml_tensor  * w,
5703        struct ggml_tensor  * k,
5704        struct ggml_tensor  * v,
5705        struct ggml_tensor  * a,
5706        struct ggml_tensor  * b,
5707        struct ggml_tensor  * state) {
5708    GGML_ASSERT(ggml_is_contiguous(r));
5709    GGML_ASSERT(ggml_is_contiguous(w));
5710    GGML_ASSERT(ggml_is_contiguous(k));
5711    GGML_ASSERT(ggml_is_contiguous(v));
5712    GGML_ASSERT(ggml_is_contiguous(a));
5713    GGML_ASSERT(ggml_is_contiguous(b));
5714    GGML_ASSERT(ggml_is_contiguous(state));
5715
5716    const int64_t S = k->ne[0];
5717    const int64_t H = k->ne[1];
5718    const int64_t n_tokens = k->ne[2];
5719    const int64_t n_seqs = state->ne[1];
5720    {
5721        GGML_ASSERT(w->ne[0] == S && w->ne[1] == H && w->ne[2] == n_tokens);
5722        GGML_ASSERT(k->ne[0] == S && k->ne[1] == H && k->ne[2] == n_tokens);
5723        GGML_ASSERT(v->ne[0] == S && v->ne[1] == H && v->ne[2] == n_tokens);
5724        GGML_ASSERT(a->ne[0] == S && a->ne[1] == H && a->ne[2] == n_tokens);
5725        GGML_ASSERT(b->ne[0] == S && b->ne[1] == H && b->ne[2] == n_tokens);
5726        GGML_ASSERT(ggml_nelements(state) == S * S * H * n_seqs);
5727    }
5728
5729    // concat output and new_state
5730    const int64_t ne[4] = { S * H, n_tokens + S * n_seqs, 1, 1 };
5731    struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne);
5732
5733    result->op     = GGML_OP_RWKV_WKV7;
5734    result->src[0] = r;
5735    result->src[1] = w;
5736    result->src[2] = k;
5737    result->src[3] = v;
5738    result->src[4] = a;
5739    result->src[5] = b;
5740    result->src[6] = state;
5741
5742    return result;
5743}
5744
5745// ggml_unary
5746
5747static struct ggml_tensor * ggml_unary_impl(
5748        struct ggml_context * ctx,
5749        struct ggml_tensor  * a,
5750        enum ggml_unary_op    op,
5751        bool                  inplace) {
5752    GGML_ASSERT(ggml_is_contiguous_rows(a));
5753
5754    struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
5755
5756    ggml_set_op_params_i32(result, 0, (int32_t) op);
5757
5758    result->op     = GGML_OP_UNARY;
5759    result->src[0] = a;
5760
5761    return result;
5762}
5763
5764struct ggml_tensor * ggml_unary(
5765        struct ggml_context * ctx,
5766        struct ggml_tensor  * a,
5767        enum ggml_unary_op    op) {
5768    return ggml_unary_impl(ctx, a, op, false);
5769}
5770
5771struct ggml_tensor * ggml_unary_inplace(
5772        struct ggml_context * ctx,
5773        struct ggml_tensor  * a,
5774        enum ggml_unary_op    op) {
5775    return ggml_unary_impl(ctx, a, op, true);
5776}
5777
5778// ggml_map_custom1
5779
5780static struct ggml_tensor * ggml_map_custom1_impl(
5781        struct ggml_context      * ctx,
5782        struct ggml_tensor       * a,
5783        const  ggml_custom1_op_t   fun,
5784        int                        n_tasks,
5785        void                     * userdata,
5786        bool                       inplace) {
5787    GGML_ASSERT(n_tasks == GGML_N_TASKS_MAX || n_tasks > 0);
5788
5789    struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
5790
5791    struct ggml_map_custom1_op_params params = {
5792        /*.fun      =*/ fun,
5793        /*.n_tasks  =*/ n_tasks,
5794        /*.userdata =*/ userdata
5795    };
5796    ggml_set_op_params(result, &params, sizeof(params));
5797
5798    result->op     = GGML_OP_MAP_CUSTOM1;
5799    result->src[0] = a;
5800
5801    return result;
5802}
5803
5804struct ggml_tensor * ggml_map_custom1(
5805        struct ggml_context      * ctx,
5806        struct ggml_tensor       * a,
5807        const  ggml_custom1_op_t   fun,
5808        int                        n_tasks,
5809        void                     * userdata) {
5810    return ggml_map_custom1_impl(ctx, a, fun, n_tasks, userdata, false);
5811}
5812
5813struct ggml_tensor * ggml_map_custom1_inplace(
5814        struct ggml_context      * ctx,
5815        struct ggml_tensor       * a,
5816        const  ggml_custom1_op_t   fun,
5817        int                        n_tasks,
5818        void                     * userdata) {
5819    return ggml_map_custom1_impl(ctx, a, fun, n_tasks, userdata, true);
5820}
5821
5822// ggml_map_custom2
5823
5824static struct ggml_tensor * ggml_map_custom2_impl(
5825        struct ggml_context      * ctx,
5826        struct ggml_tensor       * a,
5827        struct ggml_tensor       * b,
5828        const  ggml_custom2_op_t   fun,
5829        int                        n_tasks,
5830        void                     * userdata,
5831        bool                       inplace) {
5832    GGML_ASSERT(n_tasks == GGML_N_TASKS_MAX || n_tasks > 0);
5833
5834    struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
5835
5836    struct ggml_map_custom2_op_params params = {
5837        /*.fun      =*/ fun,
5838        /*.n_tasks  =*/ n_tasks,
5839        /*.userdata =*/ userdata
5840    };
5841    ggml_set_op_params(result, &params, sizeof(params));
5842
5843    result->op     = GGML_OP_MAP_CUSTOM2;
5844    result->src[0] = a;
5845    result->src[1] = b;
5846
5847    return result;
5848}
5849
5850struct ggml_tensor * ggml_map_custom2(
5851        struct ggml_context      * ctx,
5852        struct ggml_tensor       * a,
5853        struct ggml_tensor       * b,
5854        const  ggml_custom2_op_t   fun,
5855        int                        n_tasks,
5856        void                     * userdata) {
5857    return ggml_map_custom2_impl(ctx, a, b, fun, n_tasks, userdata, false);
5858}
5859
5860struct ggml_tensor * ggml_map_custom2_inplace(
5861        struct ggml_context      * ctx,
5862        struct ggml_tensor       * a,
5863        struct ggml_tensor       * b,
5864        const  ggml_custom2_op_t   fun,
5865        int                        n_tasks,
5866        void                     * userdata) {
5867    return ggml_map_custom2_impl(ctx, a, b, fun, n_tasks, userdata, true);
5868}
5869
5870// ggml_map_custom3
5871
5872static struct ggml_tensor * ggml_map_custom3_impl(
5873        struct ggml_context      * ctx,
5874        struct ggml_tensor       * a,
5875        struct ggml_tensor       * b,
5876        struct ggml_tensor       * c,
5877        const  ggml_custom3_op_t   fun,
5878        int                        n_tasks,
5879        void                     * userdata,
5880        bool                       inplace) {
5881    GGML_ASSERT(n_tasks == GGML_N_TASKS_MAX || n_tasks > 0);
5882
5883    struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
5884
5885    struct ggml_map_custom3_op_params params = {
5886        /*.fun      =*/ fun,
5887        /*.n_tasks  =*/ n_tasks,
5888        /*.userdata =*/ userdata
5889    };
5890    ggml_set_op_params(result, &params, sizeof(params));
5891
5892    result->op     = GGML_OP_MAP_CUSTOM3;
5893    result->src[0] = a;
5894    result->src[1] = b;
5895    result->src[2] = c;
5896
5897    return result;
5898}
5899
5900struct ggml_tensor * ggml_map_custom3(
5901        struct ggml_context      * ctx,
5902        struct ggml_tensor       * a,
5903        struct ggml_tensor       * b,
5904        struct ggml_tensor       * c,
5905        const  ggml_custom3_op_t   fun,
5906        int                        n_tasks,
5907        void                     * userdata) {
5908    return ggml_map_custom3_impl(ctx, a, b, c, fun, n_tasks, userdata, false);
5909}
5910
5911struct ggml_tensor * ggml_map_custom3_inplace(
5912        struct ggml_context      * ctx,
5913        struct ggml_tensor       * a,
5914        struct ggml_tensor       * b,
5915        struct ggml_tensor       * c,
5916        const  ggml_custom3_op_t   fun,
5917        int                        n_tasks,
5918        void                     * userdata) {
5919    return ggml_map_custom3_impl(ctx, a, b, c, fun, n_tasks, userdata, true);
5920}
5921
5922struct ggml_tensor * ggml_custom_4d(
5923        struct ggml_context * ctx,
5924        enum ggml_type        type,
5925        int64_t               ne0,
5926        int64_t               ne1,
5927        int64_t               ne2,
5928        int64_t               ne3,
5929        struct ggml_tensor ** args,
5930        int                   n_args,
5931        ggml_custom_op_t      fun,
5932        int                   n_tasks,
5933        void                * userdata) {
5934
5935    GGML_ASSERT(n_args < GGML_MAX_SRC);
5936
5937    struct ggml_tensor * result = ggml_new_tensor_4d(ctx, type, ne0, ne1, ne2, ne3);
5938
5939    struct ggml_custom_op_params params = {
5940        /*.fun      =*/ fun,
5941        /*.n_tasks  =*/ n_tasks,
5942        /*.userdata =*/ userdata
5943    };
5944    ggml_set_op_params(result, &params, sizeof(params));
5945
5946    result->op = GGML_OP_CUSTOM;
5947    for (int i = 0; i < n_args; i++) {
5948        result->src[i] = args[i];
5949    }
5950
5951    return result;
5952}
5953
5954struct ggml_tensor * ggml_custom_inplace(
5955        struct ggml_context * ctx,
5956        struct ggml_tensor  * a,
5957        struct ggml_tensor ** args,
5958        int                   n_args,
5959        ggml_custom_op_t      fun,
5960        int                   n_tasks,
5961        void                * userdata) {
5962
5963    GGML_ASSERT(n_args < GGML_MAX_SRC - 1);
5964
5965    struct ggml_tensor * result = ggml_view_tensor(ctx, a);
5966
5967    struct ggml_custom_op_params params = {
5968        /*.fun      =*/ fun,
5969        /*.n_tasks  =*/ n_tasks,
5970        /*.userdata =*/ userdata
5971    };
5972    ggml_set_op_params(result, &params, sizeof(params));
5973
5974    result->op = GGML_OP_CUSTOM;
5975    result->src[0] = a;
5976    for (int i = 0; i < n_args; i++) {
5977        result->src[i + 1] = args[i];
5978    }
5979
5980    return result;
5981}
5982// ggml_cross_entropy_loss
5983
5984struct ggml_tensor * ggml_cross_entropy_loss(
5985        struct ggml_context * ctx,
5986        struct ggml_tensor  * a,
5987        struct ggml_tensor  * b) {
5988    GGML_ASSERT(ggml_are_same_shape(a, b));
5989
5990    struct ggml_tensor * result = ggml_new_tensor_1d(ctx, a->type, 1);
5991
5992    result->op     = GGML_OP_CROSS_ENTROPY_LOSS;
5993    result->src[0] = a;
5994    result->src[1] = b;
5995
5996    return result;
5997}
5998
5999// ggml_cross_entropy_loss_back
6000
6001struct ggml_tensor * ggml_cross_entropy_loss_back(
6002        struct ggml_context * ctx,
6003        struct ggml_tensor  * a,
6004        struct ggml_tensor  * b,
6005        struct ggml_tensor  * c) {
6006    GGML_ASSERT(ggml_is_scalar(a));
6007    GGML_ASSERT(ggml_are_same_shape(b, c));
6008
6009    struct ggml_tensor * result = ggml_dup_tensor(ctx, b);
6010
6011    result->op     = GGML_OP_CROSS_ENTROPY_LOSS_BACK;
6012    result->src[0] = a;
6013    result->src[1] = b;
6014    result->src[2] = c;
6015
6016    return result;
6017}
6018
6019// opt_step_adamw
6020
6021struct ggml_tensor * ggml_opt_step_adamw(
6022        struct ggml_context * ctx,
6023        struct ggml_tensor  * a,
6024        struct ggml_tensor  * grad,
6025        struct ggml_tensor  * m,
6026        struct ggml_tensor  * v,
6027        struct ggml_tensor  * adamw_params) {
6028    GGML_ASSERT(a->flags & GGML_TENSOR_FLAG_PARAM);
6029    GGML_ASSERT(ggml_are_same_shape(a, grad));
6030    GGML_ASSERT(ggml_are_same_shape(a, m));
6031    GGML_ASSERT(ggml_are_same_shape(a, v));
6032    GGML_ASSERT(adamw_params->type == GGML_TYPE_F32);
6033    GGML_ASSERT(ggml_nelements(adamw_params) == 7);
6034
6035    struct ggml_tensor * result = ggml_view_tensor(ctx, a);
6036
6037    result->op     = GGML_OP_OPT_STEP_ADAMW;
6038    result->src[0] = a;
6039    result->src[1] = grad;
6040    result->src[2] = m;
6041    result->src[3] = v;
6042    result->src[4] = adamw_params;
6043
6044    return result;
6045}
6046
6047// opt_step_sgd
6048
6049struct ggml_tensor * ggml_opt_step_sgd(
6050        struct ggml_context * ctx,
6051        struct ggml_tensor  * a,
6052        struct ggml_tensor  * grad,
6053        struct ggml_tensor  * params) {
6054    GGML_ASSERT(a->flags & GGML_TENSOR_FLAG_PARAM);
6055    GGML_ASSERT(ggml_are_same_shape(a, grad));
6056    GGML_ASSERT(params->type == GGML_TYPE_F32);
6057    GGML_ASSERT(ggml_nelements(params) == 2);
6058
6059    struct ggml_tensor * result = ggml_view_tensor(ctx, a);
6060
6061    result->op     = GGML_OP_OPT_STEP_SGD;
6062    result->src[0] = a;
6063    result->src[1] = grad;
6064    result->src[2] = params;
6065
6066    return result;
6067}
6068
6069// solve_tri
6070
6071struct ggml_tensor * ggml_solve_tri(
6072        struct ggml_context * ctx,
6073        struct ggml_tensor  * a,
6074        struct ggml_tensor  * b,
6075        bool                  left,
6076        bool                  lower,
6077        bool                  uni) {
6078    GGML_ASSERT(a->type == GGML_TYPE_F32);
6079    GGML_ASSERT(b->type == GGML_TYPE_F32);
6080
6081    // A must be square and lower diagonal
6082    GGML_ASSERT(a->ne[0] == a->ne[1]);
6083    // B must have same outer dimension as A
6084    GGML_ASSERT(a->ne[1] == b->ne[1]);
6085
6086    // batch dimensions must be equal
6087    GGML_ASSERT(a->ne[2] == b->ne[2]);
6088    GGML_ASSERT(a->ne[3] == b->ne[3]);
6089
6090    GGML_ASSERT(ggml_is_contiguous(a));
6091    GGML_ASSERT(ggml_is_contiguous(b));
6092
6093    GGML_ASSERT(lower && left && !uni); // TODO: support other variants
6094
6095    struct ggml_tensor * result = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, b->ne[0], b->ne[1], b->ne[2], b->ne[3]);
6096
6097    result->op     = GGML_OP_SOLVE_TRI;
6098    result->src[0] = a;
6099    result->src[1] = b;
6100
6101    return result;
6102}
6103
6104////////////////////////////////////////////////////////////////////////////////
6105
6106struct ggml_hash_set ggml_hash_set_new(size_t size) {
6107    size = ggml_hash_size(size);
6108    struct ggml_hash_set result;
6109    result.size = size;
6110    result.keys = GGML_MALLOC(sizeof(struct ggml_tensor *) * size);
6111    result.used = GGML_CALLOC(ggml_bitset_size(size), sizeof(ggml_bitset_t));
6112    return result;
6113}
6114
6115void ggml_hash_set_reset(struct ggml_hash_set * hash_set) {
6116    memset(hash_set->used, 0, sizeof(ggml_bitset_t) * ggml_bitset_size(hash_set->size));
6117}
6118
6119void ggml_hash_set_free(struct ggml_hash_set * hash_set) {
6120    GGML_FREE(hash_set->used);
6121    GGML_FREE(hash_set->keys);
6122}
6123
6124size_t ggml_hash_size(size_t min_sz) {
6125    // next primes after powers of two
6126    static const size_t primes[] = {
6127        2, 3, 5, 11, 17, 37, 67, 131, 257, 521, 1031,
6128        2053, 4099, 8209, 16411, 32771, 65537, 131101,
6129        262147, 524309, 1048583, 2097169, 4194319, 8388617,
6130        16777259, 33554467, 67108879, 134217757, 268435459,
6131        536870923, 1073741827, 2147483659
6132    };
6133    static const size_t n_primes = sizeof(primes)/sizeof(primes[0]);
6134
6135    // find the smallest prime that is larger or equal than min_sz
6136    size_t l = 0;
6137    size_t r = n_primes;
6138    while (l < r) {
6139        size_t m = (l + r)/2;
6140        if (primes[m] < min_sz) {
6141            l = m + 1;
6142        } else {
6143            r = m;
6144        }
6145    }
6146    size_t sz = l < n_primes ? primes[l] : min_sz | 1;
6147    return sz;
6148}
6149
6150struct hash_map {
6151    struct ggml_hash_set set;
6152    struct ggml_tensor ** vals;
6153};
6154
6155static struct hash_map * ggml_new_hash_map(size_t size) {
6156    struct hash_map * result = GGML_MALLOC(sizeof(struct hash_map));
6157    result->set = ggml_hash_set_new(size);
6158    result->vals = GGML_CALLOC(result->set.size, sizeof(struct ggml_tensor *));
6159    return result;
6160}
6161
6162static void ggml_hash_map_free(struct hash_map * map) {
6163    ggml_hash_set_free(&map->set);
6164    GGML_FREE(map->vals);
6165    GGML_FREE(map);
6166}
6167
6168// utility functions to change gradients
6169// isrc is the index of tensor in cgraph->visited_has_set.keys
6170// the corresponding gradient (accumulators) are also at position isrc
6171// if tensor has a gradient accumulator, modify that accumulator in-place
6172// else if there is no gradient for tensor, set the corresponding value
6173// else, just add/subtract/etc. the gradients
6174
6175static void ggml_add_or_set(
6176        struct ggml_context * ctx,
6177        struct ggml_cgraph  * cgraph,
6178        size_t                isrc,
6179        struct ggml_tensor  * tensor) {
6180    struct ggml_tensor * src = cgraph->visited_hash_set.keys[isrc];
6181    GGML_ASSERT(src);
6182    if (cgraph->grads[isrc]) {
6183        cgraph->grads[isrc] = ggml_add_impl(ctx, cgraph->grads[isrc], tensor, /*inplace =*/ cgraph->grad_accs[isrc]);
6184    } else {
6185        cgraph->grads[isrc] = tensor;
6186    }
6187    ggml_format_name(cgraph->grads[isrc], "grad for %s", src->name);
6188    ggml_build_forward_expand(cgraph, cgraph->grads[isrc]);
6189}
6190
6191static void ggml_acc_or_set(
6192        struct ggml_context * ctx,
6193        struct ggml_cgraph  * cgraph,
6194        size_t                isrc,
6195        struct ggml_tensor  * tensor,
6196        const  size_t         nb1,
6197        const  size_t         nb2,
6198        const  size_t         nb3,
6199        const  size_t         offset) {
6200    struct ggml_tensor * src = cgraph->visited_hash_set.keys[isrc];
6201    GGML_ASSERT(src);
6202    if (cgraph->grads[isrc]) {
6203        cgraph->grads[isrc] = ggml_acc_impl(ctx, cgraph->grads[isrc], tensor, nb1, nb2, nb3, offset, cgraph->grad_accs[isrc]);
6204    } else {
6205        struct ggml_tensor * a_zero = ggml_scale(ctx, src, 0.0f); // FIXME this is going to produce NaN if a contains inf/NaN
6206        cgraph->grads[isrc] = ggml_acc_impl(ctx, a_zero, tensor, nb1, nb2, nb3, offset, false);
6207    }
6208    ggml_format_name(cgraph->grads[isrc], "grad for %s", cgraph->visited_hash_set.keys[isrc]->name);
6209    ggml_build_forward_expand(cgraph, cgraph->grads[isrc]);
6210}
6211
6212static void ggml_add1_or_set(
6213        struct ggml_context * ctx,
6214        struct ggml_cgraph  * cgraph,
6215        size_t                isrc,
6216        struct ggml_tensor  * tensor) {
6217    struct ggml_tensor * src = cgraph->visited_hash_set.keys[isrc];
6218    GGML_ASSERT(src);
6219    if (cgraph->grads[isrc]) {
6220        cgraph->grads[isrc] = ggml_add1_impl(ctx, cgraph->grads[isrc], tensor, cgraph->grad_accs[isrc]);
6221    } else {
6222        cgraph->grads[isrc] = ggml_repeat(ctx, tensor, src);
6223    }
6224    ggml_format_name(cgraph->grads[isrc], "grad for %s", src->name);
6225    ggml_build_forward_expand(cgraph, cgraph->grads[isrc]);
6226}
6227
6228static void ggml_sub_or_set(
6229        struct ggml_context * ctx,
6230        struct ggml_cgraph  * cgraph,
6231        size_t                isrc,
6232        struct ggml_tensor  * tensor) {
6233    struct ggml_tensor * src = cgraph->visited_hash_set.keys[isrc];
6234    GGML_ASSERT(src);
6235    if (cgraph->grads[isrc]) {
6236        cgraph->grads[isrc] = ggml_sub_impl(ctx, cgraph->grads[isrc], tensor, cgraph->grad_accs[isrc]);
6237    } else {
6238        cgraph->grads[isrc] = ggml_neg(ctx, tensor);
6239    }
6240    ggml_format_name(cgraph->grads[isrc], "grad for %s", src->name);
6241    ggml_build_forward_expand(cgraph, cgraph->grads[isrc]);
6242}
6243
6244static void ggml_compute_backward(
6245        struct ggml_context * ctx, struct ggml_cgraph * cgraph, int i, const bool * grads_needed) {
6246    struct ggml_tensor * tensor = cgraph->nodes[i];
6247    struct ggml_tensor * grad   = ggml_graph_get_grad(cgraph, tensor);
6248
6249    if (!grad) {
6250        return;
6251    }
6252
6253    struct ggml_tensor * src0 = tensor->src[0];
6254    struct ggml_tensor * src1 = tensor->src[1];
6255    struct ggml_tensor * src2 = tensor->src[2];
6256    struct ggml_hash_set * hash_set = &cgraph->visited_hash_set;
6257    const size_t isrc0 = src0 ? ggml_hash_find(hash_set, src0) : (size_t) -1;
6258    const size_t isrc1 = src1 ? ggml_hash_find(hash_set, src1) : (size_t) -1;
6259    const size_t isrc2 = src2 ? ggml_hash_find(hash_set, src2) : (size_t) -1;
6260    const bool src0_needs_grads = src0 && isrc0 != GGML_HASHSET_FULL && ggml_bitset_get(hash_set->used, isrc0) && grads_needed[isrc0];
6261    const bool src1_needs_grads = src1 && isrc1 != GGML_HASHSET_FULL && ggml_bitset_get(hash_set->used, isrc1) && grads_needed[isrc1];
6262    const bool src2_needs_grads = src2 && isrc2 != GGML_HASHSET_FULL && ggml_bitset_get(hash_set->used, isrc2) && grads_needed[isrc2];
6263
6264    switch (tensor->op) {
6265        case GGML_OP_DUP: {
6266            if (src0_needs_grads) {
6267                ggml_add_or_set(ctx, cgraph, isrc0, grad);
6268            }
6269        } break;
6270        case GGML_OP_ADD: {
6271            if (src0_needs_grads) {
6272                ggml_add_or_set(ctx, cgraph, isrc0, grad);
6273            }
6274            if (src1_needs_grads) {
6275                struct ggml_tensor * tmp = grad;
6276                if (!ggml_are_same_shape(src0, src1)) {
6277                    tmp = ggml_repeat_back(ctx, tmp, src1);
6278                }
6279                ggml_add_or_set(ctx, cgraph, isrc1, tmp);
6280            }
6281        } break;
6282        case GGML_OP_ADD1: {
6283            if (src0_needs_grads) {
6284                ggml_add_or_set(ctx, cgraph, isrc0, grad);
6285            }
6286            if (src1_needs_grads) {
6287                ggml_add_or_set(ctx, cgraph, isrc1, ggml_mean(ctx, grad)); // TODO: should probably be sum instead of mean
6288            }
6289        } break;
6290        case GGML_OP_ACC: {
6291            if (src0_needs_grads) {
6292                ggml_add_or_set(ctx, cgraph, isrc0, grad);
6293            }
6294            if (src1_needs_grads) {
6295                const size_t nb1    = ((int32_t *) tensor->op_params)[0];
6296                const size_t nb2    = ((int32_t *) tensor->op_params)[1];
6297                const size_t nb3    = ((int32_t *) tensor->op_params)[2];
6298                const size_t offset = ((int32_t *) tensor->op_params)[3];
6299
6300                struct ggml_tensor * tensor_grad_view = ggml_view_4d(ctx,
6301                    grad, src1->ne[0], src1->ne[1], src1->ne[2], src1->ne[3],
6302                    nb1, nb2, nb3, offset);
6303
6304                ggml_add_or_set(ctx, cgraph, isrc1, ggml_reshape(ctx, ggml_cont(ctx, tensor_grad_view), src1));
6305            }
6306        } break;
6307        case GGML_OP_SUB: {
6308            if (src0_needs_grads) {
6309                ggml_add_or_set(ctx, cgraph, isrc0, grad);
6310            }
6311            if (src1_needs_grads) {
6312                ggml_sub_or_set(ctx, cgraph, isrc1, grad);
6313            }
6314        } break;
6315        case GGML_OP_MUL: {
6316            if (src0_needs_grads) {
6317                ggml_add_or_set(ctx, cgraph, isrc0, ggml_mul(ctx, grad, src1));
6318            }
6319            if (src1_needs_grads) {
6320                struct ggml_tensor * tmp = ggml_mul(ctx, src0, grad);
6321                if (!ggml_are_same_shape(src0, src1)) {
6322                    tmp = ggml_repeat_back(ctx, tmp, src1);
6323                }
6324                ggml_add_or_set(ctx, cgraph, isrc1, tmp);
6325            }
6326        } break;
6327        case GGML_OP_DIV: {
6328            if (src0_needs_grads) {
6329                ggml_add_or_set(ctx, cgraph, isrc0, ggml_div(ctx, grad, src1));
6330            }
6331            if (src1_needs_grads) {
6332                ggml_sub_or_set(ctx, cgraph, isrc1, ggml_mul(ctx, grad, ggml_div(ctx, tensor, src1)));
6333            }
6334        } break;
6335        case GGML_OP_SQR: {
6336            if (src0_needs_grads) {
6337                ggml_add_or_set(ctx, cgraph, isrc0, ggml_scale(ctx, ggml_mul(ctx, src0, grad), 2.0f));
6338            }
6339        } break;
6340        case GGML_OP_SQRT: {
6341            if (src0_needs_grads) {
6342                ggml_add_or_set(ctx, cgraph, isrc0, ggml_scale(ctx, ggml_div(ctx, grad, tensor), 0.5f));
6343            }
6344        } break;
6345        case GGML_OP_LOG: {
6346            if (src0_needs_grads) {
6347                ggml_add_or_set(ctx, cgraph, isrc0, ggml_div(ctx, grad, src0));
6348            }
6349        } break;
6350        case GGML_OP_SIN: {
6351            if (src0_needs_grads) {
6352                ggml_add_or_set(ctx, cgraph, isrc0, ggml_mul(ctx, grad, ggml_cos(ctx, src0)));
6353            }
6354        } break;
6355        case GGML_OP_COS: {
6356            if (src0_needs_grads) {
6357                ggml_sub_or_set(ctx, cgraph, isrc0, ggml_mul(ctx, grad, ggml_sin(ctx, src0)));
6358            }
6359        } break;
6360        case GGML_OP_SUM: {
6361            if (src0_needs_grads) {
6362                ggml_add1_or_set(ctx, cgraph, isrc0, grad);
6363            }
6364        } break;
6365        case GGML_OP_SUM_ROWS: {
6366            if (src0_needs_grads) {
6367                ggml_add_or_set(ctx, cgraph, isrc0, ggml_repeat(ctx, grad, src0));
6368            }
6369        } break;
6370        case GGML_OP_MEAN: {
6371            if (src0_needs_grads) {
6372                ggml_add1_or_set(ctx, cgraph, isrc0, ggml_scale_impl(ctx, grad, 1.0f/src0->ne[0], 0.0, false));
6373            }
6374        } break;
6375        case GGML_OP_REPEAT: {
6376            if (src0_needs_grads) {
6377                ggml_add_or_set(ctx, cgraph, isrc0, ggml_repeat_back(ctx, grad, src0));
6378            }
6379        } break;
6380        case GGML_OP_REPEAT_BACK: {
6381            if (src0_needs_grads) {
6382                ggml_add_or_set(ctx, cgraph, isrc0, ggml_repeat(ctx, grad, src0));
6383            }
6384        } break;
6385        case GGML_OP_RMS_NORM: {
6386            if (src0_needs_grads) {
6387                float eps;
6388                memcpy(&eps, tensor->op_params, sizeof(float));
6389                ggml_add_or_set(ctx, cgraph, isrc0, ggml_rms_norm_back(ctx, grad, src0, eps));
6390            }
6391        } break;
6392        case GGML_OP_MUL_MAT: {
6393            // https://cs231n.github.io/optimization-2/#staged
6394            // # forward pass
6395            // s0 = np.random.randn(5, 10)
6396            // s1 = np.random.randn(10, 3)
6397            // t = s0.dot(s1)
6398
6399            // # now suppose we had the gradient on t from above in the circuit
6400            // dt = np.random.randn(*t.shape) # same shape as t
6401            // ds0 = dt.dot(s1.T) #.T gives the transpose of the matrix
6402            // ds1 = t.T.dot(dt)
6403
6404            // tensor.shape [m,p,qq,rr]
6405            // src0.shape   [n,m,q1,r1]
6406            // src1.shape   [n,p,qq,rr]
6407
6408            if (src0_needs_grads) {
6409                GGML_ASSERT(grad->ne[2] == src1->ne[2]);
6410                GGML_ASSERT(grad->ne[3] == src1->ne[3]);
6411                struct ggml_tensor * tmp =
6412                    ggml_out_prod(ctx, // [n,m,qq,rr]
6413                        src1,          // [n,p,qq,rr]
6414                        grad);         // [m,p,qq,rr]
6415                if (!ggml_are_same_shape(tmp, src0)) {
6416                    GGML_ASSERT(tmp->ne[0] == src0->ne[0]);
6417                    GGML_ASSERT(tmp->ne[1] == src0->ne[1]);
6418                    GGML_ASSERT(tmp->ne[3] == 1);
6419
6420                    const int64_t nr2 = tmp->ne[2] / src0->ne[2];
6421                    const size_t nb2 = tmp->nb[2] * nr2;
6422                    const size_t nb3 = tmp->nb[2];
6423
6424                    tmp = ggml_view_4d(ctx, tmp, src0->ne[0], src0->ne[1], src0->ne[2], nr2, tmp->nb[1], nb2, nb3, 0);
6425                    tmp = ggml_repeat_back(ctx, tmp, src0);
6426                }
6427                ggml_add_or_set(ctx, cgraph, isrc0, tmp);
6428            }
6429            if (src1_needs_grads) {
6430                ggml_add_or_set(ctx, cgraph, isrc1,
6431                        // ggml_mul_mat(ctx,                   // [n,p,qq,rr]
6432                        //     ggml_cont(ctx,                  // [m,n,q1,r1]
6433                        //         ggml_transpose(ctx, src0)), // [m,n,q1,r1]
6434                        //     grad),                          // [m,p,qq,rr]
6435
6436                        // when src0 is bigger than tensor->grad (this is mostly the case in llama),
6437                        // avoid transpose of src0, rather transpose smaller tensor->grad
6438                        // and then use ggml_out_prod
6439                        ggml_out_prod(ctx,      // [n,p,qq,rr]
6440                            src0,               // [n,m,q1,r1]
6441                            ggml_transpose(ctx, // [p,m,qq,rr]
6442                                grad)));        // [m,p,qq,rr]
6443            }
6444        } break;
6445        case GGML_OP_SCALE: {
6446            if (src0_needs_grads) {
6447                float s;
6448                memcpy(&s, tensor->op_params, sizeof(float));
6449                ggml_add_or_set(ctx, cgraph, isrc0, ggml_scale_impl(ctx, grad, s, 0.0, false));
6450            }
6451        } break;
6452        case GGML_OP_SET: {
6453            const size_t nb1    = ((const int32_t *) tensor->op_params)[0];
6454            const size_t nb2    = ((const int32_t *) tensor->op_params)[1];
6455            const size_t nb3    = ((const int32_t *) tensor->op_params)[2];
6456            const size_t offset = ((const int32_t *) tensor->op_params)[3];
6457
6458            struct ggml_tensor * tensor_grad_view = NULL;
6459
6460            if (src0_needs_grads || src1_needs_grads) {
6461                GGML_ASSERT(src0->type == tensor->type);
6462                GGML_ASSERT(!cgraph->grads[isrc0] ||                      cgraph->grads[isrc0]->type == grad->type);
6463                GGML_ASSERT(!cgraph->grads[isrc1] || !src1_needs_grads || cgraph->grads[isrc1]->type == grad->type);
6464
6465                tensor_grad_view = ggml_view_4d(ctx,
6466                    grad, src1->ne[0], src1->ne[1], src1->ne[2], src1->ne[3],
6467                    nb1, nb2, nb3, offset);
6468            }
6469
6470            if (src0_needs_grads) {
6471                struct ggml_tensor * tmp = ggml_neg(ctx, tensor_grad_view);
6472                ggml_add_or_set(ctx, cgraph, isrc0, ggml_acc_impl(ctx, grad, tmp, nb1, nb2, nb3, offset, false));
6473            }
6474
6475            if (src1_needs_grads) {
6476                ggml_add_or_set(ctx, cgraph, isrc1, ggml_reshape(ctx, ggml_cont(ctx, tensor_grad_view), src1));
6477            }
6478        } break;
6479        case GGML_OP_CPY: {
6480            // cpy overwrites value of src1 by src0 and returns view(src1)
6481            // the overwriting is mathematically equivalent to:
6482            // tensor = src0 * 1 + src1 * 0
6483            if (src0_needs_grads) {
6484                // dsrc0 = dtensor * 1
6485                ggml_add_or_set(ctx, cgraph, isrc0, ggml_reshape(ctx, grad, src0));
6486            }
6487            if (src1_needs_grads) {
6488                // dsrc1 = dtensor * 0 -> noop
6489            }
6490        } break;
6491        case GGML_OP_CONT: {
6492            // same as cpy
6493            if (src0_needs_grads) {
6494                GGML_ASSERT(!cgraph->grads[isrc0] || ggml_is_contiguous(cgraph->grads[isrc0]));
6495                GGML_ASSERT(ggml_is_contiguous(grad));
6496                GGML_ASSERT(ggml_nelements(tensor) == ggml_nelements(src0));
6497                ggml_add_or_set(ctx, cgraph, isrc0,
6498                    ggml_are_same_shape(tensor, src0) ? grad : ggml_reshape(ctx, grad, src0));
6499            }
6500        } break;
6501        case GGML_OP_RESHAPE: {
6502            if (src0_needs_grads) {
6503                struct ggml_tensor * grad_cont = ggml_is_contiguous(grad) ? grad : ggml_cont(ctx, grad);
6504                ggml_add_or_set(ctx, cgraph, isrc0, ggml_reshape(ctx, grad_cont, src0));
6505            }
6506        } break;
6507        case GGML_OP_VIEW: {
6508            if (src0_needs_grads) {
6509                size_t offset;
6510
6511                memcpy(&offset, tensor->op_params, sizeof(offset));
6512
6513                size_t nb1 = tensor->nb[1];
6514                size_t nb2 = tensor->nb[2];
6515                size_t nb3 = tensor->nb[3];
6516
6517                if (cgraph->grads[isrc0] && src0->type != cgraph->grads[isrc0]->type) {
6518                    // gradient is typically F32, but src0 could be other type
6519                    size_t ng = ggml_element_size(cgraph->grads[isrc0]);
6520                    size_t n0 = ggml_element_size(src0);
6521                    GGML_ASSERT(offset % n0 == 0);
6522                    GGML_ASSERT(nb1 % n0 == 0);
6523                    GGML_ASSERT(nb2 % n0 == 0);
6524                    GGML_ASSERT(nb3 % n0 == 0);
6525                    offset = (offset / n0) * ng;
6526                    nb1 = (nb1 / n0) * ng;
6527                    nb2 = (nb2 / n0) * ng;
6528                    nb3 = (nb3 / n0) * ng;
6529                }
6530
6531                ggml_acc_or_set(ctx, cgraph, isrc0, grad, nb1, nb2, nb3, offset);
6532            }
6533        } break;
6534        case GGML_OP_PERMUTE: {
6535            if (src0_needs_grads) {
6536                const int32_t * axes = (const int32_t *) tensor->op_params;
6537                const int axis0 = axes[0] & 0x3;
6538                const int axis1 = axes[1] & 0x3;
6539                const int axis2 = axes[2] & 0x3;
6540                const int axis3 = axes[3] & 0x3;
6541                int axb[4] = {0,0,0,0}; // axes backward
6542                axb[axis0] = 0;
6543                axb[axis1] = 1;
6544                axb[axis2] = 2;
6545                axb[axis3] = 3;
6546                ggml_add_or_set(ctx, cgraph, isrc0, ggml_permute(ctx, grad, axb[0], axb[1], axb[2], axb[3]));
6547            }
6548        } break;
6549        case GGML_OP_TRANSPOSE: {
6550            if (src0_needs_grads) {
6551                ggml_add_or_set(ctx, cgraph, isrc0, ggml_transpose(ctx, grad));
6552            }
6553        } break;
6554        case GGML_OP_GET_ROWS: {
6555            if (src0_needs_grads) {
6556                ggml_add_or_set(ctx, cgraph, isrc0, ggml_get_rows_back(ctx, grad, src1, src0));
6557            }
6558            if (src1_needs_grads) {
6559                // noop
6560            }
6561        } break;
6562        case GGML_OP_DIAG_MASK_INF: {
6563            if (src0_needs_grads) {
6564                /* ggml_diag_mask_inf_impl() shouldn't be here */
6565                /* ref:  https://github.com/ggml-org/llama.cpp/pull/4203#discussion_r1412377992 */
6566                const int n_past = ((const int32_t *) tensor->op_params)[0];
6567                ggml_add_or_set(ctx, cgraph, isrc0, ggml_diag_mask_zero_impl(ctx, grad, n_past, false));
6568            }
6569        } break;
6570        case GGML_OP_DIAG_MASK_ZERO: {
6571            if (src0_needs_grads) {
6572                const int n_past = ((const int32_t *) tensor->op_params)[0];
6573                ggml_add_or_set(ctx, cgraph, isrc0, ggml_diag_mask_zero_impl(ctx, grad, n_past, false));
6574            }
6575        } break;
6576        case GGML_OP_SOFT_MAX: {
6577            if (src0_needs_grads) {
6578                float scale    = 1.0f;
6579                float max_bias = 0.0f;
6580
6581                memcpy(&scale,    (const float *) tensor->op_params + 0, sizeof(float));
6582                memcpy(&max_bias, (const float *) tensor->op_params + 1, sizeof(float));
6583
6584                ggml_add_or_set(ctx, cgraph, isrc0, ggml_soft_max_ext_back(ctx, grad, tensor, scale, max_bias));
6585            }
6586            GGML_ASSERT((!src1 || !src1_needs_grads) && "backward pass for softmax mask not implemented");
6587        } break;
6588        case GGML_OP_ROPE: {
6589            if (src0_needs_grads) {
6590                //const int n_past = ((int32_t *) tensor->op_params)[0];
6591                const int n_dims     = ((const int32_t *) tensor->op_params)[1];
6592                const int mode       = ((const int32_t *) tensor->op_params)[2];
6593                //const int n_ctx      = ((int32_t *) tensor->op_params)[3];
6594                const int n_ctx_orig = ((const int32_t *) tensor->op_params)[4];
6595                float freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow;
6596                int sections[4] = {0, 0, 0, 0};
6597
6598                memcpy(&freq_base,   (const float *) tensor->op_params +  5, sizeof(float));
6599                memcpy(&freq_scale,  (const float *) tensor->op_params +  6, sizeof(float));
6600                memcpy(&ext_factor,  (const float *) tensor->op_params +  7, sizeof(float));
6601                memcpy(&attn_factor, (const float *) tensor->op_params +  8, sizeof(float));
6602                memcpy(&beta_fast,   (const float *) tensor->op_params +  9, sizeof(float));
6603                memcpy(&beta_slow,   (const float *) tensor->op_params + 10, sizeof(float));
6604                memcpy(&sections,                    tensor->op_params + 11, sizeof(sections));
6605
6606                struct ggml_tensor * rope_back = grad->ne[2] == src1->ne[0] ?
6607                    ggml_rope_ext_back(ctx, grad, src1, src2, n_dims,
6608                        mode, n_ctx_orig, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow) :
6609                    ggml_rope_multi_back(ctx, grad, src1, src2, n_dims, sections,
6610                        mode, n_ctx_orig, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow);
6611                ggml_add_or_set(ctx, cgraph, isrc0, rope_back);
6612            }
6613            GGML_ASSERT((!src2 || !src2_needs_grads) && "gradients for freq factors not implemented");
6614        } break;
6615        case GGML_OP_IM2COL: {
6616            if (src1_needs_grads) {
6617                const int32_t s0    = ggml_get_op_params_i32(tensor, 0);
6618                const int32_t s1    = ggml_get_op_params_i32(tensor, 1);
6619                const int32_t p0    = ggml_get_op_params_i32(tensor, 2);
6620                const int32_t p1    = ggml_get_op_params_i32(tensor, 3);
6621                const int32_t d0    = ggml_get_op_params_i32(tensor, 4);
6622                const int32_t d1    = ggml_get_op_params_i32(tensor, 5);
6623                const bool    is_2D = ggml_get_op_params_i32(tensor, 6) == 1;
6624
6625                ggml_add_or_set(ctx, cgraph, isrc1, ggml_im2col_back(ctx, grad, src0, src1->ne, s0, s1, p0, p1, d0, d1, is_2D));
6626            }
6627        } break;
6628        case GGML_OP_POOL_2D: {
6629            if (src0_needs_grads) {
6630                const enum ggml_op_pool op = ggml_get_op_params_i32(tensor, 0);
6631                const      int32_t      k0 = ggml_get_op_params_i32(tensor, 1);
6632                const      int32_t      k1 = ggml_get_op_params_i32(tensor, 2);
6633                const      int32_t      s0 = ggml_get_op_params_i32(tensor, 3);
6634                const      int32_t      s1 = ggml_get_op_params_i32(tensor, 4);
6635                const      int32_t      p0 = ggml_get_op_params_i32(tensor, 5);
6636                const      int32_t      p1 = ggml_get_op_params_i32(tensor, 6);
6637
6638                ggml_add_or_set(ctx, cgraph, isrc0, ggml_pool_2d_back(ctx, grad, src0, op, k0, k1, s0, s1, p0, p1));
6639            }
6640        } break;
6641        case GGML_OP_WIN_PART:
6642        case GGML_OP_WIN_UNPART:
6643        case GGML_OP_UNARY: {
6644            switch (ggml_get_unary_op(tensor)) {
6645                case GGML_UNARY_OP_ABS: {
6646                    if (src0_needs_grads) {
6647                        ggml_add_or_set(ctx, cgraph, isrc0, ggml_mul(ctx, ggml_sgn(ctx, src0), grad));
6648                    }
6649                } break;
6650                case GGML_UNARY_OP_SGN: {
6651                    // noop
6652                } break;
6653                case GGML_UNARY_OP_NEG: {
6654                    if (src0_needs_grads) {
6655                        ggml_sub_or_set(ctx, cgraph, isrc0, grad);
6656                    }
6657                } break;
6658                case GGML_UNARY_OP_STEP: {
6659                    // noop
6660                } break;
6661                case GGML_UNARY_OP_RELU: {
6662                    if (src0_needs_grads) {
6663                        ggml_add_or_set(ctx, cgraph, isrc0, ggml_mul(ctx, ggml_step(ctx, src0), grad));
6664                    }
6665                } break;
6666                case GGML_UNARY_OP_SILU: {
6667                    if (src0_needs_grads) {
6668                        ggml_add_or_set(ctx, cgraph, isrc0, ggml_silu_back(ctx, grad, src0));
6669                    }
6670                } break;
6671                case GGML_UNARY_OP_EXP: {
6672                    if (src0_needs_grads) {
6673                        ggml_add_or_set(ctx, cgraph, isrc0, ggml_mul(ctx, tensor, grad));
6674                    }
6675                } break;
6676                case GGML_UNARY_OP_EXPM1: {
6677                    if (src0_needs_grads) {
6678                        ggml_add_or_set(ctx, cgraph, isrc0, ggml_mul(ctx, grad, ggml_exp(ctx, src0)));
6679                    }
6680                } break;
6681                case GGML_UNARY_OP_SOFTPLUS: {
6682                    if (src0_needs_grads) {
6683                        ggml_add_or_set(ctx, cgraph, isrc0, ggml_mul(ctx, grad, ggml_sigmoid(ctx, src0)));
6684                    }
6685                } break;
6686                default: {
6687                    fprintf(stderr, "%s: unsupported unary op for backward pass: %s\n",
6688                        __func__, ggml_unary_op_name(ggml_get_unary_op(tensor)));
6689                    GGML_ABORT("fatal error");
6690                } //break;
6691            }
6692        } break;
6693        case GGML_OP_CROSS_ENTROPY_LOSS: {
6694            if (src0_needs_grads) {
6695                ggml_add_or_set(ctx, cgraph, isrc0, ggml_cross_entropy_loss_back(ctx, grad, src0, src1));
6696            }
6697            GGML_ASSERT(!src1_needs_grads && "backward pass for labels not implemented");
6698        } break;
6699        case GGML_OP_GLU: {
6700            switch (ggml_get_glu_op(tensor)) {
6701                case GGML_GLU_OP_SWIGLU: {
6702                    if (src0_needs_grads) {
6703                        GGML_ASSERT(src1 && "backward pass only implemented for split swiglu");
6704                        ggml_add_or_set(ctx, cgraph, isrc0, ggml_silu_back(ctx, ggml_mul(ctx, grad, src1), src0));
6705                    }
6706                    if (src1_needs_grads) {
6707                        ggml_add_or_set(ctx, cgraph, isrc1, ggml_mul(ctx, ggml_silu(ctx, src0), grad));
6708                    }
6709                } break;
6710                default: {
6711                    GGML_ABORT("unsupported glu op for backward pass: %s", ggml_glu_op_name(ggml_get_glu_op(tensor)));
6712                } //break;
6713            }
6714        } break;
6715        case GGML_OP_NONE: {
6716            // noop
6717        } break;
6718        case GGML_OP_COUNT:
6719        default: {
6720            GGML_ABORT("%s: unsupported ggml op for backward pass: %s\n", __func__, ggml_op_name(tensor->op));
6721        } //break;
6722    }
6723
6724    GGML_ASSERT(!src0_needs_grads || ggml_are_same_shape(src0, cgraph->grads[isrc0]));
6725    GGML_ASSERT(!src1_needs_grads || ggml_are_same_shape(src1, cgraph->grads[isrc1]));
6726    GGML_ASSERT(!src2_needs_grads || ggml_are_same_shape(src2, cgraph->grads[isrc2]));
6727}
6728
6729static size_t ggml_visit_parents_graph(struct ggml_cgraph * cgraph, struct ggml_tensor * node, bool compute) {
6730    if (node->op != GGML_OP_NONE && compute) {
6731        node->flags |= GGML_TENSOR_FLAG_COMPUTE;
6732    }
6733
6734    const size_t node_hash_pos = ggml_hash_find(&cgraph->visited_hash_set, node);
6735    GGML_ASSERT(node_hash_pos != GGML_HASHSET_FULL);
6736
6737    if (ggml_bitset_get(cgraph->visited_hash_set.used, node_hash_pos)) {
6738        // already visited
6739
6740        if (compute) {
6741            // update the compute flag regardless
6742            for (int i = 0; i < GGML_MAX_SRC; ++i) {
6743                struct ggml_tensor * src = node->src[i];
6744                if (src && ((src->flags & GGML_TENSOR_FLAG_COMPUTE) == 0)) {
6745                    ggml_visit_parents_graph(cgraph, src, true);
6746                }
6747            }
6748        }
6749
6750        return node_hash_pos;
6751    }
6752
6753    // This is the first time we see this node in the current graph.
6754    cgraph->visited_hash_set.keys[node_hash_pos] = node;
6755    ggml_bitset_set(cgraph->visited_hash_set.used, node_hash_pos);
6756    cgraph->use_counts[node_hash_pos] = 0;
6757
6758    for (int i = 0; i < GGML_MAX_SRC; ++i) {
6759        const int k =
6760            (cgraph->order == GGML_CGRAPH_EVAL_ORDER_LEFT_TO_RIGHT) ? i :
6761            (cgraph->order == GGML_CGRAPH_EVAL_ORDER_RIGHT_TO_LEFT) ? (GGML_MAX_SRC-1-i) :
6762            /* unknown order, just fall back to using i */ i;
6763
6764        struct ggml_tensor * src = node->src[k];
6765        if (src) {
6766            const size_t src_hash_pos = ggml_visit_parents_graph(cgraph, src, compute);
6767
6768            // Update the use count for this operand.
6769            cgraph->use_counts[src_hash_pos]++;
6770        }
6771    }
6772
6773    if (node->op == GGML_OP_NONE && !(node->flags & GGML_TENSOR_FLAG_PARAM)) {
6774        // reached a leaf node, not part of the gradient graph (e.g. a constant)
6775        GGML_ASSERT(cgraph->n_leafs < cgraph->size);
6776
6777        if (strlen(node->name) == 0) {
6778            ggml_format_name(node, "leaf_%d", cgraph->n_leafs);
6779        }
6780
6781        cgraph->leafs[cgraph->n_leafs] = node;
6782        cgraph->n_leafs++;
6783    } else {
6784        GGML_ASSERT(cgraph->n_nodes < cgraph->size);
6785
6786        if (strlen(node->name) == 0) {
6787            ggml_format_name(node, "node_%d", cgraph->n_nodes);
6788        }
6789
6790        cgraph->nodes[cgraph->n_nodes] = node;
6791        cgraph->n_nodes++;
6792    }
6793
6794    return node_hash_pos;
6795}
6796
6797static void ggml_build_forward_impl(struct ggml_cgraph * cgraph, struct ggml_tensor * tensor, bool expand, bool compute) {
6798    if (!expand) {
6799        // TODO: this branch isn't accessible anymore, maybe move this to ggml_build_forward_expand
6800        ggml_graph_clear(cgraph);
6801    }
6802
6803    const int n_old = cgraph->n_nodes;
6804
6805    ggml_visit_parents_graph(cgraph, tensor, compute);
6806
6807    const int n_new = cgraph->n_nodes - n_old;
6808    GGML_PRINT_DEBUG("%s: visited %d new nodes\n", __func__, n_new);
6809
6810    if (n_new > 0) {
6811        // the last added node should always be starting point
6812        GGML_ASSERT(cgraph->nodes[cgraph->n_nodes - 1] == tensor);
6813    }
6814}
6815
6816struct ggml_tensor * ggml_build_forward_select(
6817        struct ggml_cgraph  * cgraph,
6818        struct ggml_tensor ** tensors,
6819        int                   n_tensors,
6820        int                   idx) {
6821    GGML_ASSERT(idx >= 0 && idx < n_tensors);
6822
6823    for (int i = 0; i < n_tensors; i++) {
6824        ggml_build_forward_impl(cgraph, tensors[i], true, i == idx ? true : false);
6825    }
6826
6827    return tensors[idx];
6828}
6829
6830void ggml_build_forward_expand(struct ggml_cgraph * cgraph, struct ggml_tensor * tensor) {
6831    ggml_build_forward_impl(cgraph, tensor, true, true);
6832}
6833
6834void ggml_build_backward_expand(
6835        struct ggml_context *  ctx,
6836        struct ggml_cgraph  *  cgraph,
6837        struct ggml_tensor  ** grad_accs) {
6838    GGML_ASSERT(cgraph->n_nodes > 0);
6839    GGML_ASSERT(cgraph->grads);
6840    GGML_ASSERT(cgraph->grad_accs);
6841
6842    const int n_nodes_f = cgraph->n_nodes;
6843
6844    memset(cgraph->grads,     0, cgraph->visited_hash_set.size*sizeof(struct ggml_tensor *));
6845    memset(cgraph->grad_accs, 0, cgraph->visited_hash_set.size*sizeof(struct ggml_tensor *));
6846    bool * grads_needed = calloc(cgraph->visited_hash_set.size, sizeof(bool));
6847
6848    {
6849        bool any_params = false;
6850        bool any_loss   = false;
6851        for (int i = 0; i < n_nodes_f; ++i) {
6852            struct ggml_tensor * node = cgraph->nodes[i];
6853            any_params = any_params || (node->flags & GGML_TENSOR_FLAG_PARAM);
6854            any_loss   = any_loss   || (node->flags & GGML_TENSOR_FLAG_LOSS);
6855        }
6856        GGML_ASSERT(any_params && "no trainable parameters found, did you forget to call ggml_set_param?");
6857        GGML_ASSERT(any_loss && "no training loss found, did you forget to call ggml_set_loss?");
6858    }
6859
6860    for (int i = 0; i < n_nodes_f; ++i) {
6861        struct ggml_tensor * node = cgraph->nodes[i];
6862
6863        if (node->type == GGML_TYPE_I32) {
6864            continue;
6865        }
6866
6867        bool node_needs_grad = (node->flags & GGML_TENSOR_FLAG_PARAM) || (node->flags & GGML_TENSOR_FLAG_LOSS);
6868        bool ignore_src[GGML_MAX_SRC] = {false};
6869        switch (node->op) {
6870            // gradients in node->src[0] for one reason or another have no effect on output gradients
6871            case GGML_OP_IM2COL:      // only used for its shape
6872            case GGML_OP_IM2COL_BACK: // same as IM2COL
6873                ignore_src[0] = true;
6874                break;
6875            case GGML_OP_UNARY: {
6876                const enum ggml_unary_op uop = ggml_get_unary_op(node);
6877                // SGN and STEP unary ops are piecewise constant
6878                if (uop == GGML_UNARY_OP_SGN || uop == GGML_UNARY_OP_STEP) {
6879                    ignore_src[0] = true;
6880                }
6881            } break;
6882
6883            // gradients in node->src[1] for one reason or another have no effect on output gradients
6884            case GGML_OP_CPY:           // gradients in CPY target are irrelevant
6885            case GGML_OP_GET_ROWS:      // row indices not differentiable
6886            case GGML_OP_GET_ROWS_BACK: // same as for GET_ROWS
6887            case GGML_OP_ROPE:          // positions not differentiable
6888                ignore_src[1] = true;
6889                break;
6890
6891            default:
6892                break;
6893        }
6894        for (int j = 0; j < GGML_MAX_SRC; ++j) {
6895            if (!node->src[j] || ignore_src[j] || !grads_needed[ggml_hash_find(&cgraph->visited_hash_set, node->src[j])]) {
6896                continue;
6897            }
6898            GGML_ASSERT(node->src[j]->type == GGML_TYPE_F32 || node->src[j]->type == GGML_TYPE_F16);
6899            node_needs_grad = true;
6900            break;
6901        }
6902        if (!node_needs_grad) {
6903            continue;
6904        }
6905
6906        // inplace operations are currently not supported
6907        GGML_ASSERT(!node->view_src || node->op == GGML_OP_CPY || node->op == GGML_OP_VIEW ||
6908            node->op == GGML_OP_RESHAPE || node->op == GGML_OP_PERMUTE || node->op == GGML_OP_TRANSPOSE);
6909
6910        const size_t ihash = ggml_hash_find(&cgraph->visited_hash_set, node);
6911        GGML_ASSERT(ihash != GGML_HASHSET_FULL);
6912        GGML_ASSERT(ggml_bitset_get(cgraph->visited_hash_set.used, ihash));
6913        if (grad_accs && grad_accs[i]) {
6914            cgraph->grad_accs[ihash] = grad_accs[i];
6915            cgraph->grads[ihash]     = cgraph->grad_accs[ihash];
6916        } else if (node->flags & GGML_TENSOR_FLAG_LOSS) {
6917            // loss tensors always need a gradient accumulator
6918            cgraph->grad_accs[ihash] = ggml_new_tensor(ctx, GGML_TYPE_F32, GGML_MAX_DIMS, node->ne);
6919            cgraph->grads[ihash]     = cgraph->grad_accs[ihash];
6920        }
6921        grads_needed[ihash] = true;
6922    }
6923
6924    for (int i = n_nodes_f - 1; i >= 0; --i) {
6925        // inplace operations to add gradients are not created by ggml_compute_backward except for gradient accumulation
6926        // use allocator to automatically make inplace operations
6927        ggml_compute_backward(ctx, cgraph, i, grads_needed);
6928    }
6929
6930    free(grads_needed);
6931}
6932
6933static void * incr_ptr_aligned(void ** p, size_t size, size_t align) {
6934    void * ptr = *p;
6935    ptr = (void *) GGML_PAD((uintptr_t) ptr, align);
6936    *p = (void *) ((char *) ptr + size);
6937    return ptr;
6938}
6939
6940static size_t ggml_graph_nbytes(size_t size, bool grads) {
6941    size_t hash_size = ggml_hash_size(size * 2);
6942    void * p = 0;
6943    incr_ptr_aligned(&p, sizeof(struct ggml_cgraph), 1);
6944    incr_ptr_aligned(&p, size * sizeof(struct ggml_tensor *), sizeof(struct ggml_tensor *)); // nodes
6945    incr_ptr_aligned(&p, size * sizeof(struct ggml_tensor *), sizeof(struct ggml_tensor *)); // leafs
6946    incr_ptr_aligned(&p, hash_size * sizeof(int32_t), sizeof(int32_t)); // use_counts
6947    incr_ptr_aligned(&p, hash_size * sizeof(struct ggml_tensor *), sizeof(struct ggml_tensor *)); // hash keys
6948    if (grads) {
6949        incr_ptr_aligned(&p, hash_size * sizeof(struct ggml_tensor *), sizeof(struct ggml_tensor *)); // grads
6950        incr_ptr_aligned(&p, hash_size * sizeof(struct ggml_tensor *), sizeof(struct ggml_tensor *)); // grad_accs
6951    }
6952    incr_ptr_aligned(&p, ggml_bitset_size(hash_size) * sizeof(ggml_bitset_t), sizeof(ggml_bitset_t));
6953
6954    size_t nbytes = (size_t) p;
6955    return nbytes;
6956}
6957
6958size_t ggml_graph_overhead_custom(size_t size, bool grads) {
6959    return GGML_OBJECT_SIZE + GGML_PAD(ggml_graph_nbytes(size, grads), GGML_MEM_ALIGN);
6960}
6961
6962size_t ggml_graph_overhead(void) {
6963    return ggml_graph_overhead_custom(GGML_DEFAULT_GRAPH_SIZE, false);
6964}
6965
6966struct ggml_cgraph * ggml_new_graph_custom(struct ggml_context * ctx, size_t size, bool grads) {
6967    const size_t obj_size = ggml_graph_nbytes(size, grads);
6968    struct ggml_object * obj = ggml_new_object(ctx, GGML_OBJECT_TYPE_GRAPH, obj_size);
6969    struct ggml_cgraph * cgraph = (struct ggml_cgraph *) ((char *) ctx->mem_buffer + obj->offs);
6970
6971    // the size of the hash table is doubled since it needs to hold both nodes and leafs
6972    size_t hash_size = ggml_hash_size(size * 2);
6973
6974    void * p = cgraph + 1;
6975
6976    struct ggml_tensor ** nodes_ptr      =         incr_ptr_aligned(&p, size      * sizeof(struct ggml_tensor *), sizeof(struct ggml_tensor *));
6977    struct ggml_tensor ** leafs_ptr      =         incr_ptr_aligned(&p, size      * sizeof(struct ggml_tensor *), sizeof(struct ggml_tensor *));
6978    int32_t             * use_counts_ptr =         incr_ptr_aligned(&p, hash_size * sizeof(int32_t), sizeof(int32_t));
6979    struct ggml_tensor ** hash_keys_ptr  =         incr_ptr_aligned(&p, hash_size * sizeof(struct ggml_tensor *), sizeof(struct ggml_tensor *));
6980    struct ggml_tensor ** grads_ptr      = grads ? incr_ptr_aligned(&p, hash_size * sizeof(struct ggml_tensor *), sizeof(struct ggml_tensor *)) : NULL;
6981    struct ggml_tensor ** grad_accs_ptr  = grads ? incr_ptr_aligned(&p, hash_size * sizeof(struct ggml_tensor *), sizeof(struct ggml_tensor *)) : NULL;
6982
6983    ggml_bitset_t * hash_used = incr_ptr_aligned(&p, ggml_bitset_size(hash_size) * sizeof(ggml_bitset_t), sizeof(ggml_bitset_t));
6984
6985    // check that we allocated the correct amount of memory
6986    assert(obj_size == (size_t)((char *)p - (char *)cgraph));
6987
6988    *cgraph = (struct ggml_cgraph) {
6989        /*.size         =*/ size,
6990        /*.n_nodes      =*/ 0,
6991        /*.n_leafs      =*/ 0,
6992        /*.nodes        =*/ nodes_ptr,
6993        /*.grads        =*/ grads_ptr,
6994        /*.grad_accs    =*/ grad_accs_ptr,
6995        /*.leafs        =*/ leafs_ptr,
6996        /*.use_counts   =*/ use_counts_ptr,
6997        /*.hash_table   =*/ { hash_size, hash_used, hash_keys_ptr },
6998        /*.order        =*/ GGML_CGRAPH_EVAL_ORDER_LEFT_TO_RIGHT,
6999    };
7000
7001    ggml_hash_set_reset(&cgraph->visited_hash_set);
7002    if (grads) {
7003        memset(cgraph->grads,     0, hash_size*sizeof(struct ggml_tensor *));
7004        memset(cgraph->grad_accs, 0, hash_size*sizeof(struct ggml_tensor *));
7005    }
7006
7007    return cgraph;
7008}
7009
7010struct ggml_cgraph * ggml_new_graph(struct ggml_context * ctx) {
7011    return ggml_new_graph_custom(ctx, GGML_DEFAULT_GRAPH_SIZE, false);
7012}
7013
7014struct ggml_cgraph ggml_graph_view(struct ggml_cgraph * cgraph0, int i0, int i1) {
7015    struct ggml_cgraph cgraph = {
7016        /*.size             =*/ 0,
7017        /*.n_nodes          =*/ i1 - i0,
7018        /*.n_leafs          =*/ 0,
7019        /*.nodes            =*/ cgraph0->nodes + i0,
7020        /*.grads            =*/ NULL, // gradients would need visited_hash_set
7021        /*.grad_accs        =*/ NULL,
7022        /*.leafs            =*/ NULL,
7023        /*.use_counts       =*/ cgraph0->use_counts,
7024        /*.visited_hash_set =*/ cgraph0->visited_hash_set,
7025        /*.order            =*/ cgraph0->order,
7026    };
7027
7028    return cgraph;
7029}
7030
7031void ggml_graph_cpy(struct ggml_cgraph * src, struct ggml_cgraph * dst) {
7032    GGML_ASSERT(dst->size >= src->n_leafs);
7033    GGML_ASSERT(dst->size >= src->n_nodes);
7034    GGML_ASSERT(dst->visited_hash_set.size >= src->visited_hash_set.size);
7035
7036    dst->n_leafs = src->n_leafs;
7037    dst->n_nodes = src->n_nodes;
7038    dst->order   = src->order;
7039
7040    for (int i = 0; i < src->n_leafs; ++i) {
7041        dst->leafs[i] = src->leafs[i];
7042    }
7043
7044    for (int i = 0; i < src->n_nodes; ++i) {
7045        dst->nodes[i] = src->nodes[i];
7046    }
7047
7048    for (size_t i = 0; i < src->visited_hash_set.size; ++i) {
7049        // copy all hashset keys (tensors) that are in use
7050        if (ggml_bitset_get(src->visited_hash_set.used, i)) {
7051            size_t new_hash_pos = ggml_hash_insert(&dst->visited_hash_set, src->visited_hash_set.keys[i]);
7052            dst->use_counts[new_hash_pos] = src->use_counts[i];
7053        }
7054    }
7055
7056    if (dst->grads) {
7057        memset(dst->grads,     0, dst->visited_hash_set.size*sizeof(struct ggml_tensor *));
7058        memset(dst->grad_accs, 0, dst->visited_hash_set.size*sizeof(struct ggml_tensor *));
7059    }
7060    if (src->grads) {
7061        GGML_ASSERT(dst->grads     != NULL);
7062        GGML_ASSERT(dst->grad_accs != NULL);
7063        for (int i = 0; i < src->n_nodes; ++i) {
7064            const size_t igrad_src = ggml_hash_find(&src->visited_hash_set, src->nodes[i]);
7065            const size_t igrad_dst = ggml_hash_find(&dst->visited_hash_set, dst->nodes[i]);
7066
7067            GGML_ASSERT(igrad_src != GGML_HASHSET_FULL);
7068            GGML_ASSERT(ggml_bitset_get(src->visited_hash_set.used, igrad_src));
7069            GGML_ASSERT(igrad_dst != GGML_HASHSET_FULL);
7070            GGML_ASSERT(ggml_bitset_get(dst->visited_hash_set.used, igrad_dst));
7071
7072            dst->grads[igrad_dst]     = src->grads[igrad_src];
7073            dst->grad_accs[igrad_dst] = src->grad_accs[igrad_src];
7074        }
7075    }
7076}
7077
7078struct ggml_cgraph * ggml_graph_dup(struct ggml_context * ctx, struct ggml_cgraph * cgraph, bool force_grads) {
7079    struct ggml_cgraph * result = ggml_new_graph_custom(ctx, cgraph->size, cgraph->grads || force_grads);
7080    ggml_graph_cpy(cgraph, result);
7081    return result;
7082}
7083
7084struct ggml_tensor * ggml_set_zero(struct ggml_tensor * tensor) {
7085    if (ggml_is_empty(tensor)) {
7086        return tensor;
7087    }
7088    if (tensor->buffer) {
7089        ggml_backend_tensor_memset(tensor, 0, 0, ggml_nbytes(tensor));
7090    } else {
7091        GGML_ASSERT(tensor->data);
7092        memset(tensor->data, 0, ggml_nbytes(tensor));
7093    }
7094    return tensor;
7095}
7096
7097void ggml_graph_reset(struct ggml_cgraph * cgraph) {
7098    if (!cgraph) {
7099        return;
7100    }
7101    GGML_ASSERT(cgraph->grads != NULL);
7102
7103    for (int i = 0; i < cgraph->n_nodes; i++) {
7104        struct ggml_tensor * node     = cgraph->nodes[i];
7105        struct ggml_tensor * grad_acc = ggml_graph_get_grad_acc(cgraph, node);
7106
7107        if (node->op == GGML_OP_OPT_STEP_ADAMW) {
7108            // clear momenta
7109            ggml_set_zero(node->src[2]);
7110            ggml_set_zero(node->src[3]);
7111        }
7112
7113        // initial gradients of loss should be 1, 0 otherwise
7114        if (grad_acc) {
7115            if (node->flags & GGML_TENSOR_FLAG_LOSS) {
7116                GGML_ASSERT(grad_acc->type == GGML_TYPE_F32);
7117                GGML_ASSERT(ggml_is_scalar(grad_acc));
7118
7119                const float onef = 1.0f;
7120                if (grad_acc->buffer) {
7121                    ggml_backend_tensor_set(grad_acc, &onef, 0, sizeof(float));
7122                } else {
7123                    GGML_ASSERT(grad_acc->data);
7124                    *((float *) grad_acc->data) = onef;
7125                }
7126            } else {
7127                ggml_set_zero(grad_acc);
7128            }
7129        }
7130    }
7131}
7132
7133void ggml_graph_clear(struct ggml_cgraph * cgraph) {
7134    cgraph->n_leafs = 0;
7135    cgraph->n_nodes = 0;
7136    ggml_hash_set_reset(&cgraph->visited_hash_set);
7137}
7138
7139int ggml_graph_size(struct ggml_cgraph * cgraph) {
7140    return cgraph->size;
7141}
7142
7143struct ggml_tensor * ggml_graph_node(struct ggml_cgraph * cgraph, int i) {
7144    if (i < 0) {
7145        GGML_ASSERT(cgraph->n_nodes + i >= 0);
7146        return cgraph->nodes[cgraph->n_nodes + i];
7147    }
7148
7149    GGML_ASSERT(i < cgraph->n_nodes);
7150    return cgraph->nodes[i];
7151}
7152
7153struct ggml_tensor ** ggml_graph_nodes(struct ggml_cgraph * cgraph) {
7154    return cgraph->nodes;
7155}
7156
7157int ggml_graph_n_nodes(struct ggml_cgraph * cgraph) {
7158    return cgraph->n_nodes;
7159}
7160
7161void ggml_graph_add_node(struct ggml_cgraph * cgraph, struct ggml_tensor * tensor) {
7162    GGML_ASSERT(cgraph->size > cgraph->n_nodes);
7163    cgraph->nodes[cgraph->n_nodes] = tensor;
7164    cgraph->n_nodes++;
7165}
7166
7167struct ggml_tensor * ggml_graph_get_tensor(const struct ggml_cgraph * cgraph, const char * name) {
7168    for (int i = 0; i < cgraph->n_leafs; i++) {
7169        struct ggml_tensor * leaf = cgraph->leafs[i];
7170
7171        if (strcmp(leaf->name, name) == 0) {
7172            return leaf;
7173        }
7174    }
7175
7176    for (int i = 0; i < cgraph->n_nodes; i++) {
7177        struct ggml_tensor * node = cgraph->nodes[i];
7178
7179        if (strcmp(node->name, name) == 0) {
7180            return node;
7181        }
7182    }
7183
7184    return NULL;
7185}
7186
7187struct ggml_tensor * ggml_graph_get_grad(const struct ggml_cgraph * cgraph, const struct ggml_tensor * node) {
7188    const size_t igrad = ggml_hash_find(&cgraph->visited_hash_set, node);
7189    return igrad != GGML_HASHSET_FULL && ggml_bitset_get(cgraph->visited_hash_set.used, igrad) && cgraph->grads ? cgraph->grads[igrad] : NULL;
7190}
7191
7192struct ggml_tensor * ggml_graph_get_grad_acc(const struct ggml_cgraph * cgraph, const struct ggml_tensor * node) {
7193    const size_t igrad = ggml_hash_find(&cgraph->visited_hash_set, node);
7194    return igrad != GGML_HASHSET_FULL && ggml_bitset_get(cgraph->visited_hash_set.used, igrad) && cgraph->grad_accs ? cgraph->grad_accs[igrad] : NULL;
7195}
7196
7197void ggml_graph_print(const struct ggml_cgraph * cgraph) {
7198    GGML_LOG_INFO("=== GRAPH ===\n");
7199
7200    GGML_LOG_INFO("n_nodes = %d\n", cgraph->n_nodes);
7201    for (int i = 0; i < cgraph->n_nodes; i++) {
7202        struct ggml_tensor * node = cgraph->nodes[i];
7203
7204        GGML_LOG_INFO(" - %3d: [ %5" PRId64 ", %5" PRId64 ", %5" PRId64 "] %16s %s\n",
7205                i,
7206                node->ne[0], node->ne[1], node->ne[2],
7207                ggml_op_name(node->op), (node->flags & GGML_TENSOR_FLAG_PARAM) ? "x" :
7208                      ggml_graph_get_grad(cgraph, node) ? "g" : " ");
7209    }
7210
7211    GGML_LOG_INFO("n_leafs = %d\n", cgraph->n_leafs);
7212    for (int i = 0; i < cgraph->n_leafs; i++) {
7213        struct ggml_tensor * node = cgraph->leafs[i];
7214
7215        GGML_LOG_INFO(" - %3d: [ %5" PRId64 ", %5" PRId64 "] %8s %16s\n",
7216                i,
7217                node->ne[0], node->ne[1],
7218                ggml_op_name(node->op),
7219                ggml_get_name(node));
7220    }
7221
7222    GGML_LOG_INFO("========================================\n");
7223}
7224
7225static int ggml_node_list_find_tensor(const struct ggml_cgraph * cgraph,
7226                                      const int *                idxs,
7227                                      int                        count,
7228                                      const struct ggml_tensor * tensor) {
7229    GGML_ASSERT(cgraph && idxs);
7230    for (int i = 0; i < count; ++i) {
7231        const int node_idx = idxs[i];
7232
7233        if (node_idx >= cgraph->n_nodes) {
7234            return -1;
7235        }
7236        if (cgraph->nodes[node_idx] == tensor) {
7237            return i;
7238        }
7239    }
7240    return -1;
7241}
7242
7243bool ggml_can_fuse_subgraph_ext(const struct ggml_cgraph * cgraph,
7244                                const int *                node_idxs,
7245                                int                        count,
7246                                const enum ggml_op *       ops,
7247                                const int *                outputs,
7248                                int                        num_outputs) {
7249    GGML_ASSERT(outputs && num_outputs > 0);
7250
7251    for (int i = 0; i < count; ++i) {
7252        if (node_idxs[i] >= cgraph->n_nodes) {
7253            return false;
7254        }
7255
7256        const struct ggml_tensor * node = cgraph->nodes[node_idxs[i]];
7257
7258        if (node->op != ops[i]) {
7259            return false;
7260        }
7261
7262        if ((node->flags & GGML_TENSOR_FLAG_COMPUTE) == 0) {
7263            return false;
7264        }
7265
7266        if (ggml_node_list_find_tensor(cgraph, outputs, num_outputs, node) != -1) {
7267            continue;
7268        }
7269
7270        if (node->flags & GGML_TENSOR_FLAG_OUTPUT) {
7271            return false;
7272        }
7273
7274        int subgraph_uses = 0;
7275        for (int j = i + 1; j < count; ++j) {
7276            const struct ggml_tensor * other_node = cgraph->nodes[node_idxs[j]];
7277            for (int src_idx = 0; src_idx < GGML_MAX_SRC; src_idx++) {
7278                if (other_node->src[src_idx] == node) {
7279                    subgraph_uses++;
7280                }
7281            }
7282        }
7283
7284        if (subgraph_uses != ggml_node_get_use_count(cgraph, node_idxs[i])) {
7285            return false;
7286        }
7287
7288        // if node is a view, check if the view_src and all it's parent view_srcs are within the subgraph
7289        struct ggml_tensor * view_src = node->view_src;
7290        while (view_src) {
7291            if (ggml_node_list_find_tensor(cgraph, node_idxs, count, view_src) == -1) {
7292                return false;
7293            }
7294            view_src = view_src->view_src;
7295        }
7296    }
7297
7298    return true;
7299}
7300
7301// check if node is part of the graph
7302static bool ggml_graph_find(const struct ggml_cgraph * cgraph, const struct ggml_tensor * node) {
7303    if (cgraph == NULL) {
7304        return true;
7305    }
7306
7307    for (int i = 0; i < cgraph->n_nodes; i++) {
7308        if (cgraph->nodes[i] == node) {
7309            return true;
7310        }
7311    }
7312
7313    return false;
7314}
7315
7316static struct ggml_tensor * ggml_graph_get_parent(const struct ggml_cgraph * cgraph, const struct ggml_tensor * node) {
7317    for (int i = 0; i < cgraph->n_nodes; i++) {
7318        struct ggml_tensor * parent = cgraph->nodes[i];
7319        struct ggml_tensor * grad = ggml_graph_get_grad(cgraph, parent);
7320
7321        if (grad == node) {
7322            return parent;
7323        }
7324    }
7325
7326    return NULL;
7327}
7328
7329static void ggml_graph_dump_dot_node_edge(FILE * fp, const struct ggml_cgraph * gb, struct ggml_tensor * node, struct ggml_tensor * parent, const char * label)  {
7330    struct ggml_tensor * gparent = ggml_graph_get_parent(gb, node);
7331    struct ggml_tensor * gparent0 = ggml_graph_get_parent(gb, parent);
7332    fprintf(fp, "  \"%p\" -> \"%p\" [ arrowhead = %s; style = %s; label = \"%s\"; ]\n",
7333            gparent0 ? (void *) gparent0 : (void *) parent,
7334            gparent ? (void *) gparent : (void *) node,
7335            gparent ? "empty" : "vee",
7336            gparent ? "dashed" : "solid",
7337            label);
7338}
7339
7340static void ggml_graph_dump_dot_leaf_edge(FILE * fp, struct ggml_tensor * node, struct ggml_tensor * parent, const char * label)  {
7341    fprintf(fp, "  \"%p\" -> \"%p\" [ label = \"%s\"; ]\n",
7342            (void *) parent,
7343            (void *) node,
7344            label);
7345}
7346
7347void ggml_graph_dump_dot(const struct ggml_cgraph * gb, const struct ggml_cgraph * cgraph, const char * filename) {
7348    char color[16];
7349
7350    FILE * fp = ggml_fopen(filename, "w");
7351    GGML_ASSERT(fp);
7352
7353    fprintf(fp, "digraph G {\n");
7354    fprintf(fp, "  newrank = true;\n");
7355    fprintf(fp, "  rankdir = TB;\n");
7356
7357    for (int i = 0; i < gb->n_nodes; i++) {
7358        struct ggml_tensor * node = gb->nodes[i];
7359        struct ggml_tensor * grad = ggml_graph_get_grad(gb, node);
7360
7361        if (ggml_graph_get_parent(gb, node) != NULL) {
7362            continue;
7363        }
7364
7365        if (node->flags & GGML_TENSOR_FLAG_PARAM) {
7366            snprintf(color, sizeof(color), "yellow");
7367        } else if (grad) {
7368            if (ggml_graph_find(cgraph, node)) {
7369                snprintf(color, sizeof(color), "green");
7370            } else {
7371                snprintf(color, sizeof(color), "lightblue");
7372            }
7373        } else {
7374            snprintf(color, sizeof(color), "white");
7375        }
7376
7377        fprintf(fp, "  \"%p\" [ "
7378                    "style = filled; fillcolor = %s; shape = record; "
7379                    "label=\"",
7380                (void *) node, color);
7381
7382        if (strlen(node->name) > 0) {
7383            fprintf(fp, "%s (%s)|", node->name, ggml_type_name(node->type));
7384        } else {
7385            fprintf(fp, "(%s)|", ggml_type_name(node->type));
7386        }
7387
7388        if (ggml_is_matrix(node)) {
7389            fprintf(fp, "%d [%" PRId64 ", %" PRId64 "] | <x>%s", i, node->ne[0], node->ne[1], ggml_op_symbol(node->op));
7390        } else {
7391            fprintf(fp, "%d [%" PRId64 ", %" PRId64 ", %" PRId64 "] | <x>%s", i, node->ne[0], node->ne[1], node->ne[2], ggml_op_symbol(node->op));
7392        }
7393
7394        if (grad) {
7395            fprintf(fp, " | <g>%s\"; ]\n", ggml_op_symbol(grad->op));
7396        } else {
7397            fprintf(fp, "\"; ]\n");
7398        }
7399    }
7400
7401    for (int i = 0; i < gb->n_leafs; i++) {
7402        struct ggml_tensor * node = gb->leafs[i];
7403
7404        snprintf(color, sizeof(color), "pink");
7405
7406        fprintf(fp, "  \"%p\" [ "
7407                    "style = filled; fillcolor = %s; shape = record; "
7408                    "label=\"<x>",
7409                (void *) node, color);
7410
7411        if (strlen(node->name) > 0) {
7412            fprintf(fp, "%s (%s)|", node->name, ggml_type_name(node->type));
7413        } else {
7414            fprintf(fp, "(%s)|", ggml_type_name(node->type));
7415        }
7416
7417        fprintf(fp, "CONST %d [%" PRId64 ", %" PRId64 "]", i, node->ne[0], node->ne[1]);
7418        if (ggml_nelements(node) < 5 && node->data != NULL) {
7419            fprintf(fp, " | (");
7420            for (int j = 0; j < ggml_nelements(node); j++) {
7421                // FIXME: use ggml-backend to obtain the tensor data
7422                //if (node->type == GGML_TYPE_I8 || node->type == GGML_TYPE_I16 || node->type == GGML_TYPE_I32) {
7423                //    fprintf(fp, "%d", ggml_get_i32_1d(node, j));
7424                //}
7425                //else if (node->type == GGML_TYPE_F32 ||
7426                //         node->type == GGML_TYPE_F16 ||
7427                //         node->type == GGML_TYPE_BF16) {
7428                //    fprintf(fp, "%.1e", (double)ggml_get_f32_1d(node, j));
7429                //}
7430                //else
7431                {
7432                    fprintf(fp, "#");
7433                }
7434                if (j < ggml_nelements(node) - 1) {
7435                    fprintf(fp, ", ");
7436                }
7437            }
7438            fprintf(fp, ")");
7439        }
7440        fprintf(fp, "\"; ]\n");
7441    }
7442
7443    for (int i = 0; i < gb->n_nodes; i++) {
7444        struct ggml_tensor * node = gb->nodes[i];
7445
7446        for (int j = 0; j < GGML_MAX_SRC; j++) {
7447            if (node->src[j]) {
7448                char label[16];
7449                snprintf(label, sizeof(label), "src %d", j);
7450                ggml_graph_dump_dot_node_edge(fp, gb, node, node->src[j], label);
7451            }
7452        }
7453    }
7454
7455    for (int i = 0; i < gb->n_leafs; i++) {
7456        struct ggml_tensor * node = gb->leafs[i];
7457
7458        for (int j = 0; j < GGML_MAX_SRC; j++) {
7459            if (node->src[j]) {
7460                char label[16];
7461                snprintf(label, sizeof(label), "src %d", j);
7462                ggml_graph_dump_dot_leaf_edge(fp, node, node->src[j], label);
7463            }
7464        }
7465    }
7466
7467    fprintf(fp, "}\n");
7468
7469    fclose(fp);
7470
7471    GGML_LOG_INFO("%s: dot -Tpng %s -o %s.png && open %s.png\n", __func__, filename, filename, filename);
7472}
7473
7474////////////////////////////////////////////////////////////////////////////////
7475
7476void ggml_set_input(struct ggml_tensor * tensor) {
7477    tensor->flags |= GGML_TENSOR_FLAG_INPUT;
7478}
7479
7480void ggml_set_output(struct ggml_tensor * tensor) {
7481    tensor->flags |= GGML_TENSOR_FLAG_OUTPUT;
7482}
7483
7484void ggml_set_param(struct ggml_tensor * tensor) {
7485    GGML_ASSERT(tensor->op == GGML_OP_NONE);
7486    tensor->flags |= GGML_TENSOR_FLAG_PARAM;
7487}
7488
7489void ggml_set_loss(struct ggml_tensor * tensor) {
7490    GGML_ASSERT(ggml_is_scalar(tensor));
7491    GGML_ASSERT(tensor->type == GGML_TYPE_F32);
7492    tensor->flags |= GGML_TENSOR_FLAG_LOSS;
7493}
7494
7495////////////////////////////////////////////////////////////////////////////////
7496
7497void ggml_quantize_init(enum ggml_type type) {
7498    ggml_critical_section_start();
7499
7500    switch (type) {
7501        case GGML_TYPE_IQ2_XXS:
7502        case GGML_TYPE_IQ2_XS:
7503        case GGML_TYPE_IQ2_S:
7504        case GGML_TYPE_IQ1_S:
7505        case GGML_TYPE_IQ1_M:   iq2xs_init_impl(type); break;
7506        case GGML_TYPE_IQ3_XXS: iq3xs_init_impl(256); break;
7507        case GGML_TYPE_IQ3_S:   iq3xs_init_impl(512); break;
7508        default: // nothing
7509            break;
7510    }
7511
7512    ggml_critical_section_end();
7513}
7514
7515void ggml_quantize_free(void) {
7516    ggml_critical_section_start();
7517
7518    iq2xs_free_impl(GGML_TYPE_IQ2_XXS);
7519    iq2xs_free_impl(GGML_TYPE_IQ2_XS);
7520    iq2xs_free_impl(GGML_TYPE_IQ2_S);
7521    iq2xs_free_impl(GGML_TYPE_IQ1_S);
7522    iq2xs_free_impl(GGML_TYPE_IQ1_M);
7523    iq3xs_free_impl(256);
7524    iq3xs_free_impl(512);
7525
7526    ggml_critical_section_end();
7527}
7528
7529bool ggml_quantize_requires_imatrix(enum ggml_type type) {
7530    return
7531        type == GGML_TYPE_IQ2_XXS ||
7532        type == GGML_TYPE_IQ2_XS  ||
7533        type == GGML_TYPE_IQ1_S;//   ||
7534        //type == GGML_TYPE_IQ1_M;
7535}
7536
7537size_t ggml_quantize_chunk(
7538        enum ggml_type   type,
7539           const float * src,
7540                  void * dst,
7541               int64_t   start,
7542               int64_t   nrows,
7543               int64_t   n_per_row,
7544           const float * imatrix) {
7545    const int64_t n = (int64_t) nrows * n_per_row;
7546
7547    if (ggml_quantize_requires_imatrix(type)) {
7548        GGML_ASSERT(imatrix != NULL);
7549    }
7550
7551    GGML_ASSERT(start % type_traits[type].blck_size == 0);
7552    GGML_ASSERT(start % n_per_row == 0);
7553
7554    ggml_quantize_init(type); // this is noop if already initialized
7555
7556    const size_t start_row = start / n_per_row;
7557    const size_t row_size  = ggml_row_size(type, n_per_row);
7558
7559    size_t result = 0;
7560
7561    switch (type) {
7562        case GGML_TYPE_Q4_0:    result = quantize_q4_0(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
7563        case GGML_TYPE_Q4_1:    result = quantize_q4_1(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
7564        case GGML_TYPE_Q5_0:    result = quantize_q5_0(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
7565        case GGML_TYPE_Q5_1:    result = quantize_q5_1(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
7566        case GGML_TYPE_Q8_0:    result = quantize_q8_0(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
7567        case GGML_TYPE_MXFP4:   result = quantize_mxfp4(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
7568        case GGML_TYPE_Q2_K:    result = quantize_q2_K(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
7569        case GGML_TYPE_Q3_K:    result = quantize_q3_K(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
7570        case GGML_TYPE_Q4_K:    result = quantize_q4_K(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
7571        case GGML_TYPE_Q5_K:    result = quantize_q5_K(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
7572        case GGML_TYPE_Q6_K:    result = quantize_q6_K(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
7573        case GGML_TYPE_TQ1_0:   result = quantize_tq1_0(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
7574        case GGML_TYPE_TQ2_0:   result = quantize_tq2_0(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
7575        case GGML_TYPE_IQ2_XXS: result = quantize_iq2_xxs(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
7576        case GGML_TYPE_IQ2_XS:  result = quantize_iq2_xs (src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
7577        case GGML_TYPE_IQ3_XXS: result = quantize_iq3_xxs(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
7578        case GGML_TYPE_IQ3_S:   result = quantize_iq3_s  (src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
7579        case GGML_TYPE_IQ2_S:   result = quantize_iq2_s  (src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
7580        case GGML_TYPE_IQ1_S:   result = quantize_iq1_s  (src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
7581        case GGML_TYPE_IQ1_M:   result = quantize_iq1_m  (src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
7582        case GGML_TYPE_IQ4_NL:  result = quantize_iq4_nl (src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
7583        case GGML_TYPE_IQ4_XS:  result = quantize_iq4_xs (src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
7584        case GGML_TYPE_F16:
7585            {
7586                size_t elemsize = sizeof(ggml_fp16_t);
7587                ggml_fp32_to_fp16_row(src + start, (ggml_fp16_t *)dst + start, n);
7588                result = n * elemsize;
7589            } break;
7590        case GGML_TYPE_BF16:
7591            {
7592                size_t elemsize = sizeof(ggml_bf16_t);
7593                ggml_fp32_to_bf16_row_ref(src + start, (ggml_bf16_t *)dst + start, n);
7594                result = n * elemsize;
7595            } break;
7596        case GGML_TYPE_F32:
7597            {
7598                size_t elemsize = sizeof(float);
7599                result = n * elemsize;
7600                memcpy((uint8_t *)dst + start * elemsize, src + start, result);
7601            } break;
7602        default:
7603            assert(false);
7604    }
7605
7606    GGML_ASSERT(result == nrows * row_size);
7607
7608    return result;
7609}
7610
7611////////////////////////////////////////////////////////////////////////////////
7612
7613void ggml_log_get(ggml_log_callback * log_callback, void ** user_data) {
7614    *log_callback = g_logger_state.log_callback;
7615    *user_data    = g_logger_state.log_callback_user_data;
7616}
7617
7618void ggml_log_set(ggml_log_callback log_callback, void * user_data) {
7619    g_logger_state.log_callback = log_callback ? log_callback : ggml_log_callback_default;
7620    g_logger_state.log_callback_user_data = user_data;
7621}
7622
7623void ggml_threadpool_params_init(struct ggml_threadpool_params * p, int n_threads) {
7624    p->n_threads  = n_threads;
7625    p->prio       = 0;     // default priority (usually means normal or inherited)
7626    p->poll       = 50;    // hybrid-polling enabled
7627    p->strict_cpu = false; // no strict placement (all threads share same cpumask)
7628    p->paused     = false; // threads are ready to go
7629    memset(p->cpumask, 0, GGML_MAX_N_THREADS); // all-zero means use the default affinity (usually inherited)
7630}
7631
7632struct ggml_threadpool_params ggml_threadpool_params_default(int n_threads) {
7633    struct ggml_threadpool_params p;
7634    ggml_threadpool_params_init(&p, n_threads);
7635    return p;
7636}
7637
7638bool ggml_threadpool_params_match(const struct ggml_threadpool_params * p0, const struct ggml_threadpool_params * p1) {
7639    if (p0->n_threads      != p1->n_threads  )    return false;
7640    if (p0->prio           != p1->prio       )    return false;
7641    if (p0->poll           != p1->poll       )    return false;
7642    if (p0->strict_cpu     != p1->strict_cpu )    return false;
7643    return memcmp(p0->cpumask, p1->cpumask, GGML_MAX_N_THREADS) == 0;
7644}