1#ifndef GGML_WEBGPU_SHADER_LIB_HPP
2#define GGML_WEBGPU_SHADER_LIB_HPP
3
#include "ggml.h"
#include "pre_wgsl.hpp"

#include <algorithm>
#include <cstddef>
#include <cstdint>
#include <functional>
#include <memory>
#include <string>
#include <vector>
10
11#define GGML_WEBGPU_F16_SIZE_BYTES 2
12#define GGML_WEBGPU_F32_SIZE_BYTES 4
13#define GGML_WEBGPU_I32_SIZE_BYTES 4
14#define GGML_WEBGPU_FLASH_ATTN_PREFERRED_KV_SG_TILES 8u
15#define GGML_WEBGPU_FLASH_ATTN_PREFERRED_WG_SIZE 128u
16// Matches GGML_PAD(..., 256) in src/llama-context.cpp for KV cache sizing.
17#define GGML_WEBGPU_KV_SEQ_PAD 256u
18
19#define GGML_WEBGPU_ARGSORT_MERGE_MAX_WG_SIZE 512u
20
// Result of preprocessing a WGSL shader template: the generated WGSL source,
// a variant name that uniquely identifies this configuration (usable as a
// pipeline-cache key), and the tuning decisions made during preprocessing.
struct ggml_webgpu_processed_shader {
    std::string wgsl;                // preprocessed WGSL source text
    std::string variant;             // unique name for this shader configuration
    std::shared_ptr<void> decisions; // shader-specific decisions struct (type-erased; may be null)
};
26
// Mix `value` into `seed`; identical combine step to boost::hash_combine.
template <typename T> inline void ggml_webgpu_hash_combine(size_t & seed, const T & value) {
    const size_t hashed = std::hash<T>{}(value);
    seed ^= hashed + 0x9e3779b9 + (seed << 6) + (seed >> 2);
}
31
32/** FlashAttention */
33
// Cache key for a compiled FlashAttention pipeline: every field changes the
// generated WGSL, so each distinct key maps to its own shader variant.
struct ggml_webgpu_flash_attn_pipeline_key {
    ggml_type kv_type;       // storage type of the K/V tensors (f32/f16/q4_0/q8_0)
    uint32_t head_dim_qk;    // head dimension of Q and K
    uint32_t head_dim_v;     // head dimension of V
    bool kv_direct;          // K/V read directly, skipping the shared-memory staging buffer
    bool has_mask;           // attention mask input present (MASK define)
    bool has_sinks;          // attention sinks input present (SINKS define)
    bool uses_logit_softcap; // logit softcapping enabled (LOGIT_SOFTCAP define)

