1// This file defines tests for various GGML ops and backends.
   2// For the forward pass it asserts that the results of multiple backends computing the same GGML ops are consistent.
   3// For the backward pass it asserts that the gradients from backpropagation are consistent
   4// with the gradients obtained via the method of finite differences ("grad" mode, this is optional).
   5// It is also possible to check the performance ("perf" mode).
   6//
// This file has three sections: section 1 does general setup, section 2 defines the GGML ops to be tested,
// and section 3 defines which tests to run.
   9// Quick start for adding a new GGML op: Go to section 2 and create a struct that inherits from test_case,
  10// then go to section 3 and add an instantiation of your struct.
  11
  12
  13// ##############################
  14// ## Section 1: General Setup ##
  15// ##############################
  16
  17
  18#include <ggml.h>
  19#include <ggml-alloc.h>
  20#include <ggml-backend.h>
  21#include <ggml-cpp.h>
  22
  23#include <algorithm>
  24#include <array>
  25#include <cfloat>
  26#include <cinttypes>
  27#include <cstdarg>
  28#include <cstdint>
  29#include <cstdio>
  30#include <cstdlib>
  31#include <cstring>
  32#include <ctime>
  33#include <future>
  34#include <memory>
  35#include <random>
  36#include <regex>
  37#include <set>
  38#include <string>
  39#include <string_view>
  40#include <thread>
  41#include <vector>
  42#include <unordered_map>
  43
// Number of host threads used for parallel host-side work in this file (e.g. random tensor initialization).
// Emscripten builds are pinned to a single thread.
// NOTE(review): std::thread::hardware_concurrency() may return 0 when the value is not computable,
// so users of this macro should clamp it to at least 1.
#ifdef __EMSCRIPTEN__
#   define N_THREADS 1
#else
#   define N_THREADS std::thread::hardware_concurrency()
#endif
  49
  50static void init_tensor_uniform(ggml_tensor * tensor, float min = -1.0f, float max = 1.0f) {
  51    size_t nels = ggml_nelements(tensor);
  52    std::vector<float> data(nels);
  53    {
  54        // parallel initialization
  55        static const size_t n_threads = N_THREADS;
  56        // static RNG initialization (revisit if n_threads stops being constant)
  57        static std::vector<std::default_random_engine> generators = []() {
  58            std::random_device rd;
  59            std::vector<std::default_random_engine> vec;
  60            vec.reserve(n_threads);
  61            //for (size_t i = 0; i < n_threads; i++) { vec.emplace_back(1234 + i); } // fixed seed
  62            for (size_t i = 0; i < n_threads; i++) { vec.emplace_back(rd()); }
  63            return vec;
  64        }();
  65
  66        auto init_thread = [&](size_t ith, size_t start, size_t end) {
  67            std::uniform_real_distribution<float> distribution(min, max);
  68            auto & gen = generators[ith];
  69            for (size_t i = start; i < end; i++) {
  70                data[i] = distribution(gen);
  71            }
  72        };
  73
  74        if (n_threads == 1) {
  75            init_thread(0, 0, nels);
  76        } else {
  77            std::vector<std::future<void>> tasks;
  78            tasks.reserve(n_threads);
  79            for (size_t i = 0; i < n_threads; i++) {
  80                size_t start =     i*nels/n_threads;
  81                size_t end   = (i+1)*nels/n_threads;
  82                tasks.push_back(std::async(std::launch::async, init_thread, i, start, end));
  83            }
  84            for (auto & t : tasks) {
  85                t.get();
  86            }
  87        }
  88    }
  89
  90    if (tensor->type == GGML_TYPE_F32 || tensor->type == GGML_TYPE_I32) {
  91        ggml_backend_tensor_set(tensor, data.data(), 0, nels * sizeof(float));
  92    } else if (ggml_is_quantized(tensor->type) || tensor->type == GGML_TYPE_F16 || tensor->type == GGML_TYPE_BF16) {
  93        GGML_ASSERT(nels % ggml_blck_size(tensor->type) == 0);
  94
  95         // dummy importance matrix
  96        std::vector<float> imatrix(tensor->ne[0], 1.0f);
  97        const float * im = imatrix.data();
  98        if (!ggml_quantize_requires_imatrix(tensor->type)) {
  99            // when the imatrix is optional, we want to test both quantization with and without imatrix
 100            // use one of the random numbers to decide
 101            if (data[0] > 0.5f*(min + max)) {
 102                im = nullptr;
 103            }
 104        }
 105
 106        std::vector<uint8_t> dataq(ggml_row_size(tensor->type, nels));
 107        {
 108            // parallel quantization by block
 109            size_t blck_size = ggml_blck_size(tensor->type);
 110            size_t n_blocks = nels / blck_size;
 111
 112            auto quantize_thread = [&](size_t start, size_t end) {
 113                ggml_quantize_chunk(tensor->type, data.data(), dataq.data(),
 114                    start * blck_size, end - start, blck_size, im);
 115            };
 116
 117            const size_t min_blocks_per_thread = 1;
 118            const size_t n_quant_threads = std::min<size_t>(std::max<size_t>(N_THREADS/2, 1),
 119                                                            std::max<size_t>(1, n_blocks / min_blocks_per_thread));
 120
 121            if (n_quant_threads == 1) {
 122                // single-threaded quantization: do all blocks in the current thread
 123                quantize_thread(0, n_blocks);
 124            } else {
 125                std::vector<std::future<void>> tasks;
 126                tasks.reserve(n_quant_threads);
 127                for (size_t i = 0; i < n_quant_threads; i++) {
 128                    size_t start =     i*n_blocks/n_quant_threads;
 129                    size_t end   = (i+1)*n_blocks/n_quant_threads;
 130                    tasks.push_back(std::async(std::launch::async, quantize_thread, start, end));
 131                }
 132                for (auto & t : tasks) {
 133                    t.get();
 134                }
 135            }
 136        }
 137        ggml_backend_tensor_set(tensor, dataq.data(), 0, dataq.size());
 138    } else if (tensor->type == GGML_TYPE_I8 || tensor->type == GGML_TYPE_I16 || tensor->type == GGML_TYPE_I32) {
 139        // This is going to create some weird integers though.
 140        ggml_backend_tensor_set(tensor, data.data(), 0, ggml_nbytes(tensor));
 141    } else if (tensor->type == GGML_TYPE_I64) {
 142        // Integers with a size of 8 bytes can be set by mirroring the float data, the specific values are again not really meaningful.
 143        const size_t nbytes_half = ggml_nbytes(tensor)/2;
 144        ggml_backend_tensor_set(tensor, data.data(), 0*nbytes_half, nbytes_half);
 145        ggml_backend_tensor_set(tensor, data.data(), 1*nbytes_half, nbytes_half);
 146    } else {
 147        GGML_ABORT("fatal error");
 148    }
 149}
 150
 151// generate an F16 mask where certain blocks are randomly masked with -INF value
 152static void init_tensor_kq_mask(ggml_tensor * tensor, float min = -1.0f, float max = 1.0f) {
 153    GGML_ASSERT(tensor->type == GGML_TYPE_F16);
 154
 155    GGML_TENSOR_LOCALS( int32_t, ne, tensor, ne);
 156
 157    std::vector<float>       data_f32(ne0*ne1*ne2*ne3);
 158    std::vector<ggml_fp16_t> data_f16(ne0*ne1*ne2*ne3);
 159
 160    std::random_device rd;
 161    std::mt19937 gen(rd());
 162    std::uniform_real_distribution<float> dis(min, max);
 163
 164    for (size_t i = 0; i < data_f32.size(); i++) {
 165        data_f32[i] = dis(gen);
 166    }
 167
 168    // block size
 169    const int blck0 = 128;
 170    const int blck1 = 64;
 171
 172    // number of INF/zero blocks
 173    const int n_inf_zero_blocks = 0.2*(ne0*ne1*ne2*ne3)/(blck0*blck1);
 174
 175    for (int b = 0; b < n_inf_zero_blocks; b++) {
 176        const int p3 = (rd() % ne3);
 177        const int p2 = (rd() % ne2);
 178        const int p1 = (rd() % ne1);
 179        const int p0 = (rd() % ne0);
 180
 181        bool inf = rd() & 1;
 182
 183        for (int i1 = 0; i1 < blck1 && p1 + i1 < ne1; i1++) {
 184            const int idx = p3*ne2*ne1*ne0 + p2*ne1*ne0 + (p1 + i1)*ne0 + p0;
 185
 186            for (int i0 = 0; i0 < blck0 && p0 + i0 < ne0; i0++) {
 187                data_f32[idx + i0] = inf ? -INFINITY : 0.0f;
 188            }
 189        }
 190    }
 191
 192    ggml_fp32_to_fp16_row(data_f32.data(), data_f16.data(), ne0*ne1*ne2*ne3);
 193
 194    ggml_backend_tensor_set(tensor, data_f16.data(), 0, data_f16.size()*sizeof(ggml_fp16_t));
 195}
 196
 197// generate a lower triangular matrix
 198static void init_tensor_tril(ggml_tensor * tensor, float min = -1.0f, float max = 1.0f) {
 199    GGML_ASSERT(tensor->type == GGML_TYPE_F32);
 200    GGML_ASSERT(tensor->ne[0] == tensor->ne[1]);
 201
 202    GGML_TENSOR_LOCALS(int32_t, ne, tensor, ne);
 203    GGML_TENSOR_LOCALS(size_t, nb, tensor, nb);
 204
 205    std::vector<float> data_f32(ne0*ne1*ne2*ne3);
 206
 207    std::random_device rd;
 208    std::mt19937 gen(rd());
 209    std::uniform_real_distribution<float> dis(min, max);
 210
 211    for (int64_t i3 = 0; i3 < ne3; i3++) {
 212        for (int64_t i2 = 0; i2 < ne2; i2++) {
 213            for (int64_t i1 = 0; i1 < ne1; i1++) {
 214                for (int64_t i0 = 0; i0 < ne0; i0++) {
 215                    int64_t idx = (i0 * nb0 + i1 * nb1 + i2 * nb2 + i3 * nb3) / sizeof(float);
 216                    if (i0 <= i1) {
 217                        data_f32[idx] = dis(gen);
 218                    } else {
 219                        data_f32[idx] = 0.0f;
 220                    }
 221                }
 222            }
 223        }
 224    }
 225
 226    ggml_backend_tensor_set(tensor, data_f32.data(), 0, ggml_nbytes(tensor));
 227}
 228
 229static std::vector<float> tensor_to_float(const ggml_tensor * t) {
 230    std::vector<float> tv;
 231    tv.reserve(ggml_nelements(t));
 232
 233    std::vector<uint8_t> buf(ggml_nbytes(t));
 234    ggml_backend_tensor_get(t, buf.data(), 0, ggml_nbytes(t));
 235
 236    const auto * tt = ggml_get_type_traits(t->type);
 237    size_t bs = ggml_blck_size(t->type);
 238    std::vector<float> vq(ggml_blck_size(t->type));
 239    bool quantized = ggml_is_quantized(t->type);
 240
 241    // access elements by index to avoid gaps in views
 242    for (int64_t i3 = 0; i3 < t->ne[3]; i3++) {
 243        for (int64_t i2 = 0; i2 < t->ne[2]; i2++) {
 244            for (int64_t i1 = 0; i1 < t->ne[1]; i1++) {
 245                for (int64_t i0 = 0; i0 < t->ne[0]; i0 += bs) {
 246                    size_t i = i3*t->nb[3] + i2*t->nb[2] + i1*t->nb[1] + i0/bs*t->nb[0];
 247                    if (t->type == GGML_TYPE_F16) {
 248                        tv.push_back(ggml_fp16_to_fp32(*(ggml_fp16_t*)&buf[i]));
 249                    } else if (t->type == GGML_TYPE_BF16) {
 250                        tv.push_back(ggml_bf16_to_fp32(*(ggml_bf16_t*)&buf[i]));
 251                    } else if (t->type == GGML_TYPE_F32) {
 252                        tv.push_back(*(float *) &buf[i]);
 253                    } else if (t->type == GGML_TYPE_I64) {
 254                        tv.push_back((float)*(int64_t *) &buf[i]);
 255                    } else if (t->type == GGML_TYPE_I32) {
 256                        tv.push_back((float)*(int32_t *) &buf[i]);
 257                    } else if (t->type == GGML_TYPE_I16) {
 258                        tv.push_back((float)*(int16_t *) &buf[i]);
 259                    } else if (t->type == GGML_TYPE_I8) {
 260                        tv.push_back((float)*(int8_t *) &buf[i]);
 261                    } else if (quantized) {
 262                        tt->to_float(&buf[i], vq.data(), bs);
 263                        tv.insert(tv.end(), vq.begin(), vq.end());
 264                    } else {
 265                        GGML_ABORT("fatal error");
 266                    }
 267                }
 268            }
 269        }
 270    }
 271
 272    return tv;
 273}
 274
 275// normalized mean squared error = mse(a, b) / mse(a, 0)
 276static double nmse(const float * a, const float * b, size_t n) {
 277    double mse_a_b = 0.0;
 278    double mse_a_0 = 0.0;
 279
 280    for (size_t i = 0; i < n; i++) {
 281        float a_i = a[i];
 282        float b_i = b[i];
 283
 284        mse_a_b += (a_i - b_i) * (a_i - b_i);
 285        mse_a_0 += a_i * a_i;
 286    }
 287
 288    return mse_a_b / mse_a_0;
 289}
 290
 291// difference between 2 sets (Jaccard distance, 0 - no difference, 1 - no overlap)
 292template <typename T>
 293static double jdst(const T * a, const T * b, size_t n) {
 294    std::unordered_map<T, size_t> set_a;
 295    std::unordered_map<T, size_t> set_b;
 296
 297    for (size_t i = 0; i < n; ++i) {
 298        set_a[a[i]]++;
 299        set_b[b[i]]++;
 300    }
 301
 302    size_t diff = 0;
 303
 304    for (const auto & p : set_a) {
 305        const int64_t na = p.second;
 306        const int64_t nb = set_b.find(p.first) != set_b.end() ? set_b.at(p.first) : 0;
 307
 308        diff += std::abs(na - nb);
 309    }
 310
 311    for (const auto & p : set_b) {
 312        if (set_a.find(p.first) == set_a.end()) {
 313            diff += p.second;
 314        }
 315    }
 316
 317    return (double) diff / (2*n);
 318}
 319
 320// maximum absolute asymmetry between a and b
 321// asymmetry: (a - b) / (a + b)
 322// This is more stable than relative error if one of the values fluctuates towards zero.
 323// n: number of values to compare.
 324// expected_vals: optional vector of expected values for a. If expected_vals is not empty, filter out all comparisons where
 325//     a does not match any of the expected values. Needed for noncontinuous gradients where the numerical calculation can fail.
 326static double mean_abs_asymm(const float * a, const float * b, const size_t n, const std::vector<float> & expected_vals) {
 327    double sum = 0.0f;
 328
 329    size_t nvalid = 0;
 330    for (size_t i = 0; i < n; i++) {
 331        if (!expected_vals.empty()) {
 332            bool matches_any = false;
 333            for (const float & ev : expected_vals) {
 334                if (fabsf(a[i] - ev) < 1e-3f) {
 335                    matches_any = true;
 336                    break;
 337                }
 338            }
 339            if (!matches_any) {
 340                continue;
 341            }
 342        }
 343
 344        const float asymm = (a[i] - b[i]) / (a[i] + b[i]);
 345
 346        sum += fabsf(asymm);
 347        nvalid++;
 348    }
 349
 350    return sum/nvalid;
 351}
 352
 353// utils for printing the variables of the test cases
 354
 355static std::string var_to_str(const std::string & x) {
 356    return x;
 357}
 358
 359template<typename T>
 360static std::string var_to_str(const T & x) {
 361    return std::to_string(x);
 362}
 363
 364template<typename T, size_t N>
 365static std::string var_to_str(const T (&x)[N]) {
 366    std::string s = "[";
 367    for (size_t i = 0; i < N; i++) {
 368        if (i > 0) {
 369            s += ",";
 370        }
 371        s += var_to_str(x[i]);
 372    }
 373    s += "]";
 374    return s;
 375}
 376
 377template<typename T, size_t N>
 378static std::string var_to_str(const std::array<T, N> & x) {
 379    std::string s = "[";
 380    for (size_t i = 0; i < N; i++) {
 381        if (i > 0) {
 382            s += ",";
 383        }
 384        s += var_to_str(x[i]);
 385    }
 386    s += "]";
 387    return s;
 388}
 389
 390static std::string var_to_str(ggml_type type) {
 391    return ggml_type_name(type);
 392}
 393
 394static std::string var_to_str(ggml_prec prec) {
 395    return prec == GGML_PREC_F32 ? "f32" : "def";
 396}
 397
 398static std::string var_to_str(ggml_op_pool pool) {
 399    switch (pool) {
 400        case GGML_OP_POOL_AVG:  return "avg";
 401        case GGML_OP_POOL_MAX:  return "max";
 402        default:                return std::to_string(pool);
 403    }
 404}
 405
 406static std::string var_to_str(ggml_scale_mode mode) {
 407    std::string str;
 408    switch (mode & 0xFF) {
 409        case GGML_SCALE_MODE_NEAREST:  str = "nearest"; break;
 410        case GGML_SCALE_MODE_BILINEAR: str = "bilinear"; break;
 411        case GGML_SCALE_MODE_BICUBIC:  str = "bicubic"; break;
 412        default:                       str = std::to_string(mode); break;
 413    }
 414    if (mode & GGML_SCALE_FLAG_ALIGN_CORNERS) {
 415        str += "|align_corners";
 416    }
 417    if (mode & GGML_SCALE_FLAG_ANTIALIAS) {
 418        str += "|antialias";
 419    }
 420    return str;
 421}
 422
// "name=value" for one variable: the name is the literal source token (#x), the value comes from the
// var_to_str overload selected for the variable's type.
#define VAR_TO_STR(x) (#x "=" + var_to_str(x))

// VARS_TO_STRn(a, b, ...) joins n "name=value" pairs with commas, e.g. VARS_TO_STR2(type, ne) expands to
// VAR_TO_STR(type) + "," + VAR_TO_STR(ne). Each arity delegates to the next smaller one.
// NOTE(review): the multi-argument expansions are not wrapped in parentheses; use them as a whole expression.
#define VARS_TO_STR1(a) VAR_TO_STR(a)
#define VARS_TO_STR2(a, b) VAR_TO_STR(a) + "," + VAR_TO_STR(b)
#define VARS_TO_STR3(a, b, c) VAR_TO_STR(a) + "," + VARS_TO_STR2(b, c)
#define VARS_TO_STR4(a, b, c, d) VAR_TO_STR(a) + "," + VARS_TO_STR3(b, c, d)
#define VARS_TO_STR5(a, b, c, d, e) VAR_TO_STR(a) + "," + VARS_TO_STR4(b, c, d, e)
#define VARS_TO_STR6(a, b, c, d, e, f) VAR_TO_STR(a) + "," + VARS_TO_STR5(b, c, d, e, f)
#define VARS_TO_STR7(a, b, c, d, e, f, g) VAR_TO_STR(a) + "," + VARS_TO_STR6(b, c, d, e, f, g)
#define VARS_TO_STR8(a, b, c, d, e, f, g, h) VAR_TO_STR(a) + "," + VARS_TO_STR7(b, c, d, e, f, g, h)
#define VARS_TO_STR9(a, b, c, d, e, f, g, h, i) VAR_TO_STR(a) + "," + VARS_TO_STR8(b, c, d, e, f, g, h, i)
#define VARS_TO_STR10(a, b, c, d, e, f, g, h, i, j) VAR_TO_STR(a) + "," + VARS_TO_STR9(b, c, d, e, f, g, h, i, j)
#define VARS_TO_STR11(a, b, c, d, e, f, g, h, i, j, k) VAR_TO_STR(a) + "," + VARS_TO_STR10(b, c, d, e, f, g, h, i, j, k)
#define VARS_TO_STR12(a, b, c, d, e, f, g, h, i, j, k, l) VAR_TO_STR(a) + "," + VARS_TO_STR11(b, c, d, e, f, g, h, i, j, k, l)
#define VARS_TO_STR13(a, b, c, d, e, f, g, h, i, j, k, l, m) VAR_TO_STR(a) + "," + VARS_TO_STR12(b, c, d, e, f, g, h, i, j, k, l, m)
#define VARS_TO_STR14(a, b, c, d, e, f, g, h, i, j, k, l, m, n) VAR_TO_STR(a) + "," + VARS_TO_STR13(b, c, d, e, f, g, h, i, j, k, l, m, n)
#define VARS_TO_STR15(a, b, c, d, e, f, g, h, i, j, k, l, m, n, o) VAR_TO_STR(a) + "," + VARS_TO_STR14(b, c, d, e, f, g, h, i, j, k, l, m, n, o)
#define VARS_TO_STR16(a, b, c, d, e, f, g, h, i, j, k, l, m, n, o, p) VAR_TO_STR(a) + "," + VARS_TO_STR15(b, c, d, e, f, g, h, i, j, k, l, m, n, o, p)
 441
 442#ifdef GGML_USE_SYCL
 443static bool inline _isinf(float f) {
 444    return (*(uint32_t *)&f & 0x7fffffff) == 0x7f800000;
 445}
 446#else
 447static bool inline _isinf(float f) { return std::isinf(f); }
 448#endif
 449
 450// accept FLT_MAX as infinity
 451static bool isinf_or_max(float f) {
 452    return _isinf(f) || f == FLT_MAX || f == -FLT_MAX;
 453}
 454
 455static bool ggml_is_view_op(enum ggml_op op) {
 456    return op == GGML_OP_VIEW || op == GGML_OP_RESHAPE || op == GGML_OP_PERMUTE || op == GGML_OP_TRANSPOSE;
 457}
 458
 459static bool backend_has_feature(ggml_backend_t backend, const char * feature_name) {
 460    ggml_backend_dev_t dev = ggml_backend_get_device(backend);
 461    ggml_backend_reg_t reg = ggml_backend_dev_backend_reg(dev);
 462
 463    auto get_features = (ggml_backend_get_features_t) ggml_backend_reg_get_proc_address(reg, "ggml_backend_get_features");
 464    if (!get_features) {
 465        return false;
 466    }
 467
 468    const ggml_backend_feature * features = get_features(reg);
 469    if (!features) {
 470        return false;
 471    }
 472
 473    for (const ggml_backend_feature * f = features; f->name; ++f) {
 474        if (strcmp(f->name, feature_name) == 0 && strcmp(f->value, "1") == 0) {
 475            return true;
 476        }
 477    }
 478    return false;
 479}
 480
// Top-level mode of a test run.
enum test_mode {
    MODE_TEST,     // check that several backends compute the same results (forward pass)
    MODE_PERF,     // measure op performance ("perf" mode)
    MODE_GRAD,     // compare backpropagated gradients against finite differences ("grad" mode)
    MODE_SUPPORT,  // presumably: only report which ops each backend supports — TODO confirm
};
 487
// Output format support similar to llama-bench
// Textual names ("console", "sql", "csv") are mapped by output_format_str / output_format_from_str.
enum output_formats { CONSOLE, SQL, CSV };
 490
 491static const char * output_format_str(output_formats format) {
 492    switch (format) {
 493        case CONSOLE:
 494            return "console";
 495        case SQL:
 496            return "sql";
 497        case CSV:
 498            return "csv";
 499        default:
 500            GGML_ABORT("invalid output format");
 501    }
 502}
 503
 504static bool output_format_from_str(const std::string & s, output_formats & format) {
 505    if (s == "console") {
 506        format = CONSOLE;
 507    } else if (s == "sql") {
 508        format = SQL;
 509    } else if (s == "csv") {
 510        format = CSV;
 511    } else {
 512        return false;
 513    }
 514    return true;
 515}
 516
 517// Test result structure for SQL output
 518struct test_result {
 519    std::string test_time;
 520    std::string build_commit;
 521    std::string backend_name;
 522    std::string op_name;
 523    std::string op_params;
 524    std::string test_mode;
 525    bool        supported;
 526    bool        passed;
 527    std::string error_message;
 528    double      time_us;
 529    double      flops;
 530    double      bandwidth_gb_s;
 531    size_t      memory_kb;
 532    int         n_runs;
 533    std::string device_description;
 534    std::string backend_reg_name;
 535
 536    test_result() {
 537        // Initialize with default values
 538        time_us        = 0.0;
 539        flops          = 0.0;
 540        bandwidth_gb_s = 0.0;
 541        memory_kb      = 0;
 542        n_runs         = 0;
 543        supported      = false;
 544        passed         = false;
 545
 546        // Set test time
 547        time_t t = time(NULL);
 548        char   buf[32];
 549        std::strftime(buf, sizeof(buf), "%FT%TZ", gmtime(&t));
 550        test_time = buf;
 551
 552        // Set build info
 553        build_commit = ggml_commit();
 554    }
 555
 556    test_result(const std::string & backend_name, const std::string & op_name, const std::string & op_params,
 557                const std::string & test_mode, bool supported, bool passed, const std::string & error_message = "",
 558                double time_us = 0.0, double flops = 0.0, double bandwidth_gb_s = 0.0, size_t memory_kb = 0,
 559                int n_runs = 0, const std::string & device_description = "", const std::string & backend_reg_name = "") :
 560        backend_name(backend_name),
 561        op_name(op_name),
 562        op_params(op_params),
 563        test_mode(test_mode),
 564        supported(supported),
 565        passed(passed),
 566        error_message(error_message),
 567        time_us(time_us),
 568        flops(flops),
 569        bandwidth_gb_s(bandwidth_gb_s),
 570        memory_kb(memory_kb),
 571        n_runs(n_runs),
 572        device_description(device_description),
 573        backend_reg_name(backend_reg_name) {
 574        // Set test time
 575        time_t t = time(NULL);
 576        char   buf[32];
 577        std::strftime(buf, sizeof(buf), "%FT%TZ", gmtime(&t));
 578        test_time = buf;
 579
 580        // Set build info
 581        build_commit = ggml_commit();
 582    }
 583
 584    static const std::vector<std::string> & get_fields() {
 585        static const std::vector<std::string> fields = {
 586            "test_time", "build_commit",  "backend_name", "op_name", "op_params",      "test_mode", "supported",
 587            "passed",    "error_message", "time_us",      "flops",   "bandwidth_gb_s", "memory_kb", "n_runs",
 588            "device_description", "backend_reg_name"
 589        };
 590        return fields;
 591    }
 592
 593    enum field_type { STRING, BOOL, INT, FLOAT };
 594
 595    static field_type get_field_type(const std::string & field) {
 596        if (field == "supported" || field == "passed") {
 597            return BOOL;
 598        }
 599        if (field == "memory_kb" || field == "n_runs") {
 600            return INT;
 601        }
 602        if (field == "time_us" || field == "flops" || field == "bandwidth_gb_s") {
 603            return FLOAT;
 604        }
 605        return STRING;
 606    }
 607
 608    std::vector<std::string> get_values() const {
 609        return { test_time,
 610                 build_commit,
 611                 backend_name,
 612                 op_name,
 613                 op_params,
 614                 test_mode,
 615                 std::to_string(supported),
 616                 std::to_string(passed),
 617                 error_message,
 618                 std::to_string(time_us),
 619                 std::to_string(flops),
 620                 std::to_string(bandwidth_gb_s),
 621                 std::to_string(memory_kb),
 622                 std::to_string(n_runs),
 623                 device_description,
 624                 backend_reg_name };
 625    }
 626};
 627
 628// Printer classes for different output formats
 629enum class test_status_t { NOT_SUPPORTED, OK, FAIL, SKIPPED };
 630
 631struct test_operation_info {
 632    std::string   op_name;
 633    std::string   op_params;
 634    std::string   backend_name;
 635    test_status_t status = test_status_t::OK;
 636    std::string   failure_reason;
 637
 638    // Additional information fields that were previously in separate structs
 639    std::string error_component;
 640    std::string error_details;
 641
 642    // Gradient info
 643    int64_t     gradient_index = -1;
 644    std::string gradient_param_name;
 645    float       gradient_value = 0.0f;
 646
 647    // MAA error info
 648    double maa_error     = 0.0;
 649    double maa_threshold = 0.0;
 650
 651    // Flags for different types of information
 652    bool has_error            = false;
 653    bool has_gradient_info    = false;
 654    bool has_maa_error        = false;
 655    bool is_compare_failure   = false;
 656    bool is_large_tensor_skip = false;
 657
 658    test_operation_info() = default;
 659
 660    test_operation_info(const std::string & op_name, const std::string & op_params, const std::string & backend_name,
 661                        test_status_t status = test_status_t::OK, const std::string & failure_reason = "") :
 662        op_name(op_name),
 663        op_params(op_params),
 664        backend_name(backend_name),
 665        status(status),
 666        failure_reason(failure_reason) {}
 667
 668    // Set error information
 669    void set_error(const std::string & component, const std::string & details) {
 670        has_error       = true;
 671        error_component = component;
 672        error_details   = details;
 673        if (status == test_status_t::OK) {
 674            status = test_status_t::FAIL;
 675        }
 676    }
 677
 678    // Set gradient information
 679    void set_gradient_info(int64_t index, const std::string & param_name, float value) {
 680        has_gradient_info   = true;
 681        gradient_index      = index;
 682        gradient_param_name = param_name;
 683        gradient_value      = value;
 684        if (status == test_status_t::OK) {
 685            status = test_status_t::FAIL;
 686        }
 687    }
 688
 689    // Set MAA error information
 690    void set_maa_error(double error, double threshold) {
 691        has_maa_error = true;
 692        maa_error     = error;
 693        maa_threshold = threshold;
 694        if (status == test_status_t::OK) {
 695            status = test_status_t::FAIL;
 696        }
 697    }
 698
 699    // Set compare failure
 700    void set_compare_failure() {
 701        is_compare_failure = true;
 702        if (status == test_status_t::OK) {
 703            status = test_status_t::FAIL;
 704        }
 705    }
 706
 707    // Set large tensor skip
 708    void set_large_tensor_skip() { is_large_tensor_skip = true; }
 709};
 710
 711struct test_summary_info {
 712    size_t tests_passed;
 713    size_t tests_total;
 714    bool   is_backend_summary = false;  // true for backend summary, false for test summary
 715
 716    test_summary_info() = default;
 717
 718    test_summary_info(size_t tests_passed, size_t tests_total, bool is_backend_summary = false) :
 719        tests_passed(tests_passed),
 720        tests_total(tests_total),
 721        is_backend_summary(is_backend_summary) {}
 722};
 723
 724struct testing_start_info {
 725    size_t device_count;
 726
 727    testing_start_info() = default;
 728
 729    testing_start_info(size_t device_count) : device_count(device_count) {}
 730};
 731
 732struct backend_init_info {
 733    size_t      device_index;
 734    size_t      total_devices;
 735    std::string device_name;
 736    bool        skipped = false;
 737    std::string skip_reason;
 738    std::string description;
 739    size_t      memory_total_mb = 0;
 740    size_t      memory_free_mb  = 0;
 741    bool        has_memory_info = false;
 742
 743    backend_init_info() = default;
 744
 745    backend_init_info(size_t device_index, size_t total_devices, const std::string & device_name, bool skipped = false,
 746                      const std::string & skip_reason = "", const std::string & description = "",
 747                      size_t memory_total_mb = 0, size_t memory_free_mb = 0, bool has_memory_info = false) :
 748        device_index(device_index),
 749        total_devices(total_devices),
 750        device_name(device_name),
 751        skipped(skipped),
 752        skip_reason(skip_reason),
 753        description(description),
 754        memory_total_mb(memory_total_mb),
 755        memory_free_mb(memory_free_mb),
 756        has_memory_info(has_memory_info) {}
 757};
 758
 759struct backend_status_info {
 760    std::string   backend_name;
 761    test_status_t status;
 762
 763    backend_status_info() = default;
 764
 765    backend_status_info(const std::string & backend_name, test_status_t status) :
 766        backend_name(backend_name),
 767        status(status) {}
 768};
 769
 770struct overall_summary_info {
 771    size_t backends_passed;
 772    size_t backends_total;
 773    bool   all_passed;
 774
 775    overall_summary_info() = default;
 776
 777    overall_summary_info(size_t backends_passed, size_t backends_total, bool all_passed) :
 778        backends_passed(backends_passed),
 779        backends_total(backends_total),
 780        all_passed(all_passed) {}
 781};
 782
 783struct printer {
 784    virtual ~printer() {}
 785
 786    FILE * fout = stdout;
 787
 788    virtual void print_header() {}
 789
 790    virtual void print_test_result(const test_result & result) = 0;
 791
 792    virtual void print_footer() {}
 793
 794    virtual void print_operation(const test_operation_info & info) { (void) info; }
 795
 796    virtual void print_summary(const test_summary_info & info) { (void) info; }
 797
 798    virtual void print_testing_start(const testing_start_info & info) { (void) info; }
 799
 800    virtual void print_backend_init(const backend_init_info & info) { (void) info; }
 801
 802    virtual void print_backend_status(const backend_status_info & info) { (void) info; }
 803
 804    virtual void print_overall_summary(const overall_summary_info & info) { (void) info; }
 805
 806    virtual void print_failed_tests(const std::vector<std::string> & failed_tests) { (void) failed_tests; }
 807};
 808
 809struct console_printer : public printer {
 810    void print_test_result(const test_result & result) override {
 811        if (result.test_mode == "test") {
 812            print_test_console(result);
 813        } else if (result.test_mode == "perf") {
 814            print_perf_console(result);
 815        } else if (result.test_mode == "support") {
 816            print_support_console(result);
 817        }
 818    }
 819
    // Print one "  op(params): <outcome>" line for a tested operation.
    // The "op(params): " prefix goes to stdout and is flushed immediately. A large-tensor skip or a
    // NOT_SUPPORTED status finishes the line early. Otherwise the error (on stderr), gradient, MAA and
    // compare-failure details that are present are appended, followed by a colored OK or FAIL.
    void print_operation(const test_operation_info & info) override {
        printf("  %s(%s): ", info.op_name.c_str(), info.op_params.c_str());
        fflush(stdout);

        // Handle large tensor skip first
        if (info.is_large_tensor_skip) {
            printf("skipping large tensors for speed \n");
            return;
        }

        // Handle not supported status
        if (info.status == test_status_t::NOT_SUPPORTED) {
            // prefer the explicit reason; fall back to the backend name
            if (!info.failure_reason.empty()) {
                printf("not supported [%s]\n", info.failure_reason.c_str());
            } else {
                printf("not supported [%s]\n", info.backend_name.c_str());
            }
            return;
        }

        // Handle errors and additional information
        // NOTE(review): these messages go to stderr while the rest of the line goes to stdout, so the two
        // streams may interleave oddly when redirected.
        if (info.has_error) {
            if (info.error_component == "allocation") {
                fprintf(stderr, "failed to allocate tensors [%s] ", info.backend_name.c_str());
            } else if (info.error_component == "backend") {
                fprintf(stderr, "  Failed to initialize %s backend\n", info.backend_name.c_str());
            } else {
                fprintf(stderr, "Error in %s: %s\n", info.error_component.c_str(), info.error_details.c_str());
            }
        }

        // Handle gradient info
        if (info.has_gradient_info) {
            printf("[%s] nonfinite gradient at index %" PRId64 " (%s=%f) ", info.op_name.c_str(), info.gradient_index,
                   info.gradient_param_name.c_str(), info.gradient_value);
        }

        // Handle MAA error
        if (info.has_maa_error) {
            printf("[%s] MAA = %.9f > %.9f ", info.op_name.c_str(), info.maa_error, info.maa_threshold);
        }

        // Handle compare failure
        if (info.is_compare_failure) {
            printf("compare failed ");
        }

        // Print final status (ANSI escapes: bold green OK / bold red FAIL)
        if (info.status == test_status_t::OK) {
            printf("\033[1;32mOK\033[0m\n");
        } else {
            printf("\033[1;31mFAIL\033[0m\n");
        }
    }
 874
 875    void print_summary(const test_summary_info & info) override {
 876        if (info.is_backend_summary) {
 877            printf("%zu/%zu backends passed\n", info.tests_passed, info.tests_total);
 878        } else {
 879            printf("  %zu/%zu tests passed\n", info.tests_passed, info.tests_total);
 880        }
 881    }
 882
 883    void print_backend_status(const backend_status_info & info) override {
 884        printf("  Backend %s: ", info.backend_name.c_str());
 885        if (info.status == test_status_t::OK) {
 886            printf("\033[1;32mOK\033[0m\n");
 887        } else {
 888            printf("\033[1;31mFAIL\033[0m\n");
 889        }
 890    }
 891
 892    void print_testing_start(const testing_start_info & info) override {
 893        printf("Testing %zu devices\n\n", info.device_count);
 894    }
 895
 896    void print_backend_init(const backend_init_info & info) override {
 897        printf("Backend %zu/%zu: %s\n", info.device_index + 1, info.total_devices, info.device_name.c_str());
 898
 899        if (info.skipped) {
 900            printf("  %s\n", info.skip_reason.c_str());
 901            return;
 902        }
 903
 904        if (!info.description.empty()) {
 905            printf("  Device description: %s\n", info.description.c_str());
 906        }
 907
 908        if (info.has_memory_info) {
 909            printf("  Device memory: %zu MB (%zu MB free)\n", info.memory_total_mb, info.memory_free_mb);
 910        }
 911
 912        printf("\n");
 913    }
 914
 915    void print_overall_summary(const overall_summary_info & info) override {
 916        printf("%zu/%zu backends passed\n", info.backends_passed, info.backends_total);
 917        if (info.all_passed) {
 918            printf("\033[1;32mOK\033[0m\n");
 919        } else {
 920            printf("\033[1;31mFAIL\033[0m\n");
 921        }
 922    }
 923
 924    void print_failed_tests(const std::vector<std::string> & failed_tests) override {
 925        if (failed_tests.empty()) {
 926            return;
 927        }
 928
 929        printf("\nFailing tests:\n");
 930        for (const auto & test_name : failed_tests) {
 931            printf("  %s\n", test_name.c_str());
 932        }
 933    }
 934
 935  private:
 936    void print_test_console(const test_result & result) {
 937        printf("  %s(%s): ", result.op_name.c_str(), result.op_params.c_str());
 938        fflush(stdout);
 939
 940        if (!result.supported) {
 941            printf("not supported [%s] ", result.backend_name.c_str());
 942            printf("\n");
 943            return;
 944        }
 945
 946        if (result.passed) {
 947            printf("\033[1;32mOK\033[0m\n");
 948        } else {
 949            printf("\033[1;31mFAIL\033[0m\n");
 950        }
 951    }
 952
 953    void print_perf_console(const test_result & result) {
 954        int len = printf("  %s(%s): ", result.op_name.c_str(), result.op_params.c_str());
 955        fflush(stdout);
 956
 957        if (!result.supported) {
 958            printf("not supported\n");
 959            return;
 960        }
 961
 962        // align while also leaving some margin for variations in parameters
 963        int align = 8;
 964        int last  = (len + align - 1) / align * align;
 965        if (last - len < 5) {
 966            last += align;
 967        }
 968        printf("%*s", last - len, "");
 969
 970        printf("    %8d runs - %8.2f us/run - ", result.n_runs, result.time_us);
 971
 972        if (result.flops > 0) {
 973            auto format_flops = [](double flops) -> std::string {
 974                char buf[256];
 975                if (flops >= 1e12) {
 976                    snprintf(buf, sizeof(buf), "%6.2f TFLOP", flops / 1e12);
 977                } else if (flops >= 1e9) {
 978                    snprintf(buf, sizeof(buf), "%6.2f GFLOP", flops / 1e9);
 979                } else if (flops >= 1e6) {
 980                    snprintf(buf, sizeof(buf), "%6.2f MFLOP", flops / 1e6);
 981                } else {
 982                    snprintf(buf, sizeof(buf), "%6.2f kFLOP", flops / 1e3);
 983                }
 984                return buf;
 985            };
 986            uint64_t op_flops_per_run = result.flops * result.time_us / 1e6;
 987            printf("%s/run - \033[1;34m%sS\033[0m", format_flops(op_flops_per_run).c_str(),
 988                   format_flops(result.flops).c_str());
 989        } else {
 990            printf("%8zu kB/run - \033[1;34m%7.2f GB/s\033[0m", result.memory_kb, result.bandwidth_gb_s);
 991        }
 992        printf("\n");
 993    }
 994
 995    void print_support_console(const test_result & result) {
 996        printf("  %s(%s): ", result.op_name.c_str(), result.op_params.c_str());
 997        fflush(stdout);
 998
 999        if (result.supported) {
1000            printf("\033[1;32mSUPPORTED\033[0m\n");
1001        } else {
1002            printf("\033[1;31mNOT SUPPORTED\033[0m\n");
1003        }
1004    }
1005};
1006
1007struct sql_printer : public printer {
1008    static std::string get_sql_field_type(const std::string & field) {
1009        switch (test_result::get_field_type(field)) {
1010            case test_result::STRING:
1011                return "TEXT";
1012            case test_result::BOOL:
1013            case test_result::INT:
1014                return "INTEGER";
1015            case test_result::FLOAT:
1016                return "REAL";
1017            default:
1018                GGML_ABORT("invalid field type");
1019        }
1020    }
1021
1022    void print_header() override {
1023        std::vector<std::string> fields = test_result::get_fields();
1024        fprintf(fout, "CREATE TABLE IF NOT EXISTS test_backend_ops (\n");
1025        for (size_t i = 0; i < fields.size(); i++) {
1026            fprintf(fout, "  %s %s%s\n", fields[i].c_str(), get_sql_field_type(fields[i]).c_str(),
1027                    i < fields.size() - 1 ? "," : "");
1028        }
1029        fprintf(fout, ");\n\n");
1030    }
1031
1032    void print_test_result(const test_result & result) override {
1033        fprintf(fout, "INSERT INTO test_backend_ops (");
1034        std::vector<std::string> fields = test_result::get_fields();
1035        for (size_t i = 0; i < fields.size(); i++) {
1036            fprintf(fout, "%s%s", fields[i].c_str(), i < fields.size() - 1 ? ", " : "");
1037        }
1038        fprintf(fout, ") VALUES (");
1039        std::vector<std::string> values = result.get_values();
1040        for (size_t i = 0; i < values.size(); i++) {
1041            fprintf(fout, "'%s'%s", values[i].c_str(), i < values.size() - 1 ? ", " : "");
1042        }
1043        fprintf(fout, ");\n");
1044    }
1045};
1046
1047struct csv_printer : public printer {
1048    void print_header() override {
1049
1050        std::vector<std::string> fields     = test_result::get_fields();
1051        std::vector<std::string> fields_csv = get_fields_csv();
1052        for (size_t i = 0; i < fields.size(); i++) {
1053            if (std::find(std::begin(fields_csv), std::end(fields_csv), fields[i]) == std::end(fields_csv)) {
1054                continue;
1055            }
1056            printf("\"%s\"%s", fields[i].c_str(), i < fields.size() - 1 ? "," : "");
1057        }
1058        printf("\n");
1059    }
1060
1061    void print_test_result(const test_result & result) override {
1062
1063        std::vector<std::string> values     = result.get_values();
1064        std::vector<std::string> fields     = test_result::get_fields();
1065        std::vector<std::string> fields_csv = get_fields_csv();
1066
1067        for (size_t i = 0; i < values.size(); i++) {
1068
1069            if (std::find(std::begin(fields_csv), std::end(fields_csv), fields[i]) == std::end(fields_csv)) {
1070                continue;
1071            }
1072
1073            // Escape quotes and wrap in quotes for CSV
1074            std::string escaped_value = values[i];
1075            size_t pos = 0;
1076            while ((pos = escaped_value.find("\"", pos)) != std::string::npos) {
1077                escaped_value.replace(pos, 1, "\"\"");
1078                pos += 2;
1079            }
1080            printf("\"%s\"%s", escaped_value.c_str(), i < values.size() - 1 ? "," : "");
1081        }
1082        printf("\n");
1083    }
1084
1085    static std::vector<std::string> get_fields_csv() {
1086        return {
1087            "op_name",
1088            "op_params",
1089            "supported",
1090            "error_message",
1091            "test_mode",
1092            "backend_reg_name",
1093            "backend_name",
1094        };
1095    }
1096
1097};
1098
1099static std::unique_ptr<printer> create_printer(output_formats format) {
1100    switch (format) {
1101        case CONSOLE:
1102            return std::make_unique<console_printer>();
1103        case SQL:
1104            return std::make_unique<sql_printer>();
1105        case CSV:
1106            return std::make_unique<csv_printer>();
1107    }
1108    GGML_ABORT("invalid output format");
1109}
1110
1111struct test_case {
1112    virtual ~test_case() {}
1113
1114    virtual std::string op_desc(ggml_tensor * t) {
1115        return ggml_op_desc(t);
1116    }
1117
1118    virtual std::string vars() {
1119        return "";
1120    }
1121
1122    virtual ggml_tensor * build_graph(ggml_context * ctx) = 0;
1123
1124    virtual double max_nmse_err() {
1125        return 1e-7;
1126    }
1127
1128    virtual double max_nmse_err(ggml_backend_t backend) {
1129        GGML_UNUSED(backend);
1130        return max_nmse_err();
1131    }
1132
1133    virtual double max_maa_err() {
1134        return 1e-4;
1135    }
1136
1137    virtual double max_err() {
1138        return max_nmse_err();
1139    }
1140
1141    virtual double max_err(ggml_backend_t backend) {
1142        return max_nmse_err(backend);
1143    }
1144
1145    virtual double err(const float * a, const float * b, size_t n) {
1146        return nmse(a, b, n);
1147    }
1148
1149    virtual float grad_eps() {
1150        return 1e-1f;
1151    }
1152
1153    // If false, estimate gradient with 2 points, neglects 3rd order derivative and higher.
1154    // If true,  estimate gradient with 4 points, neglects 5th order derivative and higher.
1155    virtual bool grad_precise() {
1156        return false;
1157    }
1158
1159    // Skip gradient checks if total number of gradients to be checked is larger than this (to speed up the tests).
1160    virtual int64_t grad_nmax() {
1161        return 10000;
1162    }
1163
1164    // No effect if empty.
1165    // If not empty, skip all gradient checks where the numerical result does not match any of the values.
1166    // Needed for dealing with noncontinuous gradients (e.g. ReLU) where estimation using finite differences is unreliable.
1167    virtual std::vector<float> grad_expect() {
1168        return {};
1169    }
1170
1171    virtual void initialize_tensors(ggml_context * ctx) {
1172        for (ggml_tensor * t = ggml_get_first_tensor(ctx); t != nullptr; t = ggml_get_next_tensor(ctx, t)) {
1173            init_tensor_uniform(t);
1174        }
1175    }
1176
1177    virtual size_t op_size(ggml_tensor * t) {
1178        size_t size = ggml_nbytes(t);
1179        // add source tensors
1180        for (int i = 0; i < GGML_MAX_SRC; i++) {
1181            if (t->src[i] != NULL) {
1182                size += ggml_nbytes(t->src[i]);
1183            }
1184        }
1185        return size;
1186    }
1187
1188    virtual uint64_t op_flops(ggml_tensor * t) {
1189        GGML_UNUSED(t);
1190        return 0;
1191    }
1192
1193    virtual bool run_whole_graph() { return false; }
1194    virtual std::vector<ggml_tensor *> fusion_test_nodes() { return {}; }
1195
1196    ggml_cgraph * gf = nullptr;
1197    ggml_cgraph * gb = nullptr;
1198
1199    static const int sentinel_size = 1024;
1200
1201    test_mode mode;
1202
1203    std::vector<ggml_tensor *> sentinels;
1204
1205    std::string current_op_name;
1206
1207    void add_sentinel(ggml_context * ctx) {
1208        if (mode == MODE_PERF || mode == MODE_GRAD || mode == MODE_SUPPORT) {
1209            return;
1210        }
1211        ggml_tensor * sentinel = ::ggml_new_tensor_1d(ctx, GGML_TYPE_F32, sentinel_size);
1212        ggml_format_name(sentinel, "sent_%zu", sentinels.size());
1213        sentinels.push_back(sentinel);
1214    }
1215
1216    // hijack ggml_new_tensor to add sentinels after each tensor to check for overflows in the backend
1217
1218    ggml_tensor * ggml_new_tensor(ggml_context * ctx, ggml_type type, int n_dims, const int64_t * ne) {
1219        ggml_tensor * t = ::ggml_new_tensor(ctx, type, n_dims, ne);
1220        add_sentinel(ctx);
1221        return t;
1222    }
1223
1224    ggml_tensor * ggml_new_tensor_1d(ggml_context * ctx, ggml_type type, int64_t ne0) {
1225        ggml_tensor * t = ::ggml_new_tensor_1d(ctx, type, ne0);
1226        add_sentinel(ctx);
1227        return t;
1228    }
1229
1230    ggml_tensor * ggml_new_tensor_2d(ggml_context * ctx, ggml_type type, int64_t ne0, int64_t ne1) {
1231        ggml_tensor * t = ::ggml_new_tensor_2d(ctx, type, ne0, ne1);
1232        add_sentinel(ctx);
1233        return t;
1234    }
1235
1236    ggml_tensor * ggml_new_tensor_3d(ggml_context * ctx, ggml_type type, int64_t ne0, int64_t ne1, int64_t ne2) {
1237        ggml_tensor * t = ::ggml_new_tensor_3d(ctx, type, ne0, ne1, ne2);
1238        add_sentinel(ctx);
1239        return t;
1240    }
1241
1242    ggml_tensor * ggml_new_tensor_4d(ggml_context * ctx, ggml_type type, int64_t ne0, int64_t ne1, int64_t ne2, int64_t ne3) {
1243        ggml_tensor * t = ::ggml_new_tensor_4d(ctx, type, ne0, ne1, ne2, ne3);
1244        add_sentinel(ctx);
1245        return t;
1246    }
1247
1248    // Checks an op against the test filter, which is a comma separated list of OP names or specific variations
1249    bool matches_filter(ggml_tensor * op, const char * op_names_filter) {
1250        if (op_names_filter) {
1251            const auto op_name = op_desc(op);
1252            const auto op_full_name = op_name + "(" + vars() + ")";
1253            std::string_view filter(op_names_filter);
1254            while (!filter.empty()) {
1255                auto comma_pos = filter.find_first_of(',');
1256                const auto lparen_pos = filter.find_first_of('(');
1257                if (lparen_pos < comma_pos) {
1258                    auto rparen_pos = filter.find_first_of(')');
1259                    comma_pos = filter.find_first_of(',', rparen_pos);
1260                    const auto op_filter = filter.substr(0, comma_pos);
1261                    if (op_filter == op_full_name) {
1262                        return true;
1263                    }
1264                } else {
1265                    const auto op_filter = filter.substr(0, comma_pos);
1266                    if (op_filter == op_name) {
1267                        return true;
1268                    }
1269                }
1270                filter = comma_pos != std::string_view::npos ? filter.substr(comma_pos + 1) : "";
1271            }
1272            return false;
1273        } else {
1274            return true;
1275        }
1276    }
1277
    // Forward-pass consistency test ("test" mode). Builds the op's graph once, allocates it on
    // backend1, runs it on backend1 and backend2 via ggml_backend_compare_graph_backend, and
    // checks node outputs in a per-node callback (NaN, inf mismatch, error above max_err()).
    //
    // backend1        - backend under test; its name is reported and the tensors are allocated on it
    // backend2        - backend whose results backend1 is compared against
    // op_names_filter - optional comma-separated filter (see matches_filter); NULL runs everything
    // output_printer  - optional sink for the test_result; may be NULL
    //
    // Returns SKIPPED when filtered out, NOT_SUPPORTED when either backend rejects any tensor in
    // the context, FAIL on allocation failure or mismatch, OK otherwise.
    //
    // NOTE(review): the `sentinels` member is never cleared here, and `gf` points into ctx, which
    // is freed before returning. This assumes eval() runs at most once per test_case instance.
    test_status_t eval(ggml_backend_t backend1,
                       ggml_backend_t backend2,
                       const char *   op_names_filter,
                       printer *      output_printer) {
        mode = MODE_TEST;

        ggml_init_params params = {
            /* .mem_size = */ ggml_tensor_overhead()*128 + ggml_graph_overhead(),
            /* .mem_base = */ NULL,
            /* .no_alloc = */ true,
        };
        ggml_context * ctx = ggml_init(params);
        GGML_ASSERT(ctx);

        gf = ggml_new_graph(ctx);

        // pre-graph sentinel
        add_sentinel(ctx);

        ggml_tensor * out = build_graph(ctx);
        current_op_name   = op_desc(out);

        if (!matches_filter(out, op_names_filter)) {
            //printf("  %s: skipping\n", op_desc(out).c_str());
            ggml_free(ctx);
            return test_status_t::SKIPPED;
        }

        // check if the backends support the ops
        // (every tensor in ctx is checked, sentinels included; `break` only leaves the inner loop,
        // so the second backend is still scanned even after a failure)
        bool supported = true;
        for (ggml_backend_t backend : {backend1, backend2}) {
            for (ggml_tensor * t = ggml_get_first_tensor(ctx); t != NULL; t = ggml_get_next_tensor(ctx, t)) {
                if (!ggml_backend_supports_op(backend, t)) {
                    supported = false;
                    break;
                }
            }
        }

        if (!supported) {
            // Create test result for unsupported operation
            test_result result(ggml_backend_name(backend1), current_op_name, vars(), "test",
                             false, false, "not supported");

            if (output_printer) {
                output_printer->print_test_result(result);
            }

            ggml_free(ctx);
            return test_status_t::NOT_SUPPORTED;
        }

        // post-graph sentinel
        add_sentinel(ctx);

        // allocate
        ggml_backend_buffer_t buf = ggml_backend_alloc_ctx_tensors(ctx, backend1);

        if (buf == NULL) {
            printf("failed to allocate tensors [%s] ", ggml_backend_name(backend1));
            ggml_free(ctx);
            return test_status_t::FAIL;
        }

        // build graph
        ggml_build_forward_expand(gf, out);

        // add sentinels as graph nodes so that they are checked in the callback
        for (ggml_tensor * sentinel : sentinels) {
            ggml_graph_add_node(gf, sentinel);
        }

        // randomize tensors
        initialize_tensors(ctx);

        // compare
        // state shared with the capture-less callback through its void * user_data argument
        struct callback_userdata {
            bool   ok;
            test_case * tc;
            ggml_backend_t backend1;
            ggml_backend_t backend2;
        };

        callback_userdata ud {
            true,
            this,
            backend1,
            backend2,
        };

        // Invoked with the same node as computed by each backend (t1, t2). Every path returns true;
        // failures are recorded in ud->ok. (Presumably true means "keep comparing"; confirm against
        // ggml_backend_compare_graph_backend.)
        auto callback = [](int index, ggml_tensor * t1, ggml_tensor * t2, void * user_data) -> bool {
            callback_userdata * ud = (callback_userdata *) user_data;
            const char * bn1 = ggml_backend_name(ud->backend1);
            const char * bn2 = ggml_backend_name(ud->backend2);

            if (t1->op == GGML_OP_NONE) {
                // sentinels must be unchanged
                // (byte-exact comparison; a difference means some kernel wrote out of bounds)
                std::vector<uint8_t> t1_data(ggml_nbytes(t1));
                std::vector<uint8_t> t2_data(ggml_nbytes(t2));
                ggml_backend_tensor_get(t1, t1_data.data(), 0, ggml_nbytes(t1));
                ggml_backend_tensor_get(t2, t2_data.data(), 0, ggml_nbytes(t2));

                if (memcmp(t1_data.data(), t2_data.data(), ggml_nbytes(t1)) != 0) {
                    printf("sentinel mismatch: %s ", t1->name);
                    ud->ok = false;
                    return true;
                }
            }

            std::vector<float> f1 = tensor_to_float(t1);
            std::vector<float> f2 = tensor_to_float(t2);

            for (size_t i = 0; i < f1.size(); i++) {
                // check for nans
                if (std::isnan(f1[i]) || std::isnan(f2[i])) {
                    printf("[%s] NaN at index %zu (%s=%f %s=%f) ", ggml_op_desc(t1), i, bn1, f1[i], bn2, f2[i]);
                    ud->ok = false;
                    return true;
                }
                // check for infs: both must be inf of the same sign, or both must be finite
                // (isinf_or_max is defined elsewhere; by its name it presumably also treats
                // +/-FLT_MAX as infinite)
                if (isinf_or_max(f1[i]) || isinf_or_max(f2[i])) {
                    if (isinf_or_max(f1[i]) && isinf_or_max(f2[i])) {
                        if (std::signbit(f1[i]) != std::signbit(f2[i])) {
                            printf("[%s] inf sign mismatch: %s=%f %s=%f ", ggml_op_desc(t1), bn1, f1[i], bn2, f2[i]);
                            ud->ok = false;
                            return true;
                        }
                    } else {
                        printf("[%s] inf mismatch: %s=%f %s=%f ", ggml_op_desc(t1), bn1, f1[i], bn2, f2[i]);
                        ud->ok = false;
                        return true;
                    }
                }
            }

            // aggregate error metric (err(), NMSE by default) against the per-case tolerance
            double err = ud->tc->err(f1.data(), f2.data(), f1.size());
            if (err > ud->tc->max_err(ud->backend1)) {
                printf("[%s] ERR = %.9f > %.9f ", ggml_op_desc(t1), err, ud->tc->max_err(ud->backend1));
                //for (int i = 0; i < (int) f1.size(); i++) {
                //    printf("%5d %9.6f %9.6f, diff = %9.6f\n", i, f1[i], f2[i], f1[i] - f2[i]);
                //}
                //printf("\n");
                //exit(1);
                ud->ok = false;
            }
            return true;

            GGML_UNUSED(index);
        };

        // Fusion tests either name the nodes to verify or, with run_whole_graph(), verify `out`.
        // The node list is only passed on when run_whole_graph() is true (nullptr otherwise).
        std::vector<ggml_tensor *> fused_nodes_to_verify = fusion_test_nodes();
        if (fused_nodes_to_verify.size() == 0 && run_whole_graph()) {
            fused_nodes_to_verify.push_back(out);
        }
        const bool cmp_ok = ggml_backend_compare_graph_backend(backend1, backend2, gf, callback, &ud,
                                                               run_whole_graph() ? fused_nodes_to_verify.data() : nullptr,
                                                               fused_nodes_to_verify.size());

        ggml_backend_buffer_free(buf);

        ggml_free(ctx);

        // Create test result
        bool        test_passed = ud.ok && cmp_ok;
        std::string error_msg   = test_passed ? "" : (!cmp_ok ? "compare failed" : "test failed");
        test_result result(ggml_backend_name(backend1), current_op_name, vars(), "test", supported, test_passed,
                           error_msg);

        if (output_printer) {
            output_printer->print_test_result(result);
        }

        return test_passed ? test_status_t::OK : test_status_t::FAIL;
    }
1452
1453    bool eval_perf(ggml_backend_t backend, const char * op_names_filter, printer * output_printer) {
1454        mode = MODE_PERF;
1455
1456        static const size_t graph_nodes = 8192;
1457
1458        ggml_init_params params = {
1459            /* .mem_size = */ ggml_tensor_overhead()*128 + ggml_graph_overhead_custom(graph_nodes, false),
1460            /* .mem_base = */ NULL,
1461            /* .no_alloc = */ true,
1462        };
1463        ggml_context_ptr ctx(ggml_init(params)); // smart ptr
1464        GGML_ASSERT(ctx);
1465
1466        ggml_tensor * out             = build_graph(ctx.get());
1467        current_op_name               = op_desc(out);
1468        if (!matches_filter(out, op_names_filter)) {
1469            //printf("  %s: skipping\n", op_desc(out).c_str());
1470            return true;
1471        }
1472
1473        if (!ggml_backend_supports_op(backend, out)) {
1474            // Create test result for unsupported performance test
1475            test_result result(ggml_backend_name(backend), current_op_name, vars(), "perf", false, false,
1476                               "not supported");
1477
1478            output_printer->print_test_result(result);
1479
1480            return true;
1481        }
1482
1483        // allocate
1484        ggml_backend_buffer_ptr buf(ggml_backend_alloc_ctx_tensors(ctx.get(), backend)); // smart ptr
1485
1486        if (buf == NULL) {
1487            printf("failed to allocate tensors\n");
1488            return false;
1489        }
1490
1491        // randomize tensors
1492        initialize_tensors(ctx.get());
1493
1494        // build graph
1495        ggml_cgraph * gf = ggml_new_graph_custom(ctx.get(), graph_nodes, false);
1496        ggml_build_forward_expand(gf, out);
1497
1498        // warmup run
1499        ggml_status status = ggml_backend_graph_compute(backend, gf);
1500        if (status != GGML_STATUS_SUCCESS) {
1501            fprintf(stderr, "%s: ggml_backend_graph_compute failed. status=%s \n", __func__, ggml_status_to_string(status));
1502            return false;
1503        }
1504
1505        // determine number of runs
1506        int n_runs;
1507        bool is_cpu = ggml_backend_dev_type(ggml_backend_get_device(backend)) == GGML_BACKEND_DEVICE_TYPE_CPU;
1508        if (op_flops(out) > 0) {
1509            // based on flops
1510            const uint64_t GFLOP = 1000 * 1000 * 1000;
1511            const uint64_t target_flops_cpu =   8ULL * GFLOP;
1512            const uint64_t target_flops_gpu = 100ULL * GFLOP;
1513            uint64_t target_flops = is_cpu ? target_flops_cpu : target_flops_gpu;
1514            n_runs = (int)std::min<int64_t>(ggml_graph_size(gf) - ggml_graph_n_nodes(gf), target_flops / op_flops(out)) + 1;
1515        } else {
1516            // based on memory size
1517            const size_t GB = 1ULL << 30;
1518            const size_t target_size_cpu =  8 * GB;
1519            const size_t target_size_gpu = 32 * GB;
1520            size_t target_size = is_cpu ? target_size_cpu : target_size_gpu;
1521            n_runs = (int)std::min<int64_t>(ggml_graph_size(gf) - ggml_graph_n_nodes(gf), target_size / op_size(out)) + 1;
1522        }
1523
1524        // duplicate the op
1525        for (int i = 1; i < n_runs; i++) {
1526            ggml_graph_add_node(gf, out);
1527        }
1528
1529        // calculate memory
1530        size_t mem = n_runs * op_size(out);
1531        auto tensor_op_size = [](ggml_tensor * t) {
1532            size_t size = ggml_nbytes(t);
1533            // add source tensors
1534            for (int i = 0; i < GGML_MAX_SRC; i++) {
1535                if (t->src[i] != NULL) {
1536                    size += ggml_nbytes(t->src[i]);
1537                }
1538            }
1539            return size;
1540        };
1541        for (int i = 0; i < ggml_graph_n_nodes(gf); ++i) {
1542            if (ggml_is_view_op(ggml_graph_node(gf, i)->op) || ggml_graph_node(gf, i) == out) {
1543                continue;
1544            }
1545            mem += tensor_op_size(ggml_graph_node(gf, i));
1546        }
1547
1548        // run
1549        int64_t total_time_us = 0;
1550        int64_t total_mem = 0;
1551        int total_runs = 0;
1552        do {
1553            int64_t start_time = ggml_time_us();
1554            ggml_status status = ggml_backend_graph_compute(backend, gf);
1555            if (status != GGML_STATUS_SUCCESS) {
1556                fprintf(stderr, "%s: ggml_backend_graph_compute failed. status=%s \n", __func__, ggml_status_to_string(status));
1557                return false;
1558            }
1559            int64_t end_time = ggml_time_us();
1560
1561            total_time_us += end_time - start_time;
1562            total_mem += mem;
1563            total_runs += n_runs;
1564        } while (total_time_us < 1000*1000); // run for at least 1 second
1565
1566        // Create test result
1567        double avg_time_us      = (double) total_time_us / total_runs;
1568        double calculated_flops = (op_flops(out) > 0) ? (op_flops(out) * total_runs) / (total_time_us / 1e6) : 0.0;
1569        double calculated_bandwidth =
1570            (op_flops(out) == 0) ? total_mem / (total_time_us / 1e6) / 1024.0 / 1024.0 / 1024.0 : 0.0;
1571        size_t calculated_memory_kb = op_size(out) / 1024;
1572
1573        test_result result(ggml_backend_name(backend), current_op_name, vars(), "perf", true, true, "", avg_time_us,
1574                           calculated_flops, calculated_bandwidth, calculated_memory_kb, total_runs);
1575
1576        if (output_printer) {
1577            output_printer->print_test_result(result);
1578        }
1579
1580        return true;
1581    }
1582
1583    bool eval_support(ggml_backend_t backend, const char * op_names_filter, printer * output_printer) {
1584        mode = MODE_SUPPORT;
1585
1586        static const size_t graph_nodes = 8192;
1587
1588        ggml_init_params params = {
1589            /* .mem_size = */ ggml_tensor_overhead()*128 + ggml_graph_overhead_custom(graph_nodes, false),
1590            /* .mem_base = */ NULL,
1591            /* .no_alloc = */ true,
1592        };
1593        ggml_context_ptr ctx(ggml_init(params)); // smart ptr
1594        GGML_ASSERT(ctx);
1595
1596        gf = ggml_new_graph_custom(ctx.get(), graph_nodes, false);
1597
1598        ggml_tensor * out = build_graph(ctx.get());
1599        current_op_name   = op_desc(out);
1600
1601        if (!matches_filter(out, op_names_filter)) {
1602            return true;
1603        }
1604
1605        bool supported = ggml_backend_supports_op(backend, out);
1606
1607        std::string device_desc = ggml_backend_dev_description(ggml_backend_get_device(backend));
1608        std::string backend_reg_name = ggml_backend_reg_name(ggml_backend_dev_backend_reg(ggml_backend_get_device(backend)));
1609
1610        test_result result(ggml_backend_name(backend), current_op_name, vars(), "support", supported, supported,
1611                           supported ? "yes" : "no", 0.0, 0.0, 0.0, 0, 0, device_desc, backend_reg_name);
1612
1613        output_printer->print_test_result(result);
1614
1615        return true;
1616    }
1617
1618    bool eval_grad(ggml_backend_t backend, const char * op_names_filter, printer * output_printer) {
1619        mode = MODE_GRAD;
1620        const std::vector<float> expect = grad_expect();
1621
1622        ggml_init_params params = {
1623            /* .mem_size = */ ggml_tensor_overhead()*128 + 2*ggml_graph_overhead_custom(GGML_DEFAULT_GRAPH_SIZE, true),
1624            /* .mem_base = */ NULL,
1625            /* .no_alloc = */ true,
1626        };
1627        ggml_context_ptr ctx(ggml_init(params)); // smart ptr
1628        GGML_ASSERT(ctx);
1629
1630        gf = ggml_new_graph_custom(ctx.get(), GGML_DEFAULT_GRAPH_SIZE, true);
1631        gb = ggml_new_graph_custom(ctx.get(), GGML_DEFAULT_GRAPH_SIZE, true);
1632
1633        ggml_tensor * out = build_graph(ctx.get());
1634
1635        if (!matches_filter(out, op_names_filter) || out->op == GGML_OP_OPT_STEP_ADAMW) {
1636            return true;
1637        }
1638
1639        if (out->type != GGML_TYPE_F32) {
1640            output_printer->print_operation(test_operation_info(op_desc(out), vars(), ggml_backend_name(backend),
1641                                                                test_status_t::NOT_SUPPORTED,
1642                                                                out->name + std::string("->type != FP32")));
1643            return true;
1644        }
1645
1646        // Print operation info first
1647        output_printer->print_operation(test_operation_info(op_desc(out), vars(), ggml_backend_name(backend)));
1648
1649        // check if the backend supports the ops
1650        bool        supported  = true;
1651        bool        any_params = false;
1652        std::string failure_reason;
1653
1654        for (ggml_tensor * t = ggml_get_first_tensor(ctx.get()); t != NULL; t = ggml_get_next_tensor(ctx.get(), t)) {
1655            if (!ggml_backend_supports_op(backend, t)) {
1656                supported      = false;
1657                failure_reason = ggml_backend_name(backend);
1658                break;
1659            }
1660            if ((t->flags & GGML_TENSOR_FLAG_PARAM)) {
1661                any_params = true;
1662                if (t->type != GGML_TYPE_F32) {
1663                    supported      = false;
1664                    failure_reason = std::string(t->name) + "->type != FP32";
1665                    break;
1666                }
1667            }
1668        }
1669        if (!any_params) {
1670            supported      = false;
1671            failure_reason = op_desc(out);
1672        }
1673
1674        if (!supported) {
1675            output_printer->print_operation(test_operation_info(op_desc(out), vars(), ggml_backend_name(backend),
1676                                                                test_status_t::NOT_SUPPORTED, failure_reason));
1677            return true;
1678        }
1679
1680        int64_t ngrads = 0;
1681        for (ggml_tensor * t = ggml_get_first_tensor(ctx.get()); t != NULL; t = ggml_get_next_tensor(ctx.get(), t)) {
1682            if (t->flags & GGML_TENSOR_FLAG_PARAM) {
1683                ngrads += ggml_nelements(t);
1684            }
1685        }
1686        if (ngrads > grad_nmax()) {
1687            test_operation_info info(op_desc(out), vars(), ggml_backend_name(backend));
1688            info.set_large_tensor_skip();
1689            output_printer->print_operation(info);
1690            return true;
1691        }
1692
1693
1694        if (!ggml_is_scalar(out)) {
1695            out = ggml_sum(ctx.get(), out);
1696            ggml_set_name(out, "sum_of_out");
1697        }
1698        ggml_set_loss(out);
1699
1700        ggml_build_forward_expand(gf, out);
1701        ggml_graph_cpy(gf, gb);
1702        ggml_build_backward_expand(ctx.get(), gb, nullptr);
1703        if (expect.size() != 1 || expect[0] != 0.0f) {
1704            GGML_ASSERT(ggml_graph_n_nodes(gb) > ggml_graph_n_nodes(gf));
1705            for (ggml_tensor * t = ggml_get_first_tensor(ctx.get()); t != NULL; t = ggml_get_next_tensor(ctx.get(), t)) {
1706                GGML_ASSERT(!(t->flags & GGML_TENSOR_FLAG_PARAM) || ggml_graph_get_grad(gb, t)->op != GGML_OP_NONE);
1707            }
1708        }
1709
1710        for (ggml_tensor * t = ggml_get_first_tensor(ctx.get()); t != NULL; t = ggml_get_next_tensor(ctx.get(), t)) {
1711            if (!ggml_backend_supports_op(backend, t)) {
1712                output_printer->print_operation(test_operation_info(op_desc(out), vars(), ggml_backend_name(backend),
1713                                                                    test_status_t::NOT_SUPPORTED,
1714                                                                    ggml_backend_name(backend)));
1715                supported = false;
1716                break;
1717            }
1718            if ((t->flags & GGML_TENSOR_FLAG_PARAM) && t->type != GGML_TYPE_F32) {
1719                output_printer->print_operation(test_operation_info(op_desc(out), vars(), ggml_backend_name(backend),
1720                                                                    test_status_t::NOT_SUPPORTED,
1721                                                                    std::string(t->name) + "->type != FP32"));
1722                supported = false;
1723                break;
1724            }
1725        }
1726        if (!supported) {
1727            return true;
1728        }
1729
1730        // allocate
1731        ggml_backend_buffer_ptr buf(ggml_backend_alloc_ctx_tensors(ctx.get(), backend)); // smart ptr
1732        if (buf == NULL) {
1733            test_operation_info info(op_desc(out), vars(), ggml_backend_name(backend));
1734            info.set_error("allocation", "");
1735            output_printer->print_operation(info);
1736            return false;
1737        }
1738
1739        initialize_tensors(ctx.get()); // Randomizes all tensors (including gradients).
1740        ggml_graph_reset(gb);    // Sets gradients to 1 if loss, 0 otherwise.
1741
1742        ggml_status status = ggml_backend_graph_compute(backend, gf);
1743        if (status != GGML_STATUS_SUCCESS) {
1744            fprintf(stderr, "%s: ggml_backend_graph_compute failed. status=%s \n", __func__, ggml_status_to_string(status));
1745            return false;
1746        }
1747        status = ggml_backend_graph_compute(backend, gb);
1748        if (status != GGML_STATUS_SUCCESS) {
1749            fprintf(stderr, "%s: ggml_backend_graph_compute failed. status=%s \n", __func__, ggml_status_to_string(status));
1750            return false;
1751        }
1752
1753        bool ok = true;
1754        for (struct ggml_tensor * t = ggml_get_first_tensor(ctx.get()); t != nullptr; t = ggml_get_next_tensor(ctx.get(), t)) {
1755            if (!(t->flags & GGML_TENSOR_FLAG_PARAM)) {
1756                continue;
1757            }
1758
1759            const char * bn = ggml_backend_name(backend);
1760            const int64_t ne = ggml_nelements(t);
1761
1762            std::vector<float> ga;
1763            struct ggml_tensor * grad = ggml_graph_get_grad(gb, t);
1764            if (grad) {
1765                ga = tensor_to_float(grad);
1766            } else {
1767                ga.resize(ne); // default value is 0.0f
1768            }
1769
1770            for (int64_t i = 0; i < ne; ++i) { // gradient algebraic
1771                // check for nans
1772                if (!std::isfinite(ga[i])) {
1773                    test_operation_info info(op_desc(out), vars(), ggml_backend_name(backend));
1774                    info.set_gradient_info(i, bn, ga[i]);
1775                    output_printer->print_operation(info);
1776                    ok = false;
1777                    break;
1778                }
1779            }
1780            if (!ok) {
1781                break;
1782            }
1783
1784            std::vector<float> gn(ne); // gradient numeric
1785            GGML_ASSERT(ga.size() == gn.size());
1786
1787            std::vector<float> x0 = tensor_to_float(t); // original t data
1788            GGML_ASSERT(ggml_is_scalar(out));
1789            GGML_ASSERT(out->type == GGML_TYPE_F32);
1790
1791            const float eps = grad_eps();
1792            for (int64_t i = 0; i < ne; ++i) {
1793                const float xiu  = x0[i] + 1.0f*eps; // x, index i, up
1794                const float xiuh = x0[i] + 0.5f*eps; // x, index i, up half
1795                const float xidh = x0[i] - 0.5f*eps; // x, index i, down half
1796                const float xid  = x0[i] - 1.0f*eps; // x, index i, down
1797
1798                float fu, fuh, fdh, fd; // output values for xiu, xiuh, xid, xidh
1799
1800                ggml_backend_tensor_set(t, &xiu, i*sizeof(float), sizeof(float));
1801                status = ggml_backend_graph_compute(backend, gf);
1802                if (status != GGML_STATUS_SUCCESS) {
1803                    fprintf(stderr, "%s: ggml_backend_graph_compute failed. status=%s \n", __func__, ggml_status_to_string(status));
1804                    return false;
1805                }
1806                ggml_backend_tensor_get(out, &fu, 0, ggml_nbytes(out));
1807
1808                ggml_backend_tensor_set(t, &xid, i*sizeof(float), sizeof(float));
1809                status = ggml_backend_graph_compute(backend, gf);
1810                if (status != GGML_STATUS_SUCCESS) {
1811                    fprintf(stderr, "%s: ggml_backend_graph_compute failed. status=%s \n", __func__, ggml_status_to_string(status));
1812                    return false;
1813                }
1814                ggml_backend_tensor_get(out, &fd, 0, ggml_nbytes(out));
1815
1816                if (grad_precise()) {
1817                    ggml_backend_tensor_set(t, &xiuh, i*sizeof(float), sizeof(float));
1818                    status = ggml_backend_graph_compute(backend, gf);
1819                    if (status != GGML_STATUS_SUCCESS) {
1820                        fprintf(stderr, "%s: ggml_backend_graph_compute failed. status=%s \n", __func__, ggml_status_to_string(status));
1821                        return false;
1822                    }
1823                    ggml_backend_tensor_get(out, &fuh, 0, ggml_nbytes(out));
1824
1825                    ggml_backend_tensor_set(t, &xidh, i*sizeof(float), sizeof(float));
1826                    status = ggml_backend_graph_compute(backend, gf);
1827                    if (status != GGML_STATUS_SUCCESS) {
1828                        fprintf(stderr, "%s: ggml_backend_graph_compute failed. status=%s \n", __func__, ggml_status_to_string(status));
1829                        return false;
1830                    }
1831                    ggml_backend_tensor_get(out, &fdh, 0, ggml_nbytes(out));
1832
1833                    gn[i] = (8.0*(double)fuh + (double)fd - (8.0*(double)fdh + (double)fu)) / (6.0*(double)eps);
1834                } else {
1835                    gn[i] = (fu - fd) / (2.0f*eps);
1836                }
1837
1838                ggml_backend_tensor_set(t, x0.data(), 0, ggml_nbytes(t));
1839            }
1840
1841            const double err = mean_abs_asymm(gn.data(), ga.data(), gn.size(), expect);
1842            if (err > max_maa_err()) {
1843                test_operation_info info(op_desc(out), vars(), ggml_backend_name(backend));
1844                info.set_maa_error(err, max_maa_err());
1845                output_printer->print_operation(info);
1846                ok = false;
1847                break;
1848            }
1849            if (!ok) {
1850                break;
1851            }
1852        }
1853
1854        // Create final test result
1855        test_operation_info final_info(op_desc(out), vars(), ggml_backend_name(backend));
1856        if (!ok) {
1857            final_info.set_compare_failure();
1858        }
1859        final_info.status = ok ? test_status_t::OK : test_status_t::FAIL;
1860        output_printer->print_operation(final_info);
1861
1862        if (ok) {
1863            return true;
1864        }
1865
1866        return false;
1867    }
1868};
1869
1870
1871// ###################################
1872// ## Section 2: GGML Op Defintions ##
1873// ###################################
1874
1875
1876// The following is an example showing the bare minimum for creating a test for a GGML op.
1877
1878// GGML_OP_EXAMPLE
1879struct test_example : public test_case {
1880    // Always define these 2 or variants thereof:
1881    const ggml_type type; // The type of the input tensors.
1882    const std::array<int64_t, 4> ne; // The shape of the input tensors.
1883    // For some ops it's necessary to define multiple types or shapes for the inputs.
1884    // Or they may need additional parameters.
1885
1886    // Put all parameters needed to fully define the test into one of the VARS_TO_STR macros.
1887    // In most cases these are just the properties of the struct that you defined above.
1888    // This is needed for info prints.
1889    std::string vars() override {
1890        return VARS_TO_STR2(type, ne);
1891    }
1892
1893    // Define a constructor for the struct.
1894    // In most cases it will be sufficient to have the same arguments as the struct has properties
1895    // and just use initializer lists.
1896    test_example(ggml_type type = GGML_TYPE_F32,
1897            std::array<int64_t, 4> ne = {10, 5, 4, 3})
1898        : type(type), ne(ne) {}
1899
1900    // Define how a simple GGML compute graph can be constructed for the new GGML op.
1901    ggml_tensor * build_graph(ggml_context * ctx) override {
1902        // Step 1: create input tensors that don't depend on any other tensors:
1903        ggml_tensor * a = ggml_new_tensor(ctx, type, 4, ne.data());
1904        ggml_set_name(a, "a"); // Setting names is optional but it's useful for debugging.
1905
1906        ggml_tensor * b = ggml_new_tensor(ctx, type, 4, ne.data());
1907        ggml_set_name(b, "b");
1908
1909        // Step 2: use the op that you want to test in the GGML compute graph.
1910        ggml_tensor * out = ggml_add(ctx, a, b); // For this example we're just doing a simple addition.
1911        ggml_set_name(out, "out");
1912
1913        // Step 3: return the output tensor.
1914        return out;
1915    }
1916    // In order to also check the gradients for your op, add calls like ggml_set_param(a)
1917    // immediately after you create the tensors.
1918    // This is optional and only makes sense if a backward pass has actually been implemented for the new op.
1919};
1920
1921
1922// GGML_OP_UNARY
1923struct test_unary : public test_case {
1924    const ggml_unary_op op;
1925    const ggml_type type;
1926    const std::array<int64_t, 4> ne_a;
1927    int v; // view (1 : non-contiguous a)
1928
1929    std::string vars() override {
1930        return VARS_TO_STR3(type, ne_a, v);
1931    }
1932
1933    test_unary(ggml_unary_op op,
1934            ggml_type type = GGML_TYPE_F32,
1935            std::array<int64_t, 4> ne_a = {128, 2, 2, 2},
1936            int v = 0)
1937        : op(op), type(type), ne_a(ne_a), v(v) {}
1938
1939    ggml_tensor * build_graph(ggml_context * ctx) override {
1940        const bool grad_supported = op == GGML_UNARY_OP_ABS || op == GGML_UNARY_OP_SGN || op == GGML_UNARY_OP_NEG ||
1941            op == GGML_UNARY_OP_STEP || op == GGML_UNARY_OP_RELU || op == GGML_UNARY_OP_SILU ||
1942            op == GGML_UNARY_OP_EXPM1 || op == GGML_UNARY_OP_SOFTPLUS;
1943
1944        ggml_tensor * a;
1945        if (v & 1) {
1946            auto ne = ne_a;
1947            ne[0] *= 3;
1948            ne[1] *= 2;
1949            ne[2] *= 5;
1950            ne[3] *= 4;
1951            a = ggml_new_tensor(ctx, type, 4, ne.data());
1952            if (grad_supported) {
1953                ggml_set_param(a);
1954            }
1955            ggml_set_name(a, "a");
1956
1957            a = ggml_view_4d(ctx, a, ne_a[0], ne_a[1], ne_a[2], ne_a[3], a->nb[1], a->nb[2], a->nb[3], 0);
1958            ggml_set_name(a, "view_of_a");
1959        } else {
1960            a = ggml_new_tensor(ctx, type, 4, ne_a.data());
1961            if (grad_supported) {
1962                ggml_set_param(a);
1963            }
1964            ggml_set_name(a, "a");
1965        }
1966
1967        ggml_tensor * out = ggml_unary(ctx, a, op);
1968        ggml_set_name(out, "out");
1969
1970        return out;
1971    }
1972
1973    void initialize_tensors(ggml_context * ctx) override {
1974        for (ggml_tensor * t = ggml_get_first_tensor(ctx); t != NULL; t = ggml_get_next_tensor(ctx, t)) {
1975            // test extended range of values to check for NaNs in GELU
1976            init_tensor_uniform(t, -150.f, 150.f);
1977        }
1978    }
1979
1980    float grad_eps() override {
1981        return 15.0f;
1982    }
1983
1984    std::vector<float> grad_expect() override {
1985        if (op == GGML_UNARY_OP_ABS) {
1986            return {-1.0f, 1.0f};
1987        }
1988        if (op == GGML_UNARY_OP_SGN || op == GGML_UNARY_OP_STEP) {
1989            return {0.0f};
1990        }
1991        if (op == GGML_UNARY_OP_RELU) {
1992            return {0.0f, 1.0f};
1993        }
1994        return {};
1995    }
1996
1997};
1998
1999// GGML_OP_GLU
2000struct test_glu : public test_case {
2001    const ggml_glu_op op;
2002    const ggml_type type;
2003    const std::array<int64_t, 4> ne_a;
2004    int v; // view (1 : non-contiguous a)
2005    bool swapped;
2006
2007    std::string vars() override {
2008        return VARS_TO_STR4(type, ne_a, v, swapped);
2009    }
2010
2011    test_glu(ggml_glu_op op,
2012            ggml_type type = GGML_TYPE_F32,
2013            std::array<int64_t, 4> ne_a = {128, 2, 2, 2},
2014            int v = 0,
2015            bool swapped = false)
2016        : op(op), type(type), ne_a(ne_a), v(v), swapped(swapped) {}
2017
2018    ggml_tensor * build_graph(ggml_context * ctx) override {
2019        ggml_tensor * a;
2020        if (v & 1) {
2021            auto ne = ne_a; ne[0] *= 3;
2022            a = ggml_new_tensor(ctx, type, 4, ne.data());
2023            ggml_set_name(a, "a");
2024
2025            a = ggml_view_4d(ctx, a, ne_a[0], ne_a[1], ne_a[2], ne_a[3], a->nb[1], a->nb[2], a->nb[3], 0);
2026            ggml_set_name(a, "view_of_a");
2027        } else {
2028            a = ggml_new_tensor(ctx, type, 4, ne_a.data());
2029            ggml_set_name(a, "a");
2030        }
2031
2032        ggml_tensor * out = ggml_glu(ctx, a, op, swapped);
2033        ggml_set_name(out, "out");
2034
2035        return out;
2036    }
2037
2038    void initialize_tensors(ggml_context * ctx) override {
2039        for (ggml_tensor * t = ggml_get_first_tensor(ctx); t != NULL; t = ggml_get_next_tensor(ctx, t)) {
2040            // test extended range of values to check for NaNs in GELU
2041            init_tensor_uniform(t, -150.f, 150.f);
2042        }
2043    }
2044};
2045
2046struct test_glu_split : public test_case {
2047    const ggml_glu_op op;
2048    const ggml_type type;
2049    const std::array<int64_t, 4> ne_a;
2050    int v; // view (1 : non-contiguous a)
2051
2052    std::string vars() override {
2053        return VARS_TO_STR3(type, ne_a, v) + ",split";
2054    }
2055
2056    test_glu_split(ggml_glu_op op,
2057            ggml_type type = GGML_TYPE_F32,
2058            std::array<int64_t, 4> ne_a = {128, 2, 2, 2},
2059            int v = 0)
2060        : op(op), type(type), ne_a(ne_a), v(v) {}
2061
2062    ggml_tensor * build_graph(ggml_context * ctx) override {
2063        ggml_tensor * a;
2064        ggml_tensor * b;
2065        if (v & 1) {
2066            auto ne = ne_a; ne[0] *= 3;
2067            a = ggml_new_tensor(ctx, type, 4, ne.data());
2068            ggml_set_param(a);
2069            ggml_set_name(a, "a");
2070
2071            a = ggml_view_4d(ctx, a, ne_a[0], ne_a[1], ne_a[2], ne_a[3], a->nb[1], a->nb[2], a->nb[3], 0);
2072            ggml_set_name(a, "view_of_a");
2073
2074            b = ggml_new_tensor(ctx, type, 4, ne.data());
2075            ggml_set_param(b);
2076            ggml_set_name(b, "b");
2077
2078            b = ggml_view_4d(ctx, b, ne_a[0], ne_a[1], ne_a[2], ne_a[3], b->nb[1], b->nb[2], b->nb[3], 0);
2079            ggml_set_name(a, "view_of_b");
2080        } else {
2081            a = ggml_new_tensor(ctx, type, 4, ne_a.data());
2082            ggml_set_param(a);
2083            ggml_set_name(a, "a");
2084
2085            b = ggml_new_tensor(ctx, type, 4, ne_a.data());
2086            ggml_set_param(b);
2087            ggml_set_name(b, "b");
2088        }
2089
2090        ggml_tensor * out = ggml_glu_split(ctx, a, b, op);
2091        ggml_set_name(out, "out");
2092
2093        return out;
2094    }
2095
2096    void initialize_tensors(ggml_context * ctx) override {
2097        for (ggml_tensor * t = ggml_get_first_tensor(ctx); t != NULL; t = ggml_get_next_tensor(ctx, t)) {
2098            // test extended range of values to check for NaNs in GELU
2099            init_tensor_uniform(t, -150.f, 150.f);
2100        }
2101    }
2102};
2103
2104struct test_swiglu_oai : public test_case {
2105    const ggml_type type;
2106    const std::array<int64_t, 4> ne_a;
2107    int v; // view (1 : non-contiguous a)
2108    float alpha;
2109    float limit;
2110
2111    std::string vars() override {
2112        return VARS_TO_STR5(type, ne_a, v, alpha, limit);
2113    }
2114
2115    test_swiglu_oai(ggml_type type = GGML_TYPE_F32,
2116                    std::array<int64_t, 4> ne_a = {128, 2, 2, 2},
2117                    int v = 0,
2118                    float alpha = 1.702f,
2119                    float limit = 7.0f)
2120        : type(type), ne_a(ne_a), v(v), alpha(alpha), limit(limit) {}
2121
2122    ggml_tensor * build_graph(ggml_context * ctx) override {
2123        ggml_tensor * a;
2124        ggml_tensor * b;
2125        if (v & 1) {
2126            auto ne = ne_a; ne[0] *= 3;
2127            a = ggml_new_tensor(ctx, type, 4, ne.data());
2128            ggml_set_param(a);
2129            ggml_set_name(a, "a");
2130
2131            a = ggml_view_4d(ctx, a, ne_a[0], ne_a[1], ne_a[2], ne_a[3], a->nb[1], a->nb[2], a->nb[3], 0);
2132            ggml_set_name(a, "view_of_a");
2133
2134            b = ggml_new_tensor(ctx, type, 4, ne.data());
2135            ggml_set_param(b);
2136            ggml_set_name(b, "b");
2137
2138            b = ggml_view_4d(ctx, b, ne_a[0], ne_a[1], ne_a[2], ne_a[3], b->nb[1], b->nb[2], b->nb[3], 0);
2139            ggml_set_name(a, "view_of_b");
2140        } else {
2141            a = ggml_new_tensor(ctx, type, 4, ne_a.data());
2142            ggml_set_param(a);
2143            ggml_set_name(a, "a");
2144
2145            b = ggml_new_tensor(ctx, type, 4, ne_a.data());
2146            ggml_set_param(b);
2147            ggml_set_name(b, "b");
2148        }
2149
2150        ggml_tensor * out = ggml_swiglu_oai(ctx, a, b, alpha, limit);
2151        ggml_set_name(out, "out");
2152
2153        return out;
2154    }
2155
2156    void initialize_tensors(ggml_context * ctx) override {
2157        for (ggml_tensor * t = ggml_get_first_tensor(ctx); t != NULL; t = ggml_get_next_tensor(ctx, t)) {
2158            // test extended range of values to check for NaNs in GELU
2159            init_tensor_uniform(t, -150.f, 150.f);
2160        }
2161    }
2162};
2163
2164// GGML_OP_GET_ROWS
2165struct test_get_rows : public test_case {
2166    const ggml_type type;
2167    const int n; // cols
2168    const int m; // rows
2169    const int r; // rows to get
2170    const int be1; // batch size
2171    const int be2; // batch size
2172    const bool v; // view (non-contiguous src1)
2173
2174    std::string vars() override {
2175        return VARS_TO_STR7(type, n, m, r, be1, be2, v);
2176    }
2177
2178    test_get_rows(ggml_type type = GGML_TYPE_F32, int n = 10, int m = 5, int r = 3, int be1 = 1, int be2 = 1, bool v = false)
2179        : type(type), n(n), m(m), r(r), be1(be1), be2(be2), v(v) {}
2180
2181    ggml_tensor * build_graph(ggml_context * ctx) override {
2182        ggml_tensor * in = ggml_new_tensor_4d(ctx, type, n, m, be1, be2);
2183        ggml_set_name(in, "in");
2184
2185        ggml_tensor * rows = ggml_new_tensor_3d(ctx, GGML_TYPE_I32, r, be1, be2);
2186        ggml_set_name(rows, "rows");
2187        if (v) {
2188            rows = ggml_view_3d(ctx, rows, r/2, be1, be2, rows->nb[1], rows->nb[2], 0);
2189            ggml_set_name(rows, "view_of_rows");
2190        }
2191
2192        const bool grad_supported = ggml_is_matrix(in) && ggml_is_vector(rows);
2193        if (grad_supported) {
2194            ggml_set_param(in);
2195            // rows is a constant input -> no gradients
2196        }
2197
2198        ggml_tensor * out = ggml_get_rows(ctx, in, rows);
2199        ggml_set_name(out, "out");
2200
2201        return out;
2202    }
2203
2204    void initialize_tensors(ggml_context * ctx) override {
2205        for (ggml_tensor * t = ggml_get_first_tensor(ctx); t != NULL; t = ggml_get_next_tensor(ctx, t)) {
2206            if (t->type == GGML_TYPE_I32) {
2207                if (ggml_is_view_op(t->op)) { continue; }
2208                // rows
2209                std::vector<int> data(r*be1*be2);
2210                for (int i = 0; i < r*be1*be2; i++) {
2211                    data[i] = rand() % m;
2212                }
2213                ggml_backend_tensor_set(t, data.data(), 0, r * be1 * be2 * sizeof(int));
2214            } else {
2215                init_tensor_uniform(t);
2216            }
2217        }
2218    }
2219};
2220
2221// GGML_OP_GET_ROWS_BACK
2222struct test_get_rows_back : public test_case {
2223    const ggml_type type;
2224    const int n; // cols
2225    const int m; // rows
2226    const int r; // rows to get
2227    const int b; // batch size
2228    const bool v; // view (non-contiguous src1)
2229
2230    std::string vars() override {
2231        return VARS_TO_STR6(type, n, m, r, b, v);
2232    }
2233
2234    test_get_rows_back(ggml_type type = GGML_TYPE_F32, int n = 10, int m = 5, int r = 3, int b = 1, bool v = false)
2235        : type(type), n(n), m(m), r(r), b(b), v(v) {}
2236
2237    ggml_tensor * build_graph(ggml_context * ctx) override {
2238        ggml_tensor * in_forward = ggml_new_tensor_3d(ctx, type, n, m, b);
2239        ggml_set_name(in_forward, "in_forward");
2240
2241        ggml_tensor * rows = ggml_new_tensor_2d(ctx, GGML_TYPE_I32, r, b);
2242        ggml_set_name(rows, "rows");
2243        if (v) {
2244            rows = ggml_view_2d(ctx, rows, r/2, b, rows->nb[1], 0);
2245            ggml_set_name(rows, "view_of_rows");
2246        }
2247
2248        ggml_tensor * grad = ggml_new_tensor_3d(ctx, type, n, r, b);
2249        ggml_set_name(grad, "grad");
2250
2251        ggml_tensor * out = ggml_get_rows_back(ctx, grad, rows, in_forward);
2252        ggml_set_name(out, "out");
2253
2254        return out;
2255    }
2256
2257    void initialize_tensors(ggml_context * ctx) override {
2258        for (ggml_tensor * t = ggml_get_first_tensor(ctx); t != NULL; t = ggml_get_next_tensor(ctx, t)) {
2259            if (t->type == GGML_TYPE_I32) {
2260                if (ggml_is_view_op(t->op)) { continue; }
2261                // rows
2262                std::vector<int> data(r*b);
2263                for (int i = 0; i < r*b; i++) {
2264                    data[i] = rand() % m;
2265                }
2266                ggml_backend_tensor_set(t, data.data(), 0, r * b * sizeof(int));
2267            } else {
2268                init_tensor_uniform(t);
2269            }
2270        }
2271    }
2272};
2273
2274static void init_set_rows_row_ids(ggml_tensor * t, int num_rows) {
2275    std::random_device rd;
2276    std::default_random_engine rng(rd());
2277    for (int i2 = 0; i2 < t->ne[2]; i2++) {
2278        for (int i1 = 0; i1 < t->ne[1]; i1++) {
2279            // generate a shuffled subset of row indices
2280            std::vector<int64_t> data(num_rows);
2281            for (int i = 0; i < num_rows; i++) {
2282                data[i] = i;
2283            }
2284            std::shuffle(data.begin(), data.end(), rng);
2285            data.resize(t->ne[0]);
2286
2287            const size_t offs = i1*t->nb[1] + i2*t->nb[2];
2288            if (t->type == GGML_TYPE_I32) {
2289                // TODO: Make a template or something
2290                std::vector<int32_t> data_i32(t->ne[0]);
2291                for (int i = 0; i < t->ne[0]; i++) {
2292                    data_i32[i] = static_cast<int32_t>(data[i]);
2293                }
2294                ggml_backend_tensor_set(t, data_i32.data(), offs, t->ne[0]*sizeof(int32_t));
2295            } else {
2296                ggml_backend_tensor_set(t, data.data(), offs, t->ne[0]*sizeof(int64_t));
2297            }
2298        }
2299    }
2300}
2301
2302// GGML_OP_SET_ROWS
2303struct test_set_rows : public test_case {
2304    const ggml_type type;
2305    const ggml_type type_idx;
2306    const std::array<int64_t, 4> ne;
2307    const std::array<int, 2> nr23; // broadcast only dims 2 and 3
2308    const int r; // rows to set
2309    const bool v; // view (non-contiguous src1)
2310
2311    std::string vars() override {
2312        return VARS_TO_STR6(type, type_idx, ne, nr23, r, v);
2313    }
2314
2315    test_set_rows(ggml_type type,
2316            ggml_type type_idx,
2317            std::array<int64_t, 4> ne,
2318            std::array<int, 2> nr23,
2319            int r, bool v = false)
2320        : type(type), type_idx(type_idx), ne(ne), nr23(nr23), r(r), v(v) {}
2321
2322    ggml_tensor * build_graph(ggml_context * ctx) override {
2323        ggml_tensor * dst = ggml_new_tensor_4d(ctx, type,          ne[0], ne[1], ne[2]*nr23[0], ne[3]*nr23[1]);
2324        ggml_set_name(dst, "dst");
2325
2326        ggml_tensor * src = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, ne[0], r,     ne[2]*nr23[0], ne[3]*nr23[1]);
2327        ggml_set_name(src, "src");
2328
2329        ggml_tensor * row_idxs = ggml_new_tensor_3d(ctx, type_idx, r, ne[2], ne[3]);
2330        ggml_set_name(row_idxs, "row_idxs");
2331
2332        if (v) {
2333            src      = ggml_view_4d(ctx, src, ne[0], r/2, ne[2]*nr23[0], ne[3]*nr23[1], src->nb[1], src->nb[2], src->nb[3], 0);
2334            row_idxs = ggml_view_3d(ctx, row_idxs, r/2, ne[2], ne[3], row_idxs->nb[1], row_idxs->nb[2], 0);
2335            ggml_set_name(row_idxs, "view_of_rows");
2336        }
2337
2338        ggml_tensor * out = ggml_set_rows(ctx, dst, src, row_idxs);
2339        ggml_set_name(out, "out");
2340
2341        return out;
2342    }
2343
2344    void initialize_tensors(ggml_context * ctx) override {
2345        for (ggml_tensor * t = ggml_get_first_tensor(ctx); t != NULL; t = ggml_get_next_tensor(ctx, t)) {
2346            if (t->type == GGML_TYPE_I64 || t->type == GGML_TYPE_I32) {
2347                if (ggml_is_view_op(t->op)) {
2348                    continue;
2349                }
2350
2351                init_set_rows_row_ids(t, ne[1]);
2352            } else {
2353                init_tensor_uniform(t);
2354            }
2355        }
2356    }
2357
2358    double max_nmse_err() override {
2359        if (type == GGML_TYPE_Q4_0 || type == GGML_TYPE_Q4_1 || type == GGML_TYPE_IQ4_NL ||
2360            type == GGML_TYPE_Q5_0 || type == GGML_TYPE_Q5_1 || type == GGML_TYPE_Q8_0) {
2361            // estimate what the max nmse error would be if one quantized value is
2362            // off by one. The test values are distributed in [-1,1], so it'll be
2363            // roughly (2.0 / 2^bits)^2, divided by the mean square value of the reference,
2364            // which is roughly 0.25 times the number of elements.
2365            double err_estimate = 1.0f/8.0f;
2366            if (type == GGML_TYPE_Q5_0 || type == GGML_TYPE_Q5_1) {
2367                err_estimate /= 2.0f;
2368            }
2369            if (type == GGML_TYPE_Q8_0) {
2370                err_estimate /= 8.0f;
2371            }
2372            err_estimate *= err_estimate;
2373            err_estimate /= 0.25f*float(ne[0] * r * ne[2]*nr23[0] * ne[3]*nr23[1]);
2374            return err_estimate;
2375        }
2376        return 1e-7;
2377    }
2378};
2379
2380// GGML_OP_ROPE + GGML_OP_VIEW + GGML_OP_SET_ROWS
2381struct test_rope_set_rows : public test_case {
2382    const ggml_type type;
2383    const ggml_type type_idx;
2384    const std::array<int64_t, 4> ne_a;
2385    int mode;
2386    const int n_ctx{512};
2387    const int n_dims{128};
2388
2389    std::string vars() override {
2390        return VARS_TO_STR4(type, type_idx, ne_a, mode);
2391    }
2392
2393    std::string op_desc(ggml_tensor * t) override {
2394        GGML_UNUSED(t);
2395        return "ROPE_SET_ROWS";
2396    }
2397
2398    bool run_whole_graph() override { return true; }
2399
2400    test_rope_set_rows(ggml_type type,
2401            ggml_type type_idx,
2402            std::array<int64_t, 4> ne_a,
2403            int mode)
2404        : type(type), type_idx(type_idx), ne_a(ne_a), mode(mode) {}
2405
2406    ggml_tensor * build_graph(ggml_context * ctx) override {
2407        ggml_tensor * a = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, ne_a[0], ne_a[1], ne_a[2], 1);
2408        ggml_set_name(a, "a");
2409
2410        const bool is_mrope = mode & GGML_ROPE_TYPE_MROPE;
2411        const bool is_vision = mode == GGML_ROPE_TYPE_VISION;
2412
2413        ggml_tensor * pos;
2414        if (is_mrope || is_vision) {
2415            pos = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, ne_a[2] * 4);
2416        } else {
2417            pos = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, ne_a[2]);
2418        }
2419        ggml_set_name(pos, "pos");
2420
2421        float fs = 1.4245f;
2422        float ef = 0.7465f;
2423        float af = 1.4245f;
2424        ggml_tensor * freq = nullptr;
2425
2426        ggml_tensor * rope = nullptr;
2427        if (is_mrope) {
2428            if (is_vision) {
2429                GGML_ASSERT(n_dims/4 > 0);
2430                int rope_sections[4] = {n_dims/4, n_dims/4, 0, 0}; // Vision-RoPE only use first two dimension for image (x, y) coordinate
2431                rope = ggml_rope_multi(ctx, a, pos, freq, n_dims/2, rope_sections, mode, 0, 10000.0f, fs, ef, af, 1.0f, 1.0f);
2432            } else {
2433                GGML_ASSERT(n_dims/3 > 0);
2434                int rope_sections[4] = {n_dims/3, n_dims/3, n_dims/3, 0};
2435                rope = ggml_rope_multi(ctx, a, pos, freq, n_dims, rope_sections, mode, 0, 10000.0f, fs, ef, af, 1.0f, 1.0f);
2436            }
2437        } else {
2438            rope = ggml_rope(ctx, a, pos, ne_a[0], mode);
2439        }
2440
2441        ggml_tensor * view = ggml_view_2d(ctx, rope, ne_a[0] * ne_a[1], ne_a[2], rope->nb[2], 0);
2442
2443        ggml_tensor * dst = ggml_new_tensor_4d(ctx, type, ne_a[0] * ne_a[1], ne_a[2] * ne_a[3], 1, 1);
2444        ggml_set_name(dst, "dst");
2445
2446        ggml_tensor * row_idxs = ggml_new_tensor_3d(ctx, type_idx, ne_a[2], 1, 1);
2447        ggml_set_name(row_idxs, "row_idxs");
2448
2449        ggml_tensor * out = ggml_set_rows(ctx, dst, view, row_idxs);
2450        ggml_set_name(out, "out");
2451
2452        return out;
2453    }
2454
2455    void initialize_tensors(ggml_context * ctx) override {
2456        for (ggml_tensor * t = ggml_get_first_tensor(ctx); t != NULL; t = ggml_get_next_tensor(ctx, t)) {
2457            if (strcmp(t->name, "row_idxs") == 0) {
2458                if (ggml_is_view_op(t->op)) {
2459                    continue;
2460                }
2461                init_set_rows_row_ids(t, ne_a[2]);
2462            } else if (t->type == GGML_TYPE_I32) {
2463                // pos
2464                const int num_pos_ids = (mode & GGML_ROPE_TYPE_MROPE) ? ne_a[2] * 4 : ne_a[2];
2465                std::vector<int> data(num_pos_ids);
2466                for (int i = 0; i < num_pos_ids; i++) {
2467                    data[i] = rand() % n_ctx;
2468                }
2469                ggml_backend_tensor_set(t, data.data(), 0, num_pos_ids * sizeof(int));
2470            } else {
2471                if (t->ne[0] == n_dims/2) {
2472                    // frequency factors in the range [0.9f, 1.1f]
2473                    init_tensor_uniform(t, 0.9f, 1.1f);
2474                } else {
2475                    init_tensor_uniform(t);
2476                }
2477            }
2478        }
2479    }
2480};
2481
2482// GGML_OP_RMS_NORM + GGML_OP_MUL + GGML_OP_ROPE (+ GGML_OP_VIEW + GGML_OP_SET_ROWS)
2483struct test_rms_norm_mul_rope : public test_case {
2484    const std::array<int64_t, 4> ne;
2485    const float eps;
2486    const bool multi_add; // test a sequence of adds feeding into rms_norm
2487    const bool set_rows;
2488    int mode;
2489
2490    std::string op_desc(ggml_tensor * t) override {
2491        GGML_UNUSED(t);
2492        return "RMS_NORM_MUL_ROPE";
2493    }
2494
2495    bool run_whole_graph() override { return true; }
2496
2497    std::string vars() override {
2498        return VARS_TO_STR5(ne, eps, multi_add, set_rows, mode);
2499    }
2500
2501    test_rms_norm_mul_rope(std::array<int64_t, 4> ne, float eps = 1e-6f, bool multi_add = false,
2502                           bool set_rows = false, int mode = GGML_ROPE_TYPE_NORMAL)
2503        : ne(ne), eps(eps), multi_add(multi_add), set_rows(set_rows), mode(mode) {}
2504
2505    ggml_tensor * build_graph(ggml_context * ctx) override {
2506        ggml_tensor * a = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, ne[0], ne[1], ne[2], 1);
2507        ggml_tensor * b = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, ne[0], ne[1], ne[2], 1);
2508        ggml_tensor * c = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, ne[0], ne[1], ne[2], 1);
2509
2510        if (multi_add) {
2511            a = ggml_add(ctx, ggml_add(ctx, a, b), c);
2512        }
2513
2514        a = ggml_mul(ctx, ggml_rms_norm(ctx, a, eps), b);
2515
2516        ggml_tensor * pos = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, ne[2]);
2517
2518        ggml_tensor * rope = ggml_rope(ctx, a, pos, ne[0], mode);
2519
2520        ggml_tensor * out;
2521
2522        if (set_rows) {
2523            ggml_tensor * view = ggml_view_2d(ctx, rope, ne[0] * ne[1], ne[2], rope->nb[2], 0);
2524
2525            ggml_tensor * dst = ggml_new_tensor_4d(ctx, GGML_TYPE_F16, ne[0] * ne[1], ne[2] * ne[3], 1, 1);
2526            ggml_set_name(dst, "dst");
2527
2528            ggml_tensor * row_idxs = ggml_new_tensor_3d(ctx, GGML_TYPE_I64, ne[2], 1, 1);
2529            ggml_set_name(row_idxs, "row_idxs");
2530
2531            out = ggml_set_rows(ctx, dst, view, row_idxs);
2532            ggml_set_name(out, "out");
2533        } else {
2534            out = rope;
2535        }
2536
2537        return out;
2538    }
2539
2540    void initialize_tensors(ggml_context * ctx) override {
2541        for (ggml_tensor * t = ggml_get_first_tensor(ctx); t != NULL; t = ggml_get_next_tensor(ctx, t)) {
2542            if (t->type == GGML_TYPE_I64 || t->type == GGML_TYPE_I32) {
2543                if (ggml_is_view_op(t->op)) {
2544                    continue;
2545                }
2546
2547                init_set_rows_row_ids(t, ne[2]);
2548            } else {
2549                init_tensor_uniform(t);
2550            }
2551        }
2552    }
2553};
2554
2555// GGML_OP_ARGMAX
2556struct test_argmax : public test_case {
2557    const ggml_type type;
2558    const std::array<int64_t, 4> ne;
2559
2560    std::string vars() override {
2561        return VARS_TO_STR2(type, ne);
2562    }
2563
2564    test_argmax(ggml_type type = GGML_TYPE_F32,
2565            std::array<int64_t, 4> ne = {10, 100, 1, 1})
2566        : type(type), ne(ne) {}
2567
2568    ggml_tensor * build_graph(ggml_context * ctx) override {
2569        ggml_tensor * a = ggml_new_tensor(ctx, type, 4, ne.data());
2570        ggml_set_name(a, "a");
2571
2572        ggml_tensor * out = ggml_argmax(ctx, a);
2573        ggml_set_name(out, "out");
2574
2575        return out;
2576    }
2577
2578    void initialize_tensors(ggml_context * ctx) override {
2579        std::random_device rd;
2580        std::default_random_engine rng(rd());
2581        for (ggml_tensor * t = ggml_get_first_tensor(ctx); t != NULL; t = ggml_get_next_tensor(ctx, t)) {
2582            if (t->type == GGML_TYPE_F32) {
2583                // initialize with unique values to avoid ties
2584                for (int64_t r = 0; r < ggml_nrows(t); r++) {
2585                    std::vector<float> data(t->ne[0]);
2586                    for (int i = 0; i < t->ne[0]; i++) {
2587                        data[i] = i;
2588                    }
2589                    std::shuffle(data.begin(), data.end(), rng);
2590                    ggml_backend_tensor_set(t, data.data(), r * t->nb[1], t->ne[0] * sizeof(float));
2591                }
2592            } else {
2593                init_tensor_uniform(t);
2594            }
2595        }
2596    }
2597
2598    double max_nmse_err() override {
2599        return 0.0;
2600    }
2601};
2602
2603// GGML_OP_COUNT_EQUAL
2604struct test_count_equal : public test_case {
2605    const ggml_type type;
2606    const std::array<int64_t, 4> ne;
2607
2608    std::string vars() override {
2609        return VARS_TO_STR2(type, ne);
2610    }
2611
2612    test_count_equal(ggml_type type = GGML_TYPE_F32,
2613            std::array<int64_t, 4> ne = {4, 500, 1, 1})
2614        : type(type), ne(ne) {}
2615
2616    ggml_tensor * build_graph(ggml_context * ctx) override {
2617        ggml_tensor * a = ggml_new_tensor(ctx, type, 4, ne.data());
2618        ggml_set_name(a, "a");
2619
2620        ggml_tensor * a_argmax = ggml_argmax(ctx, a);
2621        ggml_set_name(a_argmax, "a_argmax");
2622
2623        ggml_tensor * b = ggml_new_tensor(ctx, type, 4, ne.data());
2624        ggml_set_name(b, "b");
2625
2626        ggml_tensor * b_argmax = ggml_argmax(ctx, b);
2627        ggml_set_name(b_argmax, "b_argmax");
2628
2629        ggml_tensor * out = ggml_count_equal(ctx, a_argmax, b_argmax);
2630        ggml_set_name(out, "out");
2631
2632        return out;
2633    }
2634
2635    double max_nmse_err() override {
2636        return 0.0;
2637    }
2638
2639    void initialize_tensors(ggml_context * ctx) override {
2640        std::random_device rd;
2641        std::default_random_engine rng(rd());
2642        for (ggml_tensor * t = ggml_get_first_tensor(ctx); t != NULL; t = ggml_get_next_tensor(ctx, t)) {
2643            if (t->type == GGML_TYPE_F32) {
2644                // initialize with unique values to avoid ties
2645                for (int64_t r = 0; r < ggml_nrows(t); r++) {
2646                    std::vector<float> data(t->ne[0]);
2647                    for (int i = 0; i < t->ne[0]; i++) {
2648                        data[i] = i;
2649                    }
2650                    std::shuffle(data.begin(), data.end(), rng);
2651                    ggml_backend_tensor_set(t, data.data(), r * t->nb[1], t->ne[0] * sizeof(float));
2652                }
2653            } else {
2654                init_tensor_uniform(t);
2655            }
2656        }
2657    }
2658};
2659
2660// GGML_OP_REPEAT
2661struct test_repeat : public test_case {
2662    const ggml_type type;
2663    const std::array<int64_t, 4> ne;
2664    const std::array<int, 4> nr;
2665
2666    std::string vars() override {
2667        return VARS_TO_STR3(type, ne, nr);
2668    }
2669
2670    size_t op_size(ggml_tensor * t) override {
2671        return ggml_nbytes(t) * 2;
2672    }
2673
2674    test_repeat(ggml_type type = GGML_TYPE_F32,
2675            std::array<int64_t, 4> ne = {10, 5, 4, 3},
2676            std::array<int, 4> nr = {2, 2, 2, 2})
2677        : type(type), ne(ne), nr(nr) {}
2678
2679    ggml_tensor * build_graph(ggml_context * ctx) override {
2680        ggml_tensor * target = ggml_new_tensor_4d(ctx, type, ne[0]*nr[0], ne[1]*nr[1], ne[2]*nr[2], ne[3]*nr[3]);
2681        ggml_set_name(target, "target");
2682
2683        ggml_tensor * src = ggml_new_tensor(ctx, type, 4, ne.data());
2684        ggml_set_param(src);
2685        ggml_set_name(src, "src");
2686
2687        ggml_tensor * out = ggml_repeat(ctx, src, target);
2688        ggml_set_name(out, "out");
2689
2690        return out;
2691    }
2692};
2693
2694// GGML_OP_REPEAT_BACK
2695struct test_repeat_back : public test_case {
2696    const ggml_type type;
2697    const std::array<int64_t, 4> ne;
2698    const std::array<int, 4> nr;
2699    const bool v; // whether src is a noncontiguous view
2700
2701    std::string vars() override {
2702        return VARS_TO_STR4(type, ne, nr, v);
2703    }
2704
2705    size_t op_size(ggml_tensor * t) override {
2706        return ggml_nbytes(t) * 2;
2707    }
2708
2709    test_repeat_back(ggml_type type = GGML_TYPE_F32,
2710            std::array<int64_t, 4> ne = {8, 6, 4, 2},
2711            std::array<int, 4> nr = {2, 2, 2, 2},
2712            bool v = false)
2713        : type(type), ne(ne), nr(nr), v(v) {}
2714
2715    ggml_tensor * build_graph(ggml_context * ctx) override {
2716        ggml_tensor * src = ggml_new_tensor_4d(ctx, type, ne[0]*nr[0], ne[1]*nr[1], ne[2]*nr[2], ne[3]*nr[3]);
2717        ggml_set_name(src, "src");
2718
2719        if (v) {
2720            GGML_ASSERT(ne[0] % 2 == 0);
2721            GGML_ASSERT(ne[1] % 2 == 0);
2722            GGML_ASSERT(ne[2] % 2 == 0);
2723            GGML_ASSERT(ne[3] % 2 == 0);
2724            GGML_ASSERT(nr[0] % 2 == 0 || nr[0] == 1);
2725            GGML_ASSERT(nr[1] % 2 == 0 || nr[1] == 1);
2726            GGML_ASSERT(nr[2] % 2 == 0 || nr[2] == 1);
2727            GGML_ASSERT(nr[3] % 2 == 0 || nr[3] == 1);
2728
2729            const int64_t ne00 = nr[0] == 1 ? src->ne[0] : src->ne[0] / 2;
2730            const int64_t ne01 = nr[1] == 1 ? src->ne[1] : src->ne[1] / 2;
2731            const int64_t ne02 = nr[2] == 1 ? src->ne[2] : src->ne[2] / 2;
2732            const int64_t ne03 = nr[3] == 1 ? src->ne[3] : src->ne[3] / 2;
2733
2734            src = ggml_view_4d(ctx, src, ne00, ne01, ne02, ne03, src->nb[1], src->nb[2], src->nb[3], 0);
2735        }
2736
2737        ggml_tensor * target = ggml_new_tensor(ctx, type, 4, ne.data());
2738        ggml_set_name(target, "target");
2739
2740        ggml_tensor * out = ggml_repeat_back(ctx, src, target);
2741        ggml_set_name(out, "out");
2742
2743        return out;
2744    }
2745};
2746
2747// GGML_OP_DUP
2748struct test_dup : public test_case {
2749    const ggml_type type;
2750    const std::array<int64_t, 4> ne;
2751    const std::array<int64_t, 4> permute;
2752    bool _use_permute;
2753
2754    std::string vars() override {
2755        std::string v = VARS_TO_STR2(type, ne);
2756        if (_use_permute) v += "," + VAR_TO_STR(permute);
2757        return v;
2758    }
2759
2760    test_dup(ggml_type type = GGML_TYPE_F32,
2761            std::array<int64_t, 4> ne = {10, 10, 20, 1},
2762            std::array<int64_t, 4> permute = {0, 0, 0, 0})
2763        : type(type), ne(ne), permute(permute),
2764            _use_permute(permute[0] + permute[1] + permute[2] + permute[3] > 0) {}
2765
2766    ggml_tensor * build_graph(ggml_context * ctx) override {
2767        ggml_tensor * src = ggml_new_tensor(ctx, type, 4, ne.data());
2768        ggml_set_param(src);
2769        ggml_set_name(src, "src");
2770
2771        if (_use_permute) {
2772            src = ggml_permute(ctx, src, permute[0], permute[1], permute[2], permute[3]);
2773            ggml_set_name(src, "src_permuted");
2774        }
2775
2776        ggml_tensor * out = ggml_dup(ctx, src);
2777        ggml_set_name(out, "out");
2778
2779        return out;
2780    }
2781};
2782
2783// GGML_OP_SET
2784struct test_set : public test_case {
2785    const ggml_type type_src;
2786    const ggml_type type_dst;
2787    const std::array<int64_t, 4> ne;
2788    const int dim;
2789
2790    std::string vars() override {
2791        return VARS_TO_STR4(type_src, type_dst, ne, dim);
2792    }
2793
2794    size_t op_size(ggml_tensor * t) override {
2795        return ggml_nbytes(t) + ggml_nbytes(t->src[0]);
2796    }
2797
2798    test_set(ggml_type type_src = GGML_TYPE_F32, ggml_type type_dst = GGML_TYPE_F32,
2799            std::array<int64_t, 4> ne = {6, 5, 4, 3}, int dim = 1)
2800        : type_src(type_src), type_dst(type_dst), ne(ne), dim(dim) {}
2801
2802    ggml_tensor * build_graph(ggml_context * ctx) override {
2803        ggml_tensor * src = ggml_new_tensor(ctx, type_src, 4, ne.data());
2804        ggml_set_param(src);
2805        ggml_set_name(src, "src");
2806
2807        auto ne_dst = ne;
2808        for (int i = 0; i < dim; ++i) {
2809            ne_dst[i] *= 2;
2810        }
2811        ggml_tensor* dst = ggml_new_tensor(ctx, type_dst, 4, ne_dst.data());
2812        ggml_set_param(dst);
2813        ggml_set_name(dst, "dst");
2814
2815        size_t offset = 0;
2816        for (int i = 0; i < dim; ++i) {
2817            offset += ((ne_dst[i] - ne[i])/2)*dst->nb[i];
2818        }
2819        ggml_tensor * out = ggml_set(ctx, dst, src,
2820            // The backward pass requires setting a contiguous region:
2821            src->nb[1], src->nb[2], src->nb[3], offset);
2822        ggml_set_name(out, "out");
2823
2824        return out;
2825    }
2826};
2827
2828// GGML_OP_CPY
2829struct test_cpy : public test_case {
2830    const ggml_type type_src;
2831    const ggml_type type_dst;
2832    const std::array<int64_t, 4> ne;
2833    const std::array<int64_t, 4> permute_src;
2834    const std::array<int64_t, 4> permute_dst;
2835    bool _src_use_permute;
2836    bool _dst_use_permute;
2837    bool _src_transpose;
2838
2839    std::string vars() override {
2840        return VARS_TO_STR6(type_src, type_dst, ne, permute_src, permute_dst, _src_transpose);
2841    }
2842
2843    double max_nmse_err() override {
2844        if (type_src == type_dst) {
2845            return 0.0;
2846        }
2847        if (type_dst == GGML_TYPE_Q4_0 || type_dst == GGML_TYPE_Q4_1 || type_dst == GGML_TYPE_IQ4_NL ||
2848            type_dst == GGML_TYPE_Q5_0 || type_dst == GGML_TYPE_Q5_1 || type_dst == GGML_TYPE_Q8_0) {
2849            // estimate what the max nmse error would be if one quantized value is
2850            // off by one. The test values are distributed in [-150,150], so it'll be
2851            // roughly (150*2.0 / 2^bits)^2, divided by the mean square value of the reference,
2852            // which is roughly 0.25*150^2 times the number of elements.
2853            double err_estimate = 1.0f/8.0f * 150.0f;
2854            if (type_dst == GGML_TYPE_IQ4_NL) {
2855                // iq4_nl values are a bit more spread out
2856                err_estimate *= 2.0f;
2857            }
2858            if (type_dst == GGML_TYPE_Q5_0 || type_dst == GGML_TYPE_Q5_1) {
2859                err_estimate /= 2.0f;
2860            }
2861            if (type_dst == GGML_TYPE_Q8_0) {
2862                err_estimate /= 8.0f;
2863            }
2864            err_estimate *= err_estimate;
2865            err_estimate /= (150.0f*150.0f*0.25f)*float(ne[0] * ne[1] * ne[2] * ne[3]);
2866            return err_estimate;
2867        }
2868        return 1e-6;
2869    }
2870
2871    size_t op_size(ggml_tensor * t) override {
2872        return ggml_nbytes(t) + ggml_nbytes(t->src[0]);
2873    }
2874
2875    test_cpy(ggml_type type_src = GGML_TYPE_F32, ggml_type type_dst = GGML_TYPE_F32,
2876            std::array<int64_t, 4> ne = {10, 10, 10, 1},
2877            std::array<int64_t, 4> permute_src = {0, 0, 0, 0},
2878            std::array<int64_t, 4> permute_dst = {0, 0, 0, 0},
2879            bool transpose_src = false)
2880        : type_src(type_src), type_dst(type_dst), ne(ne), permute_src(permute_src), permute_dst(permute_dst),
2881          _src_use_permute(permute_src[0] + permute_src[1] + permute_src[2] + permute_src[3] > 0),
2882          _dst_use_permute(permute_dst[0] + permute_dst[1] + permute_dst[2] + permute_dst[3] > 0),
2883          _src_transpose(transpose_src){}
2884
2885    ggml_tensor * build_graph(ggml_context * ctx) override {
2886        ggml_tensor * src = ggml_new_tensor(ctx, type_src, 4, ne.data());
2887        ggml_set_param(src);
2888        ggml_set_name(src, "src");
2889
2890        if (_src_use_permute) {
2891            src = ggml_permute(ctx, src, permute_src[0], permute_src[1], permute_src[2], permute_src[3]);
2892            ggml_set_name(src, "src_permuted");
2893        }
2894
2895        if (_src_transpose) {
2896            src = ggml_transpose(ctx, src);
2897            ggml_set_name(src, "src_transposed");
2898        }
2899
2900        ggml_tensor * dst = ggml_new_tensor(ctx, type_dst, 4, src->ne);
2901        ggml_set_name(dst, "dst");
2902
2903        if (_dst_use_permute) {
2904            dst = ggml_permute(ctx, dst, permute_dst[0], permute_dst[1], permute_dst[2], permute_dst[3]);
2905            ggml_set_name(dst, "dst_permuted");
2906        }
2907
2908        ggml_tensor * out = ggml_cpy(ctx, src, dst);
2909        ggml_set_name(out, "out");
2910
2911        return out;
2912    }
2913
2914    void initialize_tensors(ggml_context * ctx) override {
2915        for (ggml_tensor * t = ggml_get_first_tensor(ctx); t != NULL; t = ggml_get_next_tensor(ctx, t)) {
2916            // test extended range of values to check if casting between f32 and i32 is consistent
2917            init_tensor_uniform(t, -150.f, 150.f);
2918        }
2919    }
2920};
2921
2922// GGML_OP_CONT
2923struct test_cont : public test_case {
2924    const ggml_type type;
2925    const std::array<int64_t, 4> ne;
2926    bool use_view_slice;
2927
2928    std::string vars() override {
2929        return VARS_TO_STR3(type, ne, use_view_slice);
2930    }
2931
2932    test_cont(ggml_type type = GGML_TYPE_F32,
2933            std::array<int64_t, 4> ne = {10, 10, 10, 1},
2934            bool use_view_slice = false)
2935        : type(type), ne(ne), use_view_slice(use_view_slice) {}
2936
2937    ggml_tensor * build_graph(ggml_context * ctx) override {
2938        ggml_tensor * src = ggml_new_tensor(ctx, type, 4, ne.data());
2939        ggml_set_param(src);
2940        ggml_set_name(src, "src");
2941
2942
2943        ggml_tensor * dst;
2944        if (use_view_slice) {
2945            dst = ggml_view_4d(ctx, src, src->ne[0], 1, src->ne[2], src->ne[3],
2946                src->nb[1], src->nb[2], src->nb[3], src->nb[0] * (src->ne[1] - 1));
2947            ggml_set_name(dst, "src_view_slice");
2948        } else {
2949            dst = ggml_transpose(ctx, src);
2950            ggml_set_name(dst, "src_transposed");
2951        }
2952
2953        ggml_tensor * out = ggml_cont(ctx, dst);
2954        ggml_set_name(out, "out");
2955
2956        return out;
2957    }
2958};
2959
2960// GGML_OP_ADD
2961// GGML_OP_SUB
2962// GGML_OP_MUL
2963// GGML_OP_DIV
2964struct test_bin_bcast : public test_case {
2965    using op_t = ggml_tensor * (*) (ggml_context *, ggml_tensor *, ggml_tensor *);
2966    op_t op;
2967    const ggml_type type;
2968    const std::array<int64_t, 4> ne;
2969    const std::array<int, 4> nr;
2970    int nf; // number of fused ops, nf == 1 -> single op (no fusion)
2971    bool perm1; // permute src1?
2972
2973    bool run_whole_graph() override { return nf > 1; }
2974
2975    std::string vars() override {
2976        return VARS_TO_STR5(type, ne, nr, nf, perm1);
2977    }
2978
2979    size_t op_size(ggml_tensor * t) override {
2980        return ggml_nbytes(t) * 3;
2981    }
2982
2983    test_bin_bcast(op_t op, ggml_type type = GGML_TYPE_F32,
2984            std::array<int64_t, 4> ne = {10, 10, 1, 1},
2985            std::array<int, 4> nr = {1, 2, 1, 1},
2986            int nf = 1,
2987            bool perm1 = false)
2988        : op(op), type(type), ne(ne), nr(nr), nf(nf), perm1(perm1) {}
2989
2990    ggml_tensor * build_graph(ggml_context * ctx) override {
2991        GGML_ASSERT(nf <= 16);
2992
2993        ggml_tensor * a = ggml_new_tensor_4d(ctx, type, ne[0]*nr[0], ne[1]*nr[1], ne[2]*nr[2], ne[3]*nr[3]);
2994        ggml_set_name(a, "a");
2995
2996        ggml_tensor * b[16];
2997        for (int i = 0; i < nf; ++i) {
2998            if (perm1) {
2999                const int p[4] = { 1, 2, 0, 3 }; // hardcoded for now
3000
3001                b[i] = ggml_new_tensor_4d(ctx, type, ne[p[0]], ne[p[1]], ne[p[2]], ne[p[3]]);
3002                b[i] = ggml_permute(ctx, b[i], p[0], p[1], p[2], p[3]);
3003            } else {
3004                b[i] = ggml_new_tensor(ctx, type, 4, ne.data());
3005            }
3006            ggml_set_name(b[i], (std::string("b") + std::to_string(i)).c_str());
3007        }
3008
3009        // The backward pass supports broadcasting only for GGML_ADD:
3010        const bool grad_supported = op == ggml_add && ggml_are_same_shape(a, b[0]) && nf == 1 && !perm1;
3011        if (grad_supported) {
3012            ggml_set_param(a);
3013            ggml_set_param(b[0]);
3014        }
3015
3016        ggml_tensor * out = a;
3017
3018        for (int i = 0; i < nf; ++i) {
3019            out = op(ctx, out, b[i]);
3020        }
3021
3022        ggml_set_name(out, "out");
3023
3024        return out;
3025    }
3026
3027    void initialize_tensors(ggml_context * ctx) override {
3028        for (ggml_tensor * t = ggml_get_first_tensor(ctx); t != NULL; t = ggml_get_next_tensor(ctx, t)) {
3029            if (op == ggml_mul || op == ggml_div) {
3030                // MUL and DIV have numerical issues around zero:
3031                init_tensor_uniform(t, 0.9f, 1.1f);
3032            } else {
3033                init_tensor_uniform(t);
3034            }
3035        }
3036    }
3037
3038    float grad_eps() override {
3039        return 0.1f * (op == ggml_mul ? ne[0]*ne[1]*ne[2]*ne[3] : 1);
3040    }
3041
3042    bool grad_precise() override {
3043        return op == ggml_div;
3044    }
3045
3046    double max_maa_err() override {
3047        return op == ggml_add ? 1e-4 : 1e-3;
3048    }
3049};
3050
3051// GGML_OP_ADD_ID
3052struct test_add_id : public test_case {
3053    const ggml_type type_a;
3054    const ggml_type type_b;
3055    const int64_t n_embd;
3056    const int64_t n_experts;
3057    const int64_t n_experts_used;
3058    const int64_t n_token;
3059
3060    std::string vars() override {
3061        return VARS_TO_STR6(type_a, type_b, n_embd, n_experts, n_experts_used, n_token);
3062    }
3063
3064    size_t op_size(ggml_tensor * t) override {
3065        return ggml_nbytes(t) + ggml_nbytes(t->src[0]) + ggml_nbytes(t->src[2]);
3066    }
3067
3068    test_add_id(ggml_type type_a = GGML_TYPE_F32,
3069            ggml_type type_b = GGML_TYPE_F32,
3070            int64_t n_embd = 128,
3071            int64_t n_experts = 16,
3072            int64_t n_experts_used = 8,
3073            int64_t n_token = 10)
3074        : type_a(type_a), type_b(type_b), n_embd(n_embd),
3075          n_experts(n_experts), n_experts_used(n_experts_used), n_token(n_token) {}
3076
3077    ggml_tensor * build_graph(ggml_context * ctx) override {
3078        ggml_tensor * a = ggml_new_tensor_3d(ctx, type_a, n_embd, n_experts_used, n_token);
3079        ggml_tensor * b = ggml_new_tensor_2d(ctx, type_b, n_embd, n_experts);
3080        ggml_tensor * ids = ggml_new_tensor_2d(ctx, GGML_TYPE_I32, n_experts, n_token);
3081        if (n_experts_used != n_experts) {
3082            ids = ggml_view_2d(ctx, ids, n_experts_used, n_token, ids->nb[1], 0);
3083            ggml_set_name(ids, "view_of_ids");
3084        }
3085
3086        ggml_tensor * out = ggml_add_id(ctx, a, b, ids);
3087        ggml_set_name(out, "out");
3088        return out;
3089    }
3090
3091    void initialize_tensors(ggml_context * ctx) override {
3092        for (ggml_tensor * t = ggml_get_first_tensor(ctx); t != NULL; t = ggml_get_next_tensor(ctx, t)) {
3093            if (t->type == GGML_TYPE_I32) {
3094                if (ggml_is_view_op(t->op)) { continue; }
3095                std::random_device rd;
3096                std::default_random_engine rng(rd());
3097                // ids
3098                for (int64_t r = 0; r < ggml_nrows(t); r++) {
3099                    std::vector<int32_t> data(t->ne[0]);
3100                    for (int i = 0; i < t->ne[0]; i++) {
3101                        data[i] = i % n_experts;
3102                    }
3103                    std::shuffle(data.begin(), data.end(), rng);
3104                    ggml_backend_tensor_set(t, data.data(), r * t->nb[1], t->ne[0] * sizeof(int32_t));
3105                }
3106            } else {
3107                init_tensor_uniform(t);
3108            }
3109        }
3110    }
3111};
3112
3113// GGML_OP_ADD1
3114struct test_add1 : public test_case {
3115    const ggml_type type;
3116    const std::array<int64_t, 4> ne;
3117
3118    std::string vars() override {
3119        return VARS_TO_STR2(type, ne);
3120    }
3121
3122    test_add1(ggml_type type = GGML_TYPE_F32,
3123            std::array<int64_t, 4> ne = {10, 5, 4, 3})
3124        : type(type), ne(ne) {}
3125
3126    ggml_tensor * build_graph(ggml_context * ctx) override {
3127        ggml_tensor * a = ggml_new_tensor(ctx, type, 4, ne.data());
3128        ggml_set_param(a);
3129        ggml_set_name(a, "a");
3130
3131        ggml_tensor * b = ggml_new_tensor_1d(ctx, type, 1);
3132        // ggml_set_param(b); // TODO: implement
3133        ggml_set_name(b, "b");
3134
3135        ggml_tensor * out = ggml_add1(ctx, a, b);
3136        ggml_set_name(out, "out");
3137
3138        return out;
3139    }
3140
3141    float grad_eps() override {
3142        return 0.1f * ne[0]*ne[1]*ne[2]*ne[3];
3143    }
3144};
3145
3146// GGML_OP_SCALE
3147struct test_scale : public test_case {
3148    const ggml_type type;
3149    const std::array<int64_t, 4> ne;
3150    float scale;
3151    float bias;
3152    bool inplace;
3153
3154    std::string vars() override {
3155        return VARS_TO_STR5(type, ne, scale, bias, inplace);
3156    }
3157
3158    test_scale(ggml_type type = GGML_TYPE_F32,
3159            std::array<int64_t, 4> ne = {10, 10, 10, 10},
3160            float scale = 2.0f,
3161            float bias = 0.0f,
3162            bool inplace = false)
3163        : type(type), ne(ne), scale(scale), bias(bias), inplace(inplace) {}
3164
3165    ggml_tensor * build_graph(ggml_context * ctx) override {
3166        ggml_tensor * a = ggml_new_tensor(ctx, type, 4, ne.data());
3167        ggml_set_param(a);
3168        ggml_set_name(a, "a");
3169
3170        ggml_tensor * out;
3171        if (inplace) {
3172            out = ggml_scale_bias_inplace(ctx, a, scale, bias);
3173        } else {
3174            out = ggml_scale_bias(ctx, a, scale, bias);
3175        }
3176        ggml_set_name(out, "out");
3177
3178        return out;
3179    }
3180};
3181
3182// GGML_OP_SCALE + GGML_UNARY_OP_TANH + GGML_OP_SCALE
3183struct test_softcap : public test_case {
3184    const ggml_type type;
3185    const std::array<int64_t, 4> ne;
3186    float softcap;
3187
3188    std::string op_desc(ggml_tensor * t) override {
3189        GGML_UNUSED(t);
3190        return "SOFTCAP";
3191    }
3192
3193    bool run_whole_graph() override { return true; }
3194
3195    std::string vars() override {
3196        return VARS_TO_STR3(type, ne, softcap);
3197    }
3198
3199    test_softcap(ggml_type type = GGML_TYPE_F32,
3200            std::array<int64_t, 4> ne = {10, 10, 10, 10},
3201            float softcap = 30.0f)
3202        : type(type), ne(ne), softcap(softcap) {}
3203
3204    ggml_tensor * build_graph(ggml_context * ctx) override {
3205        ggml_tensor * a = ggml_new_tensor(ctx, type, 4, ne.data());
3206
3207        ggml_set_param(a);
3208        ggml_set_name(a, "a");
3209
3210        ggml_tensor * out = ggml_scale(ctx, ggml_tanh(ctx, ggml_scale(ctx, a, 1.0f / softcap)), softcap);
3211        ggml_set_name(out, "out");
3212
3213        return out;
3214    }
3215};
3216
3217// GGML_OP_SILU_BACK
3218struct test_silu_back : public test_case {
3219    const ggml_type type;
3220    const std::array<int64_t, 4> ne;
3221    float eps;
3222
3223    std::string vars() override {
3224        return VARS_TO_STR3(type, ne, eps);
3225    }
3226
3227    test_silu_back(ggml_type type = GGML_TYPE_F32,
3228            std::array<int64_t, 4> ne = {64, 5, 4, 3},
3229            float eps = 1e-6f)
3230        : type(type), ne(ne), eps(eps) {}
3231
3232    ggml_tensor * build_graph(ggml_context * ctx) override {
3233        ggml_tensor * a = ggml_new_tensor(ctx, type, 4, ne.data());
3234        ggml_set_name(a, "a");
3235
3236        ggml_tensor * grad = ggml_new_tensor(ctx, type, 4, ne.data());
3237        ggml_set_name(grad, "grad");
3238
3239        ggml_tensor * out = ggml_silu_back(ctx, a, grad);
3240        ggml_set_name(out, "out");
3241
3242        return out;
3243    }
3244
3245    bool grad_precise() override {
3246        return true;
3247    }
3248};
3249
3250// GGML_OP_NORM
3251struct test_norm : public test_case {
3252    const ggml_type type;
3253    const std::array<int64_t, 4> ne;
3254    const bool v; // whether a is a non-contiguous view
3255    const float eps;
3256
3257    std::string vars() override {
3258        return VARS_TO_STR4(type, ne, v, eps);
3259    }
3260
3261    test_norm(ggml_type type = GGML_TYPE_F32,
3262            std::array<int64_t, 4> ne = {64, 5, 4, 3},
3263            bool v = false,
3264            float eps = 1e-6f)
3265        : type(type), ne(ne), v(v), eps(eps) {}
3266
3267    ggml_tensor * build_graph(ggml_context * ctx) override {
3268        ggml_tensor * a = ggml_new_tensor(ctx, type, 4, ne.data());
3269        ggml_set_name(a, "a");
3270
3271        if (v) {
3272            a = ggml_view_4d(ctx, a, a->ne[0]/2, a->ne[1]/2, a->ne[2]/2, a->ne[3]/2, a->nb[1], a->nb[2], a->nb[3], 0);
3273            ggml_set_name(a, "view of a");
3274        }
3275
3276        ggml_tensor * out = ggml_norm(ctx, a, eps);
3277        ggml_set_name(out, "out");
3278
3279        return out;
3280    }
3281};
3282
3283// GGML_OP_NORM + GGML_OP_MUL + GGML_OP_ADD
3284struct test_norm_mul_add : public test_case {
3285    const ggml_type type;
3286    const std::array<int64_t, 4> ne;
3287    float eps;
3288    const bool broadcast;
3289
3290    std::string op_desc(ggml_tensor * t) override {
3291        GGML_UNUSED(t);
3292        return "NORM_MUL_ADD";
3293    }
3294
3295    bool run_whole_graph() override { return true; }
3296
3297    std::string vars() override {
3298        return VARS_TO_STR4(type, ne, eps, broadcast);
3299    }
3300
3301    test_norm_mul_add(ggml_type type = GGML_TYPE_F32,
3302            std::array<int64_t, 4> ne = {128, 2, 1, 1},
3303            float eps = 1e-5f,
3304            bool broadcast = false)
3305        : type(type), ne(ne), eps(eps), broadcast(broadcast) {}
3306
3307    ggml_tensor * build_graph(ggml_context * ctx) override {
3308        std::array<int64_t, 4> broadcast_dims = {ne[0], ne[1] * 2, ne[2] * 2, ne[3] * 2};
3309
3310        ggml_tensor * a = ggml_new_tensor(ctx, type, 4, broadcast ? broadcast_dims.data() : ne.data());
3311        ggml_tensor * w = ggml_new_tensor(ctx, type, 4, ne.data());
3312        ggml_tensor * b = ggml_new_tensor(ctx, type, 4, ne.data());
3313        ggml_set_param(a); ggml_set_param(w); ggml_set_param(b);
3314        ggml_set_name(a, "a"); ggml_set_name(w, "w"); ggml_set_name(b, "b");
3315
3316        // Use a, w and b early to avoid OP_NONE in graph
3317        a = ggml_add(ctx, ggml_add(ctx, a, w), b);
3318
3319        ggml_tensor * n = ggml_norm(ctx, a, eps);
3320        ggml_tensor * m = ggml_mul(ctx, n, w);
3321        ggml_tensor * out = ggml_add(ctx, m, b);
3322        ggml_set_name(out, "out");
3323        return out;
3324    }
3325};
3326// GGML_OP_RMS_NORM
3327struct test_rms_norm : public test_case {
3328    const ggml_type type;
3329    const std::array<int64_t, 4> ne;
3330    const bool v; // whether a is a non-contiguous view
3331    const float eps;
3332    const bool inplace; // whether to do the operation inplace
3333
3334    std::string vars() override {
3335        return VARS_TO_STR5(type, ne, v, eps, inplace);
3336    }
3337
3338    test_rms_norm(ggml_type type = GGML_TYPE_F32,
3339            std::array<int64_t, 4> ne = {64, 5, 4, 3},
3340            bool v = false,
3341            float eps = 1e-6f,
3342            bool inplace = false)
3343        : type(type), ne(ne), v(v), eps(eps), inplace(inplace) {}
3344
3345    ggml_tensor * build_graph(ggml_context * ctx) override {
3346        ggml_tensor * a = ggml_new_tensor(ctx, type, 4, ne.data());
3347        ggml_set_param(a);
3348        ggml_set_name(a, "a");
3349
3350        if (v) {
3351            a = ggml_view_4d(ctx, a, a->ne[0]/2, a->ne[1]/2, a->ne[2]/2, a->ne[3]/2, a->nb[1], a->nb[2], a->nb[3], 0);
3352            ggml_set_name(a, "view of a");
3353        }
3354
3355        ggml_tensor * out;
3356        if (inplace) {
3357            out = ggml_rms_norm_inplace(ctx, a, eps);
3358        } else {
3359            out = ggml_rms_norm(ctx, a, eps);
3360        }
3361        ggml_set_name(out, "out");
3362
3363        return out;
3364    }
3365
3366    void initialize_tensors(ggml_context * ctx) override {
3367        for (ggml_tensor * t = ggml_get_first_tensor(ctx); t != NULL; t = ggml_get_next_tensor(ctx, t)) {
3368            init_tensor_uniform(t, -10.f, 10.f);
3369        }
3370    }
3371
3372    float grad_eps() override {
3373        return 1.0f;
3374    }
3375
3376    bool grad_precise() override {
3377        return true;
3378    }
3379};
3380
3381// GGML_OP_RMS_NORM_BACK
3382struct test_rms_norm_back : public test_case {
3383    const ggml_type type;
3384    const std::array<int64_t, 4> ne;
3385    const float eps;
3386
3387    std::string vars() override {
3388        return VARS_TO_STR3(type, ne, eps);
3389    }
3390
3391    test_rms_norm_back(ggml_type type = GGML_TYPE_F32,
3392            std::array<int64_t, 4> ne = {64, 5, 4, 3},
3393            float eps = 1e-6f)
3394        : type(type), ne(ne), eps(eps) {}
3395
3396    ggml_tensor * build_graph(ggml_context * ctx) override {
3397        ggml_tensor * a = ggml_new_tensor(ctx, type, 4, ne.data());
3398        ggml_set_name(a, "a");
3399
3400        ggml_tensor * b = ggml_new_tensor(ctx, type, 4, ne.data());
3401        ggml_set_name(b, "b");
3402
3403        ggml_tensor * out = ggml_rms_norm_back(ctx, a, b, eps);
3404        ggml_set_name(out, "out");
3405
3406        return out;
3407    }
3408
3409    void initialize_tensors(ggml_context * ctx) override {
3410        for (ggml_tensor * t = ggml_get_first_tensor(ctx); t != NULL; t = ggml_get_next_tensor(ctx, t)) {
3411            init_tensor_uniform(t, -10.f, 10.f);
3412        }
3413    }
3414};
3415
3416// GGML_OP_RMS_NORM + GGML_OP_MUL + GGML_OP_ADD
3417struct test_rms_norm_mul_add : public test_case {
3418    const ggml_type type;
3419    const std::array<int64_t, 4> ne;
3420    const float eps;
3421    const bool broadcast;
3422    const bool multi_add; // test a sequence of adds feeding into rms_norm
3423
3424    std::string op_desc(ggml_tensor * t) override {
3425        GGML_UNUSED(t);
3426        return "RMS_NORM_MUL_ADD";
3427    }
3428
3429    bool run_whole_graph() override { return true; }
3430
3431    std::string vars() override {
3432        return VARS_TO_STR5(type, ne, eps, broadcast, multi_add);
3433    }
3434
3435    test_rms_norm_mul_add(ggml_type type = GGML_TYPE_F32,
3436            std::array<int64_t, 4> ne = {64, 5, 4, 3},
3437            float eps = 1e-6f, bool broadcast = false, bool multi_add = false)
3438        : type(type), ne(ne), eps(eps), broadcast(broadcast), multi_add(multi_add) {}
3439
3440    ggml_tensor * build_graph(ggml_context * ctx) override {
3441        std::array<int64_t, 4> broadcast_dims = {ne[0]*2, ne[1]*3, ne[2]*3, ne[3]*4};
3442
3443        ggml_tensor * a = ggml_new_tensor(ctx, type, 4, broadcast ? broadcast_dims.data() : ne.data());
3444        ggml_tensor * b = ggml_new_tensor(ctx, type, 4, ne.data());
3445        ggml_tensor * c = ggml_new_tensor(ctx, type, 4, ne.data());
3446
3447        ggml_set_param(a);
3448        ggml_set_name(a, "a");
3449        ggml_set_param(b);
3450        ggml_set_name(b, "b");
3451        ggml_set_param(c);
3452        ggml_set_name(c, "c");
3453
3454        // Use a, b and c early, so we don't end up with an OP_NONE between rms_norm and mul
3455        a = ggml_add(ctx, ggml_add(ctx, a, b), c);
3456        if (multi_add) {
3457            a = ggml_add(ctx, ggml_add(ctx, a, b), c);
3458        }
3459        ggml_tensor * out = ggml_add(ctx, ggml_mul(ctx, ggml_rms_norm(ctx, a, eps), b), c);
3460        ggml_set_name(out, "out");
3461
3462        return out;
3463    }
3464
3465    void initialize_tensors(ggml_context * ctx) override {
3466        for (ggml_tensor * t = ggml_get_first_tensor(ctx); t != NULL; t = ggml_get_next_tensor(ctx, t)) {
3467            init_tensor_uniform(t, -10.f, 10.f);
3468        }
3469    }
3470
3471    float grad_eps() override {
3472        return 1.0f;
3473    }
3474
3475    bool grad_precise() override {
3476        return true;
3477    }
3478};
3479
3480// GGML_OP_ADD + GGML_OP_RMS_NORM (fused operation)
3481struct test_add_rms_norm : public test_case {
3482    const ggml_type type;
3483    const std::array<int64_t, 4> ne;
3484    const float eps;
3485    const bool broadcast;
3486
3487    std::string op_desc(ggml_tensor * t) override {
3488        GGML_UNUSED(t);
3489        return "ADD_RMS_NORM";
3490    }
3491
3492    bool run_whole_graph() override { return true; }
3493
3494    std::string vars() override {
3495        return VARS_TO_STR4(type, ne, eps, broadcast);
3496    }
3497
3498    test_add_rms_norm(ggml_type type = GGML_TYPE_F32,
3499            std::array<int64_t, 4> ne = {64, 5, 4, 3},
3500            float eps = 1e-6f, bool broadcast = false)
3501        : type(type), ne(ne), eps(eps), broadcast(broadcast) {}
3502
3503    ggml_tensor * build_graph(ggml_context * ctx) override {
3504        std::array<int64_t, 4> broadcast_dims = {ne[0]*2, ne[1]*3, ne[2]*3, ne[3]*4};
3505
3506        ggml_tensor * a = ggml_new_tensor(ctx, type, 4, broadcast ? broadcast_dims.data() : ne.data());
3507        ggml_tensor * b = ggml_new_tensor(ctx, type, 4, ne.data());
3508
3509        ggml_set_param(a);
3510        ggml_set_name(a, "a");
3511        ggml_set_param(b);
3512        ggml_set_name(b, "b");
3513
3514        // ADD operation followed by RMS_NORM
3515        ggml_tensor * add_result = ggml_add(ctx, a, b);
3516        ggml_set_name(add_result, "add_result");
3517
3518        ggml_tensor * out = ggml_rms_norm(ctx, add_result, eps);
3519        ggml_set_name(out, "out");
3520
3521        return out;
3522    }
3523
3524    void initialize_tensors(ggml_context * ctx) override {
3525        for (ggml_tensor * t = ggml_get_first_tensor(ctx); t != NULL; t = ggml_get_next_tensor(ctx, t)) {
3526            init_tensor_uniform(t, -10.f, 10.f);
3527        }
3528    }
3529
3530    float grad_eps() override {
3531        return 1.0f;
3532    }
3533
3534    bool grad_precise() override {
3535        return true;
3536    }
3537};
3538
3539// GGML_OP_SSM_CONV
3540struct test_ssm_conv : public test_case {
3541    const ggml_type type;
3542    const std::array<int64_t, 4> ne_a;
3543    const std::array<int64_t, 4> ne_b;
3544
3545    std::string vars() override {
3546        return VARS_TO_STR3(type, ne_a, ne_b);
3547    }
3548
3549    test_ssm_conv(ggml_type type = GGML_TYPE_F32,
3550            std::array<int64_t, 4> ne_a = {10, 10, 10, 1},
3551            std::array<int64_t, 4> ne_b = {3, 3, 1, 1})
3552        : type(type), ne_a(ne_a), ne_b(ne_b) {}
3553
3554    ggml_tensor * build_graph(ggml_context * ctx) override {
3555        ggml_tensor * a   = ggml_new_tensor(ctx, type, 4, ne_a.data());
3556        ggml_tensor * b   = ggml_new_tensor(ctx, type, 4, ne_b.data());
3557        ggml_tensor * out = ggml_ssm_conv(ctx, a, b);
3558        return out;
3559    }
3560};
3561
3562// GGML_OP_SSM_SCAN
3563struct test_ssm_scan : public test_case {
3564    const ggml_type type;
3565
3566    const int64_t d_state;
3567    const int64_t head_dim;
3568    const int64_t n_head;
3569    const int64_t n_group;
3570    const int64_t n_seq_tokens;
3571    const int64_t n_seqs;
3572
3573    std::string vars() override {
3574        return VARS_TO_STR7(type, d_state, head_dim, n_head, n_group, n_seq_tokens, n_seqs);
3575    }
3576
3577    test_ssm_scan(ggml_type type = GGML_TYPE_F32,
3578            int64_t d_state = 32,
3579            int64_t head_dim = 1, // non-zero for Mamba-2
3580            int64_t n_head  = 32,
3581            int64_t n_group = 1,
3582            int64_t n_seq_tokens = 32,
3583            int64_t n_seqs = 32)
3584        : type(type), d_state(d_state), head_dim(head_dim), n_head(n_head), n_group(n_group), n_seq_tokens(n_seq_tokens), n_seqs(n_seqs) {}
3585
3586    ggml_tensor * build_graph(ggml_context * ctx) override {
3587        ggml_tensor * s   = ggml_new_tensor_4d(ctx, type, d_state,  head_dim,     n_head,       n_seqs);
3588        ggml_tensor * x   = ggml_new_tensor_4d(ctx, type, head_dim, n_head,       n_seq_tokens, n_seqs);
3589        ggml_tensor * dt  = ggml_new_tensor_3d(ctx, type, n_head,   n_seq_tokens, n_seqs);
3590        ggml_tensor * A   = ggml_new_tensor_2d(ctx, type, (head_dim > 1) ? 1 : d_state, n_head);
3591        ggml_tensor * B   = ggml_new_tensor_4d(ctx, type, d_state,  n_group,      n_seq_tokens, n_seqs);
3592        ggml_tensor * C   = ggml_new_tensor_4d(ctx, type, d_state,  n_group,      n_seq_tokens, n_seqs);
3593        ggml_tensor * ids = ggml_new_tensor_1d(ctx, GGML_TYPE_I32,  n_seqs);
3594        ggml_tensor * out = ggml_ssm_scan(ctx, s, x, dt, A, B, C, ids);
3595        return out;
3596    }
3597
3598    // similar to test_mul_mat_id
3599    void initialize_tensors(ggml_context * ctx) override {
3600        std::random_device rd;
3601        std::default_random_engine rng(rd());
3602        for (ggml_tensor * t = ggml_get_first_tensor(ctx); t != NULL; t = ggml_get_next_tensor(ctx, t)) {
3603            if (t->type == GGML_TYPE_I32) {
3604                if (ggml_is_view_op(t->op)) { continue; }
3605                // ids
3606                for (int64_t r = 0; r < ggml_nrows(t); r++) {
3607                    std::vector<int32_t> data(t->ne[0]);
3608                    for (int i = 0; i < t->ne[0]; i++) {
3609                        data[i] = i;
3610                    }
3611                    std::shuffle(data.begin(), data.end(), rng);
3612                    ggml_backend_tensor_set(t, data.data(), r * t->nb[1], t->ne[0] * sizeof(int32_t));
3613                }
3614            } else {
3615                init_tensor_uniform(t);
3616            }
3617        }
3618    }
3619};
3620
3621// GGML_OP_RWKV_WKV6
3622struct test_rwkv_wkv6 : public test_case {
3623    const ggml_type type;
3624
3625    const int64_t head_count;
3626    const int64_t head_size;
3627    const int64_t n_seq_tokens;
3628    const int64_t n_seqs;
3629
3630    std::string vars() override {
3631        return VARS_TO_STR5(type, head_count, head_size, n_seq_tokens, n_seqs);
3632    }
3633
3634    test_rwkv_wkv6(ggml_type type = GGML_TYPE_F32,
3635            int64_t head_count = 32, int64_t head_size = 64, int64_t n_seq_tokens = 32, int64_t n_seqs = 32)
3636        : type(type), head_count(head_count), head_size(head_size), n_seq_tokens(n_seq_tokens), n_seqs(n_seqs) {}
3637
3638    ggml_tensor * build_graph(ggml_context * ctx) override {
3639        const int64_t n_tokens = n_seq_tokens * n_seqs;
3640        ggml_tensor * r   = ggml_new_tensor(ctx, type, 3, std::vector<int64_t>{ head_size, head_count, n_tokens }.data());
3641        ggml_tensor * k   = ggml_new_tensor(ctx, type, 3, std::vector<int64_t>{ head_size, head_count, n_tokens }.data());
3642        ggml_tensor * v   = ggml_new_tensor(ctx, type, 3, std::vector<int64_t>{ head_size, head_count, n_tokens }.data());
3643        ggml_tensor * tf  = ggml_new_tensor(ctx, type, 2, std::vector<int64_t>{ head_size, head_count }.data());
3644        ggml_tensor * td  = ggml_new_tensor(ctx, type, 3, std::vector<int64_t>{ head_size, head_count, n_tokens }.data());
3645        ggml_tensor * s   = ggml_new_tensor(ctx, type, 2, std::vector<int64_t>{ head_size * head_size * head_count, n_seqs }.data());
3646        ggml_tensor * out = ggml_rwkv_wkv6(ctx, k, v, r, tf, td, s);
3647        return out;
3648    }
3649};
3650
3651// GGML_OP_GATED_LINEAR_ATTN
3652struct test_gla : public test_case {
3653    const ggml_type type;
3654
3655    const int64_t head_count;
3656    const int64_t head_size;
3657    const int64_t n_seq_tokens;
3658    const int64_t n_seqs;
3659
3660    std::string vars() override {
3661        return VARS_TO_STR5(type, head_count, head_size, n_seq_tokens, n_seqs);
3662    }
3663
3664    test_gla(ggml_type type = GGML_TYPE_F32,
3665            int64_t head_count = 32, int64_t head_size = 64, int64_t n_seq_tokens = 32, int64_t n_seqs = 32)
3666        : type(type), head_count(head_count), head_size(head_size), n_seq_tokens(n_seq_tokens), n_seqs(n_seqs) {}
3667
3668    ggml_tensor * build_graph(ggml_context * ctx) override {
3669        const int64_t n_tokens = n_seq_tokens * n_seqs;
3670        ggml_tensor * q   = ggml_new_tensor(ctx, type, 3, std::vector<int64_t>{ head_size, head_count, n_tokens }.data());
3671        ggml_tensor * k   = ggml_new_tensor(ctx, type, 3, std::vector<int64_t>{ head_size, head_count, n_tokens }.data());
3672        ggml_tensor * v   = ggml_new_tensor(ctx, type, 3, std::vector<int64_t>{ head_size, head_count, n_tokens }.data());
3673        ggml_tensor * g   = ggml_new_tensor(ctx, type, 3, std::vector<int64_t>{ head_size, head_count, n_tokens }.data());
3674        ggml_tensor * s   = ggml_new_tensor(ctx, type, 2, std::vector<int64_t>{ head_size * head_size * head_count, n_seqs }.data());
3675        ggml_tensor * out = ggml_gated_linear_attn(ctx, k, v, q, g, s, pow(head_size, -0.5));
3676        return out;
3677    }
3678};
3679
3680// GGML_OP_RWKV_WKV7
3681struct test_rwkv_wkv7 : public test_case {
3682    const ggml_type type;
3683
3684    const int64_t head_count;
3685    const int64_t head_size;
3686    const int64_t n_seq_tokens;
3687    const int64_t n_seqs;
3688
3689    std::string vars() override {
3690        return VARS_TO_STR5(type, head_count, head_size, n_seq_tokens, n_seqs);
3691    }
3692
3693    test_rwkv_wkv7(ggml_type type = GGML_TYPE_F32,
3694            int64_t head_count = 32, int64_t head_size = 64, int64_t n_seq_tokens = 32, int64_t n_seqs = 32)
3695        : type(type), head_count(head_count), head_size(head_size), n_seq_tokens(n_seq_tokens), n_seqs(n_seqs) {}
3696
3697    ggml_tensor * build_graph(ggml_context * ctx) override {
3698        const int64_t n_tokens = n_seq_tokens * n_seqs;
3699        ggml_tensor * r   = ggml_new_tensor(ctx, type, 3, std::vector<int64_t>{ head_size, head_count, n_tokens }.data());
3700        ggml_tensor * w   = ggml_new_tensor(ctx, type, 3, std::vector<int64_t>{ head_size, head_count, n_tokens }.data());
3701        ggml_tensor * k   = ggml_new_tensor(ctx, type, 3, std::vector<int64_t>{ head_size, head_count, n_tokens }.data());
3702        ggml_tensor * v   = ggml_new_tensor(ctx, type, 3, std::vector<int64_t>{ head_size, head_count, n_tokens }.data());
3703        ggml_tensor * a   = ggml_new_tensor(ctx, type, 3, std::vector<int64_t>{ head_size, head_count, n_tokens }.data());
3704        ggml_tensor * b   = ggml_new_tensor(ctx, type, 3, std::vector<int64_t>{ head_size, head_count, n_tokens }.data());
3705        // Outputs may become NaN with long seqlen without these normalization
3706        a = ggml_l2_norm(ctx, a, 1e-7F);
3707        b = ggml_l2_norm(ctx, b, 1e-7F);
3708        ggml_tensor * s   = ggml_new_tensor(ctx, type, 2, std::vector<int64_t>{ head_size * head_size * head_count, n_seqs }.data());
3709        ggml_tensor * out = ggml_rwkv_wkv7(ctx, r, w, k, v, a, b, s);
3710        return out;
3711    }
3712};
3713
3714// GGML_OP_MUL_MAT
3715struct test_mul_mat : public test_case {
3716    const ggml_type type_a;
3717    const ggml_type type_b;
3718    const int64_t m;
3719    const int64_t n;
3720    const int64_t k;
3721    const std::array<int64_t, 2> bs;  // dims 3 and 4
3722    const std::array<int64_t, 2> nr;  // repeat in dims 3 and 4
3723    const std::array<int64_t, 4> per; // permutation of dimensions
3724    const int64_t k_v; // size of k in memory, resulting in a non-contiguous view for k_v > k, no view for k_v == 0
3725    const uint32_t o; // number of outputs
3726
3727    std::string vars() override {
3728        return VARS_TO_STR10(type_a, type_b, m, n, k, bs, nr, per, k_v, o);
3729    }
3730
3731    double max_nmse_err() override {
3732        return 5e-4;
3733    }
3734
3735    double max_nmse_err(ggml_backend_t backend) override {
3736        // for blackwell we quantize activations to mxfp4 instead of q8_1 so we add higher tolerance
3737        if (type_a == GGML_TYPE_MXFP4 && backend_has_feature(backend, "BLACKWELL_NATIVE_FP4")) {
3738            return 2e-2;
3739        }
3740        return max_nmse_err();
3741    }
3742
3743    int64_t grad_nmax() override {
3744        return 20000;
3745    }
3746
3747    uint64_t op_flops(ggml_tensor * t) override {
3748        GGML_UNUSED(t);
3749        return 2 * m * n * k * bs[0] * nr[0] * bs[1] * nr[1];
3750    }
3751
3752    test_mul_mat(ggml_type type_a = GGML_TYPE_F32, ggml_type type_b = GGML_TYPE_F32,
3753            int64_t m = 32, int64_t n = 32, int64_t k = 32,
3754            std::array<int64_t, 2> bs = {10, 10},
3755            std::array<int64_t, 2> nr = {2, 2},
3756            std::array<int64_t, 4> per = {0, 1, 2, 3},
3757            int64_t k_v = 0, uint32_t o = 1)
3758        : type_a(type_a), type_b(type_b), m(m), n(n), k(k), bs(bs), nr(nr), per(per), k_v(k_v), o(o) {}
3759
3760    ggml_tensor * build_graph(ggml_context * ctx) override {
3761        // C^T = A * B^T: (k, m) * (k, n) => (m, n)
3762        ggml_tensor * a;
3763        ggml_tensor * b;
3764
3765        const int npermuted = (per[0] != 0) + (per[1] != 1) + (per[2] != 2) + (per[3] != 3);
3766        if (npermuted > 0) {
3767            GGML_ASSERT(npermuted == 2);
3768            GGML_ASSERT(k_v == 0); // not handled
3769            GGML_ASSERT(!ggml_is_quantized(type_a) || per[0] == 0);
3770            GGML_ASSERT(!ggml_is_quantized(type_b) || per[0] == 0);
3771
3772            // Create tensors with the permuted dimensions, then permute them back to the dimensions given by m,n,k.
3773            const int64_t ne_a[4] = {k, m, bs[0],       bs[1]};
3774            const int64_t ne_b[4] = {k, n, bs[0]*nr[0], bs[1]*nr[1]};
3775
3776            a = ggml_new_tensor_4d(ctx, type_a, ne_a[per[0]], ne_a[per[1]], ne_a[per[2]], ne_a[per[3]]);
3777            b = ggml_new_tensor_4d(ctx, type_b, ne_b[per[0]], ne_b[per[1]], ne_b[per[2]], ne_b[per[3]]);
3778            if (!ggml_is_quantized(type_a)) {
3779                if (bs[1] == 1 && nr[1] == 1) {
3780                    ggml_set_param(a);
3781                }
3782                ggml_set_param(b);
3783            }
3784            ggml_set_name(a, "a");
3785            ggml_set_name(b, "b");
3786
3787            a = ggml_permute(ctx, a, per[0], per[1], per[2], per[3]);
3788            b = ggml_permute(ctx, b, per[0], per[1], per[2], per[3]);
3789            ggml_set_name(a, "a_permuted");
3790            ggml_set_name(b, "b_permuted");
3791        } else {
3792            const int64_t k_physical = k_v == 0 ? k : k_v;
3793            a = ggml_new_tensor_4d(ctx, type_a, k_physical, m, bs[0],       bs[1]);
3794            b = ggml_new_tensor_4d(ctx, type_b, k_physical, n, bs[0]*nr[0], bs[1]*nr[1]);
3795
3796            if (!ggml_is_quantized(type_a)) {
3797                if (bs[1] == 1 && nr[1] == 1) {
3798                    ggml_set_param(a);
3799                }
3800                ggml_set_param(b);
3801            }
3802
3803            if (k_v != 0) {
3804                GGML_ASSERT(k_v > k);
3805                a = ggml_view_4d(ctx, a, k, m, bs[0],       bs[1],       a->nb[1], a->nb[2], a->nb[3], 0);
3806                b = ggml_view_4d(ctx, b, k, n, bs[0]*nr[0], bs[1]*nr[1], b->nb[1], b->nb[2], b->nb[3], 0);
3807            }
3808            ggml_set_name(a, "a");
3809            ggml_set_name(b, "b");
3810        }
3811
3812        ggml_tensor * out = ggml_mul_mat(ctx, a, b);
3813        ggml_set_name(out, "out");
3814        for (uint32_t i = 1; i < o; ++i) {
3815            ggml_tensor * out2 = ggml_mul_mat(ctx, a, b);
3816            ggml_set_name(out2, "out2");
3817            out = ggml_add(ctx, out, out2);
3818        }
3819
3820        return out;
3821    }
3822
3823    bool run_whole_graph() override { return o > 1; }
3824
3825    std::string op_desc(ggml_tensor * t) override {
3826        GGML_UNUSED(t);
3827        return ggml_op_name(GGML_OP_MUL_MAT);
3828    }
3829};
3830
3831static void init_mul_mat_id_tensors(ggml_context * ctx, int n_mats) {
3832    std::random_device rd;
3833    std::default_random_engine rng(rd());
3834    for (ggml_tensor * t = ggml_get_first_tensor(ctx); t != NULL; t = ggml_get_next_tensor(ctx, t)) {
3835        if (t->type == GGML_TYPE_I32) {
3836            if (ggml_is_view_op(t->op)) { continue; }
3837            // ids
3838            for (int64_t r = 0; r < ggml_nrows(t); r++) {
3839                std::vector<int32_t> data(t->ne[0]);
3840                for (int i = 0; i < t->ne[0]; i++) {
3841                    data[i] = i % n_mats;
3842                }
3843                std::shuffle(data.begin(), data.end(), rng);
3844                ggml_backend_tensor_set(t, data.data(), r * t->nb[1], t->ne[0] * sizeof(int32_t));
3845            }
3846        } else {
3847            init_tensor_uniform(t);
3848        }
3849    }
3850}
3851
3852// GGML_OP_MUL_MAT_ID
3853struct test_mul_mat_id : public test_case {
3854    const ggml_type type_a;
3855    const ggml_type type_b;
3856    const int n_mats;
3857    const int n_used;
3858    const bool b; // broadcast b matrix
3859    const int64_t m;
3860    const int64_t n;
3861    const int64_t k;
3862
3863    std::string vars() override {
3864        return VARS_TO_STR8(type_a, type_b, n_mats, n_used, b, m, n, k);
3865    }
3866
3867    double max_nmse_err() override {
3868        return 5e-4;
3869    }
3870
3871    double max_nmse_err(ggml_backend_t backend) override {
3872        // for blackwell we quantize activations to mxfp4 instead of q8_1 so we add higher tolerance
3873        if (type_a == GGML_TYPE_MXFP4 && backend_has_feature(backend, "BLACKWELL_NATIVE_FP4")) {
3874            return 2e-2;
3875        }
3876        return max_nmse_err();
3877    }
3878
3879    uint64_t op_flops(ggml_tensor * t) override {
3880        GGML_UNUSED(t);
3881        return 2 * m * k * n * n_used;
3882    }
3883
3884    test_mul_mat_id(ggml_type type_a = GGML_TYPE_F32, ggml_type type_b = GGML_TYPE_F32,
3885            int n_mats = 8, int n_used = 2, bool b = false,
3886            int64_t m = 32, int64_t n = 32, int64_t k = 32)
3887        : type_a(type_a), type_b(type_b), n_mats(n_mats), n_used(n_used), b(b),
3888            m(m), n(n), k(k) {
3889            GGML_ASSERT(n_used <= n_mats);
3890        }
3891
3892    ggml_tensor * build_graph(ggml_context * ctx) override {
3893        // C^T = A * B^T: (k, m) * (k, n) => (m, n)
3894        ggml_tensor * as = ggml_new_tensor_3d(ctx, type_a, k, m, n_mats);
3895        ggml_set_name(as, "as");
3896
3897        ggml_tensor * ids = ggml_new_tensor_2d(ctx, GGML_TYPE_I32, n_mats, n);
3898        ggml_set_name(ids, "ids");
3899        if (n_used != n_mats) {
3900            ids = ggml_view_2d(ctx, ids, n_used, n, ids->nb[1], 0);
3901            ggml_set_name(ids, "view_of_ids");
3902        }
3903
3904        ggml_tensor * b = ggml_new_tensor_3d(ctx, type_b, k, this->b ? 1 : n_used, n);
3905        ggml_set_name(b, "b");
3906
3907        ggml_tensor * out = ggml_mul_mat_id(ctx, as, b, ids);
3908        ggml_set_name(out, "out");
3909
3910        return out;
3911    }
3912
3913    void initialize_tensors(ggml_context * ctx) override {
3914        init_mul_mat_id_tensors(ctx, n_mats);
3915    }
3916};
3917
3918// GGML_OP_MUL_MAT_ID + GGML_OP_ADD or GGML_OP_MUL
3919struct test_mul_mat_id_fusion : public test_case {
3920    const ggml_type type_a;
3921    const ggml_type type_b;
3922    const int n_mats;
3923    const int n_used;
3924    const bool b; // broadcast b matrix
3925    const int64_t m;
3926    const int64_t n;
3927    const int64_t k;
3928    const uint32_t o; // number of outputs
3929    const bool mul;
3930
3931    std::string vars() override {
3932        return VARS_TO_STR10(type_a, type_b, n_mats, n_used, b, m, n, k, o, mul);
3933    }
3934
3935    double max_nmse_err() override {
3936        return 5e-4;
3937    }
3938
3939    uint64_t op_flops(ggml_tensor * t) override {
3940        GGML_UNUSED(t);
3941        return 2 * m * k * n * n_used;
3942    }
3943
3944    test_mul_mat_id_fusion(ggml_type type_a = GGML_TYPE_F32, ggml_type type_b = GGML_TYPE_F32,
3945            int n_mats = 8, int n_used = 2, bool b = false,
3946            int64_t m = 32, int64_t n = 32, int64_t k = 32, uint32_t o = 1, bool mul = false)
3947        : type_a(type_a), type_b(type_b), n_mats(n_mats), n_used(n_used), b(b),
3948            m(m), n(n), k(k), o(o), mul(mul) {
3949            GGML_ASSERT(n_used <= n_mats);
3950        }
3951
3952    ggml_tensor * build_graph(ggml_context * ctx) override {
3953        // C^T = A * B^T: (k, m) * (k, n) => (m, n)
3954        ggml_tensor * as = ggml_new_tensor_3d(ctx, type_a, k, m, n_mats);
3955        ggml_set_name(as, "as");
3956
3957        ggml_tensor * ids = ggml_new_tensor_2d(ctx, GGML_TYPE_I32, n_mats, n);
3958        ggml_set_name(ids, "ids");
3959        if (n_used != n_mats) {
3960            ids = ggml_view_2d(ctx, ids, n_used, n, ids->nb[1], 0);
3961            ggml_set_name(ids, "view_of_ids");
3962        }
3963
3964        ggml_tensor * b = ggml_new_tensor_3d(ctx, type_b, k, this->b ? 1 : n_used, n);
3965        ggml_set_name(b, "b");
3966
3967        ggml_tensor * out = ggml_mul_mat_id(ctx, as, b, ids);
3968        ggml_set_name(out, "out");
3969
3970        for (uint32_t i = 1; i < o; ++i) {
3971            ggml_tensor * a2 = ggml_new_tensor_3d(ctx, type_a, k, m, n_mats);
3972            ggml_tensor * out2 = ggml_mul_mat_id(ctx, a2, b, ids);
3973            ggml_set_name(out2, "out2");
3974            out = ggml_add(ctx, out, out2);
3975        }
3976
3977        if (mul) {
3978            std::array<int64_t, 4> ne { 1, out->ne[1], out->ne[2], out->ne[3] };
3979            ne[0] = 1;
3980            ggml_tensor * m = ggml_new_tensor(ctx, out->type, 4, ne.data());
3981            out = ggml_mul(ctx, out, m);
3982        }
3983
3984        return out;
3985    }
3986
3987    void initialize_tensors(ggml_context * ctx) override {
3988        init_mul_mat_id_tensors(ctx, n_mats);
3989    }
3990
3991    bool run_whole_graph() override { return true; }
3992
3993    std::string op_desc(ggml_tensor * t) override {
3994        GGML_UNUSED(t);
3995        return "MUL_MAT_ID_FUSION";
3996    }
3997};
3998
3999// GGML_OP_OUT_PROD
4000struct test_out_prod : public test_case {
4001    const ggml_type type_a;
4002    const ggml_type type_b;
4003    const int64_t m;
4004    const int64_t n;
4005    const int64_t k;
4006    const std::array<int64_t, 2> bs; // dims 3 and 4
4007    const std::array<int64_t, 2> nr; // repeat in dims 3 and 4
4008    const bool trans_b;
4009
4010    std::string vars() override {
4011        return VARS_TO_STR8(type_a, type_b, m, n, k, bs, nr, trans_b);
4012    }
4013
4014    double max_nmse_err() override {
4015        return 5e-4;
4016    }
4017
4018    test_out_prod(ggml_type type_a = GGML_TYPE_F32, ggml_type type_b = GGML_TYPE_F32,
4019            int64_t m = 32, int64_t n = 32, int64_t k = 32,
4020            std::array<int64_t, 2> bs = {10, 10},
4021            std::array<int64_t, 2> nr = {2, 2},
4022            bool trans_b = false)
4023        : type_a(type_a), type_b(type_b), m(m), n(n), k(k), bs(bs), nr(nr), trans_b(trans_b) {}
4024
4025    ggml_tensor * build_graph(ggml_context * ctx) override {
4026        ggml_tensor * a = ggml_new_tensor_4d(ctx, type_a, m, k, bs[0], bs[1]);
4027        ggml_set_name(a, "a");
4028
4029        ggml_tensor * b;
4030        if (trans_b) {
4031            b = ggml_new_tensor_4d(ctx, type_b, k, n, bs[0]*nr[0], bs[1]*nr[1]);
4032            b = ggml_transpose(ctx, b);
4033        } else {
4034            b = ggml_new_tensor_4d(ctx, type_b, n, k, bs[0]*nr[0], bs[1]*nr[1]);
4035        }
4036        ggml_set_name(b, "b");
4037
4038        ggml_tensor * out = ggml_out_prod(ctx, a, b);
4039        ggml_set_name(out, "out");
4040
4041        return out;
4042    }
4043};
4044
4045// GGML_OP_SQR
4046struct test_sqr : public test_case {
4047    const ggml_type type;
4048    const std::array<int64_t, 4> ne;
4049
4050    std::string vars() override {
4051        return VARS_TO_STR2(type, ne);
4052    }
4053
4054    test_sqr(ggml_type type = GGML_TYPE_F32,
4055            std::array<int64_t, 4> ne = {10, 5, 4, 3})
4056        : type(type), ne(ne) {}
4057
4058    ggml_tensor * build_graph(ggml_context * ctx) override {
4059        ggml_tensor * a = ggml_new_tensor(ctx, type, 4, ne.data());
4060        ggml_set_param(a);
4061        ggml_set_name(a, "a");
4062
4063        ggml_tensor * out = ggml_sqr(ctx, a);
4064        ggml_set_name(out, "out");
4065
4066        return out;
4067    }
4068
4069    float grad_eps() override {
4070        return 0.1f * 0.25f*ne[0]*ne[1]*ne[2]*ne[3]; // 10% of expected value of sum.
4071    }
4072};
4073
4074// GGML_OP_SQRT
4075struct test_sqrt : public test_case {
4076    const ggml_type type;
4077    const std::array<int64_t, 4> ne;
4078
4079    std::string vars() override {
4080        return VARS_TO_STR2(type, ne);
4081    }
4082
4083    test_sqrt(ggml_type type = GGML_TYPE_F32,
4084            std::array<int64_t, 4> ne = {10, 3, 3, 2})
4085        : type(type), ne(ne) {}
4086
4087    ggml_tensor * build_graph(ggml_context * ctx) override {
4088        ggml_tensor * a = ggml_new_tensor(ctx, type, 4, ne.data());
4089        ggml_set_param(a);
4090        ggml_set_name(a, "a");
4091
4092        ggml_tensor * out = ggml_sqrt(ctx, a);
4093        ggml_set_name(out, "out");
4094
4095        return out;
4096    }
4097
4098    void initialize_tensors(ggml_context * ctx) override {
4099        // fill with positive values
4100        for (ggml_tensor * t = ggml_get_first_tensor(ctx); t != NULL; t = ggml_get_next_tensor(ctx, t)) {
4101            init_tensor_uniform(t, 50.0f, 100.0f);
4102        }
4103    }
4104
4105    float grad_eps() override {
4106        return 20.0f;
4107    }
4108
4109    bool grad_precise() override {
4110        return true;
4111    }
4112};
4113
4114// GGML_OP_LOG
4115struct test_log : public test_case {
4116    const ggml_type type;
4117    const std::array<int64_t, 4> ne;
4118
4119    std::string vars() override {
4120        return VARS_TO_STR2(type, ne);
4121    }
4122
4123    test_log(ggml_type type = GGML_TYPE_F32,
4124            std::array<int64_t, 4> ne = {10, 5, 4, 3})
4125        : type(type), ne(ne) {}
4126
4127    ggml_tensor * build_graph(ggml_context * ctx) override {
4128        ggml_tensor * a = ggml_new_tensor(ctx, type, 4, ne.data());
4129        ggml_set_param(a);
4130        ggml_set_name(a, "a");
4131
4132        ggml_tensor * out = ggml_log(ctx, a);
4133        ggml_set_name(out, "out");
4134
4135        return out;
4136    }
4137
4138    void initialize_tensors(ggml_context * ctx) override {
4139        for (ggml_tensor * t = ggml_get_first_tensor(ctx); t != NULL; t = ggml_get_next_tensor(ctx, t)) {
4140            // log(1) == 0, cluster values there to keep the sum low for better precision in the backward pass:
4141            init_tensor_uniform(t, 0.9f, 1.1f);
4142        }
4143    }
4144
4145    bool grad_precise() override {
4146        return true;
4147    }
4148};
4149
4150// GGML_OP_SIN
4151struct test_sin : public test_case {
4152    const ggml_type type;
4153    const std::array<int64_t, 4> ne;
4154
4155    std::string vars() override {
4156        return VARS_TO_STR2(type, ne);
4157    }
4158
4159    test_sin(ggml_type type = GGML_TYPE_F32,
4160            std::array<int64_t, 4> ne = {10, 2, 2, 2})
4161        : type(type), ne(ne) {}
4162
4163    ggml_tensor * build_graph(ggml_context * ctx) override {
4164        ggml_tensor * a = ggml_new_tensor(ctx, type, 4, ne.data());
4165        ggml_set_param(a);
4166        ggml_set_name(a, "a");
4167
4168        ggml_tensor * out = ggml_sin(ctx, a);
4169        ggml_set_name(out, "out");
4170
4171        return out;
4172    }
4173
4174    void initialize_tensors(ggml_context * ctx) override {
4175        for (ggml_tensor * t = ggml_get_first_tensor(ctx); t != NULL; t = ggml_get_next_tensor(ctx, t)) {
4176            init_tensor_uniform(t, -6.5f, 6.5f); // Covers interval [-2*pi, 2*pi].
4177        }
4178    }
4179
4180    double max_maa_err() override {
4181        return 1e-3;
4182    }
4183
4184    float grad_eps() override {
4185        return 0.2f;
4186    }
4187
4188    bool grad_precise() override {
4189        return true;
4190    }
4191};
4192
4193// GGML_OP_COS
4194struct test_cos : public test_case {
4195    const ggml_type type;
4196    const std::array<int64_t, 4> ne;
4197
4198    std::string vars() override {
4199        return VARS_TO_STR2(type, ne);
4200    }
4201
4202    test_cos(ggml_type type = GGML_TYPE_F32,
4203            std::array<int64_t, 4> ne = {10, 2, 2, 2})
4204        : type(type), ne(ne) {}
4205
4206    ggml_tensor * build_graph(ggml_context * ctx) override {
4207        ggml_tensor * a = ggml_new_tensor(ctx, type, 4, ne.data());
4208        ggml_set_param(a);
4209        ggml_set_name(a, "a");
4210
4211        ggml_tensor * out = ggml_cos(ctx, a);
4212        ggml_set_name(out, "out");
4213
4214        return out;
4215    }
4216
4217    void initialize_tensors(ggml_context * ctx) override {
4218        for (ggml_tensor * t = ggml_get_first_tensor(ctx); t != NULL; t = ggml_get_next_tensor(ctx, t)) {
4219            init_tensor_uniform(t, -6.5f, 6.5f); // Covers interval [-2*pi, 2*pi].
4220        }
4221    }
4222
4223    double max_maa_err() override {
4224        return 1e-3;
4225    }
4226
4227    float grad_eps() override {
4228        return 0.2f;
4229    }
4230
4231    bool grad_precise() override {
4232        return true;
4233    }
4234};
4235
4236// GGML_OP_CLAMP
4237struct test_clamp : public test_case {
4238    const ggml_type type;
4239    const std::array<int64_t, 4> ne;
4240    float min;
4241    float max;
4242
4243    std::string vars() override {
4244        return VARS_TO_STR4(type, ne, min, max);
4245    }
4246
4247    test_clamp(ggml_type type = GGML_TYPE_F32,
4248            std::array<int64_t, 4> ne = {10, 5, 4, 3},
4249            float min = -0.5f, float max = 0.5f)
4250        : type(type), ne(ne), min(min), max(max) {}
4251
4252    ggml_tensor * build_graph(ggml_context * ctx) override {
4253        ggml_tensor * a = ggml_new_tensor(ctx, type, 4, ne.data());
4254        ggml_set_name(a, "a");
4255
4256        ggml_tensor * out = ggml_clamp(ctx, a, min, max);
4257        ggml_set_name(out, "out");
4258
4259        return out;
4260    }
4261
4262    float grad_eps() override {
4263        return 1e-2f;
4264    }
4265
4266    std::vector<float> grad_expect() override {
4267        return {0.0f, 1.0f};
4268    }
4269};
4270
4271// GGML_OP_FLOOR
4272struct test_floor : public test_case {
4273    const ggml_type type;
4274    const std::array<int64_t, 4> ne;
4275
4276    std::string vars() override {
4277        return VARS_TO_STR2(type, ne);
4278    }
4279
4280    test_floor(ggml_type type = GGML_TYPE_F32,
4281               std::array<int64_t, 4> ne = {10, 2, 2, 2})
4282        : type(type), ne(ne) {}
4283
4284    ggml_tensor * build_graph(ggml_context * ctx) override {
4285        ggml_tensor * a = ggml_new_tensor(ctx, type, 4, ne.data());
4286        ggml_set_param(a);
4287        ggml_set_name(a, "a");
4288
4289        ggml_tensor * out = ggml_floor(ctx, a);
4290        ggml_set_name(out, "out");
4291
4292        return out;
4293    }
4294
4295    void initialize_tensors(ggml_context * ctx) override {
4296        for (ggml_tensor * t = ggml_get_first_tensor(ctx); t != NULL; t = ggml_get_next_tensor(ctx, t)) {
4297            init_tensor_uniform(t, -10.0f, 10.0f);
4298        }
4299    }
4300};
4301
4302// GGML_OP_CEIL
4303struct test_ceil : public test_case {
4304    const ggml_type type;
4305    const std::array<int64_t, 4> ne;
4306
4307    std::string vars() override {
4308        return VARS_TO_STR2(type, ne);
4309    }
4310
4311    test_ceil(ggml_type type = GGML_TYPE_F32,
4312              std::array<int64_t, 4> ne = {10, 2, 2, 2})
4313        : type(type), ne(ne) {}
4314
4315    ggml_tensor * build_graph(ggml_context * ctx) override {
4316        ggml_tensor * a = ggml_new_tensor(ctx, type, 4, ne.data());
4317        ggml_set_param(a);
4318        ggml_set_name(a, "a");
4319
4320        ggml_tensor * out = ggml_ceil(ctx, a);
4321        ggml_set_name(out, "out");
4322
4323        return out;
4324    }
4325
4326    void initialize_tensors(ggml_context * ctx) override {
4327        for (ggml_tensor * t = ggml_get_first_tensor(ctx); t != NULL; t = ggml_get_next_tensor(ctx, t)) {
4328            init_tensor_uniform(t, -10.0f, 10.0f);
4329        }
4330    }
4331};
4332
4333// GGML_OP_ROUND
4334struct test_round : public test_case {
4335    const ggml_type type;
4336    const std::array<int64_t, 4> ne;
4337
4338    std::string vars() override {
4339        return VARS_TO_STR2(type, ne);
4340    }
4341
4342    test_round(ggml_type type = GGML_TYPE_F32,
4343               std::array<int64_t, 4> ne = {10, 2, 2, 2})
4344        : type(type), ne(ne) {}
4345
4346    ggml_tensor * build_graph(ggml_context * ctx) override {
4347        ggml_tensor * a = ggml_new_tensor(ctx, type, 4, ne.data());
4348        ggml_set_param(a);
4349        ggml_set_name(a, "a");
4350
4351        ggml_tensor * out = ggml_round(ctx, a);
4352        ggml_set_name(out, "out");
4353
4354        return out;
4355    }
4356
4357    void initialize_tensors(ggml_context * ctx) override {
4358        for (ggml_tensor * t = ggml_get_first_tensor(ctx); t != NULL; t = ggml_get_next_tensor(ctx, t)) {
4359            init_tensor_uniform(t, -10.0f, 10.0f);
4360        }
4361    }
4362};
4363
4364// GGML_OP_TRUNC
4365struct test_trunc : public test_case {
4366    const ggml_type type;
4367    const std::array<int64_t, 4> ne;
4368
4369    std::string vars() override {
4370        return VARS_TO_STR2(type, ne);
4371    }
4372
4373    test_trunc(ggml_type type = GGML_TYPE_F32,
4374               std::array<int64_t, 4> ne = {10, 2, 2, 2})
4375        : type(type), ne(ne) {}
4376
4377    ggml_tensor * build_graph(ggml_context * ctx) override {
4378        ggml_tensor * a = ggml_new_tensor(ctx, type, 4, ne.data());
4379        ggml_set_param(a);
4380        ggml_set_name(a, "a");
4381
4382        ggml_tensor * out = ggml_trunc(ctx, a);
4383        ggml_set_name(out, "out");
4384
4385        return out;
4386    }
4387
4388    void initialize_tensors(ggml_context * ctx) override {
4389        for (ggml_tensor * t = ggml_get_first_tensor(ctx); t != NULL; t = ggml_get_next_tensor(ctx, t)) {
4390            init_tensor_uniform(t, -10.0f, 10.0f);
4391        }
4392    }
4393};
4394
4395// GGML_OP_DIAG_MASK_INF
4396struct test_diag_mask_inf : public test_case {
4397    const ggml_type type;
4398    const std::array<int64_t, 4> ne;
4399    const int n_past;
4400
4401    std::string vars() override {
4402        return VARS_TO_STR3(type, ne, n_past);
4403    }
4404
4405    test_diag_mask_inf(ggml_type type = GGML_TYPE_F32,
4406            std::array<int64_t, 4> ne = {10, 10, 3, 2},
4407            int n_past = 5)
4408        : type(type), ne(ne), n_past(n_past) {}
4409
4410    ggml_tensor * build_graph(ggml_context * ctx) override {
4411        ggml_tensor * a = ggml_new_tensor(ctx, type, 4, ne.data());
4412        ggml_set_param(a);
4413        ggml_set_name(a, "a");
4414
4415        ggml_tensor * out = ggml_diag_mask_inf(ctx, a, n_past);
4416        ggml_set_name(out, "out");
4417
4418        return out;
4419    }
4420};
4421
4422// GGML_OP_SOFT_MAX
4423struct test_soft_max : public test_case {
4424    const ggml_type type;
4425    const std::array<int64_t, 4> ne;
4426    const bool mask;
4427    const bool sinks;
4428    const ggml_type m_prec;
4429    const std::array<int64_t, 2> nr23; // broadcast only dims 2 and 3
4430    const float scale;
4431    const float max_bias;
4432    const bool inplace;
4433
4434    std::string vars() override {
4435        return VARS_TO_STR9(type, ne, mask, sinks, m_prec, nr23, scale, max_bias, inplace);
4436    }
4437
4438    // the 1024 test with bias occasionally fails:
4439    // SOFT_MAX(type=f32,ne=[1024,16,1,1],mask=1,scale=1.000000,max_bias=8.000000): [SOFT_MAX] NMSE = 0.000000103 > 0.000000100 FAIL
4440    virtual double max_nmse_err() override {
4441        return 1e-6;
4442    }
4443
4444    test_soft_max(ggml_type type = GGML_TYPE_F32,
4445            std::array<int64_t, 4> ne = {10, 5, 4, 3},
4446            bool mask = false,
4447            bool sinks = false,
4448            ggml_type m_prec = GGML_TYPE_F32,
4449            std::array<int64_t, 2> nr23 = {1, 1},
4450            float scale = 1.0f,
4451            float max_bias = 0.0f,
4452            bool inplace = false)
4453        : type(type), ne(ne), mask(mask), sinks(sinks), m_prec(m_prec), nr23(nr23), scale(scale), max_bias(max_bias), inplace(inplace) {}
4454
4455    ggml_tensor * build_graph(ggml_context * ctx) override {
4456        ggml_tensor * a = ggml_new_tensor_4d(ctx, type, ne[0], ne[1], ne[2]*nr23[0], ne[3]*nr23[1]);
4457        ggml_set_param(a);
4458        ggml_set_name(a, "a");
4459
4460        ggml_tensor * mask = nullptr;
4461        if (this->mask) {
4462            mask = ggml_new_tensor_4d(ctx, m_prec, ne[0], ne[1], ne[2], ne[3]);
4463            ggml_set_name(mask, "mask");
4464        }
4465
4466        ggml_tensor * sinks = nullptr;
4467        if (this->sinks) {
4468            sinks = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, ne[2]*nr23[0]);
4469            ggml_set_name(sinks, "sinks");
4470        }
4471
4472        ggml_tensor * out;
4473        if (inplace) {
4474            out = ggml_soft_max_ext_inplace(ctx, a, mask, scale, max_bias);
4475        } else {
4476            out = ggml_soft_max_ext(ctx, a, mask, scale, max_bias);
4477        }
4478        ggml_soft_max_add_sinks(out, sinks);
4479        ggml_set_name(out, "out");
4480
4481        return out;
4482    }
4483
4484    bool grad_precise() override {
4485        return true;
4486    }
4487};
4488
4489// GGML_OP_SOFT_MAX_BACK
4490struct test_soft_max_back : public test_case {
4491    const ggml_type type;
4492    const std::array<int64_t, 4> ne;
4493    const float scale;
4494    const float max_bias;
4495
4496    std::string vars() override {
4497        return VARS_TO_STR4(type, ne, scale, max_bias);
4498    }
4499
4500    test_soft_max_back(ggml_type type = GGML_TYPE_F32,
4501            std::array<int64_t, 4> ne = {10, 5, 4, 3},
4502            float scale = 1.0f,
4503            float max_bias = 0.0f)
4504        : type(type), ne(ne), scale(scale), max_bias(max_bias) {}
4505
4506    ggml_tensor * build_graph(ggml_context * ctx) override {
4507        ggml_tensor * a = ggml_new_tensor(ctx, type, 4, ne.data());
4508        ggml_set_name(a, "a");
4509
4510        ggml_tensor * b = ggml_new_tensor(ctx, type, 4, ne.data());
4511        ggml_set_name(a, "a");
4512
4513        ggml_tensor * out = ggml_soft_max_ext_back(ctx, a, b, scale, max_bias);
4514        ggml_set_name(out, "out");
4515
4516        return out;
4517    }
4518};
4519
4520// GGML_OP_ROPE + GGML_OP_ROPE_BACK
4521struct test_rope : public test_case {
4522    const ggml_type type;
4523    const std::array<int64_t, 4> ne_a;
4524    int n_dims;
4525    int mode;
4526    int n_ctx; // used to generate positions
4527    float fs; // freq_scale
4528    float ef; // ext_factor
4529    float af; // attn_factor
4530    bool ff;
4531    int v; // view (1 : non-contiguous a)
4532    bool forward;
4533    bool inplace;
4534
4535    std::string vars() override {
4536        // forward can be inferred from the op, does not need to be printed
4537        return VARS_TO_STR11(type, ne_a, n_dims, mode, n_ctx, fs, ef, af, ff, v, inplace);
4538    }
4539
4540    test_rope(ggml_type type = GGML_TYPE_F32,
4541            std::array<int64_t, 4> ne_a = {10, 5, 3, 1},
4542            int n_dims = 10, int mode = GGML_ROPE_TYPE_NORMAL, int n_ctx = 512, float fs = 1.0f,
4543            float ef = 0.0f, float af = 0.0f, bool ff = false, int v = 0, bool forward = true, bool inplace = false)
4544        : type(type), ne_a(ne_a), n_dims(n_dims), mode(mode), n_ctx(n_ctx), fs(fs), ef(ef), af(af), ff(ff), v(v), forward(forward), inplace(inplace) {}
4545
4546    ggml_tensor * build_graph(ggml_context * ctx) override {
4547        ggml_tensor * a;
4548        if (v & 1) {
4549            auto ne = ne_a; ne[0] *= 2; ne[1] *= 4; ne[2] *= 3;
4550            a = ggml_new_tensor(ctx, type, 4, ne.data());
4551            if (forward) {
4552                ggml_set_param(a);
4553            }
4554            ggml_set_name(a, "a");
4555
4556            a = ggml_view_4d(ctx, a, ne_a[0], ne_a[1], ne_a[2], ne_a[3], a->nb[1], a->nb[2], a->nb[3], 0);
4557            ggml_set_name(a, "view_of_a");
4558        } else {
4559            a = ggml_new_tensor(ctx, type, 4, ne_a.data());
4560            if (forward) {
4561                ggml_set_param(a);
4562            }
4563            ggml_set_name(a, "a");
4564        }
4565
4566        const bool is_mrope = mode & GGML_ROPE_TYPE_MROPE;
4567        const bool is_vision = mode == GGML_ROPE_TYPE_VISION;
4568
4569        ggml_tensor * pos;
4570        if (is_mrope || is_vision) {
4571            pos = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, ne_a[2] * 4);
4572        } else {
4573            pos = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, ne_a[2]);
4574        }
4575        ggml_set_name(pos, "pos");
4576
4577        ggml_tensor * freq = nullptr;
4578        if (ff) {
4579            freq = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_dims/2);
4580            ggml_set_name(freq, "freq");
4581        }
4582
4583        ggml_tensor * out;
4584        if (is_mrope) {
4585            if (is_vision) {
4586                GGML_ASSERT(n_dims/4 > 0);
4587                int rope_sections[4] = {n_dims/4, n_dims/4, 0, 0}; // Vision-RoPE only use first two dimension for image (x, y) coordinate
4588                if (forward) {
4589                    if (inplace) {
4590                        out = ggml_rope_multi_inplace(ctx, a, pos, freq, n_dims/2, rope_sections, mode, 0, 10000.0f, fs, ef, af, 1.0f, 1.0f);
4591                    } else {
4592                        out = ggml_rope_multi(ctx, a, pos, freq, n_dims/2, rope_sections, mode, 0, 10000.0f, fs, ef, af, 1.0f, 1.0f);
4593                    }
4594                } else {
4595                    out = ggml_rope_multi_back(ctx, a, pos, freq, n_dims/2, rope_sections, mode, 0, 10000.0f, fs, ef, af, 1.0f, 1.0f);
4596                }
4597            } else {
4598                GGML_ASSERT(n_dims/3 > 0);
4599                int rope_sections[4] = {n_dims/3, n_dims/3, n_dims/3, 0};
4600                if (forward) {
4601                    if (inplace) {
4602                        out = ggml_rope_multi_inplace(ctx, a, pos, freq, n_dims, rope_sections, mode, 0, 10000.0f, fs, ef, af, 1.0f, 1.0f);
4603                    } else {
4604                        out = ggml_rope_multi(ctx, a, pos, freq, n_dims, rope_sections, mode, 0, 10000.0f, fs, ef, af, 1.0f, 1.0f);
4605                    }
4606                } else {
4607                    out = ggml_rope_multi_back(ctx, a, pos, freq, n_dims, rope_sections, mode, 0, 10000.0f, fs, ef, af, 1.0f, 1.0f);
4608                }
4609            }
4610        } else {
4611            if (forward) {
4612                if (inplace) {
4613                    out = ggml_rope_ext_inplace(ctx, a, pos, freq, n_dims, mode, 0, 10000.0f, fs, ef, af, 1.0f, 1.0f);
4614                } else {
4615                    out = ggml_rope_ext(ctx, a, pos, freq, n_dims, mode, 0, 10000.0f, fs, ef, af, 1.0f, 1.0f);
4616                }
4617            } else {
4618                out = ggml_rope_ext_back(ctx, a, pos, freq, n_dims, mode, 0, 10000.0f, fs, ef, af, 1.0f, 1.0f);
4619            }
4620
4621            // TODO: add test with a non-contiguous view as input ; this case is needed for build_rope_2d in clip.cpp
4622        }
4623        ggml_set_name(out, "out");
4624
4625        return out;
4626    }
4627
4628    void initialize_tensors(ggml_context * ctx) override {
4629        for (ggml_tensor * t = ggml_get_first_tensor(ctx); t != NULL; t = ggml_get_next_tensor(ctx, t)) {
4630            if (t->type == GGML_TYPE_I32) {
4631                // pos
4632                const int num_pos_ids = (mode & GGML_ROPE_TYPE_MROPE) ? ne_a[2] * 4 : ne_a[2];
4633                std::vector<int> data(num_pos_ids);
4634                for (int i = 0; i < num_pos_ids; i++) {
4635                    data[i] = rand() % n_ctx;
4636                }
4637                ggml_backend_tensor_set(t, data.data(), 0, num_pos_ids * sizeof(int));
4638            } else {
4639                if (t->ne[0] == n_dims/2) {
4640                    // frequency factors in the range [0.9f, 1.1f]
4641                    init_tensor_uniform(t, 0.9f, 1.1f);
4642                } else {
4643                    init_tensor_uniform(t);
4644                }
4645            }
4646        }
4647    }
4648
4649    double max_maa_err() override {
4650        return 1e-3;
4651    }
4652
4653    bool grad_precise() override {
4654        return true;
4655    }
4656};
4657
4658// GGML_OP_POOL2D
4659struct test_pool2d : public test_case {
4660    enum ggml_op_pool pool_type;
4661    const ggml_type type_input;
4662    const std::array<int64_t, 4> ne_input;
4663    // kernel size
4664    const int k0;
4665    const int k1;
4666    // stride
4667    const int s0;
4668    const int s1;
4669    // padding
4670    const int p0;
4671    const int p1;
4672
4673    std::string vars() override {
4674        return VARS_TO_STR9(pool_type, type_input, ne_input, k0, k1, s0, s1, p0, p1);
4675    }
4676
4677    test_pool2d(ggml_op_pool pool_type = GGML_OP_POOL_AVG,
4678            ggml_type type_input = GGML_TYPE_F32,
4679            std::array<int64_t, 4> ne_input = {10, 10, 3, 1}, // [input_width, input_height, input_channels, 1]
4680            int k0 = 3, int k1 = 3,
4681            int s0 = 1, int s1 = 1,
4682            int p0 = 1, int p1 = 1)
4683        : pool_type(pool_type), type_input(type_input), ne_input(ne_input), k0(k0), k1(k1), s0(s0), s1(s1), p0(p0), p1(p1) {}
4684
4685    ggml_tensor * build_graph(ggml_context * ctx) override {
4686        ggml_tensor * input = ggml_new_tensor(ctx, type_input, 4, ne_input.data());
4687        ggml_set_param(input);
4688        ggml_set_name(input, "input");
4689
4690        ggml_tensor * out = ggml_pool_2d(ctx, input, pool_type, k0, k1, s0, s1, p0, p1);
4691        ggml_set_name(out, "out");
4692
4693        return out;
4694    }
4695};
4696
4697// GGML_OP_POOL1D
4698struct test_pool1d : public test_case {
4699    enum ggml_op_pool pool_type;
4700    const ggml_type type_input;
4701    const std::array<int64_t, 4> ne_input;
4702    const int k0;
4703    const int s0;
4704    const int p0;
4705
4706    std::string vars() override {
4707        return VARS_TO_STR6(pool_type, type_input, ne_input, k0, s0, p0);
4708    }
4709
4710    test_pool1d(ggml_op_pool pool_type = GGML_OP_POOL_AVG,
4711                ggml_type type_input = GGML_TYPE_F32,
4712                std::array<int64_t,4> ne_input = {10, 1, 1, 1},
4713                int k0 = 3, int s0 = 3, int p0 = 0)
4714        : pool_type(pool_type), type_input(type_input), ne_input(ne_input), k0(k0), s0(s0), p0(p0) {}
4715
4716    ggml_tensor * build_graph(ggml_context * ctx) override {
4717        ggml_tensor * input = ggml_new_tensor(ctx, type_input, 4, ne_input.data());
4718        ggml_set_param(input);
4719        ggml_set_name(input, "input");
4720
4721        ggml_tensor * out = ggml_pool_1d(ctx, input, pool_type, k0, s0, p0);
4722        ggml_set_name(out, "out");
4723
4724        return out;
4725    }
4726};
4727
4728// GGML_OP_CONV_TRANSPOSE_1D
4729struct test_conv_transpose_1d : public test_case {
4730    const std::array<int64_t, 4> ne_input;
4731    const std::array<int64_t, 4> ne_kernel;
4732
4733    const int s0; // stride
4734    const int p0; // padding
4735    const int d0; // dilation
4736
4737    std::string vars() override {
4738        return VARS_TO_STR5(ne_input, ne_kernel, s0, p0, d0);
4739    }
4740
4741    test_conv_transpose_1d(std::array<int64_t, 4> ne_input = {197, 32, 1, 1}, // [input_width, input_channels, 1 /* assert in cpu kernel*/, 1 (should be batch)]
4742                           std::array<int64_t, 4> ne_kernel = {16, 32, 32, 1}, // [kernel_width, output_channels, input_channels, 1 (should be batch)]
4743                           int s0 = 1, int p0 = 0, int d0 = 1)
4744        : ne_input(ne_input), ne_kernel(ne_kernel), s0(s0), p0(p0), d0(d0) {}
4745
4746    ggml_tensor * build_graph(ggml_context * ctx) override {
4747        ggml_tensor * input = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne_input.data());
4748        ggml_set_name(input, "input");
4749
4750        ggml_tensor * kernel = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne_kernel.data());
4751        ggml_set_name(kernel, "kernel");
4752
4753        ggml_tensor * out = ggml_conv_transpose_1d(ctx, kernel, input, s0, p0, d0);
4754        ggml_set_name(out, "out");
4755
4756        return out;
4757    }
4758};
4759
4760// GGML_OP_CONV_TRANSPOSE_2D
4761struct test_conv_transpose_2d : public test_case {
4762    const std::array<int64_t, 4> ne_input;
4763    const std::array<int64_t, 4> ne_kernel;
4764    const int stride;
4765
4766    std::string vars() override {
4767        return VARS_TO_STR3(ne_input, ne_kernel, stride);
4768    }
4769
4770    double max_nmse_err() override {
4771        return 5e-4; // The default 1e-7 is too small for Vulkan.
4772    }
4773
4774    test_conv_transpose_2d(std::array<int64_t, 4> ne_input = {10, 10, 3, 1}, // [input_width, input_height, input_channels, 1]
4775                           std::array<int64_t, 4> ne_kernel = {3, 3, 3, 1}, // [kernel_width, kernel_height, input_channels, 1]
4776                           int stride = 1)
4777        : ne_input(ne_input), ne_kernel(ne_kernel), stride(stride){}
4778
4779    ggml_tensor * build_graph(ggml_context * ctx) override {
4780        ggml_tensor * input = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne_input.data());
4781        ggml_set_name(input, "input");
4782
4783        ggml_tensor * kernel = ggml_new_tensor(ctx, GGML_TYPE_F16, 4, ne_kernel.data());
4784        ggml_set_name(kernel, "kernel");
4785
4786        ggml_tensor * out = ggml_conv_transpose_2d_p0(ctx, kernel, input, stride);
4787        ggml_set_name(out, "out");
4788
4789        return out;
4790    }
4791};
4792
4793// GGML_OP_IM2COL
4794struct test_im2col : public test_case {
4795    const ggml_type type_input;
4796    const ggml_type type_kernel;
4797    const ggml_type dst_type;
4798    const std::array<int64_t, 4> ne_input;
4799    const std::array<int64_t, 4> ne_kernel;
4800    // stride
4801    const int s0;
4802    const int s1;
4803    // padding
4804    const int p0;
4805    const int p1;
4806    // dilation
4807    const int d0;
4808    const int d1;
4809    // mode
4810    const bool is_2D;
4811
4812    std::string vars() override {
4813        return VARS_TO_STR12(type_input, type_kernel, dst_type, ne_input, ne_kernel, s0, s1, p0, p1, d0, d1, is_2D);
4814    }
4815
4816    test_im2col(ggml_type type_input = GGML_TYPE_F32, ggml_type type_kernel = GGML_TYPE_F16, ggml_type dst_type = GGML_TYPE_F32,
4817            std::array<int64_t, 4> ne_input = {10, 10, 3, 1}, // [input_width, input_height, input_channels, 1]
4818            std::array<int64_t, 4> ne_kernel = {3, 3, 3, 1}, // [kernel_width, kernel_height, input_channels, 1]
4819            int s0 = 1, int s1 = 1,
4820            int p0 = 1, int p1 = 1,
4821            int d0 = 1, int d1 = 1,
4822            bool is_2D = true)
4823        : type_input(type_input), type_kernel(type_kernel), dst_type(dst_type), ne_input(ne_input), ne_kernel(ne_kernel), s0(s0), s1(s1), p0(p0), p1(p1), d0(d0), d1(d1), is_2D(is_2D) {}
4824
4825    ggml_tensor * build_graph(ggml_context * ctx) override {
4826        ggml_tensor * input = ggml_new_tensor(ctx, type_input, 4, ne_input.data());
4827        ggml_set_param(input);
4828        ggml_set_name(input, "input");
4829
4830        ggml_tensor * kernel = ggml_new_tensor(ctx, type_kernel, 4, ne_kernel.data());
4831        ggml_set_name(kernel, "kernel");
4832
4833        ggml_tensor * out = ggml_im2col(ctx, kernel, input, s0, s1, p0, p1, d0, d1, is_2D, dst_type);
4834        ggml_set_name(out, "out");
4835
4836        return out;
4837    }
4838};
4839
4840// GGML_OP_IM2COL_3D
4841struct test_im2col_3d : public test_case {
4842    const ggml_type type_input;
4843    const ggml_type type_kernel;
4844    const ggml_type dst_type;
4845    const std::array<int64_t, 4> ne_input;
4846    const std::array<int64_t, 4> ne_kernel;
4847    // stride
4848    const int s0;
4849    const int s1;
4850    const int s2;
4851    // padding
4852    const int p0;
4853    const int p1;
4854    const int p2;
4855    // dilation
4856    const int d0;
4857    const int d1;
4858    const int d2;
4859
4860    const int64_t IC;
4861    const bool v;
4862
4863    std::string vars() override {
4864        return VARS_TO_STR16(type_input, type_kernel, dst_type, ne_input, ne_kernel, IC, s0, s1, s2, p0, p1, p2, d0, d1, d2, v);
4865    }
4866
4867    test_im2col_3d(ggml_type type_input = GGML_TYPE_F32, ggml_type type_kernel = GGML_TYPE_F16, ggml_type dst_type = GGML_TYPE_F32,
4868                std::array<int64_t, 4> ne_input = {10, 10, 10, 9}, // [OC*IC, KD, KH, KW]
4869                std::array<int64_t, 4> ne_kernel = {3, 3, 3, 1}, // [N*IC, ID, IH, IW]
4870                int64_t IC = 3,
4871                int s0 = 1, int s1 = 1, int s2 = 1,
4872                int p0 = 1, int p1 = 1, int p2 = 1,
4873                int d0 = 1, int d1 = 1, int d2 = 1,
4874                bool v = false)
4875        : type_input(type_input), type_kernel(type_kernel), dst_type(dst_type), ne_input(ne_input), ne_kernel(ne_kernel), s0(s0), s1(s1), s2(s2), p0(p0), p1(p1), p2(p2), d0(d0), d1(d1), d2(d2), IC(IC), v(v) {}
4876
4877    ggml_tensor * build_graph(ggml_context * ctx) override {
4878        ggml_tensor * input = ggml_new_tensor(ctx, type_input, 4, ne_input.data());
4879        ggml_set_param(input);
4880        ggml_set_name(input, "input");
4881
4882        if (v) {
4883            input = ggml_view_4d(ctx, input, ne_input[0] - 2, ne_input[1] - 2, ne_input[2] - 2, ne_input[3] - 2, input->nb[1], input->nb[2], input->nb[3], 0);
4884            ggml_set_name(input, "view_of_input");
4885        }
4886
4887        ggml_tensor * kernel = ggml_new_tensor(ctx, type_kernel, 4, ne_kernel.data());
4888        ggml_set_name(kernel, "kernel");
4889
4890        ggml_tensor * out = ggml_im2col_3d(ctx, kernel, input, IC, s0, s1, s2, p0, p1, p2, d0, d1, d2, dst_type);
4891        ggml_set_name(out, "out");
4892
4893        return out;
4894    }
4895};
4896
4897// CONV_2D
4898struct test_conv_2d : public test_case {
4899    const std::array<int64_t, 4> ne_input;
4900    const std::array<int64_t, 4> ne_kernel;
4901    const ggml_type              type_kernel;
4902    const int                    stride0;
4903    const int                    stride1;
4904    const int                    padding0;
4905    const int                    padding1;
4906    const int                    dilation0;
4907    const int                    dilation1;
4908    // Whether the inputs are contiguous in the channel dim or the width dim
4909    const bool                   cwhn;
4910
4911    // If true, the direct CONV_2D will be used in the graph, otherwise it
4912    // uses ggml_conv_2d:
4913    // * if the program is called with -o CONV_2D_DIRECT_IMPL, the
4914    // CONV_2D graph will be built, while
4915    // * if the program is called with -o CONV_2D_INDIRECT_IMPL, the
4916    // IM2COL -> MUL_MM graph will be built.
4917
4918    std::string vars() override {
4919        return VARS_TO_STR10(ne_input, ne_kernel, type_kernel, stride0, stride1, padding0, padding1, dilation0, dilation1, cwhn);
4920    }
4921
4922    double max_nmse_err() override {
4923        return 5e-4;
4924    }
4925
4926    uint64_t op_flops(ggml_tensor * t) override {
4927        GGML_UNUSED(t);
4928        // Just counting matmul costs:
4929        // KxCRS @ CRSxNPQ = KxNPQ --> KxNPQx(CRS+CRS-1) flops
4930
4931        // Copied from ggml.c: int64_t ggml_calc_conv_output_size(int64_t ins, int64_t ks, int s, int p, int d)
4932        auto calc_conv_output_size = [](int64_t ins, int64_t ks, int s, int p, int d) -> int64_t {
4933            return (ins + 2 * p - d * (ks - 1) - 1) / s + 1;
4934        };
4935
4936        int64_t W    = ne_input[0];
4937        int64_t H    = ne_input[1];
4938        int64_t KW   = ne_kernel[0];
4939        int64_t KH   = ne_kernel[1];
4940        int64_t Cin  = ne_kernel[2];
4941        int64_t Cout = ne_kernel[3];
4942        int64_t N    = ne_input[3];
4943        int64_t OH   = calc_conv_output_size(H, KH, stride0, padding0, dilation0);
4944        int64_t OW   = calc_conv_output_size(W, KW, stride0, padding0, dilation0);
4945
4946        int64_t K   = Cout;
4947        int64_t CRS = Cin * KH * KW;
4948        int64_t NPQ = N * OH * OW;
4949
4950        return K * NPQ * (2 * CRS - 1);
4951    }
4952
4953    test_conv_2d(std::array<int64_t, 4> ne_input  = { 64, 64, 16, 1 },
4954                 std::array<int64_t, 4> ne_kernel = { 3, 3, 1, 16 }, ggml_type type_kernel = GGML_TYPE_F32, int stride0 = 1,
4955                 int stride1 = 1, int padding0 = 0, int padding1 = 0, int dilation0 = 1, int dilation1 = 1, bool cwhn = false) :
4956        ne_input(ne_input),
4957        ne_kernel(ne_kernel),
4958        type_kernel(type_kernel),
4959        stride0(stride0),
4960        stride1(stride1),
4961        padding0(padding0),
4962        padding1(padding1),
4963        dilation0(dilation0),
4964        dilation1(dilation1),
4965        cwhn(cwhn) {}
4966
4967    ggml_tensor * build_graph(ggml_context * ctx) override {
4968        ggml_tensor * input = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne_input.data());
4969        ggml_set_name(input, "input");
4970
4971        ggml_tensor * kernel = ggml_new_tensor(ctx, type_kernel, 4, ne_kernel.data());
4972        ggml_set_name(kernel, "kernel");
4973
4974        if (cwhn) {
4975            // change memory layout to channel-most-contiguous (CWHN),
4976            // then permute it back so NE matches the original input
4977            input  = ggml_cont(ctx, ggml_permute(ctx, input, 1, 2, 0, 3));
4978            input  = ggml_permute(ctx, input, 2, 0, 1, 3);
4979            kernel = ggml_cont(ctx, ggml_permute(ctx, kernel, 2, 3, 1, 0));
4980            kernel = ggml_permute(ctx, kernel, 3, 2, 0, 1);
4981        }
4982
4983        ggml_tensor * out =
4984            ggml_conv_2d_direct(ctx, kernel, input, stride0, stride1, padding0, padding1, dilation0, dilation1);
4985        ggml_set_name(out, "out");
4986        return out;
4987    }
4988};
4989
4990// GGML_OP_CONV_2D_DW
4991struct test_conv_2d_dw : public test_case {
4992    const std::array<int64_t, 4> ne_input;
4993    const std::array<int64_t, 4> ne_kernel;
4994    const int stride;
4995    const int padding;
4996    const int dilation;
4997    const bool cwhn;
4998
4999    std::string vars() override {
5000        return VARS_TO_STR6(ne_input, ne_kernel, stride, padding, dilation, cwhn);
5001    }
5002
5003    test_conv_2d_dw(std::array<int64_t, 4> ne_input = {64, 64, 16, 1},
5004            std::array<int64_t, 4> ne_kernel = {3, 3, 1, 16},
5005            int stride = 1, int padding = 0, int dilation = 1, bool cwhn = false)
5006        : ne_input(ne_input), ne_kernel(ne_kernel), stride(stride), padding(padding), dilation(dilation), cwhn(cwhn) {}
5007
5008    ggml_tensor * build_graph(ggml_context * ctx) override {
5009        ggml_tensor * input = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne_input.data());
5010        ggml_set_name(input, "input");
5011
5012        ggml_tensor * kernel = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne_kernel.data());
5013        ggml_set_name(kernel, "kernel");
5014
5015        if (cwhn) {
5016            // change memory layout to channel-most-contiguous (CWHN),
5017            // then permute it back so NE matches the original input
5018            input = ggml_cont(ctx, ggml_permute(ctx, input, 1, 2, 0, 3));
5019            input = ggml_permute(ctx, input, 2, 0, 1, 3);
5020            kernel = ggml_cont(ctx, ggml_permute(ctx, kernel, 2, 3, 1, 0));
5021            kernel = ggml_permute(ctx, kernel, 3, 2, 0, 1);
5022        }
5023
5024        ggml_tensor * out = ggml_conv_2d_dw_direct(
5025            ctx, kernel, input,
5026            stride, stride, padding, padding, dilation, dilation);
5027        ggml_set_name(out, "out");
5028        return out;
5029    }
5030};
5031
5032// GGML_OP_CONV_3D
5033struct test_conv_3d : public test_case {
5034    // Logical 5D dimensions
5035    const int64_t N, IC, ID, IH, IW;
5036    const int64_t OC, KD, KH, KW;
5037    // Conv params
5038    const int s0, s1, s2;
5039    const int p0, p1, p2;
5040    const int d0, d1, d2;
5041    // Types
5042    const ggml_type type_kernel;
5043
5044    std::string op_desc(ggml_tensor * t) override {
5045        GGML_UNUSED(t);
5046        return "CONV_3D";
5047    }
5048
5049    std::string vars() override {
5050        return VARS_TO_STR11(N, IC, ID, IH, IW, OC, KD, KH, KW, s0, s1) + "," +
5051               VARS_TO_STR8(s2, p0, p1, p2, d0, d1, d2, type_kernel);
5052    }
5053
5054    double max_nmse_err() override {
5055        return 5e-4;
5056    }
5057
5058    uint64_t op_flops(ggml_tensor * t) override {
5059        GGML_UNUSED(t);
5060        auto calc_conv_output_size = [](int64_t ins, int64_t ks, int s, int p, int d) -> int64_t {
5061            return (ins + 2 * p - d * (ks - 1) - 1) / s + 1;
5062        };
5063        const int64_t OD = calc_conv_output_size(ID, KD, s2, p2, d2);
5064        const int64_t OH = calc_conv_output_size(IH, KH, s1, p1, d1);
5065        const int64_t OW = calc_conv_output_size(IW, KW, s0, p0, d0);
5066
5067        return (uint64_t)N * OC * OD * OH * OW * (2 * IC * KD * KH * KW - 1);
5068    }
5069
5070    test_conv_3d(
5071        int64_t N, int64_t IC, int64_t ID, int64_t IH, int64_t IW,
5072        int64_t OC, int64_t KD, int64_t KH, int64_t KW,
5073        int s0, int s1, int s2,
5074        int p0, int p1, int p2,
5075        int d0, int d1, int d2,
5076        ggml_type type_kernel
5077    ) : N(N), IC(IC), ID(ID), IH(IH), IW(IW),
5078        OC(OC), KD(KD), KH(KH), KW(KW),
5079        s0(s0), s1(s1), s2(s2),
5080        p0(p0), p1(p1), p2(p2),
5081        d0(d0), d1(d1), d2(d2),
5082        type_kernel(type_kernel) {}
5083
5084    ggml_tensor * build_graph(ggml_context * ctx) override {
5085        // GGML input tensor is packed as [W, H, D, C*N]
5086        const int64_t ne_input[] = {IW, IH, ID, IC * N};
5087        ggml_tensor * input = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne_input);
5088        ggml_set_name(input, "input");
5089
5090        // GGML kernel tensor is packed as [KW, KH, KD, IC*OC]
5091        const int64_t ne_kernel[] = {KW, KH, KD, IC * OC};
5092        ggml_tensor * kernel = ggml_new_tensor(ctx, type_kernel, 4, ne_kernel);
5093        ggml_set_name(kernel, "kernel");
5094
5095        ggml_tensor * out = ggml_conv_3d_direct(ctx, kernel, input, s0, s1, s2, p0, p1, p2, d0, d1, d2, (int)IC, (int)N, (int)OC);
5096        ggml_set_name(out, "out");
5097        return out;
5098    }
5099};
5100
5101// GGML_OP_CONCAT
5102struct test_concat : public test_case {
5103    const ggml_type type;
5104    const std::array<int64_t, 4> ne_a;
5105    const int64_t ne_b_d;
5106    const int dim;
5107    const int v; // view (1 << 0: non-cont a, 1 << 1: non-cont b)
5108
5109    std::string vars() override {
5110        return VARS_TO_STR5(type, ne_a, ne_b_d, dim, v);
5111    }
5112
5113    test_concat(ggml_type type = GGML_TYPE_F32,
5114            std::array<int64_t, 4> ne_a = {10, 5, 5, 5},
5115            int64_t ne_b_d = 5,
5116            int dim = 2, int v = 0)
5117        : type(type), ne_a(ne_a), ne_b_d(ne_b_d), dim(dim), v(v) {}
5118
5119    ggml_tensor * build_graph(ggml_context * ctx) override {
5120        auto ne_b = ne_a;
5121        ne_b[dim] = ne_b_d;
5122        ggml_tensor * a;
5123        if (v & 1) {
5124            auto ne = ne_a; ne[0] *= 2; ne[1] *= 4; ne[2] *= 3;
5125            a = ggml_new_tensor(ctx, type, 4, ne.data());
5126            ggml_set_name(a, "a");
5127
5128            a = ggml_view_4d(ctx, a, ne_a[0], ne_a[1], ne_a[2], ne_a[3], a->nb[1], a->nb[2], a->nb[3], 0);
5129            ggml_set_name(a, "view_of_a");
5130        } else {
5131            a = ggml_new_tensor(ctx, type, 4, ne_a.data());
5132            ggml_set_name(a, "a");
5133        }
5134        ggml_tensor * b;
5135        if (v & 2) {
5136            auto ne = ne_b; ne[0] *= 3; ne[1] *= 2; ne[2] *= 4;
5137            b = ggml_new_tensor(ctx, type, 4, ne.data());
5138            ggml_set_name(b, "b");
5139
5140            b = ggml_view_4d(ctx, b, ne_b[0], ne_b[1], ne_b[2], ne_b[3], b->nb[1], b->nb[2], b->nb[3], 0);
5141            ggml_set_name(b, "view_of_b");
5142        } else {
5143            b = ggml_new_tensor(ctx, type, 4, ne_b.data());
5144            ggml_set_name(b, "b");
5145        }
5146
5147        ggml_tensor * out = ggml_concat(ctx, a, b, dim);
5148        ggml_set_name(out, "out");
5149
5150        return out;
5151    }
5152};
5153
5154// GGML_OP_ARGSORT
5155struct test_argsort : public test_case {
5156    const ggml_type type;
5157    const std::array<int64_t, 4> ne;
5158    ggml_sort_order order;
5159
5160    std::string vars() override {
5161        return VARS_TO_STR3(type, ne, order);
5162    }
5163
5164    test_argsort(ggml_type type = GGML_TYPE_F32,
5165            std::array<int64_t, 4> ne = {16, 10, 10, 10},
5166            ggml_sort_order order = GGML_SORT_ORDER_ASC)
5167        : type(type), ne(ne), order(order) {}
5168
5169    ggml_tensor * build_graph(ggml_context * ctx) override {
5170        ggml_tensor * a = ggml_new_tensor(ctx, type, 4, ne.data());
5171        ggml_set_name(a, "a");
5172
5173        ggml_tensor * out = ggml_argsort(ctx, a, order);
5174        ggml_set_name(out, "out");
5175
5176        return out;
5177    }
5178
5179    void initialize_tensors(ggml_context * ctx) override {
5180        std::random_device rd;
5181        std::default_random_engine rng(rd());
5182        for (ggml_tensor * t = ggml_get_first_tensor(ctx); t != NULL; t = ggml_get_next_tensor(ctx, t)) {
5183            if (t->type == GGML_TYPE_I32) {
5184                // indices
5185                std::vector<int> data(ggml_nelements(t));
5186                for (int i = 0; i < ggml_nelements(t); i++) {
5187                    data[i] = rand();
5188                }
5189                std::shuffle(data.begin(), data.end(), rng);
5190                ggml_backend_tensor_set(t, data.data(), 0, ne[0]*ne[1]*ne[2]*ne[3] * sizeof(int));
5191            } else if (t->type == GGML_TYPE_F32) {
5192                // initialize with unique values to avoid ties
5193                for (int64_t r = 0; r < ggml_nrows(t); r++) {
5194                    std::vector<float> data(t->ne[0]);
5195                    for (int i = 0; i < t->ne[0]; i++) {
5196                        data[i] = i;
5197                    }
5198                    std::shuffle(data.begin(), data.end(), rng);
5199                    ggml_backend_tensor_set(t, data.data(), r * t->nb[1], t->ne[0] * sizeof(float));
5200                }
5201            } else {
5202                GGML_ABORT("fatal error");
5203            }
5204        }
5205    }
5206};
5207
5208// GGML_OP_TOP_K
5209struct test_top_k : public test_case {
5210    const ggml_type type;
5211    const std::array<int64_t, 4> ne;
5212    const int k;
5213    const bool ties;
5214    ggml_tensor * input {};
5215
5216    std::string vars() override {
5217        return VARS_TO_STR4(type, ne, k, ties);
5218    }
5219
5220    test_top_k(ggml_type type = GGML_TYPE_F32,
5221            std::array<int64_t, 4> ne = {16, 10, 10, 10},
5222            int k = 4, bool ties = false)
5223        : type(type), ne(ne), k(k), ties(ties) {}
5224
5225    double max_err() override {
5226        return 0.0;
5227    }
5228
5229    // When there are ties, only validate the final result.
5230    // The logic in err can't handle the sentinel tensors.
5231    bool run_whole_graph() override { return ties; }
5232
5233    double err(const float * a, const float * b, size_t n) override {
5234        // When there are no ties, we expect the exact same set of indices,
5235        // but possibly in a different order. When there are ties, the indices
5236        // can be different but the input values they correspond to should be
5237        // the same. The logic for ties could work for non-ties, but only for
5238        // the output tensor, not for the sentinel tensors.
5239        if (ties) {
5240            std::vector<float> src(ggml_nelements(input));
5241
5242            ggml_backend_tensor_get(input, src.data(), 0, ggml_nelements(input) * ggml_type_size(type));
5243
5244            double diff = 0.0f;
5245
5246            GGML_ASSERT(n == (size_t)(ggml_nrows(input) * k));
5247            int64_t cols = input->ne[0];
5248            std::vector<int32_t> ia(k);
5249            std::vector<int32_t> ib(k);
5250            std::vector<float> asrc(k);
5251            std::vector<float> bsrc(k);
5252            for (int64_t r = 0; r < ggml_nrows(input); r++) {
5253                // Convert indices for the row back to integer
5254                for (int64_t c = 0; c < k; c++) {
5255                    ia[c] = (int32_t)a[r * k + c];
5256                    ib[c] = (int32_t)b[r * k + c];
5257                }
5258                // The src values for each row should match.
5259                for (int64_t c = 0; c < k; c++) {
5260                    asrc[c] = src[r * cols + ia[c]];
5261                    bsrc[c] = src[r * cols + ib[c]];
5262                }
5263                diff += jdst(asrc.data(), bsrc.data(), k);
5264                // There should be no duplicate indices
5265                std::sort(ia.begin(), ia.end());
5266                std::sort(ib.begin(), ib.end());
5267                if (std::adjacent_find(ia.begin(), ia.end()) != ia.end()) {
5268                    diff += 1;
5269                }
5270                if (std::adjacent_find(ib.begin(), ib.end()) != ib.end()) {
5271                    diff += 1;
5272                }
5273            }
5274            return diff;
5275        } else {
5276            std::vector<int32_t> ia(n);
5277            std::vector<int32_t> ib(n);
5278
5279            double diff = 0.0f;
5280
5281            for (size_t i = 0; i < n; i++) {
5282                ia[i] = (int32_t) a[i];
5283                ib[i] = (int32_t) b[i];
5284
5285                // penalize the result if the data is not integer valued
5286                diff += std::fabs(a[i] - ia[i]);
5287                diff += std::fabs(b[i] - ib[i]);
5288            }
5289
5290            return diff + jdst(ia.data(), ib.data(), n);
5291        }
5292    }
5293
5294    ggml_tensor * build_graph(ggml_context * ctx) override {
5295        ggml_tensor * a = ggml_new_tensor(ctx, type, 4, ne.data());
5296        ggml_set_name(a, "a");
5297
5298        // Save 'a' for err()
5299        input = a;
5300
5301        ggml_tensor * out = ggml_top_k(ctx, a, k);
5302        ggml_set_name(out, "out");
5303
5304        return out;
5305    }
5306
5307    void initialize_tensors(ggml_context * ctx) override {
5308        std::random_device rd;
5309        std::default_random_engine rng(rd());
5310        for (ggml_tensor * t = ggml_get_first_tensor(ctx); t != NULL; t = ggml_get_next_tensor(ctx, t)) {
5311            int tie_denom = std::max(1, std::min(10, k / 2));
5312            for (int64_t r = 0; r < ggml_nrows(t); r++) {
5313                std::vector<float> data(t->ne[0]);
5314                for (int i = 0; i < t->ne[0]; i++) {
5315                    if (ties) {
5316                        // integer division to introduce duplicates
5317                        data[i] = i / tie_denom;
5318                    } else {
5319                        data[i] = i;
5320                    }
5321                }
5322                std::shuffle(data.begin(), data.end(), rng);
5323                ggml_backend_tensor_set(t, data.data(), r * t->nb[1], t->ne[0] * sizeof(float));
5324            }
5325        }
5326    }
5327};
5328
// Gating applied to the router logits in test_topk_moe.
enum MoeGatingFunc {
    GATING_FUNC_SOFTMAX        = 0, // softmax over all experts before selection
    GATING_FUNC_SIGMOID        = 1, // element-wise sigmoid before selection
    GATING_FUNC_SOFTMAX_WEIGHT = 2, // select on raw logits, softmax over the selected weights
};
5334
// Tests the MoE routing subgraph as a whole:
// 1. gate the router logits (softmax / sigmoid / none);
// 2. optionally add a selection-only bias;
// 3. pick the top n_expert_used experts per token with ggml_argsort_top_k;
// 4. gather their weights from probs;
// 5. optionally softmax, normalize and scale those weights.
// Both selected_experts and weights are verified via fusion_test_nodes(), and
// run_whole_graph() returns true. Presumably this targets backends that fuse
// the sequence — TODO confirm.
struct test_topk_moe : public test_case {
    // Logits shape: ne[0] = n_expert, ne[1] = n_tokens (see build_graph()).
    const std::array<int64_t, 4> ne;
    // Experts selected per token; asserted <= ne[0] in the constructor.
    const int n_expert_used;
    // If set, the selected weights are divided by their per-token sum.
    const bool with_norm;
    // If set, a per-expert bias is added to probs for selection only; the
    // gathered weights still come from the unbiased probs.
    const bool bias_probs;
    const MoeGatingFunc gating_func;
    // If non-zero, the final weights are multiplied by this factor.
    const float scale_w;
    // Saved by build_graph() so fusion_test_nodes() can return them.
    ggml_tensor * weights {};
    ggml_tensor * selected_experts {};

    test_topk_moe(std::array<int64_t, 4> ne              = { 10, 5, 1, 1 },
                  int                    n_expert_used   = 1,
                  bool                   with_norm       = false,
                  bool                   bias_probs      = false,
                  MoeGatingFunc          gating_func     = GATING_FUNC_SOFTMAX,
                  float                  scale_w         = 0.0f) :
        ne(ne),
        n_expert_used(n_expert_used),
        with_norm(with_norm),
        bias_probs(bias_probs),
        gating_func(gating_func),
        scale_w(scale_w) {
        GGML_ASSERT(n_expert_used <= ne[0]);
    }

    std::string vars() override { return VARS_TO_STR6(ne, n_expert_used, with_norm, bias_probs, gating_func, scale_w); }

    std::string op_desc(ggml_tensor * t) override {
        GGML_UNUSED(t);
        return "TOPK_MOE";
    }

    bool run_whole_graph() override { return true; }

    // Builds the routing graph; returns the final weights tensor [1, n_expert_used, n_tokens].
    ggml_tensor * build_graph(ggml_context * ctx) override {
        const int n_expert = ne[0];
        const int n_tokens = ne[1];

        ggml_tensor * logits = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne.data());
        // GATING_FUNC_SOFTMAX_WEIGHT keeps the raw logits here; its softmax is
        // applied later, to the selected weights only.
        ggml_tensor * probs            =
            (gating_func == GATING_FUNC_SOFTMAX) ? ggml_soft_max(ctx, logits) :
            (gating_func == GATING_FUNC_SIGMOID) ? ggml_sigmoid(ctx, logits) : logits;
        ggml_set_name(probs, "probs");

        // The optional bias only affects which experts are selected.
        ggml_tensor * selection_probs = probs;
        if (bias_probs) {
            ggml_tensor * exp_probs_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, ne[0]);
            ggml_set_name(exp_probs_b, "exp_probs_b");
            selection_probs = ggml_add(ctx, probs, exp_probs_b);
            ggml_set_name(selection_probs, "selection_probs");
        }

        selected_experts = ggml_argsort_top_k(ctx, selection_probs, n_expert_used); // [n_expert_used, n_tokens]
        ggml_set_name(selected_experts, "selected_experts");

        // probs viewed as [1, n_expert, n_tokens] so get_rows gathers one value per selected expert
        weights = ggml_get_rows(ctx, ggml_reshape_3d(ctx, probs, 1, n_expert, n_tokens), selected_experts); // [1, n_expert_used, n_tokens]
        ggml_set_name(weights, "weights");

        if (gating_func == GATING_FUNC_SOFTMAX_WEIGHT) {
            weights = ggml_reshape_2d(ctx, weights, n_expert_used, n_tokens);
            weights = ggml_soft_max(ctx, weights);  // [n_expert_used, n_tokens]
            weights = ggml_reshape_3d(ctx, weights, 1, n_expert_used, n_tokens);
        }

        if (with_norm) {
            weights = ggml_reshape_2d(ctx, weights, n_expert_used, n_tokens);
            ggml_tensor * weights_sum = ggml_sum_rows(ctx, weights); // [1, n_tokens]
            ggml_set_name(weights_sum, "weights_sum");

            // Lower-bound the divisor at 6.103515625e-5 (= 2^-14) before dividing.
            weights_sum = ggml_clamp(ctx, weights_sum, 6.103515625e-5, INFINITY);
            weights = ggml_div(ctx, weights, weights_sum); // [n_expert_used, n_tokens]
            weights = ggml_reshape_3d(ctx, weights, 1, n_expert_used, n_tokens);
        }

        // scale_w == 0.0f means "no scaling".
        if (scale_w) {
            weights = ggml_scale(ctx, weights, scale_w);
        }

        // Name the final weights tensor too (the gathered tensor above was also named "weights").
        ggml_set_name(weights, "weights");
        return weights;
    }
    // Verify two outputs
    std::vector<ggml_tensor *> fusion_test_nodes() override { return { selected_experts, weights }; }

    // allow output in arbitrary order
    // Both outputs are sorted before the NMSE, so this compares the multisets
    // of values, not their positions.
    double err(const float * a, const float * b, size_t n) override {
        std::vector<float> a2(n);
        std::vector<float> b2(n);
        for (size_t i = 0; i < n; ++i) {
            a2[i] = a[i];
            b2[i] = b[i];
        }
        std::sort(a2.begin(), a2.end());
        std::sort(b2.begin(), b2.end());
        return nmse(a2.data(), b2.data(), n);
    }
};
5432
// Tests an FFN-style projection subgraph end to end:
// 1. up (and optionally gate) projections via ggml_mul_mat, or ggml_mul_mat_id when use_id;
// 2. optional biases;
// 3. an optional GLU combining gate and up;
// 4. a trailing elementwise op: add of `bias2` on the dense path, mul by a
//    per-row `scale` on the _id path.
// run_whole_graph() returns true. Presumably this targets backends that fuse
// the sequence — TODO confirm.
struct test_mul_mat_vec_fusion : public test_case {
    const ggml_type type;       // type of the gate/up weight matrices
    const ggml_glu_op glu_op;
    const int64_t m;            // number of input vectors (cur ne[1] on the dense path)
    const int64_t n;            // output rows per weight matrix
    const int64_t k;            // shared inner dimension
    const bool use_id;          // use ggml_mul_mat_id / ggml_add_id (expert matrices)
    const int n_mats;           // number of expert matrices (use_id only)
    const int n_used;           // experts used per input row (use_id only)
    const bool b;        // broadcast b matrix (only for use_id)
    const bool with_bias;
    const bool with_gate;       // if false, only the up projection is built
    std::array<int64_t, 2> batch_dims; // {channels, samples} for the dense path

    test_mul_mat_vec_fusion(ggml_type type, ggml_glu_op op, int64_t m, int64_t n, int64_t k,
                        bool use_id = false, int n_mats = 1, int n_used = 1, bool b = false, bool with_bias = false, bool with_gate = true,
                        std::array<int64_t, 2> batch_dims = {4, 2})
    : type(type), glu_op(op), m(m), n(n), k(k), use_id(use_id), n_mats(n_mats), n_used(n_used), b(b), with_bias(with_bias), with_gate(with_gate), batch_dims(batch_dims) {
        if (use_id) {
            GGML_ASSERT(n_used <= n_mats);
        }
    }

    std::string vars() override {
        return VARS_TO_STR12(type, glu_op, m, n, k, use_id, n_mats, n_used, b, with_bias, with_gate, batch_dims);
    }

    std::string op_desc(ggml_tensor * t) override {
        GGML_UNUSED(t);
        return "MUL_MAT_VEC_FUSION";
    }

    bool run_whole_graph() override { return true; }

    // Combines gate and up with the configured GLU op. Returns nullptr when
    // with_gate is false (callers use ffn_up directly in that case).
    ggml_tensor * build_gate(ggml_context * ctx, ggml_tensor * ffn_gate, ggml_tensor * ffn_up) {
        ggml_tensor * out = nullptr;
        if (with_gate) {
            if (glu_op == GGML_GLU_OP_SWIGLU_OAI) {
                constexpr float alpha = 1.702f;
                constexpr float limit = 7.0f;
                out = ggml_swiglu_oai(ctx, ffn_gate, ffn_up, alpha, limit);
            } else {
                out = ggml_glu_split(ctx, ffn_gate, ffn_up, glu_op);
            }
        }
        return out;
    }

    // Tensor creation order is part of the test: it determines graph node order
    // and the order in which initialize_tensors() visits the tensors.
    ggml_tensor * build_graph(ggml_context * ctx) override {
        if (!use_id) {
            // Dense path: batched mat-vec over [channels, samples].
            const int              channels = batch_dims[0];
            const int              samples  = batch_dims[1];
            std::array<int64_t, 4> ne       = { k, m, channels, samples };
            std::array<int64_t, 4> ne0      = { k, n, channels, samples };

            ggml_tensor * cur  = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne.data());
            ggml_tensor * gate = with_gate ? ggml_new_tensor(ctx, type, 4, ne0.data()) : nullptr;
            ggml_tensor * up   = ggml_new_tensor(ctx, type, 4, ne0.data());

            ggml_tensor * ffn_up = ggml_mul_mat(ctx, up, cur);
            if (with_bias) {
                // bias broadcast over ne[1] of the product
                std::array<int64_t, 4> bias_ne = { ffn_up->ne[0], 1, channels, samples };
                ggml_tensor * up_bias = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, bias_ne.data());
                ffn_up = ggml_add(ctx, ffn_up, up_bias);
            }

            ggml_tensor * ffn_gate = with_gate ? ggml_mul_mat(ctx, gate, cur) : nullptr;
            if (with_bias && with_gate) {
                std::array<int64_t, 4> bias_ne   = { ffn_gate->ne[0], 1, channels, samples };
                ggml_tensor * gate_bias = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, bias_ne.data());
                ffn_gate = ggml_add(ctx, ffn_gate, gate_bias);
            }

            ggml_tensor * out = with_gate ? build_gate(ctx, ffn_gate, ffn_up) : ffn_up;

            // trailing add (always present on the dense path)
            std::array<int64_t, 4> bias2_ne   = { out->ne[0], 1, channels, samples };
            ggml_tensor * bias2 = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, bias2_ne.data());
            out = ggml_add(ctx, out, bias2);

            ggml_set_name(out, "out");
            return out;
        } else {
            // Expert (_id) path: n_mats expert matrices selected per row by ids.
            ggml_tensor * gates = ggml_new_tensor_3d(ctx, type, k, n, n_mats);
            ggml_tensor * ups   = ggml_new_tensor_3d(ctx, type, k, n, n_mats);
            ggml_tensor * ids   = ggml_new_tensor_2d(ctx, GGML_TYPE_I32, n_mats, m);

            // keep only the first n_used ids of each row
            if (n_used != n_mats) {
                ids = ggml_view_2d(ctx, ids, n_used, m, ids->nb[1], 0);
            }

            // with b, one input vector per row is broadcast to all used experts
            ggml_tensor * cur = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, k, this->b ? 1 : n_used, m);
            ggml_set_name(cur, "cur");

            ggml_tensor * ffn_up = ggml_mul_mat_id(ctx, ups, cur, ids);
            if (with_bias) {
                ggml_tensor * up_bias_param = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, ffn_up->ne[0], n_mats);
                ffn_up = ggml_add_id(ctx, ffn_up, up_bias_param, ids);
            }

            ggml_tensor * ffn_gate = with_gate? ggml_mul_mat_id(ctx, gates, cur, ids) : nullptr;
            if (with_bias && with_gate) {
                ggml_tensor * gate_bias_param = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, ffn_gate->ne[0], n_mats);
                ffn_gate = ggml_add_id(ctx, ffn_gate, gate_bias_param, ids);
            }

            ggml_tensor * out = with_gate ? build_gate(ctx, ffn_gate, ffn_up) : ffn_up;

            // trailing mul by a scale broadcast along ne[0]
            std::array<int64_t, 4> scale_ne { 1, out->ne[1], out->ne[2], out->ne[3] };
            ggml_tensor * scale = ggml_new_tensor(ctx, out->type, 4, scale_ne.data());
            out = ggml_mul(ctx, out, scale);

            ggml_set_name(out, "out");
            return out;
        }
    }

    // Dense path: uniform init of every tensor. _id path: delegated to
    // init_mul_mat_id_tensors (presumably fills ids with valid expert indices
    // in [0, n_mats) — TODO confirm).
    void initialize_tensors(ggml_context * ctx) override {
        if (!use_id) {
            for (ggml_tensor * t = ggml_get_first_tensor(ctx); t != NULL; t = ggml_get_next_tensor(ctx, t)) {
                init_tensor_uniform(t);
            }
        } else {
            init_mul_mat_id_tensors(ctx, n_mats);
        }
    }

    double max_nmse_err() override {
        return 5e-3;
    }
};
5563
5564// GGML_OP_SUM
5565struct test_sum : public test_case {
5566    const ggml_type type;
5567    const std::array<int64_t, 4> ne;
5568    const std::array<int64_t, 4> permute;
5569    bool _use_permute;
5570
5571    std::string vars() override {
5572        std::string v = VARS_TO_STR2(type, ne);
5573        if (_use_permute) v += "," + VAR_TO_STR(permute);
5574        return v;
5575    }
5576
5577    test_sum(ggml_type type = GGML_TYPE_F32,
5578            std::array<int64_t, 4> ne = {10, 5, 4, 3},
5579            std::array<int64_t, 4> permute = {0, 0, 0, 0})
5580        : type(type), ne(ne), permute(permute),
5581            _use_permute(permute[0] + permute[1] + permute[2] + permute[3] > 0) {}
5582
5583    ggml_tensor * build_graph(ggml_context * ctx) override {
5584        ggml_tensor * a = ggml_new_tensor(ctx, type, 4, ne.data());
5585        ggml_set_param(a);
5586        ggml_set_name(a, "a");
5587
5588        if (_use_permute) {
5589            a = ggml_permute(ctx, a, permute[0], permute[1], permute[2], permute[3]);
5590            ggml_set_name(a, "a_permuted");
5591        }
5592
5593        ggml_tensor * out = ggml_sum(ctx, a);
5594        ggml_set_name(out, "out");
5595
5596        return out;
5597    }
5598
5599    float grad_eps() override {
5600        return 0.1f * sqrtf(ne[0]*ne[1]*ne[2]*ne[3]);
5601    }
5602
5603    // Don't center the distribution around zero. Helps to avoid catastrophic cancellation.
5604    void initialize_tensors(ggml_context * ctx) override {
5605        for (ggml_tensor * t = ggml_get_first_tensor(ctx); t != nullptr; t = ggml_get_next_tensor(ctx, t)) {
5606            init_tensor_uniform(t, -0.9f, 1.1f);
5607        }
5608    }
5609};
5610
5611// GGML_OP_SUM_ROWS
5612struct test_sum_rows : public test_case {
5613    const ggml_type type;
5614    const std::array<int64_t, 4> ne;
5615    const bool permute;
5616    const bool slice;
5617
5618    std::string vars() override {
5619        return VARS_TO_STR4(type, ne, permute, slice);
5620    }
5621
5622    test_sum_rows(ggml_type type = GGML_TYPE_F32,
5623            std::array<int64_t, 4> ne = {10, 5, 4, 3},
5624            bool permute = false, bool slice = false)
5625        : type(type), ne(ne), permute(permute), slice(slice) {}
5626
5627    ggml_tensor * build_graph(ggml_context * ctx) override {
5628        ggml_tensor * a = ggml_new_tensor(ctx, type, 4, ne.data());
5629        ggml_set_param(a);
5630        ggml_set_name(a, "a");
5631
5632        if (slice) {
5633            a = ggml_view_4d(ctx, a,
5634                             ne[0], ne[1], ne[2] / 2, ne[3] - 1,
5635                             a->nb[1], a->nb[2] * 2, a->nb[3], /*offset=*/a->nb[3]);
5636        }
5637        if (permute) {
5638            a = ggml_permute(ctx, a, 0, 2, 3, 1);
5639        }
5640
5641        ggml_tensor * out = ggml_sum_rows(ctx, a);
5642        ggml_set_name(out, "out");
5643
5644        return out;
5645    }
5646};
5647
5648// GGML_OP_MEAN
5649struct test_mean : public test_case {
5650    const ggml_type type;
5651    const std::array<int64_t, 4> ne;
5652
5653    std::string vars() override {
5654        return VARS_TO_STR2(type, ne);
5655    }
5656
5657    test_mean(ggml_type type = GGML_TYPE_F32,
5658            std::array<int64_t, 4> ne = {10, 5, 4, 3})
5659        : type(type), ne(ne) {}
5660
5661    ggml_tensor * build_graph(ggml_context * ctx) override {
5662        ggml_tensor * a = ggml_new_tensor(ctx, type, 4, ne.data());
5663        ggml_set_param(a);
5664        ggml_set_name(a, "a");
5665
5666        ggml_tensor * out = ggml_mean(ctx, a);
5667        ggml_set_name(out, "out");
5668
5669        return out;
5670    }
5671
5672    float grad_eps() override {
5673        return 0.1f * ne[0]*ne[1]*ne[2]*ne[3];
5674    }
5675
5676    // Don't center the distribution around zero. Helps to avoid catastrophic cancellation.
5677    void initialize_tensors(ggml_context * ctx) override {
5678        for (ggml_tensor * t = ggml_get_first_tensor(ctx); t != nullptr; t = ggml_get_next_tensor(ctx, t)) {
5679            init_tensor_uniform(t, -0.9f, 1.1f);
5680        }
5681    }
5682};
5683
5684// GGML_OP_UPSCALE
5685struct test_upscale : public test_case {
5686    const ggml_type type;
5687    const std::array<int64_t, 4> ne;
5688    const int32_t scale_factor;
5689    const bool transpose;
5690    const ggml_scale_mode mode;
5691
5692    std::string vars() override {
5693        return VARS_TO_STR5(type, ne, scale_factor, mode, transpose);
5694    }
5695
5696    test_upscale(ggml_type type = GGML_TYPE_F32,
5697            std::array<int64_t, 4> ne = {512, 512, 3, 1},
5698            int32_t scale_factor = 2, ggml_scale_mode mode = GGML_SCALE_MODE_NEAREST, bool transpose = false)
5699        : type(type), ne(ne), scale_factor(scale_factor), transpose(transpose), mode(mode) {}
5700
5701    ggml_tensor * build_graph(ggml_context * ctx) override {
5702        ggml_tensor * a = ggml_new_tensor(ctx, type, 4, ne.data());
5703        ggml_set_name(a, "a");
5704
5705        if (transpose) {
5706            a = ggml_transpose(ctx, a);
5707            ggml_set_name(a, "a_transposed");
5708        }
5709
5710        ggml_tensor * out = ggml_upscale(ctx, a, scale_factor, mode);
5711        ggml_set_name(out, "out");
5712
5713        return out;
5714    }
5715};
5716
5717// GGML_OP_UPSCALE (via ggml_interpolate)
5718struct test_interpolate : public test_case {
5719    const ggml_type type;
5720    const std::array<int64_t, 4> ne;
5721    const std::array<int64_t, 4> ne_tgt;
5722    const ggml_scale_mode mode = GGML_SCALE_MODE_NEAREST;
5723
5724    std::string vars() override {
5725        return VARS_TO_STR4(type, ne, ne_tgt, mode);
5726    }
5727
5728    test_interpolate(ggml_type type = GGML_TYPE_F32,
5729            std::array<int64_t, 4> ne     = {2, 5,  7, 11},
5730            std::array<int64_t, 4> ne_tgt = {5, 7, 11, 13},
5731            ggml_scale_mode mode = GGML_SCALE_MODE_NEAREST)
5732        : type(type), ne(ne), ne_tgt(ne_tgt), mode(mode) {}
5733
5734    ggml_tensor * build_graph(ggml_context * ctx) override {
5735        ggml_tensor * a = ggml_new_tensor(ctx, type, 4, ne.data());
5736        ggml_set_name(a, "a");
5737
5738        ggml_tensor * out = ggml_interpolate(ctx, a, ne_tgt[0], ne_tgt[1],ne_tgt[2], ne_tgt[3], mode);
5739        ggml_set_name(out, "out");
5740
5741        return out;
5742    }
5743};
5744
5745// GGML_OP_GROUP_NORM
5746struct test_group_norm : public test_case {
5747    const ggml_type type;
5748    const std::array<int64_t, 4> ne;
5749    const int32_t num_groups;
5750    const float eps;
5751
5752    std::string vars() override {
5753        return VARS_TO_STR4(type, ne, num_groups, eps);
5754    }
5755
5756    test_group_norm(ggml_type type = GGML_TYPE_F32,
5757            std::array<int64_t, 4> ne = {64, 64, 320, 1},
5758            int32_t num_groups = 32,
5759            float eps = 1e-6f)
5760        : type(type), ne(ne), num_groups(num_groups), eps(eps) {}
5761
5762    ggml_tensor * build_graph(ggml_context * ctx) override {
5763        ggml_tensor * a = ggml_new_tensor(ctx, type, 4, ne.data());
5764        ggml_set_name(a, "a");
5765
5766        ggml_tensor * out = ggml_group_norm(ctx, a, num_groups, eps);
5767        ggml_set_name(out, "out");
5768
5769        return out;
5770    }
5771};
5772
5773// GGML_OP_GROUP_NORM + GGML_OP_MUL + GGML_OP_ADD
5774struct test_group_norm_mul_add : public test_case {
5775    const ggml_type type;
5776    const std::array<int64_t, 4> ne;
5777    int num_groups;
5778    float eps;
5779
5780    std::string op_desc(ggml_tensor * t) override {
5781        GGML_UNUSED(t);
5782        return "GROUP_NORM_MUL_ADD";
5783    }
5784
5785    bool run_whole_graph() override { return true; }
5786
5787    std::string vars() override {
5788        return VARS_TO_STR4(type, ne, num_groups, eps);
5789    }
5790
5791    test_group_norm_mul_add(ggml_type type = GGML_TYPE_F32,
5792            std::array<int64_t, 4> ne = {128, 1, 1, 1},
5793            int num_groups = 4,
5794            float eps = 1e-5f)
5795        : type(type), ne(ne), num_groups(num_groups), eps(eps) {}
5796
5797    ggml_tensor * build_graph(ggml_context * ctx) override {
5798        ggml_tensor * a = ggml_new_tensor(ctx, type, 4, ne.data());
5799        ggml_tensor * w = ggml_new_tensor(ctx, type, 4, ne.data());
5800        ggml_tensor * b = ggml_new_tensor(ctx, type, 4, ne.data());
5801        ggml_set_param(a); ggml_set_param(w); ggml_set_param(b);
5802        ggml_set_name(a, "a"); ggml_set_name(w, "w"); ggml_set_name(b, "b");
5803        ggml_tensor * n = ggml_group_norm(ctx, a, num_groups, eps);
5804        ggml_tensor * m = ggml_mul(ctx, n, w);
5805        ggml_tensor * out = ggml_add(ctx, m, b);
5806        ggml_set_name(out, "out");
5807        return out;
5808    }
5809};
5810
5811// GGML_OP_L2_NORM
5812struct test_l2_norm : public test_case {
5813    const ggml_type type;
5814    const std::array<int64_t, 4> ne;
5815    const float eps;
5816
5817    std::string vars() override {
5818        return VARS_TO_STR2(type, ne);
5819    }
5820
5821    test_l2_norm(ggml_type type = GGML_TYPE_F32,
5822            std::array<int64_t, 4> ne = {64, 64, 320, 1},
5823            float eps = 1e-12f)
5824        : type(type), ne(ne), eps(eps) {}
5825
5826    ggml_tensor * build_graph(ggml_context * ctx) override {
5827        ggml_tensor * a = ggml_new_tensor(ctx, type, 4, ne.data());
5828        ggml_set_name(a, "a");
5829
5830        ggml_tensor * out = ggml_l2_norm(ctx, a, eps);
5831        ggml_set_name(out, "out");
5832
5833        return out;
5834    }
5835};
5836
5837// GGML_OP_ACC
5838struct test_acc : public test_case {
5839    const ggml_type type;
5840    const std::array<int64_t, 4> ne_a;
5841    const std::array<int64_t, 4> ne_b;
5842
5843    std::string vars() override {
5844        return VARS_TO_STR3(type, ne_a, ne_b);
5845    }
5846
5847    test_acc(ggml_type type = GGML_TYPE_F32,
5848            std::array<int64_t, 4> ne_a = {256, 17, 1, 1},
5849            std::array<int64_t, 4> ne_b = {256, 16, 1, 1})
5850        : type(type), ne_a(ne_a), ne_b(ne_b) {}
5851
5852    ggml_tensor * build_graph(ggml_context * ctx) override {
5853        ggml_tensor * a = ggml_new_tensor(ctx, type, 4, ne_a.data());
5854        ggml_set_param(a);
5855        ggml_set_name(a, "a");
5856
5857        ggml_tensor * b = ggml_new_tensor(ctx, type, 4, ne_b.data());
5858        ggml_set_param(b);
5859        ggml_set_name(b, "b");
5860
5861        ggml_tensor * out = ggml_acc(ctx, a, b, a->nb[1], a->nb[2], a->nb[3], b->nb[1]);
5862        ggml_set_name(out, "out");
5863
5864        return out;
5865    }
5866};
5867
5868// GGML_OP_PAD
5869struct test_pad : public test_case {
5870    const ggml_type type;
5871    const std::array<int64_t, 4> ne_a;
5872    const int pad_0;
5873    const int pad_1;
5874    const bool circular;
5875
5876    std::string vars() override {
5877        return VARS_TO_STR5(type, ne_a, pad_0, pad_1, circular);
5878    }
5879
5880    test_pad(ggml_type type = GGML_TYPE_F32,
5881            std::array<int64_t, 4> ne_a = {512, 512, 1, 1},
5882            int pad_0 = 1, int pad_1 = 1, bool circular = false)
5883        : type(type), ne_a(ne_a), pad_0(pad_0), pad_1(pad_1), circular(circular) {}
5884
5885    ggml_tensor * build_graph(ggml_context * ctx) override {
5886        ggml_tensor * a = ggml_new_tensor(ctx, type, 4, ne_a.data());
5887        ggml_set_name(a, "a");
5888
5889        ggml_tensor * out = circular
5890            ? ggml_pad_circular(ctx, a, pad_0, pad_1, 0, 0)
5891            : ggml_pad(ctx, a, pad_0, pad_1, 0, 0);
5892        ggml_set_name(out, "out");
5893
5894        return out;
5895    }
5896};
5897
5898// GGML_OP_PAD (with extension)
5899struct test_pad_ext : public test_case {
5900    const ggml_type type;
5901    const std::array<int64_t, 4> ne_a;
5902    const int lp0;
5903    const int rp0;
5904    const int lp1;
5905    const int rp1;
5906    const int lp2;
5907    const int rp2;
5908    const int lp3;
5909    const int rp3;
5910    const int tfrm; // 0 - none, 1 - non-cont, 2 - perm
5911    const bool circular;
5912
5913    std::string vars() override {
5914        return VARS_TO_STR12(type, ne_a, lp0, rp0, lp1, rp1, lp2, rp2, lp3, rp3, tfrm, circular);
5915    }
5916
5917    test_pad_ext(ggml_type type = GGML_TYPE_F32,
5918            std::array<int64_t, 4> ne_a = {512, 512, 3, 1},
5919            int lp0 = 1, int rp0 = 1, int lp1 = 1, int rp1 = 1,
5920            int lp2 = 1, int rp2 = 1, int lp3 = 1, int rp3 = 1,
5921            int tfrm = 0, bool circular = false)
5922        : type(type), ne_a(ne_a), lp0(lp0), rp0(rp0), lp1(lp1), rp1(rp1), lp2(lp2), rp2(rp2), lp3(lp3), rp3(rp3),
5923          tfrm(tfrm), circular(circular) {}
5924
5925    ggml_tensor * build_graph(ggml_context * ctx) override {
5926        ggml_tensor * a = ggml_new_tensor(ctx, type, 4, ne_a.data());
5927        ggml_set_name(a, "a");
5928
5929        if (tfrm == 1) {
5930            a = ggml_view_4d(ctx, a, (a->ne[0] + 1) / 2, (a->ne[1] + 1) / 2, (a->ne[2] + 1) / 2, (a->ne[3] + 1) / 2, a->nb[1], a->nb[2], a->nb[3], 0);
5931            ggml_set_name(a, "view of a");
5932        } else if (tfrm == 2) {
5933            a = ggml_permute(ctx, a, 2, 1, 0, 3);
5934            ggml_set_name(a, "permuted a");
5935        }
5936
5937        ggml_tensor * out = circular
5938            ? ggml_pad_ext_circular(ctx, a, lp0, rp0, lp1, rp1, lp2, rp2, lp3, rp3)
5939            : ggml_pad_ext         (ctx, a, lp0, rp0, lp1, rp1, lp2, rp2, lp3, rp3);
5940        ggml_set_name(out, "out");
5941
5942        return out;
5943    }
5944};
5945
5946// GGML_OP_PAD_REFLECT_1D
5947struct test_pad_reflect_1d : public test_case {
5948    const ggml_type type;
5949    const std::array<int64_t, 4> ne_a;
5950    const int pad_0;
5951    const int pad_1;
5952
5953    std::string vars() override {
5954        return VARS_TO_STR4(type, ne_a, pad_0, pad_1);
5955    }
5956
5957    test_pad_reflect_1d(ggml_type type = GGML_TYPE_F32,
5958            std::array<int64_t, 4> ne_a = {512, 34, 2, 1},
5959            int pad_0 = 10, int pad_1 = 9)
5960        : type(type), ne_a(ne_a), pad_0(pad_0), pad_1(pad_1)  {}
5961
5962    ggml_tensor * build_graph(ggml_context * ctx) override {
5963        ggml_tensor * a = ggml_new_tensor(ctx, type, 2, ne_a.data());
5964        ggml_set_name(a, "a");
5965
5966        ggml_tensor * out = ggml_pad_reflect_1d(ctx, a, pad_0, pad_1);
5967        ggml_set_name(out, "out");
5968
5969        return out;
5970    }
5971};
5972
5973// GGML_OP_ROLL
5974struct test_roll : public test_case {
5975    const int shift0;
5976    const int shift1;
5977    const int shift3;
5978    const int shift4;
5979
5980    std::string vars() override {
5981        return VARS_TO_STR4(shift0, shift1, shift3, shift4);
5982    }
5983
5984    test_roll(int shift0 = 3, int shift1 = -2, int shift3 = 1, int shift4 = -1)
5985        : shift0(shift0), shift1(shift1), shift3(shift3), shift4(shift4) {}
5986
5987    ggml_tensor * build_graph(ggml_context * ctx) override {
5988        int64_t ne[4] = {10, 5, 4, 3};
5989        ggml_tensor * a = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne);
5990        ggml_set_name(a, "a");
5991
5992        ggml_tensor * out = ggml_roll(ctx, a, shift0, shift1, shift3, shift4);
5993        ggml_set_name(out, "out");
5994
5995        return out;
5996    }
5997};
5998
5999// GGML_OP_ARANGE
6000struct test_arange : public test_case {
6001    const ggml_type type;
6002    const float start;
6003    const float stop;
6004    const float step;
6005
6006    std::string vars() override {
6007        return VARS_TO_STR4(type, start, stop, step);
6008    }
6009
6010    test_arange(ggml_type type = GGML_TYPE_F32,
6011            float start = 0.f, float stop = 10.f, float step = 1.f)
6012        : type(type), start(start), stop(stop), step(step)  {}
6013
6014    ggml_tensor * build_graph(ggml_context * ctx) override {
6015        ggml_tensor * out = ggml_arange(ctx, start, stop, step);
6016        ggml_set_name(out, "out");
6017
6018        return out;
6019    }
6020};
6021
6022// GGML_OP_TIMESTEP_EMBEDDING
6023struct test_timestep_embedding : public test_case {
6024    const ggml_type type;
6025    const std::array<int64_t, 4> ne_a;
6026    const int dim;
6027    const int max_period;
6028
6029    std::string vars() override {
6030        return VARS_TO_STR4(type, ne_a, dim, max_period);
6031    }
6032
6033    test_timestep_embedding(ggml_type type = GGML_TYPE_F32,
6034            std::array<int64_t, 4> ne_a = {2, 1, 1, 1},
6035            int dim = 320, int max_period=10000)
6036        : type(type), ne_a(ne_a), dim(dim), max_period(max_period)  {}
6037
6038    ggml_tensor * build_graph(ggml_context * ctx) override {
6039        ggml_tensor * a = ggml_new_tensor(ctx, type, 4, ne_a.data());
6040        ggml_set_name(a, "a");
6041
6042        ggml_tensor * out = ggml_timestep_embedding(ctx, a, dim, max_period);
6043        ggml_set_name(out, "out");
6044
6045        return out;
6046    }
6047};
6048
6049// GGML_OP_LEAKY_RELU
6050struct test_leaky_relu : public test_case {
6051    const ggml_type type;
6052    const std::array<int64_t, 4> ne_a;
6053    const float negative_slope;
6054
6055    std::string vars() override {
6056        return VARS_TO_STR3(type, ne_a, negative_slope);
6057    }
6058
6059    test_leaky_relu(ggml_type type = GGML_TYPE_F32,
6060            std::array<int64_t, 4> ne_a = {10, 5, 4, 3},
6061            float negative_slope = 0.1f)
6062        : type(type), ne_a(ne_a), negative_slope(negative_slope)  {}
6063
6064    ggml_tensor * build_graph(ggml_context * ctx) override {
6065        ggml_tensor * a = ggml_new_tensor(ctx, type, 4, ne_a.data());
6066        ggml_set_name(a, "a");
6067
6068        ggml_tensor * out = ggml_leaky_relu(ctx, a, negative_slope, true);
6069        ggml_set_name(out, "out");
6070
6071        return out;
6072    }
6073};
6074
6075// GGML_OP_FLASH_ATTN_EXT
6076struct test_flash_attn_ext : public test_case {
6077    const int64_t hsk; // K head size
6078    const int64_t hsv; // V head size
6079    const int64_t nh; // num heads
6080    const std::array<int64_t, 2> nr23; // repeat in dim 2 and 3, tests for grouped-query attention
6081    const int64_t kv; // kv size
6082    const int64_t nb; // batch size
6083
6084    const bool mask; // use mask
6085    const bool sinks; // use sinks
6086
6087    const float max_bias; // ALiBi
6088    const float logit_softcap; // Gemma 2
6089
6090    const ggml_prec prec;
6091    const ggml_type type_KV;
6092    std::array<int32_t, 4> permute;
6093
6094    std::string vars() override {
6095        return VARS_TO_STR13(hsk, hsv, nh, nr23, kv, nb, mask, sinks, max_bias, logit_softcap, prec, type_KV, permute);
6096    }
6097
6098    double max_nmse_err() override {
6099        return 5e-4;
6100    }
6101
6102    uint64_t op_flops(ggml_tensor * t) override {
6103        GGML_UNUSED(t);
6104        // Just counting matmul costs:
6105        // Q*K^T is nb x hsk x kv, P*V is nb x kv x hsv, per head
6106        return (2 * nh*nr23[0] * nb * (hsk + hsv) * kv)*nr23[1];
6107    }
6108
6109    test_flash_attn_ext(int64_t hsk = 128, int64_t hsv = 128, int64_t nh = 32, std::array<int64_t, 2> nr23 = {1, 1}, int64_t kv = 96, int64_t nb = 8,
6110                        bool mask = true, bool sinks = false, float max_bias = 0.0f, float logit_softcap = 0.0f, ggml_prec prec = GGML_PREC_F32,
6111                        ggml_type type_KV = GGML_TYPE_F16, std::array<int32_t, 4> permute = {0, 1, 2, 3})
6112        : hsk(hsk), hsv(hsv), nh(nh), nr23(nr23), kv(kv), nb(nb), mask(mask), sinks(sinks), max_bias(max_bias), logit_softcap(logit_softcap), prec(prec), type_KV(type_KV), permute(permute) {}
6113
6114    ggml_tensor * build_graph(ggml_context * ctx) override {
6115        const int64_t hsk_padded = GGML_PAD(hsk, ggml_blck_size(type_KV));
6116        const int64_t hsv_padded = GGML_PAD(hsv, ggml_blck_size(type_KV));
6117
6118        auto const &create_permuted = [&](ggml_type type, int64_t ne0, int64_t ne1, int64_t ne2, int64_t ne3, bool is_view) -> ggml_tensor * {
6119            int64_t ne[4] = {ne0, ne1, ne2, ne3};
6120            int64_t ne_perm[4];
6121            for (int i = 0; i < 4; ++i) {
6122                ne_perm[permute[i]] = ne[i];
6123            }
6124            ggml_tensor * t;
6125            if (is_view) {
6126                ggml_tensor * t0 = ggml_new_tensor_4d(ctx, type, ne_perm[0], 2*ne_perm[1], ne_perm[2], ne_perm[3]);
6127                t = ggml_view_4d(ctx, t0, ne_perm[0], ne_perm[1], ne_perm[2], ne_perm[3], t0->nb[1], t0->nb[2], t0->nb[3], 0);
6128            } else {
6129                t = ggml_new_tensor_4d(ctx, type, ne_perm[0], ne_perm[1], ne_perm[2], ne_perm[3]);
6130            }
6131            if (permute != std::array<int32_t, 4>{0, 1, 2, 3}) {
6132                t = ggml_permute(ctx, t, permute[0], permute[1], permute[2], permute[3]);
6133            }
6134            return t;
6135        };
6136
6137        ggml_tensor * q = create_permuted(GGML_TYPE_F32, hsk_padded, nb, nh*nr23[0], nr23[1], false);
6138        ggml_set_name(q, "q");
6139
6140        ggml_tensor * k = create_permuted(type_KV,       hsk_padded, kv, nh,         nr23[1], true); // the K tensor is usually a view of the K cache
6141        ggml_set_name(k, "k");
6142
6143        ggml_tensor * v = nullptr;
6144        if (hsk_padded == 576 && hsv_padded == 512) {
6145            // TODO: this branch should become a separate test case parameter instead of hardcoding this for these head shapes
6146
6147            // in this branch, the V cache is sub-view of the K cache. this is used by some MLA-based models
6148            // for more info:
6149            //   - https://github.com/ggml-org/llama.cpp/pull/13435
6150            //   - https://github.com/ggml-org/llama.cpp/pull/18953#issuecomment-3774948392
6151            //   - https://github.com/ggml-org/llama.cpp/pull/18986
6152            v = ggml_view_4d(ctx, k, hsv_padded, kv, nh, nr23[1], k->nb[1], k->nb[2], k->nb[3], 0);
6153        } else {
6154            v = create_permuted(type_KV,       hsv_padded, kv, nh,         nr23[1], true); // the V tensor is usually a view of the V cache
6155        }
6156        ggml_set_name(v, "v");
6157
6158        ggml_tensor * m = nullptr;
6159        if (mask) {
6160            m = ggml_new_tensor_4d(ctx, GGML_TYPE_F16, kv, nb, 1, nr23[1]);
6161            ggml_set_name(m, "m");
6162        }
6163
6164        ggml_tensor * s = nullptr;
6165        if (sinks) {
6166            s = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, q->ne[2]);
6167            ggml_set_name(s, "s");
6168        }
6169
6170        ggml_tensor * out = ggml_flash_attn_ext(ctx, q, k, v, m, 1.0f/sqrtf(hsk), max_bias, logit_softcap);
6171        ggml_flash_attn_ext_add_sinks(out, s);
6172        ggml_flash_attn_ext_set_prec (out, prec);
6173        ggml_set_name(out, "out");
6174
6175        return out;
6176    }
6177
6178    void initialize_tensors(ggml_context * ctx) override {
6179        for (ggml_tensor * t = ggml_get_first_tensor(ctx); t != NULL; t = ggml_get_next_tensor(ctx, t)) {
6180            if (strcmp(t->name, "s") == 0) {
6181                // make the sink values more noticable in order to trigger a test failure when the implementation is wrong
6182                init_tensor_uniform(t, -10.0f, 10.0f);
6183            } else if (strcmp(t->name, "m") == 0) {
6184                init_tensor_kq_mask(t);
6185            } else {
6186                init_tensor_uniform(t);
6187            }
6188        }
6189    }
6190
6191    bool grad_precise() override {
6192        return true;
6193    }
6194};
6195
6196// GGML_OP_CROSS_ENTROPY_LOSS
6197struct test_cross_entropy_loss : public test_case {
6198    const ggml_type type;
6199    const std::array<int64_t, 4> ne;
6200
6201    std::string vars() override {
6202        return VARS_TO_STR2(type, ne);
6203    }
6204
6205    test_cross_entropy_loss(ggml_type type = GGML_TYPE_F32,
6206            std::array<int64_t, 4> ne = {10, 5, 4, 3})
6207        : type(type), ne(ne) {}
6208
6209    ggml_tensor * build_graph(ggml_context * ctx) override {
6210        ggml_tensor * logits = ggml_new_tensor(ctx, type, 4, ne.data());
6211        ggml_set_param(logits);
6212        ggml_set_name(logits, "logits");
6213
6214        ggml_tensor * labels = ggml_new_tensor(ctx, type, 4, ne.data());
6215        // The labels are assumed to be constant -> no gradients.
6216        ggml_set_name(labels, "labels");
6217
6218        // Ensure labels add up to 1:
6219        labels = ggml_soft_max(ctx, labels);
6220        ggml_set_name(labels, "labels_normalized");
6221
6222        ggml_tensor * out = ggml_cross_entropy_loss(ctx, logits, labels);
6223        ggml_set_name(out, "out");
6224
6225        return out;
6226    }
6227
6228    void initialize_tensors(ggml_context * ctx) override {
6229        // For larger abs. diffs between logits softmax is more linear, therefore more precise num. gradients.
6230        for (ggml_tensor * t = ggml_get_first_tensor(ctx); t != NULL; t = ggml_get_next_tensor(ctx, t)) {
6231            init_tensor_uniform(t, -100.0f, 100.0f);
6232        }
6233    }
6234
6235    float grad_eps() override {
6236        return 1.0f;
6237    }
6238
6239    bool grad_precise() override {
6240        return true;
6241    }
6242};
6243
6244// GGML_OP_CROSS_ENTROPY_LOSS_BACK
6245struct test_cross_entropy_loss_back : public test_case {
6246    const ggml_type type;
6247    const std::array<int64_t, 4> ne;
6248
6249    std::string vars() override {
6250        return VARS_TO_STR2(type, ne);
6251    }
6252
6253    test_cross_entropy_loss_back(ggml_type type = GGML_TYPE_F32,
6254            std::array<int64_t, 4> ne = {10, 5, 4, 3})
6255        : type(type), ne(ne) {}
6256
6257    ggml_tensor * build_graph(ggml_context * ctx) override {
6258        ggml_tensor * grad = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 1);
6259        ggml_set_name(grad, "grad");
6260
6261        ggml_tensor * logits = ggml_new_tensor(ctx, type, 4, ne.data());
6262        ggml_set_name(logits, "logits");
6263
6264        ggml_tensor * labels = ggml_new_tensor(ctx, type, 4, ne.data());
6265        ggml_set_name(labels, "labels");
6266
6267        // Ensure labels add up to 1:
6268        labels = ggml_soft_max(ctx, labels);
6269        ggml_set_name(labels, "labels_normalized");
6270
6271        ggml_tensor * out = ggml_cross_entropy_loss_back(ctx, grad, logits, labels);
6272        ggml_set_name(out, "out");
6273
6274        return out;
6275    }
6276};
6277
6278// GGML_OP_OPT_STEP_ADAMW
6279struct test_opt_step_adamw : public test_case {
6280    const ggml_type type;
6281    const std::array<int64_t, 4> ne;
6282
6283    std::string vars() override {
6284        return VARS_TO_STR2(type, ne);
6285    }
6286
6287    test_opt_step_adamw(ggml_type type = GGML_TYPE_F32,
6288            std::array<int64_t, 4> ne = {10, 5, 4, 3})
6289        : type(type), ne(ne) {}
6290
6291    ggml_tensor * build_graph(ggml_context * ctx) override {
6292        ggml_tensor * a = ggml_new_tensor_4d(ctx, type, ne[0], ne[1], ne[2], ne[3]);
6293        ggml_set_param(a); // Despite tensor a having gradients the output tensor will not.
6294        ggml_set_name(a, "a");
6295
6296        ggml_tensor * grad = ggml_new_tensor_4d(ctx, type, ne[0], ne[1], ne[2], ne[3]);
6297        ggml_set_name(grad, "grad");
6298
6299        ggml_tensor * grad_m = ggml_new_tensor_4d(ctx, type, ne[0], ne[1], ne[2], ne[3]);
6300        ggml_set_name(grad_m, "grad_m");
6301
6302        ggml_tensor * grad_v = ggml_new_tensor_4d(ctx, type, ne[0], ne[1], ne[2], ne[3]);
6303        ggml_set_name(grad_v, "grad_v");
6304
6305        ggml_tensor * adamw_params = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 7);
6306        ggml_set_name(adamw_params, "adamw_params");
6307
6308        ggml_tensor * out = ggml_opt_step_adamw(ctx, a, grad, grad_m, grad_v, adamw_params);
6309        ggml_set_name(out, "out");
6310
6311        return out;
6312    }
6313
6314    void initialize_tensors(ggml_context * ctx) override {
6315        for (ggml_tensor * t = ggml_get_first_tensor(ctx); t != NULL; t = ggml_get_next_tensor(ctx, t)) {
6316            init_tensor_uniform(t, 0.0f, 1.0f); // grad_v and adamw_params need non-negative values.
6317        }
6318    }
6319
6320    bool grad_precise() override {
6321        return true;
6322    }
6323};
6324
6325// GGML_OP_OPT_STEP_SGD
6326struct test_opt_step_sgd : public test_case {
6327    const ggml_type              type;
6328    const std::array<int64_t, 4> ne;
6329
6330    std::string vars() override { return VARS_TO_STR2(type, ne); }
6331
6332    test_opt_step_sgd(ggml_type type = GGML_TYPE_F32,
6333            std::array<int64_t, 4> ne = { 10, 5, 4, 3 })
6334        : type(type), ne(ne) {}
6335
6336    ggml_tensor * build_graph(ggml_context * ctx) override {
6337        ggml_tensor * a = ggml_new_tensor_4d(ctx, type, ne[0], ne[1], ne[2], ne[3]);
6338        ggml_set_param(a);  // Despite tensor a having gradients the output tensor will not.
6339        ggml_set_name(a, "a");
6340
6341        ggml_tensor * grad = ggml_new_tensor_4d(ctx, type, ne[0], ne[1], ne[2], ne[3]);
6342        ggml_set_name(grad, "grad");
6343
6344        ggml_tensor * sgd_params = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 2);
6345        ggml_set_name(sgd_params, "sgd_params");
6346
6347        ggml_tensor * out = ggml_opt_step_sgd(ctx, a, grad, sgd_params);
6348
6349        ggml_set_name(out, "out");
6350
6351        return out;
6352    }
6353
6354    void initialize_tensors(ggml_context * ctx) override {
6355        for (ggml_tensor * t = ggml_get_first_tensor(ctx); t != NULL; t = ggml_get_next_tensor(ctx, t)) {
6356            init_tensor_uniform(t, 0.0f, 1.0f);  // sgd_params need non-negative values.
6357        }
6358    }
6359
6360    bool grad_precise() override {
6361        return true;
6362    }
6363};
6364
6365// GGML_OP_CUMSUM
6366struct test_cumsum : public test_case {
6367    const ggml_type              type;
6368    const std::array<int64_t, 4> ne;
6369
6370    std::string vars() override { return VARS_TO_STR2(type, ne); }
6371
6372    test_cumsum(ggml_type type = GGML_TYPE_F32,
6373            std::array<int64_t, 4> ne = { 10, 5, 4, 3 })
6374        : type(type), ne(ne) {}
6375
6376    ggml_tensor * build_graph(ggml_context * ctx) override {
6377        ggml_tensor * a = ggml_new_tensor_4d(ctx, type, ne[0], ne[1], ne[2], ne[3]);
6378        ggml_set_param(a);
6379        ggml_set_name(a, "a");
6380
6381        ggml_tensor * out = ggml_cumsum(ctx, a);
6382
6383        ggml_set_name(out, "out");
6384
6385        return out;
6386    }
6387
6388    void initialize_tensors(ggml_context * ctx) override {
6389        for (ggml_tensor * t = ggml_get_first_tensor(ctx); t != NULL; t = ggml_get_next_tensor(ctx, t)) {
6390            init_tensor_uniform(t, -1.0f, 1.0f);
6391        }
6392    }
6393};
6394
6395// GGML_OP_XIELU
6396struct test_xielu : public test_case {
6397    const ggml_type              type;
6398    const std::array<int64_t, 4> ne;
6399
6400    std::string vars() override { return VARS_TO_STR2(type, ne); }
6401
6402    test_xielu(ggml_type type = GGML_TYPE_F32,
6403            std::array<int64_t, 4> ne = { 10, 5, 4, 3 })
6404        : type(type), ne(ne) {}
6405
6406    ggml_tensor * build_graph(ggml_context * ctx) override {
6407        ggml_tensor * a = ggml_new_tensor_4d(ctx, type, ne[0], ne[1], ne[2], ne[3]);
6408        ggml_set_param(a);
6409        ggml_set_name(a, "a");
6410
6411        float alpha_n = 4.0f;
6412        float alpha_p = 20.0f;
6413        float beta = 0.5f;
6414        float eps = 0.0000001f;
6415
6416        ggml_tensor * out = ggml_xielu(ctx, a, alpha_n, alpha_p, beta, eps);
6417
6418        ggml_set_name(out, "out");
6419
6420        return out;
6421    }
6422
6423    void initialize_tensors(ggml_context * ctx) override {
6424        for (ggml_tensor * t = ggml_get_first_tensor(ctx); t != NULL; t = ggml_get_next_tensor(ctx, t)) {
6425            init_tensor_uniform(t, -1.0f, 1.0f);
6426        }
6427    }
6428};
6429
6430// GGML_OP_TRI
6431struct test_tri : public test_case {
6432    const ggml_type              type;
6433    const std::array<int64_t, 4> ne;
6434    const ggml_tri_type          tri_type;
6435
6436    std::string vars() override { return VARS_TO_STR3(type, ne, tri_type); }
6437
6438    test_tri(ggml_tri_type tri_type, ggml_type type = GGML_TYPE_F32,
6439            std::array<int64_t, 4> ne = { 10, 10, 4, 3 })
6440        : type(type), ne(ne), tri_type(tri_type) {
6441            GGML_ASSERT(ne[0] == ne[1]);
6442        }
6443
6444    ggml_tensor * build_graph(ggml_context * ctx) override {
6445        ggml_tensor * a = ggml_new_tensor_4d(ctx, type, ne[0], ne[1], ne[2], ne[3]);
6446        ggml_set_param(a);
6447        ggml_set_name(a, "a");
6448
6449        ggml_tensor * out = ggml_tri(ctx, a, tri_type);
6450
6451        ggml_set_name(out, "out");
6452
6453        return out;
6454    }
6455
6456    void initialize_tensors(ggml_context * ctx) override {
6457        for (ggml_tensor * t = ggml_get_first_tensor(ctx); t != NULL; t = ggml_get_next_tensor(ctx, t)) {
6458            init_tensor_uniform(t, -1.0f, 1.0f);
6459        }
6460    }
6461};
6462
6463// GGML_OP_FILL
6464struct test_fill : public test_case {
6465    const ggml_type              type;
6466    const std::array<int64_t, 4> ne;
6467    float                        c;
6468
6469    std::string vars() override { return VARS_TO_STR3(type, ne, c); }
6470
6471    test_fill(float c, ggml_type type = GGML_TYPE_F32,
6472            std::array<int64_t, 4> ne = { 10, 10, 4, 3 })
6473        : type(type), ne(ne), c(c) {}
6474
6475    ggml_tensor * build_graph(ggml_context * ctx) override {
6476        ggml_tensor * a = ggml_new_tensor_4d(ctx, type, ne[0], ne[1], ne[2], ne[3]);
6477        ggml_set_param(a);
6478        ggml_set_name(a, "a");
6479
6480        ggml_tensor * out = ggml_fill(ctx, a, c);
6481
6482        ggml_set_name(out, "out");
6483
6484        return out;
6485    }
6486};
6487
6488// GGML_OP_SOLVE_TRI
6489struct test_solve_tri : public test_case {
6490    const ggml_type              type;
6491    const std::array<int64_t, 4> ne_lhs;
6492    const std::array<int64_t, 4> ne_rhs;
6493
6494    std::string vars() override { return VARS_TO_STR3(type, ne_lhs, ne_rhs); }
6495
6496    uint64_t op_flops(ggml_tensor * t) override {
6497        GGML_UNUSED(t);
6498        int64_t n = ne_lhs[0];
6499        int64_t k = ne_rhs[0];
6500        int64_t batch = ne_lhs[2] * ne_lhs[3];
6501        // n * (n + 1) / 2 non-zero elements of lhs, 2 flops each, for each col of rhs
6502        return n * (n + 1) * k * batch;
6503    }
6504
6505    test_solve_tri(ggml_type type = GGML_TYPE_F32,
6506            std::array<int64_t, 4> ne_lhs = { 10, 10, 4, 3 },
6507            std::array<int64_t, 4> ne_rhs = { 3, 10, 4, 3 }
6508        )
6509        : type(type), ne_lhs(ne_lhs), ne_rhs(ne_rhs) {}
6510
6511    ggml_tensor * build_graph(ggml_context * ctx) override {
6512        ggml_tensor * a = ggml_new_tensor_4d(ctx, type, ne_lhs[0], ne_lhs[1], ne_lhs[2], ne_lhs[3]);
6513        ggml_set_param(a);
6514        ggml_set_name(a, "a");
6515
6516        ggml_tensor * b = ggml_new_tensor_4d(ctx, type, ne_rhs[0], ne_rhs[1], ne_rhs[2], ne_rhs[3]);
6517        ggml_set_param(b);
6518        ggml_set_name(b, "b");
6519
6520        ggml_tensor * out = ggml_solve_tri(ctx, a, b, true, true, false);
6521        ggml_set_name(out, "out");
6522
6523        return out;
6524    }
6525
6526    void initialize_tensors(ggml_context * ctx) override {
6527        for (ggml_tensor * t = ggml_get_first_tensor(ctx); t != NULL; t = ggml_get_next_tensor(ctx, t)) {
6528            if (strcmp(t->name, "a") == 0) {
6529                // note: avoid zeros in the diagonal
6530                init_tensor_tril(t, 0.1, 1.0f);
6531            } else {
6532                init_tensor_uniform(t, -1.0f, 1.0f);
6533            }
6534        }
6535    }
6536};
6537
6538// GGML_OP_DIAG
6539struct test_diag : public test_case {
6540    const ggml_type              type;
6541    const std::array<int64_t, 4> ne;
6542
6543    std::string vars() override { return VARS_TO_STR2(type, ne); }
6544
6545    test_diag(ggml_type type = GGML_TYPE_F32,
6546            std::array<int64_t, 4> ne = { 10, 1, 4, 3 })
6547        : type(type), ne(ne) {}
6548
6549    ggml_tensor * build_graph(ggml_context * ctx) override {
6550        GGML_ASSERT(ne[1] == 1);
6551        ggml_tensor * a = ggml_new_tensor_4d(ctx, type, ne[0], ne[1], ne[2], ne[3]);
6552        ggml_set_param(a);
6553        ggml_set_name(a, "a");
6554
6555        ggml_tensor * out = ggml_diag(ctx, a);
6556        ggml_set_name(out, "out");
6557
6558        return out;
6559    }
6560};
6561
6562
// Normalization flavor applied by test_llm::llm_build_norm.
enum llm_norm_type {
    LLM_NORM     = 0, // ggml_norm, uses f_norm_eps
    LLM_NORM_RMS = 1, // ggml_rms_norm, uses f_norm_rms_eps
};
6567
// Hyperparameters for the synthetic LLM graphs (test_llm and subclasses).
// The non-static members are filled positionally by aggregate initialization in the
// test_llm subclasses, so their declaration order must not change.
struct llama_hparams {
    uint32_t n_vocab;
    uint32_t n_embd;
    uint32_t n_head;
    uint32_t n_head_kv;
    static constexpr uint32_t n_layer = 1;
    uint32_t n_rot;
    uint32_t n_embd_head; // dimension of values (d_v)
    uint32_t n_ff;

    float f_norm_eps;
    float f_norm_rms_eps;

    // cparams
    static constexpr uint32_t n_ctx = 512; // user-specified context size
    static constexpr uint32_t n_ctx_orig = n_ctx;

    // batch
    int32_t n_tokens;

    // llm_build_context
    static constexpr int32_t n_kv    = 32; // size of KV cache to consider (n_kv <= n_ctx)
    static constexpr int32_t kv_head = 1;  // index of where we store new KV data in the cache

    // The attention views read n_kv slots out of caches laid out with n_ctx slots
    // (the V view strides rows by n_ctx), so these must hold for the views to be valid.
    static_assert(n_kv >= 0 && uint32_t(n_kv) <= n_ctx, "n_kv must be in [0, n_ctx]");
    static_assert(kv_head >= 0 && uint32_t(kv_head) < n_ctx, "kv_head must be a valid cache slot");

    uint32_t n_embd_gqa() const { // dimension of key embeddings across all k-v heads
        return n_embd_head * n_head_kv;
    }
};
6596
6597// LLM base class
6598struct test_llm : public test_case {
6599    llama_hparams hp;
6600
6601protected:
6602    test_llm(llama_hparams hp)
6603        : hp(std::move(hp)) {
6604    }
6605
6606public:
6607    struct ggml_tensor * llm_build_norm(
6608            struct ggml_context * ctx,
6609             struct ggml_tensor * cur,
6610             struct ggml_tensor * mw,
6611             struct ggml_tensor * mb,
6612                  llm_norm_type   type) {
6613        switch (type) {
6614            case LLM_NORM:     cur = ggml_norm    (ctx, cur, hp.f_norm_eps); break;
6615            case LLM_NORM_RMS: cur = ggml_rms_norm(ctx, cur, hp.f_norm_rms_eps); break;
6616        }
6617        cur = ggml_mul(ctx, cur, mw);
6618        if (mb) {
6619            cur = ggml_add(ctx, cur, mb);
6620        }
6621        return cur;
6622    }
6623
6624    void llm_build_kv_store(
6625            struct ggml_context * ctx,
6626             struct ggml_tensor * k_l,
6627             struct ggml_tensor * v_l,
6628             struct ggml_tensor * k_cur,
6629             struct ggml_tensor * v_cur) {
6630        // compute the transposed [n_tokens, n_embd] V matrix
6631        struct ggml_tensor * v_cur_t = ggml_transpose(ctx, ggml_reshape_2d(ctx, v_cur, hp.n_embd_gqa(), hp.n_tokens));
6632
6633        struct ggml_tensor * k_cache_view = ggml_view_1d(ctx, k_l, hp.n_tokens*hp.n_embd_gqa(),
6634                (ggml_row_size(k_l->type, hp.n_embd_gqa()))*hp.kv_head);
6635
6636        struct ggml_tensor * v_cache_view = ggml_view_2d(ctx, v_l, hp.n_tokens, hp.n_embd_gqa(),
6637                (  hp.n_ctx)*ggml_element_size(v_l),
6638                (hp.kv_head)*ggml_element_size(v_l));
6639
6640        // important: storing RoPE-ed version of K in the KV cache!
6641        ggml_cpy(ctx, k_cur,   k_cache_view);
6642        ggml_cpy(ctx, v_cur_t, v_cache_view);
6643    }
6644
6645    struct ggml_tensor * llm_build_kqv(
6646            struct ggml_context * ctx,
6647             struct ggml_tensor * k_l,
6648             struct ggml_tensor * v_l,
6649             struct ggml_tensor * q_cur,
6650             struct ggml_tensor * kq_mask,
6651                        float     kq_scale) {
6652        struct ggml_tensor * q = ggml_permute(ctx, q_cur, 0, 2, 1, 3);
6653
6654        struct ggml_tensor * k =
6655            ggml_view_3d(ctx, k_l,
6656                    hp.n_embd_head, hp.n_kv, hp.n_head_kv,
6657                    ggml_row_size(k_l->type, hp.n_embd_gqa()),
6658                    ggml_row_size(k_l->type, hp.n_embd_head),
6659                    0);
6660
6661        struct ggml_tensor * kq = ggml_mul_mat(ctx, k, q);
6662
6663        kq = ggml_soft_max_ext(ctx, kq, kq_mask, kq_scale, 0.0f);
6664
6665        // split cached v into n_head heads
6666        struct ggml_tensor * v =
6667            ggml_view_3d(ctx, v_l,
6668                    hp.n_kv, hp.n_embd_head, hp.n_head_kv,
6669                    ggml_element_size(v_l)*hp.n_ctx,
6670                    ggml_element_size(v_l)*hp.n_ctx*hp.n_embd_head,
6671                    0);
6672
6673        struct ggml_tensor * kqv = ggml_mul_mat(ctx, v, kq);
6674
6675        struct ggml_tensor * kqv_merged = ggml_permute(ctx, kqv, 0, 2, 1, 3);
6676
6677        struct ggml_tensor * cur = ggml_cont_2d(ctx, kqv_merged, hp.n_embd_head*hp.n_head, hp.n_tokens);
6678
6679        struct ggml_tensor * wo = ggml_new_tensor_2d(ctx, GGML_TYPE_Q4_0, hp.n_embd, hp.n_embd);
6680        cur = ggml_mul_mat(ctx, wo, cur);
6681
6682        return cur;
6683    }
6684
6685    void initialize_tensors(ggml_context * ctx) override {
6686        for (ggml_tensor * t = ggml_get_first_tensor(ctx); t != NULL; t = ggml_get_next_tensor(ctx, t)) {
6687            if (t->type == GGML_TYPE_I32) {
6688                // pos
6689                std::vector<int> data(hp.n_tokens);
6690                for (int i = 0; i < hp.n_tokens; i++) {
6691                    data[i] = rand() % hp.n_ctx;
6692                }
6693                ggml_backend_tensor_set(t, data.data(), 0, hp.n_tokens * sizeof(int));
6694            } else {
6695                init_tensor_uniform(t);
6696            }
6697        }
6698    }
6699};
6700
6701// Llama
6702struct test_llama : public test_llm {
6703    static constexpr float freq_base = 10000.0f;
6704    static constexpr float freq_scale = 1.0f;
6705    static constexpr float ext_factor = 0.0f;
6706    static constexpr float attn_factor = 1.0f;
6707    static constexpr float beta_fast = 32.0f;
6708    static constexpr float beta_slow = 1.0f;
6709    bool fused;
6710
6711    std::string op_desc(ggml_tensor * t) override {
6712        GGML_UNUSED(t);
6713        return "LLAMA";
6714    }
6715
6716    std::string vars() override {
6717        auto n_tokens = hp.n_tokens;
6718        return VARS_TO_STR1(n_tokens);
6719    }
6720
6721    double max_nmse_err() override {
6722        return 2e-3;
6723    }
6724
6725    bool run_whole_graph() override { return fused; }
6726
6727    test_llama(int n_tokens = 1, bool fused = false)
6728        : test_llm({
6729            /*n_vocab        =*/ 32000,
6730            /*n_embd         =*/ 3200,
6731            /*n_head         =*/ 32,
6732            /*n_head_kv      =*/ 32,
6733            /*n_rot          =*/ 100,
6734            /*n_embd_head    =*/ 100,
6735            /*n_ff           =*/ 8640,
6736            /*f_norm_eps     =*/ 0.f,
6737            /*f_norm_rms_eps =*/ 1e-5f,
6738            /*n_tokens       =*/ n_tokens,
6739        })
6740        , fused(fused)
6741    {
6742    }
6743
6744    ggml_tensor * build_graph(ggml_context * ctx) override {
6745        struct ggml_tensor * cur;
6746        struct ggml_tensor * inpL;
6747
6748        inpL = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, hp.n_embd, hp.n_tokens);
6749
6750        // inp_pos - contains the positions
6751        struct ggml_tensor * inp_pos = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, hp.n_tokens);
6752
6753        // KQ_mask (mask for 1 head, it will be broadcasted to all heads)
6754        struct ggml_tensor * KQ_mask = ggml_new_tensor_3d(ctx, GGML_TYPE_F16, hp.n_kv, hp.n_tokens, 1);
6755
6756        ggml_tensor * k_l = ggml_new_tensor_1d(ctx, GGML_TYPE_F16, 1638400);
6757        ggml_tensor * v_l = ggml_new_tensor_1d(ctx, GGML_TYPE_F16, 1638400);
6758
6759        for (uint32_t il = 0; il < hp.n_layer; ++il) {
6760            struct ggml_tensor * inpSA = inpL;
6761
6762            // norm
6763            ggml_tensor * attn_norm = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, hp.n_embd);
6764            cur = llm_build_norm(ctx, inpL, attn_norm, nullptr, LLM_NORM_RMS);
6765
6766            // self-attention
6767            {
6768                ggml_tensor * wq = ggml_new_tensor_2d(ctx, GGML_TYPE_Q4_0, hp.n_embd, hp.n_embd);
6769                ggml_tensor * wk = ggml_new_tensor_2d(ctx, GGML_TYPE_Q4_0, hp.n_embd, hp.n_embd_gqa());
6770                ggml_tensor * wv = ggml_new_tensor_2d(ctx, GGML_TYPE_Q4_0, hp.n_embd, hp.n_embd_gqa());
6771
6772                // compute Q and K and RoPE them
6773                struct ggml_tensor * Qcur = ggml_mul_mat(ctx, wq, cur);
6774                struct ggml_tensor * Kcur = ggml_mul_mat(ctx, wk, cur);
6775                struct ggml_tensor * Vcur = ggml_mul_mat(ctx, wv, cur);
6776
6777                Qcur = ggml_rope_ext(
6778                    ctx, ggml_reshape_3d(ctx, Qcur, hp.n_embd_head, hp.n_head,    hp.n_tokens), inp_pos, nullptr,
6779                    hp.n_rot, 0, hp.n_ctx_orig, freq_base, freq_scale,
6780                    ext_factor, attn_factor, beta_fast, beta_slow
6781                );
6782
6783                Kcur = ggml_rope_ext(
6784                    ctx, ggml_reshape_3d(ctx, Kcur, hp.n_embd_head, hp.n_head_kv, hp.n_tokens), inp_pos, nullptr,
6785                    hp.n_rot, 0, hp.n_ctx_orig, freq_base, freq_scale,
6786                    ext_factor, attn_factor, beta_fast, beta_slow
6787                );
6788
6789                llm_build_kv_store(ctx, k_l, v_l, Kcur, Vcur);
6790
6791                cur = llm_build_kqv(ctx, k_l, v_l, Qcur, KQ_mask, 1.0f/sqrtf(float(hp.n_embd_head)));
6792            }
6793
6794            struct ggml_tensor * ffn_inp = ggml_add(ctx, cur, inpSA);
6795
6796            // feed-forward network
6797            ggml_tensor * ffn_norm = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, hp.n_embd);
6798            cur = llm_build_norm(ctx, ffn_inp, ffn_norm, nullptr, LLM_NORM_RMS);
6799
6800            ggml_tensor * ffn_gate = ggml_new_tensor_2d(ctx, GGML_TYPE_Q4_0, hp.n_embd, hp.n_ff);
6801            ggml_tensor * ffn_down = ggml_new_tensor_2d(ctx, GGML_TYPE_Q4_0, hp.n_ff,   hp.n_embd);
6802            ggml_tensor * ffn_up   = ggml_new_tensor_2d(ctx, GGML_TYPE_Q4_0, hp.n_embd, hp.n_ff);
6803            struct ggml_tensor * tmp = ggml_mul_mat(ctx, ffn_up, cur);
6804            cur = ggml_mul_mat(ctx, ffn_gate, cur);
6805            cur = ggml_silu(ctx, cur);
6806            cur = ggml_mul(ctx, cur, tmp);
6807            cur = ggml_mul_mat(ctx, ffn_down, cur);
6808
6809            cur = ggml_add(ctx, cur, ffn_inp);
6810
6811            // input for next layer
6812            inpL = cur;
6813        }
6814
6815        cur = inpL;
6816
6817        ggml_tensor * output_norm = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, hp.n_embd);
6818        cur = llm_build_norm(ctx, cur, output_norm, nullptr, LLM_NORM_RMS);
6819
6820        // lm_head
6821        ggml_tensor * output = ggml_new_tensor_2d(ctx, GGML_TYPE_Q4_0, hp.n_embd, hp.n_vocab);
6822        cur = ggml_mul_mat(ctx, output, cur);
6823
6824        return cur;
6825    }
6826};
6827
6828// Falcon
6829struct test_falcon : public test_llm {
6830    static constexpr float freq_base = 10000.0f;
6831    static constexpr float freq_scale = 1.0f;
6832    static constexpr float ext_factor = 0.0f;
6833    static constexpr float attn_factor = 1.0f;
6834    static constexpr float beta_fast = 32.0f;
6835    static constexpr float beta_slow = 1.0f;
6836
6837    std::string op_desc(ggml_tensor * t) override {
6838        GGML_UNUSED(t);
6839        return "FALCON";
6840    }
6841
6842    std::string vars() override {
6843        auto n_tokens = hp.n_tokens;
6844        return VARS_TO_STR1(n_tokens);
6845    }
6846
6847    double max_nmse_err() override {
6848        return 2e-3;
6849    }
6850
6851    test_falcon(int n_tokens = 1)
6852        : test_llm({
6853            /*n_vocab        =*/ 32000,
6854            /*n_embd         =*/ 3200,
6855            /*n_head         =*/ 50,
6856            /*n_head_kv      =*/ 1,
6857            /*n_rot          =*/ 64,
6858            /*n_embd_head    =*/ 64,
6859            /*n_ff           =*/ 8640,
6860            /*f_norm_eps     =*/ 1e-5f,
6861            /*f_norm_rms_eps =*/ 0.f,
6862            /*n_tokens       =*/ n_tokens,
6863        }) {
6864    }
6865
6866    ggml_tensor * build_graph(ggml_context * ctx) override {
6867        struct ggml_tensor * cur;
6868        struct ggml_tensor * inpL;
6869
6870        inpL = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, hp.n_embd, hp.n_tokens);
6871
6872        // inp_pos - contains the positions
6873        struct ggml_tensor * inp_pos = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, hp.n_tokens);
6874
6875        // KQ_mask (mask for 1 head, it will be broadcasted to all heads)
6876        struct ggml_tensor * KQ_mask = ggml_new_tensor_3d(ctx, GGML_TYPE_F16, hp.n_kv, hp.n_tokens, 1);
6877
6878        ggml_tensor * k_l = ggml_new_tensor_1d(ctx, GGML_TYPE_F16, 1638400);
6879        ggml_tensor * v_l = ggml_new_tensor_1d(ctx, GGML_TYPE_F16, 1638400);
6880
6881        for (uint32_t il = 0; il < hp.n_layer; ++il) {
6882            // norm
6883            ggml_tensor * attn_norm_w = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, hp.n_embd);
6884            ggml_tensor * attn_norm_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, hp.n_embd);
6885            ggml_tensor * attn_norm = llm_build_norm(ctx, inpL, attn_norm_w, attn_norm_b, LLM_NORM);
6886
6887            // self-attention
6888            {
6889                cur = attn_norm;
6890
6891                ggml_tensor * wqkv = ggml_new_tensor_2d(ctx, GGML_TYPE_Q4_0, hp.n_embd, hp.n_embd + 2*hp.n_embd_gqa());
6892
6893                cur = ggml_mul_mat(ctx, wqkv, cur);
6894
6895                struct ggml_tensor * Qcur = ggml_cont(ctx, ggml_view_2d(ctx, cur, hp.n_embd,     hp.n_tokens, cur->nb[1], 0*sizeof(float)*(hp.n_embd)));
6896                struct ggml_tensor * Kcur = ggml_cont(ctx, ggml_view_2d(ctx, cur, hp.n_embd_gqa(), hp.n_tokens, cur->nb[1], 1*sizeof(float)*(hp.n_embd)));
6897                struct ggml_tensor * Vcur = ggml_cont(ctx, ggml_view_2d(ctx, cur, hp.n_embd_gqa(), hp.n_tokens, cur->nb[1], 1*sizeof(float)*(hp.n_embd + hp.n_embd_gqa())));
6898
6899                Qcur = ggml_reshape_3d(ctx, Qcur, hp.n_embd_head, hp.n_head,    hp.n_tokens);
6900                Kcur = ggml_reshape_3d(ctx, Kcur, hp.n_embd_head, hp.n_head_kv, hp.n_tokens);
6901
6902                // using mode = 2 for neox mode
6903                Qcur = ggml_rope_ext(
6904                    ctx, Qcur, inp_pos, nullptr, hp.n_rot, 2, hp.n_ctx_orig,
6905                    freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow
6906                );
6907
6908                Kcur = ggml_rope_ext(
6909                    ctx, Kcur, inp_pos, nullptr, hp.n_rot, 2, hp.n_ctx_orig,
6910                    freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow
6911                );
6912
6913                llm_build_kv_store(ctx, k_l, v_l, Kcur, Vcur);
6914
6915                cur = llm_build_kqv(ctx, k_l, v_l, Qcur, KQ_mask, 1.0f/sqrtf(float(hp.n_embd_head)));
6916            }
6917
6918            struct ggml_tensor * ffn_inp = cur;
6919
6920            // feed forward
6921            {
6922                ggml_tensor * ffn_up   = ggml_new_tensor_2d(ctx, GGML_TYPE_Q4_0, hp.n_embd, hp.n_ff);
6923                ggml_tensor * ffn_down = ggml_new_tensor_2d(ctx, GGML_TYPE_Q4_0, hp.n_ff, hp.n_embd);
6924                cur = attn_norm;
6925                cur = ggml_mul_mat(ctx, ffn_up, cur);
6926                cur = ggml_gelu(ctx, cur);
6927                cur = ggml_mul_mat(ctx, ffn_down, cur);
6928            }
6929
6930            cur = ggml_add(ctx, cur, ffn_inp);
6931
6932            cur = ggml_add(ctx, cur, inpL);
6933
6934            // input for next layer
6935            inpL = cur;
6936        }
6937
6938        cur = inpL;
6939
6940        ggml_tensor * output_norm   = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, hp.n_embd);
6941        ggml_tensor * output_norm_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, hp.n_embd);
6942        cur = llm_build_norm(ctx, cur, output_norm, output_norm_b, LLM_NORM);
6943
6944        // lm_head
6945        ggml_tensor * output = ggml_new_tensor_2d(ctx, GGML_TYPE_Q8_0, hp.n_embd, hp.n_vocab);
6946        cur = ggml_mul_mat(ctx, output, cur);
6947
6948        return cur;
6949    }
6950};
6951
6952
6953// ###########################################
6954// ## Section 3: GGML Op Test Instantiation ##
6955// ###########################################
6956static const ggml_type all_types[] = {
6957    GGML_TYPE_F32, GGML_TYPE_F16, GGML_TYPE_BF16,
6958    GGML_TYPE_Q4_0, GGML_TYPE_Q4_1,
6959    GGML_TYPE_Q5_0, GGML_TYPE_Q5_1,
6960    GGML_TYPE_Q8_0,
6961    GGML_TYPE_MXFP4,
6962    GGML_TYPE_Q2_K, GGML_TYPE_Q3_K,
6963    GGML_TYPE_Q4_K, GGML_TYPE_Q5_K,
6964    GGML_TYPE_Q6_K,
6965    // GGML_TYPE_TQ1_0, GGML_TYPE_TQ2_0, // TODO: implement for all backends
6966    GGML_TYPE_IQ2_XXS, GGML_TYPE_IQ2_XS, GGML_TYPE_IQ2_S,
6967    GGML_TYPE_IQ3_XXS, GGML_TYPE_IQ1_S, GGML_TYPE_IQ1_M,
6968    GGML_TYPE_IQ4_NL, GGML_TYPE_IQ3_S, GGML_TYPE_IQ4_XS,
6969};
6970
6971static const ggml_type base_types[] = {
6972    GGML_TYPE_F32, GGML_TYPE_F16,
6973    GGML_TYPE_Q8_0, // for I8MM tests
6974    GGML_TYPE_Q4_0,
6975    GGML_TYPE_Q4_1, // for I8MM tests
6976    GGML_TYPE_Q4_K,
6977    GGML_TYPE_MXFP4, // TODO: or "other"
6978    GGML_TYPE_IQ2_XXS
6979};
6980
6981static const ggml_type other_types[] = {
6982    GGML_TYPE_Q4_1,
6983    GGML_TYPE_Q5_0, GGML_TYPE_Q5_1,
6984    GGML_TYPE_Q8_0,
6985    GGML_TYPE_Q2_K, GGML_TYPE_Q3_K,
6986    GGML_TYPE_Q5_K,
6987    GGML_TYPE_Q6_K,
6988    // GGML_TYPE_TQ1_0, GGML_TYPE_TQ2_0, // TODO: implement for all backends
6989    GGML_TYPE_IQ2_XS, GGML_TYPE_IQ2_S,
6990    GGML_TYPE_IQ3_XXS, GGML_TYPE_IQ1_S, GGML_TYPE_IQ1_M,
6991    GGML_TYPE_IQ4_NL, GGML_TYPE_IQ3_S, GGML_TYPE_IQ4_XS,
6992    GGML_TYPE_BF16,
6993};
6994
6995#ifdef _MSC_VER
6996// Workaround long compile time with msvc
6997#pragma optimize("", off)
6998#endif
6999
7000// Test cases for evaluation: should try to cover edge cases while using small input sizes to keep the runtime low
7001static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
7002    std::vector<std::unique_ptr<test_case>> test_cases;
7003    std::default_random_engine rng(0);
7004
7005    // unary ops
7006    for (ggml_type type : {GGML_TYPE_F16, GGML_TYPE_F32}) {
7007        for (int v : {0, 1}) {
7008            for (int op = 0; op < GGML_UNARY_OP_COUNT; op++) {
7009                if (op == GGML_UNARY_OP_XIELU) {
7010                    continue; // need extra params, separate test
7011                }
7012                test_cases.emplace_back(new test_unary((ggml_unary_op) op, type, { 128, 2, 2, 2 }, v));
7013                test_cases.emplace_back(new test_unary((ggml_unary_op) op, type, { 5, 7, 11, 13 }, v));
7014            }
7015        }
7016    }
7017
7018    // glu ops
7019    for (ggml_type type : {GGML_TYPE_F16, GGML_TYPE_F32}) {
7020        for (int v : {0, 1}) {
7021            for (int op = 0; op < GGML_GLU_OP_COUNT; op++) {
7022                if (op == GGML_GLU_OP_SWIGLU_OAI) {
7023                    // SWIGLU_OAI is handled separately
7024                    continue;
7025                }
7026
7027                for (bool swapped : {false, true}) {
7028                    test_cases.emplace_back(new test_glu((ggml_glu_op) op, type, { 128, 2, 2, 2 }, v, swapped));
7029                    test_cases.emplace_back(new test_glu((ggml_glu_op) op, type, { 5, 7, 11, 13 }, v, swapped));
7030                }
7031
7032                test_cases.emplace_back(new test_glu_split((ggml_glu_op) op, type, { 128, 2, 2, 2 }, v));
7033                test_cases.emplace_back(new test_glu_split((ggml_glu_op) op, type, { 5, 7, 11, 13 }, v));
7034            }
7035        }
7036    }
7037
7038    for (int v : {0, 1}) {
7039        for (float alpha : {.5f, 1.702f}) {
7040            for (float limit : {2.0f, 7.0f}) {
7041                test_cases.emplace_back(new test_swiglu_oai(GGML_TYPE_F32, { 128, 2, 2, 2 }, v, alpha, limit));
7042            }
7043        }
7044    }
7045
7046    for (ggml_type type : {GGML_TYPE_F32, GGML_TYPE_Q4_0}) {
7047        test_cases.emplace_back(new test_get_rows(type, 300*256,   5,         4,   1,   2, false));
7048        test_cases.emplace_back(new test_get_rows(type,     256,   80000, 70000,   2,   1, false));
7049        test_cases.emplace_back(new test_get_rows(type,     256,   5,         4, 700, 100, false));
7050    }
7051
7052    test_cases.emplace_back(new test_get_rows(GGML_TYPE_F32, 1, 8, 2, 1, 1, false));
7053    for (ggml_type type : all_types) {
7054        for (int b : {1, 7}) {
7055            for (bool v : {false, true}) {
7056                test_cases.emplace_back(new test_get_rows(type, 256, 5, 4, b, 1, v));
7057            }
7058        }
7059    }
7060    for (int b : {1, 7}) {
7061        for (bool v : {false, true}) {
7062            test_cases.emplace_back(new test_get_rows(GGML_TYPE_I32, 256, 5, 4, b, 1, v));
7063        }
7064    }
7065
7066    test_cases.emplace_back(new test_get_rows_back(GGML_TYPE_F32, 1, 8, 2, 1, false));
7067    for (ggml_type type : all_types) {
7068        for (bool v : {false, true}) {
7069            test_cases.emplace_back(new test_get_rows_back(type, 256, 5, 4, 1, v));
7070        }
7071    }
7072    for (bool v : {false, true}) {
7073        test_cases.emplace_back(new test_get_rows_back(GGML_TYPE_I32, 256, 5, 4, 1, v));
7074    }
7075
7076    test_cases.emplace_back(new test_set_rows(GGML_TYPE_F32, GGML_TYPE_I64, { 1, 8, 1, 3 }, { 1, 1 }, 2, false));
7077    test_cases.emplace_back(new test_set_rows(GGML_TYPE_F32, GGML_TYPE_I32, { 1, 8, 1, 3 }, { 1, 1 }, 2, false));
7078    test_cases.emplace_back(new test_set_rows(GGML_TYPE_Q8_0, GGML_TYPE_I32, { 256, 5, 1, 3 }, { 1, 1, }, 1, false));
7079    for (ggml_type type : all_types) {
7080        for (int b : {1, 7}) {
7081            for (bool v : {false, true}) {
7082                test_cases.emplace_back(new test_set_rows(type, GGML_TYPE_I64, { 256, 5,  b, 3 }, { 1, 1, }, 1, v));
7083                test_cases.emplace_back(new test_set_rows(type, GGML_TYPE_I64, { 256, 11, 1, b }, { 2, 3, }, 7, v));
7084
7085                test_cases.emplace_back(new test_set_rows(type, GGML_TYPE_I64, { 3*ggml_blck_size(type), 3, b, 1 }, { 2, 3, }, 2, v));
7086
7087                if (ggml_blck_size(type) == 1) {
7088                    test_cases.emplace_back(new test_set_rows(type, GGML_TYPE_I64, { 31, 3, b, 1 }, { 2, 3, }, 2, v));
7089                    test_cases.emplace_back(new test_set_rows(type, GGML_TYPE_I64, { 33, 5, 1, b }, { 2, 3, }, 1, v));
7090                }
7091            }
7092        }
7093    }
7094
7095    for (int mode : { GGML_ROPE_TYPE_NORMAL, GGML_ROPE_TYPE_NEOX, GGML_ROPE_TYPE_MROPE, GGML_ROPE_TYPE_VISION }) {
7096        for (ggml_type type : {GGML_TYPE_F16, GGML_TYPE_F32}) {
7097            for (int ne2 : {1, 8, 512}) {
7098                test_cases.emplace_back(new test_rope_set_rows(type, GGML_TYPE_I64, { 128, 32, ne2, 1 }, mode));
7099                test_cases.emplace_back(new test_rope_set_rows(type, GGML_TYPE_I64, { 128, 32, ne2, 3 }, mode));
7100            }
7101        }
7102    }
7103
7104    for (ggml_type type_input : {GGML_TYPE_F32}) {
7105        for (ggml_op_pool pool_type : {GGML_OP_POOL_AVG, GGML_OP_POOL_MAX}) {
7106            for (int k0 : {1, 3}) {
7107                for (int k1 : {1, 3}) {
7108                    for (int s0 : {1, 2}) {
7109                        for (int s1 : {1, 2}) {
7110                            for (int p0 : {0, 1}) {
7111                                for (int p1 : {0, 1}) {
7112                                    test_cases.emplace_back(new test_pool2d(pool_type, type_input, {10, 10, 3, 1}, k0, k1, s0, s1, p0, p1));
7113                                }
7114                            }
7115                        }
7116                    }
7117                }
7118            }
7119        }
7120    }
7121
7122    for (ggml_type type_input : {GGML_TYPE_F32}) {
7123        for (ggml_op_pool pool_type : {GGML_OP_POOL_AVG, GGML_OP_POOL_MAX}) {
7124            for (int k0 : {1, 3}) {
7125                for (int s0 : {1, 2}) {
7126                    for (int p0 : {0, 1}) {
7127                        test_cases.emplace_back(new test_pool1d(pool_type, type_input, { 10,  3, 2, 1 }, k0, s0, p0));
7128                        test_cases.emplace_back(new test_pool1d(pool_type, type_input, { 11,  1, 3, 2 }, k0, s0, p0));
7129                        test_cases.emplace_back(new test_pool1d(pool_type, type_input, { 128, 2, 1, 3 }, k0, s0, p0));
7130                    }
7131                }
7132            }
7133        }
7134    }
7135
7136#if 0
7137    // >4GB im2col destination. Too slow to run by default.
7138    // Test cases taken from Wan2.1 T2V 1.3B.
7139    test_cases.emplace_back(new test_im2col   (GGML_TYPE_F32, GGML_TYPE_F32, GGML_TYPE_F32, {832, 480, 192, 4}, {3, 3, 192, 96}, 1, 1, 1, 1, 1, 1, true));
7140    test_cases.emplace_back(new test_im2col_3d(GGML_TYPE_F32, GGML_TYPE_F32, GGML_TYPE_F32, {834, 482, 6, 96},  {3, 3,3, 9216}, 96, 1, 1, 1, 0, 0, 0, 1, 1, 1, false));
7141#endif
7142
7143    // im2col 1D
7144    test_cases.emplace_back(new test_im2col(GGML_TYPE_F32, GGML_TYPE_F32, GGML_TYPE_F32, {3000, 128, 1, 1}, {3, 128, 1280, 1}, 1, 0, 1, 0, 1, 0, false));
7145    test_cases.emplace_back(new test_im2col(GGML_TYPE_F32, GGML_TYPE_F16, GGML_TYPE_F32, {3000, 128, 1, 1}, {3, 128, 1280, 1}, 1, 0, 1, 0, 1, 0, false));
7146    test_cases.emplace_back(new test_im2col(GGML_TYPE_F32, GGML_TYPE_F16, GGML_TYPE_F16, {3000, 128, 1, 1}, {3, 128, 1280, 1}, 1, 0, 1, 0, 1, 0, false));
7147    for (int s0 : {1, 3}) {
7148        for (int p0 : {0, 3}) {
7149            for (int d0 : {1, 3}) {
7150                test_cases.emplace_back(new test_im2col(
7151                    GGML_TYPE_F32, GGML_TYPE_F32, GGML_TYPE_F32, {20, 2, 2, 1}, {3, 2, 2, 1},
7152                    s0, 0, p0, 0, d0, 0, false));
7153            }
7154        }
7155    }
7156
7157    // im2col 2D
7158    test_cases.emplace_back(new test_im2col(GGML_TYPE_F32, GGML_TYPE_F32, GGML_TYPE_F32));
7159    test_cases.emplace_back(new test_im2col(GGML_TYPE_F32, GGML_TYPE_F16, GGML_TYPE_F32));
7160    test_cases.emplace_back(new test_im2col(GGML_TYPE_F32, GGML_TYPE_F16, GGML_TYPE_F16));
7161    for (int s0 : {1, 3}) {
7162        for (int s1 : {1, 3}) {
7163            for (int p0 : {0, 3}) {
7164                for (int p1 : {0, 3}) {
7165                    for (int d0 : {1, 3}) {
7166                        for (int d1 : {1, 3}) {
7167                            test_cases.emplace_back(new test_im2col(
7168                                GGML_TYPE_F32, GGML_TYPE_F32, GGML_TYPE_F32, {20, 20, 2, 2}, {3, 3, 2, 2},
7169                                s0, s1, p0, p1, d0, d1, true));
7170                        }
7171                    }
7172                }
7173            }
7174        }
7175    }
7176
7177    // extra tests for im2col 2D
7178    test_cases.emplace_back(new test_im2col(GGML_TYPE_F32, GGML_TYPE_F16, GGML_TYPE_F16, {12, 12, 1, 32}, {3, 3, 1, 32}, 1, 1, 1, 1, 1, 1, true));
7179    test_cases.emplace_back(new test_im2col(GGML_TYPE_F32, GGML_TYPE_F16, GGML_TYPE_F16, {12, 12, 2, 32}, {3, 3, 2, 32}, 1, 1, 1, 1, 1, 1, true));
7180    test_cases.emplace_back(new test_im2col(GGML_TYPE_F32, GGML_TYPE_F16, GGML_TYPE_F16, {12, 12, 1, 1024}, {3, 3, 1, 1024}, 1, 1, 1, 1, 1, 1, true));
7181    test_cases.emplace_back(new test_im2col(GGML_TYPE_F32, GGML_TYPE_F16, GGML_TYPE_F16, {12, 12, 2, 1024}, {3, 3, 2, 1024}, 1, 1, 1, 1, 1, 1, true));
7182    test_cases.emplace_back(new test_im2col(GGML_TYPE_F32, GGML_TYPE_F16, GGML_TYPE_F16, {12, 12, 1, 2048}, {3, 3, 1, 2048}, 1, 1, 1, 1, 1, 1, true));
7183    test_cases.emplace_back(new test_im2col(GGML_TYPE_F32, GGML_TYPE_F16, GGML_TYPE_F16, {12, 12, 2, 2048}, {3, 3, 2, 2048}, 1, 1, 1, 1, 1, 1, true));
7184    test_cases.emplace_back(new test_im2col(GGML_TYPE_F32, GGML_TYPE_F16, GGML_TYPE_F16, {12, 12, 1, 2560}, {3, 3, 1, 2560}, 1, 1, 1, 1, 1, 1, true));
7185    test_cases.emplace_back(new test_im2col(GGML_TYPE_F32, GGML_TYPE_F16, GGML_TYPE_F16, {12, 12, 2, 2560}, {3, 3, 2, 2560}, 1, 1, 1, 1, 1, 1, true));
7186    test_cases.emplace_back(new test_im2col(GGML_TYPE_F32, GGML_TYPE_F16, GGML_TYPE_F16, {5, 5, 1, 32}, {3, 4, 1, 32}, 1, 1, 0, 0, 1, 1, true));
7187    test_cases.emplace_back(new test_im2col(GGML_TYPE_F32, GGML_TYPE_F32, GGML_TYPE_F32, {2, 2, 1536, 729}, {2, 2, 1536, 4096}, 1, 1, 0, 0, 1, 1, true));
7188
7189    // im2col 3D
7190    test_cases.emplace_back(new test_im2col_3d(GGML_TYPE_F32, GGML_TYPE_F32, GGML_TYPE_F32));
7191    test_cases.emplace_back(new test_im2col_3d(GGML_TYPE_F32, GGML_TYPE_F16, GGML_TYPE_F32));
7192    test_cases.emplace_back(new test_im2col_3d(GGML_TYPE_F32, GGML_TYPE_F16, GGML_TYPE_F16));
7193    for (int s0 : {1, 3}) {
7194        for (int s1 : {1, 3}) {
7195            for (int s2 : {1, 3}) {
7196                for (int p0 : {0, 3}) {
7197                    for (int p1 : {0, 3}) {
7198                        for (int p2 : {0, 3}) {
7199                            for (int d0 : {1, 3}) {
7200                                for (int d1 : {1, 3}) {
7201                                    for (int d2 : {1, 3}) {
7202                                        for (int IC : {1, 3}) {
7203                                            for (bool v : {false, true}) {
7204                                                test_cases.emplace_back(new test_im2col_3d(
7205                                                    GGML_TYPE_F32, GGML_TYPE_F32, GGML_TYPE_F32, {20, 20, 10, 3}, {3, 3, 3, 3},
7206                                                    IC, s0, s1, s2, p0, p1, p2, d0, d1, d2, v));
7207                                            }
7208                                        }
7209                                    }
7210                                }
7211                            }
7212                        }
7213                    }
7214                }
7215            }
7216        }
7217    }
7218
7219// Conv_2D test cases
7220#ifdef DETAILED_TESTS
7221    // Probably we do not have enough time to execute these in the pipeline.
7222    uint32_t iwh_idx  = 0;
7223    uint32_t kwh_idx  = 1;
7224    uint32_t Cout_idx = 2;
7225    uint32_t Cin_idx  = 3;
7226    uint32_t B_idx    = 4;
7227
7228    std::vector<std::array<int, 5>> cases = {
7229  //{IWH, KWH, Cout, Cin, B}
7230  // K=CRS=NPQ=4096 conv_2d matmul performance
7231        {19,   4, 4096, 256, 16},
7232 // K=128, CRS=128, NPQ=4096
7233        { 19,  4, 128,  8,   16},
7234 // K=130, CRS=128, NPQ=4096
7235        { 19,  4, 130,  8,   16},
7236 // Edge case: K x CRS is small
7237        { 19,  2, 4,    4,   16},
7238 // A ConvNet's first layer
7239        { 224, 3, 8,    3,   1 },
7240 // A ConvNet's first layer with 2x2 convolution, and 1 channel
7241        { 224, 2, 8,    1,   1 },
7242 // A ConvNet's first layer with 2x2 convolution, and 1 channel, several images in the batch
7243        { 224, 2, 8,    1,   8 },
7244 // A middle layer of a ConvNet
7245        { 58,  3, 64,   32,  1 },
7246 // A middle layer of a ConvNet, several images in the batch
7247        { 58,  3, 64,   32,  8 },
7248 // A deep layer of a ConvNet, several images in the batch
7249        { 16,  3, 256,  128, 8 }
7250    };
7251
7252    for (auto kernel_type : {GGML_TYPE_F32, GGML_TYPE_F16}) {
7253        for (auto act_case : cases) {
7254            test_cases.emplace_back(new test_conv_2d(
7255                { act_case[iwh_idx], act_case[iwh_idx], act_case[Cin_idx], act_case[B_idx] },
7256                { act_case[kwh_idx], act_case[kwh_idx], act_case[Cin_idx], act_case[Cout_idx] },
7257                kernel_type, 1, 1, 0, 0, 1, 1, false));
7258        }
7259    }
7260#endif
7261
7262    // CONV_2D:
7263    auto calc_conv_output_size = [](int64_t ins, int64_t ks, int s, int p, int d) -> int64_t {
7264        return (ins + 2 * p - d * (ks - 1) - 1) / s + 1;
7265    };
7266
7267    //uint32_t s0 = 3;
7268    uint32_t s1 = 5;
7269    uint32_t p0 = 5;
7270    //uint32_t p1 = 2;
7271    uint32_t d0 = 2;
7272    uint32_t d1 = 4;
7273
7274    for (uint32_t s0 : { 1, 3 }) {
7275        for (uint32_t p1 : { 2, 5 }) {
7276            for (uint32_t Cin : { 1, 25 }) {
7277                for (uint32_t Cout : { 1, 12 }) {
7278                    for (uint32_t KH : { 1, 2, 3, 11 }) {
7279                        for (uint32_t KW : { 1, 2, 3, 11 }) {
7280                            for (uint32_t H : { 1, 133 }) {
7281                                for (uint32_t W : { 1, 141 }) {
7282                                    if (calc_conv_output_size(W, KW, s0, p0, d0) > 0 &&
7283                                        calc_conv_output_size(H, KH, s1, p1, d1) > 0) {
7284                                        for (auto kernel_type : {GGML_TYPE_F32, GGML_TYPE_F16}) {
7285                                            test_cases.emplace_back(new test_conv_2d(
7286                                                { W, H, Cin, 2 }, { KW, KH, Cin, Cout }, kernel_type, s0, s1, p0, p1, d0, d1, false));
7287                                        }
7288                                    }
7289                                }
7290                            }
7291                        }
7292                    }
7293                }
7294            }
7295        }
7296    }
7297
7298    // sycl backend will limit task global_range < MAX_INT
7299    // test cases for 2D im2col with large input W and H (occurs in stable-diffusion)
7300    // however these cases need to alloc more memory which may fail in some devices (Intel Arc770, etc.)
7301    // these cases are verified (pass) in Intel(R) Data Center GPU Max 1100 (sycl backend) and NV A30 (cuda backend)
7302    // test_cases.emplace_back(new test_im2col(GGML_TYPE_F32, GGML_TYPE_F16, GGML_TYPE_F16, {1024, 1024, 256, 1}, {3, 3, 256, 1}, 1, 1, 1, 1, 1, 1, true));
7303    // test_cases.emplace_back(new test_im2col(GGML_TYPE_F32, GGML_TYPE_F16, GGML_TYPE_F32, {1024, 1024, 256, 1}, {3, 3, 256, 1}, 1, 1, 1, 1, 1, 1, true));
7304
7305    test_cases.emplace_back(new test_conv_2d_dw({17, 34, 9, 1}, {3, 3, 1, 9}, 1, 0, 1, false));
7306    test_cases.emplace_back(new test_conv_2d_dw({17, 34, 9, 1}, {3, 3, 1, 9}, 1, 0, 1, true));
7307    test_cases.emplace_back(new test_conv_2d_dw({32, 8, 64, 1}, {3, 3, 1, 64}, 2, 1, 1, false));
7308    test_cases.emplace_back(new test_conv_2d_dw({32, 8, 64, 1}, {3, 3, 1, 64}, 2, 1, 1, true));
7309
7310    // CONV_3D
7311    auto calc_conv_output_size_3d = [](int64_t ins, int64_t ks, int s, int p, int d) -> int64_t {
7312        return (ins + 2 * p - d * (ks - 1) - 1) / s + 1;
7313    };
7314
7315    for (ggml_type kernel_type : {GGML_TYPE_F32, GGML_TYPE_F16}) {
7316        for (int N : {1, 2}) {
7317            for (int IC : {1, 3}) {
7318                for (int OC : {1, 4}) {
7319                    for (int s0 : {1, 2}) {
7320                        for (int p1 : {0, 1}) {
7321                            for (int d2 : {1, 2}) {
7322                                int64_t IW = 20, IH = 22, ID = 18;
7323                                int64_t KW = 3,  KH = 3,  KD = 3;
7324                                int s1 = s0, s2 = s0;
7325                                int p0 = p1, p2 = p1;
7326                                int d0 = d2, d1 = d2;
7327
7328                                if (calc_conv_output_size_3d(IW, KW, s0, p0, d0) <= 0 ||
7329                                    calc_conv_output_size_3d(IH, KH, s1, p1, d1) <= 0 ||
7330                                    calc_conv_output_size_3d(ID, KD, s2, p2, d2) <= 0) {
7331                                    continue;
7332                                }
7333                                test_cases.emplace_back(new test_conv_3d(
7334                                    N, IC, ID, IH, IW,
7335                                    OC, KD, KH, KW,
7336                                    s0, s1, s2, p0, p1, p2, d0, d1, d2,
7337                                    kernel_type));
7338
7339                                // Asymmetric kernel and params
7340                                int64_t asym_KW = 5, asym_KH = 1, asym_KD = 3;
7341                                int asym_s0 = 2, asym_s1 = 1, asym_s2 = 1;
7342                                int asym_p0 = 2, asym_p1 = 0, asym_p2 = 1;
7343                                int asym_d0 = 1, asym_d1 = 1, asym_d2 = 2;
7344
7345                                if (calc_conv_output_size_3d(IW, asym_KW, asym_s0, asym_p0, asym_d0) <= 0 ||
7346                                    calc_conv_output_size_3d(IH, asym_KH, asym_s1, asym_p1, asym_d1) <= 0 ||
7347                                    calc_conv_output_size_3d(ID, asym_KD, asym_s2, asym_p2, asym_d2) <= 0) {
7348                                    continue;
7349                                }
7350                                test_cases.emplace_back(new test_conv_3d(
7351                                    N, IC, ID, IH, IW,
7352                                    OC, asym_KD, asym_KH, asym_KW,
7353                                    asym_s0, asym_s1, asym_s2, asym_p0, asym_p1, asym_p2, asym_d0, asym_d1, asym_d2,
7354                                    kernel_type));
7355                            }
7356                        }
7357                    }
7358                }
7359            }
7360        }
7361        // Case with kernel size 1
7362        test_cases.emplace_back(new test_conv_3d(1, 4, 8, 8, 8, 8, 1, 1, 1, 1, 1, 1, 0, 0, 0, 1, 1, 1, kernel_type));
7363    }
7364
7365    for(uint32_t Cout : {1, 9}){
7366        for(uint32_t Cin : {1, 7}){
7367            for(uint32_t K : {1, 3, 1337}){
7368                for(uint32_t L : {1, 2, 13}){
7369                    for(uint32_t s0: {1, 2, 3}){
7370                        test_cases.emplace_back(new test_conv_transpose_1d({L,Cin,1,1}, {K,Cout,Cin,1}, s0, 0, 1));
7371                    }
7372                }
7373            }
7374        }
7375    }
7376
    // CONV_TRANSPOSE_1D: the default-constructed case plus hand-picked small shapes.
    test_cases.emplace_back(new test_conv_transpose_1d());
    test_cases.emplace_back(new test_conv_transpose_1d({3,2,1,1}, {2,3,2,1}, 3, 0, 1));
    test_cases.emplace_back(new test_conv_transpose_1d({3,2,1,1}, {2,3,2,1}, 2, 0, 1));
    test_cases.emplace_back(new test_conv_transpose_1d({3,2,1,1}, {2,3,2,1}, 1, 0, 1));
    test_cases.emplace_back(new test_conv_transpose_1d({3,2,1,1}, {3,2,2,1}, 2, 0, 1));
    test_cases.emplace_back(new test_conv_transpose_1d({3,2,1,1}, {3,2,2,1}, 1, 0, 1));
    test_cases.emplace_back(new test_conv_transpose_1d({3,2,1,1}, {3,1,2,1}, 1, 0, 1));
    test_cases.emplace_back(new test_conv_transpose_1d({2,1,1,1}, {3,1,1,1}, 1, 0, 1));

    // CONV_TRANSPOSE_2D: three shapes of increasing size; the trailing integer argument is 1 or 2.
    test_cases.emplace_back(new test_conv_transpose_2d({3, 2, 3, 1}, {2, 2, 1, 3}, 1));
    test_cases.emplace_back(new test_conv_transpose_2d({10, 10, 9, 1}, {3, 3, 1, 9}, 2));
    test_cases.emplace_back(new test_conv_transpose_2d({129, 63, 35, 1}, {3, 3, 48, 35}, 1));

    // COUNT_EQUAL on F32 inputs whose second dimension is 500 and 5000.
    test_cases.emplace_back(new test_count_equal(GGML_TYPE_F32, {4,  500, 1, 1}));
    test_cases.emplace_back(new test_count_equal(GGML_TYPE_F32, {4, 5000, 1, 1}));

    // ARGMAX on F32 inputs over a range of {ne0, ne1} shapes, including ne0 values
    // that are not powers of two (100, 2000, 5438).
    test_cases.emplace_back(new test_argmax(GGML_TYPE_F32, {32,    1, 1, 1}));
    test_cases.emplace_back(new test_argmax(GGML_TYPE_F32, {32,  513, 1, 1}));
    test_cases.emplace_back(new test_argmax(GGML_TYPE_F32, {100,  10, 1, 1}));
    test_cases.emplace_back(new test_argmax(GGML_TYPE_F32, {1024, 10, 1, 1}));
    test_cases.emplace_back(new test_argmax(GGML_TYPE_F32, {1024, 12, 1, 1}));
    test_cases.emplace_back(new test_argmax(GGML_TYPE_F32, {2000, 10, 1, 1}));
    test_cases.emplace_back(new test_argmax(GGML_TYPE_F32, {5438,  3, 1, 1}));
7400
    // REPEAT: a {10, 5, 4, ne3} source with the second array (presumably per-dimension
    // repeat factors) set to 2 in one dimension at a time; the I32/I16 cases cover non-float types.
    for (int ne3 : {1, 3}) { // CUDA backward pass only supports ne3 == 1
        test_cases.emplace_back(new test_repeat(GGML_TYPE_F32, {10, 5, 4, ne3}, {1, 1, 1, 1}));
        test_cases.emplace_back(new test_repeat(GGML_TYPE_F32, {10, 5, 4, ne3}, {2, 1, 1, 1}));
        test_cases.emplace_back(new test_repeat(GGML_TYPE_F32, {10, 5, 4, ne3}, {1, 2, 1, 1}));
        test_cases.emplace_back(new test_repeat(GGML_TYPE_F32, {10, 5, 4, ne3}, {1, 1, 2, 1}));
        test_cases.emplace_back(new test_repeat(GGML_TYPE_F32, {10, 5, 4, ne3}, {1, 1, 1, 2}));
        test_cases.emplace_back(new test_repeat(GGML_TYPE_I32, {10, 5, 4, ne3}, {2, 1, 1, 1}));
        test_cases.emplace_back(new test_repeat(GGML_TYPE_I16, {10, 5, 4, ne3}, {1, 1, 1, 2}));
    }

    // REPEAT_BACK: the same one-dimension-at-a-time patterns on an {8, 6, 4, 2} tensor,
    // each registered with view = false and view = true.
    for (bool view : {false, true}) {
        test_cases.emplace_back(new test_repeat_back(GGML_TYPE_F32, {8, 6, 4, 2}, {1, 1, 1, 1}, view));
        test_cases.emplace_back(new test_repeat_back(GGML_TYPE_F32, {8, 6, 4, 2}, {2, 1, 1, 1}, view));
        test_cases.emplace_back(new test_repeat_back(GGML_TYPE_F32, {8, 6, 4, 2}, {1, 2, 1, 1}, view));
        test_cases.emplace_back(new test_repeat_back(GGML_TYPE_F32, {8, 6, 4, 2}, {1, 1, 2, 1}, view));
        test_cases.emplace_back(new test_repeat_back(GGML_TYPE_F32, {8, 6, 4, 2}, {1, 1, 1, 2}, view));
    }

    // DUP: default shape per type, then cases with an extra permutation-like array.
    test_cases.emplace_back(new test_dup(GGML_TYPE_F32));
    test_cases.emplace_back(new test_dup(GGML_TYPE_F16));
    test_cases.emplace_back(new test_dup(GGML_TYPE_I32));
    test_cases.emplace_back(new test_dup(GGML_TYPE_I16));
    test_cases.emplace_back(new test_dup(GGML_TYPE_F32, {10, 10, 5, 1}, {0, 2, 1, 3}));
    test_cases.emplace_back(new test_dup(GGML_TYPE_F16, {10, 10, 5, 1}, {0, 2, 1, 3})); // dup by rows
    test_cases.emplace_back(new test_dup(GGML_TYPE_F32, {10, 10, 5, 1}, {1, 0, 2, 3}));
    test_cases.emplace_back(new test_dup(GGML_TYPE_F16, {10, 10, 5, 1}, {1, 0, 2, 3})); // dup dst not-contiguous
    test_cases.emplace_back(new test_dup(GGML_TYPE_I16, {10,  8, 3, 1}, {0, 2, 1, 3}));
    test_cases.emplace_back(new test_dup(GGML_TYPE_I16, {10,  8, 3, 1}, {1, 2, 0, 3}));
7429
7430    for (int dim = 1; dim < GGML_MAX_DIMS; ++dim) {
7431        test_cases.emplace_back(new test_set(GGML_TYPE_F32, GGML_TYPE_F32, {6, 5, 4, 3}, dim));
7432    }
7433
7434    for (int dim = 1; dim < GGML_MAX_DIMS; ++dim) {
7435        test_cases.emplace_back(new test_set(GGML_TYPE_I32, GGML_TYPE_I32, {6, 5, 4, 3}, dim));
7436    }
7437
    // same-type copy
    // For every type, copy a {k*nk, 2, 3, 4} tensor, where nk is the type's block size
    // and k = 1..3. Each is registered plain and with one or two extra permutation-like arrays.
    for (ggml_type type : all_types) {
        const auto nk = ggml_blck_size(type);

        for (int k = 1; k < 4; ++k) {
            test_cases.emplace_back(new test_cpy(type, type, {k*nk, 2, 3, 4}));
            test_cases.emplace_back(new test_cpy(type, type, {k*nk, 2, 3, 4}, {0, 2, 1, 3}));
            test_cases.emplace_back(new test_cpy(type, type, {k*nk, 2, 3, 4}, {0, 3, 1, 2}, {0, 2, 1, 3}));
        }
    }

    // conversions from F16/BF16/F32 to every type
    for (ggml_type type_src : {GGML_TYPE_F16, GGML_TYPE_BF16, GGML_TYPE_F32}) {
        for (ggml_type type_dst : all_types) {
            test_cases.emplace_back(new test_cpy(type_src, type_dst, {256, 4, 4, 4}));
            test_cases.emplace_back(new test_cpy(type_src, type_dst, {256, 2, 3, 4}, {0, 2, 1, 3})); // cpy by rows
        }
    }
    // conversions from every type to F32
    for (ggml_type type_src : all_types) {
        for (ggml_type type_dst : {GGML_TYPE_F32}) {
            test_cases.emplace_back(new test_cpy(type_src, type_dst, {256, 4, 4, 4}));
            test_cases.emplace_back(new test_cpy(type_src, type_dst, {256, 2, 3, 4}, {0, 2, 1, 3})); // cpy by rows
        }
    }
    // F16/F32 pairs with the {1, 0, 2, 3} argument
    for (ggml_type type_src : {GGML_TYPE_F16, GGML_TYPE_F32}) {
        for (ggml_type type_dst : {GGML_TYPE_F16, GGML_TYPE_F32}) {
            test_cases.emplace_back(new test_cpy(type_src, type_dst, {256, 2, 3, 4}, {1, 0, 2, 3})); // cpy not-contiguous
        }
    }
    // F32 <-> I32 conversions, plain and with the {1, 0, 2, 3} argument
    test_cases.emplace_back(new test_cpy(GGML_TYPE_F32, GGML_TYPE_I32, {256, 2, 3, 4}));
    test_cases.emplace_back(new test_cpy(GGML_TYPE_F32, GGML_TYPE_I32, {256, 2, 3, 4}, {1, 0, 2, 3}));
    test_cases.emplace_back(new test_cpy(GGML_TYPE_I32, GGML_TYPE_F32, {256, 2, 3, 4}));
    test_cases.emplace_back(new test_cpy(GGML_TYPE_I32, GGML_TYPE_F32, {256, 2, 3, 4}, {1, 0, 2, 3}));
    // same-type copies with a trailing `true` flag (its meaning is defined by test_cpy)
    test_cases.emplace_back(new test_cpy(GGML_TYPE_F16, GGML_TYPE_F16, {256, 4, 3, 1}, {0, 0, 0, 0}, {0, 0, 0, 0}, true));
    test_cases.emplace_back(new test_cpy(GGML_TYPE_F32, GGML_TYPE_F32, {256, 4, 3, 1}, {0, 0, 0, 0}, {0, 0, 0, 0}, true));
    test_cases.emplace_back(new test_cpy(GGML_TYPE_F32, GGML_TYPE_F32, {256, 4, 3, 3}, {0, 0, 0, 0}, {0, 0, 0, 0}, true));
    test_cases.emplace_back(new test_cpy(GGML_TYPE_BF16, GGML_TYPE_BF16, {256, 4, 3, 1}, {0, 0, 0, 0}, {0, 0, 0, 0}, true));
    test_cases.emplace_back(new test_cpy(GGML_TYPE_F16, GGML_TYPE_F16, {256, 4, 1, 1}, {0, 0, 0, 0}, {0, 0, 0, 0}, true));
    test_cases.emplace_back(new test_cpy(GGML_TYPE_F32, GGML_TYPE_F32, {256, 4, 1, 1}, {0, 0, 0, 0}, {0, 0, 0, 0}, true));
    test_cases.emplace_back(new test_cpy(GGML_TYPE_BF16, GGML_TYPE_BF16, {256, 4, 1, 1}, {0, 0, 0, 0}, {0, 0, 0, 0}, true));
    test_cases.emplace_back(new test_cpy(GGML_TYPE_I32, GGML_TYPE_I32, {256, 4, 1, 1}, {0, 0, 0, 0}, {0, 0, 0, 0}, true));
    // same-type copies with the {1, 2, 0, 3} argument and an explicit all-zero fourth argument
    test_cases.emplace_back(new test_cpy(GGML_TYPE_I32, GGML_TYPE_I32, {256, 1, 4, 1}, {1, 2, 0, 3}, {0, 0, 0, 0}));
    test_cases.emplace_back(new test_cpy(GGML_TYPE_F32, GGML_TYPE_F32, {256, 1, 4, 1}, {1, 2, 0, 3}, {0, 0, 0, 0}));
7480
7481    for (ggml_type type_dst : { GGML_TYPE_F32, GGML_TYPE_I32, GGML_TYPE_F16, GGML_TYPE_BF16 }) {
7482        for (bool use_view_slice : { true, false }) {
7483            for (std::array<int64_t, 4> ne : std::initializer_list<std::array<int64_t, 4>>{ {2, 1, 1, 1}, {2, 1, 3, 5},
7484                {2, 3, 5, 7}, {1, 4, 4, 1}, {1, 8, 17, 1}, {10, 10, 10, 1} }) {
7485                if (use_view_slice && (type_dst == GGML_TYPE_F16 || type_dst == GGML_TYPE_BF16)) {
7486                    continue; // TODO: add after WebGPU is fixed
7487                }
7488                test_cases.emplace_back(new test_cont(type_dst, ne, use_view_slice));
7489            }
7490        }
7491    }
7492
7493    auto add_test_bin_bcast = [&](ggml_type type, std::array<int64_t, 4> ne, std::array<int, 4> nr, bool perm1 = false) {
7494        for (auto op : {ggml_add, ggml_sub, ggml_mul, ggml_div}) {
7495            test_cases.emplace_back(new test_bin_bcast(op, type, ne, nr, 1, perm1));
7496        }
7497    };
    // Broadcasting binary-op cases (via add_test_bin_bcast) for F16 and F32.
    for (ggml_type type : {GGML_TYPE_F16, GGML_TYPE_F32}) {
        // small shapes with a set of `nr` patterns, each with perm1 = false and perm1 = true
        for (bool perm1 : {false, true}) {
            add_test_bin_bcast(type, {1,  1,   8,   1}, {1,  1, 1, 1}, perm1);
            add_test_bin_bcast(type, {1,  1,   1,   1}, {32, 1, 1, 1}, perm1);
            add_test_bin_bcast(type, {1,  1, 320, 320}, {1,  1, 1, 1}, perm1);
            add_test_bin_bcast(type, {10, 5,   1,   1}, {1,  1, 1, 1}, perm1);
            add_test_bin_bcast(type, {10, 5,   4,   1}, {1,  1, 1, 1}, perm1);
            add_test_bin_bcast(type, {10, 5,   4,   3}, {1,  1, 1, 1}, perm1);
            add_test_bin_bcast(type, {10, 5,   4,   3}, {2,  1, 1, 1}, perm1);
            add_test_bin_bcast(type, {10, 5,   4,   3}, {1,  2, 1, 1}, perm1);
            add_test_bin_bcast(type, {10, 5,   4,   3}, {1,  1, 2, 1}, perm1);
            add_test_bin_bcast(type, {10, 5,   4,   3}, {1,  1, 1, 2}, perm1);
            add_test_bin_bcast(type, {10, 5,   4,   3}, {1,  1, 2, 2}, perm1);
            add_test_bin_bcast(type, {10, 5,   4,   3}, {1,  2, 2, 2}, perm1);
            add_test_bin_bcast(type, {10, 5,   4,   3}, {2,  2, 2, 2}, perm1);
        }

        // test case for k_bin_bcast_unravel in CUDA backend
        add_test_bin_bcast(type, {1, 1, 65536, 1}, {256, 1, 1, 1});

        // stable diffusion
        add_test_bin_bcast(type, {1280, 1, 1, 1}, {1, 1, 1, 1});
        add_test_bin_bcast(type, {1280, 1, 1, 1}, {1, 16, 16, 1});
        add_test_bin_bcast(type, {1280, 16, 16, 1}, {1, 1, 1, 1});
        add_test_bin_bcast(type, {1280, 1, 1, 1}, {1, 256, 1, 1});
        add_test_bin_bcast(type, {1, 1, 1280, 1}, {16, 16, 1, 1});
        add_test_bin_bcast(type, {16, 16, 1280, 1}, {1, 1, 1, 1});
        add_test_bin_bcast(type, {1, 1, 1920, 1}, {16, 16, 1, 1});
        add_test_bin_bcast(type, {1, 1, 2560, 1}, {16, 16, 1, 1});
        add_test_bin_bcast(type, {1, 1, 1280, 1}, {32, 32, 1, 1});
        add_test_bin_bcast(type, {1, 1, 1920, 1}, {32, 32, 1, 1});
        add_test_bin_bcast(type, {1, 1, 640, 1}, {32, 32, 1, 1});
        add_test_bin_bcast(type, {5120, 1, 1, 1}, {1, 256, 1, 1});
        add_test_bin_bcast(type, {640, 1, 1, 1}, {1, 1, 1, 1});
        add_test_bin_bcast(type, {64, 262144, 1, 1}, {1, 1, 1, 1});
        //add_test_bin_bcast(type, {3, 3, 2560, 1280}, {1, 1, 1, 1});
        //add_test_bin_bcast(type, {3, 3, 2560, 1280}, {2, 1, 1, 1});
    }
7536
    // single inplace tests, especially important for WebGPU backend since kernels for inplace vs. not are different
    test_cases.emplace_back(new test_bin_bcast(ggml_add_inplace, GGML_TYPE_F32, {16, 5, 4, 3}, {1, 1, 1, 1}, 16));
    test_cases.emplace_back(new test_bin_bcast(ggml_mul_inplace, GGML_TYPE_F32, {16, 5, 4, 3}, {1, 1, 1, 1}, 16));
    test_cases.emplace_back(new test_bin_bcast(ggml_sub_inplace, GGML_TYPE_F32, {16, 5, 4, 3}, {1, 1, 1, 1}, 16));
    test_cases.emplace_back(new test_bin_bcast(ggml_div_inplace, GGML_TYPE_F32, {16, 5, 4, 3}, {1, 1, 1, 1}, 16));

    // fusion
    // NOTE(review): the trailing integer is varied from 2 to 16 here. Given the "fusion" heading,
    // it presumably controls how many ops are chained; confirm in test_bin_bcast.
    test_cases.emplace_back(new test_bin_bcast(ggml_add, GGML_TYPE_F32, {10, 5, 4, 3}, {2, 1, 1, 1}, 2));
    test_cases.emplace_back(new test_bin_bcast(ggml_add, GGML_TYPE_F32, {16, 5, 4, 3}, {1, 2, 1, 1}, 3));
    test_cases.emplace_back(new test_bin_bcast(ggml_add, GGML_TYPE_F32, {10, 5, 4, 3}, {1, 1, 2, 1}, 4));
    test_cases.emplace_back(new test_bin_bcast(ggml_add, GGML_TYPE_F32, {16, 5, 4, 3}, {1, 1, 1, 2}, 5));
    test_cases.emplace_back(new test_bin_bcast(ggml_add, GGML_TYPE_F32, {10, 5, 4, 3}, {1, 1, 2, 2}, 6));
    test_cases.emplace_back(new test_bin_bcast(ggml_add, GGML_TYPE_F32, {10, 5, 4, 3}, {1, 2, 2, 2}, 7));
    test_cases.emplace_back(new test_bin_bcast(ggml_add, GGML_TYPE_F32, {16, 5, 4, 3}, {2, 2, 2, 2}, 8));
    test_cases.emplace_back(new test_bin_bcast(ggml_add, GGML_TYPE_F32, {16, 5, 4, 3}, {1, 1, 1, 1}, 16));

    // ADD1, SCALE (including one in-place case), SOFTCAP and SILU_BACK.
    test_cases.emplace_back(new test_add1());
    test_cases.emplace_back(new test_add1(GGML_TYPE_F32, {1024, 1024, 1, 1}));
    test_cases.emplace_back(new test_scale());
    test_cases.emplace_back(new test_scale(GGML_TYPE_F32, {10, 10, 10, 10}, 2.0f, 1.0f));
    test_cases.emplace_back(new test_scale(GGML_TYPE_F32, {10, 10, 10, 10}, 2.0f, 1.0f, true)); // inplace test
    test_cases.emplace_back(new test_scale(GGML_TYPE_F32, {100, 10, 10, 10}, 2.0f, 1.0f));
    test_cases.emplace_back(new test_softcap(GGML_TYPE_F32, {10, 10, 10, 10}, 50.0f));
    test_cases.emplace_back(new test_silu_back());
7561
    // NORM and RMS_NORM (each with v = false/true), RMS_NORM_BACK and L2_NORM for several eps
    // values, with ne0 = 64 and ne0 = 1025.
    for (float eps : { 0.0f, 1e-6f, 1e-4f, 1e-1f }) {
        for (uint32_t n : { 64, 1025 }) {
            for (bool v : { false, true }) {
                test_cases.emplace_back(new test_norm(GGML_TYPE_F32, { n, 5, 4, 3 }, v, eps));
                test_cases.emplace_back(new test_rms_norm(GGML_TYPE_F32, { n, 5, 4, 3 }, v, eps));
            }
            test_cases.emplace_back(new test_rms_norm_back(GGML_TYPE_F32, { n, 5, 4, 3 }, eps));
            test_cases.emplace_back(new test_l2_norm(GGML_TYPE_F32, { n, 5, 4, 3 }, eps));
        }
    }

    // in-place tests
    test_cases.emplace_back(new test_rms_norm(GGML_TYPE_F32, {64, 5, 4, 3}, false, 1e-6f, true));

    // Norm-fusion test classes (rms_norm_mul_add, norm_mul_add, add_rms_norm), each registered
    // with both values of the trailing boolean argument, for eps values up to 1.0.
    for (float eps : { 0.0f, 1e-6f, 1e-4f, 1e-1f, 1.0f }) {
        for (uint32_t n : { 64, 1025 }) {
            test_cases.emplace_back(new test_rms_norm_mul_add(GGML_TYPE_F32, { n, 5, 4, 3 }, eps, false));
            test_cases.emplace_back(new test_rms_norm_mul_add(GGML_TYPE_F32, { n, 5, 4, 3 }, eps, true));
            test_cases.emplace_back(new test_norm_mul_add(GGML_TYPE_F32, { n, 5, 4, 3 }, eps, false));
            test_cases.emplace_back(new test_norm_mul_add(GGML_TYPE_F32, { n, 5, 4, 3 }, eps, true));
            test_cases.emplace_back(new test_add_rms_norm(GGML_TYPE_F32, { n, 5, 4, 3 }, eps, false));
            test_cases.emplace_back(new test_add_rms_norm(GGML_TYPE_F32, { n, 5, 4, 3 }, eps, true));
        }
    }
    // Single-row ({n, 1, 1, 1}) fusion cases for ne0 from 1 up to 33*512.
    for (uint32_t n : {1, 511, 1025, 8192, 33*512}) {
        for (bool multi_add : {false, true}) {
            test_cases.emplace_back(new test_rms_norm_mul_add(GGML_TYPE_F32, {n, 1, 1, 1}, 1e-6f, false, multi_add));
        }
        test_cases.emplace_back(new test_add_rms_norm(GGML_TYPE_F32, {n, 1, 1, 1}, 1e-6f, false));
    }
7592
7593    for (auto multi_add : {false, true}) {
7594        for (auto set_rows : {false, true}) {
7595            for (auto rope : {GGML_ROPE_TYPE_NORMAL, GGML_ROPE_TYPE_NEOX}) {
7596                test_cases.emplace_back(new test_rms_norm_mul_rope({768, 1, 1, 1}, 1e-6f, multi_add, set_rows, rope));
7597                test_cases.emplace_back(new test_rms_norm_mul_rope({768, 3, 1, 1}, 1e-6f, multi_add, set_rows, rope));
7598                test_cases.emplace_back(new test_rms_norm_mul_rope({768, 3, 5, 1}, 1e-6f, multi_add, set_rows, rope));
7599                test_cases.emplace_back(new test_rms_norm_mul_rope({128, 32, 2, 1}, 1e-6f, multi_add, set_rows, rope));
7600                test_cases.emplace_back(new test_rms_norm_mul_rope({128, 4, 2, 1}, 1e-6f, multi_add, set_rows, rope));
7601                test_cases.emplace_back(new test_rms_norm_mul_rope({128, 32, 50, 1}, 1e-6f, multi_add, set_rows, rope));
7602                test_cases.emplace_back(new test_rms_norm_mul_rope({128, 4, 50, 1}, 1e-6f, multi_add, set_rows, rope));
7603                test_cases.emplace_back(new test_rms_norm_mul_rope({8192, 2, 2, 1}, 1e-6f, multi_add, set_rows, rope));
7604                test_cases.emplace_back(new test_rms_norm_mul_rope({8192, 2, 2, 1}, 1e-6f, multi_add, set_rows, rope));
7605            }
7606        }
7607    }
    // SSM_CONV: for each (d_conv, d_inner) pair, three first-operand shapes
    // ({d_conv, ...}, {2*d_conv, ...}, and ne2 = 4) against a {d_conv, d_inner, 1, 1} second operand.
    for (int64_t d_conv : {3, 4, 9}) {
        for (int64_t d_inner: {1024, 1536, 2048}) {
            test_cases.emplace_back(new test_ssm_conv(GGML_TYPE_F32, {d_conv, d_inner, 1, 1}, {d_conv, d_inner, 1, 1}));
            test_cases.emplace_back(new test_ssm_conv(GGML_TYPE_F32, {2 * d_conv, d_inner, 1, 1}, {d_conv, d_inner, 1, 1}));
            test_cases.emplace_back(new test_ssm_conv(GGML_TYPE_F32, {d_conv, d_inner, 4, 1}, {d_conv, d_inner, 1, 1}));
        }
    }

    // SSM_SCAN with sizes taken from the models named in the trailing comments.
    test_cases.emplace_back(new test_ssm_scan(GGML_TYPE_F32, 16, 1, 1024, 1, 32, 4)); // Mamba-1
    test_cases.emplace_back(new test_ssm_scan(GGML_TYPE_F32, 128, 64, 16, 2, 32, 4)); // Mamba-2
    test_cases.emplace_back(new test_ssm_scan(GGML_TYPE_F32, 256, 64,  8, 2, 32, 4)); // Falcon-H1

    // RWKV_WKV6, RWKV_WKV7 and GLA share the same four argument sets:
    // the last two integers go (1, 1), (32, 1), (32, 4), (128, 4).
    test_cases.emplace_back(new test_rwkv_wkv6(GGML_TYPE_F32, 32, 64, 1, 1));
    test_cases.emplace_back(new test_rwkv_wkv6(GGML_TYPE_F32, 32, 64, 32, 1));
    test_cases.emplace_back(new test_rwkv_wkv6(GGML_TYPE_F32, 32, 64, 32, 4));
    test_cases.emplace_back(new test_rwkv_wkv6(GGML_TYPE_F32, 32, 64, 128, 4));

    test_cases.emplace_back(new test_rwkv_wkv7(GGML_TYPE_F32, 32, 64, 1, 1));
    test_cases.emplace_back(new test_rwkv_wkv7(GGML_TYPE_F32, 32, 64, 32, 1));
    test_cases.emplace_back(new test_rwkv_wkv7(GGML_TYPE_F32, 32, 64, 32, 4));
    test_cases.emplace_back(new test_rwkv_wkv7(GGML_TYPE_F32, 32, 64, 128, 4));

    test_cases.emplace_back(new test_gla(GGML_TYPE_F32, 32, 64, 1, 1));
    test_cases.emplace_back(new test_gla(GGML_TYPE_F32, 32, 64, 32, 1));
    test_cases.emplace_back(new test_gla(GGML_TYPE_F32, 32, 64, 32, 4));
    test_cases.emplace_back(new test_gla(GGML_TYPE_F32, 32, 64, 128, 4));
7634
#if 0
    // Disabled by default; change to #if 1 locally to run.
    // > 4GB A matrix. Too slow to be enabled by default.
    test_cases.emplace_back(new test_mul_mat(GGML_TYPE_F16, GGML_TYPE_F16,  900000,  3, 2592, {1, 1}, {1, 1}));
    test_cases.emplace_back(new test_mul_mat(GGML_TYPE_F16, GGML_TYPE_F16, 1700000, 96, 2592, {1, 1}, {1, 1}));
    test_cases.emplace_back(new test_mul_mat(GGML_TYPE_F16, GGML_TYPE_F16, 1700000,  3, 2592, {1, 1}, {1, 1}));
    test_cases.emplace_back(new test_mul_mat(GGML_TYPE_F16, GGML_TYPE_F16, 1700000,  1, 2592, {1, 1}, {1, 1}));

    // MUL_MAT_ID / MUL_MAT shapes labelled with the model they come from.
    test_cases.emplace_back(new test_mul_mat_id(GGML_TYPE_Q8_0, GGML_TYPE_F32, 128, 128, false, 8192, 2, 5120)); // Llama-4-Maverick-17B-128E-PAB-Q8_0
    test_cases.emplace_back(new test_mul_mat_id(GGML_TYPE_Q8_0, GGML_TYPE_F32, 128, 128, false, 8192, 1, 5120)); // Llama-4-Maverick-17B-128E-PAB-Q8_0
    test_cases.emplace_back(new test_mul_mat(GGML_TYPE_Q8_0, GGML_TYPE_F32, 8192, 1, 5120, {128, 1}, {1, 1}));
    test_cases.emplace_back(new test_mul_mat(GGML_TYPE_Q8_0, GGML_TYPE_F32, 8192, 512, 5120, {128, 1}, {1, 1}));
#endif
7647
7648    for (ggml_type type_a : all_types) {
7649        for (int i = 1; i < 10; ++i) {
7650            test_cases.emplace_back(new test_mul_mat(type_a,    GGML_TYPE_F32, 16,  i, 256, { 1,  1}, {1, 1}));
7651        }
7652    }
7653
#if 0
    // Disabled by default: Q8_0 x F32 MUL_MAT with m = 1024 over a grid of n and k values.
    {
        // Test paths in OpenCL
        std::vector<int> ns = {32, 64, 128, 256, 512, 1024, 4096};
        std::vector<int> ks = {896, 1536, 4096};
        for (auto n : ns) {
            for (auto k : ks) {
                test_cases.emplace_back(new test_mul_mat(GGML_TYPE_Q8_0, GGML_TYPE_F32, 1024, n, k, {1, 1}, {1, 1}));
            }
        }
    }
#endif
7666
#if 1
    // Deterministic MUL_MAT grid (the #else branch below is an alternative randomized sweep).
    // Positional arguments: (type_a, type_b, m, n, k), then two 2-element arrays and an
    // optional permutation.
    for (ggml_type type_a : base_types) {
        for (ggml_type type_b : {GGML_TYPE_F32, GGML_TYPE_F16}) {
            // k = 256 always; k = 4 is added only for types whose block size is 1.
            std::vector<int> ks = { 256 };
            if (ggml_blck_size(type_a) == 1) {
                ks.push_back(4);
            }
            for (auto k : ks) {
                // test cases without permutation
                test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16,  1, k, {1, 1}, {1, 1}));
                test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16,  1, k, {1, 1}, {2, 1}));
                test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16,  1, k, {1, 1}, {1, 2}));
                test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16,  1, k, {3, 1}, {1, 1}));
                test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16,  1, k, {3, 1}, {2, 1}));
                test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16,  1, k, {3, 2}, {1, 1}));
                test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16,  1, k, {3, 2}, {2, 1}));
                test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16,  1, k, {3, 2}, {1, 2}));
                test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16,  1, k, {3, 2}, {2, 2}));

                test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 16, k, {1, 1}, {1, 1}));
                test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 16, k, {1, 1}, {2, 1}));
                test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 16, k, {1, 1}, {1, 2}));
                test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 16, k, {3, 1}, {1, 1}));
                test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 16, k, {3, 1}, {2, 1}));
                test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 16, k, {3, 2}, {1, 1}));
                test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 16, k, {3, 2}, {2, 1}));
                test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 16, k, {3, 2}, {1, 2}));
                test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 16, k, {3, 2}, {2, 2}));

                // test cases with permutation
                test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16,  1, k, {2, 3}, {1, 1}, {0, 2, 1, 3}));
                test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16,  1, k, {2, 3}, {1, 1}, {0, 1, 3, 2}));
                test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16,  1, k, {2, 3}, {1, 1}, {0, 3, 2, 1}));

                test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16,  8, k, {2, 3}, {1, 1}, {0, 2, 1, 3}));
                test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16,  8, k, {2, 3}, {1, 1}, {0, 1, 3, 2}));
                test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16,  8, k, {2, 3}, {1, 1}, {0, 3, 2, 1}));

                test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 16, k, {2, 3}, {1, 1}, {0, 2, 1, 3}));
                test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 16, k, {2, 3}, {1, 1}, {0, 1, 3, 2}));
                test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 16, k, {2, 3}, {1, 1}, {0, 3, 2, 1}));
            }

            // test cases with large ne00/ne10 to cover stream-k fixup
            test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16,  1, 1024, {3, 2}, {1, 1}));
            test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16,  8, 1024, {3, 2}, {1, 1}));
            test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 16, 1024, {3, 2}, {1, 1}));

            // test cases with large batch size
            test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 8, 256, {1536, 1}, {1, 1}));
        }
    }
    // other_types: one case with k equal to the block size (when it is not 256) and one with k = 256.
    for (ggml_type type_a : other_types) {
        for (ggml_type type_b : {GGML_TYPE_F32}) {
            if (ggml_blck_size(type_a) != 256) {
                test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 1, ggml_blck_size(type_a), {1,  1}, {1, 1}));
            }
            test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 1, 256, {1,  1}, {1, 1}));
        }
    }
#else
    // Randomized alternative: 1000 rounds of random (m, n, k) per type, with k a multiple of the block size.
    // NOTE(review): relies on an `rng` declared outside this view.
    // m = a rows
    // n = b rows
    // k = cols
    std::uniform_int_distribution<> dist_m(1, 128);
    std::uniform_int_distribution<> dist_n(16, 128);
    std::uniform_int_distribution<> dist_k(1, 16);
    for (int i = 0; i < 1000; i++) {
        for (ggml_type type_a : all_types) {
            for (ggml_type type_b : {GGML_TYPE_F32}) {
                int m = dist_m(rng);
                int n = dist_n(rng);
                int k = dist_k(rng) * ggml_blck_size(type_a);
                test_cases.emplace_back(new test_mul_mat(type_a, type_b, m, n, k, { 1,  1}, {1, 1}));
            }
        }
    }
#endif
7745
    // Hand-picked MUL_MAT shapes: F16 x F32 with odd m/k values (83, 193, 67) and a permuted
    // case, followed by F32 x F32 cases; the 16x32x32 case passes two extra trailing integers (64, 3).
    test_cases.emplace_back(new test_mul_mat(GGML_TYPE_F16, GGML_TYPE_F32,  64, 2,  128, { 8,  1}, {1, 1}));
    test_cases.emplace_back(new test_mul_mat(GGML_TYPE_F16, GGML_TYPE_F32,  83, 2,  128, { 8,  1}, {4, 1}));
    test_cases.emplace_back(new test_mul_mat(GGML_TYPE_F16, GGML_TYPE_F32,  64, 2,   64, { 8,  1}, {4, 1}));
    test_cases.emplace_back(new test_mul_mat(GGML_TYPE_F16, GGML_TYPE_F32,  83, 2,   64, { 8,  1}, {4, 1}));
    test_cases.emplace_back(new test_mul_mat(GGML_TYPE_F16, GGML_TYPE_F32,  64, 45, 128, { 8,  1}, {4, 1}));
    test_cases.emplace_back(new test_mul_mat(GGML_TYPE_F16, GGML_TYPE_F32, 128, 45,  64, { 8,  1}, {4, 1}));
    test_cases.emplace_back(new test_mul_mat(GGML_TYPE_F16, GGML_TYPE_F32, 1056, 1, 193, {1,  1}, {4, 1}, {0, 2, 1, 3}));
    test_cases.emplace_back(new test_mul_mat(GGML_TYPE_F16, GGML_TYPE_F32, 1056, 1, 67,  {1,  1}, {4, 1}, {0, 2, 1, 3}));
    test_cases.emplace_back(new test_mul_mat(GGML_TYPE_F32, GGML_TYPE_F32, 16, 32, 32, { 1,  1}, {1, 1}, {0, 1, 2, 3}, 64, 3));
    test_cases.emplace_back(new test_mul_mat(GGML_TYPE_F32, GGML_TYPE_F32, 64, 77, 77, {12,1}, {1,1}));

    // Q4_0 x F32 with larger shapes, then m = 1, n = 64, k = 256 for every type.
    test_cases.emplace_back(new test_mul_mat(GGML_TYPE_Q4_0, GGML_TYPE_F32, 576, 512, 576, {1,1}, {1,1}));
    test_cases.emplace_back(new test_mul_mat(GGML_TYPE_Q4_0, GGML_TYPE_F32, 1, 2048, 8192, {1,  1}, {1, 1}));
    for (ggml_type type_a : all_types) {
        test_cases.emplace_back(new test_mul_mat(type_a, GGML_TYPE_F32, 1, 64, 256, {1,  1}, {1, 1}));
    }
7762
7763#if 0
7764    // test the mat-mat path for Metal
7765    for (int k = 1; k < 512; ++k) {
7766        test_cases.emplace_back(new test_mul_mat(GGML_TYPE_F16, GGML_TYPE_F32, 64, 127, k, {12,1}, {1,1}));
7767        test_cases.emplace_back(new test_mul_mat(GGML_TYPE_F32, GGML_TYPE_F32, 64, 127, k, {12,1}, {1,1}));
7768        test_cases.emplace_back(new test_mul_mat(GGML_TYPE_F16, GGML_TYPE_F32, 64, 77, k, {12,1}, {1,1}));
7769        test_cases.emplace_back(new test_mul_mat(GGML_TYPE_F32, GGML_TYPE_F32, 64, 77, k, {12,1}, {1,1}));
7770        test_cases.emplace_back(new test_mul_mat(GGML_TYPE_F16, GGML_TYPE_F32, 64, 128, k, {12,1}, {1,1}));
7771        test_cases.emplace_back(new test_mul_mat(GGML_TYPE_F32, GGML_TYPE_F32, 64, 128, k, {12,1}, {1,1}));
7772        test_cases.emplace_back(new test_mul_mat_id(GGML_TYPE_F16, GGML_TYPE_F32, 16, 16, false, 50, 200, k));
7773        test_cases.emplace_back(new test_mul_mat_id(GGML_TYPE_F16, GGML_TYPE_F32, 16, 16, true, 50, 200, k));
7774        test_cases.emplace_back(new test_mul_mat_id(GGML_TYPE_F32, GGML_TYPE_F32, 16, 16, false, 50, 200, k));
7775        test_cases.emplace_back(new test_mul_mat_id(GGML_TYPE_F32, GGML_TYPE_F32, 16, 16, true, 50, 200, k));
7776    }
7777#endif
7778
7779    for (auto bs2 : {1,3}) {
7780        for (auto bs : {1,2,4,8}) {
7781            for (auto nr : {1,4}) {
7782                for (uint32_t m = 0; m < 2; ++m) {
7783                    for (uint32_t k = 0; k < 2; ++k) {
7784                        for (ggml_type type: {GGML_TYPE_F16, GGML_TYPE_BF16, GGML_TYPE_F32}) {
7785                            test_cases.emplace_back(new test_mul_mat(type, GGML_TYPE_F32, 1056 + m, 1, 128 + k,  {bs,  bs2}, {nr, 1}, {0, 2, 1, 3}));
7786                            test_cases.emplace_back(new test_mul_mat(type, GGML_TYPE_F32, 128 + m,  1, 1056 + k, {bs,  bs2}, {nr, 1}, {0, 1, 2, 3}, 2*1056 + k));
7787                        }
7788                    }
7789                }
7790            }
7791        }
7792    }
7793
7794    // sycl backend will limit task global_range < MAX_INT
7795    // test case for f16-type-convert-to-fp32 kernel with large k under fp32 compute dtype (occurs in stable-diffusion)
7796    // however this case needs to alloc more memory which may fail in some devices (Intel Arc770, etc.)
7797    // this case is verified (pass) in Intel(R) Data Center GPU Max 1100 (sycl backend) and NV A30 (cuda backend)
7798    // test_cases.emplace_back(new test_mul_mat(GGML_TYPE_F16, GGML_TYPE_F16, 512, 262144, 9216, {1, 1}, {1, 1}));
7799
    // test large experts*tokens
    // Argument order used throughout this group, as named in the loops below:
    // (type_a, type_b, n_mats, n_used, b, m, n, k).
    for (bool b : {false, true}) {
        test_cases.emplace_back(new test_mul_mat_id(GGML_TYPE_F16, GGML_TYPE_F32, 16, 16, b, 32, 1024, 16));
        test_cases.emplace_back(new test_mul_mat_id(GGML_TYPE_F16, GGML_TYPE_F32, 2, 2, b, 32, 8192, 64));
        test_cases.emplace_back(new test_mul_mat_id(GGML_TYPE_F16, GGML_TYPE_F32, 16, 16, b, 50, 200, 64));
    }

    test_cases.emplace_back(new test_mul_mat_id(GGML_TYPE_F16, GGML_TYPE_F32, 1, 1, false, 8, 16, 1));
    test_cases.emplace_back(new test_mul_mat_id_fusion(GGML_TYPE_F16, GGML_TYPE_F32, 16, 16, false, 32, 32, 32, 3));

    // gpt-oss issue with Vulkan mmq_id
    test_cases.emplace_back(new test_mul_mat_id(GGML_TYPE_MXFP4, GGML_TYPE_F32, 32, 2, false, 2880, 32, 2880));

    // Full n_mats/n_used/b/n sweep over base_types with fixed m = 512, k = 256.
    for (ggml_type type_a : base_types) {
        for (ggml_type type_b : {GGML_TYPE_F32 /*, GGML_TYPE_F16 */}) {
            for (int n_mats : {4, 8}) {
                for (int n_used : {1, 2, 4}) {
                    for (bool b : {false, true}) {
                        for (int n : {1, 4, 5, 17, 32, 129}) {
                            int m = 512;
                            int k = 256;
                            test_cases.emplace_back(new test_mul_mat_id(type_a, type_b, n_mats, n_used, b, m, n, k));
                        }
                    }
                }
            }
        }
    }

    // Reduced sweep for other_types: a single n_mats/n_used/b setting, two values of n.
    for (ggml_type type_a : other_types) {
        for (ggml_type type_b : {GGML_TYPE_F32 /*, GGML_TYPE_F16 */}) {
            for (int n_mats : {4}) {
                for (int n_used : {2}) {
                    for (bool b : {false}) {
                        for (int n : {1, 32}) {
                            int m = 512;
                            int k = 256;
                            test_cases.emplace_back(new test_mul_mat_id(type_a, type_b, n_mats, n_used, b, m, n, k));
                        }
                    }
                }
            }
        }
    }

    // Fused mul_mat_id followed by a mul, for several batch sizes bs and a few weight types.
    for (int bs : {1, 4, 512}) {
        for (ggml_type type_a : {GGML_TYPE_F32, GGML_TYPE_F16, GGML_TYPE_Q4_0, GGML_TYPE_Q4_K}) {
            for (ggml_type type_b : {GGML_TYPE_F32}) {
                // test with mul after (ffn_moe_weighted)
                test_cases.emplace_back(new test_mul_mat_id_fusion(type_a, type_b, 128, 8, false, 768, bs, 2048, 1, true));
            }
        }
    }
7853
    // out_prod: every base type against F32/F16, with n, k in {1, 16}, batch pair {bs2, bs3}
    // and repeat pair {nr2, nr3} each in small ranges; the first size argument is fixed at 256.
    for (ggml_type type_a : base_types) {
        for (ggml_type type_b : {GGML_TYPE_F32, GGML_TYPE_F16}) {
            for (int n : {1, 16}) {
                for (int k : {1, 16}) {
                    for (int bs2 : {1, 3}) {
                        for (int bs3 : {1, 3}) {
                            for (int nr2 : {1, 2}) {
                                for (int nr3 : {1, 2}) {
                                    test_cases.emplace_back(new test_out_prod(type_a, type_b, 256, n, k, {bs2, bs3}, {nr2, nr3}));
                                }
                            }
                        }
                    }
                }
            }
        }
    }

    // add_id
    // F32-only sweep over n_mats, n_used, embedding width (n_embd, including odd 129) and token count.
    for (ggml_type type_a : {GGML_TYPE_F32}) {
        for (ggml_type type_b : {GGML_TYPE_F32}) {
            for (int n_mats : {4, 8}) {
                for (int n_used : {1, 2, 4}) {
                    for (int n_embd : {32, 129}) {
                        for (int n_token : {1, 32, 129}) {
                            test_cases.emplace_back(new test_add_id(type_a, type_b, n_embd, n_mats, n_used, n_token));
                        }
                    }
                }
            }
        }
    }
7886
    // Element-wise unary ops for F16 and F32: first each op with its default shape, then each op
    // with a small odd-sized 4-D shape and a large 1024x1024 shape.
    for (ggml_type type : {GGML_TYPE_F16, GGML_TYPE_F32}) {
        test_cases.emplace_back(new test_sqr       (type));
        test_cases.emplace_back(new test_sqrt      (type));
        test_cases.emplace_back(new test_log       (type));
        test_cases.emplace_back(new test_sin       (type));
        test_cases.emplace_back(new test_cos       (type));
        test_cases.emplace_back(new test_clamp     (type));
        test_cases.emplace_back(new test_leaky_relu(type));
        test_cases.emplace_back(new test_floor     (type));
        test_cases.emplace_back(new test_ceil      (type));
        test_cases.emplace_back(new test_round     (type));
        test_cases.emplace_back(new test_trunc     (type));
        test_cases.emplace_back(new test_sqr       (type, {7, 1, 5, 3}));
        test_cases.emplace_back(new test_sqr       (type, {1024, 1024, 1, 1}));
        test_cases.emplace_back(new test_sqrt      (type, {7, 1, 5, 3}));
        test_cases.emplace_back(new test_sqrt      (type, {1024, 1024, 1, 1}));
        test_cases.emplace_back(new test_log       (type, {7, 1, 5, 3}));
        test_cases.emplace_back(new test_log       (type, {1024, 1024, 1, 1}));
        test_cases.emplace_back(new test_sin       (type, {7, 1, 5, 3}));
        test_cases.emplace_back(new test_sin       (type, {1024, 1024, 1, 1}));
        test_cases.emplace_back(new test_cos       (type, {7, 1, 5, 3}));
        test_cases.emplace_back(new test_cos       (type, {1024, 1024, 1, 1}));
        test_cases.emplace_back(new test_clamp     (type, {7, 1, 5, 3}));
        test_cases.emplace_back(new test_clamp     (type, {1024, 1024, 1, 1}));
        test_cases.emplace_back(new test_leaky_relu(type, {7, 1, 5, 3}));
        test_cases.emplace_back(new test_leaky_relu(type, {1024, 1024, 1, 1}));
        test_cases.emplace_back(new test_floor     (type, {7, 1, 5, 3}));
        test_cases.emplace_back(new test_floor     (type, {1024, 1024, 1, 1}));
        test_cases.emplace_back(new test_ceil      (type, {7, 1, 5, 3}));
        test_cases.emplace_back(new test_ceil      (type, {1024, 1024, 1, 1}));
        test_cases.emplace_back(new test_round     (type, {7, 1, 5, 3}));
        test_cases.emplace_back(new test_round     (type, {1024, 1024, 1, 1}));
        test_cases.emplace_back(new test_trunc     (type, {7, 1, 5, 3}));
        test_cases.emplace_back(new test_trunc     (type, {1024, 1024, 1, 1}));
    }
7922
    // diag_mask_inf on 10x10 matrices with growing batch dims; the trailing 5 is presumably
    // n_past — confirm against test_diag_mask_inf.
    test_cases.emplace_back(new test_diag_mask_inf(GGML_TYPE_F32, {10, 10, 1, 1}, 5));
    test_cases.emplace_back(new test_diag_mask_inf(GGML_TYPE_F32, {10, 10, 3, 1}, 5));
    test_cases.emplace_back(new test_diag_mask_inf(GGML_TYPE_F32, {10, 10, 3, 2}, 5));

#if 0
    // Disabled randomized soft_max size sweep (row lengths doubling up to 2^17).
    // NOTE(review): this call passes GGML_TYPE_F32 twice and uses a different argument list than
    // the active test_soft_max calls below (type, ne, mask, sinks, m_prec, {..}, scale, max_bias),
    // so it likely no longer compiles if re-enabled. Also `n/2 == 0` is true only for n = 0 and 1;
    // if alternating mask on/off was intended it would be `n%2 == 0` — confirm before enabling.
    std::uniform_int_distribution<> dist_ne1(1, 50);
    int exponent = 1;
    while (exponent < (1 << 17)) {
        std::uniform_int_distribution<> dist_ne0(exponent, 2*exponent);

        for (int n = 0; n < 10; ++n) {
            int64_t ne0 = dist_ne0(rng);
            int64_t ne1 = dist_ne1(rng);
            test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, GGML_TYPE_F32, {ne0, ne1, 1, 1}, n/2 == 0, 0.1f, ne0 < 1000 ? 4.0f : 0.0f));
        }

        exponent <<= 1;
    }
#endif
    // soft_max sweep over mask/sinks/ALiBi bias/scale and sizes (plus size-1 variants).
    // A non-zero max_bias is only tested together with a mask (see the `continue` below).
    // With a mask, both F32 and F16 mask precisions are covered; small shapes also get
    // broadcast variants via the {3, 1} / {2, 3} pair.
    for (bool mask : {false, true}) {
        for (bool sinks : {false, true}) {
            for (float max_bias : {0.0f, 8.0f}) {
                if (!mask && max_bias > 0.0f) continue;
                for (float scale : {1.0f, 0.1f}) {
                    for (int64_t ne0 : {16, 1024}) {
                        for (int64_t ne1 : {16, 1024}) {
                            if (mask) {
                                for (ggml_type m_prec : {GGML_TYPE_F32, GGML_TYPE_F16}) {
                                    test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {ne0,   ne1,   1, 1}, mask, sinks, m_prec, {1, 1}, scale, max_bias));
                                    test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {ne0-1, ne1-1, 1, 1}, mask, sinks, m_prec, {1, 1}, scale, max_bias));

                                    if (ne0 <= 32 && ne1 <= 32) {
                                        test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {ne0,   ne1,   1, 3}, mask, sinks, m_prec, {3, 1}, scale, max_bias));
                                        test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {ne0-1, ne1-1, 1, 1}, mask, sinks, m_prec, {2, 3}, scale, max_bias));
                                    }
                                }
                            } else {
                                /* The precision of mask here doesn't matter as boolean mask is false */
                                test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {ne0,   ne1,   1, 1}, mask, sinks, GGML_TYPE_F32, {1, 1}, scale, max_bias));
                                test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {ne0-1, ne1-1, 1, 1}, mask, sinks, GGML_TYPE_F32, {1, 1}, scale, max_bias));
                            }
                        }
                    }
                }
            }
            // inplace tests
            test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {16, 2, 32, 1}, mask, sinks, GGML_TYPE_F32, {1, 1}, 0.1f, 0.0f, true));
            test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {16, 2, 32, 1}, mask, sinks, GGML_TYPE_F16, {1, 1}, 0.1f, 0.0f, true));
        }
    }
    // Fixed mid-size soft_max cases (rows of 16 and 32), some with ALiBi bias.
    test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {16, 2, 32, 1}, true,  true,  GGML_TYPE_F32, {1, 1}, 0.1f, 0.0f));
    test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {16, 2, 32, 1}, true,  false, GGML_TYPE_F16, {1, 1}, 0.1f, 0.0f));
    test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {16, 2, 32, 1}, false, true,  GGML_TYPE_F32, {1, 1}, 0.1f, 0.0f));
    test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {32, 2, 32, 1}, true,  true,  GGML_TYPE_F32, {1, 1}, 0.1f, 0.0f));
    test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {32, 2, 32, 1}, true,  false, GGML_TYPE_F16, {1, 1}, 0.1f, 0.0f));
    test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {32, 2, 32, 1}, true,  true,  GGML_TYPE_F32, {1, 1}, 0.1f, 8.0f));
    test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {32, 2, 32, 1}, true,  true,  GGML_TYPE_F16, {1, 1}, 0.1f, 8.0f));

    // Very long rows (200k+ elements, including odd lengths) to stress large-row code paths.
    test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {200001, 2, 3, 1}, true,   true,  GGML_TYPE_F32, {1, 1}, 0.1f, 8.0f));
    test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {200001, 2, 3, 1}, true,   true,  GGML_TYPE_F16, {1, 1}, 0.1f, 8.0f));
    test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {200000, 1, 1, 1}, false,  false, GGML_TYPE_F32, {1, 1}, 1.0f, 0.0f));
    test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {200000, 4, 1, 1}, false,  false, GGML_TYPE_F32, {1, 1}, 1.0f, 0.0f));
    test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {643251, 3, 1, 1}, false,  false, GGML_TYPE_F32, {1, 1}, 1.0f, 0.0f));
7986
    // soft_max_back over the same bias/scale/size grid as the forward sweep, plus size-1
    // variants and a batched {.., 2, 3} shape.
    for (float max_bias : {0.0f, 8.0f}) {
        for (float scale : {1.0f, 0.1f}) {
            for (int64_t ne0 : {16, 1024}) {
                for (int64_t ne1 : {16, 1024}) {
                    test_cases.emplace_back(new test_soft_max_back(GGML_TYPE_F32, {ne0,   ne1,   1, 1}, scale, max_bias));
                    test_cases.emplace_back(new test_soft_max_back(GGML_TYPE_F32, {ne0-1, ne1-1, 1, 1}, scale, max_bias));
                    test_cases.emplace_back(new test_soft_max_back(GGML_TYPE_F32, {ne0,   ne1,   2, 3}, scale, max_bias));
                }
            }
        }
    }
7998
7999    for (bool fw : {true, false}) { // fw == forward
8000        bool all = true;
8001
8002        for (float fs : { 1.0f, 1.4245f }) {
8003            for (float ef : { 0.0f, 0.7465f }) {
8004                for (float af : { 1.0f, 1.4245f }) {
8005                    for (ggml_type type : {GGML_TYPE_F32, GGML_TYPE_F16}) {
8006                        for (bool ff : {false, true}) { // freq_factors
8007                            for (float v : { 0, 1 }) {
8008                                test_cases.emplace_back(new test_rope(type, {128,  32, 2, 1}, 128, GGML_ROPE_TYPE_NORMAL, 512, fs, ef, af, ff, v, fw)); // llama 7B
8009
8010                                if (all) {
8011                                    test_cases.emplace_back(new test_rope(type, {128,  40, 2, 1}, 128, GGML_ROPE_TYPE_NORMAL, 512, fs, ef, af, ff, v, fw)); // llama 13B
8012                                    test_cases.emplace_back(new test_rope(type, {128,  52, 2, 1}, 128, GGML_ROPE_TYPE_NORMAL, 512, fs, ef, af, ff, v, fw)); // llama 30B
8013                                    test_cases.emplace_back(new test_rope(type, {128,  64, 2, 1}, 128, GGML_ROPE_TYPE_NORMAL, 512, fs, ef, af, ff, v, fw)); // llama 65B
8014                                    test_cases.emplace_back(new test_rope(type, {16, 16, 8192, 1}, 16, GGML_ROPE_TYPE_NORMAL, 512, fs, ef, af, ff, v, fw));
8015                                }
8016
8017                                if (all) {
8018                                    test_cases.emplace_back(new test_rope(type, { 64,   1, 2, 1},  64, GGML_ROPE_TYPE_NEOX, 512, fs, ef, af, ff, v, fw)); // neox (falcon 7B)
8019                                    test_cases.emplace_back(new test_rope(type, { 64,  71, 2, 1},  64, GGML_ROPE_TYPE_NEOX, 512, fs, ef, af, ff, v, fw)); // neox (falcon 7B)
8020                                    test_cases.emplace_back(new test_rope(type, { 64,   8, 2, 1},  64, GGML_ROPE_TYPE_NEOX, 512, fs, ef, af, ff, v, fw)); // neox (falcon 40B)
8021
8022                                    test_cases.emplace_back(new test_rope(type, { 80,  32, 2, 1},  20, GGML_ROPE_TYPE_NORMAL, 512, fs, ef, af, ff, v, fw));
8023                                    test_cases.emplace_back(new test_rope(type, { 80,  32, 2, 1},  32, GGML_ROPE_TYPE_NORMAL, 512, fs, ef, af, ff, v, fw));
8024                                    test_cases.emplace_back(new test_rope(type, { 80,  32, 4, 1},  32, GGML_ROPE_TYPE_NORMAL, 512, fs, ef, af, ff, v, fw));
8025
8026                                    test_cases.emplace_back(new test_rope(type, { 80,  32, 2, 1},  20, GGML_ROPE_TYPE_NEOX, 512, fs, ef, af, ff, v, fw)); // neox (stablelm)
8027                                    test_cases.emplace_back(new test_rope(type, { 80,  32, 2, 1},  32, GGML_ROPE_TYPE_NEOX, 512, fs, ef, af, ff, v, fw)); // neox (phi-2)
8028                                    test_cases.emplace_back(new test_rope(type, { 80,  32, 4, 1},  32, GGML_ROPE_TYPE_NEOX, 512, fs, ef, af, ff, v, fw)); // neox (phi-2)
8029                                    test_cases.emplace_back(new test_rope(type, { 16, 16, 8192, 1},  16, GGML_ROPE_TYPE_NEOX, 512, fs, ef, af, ff, v, fw));
8030                                }
8031
8032                                if (all) {
8033                                    test_cases.emplace_back(new test_rope(type, {128,  12, 2, 1}, 128, GGML_ROPE_TYPE_MROPE,  512, fs, ef, af, ff, v, fw)); // rope_multi,m-rope (qwen2vl 2B)
8034                                    test_cases.emplace_back(new test_rope(type, {128,  28, 2, 1}, 128, GGML_ROPE_TYPE_MROPE,  512, fs, ef, af, ff, v, fw)); // rope_multi,m-rope (qwen2vl 7B)
8035                                    test_cases.emplace_back(new test_rope(type, {128,  12, 2, 1},  20, GGML_ROPE_TYPE_MROPE,  512, fs, ef, af, ff, v, fw));
8036                                    test_cases.emplace_back(new test_rope(type, {128,  28, 2, 1},  32, GGML_ROPE_TYPE_MROPE,  512, fs, ef, af, ff, v, fw));
8037                                    test_cases.emplace_back(new test_rope(type, {128,  12, 2, 1}, 128, GGML_ROPE_TYPE_IMROPE,  512, fs, ef, af, ff, v, fw)); // rope_multi,imrope (qwen3vl 2B)
8038                                    test_cases.emplace_back(new test_rope(type, {128,  28, 2, 1}, 128, GGML_ROPE_TYPE_IMROPE,  512, fs, ef, af, ff, v, fw)); // rope_multi,imrope (qwen3vl 7B)
8039                                    test_cases.emplace_back(new test_rope(type, {128,  12, 2, 1},  20, GGML_ROPE_TYPE_IMROPE,  512, fs, ef, af, ff, v, fw));
8040                                    test_cases.emplace_back(new test_rope(type, {128,  28, 2, 1},  32, GGML_ROPE_TYPE_IMROPE,  512, fs, ef, af, ff, v, fw));
8041                                    test_cases.emplace_back(new test_rope(type, { 80,  16, 2, 1},  80, GGML_ROPE_TYPE_VISION, 512, fs, ef, af, ff, v, fw)); // rope_multi,m-rope (qwen2vl ViT)
8042                                    test_cases.emplace_back(new test_rope(type, {128,  16, 2, 1}, 128, GGML_ROPE_TYPE_IMROPE, 512, fs, ef, af, ff, v, fw)); // rope_multi,m-rope (qwen3vl)
8043                                    test_cases.emplace_back(new test_rope(type, {16, 16, 8192, 1}, 16, GGML_ROPE_TYPE_IMROPE, 512, fs, ef, af, ff, v, fw));
8044                                }
8045
8046                                test_cases.emplace_back(new test_rope(type, { 64, 128, 2, 1},  64, GGML_ROPE_TYPE_NEOX, 512, fs, ef, af, ff, v, fw)); // neox (falcon 40B)
8047                            }
8048                        }
8049
8050                        all = false;
8051                    }
8052                }
8053            }
8054        }
8055    }
8056
8057    // single inplace test per type/mode/ff
8058    for (ggml_type type : {GGML_TYPE_F32, GGML_TYPE_F16}) {
8059        for (int mode : {GGML_ROPE_TYPE_NORMAL, GGML_ROPE_TYPE_NEOX, GGML_ROPE_TYPE_MROPE, GGML_ROPE_TYPE_IMROPE, GGML_ROPE_TYPE_VISION}) {
8060            for (bool ff : {false, true}) {
8061                test_cases.emplace_back(new test_rope(type, {128,  32, 2, 1}, 128, mode, 512, 1.4245f, 0.7465f, 1.4245f, ff, 0, true, true));
8062                test_cases.emplace_back(new test_rope(type, {128,  32, 2, 1}, 128, mode, 512, 1.4245f, 0.7465f, 1.4245f, ff, 1, true, true));
8063                test_cases.emplace_back(new test_rope(type, {128,  32, 2, 3}, 128, mode, 512, 1.4245f, 0.7465f, 1.4245f, ff, 1, true, true));
8064            }
8065        }
8066    }
8067
8068    for (int v : { 0, 1, 2, 3 }) {
8069        for (int dim : { 0, 1, 2, 3, }) {
8070            test_cases.emplace_back(new test_concat(GGML_TYPE_F32, {11, 12, 13, 14}, 7, dim, v));
8071            test_cases.emplace_back(new test_concat(GGML_TYPE_I32, {11, 12, 13, 14}, 7, dim, v));
8072        }
8073    }
8074
8075    for (ggml_sort_order order : {GGML_SORT_ORDER_ASC, GGML_SORT_ORDER_DESC}) {
8076        for (uint32_t i = 4; i <= 1024*1024; i *= 2) {
8077            test_cases.emplace_back(new test_argsort(GGML_TYPE_F32, {i-1, 1, 1, 1}));
8078            test_cases.emplace_back(new test_argsort(GGML_TYPE_F32, {i, 1, 1, 1}));
8079        }
8080        test_cases.emplace_back(new test_argsort(GGML_TYPE_F32, {16, 10, 10, 10}, order));
8081        test_cases.emplace_back(new test_argsort(GGML_TYPE_F32, {60, 10, 10, 10}, order)); // qwen
8082        test_cases.emplace_back(new test_argsort(GGML_TYPE_F32, {1023, 2, 1, 3}, order));
8083        test_cases.emplace_back(new test_argsort(GGML_TYPE_F32, {1024, 2, 1, 3}, order));
8084        test_cases.emplace_back(new test_argsort(GGML_TYPE_F32, {1025, 2, 1, 3}, order));
8085        test_cases.emplace_back(new test_argsort(GGML_TYPE_F32, {2047, 2, 1, 3}, order));
8086        test_cases.emplace_back(new test_argsort(GGML_TYPE_F32, {2048, 2, 1, 3}, order));
8087        test_cases.emplace_back(new test_argsort(GGML_TYPE_F32, {2049, 2, 1, 3}, order));
8088        test_cases.emplace_back(new test_argsort(GGML_TYPE_F32, {2, 8, 8192, 1}, order)); // bailingmoe2 (group selection)
8089    }
8090
    // top_k on tiny rows: every k in [1, n] for n in [1, 4].
    // NOTE(review): the trailing `true` toggles an option of test_top_k that is not visible
    // here — confirm its meaning in the test_top_k definition.
    for (int n = 1; n < 5; ++n) {
        for (int k = 1; k <= n; ++k) {
            test_cases.emplace_back(new test_top_k(GGML_TYPE_F32, {n, 2, 1, 3}, k, true));
        }
    }
    // Row lengths 2^i (i in [0, 19]) and 2^i + 11, for each listed k that fits in 2^i.
    for (int i = 0; i < 20; ++i) {
        for (int k : {1, 2, 3, 7, 15, 100, 500, 1023, 9999}) {
            if (k <= 1<<i) {
                test_cases.emplace_back(new test_top_k(GGML_TYPE_F32, {(1<<i), 1, 1, 1}, k));
                test_cases.emplace_back(new test_top_k(GGML_TYPE_F32, {(1<<i) + 11, 1, 2, 1}, k));
                test_cases.emplace_back(new test_top_k(GGML_TYPE_F32, {(1<<i) + 11, 1, 2, 1}, k, true));
            }
        }
    }
    // Small k on multi-row shapes around the 1024/2048 boundaries (mirrors the argsort shapes).
    for (int k : {1, 2, 3, 7, 15}) {
        test_cases.emplace_back(new test_top_k(GGML_TYPE_F32, {16, 10, 10, 10}, k));
        test_cases.emplace_back(new test_top_k(GGML_TYPE_F32, {60, 10, 10, 10}, k));
        test_cases.emplace_back(new test_top_k(GGML_TYPE_F32, {1023, 2, 1, 3}, k));
        test_cases.emplace_back(new test_top_k(GGML_TYPE_F32, {1024, 2, 1, 3}, k));
        test_cases.emplace_back(new test_top_k(GGML_TYPE_F32, {1025, 2, 1, 3}, k));
        test_cases.emplace_back(new test_top_k(GGML_TYPE_F32, {16384, 1, 1, 1}, k));
        test_cases.emplace_back(new test_top_k(GGML_TYPE_F32, {2047, 2, 1, 3}, k));
        test_cases.emplace_back(new test_top_k(GGML_TYPE_F32, {2048, 2, 1, 3}, k));
        test_cases.emplace_back(new test_top_k(GGML_TYPE_F32, {2049, 2, 1, 3}, k));
    }

    // exhaustive top_k tests
    // (disabled: thousands of cases; k is drawn with rand() so the set is not reproducible)
    //for (int i = 1; i < 9999; ++i) {
    //    test_cases.emplace_back(new test_top_k(GGML_TYPE_F32, {i, 2, 1, 3}, rand() % i + 1));
    //}
8121
    // upscale/interpolate for every scale mode (including bilinear + antialias), covering both
    // enlarging and shrinking in all four dimensions.
    for (ggml_scale_mode mode : {GGML_SCALE_MODE_NEAREST, GGML_SCALE_MODE_BILINEAR, GGML_SCALE_MODE_BICUBIC, ggml_scale_mode(GGML_SCALE_MODE_BILINEAR | GGML_SCALE_FLAG_ANTIALIAS)}) {
        test_cases.emplace_back(new test_upscale(GGML_TYPE_F32, {512, 512, 3, 2}, 2, mode));
        test_cases.emplace_back(new test_upscale(GGML_TYPE_F32, {512, 512, 3, 2}, 2, mode, true));
        test_cases.emplace_back(new test_interpolate(GGML_TYPE_F32, {2, 5,  7, 11}, {5, 7, 11, 13}, mode));
        test_cases.emplace_back(new test_interpolate(GGML_TYPE_F32, {5, 7, 11, 13}, {2, 5,  7, 11}, mode));
    }
    // ALIGN_CORNERS variants for the interpolating modes, including a width-1 output.
    for (ggml_scale_mode mode : {GGML_SCALE_MODE_BILINEAR, GGML_SCALE_MODE_BICUBIC}) {
        test_cases.emplace_back(new test_interpolate(GGML_TYPE_F32, {2, 5, 7, 11}, {5, 7, 11, 13}, (ggml_scale_mode)(mode | GGML_SCALE_FLAG_ALIGN_CORNERS)));
        test_cases.emplace_back(new test_interpolate(GGML_TYPE_F32, {1, 4, 3, 2}, {2, 8, 3, 2}, (ggml_scale_mode)(mode | GGML_SCALE_FLAG_ALIGN_CORNERS)));
        test_cases.emplace_back(new test_interpolate(GGML_TYPE_F32, {4, 1, 3, 2}, {1, 1, 3, 2}, (ggml_scale_mode)(mode | GGML_SCALE_FLAG_ALIGN_CORNERS)));
    }
8133
    // Reductions (sum, sum_rows, mean), including permuted (non-contiguous) inputs, a
    // non-contiguous destination and odd row lengths (33, 32769).
    test_cases.emplace_back(new test_sum());
    test_cases.emplace_back(new test_sum_rows());
    test_cases.emplace_back(new test_sum(GGML_TYPE_F32, {11, 5, 6, 3}, {0, 2, 1, 3}));  // row-contiguous but non-contiguous
    test_cases.emplace_back(new test_sum(GGML_TYPE_F32, {11, 5, 6, 3}, {0, 3, 2, 1}));
    test_cases.emplace_back(new test_sum(GGML_TYPE_F32, {11, 5, 6, 3}, {0, 1, 3, 2}));
    test_cases.emplace_back(new test_sum_rows(GGML_TYPE_F32, { 11, 5, 6, 3 }, true, false));
    test_cases.emplace_back(new test_sum_rows(GGML_TYPE_F32, { 11, 5, 6, 3 }, false, true));
    test_cases.emplace_back(new test_sum_rows(GGML_TYPE_F32, { 11, 5, 6, 3 }, true, true));
    test_cases.emplace_back(new test_mean());
    test_cases.emplace_back(new test_sum(GGML_TYPE_F32, { 33, 1, 1, 1 }));
    test_cases.emplace_back(new test_sum_rows(GGML_TYPE_F32, { 33, 1, 1, 1 }));
    test_cases.emplace_back(new test_mean(GGML_TYPE_F32, { 33, 1, 1, 1 }));
    test_cases.emplace_back(new test_sum(GGML_TYPE_F32, { 33, 1024, 1, 1 }));
    test_cases.emplace_back(new test_sum_rows(GGML_TYPE_F32, { 33, 1024, 1, 1 }));
    test_cases.emplace_back(new test_sum(GGML_TYPE_F32, { 33, 256, 1, 1 }));
    test_cases.emplace_back(new test_sum(GGML_TYPE_F32, { 33, 256, 1, 1 }, { 1, 0, 2, 3 })); // sum dst not-contiguous
    test_cases.emplace_back(new test_sum_rows(GGML_TYPE_F32, { 33, 256, 1, 1 }));
    test_cases.emplace_back(new test_mean(GGML_TYPE_F32, { 33, 256, 1, 1 }));
    test_cases.emplace_back(new test_mean(GGML_TYPE_F32, { 32769, 1, 1, 1 }));
    // Normalization, padding, and miscellaneous ops at default and selected shapes.
    test_cases.emplace_back(new test_group_norm(GGML_TYPE_F32, {64, 64, 320, 1}));
    test_cases.emplace_back(new test_group_norm(GGML_TYPE_F32, {9, 9, 1280, 1}));
    test_cases.emplace_back(new test_group_norm_mul_add(GGML_TYPE_F32, {64, 64, 320, 1}));
    test_cases.emplace_back(new test_group_norm_mul_add(GGML_TYPE_F32, {9, 9, 1280, 1}));
    test_cases.emplace_back(new test_acc());
    test_cases.emplace_back(new test_pad());
    test_cases.emplace_back(new test_pad(GGML_TYPE_F32, {33, 17, 2, 1}, 4, 3, true)); // circular
    test_cases.emplace_back(new test_pad_ext());
    test_cases.emplace_back(new test_pad_reflect_1d());
    test_cases.emplace_back(new test_pad_reflect_1d(GGML_TYPE_F32, {3000, 384, 4, 1}));
    test_cases.emplace_back(new test_roll());
    test_cases.emplace_back(new test_arange());
    test_cases.emplace_back(new test_arange(GGML_TYPE_F32, 0.0f, 1048576.0f, 1.0f));
    test_cases.emplace_back(new test_timestep_embedding());
    test_cases.emplace_back(new test_leaky_relu());
8168
    // cumsum at row lengths on both sides of each power of two from 128 to 2048, plus very long rows.
    // NOTE(review): 201*1204 and 312*1205 may be deliberately odd lengths or typos for *1024 — confirm.
    test_cases.emplace_back(new test_cumsum(GGML_TYPE_F32, { 10, 5, 4, 3 }));
    test_cases.emplace_back(new test_cumsum(GGML_TYPE_F32, { 127, 5, 4, 3 }));
    test_cases.emplace_back(new test_cumsum(GGML_TYPE_F32, { 128, 5, 4, 3 }));
    test_cases.emplace_back(new test_cumsum(GGML_TYPE_F32, { 128, 128, 4, 4 }));
    test_cases.emplace_back(new test_cumsum(GGML_TYPE_F32, { 255, 5, 4, 3 }));
    test_cases.emplace_back(new test_cumsum(GGML_TYPE_F32, { 256, 5, 4, 3 }));
    test_cases.emplace_back(new test_cumsum(GGML_TYPE_F32, { 511, 5, 4, 3 }));
    test_cases.emplace_back(new test_cumsum(GGML_TYPE_F32, { 512, 5, 4, 3 }));
    test_cases.emplace_back(new test_cumsum(GGML_TYPE_F32, { 1023, 5, 4, 3 }));
    test_cases.emplace_back(new test_cumsum(GGML_TYPE_F32, { 1024, 5, 4, 3 }));
    test_cases.emplace_back(new test_cumsum(GGML_TYPE_F32, { 2047, 5, 4, 3 }));
    test_cases.emplace_back(new test_cumsum(GGML_TYPE_F32, { 2048, 5, 4, 3 }));
    test_cases.emplace_back(new test_cumsum(GGML_TYPE_F32, { 201*1204, 1, 1, 1 }));
    test_cases.emplace_back(new test_cumsum(GGML_TYPE_F32, { 312*1205, 1, 1, 1 }));
    test_cases.emplace_back(new test_cumsum(GGML_TYPE_F32, { 20481, 4, 1, 1 }));

    test_cases.emplace_back(new test_xielu());

    // tri: one case per triangle variant (lower/upper, with and without the diagonal).
    test_cases.emplace_back(new test_tri(GGML_TRI_TYPE_LOWER));
    test_cases.emplace_back(new test_tri(GGML_TRI_TYPE_LOWER_DIAG));
    test_cases.emplace_back(new test_tri(GGML_TRI_TYPE_UPPER));
    test_cases.emplace_back(new test_tri(GGML_TRI_TYPE_UPPER_DIAG));

    // fill with zero, positive, negative and fractional values at increasing sizes.
    test_cases.emplace_back(new test_fill(0.0f));
    test_cases.emplace_back(new test_fill(2.0f, GGML_TYPE_F32, { 303, 207, 11, 3 }));
    test_cases.emplace_back(new test_fill(-152.0f, GGML_TYPE_F32, { 800, 600, 4, 4 }));
    test_cases.emplace_back(new test_fill(3.5f, GGML_TYPE_F32, { 2048, 512, 2, 2 }));

    test_cases.emplace_back(new test_diag());
    test_cases.emplace_back(new test_diag(GGML_TYPE_F32, { 79, 1, 19, 13 }));
    test_cases.emplace_back(new test_diag(GGML_TYPE_F32, { 256, 1, 8, 16 }));

    // solve_tri: square first operand { n, n, .. } paired with a second operand whose
    // second dimension matches n; widths and batch dims vary, including odd sizes around 80.
    test_cases.emplace_back(new test_solve_tri());
    test_cases.emplace_back(new test_solve_tri(GGML_TYPE_F32, { 11, 11, 1, 1 }, { 5, 11, 1, 1 }));
    test_cases.emplace_back(new test_solve_tri(GGML_TYPE_F32, { 17, 17, 2, 4 }, { 9, 17, 2, 4 }));
    test_cases.emplace_back(new test_solve_tri(GGML_TYPE_F32, { 30, 30, 7, 1 }, { 8, 30, 7, 1 }));
    test_cases.emplace_back(new test_solve_tri(GGML_TYPE_F32, { 42, 42, 5, 2 }, { 10, 42, 5, 2 }));
    test_cases.emplace_back(new test_solve_tri(GGML_TYPE_F32, { 64, 64, 2, 2 }, { 10, 64, 2, 2 }));
    test_cases.emplace_back(new test_solve_tri(GGML_TYPE_F32, { 64, 64, 2, 2 }, { 64, 64, 2, 2 }));
    test_cases.emplace_back(new test_solve_tri(GGML_TYPE_F32, { 79, 79, 5, 3 }, { 417, 79, 5, 3 }));
    test_cases.emplace_back(new test_solve_tri(GGML_TYPE_F32, { 128, 128, 4, 2 }, { 32, 128, 4, 2 }));
    test_cases.emplace_back(new test_solve_tri(GGML_TYPE_F32, { 80, 80, 2, 8 }, { 80, 80, 2, 8 }));
    test_cases.emplace_back(new test_solve_tri(GGML_TYPE_F32, { 80, 80, 2, 8 }, { 79, 80, 2, 8 }));
    test_cases.emplace_back(new test_solve_tri(GGML_TYPE_F32, { 80, 80, 2, 8 }, { 81, 80, 2, 8 }));
    test_cases.emplace_back(new test_solve_tri(GGML_TYPE_F32, { 80, 80, 8, 8 }, { 80, 80, 8, 8 }));
    test_cases.emplace_back(new test_solve_tri(GGML_TYPE_F32, { 80, 80, 8, 8 }, { 79, 80, 8, 8 }));
    test_cases.emplace_back(new test_solve_tri(GGML_TYPE_F32, { 80, 80, 8, 8 }, { 81, 80, 8, 8 }));
    test_cases.emplace_back(new test_solve_tri(GGML_TYPE_F32, { 84, 84, 4, 4 }, { 32, 84, 4, 4 }));
    test_cases.emplace_back(new test_solve_tri(GGML_TYPE_F32, { 95, 95, 8, 8 }, { 40, 95, 8, 8 }));
    test_cases.emplace_back(new test_solve_tri(GGML_TYPE_F32, { 100, 100, 4, 4 }, { 41, 100, 4, 4 }));
    test_cases.emplace_back(new test_solve_tri(GGML_TYPE_F32, { 128, 128, 4, 4 }, { 31, 128, 4, 4 }));
    test_cases.emplace_back(new test_solve_tri(GGML_TYPE_F32, { 128, 128, 4, 4 }, { 32, 128, 4, 4 }));
    test_cases.emplace_back(new test_solve_tri(GGML_TYPE_F32, { 128, 128, 3, 4 }, { 32, 128, 3, 4 }));
    test_cases.emplace_back(new test_solve_tri(GGML_TYPE_F32, { 128, 128, 4, 1 }, { 32, 128, 4, 1 }));
    test_cases.emplace_back(new test_solve_tri(GGML_TYPE_F32, { 64, 64, 4, 4 }, { 200, 64, 4, 4 }));
    test_cases.emplace_back(new test_solve_tri(GGML_TYPE_F32, { 64, 64, 4, 4 }, { 384, 64, 4, 4 }));
8225
8226    for (int tfrm : {0, 1, 2}) {
8227        for (bool circular : {false, true}) {
8228            test_cases.emplace_back(new test_pad_ext(GGML_TYPE_F32, {512, 512, 1, 1}, 0, 1, 0, 1, 0, 0, 0, 0, tfrm, circular));
8229            test_cases.emplace_back(new test_pad_ext(GGML_TYPE_F32, {11, 22, 33, 44}, 1, 2, 3, 4, 5, 6, 7, 8, tfrm, circular));
8230        }
8231    }
8232
8233    for (int hsk : { 40, 64, 72, 80, 96, 128, 192, 256, 576 }) {
8234        for (int hsv : { 40, 64, 72, 80, 96, 128, 192, 256, 512 }) {
8235            if (hsk != 192 && hsk != 576 && hsk != hsv) continue;
8236            if (hsk == 192 && (hsv != 128 && hsv != 192)) continue;
8237            if (hsk == 576 && hsv != 512) continue; // DeepSeek MLA
8238
8239            for (bool mask : { true, false } ) {
8240                for (bool sinks : { true, false } ) {
8241                    for (float max_bias : { 0.0f, 8.0f }) {
8242                        if (!mask && max_bias > 0.0f) continue;
8243                        for (float logit_softcap : {0.0f, 10.0f}) {
8244                            if (hsk != 128 && logit_softcap != 0.0f) continue;
8245                            for (int nh : { 1, 4 }) {
8246                                if (nh == 1 && hsk != 576) continue; // GLM 4.7 Flash
8247                                for (int nr3 : { 1, 3, }) {
8248                                    if (hsk > 64 && nr3 > 1) continue; // skip broadcast for large head sizes
8249                                    for (int nr2 : { 1, 4, 12, 20 }) {
8250                                        if (nr2 == 12 && hsk != 128) continue;
8251                                        if (nr2 == 20 && (nh != 1 || hsk != 576)) continue;
8252                                        //for (int kv : { 1, 17, 31, 33, 61, 113, 65, 127, 129, 130, 255, 260, 371, 380, 407, 512, 1024, }) {
8253                                        for (int kv : { 113, 512, 1024, }) {
8254                                            if (nr2 != 1 && kv != 512) continue;
8255                                            for (int nb : { 1, 3, 32, 35, }) {
8256                                                for (ggml_prec prec : {GGML_PREC_F32, GGML_PREC_DEFAULT}) {
8257                                                    if (hsk != 128 && prec == GGML_PREC_DEFAULT) continue;
8258                                                    for (ggml_type type_KV : {GGML_TYPE_F32, GGML_TYPE_F16, GGML_TYPE_BF16, GGML_TYPE_Q8_0, GGML_TYPE_Q4_0}) {
8259                                                        if (type_KV != GGML_TYPE_F16 && hsk != 64 && hsk != 72) continue;
8260                                                        test_cases.emplace_back(new test_flash_attn_ext(
8261                                                                    hsk, hsv, nh, {nr2, nr3}, kv, nb, mask, sinks, max_bias, logit_softcap, prec, type_KV));
8262                                                        // run fewer test cases permuted
8263                                                        if (mask == true && max_bias == 0.0f && logit_softcap == 0 && kv == 512) {
8264                                                            test_cases.emplace_back(new test_flash_attn_ext(
8265                                                                        hsk, hsv, nh, {nr2, nr3}, kv, nb, mask, sinks, max_bias, logit_softcap, prec, type_KV, {0, 2, 1, 3}));
8266                                                        }
8267                                                    }
8268                                                }
8269                                            }
8270                                        }
8271                                    }
8272                                }
8273                            }
8274                        }
8275                    }
8276                }
8277            }
8278        }
8279    }
8280
8281    test_cases.emplace_back(new test_cross_entropy_loss     (GGML_TYPE_F32, {   10, 5, 4, 3}));
8282    test_cases.emplace_back(new test_cross_entropy_loss     (GGML_TYPE_F32, {30000, 1, 1, 1}));
8283    test_cases.emplace_back(new test_cross_entropy_loss_back(GGML_TYPE_F32, {   10, 5, 4, 3}));
8284    test_cases.emplace_back(new test_cross_entropy_loss_back(GGML_TYPE_F32, {30000, 1, 1, 1}));
8285
8286    test_cases.emplace_back(new test_opt_step_adamw(GGML_TYPE_F32, {10, 5, 4, 3}));
8287    test_cases.emplace_back(new test_opt_step_sgd(GGML_TYPE_F32, {10, 5, 4, 3}));
8288
8289    for (ggml_type type : base_types) {
8290        for (bool with_gate : {false, true}) {
8291            for (bool use_id : {false, true}) {
8292                for (bool b : {false, true}) {
8293                    if (!use_id && b) {
8294                        continue;
8295                    }
8296                    for (bool with_bias : {false, true}) {
8297                        if (!with_gate && !with_bias) {
8298                            continue;
8299                        }
8300                        for (ggml_glu_op glu_op : {GGML_GLU_OP_SWIGLU, GGML_GLU_OP_GEGLU}) {
8301                            if (!with_bias && glu_op == GGML_GLU_OP_SWIGLU_OAI) {
8302                                continue;
8303                            }
8304                            if (!with_gate && glu_op != GGML_GLU_OP_SWIGLU) {
8305                                continue;
8306                            }
8307                            test_cases.emplace_back(new test_mul_mat_vec_fusion(type, glu_op, 1, 32, 256,
8308                                use_id, 16, 8, b, with_bias, with_gate));
8309                            test_cases.emplace_back(new test_mul_mat_vec_fusion(type, glu_op, 1, 32, 256,
8310                                use_id, 16, 8, b, with_bias, with_gate, {1, 1}));
8311                        }
8312                    }
8313                }
8314            }
8315        }
8316    }
8317
8318    for (auto gate : {GATING_FUNC_SOFTMAX, GATING_FUNC_SIGMOID, GATING_FUNC_SOFTMAX_WEIGHT}) {
8319        for (bool with_norm : {false, true}) {
8320            for (bool bias_probs : {false, true}) {
8321                for (float scale_w : {0.0f, 2.0f}) {
8322                    test_cases.emplace_back(new test_topk_moe({8, 22, 1, 1}, 4, with_norm, bias_probs, gate, scale_w));
8323                    test_cases.emplace_back(new test_topk_moe({31, 22, 1, 1}, 8, with_norm, bias_probs, gate, scale_w));
8324                    test_cases.emplace_back(new test_topk_moe({32, 22, 1, 1}, 8, with_norm, bias_probs, gate, scale_w));
8325                    test_cases.emplace_back(new test_topk_moe({40, 22, 1, 1}, 8, with_norm, bias_probs, gate, scale_w));
8326                    test_cases.emplace_back(new test_topk_moe({71, 22, 1, 1}, 8, with_norm, bias_probs, gate, scale_w));
8327                    test_cases.emplace_back(new test_topk_moe({128, 1, 1, 1}, 128, with_norm, bias_probs, gate, scale_w));
8328                    test_cases.emplace_back(new test_topk_moe({129, 1, 1, 1}, 128, with_norm, bias_probs, gate, scale_w));
8329                    test_cases.emplace_back(new test_topk_moe({160, 4, 1, 1}, 160, with_norm, bias_probs, gate, scale_w));
8330                }
8331            }
8332        }
8333    }
8334
8335#if 0
    // these tests are disabled to save execution time, but they can be handy for debugging
8337    test_cases.emplace_back(new test_llama(2, true));
8338    test_cases.emplace_back(new test_llama(1));
8339    test_cases.emplace_back(new test_llama(2));
8340    test_cases.emplace_back(new test_falcon(1));
8341    test_cases.emplace_back(new test_falcon(2));
8342#endif
8343
8344    return test_cases;
8345}
8346#ifdef _MSC_VER
8347#pragma optimize("", on)
8348#endif
8349
// Test cases for performance evaluation: should be representative of real-world use cases.
// The returned cases are executed, in insertion order, by the MODE_PERF branch of
// test_backend() through test_case::eval_perf(); this function only builds the list.
static std::vector<std::unique_ptr<test_case>> make_test_cases_perf() {
    std::vector<std::unique_ptr<test_case>> test_cases;

    // Conv2d: K=CRS=NPQ=4096 matmul performance
    // Each row of `cases` is {IWH, KWH, Cout, Cin, B}; the *_idx constants name those columns.
    uint32_t                        iwh_idx  = 0;
    uint32_t                        kwh_idx  = 1;
    uint32_t                        Cout_idx = 2;
    uint32_t                        Cin_idx  = 3;
    uint32_t                        B_idx    = 4;
    std::vector<std::array<int, 5>> cases    = {
        //{IWH, KWH, Cout, Cin, B}
        // K=CRS=NPQ=4096 conv2d matmul performance
        {19,   4, 4096, 256, 16},
        // K=128, CRS=128, NPQ=4096
        { 19,  4, 128,  8,   16},
        // K=130, CRS=128, NPQ=4096
        { 19,  4, 130,  8,   16},
        // Edge case: K x CRS is small
        { 19,  2, 4,    4,   16},
        // A ConvNet's first layer
        { 224, 3, 8,    3,   1 },
        // A ConvNet's first layer with 2x2 convolution, and 1 channel
        { 224, 2, 8,    1,   1 },
        // A ConvNet's first layer with 2x2 convolution, and 1 channel, several images in the batch
        { 224, 2, 8,    1,   8 },
        // A middle layer of a ConvNet
        { 58,  3, 64,   32,  1 },
        // A middle layer of a ConvNet, several images in the batch
        { 58,  3, 64,   32,  8 },
        // A deep layer of a ConvNet, several images in the batch
        { 16,  3, 512,  128, 8 },
        // High resolution output (large NPQ)
        {1536, 3, 64,   32,  1 },
    };

    // Every row of `cases` is benchmarked once with an F32 and once with an F16 kernel.
    for (auto kernel_type : {GGML_TYPE_F32, GGML_TYPE_F16}) {
        for (auto act_case : cases) {
            // Direct CONV_2D
            // input shape: {IWH, IWH, Cin, B}; kernel shape: {KWH, KWH, Cin, Cout} (both square).
            // NOTE(review): the trailing scalars presumably are stride/padding/dilation
            // (1,1 / 0,0 / 1,1) — confirm against the test_conv_2d constructor.
            test_cases.emplace_back(new test_conv_2d(
                { act_case[iwh_idx], act_case[iwh_idx], act_case[Cin_idx], act_case[B_idx] },
                { act_case[kwh_idx], act_case[kwh_idx], act_case[Cin_idx], act_case[Cout_idx] },
                kernel_type, 1, 1, 0, 0, 1, 1, false));
        }
    }

    // ADD of a 4096-element row with a broadcast second operand
    test_cases.emplace_back(new test_bin_bcast(ggml_add, GGML_TYPE_F32, {4096, 1, 1, 1}, {1,   1, 1, 1}));
    test_cases.emplace_back(new test_bin_bcast(ggml_add, GGML_TYPE_F32, {4096, 1, 1, 1}, {1, 512, 1, 1}));

    // CPY between (possibly different) tensor types.
    // NOTE(review): the extra array arguments presumably describe source/destination
    // permutations and the trailing bool a transposed copy — confirm against test_cpy.
    test_cases.emplace_back(new test_cpy(GGML_TYPE_F32,  GGML_TYPE_F16,  {512, 3072, 1, 1}));
    test_cases.emplace_back(new test_cpy(GGML_TYPE_F32,  GGML_TYPE_F32,  {8192, 512, 2, 1}, {0, 2, 1, 3}));
    test_cases.emplace_back(new test_cpy(GGML_TYPE_F32,  GGML_TYPE_F32,  {3072, 512, 2, 1}, {0, 2, 1, 3}));
    test_cases.emplace_back(new test_cpy(GGML_TYPE_F32,  GGML_TYPE_Q4_0, {8192, 512, 2, 1}));
    test_cases.emplace_back(new test_cpy(GGML_TYPE_Q4_0, GGML_TYPE_F32,  {8192, 512, 2, 1}));

    test_cases.emplace_back(new test_cpy(GGML_TYPE_F32, GGML_TYPE_F32, {768*1024, 256, 1, 1}, {1, 0, 2, 3}, {0, 0, 0, 0}));
    test_cases.emplace_back(new test_cpy(GGML_TYPE_F16, GGML_TYPE_F16, {768*1024, 256, 1, 1}, {1, 0, 2, 3}, {0, 0, 0, 0}));
    test_cases.emplace_back(new test_cpy(GGML_TYPE_F16, GGML_TYPE_F16, {768, 1024, 256, 1}, {1, 0, 2, 3}, {0, 0, 0, 0}));
    test_cases.emplace_back(new test_cpy(GGML_TYPE_BF16, GGML_TYPE_BF16, {768, 1024, 256, 1}, {1, 0, 2, 3}, {0, 0, 0, 0}));

    test_cases.emplace_back(new test_cpy(GGML_TYPE_F32, GGML_TYPE_F32, {768*1024, 256, 1, 1}, {0, 0, 0, 0}, {0, 0, 0, 0}, true));
    test_cases.emplace_back(new test_cpy(GGML_TYPE_F32, GGML_TYPE_F32, {768, 1024, 256, 1}, {0, 0, 0, 0}, {0, 0, 0, 0}, true));
    test_cases.emplace_back(new test_cpy(GGML_TYPE_F16, GGML_TYPE_F16, {768*1024, 256, 1, 1}, {0, 0, 0, 0}, {0, 0, 0, 0}, true));
    test_cases.emplace_back(new test_cpy(GGML_TYPE_F16, GGML_TYPE_F16, {768, 1024, 256, 1}, {0, 0, 0, 0}, {0, 0, 0, 0}, true));
    test_cases.emplace_back(new test_cpy(GGML_TYPE_BF16, GGML_TYPE_BF16, {768, 1024, 256, 1}, {0, 0, 0, 0}, {0, 0, 0, 0}, true));


    // SOFT_MAX over a range of row lengths and row counts
    test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {4096, 4096, 5, 1}, false, false, GGML_TYPE_F32, {1, 1}, 1.0f, 0.0f));
    test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {12888, 256, 5, 1}, false, false, GGML_TYPE_F32, {1, 1}, 1.0f, 0.0f));
    test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {77, 4096, 5, 1}, false, false, GGML_TYPE_F32, {1, 1}, 1.0f, 0.0f));
    test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {1024, 1024, 10, 1}, false, false, GGML_TYPE_F32, {1, 1}, 1.0f, 0.0f));
    test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {77, 1024, 10, 1}, false, false, GGML_TYPE_F32, {1, 1}, 1.0f, 0.0f));
    test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {256, 256, 20, 1}, false, false, GGML_TYPE_F32, {1, 1}, 1.0f, 0.0f));
    test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {64, 64, 20, 1}, false, false, GGML_TYPE_F32, {1, 1}, 1.0f, 0.0f));
    test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {77, 64, 20, 1}, false, false, GGML_TYPE_F32, {1, 1}, 1.0f, 0.0f));

    test_cases.emplace_back(new test_argmax(GGML_TYPE_F32, {32, 10, 1, 1}));
    test_cases.emplace_back(new test_argmax(GGML_TYPE_F32, {1024, 10, 1, 1}));
    test_cases.emplace_back(new test_argmax(GGML_TYPE_F32, {32000, 512, 1, 1}));

    test_cases.emplace_back(new test_pad_reflect_1d(GGML_TYPE_F32, {512, 34, 2, 1}));
    test_cases.emplace_back(new test_pad_reflect_1d(GGML_TYPE_F32, {3000, 80, 1, 1}));
    test_cases.emplace_back(new test_pad_reflect_1d(GGML_TYPE_F32, {3000, 80, 4, 1}));
    test_cases.emplace_back(new test_pad_reflect_1d(GGML_TYPE_F32, {3000, 384, 1, 1}));
    test_cases.emplace_back(new test_pad_reflect_1d(GGML_TYPE_F32, {3000, 384, 4, 1}));

    test_cases.emplace_back(new test_mul_mat(GGML_TYPE_F16, GGML_TYPE_F32, 16416, 1, 128, {8,  1}, {4, 1}, {0, 2, 1, 3}));
    test_cases.emplace_back(new test_mul_mat(GGML_TYPE_F16, GGML_TYPE_F32, 128, 1, 16416, {8,  1}, {4, 1}, {0, 1, 2, 3}, 2*16416));

    test_cases.emplace_back(new test_solve_tri(GGML_TYPE_F32, { 64, 64, 4, 4 }, { 32, 64, 4, 4 }));
    test_cases.emplace_back(new test_solve_tri(GGML_TYPE_F32, { 128, 128, 4, 2 }, { 32, 128, 4, 2 }));
    // qwen3next with CHUNK_SIZE 64
    test_cases.emplace_back(new test_solve_tri(GGML_TYPE_F32, { 64, 64, 8, 32 }, { 64, 64, 8, 32 }));
    // qwen3next with CHUNK_SIZE 128
    test_cases.emplace_back(new test_solve_tri(GGML_TYPE_F32, { 128, 128, 4, 32 }, { 128, 128, 4, 32 }));
    test_cases.emplace_back(new test_solve_tri(GGML_TYPE_F32, { 256, 256, 4, 2 }, { 128, 256, 4, 2 }));

    test_cases.emplace_back(new test_tri(GGML_TRI_TYPE_LOWER, GGML_TYPE_F32, { 256, 256, 4, 4 }));
    test_cases.emplace_back(new test_tri(GGML_TRI_TYPE_UPPER_DIAG, GGML_TYPE_F32, { 1024, 1024, 8, 4 }));

    test_cases.emplace_back(new test_cumsum(GGML_TYPE_F32, { 128, 128, 4, 4 }));
    test_cases.emplace_back(new test_cumsum(GGML_TYPE_F32, { 2048, 16, 5, 4 }));
    test_cases.emplace_back(new test_cumsum(GGML_TYPE_F32, { 20000, 10, 4, 1 }));

    // MUL_MAT with every type in all_types as the first operand, across several second-operand column counts (bs)
    for (int bs : {1, 2, 3, 4, 5, 8, 512}) {
        for (ggml_type type_a : all_types) {
            for (ggml_type type_b : {GGML_TYPE_F32}) {
                test_cases.emplace_back(new test_mul_mat(type_a, type_b, 4096, bs, 14336, {1,  1}, {1, 1}));
            }
        }
    }

    // qwen3-30b-a3b
    for (int bs : {1, 4, 8, 32, 64, 128, 256, 512}) {
        for (ggml_type type_a : {GGML_TYPE_F32, GGML_TYPE_F16, GGML_TYPE_Q4_0, GGML_TYPE_Q8_0, GGML_TYPE_Q4_K, GGML_TYPE_Q6_K, GGML_TYPE_IQ2_XS}) {
            for (ggml_type type_b : {GGML_TYPE_F32}) {
                test_cases.emplace_back(new test_mul_mat_id(type_a, type_b, 128, 8, false, 768, bs, 2048));
                test_cases.emplace_back(new test_mul_mat_id_fusion(type_a, type_b, 128, 8, false, 768, bs, 2048, 1));
            }
        }
    }

    // Same sweep as above with different MUL_MAT_ID dimensions (32/4 instead of 128/8, 1792 instead of 768)
    for (int bs : {1, 4, 8, 32, 64, 128, 256, 512}) {
        for (ggml_type type_a : {GGML_TYPE_F32, GGML_TYPE_F16, GGML_TYPE_Q4_0, GGML_TYPE_Q8_0, GGML_TYPE_Q4_K, GGML_TYPE_Q6_K, GGML_TYPE_IQ2_XS}) {
            for (ggml_type type_b : {GGML_TYPE_F32}) {
                test_cases.emplace_back(new test_mul_mat_id(type_a, type_b, 32, 4, false, 1792, bs, 2048));
                test_cases.emplace_back(new test_mul_mat_id_fusion(type_a, type_b, 32, 4, false, 1792, bs, 2048, 1));
            }
        }
    }


    // gpt-oss-20b
    for (int bs : {1, 4, 8, 512}) {
        for (ggml_type type_a : {GGML_TYPE_MXFP4}) {
            for (ggml_type type_b : {GGML_TYPE_F32}) {
                test_cases.emplace_back(new test_mul_mat_id(type_a, type_b, 32, 4, false, 2880, bs, 2880));
                test_cases.emplace_back(new test_mul_mat_id_fusion(type_a, type_b, 32, 4, false, 2880, bs, 2880, 1));
            }
        }
    }

    // IM2COL with square KxK kernels over square IW_IH x IW_IH inputs with IC channels
    for (int K : {3, 5}) {
        for (int IC : {256, 2560}) {
            for (int IW_IH : {32, 64, 256}) {
                if (IC == 2560 && IW_IH == 256) {
                    // too big
                    continue;
                }
                test_cases.emplace_back(new test_im2col(GGML_TYPE_F32, GGML_TYPE_F16, GGML_TYPE_F32, {IW_IH, IW_IH, IC, 1}, {K, K, IC, 1}, 1, 1, 1, 1, 1, 1, true));
            }
        }
    }

    // Qwen3-VL-8B https://github.com/ggml-org/llama.cpp/issues/17012
    test_cases.emplace_back(new test_flash_attn_ext(72, 72, 16, {1, 1}, 5776, 5776, false, false, 0, 0, GGML_PREC_F32, GGML_TYPE_F16));

    test_cases.emplace_back(new test_flash_attn_ext(64, 64, 8, {8, 1}, 7680, 1, true, false, 0, 0, GGML_PREC_F32, GGML_TYPE_F16));
    test_cases.emplace_back(new test_flash_attn_ext(64, 64, 8, {8, 1}, 7680, 4, true, false, 0, 0, GGML_PREC_F32, GGML_TYPE_F16));

    // FLASH_ATTN_EXT sweep over kv length, head size (same for K and V) and the first broadcast factor
    for (int kv : { 4096, 8192, 16384, }) {
        for (int hs : { 64, 128, }) {
            for (int nr : { 1, 4, }) {
                test_cases.emplace_back(new test_flash_attn_ext(hs, hs, 8, {nr, 1}, kv, 1, true, false, 0, 0, GGML_PREC_F32, GGML_TYPE_F16));
            }
        }
    }

    // SOFT_MAX on very long rows
    for (int col : {8192, 16384, 32768, 65536, 131072, 262144, 524288}) {
        for (int rows : {1, 4, 16}){
            test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {col, rows, 1, 1}, false,  false,  GGML_TYPE_F32, {1, 1}, 1.0f, 0.0f));
        }
    }

    test_cases.emplace_back(new test_conv_2d_dw({512, 512, 256, 1}, {3, 3, 1, 256}, 1, 1, 1, false));
    test_cases.emplace_back(new test_conv_2d_dw({512, 512, 256, 1}, {3, 3, 1, 256}, 1, 1, 1, true));

    test_cases.emplace_back(new test_conv_transpose_2d({256, 256, 256, 1}, {3, 3, 16, 256}, 1));
    test_cases.emplace_back(new test_conv_transpose_2d({16, 16, 16, 1}, {3, 3, 8, 16}, 1));
    test_cases.emplace_back(new test_conv_transpose_2d({10, 10, 9, 1}, {3, 3, 1, 9}, 2));

    test_cases.emplace_back(new test_mean(GGML_TYPE_F32, {256, 256, 3, 1}));


    for (int n_token : {1, 512}) {
        test_cases.emplace_back(new test_add_id(GGML_TYPE_F32, GGML_TYPE_F32, 2880, 128, 4, n_token));
        test_cases.emplace_back(new test_add_id(GGML_TYPE_F32, GGML_TYPE_F32, 2880, 32, 4, n_token));
    }

    for (bool fw : {true, false}) { // fw == forward
        for (ggml_type type : {GGML_TYPE_F32, GGML_TYPE_F16}) {
            for (bool ff : {false, true}) { // freq_factors
                for (float v : { 0, 1 }) {
                    test_cases.emplace_back(new test_rope(type, {128,  32, 512, 1}, 128, GGML_ROPE_TYPE_NORMAL, 512, 1.0f, 0.0f, 1.0f, ff, v, fw)); // llama 7B
                    test_cases.emplace_back(new test_rope(type, {128,  64, 512, 1}, 128, GGML_ROPE_TYPE_NORMAL, 512, 1.0f, 0.0f, 1.0f, ff, v, fw)); // llama 65B
                    test_cases.emplace_back(new test_rope(type, { 80,  32, 512, 1},  20, GGML_ROPE_TYPE_NEOX, 512, 1.0f, 0.0f, 1.0f, ff, v, fw)); // neox (stablelm)
                    test_cases.emplace_back(new test_rope(type, { 64,   8, 512, 1},  64, GGML_ROPE_TYPE_NEOX, 512, 1.0f, 0.0f, 1.0f, ff, v, fw)); // neox (falcon 40B)
                    test_cases.emplace_back(new test_rope(type, {128,  12, 512, 1}, 128, GGML_ROPE_TYPE_MROPE,  512, 1.0f, 0.0f, 1.0f, ff, v, fw)); // rope_multi,m-rope (qwen2vl 2B)
                    test_cases.emplace_back(new test_rope(type, {128,  12, 512, 1}, 128, GGML_ROPE_TYPE_IMROPE,  512, 1.0f, 0.0f, 1.0f, ff, v, fw)); // rope_multi,imrope (qwen3vl 2B)
                    test_cases.emplace_back(new test_rope(type, { 80,  16, 2, 1},  80, GGML_ROPE_TYPE_VISION, 512, 1.0f, 0.0f, 1.0f, ff, v, fw)); // rope_multi,m-rope (qwen2vl ViT)
                }
            }
        }
    }

    // Shapes shared by the MEAN / SUM_ROWS / SUM benchmarks below
    std::vector<std::array<int64_t, 4>> reduce_rows_cases = {
        { 8192, 1,    1, 1 },
        { 8192, 8192, 1, 1 },
        { 128,  8192, 1, 1 },
    };

    for (auto it: reduce_rows_cases){
        test_cases.emplace_back(new test_mean(GGML_TYPE_F32, it));
        test_cases.emplace_back(new test_sum_rows(GGML_TYPE_F32, it));
        test_cases.emplace_back(new test_sum(GGML_TYPE_F32, it));
    }

    test_cases.emplace_back(new test_argsort(GGML_TYPE_F32, {65000,  16, 1, 1}));
    test_cases.emplace_back(new test_argsort(GGML_TYPE_F32, {200000, 1,  1, 1}));
    test_cases.emplace_back(new test_argsort(GGML_TYPE_F32, {200000, 16, 1, 1}));

    // TOP_K: the first column count in each sweep equals k itself
    test_cases.emplace_back(new test_top_k(GGML_TYPE_F32, {2, 1, 1, 1}, 1));
    for (auto k : {1, 10, 40, 400}) {
        for (auto nrows : {1, 16}) {
            for (auto cols : {k, 1000, 65000, 200000}) {
                test_cases.emplace_back(new test_top_k(GGML_TYPE_F32, {cols, nrows, 1, 1}, k));
            }
        }
    }

    for (auto nrows : {1, 4, 8, 16}) {
        for (auto cols : {128, 1024, 4096, 8192, 16384, 32768, 65536, 131072, 200000, 2000000}) {
            test_cases.emplace_back(new test_cumsum(GGML_TYPE_F32, {cols, nrows, 1, 1}));
        }
    }

    // Examples from granite-4.0-h-1b/ggml-model-Q8_0.gguf
    test_cases.emplace_back(new test_ssm_conv(GGML_TYPE_F32, {515, 3328, 1, 1}, {4, 3328, 1, 1})); // prefill
    test_cases.emplace_back(new test_ssm_conv(GGML_TYPE_F32, {4,   3328, 1, 1}, {4, 3328, 1, 1})); // generate
    test_cases.emplace_back(new test_ssm_scan(GGML_TYPE_F32, 128, 64, 48, 1, 512, 1)); // prefill
    test_cases.emplace_back(new test_ssm_scan(GGML_TYPE_F32, 128, 64, 48, 1, 1,   1)); // generate

    return test_cases;
}
8594
8595static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op_names_filter, const char * params_filter,
8596                         printer * output_printer) {
8597    auto filter_test_cases = [](std::vector<std::unique_ptr<test_case>> & test_cases, const char * params_filter) {
8598        if (params_filter == nullptr) {
8599            return;
8600        }
8601
8602        std::regex params_filter_regex(params_filter);
8603
8604        for (auto it = test_cases.begin(); it != test_cases.end();) {
8605            if (!std::regex_search((*it)->vars(), params_filter_regex)) {
8606                it = test_cases.erase(it);
8607                continue;
8608            }
8609
8610            it++;
8611        }
8612    };
8613
8614    if (mode == MODE_TEST) {
8615        auto test_cases = make_test_cases_eval();
8616        filter_test_cases(test_cases, params_filter);
8617        ggml_backend_t backend_cpu = ggml_backend_init_by_type(GGML_BACKEND_DEVICE_TYPE_CPU, NULL);
8618        if (backend_cpu == NULL) {
8619            test_operation_info info("", "", "CPU");
8620            info.set_error("backend", "Failed to initialize CPU backend");
8621            output_printer->print_operation(info);
8622            return false;
8623        }
8624        // Use reference implementation on the CPU backend for comparison
8625        using ggml_backend_cpu_set_use_ref_t = void (*)(ggml_backend_t, bool);
8626        auto * reg = ggml_backend_dev_backend_reg(ggml_backend_get_device(backend_cpu));
8627        auto * set_use_ref = (ggml_backend_cpu_set_use_ref_t) ggml_backend_reg_get_proc_address(reg, "ggml_backend_cpu_set_use_ref");
8628        if (set_use_ref) {
8629            set_use_ref(backend_cpu, true);
8630        }
8631
8632        size_t n_ok = 0;
8633        size_t                   tests_run = 0;
8634        std::vector<std::string> failed_tests;
8635        for (auto & test : test_cases) {
8636            test_status_t status = test->eval(backend, backend_cpu, op_names_filter, output_printer);
8637            if (status == test_status_t::SKIPPED || status == test_status_t::NOT_SUPPORTED) {
8638                continue;
8639            }
8640            tests_run++;
8641            if (status == test_status_t::OK) {
8642                n_ok++;
8643            } else if (status == test_status_t::FAIL) {
8644                failed_tests.push_back(test->current_op_name + "(" + test->vars() + ")");
8645            }
8646        }
8647        output_printer->print_summary(test_summary_info(n_ok, tests_run, false));
8648        output_printer->print_failed_tests(failed_tests);
8649
8650        ggml_backend_free(backend_cpu);
8651
8652        return n_ok == tests_run;
8653    }
8654
8655    if (mode == MODE_GRAD) {
8656        auto test_cases = make_test_cases_eval();
8657        filter_test_cases(test_cases, params_filter);
8658        size_t n_ok = 0;
8659        for (auto & test : test_cases) {
8660            if (test->eval_grad(backend, op_names_filter, output_printer)) {
8661                n_ok++;
8662            }
8663        }
8664        output_printer->print_summary(test_summary_info(n_ok, test_cases.size(), false));
8665
8666        return n_ok == test_cases.size();
8667    }
8668
8669    if (mode == MODE_PERF) {
8670        auto test_cases = make_test_cases_perf();
8671        filter_test_cases(test_cases, params_filter);
8672        for (auto & test : test_cases) {
8673            test->eval_perf(backend, op_names_filter, output_printer);
8674        }
8675        return true;
8676    }
8677
8678    if (mode == MODE_SUPPORT) {
8679        auto test_cases = make_test_cases_eval();
8680        filter_test_cases(test_cases, params_filter);
8681
8682        // Filter out fusion cases
8683        test_cases.erase(
8684            std::remove_if(test_cases.begin(), test_cases.end(), [](const std::unique_ptr<test_case> & tc) {
8685                return tc->run_whole_graph();
8686            }),
8687            test_cases.end()
8688        );
8689
8690        for (auto & test : test_cases) {
8691            test->eval_support(backend, op_names_filter, output_printer);
8692        }
8693        return true;
8694    }
8695
8696    GGML_ABORT("fatal error");
8697}
8698
8699static void list_all_ops() {
8700    printf("GGML operations:\n");
8701    std::set<std::string> all_ops;
8702
8703    for (int i = 1; i < GGML_OP_COUNT; i++) {
8704        all_ops.insert(ggml_op_name((enum ggml_op)i));
8705    }
8706    for (int i = 0; i < GGML_UNARY_OP_COUNT; i++) {
8707        all_ops.insert(ggml_unary_op_name((enum ggml_unary_op)i));
8708    }
8709    for (int i = 0; i < GGML_GLU_OP_COUNT; i++) {
8710        all_ops.insert(ggml_glu_op_name((enum ggml_glu_op)i));
8711    }
8712    for (const auto & op : all_ops) {
8713        printf("  %s\n", op.c_str());
8714    }
8715    printf("\nTotal: %zu operations\n", all_ops.size());
8716}
8717
8718static void show_test_coverage() {
8719    std::set<std::string> all_ops;
8720    for (int i = 1; i < GGML_OP_COUNT; i++) {
8721        auto op = (enum ggml_op)i;
8722        if (op == GGML_OP_VIEW      ||
8723            op == GGML_OP_RESHAPE   ||
8724            op == GGML_OP_PERMUTE   ||
8725            op == GGML_OP_TRANSPOSE ||
8726            op == GGML_OP_CONT      ||
8727            op == GGML_OP_GLU       ||
8728            op == GGML_OP_UNARY) {
8729            continue;
8730        }
8731        all_ops.insert(ggml_op_name(op));
8732    }
8733    for (int i = 0; i < GGML_UNARY_OP_COUNT; i++) {
8734        all_ops.insert(ggml_unary_op_name((enum ggml_unary_op)i));
8735    }
8736    for (int i = 0; i < GGML_GLU_OP_COUNT; i++) {
8737        all_ops.insert(ggml_glu_op_name((enum ggml_glu_op)i));
8738    }
8739    auto test_cases = make_test_cases_eval();
8740    // Filter out fusion cases
8741    test_cases.erase(
8742        std::remove_if(test_cases.begin(), test_cases.end(), [](const std::unique_ptr<test_case> & tc) {
8743            return tc->run_whole_graph();
8744        }),
8745        test_cases.end()
8746    );
8747
8748    std::set<std::string> tested_ops;
8749
8750    ggml_init_params params = {
8751        /* .mem_size = */ ggml_tensor_overhead()*128 + ggml_graph_overhead(),
8752        /* .mem_base = */ NULL,
8753        /* .no_alloc = */ true,
8754    };
8755
8756    for (auto & test_case : test_cases) {
8757        ggml_context * ctx = ggml_init(params);
8758        if (ctx) {
8759            test_case->mode = MODE_TEST;
8760            ggml_tensor * out = test_case->build_graph(ctx);
8761            if (out && out->op != GGML_OP_NONE) {
8762                if (out->op == GGML_OP_UNARY) {
8763                    tested_ops.insert(ggml_unary_op_name(ggml_get_unary_op(out)));
8764                } else if (out->op == GGML_OP_GLU) {
8765                    tested_ops.insert(ggml_glu_op_name(ggml_get_glu_op(out)));
8766                } else {
8767                    tested_ops.insert(ggml_op_name(out->op));
8768                }
8769            }
8770            ggml_free(ctx);
8771        }
8772    }
8773    std::set<std::string> covered_ops;
8774    std::set<std::string> uncovered_ops;
8775    for (const auto & op : all_ops) {
8776        if (tested_ops.count(op) > 0) {
8777            covered_ops.insert(op);
8778        } else {
8779            uncovered_ops.insert(op);
8780        }
8781    }
8782
8783    printf("Operations covered by tests (%zu):\n", covered_ops.size());
8784    for (const auto & op : covered_ops) {
8785        printf("  ✓ %s\n", op.c_str());
8786    }
8787    printf("\nOperations without tests (%zu):\n", uncovered_ops.size());
8788    for (const auto & op : uncovered_ops) {
8789        printf("  ✗ %s\n", op.c_str());
8790    }
8791
8792    printf("\nCoverage Summary:\n");
8793    printf("  Total operations: %zu\n", all_ops.size());
8794    printf("  Tested operations: %zu\n", covered_ops.size());
8795    printf("  Untested operations: %zu\n", uncovered_ops.size());
8796    printf("  Coverage: %.1f%%\n", (double)covered_ops.size() / all_ops.size() * 100.0);
8797}
8798
// Prints the command-line help for this test driver to stdout.
// argv[0] is used as the program name in the synopsis line.
// Every option listed in the synopsis is also described below it.
static void usage(char ** argv) {
    printf("Usage: %s [mode] [-o <op,..>] [-b <backend>] [-p <params regex>] [--output <console|sql|csv>] [--list-ops] [--show-coverage]\n", argv[0]);
    printf("    valid modes:\n");
    printf("      - test (default, compare with CPU backend for correctness)\n");
    printf("      - grad (compare gradients from backpropagation with method of finite differences)\n");
    printf("      - perf (performance evaluation)\n");
    printf("      - support (probe backend operation support)\n");
    printf("    op names for -o are as given by ggml_op_desc() (e.g. ADD, MUL_MAT, etc),\n");
    printf("        optionally including the full test case string (e.g. \"ADD(type=f16,ne=[1,1,8,1],nr=[1,1,1,1],nf=1)\")\n");
    printf("    -b restricts testing to the backend device with exactly this name\n");
    printf("        (without -b, the CPU device is skipped except in grad mode)\n");
    printf("    -p only runs test cases whose parameter string matches the given regex\n");
    printf("    --output specifies output format (default: console, options: console, sql, csv)\n");
    printf("    --list-ops lists all available GGML operations\n");
    printf("    --show-coverage shows test coverage\n");
}
8812
8813int main(int argc, char ** argv) {
8814    test_mode mode = MODE_TEST;
8815    output_formats output_format = CONSOLE;
8816    const char * op_names_filter = nullptr;
8817    const char * backend_filter = nullptr;
8818    const char * params_filter = nullptr;
8819
8820    for (int i = 1; i < argc; i++) {
8821        if (strcmp(argv[i], "test") == 0) {
8822            mode = MODE_TEST;
8823        } else if (strcmp(argv[i], "perf") == 0) {
8824            mode = MODE_PERF;
8825        } else if (strcmp(argv[i], "grad") == 0) {
8826            mode = MODE_GRAD;
8827        } else if (strcmp(argv[i], "support") == 0) {
8828            mode = MODE_SUPPORT;
8829        } else if (strcmp(argv[i], "-o") == 0) {
8830            if (i + 1 < argc) {
8831                op_names_filter = argv[++i];
8832            } else {
8833                usage(argv);
8834                return 1;
8835            }
8836        } else if (strcmp(argv[i], "-b") == 0) {
8837            if (i + 1 < argc) {
8838                backend_filter = argv[++i];
8839            } else {
8840                usage(argv);
8841                return 1;
8842            }
8843        } else if (strcmp(argv[i], "-p") == 0) {
8844            if (i + 1 < argc) {
8845                params_filter = argv[++i];
8846            } else {
8847                usage(argv);
8848                return 1;
8849            }
8850        } else if (strcmp(argv[i], "--output") == 0) {
8851            if (i + 1 < argc) {
8852                if (!output_format_from_str(argv[++i], output_format)) {
8853                    usage(argv);
8854                    return 1;
8855                }
8856            } else {
8857                usage(argv);
8858                return 1;
8859            }
8860        } else if (strcmp(argv[i], "--list-ops") == 0) {
8861            list_all_ops();
8862            return 0;
8863        } else if (strcmp(argv[i], "--show-coverage") == 0) {
8864            show_test_coverage();
8865            return 0;
8866        } else {
8867            usage(argv);
8868            return 1;
8869        }
8870    }
8871
8872    // load and enumerate backends
8873    ggml_backend_load_all();
8874
8875    // Create printer for output format
8876    std::unique_ptr<printer> output_printer = create_printer(output_format);
8877    if (output_printer) {
8878        output_printer->print_header();
8879    }
8880
8881    output_printer->print_testing_start(testing_start_info(ggml_backend_dev_count()));
8882
8883    size_t n_ok = 0;
8884
8885    for (size_t i = 0; i < ggml_backend_dev_count(); i++) {
8886        ggml_backend_dev_t dev = ggml_backend_dev_get(i);
8887
8888        if (backend_filter != NULL && strcmp(backend_filter, ggml_backend_dev_name(dev)) != 0) {
8889            output_printer->print_backend_init(
8890                backend_init_info(i, ggml_backend_dev_count(), ggml_backend_dev_name(dev), true, "Skipping"));
8891            n_ok++;
8892            continue;
8893        }
8894
8895        if (backend_filter == NULL && ggml_backend_dev_type(dev) == GGML_BACKEND_DEVICE_TYPE_CPU && mode != MODE_GRAD) {
8896            output_printer->print_backend_init(backend_init_info(
8897                i, ggml_backend_dev_count(), ggml_backend_dev_name(dev), true, "Skipping CPU backend"));
8898            n_ok++;
8899            continue;
8900        }
8901
8902        ggml_backend_t backend = ggml_backend_dev_init(dev, NULL);
8903        GGML_ASSERT(backend != NULL);
8904
8905        ggml_backend_reg_t reg = ggml_backend_dev_backend_reg(dev);
8906        auto ggml_backend_set_n_threads_fn = (ggml_backend_set_n_threads_t) ggml_backend_reg_get_proc_address(reg, "ggml_backend_set_n_threads");
8907        if (ggml_backend_set_n_threads_fn) {
8908            // TODO: better value for n_threads
8909            ggml_backend_set_n_threads_fn(backend, N_THREADS);
8910        }
8911
8912        size_t free, total;  // NOLINT
8913        ggml_backend_dev_memory(dev, &free, &total);
8914        output_printer->print_backend_init(backend_init_info(i, ggml_backend_dev_count(), ggml_backend_dev_name(dev),
8915                                                             false, "", ggml_backend_dev_description(dev),
8916                                                             total / 1024 / 1024, free / 1024 / 1024, true));
8917
8918        bool ok = test_backend(backend, mode, op_names_filter, params_filter, output_printer.get());
8919
8920        if (ok) {
8921            n_ok++;
8922        }
8923        output_printer->print_backend_status(
8924            backend_status_info(ggml_backend_name(backend), ok ? test_status_t::OK : test_status_t::FAIL));
8925
8926        ggml_backend_free(backend);
8927    }
8928
8929    ggml_quantize_free();
8930
8931    if (output_printer) {
8932        output_printer->print_footer();
8933    }
8934
8935    output_printer->print_overall_summary(
8936        overall_summary_info(n_ok, ggml_backend_dev_count(), n_ok == ggml_backend_dev_count()));
8937
8938    if (n_ok != ggml_backend_dev_count()) {
8939        return 1;
8940    }
8941
8942    return 0;
8943}