1/*
   2    WebGPU backend implementation.
   3    Note: Use ClangFormat to format this file.
   4*/
   5
   6#include "ggml-webgpu.h"
   7
   8#include "ggml-backend-impl.h"
   9#include "ggml-impl.h"
  10#include "ggml-webgpu-shader-lib.hpp"
  11#include "ggml-wgsl-shaders.hpp"
  12#include "pre_wgsl.hpp"
  13
  14#ifdef __EMSCRIPTEN__
  15#    include <emscripten/emscripten.h>
  16#endif
  17
  18#include <webgpu/webgpu_cpp.h>
  19
  20#include <atomic>
  21#include <condition_variable>
  22#include <cstdint>
  23#include <cstring>
  24#include <iostream>
  25#include <map>
  26#include <mutex>
  27#include <optional>
  28#include <string>
  29#include <vector>
  30
// Round x up to the next multiple of pow2. Only correct when pow2 is a power of two, because the
// result is produced by masking off the low bits with ~(pow2 - 1).
#define ROUNDUP_POW2(x, pow2) (((x) + ((pow2) - 1)) & ~((pow2) - 1))
// Integer ceiling of M / N. N must be non-zero, and M + N - 1 must not overflow the operand type.
#define CEIL_DIV(M, N)        (((M) + (N) - 1) / (N))
  33
#ifdef GGML_WEBGPU_DEBUG
// Debug builds: `msg` is streamed into std::cout, so callers may pass a chain such as `"x=" << x`.
// std::endl flushes after every message.
#    define WEBGPU_LOG_DEBUG(msg)  std::cout << msg << std::endl
// NOTE(review): presumably the element count of the shader debug buffers (debug_dev_buf/debug_host_buf);
// they are allocated outside this view — confirm.
#    define WEBGPU_DEBUG_BUF_ELEMS 512
#else
// Release builds: expands to a no-op expression so `WEBGPU_LOG_DEBUG(...);` stays a valid statement.
#    define WEBGPU_LOG_DEBUG(msg) ((void) 0)
#endif  // GGML_WEBGPU_DEBUG
  40
#ifdef GGML_WEBGPU_CPU_PROFILE
// CPU-side timing helpers. Each START/END pair declares local variables whose names are built from `id`,
// so a given `id` can be used at most once per scope, and START(id) must precede END(id) in the same scope.
// END adds the elapsed time (std::chrono::high_resolution_clock, in milliseconds) to a map keyed by the
// stringified id (#id) on `ctx`, which must point to an object with cpu_time_ms / cpu_detail_ms members.
// NOTE(review): <chrono> is not in this file's visible include list — presumably included transitively.
// total timing (aggregated)
#    define WEBGPU_CPU_PROFILE_TOTAL_START(id) auto cpu_total_start_##id = std::chrono::high_resolution_clock::now();

