1#ifndef GGML_WEBGPU_SHADER_LIB_HPP
  2#define GGML_WEBGPU_SHADER_LIB_HPP
  3
#include "ggml.h"
#include "pre_wgsl.hpp"

#include <algorithm>   // std::max, std::min
#include <cstddef>     // size_t
#include <cstdint>     // uint32_t, int32_t
#include <functional>  // std::hash
#include <memory>
#include <string>
#include <vector>
 10
 11#define GGML_WEBGPU_F16_SIZE_BYTES                   2
 12#define GGML_WEBGPU_F32_SIZE_BYTES                   4
 13#define GGML_WEBGPU_I32_SIZE_BYTES                   4
 14#define GGML_WEBGPU_FLASH_ATTN_PREFERRED_KV_SG_TILES 8u
 15#define GGML_WEBGPU_FLASH_ATTN_PREFERRED_WG_SIZE     128u
 16// Matches GGML_PAD(..., 256) in src/llama-context.cpp for KV cache sizing.
 17#define GGML_WEBGPU_KV_SEQ_PAD                       256u
 18
 19#define GGML_WEBGPU_ARGSORT_MERGE_MAX_WG_SIZE 512u
 20
// Result of preprocessing a WGSL shader template: the generated source, a
// unique variant name (used to identify/cache the compiled pipeline), and
// optional per-shader tuning decisions (type-erased; see the
// *_shader_decisions structs below for the concrete payloads).
struct ggml_webgpu_processed_shader {
    std::string           wgsl;       // preprocessed WGSL source text
    std::string           variant;    // variant name, e.g. "flash_attn_f16_mask_..."
    std::shared_ptr<void> decisions;  // concrete *_shader_decisions struct; may be null
};
 26
 27// Same hash combine function as in boost
 28template <typename T> inline void ggml_webgpu_hash_combine(size_t & seed, const T & value) {
 29    seed ^= std::hash<T>{}(value) + 0x9e3779b9 + (seed << 6) + (seed >> 2);
 30}
 31
 32/** FlashAttention */
 33
// Cache key for a compiled flash-attention pipeline. Every field that changes
// the generated WGSL must appear here, in operator==, and in the hash functor
// below.
struct ggml_webgpu_flash_attn_pipeline_key {
    ggml_type kv_type;             // storage type of K/V (f32/f16/q4_0/q8_0; see switch in preprocessor)
    uint32_t  head_dim_qk;         // head dimension of Q/K
    uint32_t  head_dim_v;          // head dimension of V
    bool      kv_direct;           // read KV directly (skips kv_shmem staging; see wg_mem accounting)
    bool      has_mask;            // attention mask present
    bool      has_sinks;           // attention sinks present
    bool      uses_logit_softcap;  // logit softcapping enabled

    bool operator==(const ggml_webgpu_flash_attn_pipeline_key & other) const {
        return kv_type == other.kv_type && head_dim_qk == other.head_dim_qk && head_dim_v == other.head_dim_v &&
               kv_direct == other.kv_direct && has_mask == other.has_mask && has_sinks == other.has_sinks &&
               uses_logit_softcap == other.uses_logit_softcap;
    }
};
 49
 50struct ggml_webgpu_flash_attn_pipeline_key_hash {
 51    size_t operator()(const ggml_webgpu_flash_attn_pipeline_key & key) const {
 52        size_t seed = 0;
 53        ggml_webgpu_hash_combine(seed, key.kv_type);
 54        ggml_webgpu_hash_combine(seed, key.head_dim_qk);
 55        ggml_webgpu_hash_combine(seed, key.head_dim_v);
 56        ggml_webgpu_hash_combine(seed, key.kv_direct);
 57        ggml_webgpu_hash_combine(seed, key.has_mask);
 58        ggml_webgpu_hash_combine(seed, key.has_sinks);
 59        ggml_webgpu_hash_combine(seed, key.uses_logit_softcap);
 60        return seed;
 61    }
 62};
 63
// Inputs to flash-attention shader preprocessing: the pipeline key plus the
// device's subgroup-matrix dimensions and workgroup limits.
struct ggml_webgpu_flash_attn_shader_lib_context {
    ggml_webgpu_flash_attn_pipeline_key key;
    uint32_t                            sg_mat_m;            // subgroup-matrix M dimension (Q rows per tile)
    uint32_t                            sg_mat_n;            // subgroup-matrix N dimension (KV-tile granularity)
    uint32_t                            sg_mat_k;            // subgroup-matrix K dimension
    size_t                              wg_mem_limit_bytes;  // workgroup shared-memory budget in bytes
    uint32_t                            max_subgroup_size;   // device's maximum subgroup size
};
 72