    // Full member-wise equality; required for use as an unordered_map key.
    bool operator==(const ggml_webgpu_flash_attn_pipeline_key & other) const {
        return kv_type == other.kv_type && head_dim_qk == other.head_dim_qk && head_dim_v == other.head_dim_v &&
               kv_direct == other.kv_direct && has_mask == other.has_mask && has_sinks == other.has_sinks &&
               uses_logit_softcap == other.uses_logit_softcap;
    }
};
49
50struct ggml_webgpu_flash_attn_pipeline_key_hash {
51 size_t operator()(const ggml_webgpu_flash_attn_pipeline_key & key) const {
52 size_t seed = 0;
53 ggml_webgpu_hash_combine(seed, key.kv_type);
54 ggml_webgpu_hash_combine(seed, key.head_dim_qk);
55 ggml_webgpu_hash_combine(seed, key.head_dim_v);
56 ggml_webgpu_hash_combine(seed, key.kv_direct);
57 ggml_webgpu_hash_combine(seed, key.has_mask);
58 ggml_webgpu_hash_combine(seed, key.has_sinks);
59 ggml_webgpu_hash_combine(seed, key.uses_logit_softcap);
60 return seed;
61 }
62};
63
// Inputs for preprocessing the FlashAttention shader: the pipeline key plus
// the device capabilities that drive tile-size and workgroup-size selection.
struct ggml_webgpu_flash_attn_shader_lib_context {
    ggml_webgpu_flash_attn_pipeline_key key;
    uint32_t sg_mat_m;          // subgroup-matrix M dimension (also the Q tile size)
    uint32_t sg_mat_n;          // subgroup-matrix N dimension (KV tile granularity)
    uint32_t sg_mat_k;          // subgroup-matrix K dimension
    size_t wg_mem_limit_bytes;  // workgroup (shared) memory budget, in bytes
    uint32_t max_subgroup_size; // device maximum subgroup size
};
72
// Tuning decisions made while preprocessing the FlashAttention shader;
// returned type-erased via ggml_webgpu_processed_shader::decisions.
struct ggml_webgpu_flash_attn_shader_decisions {
    uint32_t q_tile = 0;  // rows of Q processed per workgroup
    uint32_t kv_tile = 0; // KV rows processed per tile iteration
    uint32_t wg_size = 0; // workgroup size, in threads
};
78
79// This is exposed because it's necessary in supports_op
80inline size_t ggml_webgpu_flash_attn_wg_mem_bytes(uint32_t q_tile,
81 uint32_t kv_tile,
82 uint32_t head_dim_qk,
83 uint32_t head_dim_v,
84 bool has_mask,
85 bool kv_direct) {
86 const uint32_t max_head_dim = std::max(head_dim_qk, head_dim_v);
87 size_t f16_elems = 0;
88 size_t f32_elems = 0;
89 f16_elems += q_tile * head_dim_qk; // q_shmem
90 if (!kv_direct) {
91 f16_elems += kv_tile * max_head_dim; // kv_shmem
92 }
93 f16_elems += q_tile * head_dim_v; // o_shmem
94 if (has_mask) {
95 f16_elems += q_tile * kv_tile; // mask_shmem
96 }
97 f16_elems += q_tile * kv_tile; // inter_shmem
98 f32_elems += q_tile; // row_max_shmem
99 f32_elems += q_tile; // exp_sum_shmem
100 return f16_elems * GGML_WEBGPU_F16_SIZE_BYTES + f32_elems * GGML_WEBGPU_F32_SIZE_BYTES;
101}
102
103static uint32_t ggml_webgpu_flash_attn_max_kv_tile(const ggml_webgpu_flash_attn_shader_lib_context & context) {
104 const size_t limit_bytes = context.wg_mem_limit_bytes;
105 const size_t q_tile = context.sg_mat_m;
106 const size_t base_q_bytes =
107 (context.key.head_dim_qk + context.key.head_dim_v) * q_tile * GGML_WEBGPU_F16_SIZE_BYTES +
108 2 * q_tile * GGML_WEBGPU_F32_SIZE_BYTES;
109 size_t bytes_per_kv = 0;
110 if (!context.key.kv_direct) {
111 bytes_per_kv += std::max(context.key.head_dim_qk, context.key.head_dim_v);
112 }
113 if (context.key.has_mask) {
114 bytes_per_kv += q_tile;
115 }
116 bytes_per_kv += q_tile;
117 bytes_per_kv *= GGML_WEBGPU_F16_SIZE_BYTES;
118 const uint32_t max_kv_tile = (limit_bytes - base_q_bytes) / bytes_per_kv;
119 return (max_kv_tile / context.sg_mat_n) * context.sg_mat_n;
120}
121
// Preprocess the FlashAttention WGSL shader for one pipeline key: emits the
// feature defines implied by `context.key`, chooses Q/KV tile sizes that fit
// the workgroup-memory budget, picks the workgroup size, and builds a unique
// variant name. The chosen tiles/workgroup size are returned type-erased in
// result.decisions (ggml_webgpu_flash_attn_shader_decisions).
inline ggml_webgpu_processed_shader ggml_webgpu_preprocess_flash_attn_shader(
    pre_wgsl::Preprocessor & preprocessor,
    const char * shader_src,
    const ggml_webgpu_flash_attn_shader_lib_context & context) {
    std::vector<std::string> defines;
    std::string variant = "flash_attn";

    // KV storage type selects the load/dequantization path in the shader.
    switch (context.key.kv_type) {
        case GGML_TYPE_F32:
            defines.push_back("KV_F32");
            break;
        case GGML_TYPE_F16:
            defines.push_back("KV_F16");
            break;
        case GGML_TYPE_Q4_0:
            defines.push_back("KV_Q4_0");
            break;
        case GGML_TYPE_Q8_0:
            defines.push_back("KV_Q8_0");
            break;
        default:
            GGML_ABORT("Unsupported KV type for flash attention shader");
    }
    variant += std::string("_") + ggml_type_name(context.key.kv_type);

    // Optional features; each is reflected in both the defines and the
    // variant name so distinct configurations never collide in the cache.
    if (context.key.has_mask) {
        defines.push_back("MASK");
        variant += "_mask";
    }
    if (context.key.has_sinks) {
        defines.push_back("SINKS");
        variant += "_sinks";
    }
    if (context.key.uses_logit_softcap) {
        defines.push_back("LOGIT_SOFTCAP");
        variant += "_lgsc";
    }

    if (context.key.kv_direct) {
        defines.push_back("KV_DIRECT");
        variant += "_kvdirect";
    }

    defines.push_back(std::string("HEAD_DIM_QK=") + std::to_string(context.key.head_dim_qk));
    variant += std::string("_hsqk") + std::to_string(context.key.head_dim_qk);

    defines.push_back(std::string("HEAD_DIM_V=") + std::to_string(context.key.head_dim_v));
    variant += std::string("_hsv") + std::to_string(context.key.head_dim_v);
    // For now these are not part of the variant name
    defines.push_back(std::string("SG_MAT_M=") + std::to_string(context.sg_mat_m));
    defines.push_back(std::string("SG_MAT_N=") + std::to_string(context.sg_mat_n));
    defines.push_back(std::string("SG_MAT_K=") + std::to_string(context.sg_mat_k));

    // Add chosen Q/KV tile sizes: one subgroup-matrix block of Q rows, and as
    // many KV rows as fit in workgroup memory, capped at the preferred number
    // of subgroup tiles.
    uint32_t q_tile = context.sg_mat_m;
    uint32_t kv_tile = std::min(ggml_webgpu_flash_attn_max_kv_tile(context),
                                context.sg_mat_n * GGML_WEBGPU_FLASH_ATTN_PREFERRED_KV_SG_TILES);
    if (context.key.kv_direct) {
        GGML_ASSERT(kv_tile <= GGML_WEBGPU_KV_SEQ_PAD);
        // Avoids having to use bounds-checks and decreasing performance for direct KV loads
        // (shrink until the padded KV sequence length is an exact multiple of the tile).
        // NOTE(review): assumes kv_tile stays > 0 here; a tiny memory limit
        // would make this loop wrap — presumably ruled out by supports_op.
        while (GGML_WEBGPU_KV_SEQ_PAD % kv_tile != 0) {
            kv_tile -= context.sg_mat_n;
        }
    }

    defines.push_back(std::string("Q_TILE=") + std::to_string(q_tile));
    defines.push_back(std::string("KV_TILE=") + std::to_string(kv_tile));

    // workgroup size: at least one full subgroup, preferring
    // GGML_WEBGPU_FLASH_ATTN_PREFERRED_WG_SIZE threads.
    uint32_t wg_size = std::max(context.max_subgroup_size, GGML_WEBGPU_FLASH_ATTN_PREFERRED_WG_SIZE);

    defines.push_back(std::string("WG_SIZE=") + std::to_string(wg_size));

    ggml_webgpu_processed_shader result;
    result.wgsl = preprocessor.preprocess(shader_src, defines);
    result.variant = variant;
    auto decisions = std::make_shared<ggml_webgpu_flash_attn_shader_decisions>();
    decisions->q_tile = q_tile;
    decisions->kv_tile = kv_tile;
    decisions->wg_size = wg_size;
    result.decisions = decisions;
    return result;
}
205
206/** Generic **/
207
// Inputs for preprocessing simple elementwise shaders.
struct ggml_webgpu_generic_shader_lib_context {
    int vec4;             // nonzero: use the vec4-per-thread path (VEC4 define)
    uint32_t max_wg_size; // workgroup size to compile with (device maximum)
};
212
// Decision record shared by the simple preprocess helpers; returned
// type-erased via ggml_webgpu_processed_shader::decisions.
struct ggml_webgpu_generic_shader_decisions {
    uint32_t wg_size; // workgroup size the shader was compiled with
};
216
217inline ggml_webgpu_processed_shader ggml_webgpu_preprocess_generic_shader(
218 pre_wgsl::Preprocessor & preprocessor,
219 const char * shader_src,
220 const ggml_webgpu_generic_shader_lib_context & context,
221 const std::string & base_variant) {
222 std::vector<std::string> defines;
223 std::string variant = base_variant;
224
225 if (context.vec4) {
226 defines.push_back("VEC4");
227 variant += "_vec";
228 }
229
230 defines.push_back(std::string("WG_SIZE=") + std::to_string(context.max_wg_size));
231
232 ggml_webgpu_processed_shader result;
233 result.wgsl = preprocessor.preprocess(shader_src, defines);
234 result.variant = variant;
235 return result;
236}
237
238/** Pad **/
239
// Cache key for the pad pipeline: only the padding mode varies.
struct ggml_webgpu_pad_pipeline_key {
    bool circular; // circular padding variant (CIRCULAR define)