#    define WEBGPU_CPU_PROFILE_TOTAL_END(id, ctx)                                                         \
        auto   cpu_total_end_##id = std::chrono::high_resolution_clock::now();                            \
        double cpu_total_time_##id =                                                                      \
            std::chrono::duration<double, std::milli>(cpu_total_end_##id - cpu_total_start_##id).count(); \
        (ctx)->cpu_time_ms[#id] += cpu_total_time_##id;
// fine-grained timing (not included in totals): accumulated into cpu_detail_ms instead of cpu_time_ms
#    define WEBGPU_CPU_PROFILE_DETAIL_START(id) auto cpu_detail_start_##id = std::chrono::high_resolution_clock::now();

#    define WEBGPU_CPU_PROFILE_DETAIL_END(id, ctx)                                                          \
        auto   cpu_detail_end_##id = std::chrono::high_resolution_clock::now();                             \
        double cpu_detail_time_##id =                                                                       \
            std::chrono::duration<double, std::milli>(cpu_detail_end_##id - cpu_detail_start_##id).count(); \
        (ctx)->cpu_detail_ms[#id] += cpu_detail_time_##id;
#else
// Profiling disabled: all helpers expand to nothing.
#    define WEBGPU_CPU_PROFILE_TOTAL_START(id)
#    define WEBGPU_CPU_PROFILE_TOTAL_END(id, ctx)
#    define WEBGPU_CPU_PROFILE_DETAIL_START(id)
#    define WEBGPU_CPU_PROFILE_DETAIL_END(id, ctx)
#endif  // GGML_WEBGPU_CPU_PROFILE
  64
  65#ifdef GGML_WEBGPU_GPU_PROFILE
  66#    define WEBGPU_NUM_TIMESTAMP_QUERY_BUFS       24
  67#    define WEBGPU_TIMESTAMP_QUERY_BUF_SIZE_BYTES 16  // e.g. enough for two timestamps
  68#endif
  69
  70/* Constants */
  71
  72// Track https://github.com/gpuweb/gpuweb/issues/5315 for fixes to implementations so this can be removed.
  73#define WEBGPU_MAX_WG_SIZE 288
  74
  75#define WEBGPU_MUL_MAT_WG_SIZE               256
  76#define WEBGPU_NUM_PARAM_BUFS                16u
  77#define WEBGPU_COMMAND_SUBMIT_BATCH_SIZE     8u
  78#define WEBGPU_WAIT_ANY_TIMEOUT_MS           0
  79// Maximum number of in-flight submissions per-thread, to avoid exhausting the parameter buffer pool
  80#define WEBGPU_MAX_INFLIGHT_SUBS_PER_THREAD  WEBGPU_NUM_PARAM_BUFS / WEBGPU_COMMAND_SUBMIT_BATCH_SIZE
  81#define WEBGPU_PARAMS_BUF_SIZE_BYTES         128  // enough for 32 parameters
  82#define WEBGPU_NUM_SET_ROWS_ERROR_BUFS       16
  83#define WEBGPU_SET_ROWS_ERROR_BUF_SIZE_BYTES 4
  84#define WEBGPU_STORAGE_BUF_BINDING_MULT      4  // a storage buffer binding size must be a multiple of 4
  85
  86// For operations which process a row in parallel, this seems like a reasonable default
  87#define WEBGPU_ROW_SPLIT_WG_SIZE 64
  88
  89// Matrix multiplication parameters
  90
  91// Register tiling parameters
  92#define WEBGPU_MUL_MAT_TILE_M    8
  93#define WEBGPU_MUL_MAT_TILE_N    8
  94#define WEBGPU_MUL_MAT_WG_SIZE_M 8
  95#define WEBGPU_MUL_MAT_WG_SIZE_N 8
  96#define WEBGPU_MUL_MAT_TILE_K    32
  97
  98// Subgroup matrix parameters
  99// The number of subgroups in the M dimension
 100#define WEBGPU_MUL_MAT_SUBGROUP_M        2
 101// The number of subgroups in the N dimension
 102#define WEBGPU_MUL_MAT_SUBGROUP_N        2
 103// The number of subgroup matrices each subgroup accumulates over
 104#define WEBGPU_MUL_MAT_SUBGROUP_MATRIX_M 4
 105#define WEBGPU_MUL_MAT_SUBGROUP_MATRIX_N 2
 106
 107// Matrix-vector multiplication parameters
 108#define WEBGPU_MUL_MAT_VEC_WG_SIZE        256
 109// Must be multiple of 4 to work with vectorized paths, and must divide mul_mat_vec wg size
 110#define WEBGPU_MUL_MAT_VEC_OUTPUTS_PER_WG 64
 111#define WEBGPU_MUL_MAT_VEC_TILE_K         256
 112
 113/* End Constants */
 114
// This is a "fake" base pointer, since WebGPU buffers do not have pointers to their locations.
// Tensor `data` pointers are expressed relative to this value, and webgpu_tensor_offset() subtracts it
// to recover a byte offset. It is non-null, presumably so that a tensor at offset 0 still has a non-null
// data pointer — confirm against the buffer interface.
static void * const webgpu_ptr_base = (void *) (uintptr_t) 0x1000;  // NOLINT
 117
 118// Always returns the base offset of a tensor, regardless of views.
 119static uint64_t webgpu_tensor_offset(const ggml_tensor * tensor) {
 120    if (tensor->view_src) {
 121        return (uint8_t *) tensor->view_src->data - (uint8_t *) webgpu_ptr_base;
 122    }
 123    return (uint8_t *) tensor->data - (uint8_t *) webgpu_ptr_base;
 124}
 125
 126/* Struct definitions */
 127
// Forward reference: the buffer pools' init() methods below create buffers through this helper,
// which is defined later in the "WebGPU object initializations" section.
static void ggml_webgpu_create_buffer(wgpu::Device &    device,
                                      wgpu::Buffer &    buffer,
                                      size_t            size,
                                      wgpu::BufferUsage usage,
                                      const char *      label);
 134
// A host-side buffer paired with a device-side buffer of the same size, handed out and returned
// together by webgpu_buf_pool. Data is moved between the two with CopyBufferToBuffer; the direction
// depends on the pool (host->device for parameters, device->host for SET_ROWS error flags).
struct webgpu_pool_bufs {
    wgpu::Buffer host_buf;  // CPU-mappable side (created with the pool's host_buf_usage)
    wgpu::Buffer dev_buf;   // side bound to shaders (created with the pool's dev_buf_usage)
};
 139
// The futures to wait on for a single queue submission: the OnSubmittedWorkDone future, plus one
// MapAsync future per SET_ROWS error buffer (and per timestamp buffer when GPU profiling is enabled),
// as produced by ggml_backend_webgpu_submit.
struct webgpu_submission_futures {
    std::vector<wgpu::FutureWaitInfo> futures;
};
 144
 145// Holds a pool of parameter buffers for WebGPU operations
 146struct webgpu_buf_pool {
 147    std::vector<webgpu_pool_bufs> free;
 148
 149    // The pool must be synchronized because
 150    // 1. The memset pool is shared globally by every ggml buffer,
 151    // since allocating a pool per ggml buffer would consume too much memory.
 152    // 2. For the per-thread buffer pools in webgpu_context,
 153    // buffers are allocated and freed in Dawn callbacks,
 154    // which can run on a different thread than the calling thread.
 155    std::mutex              mutex;
 156    std::condition_variable cv;
 157
 158    void init(wgpu::Device      device,
 159              int               num_bufs,
 160              size_t            buf_size,
 161              wgpu::BufferUsage dev_buf_usage,
 162              wgpu::BufferUsage host_buf_usage) {
 163        for (int i = 0; i < num_bufs; i++) {
 164            wgpu::Buffer host_buf;
 165            wgpu::Buffer dev_buf;
 166            ggml_webgpu_create_buffer(device, host_buf, buf_size, host_buf_usage, "ggml_webgpu_host_pool_buf");
 167            ggml_webgpu_create_buffer(device, dev_buf, buf_size, dev_buf_usage, "ggml_webgpu_dev_pool_buf");
 168            free.push_back({ host_buf, dev_buf });
 169        }
 170    }
 171
 172    webgpu_pool_bufs alloc_bufs() {
 173        std::unique_lock<std::mutex> lock(mutex);
 174        cv.wait(lock, [this] { return !free.empty(); });
 175        webgpu_pool_bufs bufs = free.back();
 176        free.pop_back();
 177        return bufs;
 178    }
 179
 180    void free_bufs(std::vector<webgpu_pool_bufs> bufs) {
 181        std::lock_guard<std::mutex> lock(mutex);
 182        free.insert(free.end(), bufs.begin(), bufs.end());
 183        cv.notify_all();
 184    }
 185
 186    void cleanup() {
 187        std::lock_guard<std::mutex> lock(mutex);
 188        for (auto & bufs : free) {
 189            if (bufs.host_buf) {
 190                bufs.host_buf.Destroy();
 191            }
 192            if (bufs.dev_buf) {
 193                bufs.dev_buf.Destroy();
 194            }
 195        }
 196        free.clear();
 197    }
 198
 199    ~webgpu_buf_pool() { this->cleanup(); }
 200};
 201
 202#ifdef GGML_WEBGPU_GPU_PROFILE
 203struct webgpu_gpu_profile_bufs {
 204    wgpu::Buffer   host_buf;
 205    wgpu::Buffer   dev_buf;
 206    wgpu::QuerySet query_set;
 207};
 208
 209// Holds a pool of parameter buffers for WebGPU operations
 210struct webgpu_gpu_profile_buf_pool {
 211    std::vector<webgpu_gpu_profile_bufs> free;
 212
 213    std::mutex mutex;
 214
 215    std::condition_variable cv;
 216
 217    void init(wgpu::Device      device,
 218              int               num_bufs,
 219              size_t            buf_size,
 220              wgpu::BufferUsage dev_buf_usage,
 221              wgpu::BufferUsage host_buf_usage) {
 222        for (int i = 0; i < num_bufs; i++) {
 223            wgpu::Buffer host_buf;
 224            wgpu::Buffer dev_buf;
 225            ggml_webgpu_create_buffer(device, host_buf, buf_size, host_buf_usage, "ggml_webgpu_host_profile_buf");
 226            ggml_webgpu_create_buffer(device, dev_buf, buf_size, dev_buf_usage, "ggml_webgpu_dev_profile_buf");
 227            // Create a query set for 2 timestamps
 228            wgpu::QuerySetDescriptor ts_query_set_desc = {};
 229
 230            ts_query_set_desc.type      = wgpu::QueryType::Timestamp;
 231            ts_query_set_desc.count     = 2;
 232            wgpu::QuerySet ts_query_set = device.CreateQuerySet(&ts_query_set_desc);
 233
 234            free.push_back({ host_buf, dev_buf, ts_query_set });
 235        }
 236    }
 237
 238    webgpu_gpu_profile_bufs alloc_bufs() {
 239        std::unique_lock<std::mutex> lock(mutex);
 240        cv.wait(lock, [this] { return !free.empty(); });
 241        webgpu_gpu_profile_bufs bufs = free.back();
 242        free.pop_back();
 243        return bufs;
 244    }
 245
 246    void free_bufs(std::vector<webgpu_gpu_profile_bufs> bufs) {
 247        std::lock_guard<std::mutex> lock(mutex);
 248        free.insert(free.end(), bufs.begin(), bufs.end());
 249        cv.notify_all();
 250    }
 251
 252    void cleanup() {
 253        std::lock_guard<std::mutex> lock(mutex);
 254        for (auto & bufs : free) {
 255            bufs.host_buf.Destroy();
 256            bufs.dev_buf.Destroy();
 257            bufs.query_set.Destroy();
 258        }
 259        free.clear();
 260    }
 261
 262    ~webgpu_gpu_profile_buf_pool() { this->cleanup(); }
 263};
 264#endif
 265
// A compiled compute pipeline together with its name. The name is used as the bind-group label when
// building commands and, with GGML_WEBGPU_GPU_PROFILE, as the key for per-shader GPU timings.
struct webgpu_pipeline {
    wgpu::ComputePipeline pipeline;
    std::string           name;
    // Opaque, type-erased per-pipeline data; not read in this part of the file — presumably set by the
    // shader library for variant-specific metadata (confirm).
    std::shared_ptr<void> context = nullptr;
};
 271
// One encoded command buffer plus the pooled buffers it uses. ggml_backend_webgpu_submit hands these
// buffers back to their pools from the completion/map callbacks of the submission.
struct webgpu_command {
    wgpu::CommandBuffer             commands;
    std::vector<webgpu_pool_bufs>   params_bufs;          // one parameter buffer pair per dispatched pipeline
    std::optional<webgpu_pool_bufs> set_rows_error_bufs;  // set when the command contains SET_ROWS operations
#ifdef GGML_WEBGPU_GPU_PROFILE
    webgpu_gpu_profile_bufs timestamp_query_bufs;  // beginning/end-of-pass timestamps for this command
    std::string             pipeline_name;         // key into shader_gpu_time_ms
#endif
};
 281
 282struct webgpu_capabilities {
 283    wgpu::Limits limits;
 284    bool         supports_subgroup_matrix = false;
 285
 286    uint32_t sg_mat_m = 0;
 287    uint32_t sg_mat_n = 0;
 288    uint32_t sg_mat_k = 0;
 289
 290    uint32_t subgroup_size     = 0;
 291    uint32_t max_subgroup_size = 0;
 292    size_t   memset_bytes_per_thread;
 293};
 294
 295// Stores global webgpu members
 296struct webgpu_global_context_struct {
 297    wgpu::Instance instance;
 298    wgpu::Adapter  adapter;
 299    wgpu::Device   device;
 300    wgpu::Queue    queue;
 301
 302    webgpu_capabilities  capabilities;
 303    // Shared buffer to move data from device to host
 304    wgpu::Buffer         get_tensor_staging_buf;
 305    // Global mutex for pipeline and staging buffer, will be refactored to exclude pipeline caches.
 306    std::recursive_mutex mutex;
 307
 308    webgpu_buf_pool                memset_buf_pool;
 309    std::map<int, webgpu_pipeline> memset_pipelines;  // variant or type index
 310    std::atomic_uint               inflight_threads = 0;
 311
 312#ifdef GGML_WEBGPU_CPU_PROFILE
 313    // Profiling: labeled CPU time in ms (total)
 314    std::unordered_map<std::string, double> cpu_time_ms;
 315    // Profiling: detailed CPU time in ms
 316    std::unordered_map<std::string, double> cpu_detail_ms;
 317#endif
 318
 319#ifdef GGML_WEBGPU_GPU_PROFILE
 320    // Profiling: per-shader GPU time in ms
 321    std::unordered_map<std::string, double> shader_gpu_time_ms;
 322    // Profiling: pool of timestamp query buffers (one per operation)
 323    webgpu_gpu_profile_buf_pool             timestamp_query_buf_pool;
 324#endif
 325
 326#ifdef GGML_WEBGPU_DEBUG
 327    wgpu::Buffer debug_host_buf;
 328    wgpu::Buffer debug_dev_buf;
 329#endif
 330
 331    ~webgpu_global_context_struct() {
 332        if (this->get_tensor_staging_buf) {
 333            this->get_tensor_staging_buf.Destroy();
 334            this->get_tensor_staging_buf = nullptr;
 335        }
 336#ifdef GGML_WEBGPU_DEBUG
 337        if (this->debug_host_buf) {
 338            this->debug_host_buf.Destroy();
 339            this->debug_host_buf = nullptr;
 340        }
 341        if (this->debug_dev_buf) {
 342            this->debug_dev_buf.Destroy();
 343            this->debug_dev_buf = nullptr;
 344        }
 345#endif
 346    }
 347};
 348
// Shared ownership of the global context: the registration, device, per-backend and buffer contexts
// each hold a reference, so the device outlives any one of them.
typedef std::shared_ptr<webgpu_global_context_struct> webgpu_global_context;
 350
 351// All the base objects needed to run operations on a WebGPU device
 352struct webgpu_context_struct {
 353    // Points to global instances owned by ggml_backend_webgpu_reg_context
 354    webgpu_global_context global_ctx;
 355
 356    pre_wgsl::Preprocessor p;
 357
 358    webgpu_buf_pool param_buf_pool;
 359    webgpu_buf_pool set_rows_error_buf_pool;
 360
 361    std::map<int, std::map<int, std::map<int, webgpu_pipeline>>> mul_mat_pipelines;  // src0_type, src1_type, vectorized
 362    std::map<int, std::map<int, std::map<int, webgpu_pipeline>>>
 363        mul_mat_vec_pipelines;                                                       // src0_type, src1_type, vectorized
 364
 365    std::unordered_map<ggml_webgpu_flash_attn_pipeline_key, webgpu_pipeline, ggml_webgpu_flash_attn_pipeline_key_hash>
 366        flash_attn_pipelines;
 367
 368    std::unordered_map<int, webgpu_pipeline> argmax_pipelines;         // key is vec4
 369    std::unordered_map<int, webgpu_pipeline> argsort_pipelines;        // key is order (asc/desc)
 370    std::unordered_map<int, webgpu_pipeline> argsort_merge_pipelines;  // key is order (asc/desc)
 371    std::unordered_map<int, webgpu_pipeline> cumsum_pipelines;         // key is fixed, no variants yet
 372    std::unordered_map<int, webgpu_pipeline> sum_rows_pipelines;       // key is fixed, no variants yet
 373
 374    std::unordered_map<ggml_webgpu_set_rows_pipeline_key, webgpu_pipeline, ggml_webgpu_set_rows_pipeline_key_hash>
 375                                                  set_rows_pipelines;
 376    std::map<int, std::map<int, webgpu_pipeline>> get_rows_pipelines;  // src_type, vectorized
 377
 378    std::map<int, std::map<int, webgpu_pipeline>> cpy_pipelines;       // src_type, dst_type
 379
 380    std::unordered_map<ggml_webgpu_binary_pipeline_key, webgpu_pipeline, ggml_webgpu_binary_pipeline_key_hash>
 381        binary_pipelines;
 382
 383    std::map<int, webgpu_pipeline>                               rms_norm_pipelines;  // inplace
 384    std::map<int, std::map<int, std::map<int, webgpu_pipeline>>> rope_pipelines;      // type, ff, inplace
 385    std::map<int, std::map<int, std::map<int, webgpu_pipeline>>> glu_pipelines;       // glu_op, type, split
 386    std::map<int, webgpu_pipeline>                               scale_pipelines;     // inplace
 387    std::map<int, std::map<int, std::map<int, webgpu_pipeline>>> soft_max_pipelines;  // mask_type, has_sink, inplace
 388    std::unordered_map<ggml_webgpu_unary_pipeline_key, webgpu_pipeline, ggml_webgpu_unary_pipeline_key_hash>
 389        unary_pipelines;
 390    std::unordered_map<ggml_webgpu_pad_pipeline_key, webgpu_pipeline, ggml_webgpu_pad_pipeline_key_hash> pad_pipelines;
 391
 392    size_t memset_bytes_per_thread;
 393};
 394
// Shared handle to a per-backend context (held by ggml_backend_webgpu_context).
typedef std::shared_ptr<webgpu_context_struct> webgpu_context;
 396
 397// Metadata required for the ggml backend registration/discovery interface
 398struct ggml_backend_webgpu_reg_context {
 399    // Since the Instance is a global entrypoint into the WebGPU API, it lives here
 400    webgpu_global_context webgpu_global_ctx;
 401    size_t                device_count;
 402    const char *          name;
 403};
 404
// Per-device struct for the global logical device interface: a shared reference to the global
// context plus the device's name and description strings.
struct ggml_backend_webgpu_device_context {
    webgpu_global_context webgpu_global_ctx;
    std::string           device_name;
    std::string           device_desc;
};
 411
// Per-thread data required to actually run WebGPU operations in a backend instance
struct ggml_backend_webgpu_context {
    webgpu_context webgpu_ctx;  // buffer pools and pipeline caches used by this backend instance
    std::string    name;
};
 417
 418// Per-thread data related to buffers
 419struct ggml_backend_webgpu_buffer_context {
 420    wgpu::Buffer          buffer;
 421    std::string           label;
 422    webgpu_global_context global_ctx;
 423
 424    ggml_backend_webgpu_buffer_context(wgpu::Buffer buf, std::string lbl, webgpu_global_context global_ctx_) :
 425        buffer(std::move(buf)),
 426        label(std::move(lbl)),
 427        global_ctx(std::move(global_ctx_)) {}
 428};
 429
 430/* WebGPU object initializations */
 431
 432// Process a WGSL shader string, replacing tokens of the form {{KEY}} with
 433// the corresponding values provided in `repls`.
 434static std::string ggml_webgpu_process_shader_repls(const char *                               src,
 435                                                    const std::map<std::string, std::string> & repls) {
 436    if (!src) {
 437        return std::string();
 438    }
 439    std::string s = src;
 440    for (const auto & kv : repls) {
 441        std::string token = "{{" + kv.first + "}}";
 442        size_t      pos   = 0;
 443        while ((pos = s.find(token, pos)) != std::string::npos) {
 444            s.replace(pos, token.length(), kv.second);
 445            pos += kv.second.length();
 446        }
 447    }
 448    return s;
 449}
 450
 451static webgpu_pipeline ggml_webgpu_create_pipeline(wgpu::Device &                           device,
 452                                                   const char *                             shader_code,
 453                                                   const char *                             label,
 454                                                   const std::vector<wgpu::ConstantEntry> & constants = {}) {
 455    wgpu::ShaderSourceWGSL shader_source;
 456    shader_source.code = shader_code;
 457
 458    wgpu::ShaderModuleDescriptor shader_desc;
 459    shader_desc.nextInChain = &shader_source;
 460
 461    wgpu::ShaderModule shader_module = device.CreateShaderModule(&shader_desc);
 462
 463    wgpu::ComputePipelineDescriptor pipeline_desc;
 464    pipeline_desc.label              = label;
 465    pipeline_desc.compute.module     = shader_module;
 466    pipeline_desc.compute.entryPoint = "main";   // Entry point in the WGSL code
 467    pipeline_desc.layout             = nullptr;  // nullptr means auto layout
 468    if (constants.size() > 0) {
 469        pipeline_desc.compute.constants     = constants.data();
 470        pipeline_desc.compute.constantCount = constants.size();
 471    }
 472    return { device.CreateComputePipeline(&pipeline_desc), label };
 473}
 474
 475static void ggml_webgpu_create_buffer(wgpu::Device &    device,
 476                                      wgpu::Buffer &    buffer,
 477                                      size_t            size,
 478                                      wgpu::BufferUsage usage,
 479                                      const char *      label) {
 480    wgpu::BufferDescriptor buffer_desc;
 481    buffer_desc.size             = size;
 482    buffer_desc.usage            = usage;
 483    buffer_desc.label            = label;
 484    buffer_desc.mappedAtCreation = false;
 485
 486    // TODO: error handling
 487    buffer = device.CreateBuffer(&buffer_desc);
 488}
 489
 490/** End WebGPU object initializations */
 491
 492/** WebGPU Actions */
 493
 494// Wait for the queue to finish processing all submitted work
 495static void ggml_backend_webgpu_wait(webgpu_global_context &                  ctx,
 496                                     std::vector<webgpu_submission_futures> & futures,
 497                                     bool                                     block = true) {
 498    // If we have too many in-flight submissions, wait on the oldest one first. If there are many threads,
 499    // inflight_max may be 0, meaning that we must wait on all futures.
 500    uint64_t timeout_ms       = block ? UINT64_MAX : 0;
 501    uint32_t inflight_threads = ctx->inflight_threads;
 502    uint32_t inflight_max     = WEBGPU_MAX_INFLIGHT_SUBS_PER_THREAD / std::max(inflight_threads, 1u);
 503    while (futures.size() >= inflight_max && futures.size() > 0) {
 504        ctx->instance.WaitAny(futures[0].futures.size(), futures[0].futures.data(), UINT64_MAX);
 505        futures.erase(futures.begin());
 506    }
 507    size_t i = 0;
 508    while (i < futures.size()) {
 509        auto waitStatus = ctx->instance.WaitAny(futures[i].futures.size(), futures[i].futures.data(), timeout_ms);
 510        switch (waitStatus) {
 511            case wgpu::WaitStatus::Success:
 512                futures.erase(futures.begin() + i);
 513                break;
 514            case wgpu::WaitStatus::TimedOut:
 515                i++;
 516                break;
 517            case wgpu::WaitStatus::Error:
 518                GGML_LOG_ERROR("ggml_webgpu: WaitAny returned an error\n");
 519                break;
 520            default:
 521                GGML_LOG_ERROR("ggml_webgpu: WaitAny returned an unknown status\n");
 522                break;
 523        }
 524    }
 525}
 526
 527static void ggml_backend_webgpu_map_buffer(webgpu_global_context & ctx,
 528                                           wgpu::Buffer &          buffer,
 529                                           wgpu::MapMode           mode,
 530                                           size_t                  offset,
 531                                           size_t                  size) {
 532    ctx->instance.WaitAny(buffer.MapAsync(mode, offset, size, wgpu::CallbackMode::AllowSpontaneous,
 533                                          [](wgpu::MapAsyncStatus status, wgpu::StringView message) {
 534                                              if (status != wgpu::MapAsyncStatus::Success) {
 535                                                  GGML_LOG_ERROR("ggml_webgpu: Failed to map buffer: %s\n",
 536                                                                 message.data);
 537                                              }
 538                                          }),
 539                          UINT64_MAX);
 540}
 541
 542#ifdef GGML_WEBGPU_DEBUG
 543// This function adds debugging information to shaders, as WebGPU does not support printing directly.
 544// To use, add a bind group entry to the setup for the shader you are debugging, add the buffer and
 545// debug statements in the shader, and then call this function after encoding the commands and submitting them.
 546static void ggml_backend_webgpu_debug(webgpu_global_context & ctx) {
 547    wgpu::CommandEncoder encoder = ctx->device.CreateCommandEncoder();
 548    encoder.CopyBufferToBuffer(ctx->debug_dev_buf, 0, ctx->debug_host_buf, 0, ctx->debug_host_buf.GetSize());
 549    wgpu::CommandBuffer commands = encoder.Finish();
 550    ctx->queue.Submit(1, &commands);
 551    ggml_backend_webgpu_map_buffer(ctx, ctx->debug_host_buf, wgpu::MapMode::Read, 0, ctx->debug_host_buf.GetSize());
 552    const float * debug_data = (const float *) ctx->debug_host_buf.GetConstMappedRange();
 553    std::cout << "debug[0]: " << debug_data[0] << "\n";
 554    ctx->debug_host_buf.Unmap();
 555}
 556#endif
 557
 558static webgpu_submission_futures ggml_backend_webgpu_submit(webgpu_global_context       ctx,
 559                                                            std::vector<webgpu_command> commands,
 560                                                            webgpu_buf_pool &           param_buf_pool,
 561                                                            webgpu_buf_pool * set_rows_error_buf_pool = nullptr) {
 562    std::vector<wgpu::CommandBuffer> command_buffers;
 563    std::vector<webgpu_pool_bufs>    params_bufs;
 564    std::vector<webgpu_pool_bufs>    set_rows_error_bufs;
 565#ifdef GGML_WEBGPU_GPU_PROFILE
 566    std::vector<std::pair<std::string, webgpu_gpu_profile_bufs>> pipeline_name_and_ts_bufs;
 567#endif
 568
 569    for (const auto & command : commands) {
 570        command_buffers.push_back(command.commands);
 571        params_bufs.insert(params_bufs.end(), command.params_bufs.begin(), command.params_bufs.end());
 572        if (command.set_rows_error_bufs) {
 573            set_rows_error_bufs.push_back(command.set_rows_error_bufs.value());
 574        }
 575    }
 576    ctx->queue.Submit(command_buffers.size(), command_buffers.data());
 577
 578    std::vector<wgpu::FutureWaitInfo> futures;
 579
 580    wgpu::Future p_f = ctx->queue.OnSubmittedWorkDone(
 581        wgpu::CallbackMode::AllowSpontaneous,
 582        [&param_buf_pool, params_bufs](wgpu::QueueWorkDoneStatus status, wgpu::StringView message) {
 583            if (status != wgpu::QueueWorkDoneStatus::Success) {
 584                GGML_LOG_ERROR("ggml_webgpu: Failed to submit commands: %s\n", std::string(message).c_str());
 585            }
 586            // Free the staged buffers
 587            param_buf_pool.free_bufs(params_bufs);
 588        });
 589    futures.push_back({ p_f });
 590
 591    for (const auto & bufs : set_rows_error_bufs) {
 592        wgpu::Future f = bufs.host_buf.MapAsync(
 593            wgpu::MapMode::Read, 0, bufs.host_buf.GetSize(), wgpu::CallbackMode::AllowSpontaneous,
 594            [set_rows_error_buf_pool, bufs](wgpu::MapAsyncStatus status, wgpu::StringView message) {
 595                if (status != wgpu::MapAsyncStatus::Success) {
 596                    GGML_LOG_ERROR("ggml_webgpu: Failed to map error buffer: %s\n", std::string(message).c_str());
 597                } else {
 598                    const uint32_t * error_data = (const uint32_t *) bufs.host_buf.GetConstMappedRange();
 599                    if (*error_data) {
 600                        GGML_ABORT("ggml_webgpu: SET_ROWS index > 2^32, unsupported.");
 601                    }
 602                    // We can't unmap in here due to WebGPU reentrancy limitations.
 603                    if (set_rows_error_buf_pool) {
 604                        set_rows_error_buf_pool->free_bufs({ bufs });
 605                    }
 606                }
 607            });
 608        futures.push_back({ f });
 609    }
 610
 611#ifdef GGML_WEBGPU_GPU_PROFILE
 612    for (const auto & command : commands) {
 613        auto label   = command.pipeline_name;
 614        auto ts_bufs = command.timestamp_query_bufs;
 615
 616        wgpu::Future f = ts_bufs.host_buf.MapAsync(
 617            wgpu::MapMode::Read, 0, ts_bufs.host_buf.GetSize(), wgpu::CallbackMode::AllowSpontaneous,
 618            [ctx, ts_bufs, label](wgpu::MapAsyncStatus status, wgpu::StringView message) {
 619                if (status != wgpu::MapAsyncStatus::Success) {
 620                    GGML_LOG_ERROR("ggml_webgpu: Failed to map timestamp buffer: %s\n", std::string(message).c_str());
 621                } else {
 622                    const uint64_t * ts_data    = (const uint64_t *) ts_bufs.host_buf.GetConstMappedRange();
 623                    // WebGPU timestamps are in ns; convert to ms
 624                    double           elapsed_ms = double(ts_data[1] - ts_data[0]) * 1e-6;
 625                    ctx->shader_gpu_time_ms[label] += elapsed_ms;
 626                    // We can't unmap in here due to WebGPU reentrancy limitations.
 627                    ctx->timestamp_query_buf_pool.free_bufs({ ts_bufs });
 628                }
 629            });
 630        futures.push_back({ f });
 631    }
 632#endif
 633    return { futures };
 634}
 635
 636static webgpu_command ggml_backend_webgpu_build_multi(
 637    webgpu_global_context &                                ctx,
 638    webgpu_buf_pool &                                      param_buf_pool,
 639    const std::vector<webgpu_pipeline> &                   pipelines,
 640    const std::vector<std::vector<uint32_t>> &             params_list,
 641    const std::vector<std::vector<wgpu::BindGroupEntry>> & bind_group_entries_list,
 642    const std::vector<std::pair<uint32_t, uint32_t>> &     workgroups_list,
 643    const std::optional<webgpu_pool_bufs> &                set_rows_error_bufs = std::nullopt) {
 644    GGML_ASSERT(pipelines.size() == params_list.size());
 645    GGML_ASSERT(pipelines.size() == bind_group_entries_list.size());
 646    GGML_ASSERT(pipelines.size() == workgroups_list.size());
 647
 648    std::vector<webgpu_pool_bufs> params_bufs_list;
 649    std::vector<wgpu::BindGroup>  bind_groups;
 650
 651    for (size_t i = 0; i < pipelines.size(); i++) {
 652        webgpu_pool_bufs params_bufs = param_buf_pool.alloc_bufs();
 653
 654        ggml_backend_webgpu_map_buffer(ctx, params_bufs.host_buf, wgpu::MapMode::Write, 0,
 655                                       params_bufs.host_buf.GetSize());
 656        uint32_t * _params = (uint32_t *) params_bufs.host_buf.GetMappedRange();
 657        for (size_t j = 0; j < params_list[i].size(); j++) {
 658            _params[j] = params_list[i][j];
 659        }
 660        params_bufs.host_buf.Unmap();
 661
 662        std::vector<wgpu::BindGroupEntry> entries            = bind_group_entries_list[i];
 663        uint32_t                          params_binding_num = entries.size();
 664        entries.push_back({ .binding = params_binding_num,
 665                            .buffer  = params_bufs.dev_buf,
 666                            .offset  = 0,
 667                            .size    = params_bufs.dev_buf.GetSize() });
 668
 669        wgpu::BindGroupDescriptor bind_group_desc;
 670        bind_group_desc.layout     = pipelines[i].pipeline.GetBindGroupLayout(0);
 671        bind_group_desc.entryCount = entries.size();
 672        bind_group_desc.entries    = entries.data();
 673        bind_group_desc.label      = pipelines[i].name.c_str();
 674        bind_groups.push_back(ctx->device.CreateBindGroup(&bind_group_desc));
 675
 676        params_bufs_list.push_back(params_bufs);
 677    }
 678
 679    wgpu::CommandEncoder encoder = ctx->device.CreateCommandEncoder();
 680    for (const auto & params_bufs : params_bufs_list) {
 681        encoder.CopyBufferToBuffer(params_bufs.host_buf, 0, params_bufs.dev_buf, 0, params_bufs.dev_buf.GetSize());
 682    }
 683
 684    // If there are SET_ROWS operations in this submission, copy their error buffers to the host.
 685    if (set_rows_error_bufs) {
 686        encoder.CopyBufferToBuffer(set_rows_error_bufs->dev_buf, 0, set_rows_error_bufs->host_buf, 0,
 687                                   set_rows_error_bufs->host_buf.GetSize());
 688    }
 689
 690#ifdef GGML_WEBGPU_GPU_PROFILE
 691    webgpu_gpu_profile_bufs ts_bufs = ctx->timestamp_query_buf_pool.alloc_bufs();
 692    if (ts_bufs.host_buf.GetMapState() == wgpu::BufferMapState::Mapped) {
 693        ts_bufs.host_buf.Unmap();
 694    }
 695
 696    wgpu::PassTimestampWrites   ts_writes = { .querySet                  = ts_bufs.query_set,
 697                                              .beginningOfPassWriteIndex = 0,
 698                                              .endOfPassWriteIndex       = 1 };
 699    wgpu::ComputePassDescriptor pass_desc = { .timestampWrites = &ts_writes };
 700    wgpu::ComputePassEncoder    pass      = encoder.BeginComputePass(&pass_desc);
 701#else
 702    wgpu::ComputePassEncoder pass = encoder.BeginComputePass();
 703#endif
 704    for (size_t i = 0; i < pipelines.size(); i++) {
 705        pass.SetPipeline(pipelines[i].pipeline);
 706        pass.SetBindGroup(0, bind_groups[i]);
 707        pass.DispatchWorkgroups(workgroups_list[i].first, workgroups_list[i].second, 1);
 708    }
 709    pass.End();
 710
 711#ifdef GGML_WEBGPU_GPU_PROFILE
 712    encoder.ResolveQuerySet(ts_bufs.query_set, 0, 2, ts_bufs.dev_buf, 0);
 713    encoder.CopyBufferToBuffer(ts_bufs.dev_buf, 0, ts_bufs.host_buf, 0, ts_bufs.host_buf.GetSize());
 714#endif
 715
 716    wgpu::CommandBuffer commands = encoder.Finish();
 717    webgpu_command      result   = {};
 718    result.commands              = commands;
 719    result.params_bufs           = params_bufs_list;
 720    result.set_rows_error_bufs   = set_rows_error_bufs;
 721#ifdef GGML_WEBGPU_GPU_PROFILE
 722    result.timestamp_query_bufs = ts_bufs;
 723    // TODO: handle multiple pipeline names
 724    result.pipeline_name        = pipelines.front().name;
 725#endif
 726    return result;
 727}
 728
 729static webgpu_command ggml_backend_webgpu_build(webgpu_global_context &           ctx,
 730                                                webgpu_buf_pool &                 param_buf_pool,
 731                                                webgpu_pipeline &                 pipeline,
 732                                                std::vector<uint32_t>             params,
 733                                                std::vector<wgpu::BindGroupEntry> bind_group_entries,
 734                                                uint32_t                          wg_x,
 735                                                uint32_t                          wg_y                = 1,
 736                                                std::optional<webgpu_pool_bufs>   set_rows_error_bufs = std::nullopt) {
 737    return ggml_backend_webgpu_build_multi(ctx, param_buf_pool,
 738                                           {
 739                                               pipeline
 740    },
 741                                           { params }, { bind_group_entries }, { { wg_x, wg_y } }, set_rows_error_bufs);
 742}
 743
 744static void ggml_backend_webgpu_buffer_memset(webgpu_global_context & ctx,
 745                                              wgpu::Buffer &          buf,
 746                                              uint32_t                value,
 747                                              size_t                  offset,
 748                                              size_t                  size) {
 749    std::vector<uint32_t>             params  = { (uint32_t) offset, (uint32_t) size, value };
 750    std::vector<wgpu::BindGroupEntry> entries = {
 751        { .binding = 0, .buffer = buf, .offset = 0, .size = buf.GetSize() }
 752    };
 753    size_t   bytes_per_wg = WEBGPU_MAX_WG_SIZE * ctx->capabilities.memset_bytes_per_thread;
 754    uint32_t wg_x         = CEIL_DIV(size + 3, bytes_per_wg);
 755
 756    webgpu_command command =
 757        ggml_backend_webgpu_build(ctx, ctx->memset_buf_pool, ctx->memset_pipelines[0], params, entries, wg_x);
 758    std::vector<webgpu_submission_futures> futures = { ggml_backend_webgpu_submit(ctx, { command },
 759                                                                                  ctx->memset_buf_pool) };
 760    ggml_backend_webgpu_wait(ctx, futures);
 761}
 762
 763/** End WebGPU Actions */
 764
 765/** GGML Backend Interface */
 766
 767static const char * ggml_backend_webgpu_name(ggml_backend_t backend) {
 768    ggml_backend_webgpu_context * ctx = (ggml_backend_webgpu_context *) backend->context;
 769    return ctx->name.c_str();
 770}
 771
 772static void ggml_backend_webgpu_free(ggml_backend_t backend) {
 773    ggml_backend_webgpu_context * ctx = (ggml_backend_webgpu_context *) backend->context;
 774    WEBGPU_LOG_DEBUG("ggml_backend_webgpu_free(" << ctx->name << ")");
 775
 776#ifdef GGML_WEBGPU_CPU_PROFILE
 777    std::cout << "\n[ggml_webgpu cpu profiling summary]\n";
 778    double total_cpu = 0.0;
 779    for (const auto & kv : ctx->webgpu_ctx->global_ctx->cpu_time_ms) {
 780        total_cpu += kv.second;
 781    }
 782    std::cout << "ggml_webgpu: total cpu time: " << total_cpu << " ms\n";
 783    std::cout << "ggml_webgpu: cpu breakdown:\n";
 784    for (const auto & kv : ctx->webgpu_ctx->global_ctx->cpu_time_ms) {
 785        double pct = (total_cpu > 0.0) ? (kv.second / total_cpu * 100.0) : 0.0;
 786        std::cout << "ggml_webgpu:  " << kv.first << ": " << kv.second << " ms (" << pct << "%)\n";
 787    }
 788    if (ctx->webgpu_ctx->global_ctx->cpu_detail_ms.size() > 0) {
 789        std::cout << "ggml_webgpu: cpu detailed breakdown:\n";
 790    }
 791    for (const auto & kv : ctx->webgpu_ctx->global_ctx->cpu_detail_ms) {
 792        double pct = (total_cpu > 0.0) ? (kv.second / total_cpu * 100.0) : 0.0;
 793        std::cout << "ggml_webgpu:  " << kv.first << ": " << kv.second << " ms (" << pct << "%)\n";
 794    }
 795#endif
 796
 797#ifdef GGML_WEBGPU_GPU_PROFILE
 798    std::cout << "\n[ggml_webgpu gpu profiling summary]\n";
 799    double total_gpu = 0.0;
 800    for (const auto & kv : ctx->webgpu_ctx->global_ctx->shader_gpu_time_ms) {
 801        total_gpu += kv.second;
 802    }
 803    std::cout << "ggml_webgpu: total gpu time (all shaders): " << total_gpu << " ms\n";
 804    std::cout << "\nggml_webgpu: gpu breakdown:\n";
 805    for (const auto & kv : ctx->webgpu_ctx->global_ctx->shader_gpu_time_ms) {
 806        double pct = (total_gpu > 0.0) ? (kv.second / total_gpu * 100.0) : 0.0;
 807        std::cout << "ggml_webgpu:  " << kv.first << ": " << kv.second << " ms (" << pct << "%)\n";
 808    }
 809#endif
 810
 811#if defined(GGML_WEBGPU_CPU_PROFILE) && defined(GGML_WEBGPU_GPU_PROFILE)
 812    std::cout << "ggml_webgpu: gpu/cpu ratio: " << (total_cpu > 0.0 ? total_gpu / total_cpu : 0.0) << "\n";
 813#endif
 814
 815    delete ctx;
 816    delete backend;
 817}
 818
 819static size_t ggml_webgpu_tensor_offset(const ggml_tensor * tensor) {
 820    return webgpu_tensor_offset(tensor) + tensor->view_offs;
 821}
 822
 823static wgpu::Buffer ggml_webgpu_tensor_buf(const ggml_tensor * tensor) {
 824    ggml_backend_webgpu_buffer_context * ctx = (ggml_backend_webgpu_buffer_context *) tensor->buffer->context;
 825    return ctx->buffer;
 826}
 827
 828static size_t ggml_webgpu_tensor_misalignment(webgpu_context & ctx, const ggml_tensor * t) {
 829    size_t offset = ggml_webgpu_tensor_offset(t);
 830    return offset & (ctx->global_ctx->capabilities.limits.minStorageBufferOffsetAlignment - 1);
 831}
 832
 833static size_t ggml_webgpu_tensor_align_offset(webgpu_context & ctx, const ggml_tensor * t) {
 834    size_t offset = ggml_webgpu_tensor_offset(t);
 835    return offset & ~(ctx->global_ctx->capabilities.limits.minStorageBufferOffsetAlignment - 1);
 836}
 837
 838static size_t ggml_webgpu_tensor_binding_size(webgpu_context & ctx, ggml_tensor * t) {
 839    return ROUNDUP_POW2(ggml_nbytes(t) + ggml_webgpu_tensor_misalignment(ctx, t), WEBGPU_STORAGE_BUF_BINDING_MULT);
 840}
 841
 842// Used to determine if two tensors are the same for in-place operations
 843static bool ggml_webgpu_tensor_equal(ggml_tensor * a, ggml_tensor * b) {
 844    return (ggml_webgpu_tensor_buf(a).Get() == ggml_webgpu_tensor_buf(b).Get()) &&
 845           (ggml_webgpu_tensor_offset(a) == ggml_webgpu_tensor_offset(b));
 846}
 847
 848// Used to determine if two tensors share the same buffer and their byte ranges overlap,
 849static bool ggml_webgpu_tensor_overlap(ggml_tensor * a, ggml_tensor * b) {
 850    return (ggml_webgpu_tensor_buf(a).Get() == ggml_webgpu_tensor_buf(b).Get()) &&
 851           ggml_webgpu_tensor_offset(a) < (ggml_webgpu_tensor_offset(b) + ggml_nbytes(b)) &&
 852           ggml_webgpu_tensor_offset(b) < (ggml_webgpu_tensor_offset(a) + ggml_nbytes(a));
 853}
 854
// Aliasing summary for a binary op (src0, src1 -> dst), filled by ggml_webgpu_detect_binary_overlap.
struct binary_overlap_flags {
    bool inplace;  // src0 and dst are in the same buffer at the same byte offset (ggml_webgpu_tensor_equal)
    bool overlap;  // src1's byte range intersects dst's within the same buffer (ggml_webgpu_tensor_overlap)
};
 859
 860static binary_overlap_flags ggml_webgpu_detect_binary_overlap(ggml_tensor * src0,
 861                                                              ggml_tensor * src1,
 862                                                              ggml_tensor * dst) {
 863    binary_overlap_flags flags = {};
 864    flags.inplace              = ggml_webgpu_tensor_equal(src0, dst);
 865    flags.overlap              = ggml_webgpu_tensor_overlap(src1, dst);
 866
 867    return flags;
 868}
 869
 870static webgpu_command ggml_webgpu_cpy(webgpu_context & ctx, ggml_tensor * src, ggml_tensor * dst) {
 871    uint32_t ne = (uint32_t) ggml_nelements(dst);
 872
 873    std::vector<uint32_t> params = {
 874        ne, (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src) / ggml_type_size(src->type)),
 875        (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, dst) / ggml_type_size(dst->type)),
 876        // Convert byte-strides to element-strides
 877        (uint32_t) (src->nb[0] / ggml_type_size(src->type)), (uint32_t) (src->nb[1] / ggml_type_size(src->type)),
 878        (uint32_t) (src->nb[2] / ggml_type_size(src->type)), (uint32_t) (src->nb[3] / ggml_type_size(src->type)),
 879        (uint32_t) (dst->nb[0] / ggml_type_size(dst->type)), (uint32_t) (dst->nb[1] / ggml_type_size(dst->type)),
 880        (uint32_t) (dst->nb[2] / ggml_type_size(dst->type)), (uint32_t) (dst->nb[3] / ggml_type_size(dst->type)),
 881        // Logical shapes
 882        (uint32_t) src->ne[0], (uint32_t) src->ne[1], (uint32_t) src->ne[2], (uint32_t) dst->ne[0],
 883        (uint32_t) dst->ne[1], (uint32_t) dst->ne[2]
 884    };
 885
 886    std::vector<wgpu::BindGroupEntry> entries = {
 887        { .binding = 0,
 888         .buffer  = ggml_webgpu_tensor_buf(src),
 889         .offset  = ggml_webgpu_tensor_align_offset(ctx, src),
 890         .size    = ggml_webgpu_tensor_binding_size(ctx, src) },
 891        { .binding = 1,
 892         .buffer  = ggml_webgpu_tensor_buf(dst),
 893         .offset  = ggml_webgpu_tensor_align_offset(ctx, dst),
 894         .size    = ggml_webgpu_tensor_binding_size(ctx, dst) }
 895    };
 896
 897    uint32_t wg_x = CEIL_DIV(ne, WEBGPU_MAX_WG_SIZE);
 898    return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_buf_pool, ctx->cpy_pipelines[src->type][dst->type],
 899                                     params, entries, wg_x);
 900}
 901
 902static webgpu_command ggml_webgpu_pad(webgpu_context & ctx, ggml_tensor * src, ggml_tensor * dst) {
 903    const bool circular = ggml_get_op_params_i32(dst, 8) != 0;
 904
 905    ggml_webgpu_pad_pipeline_key       pipeline_key   = { .circular = circular };
 906    ggml_webgpu_pad_shader_lib_context shader_lib_ctx = {
 907        .key = pipeline_key, .max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup
 908    };
 909
 910    webgpu_pipeline pipeline;
 911    auto            it = ctx->pad_pipelines.find(pipeline_key);
 912    if (it != ctx->pad_pipelines.end()) {
 913        pipeline = it->second;
 914    } else {
 915        ggml_webgpu_processed_shader processed = ggml_webgpu_preprocess_pad_shader(ctx->p, wgsl_pad, shader_lib_ctx);
 916        pipeline =
 917            ggml_webgpu_create_pipeline(ctx->global_ctx->device, processed.wgsl.c_str(), processed.variant.c_str());
 918        pipeline.context = processed.decisions;
 919        ctx->pad_pipelines.emplace(pipeline_key, pipeline);
 920    }
 921
 922    auto * decisions = static_cast<ggml_webgpu_generic_shader_decisions *>(pipeline.context.get());
 923
 924    const uint32_t ne = (uint32_t) ggml_nelements(dst);
 925
 926    std::vector<uint32_t> params = {
 927        ne,
 928        (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src) / ggml_type_size(src->type)),
 929        (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, dst) / ggml_type_size(dst->type)),
 930        // Strides (in elements)
 931        (uint32_t) (src->nb[0] / ggml_type_size(src->type)),
 932        (uint32_t) (src->nb[1] / ggml_type_size(src->type)),
 933        (uint32_t) (src->nb[2] / ggml_type_size(src->type)),
 934        (uint32_t) (src->nb[3] / ggml_type_size(src->type)),
 935        // Shapes
 936        (uint32_t) src->ne[0],
 937        (uint32_t) src->ne[1],
 938        (uint32_t) src->ne[2],
 939        (uint32_t) src->ne[3],
 940        (uint32_t) dst->ne[0],
 941        (uint32_t) dst->ne[1],
 942        (uint32_t) dst->ne[2],
 943        (uint32_t) dst->ne[3],
 944        // Pad sizes
 945        (uint32_t) ggml_get_op_params_i32(dst, 0),
 946        (uint32_t) ggml_get_op_params_i32(dst, 1),
 947        (uint32_t) ggml_get_op_params_i32(dst, 2),
 948        (uint32_t) ggml_get_op_params_i32(dst, 3),
 949        (uint32_t) ggml_get_op_params_i32(dst, 4),
 950        (uint32_t) ggml_get_op_params_i32(dst, 5),
 951        (uint32_t) ggml_get_op_params_i32(dst, 6),
 952        (uint32_t) ggml_get_op_params_i32(dst, 7),
 953    };
 954
 955    std::vector<wgpu::BindGroupEntry> entries = {
 956        { .binding = 0,
 957         .buffer  = ggml_webgpu_tensor_buf(src),
 958         .offset  = ggml_webgpu_tensor_align_offset(ctx, src),
 959         .size    = ggml_webgpu_tensor_binding_size(ctx, src) },
 960        { .binding = 1,
 961         .buffer  = ggml_webgpu_tensor_buf(dst),
 962         .offset  = ggml_webgpu_tensor_align_offset(ctx, dst),
 963         .size    = ggml_webgpu_tensor_binding_size(ctx, dst) }
 964    };
 965
 966    uint32_t wg_x = CEIL_DIV(ne, decisions->wg_size);
 967    return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_buf_pool, pipeline, params, entries, wg_x);
 968}
 969
 970static std::optional<webgpu_command> ggml_webgpu_set_rows(webgpu_context & ctx,
 971                                                          ggml_tensor *    src,
 972                                                          ggml_tensor *    idx,
 973                                                          ggml_tensor *    dst) {
 974    // For set rows specifically, we need to check if src and idx are empty tensors.
 975    if (ggml_is_empty(src) || ggml_is_empty(idx)) {
 976        return std::nullopt;
 977    }
 978
 979    ggml_webgpu_set_rows_pipeline_key key = { .dst_type = dst->type,
 980                                              .vec4     = src->ne[0] % 4 == 0,
 981                                              .i64_idx  = idx->type == GGML_TYPE_I64 };
 982
 983    ggml_webgpu_set_rows_shader_lib_context shader_lib_ctx = {
 984        .key = key, .max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup
 985    };
 986
 987    webgpu_pipeline pipeline;
 988    auto            it = ctx->set_rows_pipelines.find(key);
 989    if (it != ctx->set_rows_pipelines.end()) {
 990        pipeline = it->second;
 991    } else {
 992        ggml_webgpu_processed_shader processed =
 993            ggml_webgpu_preprocess_set_rows_shader(ctx->p, wgsl_set_rows, shader_lib_ctx);
 994        pipeline =
 995            ggml_webgpu_create_pipeline(ctx->global_ctx->device, processed.wgsl.c_str(), processed.variant.c_str());
 996        pipeline.context = processed.decisions;
 997        ctx->set_rows_pipelines.emplace(key, pipeline);
 998    }
 999
1000    auto * decisions = static_cast<ggml_webgpu_generic_shader_decisions *>(pipeline.context.get());
1001
1002    std::optional<webgpu_pool_bufs> error_bufs = std::nullopt;
1003    if (key.i64_idx) {
1004        error_bufs = ctx->set_rows_error_buf_pool.alloc_bufs();
1005        if (error_bufs->host_buf.GetMapState() == wgpu::BufferMapState::Mapped) {
1006            error_bufs->host_buf.Unmap();
1007        }
1008    }
1009
1010    std::vector<uint32_t> params = {
1011        (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src) / ggml_type_size(src->type)),
1012        (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, idx) / ggml_type_size(idx->type)),
1013        (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, dst) / ggml_type_size(dst->type)),
1014        // Convert byte-strides to element-strides
1015        (uint32_t) (src->nb[1] / ggml_type_size(src->type)), (uint32_t) (src->nb[2] / ggml_type_size(src->type)),
1016        (uint32_t) (src->nb[3] / ggml_type_size(src->type)), (uint32_t) (idx->nb[0] / ggml_type_size(idx->type)),
1017        (uint32_t) (idx->nb[1] / ggml_type_size(idx->type)), (uint32_t) (idx->nb[2] / ggml_type_size(idx->type)),
1018        (uint32_t) (dst->nb[1] / ggml_type_size(dst->type)), (uint32_t) (dst->nb[2] / ggml_type_size(dst->type)),
1019        (uint32_t) (dst->nb[3] / ggml_type_size(dst->type)),
1020        // Shape of src
1021        (uint32_t) src->ne[0], (uint32_t) src->ne[1], (uint32_t) src->ne[2], (uint32_t) src->ne[3],
1022        // Shape of idx
1023        (uint32_t) (idx->ne[1]), (uint32_t) (idx->ne[2])
1024    };
1025
1026    std::vector<wgpu::BindGroupEntry> entries = {
1027        { .binding = 0,
1028         .buffer  = ggml_webgpu_tensor_buf(src),
1029         .offset  = ggml_webgpu_tensor_align_offset(ctx, src),
1030         .size    = ggml_webgpu_tensor_binding_size(ctx, src) },
1031        { .binding = 1,
1032         .buffer  = ggml_webgpu_tensor_buf(idx),
1033         .offset  = ggml_webgpu_tensor_align_offset(ctx, idx),
1034         .size    = ggml_webgpu_tensor_binding_size(ctx, idx) },
1035        { .binding = 2,
1036         .buffer  = ggml_webgpu_tensor_buf(dst),
1037         .offset  = ggml_webgpu_tensor_align_offset(ctx, dst),
1038         .size    = ggml_webgpu_tensor_binding_size(ctx, dst) }
1039    };
1040
1041    if (key.i64_idx) {
1042        entries.push_back(
1043            { .binding = 3, .buffer = error_bufs->dev_buf, .offset = 0, .size = error_bufs->dev_buf.GetSize() });
1044    }
1045
1046    uint32_t threads;
1047    if (key.vec4) {
1048        threads = (src->ne[1] * src->ne[2] * src->ne[3]) * (src->ne[0] / 4);
1049    } else {
1050        threads = src->ne[0] * src->ne[1] * src->ne[2] * src->ne[3];
1051    }
1052    uint32_t wg_x = CEIL_DIV(threads, decisions->wg_size);
1053    return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_buf_pool, pipeline, params, entries, wg_x, 1,
1054                                     error_bufs);
1055}
1056
1057static webgpu_command ggml_webgpu_get_rows(webgpu_context & ctx,
1058                                           ggml_tensor *    src,
1059                                           ggml_tensor *    idx,
1060                                           ggml_tensor *    dst) {
1061    std::vector<uint32_t> params = {
1062        (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src) / ggml_type_size(src->type)),
1063        (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, idx) / ggml_type_size(idx->type)),
1064        (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, dst) / ggml_type_size(dst->type)),
1065        // Convert byte-strides to element-strides
1066        (uint32_t) (src->nb[1] / ggml_type_size(src->type)), (uint32_t) (src->nb[2] / ggml_type_size(src->type)),
1067        (uint32_t) (src->nb[3] / ggml_type_size(src->type)), (uint32_t) (idx->nb[0] / ggml_type_size(idx->type)),
1068        (uint32_t) (idx->nb[1] / ggml_type_size(idx->type)), (uint32_t) (idx->nb[2] / ggml_type_size(idx->type)),
1069        (uint32_t) (dst->nb[1] / ggml_type_size(dst->type)), (uint32_t) (dst->nb[2] / ggml_type_size(dst->type)),
1070        (uint32_t) (dst->nb[3] / ggml_type_size(dst->type)),
1071        // Shape of dst
1072        (uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2], (uint32_t) dst->ne[3],
1073        // Shape of idx
1074        (uint32_t) (idx->ne[1]), (uint32_t) (idx->ne[2])
1075    };
1076
1077    std::vector<wgpu::BindGroupEntry> entries = {
1078        { .binding = 0,
1079         .buffer  = ggml_webgpu_tensor_buf(src),
1080         .offset  = ggml_webgpu_tensor_align_offset(ctx, src),
1081         .size    = ggml_webgpu_tensor_binding_size(ctx, src) },
1082        { .binding = 1,
1083         .buffer  = ggml_webgpu_tensor_buf(idx),
1084         .offset  = ggml_webgpu_tensor_align_offset(ctx, idx),
1085         .size    = ggml_webgpu_tensor_binding_size(ctx, idx) },
1086        { .binding = 2,
1087         .buffer  = ggml_webgpu_tensor_buf(dst),
1088         .offset  = ggml_webgpu_tensor_align_offset(ctx, dst),
1089         .size    = ggml_webgpu_tensor_binding_size(ctx, dst) }
1090    };
1091
1092    uint32_t wg_x = CEIL_DIV(dst->ne[1] * dst->ne[2] * dst->ne[3], WEBGPU_MAX_WG_SIZE);
1093
1094    uint32_t        vectorized = src->type == GGML_TYPE_F32 && dst->ne[0] % 4 == 0;
1095    webgpu_pipeline pipeline   = ctx->get_rows_pipelines[src->type][vectorized];
1096    return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_buf_pool, pipeline, params, entries, wg_x);
1097}
1098
1099static webgpu_command ggml_webgpu_mul_mat(webgpu_context & ctx,
1100                                          ggml_tensor *    src0,
1101                                          ggml_tensor *    src1,
1102                                          ggml_tensor *    dst) {
1103    std::vector<uint32_t> params = {
1104        (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src0) / ggml_type_size(src0->type)),
1105        (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src1) / ggml_type_size(src1->type)),
1106        (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, dst) / ggml_type_size(dst->type)),
1107        (uint32_t) dst->ne[0],                                  // number of rows in result (M, transposed)
1108        (uint32_t) dst->ne[1],                                  // number of columns in result (N)
1109        (uint32_t) src0->ne[0],                                 // number of columns in src0/src1 (K)
1110        (uint32_t) (src0->nb[1] / ggml_type_size(src0->type)),  // stride (elements/blocks) of src0 in dimension 1
1111        (uint32_t) (src1->nb[1] / ggml_type_size(src1->type)),  // stride (elements/blocks) of src1 in dimension 1
1112        (uint32_t) (src0->nb[2] / ggml_type_size(src0->type)),  // stride (elements/blocks) of src0 in dimension 2
1113        (uint32_t) (src1->nb[2] / ggml_type_size(src1->type)),  // stride (elements/blocks) of src1 in dimension 2
1114        (uint32_t) (src0->nb[3] / ggml_type_size(src0->type)),  // stride (elements/blocks) of src0 in dimension 3
1115        (uint32_t) (src1->nb[3] / ggml_type_size(src1->type)),  // stride (elements/blocks) of src1 in dimension 3
1116        (uint32_t) src0->ne[2],                                 // batch size in dimension 2
1117        (uint32_t) src0->ne[3],                                 // batch size in dimension 3
1118        (uint32_t) (src1->ne[2] / src0->ne[2]),                 // broadcast in dimension 2
1119        (uint32_t) (src1->ne[3] / src0->ne[3])                  // broadcast in dimension 3
1120    };
1121
1122    std::vector<wgpu::BindGroupEntry> entries = {
1123        { .binding = 0,
1124         .buffer  = ggml_webgpu_tensor_buf(src0),
1125         .offset  = ggml_webgpu_tensor_align_offset(ctx, src0),
1126         .size    = ggml_webgpu_tensor_binding_size(ctx, src0) },
1127        { .binding = 1,
1128         .buffer  = ggml_webgpu_tensor_buf(src1),
1129         .offset  = ggml_webgpu_tensor_align_offset(ctx, src1),
1130         .size    = ggml_webgpu_tensor_binding_size(ctx, src1) },
1131        { .binding = 2,
1132         .buffer  = ggml_webgpu_tensor_buf(dst),
1133         .offset  = ggml_webgpu_tensor_align_offset(ctx, dst),
1134         .size    = ggml_webgpu_tensor_binding_size(ctx, dst)  },
1135    };
1136
1137    webgpu_pipeline pipeline = ctx->mul_mat_pipelines[src0->type][src1->type][0];
1138
1139    uint32_t wg_x = CEIL_DIV(dst->ne[0] * dst->ne[1] * dst->ne[2] * dst->ne[3], WEBGPU_MUL_MAT_WG_SIZE);
1140    uint32_t wg_y = 1;
1141
1142    bool use_fast = false;
1143    switch (src1->type) {
1144        case GGML_TYPE_F16:
1145            use_fast = (src0->type == GGML_TYPE_F16);
1146            break;
1147        case GGML_TYPE_F32:
1148            switch (src0->type) {
1149                case GGML_TYPE_F32:
1150                case GGML_TYPE_F16:
1151                case GGML_TYPE_Q4_0:
1152                    use_fast = true;
1153                    break;
1154                default:
1155                    break;
1156            }
1157            break;
1158        default:
1159            break;
1160    }
1161
1162    if (use_fast) {
1163        int vectorized = src0->ne[0] % 4 == 0 && dst->ne[0] % 4 == 0 && dst->ne[1] % 4 == 0;
1164        if (dst->ne[1] == 1) {
1165            // We don't support vectorized mul_mat_vec for quantized types
1166            vectorized             = vectorized && (src0->type < 2);
1167            pipeline               = ctx->mul_mat_vec_pipelines[src0->type][src1->type][vectorized];
1168            uint32_t batches       = dst->ne[2] * dst->ne[3];
1169            uint32_t output_groups = CEIL_DIV(dst->ne[0], WEBGPU_MUL_MAT_VEC_OUTPUTS_PER_WG);
1170            uint32_t total_wg      = output_groups * batches;
1171            wg_x                   = total_wg % ctx->global_ctx->capabilities.limits.maxComputeWorkgroupsPerDimension;
1172            wg_y = CEIL_DIV(total_wg, ctx->global_ctx->capabilities.limits.maxComputeWorkgroupsPerDimension);
1173        } else {
1174            pipeline = ctx->mul_mat_pipelines[src0->type][src1->type][vectorized];
1175            uint32_t wg_m;
1176            uint32_t wg_n;
1177#ifndef __EMSCRIPTEN__
1178            if (ctx->global_ctx->capabilities.supports_subgroup_matrix) {
1179                // The total number of subgroups/workgroups needed per matrix.
1180                uint32_t wg_m_sg_tile = WEBGPU_MUL_MAT_SUBGROUP_M * WEBGPU_MUL_MAT_SUBGROUP_MATRIX_M *
1181                                        ctx->global_ctx->capabilities.sg_mat_m;
1182                wg_m                  = CEIL_DIV(dst->ne[0], wg_m_sg_tile);
1183                uint32_t wg_n_sg_tile = WEBGPU_MUL_MAT_SUBGROUP_N * WEBGPU_MUL_MAT_SUBGROUP_MATRIX_N *
1184                                        ctx->global_ctx->capabilities.sg_mat_n;
1185                wg_n = CEIL_DIV(dst->ne[1], wg_n_sg_tile);
1186            } else {
1187#endif
1188                uint32_t tile_m_s = WEBGPU_MUL_MAT_TILE_M * WEBGPU_MUL_MAT_WG_SIZE_M;
1189                uint32_t tile_n_s = WEBGPU_MUL_MAT_TILE_N * WEBGPU_MUL_MAT_WG_SIZE_N;
1190                wg_m              = CEIL_DIV(dst->ne[0], tile_m_s);
1191                wg_n              = CEIL_DIV(dst->ne[1], tile_n_s);
1192#ifndef __EMSCRIPTEN__
1193            }
1194#endif
1195
1196            wg_x = wg_m * wg_n * dst->ne[2] * dst->ne[3];
1197        }
1198    }
1199    return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_buf_pool, pipeline, params, entries, wg_x, wg_y);
1200}
1201
1202#ifndef __EMSCRIPTEN__
1203static webgpu_command ggml_webgpu_flash_attn(webgpu_context & ctx,
1204                                             ggml_tensor *    Q,
1205                                             ggml_tensor *    K,
1206                                             ggml_tensor *    V,
1207                                             ggml_tensor *    mask,
1208                                             ggml_tensor *    sinks,
1209                                             ggml_tensor *    dst) {
1210    float scale = *(float *) dst->op_params;
1211    float max_bias;
1212    memcpy(&max_bias, (float *) dst->op_params + 1, sizeof(float));
1213    float logit_softcap;
1214    memcpy(&logit_softcap, (float *) dst->op_params + 2, sizeof(float));
1215    if (logit_softcap != 0.0f) {
1216        scale /= logit_softcap;
1217    }
1218    float n_head_log2 = float(1u << (uint32_t) floor(log2(Q->ne[2])));
1219    float m0          = powf(2.0f, -(max_bias) / n_head_log2);
1220    float m1          = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
1221
1222    const int has_mask  = (mask != nullptr);
1223    const int has_sinks = (sinks != nullptr);
1224
1225    std::vector<uint32_t> params = {
1226        (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, Q) / ggml_type_size(Q->type)),
1227        (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, K) / ggml_type_size(K->type)),
1228        (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, V) / ggml_type_size(V->type)),
1229        has_mask ? (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, mask) / ggml_type_size(mask->type)) : 0,
1230        has_sinks ? (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, sinks) / ggml_type_size(sinks->type)) : 0,
1231        (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, dst) / ggml_type_size(dst->type)),
1232        (uint32_t) Q->ne[2],                              // number of heads
1233        (uint32_t) Q->ne[1],                              // sequence length (Q)
1234        (uint32_t) K->ne[1],                              // sequence length (K/V)
1235        (uint32_t) (Q->nb[1] / ggml_type_size(Q->type)),  // stride (elements/blocks) of Q in dimension 1
1236        (uint32_t) (Q->nb[2] / ggml_type_size(Q->type)),  // stride (elements/blocks) of Q in dimension 2
1237        (uint32_t) (Q->nb[3] / ggml_type_size(Q->type)),  // stride (elements/blocks) of Q in dimension 3
1238        (uint32_t) (K->nb[1] / ggml_type_size(K->type)),  // stride (elements/blocks) of K in dimension 1
1239        (uint32_t) (K->nb[2] / ggml_type_size(K->type)),  // stride (elements/blocks) of K in dimension 2
1240        (uint32_t) (K->nb[3] / ggml_type_size(K->type)),  // stride (elements/blocks) of K in dimension 3
1241        (uint32_t) (V->nb[1] / ggml_type_size(V->type)),  // stride (elements/blocks) of V in dimension 1
1242        (uint32_t) (V->nb[2] / ggml_type_size(V->type)),  // stride (elements/blocks) of V in dimension 2
1243        (uint32_t) (V->nb[3] / ggml_type_size(V->type)),  // stride (elements/blocks) of V in dimension 3
1244        has_mask ? (uint32_t) (mask->nb[3] / ggml_type_size(mask->type)) : 0,  // stride of mask dim 3
1245        (uint32_t) (Q->ne[2] / K->ne[2]),  // repeat factor for K/V in dim 2 (MHA/MQA/GQA)
1246        *(uint32_t *) &scale,              // scale (possibly adjusted for logit softcap)
1247        *(uint32_t *) &max_bias,
1248        *(uint32_t *) &logit_softcap,
1249        *(uint32_t *) &n_head_log2,
1250        *(uint32_t *) &m0,
1251        *(uint32_t *) &m1
1252
1253    };
1254    std::vector<wgpu::BindGroupEntry> entries = {
1255        { .binding = 0,
1256         .buffer  = ggml_webgpu_tensor_buf(Q),
1257         .offset  = ggml_webgpu_tensor_align_offset(ctx, Q),
1258         .size    = ggml_webgpu_tensor_binding_size(ctx, Q) },
1259        { .binding = 1,
1260         .buffer  = ggml_webgpu_tensor_buf(K),
1261         .offset  = ggml_webgpu_tensor_align_offset(ctx, K),
1262         .size    = ggml_webgpu_tensor_binding_size(ctx, K) },
1263        { .binding = 2,
1264         .buffer  = ggml_webgpu_tensor_buf(V),
1265         .offset  = ggml_webgpu_tensor_align_offset(ctx, V),
1266         .size    = ggml_webgpu_tensor_binding_size(ctx, V) }
1267    };
1268    uint32_t binding_index = 3;
1269    if (has_mask) {
1270        entries.push_back({ .binding = binding_index++,
1271                            .buffer  = ggml_webgpu_tensor_buf(mask),
1272                            .offset  = ggml_webgpu_tensor_align_offset(ctx, mask),
1273                            .size    = ggml_webgpu_tensor_binding_size(ctx, mask) });
1274    }
1275    if (has_sinks) {
1276        entries.push_back({ .binding = binding_index++,
1277                            .buffer  = ggml_webgpu_tensor_buf(sinks),
1278                            .offset  = ggml_webgpu_tensor_align_offset(ctx, sinks),
1279                            .size    = ggml_webgpu_tensor_binding_size(ctx, sinks) });
1280    }
1281    entries.push_back({ .binding = binding_index++,
1282                        .buffer  = ggml_webgpu_tensor_buf(dst),
1283                        .offset  = ggml_webgpu_tensor_align_offset(ctx, dst),
1284                        .size    = ggml_webgpu_tensor_binding_size(ctx, dst) });
1285
1286    bool kv_direct = (K->type == GGML_TYPE_F16) && (Q->ne[0] % ctx->global_ctx->capabilities.sg_mat_k == 0) &&
1287                     (K->ne[1] % GGML_WEBGPU_KV_SEQ_PAD == 0);
1288
1289    ggml_webgpu_flash_attn_pipeline_key key = {
1290        .kv_type            = K->type,
1291        .head_dim_qk        = (uint32_t) Q->ne[0],
1292        .head_dim_v         = (uint32_t) V->ne[0],
1293        .kv_direct          = kv_direct,
1294        .has_mask           = static_cast<bool>(has_mask),
1295        .has_sinks          = static_cast<bool>(has_sinks),
1296        .uses_logit_softcap = logit_softcap != 0.0f,
1297    };
1298
1299    webgpu_pipeline pipeline;
1300    auto            it = ctx->flash_attn_pipelines.find(key);
1301    if (it != ctx->flash_attn_pipelines.end()) {
1302        pipeline = it->second;
1303    } else {
1304        ggml_webgpu_flash_attn_shader_lib_context shader_lib_ctx = {
1305            .key                = key,
1306            .sg_mat_m           = ctx->global_ctx->capabilities.sg_mat_m,
1307            .sg_mat_n           = ctx->global_ctx->capabilities.sg_mat_n,
1308            .sg_mat_k           = ctx->global_ctx->capabilities.sg_mat_k,
1309            .wg_mem_limit_bytes = ctx->global_ctx->capabilities.limits.maxComputeWorkgroupStorageSize,
1310            .max_subgroup_size  = ctx->global_ctx->capabilities.max_subgroup_size
1311        };
1312
1313        ggml_webgpu_processed_shader processed =
1314            ggml_webgpu_preprocess_flash_attn_shader(ctx->p, wgsl_flash_attn, shader_lib_ctx);
1315        pipeline =
1316            ggml_webgpu_create_pipeline(ctx->global_ctx->device, processed.wgsl.c_str(), processed.variant.c_str());
1317        pipeline.context = processed.decisions;
1318        ctx->flash_attn_pipelines.emplace(key, pipeline);
1319    }
1320
1321    auto * decisions = static_cast<ggml_webgpu_flash_attn_shader_decisions *>(pipeline.context.get());
1322
1323    uint32_t wg_per_head = CEIL_DIV(Q->ne[1], decisions->q_tile);
1324    uint32_t wg_x        = wg_per_head * Q->ne[2] * Q->ne[3];  // wg per head * number of heads * number of batches
1325    return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_buf_pool, pipeline, params, entries, wg_x);
1326}
1327#endif
1328
1329static webgpu_command ggml_webgpu_unary_op(webgpu_context & ctx, ggml_tensor * src, ggml_tensor * dst) {
1330    bool is_unary = dst->op == GGML_OP_UNARY;
1331    bool inplace  = ggml_webgpu_tensor_equal(src, dst) || (dst->op == GGML_OP_FILL);
1332    int  op       = is_unary ? (int) ggml_get_unary_op(dst) : dst->op;
1333
1334    ggml_webgpu_unary_pipeline_key pipeline_key = {
1335        .type = dst->type, .op = op, .is_unary = is_unary, .inplace = inplace
1336    };
1337    ggml_webgpu_unary_shader_lib_context shader_lib_ctx = {
1338        .key = pipeline_key, .max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup
1339    };
1340
1341    webgpu_pipeline pipeline;
1342    auto            it = ctx->unary_pipelines.find(pipeline_key);
1343    if (it != ctx->unary_pipelines.end()) {
1344        pipeline = it->second;
1345    } else {
1346        ggml_webgpu_processed_shader processed =
1347            ggml_webgpu_preprocess_unary_shader(ctx->p, wgsl_unary, shader_lib_ctx);
1348        pipeline =
1349            ggml_webgpu_create_pipeline(ctx->global_ctx->device, processed.wgsl.c_str(), processed.variant.c_str());
1350        pipeline.context = processed.decisions;
1351        ctx->unary_pipelines.emplace(pipeline_key, pipeline);
1352    }
1353
1354    auto * decisions = static_cast<ggml_webgpu_generic_shader_decisions *>(pipeline.context.get());
1355
1356    uint32_t ne = (uint32_t) ggml_nelements(dst);
1357
1358    std::vector<uint32_t> params = { ne,
1359                                     (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src) / ggml_type_size(src->type)),
1360                                     (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, dst) / ggml_type_size(dst->type)),
1361                                     (uint32_t) (src->nb[0] / ggml_type_size(src->type)),
1362                                     (uint32_t) (src->nb[1] / ggml_type_size(src->type)),
1363                                     (uint32_t) (src->nb[2] / ggml_type_size(src->type)),
1364                                     (uint32_t) (src->nb[3] / ggml_type_size(src->type)),
1365                                     (uint32_t) src->ne[0],
1366                                     (uint32_t) src->ne[1],
1367                                     (uint32_t) src->ne[2] };
1368
1369    ggml_tensor * effective_src = src;
1370    if (is_unary) {
1371        ggml_unary_op unary_op = ggml_get_unary_op(dst);
1372        switch (unary_op) {
1373            case GGML_UNARY_OP_XIELU:
1374                {
1375                    // Get float parameters and reinterpret their bit patterns as uint32_t
1376                    // for passing through the params buffer
1377                    float alpha_n = ggml_get_op_params_f32(dst, 1);
1378                    float alpha_p = ggml_get_op_params_f32(dst, 2);
1379                    float beta    = ggml_get_op_params_f32(dst, 3);
1380                    float eps     = ggml_get_op_params_f32(dst, 4);
1381                    params.push_back(*reinterpret_cast<const uint32_t *>(&alpha_n));
1382                    params.push_back(*reinterpret_cast<const uint32_t *>(&alpha_p));
1383                    params.push_back(*reinterpret_cast<const uint32_t *>(&beta));
1384                    params.push_back(*reinterpret_cast<const uint32_t *>(&eps));
1385                    break;
1386                }
1387            default:
1388                break;
1389        }
1390    } else if (dst->op == GGML_OP_CLAMP) {
1391        float clamp_min = ggml_get_op_params_f32(dst, 0);
1392        float clamp_max = ggml_get_op_params_f32(dst, 1);
1393        params.push_back(*reinterpret_cast<const uint32_t *>(&clamp_min));
1394        params.push_back(*reinterpret_cast<const uint32_t *>(&clamp_max));
1395    } else if (dst->op == GGML_OP_FILL) {
1396        float fill_val = ggml_get_op_params_f32(dst, 0);
1397        params.push_back(*reinterpret_cast<const uint32_t *>(&fill_val));
1398        effective_src = dst;  // fill simply fills dst
1399    }
1400
1401    std::vector<wgpu::BindGroupEntry> entries = {
1402        { .binding = 0,
1403         .buffer  = ggml_webgpu_tensor_buf(effective_src),
1404         .offset  = ggml_webgpu_tensor_align_offset(ctx, effective_src),
1405         .size    = ggml_webgpu_tensor_binding_size(ctx, effective_src) },
1406    };
1407    if (!inplace) {
1408        entries.push_back({ .binding = 1,
1409                            .buffer  = ggml_webgpu_tensor_buf(dst),
1410                            .offset  = ggml_webgpu_tensor_align_offset(ctx, dst),
1411                            .size    = ggml_webgpu_tensor_binding_size(ctx, dst) });
1412    }
1413
1414    uint32_t wg_x = CEIL_DIV(ne, decisions->wg_size);
1415    return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_buf_pool, pipeline, params, entries, wg_x);
1416}
1417
1418static webgpu_command ggml_webgpu_binary_op(webgpu_context & ctx,
1419                                            ggml_tensor *    src0,
1420                                            ggml_tensor *    src1,
1421                                            ggml_tensor *    dst) {
1422    binary_overlap_flags flags = ggml_webgpu_detect_binary_overlap(src0, src1, dst);
1423
1424    ggml_webgpu_binary_pipeline_key pipeline_key = {
1425        .type    = dst->type,
1426        .op      = dst->op,
1427        .inplace = flags.inplace,
1428        .overlap = flags.overlap,
1429    };
1430    ggml_webgpu_binary_shader_lib_context shader_lib_ctx = {
1431        .key = pipeline_key, .max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup
1432    };
1433
1434    webgpu_pipeline pipeline;
1435    auto            it = ctx->binary_pipelines.find(pipeline_key);
1436    if (it != ctx->binary_pipelines.end()) {
1437        pipeline = it->second;
1438    } else {
1439        ggml_webgpu_processed_shader processed =
1440            ggml_webgpu_preprocess_binary_shader(ctx->p, wgsl_binary, shader_lib_ctx);
1441        pipeline =
1442            ggml_webgpu_create_pipeline(ctx->global_ctx->device, processed.wgsl.c_str(), processed.variant.c_str());
1443        pipeline.context = processed.decisions;
1444        ctx->binary_pipelines.emplace(pipeline_key, pipeline);
1445    }
1446
1447    auto * decisions = static_cast<ggml_webgpu_argsort_shader_decisions *>(pipeline.context.get());
1448
1449    uint32_t ne = (uint32_t) ggml_nelements(dst);
1450
1451    std::vector<uint32_t> params = {
1452        ne,
1453        (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src0) / ggml_type_size(src0->type)),
1454        (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src1) / ggml_type_size(src1->type)),
1455        (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, dst) / ggml_type_size(dst->type)),
1456        (uint32_t) (src1->nb[0] / ggml_type_size(src1->type)),
1457        (uint32_t) (src1->nb[1] / ggml_type_size(src1->type)),
1458        (uint32_t) (src1->nb[2] / ggml_type_size(src1->type)),
1459        (uint32_t) (src1->nb[3] / ggml_type_size(src1->type)),
1460        (uint32_t) src0->ne[0],
1461        (uint32_t) src0->ne[1],
1462        (uint32_t) src0->ne[2],
1463        (uint32_t) src1->ne[0],
1464        (uint32_t) src1->ne[1],
1465        (uint32_t) src1->ne[2],
1466        (uint32_t) src1->ne[3],
1467    };
1468
1469    std::vector<wgpu::BindGroupEntry> entries;
1470
1471    entries.push_back({
1472        .binding = 0,
1473        .buffer  = ggml_webgpu_tensor_buf(src0),
1474        .offset  = ggml_webgpu_tensor_align_offset(ctx, src0),
1475        .size    = ggml_webgpu_tensor_binding_size(ctx, src0),
1476    });
1477
1478    entries.push_back({
1479        .binding = 1,
1480        .buffer  = ggml_webgpu_tensor_buf(src1),
1481        .offset  = ggml_webgpu_tensor_align_offset(ctx, src1),
1482        .size    = ggml_webgpu_tensor_binding_size(ctx, src1),
1483    });
1484
1485    if (!flags.inplace && !flags.overlap) {
1486        entries.push_back({ .binding = 2,
1487                            .buffer  = ggml_webgpu_tensor_buf(dst),
1488                            .offset  = ggml_webgpu_tensor_align_offset(ctx, dst),
1489                            .size    = ggml_webgpu_tensor_binding_size(ctx, dst) });
1490    }
1491
1492    uint32_t wg_x = CEIL_DIV(ne, decisions->wg_size);
1493    return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_buf_pool, pipeline, params, entries, wg_x);
1494}
1495
1496static webgpu_command ggml_webgpu_rms_norm(webgpu_context & ctx, ggml_tensor * src, ggml_tensor * dst) {
1497    int inplace = ggml_webgpu_tensor_equal(src, dst);
1498
1499    std::vector<uint32_t> params = {
1500        (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src) / ggml_type_size(src->type)),
1501        (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, dst) / ggml_type_size(dst->type)),
1502        (uint32_t) (src->nb[1] / ggml_type_size(src->type)),
1503        (uint32_t) (src->nb[2] / ggml_type_size(src->type)),
1504        (uint32_t) (src->nb[3] / ggml_type_size(src->type)),
1505        (uint32_t) (dst->nb[1] / ggml_type_size(dst->type)),
1506        (uint32_t) (dst->nb[2] / ggml_type_size(dst->type)),
1507        (uint32_t) (dst->nb[3] / ggml_type_size(dst->type)),
1508        (uint32_t) src->ne[0],
1509        (uint32_t) src->ne[1],
1510        (uint32_t) src->ne[2],
1511        (uint32_t) src->ne[3],
1512        *(uint32_t *) dst->op_params  // epsilon, treated as f32 in the shader
1513    };
1514
1515    std::vector<wgpu::BindGroupEntry> entries = {
1516        { .binding = 0,
1517         .buffer  = ggml_webgpu_tensor_buf(src),
1518         .offset  = ggml_webgpu_tensor_align_offset(ctx, src),
1519         .size    = ggml_webgpu_tensor_binding_size(ctx, src) }
1520    };
1521    if (!inplace) {
1522        entries.push_back({ .binding = 1,
1523                            .buffer  = ggml_webgpu_tensor_buf(dst),
1524                            .offset  = ggml_webgpu_tensor_align_offset(ctx, dst),
1525                            .size    = ggml_webgpu_tensor_binding_size(ctx, dst) });
1526    }
1527
1528    return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_buf_pool, ctx->rms_norm_pipelines[inplace], params,
1529                                     entries, ggml_nrows(src));
1530}
1531
1532static webgpu_command ggml_webgpu_rope(webgpu_context & ctx,
1533                                       ggml_tensor *    src0,
1534                                       ggml_tensor *    src1,
1535                                       ggml_tensor *    src2,
1536                                       ggml_tensor *    dst) {
1537    const int inplace         = ggml_webgpu_tensor_equal(src0, dst);
1538    const int has_freq_factor = (src2 != nullptr);
1539
1540    const int n_dims     = ((int32_t *) dst->op_params)[1];
1541    const int mode       = ((int32_t *) dst->op_params)[2];
1542    const int n_ctx_orig = ((int32_t *) dst->op_params)[4];
1543
1544    float freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow;
1545    memcpy(&freq_base, (int32_t *) dst->op_params + 5, sizeof(float));
1546    memcpy(&freq_scale, (int32_t *) dst->op_params + 6, sizeof(float));
1547    memcpy(&ext_factor, (int32_t *) dst->op_params + 7, sizeof(float));
1548    memcpy(&attn_factor, (int32_t *) dst->op_params + 8, sizeof(float));
1549    memcpy(&beta_fast, (int32_t *) dst->op_params + 9, sizeof(float));
1550    memcpy(&beta_slow, (int32_t *) dst->op_params + 10, sizeof(float));
1551
1552    int sections[4];
1553    memcpy(sections, (int32_t *) dst->op_params + 11, 4 * sizeof(int));
1554
1555    float theta_scale = powf(freq_base, -2.0f / n_dims);
1556
1557    float corr_dims[2];
1558    ggml_rope_yarn_corr_dims(n_dims, n_ctx_orig, freq_base, beta_fast, beta_slow, corr_dims);
1559
1560    std::vector<uint32_t> params = {
1561        (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src0) / ggml_type_size(src0->type)),
1562        (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src1) / ggml_type_size(src1->type)),
1563        src2 != nullptr ? (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src2) / ggml_type_size(src2->type)) : 0,
1564        (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, dst) / ggml_type_size(dst->type)),
1565        (uint32_t) (src0->nb[1] / ggml_type_size(src0->type)),
1566        (uint32_t) (src0->nb[2] / ggml_type_size(src0->type)),
1567        (uint32_t) (src0->nb[3] / ggml_type_size(src0->type)),
1568        (uint32_t) (dst->nb[1] / ggml_type_size(dst->type)),
1569        (uint32_t) (dst->nb[2] / ggml_type_size(dst->type)),
1570        (uint32_t) (dst->nb[3] / ggml_type_size(dst->type)),
1571        (uint32_t) ggml_nelements(src0) / 2,
1572        (uint32_t) src0->ne[0],
1573        (uint32_t) src0->ne[1],
1574        (uint32_t) src0->ne[2],
1575        (uint32_t) n_dims,
1576        (uint32_t) mode,
1577        *(uint32_t *) &theta_scale,
1578        *(uint32_t *) &attn_factor,
1579        *(uint32_t *) &freq_scale,
1580        *(uint32_t *) &ext_factor,
1581        *(uint32_t *) &corr_dims[0],
1582        *(uint32_t *) &corr_dims[1],
1583        (uint32_t) sections[0],
1584        (uint32_t) sections[1],
1585        (uint32_t) sections[2],
1586        (uint32_t) sections[3]
1587    };
1588
1589    std::vector<wgpu::BindGroupEntry> entries = {
1590        { .binding = 0,
1591         .buffer  = ggml_webgpu_tensor_buf(src0),
1592         .offset  = ggml_webgpu_tensor_align_offset(ctx, src0),
1593         .size    = ggml_webgpu_tensor_binding_size(ctx, src0) },
1594        { .binding = 1,
1595         .buffer  = ggml_webgpu_tensor_buf(src1),
1596         .offset  = ggml_webgpu_tensor_align_offset(ctx, src1),
1597         .size    = ggml_webgpu_tensor_binding_size(ctx, src1) }
1598    };
1599    uint32_t dst_binding = 2;
1600    if (has_freq_factor) {
1601        dst_binding = 3;
1602        entries.push_back({ .binding = 2,
1603                            .buffer  = ggml_webgpu_tensor_buf(src2),
1604                            .offset  = ggml_webgpu_tensor_align_offset(ctx, src2),
1605                            .size    = ggml_webgpu_tensor_binding_size(ctx, src2) });
1606    }
1607    if (!inplace) {
1608        entries.push_back({ .binding = dst_binding,
1609                            .buffer  = ggml_webgpu_tensor_buf(dst),
1610                            .offset  = ggml_webgpu_tensor_align_offset(ctx, dst),
1611                            .size    = ggml_webgpu_tensor_binding_size(ctx, dst) });
1612    }
1613
1614    webgpu_pipeline pipeline = ctx->rope_pipelines[dst->type][has_freq_factor][inplace];
1615    uint32_t        wg_x     = CEIL_DIV(ggml_nelements(dst), WEBGPU_MAX_WG_SIZE);
1616    return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_buf_pool, pipeline, params, entries, wg_x);
1617}
1618
1619static webgpu_command ggml_webgpu_glu(webgpu_context & ctx, ggml_tensor * src0, ggml_tensor * src1, ggml_tensor * dst) {
1620    const int split = (src1 != nullptr);
1621
1622    std::vector<uint32_t> params = {
1623        (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src0) / ggml_type_size(src0->type)),
1624        src1 != nullptr ? (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src1) / ggml_type_size(src1->type)) : 0,
1625        (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, dst) / ggml_type_size(dst->type)),
1626        (uint32_t) (src0->nb[1] / ggml_type_size(src0->type)),
1627        (uint32_t) (src0->nb[2] / ggml_type_size(src0->type)),
1628        (uint32_t) (src0->nb[3] / ggml_type_size(src0->type)),
1629        src1 != nullptr ? (uint32_t) (src1->nb[1] / ggml_type_size(src1->type)) :
1630                          (uint32_t) (src0->nb[1] / ggml_type_size(src0->type)),
1631        src1 != nullptr ? (uint32_t) (src1->nb[2] / ggml_type_size(src1->type)) :
1632                          (uint32_t) (src0->nb[2] / ggml_type_size(src0->type)),
1633        src1 != nullptr ? (uint32_t) (src1->nb[3] / ggml_type_size(src1->type)) :
1634                          (uint32_t) (src0->nb[3] / ggml_type_size(src0->type)),
1635        (uint32_t) (dst->nb[1] / ggml_type_size(dst->type)),
1636        (uint32_t) (dst->nb[2] / ggml_type_size(dst->type)),
1637        (uint32_t) (dst->nb[3] / ggml_type_size(dst->type)),
1638        (uint32_t) ggml_nelements(dst),
1639        (uint32_t) dst->ne[0],
1640        (uint32_t) dst->ne[1],
1641        (uint32_t) dst->ne[2],
1642        (uint32_t) ((int32_t *) dst->op_params)[1],  // swapped
1643        *(uint32_t *) &dst->op_params[2],            // alpha, for swiglu_oai
1644        *(uint32_t *) &dst->op_params[3],            // limit, for swiglu_oai
1645    };
1646
1647    std::vector<wgpu::BindGroupEntry> entries = {
1648        { .binding = 0,
1649         .buffer  = ggml_webgpu_tensor_buf(src0),
1650         .offset  = ggml_webgpu_tensor_align_offset(ctx, src0),
1651         .size    = ggml_webgpu_tensor_binding_size(ctx, src0) },
1652    };
1653    uint32_t dst_binding = 1;
1654    if (split) {
1655        dst_binding = 2;
1656        entries.push_back({ .binding = 1,
1657                            .buffer  = ggml_webgpu_tensor_buf(src1),
1658                            .offset  = ggml_webgpu_tensor_align_offset(ctx, src1),
1659                            .size    = ggml_webgpu_tensor_binding_size(ctx, src1) });
1660    }
1661    entries.push_back({ .binding = dst_binding,
1662                        .buffer  = ggml_webgpu_tensor_buf(dst),
1663                        .offset  = ggml_webgpu_tensor_align_offset(ctx, dst),
1664                        .size    = ggml_webgpu_tensor_binding_size(ctx, dst) });
1665
1666    webgpu_pipeline pipeline = ctx->glu_pipelines[ggml_get_glu_op(dst)][dst->type][split];
1667    uint32_t        wg_x     = CEIL_DIV(ggml_nelements(dst), WEBGPU_MAX_WG_SIZE);
1668    return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_buf_pool, pipeline, params, entries, wg_x);
1669}
1670
1671static webgpu_command ggml_webgpu_scale(webgpu_context & ctx, ggml_tensor * src, ggml_tensor * dst) {
1672    int inplace = ggml_webgpu_tensor_equal(src, dst);
1673
1674    std::vector<uint32_t> params = {
1675        (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src) / ggml_type_size(src->type)),
1676        (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, dst) / ggml_type_size(dst->type)),
1677        (uint32_t) (src->nb[1] / ggml_type_size(src->type)),
1678        (uint32_t) (src->nb[2] / ggml_type_size(src->type)),
1679        (uint32_t) (src->nb[3] / ggml_type_size(src->type)),
1680        (uint32_t) (dst->nb[1] / ggml_type_size(dst->type)),
1681        (uint32_t) (dst->nb[2] / ggml_type_size(dst->type)),
1682        (uint32_t) (dst->nb[3] / ggml_type_size(dst->type)),
1683        (uint32_t) ggml_nelements(dst),
1684        (uint32_t) src->ne[0],
1685        (uint32_t) src->ne[1],
1686        (uint32_t) src->ne[2],
1687        *(uint32_t *) dst->op_params,     // scale
1688        *(uint32_t *) &dst->op_params[1]  // bias
1689    };
1690
1691    std::vector<wgpu::BindGroupEntry> entries = {
1692        { .binding = 0,
1693         .buffer  = ggml_webgpu_tensor_buf(src),
1694         .offset  = ggml_webgpu_tensor_align_offset(ctx, src),
1695         .size    = ggml_webgpu_tensor_binding_size(ctx, src) }
1696    };
1697    if (!inplace) {
1698        entries.push_back({ .binding = 1,
1699                            .buffer  = ggml_webgpu_tensor_buf(dst),
1700                            .offset  = ggml_webgpu_tensor_align_offset(ctx, dst),
1701                            .size    = ggml_webgpu_tensor_binding_size(ctx, dst) });
1702    }
1703
1704    uint32_t wg_x = CEIL_DIV(ggml_nelements(dst), WEBGPU_MAX_WG_SIZE);
1705    return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_buf_pool, ctx->scale_pipelines[inplace], params,
1706                                     entries, wg_x);
1707}
1708
1709static webgpu_command ggml_webgpu_soft_max(webgpu_context & ctx,
1710                                           ggml_tensor *    src0,
1711                                           ggml_tensor *    src1,
1712                                           ggml_tensor *    src2,
1713                                           ggml_tensor *    dst) {
1714    const int inplace   = ggml_webgpu_tensor_equal(src0, dst);
1715    const int mask_type = (src1 != nullptr) ? src1->type : 2;  // use 2 for no mask here
1716    const int has_sink  = (src2 != nullptr);
1717    float     max_bias;
1718    memcpy(&max_bias, (float *) dst->op_params + 1, sizeof(float));
1719    float n_head_log2 = float(1u << (uint32_t) floor(log2(src0->ne[2])));
1720    float m0          = powf(2.0f, -(max_bias) / n_head_log2);
1721    float m1          = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
1722
1723    std::vector<uint32_t> params = {
1724        (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src0) / ggml_type_size(src0->type)),
1725        mask_type < 2 ? (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src1) / ggml_type_size(src1->type)) : 0,
1726        has_sink ? (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src2) / ggml_type_size(src2->type)) : 0,
1727        (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, dst) / ggml_type_size(dst->type)),
1728        (uint32_t) (src0->nb[1] / ggml_type_size(src0->type)),
1729        (uint32_t) (src0->nb[2] / ggml_type_size(src0->type)),
1730        (uint32_t) (src0->nb[3] / ggml_type_size(src0->type)),
1731        mask_type < 2 ? (uint32_t) (src1->nb[1] / ggml_type_size(src1->type)) : 0,
1732        mask_type < 2 ? (uint32_t) (src1->nb[2] / ggml_type_size(src1->type)) : 0,
1733        mask_type < 2 ? (uint32_t) (src1->nb[3] / ggml_type_size(src1->type)) : 0,
1734        (uint32_t) (dst->nb[1] / ggml_type_size(dst->type)),
1735        (uint32_t) (dst->nb[2] / ggml_type_size(dst->type)),
1736        (uint32_t) (dst->nb[3] / ggml_type_size(dst->type)),
1737        (uint32_t) ggml_nelements(dst),
1738        (uint32_t) src0->ne[0],
1739        (uint32_t) src0->ne[1],
1740        (uint32_t) src0->ne[2],
1741        mask_type < 2 ? (uint32_t) src1->ne[2] : 0,
1742        mask_type < 2 ? (uint32_t) src1->ne[3] : 0,
1743        *(uint32_t *) dst->op_params,  // scale
1744        *(uint32_t *) &max_bias,
1745        *(uint32_t *) &n_head_log2,
1746        *(uint32_t *) &m0,
1747        *(uint32_t *) &m1
1748    };
1749
1750    std::vector<wgpu::BindGroupEntry> entries = {
1751        { .binding = 0,
1752         .buffer  = ggml_webgpu_tensor_buf(src0),
1753         .offset  = ggml_webgpu_tensor_align_offset(ctx, src0),
1754         .size    = ggml_webgpu_tensor_binding_size(ctx, src0) }
1755    };
1756    uint32_t binding_num = 1;
1757    if (mask_type < 2) {
1758        entries.push_back({ .binding = binding_num,
1759                            .buffer  = ggml_webgpu_tensor_buf(src1),
1760                            .offset  = ggml_webgpu_tensor_align_offset(ctx, src1),
1761                            .size    = ggml_webgpu_tensor_binding_size(ctx, src1) });
1762        binding_num++;
1763    }
1764    if (has_sink) {
1765        entries.push_back({ .binding = binding_num,
1766                            .buffer  = ggml_webgpu_tensor_buf(src2),
1767                            .offset  = ggml_webgpu_tensor_align_offset(ctx, src2),
1768                            .size    = ggml_webgpu_tensor_binding_size(ctx, src2) });
1769        binding_num++;
1770    }
1771    if (!inplace) {
1772        entries.push_back({ .binding = binding_num,
1773                            .buffer  = ggml_webgpu_tensor_buf(dst),
1774                            .offset  = ggml_webgpu_tensor_align_offset(ctx, dst),
1775                            .size    = ggml_webgpu_tensor_binding_size(ctx, dst) });
1776    }
1777
1778    return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_buf_pool,
1779                                     ctx->soft_max_pipelines[mask_type][has_sink][inplace], params, entries,
1780                                     ggml_nrows(dst));
1781}
1782
1783static webgpu_command ggml_webgpu_argmax(webgpu_context & ctx, ggml_tensor * src, ggml_tensor * dst) {
1784    std::vector<uint32_t> params = { (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src) / ggml_type_size(src->type)),
1785                                     (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, dst) / ggml_type_size(dst->type)),
1786                                     (uint32_t) src->ne[0] };
1787
1788    std::vector<wgpu::BindGroupEntry> entries = {
1789        { .binding = 0,
1790         .buffer  = ggml_webgpu_tensor_buf(src),
1791         .offset  = ggml_webgpu_tensor_align_offset(ctx, src),
1792         .size    = ggml_webgpu_tensor_binding_size(ctx, src) },
1793        { .binding = 1,
1794         .buffer  = ggml_webgpu_tensor_buf(dst),
1795         .offset  = ggml_webgpu_tensor_align_offset(ctx, dst),
1796         .size    = ggml_webgpu_tensor_binding_size(ctx, dst) }
1797    };
1798
1799    ggml_webgpu_generic_shader_lib_context shader_lib_ctx = {
1800        .vec4        = src->ne[0] % 4 == 0,
1801        .max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup,
1802    };
1803
1804    webgpu_pipeline pipeline;
1805    auto            it = ctx->argmax_pipelines.find(shader_lib_ctx.vec4);
1806    if (it != ctx->argmax_pipelines.end()) {
1807        pipeline = it->second;
1808    } else {
1809        ggml_webgpu_processed_shader processed =
1810            ggml_webgpu_preprocess_generic_shader(ctx->p, wgsl_argmax, shader_lib_ctx, "argmax");
1811        pipeline =
1812            ggml_webgpu_create_pipeline(ctx->global_ctx->device, processed.wgsl.c_str(), processed.variant.c_str());
1813        ctx->argmax_pipelines.emplace(shader_lib_ctx.vec4, pipeline);
1814    }
1815    uint32_t wg_x = ggml_nelements(dst);
1816    return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_buf_pool, pipeline, params, entries, wg_x);
1817}
1818
1819static webgpu_command ggml_webgpu_argsort(webgpu_context & ctx, ggml_tensor * src, ggml_tensor * dst) {
1820    bool          is_top_k = dst->op == GGML_OP_TOP_K;
1821    // ascending order is 0, descending order is 1
1822    const int32_t order    = is_top_k ? (int32_t) GGML_SORT_ORDER_DESC : (int32_t) ggml_get_op_params_i32(dst, 0);
1823
1824    ggml_webgpu_argsort_shader_lib_context shader_lib_ctx = {
1825        .max_wg_size        = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup,
1826        .wg_mem_limit_bytes = ctx->global_ctx->capabilities.limits.maxComputeWorkgroupStorageSize,
1827        .order              = order
1828    };
1829
1830    webgpu_pipeline argsort_pipeline;
1831    auto            it = ctx->argsort_pipelines.find(order);
1832    if (it != ctx->argsort_pipelines.end()) {
1833        argsort_pipeline = it->second;
1834    } else {
1835        ggml_webgpu_processed_shader processed =
1836            ggml_webgpu_preprocess_argsort_shader(ctx->p, wgsl_argsort, shader_lib_ctx);
1837        argsort_pipeline =
1838            ggml_webgpu_create_pipeline(ctx->global_ctx->device, processed.wgsl.c_str(), processed.variant.c_str());
1839        argsort_pipeline.context = processed.decisions;
1840        ctx->argsort_pipelines.emplace(order, argsort_pipeline);
1841    }
1842    auto * argsort_decisions = static_cast<ggml_webgpu_argsort_shader_decisions *>(argsort_pipeline.context.get());
1843
1844    webgpu_pipeline argsort_merge_pipeline;
1845    it = ctx->argsort_merge_pipelines.find(order);
1846    if (it != ctx->argsort_merge_pipelines.end()) {
1847        argsort_merge_pipeline = it->second;
1848    } else {
1849        ggml_webgpu_processed_shader processed =
1850            ggml_webgpu_preprocess_argsort_merge_shader(ctx->p, wgsl_argsort_merge, shader_lib_ctx);
1851        argsort_merge_pipeline =
1852            ggml_webgpu_create_pipeline(ctx->global_ctx->device, processed.wgsl.c_str(), processed.variant.c_str());
1853        argsort_merge_pipeline.context = processed.decisions;
1854        ctx->argsort_merge_pipelines.emplace(order, argsort_merge_pipeline);
1855    }
1856
1857    const uint32_t src_ne0 = (uint32_t) src->ne[0];
1858    const uint32_t nrows   = (uint32_t) ggml_nrows(src);
1859    const uint32_t npr     = CEIL_DIV(src_ne0, argsort_decisions->wg_size);
1860    const uint32_t block_size =
1861        is_top_k ? std::min(argsort_decisions->wg_size, (uint32_t) dst->ne[0]) : argsort_decisions->wg_size;
1862    uint32_t out_ne0 = src_ne0;
1863    if (is_top_k) {
1864        if (npr > 1) {
1865            const uint32_t last_tile = src_ne0 - (npr - 1) * argsort_decisions->wg_size;
1866            out_ne0                  = (npr - 1) * block_size + std::min(last_tile, block_size);
1867        } else {
1868            out_ne0 = block_size;
1869        }
1870    }
1871
1872    uint32_t merge_len    = block_size;
1873    uint32_t merge_passes = 0;
1874    while (merge_len < out_ne0) {
1875        merge_len <<= 1;
1876        merge_passes++;
1877    }
1878
1879    const bool start_in_tmp = (merge_passes % 2) == 1;
1880
1881    const size_t dst_offset = ggml_webgpu_tensor_offset(dst);
1882    const size_t idx_nbytes = out_ne0 * ggml_nrows(dst) * sizeof(int32_t);
1883    const size_t tmp_offset =
1884        ROUNDUP_POW2(dst_offset + idx_nbytes, ctx->global_ctx->capabilities.limits.minStorageBufferOffsetAlignment);
1885    const size_t tmp_binding_size = ROUNDUP_POW2(idx_nbytes, WEBGPU_STORAGE_BUF_BINDING_MULT);
1886    const size_t dst_binding_size =
1887        ROUNDUP_POW2(idx_nbytes + ggml_webgpu_tensor_misalignment(ctx, dst), WEBGPU_STORAGE_BUF_BINDING_MULT);
1888
1889    const uint32_t offset_src  = (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src) / ggml_type_size(src->type));
1890    const uint32_t offset_dst  = (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, dst) / ggml_type_size(dst->type));
1891    const uint32_t offset_tmp  = 0;
1892    const uint32_t stride_src1 = (uint32_t) (src->nb[1] / ggml_type_size(src->type));
1893    const uint32_t stride_src2 = (uint32_t) (src->nb[2] / ggml_type_size(src->type));
1894    const uint32_t stride_src3 = (uint32_t) (src->nb[3] / ggml_type_size(src->type));
1895    const uint32_t stride_idx1 = out_ne0;
1896    const uint32_t stride_idx2 = out_ne0 * (uint32_t) dst->ne[1];
1897    const uint32_t stride_idx3 = stride_idx2 * (uint32_t) dst->ne[2];
1898
1899    std::vector<webgpu_pipeline>                   pipelines;
1900    std::vector<std::vector<uint32_t>>             params_list;
1901    std::vector<std::vector<wgpu::BindGroupEntry>> entries_list;
1902    std::vector<std::pair<uint32_t, uint32_t>>     workgroups_list;
1903
1904    const uint32_t init_offset       = start_in_tmp ? offset_tmp : offset_dst;
1905    const size_t   init_align_offset = start_in_tmp ? tmp_offset : ggml_webgpu_tensor_align_offset(ctx, dst);
1906    const size_t   init_binding_size = start_in_tmp ? tmp_binding_size : dst_binding_size;
1907
1908    std::vector<uint32_t> init_params = {
1909        offset_src,  init_offset, stride_src1, stride_src2,           stride_src3,           stride_idx1,
1910        stride_idx2, stride_idx3, src_ne0,     (uint32_t) src->ne[1], (uint32_t) src->ne[2], out_ne0,
1911        block_size,  npr,         nrows
1912    };
1913
1914    const uint32_t                    total_wg_init = npr * nrows;
1915    const uint32_t                    max_wg    = ctx->global_ctx->capabilities.limits.maxComputeWorkgroupsPerDimension;
1916    const uint32_t                    wg_x_init = std::min(total_wg_init, max_wg);
1917    const uint32_t                    wg_y_init = CEIL_DIV(total_wg_init, wg_x_init);
1918    std::vector<wgpu::BindGroupEntry> init_entries = {
1919        { .binding = 0,
1920         .buffer  = ggml_webgpu_tensor_buf(src),
1921         .offset  = ggml_webgpu_tensor_align_offset(ctx, src),
1922         .size    = ggml_webgpu_tensor_binding_size(ctx, src) },
1923        { .binding = 1, .buffer = ggml_webgpu_tensor_buf(dst), .offset = init_align_offset, .size = init_binding_size }
1924    };
1925
1926    pipelines.push_back(argsort_pipeline);
1927    params_list.push_back(std::move(init_params));
1928    entries_list.push_back(std::move(init_entries));
1929    workgroups_list.push_back({ wg_x_init, wg_y_init });
1930
1931    if (merge_passes == 0) {
1932        return ggml_backend_webgpu_build_multi(ctx->global_ctx, ctx->param_buf_pool, pipelines, params_list,
1933                                               entries_list, workgroups_list);
1934    }
1935
1936    bool     in_is_tmp = start_in_tmp;
1937    uint32_t len       = block_size;
1938    while (len < out_ne0) {
1939        const uint32_t nm = CEIL_DIV(out_ne0, 2 * len);
1940
1941        const bool     out_is_tmp  = !in_is_tmp;
1942        const uint32_t offset_in   = in_is_tmp ? offset_tmp : offset_dst;
1943        const uint32_t offset_out  = out_is_tmp ? offset_tmp : offset_dst;
1944        const size_t   align_in    = in_is_tmp ? tmp_offset : ggml_webgpu_tensor_align_offset(ctx, dst);
1945        const size_t   align_out   = out_is_tmp ? tmp_offset : ggml_webgpu_tensor_align_offset(ctx, dst);
1946        const size_t   size_in     = in_is_tmp ? tmp_binding_size : dst_binding_size;
1947        const size_t   size_out    = out_is_tmp ? tmp_binding_size : dst_binding_size;
1948        const uint32_t top_k_out   = (is_top_k && nm == 1) ? (uint32_t) dst->ne[0] : out_ne0;
1949        const uint32_t stride_out1 = top_k_out;
1950        const uint32_t stride_out2 = top_k_out * (uint32_t) dst->ne[1];
1951        const uint32_t stride_out3 = stride_out2 * (uint32_t) dst->ne[2];
1952
1953        std::vector<uint32_t> merge_params = { offset_src,
1954                                               offset_in,
1955                                               offset_out,
1956                                               stride_src1,
1957                                               stride_src2,
1958                                               stride_src3,
1959                                               stride_idx1,
1960                                               stride_idx2,
1961                                               stride_idx3,
1962                                               stride_out1,
1963                                               stride_out2,
1964                                               stride_out3,
1965                                               out_ne0,
1966                                               (uint32_t) src->ne[1],
1967                                               (uint32_t) src->ne[2],
1968                                               top_k_out,
1969                                               len,
1970                                               nm,
1971                                               nrows };
1972
1973        std::vector<wgpu::BindGroupEntry> merge_entries = {
1974            { .binding = 0,
1975             .buffer  = ggml_webgpu_tensor_buf(src),
1976             .offset  = ggml_webgpu_tensor_align_offset(ctx, src),
1977             .size    = ggml_webgpu_tensor_binding_size(ctx, src) },
1978            { .binding = 1, .buffer = ggml_webgpu_tensor_buf(dst), .offset = align_in, .size = size_in },
1979            { .binding = 2, .buffer = ggml_webgpu_tensor_buf(dst), .offset = align_out, .size = size_out }
1980        };
1981
1982        const uint32_t total_wg_merge = nm * nrows;
1983        const uint32_t wg_x_merge     = std::min(total_wg_merge, max_wg);
1984        const uint32_t wg_y_merge     = CEIL_DIV(total_wg_merge, wg_x_merge);
1985        workgroups_list.push_back({ wg_x_merge, wg_y_merge });
1986        pipelines.push_back(argsort_merge_pipeline);
1987        params_list.push_back(std::move(merge_params));
1988        entries_list.push_back(std::move(merge_entries));
1989
1990        len <<= 1;
1991        in_is_tmp = !in_is_tmp;
1992    }
1993
1994    return ggml_backend_webgpu_build_multi(ctx->global_ctx, ctx->param_buf_pool, pipelines, params_list, entries_list,
1995                                           workgroups_list);
1996}
1997
1998static webgpu_command ggml_webgpu_cumsum(webgpu_context & ctx, ggml_tensor * src, ggml_tensor * dst) {
1999    std::vector<uint32_t> params = { (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src) / ggml_type_size(src->type)),
2000                                     (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, dst) / ggml_type_size(dst->type)),
2001                                     (uint32_t) src->ne[0] };
2002
2003    std::vector<wgpu::BindGroupEntry> entries = {
2004        { .binding = 0,
2005         .buffer  = ggml_webgpu_tensor_buf(src),
2006         .offset  = ggml_webgpu_tensor_align_offset(ctx, src),
2007         .size    = ggml_webgpu_tensor_binding_size(ctx, src) },
2008        { .binding = 1,
2009         .buffer  = ggml_webgpu_tensor_buf(dst),
2010         .offset  = ggml_webgpu_tensor_align_offset(ctx, dst),
2011         .size    = ggml_webgpu_tensor_binding_size(ctx, dst) }
2012    };
2013
2014    ggml_webgpu_generic_shader_lib_context shader_lib_ctx = {
2015        .vec4        = false,
2016        .max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup,
2017    };
2018    webgpu_pipeline pipeline;
2019    auto            it = ctx->cumsum_pipelines.find(1);
2020    if (it != ctx->cumsum_pipelines.end()) {
2021        pipeline = it->second;
2022    } else {
2023        ggml_webgpu_processed_shader processed =
2024            ggml_webgpu_preprocess_generic_shader(ctx->p, wgsl_cumsum, shader_lib_ctx, "cumsum");
2025        pipeline =
2026            ggml_webgpu_create_pipeline(ctx->global_ctx->device, processed.wgsl.c_str(), processed.variant.c_str());
2027        ctx->cumsum_pipelines.emplace(1, pipeline);
2028    }
2029    uint32_t wg_x = ggml_nrows(dst);
2030    return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_buf_pool, pipeline, params, entries, wg_x);
2031}
2032
2033static webgpu_command ggml_webgpu_sum_rows(webgpu_context & ctx, ggml_tensor * src, ggml_tensor * dst) {
2034    bool                  total_sum = dst->op == GGML_OP_SUM;
2035    std::vector<uint32_t> params = { (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src) / ggml_type_size(src->type)),
2036                                     (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, dst) / ggml_type_size(dst->type)),
2037                                     total_sum ? 0 : (uint32_t) (src->nb[1] / ggml_type_size(src->type)),
2038                                     total_sum ? 0 : (uint32_t) (src->nb[2] / ggml_type_size(src->type)),
2039                                     total_sum ? 0 : (uint32_t) (src->nb[3] / ggml_type_size(src->type)),
2040                                     total_sum ? static_cast<uint32_t>(ggml_nelements(src)) : (uint32_t) src->ne[0],
2041                                     total_sum ? 1 : (uint32_t) src->ne[1],
2042                                     total_sum ? 1 : (uint32_t) src->ne[2] };
2043
2044    std::vector<wgpu::BindGroupEntry> entries = {
2045        { .binding = 0,
2046         .buffer  = ggml_webgpu_tensor_buf(src),
2047         .offset  = ggml_webgpu_tensor_align_offset(ctx, src),
2048         .size    = ggml_webgpu_tensor_binding_size(ctx, src) },
2049        { .binding = 1,
2050         .buffer  = ggml_webgpu_tensor_buf(dst),
2051         .offset  = ggml_webgpu_tensor_align_offset(ctx, dst),
2052         .size    = ggml_webgpu_tensor_binding_size(ctx, dst) }
2053    };
2054
2055    ggml_webgpu_generic_shader_lib_context shader_lib_ctx = {
2056        .vec4        = false,
2057        .max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup,
2058    };
2059
2060    webgpu_pipeline pipeline;
2061    auto            it = ctx->sum_rows_pipelines.find(1);
2062    if (it != ctx->sum_rows_pipelines.end()) {
2063        pipeline = it->second;
2064    } else {
2065        ggml_webgpu_processed_shader processed =
2066            ggml_webgpu_preprocess_generic_shader(ctx->p, wgsl_sum_rows, shader_lib_ctx, "sum_rows");
2067        pipeline =
2068            ggml_webgpu_create_pipeline(ctx->global_ctx->device, processed.wgsl.c_str(), processed.variant.c_str());
2069        ctx->sum_rows_pipelines.emplace(1, pipeline);
2070    }
2071    uint32_t wg_x = total_sum ? 1 : ggml_nrows(dst);
2072    return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_buf_pool, pipeline, params, entries, wg_x);
2073}
2074
2075// Returns the encoded command, or std::nullopt if the operation is a no-op
2076static std::optional<webgpu_command> ggml_webgpu_encode_node(webgpu_context ctx, ggml_tensor * node) {
2077    if (ggml_is_empty(node)) {
2078        return std::nullopt;
2079    }
2080    if ((node->flags & GGML_TENSOR_FLAG_COMPUTE) == 0) {
2081        return std::nullopt;
2082    }
2083    WEBGPU_LOG_DEBUG("ggml_webgpu_encode_node(" << node << ", " << ggml_op_name(node->op) << ")");
2084
2085    ggml_tensor * src0 = node->src[0];
2086    ggml_tensor * src1 = node->src[1];
2087    ggml_tensor * src2 = node->src[2];
2088
2089    switch (node->op) {
2090            // no-ops
2091        case GGML_OP_NONE:
2092        case GGML_OP_VIEW:
2093        case GGML_OP_PERMUTE:
2094        case GGML_OP_TRANSPOSE:
2095        case GGML_OP_RESHAPE:
2096            return std::nullopt;
2097        case GGML_OP_CPY:
2098        case GGML_OP_CONT:
2099            return ggml_webgpu_cpy(ctx, src0, node);
2100        case GGML_OP_SET_ROWS:
2101            return ggml_webgpu_set_rows(ctx, src0, src1, node);
2102        case GGML_OP_GET_ROWS:
2103            return ggml_webgpu_get_rows(ctx, src0, src1, node);
2104        case GGML_OP_MUL_MAT:
2105            return ggml_webgpu_mul_mat(ctx, src0, src1, node);
2106        case GGML_OP_FLASH_ATTN_EXT:
2107#ifndef __EMSCRIPTEN__
2108            return ggml_webgpu_flash_attn(ctx, src0, src1, src2, node->src[3], node->src[4], node);
2109#else
2110            return std::nullopt;
2111#endif
2112        case GGML_OP_ADD:
2113        case GGML_OP_SUB:
2114        case GGML_OP_MUL:
2115        case GGML_OP_DIV:
2116            return ggml_webgpu_binary_op(ctx, src0, src1, node);
2117        case GGML_OP_RMS_NORM:
2118            return ggml_webgpu_rms_norm(ctx, src0, node);
2119        case GGML_OP_ROPE:
2120            return ggml_webgpu_rope(ctx, src0, src1, src2, node);
2121        case GGML_OP_GLU:
2122            return ggml_webgpu_glu(ctx, src0, src1, node);
2123        case GGML_OP_SCALE:
2124            return ggml_webgpu_scale(ctx, src0, node);
2125        case GGML_OP_SOFT_MAX:
2126            return ggml_webgpu_soft_max(ctx, src0, src1, src2, node);
2127        case GGML_OP_UNARY:
2128            return ggml_webgpu_unary_op(ctx, src0, node);
2129        case GGML_OP_CLAMP:
2130            return ggml_webgpu_unary_op(ctx, src0, node);
2131        case GGML_OP_FILL:
2132            return ggml_webgpu_unary_op(ctx, src0, node);
2133        case GGML_OP_LOG:
2134            return ggml_webgpu_unary_op(ctx, src0, node);
2135        case GGML_OP_PAD:
2136            return ggml_webgpu_pad(ctx, src0, node);
2137        case GGML_OP_ARGMAX:
2138            return ggml_webgpu_argmax(ctx, src0, node);
2139        case GGML_OP_ARGSORT:
2140            return ggml_webgpu_argsort(ctx, src0, node);
2141        case GGML_OP_TOP_K:
2142            // we reuse the same argsort implementation for top_k
2143            return ggml_webgpu_argsort(ctx, src0, node);
2144        case GGML_OP_CUMSUM:
2145            return ggml_webgpu_cumsum(ctx, src0, node);
2146        case GGML_OP_SUM:
2147        case GGML_OP_SUM_ROWS:
2148            return ggml_webgpu_sum_rows(ctx, src0, node);
2149        default:
2150            return std::nullopt;
2151    }
2152}
2153
// Backend graph_compute hook.
// Encodes every node of `cgraph` (nodes that encode to std::nullopt are skipped), submits the
// resulting commands in batches, then waits on all outstanding submissions.
// Always returns GGML_STATUS_SUCCESS.
static ggml_status ggml_backend_webgpu_graph_compute(ggml_backend_t backend, struct ggml_cgraph * cgraph) {
    WEBGPU_LOG_DEBUG("ggml_backend_webgpu_graph_compute(" << cgraph->n_nodes << " nodes)");

    ggml_backend_webgpu_context * backend_ctx = (ggml_backend_webgpu_context *) backend->context;
    webgpu_context                ctx         = backend_ctx->webgpu_ctx;

    WEBGPU_CPU_PROFILE_TOTAL_START(graph_compute);

    // Count this call as an in-flight submitter for its whole duration (decremented at the end).
    // The shared count shrinks the batch size computed below.
    ctx->global_ctx->inflight_threads++;

    std::vector<webgpu_command>            commands;
    std::vector<webgpu_submission_futures> futures;
    for (int i = 0; i < cgraph->n_nodes; i++) {
        if (auto cmd = ggml_webgpu_encode_node(ctx, cgraph->nodes[i])) {
            commands.push_back(*cmd);
        }
        // compute the batch size based on the number of inflight threads
        // batch_size = clamp(WEBGPU_NUM_PARAM_BUFS / inflight_threads, 1, WEBGPU_COMMAND_SUBMIT_BATCH_SIZE).
        // It is recomputed per node because other threads may start or finish meanwhile. Presumably
        // this lets concurrent callers share the parameter-buffer pool; confirm against the pool sizing.
        uint32_t inflight_threads = ctx->global_ctx->inflight_threads;
        uint32_t batch_size       = std::min(std::max(1u, WEBGPU_NUM_PARAM_BUFS / std::max(inflight_threads, 1u)),
                                             WEBGPU_COMMAND_SUBMIT_BATCH_SIZE);
        if (commands.size() >= batch_size) {
            futures.push_back(ggml_backend_webgpu_submit(ctx->global_ctx, commands, ctx->param_buf_pool,
                                                         &ctx->set_rows_error_buf_pool));
            // Process events and check for completed submissions
            // (the trailing `false` presumably makes this a non-blocking poll rather than a full wait).
            ctx->global_ctx->instance.ProcessEvents();
            ggml_backend_webgpu_wait(ctx->global_ctx, futures, false);
            commands.clear();
        }
    }
    // Flush the final partial batch.
    if (!commands.empty()) {
        webgpu_submission_futures new_futures =
            ggml_backend_webgpu_submit(ctx->global_ctx, commands, ctx->param_buf_pool, &ctx->set_rows_error_buf_pool);
        futures.push_back(new_futures);
    }

    // Wait on everything submitted by this call before releasing the in-flight slot.
    ggml_backend_webgpu_wait(ctx->global_ctx, futures);
    ctx->global_ctx->inflight_threads--;
    WEBGPU_CPU_PROFILE_TOTAL_END(graph_compute, ctx->global_ctx);
    return GGML_STATUS_SUCCESS;
}
2194
2195static ggml_backend_i ggml_backend_webgpu_i = {
2196    /* .get_name                = */ ggml_backend_webgpu_name,
2197    /* .free                    = */ ggml_backend_webgpu_free,
2198    /* .set_tensor_async        = */ NULL,
2199    /* .get_tensor_async        = */ NULL,
2200    /* .cpy_tensor_async        = */ NULL,
2201    /* .synchronize             = */ NULL,
2202    /* .graph_plan_create       = */ NULL,
2203    /* .graph_plan_free         = */ NULL,
2204    /* .graph_plan_update       = */ NULL,
2205    /* .graph_plan_compute      = */ NULL,
2206    /* .graph_compute           = */ ggml_backend_webgpu_graph_compute,
2207    /* .event_record            = */ NULL,
2208    /* .event_wait              = */ NULL,
2209    /* .graph_optimize          = */ NULL,
2210};
2211
2212/* End GGML Backend Interface */
2213
2214/* GGML Backend Buffer Interface */
2215
2216static void ggml_backend_webgpu_buffer_free_buffer(ggml_backend_buffer_t buffer) {
2217    ggml_backend_webgpu_buffer_context * ctx = static_cast<ggml_backend_webgpu_buffer_context *>(buffer->context);
2218    if (ctx != nullptr && ctx->buffer != nullptr) {
2219        ctx->buffer.Destroy();
2220        delete ctx;
2221    }
2222}
2223
2224// Returns the "fake" base pointer.
2225static void * ggml_backend_webgpu_buffer_get_base(ggml_backend_buffer_t buffer) {
2226    GGML_UNUSED(buffer);
2227    return webgpu_ptr_base;
2228}
2229
2230static void ggml_backend_webgpu_buffer_memset_tensor(ggml_backend_buffer_t buffer,
2231                                                     ggml_tensor *         tensor,
2232                                                     uint8_t               value,
2233                                                     size_t                offset,
2234                                                     size_t                size) {
2235    if (size == 0) {
2236        WEBGPU_LOG_DEBUG("ggml_backend_webgpu_buffer_memset_tensor: size is zero, nothing to do.");
2237        return;
2238    }
2239
2240    WEBGPU_CPU_PROFILE_TOTAL_START(memset_tensor);
2241
2242    ggml_backend_webgpu_buffer_context * buf_ctx = (ggml_backend_webgpu_buffer_context *) buffer->context;
2243
2244    WEBGPU_LOG_DEBUG("ggml_backend_webgpu_buffer_memset_tensor(" << buf_ctx->label << ", " << tensor << ", " << value
2245                                                                 << ", " << offset << ", " << size << ")");
2246
2247    size_t total_offset = webgpu_tensor_offset(tensor) + tensor->view_offs + offset;
2248
2249    // This is a trick to set all bytes of a u32 to the same 1 byte value.
2250    uint32_t val32 = (uint32_t) value * 0x01010101;
2251    ggml_backend_webgpu_buffer_memset(buf_ctx->global_ctx, buf_ctx->buffer, val32, total_offset, size);
2252    WEBGPU_CPU_PROFILE_TOTAL_END(memset_tensor, buf_ctx->global_ctx);
2253}
2254
2255static void ggml_backend_webgpu_buffer_set_tensor(ggml_backend_buffer_t buffer,
2256                                                  ggml_tensor *         tensor,
2257                                                  const void *          data,
2258                                                  size_t                offset,
2259                                                  size_t                size) {
2260    WEBGPU_CPU_PROFILE_TOTAL_START(set_tensor);
2261    ggml_backend_webgpu_buffer_context * buf_ctx = (ggml_backend_webgpu_buffer_context *) buffer->context;
2262
2263    WEBGPU_LOG_DEBUG("ggml_backend_webgpu_buffer_set_tensor(" << buf_ctx->label << ", " << tensor << ", " << data
2264                                                              << ", " << offset << ", " << size << ")");
2265
2266    size_t total_offset = webgpu_tensor_offset(tensor) + tensor->view_offs + offset;
2267
2268    buf_ctx->global_ctx->queue.WriteBuffer(buf_ctx->buffer, total_offset, data, (size / 4) * 4);
2269
2270    if (size % 4 != 0) {
2271        // If size is not a multiple of 4, we need to memset the remaining bytes
2272        size_t remaining_size = size % 4;
2273
2274        // pack the remaining bytes into a uint32_t
2275        uint32_t val32 = 0;
2276
2277        for (size_t i = 0; i < remaining_size; i++) {
2278            ((uint8_t *) &val32)[i] = ((const uint8_t *) data)[size - remaining_size + i];
2279        }
2280        // memset the remaining bytes
2281        ggml_backend_webgpu_buffer_memset(buf_ctx->global_ctx, buf_ctx->buffer, val32,
2282                                          total_offset + (size - remaining_size), remaining_size);
2283    } else {
2284        // wait for WriteBuffer to complete
2285        buf_ctx->global_ctx->instance.WaitAny(buf_ctx->global_ctx->queue.OnSubmittedWorkDone(
2286                                                  wgpu::CallbackMode::AllowSpontaneous,
2287                                                  [](wgpu::QueueWorkDoneStatus status, wgpu::StringView message) {
2288                                                      if (status != wgpu::QueueWorkDoneStatus::Success) {
2289                                                          GGML_LOG_ERROR("ggml_webgpu: Failed to submit commands: %s\n",
2290                                                                         std::string(message).c_str());
2291                                                      }
2292                                                  }),
2293                                              UINT64_MAX);
2294    }
2295    WEBGPU_CPU_PROFILE_TOTAL_END(set_tensor, buf_ctx->global_ctx);
2296}
2297
2298static void ggml_backend_webgpu_buffer_get_tensor(ggml_backend_buffer_t buffer,
2299                                                  const ggml_tensor *   tensor,
2300                                                  void *                data,
2301                                                  size_t                offset,
2302                                                  size_t                size) {
2303    WEBGPU_CPU_PROFILE_TOTAL_START(get_tensor);
2304    ggml_backend_webgpu_buffer_context * buf_ctx = (ggml_backend_webgpu_buffer_context *) buffer->context;
2305    WEBGPU_LOG_DEBUG("ggml_backend_webgpu_buffer_get_tensor(" << buf_ctx->label << ", " << tensor << ", " << data
2306                                                              << ", " << offset << ", " << size << ")");
2307    wgpu::Device device = buf_ctx->global_ctx->device;
2308
2309    size_t total_offset = webgpu_tensor_offset(tensor) + tensor->view_offs + offset;
2310
2311    size_t final_size = size;
2312    if (size % 4 != 0) {
2313        // If size is not a multiple of 4, we need to round it up to the next multiple of 4
2314        final_size = size + (4 - (size % 4));
2315    }
2316
2317    std::lock_guard<std::recursive_mutex> lock(buf_ctx->global_ctx->mutex);
2318
2319    if (buf_ctx->global_ctx->get_tensor_staging_buf == nullptr ||
2320        buf_ctx->global_ctx->get_tensor_staging_buf.GetSize() < final_size) {
2321        // Create a new staging buffer if it doesn't exist or is too small
2322        if (buf_ctx->global_ctx->get_tensor_staging_buf) {
2323            buf_ctx->global_ctx->get_tensor_staging_buf.Destroy();
2324        }
2325        ggml_webgpu_create_buffer(device, buf_ctx->global_ctx->get_tensor_staging_buf, final_size,
2326                                  wgpu::BufferUsage::CopyDst | wgpu::BufferUsage::MapRead, "get_tensor_staging_buf");
2327    }
2328
2329    // Copy the data from the buffer to the staging buffer
2330    wgpu::CommandEncoder encoder = device.CreateCommandEncoder();
2331    encoder.CopyBufferToBuffer(buf_ctx->buffer, total_offset, buf_ctx->global_ctx->get_tensor_staging_buf, 0,
2332                               final_size);
2333    wgpu::CommandBuffer commands = encoder.Finish();
2334
2335    // Submit the command buffer to the queue
2336    buf_ctx->global_ctx->queue.Submit(1, &commands);
2337
2338    // Map the staging buffer to read the data
2339    ggml_backend_webgpu_map_buffer(buf_ctx->global_ctx, buf_ctx->global_ctx->get_tensor_staging_buf,
2340                                   wgpu::MapMode::Read, 0, final_size);
2341    // Must specify size here since the staging buffer might be larger than the tensor size
2342    const void * mapped_range = buf_ctx->global_ctx->get_tensor_staging_buf.GetConstMappedRange(0, final_size);
2343
2344    // Copy the data from the mapped range to the output buffer
2345    std::memcpy(data, mapped_range, size);
2346    buf_ctx->global_ctx->get_tensor_staging_buf.Unmap();
2347    WEBGPU_CPU_PROFILE_TOTAL_END(get_tensor, buf_ctx->global_ctx);
2348}
2349
2350static void ggml_backend_webgpu_buffer_clear(ggml_backend_buffer_t buffer, uint8_t value) {
2351    WEBGPU_LOG_DEBUG("ggml_backend_webgpu_buffer_clear(" << buffer << ", " << (uint32_t) value << ")");
2352    WEBGPU_CPU_PROFILE_TOTAL_START(clear);
2353    ggml_backend_webgpu_buffer_context * buf_ctx = (ggml_backend_webgpu_buffer_context *) buffer->context;
2354    ggml_backend_webgpu_buffer_memset(buf_ctx->global_ctx, buf_ctx->buffer, value, 0, buffer->size);
2355    WEBGPU_CPU_PROFILE_TOTAL_END(clear, buf_ctx->global_ctx);
2356}
2357
2358static ggml_backend_buffer_i ggml_backend_webgpu_buffer_interface = {
2359    /* .free_buffer     = */ ggml_backend_webgpu_buffer_free_buffer,
2360    /* .get_base        = */ ggml_backend_webgpu_buffer_get_base,
2361    /* .init_tensor     = */ NULL,  // TODO: optional, needed?
2362    /* .memset_tensor   = */ ggml_backend_webgpu_buffer_memset_tensor,
2363    /* .set_tensor      = */ ggml_backend_webgpu_buffer_set_tensor,
2364    /* .get_tensor      = */ ggml_backend_webgpu_buffer_get_tensor,
2365    /* .cpy_tensor      = */ NULL,  // TODO: optional, implement this
2366    /* .clear           = */ ggml_backend_webgpu_buffer_clear,
2367    /* .reset           = */ NULL,  // TODO: optional, think it coordinates with .init_tensor
2368};
2369
2370/* End GGML Backend Buffer Interface */
2371
2372/* GGML Backend Buffer Type Interface */
2373
2374static const char * ggml_backend_webgpu_buffer_type_get_name(ggml_backend_buffer_type_t buft) {
2375    ggml_backend_webgpu_device_context * ctx = static_cast<ggml_backend_webgpu_device_context *>(buft->device->context);
2376    return ctx->device_name.c_str();
2377}
2378
2379static ggml_backend_buffer_t ggml_backend_webgpu_buffer_type_alloc_buffer(ggml_backend_buffer_type_t buft,
2380                                                                          size_t                     size) {
2381    static std::atomic<int> buffer_count;
2382    int                     buffer_id = buffer_count++;
2383    std::string             buf_name  = "tensor_buf" + std::to_string(buffer_id);
2384    WEBGPU_LOG_DEBUG("ggml_backend_webgpu_buffer_type_alloc_buffer_" << buffer_id << ": " << size << " bytes");
2385
2386    ggml_backend_webgpu_device_context * ctx = static_cast<ggml_backend_webgpu_device_context *>(buft->device->context);
2387    wgpu::Buffer                         buf;
2388    ggml_webgpu_create_buffer(ctx->webgpu_global_ctx->device, buf, ROUNDUP_POW2(size, WEBGPU_STORAGE_BUF_BINDING_MULT),
2389                              wgpu::BufferUsage::Storage | wgpu::BufferUsage::CopySrc | wgpu::BufferUsage::CopyDst,
2390                              buf_name.c_str());
2391
2392    ggml_backend_webgpu_buffer_context * buf_ctx =
2393        new ggml_backend_webgpu_buffer_context(buf, buf_name, ctx->webgpu_global_ctx);
2394
2395    return ggml_backend_buffer_init(buft, ggml_backend_webgpu_buffer_interface, buf_ctx, size);
2396}
2397
2398static size_t ggml_backend_webgpu_buffer_type_get_alignment(ggml_backend_buffer_type_t buft) {
2399    ggml_backend_webgpu_device_context * dev_ctx =
2400        static_cast<ggml_backend_webgpu_device_context *>(buft->device->context);
2401    return dev_ctx->webgpu_global_ctx->capabilities.limits.minStorageBufferOffsetAlignment;
2402}
2403
2404// maxBufferSize might be larger, but you can't bind more than maxStorageBufferBindingSize to a single binding.
2405static size_t ggml_backend_webgpu_buffer_type_get_max_size(ggml_backend_buffer_type_t buft) {
2406    ggml_backend_webgpu_device_context * dev_ctx =
2407        static_cast<ggml_backend_webgpu_device_context *>(buft->device->context);
2408    return dev_ctx->webgpu_global_ctx->capabilities.limits.maxStorageBufferBindingSize;
2409}
2410
2411static size_t ggml_backend_webgpu_buffer_type_get_alloc_size(ggml_backend_buffer_type_t buft,
2412                                                             const ggml_tensor *        tensor) {
2413    ggml_backend_webgpu_device_context * ctx = static_cast<ggml_backend_webgpu_device_context *>(buft->device->context);
2414    size_t                               res = ggml_nbytes(tensor);
2415    switch (tensor->op) {
2416        case GGML_OP_ARGSORT:
2417            res = ROUNDUP_POW2(res * 2 + ctx->webgpu_global_ctx->capabilities.limits.minStorageBufferOffsetAlignment,
2418                               WEBGPU_STORAGE_BUF_BINDING_MULT);
2419            break;
2420        case GGML_OP_TOP_K:
2421            {
2422                const ggml_tensor * src0 = tensor->src[0];
2423                if (src0) {
2424                    const size_t full = sizeof(int32_t) * ggml_nelements(src0);
2425                    res               = ROUNDUP_POW2(
2426                        full * 2 + ctx->webgpu_global_ctx->capabilities.limits.minStorageBufferOffsetAlignment,
2427                        WEBGPU_STORAGE_BUF_BINDING_MULT);
2428                }
2429            }
2430            break;
2431        default:
2432            break;
2433    }
2434    return res;
2435}
2436
2437/* End GGML Backend Buffer Type Interface */
2438
2439/* GGML Backend Device Interface */
2440
2441static const char * ggml_backend_webgpu_device_get_name(ggml_backend_dev_t dev) {
2442    ggml_backend_webgpu_device_context * ctx = static_cast<ggml_backend_webgpu_device_context *>(dev->context);
2443    return ctx->device_name.c_str();
2444}
2445
2446static const char * ggml_backend_webgpu_device_get_description(ggml_backend_dev_t dev) {
2447    ggml_backend_webgpu_device_context * ctx = static_cast<ggml_backend_webgpu_device_context *>(dev->context);
2448    return ctx->device_desc.c_str();
2449}
2450
2451static void ggml_backend_webgpu_device_get_memory(ggml_backend_dev_t dev, size_t * free, size_t * total) {
2452    ggml_backend_webgpu_device_context * ctx = static_cast<ggml_backend_webgpu_device_context *>(dev->context);
2453    // TODO: for now, return maxBufferSize as both free and total memory
2454    // Track https://github.com/gpuweb/gpuweb/issues/5505 for updates.
2455    uint64_t                             max_buffer_size = ctx->webgpu_global_ctx->capabilities.limits.maxBufferSize;
2456    // If we're on a 32-bit system, clamp to UINTPTR_MAX
2457#if UINTPTR_MAX < UINT64_MAX
2458    uint64_t max_ptr_size = static_cast<uint64_t>(UINTPTR_MAX);
2459    if (max_buffer_size > max_ptr_size) {
2460        max_buffer_size = max_ptr_size;
2461    }
2462#endif
2463    *free  = static_cast<size_t>(max_buffer_size);
2464    *total = static_cast<size_t>(max_buffer_size);
2465}
2466
2467static enum ggml_backend_dev_type ggml_backend_webgpu_device_get_type(ggml_backend_dev_t dev) {
2468    GGML_UNUSED(dev);
2469    return GGML_BACKEND_DEVICE_TYPE_GPU;
2470}
2471
2472static void ggml_backend_webgpu_device_get_props(ggml_backend_dev_t dev, struct ggml_backend_dev_props * props) {
2473    props->name        = ggml_backend_webgpu_device_get_name(dev);
2474    props->description = ggml_backend_webgpu_device_get_description(dev);
2475    props->type        = ggml_backend_webgpu_device_get_type(dev);
2476    ggml_backend_webgpu_device_get_memory(dev, &props->memory_free, &props->memory_total);
2477    props->caps = {
2478        /* .async                 = */ false,
2479        /* .host_buffer           = */ false,
2480        /* .buffer_from_host_ptr  = */ false,
2481        /* .events                = */ false,
2482    };
2483}
2484
2485static ggml_guid_t ggml_backend_webgpu_guid(void) {
2486    static const char * guid_str = "__ggml_webgpu :)";
2487    return reinterpret_cast<ggml_guid_t>((void *) guid_str);
2488}
2489
2490// Workgroup size is a common constant
2491static std::vector<wgpu::ConstantEntry> ggml_webgpu_wg_size_entry(uint32_t wg_size) {
2492    std::vector<wgpu::ConstantEntry> constants(1);
2493    constants[0].key   = "wg_size";
2494    constants[0].value = wg_size;
2495    return constants;
2496}
2497
2498static void ggml_webgpu_init_memset_pipeline(webgpu_global_context & ctx) {
2499    // we use the maximum workgroup size for the memset pipeline
2500    size_t max_threads = WEBGPU_MAX_WG_SIZE * ctx->capabilities.limits.maxComputeWorkgroupsPerDimension;
2501    // Size the bytes_per_thread so that the largest buffer size can be handled
2502    ctx->capabilities.memset_bytes_per_thread =
2503        CEIL_DIV(ctx->capabilities.limits.maxStorageBufferBindingSize, max_threads);
2504    std::vector<wgpu::ConstantEntry> constants(2);
2505    constants[0].key         = "wg_size";
2506    constants[0].value       = WEBGPU_MAX_WG_SIZE;
2507    constants[1].key         = "bytes_per_thread";
2508    constants[1].value       = ctx->capabilities.memset_bytes_per_thread;
2509    ctx->memset_pipelines[0] = ggml_webgpu_create_pipeline(ctx->device, wgsl_memset, "memset", constants);
2510}
2511
2512static void ggml_webgpu_init_mul_mat_pipeline(webgpu_context & webgpu_ctx) {
2513    // Q4/Q5/Q8 classic quantizations
2514    webgpu_ctx->mul_mat_pipelines[GGML_TYPE_Q4_0][GGML_TYPE_F32][0] =
2515        ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_mul_mat_q4_0_f32, "mul_mat_q4_0_f32");
2516    webgpu_ctx->mul_mat_pipelines[GGML_TYPE_Q4_1][GGML_TYPE_F32][0] =
2517        ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_mul_mat_q4_1_f32, "mul_mat_q4_1_f32");
2518    webgpu_ctx->mul_mat_pipelines[GGML_TYPE_Q5_0][GGML_TYPE_F32][0] =
2519        ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_mul_mat_q5_0_f32, "mul_mat_q5_0_f32");
2520    webgpu_ctx->mul_mat_pipelines[GGML_TYPE_Q5_1][GGML_TYPE_F32][0] =
2521        ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_mul_mat_q5_1_f32, "mul_mat_q5_1_f32");
2522    webgpu_ctx->mul_mat_pipelines[GGML_TYPE_Q8_0][GGML_TYPE_F32][0] =
2523        ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_mul_mat_q8_0_f32, "mul_mat_q8_0_f32");
2524
2525    // K-quantizations
2526    webgpu_ctx->mul_mat_pipelines[GGML_TYPE_Q2_K][GGML_TYPE_F32][0] =
2527        ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_mul_mat_q2_k_f32, "mul_mat_q2_k_f32");
2528    webgpu_ctx->mul_mat_pipelines[GGML_TYPE_Q3_K][GGML_TYPE_F32][0] =
2529        ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_mul_mat_q3_k_f32, "mul_mat_q3_k_f32");
2530    webgpu_ctx->mul_mat_pipelines[GGML_TYPE_Q4_K][GGML_TYPE_F32][0] =
2531        ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_mul_mat_q4_k_f32, "mul_mat_q4_k_f32");
2532    webgpu_ctx->mul_mat_pipelines[GGML_TYPE_Q5_K][GGML_TYPE_F32][0] =
2533        ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_mul_mat_q5_k_f32, "mul_mat_q5_k_f32");
2534    webgpu_ctx->mul_mat_pipelines[GGML_TYPE_Q6_K][GGML_TYPE_F32][0] =
2535        ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_mul_mat_q6_k_f32, "mul_mat_q6_k_f32");
2536
2537    // IQ quantizations (2-, 3-, 4-bit variants)
2538    webgpu_ctx->mul_mat_pipelines[GGML_TYPE_IQ2_XXS][GGML_TYPE_F32][0] =
2539        ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_mul_mat_iq2_xxs_f32, "mul_mat_iq2_xxs_f32");
2540    webgpu_ctx->mul_mat_pipelines[GGML_TYPE_IQ2_XS][GGML_TYPE_F32][0] =
2541        ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_mul_mat_iq2_xs_f32, "mul_mat_iq2_xs_f32");
2542    webgpu_ctx->mul_mat_pipelines[GGML_TYPE_IQ2_S][GGML_TYPE_F32][0] =
2543        ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_mul_mat_iq2_s_f32, "mul_mat_iq2_s_f32");
2544
2545    webgpu_ctx->mul_mat_pipelines[GGML_TYPE_IQ3_XXS][GGML_TYPE_F32][0] =
2546        ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_mul_mat_iq3_xxs_f32, "mul_mat_iq3_xxs_f32");
2547    webgpu_ctx->mul_mat_pipelines[GGML_TYPE_IQ3_S][GGML_TYPE_F32][0] =
2548        ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_mul_mat_iq3_s_f32, "mul_mat_iq3_s_f32");
2549
2550    // 1-bit and 4-bit IQ variants
2551    webgpu_ctx->mul_mat_pipelines[GGML_TYPE_IQ1_S][GGML_TYPE_F32][0] =
2552        ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_mul_mat_iq1_s_f32, "mul_mat_iq1_s_f32");
2553    webgpu_ctx->mul_mat_pipelines[GGML_TYPE_IQ1_M][GGML_TYPE_F32][0] =
2554        ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_mul_mat_iq1_m_f32, "mul_mat_iq1_m_f32");
2555    webgpu_ctx->mul_mat_pipelines[GGML_TYPE_IQ4_NL][GGML_TYPE_F32][0] =
2556        ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_mul_mat_iq4_nl_f32, "mul_mat_iq4_nl_f32");
2557    webgpu_ctx->mul_mat_pipelines[GGML_TYPE_IQ4_XS][GGML_TYPE_F32][0] =
2558        ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_mul_mat_iq4_xs_f32, "mul_mat_iq4_xs_f32");
2559
2560    std::string proc_mul_mat_f32_f32;
2561    std::string proc_mul_mat_f32_f32_vec;
2562    std::string proc_mul_mat_f16_f32;
2563    std::string proc_mul_mat_f16_f32_vec;
2564    std::string proc_mul_mat_f16_f16;
2565    std::string proc_mul_mat_f16_f16_vec;
2566    std::string proc_mul_mat_q4_0_f32;
2567    std::string proc_mul_mat_q4_0_f32_vec;
2568
2569    std::vector<wgpu::ConstantEntry> mul_mat_constants;
2570#ifndef __EMSCRIPTEN__
2571    if (webgpu_ctx->global_ctx->capabilities.supports_subgroup_matrix) {
2572        std::map<std::string, std::string> sg_matrix_repls;
2573        sg_matrix_repls["WEBGPU_MAX_SUBGROUP_SIZE"] =
2574            std::to_string(webgpu_ctx->global_ctx->capabilities.max_subgroup_size);
2575        sg_matrix_repls["WEBGPU_TILE_K"]            = std::to_string(WEBGPU_MUL_MAT_TILE_K);
2576        sg_matrix_repls["WEBGPU_SUBGROUP_M"]        = std::to_string(WEBGPU_MUL_MAT_SUBGROUP_M);
2577        sg_matrix_repls["WEBGPU_SUBGROUP_N"]        = std::to_string(WEBGPU_MUL_MAT_SUBGROUP_N);
2578        sg_matrix_repls["WEBGPU_SUBGROUP_MATRIX_M"] = std::to_string(WEBGPU_MUL_MAT_SUBGROUP_MATRIX_M);
2579        sg_matrix_repls["WEBGPU_SUBGROUP_MATRIX_N"] = std::to_string(WEBGPU_MUL_MAT_SUBGROUP_MATRIX_N);
2580        sg_matrix_repls["WEBGPU_SG_MAT_M_SIZE"]     = std::to_string(webgpu_ctx->global_ctx->capabilities.sg_mat_m);
2581        sg_matrix_repls["WEBGPU_SG_MAT_N_SIZE"]     = std::to_string(webgpu_ctx->global_ctx->capabilities.sg_mat_n);
2582        sg_matrix_repls["WEBGPU_SG_MAT_K_SIZE"]     = std::to_string(webgpu_ctx->global_ctx->capabilities.sg_mat_k);
2583        proc_mul_mat_f32_f32 = ggml_webgpu_process_shader_repls(wgsl_mul_mat_subgroup_matrix_f32_f32, sg_matrix_repls);
2584        proc_mul_mat_f32_f32_vec =
2585            ggml_webgpu_process_shader_repls(wgsl_mul_mat_subgroup_matrix_f32_f32_vec, sg_matrix_repls);
2586        proc_mul_mat_f16_f32 = ggml_webgpu_process_shader_repls(wgsl_mul_mat_subgroup_matrix_f16_f32, sg_matrix_repls);
2587        proc_mul_mat_f16_f32_vec =
2588            ggml_webgpu_process_shader_repls(wgsl_mul_mat_subgroup_matrix_f16_f32_vec, sg_matrix_repls);
2589        proc_mul_mat_f16_f16 = ggml_webgpu_process_shader_repls(wgsl_mul_mat_subgroup_matrix_f16_f16, sg_matrix_repls);
2590        proc_mul_mat_f16_f16_vec =
2591            ggml_webgpu_process_shader_repls(wgsl_mul_mat_subgroup_matrix_f16_f16_vec, sg_matrix_repls);
2592        proc_mul_mat_q4_0_f32 =
2593            ggml_webgpu_process_shader_repls(wgsl_mul_mat_subgroup_matrix_q4_0_f32, sg_matrix_repls);
2594        proc_mul_mat_q4_0_f32_vec =
2595            ggml_webgpu_process_shader_repls(wgsl_mul_mat_subgroup_matrix_q4_0_f32_vec, sg_matrix_repls);
2596    } else {
2597#endif
2598        mul_mat_constants.push_back({ .key = "TILE_K", .value = WEBGPU_MUL_MAT_TILE_K });
2599        mul_mat_constants.push_back({ .key = "WORKGROUP_SIZE_M", .value = WEBGPU_MUL_MAT_WG_SIZE_M });
2600        mul_mat_constants.push_back({ .key = "WORKGROUP_SIZE_N", .value = WEBGPU_MUL_MAT_WG_SIZE_N });
2601
2602        std::map<std::string, std::string> reg_repls;
2603        reg_repls["WEBGPU_TILE_M"] = std::to_string(WEBGPU_MUL_MAT_TILE_M);
2604        reg_repls["WEBGPU_TILE_N"] = std::to_string(WEBGPU_MUL_MAT_TILE_N);
2605
2606        proc_mul_mat_f32_f32      = ggml_webgpu_process_shader_repls(wgsl_mul_mat_reg_tile_f32_f32, reg_repls);
2607        proc_mul_mat_f32_f32_vec  = ggml_webgpu_process_shader_repls(wgsl_mul_mat_reg_tile_f32_f32_vec, reg_repls);
2608        proc_mul_mat_f16_f32      = ggml_webgpu_process_shader_repls(wgsl_mul_mat_reg_tile_f16_f32, reg_repls);
2609        proc_mul_mat_f16_f32_vec  = ggml_webgpu_process_shader_repls(wgsl_mul_mat_reg_tile_f16_f32_vec, reg_repls);
2610        proc_mul_mat_f16_f16      = ggml_webgpu_process_shader_repls(wgsl_mul_mat_reg_tile_f16_f16, reg_repls);
2611        proc_mul_mat_f16_f16_vec  = ggml_webgpu_process_shader_repls(wgsl_mul_mat_reg_tile_f16_f16_vec, reg_repls);
2612        proc_mul_mat_q4_0_f32     = ggml_webgpu_process_shader_repls(wgsl_mul_mat_reg_tile_q4_0_f32, reg_repls);
2613        proc_mul_mat_q4_0_f32_vec = ggml_webgpu_process_shader_repls(wgsl_mul_mat_reg_tile_q4_0_f32_vec, reg_repls);
2614#ifndef __EMSCRIPTEN__
2615    }
2616#endif
2617
2618    webgpu_ctx->mul_mat_pipelines[GGML_TYPE_F32][GGML_TYPE_F32][0] = ggml_webgpu_create_pipeline(
2619        webgpu_ctx->global_ctx->device, proc_mul_mat_f32_f32.c_str(), "mul_mat_f32_f32", mul_mat_constants);
2620    webgpu_ctx->mul_mat_pipelines[GGML_TYPE_F32][GGML_TYPE_F32][1] = ggml_webgpu_create_pipeline(
2621        webgpu_ctx->global_ctx->device, proc_mul_mat_f32_f32_vec.c_str(), "mul_mat_f32_f32_vec", mul_mat_constants);
2622    webgpu_ctx->mul_mat_pipelines[GGML_TYPE_F16][GGML_TYPE_F32][0] = ggml_webgpu_create_pipeline(
2623        webgpu_ctx->global_ctx->device, proc_mul_mat_f16_f32.c_str(), "mul_mat_f16_f32", mul_mat_constants);
2624    webgpu_ctx->mul_mat_pipelines[GGML_TYPE_F16][GGML_TYPE_F32][1] = ggml_webgpu_create_pipeline(
2625        webgpu_ctx->global_ctx->device, proc_mul_mat_f16_f32_vec.c_str(), "mul_mat_f16_f32_vec", mul_mat_constants);
2626    webgpu_ctx->mul_mat_pipelines[GGML_TYPE_F16][GGML_TYPE_F16][0] = ggml_webgpu_create_pipeline(
2627        webgpu_ctx->global_ctx->device, proc_mul_mat_f16_f16.c_str(), "mul_mat_f16_f16", mul_mat_constants);
2628    webgpu_ctx->mul_mat_pipelines[GGML_TYPE_F16][GGML_TYPE_F16][1] = ggml_webgpu_create_pipeline(
2629        webgpu_ctx->global_ctx->device, proc_mul_mat_f16_f16_vec.c_str(), "mul_mat_f16_f16_vec", mul_mat_constants);
2630    webgpu_ctx->mul_mat_pipelines[GGML_TYPE_Q4_0][GGML_TYPE_F32][0] = ggml_webgpu_create_pipeline(
2631        webgpu_ctx->global_ctx->device, proc_mul_mat_q4_0_f32.c_str(), "mul_mat_q4_0_f32", mul_mat_constants);
2632    webgpu_ctx->mul_mat_pipelines[GGML_TYPE_Q4_0][GGML_TYPE_F32][1] = ggml_webgpu_create_pipeline(
2633        webgpu_ctx->global_ctx->device, proc_mul_mat_q4_0_f32_vec.c_str(), "mul_mat_q4_0_f32_vec", mul_mat_constants);
2634
2635    std::vector<wgpu::ConstantEntry> mul_mat_vec_constants(3);
2636    mul_mat_vec_constants[0].key   = "WORKGROUP_SIZE";
2637    mul_mat_vec_constants[0].value = WEBGPU_MUL_MAT_VEC_WG_SIZE;
2638    mul_mat_vec_constants[1].key   = "TILE_K";
2639    mul_mat_vec_constants[1].value = WEBGPU_MUL_MAT_VEC_TILE_K;
2640    mul_mat_vec_constants[2].key   = "OUTPUTS_PER_WG";
2641    mul_mat_vec_constants[2].value = WEBGPU_MUL_MAT_VEC_OUTPUTS_PER_WG;
2642
2643    webgpu_ctx->mul_mat_vec_pipelines[GGML_TYPE_F32][GGML_TYPE_F32][0] = ggml_webgpu_create_pipeline(
2644        webgpu_ctx->global_ctx->device, wgsl_mul_mat_vec_f32_f32, "mul_mat_vec_f32_f32", mul_mat_vec_constants);
2645    webgpu_ctx->mul_mat_vec_pipelines[GGML_TYPE_F32][GGML_TYPE_F32][1] = ggml_webgpu_create_pipeline(
2646        webgpu_ctx->global_ctx->device, wgsl_mul_mat_vec_f32_f32_vec, "mul_mat_vec_f32_f32_vec", mul_mat_vec_constants);
2647    webgpu_ctx->mul_mat_vec_pipelines[GGML_TYPE_F16][GGML_TYPE_F32][0] = ggml_webgpu_create_pipeline(
2648        webgpu_ctx->global_ctx->device, wgsl_mul_mat_vec_f16_f32, "mul_mat_vec_f16_f32", mul_mat_vec_constants);
2649    webgpu_ctx->mul_mat_vec_pipelines[GGML_TYPE_F16][GGML_TYPE_F32][1] = ggml_webgpu_create_pipeline(
2650        webgpu_ctx->global_ctx->device, wgsl_mul_mat_vec_f16_f32_vec, "mul_mat_vec_f16_f32_vec", mul_mat_vec_constants);
2651    webgpu_ctx->mul_mat_vec_pipelines[GGML_TYPE_F16][GGML_TYPE_F16][0] = ggml_webgpu_create_pipeline(
2652        webgpu_ctx->global_ctx->device, wgsl_mul_mat_vec_f16_f16, "mul_mat_vec_f16_f16", mul_mat_vec_constants);
2653    webgpu_ctx->mul_mat_vec_pipelines[GGML_TYPE_F16][GGML_TYPE_F16][1] = ggml_webgpu_create_pipeline(
2654        webgpu_ctx->global_ctx->device, wgsl_mul_mat_vec_f16_f16_vec, "mul_mat_vec_f16_f16_vec", mul_mat_vec_constants);
2655    webgpu_ctx->mul_mat_vec_pipelines[GGML_TYPE_Q4_0][GGML_TYPE_F32][0] = ggml_webgpu_create_pipeline(
2656        webgpu_ctx->global_ctx->device, wgsl_mul_mat_vec_q4_0_f32, "mul_mat_vec_q4_0_f32", mul_mat_vec_constants);
2657}
2658
2659static void ggml_webgpu_init_get_rows_pipeline(webgpu_context & webgpu_ctx) {
2660    std::vector<wgpu::ConstantEntry> constants = ggml_webgpu_wg_size_entry(WEBGPU_MAX_WG_SIZE);
2661
2662    webgpu_ctx->get_rows_pipelines[GGML_TYPE_F32][0] =
2663        ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_get_rows_f32, "get_rows_f32", constants);
2664    webgpu_ctx->get_rows_pipelines[GGML_TYPE_F32][1] = ggml_webgpu_create_pipeline(
2665        webgpu_ctx->global_ctx->device, wgsl_get_rows_f32_vec, "get_rows_f32_vec", constants);
2666
2667    webgpu_ctx->get_rows_pipelines[GGML_TYPE_F16][0] =
2668        ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_get_rows_f16, "get_rows_f16", constants);
2669    webgpu_ctx->get_rows_pipelines[GGML_TYPE_I32][0] =
2670        ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_get_rows_i32, "get_rows_i32", constants);
2671    webgpu_ctx->get_rows_pipelines[GGML_TYPE_Q4_0][0] =
2672        ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_get_rows_q4_0, "get_rows_q4_0", constants);
2673    webgpu_ctx->get_rows_pipelines[GGML_TYPE_Q4_1][0] =
2674        ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_get_rows_q4_1, "get_rows_q4_1", constants);
2675    webgpu_ctx->get_rows_pipelines[GGML_TYPE_Q5_0][0] =
2676        ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_get_rows_q5_0, "get_rows_q5_0", constants);
2677    webgpu_ctx->get_rows_pipelines[GGML_TYPE_Q5_1][0] =
2678        ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_get_rows_q5_1, "get_rows_q5_1", constants);
2679    webgpu_ctx->get_rows_pipelines[GGML_TYPE_Q8_0][0] =
2680        ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_get_rows_q8_0, "get_rows_q8_0", constants);
2681
2682    webgpu_ctx->get_rows_pipelines[GGML_TYPE_Q2_K][0] =
2683        ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_get_rows_q2_k, "get_rows_q2_k", constants);
2684    webgpu_ctx->get_rows_pipelines[GGML_TYPE_Q3_K][0] =
2685        ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_get_rows_q3_k, "get_rows_q3_k", constants);
2686    webgpu_ctx->get_rows_pipelines[GGML_TYPE_Q4_K][0] =
2687        ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_get_rows_q4_k, "get_rows_q4_k", constants);
2688    webgpu_ctx->get_rows_pipelines[GGML_TYPE_Q5_K][0] =
2689        ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_get_rows_q5_k, "get_rows_q5_k", constants);
2690    webgpu_ctx->get_rows_pipelines[GGML_TYPE_Q6_K][0] =
2691        ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_get_rows_q6_k, "get_rows_q6_k", constants);
2692
2693    webgpu_ctx->get_rows_pipelines[GGML_TYPE_IQ2_XXS][0] = ggml_webgpu_create_pipeline(
2694        webgpu_ctx->global_ctx->device, wgsl_get_rows_iq2_xxs, "get_rows_iq2_xxs", constants);
2695    webgpu_ctx->get_rows_pipelines[GGML_TYPE_IQ2_XS][0] =
2696        ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_get_rows_iq2_xs, "get_rows_iq2_xs", constants);
2697    webgpu_ctx->get_rows_pipelines[GGML_TYPE_IQ2_S][0] =
2698        ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_get_rows_iq2_s, "get_rows_iq2_s", constants);
2699    webgpu_ctx->get_rows_pipelines[GGML_TYPE_IQ3_XXS][0] = ggml_webgpu_create_pipeline(
2700        webgpu_ctx->global_ctx->device, wgsl_get_rows_iq3_xxs, "get_rows_iq3_xxs", constants);
2701    webgpu_ctx->get_rows_pipelines[GGML_TYPE_IQ3_S][0] =
2702        ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_get_rows_iq3_s, "get_rows_iq3_s", constants);
2703    webgpu_ctx->get_rows_pipelines[GGML_TYPE_IQ1_S][0] =
2704        ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_get_rows_iq1_s, "get_rows_iq1_s", constants);
2705    webgpu_ctx->get_rows_pipelines[GGML_TYPE_IQ1_M][0] =
2706        ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_get_rows_iq1_m, "get_rows_iq1_m", constants);
2707    webgpu_ctx->get_rows_pipelines[GGML_TYPE_IQ4_NL][0] =
2708        ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_get_rows_iq4_nl, "get_rows_iq4_nl", constants);
2709    webgpu_ctx->get_rows_pipelines[GGML_TYPE_IQ4_XS][0] =
2710        ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_get_rows_iq4_xs, "get_rows_iq4_xs", constants);
2711}
2712
2713static void ggml_webgpu_init_cpy_pipeline(webgpu_context & webgpu_ctx) {
2714    std::vector<wgpu::ConstantEntry> constants = ggml_webgpu_wg_size_entry(WEBGPU_MAX_WG_SIZE);
2715
2716    webgpu_ctx->cpy_pipelines[GGML_TYPE_F32][GGML_TYPE_F32] =
2717        ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_cpy_f32_f32, "cpy_f32_f32", constants);
2718    webgpu_ctx->cpy_pipelines[GGML_TYPE_F32][GGML_TYPE_I32] =
2719        ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_cpy_f32_i32, "cpy_f32_i32", constants);
2720    webgpu_ctx->cpy_pipelines[GGML_TYPE_F32][GGML_TYPE_F16] =
2721        ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_cpy_f32_f16, "cpy_f32_f16", constants);
2722    webgpu_ctx->cpy_pipelines[GGML_TYPE_F16][GGML_TYPE_F32] =
2723        ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_cpy_f16_f32, "cpy_f16_f32", constants);
2724    webgpu_ctx->cpy_pipelines[GGML_TYPE_F16][GGML_TYPE_F16] =
2725        ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_cpy_f16_f16, "cpy_f16_f16", constants);
2726}
2727
2728static void ggml_webgpu_init_rms_norm_pipeline(webgpu_context & webgpu_ctx) {
2729    std::vector<wgpu::ConstantEntry> constants = ggml_webgpu_wg_size_entry(WEBGPU_ROW_SPLIT_WG_SIZE);
2730
2731    webgpu_ctx->rms_norm_pipelines[0] =
2732        ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_rms_norm, "rms_norm", constants);
2733    webgpu_ctx->rms_norm_pipelines[1] = ggml_webgpu_create_pipeline(
2734        webgpu_ctx->global_ctx->device, wgsl_rms_norm_inplace, "rms_norm_inplace", constants);
2735}
2736
2737static void ggml_webgpu_init_rope_pipeline(webgpu_context & webgpu_ctx) {
2738    std::vector<wgpu::ConstantEntry> constants = ggml_webgpu_wg_size_entry(WEBGPU_MAX_WG_SIZE);
2739
2740    webgpu_ctx->rope_pipelines[GGML_TYPE_F32][0][0] =
2741        ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_rope_f32, "rope_f32", constants);
2742    webgpu_ctx->rope_pipelines[GGML_TYPE_F32][0][1] = ggml_webgpu_create_pipeline(
2743        webgpu_ctx->global_ctx->device, wgsl_rope_f32_inplace, "rope_f32_inplace", constants);
2744    webgpu_ctx->rope_pipelines[GGML_TYPE_F32][1][0] =
2745        ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_rope_f32_ff, "rope_f32_ff", constants);
2746    webgpu_ctx->rope_pipelines[GGML_TYPE_F32][1][1] = ggml_webgpu_create_pipeline(
2747        webgpu_ctx->global_ctx->device, wgsl_rope_f32_ff_inplace, "rope_f32_ff_inplace", constants);
2748
2749    webgpu_ctx->rope_pipelines[GGML_TYPE_F16][0][0] =
2750        ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_rope_f16, "rope_f16", constants);
2751    webgpu_ctx->rope_pipelines[GGML_TYPE_F16][0][1] = ggml_webgpu_create_pipeline(
2752        webgpu_ctx->global_ctx->device, wgsl_rope_f16_inplace, "rope_f16_inplace", constants);
2753    webgpu_ctx->rope_pipelines[GGML_TYPE_F16][1][0] =
2754        ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_rope_f16_ff, "rope_f16_ff", constants);
2755    webgpu_ctx->rope_pipelines[GGML_TYPE_F16][1][1] = ggml_webgpu_create_pipeline(
2756        webgpu_ctx->global_ctx->device, wgsl_rope_f16_ff_inplace, "rope_f16_ff_inplace", constants);
2757}
2758
2759static void ggml_webgpu_init_glu_pipeline(webgpu_context & webgpu_ctx) {
2760    std::vector<wgpu::ConstantEntry> constants = ggml_webgpu_wg_size_entry(WEBGPU_MAX_WG_SIZE);
2761
2762    // REGLU
2763    webgpu_ctx->glu_pipelines[GGML_GLU_OP_REGLU][GGML_TYPE_F32][0] =
2764        ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_reglu_f32, "reglu_f32", constants);
2765    webgpu_ctx->glu_pipelines[GGML_GLU_OP_REGLU][GGML_TYPE_F16][0] =
2766        ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_reglu_f16, "reglu_f16", constants);
2767    webgpu_ctx->glu_pipelines[GGML_GLU_OP_REGLU][GGML_TYPE_F32][1] =
2768        ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_reglu_f32_split, "reglu_f32_split", constants);
2769    webgpu_ctx->glu_pipelines[GGML_GLU_OP_REGLU][GGML_TYPE_F16][1] =
2770        ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_reglu_f16_split, "reglu_f16_split", constants);
2771
2772    // GEGLU
2773    webgpu_ctx->glu_pipelines[GGML_GLU_OP_GEGLU][GGML_TYPE_F32][0] =
2774        ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_geglu_f32, "geglu_f32", constants);
2775    webgpu_ctx->glu_pipelines[GGML_GLU_OP_GEGLU][GGML_TYPE_F16][0] =
2776        ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_geglu_f16, "geglu_f16", constants);
2777    webgpu_ctx->glu_pipelines[GGML_GLU_OP_GEGLU][GGML_TYPE_F32][1] =
2778        ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_geglu_f32_split, "geglu_f32_split", constants);
2779    webgpu_ctx->glu_pipelines[GGML_GLU_OP_GEGLU][GGML_TYPE_F16][1] =
2780        ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_geglu_f16_split, "geglu_f16_split", constants);
2781
2782    // SWIGLU
2783    webgpu_ctx->glu_pipelines[GGML_GLU_OP_SWIGLU][GGML_TYPE_F32][0] =
2784        ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_swiglu_f32, "swiglu_f32", constants);
2785    webgpu_ctx->glu_pipelines[GGML_GLU_OP_SWIGLU][GGML_TYPE_F16][0] =
2786        ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_swiglu_f16, "swiglu_f16", constants);
2787    webgpu_ctx->glu_pipelines[GGML_GLU_OP_SWIGLU][GGML_TYPE_F32][1] = ggml_webgpu_create_pipeline(
2788        webgpu_ctx->global_ctx->device, wgsl_swiglu_f32_split, "swiglu_f32_split", constants);
2789    webgpu_ctx->glu_pipelines[GGML_GLU_OP_SWIGLU][GGML_TYPE_F16][1] = ggml_webgpu_create_pipeline(
2790        webgpu_ctx->global_ctx->device, wgsl_swiglu_f16_split, "swiglu_f16_split", constants);
2791
2792    // SWIGLU_OAI
2793    webgpu_ctx->glu_pipelines[GGML_GLU_OP_SWIGLU_OAI][GGML_TYPE_F32][0] =
2794        ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_swiglu_oai_f32, "swiglu_oai_f32", constants);
2795    webgpu_ctx->glu_pipelines[GGML_GLU_OP_SWIGLU_OAI][GGML_TYPE_F32][1] = ggml_webgpu_create_pipeline(
2796        webgpu_ctx->global_ctx->device, wgsl_swiglu_oai_f32_split, "swiglu_oai_f32_split", constants);
2797
2798    // GEGLU_ERF
2799    webgpu_ctx->glu_pipelines[GGML_GLU_OP_GEGLU_ERF][GGML_TYPE_F32][0] =
2800        ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_geglu_erf_f32, "geglu_erf_f32", constants);
2801    webgpu_ctx->glu_pipelines[GGML_GLU_OP_GEGLU_ERF][GGML_TYPE_F16][0] =
2802        ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_geglu_erf_f16, "geglu_erf_f16", constants);
2803    webgpu_ctx->glu_pipelines[GGML_GLU_OP_GEGLU_ERF][GGML_TYPE_F32][1] = ggml_webgpu_create_pipeline(
2804        webgpu_ctx->global_ctx->device, wgsl_geglu_erf_f32_split, "geglu_erf_f32_split", constants);
2805    webgpu_ctx->glu_pipelines[GGML_GLU_OP_GEGLU_ERF][GGML_TYPE_F16][1] = ggml_webgpu_create_pipeline(
2806        webgpu_ctx->global_ctx->device, wgsl_geglu_erf_f16_split, "geglu_erf_f16_split", constants);
2807
2808    // GEGLU_QUICK
2809    webgpu_ctx->glu_pipelines[GGML_GLU_OP_GEGLU_QUICK][GGML_TYPE_F32][0] =
2810        ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_geglu_quick_f32, "geglu_quick_f32", constants);
2811    webgpu_ctx->glu_pipelines[GGML_GLU_OP_GEGLU_QUICK][GGML_TYPE_F16][0] =
2812        ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_geglu_quick_f16, "geglu_quick_f16", constants);
2813    webgpu_ctx->glu_pipelines[GGML_GLU_OP_GEGLU_QUICK][GGML_TYPE_F32][1] = ggml_webgpu_create_pipeline(
2814        webgpu_ctx->global_ctx->device, wgsl_geglu_quick_f32_split, "geglu_quick_f32_split", constants);
2815    webgpu_ctx->glu_pipelines[GGML_GLU_OP_GEGLU_QUICK][GGML_TYPE_F16][1] = ggml_webgpu_create_pipeline(
2816        webgpu_ctx->global_ctx->device, wgsl_geglu_quick_f16_split, "geglu_quick_f16_split", constants);
2817}
2818
2819static void ggml_webgpu_init_scale_pipeline(webgpu_context & webgpu_ctx) {
2820    std::vector<wgpu::ConstantEntry> constants = ggml_webgpu_wg_size_entry(WEBGPU_MAX_WG_SIZE);
2821
2822    webgpu_ctx->scale_pipelines[0] =
2823        ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_scale_f32, "scale_f32", constants);
2824    webgpu_ctx->scale_pipelines[1] = ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_scale_f32_inplace,
2825                                                                 "scale_f32_inplace", constants);
2826}
2827
2828static void ggml_webgpu_init_soft_max_pipeline(webgpu_context & webgpu_ctx) {
2829    std::vector<wgpu::ConstantEntry> constants = ggml_webgpu_wg_size_entry(WEBGPU_ROW_SPLIT_WG_SIZE);
2830
2831    // f32 (no mask)
2832    webgpu_ctx->soft_max_pipelines[2][0][0] =
2833        ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_soft_max_f32, "soft_max_f32", constants);
2834    webgpu_ctx->soft_max_pipelines[2][0][1] = ggml_webgpu_create_pipeline(
2835        webgpu_ctx->global_ctx->device, wgsl_soft_max_f32_inplace, "soft_max_f32_inplace", constants);
2836    webgpu_ctx->soft_max_pipelines[2][1][0] = ggml_webgpu_create_pipeline(
2837        webgpu_ctx->global_ctx->device, wgsl_soft_max_f32_sink, "soft_max_f32_sink", constants);
2838    webgpu_ctx->soft_max_pipelines[2][1][1] = ggml_webgpu_create_pipeline(
2839        webgpu_ctx->global_ctx->device, wgsl_soft_max_f32_sink_inplace, "soft_max_f32_sink_inplace", constants);
2840
2841    // f32 mask (mask_type = 0)
2842    webgpu_ctx->soft_max_pipelines[0][0][0] = ggml_webgpu_create_pipeline(
2843        webgpu_ctx->global_ctx->device, wgsl_soft_max_f32_mask_f32, "soft_max_f32_mask_f32", constants);
2844    webgpu_ctx->soft_max_pipelines[0][0][1] = ggml_webgpu_create_pipeline(
2845        webgpu_ctx->global_ctx->device, wgsl_soft_max_f32_mask_f32_inplace, "soft_max_f32_mask_f32_inplace", constants);
2846    webgpu_ctx->soft_max_pipelines[0][1][0] = ggml_webgpu_create_pipeline(
2847        webgpu_ctx->global_ctx->device, wgsl_soft_max_f32_mask_f32_sink, "soft_max_f32_mask_f32_sink", constants);
2848    webgpu_ctx->soft_max_pipelines[0][1][1] =
2849        ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_soft_max_f32_mask_f32_sink_inplace,
2850                                    "soft_max_f32_mask_f32_sink_inplace", constants);
2851
2852    // f16 mask (mask_type = 1)
2853    webgpu_ctx->soft_max_pipelines[1][0][0] = ggml_webgpu_create_pipeline(
2854        webgpu_ctx->global_ctx->device, wgsl_soft_max_f32_mask_f16, "soft_max_f32_mask_f16", constants);
2855    webgpu_ctx->soft_max_pipelines[1][0][1] = ggml_webgpu_create_pipeline(
2856        webgpu_ctx->global_ctx->device, wgsl_soft_max_f32_mask_f16_inplace, "soft_max_f32_mask_f16_inplace", constants);
2857    webgpu_ctx->soft_max_pipelines[1][1][0] = ggml_webgpu_create_pipeline(
2858        webgpu_ctx->global_ctx->device, wgsl_soft_max_f32_mask_f16_sink, "soft_max_f32_mask_f16_sink", constants);
2859    webgpu_ctx->soft_max_pipelines[1][1][1] =
2860        ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_soft_max_f32_mask_f16_sink_inplace,
2861                                    "soft_max_f32_mask_f16_sink_inplace", constants);
2862}
2863
2864static bool create_webgpu_device(ggml_backend_webgpu_reg_context * ctx) {
2865    wgpu::RequestAdapterOptions options = {};
2866
2867#ifndef __EMSCRIPTEN__
2868    // TODO: track need for these toggles: https://issues.chromium.org/issues/42251215
2869    const char * const          adapterEnabledToggles[] = { "vulkan_enable_f16_on_nvidia", "use_vulkan_memory_model" };
2870    wgpu::DawnTogglesDescriptor adapterTogglesDesc;
2871    adapterTogglesDesc.enabledToggles     = adapterEnabledToggles;
2872    adapterTogglesDesc.enabledToggleCount = 2;
2873    options.nextInChain                   = &adapterTogglesDesc;
2874#endif
2875
2876    ctx->webgpu_global_ctx->instance.WaitAny(
2877        ctx->webgpu_global_ctx->instance.RequestAdapter(
2878            &options, wgpu::CallbackMode::AllowSpontaneous,
2879            [&ctx](wgpu::RequestAdapterStatus status, wgpu::Adapter adapter, const char * message) {
2880                if (status != wgpu::RequestAdapterStatus::Success) {
2881                    GGML_LOG_ERROR("ggml_webgpu: Failed to get an adapter: %s\n", message);
2882                    return;
2883                }
2884                ctx->webgpu_global_ctx->adapter = std::move(adapter);
2885            }),
2886        UINT64_MAX);
2887    GGML_ASSERT(ctx->webgpu_global_ctx->adapter != nullptr);
2888
2889    ctx->webgpu_global_ctx->adapter.GetLimits(&ctx->webgpu_global_ctx->capabilities.limits);
2890
2891    wgpu::AdapterInfo info{};
2892#ifndef __EMSCRIPTEN__
2893    wgpu::AdapterPropertiesSubgroupMatrixConfigs subgroup_matrix_configs{};
2894    if (ctx->webgpu_global_ctx->adapter.HasFeature(wgpu::FeatureName::ChromiumExperimentalSubgroupMatrix)) {
2895        info.nextInChain = &subgroup_matrix_configs;
2896    }
2897#endif
2898    ctx->webgpu_global_ctx->adapter.GetInfo(&info);
2899    wgpu::SupportedFeatures features;
2900    ctx->webgpu_global_ctx->adapter.GetFeatures(&features);
2901    // we require f16 support
2902    GGML_ASSERT(ctx->webgpu_global_ctx->adapter.HasFeature(wgpu::FeatureName::ShaderF16));
2903
2904#ifndef __EMSCRIPTEN__
2905    // Only support square f16 matrices of size 8 or 16 for now
2906    bool valid_subgroup_matrix_config = false;
2907    if (ctx->webgpu_global_ctx->adapter.HasFeature(wgpu::FeatureName::ChromiumExperimentalSubgroupMatrix)) {
2908        for (size_t i = 0; i < subgroup_matrix_configs.configCount; i++) {
2909            const wgpu::SubgroupMatrixConfig config = subgroup_matrix_configs.configs[i];
2910            if (config.M == config.N && config.N == config.K && (config.K == 8 || config.K == 16) &&
2911                config.componentType == wgpu::SubgroupMatrixComponentType::F16 &&
2912                config.resultComponentType == wgpu::SubgroupMatrixComponentType::F16) {
2913                ctx->webgpu_global_ctx->capabilities.sg_mat_m = config.M;
2914                ctx->webgpu_global_ctx->capabilities.sg_mat_n = config.N;
2915                ctx->webgpu_global_ctx->capabilities.sg_mat_k = config.K;
2916                valid_subgroup_matrix_config                  = true;
2917                break;
2918            }
2919        }
2920    }
2921    ctx->webgpu_global_ctx->capabilities.supports_subgroup_matrix = valid_subgroup_matrix_config;
2922#endif
2923
2924    // For subgroup matrix code to be the most efficient, we would like the subgroup size to be consistent and accurate.
2925    // Unfortunately, that is not possible, so we use the maximum subgroup size reported by the adapter.
2926    ctx->webgpu_global_ctx->capabilities.max_subgroup_size = info.subgroupMaxSize;
2927    // Initialize device
2928    std::vector<wgpu::FeatureName> required_features       = { wgpu::FeatureName::ShaderF16 };
2929
2930#ifndef __EMSCRIPTEN__
2931    required_features.push_back(wgpu::FeatureName::ImplicitDeviceSynchronization);
2932    if (ctx->webgpu_global_ctx->capabilities.supports_subgroup_matrix) {
2933        required_features.push_back(wgpu::FeatureName::Subgroups);
2934        required_features.push_back(wgpu::FeatureName::ChromiumExperimentalSubgroupMatrix);
2935    }
2936#endif
2937
2938#ifdef GGML_WEBGPU_GPU_PROFILE
2939    required_features.push_back(wgpu::FeatureName::TimestampQuery);
2940#endif
2941
2942    wgpu::DeviceDescriptor dev_desc;
2943    dev_desc.requiredLimits       = &ctx->webgpu_global_ctx->capabilities.limits;
2944    dev_desc.requiredFeatures     = required_features.data();
2945    dev_desc.requiredFeatureCount = required_features.size();
2946    dev_desc.SetDeviceLostCallback(
2947        wgpu::CallbackMode::AllowSpontaneous,
2948        [](const wgpu::Device & device, wgpu::DeviceLostReason reason, wgpu::StringView message) {
2949            if (reason == wgpu::DeviceLostReason::Destroyed) {
2950                return;
2951            }
2952            GGML_UNUSED(device);
2953            GGML_LOG_ERROR("ggml_webgpu: Device lost! Reason: %d, Message: %s\n", static_cast<int>(reason),
2954                           std::string(message).c_str());
2955        });
2956    dev_desc.SetUncapturedErrorCallback(
2957        [](const wgpu::Device & device, wgpu::ErrorType reason, wgpu::StringView message) {
2958            GGML_UNUSED(device);
2959            GGML_ABORT("ggml_webgpu: Device error! Reason: %d, Message: %s\n", static_cast<int>(reason),
2960                       std::string(message).c_str());
2961        });
2962
2963#ifndef __EMSCRIPTEN__
2964    // Enable Dawn-specific toggles to increase native performance
2965    // TODO: Maybe WebGPU needs a "fast" mode where you can request compilers skip adding checks like these,
2966    //       only for native performance?
2967    const char * const deviceEnabledToggles[]  = { "skip_validation", "disable_robustness", "disable_workgroup_init",
2968                                                   "disable_polyfills_on_integer_div_and_mod" };
2969    const char * const deviceDisabledToggles[] = { "timestamp_quantization" };
2970    wgpu::DawnTogglesDescriptor deviceTogglesDesc;
2971    deviceTogglesDesc.enabledToggles      = deviceEnabledToggles;
2972    deviceTogglesDesc.enabledToggleCount  = 4;
2973    deviceTogglesDesc.disabledToggles     = deviceDisabledToggles;
2974    deviceTogglesDesc.disabledToggleCount = 1;
2975
2976    dev_desc.nextInChain = &deviceTogglesDesc;
2977#endif
2978
2979    ctx->webgpu_global_ctx->instance.WaitAny(
2980        ctx->webgpu_global_ctx->adapter.RequestDevice(
2981            &dev_desc, wgpu::CallbackMode::AllowSpontaneous,
2982            [ctx](wgpu::RequestDeviceStatus status, wgpu::Device device, wgpu::StringView message) {
2983                if (status != wgpu::RequestDeviceStatus::Success) {
2984                    GGML_LOG_ERROR("ggml_webgpu: Failed to get a device: %s\n", std::string(message).c_str());
2985                    return;
2986                }
2987                ctx->webgpu_global_ctx->device = std::move(device);
2988            }),
2989        UINT64_MAX);
2990    GGML_ASSERT(ctx->webgpu_global_ctx->device != nullptr);
2991
2992    ggml_webgpu_init_memset_pipeline(ctx->webgpu_global_ctx);
2993    ctx->webgpu_global_ctx->memset_buf_pool.init(ctx->webgpu_global_ctx->device, 1, WEBGPU_PARAMS_BUF_SIZE_BYTES,
2994                                                 wgpu::BufferUsage::CopyDst | wgpu::BufferUsage::Uniform,
2995                                                 wgpu::BufferUsage::CopySrc | wgpu::BufferUsage::MapWrite);
2996    ctx->webgpu_global_ctx->queue = ctx->webgpu_global_ctx->device.GetQueue();
2997
2998#ifdef GGML_WEBGPU_GPU_PROFILE
2999    // Initialize buffer pool for timestamp queries, used for profiling
3000    ctx->webgpu_global_ctx->timestamp_query_buf_pool.init(
3001        ctx->webgpu_global_ctx->device, WEBGPU_NUM_TIMESTAMP_QUERY_BUFS, WEBGPU_TIMESTAMP_QUERY_BUF_SIZE_BYTES,
3002        wgpu::BufferUsage::QueryResolve | wgpu::BufferUsage::CopySrc,
3003        wgpu::BufferUsage::MapRead | wgpu::BufferUsage::CopyDst);
3004#endif
3005
3006    GGML_LOG_INFO(
3007        "ggml_webgpu: adapter_info: vendor_id: %u | vendor: %s | architecture: %s | device_id: %u | name: %s | "
3008        "device_desc: %s\n",
3009        info.vendorID, std::string(info.vendor).c_str(), std::string(info.architecture).c_str(), info.deviceID,
3010        std::string(info.device).c_str(), std::string(info.description).c_str());
3011    return true;
3012}
3013
3014static webgpu_context initialize_webgpu_context(ggml_backend_dev_t dev) {
3015    ggml_backend_webgpu_device_context * dev_ctx    = (ggml_backend_webgpu_device_context *) dev->context;
3016    webgpu_context                       webgpu_ctx = std::make_shared<webgpu_context_struct>();
3017    webgpu_ctx->global_ctx                          = dev_ctx->webgpu_global_ctx;
3018    webgpu_ctx->param_buf_pool.init(webgpu_ctx->global_ctx->device, WEBGPU_NUM_PARAM_BUFS, WEBGPU_PARAMS_BUF_SIZE_BYTES,
3019                                    wgpu::BufferUsage::CopyDst | wgpu::BufferUsage::Uniform,
3020                                    wgpu::BufferUsage::CopySrc | wgpu::BufferUsage::MapWrite);
3021    webgpu_ctx->set_rows_error_buf_pool.init(webgpu_ctx->global_ctx->device, WEBGPU_NUM_SET_ROWS_ERROR_BUFS,
3022                                             WEBGPU_SET_ROWS_ERROR_BUF_SIZE_BYTES,
3023                                             wgpu::BufferUsage::CopySrc | wgpu::BufferUsage::Storage,
3024                                             wgpu::BufferUsage::CopyDst | wgpu::BufferUsage::MapRead);
3025
3026    ggml_webgpu_init_mul_mat_pipeline(webgpu_ctx);
3027    ggml_webgpu_init_get_rows_pipeline(webgpu_ctx);
3028    ggml_webgpu_init_cpy_pipeline(webgpu_ctx);
3029    ggml_webgpu_init_rms_norm_pipeline(webgpu_ctx);
3030    ggml_webgpu_init_rope_pipeline(webgpu_ctx);
3031    ggml_webgpu_init_glu_pipeline(webgpu_ctx);
3032    ggml_webgpu_init_scale_pipeline(webgpu_ctx);
3033    ggml_webgpu_init_soft_max_pipeline(webgpu_ctx);
3034#ifdef GGML_WEBGPU_DEBUG
3035    // Initialize debug buffers
3036    ggml_webgpu_create_buffer(webgpu_ctx->global_ctx->device, webgpu_ctx->global_ctx->debug_host_buf,
3037                              WEBGPU_DEBUG_BUF_ELEMS * sizeof(uint32_t),
3038                              wgpu::BufferUsage::CopyDst | wgpu::BufferUsage::MapRead, "debug_host_buf");
3039    ggml_webgpu_create_buffer(webgpu_ctx->global_ctx->device, webgpu_ctx->global_ctx->debug_dev_buf,
3040                              WEBGPU_DEBUG_BUF_ELEMS * sizeof(uint32_t),
3041                              wgpu::BufferUsage::Storage | wgpu::BufferUsage::CopySrc, "debug_dev_buf");
3042#endif
3043    return webgpu_ctx;
3044}
3045
3046static ggml_backend_t ggml_backend_webgpu_backend_init(ggml_backend_dev_t dev, const char * params) {
3047    GGML_UNUSED(params);
3048
3049    WEBGPU_LOG_DEBUG("ggml_backend_webgpu_backend_init()");
3050
3051    ggml_backend_webgpu_device_context * dev_ctx = static_cast<ggml_backend_webgpu_device_context *>(dev->context);
3052
3053    auto * backend_ctx      = new ggml_backend_webgpu_context();
3054    backend_ctx->name       = GGML_WEBGPU_NAME + std::string(": ") + dev_ctx->device_name;
3055    backend_ctx->webgpu_ctx = initialize_webgpu_context(dev);
3056
3057    // See GGML Backend Interface section
3058    auto * backend = new ggml_backend();
3059    *backend       = {
3060        /* .guid      = */ ggml_backend_webgpu_guid(),
3061        /* .interface = */ ggml_backend_webgpu_i,
3062        /* .device    = */ dev,
3063        /* .context   = */ backend_ctx,
3064    };
3065    return backend;
3066}
3067
3068static ggml_backend_buffer_type_t ggml_backend_webgpu_device_get_buffer_type(ggml_backend_dev_t dev) {
3069    // See GGML Backend Buffer Type Interface section
3070
3071    static struct ggml_backend_buffer_type ggml_backend_webgpu_buffer_type = {
3072        /* .iface = */ {
3073                        /* .get_name         = */ ggml_backend_webgpu_buffer_type_get_name,
3074                        /* .alloc_buffer     = */ ggml_backend_webgpu_buffer_type_alloc_buffer,
3075                        /* .get_alignment    = */ ggml_backend_webgpu_buffer_type_get_alignment,
3076                        /* .get_max_size     = */ ggml_backend_webgpu_buffer_type_get_max_size,
3077                        /* .get_alloc_size   = */ ggml_backend_webgpu_buffer_type_get_alloc_size,
3078                        /* .is_host          = */ NULL,  // defaults to false
3079        },
3080        /* .device  = */
3081        dev,
3082        /* .context = */
3083        NULL
3084    };
3085
3086    return &ggml_backend_webgpu_buffer_type;
3087}
3088
3089static bool ggml_backend_webgpu_device_supports_buft(ggml_backend_dev_t dev, ggml_backend_buffer_type_t buft) {
3090    GGML_UNUSED(dev);
3091    return buft->iface.get_name == ggml_backend_webgpu_buffer_type_get_name;
3092}
3093
3094static bool ggml_webgpu_supported_qtype(ggml_type type) {
3095    switch (type) {
3096        case GGML_TYPE_Q4_0:
3097        case GGML_TYPE_Q4_1:
3098        case GGML_TYPE_Q5_0:
3099        case GGML_TYPE_Q5_1:
3100        case GGML_TYPE_Q8_0:
3101        case GGML_TYPE_Q2_K:
3102        case GGML_TYPE_Q3_K:
3103        case GGML_TYPE_Q4_K:
3104        case GGML_TYPE_Q5_K:
3105        case GGML_TYPE_Q6_K:
3106        case GGML_TYPE_IQ2_XXS:
3107        case GGML_TYPE_IQ2_XS:
3108        case GGML_TYPE_IQ2_S:
3109        case GGML_TYPE_IQ3_XXS:
3110        case GGML_TYPE_IQ3_S:
3111        case GGML_TYPE_IQ1_S:
3112        case GGML_TYPE_IQ1_M:
3113        case GGML_TYPE_IQ4_NL:
3114        case GGML_TYPE_IQ4_XS:
3115            return true;
3116        default:
3117            return false;
3118    }
3119}
3120
3121static bool ggml_backend_webgpu_device_supports_op(ggml_backend_dev_t dev, const ggml_tensor * op) {
3122    ggml_backend_webgpu_device_context * ctx = static_cast<ggml_backend_webgpu_device_context *>(dev->context);
3123
3124    ggml_tensor * src0 = op->src[0];
3125    ggml_tensor * src1 = op->src[1];
3126    ggml_tensor * src2 = op->src[2];
3127
3128    // on smaller devices (or CI), tensors may be larger than the max storage buffer size
3129    if (ggml_nbytes(op) > ctx->webgpu_global_ctx->capabilities.limits.maxStorageBufferBindingSize ||
3130        (src0 != nullptr &&
3131         ggml_nbytes(src0) > ctx->webgpu_global_ctx->capabilities.limits.maxStorageBufferBindingSize) ||
3132        (src1 != nullptr &&
3133         ggml_nbytes(src1) > ctx->webgpu_global_ctx->capabilities.limits.maxStorageBufferBindingSize)) {
3134        return false;
3135    }
3136
3137    bool supports_op = false;
3138    switch (op->op) {
3139        case GGML_OP_NONE:
3140        case GGML_OP_VIEW:
3141        case GGML_OP_PERMUTE:
3142        case GGML_OP_TRANSPOSE:
3143        case GGML_OP_RESHAPE:
3144            supports_op = true;
3145            break;
3146        case GGML_OP_ADD:
3147        case GGML_OP_SUB:
3148        case GGML_OP_MUL:
3149        case GGML_OP_DIV:
3150            // TODO: support non-contiguous tensors, e.g. for MOE_EXPERT_REDUCE
3151            // see https://github.com/ggml-org/llama.cpp/pull/16857
3152            supports_op = (op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_F16) && (src0->type == op->type) &&
3153                          (src1->type == op->type) && ggml_is_contiguous(src0) && ggml_is_contiguous(src1);
3154            break;
3155        case GGML_OP_CPY:
3156        case GGML_OP_CONT:
3157            supports_op = ((op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_F16) &&
3158                           (src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16)) ||
3159                          (op->type == GGML_TYPE_I32 && src0->type == GGML_TYPE_F32);
3160            break;
3161        case GGML_OP_SET_ROWS:
3162            supports_op = ((op->type == GGML_TYPE_F16 || op->type == GGML_TYPE_F32) && src0->type == GGML_TYPE_F32 &&
3163                           (src1->type == GGML_TYPE_I64 || src1->type == GGML_TYPE_I32));
3164            break;
3165        case GGML_OP_GET_ROWS:
3166            if (src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16 || ggml_webgpu_supported_qtype(src0->type)) {
3167                supports_op = (op->type == GGML_TYPE_F32);
3168            } else if (src0->type == GGML_TYPE_I32) {
3169                supports_op = op->type == GGML_TYPE_I32;
3170            }
3171            break;
3172        case GGML_OP_MUL_MAT:
3173            {
3174                switch (src1->type) {
3175                    case GGML_TYPE_F16:
3176                        supports_op |= (src0->type == GGML_TYPE_F16);
3177                        break;
3178                    case GGML_TYPE_F32:
3179                        switch (src0->type) {
3180                            case GGML_TYPE_F32:
3181                            case GGML_TYPE_F16:
3182                            case GGML_TYPE_Q4_0:
3183                            case GGML_TYPE_Q4_1:
3184                            case GGML_TYPE_Q5_0:
3185                            case GGML_TYPE_Q5_1:
3186                            case GGML_TYPE_Q8_0:
3187                            case GGML_TYPE_Q2_K:
3188                            case GGML_TYPE_Q3_K:
3189                            case GGML_TYPE_Q4_K:
3190                            case GGML_TYPE_Q5_K:
3191                            case GGML_TYPE_Q6_K:
3192                            case GGML_TYPE_IQ2_XXS:
3193                            case GGML_TYPE_IQ2_XS:
3194                            case GGML_TYPE_IQ2_S:
3195                            case GGML_TYPE_IQ3_XXS:
3196                            case GGML_TYPE_IQ3_S:
3197                            case GGML_TYPE_IQ1_S:
3198                            case GGML_TYPE_IQ1_M:
3199                            case GGML_TYPE_IQ4_NL:
3200                            case GGML_TYPE_IQ4_XS:
3201                                supports_op = true;
3202                                break;
3203                            default:
3204                                break;
3205                        }
3206                    default:
3207                        break;
3208                }
3209                break;
3210            }
3211        case GGML_OP_FLASH_ATTN_EXT:
3212            {
3213#ifndef __EMSCRIPTEN__
3214                if (!ctx->webgpu_global_ctx->capabilities.supports_subgroup_matrix) {
3215                    break;
3216                }
3217                // Head dimensions must fit in workgroup memory with minimum tile sizes
3218                size_t     limit_bytes = ctx->webgpu_global_ctx->capabilities.limits.maxComputeWorkgroupStorageSize;
3219                const bool has_mask    = op->src[3] != nullptr;
3220                const bool kv_direct   = src1->type == GGML_TYPE_F16 &&
3221                                       (src0->ne[0] % ctx->webgpu_global_ctx->capabilities.sg_mat_k) == 0 &&
3222                                       (src1->ne[1] % GGML_WEBGPU_KV_SEQ_PAD) == 0;
3223                const size_t min_bytes = ggml_webgpu_flash_attn_wg_mem_bytes(
3224                    ctx->webgpu_global_ctx->capabilities.sg_mat_m, ctx->webgpu_global_ctx->capabilities.sg_mat_n,
3225                    (uint32_t) src0->ne[0], (uint32_t) src2->ne[0], has_mask, kv_direct);
3226                if (min_bytes > limit_bytes) {
3227                    break;
3228                }
3229
3230                supports_op = src0->type == GGML_TYPE_F32 &&
3231                              (src1->type == GGML_TYPE_F32 || src1->type == GGML_TYPE_F16 ||
3232                               src1->type == GGML_TYPE_Q4_0 || src1->type == GGML_TYPE_Q8_0) &&
3233                              src2->type == src1->type && op->type == GGML_TYPE_F32;
3234#endif
3235                break;
3236            }
3237        case GGML_OP_RMS_NORM:
3238            supports_op = op->type == GGML_TYPE_F32 && src0->type == GGML_TYPE_F32;
3239            break;
3240        case GGML_OP_ROPE:
3241            supports_op = op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_F16;
3242            break;
3243        case GGML_OP_GLU:
3244            switch (ggml_get_glu_op(op)) {
3245                case GGML_GLU_OP_REGLU:
3246                case GGML_GLU_OP_GEGLU:
3247                case GGML_GLU_OP_SWIGLU:
3248                case GGML_GLU_OP_GEGLU_ERF:
3249                case GGML_GLU_OP_GEGLU_QUICK:
3250                    supports_op = op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_F16;
3251                    break;
3252                case GGML_GLU_OP_SWIGLU_OAI:
3253                    supports_op = op->type == GGML_TYPE_F32;
3254                    break;
3255                default:
3256                    break;
3257            }
3258            break;
3259        case GGML_OP_SCALE:
3260            supports_op = op->type == GGML_TYPE_F32;
3261            break;
3262        case GGML_OP_SOFT_MAX:
3263            supports_op = op->type == GGML_TYPE_F32;
3264            break;
3265        case GGML_OP_UNARY:
3266            {
3267                const ggml_unary_op UNARY_OP = ggml_get_unary_op(op);
3268
3269                switch (UNARY_OP) {
3270                    case GGML_UNARY_OP_ABS:
3271                    case GGML_UNARY_OP_SGN:
3272                    case GGML_UNARY_OP_NEG:
3273                    case GGML_UNARY_OP_STEP:
3274                    case GGML_UNARY_OP_TANH:
3275                    case GGML_UNARY_OP_ELU:
3276                    case GGML_UNARY_OP_RELU:
3277                    case GGML_UNARY_OP_SIGMOID:
3278                    case GGML_UNARY_OP_GELU:
3279                    case GGML_UNARY_OP_GELU_QUICK:
3280                    case GGML_UNARY_OP_SILU:
3281                    case GGML_UNARY_OP_HARDSWISH:
3282                    case GGML_UNARY_OP_HARDSIGMOID:
3283                    case GGML_UNARY_OP_EXP:
3284                    case GGML_UNARY_OP_GELU_ERF:
3285                    case GGML_UNARY_OP_SOFTPLUS:
3286                    case GGML_UNARY_OP_EXPM1:
3287                    case GGML_UNARY_OP_FLOOR:
3288                    case GGML_UNARY_OP_CEIL:
3289                    case GGML_UNARY_OP_ROUND:
3290                    case GGML_UNARY_OP_TRUNC:
3291                    case GGML_UNARY_OP_XIELU:
3292                        supports_op =
3293                            (op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_F16) && (src0->type == op->type);
3294                        break;
3295                    default:
3296                        break;
3297                }
3298            }
3299            break;
3300        case GGML_OP_CLAMP:
3301            supports_op = (op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_F16) && (src0->type == op->type);
3302            break;
3303        case GGML_OP_FILL:
3304            supports_op = op->type == GGML_TYPE_F32 && src0->type == GGML_TYPE_F32;
3305            break;
3306        case GGML_OP_LOG:
3307            supports_op = (op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_F16) && (src0->type == op->type);
3308            break;
3309        case GGML_OP_PAD:
3310            supports_op = op->type == GGML_TYPE_F32 && src0->type == GGML_TYPE_F32;
3311            break;
3312        case GGML_OP_ARGMAX:
3313            supports_op = op->type == GGML_TYPE_I32 && src0->type == GGML_TYPE_F32;
3314            break;
3315        case GGML_OP_ARGSORT:
3316            supports_op = op->type == GGML_TYPE_I32 && src0->type == GGML_TYPE_F32 && ggml_is_contiguous_rows(src0);
3317            break;
3318        case GGML_OP_TOP_K:
3319            supports_op = op->type == GGML_TYPE_I32 && src0->type == GGML_TYPE_F32 && ggml_is_contiguous_rows(src0);
3320            break;
3321        case GGML_OP_CUMSUM:
3322            supports_op = op->type == GGML_TYPE_F32 && src0->type == op->type;
3323            break;
3324        case GGML_OP_SUM:
3325        case GGML_OP_SUM_ROWS:
3326            supports_op = op->type == GGML_TYPE_F32 && src0->type == op->type && ggml_is_contiguous_rows(src0);
3327            break;
3328        default:
3329            break;
3330    }
3331    if (ggml_nbytes(op) > ctx->webgpu_global_ctx->capabilities.limits.maxStorageBufferBindingSize ||
3332        (src0 != nullptr &&
3333         ggml_nbytes(src0) > ctx->webgpu_global_ctx->capabilities.limits.maxStorageBufferBindingSize) ||
3334        (src1 != nullptr &&
3335         ggml_nbytes(src1) > ctx->webgpu_global_ctx->capabilities.limits.maxStorageBufferBindingSize) ||
3336        (src2 != nullptr &&
3337         ggml_nbytes(src2) > ctx->webgpu_global_ctx->capabilities.limits.maxStorageBufferBindingSize)) {
3338        supports_op = false;
3339        WEBGPU_LOG_DEBUG("ggml_webgpu op not supported due to size: ");
3340    }
3341
3342    if (!supports_op) {
3343        WEBGPU_LOG_DEBUG("ggml_webgpu op not supported: "
3344                         << ggml_op_name(op->op) << " with types dst: " << ggml_type_name(op->type)
3345                         << ", src0: " << (op->src[0] ? ggml_type_name(op->src[0]->type) : "null")
3346                         << ", src1: " << (op->src[1] ? ggml_type_name(op->src[1]->type) : "null"));
3347    } else {
3348        WEBGPU_LOG_DEBUG("ggml_webgpu op supported: "
3349                         << ggml_op_name(op->op) << " with types dst: " << ggml_type_name(op->type)
3350                         << ", src0: " << (op->src[0] ? ggml_type_name(op->src[0]->type) : "null")
3351                         << ", src1: " << (op->src[1] ? ggml_type_name(op->src[1]->type) : "null"));
3352    }
3353    return supports_op;
3354}
3355
// Device interface table for the WebGPU backend device. NULL entries are optional hooks this
// backend does not provide (host buffer type, buffer-from-host-pointer, op offload, events).
static struct ggml_backend_device_i ggml_backend_webgpu_device_i = {
    /* .get_name             = */ ggml_backend_webgpu_device_get_name,
    /* .get_description      = */ ggml_backend_webgpu_device_get_description,
    /* .get_memory           = */ ggml_backend_webgpu_device_get_memory,
    /* .get_type             = */ ggml_backend_webgpu_device_get_type,
    /* .get_props            = */ ggml_backend_webgpu_device_get_props,
    /* .init_backend         = */ ggml_backend_webgpu_backend_init,
    /* .get_buffer_type      = */ ggml_backend_webgpu_device_get_buffer_type,
    /* .get_host_buffer_type = */ NULL,
    /* .buffer_from_host_ptr = */ NULL,
    /* .supports_op          = */ ggml_backend_webgpu_device_supports_op,
    /* .supports_buft        = */ ggml_backend_webgpu_device_supports_buft,
    /* .offload_op           = */ NULL,
    /* .event_new            = */ NULL,
    /* .event_free           = */ NULL,
    /* .event_synchronize    = */ NULL,
};
3373
3374/* End GGML Backend Device Interface */
3375
3376/* GGML Backend Registration Interface */
3377
3378static const char * ggml_backend_webgpu_reg_get_name(ggml_backend_reg_t reg) {
3379    ggml_backend_webgpu_reg_context * ctx = static_cast<ggml_backend_webgpu_reg_context *>(reg->context);
3380    return ctx->name;
3381}
3382
3383static size_t ggml_backend_webgpu_reg_get_device_count(ggml_backend_reg_t reg) {
3384    ggml_backend_webgpu_reg_context * ctx = static_cast<ggml_backend_webgpu_reg_context *>(reg->context);
3385    return ctx->device_count;
3386}
3387
3388// Only one device is supported for now
3389static ggml_backend_dev_t ggml_backend_webgpu_reg_get_device(ggml_backend_reg_t reg, size_t index) {
3390    GGML_ASSERT(index == 0);
3391    WEBGPU_LOG_DEBUG("ggml_backend_reg_get_device()");
3392
3393    WEBGPU_CPU_PROFILE_TOTAL_START(reg_get_device);
3394
3395    ggml_backend_webgpu_reg_context * reg_ctx = static_cast<ggml_backend_webgpu_reg_context *>(reg->context);
3396
3397    create_webgpu_device(reg_ctx);
3398
3399    static ggml_backend_webgpu_device_context device_ctx;
3400    device_ctx.device_name            = GGML_WEBGPU_NAME;
3401    device_ctx.device_desc            = GGML_WEBGPU_NAME;
3402    device_ctx.webgpu_global_ctx      = reg_ctx->webgpu_global_ctx;
3403    // See GGML Backend Device Interface section
3404    static ggml_backend_device device = {
3405        /* .iface   = */ ggml_backend_webgpu_device_i,
3406        /* .reg     = */ reg,
3407        /* .context = */ &device_ctx,
3408    };
3409
3410    WEBGPU_CPU_PROFILE_TOTAL_END(reg_get_device, reg_ctx->webgpu_global_ctx);
3411    return &device;
3412}
3413
// Registry interface table for the WebGPU backend; get_proc_address is not provided (NULL).
static const struct ggml_backend_reg_i ggml_backend_webgpu_reg_i = {
    /* .get_name         = */ ggml_backend_webgpu_reg_get_name,
    /* .get_device_count = */ ggml_backend_webgpu_reg_get_device_count,
    /* .get_device       = */ ggml_backend_webgpu_reg_get_device,
    /* .get_proc_address = */ NULL,
};
3420
3421/* End GGML Backend Registration Interface */
3422
3423ggml_backend_reg_t ggml_backend_webgpu_reg() {
3424    WEBGPU_LOG_DEBUG("ggml_backend_webgpu_reg()");
3425
3426    static ggml_backend_webgpu_reg_context ctx;
3427    ctx.name         = GGML_WEBGPU_NAME;
3428    ctx.device_count = 1;
3429
3430    wgpu::InstanceDescriptor               instance_descriptor{};
3431    std::vector<wgpu::InstanceFeatureName> instance_features = { wgpu::InstanceFeatureName::TimedWaitAny };
3432    instance_descriptor.requiredFeatures                     = instance_features.data();
3433    instance_descriptor.requiredFeatureCount                 = instance_features.size();
3434
3435#ifndef __EMSCRIPTEN__
3436    const char * const          instanceEnabledToggles[] = { "allow_unsafe_apis" };
3437    wgpu::DawnTogglesDescriptor instanceTogglesDesc;
3438    instanceTogglesDesc.enabledToggles     = instanceEnabledToggles;
3439    instanceTogglesDesc.enabledToggleCount = 1;
3440    instance_descriptor.nextInChain        = &instanceTogglesDesc;
3441#endif
3442
3443    wgpu::Instance inst             = wgpu::CreateInstance(&instance_descriptor);
3444    ctx.webgpu_global_ctx           = webgpu_global_context(new webgpu_global_context_struct());
3445    ctx.webgpu_global_ctx->instance = std::move(inst);
3446
3447#ifdef __EMSCRIPTEN__
3448    if (ctx.webgpu_global_ctx->instance == nullptr) {
3449        GGML_LOG_ERROR("ggml_webgpu: Failed to create WebGPU instance. Make sure either -sASYNCIFY or -sJSPI is set\n");
3450        return nullptr;
3451    }
3452#endif
3453    GGML_ASSERT(ctx.webgpu_global_ctx->instance != nullptr);
3454
3455    static ggml_backend_reg reg = {
3456        /* .api_version = */ GGML_BACKEND_API_VERSION,
3457        /* .iface       = */ ggml_backend_webgpu_reg_i,
3458        /* .context     = */ &ctx,
3459    };
3460    return &reg;
3461}
3462
3463ggml_backend_t ggml_backend_webgpu_init(void) {
3464    ggml_backend_dev_t dev = ggml_backend_reg_dev_get(ggml_backend_webgpu_reg(), 0);
3465
3466    return ggml_backend_webgpu_backend_init(dev, nullptr);
3467}
3468
// Exposes ggml_backend_webgpu_reg through GGML_BACKEND_DL_IMPL, presumably so the backend can be
// discovered when built as a dynamically loaded ggml backend — see the macro definition to confirm.
GGML_BACKEND_DL_IMPL(ggml_backend_webgpu_reg)