// Tile/workgroup choices made during flash-attention preprocessing; stored in
// ggml_webgpu_processed_shader::decisions so dispatch code can reuse them.
struct ggml_webgpu_flash_attn_shader_decisions {
    uint32_t q_tile  = 0;  // Q rows handled per workgroup
    uint32_t kv_tile = 0;  // KV rows handled per inner iteration
    uint32_t wg_size = 0;  // threads per workgroup
};
 78
 79// This is exposed because it's necessary in supports_op
 80inline size_t ggml_webgpu_flash_attn_wg_mem_bytes(uint32_t q_tile,
 81                                                  uint32_t kv_tile,
 82                                                  uint32_t head_dim_qk,
 83                                                  uint32_t head_dim_v,
 84                                                  bool     has_mask,
 85                                                  bool     kv_direct) {
 86    const uint32_t max_head_dim = std::max(head_dim_qk, head_dim_v);
 87    size_t         f16_elems    = 0;
 88    size_t         f32_elems    = 0;
 89    f16_elems += q_tile * head_dim_qk;        // q_shmem
 90    if (!kv_direct) {
 91        f16_elems += kv_tile * max_head_dim;  // kv_shmem
 92    }
 93    f16_elems += q_tile * head_dim_v;         // o_shmem
 94    if (has_mask) {
 95        f16_elems += q_tile * kv_tile;        // mask_shmem
 96    }
 97    f16_elems += q_tile * kv_tile;            // inter_shmem
 98    f32_elems += q_tile;                      // row_max_shmem
 99    f32_elems += q_tile;                      // exp_sum_shmem
100    return f16_elems * GGML_WEBGPU_F16_SIZE_BYTES + f32_elems * GGML_WEBGPU_F32_SIZE_BYTES;
101}
102
103static uint32_t ggml_webgpu_flash_attn_max_kv_tile(const ggml_webgpu_flash_attn_shader_lib_context & context) {
104    const size_t limit_bytes = context.wg_mem_limit_bytes;
105    const size_t q_tile      = context.sg_mat_m;
106    const size_t base_q_bytes =
107        (context.key.head_dim_qk + context.key.head_dim_v) * q_tile * GGML_WEBGPU_F16_SIZE_BYTES +
108        2 * q_tile * GGML_WEBGPU_F32_SIZE_BYTES;
109    size_t bytes_per_kv = 0;
110    if (!context.key.kv_direct) {
111        bytes_per_kv += std::max(context.key.head_dim_qk, context.key.head_dim_v);
112    }
113    if (context.key.has_mask) {
114        bytes_per_kv += q_tile;
115    }
116    bytes_per_kv += q_tile;
117    bytes_per_kv *= GGML_WEBGPU_F16_SIZE_BYTES;
118    const uint32_t max_kv_tile = (limit_bytes - base_q_bytes) / bytes_per_kv;
119    return (max_kv_tile / context.sg_mat_n) * context.sg_mat_n;
120}
121
// Preprocess the flash-attention WGSL template: translate the pipeline key
// into feature defines, choose Q/KV tile sizes that fit the workgroup memory
// budget and a workgroup size, and record those choices in the returned
// `decisions` (a ggml_webgpu_flash_attn_shader_decisions).
inline ggml_webgpu_processed_shader ggml_webgpu_preprocess_flash_attn_shader(
    pre_wgsl::Preprocessor &                          preprocessor,
    const char *                                      shader_src,
    const ggml_webgpu_flash_attn_shader_lib_context & context) {
    std::vector<std::string> defines;
    std::string              variant = "flash_attn";

    // KV storage type selects the shader's load/dequantization path.
    switch (context.key.kv_type) {
        case GGML_TYPE_F32:
            defines.push_back("KV_F32");
            break;
        case GGML_TYPE_F16:
            defines.push_back("KV_F16");
            break;
        case GGML_TYPE_Q4_0:
            defines.push_back("KV_Q4_0");
            break;
        case GGML_TYPE_Q8_0:
            defines.push_back("KV_Q8_0");
            break;
        default:
            GGML_ABORT("Unsupported KV type for flash attention shader");
    }
    variant += std::string("_") + ggml_type_name(context.key.kv_type);

    // Optional features: each both specializes the shader and extends the
    // variant (cache) name.
    if (context.key.has_mask) {
        defines.push_back("MASK");
        variant += "_mask";
    }
    if (context.key.has_sinks) {
        defines.push_back("SINKS");
        variant += "_sinks";
    }
    if (context.key.uses_logit_softcap) {
        defines.push_back("LOGIT_SOFTCAP");
        variant += "_lgsc";
    }

    if (context.key.kv_direct) {
        defines.push_back("KV_DIRECT");
        variant += "_kvdirect";
    }

    // Head dimensions are compile-time constants in the shader.
    defines.push_back(std::string("HEAD_DIM_QK=") + std::to_string(context.key.head_dim_qk));
    variant += std::string("_hsqk") + std::to_string(context.key.head_dim_qk);

    defines.push_back(std::string("HEAD_DIM_V=") + std::to_string(context.key.head_dim_v));
    variant += std::string("_hsv") + std::to_string(context.key.head_dim_v);
    // For now these are not part of the variant name
    defines.push_back(std::string("SG_MAT_M=") + std::to_string(context.sg_mat_m));
    defines.push_back(std::string("SG_MAT_N=") + std::to_string(context.sg_mat_n));
    defines.push_back(std::string("SG_MAT_K=") + std::to_string(context.sg_mat_k));

    // Add chosen Q/KV tile sizes: one subgroup-matrix of Q rows, and as many
    // KV rows as the shared-memory budget allows, capped at the preferred
    // number of subgroup tiles.
    uint32_t q_tile  = context.sg_mat_m;
    uint32_t kv_tile = std::min(ggml_webgpu_flash_attn_max_kv_tile(context),
                                context.sg_mat_n * GGML_WEBGPU_FLASH_ATTN_PREFERRED_KV_SG_TILES);
    if (context.key.kv_direct) {
        GGML_ASSERT(kv_tile <= GGML_WEBGPU_KV_SEQ_PAD);
        // Avoids having to use bounds-checks and decreasing performance for direct KV loads
        // NOTE(review): assumes this loop terminates with kv_tile > 0, i.e.
        // some positive multiple of sg_mat_n divides GGML_WEBGPU_KV_SEQ_PAD
        // (holds for power-of-two sg_mat_n) — verify for unusual devices.
        while (GGML_WEBGPU_KV_SEQ_PAD % kv_tile != 0) {
            kv_tile -= context.sg_mat_n;
        }
    }

    defines.push_back(std::string("Q_TILE=") + std::to_string(q_tile));
    defines.push_back(std::string("KV_TILE=") + std::to_string(kv_tile));

    // workgroup size: the preferred size, but never smaller than one full subgroup
    uint32_t wg_size = std::max(context.max_subgroup_size, GGML_WEBGPU_FLASH_ATTN_PREFERRED_WG_SIZE);

    defines.push_back(std::string("WG_SIZE=") + std::to_string(wg_size));

    ggml_webgpu_processed_shader result;
    result.wgsl        = preprocessor.preprocess(shader_src, defines);
    result.variant     = variant;
    auto decisions     = std::make_shared<ggml_webgpu_flash_attn_shader_decisions>();
    decisions->q_tile  = q_tile;
    decisions->kv_tile = kv_tile;
    decisions->wg_size = wg_size;
    result.decisions   = decisions;
    return result;
}
205
206/** Generic **/
207
// Inputs for "generic" elementwise shader preprocessing.
struct ggml_webgpu_generic_shader_lib_context {
    int      vec4;         // nonzero: enable the vec4 path (VEC4 define)
    uint32_t max_wg_size;  // workgroup size to compile with
};
212
// Decision record used by the simple preprocessors (generic, pad, set_rows,
// unary, binary): only the workgroup size is chosen.
struct ggml_webgpu_generic_shader_decisions {
    uint32_t wg_size;  // threads per workgroup
};
216
217inline ggml_webgpu_processed_shader ggml_webgpu_preprocess_generic_shader(
218    pre_wgsl::Preprocessor &                       preprocessor,
219    const char *                                   shader_src,
220    const ggml_webgpu_generic_shader_lib_context & context,
221    const std::string &                            base_variant) {
222    std::vector<std::string> defines;
223    std::string              variant = base_variant;
224
225    if (context.vec4) {
226        defines.push_back("VEC4");
227        variant += "_vec";
228    }
229
230    defines.push_back(std::string("WG_SIZE=") + std::to_string(context.max_wg_size));
231
232    ggml_webgpu_processed_shader result;
233    result.wgsl    = preprocessor.preprocess(shader_src, defines);
234    result.variant = variant;
235    return result;
236}
237
238/** Pad **/
239
// Cache key for a pad pipeline.
struct ggml_webgpu_pad_pipeline_key {
    bool circular;  // circular-padding variant (CIRCULAR define)