    bool operator==(const ggml_webgpu_pad_pipeline_key & other) const { return circular == other.circular; }
};
245
246struct ggml_webgpu_pad_pipeline_key_hash {
247 size_t operator()(const ggml_webgpu_pad_pipeline_key & key) const {
248 size_t seed = 0;
249 ggml_webgpu_hash_combine(seed, key.circular);
250 return seed;
251 }
252};
253
// Inputs for preprocessing the pad shader.
struct ggml_webgpu_pad_shader_lib_context {
    ggml_webgpu_pad_pipeline_key key;
    uint32_t max_wg_size; // workgroup size to compile with
};
258
259inline ggml_webgpu_processed_shader ggml_webgpu_preprocess_pad_shader(
260 pre_wgsl::Preprocessor & preprocessor,
261 const char * shader_src,
262 const ggml_webgpu_pad_shader_lib_context & context) {
263 std::vector<std::string> defines;
264 std::string variant = "pad";
265
266 if (context.key.circular) {
267 defines.push_back("CIRCULAR");
268 variant += "_circular";
269 }
270
271 defines.push_back(std::string("WG_SIZE=") + std::to_string(context.max_wg_size));
272
273 ggml_webgpu_processed_shader result;
274 result.wgsl = preprocessor.preprocess(shader_src, defines);
275 result.variant = variant;
276 auto decisions = std::make_shared<ggml_webgpu_generic_shader_decisions>();
277 decisions->wg_size = context.max_wg_size;
278 result.decisions = decisions;
279 return result;
280}
281
282/** Argsort **/
283
// Inputs shared by the argsort and argsort-merge preprocess helpers.
struct ggml_webgpu_argsort_shader_lib_context {
    uint32_t max_wg_size;      // device maximum workgroup size
    size_t wg_mem_limit_bytes; // workgroup (shared) memory budget, in bytes
    int32_t order;             // sort order, forwarded as the ORDER define (presumably ggml_sort_order — confirm)
};
289
// Workgroup-size decision for the argsort shaders; returned type-erased via
// ggml_webgpu_processed_shader::decisions.
struct ggml_webgpu_argsort_shader_decisions {
    uint32_t wg_size = 0; // workgroup size the shader was compiled with
};
293
// Preprocess the argsort shader: fixes the sort order and picks a
// power-of-two workgroup size bounded by the device maximum and the
// shared-memory budget.
inline ggml_webgpu_processed_shader ggml_webgpu_preprocess_argsort_shader(
    pre_wgsl::Preprocessor & preprocessor,
    const char * shader_src,
    const ggml_webgpu_argsort_shader_lib_context & context) {
    std::vector<std::string> defines;
    std::string variant = "argsort";
    defines.push_back(std::string("ORDER=") + std::to_string(context.order));
    variant += std::string("_order") + std::to_string(context.order);
    // Largest power of two not exceeding max_wg_size whose i32 scratch also
    // fits the memory budget.
    // NOTE(review): the memory condition tests the CURRENT wg_size before
    // doubling, so the final wg_size can be 2x the `limit / 2` budget the
    // check implies — confirm this asymmetry is intentional.
    uint32_t wg_size = 1;
    while (wg_size * 2 <= context.max_wg_size &&
           wg_size * GGML_WEBGPU_I32_SIZE_BYTES <= context.wg_mem_limit_bytes / 2) {
        wg_size *= 2;
    }
    defines.push_back(std::string("WG_SIZE=") + std::to_string(wg_size));
    ggml_webgpu_processed_shader result;
    result.wgsl = preprocessor.preprocess(shader_src, defines);
    result.variant = variant;
    auto decisions = std::make_shared<ggml_webgpu_argsort_shader_decisions>();
    decisions->wg_size = wg_size;
    result.decisions = decisions;
    return result;
}
316
317inline ggml_webgpu_processed_shader ggml_webgpu_preprocess_argsort_merge_shader(
318 pre_wgsl::Preprocessor & preprocessor,
319 const char * shader_src,
320 const ggml_webgpu_argsort_shader_lib_context & context) {
321 std::vector<std::string> defines;
322 std::string variant = "argsort_merge";
323 defines.push_back(std::string("ORDER=") + std::to_string(context.order));
324 variant += std::string("_order") + std::to_string(context.order);
325 uint32_t wg_size = std::min(GGML_WEBGPU_ARGSORT_MERGE_MAX_WG_SIZE, context.max_wg_size);
326 defines.push_back(std::string("WG_SIZE=") + std::to_string(wg_size));
327 ggml_webgpu_processed_shader result;
328 result.wgsl = preprocessor.preprocess(shader_src, defines);
329 result.variant = variant;
330 auto decisions = std::make_shared<ggml_webgpu_argsort_shader_decisions>();
331 decisions->wg_size = wg_size;
332 result.decisions = decisions;
333 return result;
334}
335
336/** Set Rows **/
337
// Cache key for a set_rows pipeline variant.
struct ggml_webgpu_set_rows_pipeline_key {
    int dst_type; // destination ggml_type (f32 or f16 supported)
    int vec4;     // nonzero: vec4-per-thread path (VEC4 define)
    int i64_idx;  // nonzero: 64-bit row indices (I64_IDX define)

