1#pragma once
   2
   3//
   4// GGML Tensor Library
   5//
   6// This documentation is still a work in progress.
   7// If you wish some specific topics to be covered, feel free to drop a comment:
   8//
   9//   https://github.com/ggml-org/whisper.cpp/issues/40
  10//
  11// ## Overview
  12//
  13// This library implements:
  14//
  15//  - a set of tensor operations
  16//  - automatic differentiation
  17//  - basic optimization algorithms
  18//
  19// The aim of this library is to provide a minimalistic approach for various machine learning tasks. This includes,
  20// but is not limited to, the following:
  21//
  22//  - linear regression
  23//  - support vector machines
  24//  - neural networks
  25//
  26// The library allows the user to define a certain function using the available tensor operations. This function
  27// definition is represented internally via a computation graph. Each tensor operation in the function definition
  28// corresponds to a node in the graph. Having the computation graph defined, the user can choose to compute the
  29// function's value and/or its gradient with respect to the input variables. Optionally, the function can be optimized
  30// using one of the available optimization algorithms.
  31//
  32// For example, here we define the function: f(x) = a*x^2 + b
  33//
  34//   {
  35//       struct ggml_init_params params = {
  36//           .mem_size   = 16*1024*1024,
  37//           .mem_buffer = NULL,
  38//       };
  39//
  40//       // memory allocation happens here
  41//       struct ggml_context * ctx = ggml_init(params);
  42//
  43//       struct ggml_tensor * x = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 1);
  44//
  45//       ggml_set_param(ctx, x); // x is an input variable
  46//
  47//       struct ggml_tensor * a  = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 1);
  48//       struct ggml_tensor * b  = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 1);
  49//       struct ggml_tensor * x2 = ggml_mul(ctx, x, x);
  50//       struct ggml_tensor * f  = ggml_add(ctx, ggml_mul(ctx, a, x2), b);
  51//
  52//       ...
  53//   }
  54//
  55// Notice that the function definition above does not involve any actual computation. The computation is performed only
  56// when the user explicitly requests it. For example, to compute the function's value at x = 2.0:
  57//
  58//   {
  59//       ...
  60//
  61//       struct ggml_cgraph * gf = ggml_new_graph(ctx);
  62//       ggml_build_forward_expand(gf, f);
  63//
  64//       // set the input variable and parameter values
  65//       ggml_set_f32(x, 2.0f);
  66//       ggml_set_f32(a, 3.0f);
  67//       ggml_set_f32(b, 4.0f);
  68//
//       ggml_graph_compute_with_ctx(ctx, gf, n_threads);
  70//
  71//       printf("f = %f\n", ggml_get_f32_1d(f, 0));
  72//
  73//       ...
  74//   }
  75//
  76// The actual computation is performed in the ggml_graph_compute() function.
  77//
  78// The ggml_new_tensor_...() functions create new tensors. They are allocated in the memory buffer provided to the
  79// ggml_init() function. You have to be careful not to exceed the memory buffer size. Therefore, you have to know
  80// in advance how much memory you need for your computation. Alternatively, you can allocate a large enough memory
  81// and after defining the computation graph, call the ggml_used_mem() function to find out how much memory was
  82// actually needed.
  83//
  84// The ggml_set_param() function marks a tensor as an input variable. This is used by the automatic
  85// differentiation and optimization algorithms.
  86//
// The described approach allows the user to define the function graph once and then compute its forward or backward graphs
  88// multiple times. All computations will use the same memory buffer allocated in the ggml_init() function. This way
  89// the user can avoid the memory allocation overhead at runtime.
  90//
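// The library supports multi-dimensional tensors - up to 4 dimensions (GGML_MAX_DIMS). The FP16 and FP32 data types
// are first class citizens; BF16, F64, integer (I8, I16, I32, I64) and a range of quantized types are also defined
// (see enum ggml_type), and in theory the library can be extended to further formats such as FP8.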
  93//
  94// Each tensor operation produces a new tensor. Initially the library was envisioned to support only the use of unary
  95// and binary operations. Most of the available operations fall into one of these two categories. With time, it became
  96// clear that the library needs to support more complex operations. The way to support these operations is not clear
  97// yet, but a few examples are demonstrated in the following operations:
  98//
  99//   - ggml_permute()
 100//   - ggml_conv_1d_1s()
 101//   - ggml_conv_1d_2s()
 102//
 103// For each tensor operator, the library implements a forward and backward computation function. The forward function
 104// computes the output tensor value given the input tensor values. The backward function computes the adjoint of the
 105// input tensors given the adjoint of the output tensor. For a detailed explanation of what this means, take a
 106// calculus class, or watch the following video:
 107//
 108//   What is Automatic Differentiation?
 109//   https://www.youtube.com/watch?v=wG_nF1awSSY
 110//
 111//
 112// ## Tensor data (struct ggml_tensor)
 113//
 114// The tensors are stored in memory via the ggml_tensor struct. The structure provides information about the size of
 115// the tensor, the data type, and the memory buffer where the tensor data is stored. Additionally, it contains
 116// pointers to the "source" tensors - i.e. the tensors that were used to compute the current tensor. For example:
 117//
 118//   {
 119//       struct ggml_tensor * c = ggml_add(ctx, a, b);
 120//
 121//       assert(c->src[0] == a);
 122//       assert(c->src[1] == b);
 123//   }
 124//
 125// The multi-dimensional tensors are stored in row-major order. The ggml_tensor struct contains fields for the
// number of elements in each dimension ("ne") as well as the number of bytes ("nb", a.k.a. stride). This makes it
// possible to store tensors that are not contiguous in memory, which is useful for operations such as transposition and
 128// permutation. All tensor operations have to take the stride into account and not assume that the tensor is
 129// contiguous in memory.
 130//
 131// The data of the tensor is accessed via the "data" pointer. For example:
 132//
 133//   {
 134//       const int nx = 2;
 135//       const int ny = 3;
 136//
 137//       struct ggml_tensor * a = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, nx, ny);
 138//
 139//       for (int y = 0; y < ny; y++) {
 140//           for (int x = 0; x < nx; x++) {
 141//               *(float *) ((char *) a->data + y*a->nb[1] + x*a->nb[0]) = x + y;
 142//           }
 143//       }
 144//
 145//       ...
 146//   }
 147//
 148// Alternatively, there are helper functions, such as ggml_get_f32_1d() and ggml_set_f32_1d() that can be used.
 149//
 150// ## The matrix multiplication operator (ggml_mul_mat)
 151//
 152// TODO
 153//
 154//
 155// ## Multi-threading
 156//
 157// TODO
 158//
 159//
 160// ## Overview of ggml.c
 161//
 162// TODO
 163//
 164//
 165// ## SIMD optimizations
 166//
 167// TODO
 168//
 169//
 170// ## Debugging ggml
 171//
 172// TODO
 173//
 174//
 175
// GGML_API marks every public declaration of this header.
// - Shared build on Windows (non-MinGW): dllexport while compiling the library itself
//   (GGML_BUILD defined), dllimport for code that links against it.
// - Shared build elsewhere: request default symbol visibility so the symbol is exported.
// - Static build: plain extern.
#ifdef GGML_SHARED
#    if defined(_WIN32) && !defined(__MINGW32__)
#        ifdef GGML_BUILD
#            define GGML_API __declspec(dllexport) extern
#        else
#            define GGML_API __declspec(dllimport) extern
#        endif
#    else
#        define GGML_API __attribute__ ((visibility ("default"))) extern
#    endif
#else
#    define GGML_API extern
#endif
 189
// GGML_DEPRECATED(func, hint): wraps the declaration `func` so that using it emits a
// deprecation warning carrying the message `hint`. On compilers with no known
// deprecation attribute the declaration is emitted unchanged.
// TODO: support for clang
// NOTE(review): clang normally defines __GNUC__ as well, in which case it already takes
// the first branch — confirm whether this TODO is still relevant.
#ifdef __GNUC__
#    define GGML_DEPRECATED(func, hint) func __attribute__((deprecated(hint)))
#elif defined(_MSC_VER)
#    define GGML_DEPRECATED(func, hint) __declspec(deprecated(hint)) func
#else
#    define GGML_DEPRECATED(func, hint) func
#endif
 198
// GGML_ATTRIBUTE_FORMAT(fmt_index, first_arg_index): asks GCC-compatible compilers to
// check printf-style format strings against their arguments. Expands to nothing on
// compilers that do not define __GNUC__. MinGW GCC (but not MinGW clang) uses the
// gnu_printf archetype — presumably so GNU format specifiers are validated rather than
// the Microsoft runtime's; confirm.
#ifndef __GNUC__
#    define GGML_ATTRIBUTE_FORMAT(...)
#elif defined(__MINGW32__) && !defined(__clang__)
#    define GGML_ATTRIBUTE_FORMAT(...) __attribute__((format(gnu_printf, __VA_ARGS__)))
#else
#    define GGML_ATTRIBUTE_FORMAT(...) __attribute__((format(printf, __VA_ARGS__)))
#endif
 206
// On Windows, pick a default target OS version when the includer has not set one.
// NOTE(review): 0x0A00 is understood to be the Windows 10 value of _WIN32_WINNT — confirm
// against the Windows SDK version macros before relying on it.
#if defined(_WIN32) && !defined(_WIN32_WINNT)
#    define _WIN32_WINNT 0x0A00
#endif
 210
#include <stdbool.h>
#include <stddef.h>
#include <stdint.h>
#include <stdio.h>
#include <stdlib.h>
 215
// Magic number and version of the legacy ggml file format.
#define GGML_FILE_MAGIC   0x67676d6c // "ggml"
#define GGML_FILE_VERSION 2

#define GGML_QNT_VERSION        2    // bump this on quantization format changes
#define GGML_QNT_VERSION_FACTOR 1000 // do not change this

// Compile-time limits. GGML_MAX_DIMS, GGML_MAX_SRC, GGML_MAX_OP_PARAMS and GGML_MAX_NAME
// size the fixed-length arrays of struct ggml_tensor, so changing any of them changes
// that struct's layout.
#define GGML_MAX_DIMS           4
#define GGML_MAX_PARAMS         2048
#define GGML_MAX_SRC            10
#define GGML_MAX_N_THREADS      512
#define GGML_MAX_OP_PARAMS      64

// GGML_MAX_NAME can be overridden at build time (e.g. -DGGML_MAX_NAME=128).
// NOTE(review): because it sizes ggml_tensor::name, every component that shares tensor
// structs must be built with the same value.
#ifndef GGML_MAX_NAME
#   define GGML_MAX_NAME        64
#endif

#define GGML_DEFAULT_N_THREADS  4
#define GGML_DEFAULT_GRAPH_SIZE 2048

// Memory alignment used by ggml (see ggml_nbytes_pad()): 4 on 32-bit targets,
// 8 on 64-bit emscripten, 16 everywhere else.
#if UINTPTR_MAX == 0xFFFFFFFF
    #define GGML_MEM_ALIGN 4
#elif defined(__EMSCRIPTEN__)
// emscripten uses max_align_t == 8, so we need GGML_MEM_ALIGN == 8 for 64-bit wasm.
// (for 32-bit wasm, the first conditional is true and GGML_MEM_ALIGN stays 4.)
// ref: https://github.com/ggml-org/llama.cpp/pull/18628
    #define GGML_MEM_ALIGN 8
#else
    #define GGML_MEM_ALIGN 16
#endif

// Exit codes; numerically equal to GGML_STATUS_SUCCESS and GGML_STATUS_ABORTED.
#define GGML_EXIT_SUCCESS 0
#define GGML_EXIT_ABORTED 1

// RoPE mode values. The multi-section variants (MROPE = 8, VISION = 24, IMROPE = 40)
// all have bit 8 (GGML_ROPE_TYPE_MROPE) set.
// TODO: convert to enum https://github.com/ggml-org/llama.cpp/pull/16187#discussion_r2388538726
#define GGML_ROPE_TYPE_NORMAL 0
#define GGML_ROPE_TYPE_NEOX   2
#define GGML_ROPE_TYPE_MROPE  8
#define GGML_ROPE_TYPE_VISION 24
#define GGML_ROPE_TYPE_IMROPE 40 // binary: 101000

#define GGML_MROPE_SECTIONS   4
 257
// GGML_UNUSED(x): silence "unused variable/parameter" warnings for a single expression.
#define GGML_UNUSED(x) (void)(x)
// GGML_UNUSED_VARS(a, b, ...): the same for several names at once.
// Under nvcc it forwards to an empty variadic function template that is callable from
// both host and device code. Elsewhere the arguments are placed inside sizeof, whose
// operand is not evaluated, so the names are referenced without any runtime effect.
#ifdef __CUDACC__
template<typename... Args>
__host__ __device__ constexpr inline void ggml_unused_vars_impl(Args&&...) noexcept {}
#define GGML_UNUSED_VARS(...) ggml_unused_vars_impl(__VA_ARGS__)
#else
#define GGML_UNUSED_VARS(...) do { (void)sizeof((__VA_ARGS__, 0)); } while(0)
#endif // __CUDACC__
 266
// GGML_PAD(x, n): round x up to the next multiple of n.
// n MUST be a power of two: the rounding clears the low bits with the mask ~(n - 1),
// which gives wrong results for any other n. Both arguments may be evaluated twice.
#define GGML_PAD(x, n) (((x) + (n) - 1) & ~((n) - 1))
 268
 269#ifndef NDEBUG
 270#   define GGML_UNREACHABLE() do { fprintf(stderr, "statement should be unreachable\n"); abort(); } while(0)
 271#elif defined(__GNUC__)
 272#   define GGML_UNREACHABLE() __builtin_unreachable()
 273#elif defined(_MSC_VER)
 274#   define GGML_UNREACHABLE() __assume(0)
 275#else
 276#   define GGML_UNREACHABLE() ((void) 0)
 277#endif
 278
// GGML_NORETURN: marks a function that never returns. Uses the C++11 [[noreturn]]
// attribute in C++, __declspec(noreturn) under MSVC C, and C11 _Noreturn otherwise.
#ifdef __cplusplus
#   define GGML_NORETURN [[noreturn]]
#elif defined(_MSC_VER)
#   define GGML_NORETURN __declspec(noreturn)
#else
#   define GGML_NORETURN _Noreturn
#endif
 286
 287#define GGML_ABORT(...) ggml_abort(__FILE__, __LINE__, __VA_ARGS__)
 288#define GGML_ASSERT(x) if (!(x)) GGML_ABORT("GGML_ASSERT(%s) failed", #x)
 289
// used to copy the number of elements and stride in bytes of tensors into local variables.
// main purpose is to reduce code duplication and improve readability.
//
// GGML_TENSOR_LOCALS(type, prefix, pointer, array) declares four block-scope constants
//   const type prefix0 .. prefix3 = (pointer)->array[0] .. (pointer)->array[3]
// Each constant is 0 when `pointer` is NULL, and each is passed to GGML_UNUSED so that
// unreferenced ones do not trigger warnings. The _1/_2/_3 variants declare only the
// first 1/2/3 constants and serve as building blocks for the full macro. Every
// expansion already ends in ';'.
//
// example:
//
//    GGML_TENSOR_LOCALS(int64_t, ne1, src1, ne);
//    GGML_TENSOR_LOCALS(size_t,  nb1, src1, nb);
//
#define GGML_TENSOR_LOCALS_1(type, prefix, pointer, array) \
    const type prefix##0 = (pointer) ? (pointer)->array[0] : 0; \
    GGML_UNUSED(prefix##0);
#define GGML_TENSOR_LOCALS_2(type, prefix, pointer, array) \
    GGML_TENSOR_LOCALS_1    (type, prefix, pointer, array) \
    const type prefix##1 = (pointer) ? (pointer)->array[1] : 0; \
    GGML_UNUSED(prefix##1);
#define GGML_TENSOR_LOCALS_3(type, prefix, pointer, array) \
    GGML_TENSOR_LOCALS_2    (type, prefix, pointer, array) \
    const type prefix##2 = (pointer) ? (pointer)->array[2] : 0; \
    GGML_UNUSED(prefix##2);
#define GGML_TENSOR_LOCALS(type, prefix, pointer, array) \
    GGML_TENSOR_LOCALS_3  (type, prefix, pointer, array) \
    const type prefix##3 = (pointer) ? (pointer)->array[3] : 0; \
    GGML_UNUSED(prefix##3);

// Op-kernel shorthands. They expect variables named src0 (and src1, src2 where used)
// and dst to be in scope, and declare:
//   ne00..ne03 / nb00..nb03 for src0, ne10..ne13 / nb10..nb13 for src1,
//   ne20..ne23 / nb20..nb23 for src2, ne0..ne3 / nb0..nb3 for dst.
#define GGML_TENSOR_UNARY_OP_LOCALS \
    GGML_TENSOR_LOCALS(int64_t, ne0, src0, ne) \
    GGML_TENSOR_LOCALS(size_t,  nb0, src0, nb) \
    GGML_TENSOR_LOCALS(int64_t, ne,  dst,  ne) \
    GGML_TENSOR_LOCALS(size_t,  nb,  dst,  nb)

#define GGML_TENSOR_BINARY_OP_LOCALS \
    GGML_TENSOR_LOCALS(int64_t, ne0, src0, ne) \
    GGML_TENSOR_LOCALS(size_t,  nb0, src0, nb) \
    GGML_TENSOR_LOCALS(int64_t, ne1, src1, ne) \
    GGML_TENSOR_LOCALS(size_t,  nb1, src1, nb) \
    GGML_TENSOR_LOCALS(int64_t, ne,  dst,  ne) \
    GGML_TENSOR_LOCALS(size_t,  nb,  dst,  nb)

#define GGML_TENSOR_TERNARY_OP_LOCALS \
    GGML_TENSOR_LOCALS(int64_t, ne0, src0, ne) \
    GGML_TENSOR_LOCALS(size_t,  nb0, src0, nb) \
    GGML_TENSOR_LOCALS(int64_t, ne1, src1, ne) \
    GGML_TENSOR_LOCALS(size_t,  nb1, src1, nb) \
    GGML_TENSOR_LOCALS(int64_t, ne2, src2, ne) \
    GGML_TENSOR_LOCALS(size_t,  nb2, src2, nb) \
    GGML_TENSOR_LOCALS(int64_t, ne,  dst,  ne) \
    GGML_TENSOR_LOCALS(size_t,  nb,  dst,  nb)

// Same as GGML_TENSOR_BINARY_OP_LOCALS but without the dst locals.
#define GGML_TENSOR_BINARY_OP_LOCALS01 \
    GGML_TENSOR_LOCALS(int64_t, ne0, src0, ne) \
    GGML_TENSOR_LOCALS(size_t,  nb0, src0, nb) \
    GGML_TENSOR_LOCALS(int64_t, ne1, src1, ne) \
    GGML_TENSOR_LOCALS(size_t,  nb1, src1, nb)
 343
 344#ifdef  __cplusplus
 345extern "C" {
 346#endif
 347
    // Function type used in fatal error callbacks
    typedef void (*ggml_abort_callback_t)(const char * error_message);

    // Set the abort callback (passing null will restore original abort functionality: printing a message to stdout)
    // Returns the old callback for chaining
    // NOTE(review): the default handler is implemented elsewhere; confirm whether it writes to stdout or stderr.
    GGML_API ggml_abort_callback_t ggml_set_abort_callback(ggml_abort_callback_t callback);

    // Fatal-error entry point behind GGML_ABORT / GGML_ASSERT. `fmt` and the variadic
    // arguments are printf-style (checked at compile time via GGML_ATTRIBUTE_FORMAT(3, 4)),
    // and the function never returns (GGML_NORETURN).
    GGML_NORETURN GGML_ATTRIBUTE_FORMAT(3, 4)
    GGML_API void ggml_abort(const char * file, int line, const char * fmt, ...);
 357
    // Result codes: negative values are failures, 0 is success, and positive values
    // report that the operation was aborted.
    enum ggml_status {
        GGML_STATUS_ALLOC_FAILED = -2,
        GGML_STATUS_FAILED = -1,
        GGML_STATUS_SUCCESS = 0,
        GGML_STATUS_ABORTED = 1,
    };

    // get ggml_status name string
    GGML_API const char * ggml_status_to_string(enum ggml_status status);
 367
    // ieee 754-2008 half-precision float16
    // todo: make this not an integral type
    // (the value is held as its raw 16-bit pattern in a uint16_t)
    typedef uint16_t ggml_fp16_t;
    // Scalar conversions, plus *_row variants that convert from the first pointer into
    // the second. The trailing int64_t is presumably the number of elements — confirm.
    GGML_API float       ggml_fp16_to_fp32(ggml_fp16_t);
    GGML_API ggml_fp16_t ggml_fp32_to_fp16(float);
    GGML_API void        ggml_fp16_to_fp32_row(const ggml_fp16_t *, float *, int64_t);
    GGML_API void        ggml_fp32_to_fp16_row(const float *, ggml_fp16_t *, int64_t);

    // google brain half-precision bfloat16
    // Unlike ggml_fp16_t this is a struct, so it does not implicitly convert to or from integers.
    typedef struct { uint16_t bits; } ggml_bf16_t;
    // NOTE(review): ggml_fp32_to_bf16_row_ref is presumably the portable reference version of
    // ggml_fp32_to_bf16_row — confirm in the implementation.
    GGML_API ggml_bf16_t ggml_fp32_to_bf16(float);
    GGML_API float       ggml_bf16_to_fp32(ggml_bf16_t);  // consider just doing << 16
    GGML_API void        ggml_bf16_to_fp32_row(const ggml_bf16_t *, float *, int64_t);
    GGML_API void        ggml_fp32_to_bf16_row_ref(const float *, ggml_bf16_t *, int64_t);
    GGML_API void        ggml_fp32_to_bf16_row(const float *, ggml_bf16_t *, int64_t);
 383
    // Opaque types: declared but not defined in this header, so callers only hold pointers to them.
    struct ggml_object;
    struct ggml_context;
    struct ggml_cgraph;
 387
    // NOTE: always add types at the end of the enum to keep backward compatibility
    // Type values are persisted in model files (see the "removed from gguf files" notes),
    // so existing values must never be renumbered or reused. GGML_TYPE_COUNT is one past
    // the largest value, not the number of usable types: several values are retired.
    enum ggml_type {
        GGML_TYPE_F32     = 0,
        GGML_TYPE_F16     = 1,
        GGML_TYPE_Q4_0    = 2,
        GGML_TYPE_Q4_1    = 3,
        // GGML_TYPE_Q4_2 = 4, support has been removed
        // GGML_TYPE_Q4_3 = 5, support has been removed
        GGML_TYPE_Q5_0    = 6,
        GGML_TYPE_Q5_1    = 7,
        GGML_TYPE_Q8_0    = 8,
        GGML_TYPE_Q8_1    = 9,
        GGML_TYPE_Q2_K    = 10,
        GGML_TYPE_Q3_K    = 11,
        GGML_TYPE_Q4_K    = 12,
        GGML_TYPE_Q5_K    = 13,
        GGML_TYPE_Q6_K    = 14,
        GGML_TYPE_Q8_K    = 15,
        GGML_TYPE_IQ2_XXS = 16,
        GGML_TYPE_IQ2_XS  = 17,
        GGML_TYPE_IQ3_XXS = 18,
        GGML_TYPE_IQ1_S   = 19,
        GGML_TYPE_IQ4_NL  = 20,
        GGML_TYPE_IQ3_S   = 21,
        GGML_TYPE_IQ2_S   = 22,
        GGML_TYPE_IQ4_XS  = 23,
        GGML_TYPE_I8      = 24,
        GGML_TYPE_I16     = 25,
        GGML_TYPE_I32     = 26,
        GGML_TYPE_I64     = 27,
        GGML_TYPE_F64     = 28,
        GGML_TYPE_IQ1_M   = 29,
        GGML_TYPE_BF16    = 30,
        // GGML_TYPE_Q4_0_4_4 = 31, support has been removed from gguf files
        // GGML_TYPE_Q4_0_4_8 = 32,
        // GGML_TYPE_Q4_0_8_8 = 33,
        GGML_TYPE_TQ1_0   = 34,
        GGML_TYPE_TQ2_0   = 35,
        // GGML_TYPE_IQ4_NL_4_4 = 36,
        // GGML_TYPE_IQ4_NL_4_8 = 37,
        // GGML_TYPE_IQ4_NL_8_8 = 38,
        GGML_TYPE_MXFP4   = 39, // MXFP4 (1 block)
        GGML_TYPE_COUNT   = 40,
    };
 432
    // precision
    // Per-op precision hint; the default value 0 is what a fresh op_params holds.
    enum ggml_prec {
        GGML_PREC_DEFAULT =  0, // stored as ggml_tensor.op_params, 0 by default
        GGML_PREC_F32     = 10,
    };

    // model file types
    // These values are NOT the same as enum ggml_type (e.g. GGML_FTYPE_MOSTLY_Q8_0 = 7 while
    // GGML_TYPE_Q8_0 = 8); convert with ggml_ftype_to_ggml_type().
    enum ggml_ftype {
        GGML_FTYPE_UNKNOWN        = -1,
        GGML_FTYPE_ALL_F32        = 0,
        GGML_FTYPE_MOSTLY_F16     = 1,  // except 1d tensors
        GGML_FTYPE_MOSTLY_Q4_0    = 2,  // except 1d tensors
        GGML_FTYPE_MOSTLY_Q4_1    = 3,  // except 1d tensors
        GGML_FTYPE_MOSTLY_Q4_1_SOME_F16 = 4, // tok_embeddings.weight and output.weight are F16
        GGML_FTYPE_MOSTLY_Q8_0    = 7,  // except 1d tensors
        GGML_FTYPE_MOSTLY_Q5_0    = 8,  // except 1d tensors
        GGML_FTYPE_MOSTLY_Q5_1    = 9,  // except 1d tensors
        GGML_FTYPE_MOSTLY_Q2_K    = 10, // except 1d tensors
        GGML_FTYPE_MOSTLY_Q3_K    = 11, // except 1d tensors
        GGML_FTYPE_MOSTLY_Q4_K    = 12, // except 1d tensors
        GGML_FTYPE_MOSTLY_Q5_K    = 13, // except 1d tensors
        GGML_FTYPE_MOSTLY_Q6_K    = 14, // except 1d tensors
        GGML_FTYPE_MOSTLY_IQ2_XXS = 15, // except 1d tensors
        GGML_FTYPE_MOSTLY_IQ2_XS  = 16, // except 1d tensors
        GGML_FTYPE_MOSTLY_IQ3_XXS = 17, // except 1d tensors
        GGML_FTYPE_MOSTLY_IQ1_S   = 18, // except 1d tensors
        GGML_FTYPE_MOSTLY_IQ4_NL  = 19, // except 1d tensors
        GGML_FTYPE_MOSTLY_IQ3_S   = 20, // except 1d tensors
        GGML_FTYPE_MOSTLY_IQ2_S   = 21, // except 1d tensors
        GGML_FTYPE_MOSTLY_IQ4_XS  = 22, // except 1d tensors
        GGML_FTYPE_MOSTLY_IQ1_M   = 23, // except 1d tensors
        GGML_FTYPE_MOSTLY_BF16    = 24, // except 1d tensors
        GGML_FTYPE_MOSTLY_MXFP4   = 25, // except 1d tensors
    };
 467
    // available tensor operations:
    // Values are implicit and sequential from GGML_OP_NONE = 0, so inserting an entry
    // renumbers every later one. GGML_OP_COUNT equals the number of ops.
    // GGML_OP_UNARY and GGML_OP_GLU are families whose concrete function is an
    // enum ggml_unary_op / enum ggml_glu_op (see ggml_get_unary_op() / ggml_get_glu_op()).
    enum ggml_op {
        GGML_OP_NONE = 0,

        GGML_OP_DUP,
        GGML_OP_ADD,
        GGML_OP_ADD_ID,
        GGML_OP_ADD1,
        GGML_OP_ACC,
        GGML_OP_SUB,
        GGML_OP_MUL,
        GGML_OP_DIV,
        GGML_OP_SQR,
        GGML_OP_SQRT,
        GGML_OP_LOG,
        GGML_OP_SIN,
        GGML_OP_COS,
        GGML_OP_SUM,
        GGML_OP_SUM_ROWS,
        GGML_OP_CUMSUM,
        GGML_OP_MEAN,
        GGML_OP_ARGMAX,
        GGML_OP_COUNT_EQUAL,
        GGML_OP_REPEAT,
        GGML_OP_REPEAT_BACK,
        GGML_OP_CONCAT,
        GGML_OP_SILU_BACK,
        GGML_OP_NORM, // normalize
        GGML_OP_RMS_NORM,
        GGML_OP_RMS_NORM_BACK,
        GGML_OP_GROUP_NORM,
        GGML_OP_L2_NORM,

        GGML_OP_MUL_MAT,
        GGML_OP_MUL_MAT_ID,
        GGML_OP_OUT_PROD,

        GGML_OP_SCALE,
        GGML_OP_SET,
        GGML_OP_CPY,
        GGML_OP_CONT,
        GGML_OP_RESHAPE,
        GGML_OP_VIEW,
        GGML_OP_PERMUTE,
        GGML_OP_TRANSPOSE,
        GGML_OP_GET_ROWS,
        GGML_OP_GET_ROWS_BACK,
        GGML_OP_SET_ROWS,
        GGML_OP_DIAG,
        GGML_OP_DIAG_MASK_INF,
        GGML_OP_DIAG_MASK_ZERO,
        GGML_OP_SOFT_MAX,
        GGML_OP_SOFT_MAX_BACK,
        GGML_OP_ROPE,
        GGML_OP_ROPE_BACK,
        GGML_OP_CLAMP,
        GGML_OP_CONV_TRANSPOSE_1D,
        GGML_OP_IM2COL,
        GGML_OP_IM2COL_BACK,
        GGML_OP_IM2COL_3D,
        GGML_OP_CONV_2D,
        GGML_OP_CONV_3D,
        GGML_OP_CONV_2D_DW,
        GGML_OP_CONV_TRANSPOSE_2D,
        GGML_OP_POOL_1D,
        GGML_OP_POOL_2D,
        GGML_OP_POOL_2D_BACK,
        GGML_OP_UPSCALE,
        GGML_OP_PAD,
        GGML_OP_PAD_REFLECT_1D,
        GGML_OP_ROLL,
        GGML_OP_ARANGE,
        GGML_OP_TIMESTEP_EMBEDDING,
        GGML_OP_ARGSORT,
        GGML_OP_TOP_K,
        GGML_OP_LEAKY_RELU,
        GGML_OP_TRI,
        GGML_OP_FILL,

        GGML_OP_FLASH_ATTN_EXT,
        GGML_OP_FLASH_ATTN_BACK,
        GGML_OP_SSM_CONV,
        GGML_OP_SSM_SCAN,
        GGML_OP_WIN_PART,
        GGML_OP_WIN_UNPART,
        GGML_OP_GET_REL_POS,
        GGML_OP_ADD_REL_POS,
        GGML_OP_RWKV_WKV6,
        GGML_OP_GATED_LINEAR_ATTN,
        GGML_OP_RWKV_WKV7,
        GGML_OP_SOLVE_TRI,

        GGML_OP_UNARY,

        GGML_OP_MAP_CUSTOM1,
        GGML_OP_MAP_CUSTOM2,
        GGML_OP_MAP_CUSTOM3,

        GGML_OP_CUSTOM,

        GGML_OP_CROSS_ENTROPY_LOSS,
        GGML_OP_CROSS_ENTROPY_LOSS_BACK,
        GGML_OP_OPT_STEP_ADAMW,
        GGML_OP_OPT_STEP_SGD,

        GGML_OP_GLU,

        GGML_OP_COUNT,
    };
 577
    // Element-wise functions carried by GGML_OP_UNARY nodes (see ggml_get_unary_op()).
    // Values are implicit and sequential; GGML_UNARY_OP_COUNT is the number of functions.
    enum ggml_unary_op {
        GGML_UNARY_OP_ABS,
        GGML_UNARY_OP_SGN,
        GGML_UNARY_OP_NEG,
        GGML_UNARY_OP_STEP,
        GGML_UNARY_OP_TANH,
        GGML_UNARY_OP_ELU,
        GGML_UNARY_OP_RELU,
        GGML_UNARY_OP_SIGMOID,
        GGML_UNARY_OP_GELU,
        GGML_UNARY_OP_GELU_QUICK,
        GGML_UNARY_OP_SILU,
        GGML_UNARY_OP_HARDSWISH,
        GGML_UNARY_OP_HARDSIGMOID,
        GGML_UNARY_OP_EXP,
        GGML_UNARY_OP_EXPM1,
        GGML_UNARY_OP_SOFTPLUS,
        GGML_UNARY_OP_GELU_ERF,
        GGML_UNARY_OP_XIELU,
        GGML_UNARY_OP_FLOOR,
        GGML_UNARY_OP_CEIL,
        GGML_UNARY_OP_ROUND,
        GGML_UNARY_OP_TRUNC,

        GGML_UNARY_OP_COUNT,
    };

    // Gated-linear-unit variants carried by GGML_OP_GLU nodes (see ggml_get_glu_op()).
    enum ggml_glu_op {
        GGML_GLU_OP_REGLU,
        GGML_GLU_OP_GEGLU,
        GGML_GLU_OP_SWIGLU,
        GGML_GLU_OP_SWIGLU_OAI,
        GGML_GLU_OP_GEGLU_ERF,
        GGML_GLU_OP_GEGLU_QUICK,

        GGML_GLU_OP_COUNT,
    };

    // Kinds of struct ggml_object (see ggml_print_object()).
    enum ggml_object_type {
        GGML_OBJECT_TYPE_TENSOR,
        GGML_OBJECT_TYPE_GRAPH,
        GGML_OBJECT_TYPE_WORK_BUFFER
    };
 621
    // Log message levels; GGML_LOG_LEVEL_CONT continues the previous message.
    enum ggml_log_level {
        GGML_LOG_LEVEL_NONE  = 0,
        GGML_LOG_LEVEL_DEBUG = 1,
        GGML_LOG_LEVEL_INFO  = 2,
        GGML_LOG_LEVEL_WARN  = 3,
        GGML_LOG_LEVEL_ERROR = 4,
        GGML_LOG_LEVEL_CONT  = 5, // continue previous log
    };

    // this tensor...
    // (each value is a distinct bit, so flags can be OR-ed together; presumably stored in
    // ggml_tensor::flags — confirm)
    enum ggml_tensor_flag {
        GGML_TENSOR_FLAG_INPUT   =  1, // ...is an input for the GGML compute graph
        GGML_TENSOR_FLAG_OUTPUT  =  2, // ...is an output for the GGML compute graph
        GGML_TENSOR_FLAG_PARAM   =  4, // ...contains trainable parameters
        GGML_TENSOR_FLAG_LOSS    =  8, // ...defines loss for numerical optimization (multiple loss tensors add up)
        GGML_TENSOR_FLAG_COMPUTE = 16, // ...must be computed
    };

    // Which triangle of a matrix to select.
    // NOTE(review): the *_DIAG variants presumably include the main diagonal and the others
    // exclude it — confirm in the implementation of the triangular ops.
    enum ggml_tri_type {
        GGML_TRI_TYPE_UPPER_DIAG = 0,
        GGML_TRI_TYPE_UPPER      = 1,
        GGML_TRI_TYPE_LOWER_DIAG = 2,
        GGML_TRI_TYPE_LOWER      = 3
    };
 646
    // Parameters passed by value to ggml_init().
    struct ggml_init_params {
        // memory pool
        size_t mem_size;   // bytes
        void * mem_buffer; // if NULL, memory will be allocated internally
        bool   no_alloc;   // don't allocate memory for the tensor data
    };
 653
    // n-dimensional tensor
    // The layout is public: user code reads fields such as ne, nb, src and data directly
    // (see the examples in the file header), so reordering or resizing fields is an ABI change.
    struct ggml_tensor {
        enum ggml_type type;

        // NOTE(review): presumably the backend buffer that holds `data` — confirm.
        struct ggml_backend_buffer * buffer;

        int64_t ne[GGML_MAX_DIMS]; // number of elements
        size_t  nb[GGML_MAX_DIMS]; // stride in bytes:
                                   // nb[0] = ggml_type_size(type)
                                   // nb[1] = nb[0]   * (ne[0] / ggml_blck_size(type)) + padding
                                   // nb[i] = nb[i-1] * ne[i-1]
                                   // (these formulas describe a contiguous tensor; views,
                                   //  transposes and permutations can have other strides)

        // compute data
        enum ggml_op op;

        // op params - allocated as int32_t for alignment
        int32_t op_params[GGML_MAX_OP_PARAMS / sizeof(int32_t)];

        // presumably a bitmask of enum ggml_tensor_flag values — confirm
        int32_t flags;

        struct ggml_tensor * src[GGML_MAX_SRC];

        // source tensor and offset for views
        struct ggml_tensor * view_src;
        size_t               view_offs;

        void * data;

        char name[GGML_MAX_NAME];

        void * extra; // extra things e.g. for ggml-cuda.cu

        // NOTE(review): explicit trailing padding; its purpose is not visible here (presumably
        // keeps sizeof(struct ggml_tensor) at a particular multiple) — confirm before changing.
        char padding[8];
    };

    // sizeof(struct ggml_tensor), exposed as a constant.
    static const size_t GGML_TENSOR_SIZE = sizeof(struct ggml_tensor);
 690
    // Abort callback
    // If not NULL, called before ggml computation
    // If it returns true, the computation is aborted
    // (Not to be confused with ggml_abort_callback_t, which handles fatal errors.)
    // NOTE(review): `data` is presumably an opaque user pointer registered with the callback — confirm.
    typedef bool (*ggml_abort_callback)(void * data);


    //
    // GUID
    //

    // GUID types
    // A GUID is 16 raw bytes; ggml_guid_t is a pointer to one.
    typedef uint8_t ggml_guid[16];
    typedef ggml_guid * ggml_guid_t;

    GGML_API bool ggml_guid_matches(ggml_guid_t guid_a, ggml_guid_t guid_b);
 706
    // misc

    // Library identification strings.
    GGML_API const char * ggml_version(void);
    GGML_API const char * ggml_commit(void);

    // Timing helpers.
    GGML_API void    ggml_time_init(void); // call this once at the beginning of the program
    GGML_API int64_t ggml_time_ms(void);
    GGML_API int64_t ggml_time_us(void);
    GGML_API int64_t ggml_cycles(void);
    GGML_API int64_t ggml_cycles_per_ms(void);

    // accepts a UTF-8 path, even on Windows
    GGML_API FILE *  ggml_fopen(const char * fname, const char * mode);

    GGML_API void    ggml_print_object (const struct ggml_object * obj);
    GGML_API void    ggml_print_objects(const struct ggml_context * ctx);

    // Tensor size queries.
    GGML_API int64_t ggml_nelements (const struct ggml_tensor * tensor);
    GGML_API int64_t ggml_nrows     (const struct ggml_tensor * tensor);
    GGML_API size_t  ggml_nbytes    (const struct ggml_tensor * tensor);
    GGML_API size_t  ggml_nbytes_pad(const struct ggml_tensor * tensor); // same as ggml_nbytes() but padded to GGML_MEM_ALIGN

    // Per-type storage metadata: quantized types store elements in blocks, so byte sizes
    // are given per block and per row rather than per element.
    GGML_API int64_t ggml_blck_size(enum ggml_type type);
    GGML_API size_t  ggml_type_size(enum ggml_type type);             // size in bytes for all elements in a block
    GGML_API size_t  ggml_row_size (enum ggml_type type, int64_t ne); // size in bytes for all elements in a row

    GGML_DEPRECATED(
    GGML_API double ggml_type_sizef(enum ggml_type type), // ggml_type_size()/ggml_blck_size() as float
    "use ggml_row_size() instead");

    // Human-readable names (for logging and debugging).
    GGML_API const char * ggml_type_name(enum ggml_type type);
    GGML_API const char * ggml_op_name  (enum ggml_op   op);
    GGML_API const char * ggml_op_symbol(enum ggml_op   op);

    GGML_API const char * ggml_unary_op_name(enum ggml_unary_op op);
    GGML_API const char * ggml_glu_op_name(enum ggml_glu_op op);
    GGML_API const char * ggml_op_desc(const struct ggml_tensor * t); // unary or op name

    GGML_API size_t  ggml_element_size(const struct ggml_tensor * tensor);

    GGML_API bool    ggml_is_quantized(enum ggml_type type);

    // TODO: temporary until model loading of ggml examples is refactored
    GGML_API enum ggml_type ggml_ftype_to_ggml_type(enum ggml_ftype ftype);

    // Shape and layout predicates.
    GGML_API bool ggml_is_transposed(const struct ggml_tensor * tensor);
    GGML_API bool ggml_is_permuted  (const struct ggml_tensor * tensor);
    GGML_API bool ggml_is_empty     (const struct ggml_tensor * tensor);
    GGML_API bool ggml_is_scalar    (const struct ggml_tensor * tensor);
    GGML_API bool ggml_is_vector    (const struct ggml_tensor * tensor);
    GGML_API bool ggml_is_matrix    (const struct ggml_tensor * tensor);
    GGML_API bool ggml_is_3d        (const struct ggml_tensor * tensor);
    GGML_API int  ggml_n_dims       (const struct ggml_tensor * tensor); // returns 1 for scalars

    // returns whether the tensor elements can be iterated over with a flattened index (no gaps, no permutation)
    GGML_API bool ggml_is_contiguous  (const struct ggml_tensor * tensor);
    GGML_API bool ggml_is_contiguous_0(const struct ggml_tensor * tensor); // same as ggml_is_contiguous()
    GGML_API bool ggml_is_contiguous_1(const struct ggml_tensor * tensor); // contiguous for dims >= 1
    GGML_API bool ggml_is_contiguous_2(const struct ggml_tensor * tensor); // contiguous for dims >= 2

    // returns whether the tensor elements are allocated as one contiguous block of memory (no gaps, but permutation ok)
    GGML_API bool ggml_is_contiguously_allocated(const struct ggml_tensor * tensor);

    // true for tensor that is stored in memory as CxWxHxN and has been permuted to WxHxCxN
    GGML_API bool ggml_is_contiguous_channels(const struct ggml_tensor * tensor);

    // true if the elements in dimension 0 are contiguous, or there is just 1 block of elements
    GGML_API bool ggml_is_contiguous_rows(const struct ggml_tensor * tensor);

    GGML_API bool ggml_are_same_shape (const struct ggml_tensor * t0, const struct ggml_tensor * t1);
    GGML_API bool ggml_are_same_stride(const struct ggml_tensor * t0, const struct ggml_tensor * t1);

    GGML_API bool ggml_can_repeat(const struct ggml_tensor * t0, const struct ggml_tensor * t1);

    // use this to compute the memory overhead of a tensor
    GGML_API size_t ggml_tensor_overhead(void);

    // NOTE(review): presumably checks `nbytes` of raw row data of the given type for invalid
    // values (e.g. NaN/Inf) — confirm the exact checks in the implementation.
    GGML_API bool ggml_validate_row_data(enum ggml_type type, const void * data, size_t nbytes);
 785
 786    // main
 787
 788    GGML_API struct ggml_context * ggml_init (struct ggml_init_params params);
 789    GGML_API void                  ggml_reset(struct ggml_context * ctx);
 790    GGML_API void                  ggml_free (struct ggml_context * ctx);
 791
 792    GGML_API size_t  ggml_used_mem(const struct ggml_context * ctx);
 793
 794    GGML_API bool    ggml_get_no_alloc(struct ggml_context * ctx);
 795    GGML_API void    ggml_set_no_alloc(struct ggml_context * ctx, bool no_alloc);
 796
 797    GGML_API void *  ggml_get_mem_buffer     (const struct ggml_context * ctx);
 798    GGML_API size_t  ggml_get_mem_size       (const struct ggml_context * ctx);
 799    GGML_API size_t  ggml_get_max_tensor_size(const struct ggml_context * ctx);
 800
    // create a new tensor of the given type in ctx
    // n_dims: number of dimensions; ne points to n_dims element counts, one per dimension
    // NOTE(review): presumably ne[0] is the innermost (fastest-varying) dimension, as in the _1d.._4d forms below - confirm
    GGML_API struct ggml_tensor * ggml_new_tensor(
            struct ggml_context * ctx,
            enum   ggml_type type,
            int    n_dims,
            const int64_t *ne);

    // convenience forms of ggml_new_tensor that take the dimension sizes directly
    GGML_API struct ggml_tensor * ggml_new_tensor_1d(
            struct ggml_context * ctx,
            enum   ggml_type type,
            int64_t ne0);

    GGML_API struct ggml_tensor * ggml_new_tensor_2d(
            struct ggml_context * ctx,
            enum   ggml_type type,
            int64_t ne0,
            int64_t ne1);

    GGML_API struct ggml_tensor * ggml_new_tensor_3d(
            struct ggml_context * ctx,
            enum   ggml_type type,
            int64_t ne0,
            int64_t ne1,
            int64_t ne2);

    GGML_API struct ggml_tensor * ggml_new_tensor_4d(
            struct ggml_context * ctx,
            enum   ggml_type type,
            int64_t ne0,
            int64_t ne1,
            int64_t ne2,
            int64_t ne3);

    // NOTE(review): presumably reserves nbytes of untyped memory inside ctx - confirm
    GGML_API void * ggml_new_buffer(struct ggml_context * ctx, size_t nbytes);

    // NOTE(review): dup_tensor presumably creates a new tensor with src's type/shape (data not copied);
    //               view_tensor presumably creates a tensor that shares src's data - confirm
    GGML_API struct ggml_tensor * ggml_dup_tensor (struct ggml_context * ctx, const struct ggml_tensor * src);
    GGML_API struct ggml_tensor * ggml_view_tensor(struct ggml_context * ctx, struct ggml_tensor * src);
 837
    // Context tensor enumeration and lookup
    // ggml_get_first_tensor / ggml_get_next_tensor walk the tensors of ctx one at a time;
    // ggml_get_tensor looks a tensor up by the name assigned with ggml_set_name / ggml_format_name
    // NOTE(review): presumably the walk and the lookup return NULL when there is no (further) match - confirm
    GGML_API struct ggml_tensor * ggml_get_first_tensor(const struct ggml_context * ctx);
    GGML_API struct ggml_tensor * ggml_get_next_tensor (const struct ggml_context * ctx, struct ggml_tensor * tensor);
    GGML_API struct ggml_tensor * ggml_get_tensor(struct ggml_context * ctx, const char * name);

    // Converts a flat index into coordinates
    // i is the flat element index; the per-dimension coordinates are written through i0..i3
    GGML_API void ggml_unravel_index(const struct ggml_tensor * tensor, int64_t i, int64_t * i0, int64_t * i1, int64_t * i2, int64_t * i3);

    GGML_API enum ggml_unary_op ggml_get_unary_op(const struct ggml_tensor * tensor);
    GGML_API enum ggml_glu_op ggml_get_glu_op(const struct ggml_tensor * tensor);

    // raw pointer to the tensor's data
    // NOTE(review): ggml_get_data_f32 presumably only makes sense for GGML_TYPE_F32 tensors - confirm
    GGML_API void *  ggml_get_data    (const struct ggml_tensor * tensor);
    GGML_API float * ggml_get_data_f32(const struct ggml_tensor * tensor);

    // tensor names; the setters return the tensor they were given
    // ggml_format_name is printf-style: fmt is parameter 2 and the variadic arguments start at parameter 3,
    // which is what GGML_ATTRIBUTE_FORMAT(2, 3) declares for compile-time format checking
    GGML_API const char *         ggml_get_name   (const struct ggml_tensor * tensor);
    GGML_API struct ggml_tensor * ggml_set_name   (      struct ggml_tensor * tensor, const char * name);
    GGML_ATTRIBUTE_FORMAT(2, 3)
    GGML_API struct ggml_tensor * ggml_format_name(      struct ggml_tensor * tensor, const char * fmt, ...);

    // Tensor flags
    // note: ggml_set_param takes only the tensor - there is no ctx argument
    GGML_API void ggml_set_input(struct ggml_tensor * tensor);
    GGML_API void ggml_set_output(struct ggml_tensor * tensor);
    GGML_API void ggml_set_param(struct ggml_tensor * tensor);
    GGML_API void ggml_set_loss(struct ggml_tensor * tensor);
 862
 863    //
 864    // operations on tensors with backpropagation
 865    //
 866
 867    GGML_API struct ggml_tensor * ggml_dup(
 868            struct ggml_context * ctx,
 869            struct ggml_tensor  * a);
 870
 871    // in-place, returns view(a)
 872    GGML_API struct ggml_tensor * ggml_dup_inplace(
 873            struct ggml_context * ctx,
 874            struct ggml_tensor  * a);
 875
 876    GGML_API struct ggml_tensor * ggml_add(
 877            struct ggml_context * ctx,
 878            struct ggml_tensor  * a,
 879            struct ggml_tensor  * b);
 880
 881    GGML_API struct ggml_tensor * ggml_add_inplace(
 882            struct ggml_context * ctx,
 883            struct ggml_tensor  * a,
 884            struct ggml_tensor  * b);
 885
 886    GGML_API struct ggml_tensor * ggml_add_cast(
 887            struct ggml_context * ctx,
 888            struct ggml_tensor  * a,
 889            struct ggml_tensor  * b,
 890            enum   ggml_type      type);
 891
 892    // dst[i0, i1, i2] = a[i0, i1, i2] + b[i0, ids[i1, i2]]
 893    GGML_API struct ggml_tensor * ggml_add_id(
 894            struct ggml_context * ctx,
 895            struct ggml_tensor  * a,
 896            struct ggml_tensor  * b,
 897            struct ggml_tensor  * ids);
 898
 899    GGML_API struct ggml_tensor * ggml_add1(
 900            struct ggml_context * ctx,
 901            struct ggml_tensor  * a,
 902            struct ggml_tensor  * b);
 903
 904    GGML_API struct ggml_tensor * ggml_add1_inplace(
 905            struct ggml_context * ctx,
 906            struct ggml_tensor  * a,
 907            struct ggml_tensor  * b);
 908
 909    // dst = a
 910    // view(dst, nb1, nb2, nb3, offset) += b
 911    // return dst
 912    GGML_API struct ggml_tensor * ggml_acc(
 913            struct ggml_context * ctx,
 914            struct ggml_tensor  * a,
 915            struct ggml_tensor  * b,
 916            size_t                nb1,
 917            size_t                nb2,
 918            size_t                nb3,
 919            size_t                offset);
 920
 921    GGML_API struct ggml_tensor * ggml_acc_inplace(
 922            struct ggml_context * ctx,
 923            struct ggml_tensor  * a,
 924            struct ggml_tensor  * b,
 925            size_t                nb1,
 926            size_t                nb2,
 927            size_t                nb3,
 928            size_t                offset);
 929
    // element-wise arithmetic: a - b, a * b, a / b
    // NOTE(review): presumably b is broadcast to a's shape as for ggml_add - confirm
    GGML_API struct ggml_tensor * ggml_sub(
            struct ggml_context * ctx,
            struct ggml_tensor  * a,
            struct ggml_tensor  * b);

    GGML_API struct ggml_tensor * ggml_sub_inplace(
            struct ggml_context * ctx,
            struct ggml_tensor  * a,
            struct ggml_tensor  * b);

    GGML_API struct ggml_tensor * ggml_mul(
            struct ggml_context * ctx,
            struct ggml_tensor  * a,
            struct ggml_tensor  * b);

    GGML_API struct ggml_tensor * ggml_mul_inplace(
            struct ggml_context * ctx,
            struct ggml_tensor  * a,
            struct ggml_tensor  * b);

    GGML_API struct ggml_tensor * ggml_div(
            struct ggml_context * ctx,
            struct ggml_tensor  * a,
            struct ggml_tensor  * b);

    GGML_API struct ggml_tensor * ggml_div_inplace(
            struct ggml_context * ctx,
            struct ggml_tensor  * a,
            struct ggml_tensor  * b);

    // element-wise unary math on a; each *_inplace form is the in-place counterpart of the op above it

    // a^2
    GGML_API struct ggml_tensor * ggml_sqr(
            struct ggml_context * ctx,
            struct ggml_tensor  * a);

    GGML_API struct ggml_tensor * ggml_sqr_inplace(
            struct ggml_context * ctx,
            struct ggml_tensor  * a);

    GGML_API struct ggml_tensor * ggml_sqrt(
            struct ggml_context * ctx,
            struct ggml_tensor  * a);

    GGML_API struct ggml_tensor * ggml_sqrt_inplace(
            struct ggml_context * ctx,
            struct ggml_tensor  * a);

    GGML_API struct ggml_tensor * ggml_log(
            struct ggml_context * ctx,
            struct ggml_tensor  * a);

    GGML_API struct ggml_tensor * ggml_log_inplace(
            struct ggml_context * ctx,
            struct ggml_tensor  * a);

    // exp(a) - 1
    GGML_API struct ggml_tensor * ggml_expm1(
            struct ggml_context * ctx,
            struct ggml_tensor  * a);

    GGML_API struct ggml_tensor * ggml_expm1_inplace(
            struct ggml_context * ctx,
            struct ggml_tensor  * a);

    // log(1 + exp(a))
    GGML_API struct ggml_tensor * ggml_softplus(
            struct ggml_context * ctx,
            struct ggml_tensor  * a);

    GGML_API struct ggml_tensor * ggml_softplus_inplace(
            struct ggml_context * ctx,
            struct ggml_tensor  * a);

    GGML_API struct ggml_tensor * ggml_sin(
            struct ggml_context * ctx,
            struct ggml_tensor  * a);

    GGML_API struct ggml_tensor * ggml_sin_inplace(
            struct ggml_context * ctx,
            struct ggml_tensor  * a);

    GGML_API struct ggml_tensor * ggml_cos(
            struct ggml_context * ctx,
            struct ggml_tensor  * a);

    GGML_API struct ggml_tensor * ggml_cos_inplace(
            struct ggml_context * ctx,
            struct ggml_tensor  * a);
1015
1016    // return scalar
1017    GGML_API struct ggml_tensor * ggml_sum(
1018            struct ggml_context * ctx,
1019            struct ggml_tensor  * a);
1020
1021    // sums along rows, with input shape [a,b,c,d] return shape [1,b,c,d]
1022    GGML_API struct ggml_tensor * ggml_sum_rows(
1023            struct ggml_context * ctx,
1024            struct ggml_tensor  * a);
1025
1026    GGML_API struct ggml_tensor * ggml_cumsum(
1027        struct ggml_context * ctx,
1028        struct ggml_tensor  * a);
1029
1030    // mean along rows
1031    GGML_API struct ggml_tensor * ggml_mean(
1032            struct ggml_context * ctx,
1033            struct ggml_tensor  * a);
1034
1035    // argmax along rows
1036    GGML_API struct ggml_tensor * ggml_argmax(
1037            struct ggml_context * ctx,
1038            struct ggml_tensor  * a);
1039
1040    // count number of equal elements in a and b
1041    GGML_API struct ggml_tensor * ggml_count_equal(
1042            struct ggml_context * ctx,
1043            struct ggml_tensor  * a,
1044            struct ggml_tensor  * b);
1045
    // if a is the same shape as b, and a is not parameter, return a
    // otherwise, return a new tensor: repeat(a) to fit in b
    // NOTE(review): verify the "return a" shortcut still exists in ggml.c; b is used for its shape
    GGML_API struct ggml_tensor * ggml_repeat(
            struct ggml_context * ctx,
            struct ggml_tensor  * a,
            struct ggml_tensor  * b);

    // repeat a to the specified shape
    // (same as ggml_repeat, but the target shape is given as ne0..ne3 instead of a tensor b)
    GGML_API struct ggml_tensor * ggml_repeat_4d(
            struct ggml_context * ctx,
            struct ggml_tensor  * a,
                       int64_t    ne0,
                       int64_t    ne1,
                       int64_t    ne2,
                       int64_t    ne3);

    // sums repetitions in a into shape of b
    // b gives the shape of the result
    GGML_API struct ggml_tensor * ggml_repeat_back(
            struct ggml_context * ctx,
            struct ggml_tensor  * a,
            struct ggml_tensor  * b); // sum up values that are adjacent in dims > 0 instead of repeated with same stride

    // concat a and b along dim
    // used in stable-diffusion
    GGML_API struct ggml_tensor * ggml_concat(
            struct ggml_context * ctx,
            struct ggml_tensor  * a,
            struct ggml_tensor  * b,
            int                   dim);
1075
1076    GGML_API struct ggml_tensor * ggml_abs(
1077            struct ggml_context * ctx,
1078            struct ggml_tensor  * a);
1079
1080    GGML_API struct ggml_tensor * ggml_abs_inplace(
1081            struct ggml_context * ctx,
1082            struct ggml_tensor  * a);
1083
1084    GGML_API struct ggml_tensor * ggml_sgn(
1085            struct ggml_context * ctx,
1086            struct ggml_tensor  * a);
1087
1088    GGML_API struct ggml_tensor * ggml_sgn_inplace(
1089            struct ggml_context * ctx,
1090            struct ggml_tensor  * a);
1091
1092    GGML_API struct ggml_tensor * ggml_neg(
1093            struct ggml_context * ctx,
1094            struct ggml_tensor  * a);
1095
1096    GGML_API struct ggml_tensor * ggml_neg_inplace(
1097            struct ggml_context * ctx,
1098            struct ggml_tensor  * a);
1099
1100    GGML_API struct ggml_tensor * ggml_step(
1101            struct ggml_context * ctx,
1102            struct ggml_tensor  * a);
1103
1104    GGML_API struct ggml_tensor * ggml_step_inplace(
1105            struct ggml_context * ctx,
1106            struct ggml_tensor  * a);
1107
1108    GGML_API struct ggml_tensor * ggml_tanh(
1109            struct ggml_context * ctx,
1110            struct ggml_tensor  * a);
1111
1112    GGML_API struct ggml_tensor * ggml_tanh_inplace(
1113            struct ggml_context * ctx,
1114            struct ggml_tensor  * a);
1115
1116    GGML_API struct ggml_tensor * ggml_elu(
1117            struct ggml_context * ctx,
1118            struct ggml_tensor  * a);
1119
1120    GGML_API struct ggml_tensor * ggml_elu_inplace(
1121            struct ggml_context * ctx,
1122            struct ggml_tensor  * a);
1123
1124    GGML_API struct ggml_tensor * ggml_relu(
1125            struct ggml_context * ctx,
1126            struct ggml_tensor  * a);
1127
1128    GGML_API struct ggml_tensor * ggml_leaky_relu(
1129            struct ggml_context * ctx,
1130            struct ggml_tensor  * a, float negative_slope, bool inplace);
1131
1132    GGML_API struct ggml_tensor * ggml_relu_inplace(
1133            struct ggml_context * ctx,
1134            struct ggml_tensor  * a);
1135
    // element-wise activations on a; each *_inplace form is the in-place counterpart of the op above it

    GGML_API struct ggml_tensor * ggml_sigmoid(
            struct ggml_context * ctx,
            struct ggml_tensor  * a);

    GGML_API struct ggml_tensor * ggml_sigmoid_inplace(
            struct ggml_context * ctx,
            struct ggml_tensor  * a);

    GGML_API struct ggml_tensor * ggml_gelu(
            struct ggml_context * ctx,
            struct ggml_tensor  * a);

    GGML_API struct ggml_tensor * ggml_gelu_inplace(
            struct ggml_context * ctx,
            struct ggml_tensor  * a);

    // GELU using erf (error function) when possible
    // some backends may fall back to an approximation based on the Abramowitz and Stegun formula
    GGML_API struct ggml_tensor * ggml_gelu_erf(
            struct ggml_context * ctx,
            struct ggml_tensor  * a);

    GGML_API struct ggml_tensor * ggml_gelu_erf_inplace(
            struct ggml_context * ctx,
            struct ggml_tensor  * a);

    GGML_API struct ggml_tensor * ggml_gelu_quick(
            struct ggml_context * ctx,
            struct ggml_tensor  * a);

    GGML_API struct ggml_tensor * ggml_gelu_quick_inplace(
            struct ggml_context * ctx,
            struct ggml_tensor  * a);

    GGML_API struct ggml_tensor * ggml_silu(
            struct ggml_context * ctx,
            struct ggml_tensor  * a);

    GGML_API struct ggml_tensor * ggml_silu_inplace(
            struct ggml_context * ctx,
            struct ggml_tensor  * a);

    // backward pass of ggml_silu
    // a - x
    // b - dy
    GGML_API struct ggml_tensor * ggml_silu_back(
            struct ggml_context * ctx,
            struct ggml_tensor  * a,
            struct ggml_tensor  * b);

    // hardswish(x) = x * relu6(x + 3) / 6
    GGML_API struct ggml_tensor * ggml_hardswish(
            struct ggml_context * ctx,
            struct ggml_tensor  * a);

    // hardsigmoid(x) = relu6(x + 3) / 6
    GGML_API struct ggml_tensor * ggml_hardsigmoid(
            struct ggml_context * ctx,
            struct ggml_tensor  * a);

    GGML_API struct ggml_tensor * ggml_exp(
            struct ggml_context * ctx,
            struct ggml_tensor  * a);

    GGML_API struct ggml_tensor * ggml_exp_inplace(
            struct ggml_context * ctx,
            struct ggml_tensor  * a);
1202
    // element-wise rounding of a: floor rounds towards -inf, ceil towards +inf, round to the nearest integer
    // NOTE(review): tie-breaking of ggml_round (e.g. halfway cases away from zero like C round()) - confirm in ggml.c
    GGML_API struct ggml_tensor * ggml_floor(
            struct ggml_context * ctx,
            struct ggml_tensor  * a);

    GGML_API struct ggml_tensor * ggml_floor_inplace(
            struct ggml_context * ctx,
            struct ggml_tensor  * a);

    GGML_API struct ggml_tensor * ggml_ceil(
            struct ggml_context * ctx,
            struct ggml_tensor  * a);

    GGML_API struct ggml_tensor * ggml_ceil_inplace(
            struct ggml_context * ctx,
            struct ggml_tensor  * a);

    GGML_API struct ggml_tensor * ggml_round(
            struct ggml_context * ctx,
            struct ggml_tensor  * a);

    GGML_API struct ggml_tensor * ggml_round_inplace(
            struct ggml_context * ctx,
            struct ggml_tensor  * a);

    /**
     * Truncates the fractional part of each element in the tensor (towards zero).
     * For example: trunc(3.7) = 3.0, trunc(-2.9) = -2.0
     * Similar to trunc() in C and std::trunc in C++.
     */
    GGML_API struct ggml_tensor * ggml_trunc(
            struct ggml_context * ctx,
            struct ggml_tensor  * a);

    GGML_API struct ggml_tensor * ggml_trunc_inplace(
            struct ggml_context * ctx,
            struct ggml_tensor  * a);
1240
1241
1242
    // xIELU activation function
    // x = x * (c_a(alpha_n) + c_b(alpha_p, beta) * sigmoid(beta * x)) + eps * (x > 0)
    // where c_a = softplus and c_b(a, b) = softplus(a) + b are constraining functions
    // that constrain the negative (alpha_n) and positive (alpha_p) source alpha values respectively
    GGML_API struct ggml_tensor * ggml_xielu(
            struct ggml_context * ctx,
            struct ggml_tensor  * a,
            float alpha_n,
            float alpha_p,
            float beta,
            float eps);
1254
    // gated linear unit ops
    // A: n columns, r rows,
    // result is n / 2 columns, r rows,
    // expects gate in second half of row, unless swapped is true
    GGML_API struct ggml_tensor * ggml_glu(
            struct ggml_context * ctx,
             struct ggml_tensor * a,
             enum ggml_glu_op     op,
             bool                 swapped);

    // fixed-op shorthands for ggml_glu; the *_swapped forms correspond to swapped = true
    // NOTE(review): presumably each forwards to ggml_glu with the matching ggml_glu_op - confirm
    GGML_API struct ggml_tensor * ggml_reglu(
            struct ggml_context * ctx,
            struct ggml_tensor  * a);

    GGML_API struct ggml_tensor * ggml_reglu_swapped(
            struct ggml_context * ctx,
            struct ggml_tensor  * a);

    GGML_API struct ggml_tensor * ggml_geglu(
            struct ggml_context * ctx,
            struct ggml_tensor  * a);

    GGML_API struct ggml_tensor * ggml_geglu_swapped(
            struct ggml_context * ctx,
            struct ggml_tensor  * a);

    GGML_API struct ggml_tensor * ggml_swiglu(
            struct ggml_context * ctx,
            struct ggml_tensor  * a);

    GGML_API struct ggml_tensor * ggml_swiglu_swapped(
            struct ggml_context * ctx,
            struct ggml_tensor  * a);

    GGML_API struct ggml_tensor * ggml_geglu_erf(
            struct ggml_context * ctx,
            struct ggml_tensor  * a);

    GGML_API struct ggml_tensor * ggml_geglu_erf_swapped(
            struct ggml_context * ctx,
            struct ggml_tensor  * a);

    GGML_API struct ggml_tensor * ggml_geglu_quick(
            struct ggml_context * ctx,
            struct ggml_tensor  * a);

    GGML_API struct ggml_tensor * ggml_geglu_quick_swapped(
            struct ggml_context * ctx,
            struct ggml_tensor  * a);

    // split form: the two halves are passed as separate tensors instead of one row
    // A: n columns, r rows,
    // B: n columns, r rows,
    // NOTE(review): which of a/b acts as the gate, and the result shape (presumably n columns, r rows) - confirm
    GGML_API struct ggml_tensor * ggml_glu_split(
            struct ggml_context * ctx,
             struct ggml_tensor * a,
             struct ggml_tensor * b,
             enum ggml_glu_op     op);

    GGML_API struct ggml_tensor * ggml_reglu_split(
            struct ggml_context * ctx,
            struct ggml_tensor  * a,
            struct ggml_tensor  * b);

    GGML_API struct ggml_tensor * ggml_geglu_split(
            struct ggml_context * ctx,
            struct ggml_tensor  * a,
            struct ggml_tensor  * b);

    GGML_API struct ggml_tensor * ggml_swiglu_split(
            struct ggml_context * ctx,
            struct ggml_tensor  * a,
            struct ggml_tensor  * b);

    GGML_API struct ggml_tensor * ggml_geglu_erf_split(
            struct ggml_context * ctx,
            struct ggml_tensor  * a,
            struct ggml_tensor  * b);

    GGML_API struct ggml_tensor * ggml_geglu_quick_split(
            struct ggml_context * ctx,
            struct ggml_tensor  * a,
            struct ggml_tensor  * b);

    // SwiGLU variant parameterized by alpha and limit
    // NOTE(review): exact roles of alpha (sigmoid scale?) and limit (clamp bound?) - confirm in ggml.c
    GGML_API struct ggml_tensor * ggml_swiglu_oai(
            struct ggml_context * ctx,
            struct ggml_tensor  * a,
            struct ggml_tensor  * b,
            float                 alpha,
            float                 limit);
1344
    // normalize along rows
    // eps is the small constant the normalization uses for numerical stability
    GGML_API struct ggml_tensor * ggml_norm(
            struct ggml_context * ctx,
            struct ggml_tensor  * a,
            float                 eps);

    GGML_API struct ggml_tensor * ggml_norm_inplace(
            struct ggml_context * ctx,
            struct ggml_tensor  * a,
            float                 eps);

    // RMS-normalize along rows
    // NOTE(review): presumably along rows like ggml_norm - confirm
    GGML_API struct ggml_tensor * ggml_rms_norm(
            struct ggml_context * ctx,
            struct ggml_tensor  * a,
            float                 eps);

    GGML_API struct ggml_tensor * ggml_rms_norm_inplace(
            struct ggml_context * ctx,
            struct ggml_tensor  * a,
            float                 eps);

    // group normalize along ne0*ne1*n_groups
    // used in stable-diffusion
    GGML_API struct ggml_tensor * ggml_group_norm(
            struct ggml_context * ctx,
            struct ggml_tensor  * a,
            int                   n_groups,
            float                 eps);

    GGML_API struct ggml_tensor * ggml_group_norm_inplace(
            struct ggml_context * ctx,
            struct ggml_tensor  * a,
            int                   n_groups,
            float                 eps);

    // l2 normalize along rows
    // used in rwkv v7
    GGML_API struct ggml_tensor * ggml_l2_norm(
            struct ggml_context * ctx,
            struct ggml_tensor  * a,
            float                 eps);

    GGML_API struct ggml_tensor * ggml_l2_norm_inplace(
            struct ggml_context * ctx,
            struct ggml_tensor  * a,
            float                 eps);

    // backward pass of ggml_rms_norm
    // a - x
    // b - dy
    GGML_API struct ggml_tensor * ggml_rms_norm_back(
            struct ggml_context * ctx,
            struct ggml_tensor  * a,
            struct ggml_tensor  * b,
            float                 eps);
1399
    // A: k columns, n rows => [ne03, ne02, n, k]
    // B: k columns, m rows  (i.e. we transpose it internally) => [ne03 * x, ne02 * y, m, k]
    // result is n columns, m rows => [ne03 * x, ne02 * y, m, n]
    // (the x, y factors mean A is broadcast over B's outer two dimensions)
    GGML_API struct ggml_tensor * ggml_mul_mat(
            struct ggml_context * ctx,
            struct ggml_tensor  * a,
            struct ggml_tensor  * b);

    // change the precision of a matrix multiplication
    // set to GGML_PREC_F32 for higher precision (useful for phi-2)
    // a is the result tensor returned by ggml_mul_mat
    GGML_API void ggml_mul_mat_set_prec(
            struct ggml_tensor * a,
            enum ggml_prec       prec);

    // indirect matrix multiplication
    // NOTE(review): presumably `as` stacks several matrices and `ids` selects which one multiplies each part of b - confirm
    GGML_API struct ggml_tensor * ggml_mul_mat_id(
            struct ggml_context * ctx,
            struct ggml_tensor  * as,
            struct ggml_tensor  * b,
            struct ggml_tensor  * ids);

    // A: m columns, n rows,
    // B: p columns, n rows,
    // result is m columns, p rows
    GGML_API struct ggml_tensor * ggml_out_prod(
            struct ggml_context * ctx,
            struct ggml_tensor  * a,
            struct ggml_tensor  * b);
1428
1429    //
1430    // operations on tensors without backpropagation
1431    //
1432
1433    GGML_API struct ggml_tensor * ggml_scale(
1434            struct ggml_context * ctx,
1435            struct ggml_tensor  * a,
1436            float                 s);
1437
1438    // in-place, returns view(a)
1439    GGML_API struct ggml_tensor * ggml_scale_inplace(
1440            struct ggml_context * ctx,
1441            struct ggml_tensor  * a,
1442            float                 s);
1443
1444    // x = s * a + b
1445    GGML_API struct ggml_tensor * ggml_scale_bias(
1446        struct ggml_context * ctx,
1447        struct ggml_tensor  * a,
1448        float                 s,
1449        float                 b);
1450
1451    GGML_API struct ggml_tensor * ggml_scale_bias_inplace(
1452        struct ggml_context * ctx,
1453        struct ggml_tensor  * a,
1454        float                 s,
1455        float                 b);
1456
    // b -> view(a, nb1, nb2, nb3, offset), return modified a
    // (nb1..nb3 are the view's byte strides)
    GGML_API struct ggml_tensor * ggml_set(
            struct ggml_context * ctx,
            struct ggml_tensor  * a,
            struct ggml_tensor  * b,
            size_t                nb1,
            size_t                nb2,
            size_t                nb3,
            size_t                offset); // in bytes

    // b -> view(a, nb1, nb2, nb3, offset), return view(a)
    GGML_API struct ggml_tensor * ggml_set_inplace(
            struct ggml_context * ctx,
            struct ggml_tensor  * a,
            struct ggml_tensor  * b,
            size_t                nb1,
            size_t                nb2,
            size_t                nb3,
            size_t                offset); // in bytes

    // 1-d form of ggml_set: only a byte offset into a is given
    GGML_API struct ggml_tensor * ggml_set_1d(
            struct ggml_context * ctx,
            struct ggml_tensor  * a,
            struct ggml_tensor  * b,
            size_t                offset); // in bytes

    GGML_API struct ggml_tensor * ggml_set_1d_inplace(
            struct ggml_context * ctx,
            struct ggml_tensor  * a,
            struct ggml_tensor  * b,
            size_t                offset); // in bytes

    // b -> view(a, nb1, offset), return modified a
    // (nb1 is the view's row stride in bytes)
    GGML_API struct ggml_tensor * ggml_set_2d(
            struct ggml_context * ctx,
            struct ggml_tensor  * a,
            struct ggml_tensor  * b,
            size_t                nb1,
            size_t                offset); // in bytes

    // b -> view(a, nb1, offset), return view(a)
    GGML_API struct ggml_tensor * ggml_set_2d_inplace(
            struct ggml_context * ctx,
            struct ggml_tensor  * a,
            struct ggml_tensor  * b,
            size_t                nb1,
            size_t                offset); // in bytes

    // a -> b, return view(b)
    GGML_API struct ggml_tensor * ggml_cpy(
            struct ggml_context * ctx,
            struct ggml_tensor  * a,
            struct ggml_tensor  * b);
1510
    // convert a to `type`
    // note: casting from f32 to i32 will discard the fractional part
    GGML_API struct ggml_tensor * ggml_cast(
            struct ggml_context * ctx,
            struct ggml_tensor  * a,
            enum   ggml_type      type);

    // make contiguous
    GGML_API struct ggml_tensor * ggml_cont(
            struct ggml_context * ctx,
            struct ggml_tensor  * a);

    // make contiguous, with new shape
    // NOTE(review): presumably the new shape must hold the same number of elements as a - confirm
    GGML_API struct ggml_tensor * ggml_cont_1d(
            struct ggml_context * ctx,
            struct ggml_tensor  * a,
            int64_t               ne0);

    GGML_API struct ggml_tensor * ggml_cont_2d(
            struct ggml_context * ctx,
            struct ggml_tensor  * a,
            int64_t               ne0,
            int64_t               ne1);

    GGML_API struct ggml_tensor * ggml_cont_3d(
            struct ggml_context * ctx,
            struct ggml_tensor  * a,
            int64_t               ne0,
            int64_t               ne1,
            int64_t               ne2);

    GGML_API struct ggml_tensor * ggml_cont_4d(
            struct ggml_context * ctx,
            struct ggml_tensor  * a,
            int64_t               ne0,
            int64_t               ne1,
            int64_t               ne2,
            int64_t               ne3);
1548
    // return view(a), b specifies the new shape
    // TODO: when we start computing gradient, make a copy instead of view
    // NOTE(review): since the result is a view, a presumably must be contiguous - confirm
    GGML_API struct ggml_tensor * ggml_reshape(
            struct ggml_context * ctx,
            struct ggml_tensor  * a,
            struct ggml_tensor  * b);

    // return view(a)
    // TODO: when we start computing gradient, make a copy instead of view
    GGML_API struct ggml_tensor * ggml_reshape_1d(
            struct ggml_context * ctx,
            struct ggml_tensor  * a,
            int64_t               ne0);

    // return view(a) with the new shape ne0 x ne1
    GGML_API struct ggml_tensor * ggml_reshape_2d(
            struct ggml_context * ctx,
            struct ggml_tensor  * a,
            int64_t               ne0,
            int64_t               ne1);

    // return view(a)
    // TODO: when we start computing gradient, make a copy instead of view
    GGML_API struct ggml_tensor * ggml_reshape_3d(
            struct ggml_context * ctx,
            struct ggml_tensor  * a,
            int64_t               ne0,
            int64_t               ne1,
            int64_t               ne2);

    // return view(a) with the new shape ne0 x ne1 x ne2 x ne3
    GGML_API struct ggml_tensor * ggml_reshape_4d(
            struct ggml_context * ctx,
            struct ggml_tensor  * a,
            int64_t               ne0,
            int64_t               ne1,
            int64_t               ne2,
            int64_t               ne3);
1585
    // views into a's data: ne* give the view's shape, nb* its strides, offset its start
    // offset in bytes
    GGML_API struct ggml_tensor * ggml_view_1d(
            struct ggml_context * ctx,
            struct ggml_tensor  * a,
            int64_t               ne0,
            size_t                offset);

    GGML_API struct ggml_tensor * ggml_view_2d(
            struct ggml_context * ctx,
            struct ggml_tensor  * a,
            int64_t               ne0,
            int64_t               ne1,
            size_t                nb1, // row stride in bytes
            size_t                offset);

    GGML_API struct ggml_tensor * ggml_view_3d(
            struct ggml_context * ctx,
            struct ggml_tensor  * a,
            int64_t               ne0,
            int64_t               ne1,
            int64_t               ne2,
            size_t                nb1, // row   stride in bytes
            size_t                nb2, // slice stride in bytes
            size_t                offset);

    GGML_API struct ggml_tensor * ggml_view_4d(
            struct ggml_context * ctx,
            struct ggml_tensor  * a,
            int64_t               ne0,
            int64_t               ne1,
            int64_t               ne2,
            int64_t               ne3,
            size_t                nb1, // row   stride in bytes
            size_t                nb2, // slice stride in bytes
            size_t                nb3, // dim-3 stride in bytes
            size_t                offset);

    // reorder the dimensions of a
    // NOTE(review): presumably axisN is the destination position of a's dimension N (see ggml_transpose) - confirm
    GGML_API struct ggml_tensor * ggml_permute(
            struct ggml_context * ctx,
            struct ggml_tensor  * a,
            int                   axis0,
            int                   axis1,
            int                   axis2,
            int                   axis3);

    // alias for ggml_permute(ctx, a, 1, 0, 2, 3)
    GGML_API struct ggml_tensor * ggml_transpose(
            struct ggml_context * ctx,
            struct ggml_tensor  * a);
1635
    // gather rows: for each index in b, select the corresponding row (dim 1) of a
    //
    // supports 4D a:
    // a     [n_embd, ne1, ne2, ne3]
    // b I32 [n_rows, ne2, ne3, 1]
    //
    // return [n_embd, n_rows, ne2, ne3]
    GGML_API struct ggml_tensor * ggml_get_rows(
            struct ggml_context * ctx,
            struct ggml_tensor  * a,  // data
            struct ggml_tensor  * b); // row indices
1645
    // backward pass of ggml_get_rows
    // NOTE(review): presumably accumulates the rows of a into a result shaped like c,
    //               at the row indices given by b — confirm in ggml.c
    GGML_API struct ggml_tensor * ggml_get_rows_back(
            struct ggml_context * ctx,
            struct ggml_tensor  * a,  // gradients of ggml_get_rows result
            struct ggml_tensor  * b,  // row indices
            struct ggml_tensor  * c); // data for ggml_get_rows, only used for its shape
1651
    // scatter rows: copy the rows of b (source) into a (destination) at the row indices in c
    //
    // a TD  [n_embd, ne1,    ne2,    ne3]
    // b TS  [n_embd, n_rows, ne02,   ne03] | ne02 == ne2, ne03 == ne3
    // c I64 [n_rows, ne11,   ne12,   1]    | c[i] in [0, ne1)
    //
    // undefined behavior if destination rows overlap
    //
    // broadcast:
    //   ne2 % ne11 == 0
    //   ne3 % ne12 == 0
    //
    // return view(a)
    GGML_API struct ggml_tensor * ggml_set_rows(
            struct ggml_context * ctx,
            struct ggml_tensor  * a,  // destination
            struct ggml_tensor  * b,  // source
            struct ggml_tensor  * c); // row indices
1668
    // NOTE(review): presumably builds a diagonal matrix from the elements of a
    //               (a row vector) — confirm in ggml.c
    GGML_API struct ggml_tensor * ggml_diag(
        struct ggml_context     * ctx,
        struct ggml_tensor      * a);
1672
    // causal-mask helpers
    // NOTE(review): n_past presumably shifts the diagonal right by n_past columns, so that
    //               element (col i, row j) is masked when i > n_past + j — confirm in ggml.c

    // set elements above the diagonal to -INF
    GGML_API struct ggml_tensor * ggml_diag_mask_inf(
            struct ggml_context * ctx,
            struct ggml_tensor  * a,
            int                   n_past);

    // in-place, returns view(a)
    GGML_API struct ggml_tensor * ggml_diag_mask_inf_inplace(
            struct ggml_context * ctx,
            struct ggml_tensor  * a,
            int                   n_past);

    // set elements above the diagonal to 0
    GGML_API struct ggml_tensor * ggml_diag_mask_zero(
            struct ggml_context * ctx,
            struct ggml_tensor  * a,
            int                   n_past);

    // in-place, returns view(a)
    GGML_API struct ggml_tensor * ggml_diag_mask_zero_inplace(
            struct ggml_context * ctx,
            struct ggml_tensor  * a,
            int                   n_past);
1696
    // soft_max(a) without scale or mask (see ggml_soft_max_ext for the fused variant)
    // NOTE(review): presumably computed independently along dim 0 (each row) — confirm
    GGML_API struct ggml_tensor * ggml_soft_max(
            struct ggml_context * ctx,
            struct ggml_tensor  * a);

    // in-place, returns view(a)
    GGML_API struct ggml_tensor * ggml_soft_max_inplace(
            struct ggml_context * ctx,
            struct ggml_tensor  * a);
1705
    // a    [ne0, ne01, ne02, ne03]
    // mask [ne0, ne11, ne12, ne13] | ne11 >= ne01, F16 or F32, optional
    //
    // broadcast:
    //   ne02 % ne12 == 0
    //   ne03 % ne13 == 0
    //
    // fused soft_max(a*scale + mask*(ALiBi slope))
    // max_bias = 0.0f for no ALiBi
    GGML_API struct ggml_tensor * ggml_soft_max_ext(
            struct ggml_context * ctx,
            struct ggml_tensor  * a,
            struct ggml_tensor  * mask,
            float                 scale,
            float                 max_bias);

    // in-place variant of ggml_soft_max_ext (same parameters)
    GGML_API struct ggml_tensor * ggml_soft_max_ext_inplace(
            struct ggml_context * ctx,
            struct ggml_tensor  * a,
            struct ggml_tensor  * mask,
            float                 scale,
            float                 max_bias);

    // attach a sinks tensor to an existing soft_max_ext result `a`
    // (takes no context and returns nothing, so no new tensor is created)
    // NOTE(review): sinks are presumably per-head extra logits that only take part in the
    //               normalization — confirm in ggml.c
    GGML_API void ggml_soft_max_add_sinks(
            struct ggml_tensor * a,
            struct ggml_tensor * sinks);
1732
    // backward pass of ggml_soft_max_ext (scale and max_bias as in the forward call)
    // NOTE(review): presumably a is the gradient of the soft_max result and b is the forward
    //               soft_max output — confirm in ggml.c
    GGML_API struct ggml_tensor * ggml_soft_max_ext_back(
            struct ggml_context * ctx,
            struct ggml_tensor  * a,
            struct ggml_tensor  * b,
            float                 scale,
            float                 max_bias);

    // in-place, returns view(a)
    GGML_API struct ggml_tensor * ggml_soft_max_ext_back_inplace(
            struct ggml_context * ctx,
            struct ggml_tensor  * a,
            struct ggml_tensor  * b,
            float                 scale,
            float                 max_bias);
1747
    // rotary position embedding
    // if (mode & 1) - skip n_past elements (NOT SUPPORTED)
    // if (mode & GGML_ROPE_TYPE_NEOX) - GPT-NeoX style
    //
    // b is an int32 vector with size a->ne[2], it contains the positions
    // NOTE(review): n_dims is presumably the number of leading elements of dim 0 that get
    //               rotated — confirm in ggml.c
    GGML_API struct ggml_tensor * ggml_rope(
            struct ggml_context * ctx,
            struct ggml_tensor  * a,
            struct ggml_tensor  * b,
            int                   n_dims,
            int                   mode);

    // in-place, returns view(a)
    GGML_API struct ggml_tensor * ggml_rope_inplace(
            struct ggml_context * ctx,
            struct ggml_tensor  * a,
            struct ggml_tensor  * b,
            int                   n_dims,
            int                   mode);
1767
    // custom RoPE
    // c is freq factors (e.g. phi3-128k), (optional)
    //
    // NOTE(review): n_ctx_orig, ext_factor, attn_factor, beta_fast and beta_slow are presumably
    //               the YaRN context-extension parameters (beta_fast/beta_slow also feed
    //               ggml_rope_yarn_corr_dims); freq_base/freq_scale presumably set the rotation
    //               frequencies — confirm in ggml.c
    GGML_API struct ggml_tensor * ggml_rope_ext(
            struct ggml_context * ctx,
            struct ggml_tensor  * a,
            struct ggml_tensor  * b,
            struct ggml_tensor  * c,
            int                   n_dims,
            int                   mode,
            int                   n_ctx_orig,
            float                 freq_base,
            float                 freq_scale,
            float                 ext_factor,
            float                 attn_factor,
            float                 beta_fast,
            float                 beta_slow);

    // multi-section RoPE: same parameters as ggml_rope_ext plus `sections`
    // (GGML_MROPE_SECTIONS entries)
    // NOTE(review): sections presumably gives how many rotated dimensions belong to each
    //               section, with b holding one position per section per token — confirm
    GGML_API struct ggml_tensor * ggml_rope_multi(
            struct ggml_context * ctx,
            struct ggml_tensor  * a,
            struct ggml_tensor  * b,
            struct ggml_tensor  * c,
            int                   n_dims,
            int                   sections[GGML_MROPE_SECTIONS],
            int                   mode,
            int                   n_ctx_orig,
            float                 freq_base,
            float                 freq_scale,
            float                 ext_factor,
            float                 attn_factor,
            float                 beta_fast,
            float                 beta_slow);
1800
    // in-place variant of ggml_rope_ext, returns view(a)
    GGML_API struct ggml_tensor * ggml_rope_ext_inplace(
            struct ggml_context * ctx,
            struct ggml_tensor  * a,
            struct ggml_tensor  * b,
            struct ggml_tensor  * c,
            int                   n_dims,
            int                   mode,
            int                   n_ctx_orig,
            float                 freq_base,
            float                 freq_scale,
            float                 ext_factor,
            float                 attn_factor,
            float                 beta_fast,
            float                 beta_slow);

    // in-place variant of ggml_rope_multi (same parameters)
    GGML_API struct ggml_tensor * ggml_rope_multi_inplace(
            struct ggml_context * ctx,
            struct ggml_tensor  * a,
            struct ggml_tensor  * b,
            struct ggml_tensor  * c,
            int                   n_dims,
            int                   sections[GGML_MROPE_SECTIONS],
            int                   mode,
            int                   n_ctx_orig,
            float                 freq_base,
            float                 freq_scale,
            float                 ext_factor,
            float                 attn_factor,
            float                 beta_fast,
            float                 beta_slow);
1832
    // deprecated: takes the same parameters as ggml_rope_ext except the freq factors tensor c
    // NOTE(review): presumably equivalent to ggml_rope_ext(..., c = NULL, ...) — confirm
    GGML_DEPRECATED(GGML_API struct ggml_tensor * ggml_rope_custom(
            struct ggml_context * ctx,
            struct ggml_tensor  * a,
            struct ggml_tensor  * b,
            int                   n_dims,
            int                   mode,
            int                   n_ctx_orig,
            float                 freq_base,
            float                 freq_scale,
            float                 ext_factor,
            float                 attn_factor,
            float                 beta_fast,
            float                 beta_slow),
        "use ggml_rope_ext instead");

    // deprecated in-place counterpart of ggml_rope_custom
    GGML_DEPRECATED(GGML_API struct ggml_tensor * ggml_rope_custom_inplace(
            struct ggml_context * ctx,
            struct ggml_tensor  * a,
            struct ggml_tensor  * b,
            int                   n_dims,
            int                   mode,
            int                   n_ctx_orig,
            float                 freq_base,
            float                 freq_scale,
            float                 ext_factor,
            float                 attn_factor,
            float                 beta_fast,
            float                 beta_slow),
        "use ggml_rope_ext_inplace instead");
1862
    // compute correction dims for YaRN RoPE scaling
    // the two results are written to dims[0] and dims[1]
    // NOTE(review): presumably the low/high dimension bounds of the YaRN ramp — confirm
    GGML_API void ggml_rope_yarn_corr_dims(
        int n_dims, int n_ctx_orig, float freq_base, float beta_fast, float beta_slow, float dims[2]);
1866
    // rotary position embedding backward, i.e. compute dx from dy
    // a - dy
    // the remaining parameters are presumably the values passed to the forward
    // ggml_rope_ext call — confirm
    GGML_API struct ggml_tensor * ggml_rope_ext_back(
            struct ggml_context * ctx,
            struct ggml_tensor  * a, // gradients of ggml_rope result
            struct ggml_tensor  * b, // positions
            struct ggml_tensor  * c, // freq factors
            int                   n_dims,
            int                   mode,
            int                   n_ctx_orig,
            float                 freq_base,
            float                 freq_scale,
            float                 ext_factor,
            float                 attn_factor,
            float                 beta_fast,
            float                 beta_slow);
1883
1884    GGML_API struct ggml_tensor * ggml_rope_multi_back(
1885            struct ggml_context * ctx,
1886            struct ggml_tensor  * a,
1887            struct ggml_tensor  * b,
1888            struct ggml_tensor  * c,
1889            int                   n_dims,
1890            int                   sections[4],
1891            int                   mode,
1892            int                   n_ctx_orig,
1893            float                 freq_base,
1894            float                 freq_scale,
1895            float                 ext_factor,
1896            float                 attn_factor,
1897            float                 beta_fast,
1898            float                 beta_slow);
1899
1900
    // clamp: limit the elements of a to the range [min, max]
    // in-place, returns view(a)
    GGML_API struct ggml_tensor * ggml_clamp(
            struct ggml_context * ctx,
            struct ggml_tensor  * a,
            float                 min,
            float                 max);
1908
    // im2col
    // converts data into a format that effectively results in a convolution when combined with matrix multiplication
    //   is_2D    - NOTE(review): when false, presumably only the dimension-0 parameters
    //              (s0, p0, d0) are used (1D convolution) — confirm
    //   dst_type - type of the result tensor
    GGML_API struct ggml_tensor * ggml_im2col(
            struct ggml_context * ctx,
            struct ggml_tensor  * a,  // convolution kernel
            struct ggml_tensor  * b,  // data
            int                   s0, // stride dimension 0
            int                   s1, // stride dimension 1
            int                   p0, // padding dimension 0
            int                   p1, // padding dimension 1
            int                   d0, // dilation dimension 0
            int                   d1, // dilation dimension 1
            bool                  is_2D,
            enum ggml_type        dst_type);
1923
    // backward pass of ggml_im2col; stride/padding/dilation/is_2D mirror the forward call
    // NOTE(review): ne presumably points to 4 values (ne[0..3] of the forward input) — confirm
    GGML_API struct ggml_tensor * ggml_im2col_back(
        struct ggml_context * ctx,
        struct ggml_tensor  * a,  // convolution kernel
        struct ggml_tensor  * b,  // gradient of im2col output
        int64_t             * ne, // shape of im2col input
        int                   s0, // stride dimension 0
        int                   s1, // stride dimension 1
        int                   p0, // padding dimension 0
        int                   p1, // padding dimension 1
        int                   d0, // dilation dimension 0
        int                   d1, // dilation dimension 1
        bool                  is_2D);
1936
    // 1D convolution of data b with kernel a
    GGML_API struct ggml_tensor * ggml_conv_1d(
            struct ggml_context * ctx,
            struct ggml_tensor  * a,   // convolution kernel
            struct ggml_tensor  * b,   // data
            int                   s0,  // stride
            int                   p0,  // padding
            int                   d0); // dilation
1944
    // conv_1d with padding = half
    // alias for ggml_conv_1d(a, b, s, a->ne[0]/2, d)
    // (a->ne[0]/2 is an integer division)
    GGML_API struct ggml_tensor* ggml_conv_1d_ph(
            struct ggml_context * ctx,
            struct ggml_tensor  * a,  // convolution kernel
            struct ggml_tensor  * b,  // data
            int                   s,  // stride
            int                   d); // dilation
1953
    // depthwise
    // TODO: this is very likely wrong for some cases! - needs more testing
    GGML_API struct ggml_tensor * ggml_conv_1d_dw(
            struct ggml_context * ctx,
            struct ggml_tensor  * a,   // convolution kernel
            struct ggml_tensor  * b,   // data
            int                   s0,  // stride
            int                   p0,  // padding
            int                   d0); // dilation

    // depthwise conv_1d without an explicit padding argument
    // NOTE(review): presumably uses half padding, like ggml_conv_1d_ph — confirm
    GGML_API struct ggml_tensor * ggml_conv_1d_dw_ph(
            struct ggml_context * ctx,
            struct ggml_tensor  * a,   // convolution kernel
            struct ggml_tensor  * b,   // data
            int                   s0,  // stride
            int                   d0); // dilation
1970
    // 1D transposed convolution of data b with kernel a
    GGML_API struct ggml_tensor * ggml_conv_transpose_1d(
            struct ggml_context * ctx,
            struct ggml_tensor  * a,   // convolution kernel
            struct ggml_tensor  * b,   // data
            int                   s0,  // stride
            int                   p0,  // padding
            int                   d0); // dilation
1978
    // 2D convolution of data b with kernel a
    GGML_API struct ggml_tensor * ggml_conv_2d(
            struct ggml_context * ctx,
            struct ggml_tensor  * a,   // convolution kernel
            struct ggml_tensor  * b,   // data
            int                   s0,  // stride dimension 0
            int                   s1,  // stride dimension 1
            int                   p0,  // padding dimension 0
            int                   p1,  // padding dimension 1
            int                   d0,  // dilation dimension 0
            int                   d1); // dilation dimension 1
1989
    // 3D variant of ggml_im2col
    // NOTE(review): presumably a is the convolution kernel and b the data, with IC the number
    //               of input channels packed together with the batch/output-channel count
    //               (cf. the shapes documented for ggml_conv_3d) — confirm
    //   dst_type - type of the result tensor
    GGML_API struct ggml_tensor * ggml_im2col_3d(
            struct ggml_context * ctx,
            struct ggml_tensor  * a,
            struct ggml_tensor  * b,
            int64_t               IC,
            int                   s0, // stride width
            int                   s1, // stride height
            int                   s2, // stride depth
            int                   p0, // padding width
            int                   p1, // padding height
            int                   p2, // padding depth
            int                   d0, // dilation width
            int                   d1, // dilation height
            int                   d2, // dilation depth
            enum ggml_type        dst_type);
2005
2006    // a: [OC*IC, KD, KH, KW]
2007    // b: [N*IC, ID, IH, IW]
2008    // result: [N*OC, OD, OH, OW]
2009    GGML_API struct ggml_tensor * ggml_conv_3d(
2010                struct ggml_context * ctx,
2011                struct ggml_tensor  * a,
2012                struct ggml_tensor  * b,
2013                int64_t               IC,
2014                int                   s0, // stride width
2015                int                   s1, // stride height
2016                int                   s2, // stride depth
2017                int                   p0, // padding width
2018                int                   p1, // padding height
2019                int                   p2, // padding depth
2020                int                   d0, // dilation width
2021                int                   d1, // dilation height
2022                int                   d2  // dilation depth
2023        );
2024
    // 2D convolution with stride equal to the kernel size, i.e. non-overlapping windows
    // kernel size is a->ne[0] x a->ne[1]
    // stride is equal to kernel size
    // padding is zero
    // example:
    // a:     16   16    3  768
    // b:   1024 1024    3    1
    // res:   64   64  768    1      (1024 / 16 = 64)
    // used in sam
    GGML_API struct ggml_tensor * ggml_conv_2d_sk_p0(
            struct ggml_context * ctx,
            struct ggml_tensor  * a,
            struct ggml_tensor  * b);
2037
    // 2D convolution that keeps the spatial size of b (see example)
    // kernel size is a->ne[0] x a->ne[1]
    // stride is 1
    // padding is half
    // example:
    // a:      3    3    256  256
    // b:     64   64    256    1
    // res:   64   64    256    1
    // used in sam
    GGML_API struct ggml_tensor * ggml_conv_2d_s1_ph(
            struct ggml_context * ctx,
            struct ggml_tensor  * a,
            struct ggml_tensor  * b);
2050
    // depthwise (via im2col and mul_mat)
    // see ggml_conv_2d_dw_direct for a possibly faster direct implementation
    GGML_API struct ggml_tensor * ggml_conv_2d_dw(
            struct ggml_context * ctx,
            struct ggml_tensor  * a,  // convolution kernel
            struct ggml_tensor  * b,  // data
            int                  s0,  // stride dimension 0
            int                  s1,  // stride dimension 1
            int                  p0,  // padding dimension 0
            int                  p1,  // padding dimension 1
            int                  d0,  // dilation dimension 0
            int                  d1); // dilation dimension 1
2062
    // Depthwise 2D convolution
    // may be faster than ggml_conv_2d_dw, but not available in all backends
    // a:   KW    KH    1    C    convolution kernel
    // b:   W     H     C    N    input data
    // res: W_out H_out C    N
    //   stride0/stride1, pad0/pad1, dilation0/dilation1 - per-dimension (0 = W, 1 = H) parameters
    GGML_API struct ggml_tensor * ggml_conv_2d_dw_direct(
            struct ggml_context * ctx,
            struct ggml_tensor  * a,
            struct ggml_tensor  * b,
            int                   stride0,
            int                   stride1,
            int                   pad0,
            int                   pad1,
            int                   dilation0,
            int                   dilation1);
2078
    // 2D transposed convolution
    // NOTE(review): by analogy with ggml_conv_2d_sk_p0, "p0" presumably means zero padding;
    //               a is presumably the kernel and b the data — confirm in ggml.c
    GGML_API struct ggml_tensor * ggml_conv_transpose_2d_p0(
            struct ggml_context * ctx,
            struct ggml_tensor  * a,
            struct ggml_tensor  * b,
            int                   stride);
2084
    // 2D convolution with a direct kernel
    // NOTE(review): presumably avoids the im2col intermediate used by ggml_conv_2d and, like
    //               ggml_conv_2d_dw_direct, may not be available in all backends — confirm
    GGML_API struct ggml_tensor * ggml_conv_2d_direct(
            struct ggml_context * ctx,
            struct ggml_tensor  * a,   // convolution kernel [KW, KH, IC, OC]
            struct ggml_tensor  * b,   // input data [W, H, C, N]
            int                   s0,  // stride dimension 0
            int                   s1,  // stride dimension 1
            int                   p0,  // padding dimension 0
            int                   p1,  // padding dimension 1
            int                   d0,  // dilation dimension 0
            int                   d1); // dilation dimension 1
2095
    // 3D convolution with a direct kernel
    // the channel/batch counts are passed explicitly because they are packed into the last
    // dimension of the kernel (IC * OC) and of the input (C * N):
    //   n_channels     - input channels (IC / C)
    //   n_batch        - batch size (N)
    //   n_channels_out - output channels (OC)
    GGML_API struct ggml_tensor * ggml_conv_3d_direct(
            struct ggml_context * ctx,
            struct ggml_tensor  * a,   // kernel [KW, KH, KD, IC * OC]
            struct ggml_tensor  * b,   // input  [W, H, D, C * N]
            int                   s0,  // stride
            int                   s1,
            int                   s2,
            int                   p0,  // padding
            int                   p1,
            int                   p2,
            int                   d0,  // dilation
            int                   d1,
            int                   d2,
            int                   n_channels,
            int                   n_batch,
            int                   n_channels_out);
2112
    // pooling operator for ggml_pool_1d / ggml_pool_2d / ggml_pool_2d_back
    enum ggml_op_pool {
        GGML_OP_POOL_MAX,
        GGML_OP_POOL_AVG,
        GGML_OP_POOL_COUNT, // number of pooling operators (not a valid op)
    };
2118
    // 1D pooling (max or average, see enum ggml_op_pool)
    GGML_API struct ggml_tensor * ggml_pool_1d(
            struct ggml_context * ctx,
            struct ggml_tensor  * a,
            enum ggml_op_pool     op,
            int                   k0, // kernel size
            int                   s0, // stride
            int                   p0); // padding
2126
    // 2D pooling (max or average, see enum ggml_op_pool)
    //   k0, k1 - kernel size in dimensions 0 and 1
    //   s0, s1 - stride in dimensions 0 and 1
    //   p0, p1 - padding in dimensions 0 and 1 (passed as float)
    // the result will have 2*p0 padding for the first dimension
    // and 2*p1 padding for the second dimension
    GGML_API struct ggml_tensor * ggml_pool_2d(
            struct ggml_context * ctx,
            struct ggml_tensor  * a,
            enum ggml_op_pool     op,
            int                   k0,
            int                   k1,
            int                   s0,
            int                   s1,
            float                 p0,
            float                 p1);
2139
    // backward pass of ggml_pool_2d; op, kernel, stride and padding mirror the forward call
    // NOTE(review): a is presumably the gradient of the ggml_pool_2d result — confirm
    GGML_API struct ggml_tensor * ggml_pool_2d_back(
            struct ggml_context * ctx,
            struct ggml_tensor  * a,
            struct ggml_tensor  * af, // "a"/input used in forward pass
            enum ggml_op_pool     op,
            int                   k0,
            int                   k1,
            int                   s0,
            int                   s1,
            float                 p0,
            float                 p1);
2151
    // interpolation mode for ggml_upscale / ggml_interpolate
    enum ggml_scale_mode {
        GGML_SCALE_MODE_NEAREST  = 0,
        GGML_SCALE_MODE_BILINEAR = 1,
        GGML_SCALE_MODE_BICUBIC  = 2,

        GGML_SCALE_MODE_COUNT
    };

    // flags that may be OR'ed into the mode passed to ggml_interpolate;
    // they start at bit 8 so they cannot collide with the ggml_scale_mode values
    enum ggml_scale_flag {
        GGML_SCALE_FLAG_ALIGN_CORNERS = (1 << 8),
        GGML_SCALE_FLAG_ANTIALIAS     = (1 << 9),
    };
2164
    // interpolate
    // multiplies ne0 and ne1 by scale factor
    //   mode - interpolation mode (no ggml_scale_flag bits; use ggml_interpolate for those)
    GGML_API struct ggml_tensor * ggml_upscale(
            struct ggml_context * ctx,
            struct ggml_tensor  * a,
            int                   scale_factor,
            enum ggml_scale_mode  mode);
2172
    // interpolate
    // interpolate scale to specified dimensions
    // deprecated: ggml_interpolate takes int64_t sizes and a mode that can carry
    // ggml_scale_flag bits
    GGML_DEPRECATED(GGML_API struct ggml_tensor * ggml_upscale_ext(
            struct ggml_context * ctx,
            struct ggml_tensor  * a,
            int                   ne0,
            int                   ne1,
            int                   ne2,
            int                   ne3,
            enum ggml_scale_mode  mode),
        "use ggml_interpolate instead");
2184
    // Up- or downsamples the input to the specified size.
    // 2D scale modes (eg. bilinear) are applied to the first two dimensions.
    //   ne0..ne3 - size of the result in each dimension
    //   mode     - a ggml_scale_mode value, optionally OR'ed with ggml_scale_flag bits
    GGML_API struct ggml_tensor * ggml_interpolate(
            struct ggml_context * ctx,
            struct ggml_tensor  * a,
            int64_t               ne0,
            int64_t               ne1,
            int64_t               ne2,
            int64_t               ne3,
            uint32_t              mode); // ggml_scale_mode [ | ggml_scale_flag...]
2195
    // pad each dimension with zeros: [x, ..., x] -> [x, ..., x, 0, ..., 0]
    //   p0..p3 - number of elements appended at the end of dimensions 0..3
    GGML_API struct ggml_tensor * ggml_pad(
            struct ggml_context * ctx,
            struct ggml_tensor  * a,
            int                  p0,
            int                  p1,
            int                  p2,
            int                  p3);

    // pad each dimension with values on the other side of the torus (looping around)
    // NOTE(review): p0..p3 are presumably appended at the end, as in ggml_pad — confirm
    GGML_API struct ggml_tensor * ggml_pad_circular(
            struct ggml_context * ctx,
            struct ggml_tensor  * a,
            int                   p0,
            int                   p1,
            int                   p2,
            int                   p3);
2213
    // pad each dimension on both sides
    //   lpN / rpN - number of elements added before (left) / after (right) dimension N
    // NOTE(review): the padding value is presumably zero, as in ggml_pad — confirm
    GGML_API struct ggml_tensor * ggml_pad_ext(
            struct ggml_context * ctx,
            struct ggml_tensor  * a,
            int                  lp0,
            int                  rp0,
            int                  lp1,
            int                  rp1,
            int                  lp2,
            int                  rp2,
            int                  lp3,
            int                  rp3
            );

    // pad each dimension with values on the other side of the torus (looping around)
    //   lpN / rpN - number of elements added before (left) / after (right) dimension N
    GGML_API struct ggml_tensor * ggml_pad_ext_circular(
            struct ggml_context * ctx,
            struct ggml_tensor  * a,
            int                   lp0,
            int                   rp0,
            int                   lp1,
            int                   rp1,
            int                   lp2,
            int                   rp2,
            int                   lp3,
            int                   rp3);
2239
    // pad each dimension with reflection: [a, b, c, d] -> [b, a, b, c, d, c]
    //   p0 / p1 - number of reflected elements added before / after (the example uses p0 = p1 = 1)
    // NOTE(review): as the "_1d" suffix suggests, presumably only dimension 0 is padded — confirm
    GGML_API struct ggml_tensor * ggml_pad_reflect_1d(
            struct ggml_context * ctx,
            struct ggml_tensor  * a,
            int                   p0,
            int                   p1);
2246
    // Move tensor elements by an offset given for each dimension. Elements that
    // are shifted beyond the last position are wrapped around to the beginning.
    //   shift0..shift3 - offset for dimensions 0..3
    // NOTE(review): negative shifts presumably move elements towards the beginning — confirm
    GGML_API struct ggml_tensor * ggml_roll(
            struct ggml_context * ctx,
            struct ggml_tensor  * a,
            int                   shift0,
            int                   shift1,
            int                   shift2,
            int                   shift3);
2256
    // Convert matrix into a triangular one (upper, strict upper, lower or strict lower) by writing
    // zeroes everywhere outside the masked area
    //   type - which of the four triangles to keep (enum ggml_tri_type)
    GGML_API struct ggml_tensor * ggml_tri(
            struct ggml_context * ctx,
            struct ggml_tensor  * a,
            enum ggml_tri_type    type);
2263
    // Fill tensor a with constant c
    GGML_API struct ggml_tensor * ggml_fill(
            struct ggml_context * ctx,
            struct ggml_tensor  * a,
            float                 c);

    // in-place variant of ggml_fill
    GGML_API struct ggml_tensor * ggml_fill_inplace(
            struct ggml_context * ctx,
            struct ggml_tensor  * a,
            float                 c);
2274
    // Ref: https://github.com/CompVis/stable-diffusion/blob/main/ldm/modules/diffusionmodules/util.py#L151
    // timesteps: [N,]
    // return: [N, dim]
    //   dim        - embedding size of each timestep
    //   max_period - NOTE(review): presumably the maximum sinusoid period, as in the reference — confirm
    GGML_API struct ggml_tensor * ggml_timestep_embedding(
            struct ggml_context * ctx,
            struct ggml_tensor  * timesteps,
            int                   dim,
            int                   max_period);
2283
    // sort rows
    enum ggml_sort_order {
        GGML_SORT_ORDER_ASC,
        GGML_SORT_ORDER_DESC,
    };

    // returns, for each row of a, the indices that sort it in the given order
    // NOTE(review): the index tensor is presumably I32 — confirm in ggml.c
    GGML_API struct ggml_tensor * ggml_argsort(
            struct ggml_context * ctx,
            struct ggml_tensor  * a,
            enum ggml_sort_order  order);
2294
    // similar to ggml_top_k but implemented as `argsort` + `view`
    // NOTE(review): being a view of a full argsort, the k indices are presumably returned in
    //               sorted order, unlike ggml_top_k — confirm
    GGML_API struct ggml_tensor * ggml_argsort_top_k(
            struct ggml_context * ctx,
            struct ggml_tensor  * a,
            int                   k);
2300
    // top k elements per row
    // note: the resulting top k indices are in no particular order
    // (use ggml_argsort_top_k when an argsort-based result is needed)
    GGML_API struct ggml_tensor * ggml_top_k(
            struct ggml_context * ctx,
            struct ggml_tensor  * a,
            int                   k);
2307
    // 1D sequence from start towards stop in increments of step
    // NOTE(review): presumably [start, stop) (stop excluded), like numpy.arange — confirm
    GGML_API struct ggml_tensor * ggml_arange(
            struct ggml_context * ctx,
            float                 start,
            float                 stop,
            float                 step);
2313
    // fused attention: softmax(q*k^T*scale + mask) * v
    //
    // q:    [n_embd_k, n_batch, n_head,    ne3 ]
    // k:    [n_embd_k, n_kv,    n_head_kv, ne3 ]
    // v:    [n_embd_v, n_kv,    n_head_kv, ne3 ] !! not transposed !!
    // mask: [n_kv,     n_batch, ne32,      ne33]
    // res:  [n_embd_v, n_head,  n_batch,   ne3 ] !! permuted !!
    //
    // broadcast:
    //   n_head % n_head_kv == 0
    //   n_head % ne32      == 0
    //   ne3    % ne33      == 0
    //
    // NOTE(review): max_bias presumably selects ALiBi as in ggml_soft_max_ext (0.0f = no ALiBi),
    //               and logit_softcap = 0.0f presumably disables soft-capping — confirm
    GGML_API struct ggml_tensor * ggml_flash_attn_ext(
            struct ggml_context * ctx,
            struct ggml_tensor  * q,
            struct ggml_tensor  * k,
            struct ggml_tensor  * v,
            struct ggml_tensor  * mask,
            float                 scale,
            float                 max_bias,
            float                 logit_softcap);
2334
    // set / query the requested precision of a ggml_flash_attn_ext result tensor `a`
    GGML_API void ggml_flash_attn_ext_set_prec(
            struct ggml_tensor * a,
            enum ggml_prec       prec);

    GGML_API enum ggml_prec ggml_flash_attn_ext_get_prec(
            const struct ggml_tensor * a);

    // attach a sinks tensor to an existing ggml_flash_attn_ext result `a`
    // (takes no context and returns nothing, so no new tensor is created)
    // NOTE(review): presumably the same attention-sink semantics as ggml_soft_max_add_sinks — confirm
    GGML_API void ggml_flash_attn_ext_add_sinks(
            struct ggml_tensor * a,
            struct ggml_tensor * sinks);
2345
    // backward pass of flash attention
    // NOTE(review): d is presumably the gradient of the attention output, and masked presumably
    //               enables causal masking — confirm in ggml.c
    // TODO: needs to be adapted to ggml_flash_attn_ext
    GGML_API struct ggml_tensor * ggml_flash_attn_back(
           struct ggml_context * ctx,
           struct ggml_tensor  * q,
           struct ggml_tensor  * k,
           struct ggml_tensor  * v,
           struct ggml_tensor  * d,
           bool                  masked);
2354
    // state-space-model (SSM) convolution
    // NOTE(review): sx is presumably the input concatenated with the previous conv state and c
    //               the per-channel conv weights (Mamba-style) — confirm in ggml.c
    GGML_API struct ggml_tensor * ggml_ssm_conv(
            struct ggml_context * ctx,
            struct ggml_tensor  * sx,
            struct ggml_tensor  * c);
2359
    // state-space-model (SSM) selective scan
    // NOTE(review): s is presumably the recurrent state, x/dt/A/B/C the Mamba scan inputs, and
    //               ids presumably selects which state each sequence uses — confirm in ggml.c
    GGML_API struct ggml_tensor * ggml_ssm_scan(
            struct ggml_context * ctx,
            struct ggml_tensor  * s,
            struct ggml_tensor  * x,
            struct ggml_tensor  * dt,
            struct ggml_tensor  * A,
            struct ggml_tensor  * B,
            struct ggml_tensor  * C,
            struct ggml_tensor  * ids);
2369
    // partition into non-overlapping windows with padding if needed
    // example:
    // a:   768   64   64    1
    // w:    14
    // res: 768   14   14    25
    // (64 is not a multiple of 14, so each spatial dim is padded up to 5 windows:
    //  ceil(64/14) = 5 -> 5*5 = 25 windows in the last dimension)
    // used in sam
    GGML_API struct ggml_tensor * ggml_win_part(
            struct ggml_context * ctx,
            struct ggml_tensor  * a,
            int                   w);

    // reverse of ggml_win_part
    // w is the window size; w0/h0 are presumably the original (unpadded) width and
    // height to crop back to - TODO confirm
    // used in sam
    GGML_API struct ggml_tensor * ggml_win_unpart(
            struct ggml_context * ctx,
            struct ggml_tensor  * a,
            int                   w0,
            int                   h0,
            int                   w);
2389
    // Apply the element-wise unary operation op to a (generic entry point for the
    // ops enumerated in enum ggml_unary_op).
    GGML_API struct ggml_tensor * ggml_unary(
            struct ggml_context * ctx,
             struct ggml_tensor * a,
             enum ggml_unary_op op);

    // Same as ggml_unary, but the result presumably reuses a's storage (in-place)
    // - TODO confirm.
    GGML_API struct ggml_tensor * ggml_unary_inplace(
        struct ggml_context * ctx,
        struct ggml_tensor  * a,
        enum ggml_unary_op op);
2399
    // Relative positional embeddings.
    // qh/kh are presumably the query and key spatial sizes - TODO confirm.
    // used in sam
    GGML_API struct ggml_tensor * ggml_get_rel_pos(
            struct ggml_context * ctx,
            struct ggml_tensor  * a,
            int                   qh,
            int                   kh);

    // Add relative positional terms to a.
    // pw/ph presumably hold the width- and height-direction terms - TODO confirm.
    // used in sam
    GGML_API struct ggml_tensor * ggml_add_rel_pos(
            struct ggml_context * ctx,
            struct ggml_tensor  * a,
            struct ggml_tensor  * pw,
            struct ggml_tensor  * ph);

    // in-place variant of ggml_add_rel_pos (presumably writes into a)
    GGML_API struct ggml_tensor * ggml_add_rel_pos_inplace(
            struct ggml_context * ctx,
            struct ggml_tensor  * a,
            struct ggml_tensor  * pw,
            struct ggml_tensor  * ph);
2419
    // RWKV v6 WKV recurrence.
    // state is the recurrent state carried between calls; the exact layout of the
    // result (output and/or updated state) is defined in ggml.c - TODO document.
    GGML_API struct ggml_tensor * ggml_rwkv_wkv6(
            struct ggml_context * ctx,
            struct ggml_tensor  * k,
            struct ggml_tensor  * v,
            struct ggml_tensor  * r,
            struct ggml_tensor  * tf,
            struct ggml_tensor  * td,
            struct ggml_tensor  * state);

    // Gated linear attention with recurrent state.
    // g is presumably the gate and scale a multiplier on the output/logits - TODO confirm.
    GGML_API struct ggml_tensor * ggml_gated_linear_attn(
            struct ggml_context * ctx,
            struct ggml_tensor  * k,
            struct ggml_tensor  * v,
            struct ggml_tensor  * q,
            struct ggml_tensor  * g,
            struct ggml_tensor  * state,
            float scale);

    // RWKV v7 WKV recurrence (note the different argument order from wkv6).
    GGML_API struct ggml_tensor * ggml_rwkv_wkv7(
            struct ggml_context * ctx,
            struct ggml_tensor  * r,
            struct ggml_tensor  * w,
            struct ggml_tensor  * k,
            struct ggml_tensor  * v,
            struct ggml_tensor  * a,
            struct ggml_tensor  * b,
            struct ggml_tensor  * state);
2447
    /* Solves a specific equation of the form Ax=B, where A is a triangular matrix
    *  without zeroes on the diagonal (i.e. invertible).
    *  B can have any number of columns, but must have the same number of rows as A
    *  If A is [n, n] and B is [n, m], then the result will be [n, m] as well
    *  Has O(n^3) complexity (unlike most matrix ops out there), so use it sparingly
    *  when n > 100 and pre-chunk if necessary.
    *
    *  a: the triangular matrix A
    *  b: the right-hand side B
    *
    *  If left = false, solves xA=B instead
    *  If lower = false, assumes upper triangular instead
    *  If uni = true, assumes diagonal of A to be all ones (will override actual values)
    *
    *  TODO: currently only lower, right, non-unitriangular variant is implemented
    */
    GGML_API struct ggml_tensor * ggml_solve_tri(
        struct ggml_context * ctx,
        struct ggml_tensor  * a,
        struct ggml_tensor  * b,
        bool                  left,
        bool                  lower,
        bool                  uni);
2468
    // custom operators
    //
    // User callbacks invoked to compute dst from 1, 2 or 3 source tensors.
    // ith/nth: presumably the index of the calling task and the total number of
    //          tasks, so each invocation can process its own slice - TODO confirm.
    // userdata: the opaque pointer passed to the corresponding ggml_map_custom* call.

    typedef void (*ggml_custom1_op_t)(struct ggml_tensor * dst , const struct ggml_tensor * a, int ith, int nth, void * userdata);
    typedef void (*ggml_custom2_op_t)(struct ggml_tensor * dst , const struct ggml_tensor * a, const struct ggml_tensor * b, int ith, int nth, void * userdata);
    typedef void (*ggml_custom3_op_t)(struct ggml_tensor * dst , const struct ggml_tensor * a, const struct ggml_tensor * b, const struct ggml_tensor * c, int ith, int nth, void * userdata);

#define GGML_N_TASKS_MAX (-1)
    // n_tasks == GGML_N_TASKS_MAX means to use max number of tasks
2477
    // Create a graph node whose value is computed by the user callback fun.
    // n_tasks: number of parallel tasks, or GGML_N_TASKS_MAX to use the maximum.
    // userdata: forwarded unchanged to fun.
    // The _inplace variants presumably produce a result that aliases a's data - TODO confirm.
    GGML_API struct ggml_tensor * ggml_map_custom1(
            struct ggml_context   * ctx,
            struct ggml_tensor    * a,
            ggml_custom1_op_t       fun,
            int                     n_tasks,
            void                  * userdata);

    GGML_API struct ggml_tensor * ggml_map_custom1_inplace(
            struct ggml_context   * ctx,
            struct ggml_tensor    * a,
            ggml_custom1_op_t       fun,
            int                     n_tasks,
            void                  * userdata);

    // two-source variant (fun receives a and b)
    GGML_API struct ggml_tensor * ggml_map_custom2(
            struct ggml_context   * ctx,
            struct ggml_tensor    * a,
            struct ggml_tensor    * b,
            ggml_custom2_op_t       fun,
            int                     n_tasks,
            void                  * userdata);

    GGML_API struct ggml_tensor * ggml_map_custom2_inplace(
            struct ggml_context   * ctx,
            struct ggml_tensor    * a,
            struct ggml_tensor    * b,
            ggml_custom2_op_t       fun,
            int                     n_tasks,
            void                  * userdata);

    // three-source variant (fun receives a, b and c)
    GGML_API struct ggml_tensor * ggml_map_custom3(
            struct ggml_context   * ctx,
            struct ggml_tensor    * a,
            struct ggml_tensor    * b,
            struct ggml_tensor    * c,
            ggml_custom3_op_t       fun,
            int                     n_tasks,
            void                  * userdata);

    GGML_API struct ggml_tensor * ggml_map_custom3_inplace(
            struct ggml_context   * ctx,
            struct ggml_tensor    * a,
            struct ggml_tensor    * b,
            struct ggml_tensor    * c,
            ggml_custom3_op_t       fun,
            int                     n_tasks,
            void                  * userdata);
2525
    // Generic custom-op callback: unlike ggml_custom{1,2,3}_op_t it receives only dst,
    // so the inputs are presumably reached through dst's sources (the args passed
    // below) - TODO confirm. ith/nth as for the other custom ops.
    typedef void (*ggml_custom_op_t)(struct ggml_tensor * dst , int ith, int nth, void * userdata);

    // Create a new tensor of the given type and shape [ne0, ne1, ne2, ne3] computed by
    // fun from the n_args tensors in args.
    // n_tasks: number of parallel tasks, or GGML_N_TASKS_MAX to use the maximum.
    GGML_API struct ggml_tensor * ggml_custom_4d(
            struct ggml_context * ctx,
            enum ggml_type        type,
            int64_t               ne0,
            int64_t               ne1,
            int64_t               ne2,
            int64_t               ne3,
            struct ggml_tensor ** args,
            int                   n_args,
            ggml_custom_op_t      fun,
            int                   n_tasks,
            void                * userdata);

    // Like ggml_custom_4d, but the result presumably takes a's type/shape and
    // storage (in-place) - TODO confirm.
    GGML_API struct ggml_tensor * ggml_custom_inplace(
            struct ggml_context * ctx,
            struct ggml_tensor  * a,
            struct ggml_tensor ** args,
            int                   n_args,
            ggml_custom_op_t      fun,
            int                   n_tasks,
            void                * userdata);
2549
    // loss function

    // Cross-entropy loss between logits a and labels b.
    // NOTE(review): reduction (sum vs. mean) and result shape are defined in ggml.c - confirm.
    GGML_API struct ggml_tensor * ggml_cross_entropy_loss(
            struct ggml_context * ctx,
            struct ggml_tensor  * a,  // logits
            struct ggml_tensor  * b); // labels

    // Backward pass of ggml_cross_entropy_loss: gradient w.r.t. the logits given the
    // incoming gradient c.
    GGML_API struct ggml_tensor * ggml_cross_entropy_loss_back(
            struct ggml_context * ctx,
            struct ggml_tensor  * a,  // logits
            struct ggml_tensor  * b,  // labels
            struct ggml_tensor  * c); // gradients of cross_entropy_loss result
2562
    // AdamW optimizer step
    // Paper: https://arxiv.org/pdf/1711.05101v3.pdf
    // PyTorch: https://pytorch.org/docs/stable/generated/torch.optim.AdamW.html
    // a: parameter tensor being optimized, grad: its gradient,
    // m/v: presumably the first/second moment buffers (optimizer state) - TODO confirm.
    GGML_API struct ggml_tensor * ggml_opt_step_adamw(
            struct ggml_context * ctx,
            struct ggml_tensor  * a,
            struct ggml_tensor  * grad,
            struct ggml_tensor  * m,
            struct ggml_tensor  * v,
            struct ggml_tensor  * adamw_params); // parameters such as the learning rate

    // stochastic gradient descent step (with weight decay)
    // a: parameter tensor being optimized, grad: its gradient.
    GGML_API struct ggml_tensor * ggml_opt_step_sgd(
        struct ggml_context * ctx,
        struct ggml_tensor *  a,
        struct ggml_tensor *  grad,
        struct ggml_tensor *  sgd_params); // alpha, weight decay
2580
    // build forward multiple tensors into the graph and select one of them to be computed
    // this is useful for creating graphs that have constant topology but compute different things based on the input
    // ref: https://github.com/ggml-org/llama.cpp/pull/18550
    //
    // nodes:
    //   | - build forward into the graph but do not compute
    //   c - build forward into the graph and compute
    //
    //    |  |  ...  c  ...  |
    //    |  |  ...  c  ...  |
    //    |  |  ...  c  ...  |
    //   [0  1  ... idx ...  n-1]        <-- ggml_build_forward_select(..., n, idx)
    //               c
    //               c
    //
    // example:
    //   struct ggml_tensor * curs[3];
    //
    //   curs[0]  = compute0(...);
    //   curs[1]  = compute1(...);
    //   curs[2]  = compute2(...);
    //
    //   int idx = select_branch(some_input);
    //
    //   struct ggml_tensor * out = ggml_build_forward_select(cgraph, curs, 3, idx);
    //
    // tensors: array of n_tensors candidate outputs; idx: index of the branch to compute
    // (presumably 0 <= idx < n_tensors). The returned tensor is presumably tensors[idx]
    // - TODO confirm.
    GGML_API struct ggml_tensor * ggml_build_forward_select(
            struct ggml_cgraph  * cgraph,
            struct ggml_tensor ** tensors,
            int                   n_tensors,
            int                   idx);
2612
    // Add tensor, and presumably the not-yet-visited nodes it depends on, to cgraph's
    // forward pass - TODO confirm traversal details in ggml.c.
    GGML_API void ggml_build_forward_expand(
            struct ggml_cgraph * cgraph,
            struct ggml_tensor * tensor);

    // Extend cgraph with the backward (gradient) pass; gradient tensors are allocated in ctx.
    // grad_accs: presumably optional per-node gradient accumulators - TODO confirm
    // expected length and whether NULL is accepted.
    GGML_API void ggml_build_backward_expand(
        struct ggml_context *  ctx,        // context for gradient computation
        struct ggml_cgraph  *  cgraph,
        struct ggml_tensor  ** grad_accs);
2621
    // graph allocation in a context
    GGML_API struct ggml_cgraph * ggml_new_graph       (struct ggml_context * ctx); // size = GGML_DEFAULT_GRAPH_SIZE, grads = false
    GGML_API struct ggml_cgraph * ggml_new_graph_custom(struct ggml_context * ctx, size_t size, bool grads);
    GGML_API struct ggml_cgraph * ggml_graph_dup       (struct ggml_context * ctx, struct ggml_cgraph * cgraph, bool force_grads);
    GGML_API void                 ggml_graph_cpy       (struct ggml_cgraph * src, struct ggml_cgraph * dst);
    GGML_API void                 ggml_graph_reset     (struct ggml_cgraph * cgraph); // set regular grads + optimizer momenta to 0, set loss grad to 1
    GGML_API void                 ggml_graph_clear     (struct ggml_cgraph * cgraph);

    // node access
    GGML_API int                   ggml_graph_size   (struct ggml_cgraph * cgraph);
    GGML_API struct ggml_tensor *  ggml_graph_node   (struct ggml_cgraph * cgraph, int i); // if i < 0, returns nodes[n_nodes + i]
    GGML_API struct ggml_tensor ** ggml_graph_nodes  (struct ggml_cgraph * cgraph);
    GGML_API int                   ggml_graph_n_nodes(struct ggml_cgraph * cgraph);

    GGML_API void   ggml_graph_add_node(struct ggml_cgraph * cgraph, struct ggml_tensor * tensor);

    // memory required for a graph object (default size / custom size and grads flag),
    // e.g. to size a context before calling ggml_new_graph / ggml_new_graph_custom
    GGML_API size_t ggml_graph_overhead(void);
    GGML_API size_t ggml_graph_overhead_custom(size_t size, bool grads);

    // lookups: tensor by name, and the gradient / gradient accumulator of a node
    // (presumably NULL when not found / not present - TODO confirm)
    GGML_API struct ggml_tensor * ggml_graph_get_tensor  (const struct ggml_cgraph * cgraph, const char * name);
    GGML_API struct ggml_tensor * ggml_graph_get_grad    (const struct ggml_cgraph * cgraph, const struct ggml_tensor * node);
    GGML_API struct ggml_tensor * ggml_graph_get_grad_acc(const struct ggml_cgraph * cgraph, const struct ggml_tensor * node);
2643
    // print info and performance information for the graph
    GGML_API void ggml_graph_print(const struct ggml_cgraph * cgraph);

    // dump the graph into a file using the dot format
    // NOTE(review): the roles of gb vs. cgraph (two graphs) are defined in ggml.c - TODO document.
    GGML_API void ggml_graph_dump_dot(const struct ggml_cgraph * gb, const struct ggml_cgraph * cgraph, const char * filename);
2649
    // TODO these functions were sandwiched in the old optimization interface, is there a better place for them?

    // Logging callback: receives the level, the message text and the user_data
    // pointer registered together with the callback.
    typedef void (*ggml_log_callback)(enum ggml_log_level level, const char * text, void * user_data);

    // ggml_log_get: retrieve the currently registered callback and user data
    //               through the two out-pointers.
    // ggml_log_set: set callback for all future logging events.
    // If ggml_log_set is not called, or NULL is supplied, everything is output on stderr.
    GGML_API void ggml_log_get(ggml_log_callback * log_callback, void ** user_data);
    GGML_API void ggml_log_set(ggml_log_callback   log_callback, void *  user_data);

    // Presumably fills tensor's data with zeros and returns tensor - TODO confirm.
    GGML_API struct ggml_tensor * ggml_set_zero(struct ggml_tensor * tensor);
2659
2660    //
2661    // quantization
2662    //
2663
2664    // - ggml_quantize_init can be called multiple times with the same type
2665    //   it will only initialize the quantization tables for the first call or after ggml_quantize_free
2666    //   automatically called by ggml_quantize_chunk for convenience
2667    //
2668    // - ggml_quantize_free will free any memory allocated by ggml_quantize_init
2669    //   call this at the end of the program to avoid memory leaks
2670    //
2671    // note: these are thread-safe
2672    //
2673    GGML_API void ggml_quantize_init(enum ggml_type type);
2674    GGML_API void ggml_quantize_free(void);
2675
2676    // some quantization type cannot be used without an importance matrix
2677    GGML_API bool ggml_quantize_requires_imatrix(enum ggml_type type);
2678
2679    // calls ggml_quantize_init internally (i.e. can allocate memory)
2680    GGML_API size_t ggml_quantize_chunk(
2681            enum ggml_type   type,
2682               const float * src,
2683                      void * dst,
2684                   int64_t   start,
2685                   int64_t   nrows,
2686                   int64_t   n_per_row,
2687               const float * imatrix);
2688
#if defined(__cplusplus)
    // C++ has no standard `restrict` keyword, so map GGML_RESTRICT to the
    // compiler-specific spelling (or to nothing for unknown compilers).
    // Anything defining __GNUC__ gets __restrict__; clang without __GNUC__
    // (presumably e.g. clang-cl) and MSVC both take __restrict.
#    if defined(__GNUC__)
#        define GGML_RESTRICT __restrict__
#    elif defined(__clang__) || defined(_MSC_VER)
#        define GGML_RESTRICT __restrict
#    else
#        define GGML_RESTRICT
#    endif
#else
    // C99+ has `restrict`; MSVC in a pre-C11 C mode only understands the
    // __restrict extension. An undefined __STDC_VERSION__ evaluates as 0 here.
#    if defined (_MSC_VER) && (__STDC_VERSION__ < 201112L)
#        define GGML_RESTRICT __restrict
#    else
#        define GGML_RESTRICT restrict
#    endif
#endif

    // Row converters between a type's storage format and float32: x is the source,
    // y the destination (non-aliasing, hence GGML_RESTRICT) and k presumably the
    // number of elements to convert - TODO confirm.
    typedef void (*ggml_to_float_t)  (const void  * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k);
    typedef void (*ggml_from_float_t)(const float * GGML_RESTRICT x, void  * GGML_RESTRICT y, int64_t k);
2709
    // Static per-type description returned by ggml_get_type_traits.
    // (Field order is part of the public ABI - do not reorder.)
    struct ggml_type_traits {
        const char             * type_name;
        int64_t                  blck_size;            // presumably elements per block - TODO confirm
        int64_t                  blck_size_interleave; // interleave elements in blocks
        size_t                   type_size;            // presumably bytes per block - TODO confirm
        bool                     is_quantized;
        ggml_to_float_t          to_float;             // storage -> float32 converter
        ggml_from_float_t        from_float_ref;       // float32 -> storage converter
    };

    // Look up the traits for type; the result is a pointer to const data owned by ggml.
    GGML_API const struct ggml_type_traits * ggml_get_type_traits(enum ggml_type type);
2721
    // ggml threadpool
    // TODO: currently, only a few functions are in the base ggml API, while the rest are in the CPU backend
    // the goal should be to create an API that other backends can use, and to move everything into the ggml base
2725
2726    // scheduling priorities
2727    enum ggml_sched_priority {
2728        GGML_SCHED_PRIO_LOW = -1,
2729        GGML_SCHED_PRIO_NORMAL,
2730        GGML_SCHED_PRIO_MEDIUM,
2731        GGML_SCHED_PRIO_HIGH,
2732        GGML_SCHED_PRIO_REALTIME
2733    };
2734
    // threadpool params
    // Use ggml_threadpool_params_default() or ggml_threadpool_params_init() to populate the defaults
    // (cpumask holds one flag per possible thread/core, sized by GGML_MAX_N_THREADS)
    struct ggml_threadpool_params {
        bool                cpumask[GGML_MAX_N_THREADS]; // mask of cpu cores (all-zeros means use default affinity settings)
        int                 n_threads;                   // number of threads
        enum ggml_sched_priority prio;                   // thread priority
        uint32_t            poll;                        // polling level (0 - no polling, 100 - aggressive polling)
        bool                strict_cpu;                  // strict cpu placement
        bool                paused;                      // start in paused state
    };
2745
    struct ggml_threadpool;     // forward declaration, see ggml.c

    // opaque handle to a threadpool
    typedef struct ggml_threadpool * ggml_threadpool_t;

    // _default returns a params struct populated with defaults for n_threads;
    // _init populates *p in place with the same defaults.
    // _match compares two params structs (presumably true when equivalent - TODO
    // confirm which fields are compared).
    GGML_API struct ggml_threadpool_params ggml_threadpool_params_default(int n_threads);
    GGML_API void                          ggml_threadpool_params_init   (struct ggml_threadpool_params * p, int n_threads);
    GGML_API bool                          ggml_threadpool_params_match  (const struct ggml_threadpool_params * p0, const struct ggml_threadpool_params * p1);
2753
2754#ifdef  __cplusplus
2755}
2756#endif