    bool operator==(const ggml_webgpu_pad_pipeline_key & other) const { return circular == other.circular; }
};
245
// Hash functor for ggml_webgpu_pad_pipeline_key.
struct ggml_webgpu_pad_pipeline_key_hash {
    size_t operator()(const ggml_webgpu_pad_pipeline_key & key) const {
        size_t seed = 0;
        ggml_webgpu_hash_combine(seed, key.circular);
        return seed;
    }
};
253
// Inputs for pad shader preprocessing.
struct ggml_webgpu_pad_shader_lib_context {
    ggml_webgpu_pad_pipeline_key key;
    uint32_t                     max_wg_size;  // workgroup size to compile with
};
258
259inline ggml_webgpu_processed_shader ggml_webgpu_preprocess_pad_shader(
260    pre_wgsl::Preprocessor &                   preprocessor,
261    const char *                               shader_src,
262    const ggml_webgpu_pad_shader_lib_context & context) {
263    std::vector<std::string> defines;
264    std::string              variant = "pad";
265
266    if (context.key.circular) {
267        defines.push_back("CIRCULAR");
268        variant += "_circular";
269    }
270
271    defines.push_back(std::string("WG_SIZE=") + std::to_string(context.max_wg_size));
272
273    ggml_webgpu_processed_shader result;
274    result.wgsl        = preprocessor.preprocess(shader_src, defines);
275    result.variant     = variant;
276    auto decisions     = std::make_shared<ggml_webgpu_generic_shader_decisions>();
277    decisions->wg_size = context.max_wg_size;
278    result.decisions   = decisions;
279    return result;
280}
281
282/** Argsort **/
283
// Inputs for argsort shader preprocessing (shared by the sort and merge passes).
struct ggml_webgpu_argsort_shader_lib_context {
    uint32_t max_wg_size;         // device maximum workgroup size
    size_t   wg_mem_limit_bytes;  // workgroup shared-memory budget in bytes
    int32_t  order;               // sort order, passed through as the ORDER define
};
289
// Workgroup-size decision for the argsort pipelines.
struct ggml_webgpu_argsort_shader_decisions {
    uint32_t wg_size = 0;  // threads per workgroup
};
293
294inline ggml_webgpu_processed_shader ggml_webgpu_preprocess_argsort_shader(
295    pre_wgsl::Preprocessor &                       preprocessor,
296    const char *                                   shader_src,
297    const ggml_webgpu_argsort_shader_lib_context & context) {
298    std::vector<std::string> defines;
299    std::string              variant = "argsort";
300    defines.push_back(std::string("ORDER=") + std::to_string(context.order));
301    variant += std::string("_order") + std::to_string(context.order);
302    uint32_t wg_size = 1;
303    while (wg_size * 2 <= context.max_wg_size &&
304           wg_size * GGML_WEBGPU_I32_SIZE_BYTES <= context.wg_mem_limit_bytes / 2) {
305        wg_size *= 2;
306    }
307    defines.push_back(std::string("WG_SIZE=") + std::to_string(wg_size));
308    ggml_webgpu_processed_shader result;
309    result.wgsl        = preprocessor.preprocess(shader_src, defines);
310    result.variant     = variant;
311    auto decisions     = std::make_shared<ggml_webgpu_argsort_shader_decisions>();
312    decisions->wg_size = wg_size;
313    result.decisions   = decisions;
314    return result;
315}
316
317inline ggml_webgpu_processed_shader ggml_webgpu_preprocess_argsort_merge_shader(
318    pre_wgsl::Preprocessor &                       preprocessor,
319    const char *                                   shader_src,
320    const ggml_webgpu_argsort_shader_lib_context & context) {
321    std::vector<std::string> defines;
322    std::string              variant = "argsort_merge";
323    defines.push_back(std::string("ORDER=") + std::to_string(context.order));
324    variant += std::string("_order") + std::to_string(context.order);
325    uint32_t wg_size = std::min(GGML_WEBGPU_ARGSORT_MERGE_MAX_WG_SIZE, context.max_wg_size);
326    defines.push_back(std::string("WG_SIZE=") + std::to_string(wg_size));
327    ggml_webgpu_processed_shader result;
328    result.wgsl        = preprocessor.preprocess(shader_src, defines);
329    result.variant     = variant;
330    auto decisions     = std::make_shared<ggml_webgpu_argsort_shader_decisions>();
331    decisions->wg_size = wg_size;
332    result.decisions   = decisions;
333    return result;
334}
335
336/** Set Rows **/
337
// Cache key for a set_rows pipeline.
struct ggml_webgpu_set_rows_pipeline_key {
    int dst_type;  // destination ggml_type (f32 or f16; see switch in preprocessor)
    int vec4;      // nonzero: vec4 path (VEC4 define)
    int i64_idx;   // nonzero: 64-bit row indices (I64_IDX define)