    bool operator==(const ggml_webgpu_set_rows_pipeline_key & other) const {
        return dst_type == other.dst_type && vec4 == other.vec4 && i64_idx == other.i64_idx;
    }
};
347
348struct ggml_webgpu_set_rows_pipeline_key_hash {
349 size_t operator()(const ggml_webgpu_set_rows_pipeline_key & key) const {
350 size_t seed = 0;
351 ggml_webgpu_hash_combine(seed, key.dst_type);
352 ggml_webgpu_hash_combine(seed, key.vec4);
353 ggml_webgpu_hash_combine(seed, key.i64_idx);
354 return seed;
355 }
356};
357
// Inputs for preprocessing the set_rows shader.
struct ggml_webgpu_set_rows_shader_lib_context {
    ggml_webgpu_set_rows_pipeline_key key;
    uint32_t max_wg_size; // workgroup size to compile with
};
362
363inline ggml_webgpu_processed_shader ggml_webgpu_preprocess_set_rows_shader(
364 pre_wgsl::Preprocessor & preprocessor,
365 const char * shader_src,
366 const ggml_webgpu_set_rows_shader_lib_context & context) {
367 std::vector<std::string> defines;
368 std::string variant = "set_rows";
369
370 switch (context.key.dst_type) {
371 case GGML_TYPE_F32:
372 defines.push_back("DST_F32");
373 variant += "_dstf32";
374 break;
375 case GGML_TYPE_F16:
376 defines.push_back("DST_F16");
377 variant += "_dstf16";
378 break;
379 default:
380 GGML_ABORT("Unsupported dst type for set_rows shader");
381 }
382
383 if (context.key.vec4) {
384 defines.push_back("VEC4");
385 variant += "_vec";
386 }
387 if (context.key.i64_idx) {
388 defines.push_back("I64_IDX");
389 variant += "_i64idx";
390 }
391
392 defines.push_back(std::string("WG_SIZE=") + std::to_string(context.max_wg_size));
393
394 ggml_webgpu_processed_shader result;
395 result.wgsl = preprocessor.preprocess(shader_src, defines);
396 result.variant = variant;
397 auto decisions = std::make_shared<ggml_webgpu_generic_shader_decisions>();
398 decisions->wg_size = context.max_wg_size;
399 result.decisions = decisions;
400 return result;
401}
402
// Cache key for a unary-operator pipeline variant.
struct ggml_webgpu_unary_pipeline_key {
    int type;      // element ggml_type (f32 or f16 supported)
    int op;        // ggml_op, or ggml_unary_op when is_unary is set
    bool is_unary; // many unary operators fall under the GGML_OP_UNARY umbrella
    bool inplace;  // in-place variant (INPLACE define)

