1/*
2 WebGPU backend implementation.
3 Note: Use ClangFormat to format this file.
4*/
5
6#include "ggml-webgpu.h"
7
8#include "ggml-backend-impl.h"
9#include "ggml-impl.h"
10#include "ggml-webgpu-shader-lib.hpp"
11#include "ggml-wgsl-shaders.hpp"
12#include "pre_wgsl.hpp"
13
14#ifdef __EMSCRIPTEN__
15# include <emscripten/emscripten.h>
16#endif
17
18#include <webgpu/webgpu_cpp.h>
19
20#include <atomic>
21#include <condition_variable>
22#include <cstdint>
23#include <cstring>
24#include <iostream>
25#include <map>
26#include <mutex>
27#include <optional>
28#include <string>
29#include <vector>
30
31#define ROUNDUP_POW2(x, pow2) (((x) + ((pow2) - 1)) & ~((pow2) - 1))
32#define CEIL_DIV(M, N) (((M) + (N) - 1) / (N))
33
34#ifdef GGML_WEBGPU_DEBUG
35# define WEBGPU_LOG_DEBUG(msg) std::cout << msg << std::endl
36# define WEBGPU_DEBUG_BUF_ELEMS 512
37#else
38# define WEBGPU_LOG_DEBUG(msg) ((void) 0)
39#endif // GGML_WEBGPU_DEBUG
40
41#ifdef GGML_WEBGPU_CPU_PROFILE
42// total timing (aggregated)
43# define WEBGPU_CPU_PROFILE_TOTAL_START(id) auto cpu_total_start_##id = std::chrono::high_resolution_clock::now();
44
45# define WEBGPU_CPU_PROFILE_TOTAL_END(id, ctx) \
46 auto cpu_total_end_##id = std::chrono::high_resolution_clock::now(); \
47 double cpu_total_time_##id = \
48 std::chrono::duration<double, std::milli>(cpu_total_end_##id - cpu_total_start_##id).count(); \
49 (ctx)->cpu_time_ms[#id] += cpu_total_time_##id;
50// fine-grained timing (not included in totals)
51# define WEBGPU_CPU_PROFILE_DETAIL_START(id) auto cpu_detail_start_##id = std::chrono::high_resolution_clock::now();
52
53# define WEBGPU_CPU_PROFILE_DETAIL_END(id, ctx) \
54 auto cpu_detail_end_##id = std::chrono::high_resolution_clock::now(); \
55 double cpu_detail_time_##id = \
56 std::chrono::duration<double, std::milli>(cpu_detail_end_##id - cpu_detail_start_##id).count(); \
57 (ctx)->cpu_detail_ms[#id] += cpu_detail_time_##id;
58#else
59# define WEBGPU_CPU_PROFILE_TOTAL_START(id)
60# define WEBGPU_CPU_PROFILE_TOTAL_END(id, ctx)
61# define WEBGPU_CPU_PROFILE_DETAIL_START(id)
62# define WEBGPU_CPU_PROFILE_DETAIL_END(id, ctx)
63#endif // GGML_WEBGPU_CPU_PROFILE
64
65#ifdef GGML_WEBGPU_GPU_PROFILE
66# define WEBGPU_NUM_TIMESTAMP_QUERY_BUFS 24
67# define WEBGPU_TIMESTAMP_QUERY_BUF_SIZE_BYTES 16 // e.g. enough for two timestamps
68#endif
69
70/* Constants */
71
72// Track https://github.com/gpuweb/gpuweb/issues/5315 for fixes to implementations so this can be removed.
73#define WEBGPU_MAX_WG_SIZE 288
74
75#define WEBGPU_MUL_MAT_WG_SIZE 256
76#define WEBGPU_NUM_PARAM_BUFS 16u
77#define WEBGPU_COMMAND_SUBMIT_BATCH_SIZE 8u
78#define WEBGPU_WAIT_ANY_TIMEOUT_MS 0
79// Maximum number of in-flight submissions per-thread, to avoid exhausting the parameter buffer pool
80#define WEBGPU_MAX_INFLIGHT_SUBS_PER_THREAD WEBGPU_NUM_PARAM_BUFS / WEBGPU_COMMAND_SUBMIT_BATCH_SIZE
81#define WEBGPU_PARAMS_BUF_SIZE_BYTES 128 // enough for 32 parameters
82#define WEBGPU_NUM_SET_ROWS_ERROR_BUFS 16
83#define WEBGPU_SET_ROWS_ERROR_BUF_SIZE_BYTES 4
84#define WEBGPU_STORAGE_BUF_BINDING_MULT 4 // a storage buffer binding size must be a multiple of 4
85
86// For operations which process a row in parallel, this seems like a reasonable default
87#define WEBGPU_ROW_SPLIT_WG_SIZE 64
88
89// Matrix multiplication parameters
90
91// Register tiling parameters
92#define WEBGPU_MUL_MAT_TILE_M 8
93#define WEBGPU_MUL_MAT_TILE_N 8
94#define WEBGPU_MUL_MAT_WG_SIZE_M 8
95#define WEBGPU_MUL_MAT_WG_SIZE_N 8
96#define WEBGPU_MUL_MAT_TILE_K 32
97
98// Subgroup matrix parameters
99// The number of subgroups in the M dimension
100#define WEBGPU_MUL_MAT_SUBGROUP_M 2
101// The number of subgroups in the N dimension
102#define WEBGPU_MUL_MAT_SUBGROUP_N 2
103// The number of subgroup matrices each subgroup accumulates over
104#define WEBGPU_MUL_MAT_SUBGROUP_MATRIX_M 4
105#define WEBGPU_MUL_MAT_SUBGROUP_MATRIX_N 2
106
107// Matrix-vector multiplication parameters
108#define WEBGPU_MUL_MAT_VEC_WG_SIZE 256
109// Must be multiple of 4 to work with vectorized paths, and must divide mul_mat_vec wg size
110#define WEBGPU_MUL_MAT_VEC_OUTPUTS_PER_WG 64
111#define WEBGPU_MUL_MAT_VEC_TILE_K 256
112
113/* End Constants */
114
115// This is a "fake" base pointer, since WebGPU buffers do not have pointers to their locations.
116static void * const webgpu_ptr_base = (void *) (uintptr_t) 0x1000; // NOLINT
117
118// Always returns the base offset of a tensor, regardless of views.
119static uint64_t webgpu_tensor_offset(const ggml_tensor * tensor) {
120 if (tensor->view_src) {
121 return (uint8_t *) tensor->view_src->data - (uint8_t *) webgpu_ptr_base;
122 }
123 return (uint8_t *) tensor->data - (uint8_t *) webgpu_ptr_base;
124}
125
126/* Struct definitions */
127
128// Forward reference
129static void ggml_webgpu_create_buffer(wgpu::Device & device,
130 wgpu::Buffer & buffer,
131 size_t size,
132 wgpu::BufferUsage usage,
133 const char * label);
134
135struct webgpu_pool_bufs {
136 wgpu::Buffer host_buf;
137 wgpu::Buffer dev_buf;
138};
139
140// The futures to wait on for a single queue submission
141struct webgpu_submission_futures {
142 std::vector<wgpu::FutureWaitInfo> futures;
143};
144
145// Holds a pool of parameter buffers for WebGPU operations
146struct webgpu_buf_pool {
147 std::vector<webgpu_pool_bufs> free;
148
149 // The pool must be synchronized because
150 // 1. The memset pool is shared globally by every ggml buffer,
151 // since allocating a pool per ggml buffer would consume too much memory.
152 // 2. For the per-thread buffer pools in webgpu_context,
153 // buffers are allocated and freed in Dawn callbacks,
154 // which can run on a different thread than the calling thread.
155 std::mutex mutex;
156 std::condition_variable cv;
157
158 void init(wgpu::Device device,
159 int num_bufs,
160 size_t buf_size,
161 wgpu::BufferUsage dev_buf_usage,
162 wgpu::BufferUsage host_buf_usage) {
163 for (int i = 0; i < num_bufs; i++) {
164 wgpu::Buffer host_buf;
165 wgpu::Buffer dev_buf;
166 ggml_webgpu_create_buffer(device, host_buf, buf_size, host_buf_usage, "ggml_webgpu_host_pool_buf");
167 ggml_webgpu_create_buffer(device, dev_buf, buf_size, dev_buf_usage, "ggml_webgpu_dev_pool_buf");
168 free.push_back({ host_buf, dev_buf });
169 }
170 }
171
172 webgpu_pool_bufs alloc_bufs() {
173 std::unique_lock<std::mutex> lock(mutex);
174 cv.wait(lock, [this] { return !free.empty(); });
175 webgpu_pool_bufs bufs = free.back();
176 free.pop_back();
177 return bufs;
178 }
179
180 void free_bufs(std::vector<webgpu_pool_bufs> bufs) {
181 std::lock_guard<std::mutex> lock(mutex);
182 free.insert(free.end(), bufs.begin(), bufs.end());
183 cv.notify_all();
184 }
185
186 void cleanup() {
187 std::lock_guard<std::mutex> lock(mutex);
188 for (auto & bufs : free) {
189 if (bufs.host_buf) {
190 bufs.host_buf.Destroy();
191 }
192 if (bufs.dev_buf) {
193 bufs.dev_buf.Destroy();
194 }
195 }
196 free.clear();
197 }
198
199 ~webgpu_buf_pool() { this->cleanup(); }
200};
201
202#ifdef GGML_WEBGPU_GPU_PROFILE
203struct webgpu_gpu_profile_bufs {
204 wgpu::Buffer host_buf;
205 wgpu::Buffer dev_buf;
206 wgpu::QuerySet query_set;
207};
208
209// Holds a pool of parameter buffers for WebGPU operations
210struct webgpu_gpu_profile_buf_pool {
211 std::vector<webgpu_gpu_profile_bufs> free;
212
213 std::mutex mutex;
214
215 std::condition_variable cv;
216
217 void init(wgpu::Device device,
218 int num_bufs,
219 size_t buf_size,
220 wgpu::BufferUsage dev_buf_usage,
221 wgpu::BufferUsage host_buf_usage) {
222 for (int i = 0; i < num_bufs; i++) {
223 wgpu::Buffer host_buf;
224 wgpu::Buffer dev_buf;
225 ggml_webgpu_create_buffer(device, host_buf, buf_size, host_buf_usage, "ggml_webgpu_host_profile_buf");
226 ggml_webgpu_create_buffer(device, dev_buf, buf_size, dev_buf_usage, "ggml_webgpu_dev_profile_buf");
227 // Create a query set for 2 timestamps
228 wgpu::QuerySetDescriptor ts_query_set_desc = {};
229
230 ts_query_set_desc.type = wgpu::QueryType::Timestamp;
231 ts_query_set_desc.count = 2;
232 wgpu::QuerySet ts_query_set = device.CreateQuerySet(&ts_query_set_desc);
233
234 free.push_back({ host_buf, dev_buf, ts_query_set });
235 }
236 }
237
238 webgpu_gpu_profile_bufs alloc_bufs() {
239 std::unique_lock<std::mutex> lock(mutex);
240 cv.wait(lock, [this] { return !free.empty(); });
241 webgpu_gpu_profile_bufs bufs = free.back();
242 free.pop_back();
243 return bufs;
244 }
245
246 void free_bufs(std::vector<webgpu_gpu_profile_bufs> bufs) {
247 std::lock_guard<std::mutex> lock(mutex);
248 free.insert(free.end(), bufs.begin(), bufs.end());
249 cv.notify_all();
250 }
251
252 void cleanup() {
253 std::lock_guard<std::mutex> lock(mutex);
254 for (auto & bufs : free) {
255 bufs.host_buf.Destroy();
256 bufs.dev_buf.Destroy();
257 bufs.query_set.Destroy();
258 }
259 free.clear();
260 }
261
262 ~webgpu_gpu_profile_buf_pool() { this->cleanup(); }
263};
264#endif
265
266struct webgpu_pipeline {
267 wgpu::ComputePipeline pipeline;
268 std::string name;
269 std::shared_ptr<void> context = nullptr;
270};
271
272struct webgpu_command {
273 wgpu::CommandBuffer commands;
274 std::vector<webgpu_pool_bufs> params_bufs;
275 std::optional<webgpu_pool_bufs> set_rows_error_bufs;
276#ifdef GGML_WEBGPU_GPU_PROFILE
277 webgpu_gpu_profile_bufs timestamp_query_bufs;
278 std::string pipeline_name;
279#endif
280};
281
282struct webgpu_capabilities {
283 wgpu::Limits limits;
284 bool supports_subgroup_matrix = false;
285
286 uint32_t sg_mat_m = 0;
287 uint32_t sg_mat_n = 0;
288 uint32_t sg_mat_k = 0;
289
290 uint32_t subgroup_size = 0;
291 uint32_t max_subgroup_size = 0;
292 size_t memset_bytes_per_thread;
293};
294
295// Stores global webgpu members
296struct webgpu_global_context_struct {
297 wgpu::Instance instance;
298 wgpu::Adapter adapter;
299 wgpu::Device device;
300 wgpu::Queue queue;
301
302 webgpu_capabilities capabilities;
303 // Shared buffer to move data from device to host
304 wgpu::Buffer get_tensor_staging_buf;
305 // Global mutex for pipeline and staging buffer, will be refactored to exclude pipeline caches.
306 std::recursive_mutex mutex;
307
308 webgpu_buf_pool memset_buf_pool;
309 std::map<int, webgpu_pipeline> memset_pipelines; // variant or type index
310 std::atomic_uint inflight_threads = 0;
311
312#ifdef GGML_WEBGPU_CPU_PROFILE
313 // Profiling: labeled CPU time in ms (total)
314 std::unordered_map<std::string, double> cpu_time_ms;
315 // Profiling: detailed CPU time in ms
316 std::unordered_map<std::string, double> cpu_detail_ms;
317#endif
318
319#ifdef GGML_WEBGPU_GPU_PROFILE
320 // Profiling: per-shader GPU time in ms
321 std::unordered_map<std::string, double> shader_gpu_time_ms;
322 // Profiling: pool of timestamp query buffers (one per operation)
323 webgpu_gpu_profile_buf_pool timestamp_query_buf_pool;
324#endif
325
326#ifdef GGML_WEBGPU_DEBUG
327 wgpu::Buffer debug_host_buf;
328 wgpu::Buffer debug_dev_buf;
329#endif
330
331 ~webgpu_global_context_struct() {
332 if (this->get_tensor_staging_buf) {
333 this->get_tensor_staging_buf.Destroy();
334 this->get_tensor_staging_buf = nullptr;
335 }
336#ifdef GGML_WEBGPU_DEBUG
337 if (this->debug_host_buf) {
338 this->debug_host_buf.Destroy();
339 this->debug_host_buf = nullptr;
340 }
341 if (this->debug_dev_buf) {
342 this->debug_dev_buf.Destroy();
343 this->debug_dev_buf = nullptr;
344 }
345#endif
346 }
347};
348
349typedef std::shared_ptr<webgpu_global_context_struct> webgpu_global_context;
350
351// All the base objects needed to run operations on a WebGPU device
352struct webgpu_context_struct {
353 // Points to global instances owned by ggml_backend_webgpu_reg_context
354 webgpu_global_context global_ctx;
355
356 pre_wgsl::Preprocessor p;
357
358 webgpu_buf_pool param_buf_pool;
359 webgpu_buf_pool set_rows_error_buf_pool;
360
361 std::map<int, std::map<int, std::map<int, webgpu_pipeline>>> mul_mat_pipelines; // src0_type, src1_type, vectorized
362 std::map<int, std::map<int, std::map<int, webgpu_pipeline>>>
363 mul_mat_vec_pipelines; // src0_type, src1_type, vectorized
364
365 std::unordered_map<ggml_webgpu_flash_attn_pipeline_key, webgpu_pipeline, ggml_webgpu_flash_attn_pipeline_key_hash>
366 flash_attn_pipelines;
367
368 std::unordered_map<int, webgpu_pipeline> argmax_pipelines; // key is vec4
369 std::unordered_map<int, webgpu_pipeline> argsort_pipelines; // key is order (asc/desc)
370 std::unordered_map<int, webgpu_pipeline> argsort_merge_pipelines; // key is order (asc/desc)
371 std::unordered_map<int, webgpu_pipeline> cumsum_pipelines; // key is fixed, no variants yet
372 std::unordered_map<int, webgpu_pipeline> sum_rows_pipelines; // key is fixed, no variants yet
373
374 std::unordered_map<ggml_webgpu_set_rows_pipeline_key, webgpu_pipeline, ggml_webgpu_set_rows_pipeline_key_hash>
375 set_rows_pipelines;
376 std::map<int, std::map<int, webgpu_pipeline>> get_rows_pipelines; // src_type, vectorized
377
378 std::map<int, std::map<int, webgpu_pipeline>> cpy_pipelines; // src_type, dst_type
379
380 std::unordered_map<ggml_webgpu_binary_pipeline_key, webgpu_pipeline, ggml_webgpu_binary_pipeline_key_hash>
381 binary_pipelines;
382
383 std::map<int, webgpu_pipeline> rms_norm_pipelines; // inplace
384 std::map<int, std::map<int, std::map<int, webgpu_pipeline>>> rope_pipelines; // type, ff, inplace
385 std::map<int, std::map<int, std::map<int, webgpu_pipeline>>> glu_pipelines; // glu_op, type, split
386 std::map<int, webgpu_pipeline> scale_pipelines; // inplace
387 std::map<int, std::map<int, std::map<int, webgpu_pipeline>>> soft_max_pipelines; // mask_type, has_sink, inplace
388 std::unordered_map<ggml_webgpu_unary_pipeline_key, webgpu_pipeline, ggml_webgpu_unary_pipeline_key_hash>
389 unary_pipelines;
390 std::unordered_map<ggml_webgpu_pad_pipeline_key, webgpu_pipeline, ggml_webgpu_pad_pipeline_key_hash> pad_pipelines;
391
392 size_t memset_bytes_per_thread;
393};
394
395typedef std::shared_ptr<webgpu_context_struct> webgpu_context;
396
397// Metadata required for the ggml backend registration/discovery interface
398struct ggml_backend_webgpu_reg_context {
399 // Since the Instance is a global entrypoint into the WebGPU API, it lives here
400 webgpu_global_context webgpu_global_ctx;
401 size_t device_count;
402 const char * name;
403};
404
405// Per-device struct for the global logical device interface
406struct ggml_backend_webgpu_device_context {
407 webgpu_global_context webgpu_global_ctx;
408 std::string device_name;
409 std::string device_desc;
410};
411
412// Per-thread data required to actually run WebGPU operations in a backend instance
413struct ggml_backend_webgpu_context {
414 webgpu_context webgpu_ctx;
415 std::string name;
416};
417
418// Per-thread data related to buffers
419struct ggml_backend_webgpu_buffer_context {
420 wgpu::Buffer buffer;
421 std::string label;
422 webgpu_global_context global_ctx;
423
424 ggml_backend_webgpu_buffer_context(wgpu::Buffer buf, std::string lbl, webgpu_global_context global_ctx_) :
425 buffer(std::move(buf)),
426 label(std::move(lbl)),
427 global_ctx(std::move(global_ctx_)) {}
428};
429
430/* WebGPU object initializations */
431
432// Process a WGSL shader string, replacing tokens of the form {{KEY}} with
433// the corresponding values provided in `repls`.
434static std::string ggml_webgpu_process_shader_repls(const char * src,
435 const std::map<std::string, std::string> & repls) {
436 if (!src) {
437 return std::string();
438 }
439 std::string s = src;
440 for (const auto & kv : repls) {
441 std::string token = "{{" + kv.first + "}}";
442 size_t pos = 0;
443 while ((pos = s.find(token, pos)) != std::string::npos) {
444 s.replace(pos, token.length(), kv.second);
445 pos += kv.second.length();
446 }
447 }
448 return s;
449}
450
451static webgpu_pipeline ggml_webgpu_create_pipeline(wgpu::Device & device,
452 const char * shader_code,
453 const char * label,
454 const std::vector<wgpu::ConstantEntry> & constants = {}) {
455 wgpu::ShaderSourceWGSL shader_source;
456 shader_source.code = shader_code;
457
458 wgpu::ShaderModuleDescriptor shader_desc;
459 shader_desc.nextInChain = &shader_source;
460
461 wgpu::ShaderModule shader_module = device.CreateShaderModule(&shader_desc);
462
463 wgpu::ComputePipelineDescriptor pipeline_desc;
464 pipeline_desc.label = label;
465 pipeline_desc.compute.module = shader_module;
466 pipeline_desc.compute.entryPoint = "main"; // Entry point in the WGSL code
467 pipeline_desc.layout = nullptr; // nullptr means auto layout
468 if (constants.size() > 0) {
469 pipeline_desc.compute.constants = constants.data();
470 pipeline_desc.compute.constantCount = constants.size();
471 }
472 return { device.CreateComputePipeline(&pipeline_desc), label };
473}
474
475static void ggml_webgpu_create_buffer(wgpu::Device & device,
476 wgpu::Buffer & buffer,
477 size_t size,
478 wgpu::BufferUsage usage,
479 const char * label) {
480 wgpu::BufferDescriptor buffer_desc;
481 buffer_desc.size = size;
482 buffer_desc.usage = usage;
483 buffer_desc.label = label;
484 buffer_desc.mappedAtCreation = false;
485
486 // TODO: error handling
487 buffer = device.CreateBuffer(&buffer_desc);
488}
489
490/** End WebGPU object initializations */
491
492/** WebGPU Actions */
493
494// Wait for the queue to finish processing all submitted work
495static void ggml_backend_webgpu_wait(webgpu_global_context & ctx,
496 std::vector<webgpu_submission_futures> & futures,
497 bool block = true) {
498 // If we have too many in-flight submissions, wait on the oldest one first. If there are many threads,
499 // inflight_max may be 0, meaning that we must wait on all futures.
500 uint64_t timeout_ms = block ? UINT64_MAX : 0;
501 uint32_t inflight_threads = ctx->inflight_threads;
502 uint32_t inflight_max = WEBGPU_MAX_INFLIGHT_SUBS_PER_THREAD / std::max(inflight_threads, 1u);
503 while (futures.size() >= inflight_max && futures.size() > 0) {
504 ctx->instance.WaitAny(futures[0].futures.size(), futures[0].futures.data(), UINT64_MAX);
505 futures.erase(futures.begin());
506 }
507 size_t i = 0;
508 while (i < futures.size()) {
509 auto waitStatus = ctx->instance.WaitAny(futures[i].futures.size(), futures[i].futures.data(), timeout_ms);
510 switch (waitStatus) {
511 case wgpu::WaitStatus::Success:
512 futures.erase(futures.begin() + i);
513 break;
514 case wgpu::WaitStatus::TimedOut:
515 i++;
516 break;
517 case wgpu::WaitStatus::Error:
518 GGML_LOG_ERROR("ggml_webgpu: WaitAny returned an error\n");
519 break;
520 default:
521 GGML_LOG_ERROR("ggml_webgpu: WaitAny returned an unknown status\n");
522 break;
523 }
524 }
525}
526
527static void ggml_backend_webgpu_map_buffer(webgpu_global_context & ctx,
528 wgpu::Buffer & buffer,
529 wgpu::MapMode mode,
530 size_t offset,
531 size_t size) {
532 ctx->instance.WaitAny(buffer.MapAsync(mode, offset, size, wgpu::CallbackMode::AllowSpontaneous,
533 [](wgpu::MapAsyncStatus status, wgpu::StringView message) {
534 if (status != wgpu::MapAsyncStatus::Success) {
535 GGML_LOG_ERROR("ggml_webgpu: Failed to map buffer: %s\n",
536 message.data);
537 }
538 }),
539 UINT64_MAX);
540}
541
542#ifdef GGML_WEBGPU_DEBUG
543// This function adds debugging information to shaders, as WebGPU does not support printing directly.
544// To use, add a bind group entry to the setup for the shader you are debugging, add the buffer and
545// debug statements in the shader, and then call this function after encoding the commands and submitting them.
546static void ggml_backend_webgpu_debug(webgpu_global_context & ctx) {
547 wgpu::CommandEncoder encoder = ctx->device.CreateCommandEncoder();
548 encoder.CopyBufferToBuffer(ctx->debug_dev_buf, 0, ctx->debug_host_buf, 0, ctx->debug_host_buf.GetSize());
549 wgpu::CommandBuffer commands = encoder.Finish();
550 ctx->queue.Submit(1, &commands);
551 ggml_backend_webgpu_map_buffer(ctx, ctx->debug_host_buf, wgpu::MapMode::Read, 0, ctx->debug_host_buf.GetSize());
552 const float * debug_data = (const float *) ctx->debug_host_buf.GetConstMappedRange();
553 std::cout << "debug[0]: " << debug_data[0] << "\n";
554 ctx->debug_host_buf.Unmap();
555}
556#endif
557
558static webgpu_submission_futures ggml_backend_webgpu_submit(webgpu_global_context ctx,
559 std::vector<webgpu_command> commands,
560 webgpu_buf_pool & param_buf_pool,
561 webgpu_buf_pool * set_rows_error_buf_pool = nullptr) {
562 std::vector<wgpu::CommandBuffer> command_buffers;
563 std::vector<webgpu_pool_bufs> params_bufs;
564 std::vector<webgpu_pool_bufs> set_rows_error_bufs;
565#ifdef GGML_WEBGPU_GPU_PROFILE
566 std::vector<std::pair<std::string, webgpu_gpu_profile_bufs>> pipeline_name_and_ts_bufs;
567#endif
568
569 for (const auto & command : commands) {
570 command_buffers.push_back(command.commands);
571 params_bufs.insert(params_bufs.end(), command.params_bufs.begin(), command.params_bufs.end());
572 if (command.set_rows_error_bufs) {
573 set_rows_error_bufs.push_back(command.set_rows_error_bufs.value());
574 }
575 }
576 ctx->queue.Submit(command_buffers.size(), command_buffers.data());
577
578 std::vector<wgpu::FutureWaitInfo> futures;
579
580 wgpu::Future p_f = ctx->queue.OnSubmittedWorkDone(
581 wgpu::CallbackMode::AllowSpontaneous,
582 [¶m_buf_pool, params_bufs](wgpu::QueueWorkDoneStatus status, wgpu::StringView message) {
583 if (status != wgpu::QueueWorkDoneStatus::Success) {
584 GGML_LOG_ERROR("ggml_webgpu: Failed to submit commands: %s\n", std::string(message).c_str());
585 }
586 // Free the staged buffers
587 param_buf_pool.free_bufs(params_bufs);
588 });
589 futures.push_back({ p_f });
590
591 for (const auto & bufs : set_rows_error_bufs) {
592 wgpu::Future f = bufs.host_buf.MapAsync(
593 wgpu::MapMode::Read, 0, bufs.host_buf.GetSize(), wgpu::CallbackMode::AllowSpontaneous,
594 [set_rows_error_buf_pool, bufs](wgpu::MapAsyncStatus status, wgpu::StringView message) {
595 if (status != wgpu::MapAsyncStatus::Success) {
596 GGML_LOG_ERROR("ggml_webgpu: Failed to map error buffer: %s\n", std::string(message).c_str());
597 } else {
598 const uint32_t * error_data = (const uint32_t *) bufs.host_buf.GetConstMappedRange();
599 if (*error_data) {
600 GGML_ABORT("ggml_webgpu: SET_ROWS index > 2^32, unsupported.");
601 }
602 // We can't unmap in here due to WebGPU reentrancy limitations.
603 if (set_rows_error_buf_pool) {
604 set_rows_error_buf_pool->free_bufs({ bufs });
605 }
606 }
607 });
608 futures.push_back({ f });
609 }
610
611#ifdef GGML_WEBGPU_GPU_PROFILE
612 for (const auto & command : commands) {
613 auto label = command.pipeline_name;
614 auto ts_bufs = command.timestamp_query_bufs;
615
616 wgpu::Future f = ts_bufs.host_buf.MapAsync(
617 wgpu::MapMode::Read, 0, ts_bufs.host_buf.GetSize(), wgpu::CallbackMode::AllowSpontaneous,
618 [ctx, ts_bufs, label](wgpu::MapAsyncStatus status, wgpu::StringView message) {
619 if (status != wgpu::MapAsyncStatus::Success) {
620 GGML_LOG_ERROR("ggml_webgpu: Failed to map timestamp buffer: %s\n", std::string(message).c_str());
621 } else {
622 const uint64_t * ts_data = (const uint64_t *) ts_bufs.host_buf.GetConstMappedRange();
623 // WebGPU timestamps are in ns; convert to ms
624 double elapsed_ms = double(ts_data[1] - ts_data[0]) * 1e-6;
625 ctx->shader_gpu_time_ms[label] += elapsed_ms;
626 // We can't unmap in here due to WebGPU reentrancy limitations.
627 ctx->timestamp_query_buf_pool.free_bufs({ ts_bufs });
628 }
629 });
630 futures.push_back({ f });
631 }
632#endif
633 return { futures };
634}
635
636static webgpu_command ggml_backend_webgpu_build_multi(
637 webgpu_global_context & ctx,
638 webgpu_buf_pool & param_buf_pool,
639 const std::vector<webgpu_pipeline> & pipelines,
640 const std::vector<std::vector<uint32_t>> & params_list,
641 const std::vector<std::vector<wgpu::BindGroupEntry>> & bind_group_entries_list,
642 const std::vector<std::pair<uint32_t, uint32_t>> & workgroups_list,
643 const std::optional<webgpu_pool_bufs> & set_rows_error_bufs = std::nullopt) {
644 GGML_ASSERT(pipelines.size() == params_list.size());
645 GGML_ASSERT(pipelines.size() == bind_group_entries_list.size());
646 GGML_ASSERT(pipelines.size() == workgroups_list.size());
647
648 std::vector<webgpu_pool_bufs> params_bufs_list;
649 std::vector<wgpu::BindGroup> bind_groups;
650
651 for (size_t i = 0; i < pipelines.size(); i++) {
652 webgpu_pool_bufs params_bufs = param_buf_pool.alloc_bufs();
653
654 ggml_backend_webgpu_map_buffer(ctx, params_bufs.host_buf, wgpu::MapMode::Write, 0,
655 params_bufs.host_buf.GetSize());
656 uint32_t * _params = (uint32_t *) params_bufs.host_buf.GetMappedRange();
657 for (size_t j = 0; j < params_list[i].size(); j++) {
658 _params[j] = params_list[i][j];
659 }
660 params_bufs.host_buf.Unmap();
661
662 std::vector<wgpu::BindGroupEntry> entries = bind_group_entries_list[i];
663 uint32_t params_binding_num = entries.size();
664 entries.push_back({ .binding = params_binding_num,
665 .buffer = params_bufs.dev_buf,
666 .offset = 0,
667 .size = params_bufs.dev_buf.GetSize() });
668
669 wgpu::BindGroupDescriptor bind_group_desc;
670 bind_group_desc.layout = pipelines[i].pipeline.GetBindGroupLayout(0);
671 bind_group_desc.entryCount = entries.size();
672 bind_group_desc.entries = entries.data();
673 bind_group_desc.label = pipelines[i].name.c_str();
674 bind_groups.push_back(ctx->device.CreateBindGroup(&bind_group_desc));
675
676 params_bufs_list.push_back(params_bufs);
677 }
678
679 wgpu::CommandEncoder encoder = ctx->device.CreateCommandEncoder();
680 for (const auto & params_bufs : params_bufs_list) {
681 encoder.CopyBufferToBuffer(params_bufs.host_buf, 0, params_bufs.dev_buf, 0, params_bufs.dev_buf.GetSize());
682 }
683
684 // If there are SET_ROWS operations in this submission, copy their error buffers to the host.
685 if (set_rows_error_bufs) {
686 encoder.CopyBufferToBuffer(set_rows_error_bufs->dev_buf, 0, set_rows_error_bufs->host_buf, 0,
687 set_rows_error_bufs->host_buf.GetSize());
688 }
689
690#ifdef GGML_WEBGPU_GPU_PROFILE
691 webgpu_gpu_profile_bufs ts_bufs = ctx->timestamp_query_buf_pool.alloc_bufs();
692 if (ts_bufs.host_buf.GetMapState() == wgpu::BufferMapState::Mapped) {
693 ts_bufs.host_buf.Unmap();
694 }
695
696 wgpu::PassTimestampWrites ts_writes = { .querySet = ts_bufs.query_set,
697 .beginningOfPassWriteIndex = 0,
698 .endOfPassWriteIndex = 1 };
699 wgpu::ComputePassDescriptor pass_desc = { .timestampWrites = &ts_writes };
700 wgpu::ComputePassEncoder pass = encoder.BeginComputePass(&pass_desc);
701#else
702 wgpu::ComputePassEncoder pass = encoder.BeginComputePass();
703#endif
704 for (size_t i = 0; i < pipelines.size(); i++) {
705 pass.SetPipeline(pipelines[i].pipeline);
706 pass.SetBindGroup(0, bind_groups[i]);
707 pass.DispatchWorkgroups(workgroups_list[i].first, workgroups_list[i].second, 1);
708 }
709 pass.End();
710
711#ifdef GGML_WEBGPU_GPU_PROFILE
712 encoder.ResolveQuerySet(ts_bufs.query_set, 0, 2, ts_bufs.dev_buf, 0);
713 encoder.CopyBufferToBuffer(ts_bufs.dev_buf, 0, ts_bufs.host_buf, 0, ts_bufs.host_buf.GetSize());
714#endif
715
716 wgpu::CommandBuffer commands = encoder.Finish();
717 webgpu_command result = {};
718 result.commands = commands;
719 result.params_bufs = params_bufs_list;
720 result.set_rows_error_bufs = set_rows_error_bufs;
721#ifdef GGML_WEBGPU_GPU_PROFILE
722 result.timestamp_query_bufs = ts_bufs;
723 // TODO: handle multiple pipeline names
724 result.pipeline_name = pipelines.front().name;
725#endif
726 return result;
727}
728
729static webgpu_command ggml_backend_webgpu_build(webgpu_global_context & ctx,
730 webgpu_buf_pool & param_buf_pool,
731 webgpu_pipeline & pipeline,
732 std::vector<uint32_t> params,
733 std::vector<wgpu::BindGroupEntry> bind_group_entries,
734 uint32_t wg_x,
735 uint32_t wg_y = 1,
736 std::optional<webgpu_pool_bufs> set_rows_error_bufs = std::nullopt) {
737 return ggml_backend_webgpu_build_multi(ctx, param_buf_pool,
738 {
739 pipeline
740 },
741 { params }, { bind_group_entries }, { { wg_x, wg_y } }, set_rows_error_bufs);
742}
743
744static void ggml_backend_webgpu_buffer_memset(webgpu_global_context & ctx,
745 wgpu::Buffer & buf,
746 uint32_t value,
747 size_t offset,
748 size_t size) {
749 std::vector<uint32_t> params = { (uint32_t) offset, (uint32_t) size, value };
750 std::vector<wgpu::BindGroupEntry> entries = {
751 { .binding = 0, .buffer = buf, .offset = 0, .size = buf.GetSize() }
752 };
753 size_t bytes_per_wg = WEBGPU_MAX_WG_SIZE * ctx->capabilities.memset_bytes_per_thread;
754 uint32_t wg_x = CEIL_DIV(size + 3, bytes_per_wg);
755
756 webgpu_command command =
757 ggml_backend_webgpu_build(ctx, ctx->memset_buf_pool, ctx->memset_pipelines[0], params, entries, wg_x);
758 std::vector<webgpu_submission_futures> futures = { ggml_backend_webgpu_submit(ctx, { command },
759 ctx->memset_buf_pool) };
760 ggml_backend_webgpu_wait(ctx, futures);
761}
762
763/** End WebGPU Actions */
764
765/** GGML Backend Interface */
766
767static const char * ggml_backend_webgpu_name(ggml_backend_t backend) {
768 ggml_backend_webgpu_context * ctx = (ggml_backend_webgpu_context *) backend->context;
769 return ctx->name.c_str();
770}
771
772static void ggml_backend_webgpu_free(ggml_backend_t backend) {
773 ggml_backend_webgpu_context * ctx = (ggml_backend_webgpu_context *) backend->context;
774 WEBGPU_LOG_DEBUG("ggml_backend_webgpu_free(" << ctx->name << ")");
775
776#ifdef GGML_WEBGPU_CPU_PROFILE
777 std::cout << "\n[ggml_webgpu cpu profiling summary]\n";
778 double total_cpu = 0.0;
779 for (const auto & kv : ctx->webgpu_ctx->global_ctx->cpu_time_ms) {
780 total_cpu += kv.second;
781 }
782 std::cout << "ggml_webgpu: total cpu time: " << total_cpu << " ms\n";
783 std::cout << "ggml_webgpu: cpu breakdown:\n";
784 for (const auto & kv : ctx->webgpu_ctx->global_ctx->cpu_time_ms) {
785 double pct = (total_cpu > 0.0) ? (kv.second / total_cpu * 100.0) : 0.0;
786 std::cout << "ggml_webgpu: " << kv.first << ": " << kv.second << " ms (" << pct << "%)\n";
787 }
788 if (ctx->webgpu_ctx->global_ctx->cpu_detail_ms.size() > 0) {
789 std::cout << "ggml_webgpu: cpu detailed breakdown:\n";
790 }
791 for (const auto & kv : ctx->webgpu_ctx->global_ctx->cpu_detail_ms) {
792 double pct = (total_cpu > 0.0) ? (kv.second / total_cpu * 100.0) : 0.0;
793 std::cout << "ggml_webgpu: " << kv.first << ": " << kv.second << " ms (" << pct << "%)\n";
794 }
795#endif
796
797#ifdef GGML_WEBGPU_GPU_PROFILE
798 std::cout << "\n[ggml_webgpu gpu profiling summary]\n";
799 double total_gpu = 0.0;
800 for (const auto & kv : ctx->webgpu_ctx->global_ctx->shader_gpu_time_ms) {
801 total_gpu += kv.second;
802 }
803 std::cout << "ggml_webgpu: total gpu time (all shaders): " << total_gpu << " ms\n";
804 std::cout << "\nggml_webgpu: gpu breakdown:\n";
805 for (const auto & kv : ctx->webgpu_ctx->global_ctx->shader_gpu_time_ms) {
806 double pct = (total_gpu > 0.0) ? (kv.second / total_gpu * 100.0) : 0.0;
807 std::cout << "ggml_webgpu: " << kv.first << ": " << kv.second << " ms (" << pct << "%)\n";
808 }
809#endif
810
811#if defined(GGML_WEBGPU_CPU_PROFILE) && defined(GGML_WEBGPU_GPU_PROFILE)
812 std::cout << "ggml_webgpu: gpu/cpu ratio: " << (total_cpu > 0.0 ? total_gpu / total_cpu : 0.0) << "\n";
813#endif
814
815 delete ctx;
816 delete backend;
817}
818
819static size_t ggml_webgpu_tensor_offset(const ggml_tensor * tensor) {
820 return webgpu_tensor_offset(tensor) + tensor->view_offs;
821}
822
823static wgpu::Buffer ggml_webgpu_tensor_buf(const ggml_tensor * tensor) {
824 ggml_backend_webgpu_buffer_context * ctx = (ggml_backend_webgpu_buffer_context *) tensor->buffer->context;
825 return ctx->buffer;
826}
827
828static size_t ggml_webgpu_tensor_misalignment(webgpu_context & ctx, const ggml_tensor * t) {
829 size_t offset = ggml_webgpu_tensor_offset(t);
830 return offset & (ctx->global_ctx->capabilities.limits.minStorageBufferOffsetAlignment - 1);
831}
832
833static size_t ggml_webgpu_tensor_align_offset(webgpu_context & ctx, const ggml_tensor * t) {
834 size_t offset = ggml_webgpu_tensor_offset(t);
835 return offset & ~(ctx->global_ctx->capabilities.limits.minStorageBufferOffsetAlignment - 1);
836}
837
838static size_t ggml_webgpu_tensor_binding_size(webgpu_context & ctx, ggml_tensor * t) {
839 return ROUNDUP_POW2(ggml_nbytes(t) + ggml_webgpu_tensor_misalignment(ctx, t), WEBGPU_STORAGE_BUF_BINDING_MULT);
840}
841
842// Used to determine if two tensors are the same for in-place operations
843static bool ggml_webgpu_tensor_equal(ggml_tensor * a, ggml_tensor * b) {
844 return (ggml_webgpu_tensor_buf(a).Get() == ggml_webgpu_tensor_buf(b).Get()) &&
845 (ggml_webgpu_tensor_offset(a) == ggml_webgpu_tensor_offset(b));
846}
847
848// Used to determine if two tensors share the same buffer and their byte ranges overlap,
849static bool ggml_webgpu_tensor_overlap(ggml_tensor * a, ggml_tensor * b) {
850 return (ggml_webgpu_tensor_buf(a).Get() == ggml_webgpu_tensor_buf(b).Get()) &&
851 ggml_webgpu_tensor_offset(a) < (ggml_webgpu_tensor_offset(b) + ggml_nbytes(b)) &&
852 ggml_webgpu_tensor_offset(b) < (ggml_webgpu_tensor_offset(a) + ggml_nbytes(a));
853}
854
855struct binary_overlap_flags {
856 bool inplace; // src0 == dst
857 bool overlap; // src1 == dst
858};
859
860static binary_overlap_flags ggml_webgpu_detect_binary_overlap(ggml_tensor * src0,
861 ggml_tensor * src1,
862 ggml_tensor * dst) {
863 binary_overlap_flags flags = {};
864 flags.inplace = ggml_webgpu_tensor_equal(src0, dst);
865 flags.overlap = ggml_webgpu_tensor_overlap(src1, dst);
866
867 return flags;
868}
869
870static webgpu_command ggml_webgpu_cpy(webgpu_context & ctx, ggml_tensor * src, ggml_tensor * dst) {
871 uint32_t ne = (uint32_t) ggml_nelements(dst);
872
873 std::vector<uint32_t> params = {
874 ne, (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src) / ggml_type_size(src->type)),
875 (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, dst) / ggml_type_size(dst->type)),
876 // Convert byte-strides to element-strides
877 (uint32_t) (src->nb[0] / ggml_type_size(src->type)), (uint32_t) (src->nb[1] / ggml_type_size(src->type)),
878 (uint32_t) (src->nb[2] / ggml_type_size(src->type)), (uint32_t) (src->nb[3] / ggml_type_size(src->type)),
879 (uint32_t) (dst->nb[0] / ggml_type_size(dst->type)), (uint32_t) (dst->nb[1] / ggml_type_size(dst->type)),
880 (uint32_t) (dst->nb[2] / ggml_type_size(dst->type)), (uint32_t) (dst->nb[3] / ggml_type_size(dst->type)),
881 // Logical shapes
882 (uint32_t) src->ne[0], (uint32_t) src->ne[1], (uint32_t) src->ne[2], (uint32_t) dst->ne[0],
883 (uint32_t) dst->ne[1], (uint32_t) dst->ne[2]
884 };
885
886 std::vector<wgpu::BindGroupEntry> entries = {
887 { .binding = 0,
888 .buffer = ggml_webgpu_tensor_buf(src),
889 .offset = ggml_webgpu_tensor_align_offset(ctx, src),
890 .size = ggml_webgpu_tensor_binding_size(ctx, src) },
891 { .binding = 1,
892 .buffer = ggml_webgpu_tensor_buf(dst),
893 .offset = ggml_webgpu_tensor_align_offset(ctx, dst),
894 .size = ggml_webgpu_tensor_binding_size(ctx, dst) }
895 };
896
897 uint32_t wg_x = CEIL_DIV(ne, WEBGPU_MAX_WG_SIZE);
898 return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_buf_pool, ctx->cpy_pipelines[src->type][dst->type],
899 params, entries, wg_x);
900}
901
902static webgpu_command ggml_webgpu_pad(webgpu_context & ctx, ggml_tensor * src, ggml_tensor * dst) {
903 const bool circular = ggml_get_op_params_i32(dst, 8) != 0;
904
905 ggml_webgpu_pad_pipeline_key pipeline_key = { .circular = circular };
906 ggml_webgpu_pad_shader_lib_context shader_lib_ctx = {
907 .key = pipeline_key, .max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup
908 };
909
910 webgpu_pipeline pipeline;
911 auto it = ctx->pad_pipelines.find(pipeline_key);
912 if (it != ctx->pad_pipelines.end()) {
913 pipeline = it->second;
914 } else {
915 ggml_webgpu_processed_shader processed = ggml_webgpu_preprocess_pad_shader(ctx->p, wgsl_pad, shader_lib_ctx);
916 pipeline =
917 ggml_webgpu_create_pipeline(ctx->global_ctx->device, processed.wgsl.c_str(), processed.variant.c_str());
918 pipeline.context = processed.decisions;
919 ctx->pad_pipelines.emplace(pipeline_key, pipeline);
920 }
921
922 auto * decisions = static_cast<ggml_webgpu_generic_shader_decisions *>(pipeline.context.get());
923
924 const uint32_t ne = (uint32_t) ggml_nelements(dst);
925
926 std::vector<uint32_t> params = {
927 ne,
928 (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src) / ggml_type_size(src->type)),
929 (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, dst) / ggml_type_size(dst->type)),
930 // Strides (in elements)
931 (uint32_t) (src->nb[0] / ggml_type_size(src->type)),
932 (uint32_t) (src->nb[1] / ggml_type_size(src->type)),
933 (uint32_t) (src->nb[2] / ggml_type_size(src->type)),
934 (uint32_t) (src->nb[3] / ggml_type_size(src->type)),
935 // Shapes
936 (uint32_t) src->ne[0],
937 (uint32_t) src->ne[1],
938 (uint32_t) src->ne[2],
939 (uint32_t) src->ne[3],
940 (uint32_t) dst->ne[0],
941 (uint32_t) dst->ne[1],
942 (uint32_t) dst->ne[2],
943 (uint32_t) dst->ne[3],
944 // Pad sizes
945 (uint32_t) ggml_get_op_params_i32(dst, 0),
946 (uint32_t) ggml_get_op_params_i32(dst, 1),
947 (uint32_t) ggml_get_op_params_i32(dst, 2),
948 (uint32_t) ggml_get_op_params_i32(dst, 3),
949 (uint32_t) ggml_get_op_params_i32(dst, 4),
950 (uint32_t) ggml_get_op_params_i32(dst, 5),
951 (uint32_t) ggml_get_op_params_i32(dst, 6),
952 (uint32_t) ggml_get_op_params_i32(dst, 7),
953 };
954
955 std::vector<wgpu::BindGroupEntry> entries = {
956 { .binding = 0,
957 .buffer = ggml_webgpu_tensor_buf(src),
958 .offset = ggml_webgpu_tensor_align_offset(ctx, src),
959 .size = ggml_webgpu_tensor_binding_size(ctx, src) },
960 { .binding = 1,
961 .buffer = ggml_webgpu_tensor_buf(dst),
962 .offset = ggml_webgpu_tensor_align_offset(ctx, dst),
963 .size = ggml_webgpu_tensor_binding_size(ctx, dst) }
964 };
965
966 uint32_t wg_x = CEIL_DIV(ne, decisions->wg_size);
967 return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_buf_pool, pipeline, params, entries, wg_x);
968}
969
970static std::optional<webgpu_command> ggml_webgpu_set_rows(webgpu_context & ctx,
971 ggml_tensor * src,
972 ggml_tensor * idx,
973 ggml_tensor * dst) {
974 // For set rows specifically, we need to check if src and idx are empty tensors.
975 if (ggml_is_empty(src) || ggml_is_empty(idx)) {
976 return std::nullopt;
977 }
978
979 ggml_webgpu_set_rows_pipeline_key key = { .dst_type = dst->type,
980 .vec4 = src->ne[0] % 4 == 0,
981 .i64_idx = idx->type == GGML_TYPE_I64 };
982
983 ggml_webgpu_set_rows_shader_lib_context shader_lib_ctx = {
984 .key = key, .max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup
985 };
986
987 webgpu_pipeline pipeline;
988 auto it = ctx->set_rows_pipelines.find(key);
989 if (it != ctx->set_rows_pipelines.end()) {
990 pipeline = it->second;
991 } else {
992 ggml_webgpu_processed_shader processed =
993 ggml_webgpu_preprocess_set_rows_shader(ctx->p, wgsl_set_rows, shader_lib_ctx);
994 pipeline =
995 ggml_webgpu_create_pipeline(ctx->global_ctx->device, processed.wgsl.c_str(), processed.variant.c_str());
996 pipeline.context = processed.decisions;
997 ctx->set_rows_pipelines.emplace(key, pipeline);
998 }
999
1000 auto * decisions = static_cast<ggml_webgpu_generic_shader_decisions *>(pipeline.context.get());
1001
1002 std::optional<webgpu_pool_bufs> error_bufs = std::nullopt;
1003 if (key.i64_idx) {
1004 error_bufs = ctx->set_rows_error_buf_pool.alloc_bufs();
1005 if (error_bufs->host_buf.GetMapState() == wgpu::BufferMapState::Mapped) {
1006 error_bufs->host_buf.Unmap();
1007 }
1008 }
1009
1010 std::vector<uint32_t> params = {
1011 (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src) / ggml_type_size(src->type)),
1012 (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, idx) / ggml_type_size(idx->type)),
1013 (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, dst) / ggml_type_size(dst->type)),
1014 // Convert byte-strides to element-strides
1015 (uint32_t) (src->nb[1] / ggml_type_size(src->type)), (uint32_t) (src->nb[2] / ggml_type_size(src->type)),
1016 (uint32_t) (src->nb[3] / ggml_type_size(src->type)), (uint32_t) (idx->nb[0] / ggml_type_size(idx->type)),
1017 (uint32_t) (idx->nb[1] / ggml_type_size(idx->type)), (uint32_t) (idx->nb[2] / ggml_type_size(idx->type)),
1018 (uint32_t) (dst->nb[1] / ggml_type_size(dst->type)), (uint32_t) (dst->nb[2] / ggml_type_size(dst->type)),
1019 (uint32_t) (dst->nb[3] / ggml_type_size(dst->type)),
1020 // Shape of src
1021 (uint32_t) src->ne[0], (uint32_t) src->ne[1], (uint32_t) src->ne[2], (uint32_t) src->ne[3],
1022 // Shape of idx
1023 (uint32_t) (idx->ne[1]), (uint32_t) (idx->ne[2])
1024 };
1025
1026 std::vector<wgpu::BindGroupEntry> entries = {
1027 { .binding = 0,
1028 .buffer = ggml_webgpu_tensor_buf(src),
1029 .offset = ggml_webgpu_tensor_align_offset(ctx, src),
1030 .size = ggml_webgpu_tensor_binding_size(ctx, src) },
1031 { .binding = 1,
1032 .buffer = ggml_webgpu_tensor_buf(idx),
1033 .offset = ggml_webgpu_tensor_align_offset(ctx, idx),
1034 .size = ggml_webgpu_tensor_binding_size(ctx, idx) },
1035 { .binding = 2,
1036 .buffer = ggml_webgpu_tensor_buf(dst),
1037 .offset = ggml_webgpu_tensor_align_offset(ctx, dst),
1038 .size = ggml_webgpu_tensor_binding_size(ctx, dst) }
1039 };
1040
1041 if (key.i64_idx) {
1042 entries.push_back(
1043 { .binding = 3, .buffer = error_bufs->dev_buf, .offset = 0, .size = error_bufs->dev_buf.GetSize() });
1044 }
1045
1046 uint32_t threads;
1047 if (key.vec4) {
1048 threads = (src->ne[1] * src->ne[2] * src->ne[3]) * (src->ne[0] / 4);
1049 } else {
1050 threads = src->ne[0] * src->ne[1] * src->ne[2] * src->ne[3];
1051 }
1052 uint32_t wg_x = CEIL_DIV(threads, decisions->wg_size);
1053 return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_buf_pool, pipeline, params, entries, wg_x, 1,
1054 error_bufs);
1055}
1056
1057static webgpu_command ggml_webgpu_get_rows(webgpu_context & ctx,
1058 ggml_tensor * src,
1059 ggml_tensor * idx,
1060 ggml_tensor * dst) {
1061 std::vector<uint32_t> params = {
1062 (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src) / ggml_type_size(src->type)),
1063 (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, idx) / ggml_type_size(idx->type)),
1064 (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, dst) / ggml_type_size(dst->type)),
1065 // Convert byte-strides to element-strides
1066 (uint32_t) (src->nb[1] / ggml_type_size(src->type)), (uint32_t) (src->nb[2] / ggml_type_size(src->type)),
1067 (uint32_t) (src->nb[3] / ggml_type_size(src->type)), (uint32_t) (idx->nb[0] / ggml_type_size(idx->type)),
1068 (uint32_t) (idx->nb[1] / ggml_type_size(idx->type)), (uint32_t) (idx->nb[2] / ggml_type_size(idx->type)),
1069 (uint32_t) (dst->nb[1] / ggml_type_size(dst->type)), (uint32_t) (dst->nb[2] / ggml_type_size(dst->type)),
1070 (uint32_t) (dst->nb[3] / ggml_type_size(dst->type)),
1071 // Shape of dst
1072 (uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2], (uint32_t) dst->ne[3],
1073 // Shape of idx
1074 (uint32_t) (idx->ne[1]), (uint32_t) (idx->ne[2])
1075 };
1076
1077 std::vector<wgpu::BindGroupEntry> entries = {
1078 { .binding = 0,
1079 .buffer = ggml_webgpu_tensor_buf(src),
1080 .offset = ggml_webgpu_tensor_align_offset(ctx, src),
1081 .size = ggml_webgpu_tensor_binding_size(ctx, src) },
1082 { .binding = 1,
1083 .buffer = ggml_webgpu_tensor_buf(idx),
1084 .offset = ggml_webgpu_tensor_align_offset(ctx, idx),
1085 .size = ggml_webgpu_tensor_binding_size(ctx, idx) },
1086 { .binding = 2,
1087 .buffer = ggml_webgpu_tensor_buf(dst),
1088 .offset = ggml_webgpu_tensor_align_offset(ctx, dst),
1089 .size = ggml_webgpu_tensor_binding_size(ctx, dst) }
1090 };
1091
1092 uint32_t wg_x = CEIL_DIV(dst->ne[1] * dst->ne[2] * dst->ne[3], WEBGPU_MAX_WG_SIZE);
1093
1094 uint32_t vectorized = src->type == GGML_TYPE_F32 && dst->ne[0] % 4 == 0;
1095 webgpu_pipeline pipeline = ctx->get_rows_pipelines[src->type][vectorized];
1096 return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_buf_pool, pipeline, params, entries, wg_x);
1097}
1098
1099static webgpu_command ggml_webgpu_mul_mat(webgpu_context & ctx,
1100 ggml_tensor * src0,
1101 ggml_tensor * src1,
1102 ggml_tensor * dst) {
1103 std::vector<uint32_t> params = {
1104 (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src0) / ggml_type_size(src0->type)),
1105 (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src1) / ggml_type_size(src1->type)),
1106 (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, dst) / ggml_type_size(dst->type)),
1107 (uint32_t) dst->ne[0], // number of rows in result (M, transposed)
1108 (uint32_t) dst->ne[1], // number of columns in result (N)
1109 (uint32_t) src0->ne[0], // number of columns in src0/src1 (K)
1110 (uint32_t) (src0->nb[1] / ggml_type_size(src0->type)), // stride (elements/blocks) of src0 in dimension 1
1111 (uint32_t) (src1->nb[1] / ggml_type_size(src1->type)), // stride (elements/blocks) of src1 in dimension 1
1112 (uint32_t) (src0->nb[2] / ggml_type_size(src0->type)), // stride (elements/blocks) of src0 in dimension 2
1113 (uint32_t) (src1->nb[2] / ggml_type_size(src1->type)), // stride (elements/blocks) of src1 in dimension 2
1114 (uint32_t) (src0->nb[3] / ggml_type_size(src0->type)), // stride (elements/blocks) of src0 in dimension 3
1115 (uint32_t) (src1->nb[3] / ggml_type_size(src1->type)), // stride (elements/blocks) of src1 in dimension 3
1116 (uint32_t) src0->ne[2], // batch size in dimension 2
1117 (uint32_t) src0->ne[3], // batch size in dimension 3
1118 (uint32_t) (src1->ne[2] / src0->ne[2]), // broadcast in dimension 2
1119 (uint32_t) (src1->ne[3] / src0->ne[3]) // broadcast in dimension 3
1120 };
1121
1122 std::vector<wgpu::BindGroupEntry> entries = {
1123 { .binding = 0,
1124 .buffer = ggml_webgpu_tensor_buf(src0),
1125 .offset = ggml_webgpu_tensor_align_offset(ctx, src0),
1126 .size = ggml_webgpu_tensor_binding_size(ctx, src0) },
1127 { .binding = 1,
1128 .buffer = ggml_webgpu_tensor_buf(src1),
1129 .offset = ggml_webgpu_tensor_align_offset(ctx, src1),
1130 .size = ggml_webgpu_tensor_binding_size(ctx, src1) },
1131 { .binding = 2,
1132 .buffer = ggml_webgpu_tensor_buf(dst),
1133 .offset = ggml_webgpu_tensor_align_offset(ctx, dst),
1134 .size = ggml_webgpu_tensor_binding_size(ctx, dst) },
1135 };
1136
1137 webgpu_pipeline pipeline = ctx->mul_mat_pipelines[src0->type][src1->type][0];
1138
1139 uint32_t wg_x = CEIL_DIV(dst->ne[0] * dst->ne[1] * dst->ne[2] * dst->ne[3], WEBGPU_MUL_MAT_WG_SIZE);
1140 uint32_t wg_y = 1;
1141
1142 bool use_fast = false;
1143 switch (src1->type) {
1144 case GGML_TYPE_F16:
1145 use_fast = (src0->type == GGML_TYPE_F16);
1146 break;
1147 case GGML_TYPE_F32:
1148 switch (src0->type) {
1149 case GGML_TYPE_F32:
1150 case GGML_TYPE_F16:
1151 case GGML_TYPE_Q4_0:
1152 use_fast = true;
1153 break;
1154 default:
1155 break;
1156 }
1157 break;
1158 default:
1159 break;
1160 }
1161
1162 if (use_fast) {
1163 int vectorized = src0->ne[0] % 4 == 0 && dst->ne[0] % 4 == 0 && dst->ne[1] % 4 == 0;
1164 if (dst->ne[1] == 1) {
1165 // We don't support vectorized mul_mat_vec for quantized types
1166 vectorized = vectorized && (src0->type < 2);
1167 pipeline = ctx->mul_mat_vec_pipelines[src0->type][src1->type][vectorized];
1168 uint32_t batches = dst->ne[2] * dst->ne[3];
1169 uint32_t output_groups = CEIL_DIV(dst->ne[0], WEBGPU_MUL_MAT_VEC_OUTPUTS_PER_WG);
1170 uint32_t total_wg = output_groups * batches;
1171 wg_x = total_wg % ctx->global_ctx->capabilities.limits.maxComputeWorkgroupsPerDimension;
1172 wg_y = CEIL_DIV(total_wg, ctx->global_ctx->capabilities.limits.maxComputeWorkgroupsPerDimension);
1173 } else {
1174 pipeline = ctx->mul_mat_pipelines[src0->type][src1->type][vectorized];
1175 uint32_t wg_m;
1176 uint32_t wg_n;
1177#ifndef __EMSCRIPTEN__
1178 if (ctx->global_ctx->capabilities.supports_subgroup_matrix) {
1179 // The total number of subgroups/workgroups needed per matrix.
1180 uint32_t wg_m_sg_tile = WEBGPU_MUL_MAT_SUBGROUP_M * WEBGPU_MUL_MAT_SUBGROUP_MATRIX_M *
1181 ctx->global_ctx->capabilities.sg_mat_m;
1182 wg_m = CEIL_DIV(dst->ne[0], wg_m_sg_tile);
1183 uint32_t wg_n_sg_tile = WEBGPU_MUL_MAT_SUBGROUP_N * WEBGPU_MUL_MAT_SUBGROUP_MATRIX_N *
1184 ctx->global_ctx->capabilities.sg_mat_n;
1185 wg_n = CEIL_DIV(dst->ne[1], wg_n_sg_tile);
1186 } else {
1187#endif
1188 uint32_t tile_m_s = WEBGPU_MUL_MAT_TILE_M * WEBGPU_MUL_MAT_WG_SIZE_M;
1189 uint32_t tile_n_s = WEBGPU_MUL_MAT_TILE_N * WEBGPU_MUL_MAT_WG_SIZE_N;
1190 wg_m = CEIL_DIV(dst->ne[0], tile_m_s);
1191 wg_n = CEIL_DIV(dst->ne[1], tile_n_s);
1192#ifndef __EMSCRIPTEN__
1193 }
1194#endif
1195
1196 wg_x = wg_m * wg_n * dst->ne[2] * dst->ne[3];
1197 }
1198 }
1199 return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_buf_pool, pipeline, params, entries, wg_x, wg_y);
1200}
1201
1202#ifndef __EMSCRIPTEN__
1203static webgpu_command ggml_webgpu_flash_attn(webgpu_context & ctx,
1204 ggml_tensor * Q,
1205 ggml_tensor * K,
1206 ggml_tensor * V,
1207 ggml_tensor * mask,
1208 ggml_tensor * sinks,
1209 ggml_tensor * dst) {
1210 float scale = *(float *) dst->op_params;
1211 float max_bias;
1212 memcpy(&max_bias, (float *) dst->op_params + 1, sizeof(float));
1213 float logit_softcap;
1214 memcpy(&logit_softcap, (float *) dst->op_params + 2, sizeof(float));
1215 if (logit_softcap != 0.0f) {
1216 scale /= logit_softcap;
1217 }
1218 float n_head_log2 = float(1u << (uint32_t) floor(log2(Q->ne[2])));
1219 float m0 = powf(2.0f, -(max_bias) / n_head_log2);
1220 float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
1221
1222 const int has_mask = (mask != nullptr);
1223 const int has_sinks = (sinks != nullptr);
1224
1225 std::vector<uint32_t> params = {
1226 (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, Q) / ggml_type_size(Q->type)),
1227 (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, K) / ggml_type_size(K->type)),
1228 (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, V) / ggml_type_size(V->type)),
1229 has_mask ? (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, mask) / ggml_type_size(mask->type)) : 0,
1230 has_sinks ? (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, sinks) / ggml_type_size(sinks->type)) : 0,
1231 (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, dst) / ggml_type_size(dst->type)),
1232 (uint32_t) Q->ne[2], // number of heads
1233 (uint32_t) Q->ne[1], // sequence length (Q)
1234 (uint32_t) K->ne[1], // sequence length (K/V)
1235 (uint32_t) (Q->nb[1] / ggml_type_size(Q->type)), // stride (elements/blocks) of Q in dimension 1
1236 (uint32_t) (Q->nb[2] / ggml_type_size(Q->type)), // stride (elements/blocks) of Q in dimension 2
1237 (uint32_t) (Q->nb[3] / ggml_type_size(Q->type)), // stride (elements/blocks) of Q in dimension 3
1238 (uint32_t) (K->nb[1] / ggml_type_size(K->type)), // stride (elements/blocks) of K in dimension 1
1239 (uint32_t) (K->nb[2] / ggml_type_size(K->type)), // stride (elements/blocks) of K in dimension 2
1240 (uint32_t) (K->nb[3] / ggml_type_size(K->type)), // stride (elements/blocks) of K in dimension 3
1241 (uint32_t) (V->nb[1] / ggml_type_size(V->type)), // stride (elements/blocks) of V in dimension 1
1242 (uint32_t) (V->nb[2] / ggml_type_size(V->type)), // stride (elements/blocks) of V in dimension 2
1243 (uint32_t) (V->nb[3] / ggml_type_size(V->type)), // stride (elements/blocks) of V in dimension 3
1244 has_mask ? (uint32_t) (mask->nb[3] / ggml_type_size(mask->type)) : 0, // stride of mask dim 3
1245 (uint32_t) (Q->ne[2] / K->ne[2]), // repeat factor for K/V in dim 2 (MHA/MQA/GQA)
1246 *(uint32_t *) &scale, // scale (possibly adjusted for logit softcap)
1247 *(uint32_t *) &max_bias,
1248 *(uint32_t *) &logit_softcap,
1249 *(uint32_t *) &n_head_log2,
1250 *(uint32_t *) &m0,
1251 *(uint32_t *) &m1
1252
1253 };
1254 std::vector<wgpu::BindGroupEntry> entries = {
1255 { .binding = 0,
1256 .buffer = ggml_webgpu_tensor_buf(Q),
1257 .offset = ggml_webgpu_tensor_align_offset(ctx, Q),
1258 .size = ggml_webgpu_tensor_binding_size(ctx, Q) },
1259 { .binding = 1,
1260 .buffer = ggml_webgpu_tensor_buf(K),
1261 .offset = ggml_webgpu_tensor_align_offset(ctx, K),
1262 .size = ggml_webgpu_tensor_binding_size(ctx, K) },
1263 { .binding = 2,
1264 .buffer = ggml_webgpu_tensor_buf(V),
1265 .offset = ggml_webgpu_tensor_align_offset(ctx, V),
1266 .size = ggml_webgpu_tensor_binding_size(ctx, V) }
1267 };
1268 uint32_t binding_index = 3;
1269 if (has_mask) {
1270 entries.push_back({ .binding = binding_index++,
1271 .buffer = ggml_webgpu_tensor_buf(mask),
1272 .offset = ggml_webgpu_tensor_align_offset(ctx, mask),
1273 .size = ggml_webgpu_tensor_binding_size(ctx, mask) });
1274 }
1275 if (has_sinks) {
1276 entries.push_back({ .binding = binding_index++,
1277 .buffer = ggml_webgpu_tensor_buf(sinks),
1278 .offset = ggml_webgpu_tensor_align_offset(ctx, sinks),
1279 .size = ggml_webgpu_tensor_binding_size(ctx, sinks) });
1280 }
1281 entries.push_back({ .binding = binding_index++,
1282 .buffer = ggml_webgpu_tensor_buf(dst),
1283 .offset = ggml_webgpu_tensor_align_offset(ctx, dst),
1284 .size = ggml_webgpu_tensor_binding_size(ctx, dst) });
1285
1286 bool kv_direct = (K->type == GGML_TYPE_F16) && (Q->ne[0] % ctx->global_ctx->capabilities.sg_mat_k == 0) &&
1287 (K->ne[1] % GGML_WEBGPU_KV_SEQ_PAD == 0);
1288
1289 ggml_webgpu_flash_attn_pipeline_key key = {
1290 .kv_type = K->type,
1291 .head_dim_qk = (uint32_t) Q->ne[0],
1292 .head_dim_v = (uint32_t) V->ne[0],
1293 .kv_direct = kv_direct,
1294 .has_mask = static_cast<bool>(has_mask),
1295 .has_sinks = static_cast<bool>(has_sinks),
1296 .uses_logit_softcap = logit_softcap != 0.0f,
1297 };
1298
1299 webgpu_pipeline pipeline;
1300 auto it = ctx->flash_attn_pipelines.find(key);
1301 if (it != ctx->flash_attn_pipelines.end()) {
1302 pipeline = it->second;
1303 } else {
1304 ggml_webgpu_flash_attn_shader_lib_context shader_lib_ctx = {
1305 .key = key,
1306 .sg_mat_m = ctx->global_ctx->capabilities.sg_mat_m,
1307 .sg_mat_n = ctx->global_ctx->capabilities.sg_mat_n,
1308 .sg_mat_k = ctx->global_ctx->capabilities.sg_mat_k,
1309 .wg_mem_limit_bytes = ctx->global_ctx->capabilities.limits.maxComputeWorkgroupStorageSize,
1310 .max_subgroup_size = ctx->global_ctx->capabilities.max_subgroup_size
1311 };
1312
1313 ggml_webgpu_processed_shader processed =
1314 ggml_webgpu_preprocess_flash_attn_shader(ctx->p, wgsl_flash_attn, shader_lib_ctx);
1315 pipeline =
1316 ggml_webgpu_create_pipeline(ctx->global_ctx->device, processed.wgsl.c_str(), processed.variant.c_str());
1317 pipeline.context = processed.decisions;
1318 ctx->flash_attn_pipelines.emplace(key, pipeline);
1319 }
1320
1321 auto * decisions = static_cast<ggml_webgpu_flash_attn_shader_decisions *>(pipeline.context.get());
1322
1323 uint32_t wg_per_head = CEIL_DIV(Q->ne[1], decisions->q_tile);
1324 uint32_t wg_x = wg_per_head * Q->ne[2] * Q->ne[3]; // wg per head * number of heads * number of batches
1325 return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_buf_pool, pipeline, params, entries, wg_x);
1326}
1327#endif
1328
1329static webgpu_command ggml_webgpu_unary_op(webgpu_context & ctx, ggml_tensor * src, ggml_tensor * dst) {
1330 bool is_unary = dst->op == GGML_OP_UNARY;
1331 bool inplace = ggml_webgpu_tensor_equal(src, dst) || (dst->op == GGML_OP_FILL);
1332 int op = is_unary ? (int) ggml_get_unary_op(dst) : dst->op;
1333
1334 ggml_webgpu_unary_pipeline_key pipeline_key = {
1335 .type = dst->type, .op = op, .is_unary = is_unary, .inplace = inplace
1336 };
1337 ggml_webgpu_unary_shader_lib_context shader_lib_ctx = {
1338 .key = pipeline_key, .max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup
1339 };
1340
1341 webgpu_pipeline pipeline;
1342 auto it = ctx->unary_pipelines.find(pipeline_key);
1343 if (it != ctx->unary_pipelines.end()) {
1344 pipeline = it->second;
1345 } else {
1346 ggml_webgpu_processed_shader processed =
1347 ggml_webgpu_preprocess_unary_shader(ctx->p, wgsl_unary, shader_lib_ctx);
1348 pipeline =
1349 ggml_webgpu_create_pipeline(ctx->global_ctx->device, processed.wgsl.c_str(), processed.variant.c_str());
1350 pipeline.context = processed.decisions;
1351 ctx->unary_pipelines.emplace(pipeline_key, pipeline);
1352 }
1353
1354 auto * decisions = static_cast<ggml_webgpu_generic_shader_decisions *>(pipeline.context.get());
1355
1356 uint32_t ne = (uint32_t) ggml_nelements(dst);
1357
1358 std::vector<uint32_t> params = { ne,
1359 (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src) / ggml_type_size(src->type)),
1360 (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, dst) / ggml_type_size(dst->type)),
1361 (uint32_t) (src->nb[0] / ggml_type_size(src->type)),
1362 (uint32_t) (src->nb[1] / ggml_type_size(src->type)),
1363 (uint32_t) (src->nb[2] / ggml_type_size(src->type)),
1364 (uint32_t) (src->nb[3] / ggml_type_size(src->type)),
1365 (uint32_t) src->ne[0],
1366 (uint32_t) src->ne[1],
1367 (uint32_t) src->ne[2] };
1368
1369 ggml_tensor * effective_src = src;
1370 if (is_unary) {
1371 ggml_unary_op unary_op = ggml_get_unary_op(dst);
1372 switch (unary_op) {
1373 case GGML_UNARY_OP_XIELU:
1374 {
1375 // Get float parameters and reinterpret their bit patterns as uint32_t
1376 // for passing through the params buffer
1377 float alpha_n = ggml_get_op_params_f32(dst, 1);
1378 float alpha_p = ggml_get_op_params_f32(dst, 2);
1379 float beta = ggml_get_op_params_f32(dst, 3);
1380 float eps = ggml_get_op_params_f32(dst, 4);
1381 params.push_back(*reinterpret_cast<const uint32_t *>(&alpha_n));
1382 params.push_back(*reinterpret_cast<const uint32_t *>(&alpha_p));
1383 params.push_back(*reinterpret_cast<const uint32_t *>(&beta));
1384 params.push_back(*reinterpret_cast<const uint32_t *>(&eps));
1385 break;
1386 }
1387 default:
1388 break;
1389 }
1390 } else if (dst->op == GGML_OP_CLAMP) {
1391 float clamp_min = ggml_get_op_params_f32(dst, 0);
1392 float clamp_max = ggml_get_op_params_f32(dst, 1);
1393 params.push_back(*reinterpret_cast<const uint32_t *>(&clamp_min));
1394 params.push_back(*reinterpret_cast<const uint32_t *>(&clamp_max));
1395 } else if (dst->op == GGML_OP_FILL) {
1396 float fill_val = ggml_get_op_params_f32(dst, 0);
1397 params.push_back(*reinterpret_cast<const uint32_t *>(&fill_val));
1398 effective_src = dst; // fill simply fills dst
1399 }
1400
1401 std::vector<wgpu::BindGroupEntry> entries = {
1402 { .binding = 0,
1403 .buffer = ggml_webgpu_tensor_buf(effective_src),
1404 .offset = ggml_webgpu_tensor_align_offset(ctx, effective_src),
1405 .size = ggml_webgpu_tensor_binding_size(ctx, effective_src) },
1406 };
1407 if (!inplace) {
1408 entries.push_back({ .binding = 1,
1409 .buffer = ggml_webgpu_tensor_buf(dst),
1410 .offset = ggml_webgpu_tensor_align_offset(ctx, dst),
1411 .size = ggml_webgpu_tensor_binding_size(ctx, dst) });
1412 }
1413
1414 uint32_t wg_x = CEIL_DIV(ne, decisions->wg_size);
1415 return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_buf_pool, pipeline, params, entries, wg_x);
1416}
1417
1418static webgpu_command ggml_webgpu_binary_op(webgpu_context & ctx,
1419 ggml_tensor * src0,
1420 ggml_tensor * src1,
1421 ggml_tensor * dst) {
1422 binary_overlap_flags flags = ggml_webgpu_detect_binary_overlap(src0, src1, dst);
1423
1424 ggml_webgpu_binary_pipeline_key pipeline_key = {
1425 .type = dst->type,
1426 .op = dst->op,
1427 .inplace = flags.inplace,
1428 .overlap = flags.overlap,
1429 };
1430 ggml_webgpu_binary_shader_lib_context shader_lib_ctx = {
1431 .key = pipeline_key, .max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup
1432 };
1433
1434 webgpu_pipeline pipeline;
1435 auto it = ctx->binary_pipelines.find(pipeline_key);
1436 if (it != ctx->binary_pipelines.end()) {
1437 pipeline = it->second;
1438 } else {
1439 ggml_webgpu_processed_shader processed =
1440 ggml_webgpu_preprocess_binary_shader(ctx->p, wgsl_binary, shader_lib_ctx);
1441 pipeline =
1442 ggml_webgpu_create_pipeline(ctx->global_ctx->device, processed.wgsl.c_str(), processed.variant.c_str());
1443 pipeline.context = processed.decisions;
1444 ctx->binary_pipelines.emplace(pipeline_key, pipeline);
1445 }
1446
1447 auto * decisions = static_cast<ggml_webgpu_argsort_shader_decisions *>(pipeline.context.get());
1448
1449 uint32_t ne = (uint32_t) ggml_nelements(dst);
1450
1451 std::vector<uint32_t> params = {
1452 ne,
1453 (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src0) / ggml_type_size(src0->type)),
1454 (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src1) / ggml_type_size(src1->type)),
1455 (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, dst) / ggml_type_size(dst->type)),
1456 (uint32_t) (src1->nb[0] / ggml_type_size(src1->type)),
1457 (uint32_t) (src1->nb[1] / ggml_type_size(src1->type)),
1458 (uint32_t) (src1->nb[2] / ggml_type_size(src1->type)),
1459 (uint32_t) (src1->nb[3] / ggml_type_size(src1->type)),
1460 (uint32_t) src0->ne[0],
1461 (uint32_t) src0->ne[1],
1462 (uint32_t) src0->ne[2],
1463 (uint32_t) src1->ne[0],
1464 (uint32_t) src1->ne[1],
1465 (uint32_t) src1->ne[2],
1466 (uint32_t) src1->ne[3],
1467 };
1468
1469 std::vector<wgpu::BindGroupEntry> entries;
1470
1471 entries.push_back({
1472 .binding = 0,
1473 .buffer = ggml_webgpu_tensor_buf(src0),
1474 .offset = ggml_webgpu_tensor_align_offset(ctx, src0),
1475 .size = ggml_webgpu_tensor_binding_size(ctx, src0),
1476 });
1477
1478 entries.push_back({
1479 .binding = 1,
1480 .buffer = ggml_webgpu_tensor_buf(src1),
1481 .offset = ggml_webgpu_tensor_align_offset(ctx, src1),
1482 .size = ggml_webgpu_tensor_binding_size(ctx, src1),
1483 });
1484
1485 if (!flags.inplace && !flags.overlap) {
1486 entries.push_back({ .binding = 2,
1487 .buffer = ggml_webgpu_tensor_buf(dst),
1488 .offset = ggml_webgpu_tensor_align_offset(ctx, dst),
1489 .size = ggml_webgpu_tensor_binding_size(ctx, dst) });
1490 }
1491
1492 uint32_t wg_x = CEIL_DIV(ne, decisions->wg_size);
1493 return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_buf_pool, pipeline, params, entries, wg_x);
1494}
1495
1496static webgpu_command ggml_webgpu_rms_norm(webgpu_context & ctx, ggml_tensor * src, ggml_tensor * dst) {
1497 int inplace = ggml_webgpu_tensor_equal(src, dst);
1498
1499 std::vector<uint32_t> params = {
1500 (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src) / ggml_type_size(src->type)),
1501 (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, dst) / ggml_type_size(dst->type)),
1502 (uint32_t) (src->nb[1] / ggml_type_size(src->type)),
1503 (uint32_t) (src->nb[2] / ggml_type_size(src->type)),
1504 (uint32_t) (src->nb[3] / ggml_type_size(src->type)),
1505 (uint32_t) (dst->nb[1] / ggml_type_size(dst->type)),
1506 (uint32_t) (dst->nb[2] / ggml_type_size(dst->type)),
1507 (uint32_t) (dst->nb[3] / ggml_type_size(dst->type)),
1508 (uint32_t) src->ne[0],
1509 (uint32_t) src->ne[1],
1510 (uint32_t) src->ne[2],
1511 (uint32_t) src->ne[3],
1512 *(uint32_t *) dst->op_params // epsilon, treated as f32 in the shader
1513 };
1514
1515 std::vector<wgpu::BindGroupEntry> entries = {
1516 { .binding = 0,
1517 .buffer = ggml_webgpu_tensor_buf(src),
1518 .offset = ggml_webgpu_tensor_align_offset(ctx, src),
1519 .size = ggml_webgpu_tensor_binding_size(ctx, src) }
1520 };
1521 if (!inplace) {
1522 entries.push_back({ .binding = 1,
1523 .buffer = ggml_webgpu_tensor_buf(dst),
1524 .offset = ggml_webgpu_tensor_align_offset(ctx, dst),
1525 .size = ggml_webgpu_tensor_binding_size(ctx, dst) });
1526 }
1527
1528 return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_buf_pool, ctx->rms_norm_pipelines[inplace], params,
1529 entries, ggml_nrows(src));
1530}
1531
1532static webgpu_command ggml_webgpu_rope(webgpu_context & ctx,
1533 ggml_tensor * src0,
1534 ggml_tensor * src1,
1535 ggml_tensor * src2,
1536 ggml_tensor * dst) {
1537 const int inplace = ggml_webgpu_tensor_equal(src0, dst);
1538 const int has_freq_factor = (src2 != nullptr);
1539
1540 const int n_dims = ((int32_t *) dst->op_params)[1];
1541 const int mode = ((int32_t *) dst->op_params)[2];
1542 const int n_ctx_orig = ((int32_t *) dst->op_params)[4];
1543
1544 float freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow;
1545 memcpy(&freq_base, (int32_t *) dst->op_params + 5, sizeof(float));
1546 memcpy(&freq_scale, (int32_t *) dst->op_params + 6, sizeof(float));
1547 memcpy(&ext_factor, (int32_t *) dst->op_params + 7, sizeof(float));
1548 memcpy(&attn_factor, (int32_t *) dst->op_params + 8, sizeof(float));
1549 memcpy(&beta_fast, (int32_t *) dst->op_params + 9, sizeof(float));
1550 memcpy(&beta_slow, (int32_t *) dst->op_params + 10, sizeof(float));
1551
1552 int sections[4];
1553 memcpy(sections, (int32_t *) dst->op_params + 11, 4 * sizeof(int));
1554
1555 float theta_scale = powf(freq_base, -2.0f / n_dims);
1556
1557 float corr_dims[2];
1558 ggml_rope_yarn_corr_dims(n_dims, n_ctx_orig, freq_base, beta_fast, beta_slow, corr_dims);
1559
1560 std::vector<uint32_t> params = {
1561 (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src0) / ggml_type_size(src0->type)),
1562 (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src1) / ggml_type_size(src1->type)),
1563 src2 != nullptr ? (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src2) / ggml_type_size(src2->type)) : 0,
1564 (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, dst) / ggml_type_size(dst->type)),
1565 (uint32_t) (src0->nb[1] / ggml_type_size(src0->type)),
1566 (uint32_t) (src0->nb[2] / ggml_type_size(src0->type)),
1567 (uint32_t) (src0->nb[3] / ggml_type_size(src0->type)),
1568 (uint32_t) (dst->nb[1] / ggml_type_size(dst->type)),
1569 (uint32_t) (dst->nb[2] / ggml_type_size(dst->type)),
1570 (uint32_t) (dst->nb[3] / ggml_type_size(dst->type)),
1571 (uint32_t) ggml_nelements(src0) / 2,
1572 (uint32_t) src0->ne[0],
1573 (uint32_t) src0->ne[1],
1574 (uint32_t) src0->ne[2],
1575 (uint32_t) n_dims,
1576 (uint32_t) mode,
1577 *(uint32_t *) &theta_scale,
1578 *(uint32_t *) &attn_factor,
1579 *(uint32_t *) &freq_scale,
1580 *(uint32_t *) &ext_factor,
1581 *(uint32_t *) &corr_dims[0],
1582 *(uint32_t *) &corr_dims[1],
1583 (uint32_t) sections[0],
1584 (uint32_t) sections[1],
1585 (uint32_t) sections[2],
1586 (uint32_t) sections[3]
1587 };
1588
1589 std::vector<wgpu::BindGroupEntry> entries = {
1590 { .binding = 0,
1591 .buffer = ggml_webgpu_tensor_buf(src0),
1592 .offset = ggml_webgpu_tensor_align_offset(ctx, src0),
1593 .size = ggml_webgpu_tensor_binding_size(ctx, src0) },
1594 { .binding = 1,
1595 .buffer = ggml_webgpu_tensor_buf(src1),
1596 .offset = ggml_webgpu_tensor_align_offset(ctx, src1),
1597 .size = ggml_webgpu_tensor_binding_size(ctx, src1) }
1598 };
1599 uint32_t dst_binding = 2;
1600 if (has_freq_factor) {
1601 dst_binding = 3;
1602 entries.push_back({ .binding = 2,
1603 .buffer = ggml_webgpu_tensor_buf(src2),
1604 .offset = ggml_webgpu_tensor_align_offset(ctx, src2),
1605 .size = ggml_webgpu_tensor_binding_size(ctx, src2) });
1606 }
1607 if (!inplace) {
1608 entries.push_back({ .binding = dst_binding,
1609 .buffer = ggml_webgpu_tensor_buf(dst),
1610 .offset = ggml_webgpu_tensor_align_offset(ctx, dst),
1611 .size = ggml_webgpu_tensor_binding_size(ctx, dst) });
1612 }
1613
1614 webgpu_pipeline pipeline = ctx->rope_pipelines[dst->type][has_freq_factor][inplace];
1615 uint32_t wg_x = CEIL_DIV(ggml_nelements(dst), WEBGPU_MAX_WG_SIZE);
1616 return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_buf_pool, pipeline, params, entries, wg_x);
1617}
1618
1619static webgpu_command ggml_webgpu_glu(webgpu_context & ctx, ggml_tensor * src0, ggml_tensor * src1, ggml_tensor * dst) {
1620 const int split = (src1 != nullptr);
1621
1622 std::vector<uint32_t> params = {
1623 (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src0) / ggml_type_size(src0->type)),
1624 src1 != nullptr ? (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src1) / ggml_type_size(src1->type)) : 0,
1625 (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, dst) / ggml_type_size(dst->type)),
1626 (uint32_t) (src0->nb[1] / ggml_type_size(src0->type)),
1627 (uint32_t) (src0->nb[2] / ggml_type_size(src0->type)),
1628 (uint32_t) (src0->nb[3] / ggml_type_size(src0->type)),
1629 src1 != nullptr ? (uint32_t) (src1->nb[1] / ggml_type_size(src1->type)) :
1630 (uint32_t) (src0->nb[1] / ggml_type_size(src0->type)),
1631 src1 != nullptr ? (uint32_t) (src1->nb[2] / ggml_type_size(src1->type)) :
1632 (uint32_t) (src0->nb[2] / ggml_type_size(src0->type)),
1633 src1 != nullptr ? (uint32_t) (src1->nb[3] / ggml_type_size(src1->type)) :
1634 (uint32_t) (src0->nb[3] / ggml_type_size(src0->type)),
1635 (uint32_t) (dst->nb[1] / ggml_type_size(dst->type)),
1636 (uint32_t) (dst->nb[2] / ggml_type_size(dst->type)),
1637 (uint32_t) (dst->nb[3] / ggml_type_size(dst->type)),
1638 (uint32_t) ggml_nelements(dst),
1639 (uint32_t) dst->ne[0],
1640 (uint32_t) dst->ne[1],
1641 (uint32_t) dst->ne[2],
1642 (uint32_t) ((int32_t *) dst->op_params)[1], // swapped
1643 *(uint32_t *) &dst->op_params[2], // alpha, for swiglu_oai
1644 *(uint32_t *) &dst->op_params[3], // limit, for swiglu_oai
1645 };
1646
1647 std::vector<wgpu::BindGroupEntry> entries = {
1648 { .binding = 0,
1649 .buffer = ggml_webgpu_tensor_buf(src0),
1650 .offset = ggml_webgpu_tensor_align_offset(ctx, src0),
1651 .size = ggml_webgpu_tensor_binding_size(ctx, src0) },
1652 };
1653 uint32_t dst_binding = 1;
1654 if (split) {
1655 dst_binding = 2;
1656 entries.push_back({ .binding = 1,
1657 .buffer = ggml_webgpu_tensor_buf(src1),
1658 .offset = ggml_webgpu_tensor_align_offset(ctx, src1),
1659 .size = ggml_webgpu_tensor_binding_size(ctx, src1) });
1660 }
1661 entries.push_back({ .binding = dst_binding,
1662 .buffer = ggml_webgpu_tensor_buf(dst),
1663 .offset = ggml_webgpu_tensor_align_offset(ctx, dst),
1664 .size = ggml_webgpu_tensor_binding_size(ctx, dst) });
1665
1666 webgpu_pipeline pipeline = ctx->glu_pipelines[ggml_get_glu_op(dst)][dst->type][split];
1667 uint32_t wg_x = CEIL_DIV(ggml_nelements(dst), WEBGPU_MAX_WG_SIZE);
1668 return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_buf_pool, pipeline, params, entries, wg_x);
1669}
1670
1671static webgpu_command ggml_webgpu_scale(webgpu_context & ctx, ggml_tensor * src, ggml_tensor * dst) {
1672 int inplace = ggml_webgpu_tensor_equal(src, dst);
1673
1674 std::vector<uint32_t> params = {
1675 (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src) / ggml_type_size(src->type)),
1676 (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, dst) / ggml_type_size(dst->type)),
1677 (uint32_t) (src->nb[1] / ggml_type_size(src->type)),
1678 (uint32_t) (src->nb[2] / ggml_type_size(src->type)),
1679 (uint32_t) (src->nb[3] / ggml_type_size(src->type)),
1680 (uint32_t) (dst->nb[1] / ggml_type_size(dst->type)),
1681 (uint32_t) (dst->nb[2] / ggml_type_size(dst->type)),
1682 (uint32_t) (dst->nb[3] / ggml_type_size(dst->type)),
1683 (uint32_t) ggml_nelements(dst),
1684 (uint32_t) src->ne[0],
1685 (uint32_t) src->ne[1],
1686 (uint32_t) src->ne[2],
1687 *(uint32_t *) dst->op_params, // scale
1688 *(uint32_t *) &dst->op_params[1] // bias
1689 };
1690
1691 std::vector<wgpu::BindGroupEntry> entries = {
1692 { .binding = 0,
1693 .buffer = ggml_webgpu_tensor_buf(src),
1694 .offset = ggml_webgpu_tensor_align_offset(ctx, src),
1695 .size = ggml_webgpu_tensor_binding_size(ctx, src) }
1696 };
1697 if (!inplace) {
1698 entries.push_back({ .binding = 1,
1699 .buffer = ggml_webgpu_tensor_buf(dst),
1700 .offset = ggml_webgpu_tensor_align_offset(ctx, dst),
1701 .size = ggml_webgpu_tensor_binding_size(ctx, dst) });
1702 }
1703
1704 uint32_t wg_x = CEIL_DIV(ggml_nelements(dst), WEBGPU_MAX_WG_SIZE);
1705 return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_buf_pool, ctx->scale_pipelines[inplace], params,
1706 entries, wg_x);
1707}
1708
1709static webgpu_command ggml_webgpu_soft_max(webgpu_context & ctx,
1710 ggml_tensor * src0,
1711 ggml_tensor * src1,
1712 ggml_tensor * src2,
1713 ggml_tensor * dst) {
1714 const int inplace = ggml_webgpu_tensor_equal(src0, dst);
1715 const int mask_type = (src1 != nullptr) ? src1->type : 2; // use 2 for no mask here
1716 const int has_sink = (src2 != nullptr);
1717 float max_bias;
1718 memcpy(&max_bias, (float *) dst->op_params + 1, sizeof(float));
1719 float n_head_log2 = float(1u << (uint32_t) floor(log2(src0->ne[2])));
1720 float m0 = powf(2.0f, -(max_bias) / n_head_log2);
1721 float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
1722
1723 std::vector<uint32_t> params = {
1724 (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src0) / ggml_type_size(src0->type)),
1725 mask_type < 2 ? (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src1) / ggml_type_size(src1->type)) : 0,
1726 has_sink ? (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src2) / ggml_type_size(src2->type)) : 0,
1727 (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, dst) / ggml_type_size(dst->type)),
1728 (uint32_t) (src0->nb[1] / ggml_type_size(src0->type)),
1729 (uint32_t) (src0->nb[2] / ggml_type_size(src0->type)),
1730 (uint32_t) (src0->nb[3] / ggml_type_size(src0->type)),
1731 mask_type < 2 ? (uint32_t) (src1->nb[1] / ggml_type_size(src1->type)) : 0,
1732 mask_type < 2 ? (uint32_t) (src1->nb[2] / ggml_type_size(src1->type)) : 0,
1733 mask_type < 2 ? (uint32_t) (src1->nb[3] / ggml_type_size(src1->type)) : 0,
1734 (uint32_t) (dst->nb[1] / ggml_type_size(dst->type)),
1735 (uint32_t) (dst->nb[2] / ggml_type_size(dst->type)),
1736 (uint32_t) (dst->nb[3] / ggml_type_size(dst->type)),
1737 (uint32_t) ggml_nelements(dst),
1738 (uint32_t) src0->ne[0],
1739 (uint32_t) src0->ne[1],
1740 (uint32_t) src0->ne[2],
1741 mask_type < 2 ? (uint32_t) src1->ne[2] : 0,
1742 mask_type < 2 ? (uint32_t) src1->ne[3] : 0,
1743 *(uint32_t *) dst->op_params, // scale
1744 *(uint32_t *) &max_bias,
1745 *(uint32_t *) &n_head_log2,
1746 *(uint32_t *) &m0,
1747 *(uint32_t *) &m1
1748 };
1749
1750 std::vector<wgpu::BindGroupEntry> entries = {
1751 { .binding = 0,
1752 .buffer = ggml_webgpu_tensor_buf(src0),
1753 .offset = ggml_webgpu_tensor_align_offset(ctx, src0),
1754 .size = ggml_webgpu_tensor_binding_size(ctx, src0) }
1755 };
1756 uint32_t binding_num = 1;
1757 if (mask_type < 2) {
1758 entries.push_back({ .binding = binding_num,
1759 .buffer = ggml_webgpu_tensor_buf(src1),
1760 .offset = ggml_webgpu_tensor_align_offset(ctx, src1),
1761 .size = ggml_webgpu_tensor_binding_size(ctx, src1) });
1762 binding_num++;
1763 }
1764 if (has_sink) {
1765 entries.push_back({ .binding = binding_num,
1766 .buffer = ggml_webgpu_tensor_buf(src2),
1767 .offset = ggml_webgpu_tensor_align_offset(ctx, src2),
1768 .size = ggml_webgpu_tensor_binding_size(ctx, src2) });
1769 binding_num++;
1770 }
1771 if (!inplace) {
1772 entries.push_back({ .binding = binding_num,
1773 .buffer = ggml_webgpu_tensor_buf(dst),
1774 .offset = ggml_webgpu_tensor_align_offset(ctx, dst),
1775 .size = ggml_webgpu_tensor_binding_size(ctx, dst) });
1776 }
1777
1778 return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_buf_pool,
1779 ctx->soft_max_pipelines[mask_type][has_sink][inplace], params, entries,
1780 ggml_nrows(dst));
1781}
1782
1783static webgpu_command ggml_webgpu_argmax(webgpu_context & ctx, ggml_tensor * src, ggml_tensor * dst) {
1784 std::vector<uint32_t> params = { (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src) / ggml_type_size(src->type)),
1785 (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, dst) / ggml_type_size(dst->type)),
1786 (uint32_t) src->ne[0] };
1787
1788 std::vector<wgpu::BindGroupEntry> entries = {
1789 { .binding = 0,
1790 .buffer = ggml_webgpu_tensor_buf(src),
1791 .offset = ggml_webgpu_tensor_align_offset(ctx, src),
1792 .size = ggml_webgpu_tensor_binding_size(ctx, src) },
1793 { .binding = 1,
1794 .buffer = ggml_webgpu_tensor_buf(dst),
1795 .offset = ggml_webgpu_tensor_align_offset(ctx, dst),
1796 .size = ggml_webgpu_tensor_binding_size(ctx, dst) }
1797 };
1798
1799 ggml_webgpu_generic_shader_lib_context shader_lib_ctx = {
1800 .vec4 = src->ne[0] % 4 == 0,
1801 .max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup,
1802 };
1803
1804 webgpu_pipeline pipeline;
1805 auto it = ctx->argmax_pipelines.find(shader_lib_ctx.vec4);
1806 if (it != ctx->argmax_pipelines.end()) {
1807 pipeline = it->second;
1808 } else {
1809 ggml_webgpu_processed_shader processed =
1810 ggml_webgpu_preprocess_generic_shader(ctx->p, wgsl_argmax, shader_lib_ctx, "argmax");
1811 pipeline =
1812 ggml_webgpu_create_pipeline(ctx->global_ctx->device, processed.wgsl.c_str(), processed.variant.c_str());
1813 ctx->argmax_pipelines.emplace(shader_lib_ctx.vec4, pipeline);
1814 }
1815 uint32_t wg_x = ggml_nelements(dst);
1816 return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_buf_pool, pipeline, params, entries, wg_x);
1817}
1818
1819static webgpu_command ggml_webgpu_argsort(webgpu_context & ctx, ggml_tensor * src, ggml_tensor * dst) {
1820 bool is_top_k = dst->op == GGML_OP_TOP_K;
1821 // ascending order is 0, descending order is 1
1822 const int32_t order = is_top_k ? (int32_t) GGML_SORT_ORDER_DESC : (int32_t) ggml_get_op_params_i32(dst, 0);
1823
1824 ggml_webgpu_argsort_shader_lib_context shader_lib_ctx = {
1825 .max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup,
1826 .wg_mem_limit_bytes = ctx->global_ctx->capabilities.limits.maxComputeWorkgroupStorageSize,
1827 .order = order
1828 };
1829
1830 webgpu_pipeline argsort_pipeline;
1831 auto it = ctx->argsort_pipelines.find(order);
1832 if (it != ctx->argsort_pipelines.end()) {
1833 argsort_pipeline = it->second;
1834 } else {
1835 ggml_webgpu_processed_shader processed =
1836 ggml_webgpu_preprocess_argsort_shader(ctx->p, wgsl_argsort, shader_lib_ctx);
1837 argsort_pipeline =
1838 ggml_webgpu_create_pipeline(ctx->global_ctx->device, processed.wgsl.c_str(), processed.variant.c_str());
1839 argsort_pipeline.context = processed.decisions;
1840 ctx->argsort_pipelines.emplace(order, argsort_pipeline);
1841 }
1842 auto * argsort_decisions = static_cast<ggml_webgpu_argsort_shader_decisions *>(argsort_pipeline.context.get());
1843
1844 webgpu_pipeline argsort_merge_pipeline;
1845 it = ctx->argsort_merge_pipelines.find(order);
1846 if (it != ctx->argsort_merge_pipelines.end()) {
1847 argsort_merge_pipeline = it->second;
1848 } else {
1849 ggml_webgpu_processed_shader processed =
1850 ggml_webgpu_preprocess_argsort_merge_shader(ctx->p, wgsl_argsort_merge, shader_lib_ctx);
1851 argsort_merge_pipeline =
1852 ggml_webgpu_create_pipeline(ctx->global_ctx->device, processed.wgsl.c_str(), processed.variant.c_str());
1853 argsort_merge_pipeline.context = processed.decisions;
1854 ctx->argsort_merge_pipelines.emplace(order, argsort_merge_pipeline);
1855 }
1856
1857 const uint32_t src_ne0 = (uint32_t) src->ne[0];
1858 const uint32_t nrows = (uint32_t) ggml_nrows(src);
1859 const uint32_t npr = CEIL_DIV(src_ne0, argsort_decisions->wg_size);
1860 const uint32_t block_size =
1861 is_top_k ? std::min(argsort_decisions->wg_size, (uint32_t) dst->ne[0]) : argsort_decisions->wg_size;
1862 uint32_t out_ne0 = src_ne0;
1863 if (is_top_k) {
1864 if (npr > 1) {
1865 const uint32_t last_tile = src_ne0 - (npr - 1) * argsort_decisions->wg_size;
1866 out_ne0 = (npr - 1) * block_size + std::min(last_tile, block_size);
1867 } else {
1868 out_ne0 = block_size;
1869 }
1870 }
1871
1872 uint32_t merge_len = block_size;
1873 uint32_t merge_passes = 0;
1874 while (merge_len < out_ne0) {
1875 merge_len <<= 1;
1876 merge_passes++;
1877 }
1878
1879 const bool start_in_tmp = (merge_passes % 2) == 1;
1880
1881 const size_t dst_offset = ggml_webgpu_tensor_offset(dst);
1882 const size_t idx_nbytes = out_ne0 * ggml_nrows(dst) * sizeof(int32_t);
1883 const size_t tmp_offset =
1884 ROUNDUP_POW2(dst_offset + idx_nbytes, ctx->global_ctx->capabilities.limits.minStorageBufferOffsetAlignment);
1885 const size_t tmp_binding_size = ROUNDUP_POW2(idx_nbytes, WEBGPU_STORAGE_BUF_BINDING_MULT);
1886 const size_t dst_binding_size =
1887 ROUNDUP_POW2(idx_nbytes + ggml_webgpu_tensor_misalignment(ctx, dst), WEBGPU_STORAGE_BUF_BINDING_MULT);
1888
1889 const uint32_t offset_src = (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src) / ggml_type_size(src->type));
1890 const uint32_t offset_dst = (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, dst) / ggml_type_size(dst->type));
1891 const uint32_t offset_tmp = 0;
1892 const uint32_t stride_src1 = (uint32_t) (src->nb[1] / ggml_type_size(src->type));
1893 const uint32_t stride_src2 = (uint32_t) (src->nb[2] / ggml_type_size(src->type));
1894 const uint32_t stride_src3 = (uint32_t) (src->nb[3] / ggml_type_size(src->type));
1895 const uint32_t stride_idx1 = out_ne0;
1896 const uint32_t stride_idx2 = out_ne0 * (uint32_t) dst->ne[1];
1897 const uint32_t stride_idx3 = stride_idx2 * (uint32_t) dst->ne[2];
1898
1899 std::vector<webgpu_pipeline> pipelines;
1900 std::vector<std::vector<uint32_t>> params_list;
1901 std::vector<std::vector<wgpu::BindGroupEntry>> entries_list;
1902 std::vector<std::pair<uint32_t, uint32_t>> workgroups_list;
1903
1904 const uint32_t init_offset = start_in_tmp ? offset_tmp : offset_dst;
1905 const size_t init_align_offset = start_in_tmp ? tmp_offset : ggml_webgpu_tensor_align_offset(ctx, dst);
1906 const size_t init_binding_size = start_in_tmp ? tmp_binding_size : dst_binding_size;
1907
1908 std::vector<uint32_t> init_params = {
1909 offset_src, init_offset, stride_src1, stride_src2, stride_src3, stride_idx1,
1910 stride_idx2, stride_idx3, src_ne0, (uint32_t) src->ne[1], (uint32_t) src->ne[2], out_ne0,
1911 block_size, npr, nrows
1912 };
1913
1914 const uint32_t total_wg_init = npr * nrows;
1915 const uint32_t max_wg = ctx->global_ctx->capabilities.limits.maxComputeWorkgroupsPerDimension;
1916 const uint32_t wg_x_init = std::min(total_wg_init, max_wg);
1917 const uint32_t wg_y_init = CEIL_DIV(total_wg_init, wg_x_init);
1918 std::vector<wgpu::BindGroupEntry> init_entries = {
1919 { .binding = 0,
1920 .buffer = ggml_webgpu_tensor_buf(src),
1921 .offset = ggml_webgpu_tensor_align_offset(ctx, src),
1922 .size = ggml_webgpu_tensor_binding_size(ctx, src) },
1923 { .binding = 1, .buffer = ggml_webgpu_tensor_buf(dst), .offset = init_align_offset, .size = init_binding_size }
1924 };
1925
1926 pipelines.push_back(argsort_pipeline);
1927 params_list.push_back(std::move(init_params));
1928 entries_list.push_back(std::move(init_entries));
1929 workgroups_list.push_back({ wg_x_init, wg_y_init });
1930
1931 if (merge_passes == 0) {
1932 return ggml_backend_webgpu_build_multi(ctx->global_ctx, ctx->param_buf_pool, pipelines, params_list,
1933 entries_list, workgroups_list);
1934 }
1935
1936 bool in_is_tmp = start_in_tmp;
1937 uint32_t len = block_size;
1938 while (len < out_ne0) {
1939 const uint32_t nm = CEIL_DIV(out_ne0, 2 * len);
1940
1941 const bool out_is_tmp = !in_is_tmp;
1942 const uint32_t offset_in = in_is_tmp ? offset_tmp : offset_dst;
1943 const uint32_t offset_out = out_is_tmp ? offset_tmp : offset_dst;
1944 const size_t align_in = in_is_tmp ? tmp_offset : ggml_webgpu_tensor_align_offset(ctx, dst);
1945 const size_t align_out = out_is_tmp ? tmp_offset : ggml_webgpu_tensor_align_offset(ctx, dst);
1946 const size_t size_in = in_is_tmp ? tmp_binding_size : dst_binding_size;
1947 const size_t size_out = out_is_tmp ? tmp_binding_size : dst_binding_size;
1948 const uint32_t top_k_out = (is_top_k && nm == 1) ? (uint32_t) dst->ne[0] : out_ne0;
1949 const uint32_t stride_out1 = top_k_out;
1950 const uint32_t stride_out2 = top_k_out * (uint32_t) dst->ne[1];
1951 const uint32_t stride_out3 = stride_out2 * (uint32_t) dst->ne[2];
1952
1953 std::vector<uint32_t> merge_params = { offset_src,
1954 offset_in,
1955 offset_out,
1956 stride_src1,
1957 stride_src2,
1958 stride_src3,
1959 stride_idx1,
1960 stride_idx2,
1961 stride_idx3,
1962 stride_out1,
1963 stride_out2,
1964 stride_out3,
1965 out_ne0,
1966 (uint32_t) src->ne[1],
1967 (uint32_t) src->ne[2],
1968 top_k_out,
1969 len,
1970 nm,
1971 nrows };
1972
1973 std::vector<wgpu::BindGroupEntry> merge_entries = {
1974 { .binding = 0,
1975 .buffer = ggml_webgpu_tensor_buf(src),
1976 .offset = ggml_webgpu_tensor_align_offset(ctx, src),
1977 .size = ggml_webgpu_tensor_binding_size(ctx, src) },
1978 { .binding = 1, .buffer = ggml_webgpu_tensor_buf(dst), .offset = align_in, .size = size_in },
1979 { .binding = 2, .buffer = ggml_webgpu_tensor_buf(dst), .offset = align_out, .size = size_out }
1980 };
1981
1982 const uint32_t total_wg_merge = nm * nrows;
1983 const uint32_t wg_x_merge = std::min(total_wg_merge, max_wg);
1984 const uint32_t wg_y_merge = CEIL_DIV(total_wg_merge, wg_x_merge);
1985 workgroups_list.push_back({ wg_x_merge, wg_y_merge });
1986 pipelines.push_back(argsort_merge_pipeline);
1987 params_list.push_back(std::move(merge_params));
1988 entries_list.push_back(std::move(merge_entries));
1989
1990 len <<= 1;
1991 in_is_tmp = !in_is_tmp;
1992 }
1993
1994 return ggml_backend_webgpu_build_multi(ctx->global_ctx, ctx->param_buf_pool, pipelines, params_list, entries_list,
1995 workgroups_list);
1996}
1997
1998static webgpu_command ggml_webgpu_cumsum(webgpu_context & ctx, ggml_tensor * src, ggml_tensor * dst) {
1999 std::vector<uint32_t> params = { (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src) / ggml_type_size(src->type)),
2000 (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, dst) / ggml_type_size(dst->type)),
2001 (uint32_t) src->ne[0] };
2002
2003 std::vector<wgpu::BindGroupEntry> entries = {
2004 { .binding = 0,
2005 .buffer = ggml_webgpu_tensor_buf(src),
2006 .offset = ggml_webgpu_tensor_align_offset(ctx, src),
2007 .size = ggml_webgpu_tensor_binding_size(ctx, src) },
2008 { .binding = 1,
2009 .buffer = ggml_webgpu_tensor_buf(dst),
2010 .offset = ggml_webgpu_tensor_align_offset(ctx, dst),
2011 .size = ggml_webgpu_tensor_binding_size(ctx, dst) }
2012 };
2013
2014 ggml_webgpu_generic_shader_lib_context shader_lib_ctx = {
2015 .vec4 = false,
2016 .max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup,
2017 };
2018 webgpu_pipeline pipeline;
2019 auto it = ctx->cumsum_pipelines.find(1);
2020 if (it != ctx->cumsum_pipelines.end()) {
2021 pipeline = it->second;
2022 } else {
2023 ggml_webgpu_processed_shader processed =
2024 ggml_webgpu_preprocess_generic_shader(ctx->p, wgsl_cumsum, shader_lib_ctx, "cumsum");
2025 pipeline =
2026 ggml_webgpu_create_pipeline(ctx->global_ctx->device, processed.wgsl.c_str(), processed.variant.c_str());
2027 ctx->cumsum_pipelines.emplace(1, pipeline);
2028 }
2029 uint32_t wg_x = ggml_nrows(dst);
2030 return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_buf_pool, pipeline, params, entries, wg_x);
2031}
2032
2033static webgpu_command ggml_webgpu_sum_rows(webgpu_context & ctx, ggml_tensor * src, ggml_tensor * dst) {
2034 bool total_sum = dst->op == GGML_OP_SUM;
2035 std::vector<uint32_t> params = { (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src) / ggml_type_size(src->type)),
2036 (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, dst) / ggml_type_size(dst->type)),
2037 total_sum ? 0 : (uint32_t) (src->nb[1] / ggml_type_size(src->type)),
2038 total_sum ? 0 : (uint32_t) (src->nb[2] / ggml_type_size(src->type)),
2039 total_sum ? 0 : (uint32_t) (src->nb[3] / ggml_type_size(src->type)),
2040 total_sum ? static_cast<uint32_t>(ggml_nelements(src)) : (uint32_t) src->ne[0],
2041 total_sum ? 1 : (uint32_t) src->ne[1],
2042 total_sum ? 1 : (uint32_t) src->ne[2] };
2043
2044 std::vector<wgpu::BindGroupEntry> entries = {
2045 { .binding = 0,
2046 .buffer = ggml_webgpu_tensor_buf(src),
2047 .offset = ggml_webgpu_tensor_align_offset(ctx, src),
2048 .size = ggml_webgpu_tensor_binding_size(ctx, src) },
2049 { .binding = 1,
2050 .buffer = ggml_webgpu_tensor_buf(dst),
2051 .offset = ggml_webgpu_tensor_align_offset(ctx, dst),
2052 .size = ggml_webgpu_tensor_binding_size(ctx, dst) }
2053 };
2054
2055 ggml_webgpu_generic_shader_lib_context shader_lib_ctx = {
2056 .vec4 = false,
2057 .max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup,
2058 };
2059
2060 webgpu_pipeline pipeline;
2061 auto it = ctx->sum_rows_pipelines.find(1);
2062 if (it != ctx->sum_rows_pipelines.end()) {
2063 pipeline = it->second;
2064 } else {
2065 ggml_webgpu_processed_shader processed =
2066 ggml_webgpu_preprocess_generic_shader(ctx->p, wgsl_sum_rows, shader_lib_ctx, "sum_rows");
2067 pipeline =
2068 ggml_webgpu_create_pipeline(ctx->global_ctx->device, processed.wgsl.c_str(), processed.variant.c_str());
2069 ctx->sum_rows_pipelines.emplace(1, pipeline);
2070 }
2071 uint32_t wg_x = total_sum ? 1 : ggml_nrows(dst);
2072 return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_buf_pool, pipeline, params, entries, wg_x);
2073}
2074
2075// Returns the encoded command, or std::nullopt if the operation is a no-op
2076static std::optional<webgpu_command> ggml_webgpu_encode_node(webgpu_context ctx, ggml_tensor * node) {
2077 if (ggml_is_empty(node)) {
2078 return std::nullopt;
2079 }
2080 if ((node->flags & GGML_TENSOR_FLAG_COMPUTE) == 0) {
2081 return std::nullopt;
2082 }
2083 WEBGPU_LOG_DEBUG("ggml_webgpu_encode_node(" << node << ", " << ggml_op_name(node->op) << ")");
2084
2085 ggml_tensor * src0 = node->src[0];
2086 ggml_tensor * src1 = node->src[1];
2087 ggml_tensor * src2 = node->src[2];
2088
2089 switch (node->op) {
2090 // no-ops
2091 case GGML_OP_NONE:
2092 case GGML_OP_VIEW:
2093 case GGML_OP_PERMUTE:
2094 case GGML_OP_TRANSPOSE:
2095 case GGML_OP_RESHAPE:
2096 return std::nullopt;
2097 case GGML_OP_CPY:
2098 case GGML_OP_CONT:
2099 return ggml_webgpu_cpy(ctx, src0, node);
2100 case GGML_OP_SET_ROWS:
2101 return ggml_webgpu_set_rows(ctx, src0, src1, node);
2102 case GGML_OP_GET_ROWS:
2103 return ggml_webgpu_get_rows(ctx, src0, src1, node);
2104 case GGML_OP_MUL_MAT:
2105 return ggml_webgpu_mul_mat(ctx, src0, src1, node);
2106 case GGML_OP_FLASH_ATTN_EXT:
2107#ifndef __EMSCRIPTEN__
2108 return ggml_webgpu_flash_attn(ctx, src0, src1, src2, node->src[3], node->src[4], node);
2109#else
2110 return std::nullopt;
2111#endif
2112 case GGML_OP_ADD:
2113 case GGML_OP_SUB:
2114 case GGML_OP_MUL:
2115 case GGML_OP_DIV:
2116 return ggml_webgpu_binary_op(ctx, src0, src1, node);
2117 case GGML_OP_RMS_NORM:
2118 return ggml_webgpu_rms_norm(ctx, src0, node);
2119 case GGML_OP_ROPE:
2120 return ggml_webgpu_rope(ctx, src0, src1, src2, node);
2121 case GGML_OP_GLU:
2122 return ggml_webgpu_glu(ctx, src0, src1, node);
2123 case GGML_OP_SCALE:
2124 return ggml_webgpu_scale(ctx, src0, node);
2125 case GGML_OP_SOFT_MAX:
2126 return ggml_webgpu_soft_max(ctx, src0, src1, src2, node);
2127 case GGML_OP_UNARY:
2128 return ggml_webgpu_unary_op(ctx, src0, node);
2129 case GGML_OP_CLAMP:
2130 return ggml_webgpu_unary_op(ctx, src0, node);
2131 case GGML_OP_FILL:
2132 return ggml_webgpu_unary_op(ctx, src0, node);
2133 case GGML_OP_LOG:
2134 return ggml_webgpu_unary_op(ctx, src0, node);
2135 case GGML_OP_PAD:
2136 return ggml_webgpu_pad(ctx, src0, node);
2137 case GGML_OP_ARGMAX:
2138 return ggml_webgpu_argmax(ctx, src0, node);
2139 case GGML_OP_ARGSORT:
2140 return ggml_webgpu_argsort(ctx, src0, node);
2141 case GGML_OP_TOP_K:
2142 // we reuse the same argsort implementation for top_k
2143 return ggml_webgpu_argsort(ctx, src0, node);
2144 case GGML_OP_CUMSUM:
2145 return ggml_webgpu_cumsum(ctx, src0, node);
2146 case GGML_OP_SUM:
2147 case GGML_OP_SUM_ROWS:
2148 return ggml_webgpu_sum_rows(ctx, src0, node);
2149 default:
2150 return std::nullopt;
2151 }
2152}
2153
2154static ggml_status ggml_backend_webgpu_graph_compute(ggml_backend_t backend, struct ggml_cgraph * cgraph) {
2155 WEBGPU_LOG_DEBUG("ggml_backend_webgpu_graph_compute(" << cgraph->n_nodes << " nodes)");
2156
2157 ggml_backend_webgpu_context * backend_ctx = (ggml_backend_webgpu_context *) backend->context;
2158 webgpu_context ctx = backend_ctx->webgpu_ctx;
2159
2160 WEBGPU_CPU_PROFILE_TOTAL_START(graph_compute);
2161
2162 ctx->global_ctx->inflight_threads++;
2163
2164 std::vector<webgpu_command> commands;
2165 std::vector<webgpu_submission_futures> futures;
2166 for (int i = 0; i < cgraph->n_nodes; i++) {
2167 if (auto cmd = ggml_webgpu_encode_node(ctx, cgraph->nodes[i])) {
2168 commands.push_back(*cmd);
2169 }
2170 // compute the batch size based on the number of inflight threads
2171 uint32_t inflight_threads = ctx->global_ctx->inflight_threads;
2172 uint32_t batch_size = std::min(std::max(1u, WEBGPU_NUM_PARAM_BUFS / std::max(inflight_threads, 1u)),
2173 WEBGPU_COMMAND_SUBMIT_BATCH_SIZE);
2174 if (commands.size() >= batch_size) {
2175 futures.push_back(ggml_backend_webgpu_submit(ctx->global_ctx, commands, ctx->param_buf_pool,
2176 &ctx->set_rows_error_buf_pool));
2177 // Process events and check for completed submissions
2178 ctx->global_ctx->instance.ProcessEvents();
2179 ggml_backend_webgpu_wait(ctx->global_ctx, futures, false);
2180 commands.clear();
2181 }
2182 }
2183 if (!commands.empty()) {
2184 webgpu_submission_futures new_futures =
2185 ggml_backend_webgpu_submit(ctx->global_ctx, commands, ctx->param_buf_pool, &ctx->set_rows_error_buf_pool);
2186 futures.push_back(new_futures);
2187 }
2188
2189 ggml_backend_webgpu_wait(ctx->global_ctx, futures);
2190 ctx->global_ctx->inflight_threads--;
2191 WEBGPU_CPU_PROFILE_TOTAL_END(graph_compute, ctx->global_ctx);
2192 return GGML_STATUS_SUCCESS;
2193}
2194
2195static ggml_backend_i ggml_backend_webgpu_i = {
2196 /* .get_name = */ ggml_backend_webgpu_name,
2197 /* .free = */ ggml_backend_webgpu_free,
2198 /* .set_tensor_async = */ NULL,
2199 /* .get_tensor_async = */ NULL,
2200 /* .cpy_tensor_async = */ NULL,
2201 /* .synchronize = */ NULL,
2202 /* .graph_plan_create = */ NULL,
2203 /* .graph_plan_free = */ NULL,
2204 /* .graph_plan_update = */ NULL,
2205 /* .graph_plan_compute = */ NULL,
2206 /* .graph_compute = */ ggml_backend_webgpu_graph_compute,
2207 /* .event_record = */ NULL,
2208 /* .event_wait = */ NULL,
2209 /* .graph_optimize = */ NULL,
2210};
2211
2212/* End GGML Backend Interface */
2213
2214/* GGML Backend Buffer Interface */
2215
2216static void ggml_backend_webgpu_buffer_free_buffer(ggml_backend_buffer_t buffer) {
2217 ggml_backend_webgpu_buffer_context * ctx = static_cast<ggml_backend_webgpu_buffer_context *>(buffer->context);
2218 if (ctx != nullptr && ctx->buffer != nullptr) {
2219 ctx->buffer.Destroy();
2220 delete ctx;
2221 }
2222}
2223
2224// Returns the "fake" base pointer.
2225static void * ggml_backend_webgpu_buffer_get_base(ggml_backend_buffer_t buffer) {
2226 GGML_UNUSED(buffer);
2227 return webgpu_ptr_base;
2228}
2229
2230static void ggml_backend_webgpu_buffer_memset_tensor(ggml_backend_buffer_t buffer,
2231 ggml_tensor * tensor,
2232 uint8_t value,
2233 size_t offset,
2234 size_t size) {
2235 if (size == 0) {
2236 WEBGPU_LOG_DEBUG("ggml_backend_webgpu_buffer_memset_tensor: size is zero, nothing to do.");
2237 return;
2238 }
2239
2240 WEBGPU_CPU_PROFILE_TOTAL_START(memset_tensor);
2241
2242 ggml_backend_webgpu_buffer_context * buf_ctx = (ggml_backend_webgpu_buffer_context *) buffer->context;
2243
2244 WEBGPU_LOG_DEBUG("ggml_backend_webgpu_buffer_memset_tensor(" << buf_ctx->label << ", " << tensor << ", " << value
2245 << ", " << offset << ", " << size << ")");
2246
2247 size_t total_offset = webgpu_tensor_offset(tensor) + tensor->view_offs + offset;
2248
2249 // This is a trick to set all bytes of a u32 to the same 1 byte value.
2250 uint32_t val32 = (uint32_t) value * 0x01010101;
2251 ggml_backend_webgpu_buffer_memset(buf_ctx->global_ctx, buf_ctx->buffer, val32, total_offset, size);
2252 WEBGPU_CPU_PROFILE_TOTAL_END(memset_tensor, buf_ctx->global_ctx);
2253}
2254
2255static void ggml_backend_webgpu_buffer_set_tensor(ggml_backend_buffer_t buffer,
2256 ggml_tensor * tensor,
2257 const void * data,
2258 size_t offset,
2259 size_t size) {
2260 WEBGPU_CPU_PROFILE_TOTAL_START(set_tensor);
2261 ggml_backend_webgpu_buffer_context * buf_ctx = (ggml_backend_webgpu_buffer_context *) buffer->context;
2262
2263 WEBGPU_LOG_DEBUG("ggml_backend_webgpu_buffer_set_tensor(" << buf_ctx->label << ", " << tensor << ", " << data
2264 << ", " << offset << ", " << size << ")");
2265
2266 size_t total_offset = webgpu_tensor_offset(tensor) + tensor->view_offs + offset;
2267
2268 buf_ctx->global_ctx->queue.WriteBuffer(buf_ctx->buffer, total_offset, data, (size / 4) * 4);
2269
2270 if (size % 4 != 0) {
2271 // If size is not a multiple of 4, we need to memset the remaining bytes
2272 size_t remaining_size = size % 4;
2273
2274 // pack the remaining bytes into a uint32_t
2275 uint32_t val32 = 0;
2276
2277 for (size_t i = 0; i < remaining_size; i++) {
2278 ((uint8_t *) &val32)[i] = ((const uint8_t *) data)[size - remaining_size + i];
2279 }
2280 // memset the remaining bytes
2281 ggml_backend_webgpu_buffer_memset(buf_ctx->global_ctx, buf_ctx->buffer, val32,
2282 total_offset + (size - remaining_size), remaining_size);
2283 } else {
2284 // wait for WriteBuffer to complete
2285 buf_ctx->global_ctx->instance.WaitAny(buf_ctx->global_ctx->queue.OnSubmittedWorkDone(
2286 wgpu::CallbackMode::AllowSpontaneous,
2287 [](wgpu::QueueWorkDoneStatus status, wgpu::StringView message) {
2288 if (status != wgpu::QueueWorkDoneStatus::Success) {
2289 GGML_LOG_ERROR("ggml_webgpu: Failed to submit commands: %s\n",
2290 std::string(message).c_str());
2291 }
2292 }),
2293 UINT64_MAX);
2294 }
2295 WEBGPU_CPU_PROFILE_TOTAL_END(set_tensor, buf_ctx->global_ctx);
2296}
2297
2298static void ggml_backend_webgpu_buffer_get_tensor(ggml_backend_buffer_t buffer,
2299 const ggml_tensor * tensor,
2300 void * data,
2301 size_t offset,
2302 size_t size) {
2303 WEBGPU_CPU_PROFILE_TOTAL_START(get_tensor);
2304 ggml_backend_webgpu_buffer_context * buf_ctx = (ggml_backend_webgpu_buffer_context *) buffer->context;
2305 WEBGPU_LOG_DEBUG("ggml_backend_webgpu_buffer_get_tensor(" << buf_ctx->label << ", " << tensor << ", " << data
2306 << ", " << offset << ", " << size << ")");
2307 wgpu::Device device = buf_ctx->global_ctx->device;
2308
2309 size_t total_offset = webgpu_tensor_offset(tensor) + tensor->view_offs + offset;
2310
2311 size_t final_size = size;
2312 if (size % 4 != 0) {
2313 // If size is not a multiple of 4, we need to round it up to the next multiple of 4
2314 final_size = size + (4 - (size % 4));
2315 }
2316
2317 std::lock_guard<std::recursive_mutex> lock(buf_ctx->global_ctx->mutex);
2318
2319 if (buf_ctx->global_ctx->get_tensor_staging_buf == nullptr ||
2320 buf_ctx->global_ctx->get_tensor_staging_buf.GetSize() < final_size) {
2321 // Create a new staging buffer if it doesn't exist or is too small
2322 if (buf_ctx->global_ctx->get_tensor_staging_buf) {
2323 buf_ctx->global_ctx->get_tensor_staging_buf.Destroy();
2324 }
2325 ggml_webgpu_create_buffer(device, buf_ctx->global_ctx->get_tensor_staging_buf, final_size,
2326 wgpu::BufferUsage::CopyDst | wgpu::BufferUsage::MapRead, "get_tensor_staging_buf");
2327 }
2328
2329 // Copy the data from the buffer to the staging buffer
2330 wgpu::CommandEncoder encoder = device.CreateCommandEncoder();
2331 encoder.CopyBufferToBuffer(buf_ctx->buffer, total_offset, buf_ctx->global_ctx->get_tensor_staging_buf, 0,
2332 final_size);
2333 wgpu::CommandBuffer commands = encoder.Finish();
2334
2335 // Submit the command buffer to the queue
2336 buf_ctx->global_ctx->queue.Submit(1, &commands);
2337
2338 // Map the staging buffer to read the data
2339 ggml_backend_webgpu_map_buffer(buf_ctx->global_ctx, buf_ctx->global_ctx->get_tensor_staging_buf,
2340 wgpu::MapMode::Read, 0, final_size);
2341 // Must specify size here since the staging buffer might be larger than the tensor size
2342 const void * mapped_range = buf_ctx->global_ctx->get_tensor_staging_buf.GetConstMappedRange(0, final_size);
2343
2344 // Copy the data from the mapped range to the output buffer
2345 std::memcpy(data, mapped_range, size);
2346 buf_ctx->global_ctx->get_tensor_staging_buf.Unmap();
2347 WEBGPU_CPU_PROFILE_TOTAL_END(get_tensor, buf_ctx->global_ctx);
2348}
2349
2350static void ggml_backend_webgpu_buffer_clear(ggml_backend_buffer_t buffer, uint8_t value) {
2351 WEBGPU_LOG_DEBUG("ggml_backend_webgpu_buffer_clear(" << buffer << ", " << (uint32_t) value << ")");
2352 WEBGPU_CPU_PROFILE_TOTAL_START(clear);
2353 ggml_backend_webgpu_buffer_context * buf_ctx = (ggml_backend_webgpu_buffer_context *) buffer->context;
2354 ggml_backend_webgpu_buffer_memset(buf_ctx->global_ctx, buf_ctx->buffer, value, 0, buffer->size);
2355 WEBGPU_CPU_PROFILE_TOTAL_END(clear, buf_ctx->global_ctx);
2356}
2357
2358static ggml_backend_buffer_i ggml_backend_webgpu_buffer_interface = {
2359 /* .free_buffer = */ ggml_backend_webgpu_buffer_free_buffer,
2360 /* .get_base = */ ggml_backend_webgpu_buffer_get_base,
2361 /* .init_tensor = */ NULL, // TODO: optional, needed?
2362 /* .memset_tensor = */ ggml_backend_webgpu_buffer_memset_tensor,
2363 /* .set_tensor = */ ggml_backend_webgpu_buffer_set_tensor,
2364 /* .get_tensor = */ ggml_backend_webgpu_buffer_get_tensor,
2365 /* .cpy_tensor = */ NULL, // TODO: optional, implement this
2366 /* .clear = */ ggml_backend_webgpu_buffer_clear,
2367 /* .reset = */ NULL, // TODO: optional, think it coordinates with .init_tensor
2368};
2369
2370/* End GGML Backend Buffer Interface */
2371
2372/* GGML Backend Buffer Type Interface */
2373
2374static const char * ggml_backend_webgpu_buffer_type_get_name(ggml_backend_buffer_type_t buft) {
2375 ggml_backend_webgpu_device_context * ctx = static_cast<ggml_backend_webgpu_device_context *>(buft->device->context);
2376 return ctx->device_name.c_str();
2377}
2378
2379static ggml_backend_buffer_t ggml_backend_webgpu_buffer_type_alloc_buffer(ggml_backend_buffer_type_t buft,
2380 size_t size) {
2381 static std::atomic<int> buffer_count;
2382 int buffer_id = buffer_count++;
2383 std::string buf_name = "tensor_buf" + std::to_string(buffer_id);
2384 WEBGPU_LOG_DEBUG("ggml_backend_webgpu_buffer_type_alloc_buffer_" << buffer_id << ": " << size << " bytes");
2385
2386 ggml_backend_webgpu_device_context * ctx = static_cast<ggml_backend_webgpu_device_context *>(buft->device->context);
2387 wgpu::Buffer buf;
2388 ggml_webgpu_create_buffer(ctx->webgpu_global_ctx->device, buf, ROUNDUP_POW2(size, WEBGPU_STORAGE_BUF_BINDING_MULT),
2389 wgpu::BufferUsage::Storage | wgpu::BufferUsage::CopySrc | wgpu::BufferUsage::CopyDst,
2390 buf_name.c_str());
2391
2392 ggml_backend_webgpu_buffer_context * buf_ctx =
2393 new ggml_backend_webgpu_buffer_context(buf, buf_name, ctx->webgpu_global_ctx);
2394
2395 return ggml_backend_buffer_init(buft, ggml_backend_webgpu_buffer_interface, buf_ctx, size);
2396}
2397
2398static size_t ggml_backend_webgpu_buffer_type_get_alignment(ggml_backend_buffer_type_t buft) {
2399 ggml_backend_webgpu_device_context * dev_ctx =
2400 static_cast<ggml_backend_webgpu_device_context *>(buft->device->context);
2401 return dev_ctx->webgpu_global_ctx->capabilities.limits.minStorageBufferOffsetAlignment;
2402}
2403
2404// maxBufferSize might be larger, but you can't bind more than maxStorageBufferBindingSize to a single binding.
2405static size_t ggml_backend_webgpu_buffer_type_get_max_size(ggml_backend_buffer_type_t buft) {
2406 ggml_backend_webgpu_device_context * dev_ctx =
2407 static_cast<ggml_backend_webgpu_device_context *>(buft->device->context);
2408 return dev_ctx->webgpu_global_ctx->capabilities.limits.maxStorageBufferBindingSize;
2409}
2410
2411static size_t ggml_backend_webgpu_buffer_type_get_alloc_size(ggml_backend_buffer_type_t buft,
2412 const ggml_tensor * tensor) {
2413 ggml_backend_webgpu_device_context * ctx = static_cast<ggml_backend_webgpu_device_context *>(buft->device->context);
2414 size_t res = ggml_nbytes(tensor);
2415 switch (tensor->op) {
2416 case GGML_OP_ARGSORT:
2417 res = ROUNDUP_POW2(res * 2 + ctx->webgpu_global_ctx->capabilities.limits.minStorageBufferOffsetAlignment,
2418 WEBGPU_STORAGE_BUF_BINDING_MULT);
2419 break;
2420 case GGML_OP_TOP_K:
2421 {
2422 const ggml_tensor * src0 = tensor->src[0];
2423 if (src0) {
2424 const size_t full = sizeof(int32_t) * ggml_nelements(src0);
2425 res = ROUNDUP_POW2(
2426 full * 2 + ctx->webgpu_global_ctx->capabilities.limits.minStorageBufferOffsetAlignment,
2427 WEBGPU_STORAGE_BUF_BINDING_MULT);
2428 }
2429 }
2430 break;
2431 default:
2432 break;
2433 }
2434 return res;
2435}
2436
2437/* End GGML Backend Buffer Type Interface */
2438
2439/* GGML Backend Device Interface */
2440
2441static const char * ggml_backend_webgpu_device_get_name(ggml_backend_dev_t dev) {
2442 ggml_backend_webgpu_device_context * ctx = static_cast<ggml_backend_webgpu_device_context *>(dev->context);
2443 return ctx->device_name.c_str();
2444}
2445
2446static const char * ggml_backend_webgpu_device_get_description(ggml_backend_dev_t dev) {
2447 ggml_backend_webgpu_device_context * ctx = static_cast<ggml_backend_webgpu_device_context *>(dev->context);
2448 return ctx->device_desc.c_str();
2449}
2450
2451static void ggml_backend_webgpu_device_get_memory(ggml_backend_dev_t dev, size_t * free, size_t * total) {
2452 ggml_backend_webgpu_device_context * ctx = static_cast<ggml_backend_webgpu_device_context *>(dev->context);
2453 // TODO: for now, return maxBufferSize as both free and total memory
2454 // Track https://github.com/gpuweb/gpuweb/issues/5505 for updates.
2455 uint64_t max_buffer_size = ctx->webgpu_global_ctx->capabilities.limits.maxBufferSize;
2456 // If we're on a 32-bit system, clamp to UINTPTR_MAX
2457#if UINTPTR_MAX < UINT64_MAX
2458 uint64_t max_ptr_size = static_cast<uint64_t>(UINTPTR_MAX);
2459 if (max_buffer_size > max_ptr_size) {
2460 max_buffer_size = max_ptr_size;
2461 }
2462#endif
2463 *free = static_cast<size_t>(max_buffer_size);
2464 *total = static_cast<size_t>(max_buffer_size);
2465}
2466
2467static enum ggml_backend_dev_type ggml_backend_webgpu_device_get_type(ggml_backend_dev_t dev) {
2468 GGML_UNUSED(dev);
2469 return GGML_BACKEND_DEVICE_TYPE_GPU;
2470}
2471
2472static void ggml_backend_webgpu_device_get_props(ggml_backend_dev_t dev, struct ggml_backend_dev_props * props) {
2473 props->name = ggml_backend_webgpu_device_get_name(dev);
2474 props->description = ggml_backend_webgpu_device_get_description(dev);
2475 props->type = ggml_backend_webgpu_device_get_type(dev);
2476 ggml_backend_webgpu_device_get_memory(dev, &props->memory_free, &props->memory_total);
2477 props->caps = {
2478 /* .async = */ false,
2479 /* .host_buffer = */ false,
2480 /* .buffer_from_host_ptr = */ false,
2481 /* .events = */ false,
2482 };
2483}
2484
2485static ggml_guid_t ggml_backend_webgpu_guid(void) {
2486 static const char * guid_str = "__ggml_webgpu :)";
2487 return reinterpret_cast<ggml_guid_t>((void *) guid_str);
2488}
2489
2490// Workgroup size is a common constant
2491static std::vector<wgpu::ConstantEntry> ggml_webgpu_wg_size_entry(uint32_t wg_size) {
2492 std::vector<wgpu::ConstantEntry> constants(1);
2493 constants[0].key = "wg_size";
2494 constants[0].value = wg_size;
2495 return constants;
2496}
2497
2498static void ggml_webgpu_init_memset_pipeline(webgpu_global_context & ctx) {
2499 // we use the maximum workgroup size for the memset pipeline
2500 size_t max_threads = WEBGPU_MAX_WG_SIZE * ctx->capabilities.limits.maxComputeWorkgroupsPerDimension;
2501 // Size the bytes_per_thread so that the largest buffer size can be handled
2502 ctx->capabilities.memset_bytes_per_thread =
2503 CEIL_DIV(ctx->capabilities.limits.maxStorageBufferBindingSize, max_threads);
2504 std::vector<wgpu::ConstantEntry> constants(2);
2505 constants[0].key = "wg_size";
2506 constants[0].value = WEBGPU_MAX_WG_SIZE;
2507 constants[1].key = "bytes_per_thread";
2508 constants[1].value = ctx->capabilities.memset_bytes_per_thread;
2509 ctx->memset_pipelines[0] = ggml_webgpu_create_pipeline(ctx->device, wgsl_memset, "memset", constants);
2510}
2511
2512static void ggml_webgpu_init_mul_mat_pipeline(webgpu_context & webgpu_ctx) {
2513 // Q4/Q5/Q8 classic quantizations
2514 webgpu_ctx->mul_mat_pipelines[GGML_TYPE_Q4_0][GGML_TYPE_F32][0] =
2515 ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_mul_mat_q4_0_f32, "mul_mat_q4_0_f32");
2516 webgpu_ctx->mul_mat_pipelines[GGML_TYPE_Q4_1][GGML_TYPE_F32][0] =
2517 ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_mul_mat_q4_1_f32, "mul_mat_q4_1_f32");
2518 webgpu_ctx->mul_mat_pipelines[GGML_TYPE_Q5_0][GGML_TYPE_F32][0] =
2519 ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_mul_mat_q5_0_f32, "mul_mat_q5_0_f32");
2520 webgpu_ctx->mul_mat_pipelines[GGML_TYPE_Q5_1][GGML_TYPE_F32][0] =
2521 ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_mul_mat_q5_1_f32, "mul_mat_q5_1_f32");
2522 webgpu_ctx->mul_mat_pipelines[GGML_TYPE_Q8_0][GGML_TYPE_F32][0] =
2523 ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_mul_mat_q8_0_f32, "mul_mat_q8_0_f32");
2524
2525 // K-quantizations
2526 webgpu_ctx->mul_mat_pipelines[GGML_TYPE_Q2_K][GGML_TYPE_F32][0] =
2527 ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_mul_mat_q2_k_f32, "mul_mat_q2_k_f32");
2528 webgpu_ctx->mul_mat_pipelines[GGML_TYPE_Q3_K][GGML_TYPE_F32][0] =
2529 ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_mul_mat_q3_k_f32, "mul_mat_q3_k_f32");
2530 webgpu_ctx->mul_mat_pipelines[GGML_TYPE_Q4_K][GGML_TYPE_F32][0] =
2531 ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_mul_mat_q4_k_f32, "mul_mat_q4_k_f32");
2532 webgpu_ctx->mul_mat_pipelines[GGML_TYPE_Q5_K][GGML_TYPE_F32][0] =
2533 ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_mul_mat_q5_k_f32, "mul_mat_q5_k_f32");
2534 webgpu_ctx->mul_mat_pipelines[GGML_TYPE_Q6_K][GGML_TYPE_F32][0] =
2535 ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_mul_mat_q6_k_f32, "mul_mat_q6_k_f32");
2536
2537 // IQ quantizations (2-, 3-, 4-bit variants)
2538 webgpu_ctx->mul_mat_pipelines[GGML_TYPE_IQ2_XXS][GGML_TYPE_F32][0] =
2539 ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_mul_mat_iq2_xxs_f32, "mul_mat_iq2_xxs_f32");
2540 webgpu_ctx->mul_mat_pipelines[GGML_TYPE_IQ2_XS][GGML_TYPE_F32][0] =
2541 ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_mul_mat_iq2_xs_f32, "mul_mat_iq2_xs_f32");
2542 webgpu_ctx->mul_mat_pipelines[GGML_TYPE_IQ2_S][GGML_TYPE_F32][0] =
2543 ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_mul_mat_iq2_s_f32, "mul_mat_iq2_s_f32");
2544
2545 webgpu_ctx->mul_mat_pipelines[GGML_TYPE_IQ3_XXS][GGML_TYPE_F32][0] =
2546 ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_mul_mat_iq3_xxs_f32, "mul_mat_iq3_xxs_f32");
2547 webgpu_ctx->mul_mat_pipelines[GGML_TYPE_IQ3_S][GGML_TYPE_F32][0] =
2548 ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_mul_mat_iq3_s_f32, "mul_mat_iq3_s_f32");
2549
2550 // 1-bit and 4-bit IQ variants
2551 webgpu_ctx->mul_mat_pipelines[GGML_TYPE_IQ1_S][GGML_TYPE_F32][0] =
2552 ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_mul_mat_iq1_s_f32, "mul_mat_iq1_s_f32");
2553 webgpu_ctx->mul_mat_pipelines[GGML_TYPE_IQ1_M][GGML_TYPE_F32][0] =
2554 ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_mul_mat_iq1_m_f32, "mul_mat_iq1_m_f32");
2555 webgpu_ctx->mul_mat_pipelines[GGML_TYPE_IQ4_NL][GGML_TYPE_F32][0] =
2556 ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_mul_mat_iq4_nl_f32, "mul_mat_iq4_nl_f32");
2557 webgpu_ctx->mul_mat_pipelines[GGML_TYPE_IQ4_XS][GGML_TYPE_F32][0] =
2558 ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_mul_mat_iq4_xs_f32, "mul_mat_iq4_xs_f32");
2559
2560 std::string proc_mul_mat_f32_f32;
2561 std::string proc_mul_mat_f32_f32_vec;
2562 std::string proc_mul_mat_f16_f32;
2563 std::string proc_mul_mat_f16_f32_vec;
2564 std::string proc_mul_mat_f16_f16;
2565 std::string proc_mul_mat_f16_f16_vec;
2566 std::string proc_mul_mat_q4_0_f32;
2567 std::string proc_mul_mat_q4_0_f32_vec;
2568
2569 std::vector<wgpu::ConstantEntry> mul_mat_constants;
2570#ifndef __EMSCRIPTEN__
2571 if (webgpu_ctx->global_ctx->capabilities.supports_subgroup_matrix) {
2572 std::map<std::string, std::string> sg_matrix_repls;
2573 sg_matrix_repls["WEBGPU_MAX_SUBGROUP_SIZE"] =
2574 std::to_string(webgpu_ctx->global_ctx->capabilities.max_subgroup_size);
2575 sg_matrix_repls["WEBGPU_TILE_K"] = std::to_string(WEBGPU_MUL_MAT_TILE_K);
2576 sg_matrix_repls["WEBGPU_SUBGROUP_M"] = std::to_string(WEBGPU_MUL_MAT_SUBGROUP_M);
2577 sg_matrix_repls["WEBGPU_SUBGROUP_N"] = std::to_string(WEBGPU_MUL_MAT_SUBGROUP_N);
2578 sg_matrix_repls["WEBGPU_SUBGROUP_MATRIX_M"] = std::to_string(WEBGPU_MUL_MAT_SUBGROUP_MATRIX_M);
2579 sg_matrix_repls["WEBGPU_SUBGROUP_MATRIX_N"] = std::to_string(WEBGPU_MUL_MAT_SUBGROUP_MATRIX_N);
2580 sg_matrix_repls["WEBGPU_SG_MAT_M_SIZE"] = std::to_string(webgpu_ctx->global_ctx->capabilities.sg_mat_m);
2581 sg_matrix_repls["WEBGPU_SG_MAT_N_SIZE"] = std::to_string(webgpu_ctx->global_ctx->capabilities.sg_mat_n);
2582 sg_matrix_repls["WEBGPU_SG_MAT_K_SIZE"] = std::to_string(webgpu_ctx->global_ctx->capabilities.sg_mat_k);
2583 proc_mul_mat_f32_f32 = ggml_webgpu_process_shader_repls(wgsl_mul_mat_subgroup_matrix_f32_f32, sg_matrix_repls);
2584 proc_mul_mat_f32_f32_vec =
2585 ggml_webgpu_process_shader_repls(wgsl_mul_mat_subgroup_matrix_f32_f32_vec, sg_matrix_repls);
2586 proc_mul_mat_f16_f32 = ggml_webgpu_process_shader_repls(wgsl_mul_mat_subgroup_matrix_f16_f32, sg_matrix_repls);
2587 proc_mul_mat_f16_f32_vec =
2588 ggml_webgpu_process_shader_repls(wgsl_mul_mat_subgroup_matrix_f16_f32_vec, sg_matrix_repls);
2589 proc_mul_mat_f16_f16 = ggml_webgpu_process_shader_repls(wgsl_mul_mat_subgroup_matrix_f16_f16, sg_matrix_repls);
2590 proc_mul_mat_f16_f16_vec =
2591 ggml_webgpu_process_shader_repls(wgsl_mul_mat_subgroup_matrix_f16_f16_vec, sg_matrix_repls);
2592 proc_mul_mat_q4_0_f32 =
2593 ggml_webgpu_process_shader_repls(wgsl_mul_mat_subgroup_matrix_q4_0_f32, sg_matrix_repls);
2594 proc_mul_mat_q4_0_f32_vec =
2595 ggml_webgpu_process_shader_repls(wgsl_mul_mat_subgroup_matrix_q4_0_f32_vec, sg_matrix_repls);
2596 } else {
2597#endif
2598 mul_mat_constants.push_back({ .key = "TILE_K", .value = WEBGPU_MUL_MAT_TILE_K });
2599 mul_mat_constants.push_back({ .key = "WORKGROUP_SIZE_M", .value = WEBGPU_MUL_MAT_WG_SIZE_M });
2600 mul_mat_constants.push_back({ .key = "WORKGROUP_SIZE_N", .value = WEBGPU_MUL_MAT_WG_SIZE_N });
2601
2602 std::map<std::string, std::string> reg_repls;
2603 reg_repls["WEBGPU_TILE_M"] = std::to_string(WEBGPU_MUL_MAT_TILE_M);
2604 reg_repls["WEBGPU_TILE_N"] = std::to_string(WEBGPU_MUL_MAT_TILE_N);
2605
2606 proc_mul_mat_f32_f32 = ggml_webgpu_process_shader_repls(wgsl_mul_mat_reg_tile_f32_f32, reg_repls);
2607 proc_mul_mat_f32_f32_vec = ggml_webgpu_process_shader_repls(wgsl_mul_mat_reg_tile_f32_f32_vec, reg_repls);
2608 proc_mul_mat_f16_f32 = ggml_webgpu_process_shader_repls(wgsl_mul_mat_reg_tile_f16_f32, reg_repls);
2609 proc_mul_mat_f16_f32_vec = ggml_webgpu_process_shader_repls(wgsl_mul_mat_reg_tile_f16_f32_vec, reg_repls);
2610 proc_mul_mat_f16_f16 = ggml_webgpu_process_shader_repls(wgsl_mul_mat_reg_tile_f16_f16, reg_repls);
2611 proc_mul_mat_f16_f16_vec = ggml_webgpu_process_shader_repls(wgsl_mul_mat_reg_tile_f16_f16_vec, reg_repls);
2612 proc_mul_mat_q4_0_f32 = ggml_webgpu_process_shader_repls(wgsl_mul_mat_reg_tile_q4_0_f32, reg_repls);
2613 proc_mul_mat_q4_0_f32_vec = ggml_webgpu_process_shader_repls(wgsl_mul_mat_reg_tile_q4_0_f32_vec, reg_repls);
2614#ifndef __EMSCRIPTEN__
2615 }
2616#endif
2617
2618 webgpu_ctx->mul_mat_pipelines[GGML_TYPE_F32][GGML_TYPE_F32][0] = ggml_webgpu_create_pipeline(
2619 webgpu_ctx->global_ctx->device, proc_mul_mat_f32_f32.c_str(), "mul_mat_f32_f32", mul_mat_constants);
2620 webgpu_ctx->mul_mat_pipelines[GGML_TYPE_F32][GGML_TYPE_F32][1] = ggml_webgpu_create_pipeline(
2621 webgpu_ctx->global_ctx->device, proc_mul_mat_f32_f32_vec.c_str(), "mul_mat_f32_f32_vec", mul_mat_constants);
2622 webgpu_ctx->mul_mat_pipelines[GGML_TYPE_F16][GGML_TYPE_F32][0] = ggml_webgpu_create_pipeline(
2623 webgpu_ctx->global_ctx->device, proc_mul_mat_f16_f32.c_str(), "mul_mat_f16_f32", mul_mat_constants);
2624 webgpu_ctx->mul_mat_pipelines[GGML_TYPE_F16][GGML_TYPE_F32][1] = ggml_webgpu_create_pipeline(
2625 webgpu_ctx->global_ctx->device, proc_mul_mat_f16_f32_vec.c_str(), "mul_mat_f16_f32_vec", mul_mat_constants);
2626 webgpu_ctx->mul_mat_pipelines[GGML_TYPE_F16][GGML_TYPE_F16][0] = ggml_webgpu_create_pipeline(
2627 webgpu_ctx->global_ctx->device, proc_mul_mat_f16_f16.c_str(), "mul_mat_f16_f16", mul_mat_constants);
2628 webgpu_ctx->mul_mat_pipelines[GGML_TYPE_F16][GGML_TYPE_F16][1] = ggml_webgpu_create_pipeline(
2629 webgpu_ctx->global_ctx->device, proc_mul_mat_f16_f16_vec.c_str(), "mul_mat_f16_f16_vec", mul_mat_constants);
2630 webgpu_ctx->mul_mat_pipelines[GGML_TYPE_Q4_0][GGML_TYPE_F32][0] = ggml_webgpu_create_pipeline(
2631 webgpu_ctx->global_ctx->device, proc_mul_mat_q4_0_f32.c_str(), "mul_mat_q4_0_f32", mul_mat_constants);
2632 webgpu_ctx->mul_mat_pipelines[GGML_TYPE_Q4_0][GGML_TYPE_F32][1] = ggml_webgpu_create_pipeline(
2633 webgpu_ctx->global_ctx->device, proc_mul_mat_q4_0_f32_vec.c_str(), "mul_mat_q4_0_f32_vec", mul_mat_constants);
2634
2635 std::vector<wgpu::ConstantEntry> mul_mat_vec_constants(3);
2636 mul_mat_vec_constants[0].key = "WORKGROUP_SIZE";
2637 mul_mat_vec_constants[0].value = WEBGPU_MUL_MAT_VEC_WG_SIZE;
2638 mul_mat_vec_constants[1].key = "TILE_K";
2639 mul_mat_vec_constants[1].value = WEBGPU_MUL_MAT_VEC_TILE_K;
2640 mul_mat_vec_constants[2].key = "OUTPUTS_PER_WG";
2641 mul_mat_vec_constants[2].value = WEBGPU_MUL_MAT_VEC_OUTPUTS_PER_WG;
2642
2643 webgpu_ctx->mul_mat_vec_pipelines[GGML_TYPE_F32][GGML_TYPE_F32][0] = ggml_webgpu_create_pipeline(
2644 webgpu_ctx->global_ctx->device, wgsl_mul_mat_vec_f32_f32, "mul_mat_vec_f32_f32", mul_mat_vec_constants);
2645 webgpu_ctx->mul_mat_vec_pipelines[GGML_TYPE_F32][GGML_TYPE_F32][1] = ggml_webgpu_create_pipeline(
2646 webgpu_ctx->global_ctx->device, wgsl_mul_mat_vec_f32_f32_vec, "mul_mat_vec_f32_f32_vec", mul_mat_vec_constants);
2647 webgpu_ctx->mul_mat_vec_pipelines[GGML_TYPE_F16][GGML_TYPE_F32][0] = ggml_webgpu_create_pipeline(
2648 webgpu_ctx->global_ctx->device, wgsl_mul_mat_vec_f16_f32, "mul_mat_vec_f16_f32", mul_mat_vec_constants);
2649 webgpu_ctx->mul_mat_vec_pipelines[GGML_TYPE_F16][GGML_TYPE_F32][1] = ggml_webgpu_create_pipeline(
2650 webgpu_ctx->global_ctx->device, wgsl_mul_mat_vec_f16_f32_vec, "mul_mat_vec_f16_f32_vec", mul_mat_vec_constants);
2651 webgpu_ctx->mul_mat_vec_pipelines[GGML_TYPE_F16][GGML_TYPE_F16][0] = ggml_webgpu_create_pipeline(
2652 webgpu_ctx->global_ctx->device, wgsl_mul_mat_vec_f16_f16, "mul_mat_vec_f16_f16", mul_mat_vec_constants);
2653 webgpu_ctx->mul_mat_vec_pipelines[GGML_TYPE_F16][GGML_TYPE_F16][1] = ggml_webgpu_create_pipeline(
2654 webgpu_ctx->global_ctx->device, wgsl_mul_mat_vec_f16_f16_vec, "mul_mat_vec_f16_f16_vec", mul_mat_vec_constants);
2655 webgpu_ctx->mul_mat_vec_pipelines[GGML_TYPE_Q4_0][GGML_TYPE_F32][0] = ggml_webgpu_create_pipeline(
2656 webgpu_ctx->global_ctx->device, wgsl_mul_mat_vec_q4_0_f32, "mul_mat_vec_q4_0_f32", mul_mat_vec_constants);
2657}
2658
2659static void ggml_webgpu_init_get_rows_pipeline(webgpu_context & webgpu_ctx) {
2660 std::vector<wgpu::ConstantEntry> constants = ggml_webgpu_wg_size_entry(WEBGPU_MAX_WG_SIZE);
2661
2662 webgpu_ctx->get_rows_pipelines[GGML_TYPE_F32][0] =
2663 ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_get_rows_f32, "get_rows_f32", constants);
2664 webgpu_ctx->get_rows_pipelines[GGML_TYPE_F32][1] = ggml_webgpu_create_pipeline(
2665 webgpu_ctx->global_ctx->device, wgsl_get_rows_f32_vec, "get_rows_f32_vec", constants);
2666
2667 webgpu_ctx->get_rows_pipelines[GGML_TYPE_F16][0] =
2668 ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_get_rows_f16, "get_rows_f16", constants);
2669 webgpu_ctx->get_rows_pipelines[GGML_TYPE_I32][0] =
2670 ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_get_rows_i32, "get_rows_i32", constants);
2671 webgpu_ctx->get_rows_pipelines[GGML_TYPE_Q4_0][0] =
2672 ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_get_rows_q4_0, "get_rows_q4_0", constants);
2673 webgpu_ctx->get_rows_pipelines[GGML_TYPE_Q4_1][0] =
2674 ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_get_rows_q4_1, "get_rows_q4_1", constants);
2675 webgpu_ctx->get_rows_pipelines[GGML_TYPE_Q5_0][0] =
2676 ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_get_rows_q5_0, "get_rows_q5_0", constants);
2677 webgpu_ctx->get_rows_pipelines[GGML_TYPE_Q5_1][0] =
2678 ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_get_rows_q5_1, "get_rows_q5_1", constants);
2679 webgpu_ctx->get_rows_pipelines[GGML_TYPE_Q8_0][0] =
2680 ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_get_rows_q8_0, "get_rows_q8_0", constants);
2681
2682 webgpu_ctx->get_rows_pipelines[GGML_TYPE_Q2_K][0] =
2683 ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_get_rows_q2_k, "get_rows_q2_k", constants);
2684 webgpu_ctx->get_rows_pipelines[GGML_TYPE_Q3_K][0] =
2685 ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_get_rows_q3_k, "get_rows_q3_k", constants);
2686 webgpu_ctx->get_rows_pipelines[GGML_TYPE_Q4_K][0] =
2687 ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_get_rows_q4_k, "get_rows_q4_k", constants);
2688 webgpu_ctx->get_rows_pipelines[GGML_TYPE_Q5_K][0] =
2689 ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_get_rows_q5_k, "get_rows_q5_k", constants);
2690 webgpu_ctx->get_rows_pipelines[GGML_TYPE_Q6_K][0] =
2691 ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_get_rows_q6_k, "get_rows_q6_k", constants);
2692
2693 webgpu_ctx->get_rows_pipelines[GGML_TYPE_IQ2_XXS][0] = ggml_webgpu_create_pipeline(
2694 webgpu_ctx->global_ctx->device, wgsl_get_rows_iq2_xxs, "get_rows_iq2_xxs", constants);
2695 webgpu_ctx->get_rows_pipelines[GGML_TYPE_IQ2_XS][0] =
2696 ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_get_rows_iq2_xs, "get_rows_iq2_xs", constants);
2697 webgpu_ctx->get_rows_pipelines[GGML_TYPE_IQ2_S][0] =
2698 ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_get_rows_iq2_s, "get_rows_iq2_s", constants);
2699 webgpu_ctx->get_rows_pipelines[GGML_TYPE_IQ3_XXS][0] = ggml_webgpu_create_pipeline(
2700 webgpu_ctx->global_ctx->device, wgsl_get_rows_iq3_xxs, "get_rows_iq3_xxs", constants);
2701 webgpu_ctx->get_rows_pipelines[GGML_TYPE_IQ3_S][0] =
2702 ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_get_rows_iq3_s, "get_rows_iq3_s", constants);
2703 webgpu_ctx->get_rows_pipelines[GGML_TYPE_IQ1_S][0] =
2704 ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_get_rows_iq1_s, "get_rows_iq1_s", constants);
2705 webgpu_ctx->get_rows_pipelines[GGML_TYPE_IQ1_M][0] =
2706 ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_get_rows_iq1_m, "get_rows_iq1_m", constants);
2707 webgpu_ctx->get_rows_pipelines[GGML_TYPE_IQ4_NL][0] =
2708 ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_get_rows_iq4_nl, "get_rows_iq4_nl", constants);
2709 webgpu_ctx->get_rows_pipelines[GGML_TYPE_IQ4_XS][0] =
2710 ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_get_rows_iq4_xs, "get_rows_iq4_xs", constants);
2711}
2712
2713static void ggml_webgpu_init_cpy_pipeline(webgpu_context & webgpu_ctx) {
2714 std::vector<wgpu::ConstantEntry> constants = ggml_webgpu_wg_size_entry(WEBGPU_MAX_WG_SIZE);
2715
2716 webgpu_ctx->cpy_pipelines[GGML_TYPE_F32][GGML_TYPE_F32] =
2717 ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_cpy_f32_f32, "cpy_f32_f32", constants);
2718 webgpu_ctx->cpy_pipelines[GGML_TYPE_F32][GGML_TYPE_I32] =
2719 ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_cpy_f32_i32, "cpy_f32_i32", constants);
2720 webgpu_ctx->cpy_pipelines[GGML_TYPE_F32][GGML_TYPE_F16] =
2721 ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_cpy_f32_f16, "cpy_f32_f16", constants);
2722 webgpu_ctx->cpy_pipelines[GGML_TYPE_F16][GGML_TYPE_F32] =
2723 ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_cpy_f16_f32, "cpy_f16_f32", constants);
2724 webgpu_ctx->cpy_pipelines[GGML_TYPE_F16][GGML_TYPE_F16] =
2725 ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_cpy_f16_f16, "cpy_f16_f16", constants);
2726}
2727
2728static void ggml_webgpu_init_rms_norm_pipeline(webgpu_context & webgpu_ctx) {
2729 std::vector<wgpu::ConstantEntry> constants = ggml_webgpu_wg_size_entry(WEBGPU_ROW_SPLIT_WG_SIZE);
2730
2731 webgpu_ctx->rms_norm_pipelines[0] =
2732 ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_rms_norm, "rms_norm", constants);
2733 webgpu_ctx->rms_norm_pipelines[1] = ggml_webgpu_create_pipeline(
2734 webgpu_ctx->global_ctx->device, wgsl_rms_norm_inplace, "rms_norm_inplace", constants);
2735}
2736
2737static void ggml_webgpu_init_rope_pipeline(webgpu_context & webgpu_ctx) {
2738 std::vector<wgpu::ConstantEntry> constants = ggml_webgpu_wg_size_entry(WEBGPU_MAX_WG_SIZE);
2739
2740 webgpu_ctx->rope_pipelines[GGML_TYPE_F32][0][0] =
2741 ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_rope_f32, "rope_f32", constants);
2742 webgpu_ctx->rope_pipelines[GGML_TYPE_F32][0][1] = ggml_webgpu_create_pipeline(
2743 webgpu_ctx->global_ctx->device, wgsl_rope_f32_inplace, "rope_f32_inplace", constants);
2744 webgpu_ctx->rope_pipelines[GGML_TYPE_F32][1][0] =
2745 ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_rope_f32_ff, "rope_f32_ff", constants);
2746 webgpu_ctx->rope_pipelines[GGML_TYPE_F32][1][1] = ggml_webgpu_create_pipeline(
2747 webgpu_ctx->global_ctx->device, wgsl_rope_f32_ff_inplace, "rope_f32_ff_inplace", constants);
2748
2749 webgpu_ctx->rope_pipelines[GGML_TYPE_F16][0][0] =
2750 ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_rope_f16, "rope_f16", constants);
2751 webgpu_ctx->rope_pipelines[GGML_TYPE_F16][0][1] = ggml_webgpu_create_pipeline(
2752 webgpu_ctx->global_ctx->device, wgsl_rope_f16_inplace, "rope_f16_inplace", constants);
2753 webgpu_ctx->rope_pipelines[GGML_TYPE_F16][1][0] =
2754 ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_rope_f16_ff, "rope_f16_ff", constants);
2755 webgpu_ctx->rope_pipelines[GGML_TYPE_F16][1][1] = ggml_webgpu_create_pipeline(
2756 webgpu_ctx->global_ctx->device, wgsl_rope_f16_ff_inplace, "rope_f16_ff_inplace", constants);
2757}
2758
2759static void ggml_webgpu_init_glu_pipeline(webgpu_context & webgpu_ctx) {
2760 std::vector<wgpu::ConstantEntry> constants = ggml_webgpu_wg_size_entry(WEBGPU_MAX_WG_SIZE);
2761
2762 // REGLU
2763 webgpu_ctx->glu_pipelines[GGML_GLU_OP_REGLU][GGML_TYPE_F32][0] =
2764 ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_reglu_f32, "reglu_f32", constants);
2765 webgpu_ctx->glu_pipelines[GGML_GLU_OP_REGLU][GGML_TYPE_F16][0] =
2766 ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_reglu_f16, "reglu_f16", constants);
2767 webgpu_ctx->glu_pipelines[GGML_GLU_OP_REGLU][GGML_TYPE_F32][1] =
2768 ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_reglu_f32_split, "reglu_f32_split", constants);
2769 webgpu_ctx->glu_pipelines[GGML_GLU_OP_REGLU][GGML_TYPE_F16][1] =
2770 ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_reglu_f16_split, "reglu_f16_split", constants);
2771
2772 // GEGLU
2773 webgpu_ctx->glu_pipelines[GGML_GLU_OP_GEGLU][GGML_TYPE_F32][0] =
2774 ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_geglu_f32, "geglu_f32", constants);
2775 webgpu_ctx->glu_pipelines[GGML_GLU_OP_GEGLU][GGML_TYPE_F16][0] =
2776 ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_geglu_f16, "geglu_f16", constants);
2777 webgpu_ctx->glu_pipelines[GGML_GLU_OP_GEGLU][GGML_TYPE_F32][1] =
2778 ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_geglu_f32_split, "geglu_f32_split", constants);
2779 webgpu_ctx->glu_pipelines[GGML_GLU_OP_GEGLU][GGML_TYPE_F16][1] =
2780 ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_geglu_f16_split, "geglu_f16_split", constants);
2781
2782 // SWIGLU
2783 webgpu_ctx->glu_pipelines[GGML_GLU_OP_SWIGLU][GGML_TYPE_F32][0] =
2784 ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_swiglu_f32, "swiglu_f32", constants);
2785 webgpu_ctx->glu_pipelines[GGML_GLU_OP_SWIGLU][GGML_TYPE_F16][0] =
2786 ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_swiglu_f16, "swiglu_f16", constants);
2787 webgpu_ctx->glu_pipelines[GGML_GLU_OP_SWIGLU][GGML_TYPE_F32][1] = ggml_webgpu_create_pipeline(
2788 webgpu_ctx->global_ctx->device, wgsl_swiglu_f32_split, "swiglu_f32_split", constants);
2789 webgpu_ctx->glu_pipelines[GGML_GLU_OP_SWIGLU][GGML_TYPE_F16][1] = ggml_webgpu_create_pipeline(
2790 webgpu_ctx->global_ctx->device, wgsl_swiglu_f16_split, "swiglu_f16_split", constants);
2791
2792 // SWIGLU_OAI
2793 webgpu_ctx->glu_pipelines[GGML_GLU_OP_SWIGLU_OAI][GGML_TYPE_F32][0] =
2794 ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_swiglu_oai_f32, "swiglu_oai_f32", constants);
2795 webgpu_ctx->glu_pipelines[GGML_GLU_OP_SWIGLU_OAI][GGML_TYPE_F32][1] = ggml_webgpu_create_pipeline(
2796 webgpu_ctx->global_ctx->device, wgsl_swiglu_oai_f32_split, "swiglu_oai_f32_split", constants);
2797
2798 // GEGLU_ERF
2799 webgpu_ctx->glu_pipelines[GGML_GLU_OP_GEGLU_ERF][GGML_TYPE_F32][0] =
2800 ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_geglu_erf_f32, "geglu_erf_f32", constants);
2801 webgpu_ctx->glu_pipelines[GGML_GLU_OP_GEGLU_ERF][GGML_TYPE_F16][0] =
2802 ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_geglu_erf_f16, "geglu_erf_f16", constants);
2803 webgpu_ctx->glu_pipelines[GGML_GLU_OP_GEGLU_ERF][GGML_TYPE_F32][1] = ggml_webgpu_create_pipeline(
2804 webgpu_ctx->global_ctx->device, wgsl_geglu_erf_f32_split, "geglu_erf_f32_split", constants);
2805 webgpu_ctx->glu_pipelines[GGML_GLU_OP_GEGLU_ERF][GGML_TYPE_F16][1] = ggml_webgpu_create_pipeline(
2806 webgpu_ctx->global_ctx->device, wgsl_geglu_erf_f16_split, "geglu_erf_f16_split", constants);
2807
2808 // GEGLU_QUICK
2809 webgpu_ctx->glu_pipelines[GGML_GLU_OP_GEGLU_QUICK][GGML_TYPE_F32][0] =
2810 ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_geglu_quick_f32, "geglu_quick_f32", constants);
2811 webgpu_ctx->glu_pipelines[GGML_GLU_OP_GEGLU_QUICK][GGML_TYPE_F16][0] =
2812 ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_geglu_quick_f16, "geglu_quick_f16", constants);
2813 webgpu_ctx->glu_pipelines[GGML_GLU_OP_GEGLU_QUICK][GGML_TYPE_F32][1] = ggml_webgpu_create_pipeline(
2814 webgpu_ctx->global_ctx->device, wgsl_geglu_quick_f32_split, "geglu_quick_f32_split", constants);
2815 webgpu_ctx->glu_pipelines[GGML_GLU_OP_GEGLU_QUICK][GGML_TYPE_F16][1] = ggml_webgpu_create_pipeline(
2816 webgpu_ctx->global_ctx->device, wgsl_geglu_quick_f16_split, "geglu_quick_f16_split", constants);
2817}
2818
2819static void ggml_webgpu_init_scale_pipeline(webgpu_context & webgpu_ctx) {
2820 std::vector<wgpu::ConstantEntry> constants = ggml_webgpu_wg_size_entry(WEBGPU_MAX_WG_SIZE);
2821
2822 webgpu_ctx->scale_pipelines[0] =
2823 ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_scale_f32, "scale_f32", constants);
2824 webgpu_ctx->scale_pipelines[1] = ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_scale_f32_inplace,
2825 "scale_f32_inplace", constants);
2826}
2827
2828static void ggml_webgpu_init_soft_max_pipeline(webgpu_context & webgpu_ctx) {
2829 std::vector<wgpu::ConstantEntry> constants = ggml_webgpu_wg_size_entry(WEBGPU_ROW_SPLIT_WG_SIZE);
2830
2831 // f32 (no mask)
2832 webgpu_ctx->soft_max_pipelines[2][0][0] =
2833 ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_soft_max_f32, "soft_max_f32", constants);
2834 webgpu_ctx->soft_max_pipelines[2][0][1] = ggml_webgpu_create_pipeline(
2835 webgpu_ctx->global_ctx->device, wgsl_soft_max_f32_inplace, "soft_max_f32_inplace", constants);
2836 webgpu_ctx->soft_max_pipelines[2][1][0] = ggml_webgpu_create_pipeline(
2837 webgpu_ctx->global_ctx->device, wgsl_soft_max_f32_sink, "soft_max_f32_sink", constants);
2838 webgpu_ctx->soft_max_pipelines[2][1][1] = ggml_webgpu_create_pipeline(
2839 webgpu_ctx->global_ctx->device, wgsl_soft_max_f32_sink_inplace, "soft_max_f32_sink_inplace", constants);
2840
2841 // f32 mask (mask_type = 0)
2842 webgpu_ctx->soft_max_pipelines[0][0][0] = ggml_webgpu_create_pipeline(
2843 webgpu_ctx->global_ctx->device, wgsl_soft_max_f32_mask_f32, "soft_max_f32_mask_f32", constants);
2844 webgpu_ctx->soft_max_pipelines[0][0][1] = ggml_webgpu_create_pipeline(
2845 webgpu_ctx->global_ctx->device, wgsl_soft_max_f32_mask_f32_inplace, "soft_max_f32_mask_f32_inplace", constants);
2846 webgpu_ctx->soft_max_pipelines[0][1][0] = ggml_webgpu_create_pipeline(
2847 webgpu_ctx->global_ctx->device, wgsl_soft_max_f32_mask_f32_sink, "soft_max_f32_mask_f32_sink", constants);
2848 webgpu_ctx->soft_max_pipelines[0][1][1] =
2849 ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_soft_max_f32_mask_f32_sink_inplace,
2850 "soft_max_f32_mask_f32_sink_inplace", constants);
2851
2852 // f16 mask (mask_type = 1)
2853 webgpu_ctx->soft_max_pipelines[1][0][0] = ggml_webgpu_create_pipeline(
2854 webgpu_ctx->global_ctx->device, wgsl_soft_max_f32_mask_f16, "soft_max_f32_mask_f16", constants);
2855 webgpu_ctx->soft_max_pipelines[1][0][1] = ggml_webgpu_create_pipeline(
2856 webgpu_ctx->global_ctx->device, wgsl_soft_max_f32_mask_f16_inplace, "soft_max_f32_mask_f16_inplace", constants);
2857 webgpu_ctx->soft_max_pipelines[1][1][0] = ggml_webgpu_create_pipeline(
2858 webgpu_ctx->global_ctx->device, wgsl_soft_max_f32_mask_f16_sink, "soft_max_f32_mask_f16_sink", constants);
2859 webgpu_ctx->soft_max_pipelines[1][1][1] =
2860 ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_soft_max_f32_mask_f16_sink_inplace,
2861 "soft_max_f32_mask_f16_sink_inplace", constants);
2862}
2863
2864static bool create_webgpu_device(ggml_backend_webgpu_reg_context * ctx) {
2865 wgpu::RequestAdapterOptions options = {};
2866
2867#ifndef __EMSCRIPTEN__
2868 // TODO: track need for these toggles: https://issues.chromium.org/issues/42251215
2869 const char * const adapterEnabledToggles[] = { "vulkan_enable_f16_on_nvidia", "use_vulkan_memory_model" };
2870 wgpu::DawnTogglesDescriptor adapterTogglesDesc;
2871 adapterTogglesDesc.enabledToggles = adapterEnabledToggles;
2872 adapterTogglesDesc.enabledToggleCount = 2;
2873 options.nextInChain = &adapterTogglesDesc;
2874#endif
2875
2876 ctx->webgpu_global_ctx->instance.WaitAny(
2877 ctx->webgpu_global_ctx->instance.RequestAdapter(
2878 &options, wgpu::CallbackMode::AllowSpontaneous,
2879 [&ctx](wgpu::RequestAdapterStatus status, wgpu::Adapter adapter, const char * message) {
2880 if (status != wgpu::RequestAdapterStatus::Success) {
2881 GGML_LOG_ERROR("ggml_webgpu: Failed to get an adapter: %s\n", message);
2882 return;
2883 }
2884 ctx->webgpu_global_ctx->adapter = std::move(adapter);
2885 }),
2886 UINT64_MAX);
2887 GGML_ASSERT(ctx->webgpu_global_ctx->adapter != nullptr);
2888
2889 ctx->webgpu_global_ctx->adapter.GetLimits(&ctx->webgpu_global_ctx->capabilities.limits);
2890
2891 wgpu::AdapterInfo info{};
2892#ifndef __EMSCRIPTEN__
2893 wgpu::AdapterPropertiesSubgroupMatrixConfigs subgroup_matrix_configs{};
2894 if (ctx->webgpu_global_ctx->adapter.HasFeature(wgpu::FeatureName::ChromiumExperimentalSubgroupMatrix)) {
2895 info.nextInChain = &subgroup_matrix_configs;
2896 }
2897#endif
2898 ctx->webgpu_global_ctx->adapter.GetInfo(&info);
2899 wgpu::SupportedFeatures features;
2900 ctx->webgpu_global_ctx->adapter.GetFeatures(&features);
2901 // we require f16 support
2902 GGML_ASSERT(ctx->webgpu_global_ctx->adapter.HasFeature(wgpu::FeatureName::ShaderF16));
2903
2904#ifndef __EMSCRIPTEN__
2905 // Only support square f16 matrices of size 8 or 16 for now
2906 bool valid_subgroup_matrix_config = false;
2907 if (ctx->webgpu_global_ctx->adapter.HasFeature(wgpu::FeatureName::ChromiumExperimentalSubgroupMatrix)) {
2908 for (size_t i = 0; i < subgroup_matrix_configs.configCount; i++) {
2909 const wgpu::SubgroupMatrixConfig config = subgroup_matrix_configs.configs[i];
2910 if (config.M == config.N && config.N == config.K && (config.K == 8 || config.K == 16) &&
2911 config.componentType == wgpu::SubgroupMatrixComponentType::F16 &&
2912 config.resultComponentType == wgpu::SubgroupMatrixComponentType::F16) {
2913 ctx->webgpu_global_ctx->capabilities.sg_mat_m = config.M;
2914 ctx->webgpu_global_ctx->capabilities.sg_mat_n = config.N;
2915 ctx->webgpu_global_ctx->capabilities.sg_mat_k = config.K;
2916 valid_subgroup_matrix_config = true;
2917 break;
2918 }
2919 }
2920 }
2921 ctx->webgpu_global_ctx->capabilities.supports_subgroup_matrix = valid_subgroup_matrix_config;
2922#endif
2923
2924 // For subgroup matrix code to be the most efficient, we would like the subgroup size to be consistent and accurate.
2925 // Unfortunately, that is not possible, so we use the maximum subgroup size reported by the adapter.
2926 ctx->webgpu_global_ctx->capabilities.max_subgroup_size = info.subgroupMaxSize;
2927 // Initialize device
2928 std::vector<wgpu::FeatureName> required_features = { wgpu::FeatureName::ShaderF16 };
2929
2930#ifndef __EMSCRIPTEN__
2931 required_features.push_back(wgpu::FeatureName::ImplicitDeviceSynchronization);
2932 if (ctx->webgpu_global_ctx->capabilities.supports_subgroup_matrix) {
2933 required_features.push_back(wgpu::FeatureName::Subgroups);
2934 required_features.push_back(wgpu::FeatureName::ChromiumExperimentalSubgroupMatrix);
2935 }
2936#endif
2937
2938#ifdef GGML_WEBGPU_GPU_PROFILE
2939 required_features.push_back(wgpu::FeatureName::TimestampQuery);
2940#endif
2941
2942 wgpu::DeviceDescriptor dev_desc;
2943 dev_desc.requiredLimits = &ctx->webgpu_global_ctx->capabilities.limits;
2944 dev_desc.requiredFeatures = required_features.data();
2945 dev_desc.requiredFeatureCount = required_features.size();
2946 dev_desc.SetDeviceLostCallback(
2947 wgpu::CallbackMode::AllowSpontaneous,
2948 [](const wgpu::Device & device, wgpu::DeviceLostReason reason, wgpu::StringView message) {
2949 if (reason == wgpu::DeviceLostReason::Destroyed) {
2950 return;
2951 }
2952 GGML_UNUSED(device);
2953 GGML_LOG_ERROR("ggml_webgpu: Device lost! Reason: %d, Message: %s\n", static_cast<int>(reason),
2954 std::string(message).c_str());
2955 });
2956 dev_desc.SetUncapturedErrorCallback(
2957 [](const wgpu::Device & device, wgpu::ErrorType reason, wgpu::StringView message) {
2958 GGML_UNUSED(device);
2959 GGML_ABORT("ggml_webgpu: Device error! Reason: %d, Message: %s\n", static_cast<int>(reason),
2960 std::string(message).c_str());
2961 });
2962
2963#ifndef __EMSCRIPTEN__
2964 // Enable Dawn-specific toggles to increase native performance
2965 // TODO: Maybe WebGPU needs a "fast" mode where you can request compilers skip adding checks like these,
2966 // only for native performance?
2967 const char * const deviceEnabledToggles[] = { "skip_validation", "disable_robustness", "disable_workgroup_init",
2968 "disable_polyfills_on_integer_div_and_mod" };
2969 const char * const deviceDisabledToggles[] = { "timestamp_quantization" };
2970 wgpu::DawnTogglesDescriptor deviceTogglesDesc;
2971 deviceTogglesDesc.enabledToggles = deviceEnabledToggles;
2972 deviceTogglesDesc.enabledToggleCount = 4;
2973 deviceTogglesDesc.disabledToggles = deviceDisabledToggles;
2974 deviceTogglesDesc.disabledToggleCount = 1;
2975
2976 dev_desc.nextInChain = &deviceTogglesDesc;
2977#endif
2978
2979 ctx->webgpu_global_ctx->instance.WaitAny(
2980 ctx->webgpu_global_ctx->adapter.RequestDevice(
2981 &dev_desc, wgpu::CallbackMode::AllowSpontaneous,
2982 [ctx](wgpu::RequestDeviceStatus status, wgpu::Device device, wgpu::StringView message) {
2983 if (status != wgpu::RequestDeviceStatus::Success) {
2984 GGML_LOG_ERROR("ggml_webgpu: Failed to get a device: %s\n", std::string(message).c_str());
2985 return;
2986 }
2987 ctx->webgpu_global_ctx->device = std::move(device);
2988 }),
2989 UINT64_MAX);
2990 GGML_ASSERT(ctx->webgpu_global_ctx->device != nullptr);
2991
2992 ggml_webgpu_init_memset_pipeline(ctx->webgpu_global_ctx);
2993 ctx->webgpu_global_ctx->memset_buf_pool.init(ctx->webgpu_global_ctx->device, 1, WEBGPU_PARAMS_BUF_SIZE_BYTES,
2994 wgpu::BufferUsage::CopyDst | wgpu::BufferUsage::Uniform,
2995 wgpu::BufferUsage::CopySrc | wgpu::BufferUsage::MapWrite);
2996 ctx->webgpu_global_ctx->queue = ctx->webgpu_global_ctx->device.GetQueue();
2997
2998#ifdef GGML_WEBGPU_GPU_PROFILE
2999 // Initialize buffer pool for timestamp queries, used for profiling
3000 ctx->webgpu_global_ctx->timestamp_query_buf_pool.init(
3001 ctx->webgpu_global_ctx->device, WEBGPU_NUM_TIMESTAMP_QUERY_BUFS, WEBGPU_TIMESTAMP_QUERY_BUF_SIZE_BYTES,
3002 wgpu::BufferUsage::QueryResolve | wgpu::BufferUsage::CopySrc,
3003 wgpu::BufferUsage::MapRead | wgpu::BufferUsage::CopyDst);
3004#endif
3005
3006 GGML_LOG_INFO(
3007 "ggml_webgpu: adapter_info: vendor_id: %u | vendor: %s | architecture: %s | device_id: %u | name: %s | "
3008 "device_desc: %s\n",
3009 info.vendorID, std::string(info.vendor).c_str(), std::string(info.architecture).c_str(), info.deviceID,
3010 std::string(info.device).c_str(), std::string(info.description).c_str());
3011 return true;
3012}
3013
3014static webgpu_context initialize_webgpu_context(ggml_backend_dev_t dev) {
3015 ggml_backend_webgpu_device_context * dev_ctx = (ggml_backend_webgpu_device_context *) dev->context;
3016 webgpu_context webgpu_ctx = std::make_shared<webgpu_context_struct>();
3017 webgpu_ctx->global_ctx = dev_ctx->webgpu_global_ctx;
3018 webgpu_ctx->param_buf_pool.init(webgpu_ctx->global_ctx->device, WEBGPU_NUM_PARAM_BUFS, WEBGPU_PARAMS_BUF_SIZE_BYTES,
3019 wgpu::BufferUsage::CopyDst | wgpu::BufferUsage::Uniform,
3020 wgpu::BufferUsage::CopySrc | wgpu::BufferUsage::MapWrite);
3021 webgpu_ctx->set_rows_error_buf_pool.init(webgpu_ctx->global_ctx->device, WEBGPU_NUM_SET_ROWS_ERROR_BUFS,
3022 WEBGPU_SET_ROWS_ERROR_BUF_SIZE_BYTES,
3023 wgpu::BufferUsage::CopySrc | wgpu::BufferUsage::Storage,
3024 wgpu::BufferUsage::CopyDst | wgpu::BufferUsage::MapRead);
3025
3026 ggml_webgpu_init_mul_mat_pipeline(webgpu_ctx);
3027 ggml_webgpu_init_get_rows_pipeline(webgpu_ctx);
3028 ggml_webgpu_init_cpy_pipeline(webgpu_ctx);
3029 ggml_webgpu_init_rms_norm_pipeline(webgpu_ctx);
3030 ggml_webgpu_init_rope_pipeline(webgpu_ctx);
3031 ggml_webgpu_init_glu_pipeline(webgpu_ctx);
3032 ggml_webgpu_init_scale_pipeline(webgpu_ctx);
3033 ggml_webgpu_init_soft_max_pipeline(webgpu_ctx);
3034#ifdef GGML_WEBGPU_DEBUG
3035 // Initialize debug buffers
3036 ggml_webgpu_create_buffer(webgpu_ctx->global_ctx->device, webgpu_ctx->global_ctx->debug_host_buf,
3037 WEBGPU_DEBUG_BUF_ELEMS * sizeof(uint32_t),
3038 wgpu::BufferUsage::CopyDst | wgpu::BufferUsage::MapRead, "debug_host_buf");
3039 ggml_webgpu_create_buffer(webgpu_ctx->global_ctx->device, webgpu_ctx->global_ctx->debug_dev_buf,
3040 WEBGPU_DEBUG_BUF_ELEMS * sizeof(uint32_t),
3041 wgpu::BufferUsage::Storage | wgpu::BufferUsage::CopySrc, "debug_dev_buf");
3042#endif
3043 return webgpu_ctx;
3044}
3045
3046static ggml_backend_t ggml_backend_webgpu_backend_init(ggml_backend_dev_t dev, const char * params) {
3047 GGML_UNUSED(params);
3048
3049 WEBGPU_LOG_DEBUG("ggml_backend_webgpu_backend_init()");
3050
3051 ggml_backend_webgpu_device_context * dev_ctx = static_cast<ggml_backend_webgpu_device_context *>(dev->context);
3052
3053 auto * backend_ctx = new ggml_backend_webgpu_context();
3054 backend_ctx->name = GGML_WEBGPU_NAME + std::string(": ") + dev_ctx->device_name;
3055 backend_ctx->webgpu_ctx = initialize_webgpu_context(dev);
3056
3057 // See GGML Backend Interface section
3058 auto * backend = new ggml_backend();
3059 *backend = {
3060 /* .guid = */ ggml_backend_webgpu_guid(),
3061 /* .interface = */ ggml_backend_webgpu_i,
3062 /* .device = */ dev,
3063 /* .context = */ backend_ctx,
3064 };
3065 return backend;
3066}
3067
3068static ggml_backend_buffer_type_t ggml_backend_webgpu_device_get_buffer_type(ggml_backend_dev_t dev) {
3069 // See GGML Backend Buffer Type Interface section
3070
3071 static struct ggml_backend_buffer_type ggml_backend_webgpu_buffer_type = {
3072 /* .iface = */ {
3073 /* .get_name = */ ggml_backend_webgpu_buffer_type_get_name,
3074 /* .alloc_buffer = */ ggml_backend_webgpu_buffer_type_alloc_buffer,
3075 /* .get_alignment = */ ggml_backend_webgpu_buffer_type_get_alignment,
3076 /* .get_max_size = */ ggml_backend_webgpu_buffer_type_get_max_size,
3077 /* .get_alloc_size = */ ggml_backend_webgpu_buffer_type_get_alloc_size,
3078 /* .is_host = */ NULL, // defaults to false
3079 },
3080 /* .device = */
3081 dev,
3082 /* .context = */
3083 NULL
3084 };
3085
3086 return &ggml_backend_webgpu_buffer_type;
3087}
3088
3089static bool ggml_backend_webgpu_device_supports_buft(ggml_backend_dev_t dev, ggml_backend_buffer_type_t buft) {
3090 GGML_UNUSED(dev);
3091 return buft->iface.get_name == ggml_backend_webgpu_buffer_type_get_name;
3092}
3093
3094static bool ggml_webgpu_supported_qtype(ggml_type type) {
3095 switch (type) {
3096 case GGML_TYPE_Q4_0:
3097 case GGML_TYPE_Q4_1:
3098 case GGML_TYPE_Q5_0:
3099 case GGML_TYPE_Q5_1:
3100 case GGML_TYPE_Q8_0:
3101 case GGML_TYPE_Q2_K:
3102 case GGML_TYPE_Q3_K:
3103 case GGML_TYPE_Q4_K:
3104 case GGML_TYPE_Q5_K:
3105 case GGML_TYPE_Q6_K:
3106 case GGML_TYPE_IQ2_XXS:
3107 case GGML_TYPE_IQ2_XS:
3108 case GGML_TYPE_IQ2_S:
3109 case GGML_TYPE_IQ3_XXS:
3110 case GGML_TYPE_IQ3_S:
3111 case GGML_TYPE_IQ1_S:
3112 case GGML_TYPE_IQ1_M:
3113 case GGML_TYPE_IQ4_NL:
3114 case GGML_TYPE_IQ4_XS:
3115 return true;
3116 default:
3117 return false;
3118 }
3119}
3120
3121static bool ggml_backend_webgpu_device_supports_op(ggml_backend_dev_t dev, const ggml_tensor * op) {
3122 ggml_backend_webgpu_device_context * ctx = static_cast<ggml_backend_webgpu_device_context *>(dev->context);
3123
3124 ggml_tensor * src0 = op->src[0];
3125 ggml_tensor * src1 = op->src[1];
3126 ggml_tensor * src2 = op->src[2];
3127
3128 // on smaller devices (or CI), tensors may be larger than the max storage buffer size
3129 if (ggml_nbytes(op) > ctx->webgpu_global_ctx->capabilities.limits.maxStorageBufferBindingSize ||
3130 (src0 != nullptr &&
3131 ggml_nbytes(src0) > ctx->webgpu_global_ctx->capabilities.limits.maxStorageBufferBindingSize) ||
3132 (src1 != nullptr &&
3133 ggml_nbytes(src1) > ctx->webgpu_global_ctx->capabilities.limits.maxStorageBufferBindingSize)) {
3134 return false;
3135 }
3136
3137 bool supports_op = false;
3138 switch (op->op) {
3139 case GGML_OP_NONE:
3140 case GGML_OP_VIEW:
3141 case GGML_OP_PERMUTE:
3142 case GGML_OP_TRANSPOSE:
3143 case GGML_OP_RESHAPE:
3144 supports_op = true;
3145 break;
3146 case GGML_OP_ADD:
3147 case GGML_OP_SUB:
3148 case GGML_OP_MUL:
3149 case GGML_OP_DIV:
3150 // TODO: support non-contiguous tensors, e.g. for MOE_EXPERT_REDUCE
3151 // see https://github.com/ggml-org/llama.cpp/pull/16857
3152 supports_op = (op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_F16) && (src0->type == op->type) &&
3153 (src1->type == op->type) && ggml_is_contiguous(src0) && ggml_is_contiguous(src1);
3154 break;
3155 case GGML_OP_CPY:
3156 case GGML_OP_CONT:
3157 supports_op = ((op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_F16) &&
3158 (src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16)) ||
3159 (op->type == GGML_TYPE_I32 && src0->type == GGML_TYPE_F32);
3160 break;
3161 case GGML_OP_SET_ROWS:
3162 supports_op = ((op->type == GGML_TYPE_F16 || op->type == GGML_TYPE_F32) && src0->type == GGML_TYPE_F32 &&
3163 (src1->type == GGML_TYPE_I64 || src1->type == GGML_TYPE_I32));
3164 break;
3165 case GGML_OP_GET_ROWS:
3166 if (src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16 || ggml_webgpu_supported_qtype(src0->type)) {
3167 supports_op = (op->type == GGML_TYPE_F32);
3168 } else if (src0->type == GGML_TYPE_I32) {
3169 supports_op = op->type == GGML_TYPE_I32;
3170 }
3171 break;
3172 case GGML_OP_MUL_MAT:
3173 {
3174 switch (src1->type) {
3175 case GGML_TYPE_F16:
3176 supports_op |= (src0->type == GGML_TYPE_F16);
3177 break;
3178 case GGML_TYPE_F32:
3179 switch (src0->type) {
3180 case GGML_TYPE_F32:
3181 case GGML_TYPE_F16:
3182 case GGML_TYPE_Q4_0:
3183 case GGML_TYPE_Q4_1:
3184 case GGML_TYPE_Q5_0:
3185 case GGML_TYPE_Q5_1:
3186 case GGML_TYPE_Q8_0:
3187 case GGML_TYPE_Q2_K:
3188 case GGML_TYPE_Q3_K:
3189 case GGML_TYPE_Q4_K:
3190 case GGML_TYPE_Q5_K:
3191 case GGML_TYPE_Q6_K:
3192 case GGML_TYPE_IQ2_XXS:
3193 case GGML_TYPE_IQ2_XS:
3194 case GGML_TYPE_IQ2_S:
3195 case GGML_TYPE_IQ3_XXS:
3196 case GGML_TYPE_IQ3_S:
3197 case GGML_TYPE_IQ1_S:
3198 case GGML_TYPE_IQ1_M:
3199 case GGML_TYPE_IQ4_NL:
3200 case GGML_TYPE_IQ4_XS:
3201 supports_op = true;
3202 break;
3203 default:
3204 break;
3205 }
3206 default:
3207 break;
3208 }
3209 break;
3210 }
3211 case GGML_OP_FLASH_ATTN_EXT:
3212 {
3213#ifndef __EMSCRIPTEN__
3214 if (!ctx->webgpu_global_ctx->capabilities.supports_subgroup_matrix) {
3215 break;
3216 }
3217 // Head dimensions must fit in workgroup memory with minimum tile sizes
3218 size_t limit_bytes = ctx->webgpu_global_ctx->capabilities.limits.maxComputeWorkgroupStorageSize;
3219 const bool has_mask = op->src[3] != nullptr;
3220 const bool kv_direct = src1->type == GGML_TYPE_F16 &&
3221 (src0->ne[0] % ctx->webgpu_global_ctx->capabilities.sg_mat_k) == 0 &&
3222 (src1->ne[1] % GGML_WEBGPU_KV_SEQ_PAD) == 0;
3223 const size_t min_bytes = ggml_webgpu_flash_attn_wg_mem_bytes(
3224 ctx->webgpu_global_ctx->capabilities.sg_mat_m, ctx->webgpu_global_ctx->capabilities.sg_mat_n,
3225 (uint32_t) src0->ne[0], (uint32_t) src2->ne[0], has_mask, kv_direct);
3226 if (min_bytes > limit_bytes) {
3227 break;
3228 }
3229
3230 supports_op = src0->type == GGML_TYPE_F32 &&
3231 (src1->type == GGML_TYPE_F32 || src1->type == GGML_TYPE_F16 ||
3232 src1->type == GGML_TYPE_Q4_0 || src1->type == GGML_TYPE_Q8_0) &&
3233 src2->type == src1->type && op->type == GGML_TYPE_F32;
3234#endif
3235 break;
3236 }
3237 case GGML_OP_RMS_NORM:
3238 supports_op = op->type == GGML_TYPE_F32 && src0->type == GGML_TYPE_F32;
3239 break;
3240 case GGML_OP_ROPE:
3241 supports_op = op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_F16;
3242 break;
3243 case GGML_OP_GLU:
3244 switch (ggml_get_glu_op(op)) {
3245 case GGML_GLU_OP_REGLU:
3246 case GGML_GLU_OP_GEGLU:
3247 case GGML_GLU_OP_SWIGLU:
3248 case GGML_GLU_OP_GEGLU_ERF:
3249 case GGML_GLU_OP_GEGLU_QUICK:
3250 supports_op = op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_F16;
3251 break;
3252 case GGML_GLU_OP_SWIGLU_OAI:
3253 supports_op = op->type == GGML_TYPE_F32;
3254 break;
3255 default:
3256 break;
3257 }
3258 break;
3259 case GGML_OP_SCALE:
3260 supports_op = op->type == GGML_TYPE_F32;
3261 break;
3262 case GGML_OP_SOFT_MAX:
3263 supports_op = op->type == GGML_TYPE_F32;
3264 break;
3265 case GGML_OP_UNARY:
3266 {
3267 const ggml_unary_op UNARY_OP = ggml_get_unary_op(op);
3268
3269 switch (UNARY_OP) {
3270 case GGML_UNARY_OP_ABS:
3271 case GGML_UNARY_OP_SGN:
3272 case GGML_UNARY_OP_NEG:
3273 case GGML_UNARY_OP_STEP:
3274 case GGML_UNARY_OP_TANH:
3275 case GGML_UNARY_OP_ELU:
3276 case GGML_UNARY_OP_RELU:
3277 case GGML_UNARY_OP_SIGMOID:
3278 case GGML_UNARY_OP_GELU:
3279 case GGML_UNARY_OP_GELU_QUICK:
3280 case GGML_UNARY_OP_SILU:
3281 case GGML_UNARY_OP_HARDSWISH:
3282 case GGML_UNARY_OP_HARDSIGMOID:
3283 case GGML_UNARY_OP_EXP:
3284 case GGML_UNARY_OP_GELU_ERF:
3285 case GGML_UNARY_OP_SOFTPLUS:
3286 case GGML_UNARY_OP_EXPM1:
3287 case GGML_UNARY_OP_FLOOR:
3288 case GGML_UNARY_OP_CEIL:
3289 case GGML_UNARY_OP_ROUND:
3290 case GGML_UNARY_OP_TRUNC:
3291 case GGML_UNARY_OP_XIELU:
3292 supports_op =
3293 (op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_F16) && (src0->type == op->type);
3294 break;
3295 default:
3296 break;
3297 }
3298 }
3299 break;
3300 case GGML_OP_CLAMP:
3301 supports_op = (op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_F16) && (src0->type == op->type);
3302 break;
3303 case GGML_OP_FILL:
3304 supports_op = op->type == GGML_TYPE_F32 && src0->type == GGML_TYPE_F32;
3305 break;
3306 case GGML_OP_LOG:
3307 supports_op = (op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_F16) && (src0->type == op->type);
3308 break;
3309 case GGML_OP_PAD:
3310 supports_op = op->type == GGML_TYPE_F32 && src0->type == GGML_TYPE_F32;
3311 break;
3312 case GGML_OP_ARGMAX:
3313 supports_op = op->type == GGML_TYPE_I32 && src0->type == GGML_TYPE_F32;
3314 break;
3315 case GGML_OP_ARGSORT:
3316 supports_op = op->type == GGML_TYPE_I32 && src0->type == GGML_TYPE_F32 && ggml_is_contiguous_rows(src0);
3317 break;
3318 case GGML_OP_TOP_K:
3319 supports_op = op->type == GGML_TYPE_I32 && src0->type == GGML_TYPE_F32 && ggml_is_contiguous_rows(src0);
3320 break;
3321 case GGML_OP_CUMSUM:
3322 supports_op = op->type == GGML_TYPE_F32 && src0->type == op->type;
3323 break;
3324 case GGML_OP_SUM:
3325 case GGML_OP_SUM_ROWS:
3326 supports_op = op->type == GGML_TYPE_F32 && src0->type == op->type && ggml_is_contiguous_rows(src0);
3327 break;
3328 default:
3329 break;
3330 }
3331 if (ggml_nbytes(op) > ctx->webgpu_global_ctx->capabilities.limits.maxStorageBufferBindingSize ||
3332 (src0 != nullptr &&
3333 ggml_nbytes(src0) > ctx->webgpu_global_ctx->capabilities.limits.maxStorageBufferBindingSize) ||
3334 (src1 != nullptr &&
3335 ggml_nbytes(src1) > ctx->webgpu_global_ctx->capabilities.limits.maxStorageBufferBindingSize) ||
3336 (src2 != nullptr &&
3337 ggml_nbytes(src2) > ctx->webgpu_global_ctx->capabilities.limits.maxStorageBufferBindingSize)) {
3338 supports_op = false;
3339 WEBGPU_LOG_DEBUG("ggml_webgpu op not supported due to size: ");
3340 }
3341
3342 if (!supports_op) {
3343 WEBGPU_LOG_DEBUG("ggml_webgpu op not supported: "
3344 << ggml_op_name(op->op) << " with types dst: " << ggml_type_name(op->type)
3345 << ", src0: " << (op->src[0] ? ggml_type_name(op->src[0]->type) : "null")
3346 << ", src1: " << (op->src[1] ? ggml_type_name(op->src[1]->type) : "null"));
3347 } else {
3348 WEBGPU_LOG_DEBUG("ggml_webgpu op supported: "
3349 << ggml_op_name(op->op) << " with types dst: " << ggml_type_name(op->type)
3350 << ", src0: " << (op->src[0] ? ggml_type_name(op->src[0]->type) : "null")
3351 << ", src1: " << (op->src[1] ? ggml_type_name(op->src[1]->type) : "null"));
3352 }
3353 return supports_op;
3354}
3355
3356static struct ggml_backend_device_i ggml_backend_webgpu_device_i = {
3357 /* .get_name = */ ggml_backend_webgpu_device_get_name,
3358 /* .get_description = */ ggml_backend_webgpu_device_get_description,
3359 /* .get_memory = */ ggml_backend_webgpu_device_get_memory,
3360 /* .get_type = */ ggml_backend_webgpu_device_get_type,
3361 /* .get_props = */ ggml_backend_webgpu_device_get_props,
3362 /* .init_backend = */ ggml_backend_webgpu_backend_init,
3363 /* .get_buffer_type = */ ggml_backend_webgpu_device_get_buffer_type,
3364 /* .get_host_buffer_type = */ NULL,
3365 /* .buffer_from_host_ptr = */ NULL,
3366 /* .supports_op = */ ggml_backend_webgpu_device_supports_op,
3367 /* .supports_buft = */ ggml_backend_webgpu_device_supports_buft,
3368 /* .offload_op = */ NULL,
3369 /* .event_new = */ NULL,
3370 /* .event_free = */ NULL,
3371 /* .event_synchronize = */ NULL,
3372};
3373
3374/* End GGML Backend Device Interface */
3375
3376/* GGML Backend Registration Interface */
3377
3378static const char * ggml_backend_webgpu_reg_get_name(ggml_backend_reg_t reg) {
3379 ggml_backend_webgpu_reg_context * ctx = static_cast<ggml_backend_webgpu_reg_context *>(reg->context);
3380 return ctx->name;
3381}
3382
3383static size_t ggml_backend_webgpu_reg_get_device_count(ggml_backend_reg_t reg) {
3384 ggml_backend_webgpu_reg_context * ctx = static_cast<ggml_backend_webgpu_reg_context *>(reg->context);
3385 return ctx->device_count;
3386}
3387
3388// Only one device is supported for now
3389static ggml_backend_dev_t ggml_backend_webgpu_reg_get_device(ggml_backend_reg_t reg, size_t index) {
3390 GGML_ASSERT(index == 0);
3391 WEBGPU_LOG_DEBUG("ggml_backend_reg_get_device()");
3392
3393 WEBGPU_CPU_PROFILE_TOTAL_START(reg_get_device);
3394
3395 ggml_backend_webgpu_reg_context * reg_ctx = static_cast<ggml_backend_webgpu_reg_context *>(reg->context);
3396
3397 create_webgpu_device(reg_ctx);
3398
3399 static ggml_backend_webgpu_device_context device_ctx;
3400 device_ctx.device_name = GGML_WEBGPU_NAME;
3401 device_ctx.device_desc = GGML_WEBGPU_NAME;
3402 device_ctx.webgpu_global_ctx = reg_ctx->webgpu_global_ctx;
3403 // See GGML Backend Device Interface section
3404 static ggml_backend_device device = {
3405 /* .iface = */ ggml_backend_webgpu_device_i,
3406 /* .reg = */ reg,
3407 /* .context = */ &device_ctx,
3408 };
3409
3410 WEBGPU_CPU_PROFILE_TOTAL_END(reg_get_device, reg_ctx->webgpu_global_ctx);
3411 return &device;
3412}
3413
3414static const struct ggml_backend_reg_i ggml_backend_webgpu_reg_i = {
3415 /* .get_name = */ ggml_backend_webgpu_reg_get_name,
3416 /* .get_device_count = */ ggml_backend_webgpu_reg_get_device_count,
3417 /* .get_device = */ ggml_backend_webgpu_reg_get_device,
3418 /* .get_proc_address = */ NULL,
3419};
3420
3421/* End GGML Backend Registration Interface */
3422
3423ggml_backend_reg_t ggml_backend_webgpu_reg() {
3424 WEBGPU_LOG_DEBUG("ggml_backend_webgpu_reg()");
3425
3426 static ggml_backend_webgpu_reg_context ctx;
3427 ctx.name = GGML_WEBGPU_NAME;
3428 ctx.device_count = 1;
3429
3430 wgpu::InstanceDescriptor instance_descriptor{};
3431 std::vector<wgpu::InstanceFeatureName> instance_features = { wgpu::InstanceFeatureName::TimedWaitAny };
3432 instance_descriptor.requiredFeatures = instance_features.data();
3433 instance_descriptor.requiredFeatureCount = instance_features.size();
3434
3435#ifndef __EMSCRIPTEN__
3436 const char * const instanceEnabledToggles[] = { "allow_unsafe_apis" };
3437 wgpu::DawnTogglesDescriptor instanceTogglesDesc;
3438 instanceTogglesDesc.enabledToggles = instanceEnabledToggles;
3439 instanceTogglesDesc.enabledToggleCount = 1;
3440 instance_descriptor.nextInChain = &instanceTogglesDesc;
3441#endif
3442
3443 wgpu::Instance inst = wgpu::CreateInstance(&instance_descriptor);
3444 ctx.webgpu_global_ctx = webgpu_global_context(new webgpu_global_context_struct());
3445 ctx.webgpu_global_ctx->instance = std::move(inst);
3446
3447#ifdef __EMSCRIPTEN__
3448 if (ctx.webgpu_global_ctx->instance == nullptr) {
3449 GGML_LOG_ERROR("ggml_webgpu: Failed to create WebGPU instance. Make sure either -sASYNCIFY or -sJSPI is set\n");
3450 return nullptr;
3451 }
3452#endif
3453 GGML_ASSERT(ctx.webgpu_global_ctx->instance != nullptr);
3454
3455 static ggml_backend_reg reg = {
3456 /* .api_version = */ GGML_BACKEND_API_VERSION,
3457 /* .iface = */ ggml_backend_webgpu_reg_i,
3458 /* .context = */ &ctx,
3459 };
3460 return ®
3461}
3462
3463ggml_backend_t ggml_backend_webgpu_init(void) {
3464 ggml_backend_dev_t dev = ggml_backend_reg_dev_get(ggml_backend_webgpu_reg(), 0);
3465
3466 return ggml_backend_webgpu_backend_init(dev, nullptr);
3467}
3468
3469GGML_BACKEND_DL_IMPL(ggml_backend_webgpu_reg)