    bool operator==(const ggml_webgpu_set_rows_pipeline_key & other) const {
        return dst_type == other.dst_type && vec4 == other.vec4 && i64_idx == other.i64_idx;
    }
};
347
// Hash functor for ggml_webgpu_set_rows_pipeline_key.
struct ggml_webgpu_set_rows_pipeline_key_hash {
    size_t operator()(const ggml_webgpu_set_rows_pipeline_key & key) const {
        size_t seed = 0;
        ggml_webgpu_hash_combine(seed, key.dst_type);
        ggml_webgpu_hash_combine(seed, key.vec4);
        ggml_webgpu_hash_combine(seed, key.i64_idx);
        return seed;
    }
};
357
// Inputs for set_rows shader preprocessing.
struct ggml_webgpu_set_rows_shader_lib_context {
    ggml_webgpu_set_rows_pipeline_key key;
    uint32_t                          max_wg_size;  // workgroup size to compile with
};
362
363inline ggml_webgpu_processed_shader ggml_webgpu_preprocess_set_rows_shader(
364    pre_wgsl::Preprocessor &                        preprocessor,
365    const char *                                    shader_src,
366    const ggml_webgpu_set_rows_shader_lib_context & context) {
367    std::vector<std::string> defines;
368    std::string              variant = "set_rows";
369
370    switch (context.key.dst_type) {
371        case GGML_TYPE_F32:
372            defines.push_back("DST_F32");
373            variant += "_dstf32";
374            break;
375        case GGML_TYPE_F16:
376            defines.push_back("DST_F16");
377            variant += "_dstf16";
378            break;
379        default:
380            GGML_ABORT("Unsupported dst type for set_rows shader");
381    }
382
383    if (context.key.vec4) {
384        defines.push_back("VEC4");
385        variant += "_vec";
386    }
387    if (context.key.i64_idx) {
388        defines.push_back("I64_IDX");
389        variant += "_i64idx";
390    }
391
392    defines.push_back(std::string("WG_SIZE=") + std::to_string(context.max_wg_size));
393
394    ggml_webgpu_processed_shader result;
395    result.wgsl        = preprocessor.preprocess(shader_src, defines);
396    result.variant     = variant;
397    auto decisions     = std::make_shared<ggml_webgpu_generic_shader_decisions>();
398    decisions->wg_size = context.max_wg_size;
399    result.decisions   = decisions;
400    return result;
401}
402
// Cache key for a unary-op pipeline.
struct ggml_webgpu_unary_pipeline_key {
    int  type;      // tensor ggml_type (f32 or f16; see switch in preprocessor)
    int  op;        // ggml_op, or ggml_unary_op when is_unary is set
    bool is_unary;  // many unary operators fall under the GGML_OP_UNARY umbrella
    bool inplace;   // destination aliases the source (INPLACE define)