    bool operator==(const ggml_webgpu_unary_pipeline_key & other) const {
        return type == other.type && op == other.op && is_unary == other.is_unary && inplace == other.inplace;
    }
};
413
414struct ggml_webgpu_unary_pipeline_key_hash {
415 size_t operator()(const ggml_webgpu_unary_pipeline_key & key) const {
416 size_t seed = 0;
417 ggml_webgpu_hash_combine(seed, key.type);
418 ggml_webgpu_hash_combine(seed, key.op);
419 ggml_webgpu_hash_combine(seed, key.is_unary);
420 ggml_webgpu_hash_combine(seed, key.inplace);
421 return seed;
422 }
423};
424
// Inputs for preprocessing a unary elementwise shader.
struct ggml_webgpu_unary_shader_lib_context {
    ggml_webgpu_unary_pipeline_key key;
    uint32_t max_wg_size; // workgroup size to compile with
};
429
430inline ggml_webgpu_processed_shader ggml_webgpu_preprocess_unary_shader(
431 pre_wgsl::Preprocessor & preprocessor,
432 const char * shader_src,
433 const ggml_webgpu_unary_shader_lib_context & context) {
434 std::vector<std::string> defines;
435 std::string variant = context.key.is_unary ? ggml_unary_op_name((ggml_unary_op) context.key.op) :
436 ggml_op_name((ggml_op) context.key.op);
437 // Operation-specific behavior
438 defines.push_back(variant);
439
440 switch (context.key.type) {
441 case GGML_TYPE_F32:
442 defines.push_back("TYPE_F32");
443 variant += "_f32";
444 break;
445 case GGML_TYPE_F16:
446 defines.push_back("TYPE_F16");
447 variant += "_f16";
448 break;
449 default:
450 GGML_ABORT("Unsupported type for unary shader");
451 }
452
453 if (context.key.inplace) {
454 defines.push_back("INPLACE");
455 variant += "_inplace";
456 }
457
458 defines.push_back(std::string("WG_SIZE=") + std::to_string(context.max_wg_size));
459
460 ggml_webgpu_processed_shader result;
461 result.wgsl = preprocessor.preprocess(shader_src, defines);
462 result.variant = variant;
463 auto decisions = std::make_shared<ggml_webgpu_generic_shader_decisions>();
464 decisions->wg_size = context.max_wg_size;
465 result.decisions = decisions;
466 return result;
467}
468
469/** Binary **/
470
// Cache key for a binary elementwise pipeline variant.
struct ggml_webgpu_binary_pipeline_key {
    int type;     // element ggml_type (f32 or f16 supported)
    int op;       // ggml_op of the binary operation
    bool inplace; // in-place variant (INPLACE define); takes precedence over overlap
    bool overlap; // overlapping src/dst variant (OVERLAP define)