    bool operator==(const ggml_webgpu_unary_pipeline_key & other) const {
        return type == other.type && op == other.op && is_unary == other.is_unary && inplace == other.inplace;
    }
};
413
// Hash functor for ggml_webgpu_unary_pipeline_key.
struct ggml_webgpu_unary_pipeline_key_hash {
    size_t operator()(const ggml_webgpu_unary_pipeline_key & key) const {
        size_t seed = 0;
        ggml_webgpu_hash_combine(seed, key.type);
        ggml_webgpu_hash_combine(seed, key.op);
        ggml_webgpu_hash_combine(seed, key.is_unary);
        ggml_webgpu_hash_combine(seed, key.inplace);
        return seed;
    }
};
424
// Inputs for unary shader preprocessing.
struct ggml_webgpu_unary_shader_lib_context {
    ggml_webgpu_unary_pipeline_key key;
    uint32_t                       max_wg_size;  // workgroup size to compile with
};
429
430inline ggml_webgpu_processed_shader ggml_webgpu_preprocess_unary_shader(
431    pre_wgsl::Preprocessor &                     preprocessor,
432    const char *                                 shader_src,
433    const ggml_webgpu_unary_shader_lib_context & context) {
434    std::vector<std::string> defines;
435    std::string              variant = context.key.is_unary ? ggml_unary_op_name((ggml_unary_op) context.key.op) :
436                                                              ggml_op_name((ggml_op) context.key.op);
437    // Operation-specific behavior
438    defines.push_back(variant);
439
440    switch (context.key.type) {
441        case GGML_TYPE_F32:
442            defines.push_back("TYPE_F32");
443            variant += "_f32";
444            break;
445        case GGML_TYPE_F16:
446            defines.push_back("TYPE_F16");
447            variant += "_f16";
448            break;
449        default:
450            GGML_ABORT("Unsupported type for unary shader");
451    }
452
453    if (context.key.inplace) {
454        defines.push_back("INPLACE");
455        variant += "_inplace";
456    }
457
458    defines.push_back(std::string("WG_SIZE=") + std::to_string(context.max_wg_size));
459
460    ggml_webgpu_processed_shader result;
461    result.wgsl        = preprocessor.preprocess(shader_src, defines);
462    result.variant     = variant;
463    auto decisions     = std::make_shared<ggml_webgpu_generic_shader_decisions>();
464    decisions->wg_size = context.max_wg_size;
465    result.decisions   = decisions;
466    return result;
467}
468
469/** Binary **/
470
// Cache key for a binary-op pipeline.
struct ggml_webgpu_binary_pipeline_key {
    int  type;     // tensor ggml_type (f32 or f16; see switch in preprocessor)
    int  op;       // ggml_op, encoded as the OP_<name> define
    bool inplace;  // destination aliases src0 (INPLACE define); takes precedence over overlap
    bool overlap;  // destination overlaps an input (OVERLAP define)

    bool operator==(const ggml_webgpu_binary_pipeline_key & other) const {
        return type == other.type && op == other.op && inplace == other.inplace && overlap == other.overlap;
    }
};
481
// Hash functor for ggml_webgpu_binary_pipeline_key.
struct ggml_webgpu_binary_pipeline_key_hash {
    size_t operator()(const ggml_webgpu_binary_pipeline_key & key) const {
        size_t seed = 0;
        ggml_webgpu_hash_combine(seed, key.type);
        ggml_webgpu_hash_combine(seed, key.op);
        ggml_webgpu_hash_combine(seed, key.inplace);
        ggml_webgpu_hash_combine(seed, key.overlap);
        return seed;
    }
};
492
// Inputs for binary shader preprocessing.
struct ggml_webgpu_binary_shader_lib_context {
    ggml_webgpu_binary_pipeline_key key;
    uint32_t                        max_wg_size;  // workgroup size to compile with
};
497
498inline ggml_webgpu_processed_shader ggml_webgpu_preprocess_binary_shader(
499    pre_wgsl::Preprocessor &                      preprocessor,
500    const char *                                  shader_src,
501    const ggml_webgpu_binary_shader_lib_context & context) {
502    std::vector<std::string> defines;
503    std::string              op_name = ggml_op_name((ggml_op) context.key.op);
504    std::string              variant = op_name;
505
506    defines.push_back(std::string("OP_") + op_name);
507
508    switch (context.key.type) {
509        case GGML_TYPE_F32:
510            defines.push_back("TYPE_F32");
511            variant += "_f32";
512            break;
513        case GGML_TYPE_F16:
514            defines.push_back("TYPE_F16");
515            variant += "_f16";
516            break;
517        default:
518            GGML_ABORT("Unsupported type for binary shader");
519    }
520
521    if (context.key.inplace) {
522        defines.push_back("INPLACE");
523        variant += "_inplace";
524    } else if (context.key.overlap) {
525        defines.push_back("OVERLAP");
526        variant += "_overlap";
527    }
528
529    defines.push_back(std::string("WG_SIZE=") + std::to_string(context.max_wg_size));
530    ggml_webgpu_processed_shader result;
531    result.wgsl        = preprocessor.preprocess(shader_src, defines);
532    result.variant     = variant;
533    auto decisions     = std::make_shared<ggml_webgpu_generic_shader_decisions>();
534    decisions->wg_size = context.max_wg_size;
535    result.decisions   = decisions;
536    return result;
537}
538#endif  // GGML_WEBGPU_SHADER_LIB_HPP