    bool operator==(const ggml_webgpu_binary_pipeline_key & other) const {
        return type == other.type && op == other.op && inplace == other.inplace && overlap == other.overlap;
    }
};
481
482struct ggml_webgpu_binary_pipeline_key_hash {
483 size_t operator()(const ggml_webgpu_binary_pipeline_key & key) const {
484 size_t seed = 0;
485 ggml_webgpu_hash_combine(seed, key.type);
486 ggml_webgpu_hash_combine(seed, key.op);
487 ggml_webgpu_hash_combine(seed, key.inplace);
488 ggml_webgpu_hash_combine(seed, key.overlap);
489 return seed;
490 }
491};
492
// Inputs for preprocessing a binary elementwise shader.
struct ggml_webgpu_binary_shader_lib_context {
    ggml_webgpu_binary_pipeline_key key;
    uint32_t max_wg_size; // workgroup size to compile with
};
497
498inline ggml_webgpu_processed_shader ggml_webgpu_preprocess_binary_shader(
499 pre_wgsl::Preprocessor & preprocessor,
500 const char * shader_src,
501 const ggml_webgpu_binary_shader_lib_context & context) {
502 std::vector<std::string> defines;
503 std::string op_name = ggml_op_name((ggml_op) context.key.op);
504 std::string variant = op_name;
505
506 defines.push_back(std::string("OP_") + op_name);
507
508 switch (context.key.type) {
509 case GGML_TYPE_F32:
510 defines.push_back("TYPE_F32");
511 variant += "_f32";
512 break;
513 case GGML_TYPE_F16:
514 defines.push_back("TYPE_F16");
515 variant += "_f16";
516 break;
517 default:
518 GGML_ABORT("Unsupported type for binary shader");
519 }
520
521 if (context.key.inplace) {
522 defines.push_back("INPLACE");
523 variant += "_inplace";
524 } else if (context.key.overlap) {
525 defines.push_back("OVERLAP");
526 variant += "_overlap";
527 }
528
529 defines.push_back(std::string("WG_SIZE=") + std::to_string(context.max_wg_size));
530 ggml_webgpu_processed_shader result;
531 result.wgsl = preprocessor.preprocess(shader_src, defines);
532 result.variant = variant;
533 auto decisions = std::make_shared<ggml_webgpu_generic_shader_decisions>();
534 decisions->wg_size = context.max_wg_size;
535 result.decisions = decisions;
536 return result;
537}
538#endif // GGML_WEBGPU_SHADER_LIB_HPP