path: root/llama.cpp/ggml/src/ggml-webgpu
author     Mitja Felicijan <mitja.felicijan@gmail.com>  2026-02-12 20:57:17 +0100
committer  Mitja Felicijan <mitja.felicijan@gmail.com>  2026-02-12 20:57:17 +0100
commit     b333b06772c89d96aacb5490d6a219fba7c09cc6 (patch)
tree       211df60083a5946baa2ed61d33d8121b7e251b06 /llama.cpp/ggml/src/ggml-webgpu
Engage!
Diffstat (limited to 'llama.cpp/ggml/src/ggml-webgpu')
-rw-r--r--  llama.cpp/ggml/src/ggml-webgpu/CMakeLists.txt  80
-rw-r--r--  llama.cpp/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp  538
-rw-r--r--  llama.cpp/ggml/src/ggml-webgpu/ggml-webgpu.cpp  3469
-rw-r--r--  llama.cpp/ggml/src/ggml-webgpu/pre_wgsl.hpp  778
-rw-r--r--  llama.cpp/ggml/src/ggml-webgpu/wgsl-shaders/argmax.wgsl  72
-rw-r--r--  llama.cpp/ggml/src/ggml-webgpu/wgsl-shaders/argsort.wgsl  106
-rw-r--r--  llama.cpp/ggml/src/ggml-webgpu/wgsl-shaders/argsort_merge.wgsl  134
-rw-r--r--  llama.cpp/ggml/src/ggml-webgpu/wgsl-shaders/binary.wgsl  107
-rw-r--r--  llama.cpp/ggml/src/ggml-webgpu/wgsl-shaders/common_decls.tmpl  930
-rw-r--r--  llama.cpp/ggml/src/ggml-webgpu/wgsl-shaders/cpy.tmpl.wgsl  107
-rw-r--r--  llama.cpp/ggml/src/ggml-webgpu/wgsl-shaders/cumsum.wgsl  66
-rwxr-xr-x  llama.cpp/ggml/src/ggml-webgpu/wgsl-shaders/embed_wgsl.py  147
-rw-r--r--  llama.cpp/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn.wgsl  636
-rw-r--r--  llama.cpp/ggml/src/ggml-webgpu/wgsl-shaders/get_rows.tmpl.wgsl  874
-rw-r--r--  llama.cpp/ggml/src/ggml-webgpu/wgsl-shaders/glu.tmpl.wgsl  323
-rw-r--r--  llama.cpp/ggml/src/ggml-webgpu/wgsl-shaders/memset.wgsl  40
-rw-r--r--  llama.cpp/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat.tmpl.wgsl  907
-rw-r--r--  llama.cpp/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_decls.tmpl  97
-rw-r--r--  llama.cpp/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_reg_tile.tmpl.wgsl  247
-rw-r--r--  llama.cpp/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_subgroup_matrix.tmpl.wgsl  302
-rw-r--r--  llama.cpp/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec.tmpl.wgsl  267
-rw-r--r--  llama.cpp/ggml/src/ggml-webgpu/wgsl-shaders/pad.wgsl  86
-rw-r--r--  llama.cpp/ggml/src/ggml-webgpu/wgsl-shaders/rms_norm.wgsl  123
-rw-r--r--  llama.cpp/ggml/src/ggml-webgpu/wgsl-shaders/rope.tmpl.wgsl  295
-rw-r--r--  llama.cpp/ggml/src/ggml-webgpu/wgsl-shaders/scale.tmpl.wgsl  90
-rw-r--r--  llama.cpp/ggml/src/ggml-webgpu/wgsl-shaders/set_rows.wgsl  109
-rw-r--r--  llama.cpp/ggml/src/ggml-webgpu/wgsl-shaders/soft_max.tmpl.wgsl  345
-rw-r--r--  llama.cpp/ggml/src/ggml-webgpu/wgsl-shaders/sum_rows.wgsl  55
-rw-r--r--  llama.cpp/ggml/src/ggml-webgpu/wgsl-shaders/unary.wgsl  179
29 files changed, 11509 insertions, 0 deletions
diff --git a/llama.cpp/ggml/src/ggml-webgpu/CMakeLists.txt b/llama.cpp/ggml/src/ggml-webgpu/CMakeLists.txt
new file mode 100644
index 0000000..3ccce58
--- /dev/null
+++ b/llama.cpp/ggml/src/ggml-webgpu/CMakeLists.txt
@@ -0,0 +1,80 @@
+cmake_minimum_required(VERSION 3.13)
+
+find_package(Python3 REQUIRED)
+
+# Shader locations
+set(SHADER_DIR "${CMAKE_CURRENT_SOURCE_DIR}/wgsl-shaders")
+set(SHADER_OUTPUT_DIR "${CMAKE_CURRENT_BINARY_DIR}/generated")
+set(SHADER_HEADER "${SHADER_OUTPUT_DIR}/ggml-wgsl-shaders.hpp")
+file(MAKE_DIRECTORY ${SHADER_OUTPUT_DIR})
+
+message(STATUS "Shader output dir: ${SHADER_OUTPUT_DIR}")
+
+# Find all WGSL files
+file(GLOB WGSL_SHADER_FILES "${SHADER_DIR}/*.wgsl")
+
+# Generate the header using a Python script
+add_custom_command(
+ OUTPUT ${SHADER_HEADER}
+ COMMAND ${CMAKE_COMMAND} -E echo "Embedding WGSL shaders to ggml-wgsl-shaders.hpp"
+ COMMAND ${CMAKE_COMMAND} -E make_directory ${SHADER_OUTPUT_DIR}
+ COMMAND ${CMAKE_COMMAND} -E env PYTHONIOENCODING=utf-8
+ ${Python3_EXECUTABLE} ${CMAKE_CURRENT_SOURCE_DIR}/wgsl-shaders/embed_wgsl.py
+ --input_dir "${SHADER_DIR}"
+ --output_file "${SHADER_HEADER}"
+ DEPENDS ${WGSL_SHADER_FILES} ${CMAKE_CURRENT_SOURCE_DIR}/wgsl-shaders/embed_wgsl.py
+ VERBATIM
+)
+
+add_custom_target(generate_shaders DEPENDS ${SHADER_HEADER})
+
+ggml_add_backend_library(ggml-webgpu
+ ggml-webgpu.cpp
+ ${SHADER_HEADER}
+ ../../include/ggml-webgpu.h
+)
+
+add_dependencies(ggml-webgpu generate_shaders)
+
+if(EMSCRIPTEN)
+ set(EMDAWNWEBGPU_DIR "" CACHE PATH "Path to emdawnwebgpu_pkg")
+
+ if(NOT EMDAWNWEBGPU_DIR)
+ # default built-in port
+ target_compile_options(ggml-webgpu PRIVATE "--use-port=emdawnwebgpu")
+ target_link_options(ggml-webgpu INTERFACE "--use-port=emdawnwebgpu")
+ else()
+ # custom port
+ target_compile_options(ggml-webgpu PRIVATE "--use-port=${EMDAWNWEBGPU_DIR}/emdawnwebgpu.port.py")
+ target_link_options(ggml-webgpu INTERFACE "--use-port=${EMDAWNWEBGPU_DIR}/emdawnwebgpu.port.py")
+ endif()
+
+ if (GGML_WEBGPU_JSPI)
+ target_compile_options(ggml-webgpu PRIVATE "-fwasm-exceptions")
+ target_link_options(ggml-webgpu INTERFACE "-sJSPI" "-fwasm-exceptions")
+ else()
+ target_compile_options(ggml-webgpu PRIVATE "-fexceptions")
+        target_link_options(ggml-webgpu INTERFACE "-sASYNCIFY" "-fexceptions")
+ endif()
+else()
+ find_package(Dawn REQUIRED)
+ set(DawnWebGPU_TARGET dawn::webgpu_dawn)
+endif()
+
+if (GGML_WEBGPU_DEBUG)
+ target_compile_definitions(ggml-webgpu PRIVATE GGML_WEBGPU_DEBUG=1)
+ if(EMSCRIPTEN)
+ target_link_options(ggml-webgpu INTERFACE "-sASSERTIONS=2")
+ endif()
+endif()
+
+if (GGML_WEBGPU_CPU_PROFILE)
+ target_compile_definitions(ggml-webgpu PRIVATE GGML_WEBGPU_CPU_PROFILE=1)
+endif()
+
+if (GGML_WEBGPU_GPU_PROFILE)
+ target_compile_definitions(ggml-webgpu PRIVATE GGML_WEBGPU_GPU_PROFILE=1)
+endif()
+
+target_include_directories(ggml-webgpu PRIVATE ${SHADER_OUTPUT_DIR})
+target_link_libraries(ggml-webgpu PRIVATE ${DawnWebGPU_TARGET})
diff --git a/llama.cpp/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp b/llama.cpp/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp
new file mode 100644
index 0000000..63f797f
--- /dev/null
+++ b/llama.cpp/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp
@@ -0,0 +1,538 @@
+#ifndef GGML_WEBGPU_SHADER_LIB_HPP
+#define GGML_WEBGPU_SHADER_LIB_HPP
+
+#include "ggml.h"
+#include "pre_wgsl.hpp"
+
+#include <algorithm>
+#include <cstdint>
+#include <functional>
+#include <memory>
+#include <string>
+#include <vector>
+
+#define GGML_WEBGPU_F16_SIZE_BYTES 2
+#define GGML_WEBGPU_F32_SIZE_BYTES 4
+#define GGML_WEBGPU_I32_SIZE_BYTES 4
+#define GGML_WEBGPU_FLASH_ATTN_PREFERRED_KV_SG_TILES 8u
+#define GGML_WEBGPU_FLASH_ATTN_PREFERRED_WG_SIZE 128u
+// Matches GGML_PAD(..., 256) in src/llama-context.cpp for KV cache sizing.
+#define GGML_WEBGPU_KV_SEQ_PAD 256u
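+// Example (illustrative): a KV length of 1000 is padded up to
+// GGML_PAD(1000, 256) = 1024, so any KV tile that divides 256 also tiles the
+// padded sequence length exactly.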
+
+#define GGML_WEBGPU_ARGSORT_MERGE_MAX_WG_SIZE 512u
+
+struct ggml_webgpu_processed_shader {
+ std::string wgsl;
+ std::string variant;
+ std::shared_ptr<void> decisions;
+};
+
+// Same hash combine function as in boost
+template <typename T> inline void ggml_webgpu_hash_combine(size_t & seed, const T & value) {
+ seed ^= std::hash<T>{}(value) + 0x9e3779b9 + (seed << 6) + (seed >> 2);
+}
+
+/** FlashAttention */
+
+struct ggml_webgpu_flash_attn_pipeline_key {
+ ggml_type kv_type;
+ uint32_t head_dim_qk;
+ uint32_t head_dim_v;
+ bool kv_direct;
+ bool has_mask;
+ bool has_sinks;
+ bool uses_logit_softcap;
+
+ bool operator==(const ggml_webgpu_flash_attn_pipeline_key & other) const {
+ return kv_type == other.kv_type && head_dim_qk == other.head_dim_qk && head_dim_v == other.head_dim_v &&
+ kv_direct == other.kv_direct && has_mask == other.has_mask && has_sinks == other.has_sinks &&
+ uses_logit_softcap == other.uses_logit_softcap;
+ }
+};
+
+struct ggml_webgpu_flash_attn_pipeline_key_hash {
+ size_t operator()(const ggml_webgpu_flash_attn_pipeline_key & key) const {
+ size_t seed = 0;
+ ggml_webgpu_hash_combine(seed, key.kv_type);
+ ggml_webgpu_hash_combine(seed, key.head_dim_qk);
+ ggml_webgpu_hash_combine(seed, key.head_dim_v);
+ ggml_webgpu_hash_combine(seed, key.kv_direct);
+ ggml_webgpu_hash_combine(seed, key.has_mask);
+ ggml_webgpu_hash_combine(seed, key.has_sinks);
+ ggml_webgpu_hash_combine(seed, key.uses_logit_softcap);
+ return seed;
+ }
+};
+
+struct ggml_webgpu_flash_attn_shader_lib_context {
+ ggml_webgpu_flash_attn_pipeline_key key;
+ uint32_t sg_mat_m;
+ uint32_t sg_mat_n;
+ uint32_t sg_mat_k;
+ size_t wg_mem_limit_bytes;
+ uint32_t max_subgroup_size;
+};
+
+struct ggml_webgpu_flash_attn_shader_decisions {
+ uint32_t q_tile = 0;
+ uint32_t kv_tile = 0;
+ uint32_t wg_size = 0;
+};
+
+// This is exposed because it's necessary in supports_op
+inline size_t ggml_webgpu_flash_attn_wg_mem_bytes(uint32_t q_tile,
+ uint32_t kv_tile,
+ uint32_t head_dim_qk,
+ uint32_t head_dim_v,
+ bool has_mask,
+ bool kv_direct) {
+ const uint32_t max_head_dim = std::max(head_dim_qk, head_dim_v);
+ size_t f16_elems = 0;
+ size_t f32_elems = 0;
+ f16_elems += q_tile * head_dim_qk; // q_shmem
+ if (!kv_direct) {
+ f16_elems += kv_tile * max_head_dim; // kv_shmem
+ }
+ f16_elems += q_tile * head_dim_v; // o_shmem
+ if (has_mask) {
+ f16_elems += q_tile * kv_tile; // mask_shmem
+ }
+ f16_elems += q_tile * kv_tile; // inter_shmem
+ f32_elems += q_tile; // row_max_shmem
+ f32_elems += q_tile; // exp_sum_shmem
+ return f16_elems * GGML_WEBGPU_F16_SIZE_BYTES + f32_elems * GGML_WEBGPU_F32_SIZE_BYTES;
+}
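+// Worked example for the accounting above (illustrative values, not taken
+// from a specific adapter): with q_tile = 16, kv_tile = 64,
+// head_dim_qk = head_dim_v = 128, has_mask = true and kv_direct = false:
+//   f16 elems = 16*128 + 64*128 + 16*128 + 16*64 + 16*64 = 14336 -> 28672 bytes
+//   f32 elems = 2*16 = 32                                        ->   128 bytes
+// Total 28800 bytes, which fits the common 32 KiB workgroup-memory limit.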
+
+static uint32_t ggml_webgpu_flash_attn_max_kv_tile(const ggml_webgpu_flash_attn_shader_lib_context & context) {
+ const size_t limit_bytes = context.wg_mem_limit_bytes;
+ const size_t q_tile = context.sg_mat_m;
+ const size_t base_q_bytes =
+ (context.key.head_dim_qk + context.key.head_dim_v) * q_tile * GGML_WEBGPU_F16_SIZE_BYTES +
+ 2 * q_tile * GGML_WEBGPU_F32_SIZE_BYTES;
+ size_t bytes_per_kv = 0;
+ if (!context.key.kv_direct) {
+ bytes_per_kv += std::max(context.key.head_dim_qk, context.key.head_dim_v);
+ }
+ if (context.key.has_mask) {
+ bytes_per_kv += q_tile;
+ }
+ bytes_per_kv += q_tile;
+ bytes_per_kv *= GGML_WEBGPU_F16_SIZE_BYTES;
+ const uint32_t max_kv_tile = (limit_bytes - base_q_bytes) / bytes_per_kv;
+ return (max_kv_tile / context.sg_mat_n) * context.sg_mat_n;
+}
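+// Illustrative inversion of the accounting above (assumed values): with a
+// 32768-byte limit, q_tile = sg_mat_m = 16, head dims of 128, mask on and
+// kv_direct off, base_q_bytes = (128+128)*16*2 + 2*16*4 = 8320 and
+// bytes_per_kv = (128+16+16)*2 = 320, so max_kv_tile = 24448/320 = 76,
+// rounded down to a multiple of sg_mat_n = 16 -> 64.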
+
+inline ggml_webgpu_processed_shader ggml_webgpu_preprocess_flash_attn_shader(
+ pre_wgsl::Preprocessor & preprocessor,
+ const char * shader_src,
+ const ggml_webgpu_flash_attn_shader_lib_context & context) {
+ std::vector<std::string> defines;
+ std::string variant = "flash_attn";
+
+ switch (context.key.kv_type) {
+ case GGML_TYPE_F32:
+ defines.push_back("KV_F32");
+ break;
+ case GGML_TYPE_F16:
+ defines.push_back("KV_F16");
+ break;
+ case GGML_TYPE_Q4_0:
+ defines.push_back("KV_Q4_0");
+ break;
+ case GGML_TYPE_Q8_0:
+ defines.push_back("KV_Q8_0");
+ break;
+ default:
+ GGML_ABORT("Unsupported KV type for flash attention shader");
+ }
+ variant += std::string("_") + ggml_type_name(context.key.kv_type);
+
+ if (context.key.has_mask) {
+ defines.push_back("MASK");
+ variant += "_mask";
+ }
+ if (context.key.has_sinks) {
+ defines.push_back("SINKS");
+ variant += "_sinks";
+ }
+ if (context.key.uses_logit_softcap) {
+ defines.push_back("LOGIT_SOFTCAP");
+ variant += "_lgsc";
+ }
+
+ if (context.key.kv_direct) {
+ defines.push_back("KV_DIRECT");
+ variant += "_kvdirect";
+ }
+
+ defines.push_back(std::string("HEAD_DIM_QK=") + std::to_string(context.key.head_dim_qk));
+ variant += std::string("_hsqk") + std::to_string(context.key.head_dim_qk);
+
+ defines.push_back(std::string("HEAD_DIM_V=") + std::to_string(context.key.head_dim_v));
+ variant += std::string("_hsv") + std::to_string(context.key.head_dim_v);
+ // For now these are not part of the variant name
+ defines.push_back(std::string("SG_MAT_M=") + std::to_string(context.sg_mat_m));
+ defines.push_back(std::string("SG_MAT_N=") + std::to_string(context.sg_mat_n));
+ defines.push_back(std::string("SG_MAT_K=") + std::to_string(context.sg_mat_k));
+
+ // Add chosen Q/KV tile sizes
+ uint32_t q_tile = context.sg_mat_m;
+ uint32_t kv_tile = std::min(ggml_webgpu_flash_attn_max_kv_tile(context),
+ context.sg_mat_n * GGML_WEBGPU_FLASH_ATTN_PREFERRED_KV_SG_TILES);
+ if (context.key.kv_direct) {
+ GGML_ASSERT(kv_tile <= GGML_WEBGPU_KV_SEQ_PAD);
+ // Avoids having to use bounds-checks and decreasing performance for direct KV loads
+ while (GGML_WEBGPU_KV_SEQ_PAD % kv_tile != 0) {
+ kv_tile -= context.sg_mat_n;
+ }
+ }
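+    // The shrink loop above, illustratively: with sg_mat_n = 16 and an
+    // initial kv_tile of 96, 256 % 96 != 0 -> 80, 256 % 80 != 0 -> 64,
+    // and 256 % 64 == 0, so kv_tile settles on 64.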
+
+ defines.push_back(std::string("Q_TILE=") + std::to_string(q_tile));
+ defines.push_back(std::string("KV_TILE=") + std::to_string(kv_tile));
+
+ // workgroup size
+ uint32_t wg_size = std::max(context.max_subgroup_size, GGML_WEBGPU_FLASH_ATTN_PREFERRED_WG_SIZE);
+
+ defines.push_back(std::string("WG_SIZE=") + std::to_string(wg_size));
+
+ ggml_webgpu_processed_shader result;
+ result.wgsl = preprocessor.preprocess(shader_src, defines);
+ result.variant = variant;
+ auto decisions = std::make_shared<ggml_webgpu_flash_attn_shader_decisions>();
+ decisions->q_tile = q_tile;
+ decisions->kv_tile = kv_tile;
+ decisions->wg_size = wg_size;
+ result.decisions = decisions;
+ return result;
+}
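+// Usage sketch (hypothetical values; the real call sites in ggml-webgpu.cpp
+// fill these fields from adapter capabilities):
+//   ggml_webgpu_flash_attn_shader_lib_context ctx = {};
+//   ctx.key = { GGML_TYPE_F16, 128, 128, false /*kv_direct*/, true /*mask*/,
+//               false /*sinks*/, false /*softcap*/ };
+//   ctx.sg_mat_m = ctx.sg_mat_n = ctx.sg_mat_k = 16;
+//   ctx.wg_mem_limit_bytes = 32768;
+//   ctx.max_subgroup_size  = 32;
+//   auto shader = ggml_webgpu_preprocess_flash_attn_shader(pp, src, ctx);
+//   // shader.variant == "flash_attn_f16_mask_hsqk128_hsv128"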
+
+/** Generic **/
+
+struct ggml_webgpu_generic_shader_lib_context {
+ int vec4;
+ uint32_t max_wg_size;
+};
+
+struct ggml_webgpu_generic_shader_decisions {
+ uint32_t wg_size;
+};
+
+inline ggml_webgpu_processed_shader ggml_webgpu_preprocess_generic_shader(
+ pre_wgsl::Preprocessor & preprocessor,
+ const char * shader_src,
+ const ggml_webgpu_generic_shader_lib_context & context,
+ const std::string & base_variant) {
+ std::vector<std::string> defines;
+ std::string variant = base_variant;
+
+ if (context.vec4) {
+ defines.push_back("VEC4");
+ variant += "_vec";
+ }
+
+ defines.push_back(std::string("WG_SIZE=") + std::to_string(context.max_wg_size));
+
+ ggml_webgpu_processed_shader result;
+ result.wgsl = preprocessor.preprocess(shader_src, defines);
+ result.variant = variant;
+ return result;
+}
+
+/** Pad **/
+
+struct ggml_webgpu_pad_pipeline_key {
+ bool circular;
+
+ bool operator==(const ggml_webgpu_pad_pipeline_key & other) const { return circular == other.circular; }
+};
+
+struct ggml_webgpu_pad_pipeline_key_hash {
+ size_t operator()(const ggml_webgpu_pad_pipeline_key & key) const {
+ size_t seed = 0;
+ ggml_webgpu_hash_combine(seed, key.circular);
+ return seed;
+ }
+};
+
+struct ggml_webgpu_pad_shader_lib_context {
+ ggml_webgpu_pad_pipeline_key key;
+ uint32_t max_wg_size;
+};
+
+inline ggml_webgpu_processed_shader ggml_webgpu_preprocess_pad_shader(
+ pre_wgsl::Preprocessor & preprocessor,
+ const char * shader_src,
+ const ggml_webgpu_pad_shader_lib_context & context) {
+ std::vector<std::string> defines;
+ std::string variant = "pad";
+
+ if (context.key.circular) {
+ defines.push_back("CIRCULAR");
+ variant += "_circular";
+ }
+
+ defines.push_back(std::string("WG_SIZE=") + std::to_string(context.max_wg_size));
+
+ ggml_webgpu_processed_shader result;
+ result.wgsl = preprocessor.preprocess(shader_src, defines);
+ result.variant = variant;
+ auto decisions = std::make_shared<ggml_webgpu_generic_shader_decisions>();
+ decisions->wg_size = context.max_wg_size;
+ result.decisions = decisions;
+ return result;
+}
+
+/** Argsort **/
+
+struct ggml_webgpu_argsort_shader_lib_context {
+ uint32_t max_wg_size;
+ size_t wg_mem_limit_bytes;
+ int32_t order;
+};
+
+struct ggml_webgpu_argsort_shader_decisions {
+ uint32_t wg_size = 0;
+};
+
+inline ggml_webgpu_processed_shader ggml_webgpu_preprocess_argsort_shader(
+ pre_wgsl::Preprocessor & preprocessor,
+ const char * shader_src,
+ const ggml_webgpu_argsort_shader_lib_context & context) {
+ std::vector<std::string> defines;
+ std::string variant = "argsort";
+ defines.push_back(std::string("ORDER=") + std::to_string(context.order));
+ variant += std::string("_order") + std::to_string(context.order);
+ uint32_t wg_size = 1;
+ while (wg_size * 2 <= context.max_wg_size &&
+ wg_size * GGML_WEBGPU_I32_SIZE_BYTES <= context.wg_mem_limit_bytes / 2) {
+ wg_size *= 2;
+ }
+ defines.push_back(std::string("WG_SIZE=") + std::to_string(wg_size));
+ ggml_webgpu_processed_shader result;
+ result.wgsl = preprocessor.preprocess(shader_src, defines);
+ result.variant = variant;
+ auto decisions = std::make_shared<ggml_webgpu_argsort_shader_decisions>();
+ decisions->wg_size = wg_size;
+ result.decisions = decisions;
+ return result;
+}
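+// The power-of-two sizing above, e.g. (illustrative): with max_wg_size = 288
+// and a 32 KiB workgroup-memory limit the loop stops at wg_size = 256, since
+// 512 would exceed max_wg_size while 256 * 4 bytes stays under half the limit.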
+
+inline ggml_webgpu_processed_shader ggml_webgpu_preprocess_argsort_merge_shader(
+ pre_wgsl::Preprocessor & preprocessor,
+ const char * shader_src,
+ const ggml_webgpu_argsort_shader_lib_context & context) {
+ std::vector<std::string> defines;
+ std::string variant = "argsort_merge";
+ defines.push_back(std::string("ORDER=") + std::to_string(context.order));
+ variant += std::string("_order") + std::to_string(context.order);
+ uint32_t wg_size = std::min(GGML_WEBGPU_ARGSORT_MERGE_MAX_WG_SIZE, context.max_wg_size);
+ defines.push_back(std::string("WG_SIZE=") + std::to_string(wg_size));
+ ggml_webgpu_processed_shader result;
+ result.wgsl = preprocessor.preprocess(shader_src, defines);
+ result.variant = variant;
+ auto decisions = std::make_shared<ggml_webgpu_argsort_shader_decisions>();
+ decisions->wg_size = wg_size;
+ result.decisions = decisions;
+ return result;
+}
+
+/** Set Rows **/
+
+struct ggml_webgpu_set_rows_pipeline_key {
+ int dst_type;
+ int vec4;
+ int i64_idx;
+
+ bool operator==(const ggml_webgpu_set_rows_pipeline_key & other) const {
+ return dst_type == other.dst_type && vec4 == other.vec4 && i64_idx == other.i64_idx;
+ }
+};
+
+struct ggml_webgpu_set_rows_pipeline_key_hash {
+ size_t operator()(const ggml_webgpu_set_rows_pipeline_key & key) const {
+ size_t seed = 0;
+ ggml_webgpu_hash_combine(seed, key.dst_type);
+ ggml_webgpu_hash_combine(seed, key.vec4);
+ ggml_webgpu_hash_combine(seed, key.i64_idx);
+ return seed;
+ }
+};
+
+struct ggml_webgpu_set_rows_shader_lib_context {
+ ggml_webgpu_set_rows_pipeline_key key;
+ uint32_t max_wg_size;
+};
+
+inline ggml_webgpu_processed_shader ggml_webgpu_preprocess_set_rows_shader(
+ pre_wgsl::Preprocessor & preprocessor,
+ const char * shader_src,
+ const ggml_webgpu_set_rows_shader_lib_context & context) {
+ std::vector<std::string> defines;
+ std::string variant = "set_rows";
+
+ switch (context.key.dst_type) {
+ case GGML_TYPE_F32:
+ defines.push_back("DST_F32");
+ variant += "_dstf32";
+ break;
+ case GGML_TYPE_F16:
+ defines.push_back("DST_F16");
+ variant += "_dstf16";
+ break;
+ default:
+ GGML_ABORT("Unsupported dst type for set_rows shader");
+ }
+
+ if (context.key.vec4) {
+ defines.push_back("VEC4");
+ variant += "_vec";
+ }
+ if (context.key.i64_idx) {
+ defines.push_back("I64_IDX");
+ variant += "_i64idx";
+ }
+
+ defines.push_back(std::string("WG_SIZE=") + std::to_string(context.max_wg_size));
+
+ ggml_webgpu_processed_shader result;
+ result.wgsl = preprocessor.preprocess(shader_src, defines);
+ result.variant = variant;
+ auto decisions = std::make_shared<ggml_webgpu_generic_shader_decisions>();
+ decisions->wg_size = context.max_wg_size;
+ result.decisions = decisions;
+ return result;
+}
+
+/** Unary **/
+
+struct ggml_webgpu_unary_pipeline_key {
+ int type;
+ int op;
+ bool is_unary; // many unary operators fall under the GGML_OP_UNARY umbrella
+ bool inplace;
+
+ bool operator==(const ggml_webgpu_unary_pipeline_key & other) const {
+ return type == other.type && op == other.op && is_unary == other.is_unary && inplace == other.inplace;
+ }
+};
+
+struct ggml_webgpu_unary_pipeline_key_hash {
+ size_t operator()(const ggml_webgpu_unary_pipeline_key & key) const {
+ size_t seed = 0;
+ ggml_webgpu_hash_combine(seed, key.type);
+ ggml_webgpu_hash_combine(seed, key.op);
+ ggml_webgpu_hash_combine(seed, key.is_unary);
+ ggml_webgpu_hash_combine(seed, key.inplace);
+ return seed;
+ }
+};
+
+struct ggml_webgpu_unary_shader_lib_context {
+ ggml_webgpu_unary_pipeline_key key;
+ uint32_t max_wg_size;
+};
+
+inline ggml_webgpu_processed_shader ggml_webgpu_preprocess_unary_shader(
+ pre_wgsl::Preprocessor & preprocessor,
+ const char * shader_src,
+ const ggml_webgpu_unary_shader_lib_context & context) {
+ std::vector<std::string> defines;
+ std::string variant = context.key.is_unary ? ggml_unary_op_name((ggml_unary_op) context.key.op) :
+ ggml_op_name((ggml_op) context.key.op);
+ // Operation-specific behavior
+ defines.push_back(variant);
+
+ switch (context.key.type) {
+ case GGML_TYPE_F32:
+ defines.push_back("TYPE_F32");
+ variant += "_f32";
+ break;
+ case GGML_TYPE_F16:
+ defines.push_back("TYPE_F16");
+ variant += "_f16";
+ break;
+ default:
+ GGML_ABORT("Unsupported type for unary shader");
+ }
+
+ if (context.key.inplace) {
+ defines.push_back("INPLACE");
+ variant += "_inplace";
+ }
+
+ defines.push_back(std::string("WG_SIZE=") + std::to_string(context.max_wg_size));
+
+ ggml_webgpu_processed_shader result;
+ result.wgsl = preprocessor.preprocess(shader_src, defines);
+ result.variant = variant;
+ auto decisions = std::make_shared<ggml_webgpu_generic_shader_decisions>();
+ decisions->wg_size = context.max_wg_size;
+ result.decisions = decisions;
+ return result;
+}
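+// E.g. (illustrative): GGML_UNARY_OP_GELU on F32 in place yields the defines
+// GELU, TYPE_F32, INPLACE, WG_SIZE=... and the variant name "GELU_f32_inplace".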
+
+/** Binary **/
+
+struct ggml_webgpu_binary_pipeline_key {
+ int type;
+ int op;
+ bool inplace;
+ bool overlap;
+
+ bool operator==(const ggml_webgpu_binary_pipeline_key & other) const {
+ return type == other.type && op == other.op && inplace == other.inplace && overlap == other.overlap;
+ }
+};
+
+struct ggml_webgpu_binary_pipeline_key_hash {
+ size_t operator()(const ggml_webgpu_binary_pipeline_key & key) const {
+ size_t seed = 0;
+ ggml_webgpu_hash_combine(seed, key.type);
+ ggml_webgpu_hash_combine(seed, key.op);
+ ggml_webgpu_hash_combine(seed, key.inplace);
+ ggml_webgpu_hash_combine(seed, key.overlap);
+ return seed;
+ }
+};
+
+struct ggml_webgpu_binary_shader_lib_context {
+ ggml_webgpu_binary_pipeline_key key;
+ uint32_t max_wg_size;
+};
+
+inline ggml_webgpu_processed_shader ggml_webgpu_preprocess_binary_shader(
+ pre_wgsl::Preprocessor & preprocessor,
+ const char * shader_src,
+ const ggml_webgpu_binary_shader_lib_context & context) {
+ std::vector<std::string> defines;
+ std::string op_name = ggml_op_name((ggml_op) context.key.op);
+ std::string variant = op_name;
+
+ defines.push_back(std::string("OP_") + op_name);
+
+ switch (context.key.type) {
+ case GGML_TYPE_F32:
+ defines.push_back("TYPE_F32");
+ variant += "_f32";
+ break;
+ case GGML_TYPE_F16:
+ defines.push_back("TYPE_F16");
+ variant += "_f16";
+ break;
+ default:
+ GGML_ABORT("Unsupported type for binary shader");
+ }
+
+ if (context.key.inplace) {
+ defines.push_back("INPLACE");
+ variant += "_inplace";
+ } else if (context.key.overlap) {
+ defines.push_back("OVERLAP");
+ variant += "_overlap";
+ }
+
+ defines.push_back(std::string("WG_SIZE=") + std::to_string(context.max_wg_size));
+ ggml_webgpu_processed_shader result;
+ result.wgsl = preprocessor.preprocess(shader_src, defines);
+ result.variant = variant;
+ auto decisions = std::make_shared<ggml_webgpu_generic_shader_decisions>();
+ decisions->wg_size = context.max_wg_size;
+ result.decisions = decisions;
+ return result;
+}
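+// E.g. (illustrative): GGML_OP_MUL on F16 with src1 overlapping dst yields the
+// defines OP_MUL, TYPE_F16, OVERLAP, WG_SIZE=... and the variant
+// "MUL_f16_overlap".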
+#endif // GGML_WEBGPU_SHADER_LIB_HPP
diff --git a/llama.cpp/ggml/src/ggml-webgpu/ggml-webgpu.cpp b/llama.cpp/ggml/src/ggml-webgpu/ggml-webgpu.cpp
new file mode 100644
index 0000000..32e1202
--- /dev/null
+++ b/llama.cpp/ggml/src/ggml-webgpu/ggml-webgpu.cpp
@@ -0,0 +1,3469 @@
+/*
+ WebGPU backend implementation.
+ Note: Use ClangFormat to format this file.
+*/
+
+#include "ggml-webgpu.h"
+
+#include "ggml-backend-impl.h"
+#include "ggml-impl.h"
+#include "ggml-webgpu-shader-lib.hpp"
+#include "ggml-wgsl-shaders.hpp"
+#include "pre_wgsl.hpp"
+
+#ifdef __EMSCRIPTEN__
+# include <emscripten/emscripten.h>
+#endif
+
+#include <webgpu/webgpu_cpp.h>
+
+#include <atomic>
+#include <condition_variable>
+#include <cstdint>
+#include <cstring>
+#include <iostream>
+#include <map>
+#include <mutex>
+#include <optional>
+#include <string>
+#include <vector>
+
+#define ROUNDUP_POW2(x, pow2) (((x) + ((pow2) - 1)) & ~((pow2) - 1))
+#define CEIL_DIV(M, N) (((M) + (N) - 1) / (N))
+
+#ifdef GGML_WEBGPU_DEBUG
+# define WEBGPU_LOG_DEBUG(msg) std::cout << msg << std::endl
+# define WEBGPU_DEBUG_BUF_ELEMS 512
+#else
+# define WEBGPU_LOG_DEBUG(msg) ((void) 0)
+#endif // GGML_WEBGPU_DEBUG
+
+#ifdef GGML_WEBGPU_CPU_PROFILE
+// total timing (aggregated)
+# define WEBGPU_CPU_PROFILE_TOTAL_START(id) auto cpu_total_start_##id = std::chrono::high_resolution_clock::now();
+
+# define WEBGPU_CPU_PROFILE_TOTAL_END(id, ctx) \
+ auto cpu_total_end_##id = std::chrono::high_resolution_clock::now(); \
+ double cpu_total_time_##id = \
+ std::chrono::duration<double, std::milli>(cpu_total_end_##id - cpu_total_start_##id).count(); \
+ (ctx)->cpu_time_ms[#id] += cpu_total_time_##id;
+// fine-grained timing (not included in totals)
+# define WEBGPU_CPU_PROFILE_DETAIL_START(id) auto cpu_detail_start_##id = std::chrono::high_resolution_clock::now();
+
+# define WEBGPU_CPU_PROFILE_DETAIL_END(id, ctx) \
+ auto cpu_detail_end_##id = std::chrono::high_resolution_clock::now(); \
+ double cpu_detail_time_##id = \
+ std::chrono::duration<double, std::milli>(cpu_detail_end_##id - cpu_detail_start_##id).count(); \
+ (ctx)->cpu_detail_ms[#id] += cpu_detail_time_##id;
+#else
+# define WEBGPU_CPU_PROFILE_TOTAL_START(id)
+# define WEBGPU_CPU_PROFILE_TOTAL_END(id, ctx)
+# define WEBGPU_CPU_PROFILE_DETAIL_START(id)
+# define WEBGPU_CPU_PROFILE_DETAIL_END(id, ctx)
+#endif // GGML_WEBGPU_CPU_PROFILE
+
+#ifdef GGML_WEBGPU_GPU_PROFILE
+# define WEBGPU_NUM_TIMESTAMP_QUERY_BUFS 24
+#    define WEBGPU_TIMESTAMP_QUERY_BUF_SIZE_BYTES 16  // enough for two 8-byte timestamps
+#endif
+
+/* Constants */
+
+// Track https://github.com/gpuweb/gpuweb/issues/5315 for fixes to implementations so this can be removed.
+#define WEBGPU_MAX_WG_SIZE 288
+
+#define WEBGPU_MUL_MAT_WG_SIZE 256
+#define WEBGPU_NUM_PARAM_BUFS 16u
+#define WEBGPU_COMMAND_SUBMIT_BATCH_SIZE 8u
+#define WEBGPU_WAIT_ANY_TIMEOUT_MS 0
+// Maximum number of in-flight submissions per-thread, to avoid exhausting the parameter buffer pool
+#define WEBGPU_MAX_INFLIGHT_SUBS_PER_THREAD (WEBGPU_NUM_PARAM_BUFS / WEBGPU_COMMAND_SUBMIT_BATCH_SIZE)
+#define WEBGPU_PARAMS_BUF_SIZE_BYTES 128 // enough for 32 parameters
+#define WEBGPU_NUM_SET_ROWS_ERROR_BUFS 16
+#define WEBGPU_SET_ROWS_ERROR_BUF_SIZE_BYTES 4
+#define WEBGPU_STORAGE_BUF_BINDING_MULT 4 // a storage buffer binding size must be a multiple of 4
+
+// For operations which process a row in parallel, this seems like a reasonable default
+#define WEBGPU_ROW_SPLIT_WG_SIZE 64
+
+// Matrix multiplication parameters
+
+// Register tiling parameters
+#define WEBGPU_MUL_MAT_TILE_M 8
+#define WEBGPU_MUL_MAT_TILE_N 8
+#define WEBGPU_MUL_MAT_WG_SIZE_M 8
+#define WEBGPU_MUL_MAT_WG_SIZE_N 8
+#define WEBGPU_MUL_MAT_TILE_K 32
+
+// Subgroup matrix parameters
+// The number of subgroups in the M dimension
+#define WEBGPU_MUL_MAT_SUBGROUP_M 2
+// The number of subgroups in the N dimension
+#define WEBGPU_MUL_MAT_SUBGROUP_N 2
+// The number of subgroup matrices each subgroup accumulates over
+#define WEBGPU_MUL_MAT_SUBGROUP_MATRIX_M 4
+#define WEBGPU_MUL_MAT_SUBGROUP_MATRIX_N 2
+
+// Matrix-vector multiplication parameters
+#define WEBGPU_MUL_MAT_VEC_WG_SIZE 256
+// Must be multiple of 4 to work with vectorized paths, and must divide mul_mat_vec wg size
+#define WEBGPU_MUL_MAT_VEC_OUTPUTS_PER_WG 64
+#define WEBGPU_MUL_MAT_VEC_TILE_K 256
+
+/* End Constants */
+
+// This is a "fake" base pointer, since WebGPU buffers do not have pointers to their locations.
+static void * const webgpu_ptr_base = (void *) (uintptr_t) 0x1000; // NOLINT
+
+// Always returns the base offset of a tensor, regardless of views.
+static uint64_t webgpu_tensor_offset(const ggml_tensor * tensor) {
+ if (tensor->view_src) {
+ return (uint8_t *) tensor->view_src->data - (uint8_t *) webgpu_ptr_base;
+ }
+ return (uint8_t *) tensor->data - (uint8_t *) webgpu_ptr_base;
+}
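+// E.g. (illustrative): a tensor whose data pointer was assigned
+// webgpu_ptr_base + 4096 at allocation reports an offset of 4096; views
+// resolve through view_src, so the offset is always relative to the start of
+// the owning buffer.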
+
+/* Struct definitions */
+
+// Forward reference
+static void ggml_webgpu_create_buffer(wgpu::Device & device,
+ wgpu::Buffer & buffer,
+ size_t size,
+ wgpu::BufferUsage usage,
+ const char * label);
+
+struct webgpu_pool_bufs {
+ wgpu::Buffer host_buf;
+ wgpu::Buffer dev_buf;
+};
+
+// The futures to wait on for a single queue submission
+struct webgpu_submission_futures {
+ std::vector<wgpu::FutureWaitInfo> futures;
+};
+
+// Holds a pool of parameter buffers for WebGPU operations
+struct webgpu_buf_pool {
+ std::vector<webgpu_pool_bufs> free;
+
+ // The pool must be synchronized because
+ // 1. The memset pool is shared globally by every ggml buffer,
+ // since allocating a pool per ggml buffer would consume too much memory.
+ // 2. For the per-thread buffer pools in webgpu_context,
+ // buffers are allocated and freed in Dawn callbacks,
+ // which can run on a different thread than the calling thread.
+ std::mutex mutex;
+ std::condition_variable cv;
+
+ void init(wgpu::Device device,
+ int num_bufs,
+ size_t buf_size,
+ wgpu::BufferUsage dev_buf_usage,
+ wgpu::BufferUsage host_buf_usage) {
+ for (int i = 0; i < num_bufs; i++) {
+ wgpu::Buffer host_buf;
+ wgpu::Buffer dev_buf;
+ ggml_webgpu_create_buffer(device, host_buf, buf_size, host_buf_usage, "ggml_webgpu_host_pool_buf");
+ ggml_webgpu_create_buffer(device, dev_buf, buf_size, dev_buf_usage, "ggml_webgpu_dev_pool_buf");
+ free.push_back({ host_buf, dev_buf });
+ }
+ }
+
+ webgpu_pool_bufs alloc_bufs() {
+ std::unique_lock<std::mutex> lock(mutex);
+ cv.wait(lock, [this] { return !free.empty(); });
+ webgpu_pool_bufs bufs = free.back();
+ free.pop_back();
+ return bufs;
+ }
+
+ void free_bufs(std::vector<webgpu_pool_bufs> bufs) {
+ std::lock_guard<std::mutex> lock(mutex);
+ free.insert(free.end(), bufs.begin(), bufs.end());
+ cv.notify_all();
+ }
+
+ void cleanup() {
+ std::lock_guard<std::mutex> lock(mutex);
+ for (auto & bufs : free) {
+ if (bufs.host_buf) {
+ bufs.host_buf.Destroy();
+ }
+ if (bufs.dev_buf) {
+ bufs.dev_buf.Destroy();
+ }
+ }
+ free.clear();
+ }
+
+ ~webgpu_buf_pool() { this->cleanup(); }
+};
+
+#ifdef GGML_WEBGPU_GPU_PROFILE
+struct webgpu_gpu_profile_bufs {
+ wgpu::Buffer host_buf;
+ wgpu::Buffer dev_buf;
+ wgpu::QuerySet query_set;
+};
+
+// Holds a pool of parameter buffers for WebGPU operations
+struct webgpu_gpu_profile_buf_pool {
+ std::vector<webgpu_gpu_profile_bufs> free;
+
+ std::mutex mutex;
+
+ std::condition_variable cv;
+
+ void init(wgpu::Device device,
+ int num_bufs,
+ size_t buf_size,
+ wgpu::BufferUsage dev_buf_usage,
+ wgpu::BufferUsage host_buf_usage) {
+ for (int i = 0; i < num_bufs; i++) {
+ wgpu::Buffer host_buf;
+ wgpu::Buffer dev_buf;
+ ggml_webgpu_create_buffer(device, host_buf, buf_size, host_buf_usage, "ggml_webgpu_host_profile_buf");
+ ggml_webgpu_create_buffer(device, dev_buf, buf_size, dev_buf_usage, "ggml_webgpu_dev_profile_buf");
+ // Create a query set for 2 timestamps
+ wgpu::QuerySetDescriptor ts_query_set_desc = {};
+
+ ts_query_set_desc.type = wgpu::QueryType::Timestamp;
+ ts_query_set_desc.count = 2;
+ wgpu::QuerySet ts_query_set = device.CreateQuerySet(&ts_query_set_desc);
+
+ free.push_back({ host_buf, dev_buf, ts_query_set });
+ }
+ }
+
+ webgpu_gpu_profile_bufs alloc_bufs() {
+ std::unique_lock<std::mutex> lock(mutex);
+ cv.wait(lock, [this] { return !free.empty(); });
+ webgpu_gpu_profile_bufs bufs = free.back();
+ free.pop_back();
+ return bufs;
+ }
+
+ void free_bufs(std::vector<webgpu_gpu_profile_bufs> bufs) {
+ std::lock_guard<std::mutex> lock(mutex);
+ free.insert(free.end(), bufs.begin(), bufs.end());
+ cv.notify_all();
+ }
+
+ void cleanup() {
+ std::lock_guard<std::mutex> lock(mutex);
+ for (auto & bufs : free) {
+ bufs.host_buf.Destroy();
+ bufs.dev_buf.Destroy();
+ bufs.query_set.Destroy();
+ }
+ free.clear();
+ }
+
+ ~webgpu_gpu_profile_buf_pool() { this->cleanup(); }
+};
+#endif
+
+struct webgpu_pipeline {
+ wgpu::ComputePipeline pipeline;
+ std::string name;
+ std::shared_ptr<void> context = nullptr;
+};
+
+struct webgpu_command {
+ wgpu::CommandBuffer commands;
+ std::vector<webgpu_pool_bufs> params_bufs;
+ std::optional<webgpu_pool_bufs> set_rows_error_bufs;
+#ifdef GGML_WEBGPU_GPU_PROFILE
+ webgpu_gpu_profile_bufs timestamp_query_bufs;
+ std::string pipeline_name;
+#endif
+};
+
+struct webgpu_capabilities {
+ wgpu::Limits limits;
+ bool supports_subgroup_matrix = false;
+
+ uint32_t sg_mat_m = 0;
+ uint32_t sg_mat_n = 0;
+ uint32_t sg_mat_k = 0;
+
+ uint32_t subgroup_size = 0;
+ uint32_t max_subgroup_size = 0;
+ size_t memset_bytes_per_thread;
+};
+
+// Stores global webgpu members
+struct webgpu_global_context_struct {
+ wgpu::Instance instance;
+ wgpu::Adapter adapter;
+ wgpu::Device device;
+ wgpu::Queue queue;
+
+ webgpu_capabilities capabilities;
+ // Shared buffer to move data from device to host
+ wgpu::Buffer get_tensor_staging_buf;
+ // Global mutex for pipeline and staging buffer, will be refactored to exclude pipeline caches.
+ std::recursive_mutex mutex;
+
+ webgpu_buf_pool memset_buf_pool;
+ std::map<int, webgpu_pipeline> memset_pipelines; // variant or type index
+ std::atomic_uint inflight_threads = 0;
+
+#ifdef GGML_WEBGPU_CPU_PROFILE
+ // Profiling: labeled CPU time in ms (total)
+ std::unordered_map<std::string, double> cpu_time_ms;
+ // Profiling: detailed CPU time in ms
+ std::unordered_map<std::string, double> cpu_detail_ms;
+#endif
+
+#ifdef GGML_WEBGPU_GPU_PROFILE
+ // Profiling: per-shader GPU time in ms
+ std::unordered_map<std::string, double> shader_gpu_time_ms;
+ // Profiling: pool of timestamp query buffers (one per operation)
+ webgpu_gpu_profile_buf_pool timestamp_query_buf_pool;
+#endif
+
+#ifdef GGML_WEBGPU_DEBUG
+ wgpu::Buffer debug_host_buf;
+ wgpu::Buffer debug_dev_buf;
+#endif
+
+ ~webgpu_global_context_struct() {
+ if (this->get_tensor_staging_buf) {
+ this->get_tensor_staging_buf.Destroy();
+ this->get_tensor_staging_buf = nullptr;
+ }
+#ifdef GGML_WEBGPU_DEBUG
+ if (this->debug_host_buf) {
+ this->debug_host_buf.Destroy();
+ this->debug_host_buf = nullptr;
+ }
+ if (this->debug_dev_buf) {
+ this->debug_dev_buf.Destroy();
+ this->debug_dev_buf = nullptr;
+ }
+#endif
+ }
+};
+
+typedef std::shared_ptr<webgpu_global_context_struct> webgpu_global_context;
+
+// All the base objects needed to run operations on a WebGPU device
+struct webgpu_context_struct {
+ // Points to global instances owned by ggml_backend_webgpu_reg_context
+ webgpu_global_context global_ctx;
+
+ pre_wgsl::Preprocessor p;
+
+ webgpu_buf_pool param_buf_pool;
+ webgpu_buf_pool set_rows_error_buf_pool;
+
+ std::map<int, std::map<int, std::map<int, webgpu_pipeline>>> mul_mat_pipelines; // src0_type, src1_type, vectorized
+ std::map<int, std::map<int, std::map<int, webgpu_pipeline>>>
+ mul_mat_vec_pipelines; // src0_type, src1_type, vectorized
+
+ std::unordered_map<ggml_webgpu_flash_attn_pipeline_key, webgpu_pipeline, ggml_webgpu_flash_attn_pipeline_key_hash>
+ flash_attn_pipelines;
+
+ std::unordered_map<int, webgpu_pipeline> argmax_pipelines; // key is vec4
+ std::unordered_map<int, webgpu_pipeline> argsort_pipelines; // key is order (asc/desc)
+ std::unordered_map<int, webgpu_pipeline> argsort_merge_pipelines; // key is order (asc/desc)
+ std::unordered_map<int, webgpu_pipeline> cumsum_pipelines; // key is fixed, no variants yet
+ std::unordered_map<int, webgpu_pipeline> sum_rows_pipelines; // key is fixed, no variants yet
+
+ std::unordered_map<ggml_webgpu_set_rows_pipeline_key, webgpu_pipeline, ggml_webgpu_set_rows_pipeline_key_hash>
+ set_rows_pipelines;
+ std::map<int, std::map<int, webgpu_pipeline>> get_rows_pipelines; // src_type, vectorized
+
+ std::map<int, std::map<int, webgpu_pipeline>> cpy_pipelines; // src_type, dst_type
+
+ std::unordered_map<ggml_webgpu_binary_pipeline_key, webgpu_pipeline, ggml_webgpu_binary_pipeline_key_hash>
+ binary_pipelines;
+
+ std::map<int, webgpu_pipeline> rms_norm_pipelines; // inplace
+ std::map<int, std::map<int, std::map<int, webgpu_pipeline>>> rope_pipelines; // type, ff, inplace
+ std::map<int, std::map<int, std::map<int, webgpu_pipeline>>> glu_pipelines; // glu_op, type, split
+ std::map<int, webgpu_pipeline> scale_pipelines; // inplace
+ std::map<int, std::map<int, std::map<int, webgpu_pipeline>>> soft_max_pipelines; // mask_type, has_sink, inplace
+ std::unordered_map<ggml_webgpu_unary_pipeline_key, webgpu_pipeline, ggml_webgpu_unary_pipeline_key_hash>
+ unary_pipelines;
+ std::unordered_map<ggml_webgpu_pad_pipeline_key, webgpu_pipeline, ggml_webgpu_pad_pipeline_key_hash> pad_pipelines;
+
+ size_t memset_bytes_per_thread;
+};
+
+typedef std::shared_ptr<webgpu_context_struct> webgpu_context;
+
+// Metadata required for the ggml backend registration/discovery interface
+struct ggml_backend_webgpu_reg_context {
+ // Since the Instance is a global entrypoint into the WebGPU API, it lives here
+ webgpu_global_context webgpu_global_ctx;
+ size_t device_count;
+ const char * name;
+};
+
+// Per-device struct for the global logical device interface
+struct ggml_backend_webgpu_device_context {
+ webgpu_global_context webgpu_global_ctx;
+ std::string device_name;
+ std::string device_desc;
+};
+
+// Per-thread data required to actually run WebGPU operations in a backend instance
+struct ggml_backend_webgpu_context {
+ webgpu_context webgpu_ctx;
+ std::string name;
+};
+
+// Per-thread data related to buffers
+struct ggml_backend_webgpu_buffer_context {
+ wgpu::Buffer buffer;
+ std::string label;
+ webgpu_global_context global_ctx;
+
+ ggml_backend_webgpu_buffer_context(wgpu::Buffer buf, std::string lbl, webgpu_global_context global_ctx_) :
+ buffer(std::move(buf)),
+ label(std::move(lbl)),
+ global_ctx(std::move(global_ctx_)) {}
+};
+
+/* WebGPU object initializations */
+
+// Process a WGSL shader string, replacing tokens of the form {{KEY}} with
+// the corresponding values provided in `repls`.
+static std::string ggml_webgpu_process_shader_repls(const char * src,
+ const std::map<std::string, std::string> & repls) {
+ if (!src) {
+ return std::string();
+ }
+ std::string s = src;
+ for (const auto & kv : repls) {
+ std::string token = "{{" + kv.first + "}}";
+ size_t pos = 0;
+ while ((pos = s.find(token, pos)) != std::string::npos) {
+ s.replace(pos, token.length(), kv.second);
+ pos += kv.second.length();
+ }
+ }
+ return s;
+}
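+// E.g. (illustrative):
+//   ggml_webgpu_process_shader_repls("fn f(x : {{TYPE}})", {{ "TYPE", "f32" }})
+//   // -> "fn f(x : f32)"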
+
+static webgpu_pipeline ggml_webgpu_create_pipeline(wgpu::Device & device,
+ const char * shader_code,
+ const char * label,
+ const std::vector<wgpu::ConstantEntry> & constants = {}) {
+ wgpu::ShaderSourceWGSL shader_source;
+ shader_source.code = shader_code;
+
+ wgpu::ShaderModuleDescriptor shader_desc;
+ shader_desc.nextInChain = &shader_source;
+
+ wgpu::ShaderModule shader_module = device.CreateShaderModule(&shader_desc);
+
+ wgpu::ComputePipelineDescriptor pipeline_desc;
+ pipeline_desc.label = label;
+ pipeline_desc.compute.module = shader_module;
+ pipeline_desc.compute.entryPoint = "main"; // Entry point in the WGSL code
+ pipeline_desc.layout = nullptr; // nullptr means auto layout
+ if (constants.size() > 0) {
+ pipeline_desc.compute.constants = constants.data();
+ pipeline_desc.compute.constantCount = constants.size();
+ }
+ return { device.CreateComputePipeline(&pipeline_desc), label };
+}
+
+static void ggml_webgpu_create_buffer(wgpu::Device & device,
+ wgpu::Buffer & buffer,
+ size_t size,
+ wgpu::BufferUsage usage,
+ const char * label) {
+ wgpu::BufferDescriptor buffer_desc;
+ buffer_desc.size = size;
+ buffer_desc.usage = usage;
+ buffer_desc.label = label;
+ buffer_desc.mappedAtCreation = false;
+
+ // TODO: error handling
+ buffer = device.CreateBuffer(&buffer_desc);
+}
+
+/** End WebGPU object initializations */
+
+/** WebGPU Actions */
+
+// Wait for the queue to finish processing all submitted work
+static void ggml_backend_webgpu_wait(webgpu_global_context & ctx,
+ std::vector<webgpu_submission_futures> & futures,
+ bool block = true) {
+ // If we have too many in-flight submissions, wait on the oldest one first. If there are many threads,
+ // inflight_max may be 0, meaning that we must wait on all futures.
+ uint64_t timeout_ms = block ? UINT64_MAX : 0;
+ uint32_t inflight_threads = ctx->inflight_threads;
+ uint32_t inflight_max = WEBGPU_MAX_INFLIGHT_SUBS_PER_THREAD / std::max(inflight_threads, 1u);
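+    // E.g. with the defaults above, WEBGPU_MAX_INFLIGHT_SUBS_PER_THREAD is
+    // 16 / 8 = 2: one thread may keep two submissions in flight, while three
+    // or more threads drive inflight_max to 0 and drain every future below.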
+ while (futures.size() >= inflight_max && futures.size() > 0) {
+ ctx->instance.WaitAny(futures[0].futures.size(), futures[0].futures.data(), UINT64_MAX);
+ futures.erase(futures.begin());
+ }
+ size_t i = 0;
+ while (i < futures.size()) {
+ auto waitStatus = ctx->instance.WaitAny(futures[i].futures.size(), futures[i].futures.data(), timeout_ms);
+ switch (waitStatus) {
+ case wgpu::WaitStatus::Success:
+ futures.erase(futures.begin() + i);
+ break;
+ case wgpu::WaitStatus::TimedOut:
+ i++;
+ break;
+            case wgpu::WaitStatus::Error:
+                GGML_LOG_ERROR("ggml_webgpu: WaitAny returned an error\n");
+                // Drop the future so a persistent error cannot spin this loop forever.
+                futures.erase(futures.begin() + i);
+                break;
+            default:
+                GGML_LOG_ERROR("ggml_webgpu: WaitAny returned an unknown status\n");
+                futures.erase(futures.begin() + i);
+                break;
+ }
+ }
+}
+
+static void ggml_backend_webgpu_map_buffer(webgpu_global_context & ctx,
+ wgpu::Buffer & buffer,
+ wgpu::MapMode mode,
+ size_t offset,
+ size_t size) {
+ ctx->instance.WaitAny(buffer.MapAsync(mode, offset, size, wgpu::CallbackMode::AllowSpontaneous,
+ [](wgpu::MapAsyncStatus status, wgpu::StringView message) {
+ if (status != wgpu::MapAsyncStatus::Success) {
+ GGML_LOG_ERROR("ggml_webgpu: Failed to map buffer: %s\n",
+ message.data);
+ }
+ }),
+ UINT64_MAX);
+}
+
+#ifdef GGML_WEBGPU_DEBUG
+// This function adds debugging information to shaders, as WebGPU does not support printing directly.
+// To use, add a bind group entry to the setup for the shader you are debugging, add the buffer and
+// debug statements in the shader, and then call this function after encoding the commands and submitting them.
+static void ggml_backend_webgpu_debug(webgpu_global_context & ctx) {
+ wgpu::CommandEncoder encoder = ctx->device.CreateCommandEncoder();
+ encoder.CopyBufferToBuffer(ctx->debug_dev_buf, 0, ctx->debug_host_buf, 0, ctx->debug_host_buf.GetSize());
+ wgpu::CommandBuffer commands = encoder.Finish();
+ ctx->queue.Submit(1, &commands);
+ ggml_backend_webgpu_map_buffer(ctx, ctx->debug_host_buf, wgpu::MapMode::Read, 0, ctx->debug_host_buf.GetSize());
+ const float * debug_data = (const float *) ctx->debug_host_buf.GetConstMappedRange();
+ std::cout << "debug[0]: " << debug_data[0] << "\n";
+ ctx->debug_host_buf.Unmap();
+}
+#endif
+
+static webgpu_submission_futures ggml_backend_webgpu_submit(webgpu_global_context ctx,
+ std::vector<webgpu_command> commands,
+ webgpu_buf_pool & param_buf_pool,
+ webgpu_buf_pool * set_rows_error_buf_pool = nullptr) {
+ std::vector<wgpu::CommandBuffer> command_buffers;
+ std::vector<webgpu_pool_bufs> params_bufs;
+ std::vector<webgpu_pool_bufs> set_rows_error_bufs;
+#ifdef GGML_WEBGPU_GPU_PROFILE
+ std::vector<std::pair<std::string, webgpu_gpu_profile_bufs>> pipeline_name_and_ts_bufs;
+#endif
+
+ for (const auto & command : commands) {
+ command_buffers.push_back(command.commands);
+ params_bufs.insert(params_bufs.end(), command.params_bufs.begin(), command.params_bufs.end());
+ if (command.set_rows_error_bufs) {
+ set_rows_error_bufs.push_back(command.set_rows_error_bufs.value());
+ }
+ }
+ ctx->queue.Submit(command_buffers.size(), command_buffers.data());
+
+ std::vector<wgpu::FutureWaitInfo> futures;
+
+ wgpu::Future p_f = ctx->queue.OnSubmittedWorkDone(
+ wgpu::CallbackMode::AllowSpontaneous,
+ [&param_buf_pool, params_bufs](wgpu::QueueWorkDoneStatus status, wgpu::StringView message) {
+ if (status != wgpu::QueueWorkDoneStatus::Success) {
+ GGML_LOG_ERROR("ggml_webgpu: Failed to submit commands: %s\n", std::string(message).c_str());
+ }
+ // Free the staged buffers
+ param_buf_pool.free_bufs(params_bufs);
+ });
+ futures.push_back({ p_f });
+
+ for (const auto & bufs : set_rows_error_bufs) {
+ wgpu::Future f = bufs.host_buf.MapAsync(
+ wgpu::MapMode::Read, 0, bufs.host_buf.GetSize(), wgpu::CallbackMode::AllowSpontaneous,
+ [set_rows_error_buf_pool, bufs](wgpu::MapAsyncStatus status, wgpu::StringView message) {
+ if (status != wgpu::MapAsyncStatus::Success) {
+ GGML_LOG_ERROR("ggml_webgpu: Failed to map error buffer: %s\n", std::string(message).c_str());
+ } else {
+ const uint32_t * error_data = (const uint32_t *) bufs.host_buf.GetConstMappedRange();
+ if (*error_data) {
+ GGML_ABORT("ggml_webgpu: SET_ROWS index > 2^32, unsupported.");
+ }
+ // We can't unmap in here due to WebGPU reentrancy limitations.
+ if (set_rows_error_buf_pool) {
+ set_rows_error_buf_pool->free_bufs({ bufs });
+ }
+ }
+ });
+ futures.push_back({ f });
+ }
+
+#ifdef GGML_WEBGPU_GPU_PROFILE
+ for (const auto & command : commands) {
+ auto label = command.pipeline_name;
+ auto ts_bufs = command.timestamp_query_bufs;
+
+ wgpu::Future f = ts_bufs.host_buf.MapAsync(
+ wgpu::MapMode::Read, 0, ts_bufs.host_buf.GetSize(), wgpu::CallbackMode::AllowSpontaneous,
+ [ctx, ts_bufs, label](wgpu::MapAsyncStatus status, wgpu::StringView message) {
+ if (status != wgpu::MapAsyncStatus::Success) {
+ GGML_LOG_ERROR("ggml_webgpu: Failed to map timestamp buffer: %s\n", std::string(message).c_str());
+ } else {
+ const uint64_t * ts_data = (const uint64_t *) ts_bufs.host_buf.GetConstMappedRange();
+ // WebGPU timestamps are in ns; convert to ms
+ double elapsed_ms = double(ts_data[1] - ts_data[0]) * 1e-6;
+ ctx->shader_gpu_time_ms[label] += elapsed_ms;
+ // We can't unmap in here due to WebGPU reentrancy limitations.
+ ctx->timestamp_query_buf_pool.free_bufs({ ts_bufs });
+ }
+ });
+ futures.push_back({ f });
+ }
+#endif
+ return { futures };
+}
+
+static webgpu_command ggml_backend_webgpu_build_multi(
+ webgpu_global_context & ctx,
+ webgpu_buf_pool & param_buf_pool,
+ const std::vector<webgpu_pipeline> & pipelines,
+ const std::vector<std::vector<uint32_t>> & params_list,
+ const std::vector<std::vector<wgpu::BindGroupEntry>> & bind_group_entries_list,
+ const std::vector<std::pair<uint32_t, uint32_t>> & workgroups_list,
+ const std::optional<webgpu_pool_bufs> & set_rows_error_bufs = std::nullopt) {
+ GGML_ASSERT(pipelines.size() == params_list.size());
+ GGML_ASSERT(pipelines.size() == bind_group_entries_list.size());
+ GGML_ASSERT(pipelines.size() == workgroups_list.size());
+
+ std::vector<webgpu_pool_bufs> params_bufs_list;
+ std::vector<wgpu::BindGroup> bind_groups;
+
+ for (size_t i = 0; i < pipelines.size(); i++) {
+ webgpu_pool_bufs params_bufs = param_buf_pool.alloc_bufs();
+
+ ggml_backend_webgpu_map_buffer(ctx, params_bufs.host_buf, wgpu::MapMode::Write, 0,
+ params_bufs.host_buf.GetSize());
+ uint32_t * _params = (uint32_t *) params_bufs.host_buf.GetMappedRange();
+ for (size_t j = 0; j < params_list[i].size(); j++) {
+ _params[j] = params_list[i][j];
+ }
+ params_bufs.host_buf.Unmap();
+
+ std::vector<wgpu::BindGroupEntry> entries = bind_group_entries_list[i];
+ uint32_t params_binding_num = entries.size();
+ entries.push_back({ .binding = params_binding_num,
+ .buffer = params_bufs.dev_buf,
+ .offset = 0,
+ .size = params_bufs.dev_buf.GetSize() });
+
+ wgpu::BindGroupDescriptor bind_group_desc;
+ bind_group_desc.layout = pipelines[i].pipeline.GetBindGroupLayout(0);
+ bind_group_desc.entryCount = entries.size();
+ bind_group_desc.entries = entries.data();
+ bind_group_desc.label = pipelines[i].name.c_str();
+ bind_groups.push_back(ctx->device.CreateBindGroup(&bind_group_desc));
+
+ params_bufs_list.push_back(params_bufs);
+ }
+
+ wgpu::CommandEncoder encoder = ctx->device.CreateCommandEncoder();
+ for (const auto & params_bufs : params_bufs_list) {
+ encoder.CopyBufferToBuffer(params_bufs.host_buf, 0, params_bufs.dev_buf, 0, params_bufs.dev_buf.GetSize());
+ }
+
+ // If there are SET_ROWS operations in this submission, copy their error buffers to the host.
+ if (set_rows_error_bufs) {
+ encoder.CopyBufferToBuffer(set_rows_error_bufs->dev_buf, 0, set_rows_error_bufs->host_buf, 0,
+ set_rows_error_bufs->host_buf.GetSize());
+ }
+
+#ifdef GGML_WEBGPU_GPU_PROFILE
+ webgpu_gpu_profile_bufs ts_bufs = ctx->timestamp_query_buf_pool.alloc_bufs();
+ if (ts_bufs.host_buf.GetMapState() == wgpu::BufferMapState::Mapped) {
+ ts_bufs.host_buf.Unmap();
+ }
+
+ wgpu::PassTimestampWrites ts_writes = { .querySet = ts_bufs.query_set,
+ .beginningOfPassWriteIndex = 0,
+ .endOfPassWriteIndex = 1 };
+ wgpu::ComputePassDescriptor pass_desc = { .timestampWrites = &ts_writes };
+ wgpu::ComputePassEncoder pass = encoder.BeginComputePass(&pass_desc);
+#else
+ wgpu::ComputePassEncoder pass = encoder.BeginComputePass();
+#endif
+ for (size_t i = 0; i < pipelines.size(); i++) {
+ pass.SetPipeline(pipelines[i].pipeline);
+ pass.SetBindGroup(0, bind_groups[i]);
+ pass.DispatchWorkgroups(workgroups_list[i].first, workgroups_list[i].second, 1);
+ }
+ pass.End();
+
+#ifdef GGML_WEBGPU_GPU_PROFILE
+ encoder.ResolveQuerySet(ts_bufs.query_set, 0, 2, ts_bufs.dev_buf, 0);
+ encoder.CopyBufferToBuffer(ts_bufs.dev_buf, 0, ts_bufs.host_buf, 0, ts_bufs.host_buf.GetSize());
+#endif
+
+ wgpu::CommandBuffer commands = encoder.Finish();
+ webgpu_command result = {};
+ result.commands = commands;
+ result.params_bufs = params_bufs_list;
+ result.set_rows_error_bufs = set_rows_error_bufs;
+#ifdef GGML_WEBGPU_GPU_PROFILE
+ result.timestamp_query_bufs = ts_bufs;
+ // TODO: handle multiple pipeline names
+ result.pipeline_name = pipelines.front().name;
+#endif
+ return result;
+}
+
+static webgpu_command ggml_backend_webgpu_build(webgpu_global_context & ctx,
+ webgpu_buf_pool & param_buf_pool,
+ webgpu_pipeline & pipeline,
+ std::vector<uint32_t> params,
+ std::vector<wgpu::BindGroupEntry> bind_group_entries,
+ uint32_t wg_x,
+ uint32_t wg_y = 1,
+ std::optional<webgpu_pool_bufs> set_rows_error_bufs = std::nullopt) {
+ return ggml_backend_webgpu_build_multi(ctx, param_buf_pool,
+ {
+ pipeline
+ },
+ { params }, { bind_group_entries }, { { wg_x, wg_y } }, set_rows_error_bufs);
+}
+
+static void ggml_backend_webgpu_buffer_memset(webgpu_global_context & ctx,
+ wgpu::Buffer & buf,
+ uint32_t value,
+ size_t offset,
+ size_t size) {
+ std::vector<uint32_t> params = { (uint32_t) offset, (uint32_t) size, value };
+ std::vector<wgpu::BindGroupEntry> entries = {
+ { .binding = 0, .buffer = buf, .offset = 0, .size = buf.GetSize() }
+ };
+ size_t bytes_per_wg = WEBGPU_MAX_WG_SIZE * ctx->capabilities.memset_bytes_per_thread;
+ uint32_t wg_x = CEIL_DIV(size + 3, bytes_per_wg);
+
+ webgpu_command command =
+ ggml_backend_webgpu_build(ctx, ctx->memset_buf_pool, ctx->memset_pipelines[0], params, entries, wg_x);
+ std::vector<webgpu_submission_futures> futures = { ggml_backend_webgpu_submit(ctx, { command },
+ ctx->memset_buf_pool) };
+ ggml_backend_webgpu_wait(ctx, futures);
+}
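+// Dispatch arithmetic above, illustratively: assuming memset_bytes_per_thread
+// is 4 (it is set from device capabilities), bytes_per_wg = 288 * 4 = 1152,
+// so a 5000-byte fill dispatches CEIL_DIV(5003, 1152) = 5 workgroups; the
+// "+ 3" rounds a trailing partial 4-byte word into the last workgroup.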
+
+/** End WebGPU Actions */
+
+/** GGML Backend Interface */
+
+static const char * ggml_backend_webgpu_name(ggml_backend_t backend) {
+ ggml_backend_webgpu_context * ctx = (ggml_backend_webgpu_context *) backend->context;
+ return ctx->name.c_str();
+}
+
+static void ggml_backend_webgpu_free(ggml_backend_t backend) {
+ ggml_backend_webgpu_context * ctx = (ggml_backend_webgpu_context *) backend->context;
+ WEBGPU_LOG_DEBUG("ggml_backend_webgpu_free(" << ctx->name << ")");
+
+#ifdef GGML_WEBGPU_CPU_PROFILE
+ std::cout << "\n[ggml_webgpu cpu profiling summary]\n";
+ double total_cpu = 0.0;
+ for (const auto & kv : ctx->webgpu_ctx->global_ctx->cpu_time_ms) {
+ total_cpu += kv.second;
+ }
+ std::cout << "ggml_webgpu: total cpu time: " << total_cpu << " ms\n";
+ std::cout << "ggml_webgpu: cpu breakdown:\n";
+ for (const auto & kv : ctx->webgpu_ctx->global_ctx->cpu_time_ms) {
+ double pct = (total_cpu > 0.0) ? (kv.second / total_cpu * 100.0) : 0.0;
+ std::cout << "ggml_webgpu: " << kv.first << ": " << kv.second << " ms (" << pct << "%)\n";
+ }
+ if (ctx->webgpu_ctx->global_ctx->cpu_detail_ms.size() > 0) {
+ std::cout << "ggml_webgpu: cpu detailed breakdown:\n";
+ }
+ for (const auto & kv : ctx->webgpu_ctx->global_ctx->cpu_detail_ms) {
+ double pct = (total_cpu > 0.0) ? (kv.second / total_cpu * 100.0) : 0.0;
+ std::cout << "ggml_webgpu: " << kv.first << ": " << kv.second << " ms (" << pct << "%)\n";
+ }
+#endif
+
+#ifdef GGML_WEBGPU_GPU_PROFILE
+ std::cout << "\n[ggml_webgpu gpu profiling summary]\n";
+ double total_gpu = 0.0;
+ for (const auto & kv : ctx->webgpu_ctx->global_ctx->shader_gpu_time_ms) {
+ total_gpu += kv.second;
+ }
+ std::cout << "ggml_webgpu: total gpu time (all shaders): " << total_gpu << " ms\n";
+ std::cout << "\nggml_webgpu: gpu breakdown:\n";
+ for (const auto & kv : ctx->webgpu_ctx->global_ctx->shader_gpu_time_ms) {
+ double pct = (total_gpu > 0.0) ? (kv.second / total_gpu * 100.0) : 0.0;
+ std::cout << "ggml_webgpu: " << kv.first << ": " << kv.second << " ms (" << pct << "%)\n";
+ }
+#endif
+
+#if defined(GGML_WEBGPU_CPU_PROFILE) && defined(GGML_WEBGPU_GPU_PROFILE)
+ std::cout << "ggml_webgpu: gpu/cpu ratio: " << (total_cpu > 0.0 ? total_gpu / total_cpu : 0.0) << "\n";
+#endif
+
+ delete ctx;
+ delete backend;
+}
+
+static size_t ggml_webgpu_tensor_offset(const ggml_tensor * tensor) {
+ return webgpu_tensor_offset(tensor) + tensor->view_offs;
+}
+
+static wgpu::Buffer ggml_webgpu_tensor_buf(const ggml_tensor * tensor) {
+ ggml_backend_webgpu_buffer_context * ctx = (ggml_backend_webgpu_buffer_context *) tensor->buffer->context;
+ return ctx->buffer;
+}
+
+static size_t ggml_webgpu_tensor_misalignment(webgpu_context & ctx, const ggml_tensor * t) {
+ size_t offset = ggml_webgpu_tensor_offset(t);
+ return offset & (ctx->global_ctx->capabilities.limits.minStorageBufferOffsetAlignment - 1);
+}
+
+static size_t ggml_webgpu_tensor_align_offset(webgpu_context & ctx, const ggml_tensor * t) {
+ size_t offset = ggml_webgpu_tensor_offset(t);
+ return offset & ~(ctx->global_ctx->capabilities.limits.minStorageBufferOffsetAlignment - 1);
+}
+
+static size_t ggml_webgpu_tensor_binding_size(webgpu_context & ctx, ggml_tensor * t) {
+ return ROUNDUP_POW2(ggml_nbytes(t) + ggml_webgpu_tensor_misalignment(ctx, t), WEBGPU_STORAGE_BUF_BINDING_MULT);
+}
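+// E.g. (illustrative): with minStorageBufferOffsetAlignment = 256, a tensor at
+// byte offset 1000 binds at the aligned offset 768 with a misalignment of 232,
+// and its binding size is ROUNDUP_POW2(nbytes + 232, 4).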
+
+// Used to determine if two tensors are the same for in-place operations
+static bool ggml_webgpu_tensor_equal(ggml_tensor * a, ggml_tensor * b) {
+ return (ggml_webgpu_tensor_buf(a).Get() == ggml_webgpu_tensor_buf(b).Get()) &&
+ (ggml_webgpu_tensor_offset(a) == ggml_webgpu_tensor_offset(b));
+}
+
+// Used to determine if two tensors share the same buffer and whether their byte ranges overlap.
+static bool ggml_webgpu_tensor_overlap(ggml_tensor * a, ggml_tensor * b) {
+ return (ggml_webgpu_tensor_buf(a).Get() == ggml_webgpu_tensor_buf(b).Get()) &&
+ ggml_webgpu_tensor_offset(a) < (ggml_webgpu_tensor_offset(b) + ggml_nbytes(b)) &&
+ ggml_webgpu_tensor_offset(b) < (ggml_webgpu_tensor_offset(a) + ggml_nbytes(a));
+}
+
+struct binary_overlap_flags {
+ bool inplace; // src0 == dst
+    bool overlap; // src1 overlaps dst
+};
+
+static binary_overlap_flags ggml_webgpu_detect_binary_overlap(ggml_tensor * src0,
+ ggml_tensor * src1,
+ ggml_tensor * dst) {
+ binary_overlap_flags flags = {};
+ flags.inplace = ggml_webgpu_tensor_equal(src0, dst);
+ flags.overlap = ggml_webgpu_tensor_overlap(src1, dst);
+
+ return flags;
+}
+
+static webgpu_command ggml_webgpu_cpy(webgpu_context & ctx, ggml_tensor * src, ggml_tensor * dst) {
+ uint32_t ne = (uint32_t) ggml_nelements(dst);
+
+ std::vector<uint32_t> params = {
+ ne, (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src) / ggml_type_size(src->type)),
+ (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, dst) / ggml_type_size(dst->type)),
+ // Convert byte-strides to element-strides
+ (uint32_t) (src->nb[0] / ggml_type_size(src->type)), (uint32_t) (src->nb[1] / ggml_type_size(src->type)),
+ (uint32_t) (src->nb[2] / ggml_type_size(src->type)), (uint32_t) (src->nb[3] / ggml_type_size(src->type)),
+ (uint32_t) (dst->nb[0] / ggml_type_size(dst->type)), (uint32_t) (dst->nb[1] / ggml_type_size(dst->type)),
+ (uint32_t) (dst->nb[2] / ggml_type_size(dst->type)), (uint32_t) (dst->nb[3] / ggml_type_size(dst->type)),
+ // Logical shapes
+ (uint32_t) src->ne[0], (uint32_t) src->ne[1], (uint32_t) src->ne[2], (uint32_t) dst->ne[0],
+ (uint32_t) dst->ne[1], (uint32_t) dst->ne[2]
+ };
+
+ std::vector<wgpu::BindGroupEntry> entries = {
+ { .binding = 0,
+ .buffer = ggml_webgpu_tensor_buf(src),
+ .offset = ggml_webgpu_tensor_align_offset(ctx, src),
+ .size = ggml_webgpu_tensor_binding_size(ctx, src) },
+ { .binding = 1,
+ .buffer = ggml_webgpu_tensor_buf(dst),
+ .offset = ggml_webgpu_tensor_align_offset(ctx, dst),
+ .size = ggml_webgpu_tensor_binding_size(ctx, dst) }
+ };
+
+ uint32_t wg_x = CEIL_DIV(ne, WEBGPU_MAX_WG_SIZE);
+ return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_buf_pool, ctx->cpy_pipelines[src->type][dst->type],
+ params, entries, wg_x);
+}
+
+static webgpu_command ggml_webgpu_pad(webgpu_context & ctx, ggml_tensor * src, ggml_tensor * dst) {
+ const bool circular = ggml_get_op_params_i32(dst, 8) != 0;
+
+ ggml_webgpu_pad_pipeline_key pipeline_key = { .circular = circular };
+ ggml_webgpu_pad_shader_lib_context shader_lib_ctx = {
+ .key = pipeline_key, .max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup
+ };
+
+ webgpu_pipeline pipeline;
+ auto it = ctx->pad_pipelines.find(pipeline_key);
+ if (it != ctx->pad_pipelines.end()) {
+ pipeline = it->second;
+ } else {
+ ggml_webgpu_processed_shader processed = ggml_webgpu_preprocess_pad_shader(ctx->p, wgsl_pad, shader_lib_ctx);
+ pipeline =
+ ggml_webgpu_create_pipeline(ctx->global_ctx->device, processed.wgsl.c_str(), processed.variant.c_str());
+ pipeline.context = processed.decisions;
+ ctx->pad_pipelines.emplace(pipeline_key, pipeline);
+ }
+
+ auto * decisions = static_cast<ggml_webgpu_generic_shader_decisions *>(pipeline.context.get());
+
+ const uint32_t ne = (uint32_t) ggml_nelements(dst);
+
+ std::vector<uint32_t> params = {
+ ne,
+ (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src) / ggml_type_size(src->type)),
+ (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, dst) / ggml_type_size(dst->type)),
+ // Strides (in elements)
+ (uint32_t) (src->nb[0] / ggml_type_size(src->type)),
+ (uint32_t) (src->nb[1] / ggml_type_size(src->type)),
+ (uint32_t) (src->nb[2] / ggml_type_size(src->type)),
+ (uint32_t) (src->nb[3] / ggml_type_size(src->type)),
+ // Shapes
+ (uint32_t) src->ne[0],
+ (uint32_t) src->ne[1],
+ (uint32_t) src->ne[2],
+ (uint32_t) src->ne[3],
+ (uint32_t) dst->ne[0],
+ (uint32_t) dst->ne[1],
+ (uint32_t) dst->ne[2],
+ (uint32_t) dst->ne[3],
+ // Pad sizes
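+        // (op_params[0..7] are assumed to hold per-dimension left/right pad pairs
+        //  lp0, rp0, ..., lp3, rp3, matching ggml_pad_ext)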
+ (uint32_t) ggml_get_op_params_i32(dst, 0),
+ (uint32_t) ggml_get_op_params_i32(dst, 1),
+ (uint32_t) ggml_get_op_params_i32(dst, 2),
+ (uint32_t) ggml_get_op_params_i32(dst, 3),
+ (uint32_t) ggml_get_op_params_i32(dst, 4),
+ (uint32_t) ggml_get_op_params_i32(dst, 5),
+ (uint32_t) ggml_get_op_params_i32(dst, 6),
+ (uint32_t) ggml_get_op_params_i32(dst, 7),
+ };
+
+ std::vector<wgpu::BindGroupEntry> entries = {
+ { .binding = 0,
+ .buffer = ggml_webgpu_tensor_buf(src),
+ .offset = ggml_webgpu_tensor_align_offset(ctx, src),
+ .size = ggml_webgpu_tensor_binding_size(ctx, src) },
+ { .binding = 1,
+ .buffer = ggml_webgpu_tensor_buf(dst),
+ .offset = ggml_webgpu_tensor_align_offset(ctx, dst),
+ .size = ggml_webgpu_tensor_binding_size(ctx, dst) }
+ };
+
+ uint32_t wg_x = CEIL_DIV(ne, decisions->wg_size);
+ return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_buf_pool, pipeline, params, entries, wg_x);
+}
+
+static std::optional<webgpu_command> ggml_webgpu_set_rows(webgpu_context & ctx,
+ ggml_tensor * src,
+ ggml_tensor * idx,
+ ggml_tensor * dst) {
+    // Unlike most ops, set_rows may receive empty src or idx tensors, in which case it is a no-op.
+ if (ggml_is_empty(src) || ggml_is_empty(idx)) {
+ return std::nullopt;
+ }
+
+ ggml_webgpu_set_rows_pipeline_key key = { .dst_type = dst->type,
+ .vec4 = src->ne[0] % 4 == 0,
+ .i64_idx = idx->type == GGML_TYPE_I64 };
+
+ ggml_webgpu_set_rows_shader_lib_context shader_lib_ctx = {
+ .key = key, .max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup
+ };
+
+ webgpu_pipeline pipeline;
+ auto it = ctx->set_rows_pipelines.find(key);
+ if (it != ctx->set_rows_pipelines.end()) {
+ pipeline = it->second;
+ } else {
+ ggml_webgpu_processed_shader processed =
+ ggml_webgpu_preprocess_set_rows_shader(ctx->p, wgsl_set_rows, shader_lib_ctx);
+ pipeline =
+ ggml_webgpu_create_pipeline(ctx->global_ctx->device, processed.wgsl.c_str(), processed.variant.c_str());
+ pipeline.context = processed.decisions;
+ ctx->set_rows_pipelines.emplace(key, pipeline);
+ }
+
+ auto * decisions = static_cast<ggml_webgpu_generic_shader_decisions *>(pipeline.context.get());
+
+ std::optional<webgpu_pool_bufs> error_bufs = std::nullopt;
+ if (key.i64_idx) {
+ error_bufs = ctx->set_rows_error_buf_pool.alloc_bufs();
+ if (error_bufs->host_buf.GetMapState() == wgpu::BufferMapState::Mapped) {
+ error_bufs->host_buf.Unmap();
+ }
+ }
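+    // WGSL has no native i64, so 64-bit row indices must be range-checked on the GPU; the
+    // shader is assumed to report out-of-range indices through this error buffer, which the
+    // backend inspects after the submission completes.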
+
+ std::vector<uint32_t> params = {
+ (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src) / ggml_type_size(src->type)),
+ (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, idx) / ggml_type_size(idx->type)),
+ (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, dst) / ggml_type_size(dst->type)),
+ // Convert byte-strides to element-strides
+ (uint32_t) (src->nb[1] / ggml_type_size(src->type)), (uint32_t) (src->nb[2] / ggml_type_size(src->type)),
+ (uint32_t) (src->nb[3] / ggml_type_size(src->type)), (uint32_t) (idx->nb[0] / ggml_type_size(idx->type)),
+ (uint32_t) (idx->nb[1] / ggml_type_size(idx->type)), (uint32_t) (idx->nb[2] / ggml_type_size(idx->type)),
+ (uint32_t) (dst->nb[1] / ggml_type_size(dst->type)), (uint32_t) (dst->nb[2] / ggml_type_size(dst->type)),
+ (uint32_t) (dst->nb[3] / ggml_type_size(dst->type)),
+ // Shape of src
+ (uint32_t) src->ne[0], (uint32_t) src->ne[1], (uint32_t) src->ne[2], (uint32_t) src->ne[3],
+ // Shape of idx
+ (uint32_t) (idx->ne[1]), (uint32_t) (idx->ne[2])
+ };
+
+ std::vector<wgpu::BindGroupEntry> entries = {
+ { .binding = 0,
+ .buffer = ggml_webgpu_tensor_buf(src),
+ .offset = ggml_webgpu_tensor_align_offset(ctx, src),
+ .size = ggml_webgpu_tensor_binding_size(ctx, src) },
+ { .binding = 1,
+ .buffer = ggml_webgpu_tensor_buf(idx),
+ .offset = ggml_webgpu_tensor_align_offset(ctx, idx),
+ .size = ggml_webgpu_tensor_binding_size(ctx, idx) },
+ { .binding = 2,
+ .buffer = ggml_webgpu_tensor_buf(dst),
+ .offset = ggml_webgpu_tensor_align_offset(ctx, dst),
+ .size = ggml_webgpu_tensor_binding_size(ctx, dst) }
+ };
+
+ if (key.i64_idx) {
+ entries.push_back(
+ { .binding = 3, .buffer = error_bufs->dev_buf, .offset = 0, .size = error_bufs->dev_buf.GetSize() });
+ }
+
+ uint32_t threads;
+ if (key.vec4) {
+ threads = (src->ne[1] * src->ne[2] * src->ne[3]) * (src->ne[0] / 4);
+ } else {
+ threads = src->ne[0] * src->ne[1] * src->ne[2] * src->ne[3];
+ }
+ uint32_t wg_x = CEIL_DIV(threads, decisions->wg_size);
+ return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_buf_pool, pipeline, params, entries, wg_x, 1,
+ error_bufs);
+}
+
+static webgpu_command ggml_webgpu_get_rows(webgpu_context & ctx,
+ ggml_tensor * src,
+ ggml_tensor * idx,
+ ggml_tensor * dst) {
+ std::vector<uint32_t> params = {
+ (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src) / ggml_type_size(src->type)),
+ (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, idx) / ggml_type_size(idx->type)),
+ (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, dst) / ggml_type_size(dst->type)),
+ // Convert byte-strides to element-strides
+ (uint32_t) (src->nb[1] / ggml_type_size(src->type)), (uint32_t) (src->nb[2] / ggml_type_size(src->type)),
+ (uint32_t) (src->nb[3] / ggml_type_size(src->type)), (uint32_t) (idx->nb[0] / ggml_type_size(idx->type)),
+ (uint32_t) (idx->nb[1] / ggml_type_size(idx->type)), (uint32_t) (idx->nb[2] / ggml_type_size(idx->type)),
+ (uint32_t) (dst->nb[1] / ggml_type_size(dst->type)), (uint32_t) (dst->nb[2] / ggml_type_size(dst->type)),
+ (uint32_t) (dst->nb[3] / ggml_type_size(dst->type)),
+ // Shape of dst
+ (uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2], (uint32_t) dst->ne[3],
+ // Shape of idx
+ (uint32_t) (idx->ne[1]), (uint32_t) (idx->ne[2])
+ };
+
+ std::vector<wgpu::BindGroupEntry> entries = {
+ { .binding = 0,
+ .buffer = ggml_webgpu_tensor_buf(src),
+ .offset = ggml_webgpu_tensor_align_offset(ctx, src),
+ .size = ggml_webgpu_tensor_binding_size(ctx, src) },
+ { .binding = 1,
+ .buffer = ggml_webgpu_tensor_buf(idx),
+ .offset = ggml_webgpu_tensor_align_offset(ctx, idx),
+ .size = ggml_webgpu_tensor_binding_size(ctx, idx) },
+ { .binding = 2,
+ .buffer = ggml_webgpu_tensor_buf(dst),
+ .offset = ggml_webgpu_tensor_align_offset(ctx, dst),
+ .size = ggml_webgpu_tensor_binding_size(ctx, dst) }
+ };
+
+ uint32_t wg_x = CEIL_DIV(dst->ne[1] * dst->ne[2] * dst->ne[3], WEBGPU_MAX_WG_SIZE);
+
+    uint32_t vectorized = src->type == GGML_TYPE_F32 && dst->ne[0] % 4 == 0;  // vec4 path only for F32 rows divisible by 4
+ webgpu_pipeline pipeline = ctx->get_rows_pipelines[src->type][vectorized];
+ return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_buf_pool, pipeline, params, entries, wg_x);
+}
+
+static webgpu_command ggml_webgpu_mul_mat(webgpu_context & ctx,
+ ggml_tensor * src0,
+ ggml_tensor * src1,
+ ggml_tensor * dst) {
+ std::vector<uint32_t> params = {
+ (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src0) / ggml_type_size(src0->type)),
+ (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src1) / ggml_type_size(src1->type)),
+ (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, dst) / ggml_type_size(dst->type)),
+ (uint32_t) dst->ne[0], // number of rows in result (M, transposed)
+ (uint32_t) dst->ne[1], // number of columns in result (N)
+ (uint32_t) src0->ne[0], // number of columns in src0/src1 (K)
+ (uint32_t) (src0->nb[1] / ggml_type_size(src0->type)), // stride (elements/blocks) of src0 in dimension 1
+ (uint32_t) (src1->nb[1] / ggml_type_size(src1->type)), // stride (elements/blocks) of src1 in dimension 1
+ (uint32_t) (src0->nb[2] / ggml_type_size(src0->type)), // stride (elements/blocks) of src0 in dimension 2
+ (uint32_t) (src1->nb[2] / ggml_type_size(src1->type)), // stride (elements/blocks) of src1 in dimension 2
+ (uint32_t) (src0->nb[3] / ggml_type_size(src0->type)), // stride (elements/blocks) of src0 in dimension 3
+ (uint32_t) (src1->nb[3] / ggml_type_size(src1->type)), // stride (elements/blocks) of src1 in dimension 3
+ (uint32_t) src0->ne[2], // batch size in dimension 2
+ (uint32_t) src0->ne[3], // batch size in dimension 3
+ (uint32_t) (src1->ne[2] / src0->ne[2]), // broadcast in dimension 2
+ (uint32_t) (src1->ne[3] / src0->ne[3]) // broadcast in dimension 3
+ };
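+    // The two broadcast factors assume src1's batch dims are integer multiples of src0's,
+    // e.g. grouped-query attention with 32 query heads over 8 KV heads gives a factor of 4.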
+
+ std::vector<wgpu::BindGroupEntry> entries = {
+ { .binding = 0,
+ .buffer = ggml_webgpu_tensor_buf(src0),
+ .offset = ggml_webgpu_tensor_align_offset(ctx, src0),
+ .size = ggml_webgpu_tensor_binding_size(ctx, src0) },
+ { .binding = 1,
+ .buffer = ggml_webgpu_tensor_buf(src1),
+ .offset = ggml_webgpu_tensor_align_offset(ctx, src1),
+ .size = ggml_webgpu_tensor_binding_size(ctx, src1) },
+ { .binding = 2,
+ .buffer = ggml_webgpu_tensor_buf(dst),
+ .offset = ggml_webgpu_tensor_align_offset(ctx, dst),
+ .size = ggml_webgpu_tensor_binding_size(ctx, dst) },
+ };
+
+ webgpu_pipeline pipeline = ctx->mul_mat_pipelines[src0->type][src1->type][0];
+
+ uint32_t wg_x = CEIL_DIV(dst->ne[0] * dst->ne[1] * dst->ne[2] * dst->ne[3], WEBGPU_MUL_MAT_WG_SIZE);
+ uint32_t wg_y = 1;
+
+ bool use_fast = false;
+ switch (src1->type) {
+ case GGML_TYPE_F16:
+ use_fast = (src0->type == GGML_TYPE_F16);
+ break;
+ case GGML_TYPE_F32:
+ switch (src0->type) {
+ case GGML_TYPE_F32:
+ case GGML_TYPE_F16:
+ case GGML_TYPE_Q4_0:
+ use_fast = true;
+ break;
+ default:
+ break;
+ }
+ break;
+ default:
+ break;
+ }
+
+ if (use_fast) {
+ int vectorized = src0->ne[0] % 4 == 0 && dst->ne[0] % 4 == 0 && dst->ne[1] % 4 == 0;
+ if (dst->ne[1] == 1) {
+            // Vectorized mul_mat_vec is not supported for quantized types, so restrict it
+            // to the float types (GGML_TYPE_F32 == 0, GGML_TYPE_F16 == 1)
+            vectorized = vectorized && (src0->type < 2);
+ pipeline = ctx->mul_mat_vec_pipelines[src0->type][src1->type][vectorized];
+ uint32_t batches = dst->ne[2] * dst->ne[3];
+ uint32_t output_groups = CEIL_DIV(dst->ne[0], WEBGPU_MUL_MAT_VEC_OUTPUTS_PER_WG);
+ uint32_t total_wg = output_groups * batches;
+            uint32_t max_wg = ctx->global_ctx->capabilities.limits.maxComputeWorkgroupsPerDimension;
+            // split the dispatch across x and y so large row counts stay within the per-dimension limit
+            wg_x = std::min(total_wg, max_wg);
+            wg_y = CEIL_DIV(total_wg, wg_x);
+ } else {
+ pipeline = ctx->mul_mat_pipelines[src0->type][src1->type][vectorized];
+ uint32_t wg_m;
+ uint32_t wg_n;
+#ifndef __EMSCRIPTEN__
+ if (ctx->global_ctx->capabilities.supports_subgroup_matrix) {
+ // The total number of subgroups/workgroups needed per matrix.
+ uint32_t wg_m_sg_tile = WEBGPU_MUL_MAT_SUBGROUP_M * WEBGPU_MUL_MAT_SUBGROUP_MATRIX_M *
+ ctx->global_ctx->capabilities.sg_mat_m;
+ wg_m = CEIL_DIV(dst->ne[0], wg_m_sg_tile);
+ uint32_t wg_n_sg_tile = WEBGPU_MUL_MAT_SUBGROUP_N * WEBGPU_MUL_MAT_SUBGROUP_MATRIX_N *
+ ctx->global_ctx->capabilities.sg_mat_n;
+ wg_n = CEIL_DIV(dst->ne[1], wg_n_sg_tile);
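+            // e.g. with WEBGPU_MUL_MAT_SUBGROUP_{M,N} = 2, WEBGPU_MUL_MAT_SUBGROUP_MATRIX_{M,N} = 2
+            // and 8x8 subgroup matrices (illustrative values), each workgroup covers a 32x32 output tile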
+ } else {
+#endif
+ uint32_t tile_m_s = WEBGPU_MUL_MAT_TILE_M * WEBGPU_MUL_MAT_WG_SIZE_M;
+ uint32_t tile_n_s = WEBGPU_MUL_MAT_TILE_N * WEBGPU_MUL_MAT_WG_SIZE_N;
+ wg_m = CEIL_DIV(dst->ne[0], tile_m_s);
+ wg_n = CEIL_DIV(dst->ne[1], tile_n_s);
+#ifndef __EMSCRIPTEN__
+ }
+#endif
+
+ wg_x = wg_m * wg_n * dst->ne[2] * dst->ne[3];
+ }
+ }
+ return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_buf_pool, pipeline, params, entries, wg_x, wg_y);
+}
+
+#ifndef __EMSCRIPTEN__
+static webgpu_command ggml_webgpu_flash_attn(webgpu_context & ctx,
+ ggml_tensor * Q,
+ ggml_tensor * K,
+ ggml_tensor * V,
+ ggml_tensor * mask,
+ ggml_tensor * sinks,
+ ggml_tensor * dst) {
+ float scale = *(float *) dst->op_params;
+ float max_bias;
+ memcpy(&max_bias, (float *) dst->op_params + 1, sizeof(float));
+ float logit_softcap;
+ memcpy(&logit_softcap, (float *) dst->op_params + 2, sizeof(float));
+ if (logit_softcap != 0.0f) {
+ scale /= logit_softcap;
+ }
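+    // With logit softcap, scores are computed as logit_softcap * tanh((scale / logit_softcap) * qk),
+    // so the division by logit_softcap is folded into scale here.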
+ float n_head_log2 = float(1u << (uint32_t) floor(log2(Q->ne[2])));
+ float m0 = powf(2.0f, -(max_bias) / n_head_log2);
+ float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
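+    // ALiBi slope bases: head h is assumed to use slope m0^(h+1) when h < n_head_log2 and
+    // m1^(2*(h - n_head_log2) + 1) otherwise, matching ggml's CPU soft_max handling.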
+
+ const int has_mask = (mask != nullptr);
+ const int has_sinks = (sinks != nullptr);
+
+ std::vector<uint32_t> params = {
+ (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, Q) / ggml_type_size(Q->type)),
+ (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, K) / ggml_type_size(K->type)),
+ (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, V) / ggml_type_size(V->type)),
+ has_mask ? (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, mask) / ggml_type_size(mask->type)) : 0,
+ has_sinks ? (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, sinks) / ggml_type_size(sinks->type)) : 0,
+ (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, dst) / ggml_type_size(dst->type)),
+ (uint32_t) Q->ne[2], // number of heads
+ (uint32_t) Q->ne[1], // sequence length (Q)
+ (uint32_t) K->ne[1], // sequence length (K/V)
+ (uint32_t) (Q->nb[1] / ggml_type_size(Q->type)), // stride (elements/blocks) of Q in dimension 1
+ (uint32_t) (Q->nb[2] / ggml_type_size(Q->type)), // stride (elements/blocks) of Q in dimension 2
+ (uint32_t) (Q->nb[3] / ggml_type_size(Q->type)), // stride (elements/blocks) of Q in dimension 3
+ (uint32_t) (K->nb[1] / ggml_type_size(K->type)), // stride (elements/blocks) of K in dimension 1
+ (uint32_t) (K->nb[2] / ggml_type_size(K->type)), // stride (elements/blocks) of K in dimension 2
+ (uint32_t) (K->nb[3] / ggml_type_size(K->type)), // stride (elements/blocks) of K in dimension 3
+ (uint32_t) (V->nb[1] / ggml_type_size(V->type)), // stride (elements/blocks) of V in dimension 1
+ (uint32_t) (V->nb[2] / ggml_type_size(V->type)), // stride (elements/blocks) of V in dimension 2
+ (uint32_t) (V->nb[3] / ggml_type_size(V->type)), // stride (elements/blocks) of V in dimension 3
+ has_mask ? (uint32_t) (mask->nb[3] / ggml_type_size(mask->type)) : 0, // stride of mask dim 3
+ (uint32_t) (Q->ne[2] / K->ne[2]), // repeat factor for K/V in dim 2 (MHA/MQA/GQA)
+ *(uint32_t *) &scale, // scale (possibly adjusted for logit softcap)
+ *(uint32_t *) &max_bias,
+ *(uint32_t *) &logit_softcap,
+ *(uint32_t *) &n_head_log2,
+ *(uint32_t *) &m0,
+        *(uint32_t *) &m1
+    };
+ std::vector<wgpu::BindGroupEntry> entries = {
+ { .binding = 0,
+ .buffer = ggml_webgpu_tensor_buf(Q),
+ .offset = ggml_webgpu_tensor_align_offset(ctx, Q),
+ .size = ggml_webgpu_tensor_binding_size(ctx, Q) },
+ { .binding = 1,
+ .buffer = ggml_webgpu_tensor_buf(K),
+ .offset = ggml_webgpu_tensor_align_offset(ctx, K),
+ .size = ggml_webgpu_tensor_binding_size(ctx, K) },
+ { .binding = 2,
+ .buffer = ggml_webgpu_tensor_buf(V),
+ .offset = ggml_webgpu_tensor_align_offset(ctx, V),
+ .size = ggml_webgpu_tensor_binding_size(ctx, V) }
+ };
+ uint32_t binding_index = 3;
+ if (has_mask) {
+ entries.push_back({ .binding = binding_index++,
+ .buffer = ggml_webgpu_tensor_buf(mask),
+ .offset = ggml_webgpu_tensor_align_offset(ctx, mask),
+ .size = ggml_webgpu_tensor_binding_size(ctx, mask) });
+ }
+ if (has_sinks) {
+ entries.push_back({ .binding = binding_index++,
+ .buffer = ggml_webgpu_tensor_buf(sinks),
+ .offset = ggml_webgpu_tensor_align_offset(ctx, sinks),
+ .size = ggml_webgpu_tensor_binding_size(ctx, sinks) });
+ }
+ entries.push_back({ .binding = binding_index++,
+ .buffer = ggml_webgpu_tensor_buf(dst),
+ .offset = ggml_webgpu_tensor_align_offset(ctx, dst),
+ .size = ggml_webgpu_tensor_binding_size(ctx, dst) });
+
+ bool kv_direct = (K->type == GGML_TYPE_F16) && (Q->ne[0] % ctx->global_ctx->capabilities.sg_mat_k == 0) &&
+ (K->ne[1] % GGML_WEBGPU_KV_SEQ_PAD == 0);
+
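+    // kv_direct: K/V are F16 and their dimensions line up with the subgroup-matrix tile
+    // sizes, so the shader can presumably load K/V tiles directly without conversion or
+    // sequence padding.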
+ ggml_webgpu_flash_attn_pipeline_key key = {
+ .kv_type = K->type,
+ .head_dim_qk = (uint32_t) Q->ne[0],
+ .head_dim_v = (uint32_t) V->ne[0],
+ .kv_direct = kv_direct,
+ .has_mask = static_cast<bool>(has_mask),
+ .has_sinks = static_cast<bool>(has_sinks),
+ .uses_logit_softcap = logit_softcap != 0.0f,
+ };
+
+ webgpu_pipeline pipeline;
+ auto it = ctx->flash_attn_pipelines.find(key);
+ if (it != ctx->flash_attn_pipelines.end()) {
+ pipeline = it->second;
+ } else {
+ ggml_webgpu_flash_attn_shader_lib_context shader_lib_ctx = {
+ .key = key,
+ .sg_mat_m = ctx->global_ctx->capabilities.sg_mat_m,
+ .sg_mat_n = ctx->global_ctx->capabilities.sg_mat_n,
+ .sg_mat_k = ctx->global_ctx->capabilities.sg_mat_k,
+ .wg_mem_limit_bytes = ctx->global_ctx->capabilities.limits.maxComputeWorkgroupStorageSize,
+ .max_subgroup_size = ctx->global_ctx->capabilities.max_subgroup_size
+ };
+
+ ggml_webgpu_processed_shader processed =
+ ggml_webgpu_preprocess_flash_attn_shader(ctx->p, wgsl_flash_attn, shader_lib_ctx);
+ pipeline =
+ ggml_webgpu_create_pipeline(ctx->global_ctx->device, processed.wgsl.c_str(), processed.variant.c_str());
+ pipeline.context = processed.decisions;
+ ctx->flash_attn_pipelines.emplace(key, pipeline);
+ }
+
+ auto * decisions = static_cast<ggml_webgpu_flash_attn_shader_decisions *>(pipeline.context.get());
+
+ uint32_t wg_per_head = CEIL_DIV(Q->ne[1], decisions->q_tile);
+ uint32_t wg_x = wg_per_head * Q->ne[2] * Q->ne[3]; // wg per head * number of heads * number of batches
+ return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_buf_pool, pipeline, params, entries, wg_x);
+}
+#endif
+
+static webgpu_command ggml_webgpu_unary_op(webgpu_context & ctx, ggml_tensor * src, ggml_tensor * dst) {
+ bool is_unary = dst->op == GGML_OP_UNARY;
+    bool inplace = ggml_webgpu_tensor_equal(src, dst) || (dst->op == GGML_OP_FILL);  // FILL ignores src and writes dst in place
+ int op = is_unary ? (int) ggml_get_unary_op(dst) : dst->op;
+
+ ggml_webgpu_unary_pipeline_key pipeline_key = {
+ .type = dst->type, .op = op, .is_unary = is_unary, .inplace = inplace
+ };
+ ggml_webgpu_unary_shader_lib_context shader_lib_ctx = {
+ .key = pipeline_key, .max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup
+ };
+
+ webgpu_pipeline pipeline;
+ auto it = ctx->unary_pipelines.find(pipeline_key);
+ if (it != ctx->unary_pipelines.end()) {
+ pipeline = it->second;
+ } else {
+ ggml_webgpu_processed_shader processed =
+ ggml_webgpu_preprocess_unary_shader(ctx->p, wgsl_unary, shader_lib_ctx);
+ pipeline =
+ ggml_webgpu_create_pipeline(ctx->global_ctx->device, processed.wgsl.c_str(), processed.variant.c_str());
+ pipeline.context = processed.decisions;
+ ctx->unary_pipelines.emplace(pipeline_key, pipeline);
+ }
+
+ auto * decisions = static_cast<ggml_webgpu_generic_shader_decisions *>(pipeline.context.get());
+
+ uint32_t ne = (uint32_t) ggml_nelements(dst);
+
+ std::vector<uint32_t> params = { ne,
+ (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src) / ggml_type_size(src->type)),
+ (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, dst) / ggml_type_size(dst->type)),
+ (uint32_t) (src->nb[0] / ggml_type_size(src->type)),
+ (uint32_t) (src->nb[1] / ggml_type_size(src->type)),
+ (uint32_t) (src->nb[2] / ggml_type_size(src->type)),
+ (uint32_t) (src->nb[3] / ggml_type_size(src->type)),
+ (uint32_t) src->ne[0],
+ (uint32_t) src->ne[1],
+ (uint32_t) src->ne[2] };
+
+ ggml_tensor * effective_src = src;
+ if (is_unary) {
+ ggml_unary_op unary_op = ggml_get_unary_op(dst);
+ switch (unary_op) {
+ case GGML_UNARY_OP_XIELU:
+ {
+ // Get float parameters and reinterpret their bit patterns as uint32_t
+ // for passing through the params buffer
+ float alpha_n = ggml_get_op_params_f32(dst, 1);
+ float alpha_p = ggml_get_op_params_f32(dst, 2);
+ float beta = ggml_get_op_params_f32(dst, 3);
+ float eps = ggml_get_op_params_f32(dst, 4);
+ params.push_back(*reinterpret_cast<const uint32_t *>(&alpha_n));
+ params.push_back(*reinterpret_cast<const uint32_t *>(&alpha_p));
+ params.push_back(*reinterpret_cast<const uint32_t *>(&beta));
+ params.push_back(*reinterpret_cast<const uint32_t *>(&eps));
+ break;
+ }
+ default:
+ break;
+ }
+ } else if (dst->op == GGML_OP_CLAMP) {
+ float clamp_min = ggml_get_op_params_f32(dst, 0);
+ float clamp_max = ggml_get_op_params_f32(dst, 1);
+ params.push_back(*reinterpret_cast<const uint32_t *>(&clamp_min));
+ params.push_back(*reinterpret_cast<const uint32_t *>(&clamp_max));
+ } else if (dst->op == GGML_OP_FILL) {
+ float fill_val = ggml_get_op_params_f32(dst, 0);
+ params.push_back(*reinterpret_cast<const uint32_t *>(&fill_val));
+ effective_src = dst; // fill simply fills dst
+ }
+
+ std::vector<wgpu::BindGroupEntry> entries = {
+ { .binding = 0,
+ .buffer = ggml_webgpu_tensor_buf(effective_src),
+ .offset = ggml_webgpu_tensor_align_offset(ctx, effective_src),
+ .size = ggml_webgpu_tensor_binding_size(ctx, effective_src) },
+ };
+ if (!inplace) {
+ entries.push_back({ .binding = 1,
+ .buffer = ggml_webgpu_tensor_buf(dst),
+ .offset = ggml_webgpu_tensor_align_offset(ctx, dst),
+ .size = ggml_webgpu_tensor_binding_size(ctx, dst) });
+ }
+
+ uint32_t wg_x = CEIL_DIV(ne, decisions->wg_size);
+ return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_buf_pool, pipeline, params, entries, wg_x);
+}
+
+static webgpu_command ggml_webgpu_binary_op(webgpu_context & ctx,
+ ggml_tensor * src0,
+ ggml_tensor * src1,
+ ggml_tensor * dst) {
+ binary_overlap_flags flags = ggml_webgpu_detect_binary_overlap(src0, src1, dst);
+
+ ggml_webgpu_binary_pipeline_key pipeline_key = {
+ .type = dst->type,
+ .op = dst->op,
+ .inplace = flags.inplace,
+ .overlap = flags.overlap,
+ };
+ ggml_webgpu_binary_shader_lib_context shader_lib_ctx = {
+ .key = pipeline_key, .max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup
+ };
+
+ webgpu_pipeline pipeline;
+ auto it = ctx->binary_pipelines.find(pipeline_key);
+ if (it != ctx->binary_pipelines.end()) {
+ pipeline = it->second;
+ } else {
+ ggml_webgpu_processed_shader processed =
+ ggml_webgpu_preprocess_binary_shader(ctx->p, wgsl_binary, shader_lib_ctx);
+ pipeline =
+ ggml_webgpu_create_pipeline(ctx->global_ctx->device, processed.wgsl.c_str(), processed.variant.c_str());
+ pipeline.context = processed.decisions;
+ ctx->binary_pipelines.emplace(pipeline_key, pipeline);
+ }
+
+    auto * decisions = static_cast<ggml_webgpu_generic_shader_decisions *>(pipeline.context.get());
+
+ uint32_t ne = (uint32_t) ggml_nelements(dst);
+
+ std::vector<uint32_t> params = {
+ ne,
+ (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src0) / ggml_type_size(src0->type)),
+ (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src1) / ggml_type_size(src1->type)),
+ (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, dst) / ggml_type_size(dst->type)),
+ (uint32_t) (src1->nb[0] / ggml_type_size(src1->type)),
+ (uint32_t) (src1->nb[1] / ggml_type_size(src1->type)),
+ (uint32_t) (src1->nb[2] / ggml_type_size(src1->type)),
+ (uint32_t) (src1->nb[3] / ggml_type_size(src1->type)),
+ (uint32_t) src0->ne[0],
+ (uint32_t) src0->ne[1],
+ (uint32_t) src0->ne[2],
+ (uint32_t) src1->ne[0],
+ (uint32_t) src1->ne[1],
+ (uint32_t) src1->ne[2],
+ (uint32_t) src1->ne[3],
+ };
+
+ std::vector<wgpu::BindGroupEntry> entries;
+
+ entries.push_back({
+ .binding = 0,
+ .buffer = ggml_webgpu_tensor_buf(src0),
+ .offset = ggml_webgpu_tensor_align_offset(ctx, src0),
+ .size = ggml_webgpu_tensor_binding_size(ctx, src0),
+ });
+
+ entries.push_back({
+ .binding = 1,
+ .buffer = ggml_webgpu_tensor_buf(src1),
+ .offset = ggml_webgpu_tensor_align_offset(ctx, src1),
+ .size = ggml_webgpu_tensor_binding_size(ctx, src1),
+ });
+
+ if (!flags.inplace && !flags.overlap) {
+ entries.push_back({ .binding = 2,
+ .buffer = ggml_webgpu_tensor_buf(dst),
+ .offset = ggml_webgpu_tensor_align_offset(ctx, dst),
+ .size = ggml_webgpu_tensor_binding_size(ctx, dst) });
+ }
+
+ uint32_t wg_x = CEIL_DIV(ne, decisions->wg_size);
+ return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_buf_pool, pipeline, params, entries, wg_x);
+}
+
+static webgpu_command ggml_webgpu_rms_norm(webgpu_context & ctx, ggml_tensor * src, ggml_tensor * dst) {
+ int inplace = ggml_webgpu_tensor_equal(src, dst);
+
+ std::vector<uint32_t> params = {
+ (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src) / ggml_type_size(src->type)),
+ (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, dst) / ggml_type_size(dst->type)),
+ (uint32_t) (src->nb[1] / ggml_type_size(src->type)),
+ (uint32_t) (src->nb[2] / ggml_type_size(src->type)),
+ (uint32_t) (src->nb[3] / ggml_type_size(src->type)),
+ (uint32_t) (dst->nb[1] / ggml_type_size(dst->type)),
+ (uint32_t) (dst->nb[2] / ggml_type_size(dst->type)),
+ (uint32_t) (dst->nb[3] / ggml_type_size(dst->type)),
+ (uint32_t) src->ne[0],
+ (uint32_t) src->ne[1],
+ (uint32_t) src->ne[2],
+ (uint32_t) src->ne[3],
+ *(uint32_t *) dst->op_params // epsilon, treated as f32 in the shader
+ };
+
+ std::vector<wgpu::BindGroupEntry> entries = {
+ { .binding = 0,
+ .buffer = ggml_webgpu_tensor_buf(src),
+ .offset = ggml_webgpu_tensor_align_offset(ctx, src),
+ .size = ggml_webgpu_tensor_binding_size(ctx, src) }
+ };
+ if (!inplace) {
+ entries.push_back({ .binding = 1,
+ .buffer = ggml_webgpu_tensor_buf(dst),
+ .offset = ggml_webgpu_tensor_align_offset(ctx, dst),
+ .size = ggml_webgpu_tensor_binding_size(ctx, dst) });
+ }
+
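+    // one workgroup per row; the shader is expected to reduce over ne0 within the workgroup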
+ return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_buf_pool, ctx->rms_norm_pipelines[inplace], params,
+ entries, ggml_nrows(src));
+}
+
+static webgpu_command ggml_webgpu_rope(webgpu_context & ctx,
+ ggml_tensor * src0,
+ ggml_tensor * src1,
+ ggml_tensor * src2,
+ ggml_tensor * dst) {
+ const int inplace = ggml_webgpu_tensor_equal(src0, dst);
+ const int has_freq_factor = (src2 != nullptr);
+
+ const int n_dims = ((int32_t *) dst->op_params)[1];
+ const int mode = ((int32_t *) dst->op_params)[2];
+ const int n_ctx_orig = ((int32_t *) dst->op_params)[4];
+
+ float freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow;
+ memcpy(&freq_base, (int32_t *) dst->op_params + 5, sizeof(float));
+ memcpy(&freq_scale, (int32_t *) dst->op_params + 6, sizeof(float));
+ memcpy(&ext_factor, (int32_t *) dst->op_params + 7, sizeof(float));
+ memcpy(&attn_factor, (int32_t *) dst->op_params + 8, sizeof(float));
+ memcpy(&beta_fast, (int32_t *) dst->op_params + 9, sizeof(float));
+ memcpy(&beta_slow, (int32_t *) dst->op_params + 10, sizeof(float));
+
+ int sections[4];
+ memcpy(sections, (int32_t *) dst->op_params + 11, 4 * sizeof(int));
+
+ float theta_scale = powf(freq_base, -2.0f / n_dims);
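+    // standard RoPE frequency decay: rotation pair i uses angular frequency
+    // freq_base^(-2i/n_dims) = theta_scale^i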
+
+ float corr_dims[2];
+ ggml_rope_yarn_corr_dims(n_dims, n_ctx_orig, freq_base, beta_fast, beta_slow, corr_dims);
+
+ std::vector<uint32_t> params = {
+ (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src0) / ggml_type_size(src0->type)),
+ (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src1) / ggml_type_size(src1->type)),
+ src2 != nullptr ? (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src2) / ggml_type_size(src2->type)) : 0,
+ (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, dst) / ggml_type_size(dst->type)),
+ (uint32_t) (src0->nb[1] / ggml_type_size(src0->type)),
+ (uint32_t) (src0->nb[2] / ggml_type_size(src0->type)),
+ (uint32_t) (src0->nb[3] / ggml_type_size(src0->type)),
+ (uint32_t) (dst->nb[1] / ggml_type_size(dst->type)),
+ (uint32_t) (dst->nb[2] / ggml_type_size(dst->type)),
+ (uint32_t) (dst->nb[3] / ggml_type_size(dst->type)),
+        (uint32_t) ggml_nelements(src0) / 2,  // number of element pairs to rotate
+ (uint32_t) src0->ne[0],
+ (uint32_t) src0->ne[1],
+ (uint32_t) src0->ne[2],
+ (uint32_t) n_dims,
+ (uint32_t) mode,
+ *(uint32_t *) &theta_scale,
+ *(uint32_t *) &attn_factor,
+ *(uint32_t *) &freq_scale,
+ *(uint32_t *) &ext_factor,
+ *(uint32_t *) &corr_dims[0],
+ *(uint32_t *) &corr_dims[1],
+ (uint32_t) sections[0],
+ (uint32_t) sections[1],
+ (uint32_t) sections[2],
+ (uint32_t) sections[3]
+ };
+
+ std::vector<wgpu::BindGroupEntry> entries = {
+ { .binding = 0,
+ .buffer = ggml_webgpu_tensor_buf(src0),
+ .offset = ggml_webgpu_tensor_align_offset(ctx, src0),
+ .size = ggml_webgpu_tensor_binding_size(ctx, src0) },
+ { .binding = 1,
+ .buffer = ggml_webgpu_tensor_buf(src1),
+ .offset = ggml_webgpu_tensor_align_offset(ctx, src1),
+ .size = ggml_webgpu_tensor_binding_size(ctx, src1) }
+ };
+ uint32_t dst_binding = 2;
+ if (has_freq_factor) {
+ dst_binding = 3;
+ entries.push_back({ .binding = 2,
+ .buffer = ggml_webgpu_tensor_buf(src2),
+ .offset = ggml_webgpu_tensor_align_offset(ctx, src2),
+ .size = ggml_webgpu_tensor_binding_size(ctx, src2) });
+ }
+ if (!inplace) {
+ entries.push_back({ .binding = dst_binding,
+ .buffer = ggml_webgpu_tensor_buf(dst),
+ .offset = ggml_webgpu_tensor_align_offset(ctx, dst),
+ .size = ggml_webgpu_tensor_binding_size(ctx, dst) });
+ }
+
+ webgpu_pipeline pipeline = ctx->rope_pipelines[dst->type][has_freq_factor][inplace];
+ uint32_t wg_x = CEIL_DIV(ggml_nelements(dst), WEBGPU_MAX_WG_SIZE);
+ return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_buf_pool, pipeline, params, entries, wg_x);
+}
+
+static webgpu_command ggml_webgpu_glu(webgpu_context & ctx, ggml_tensor * src0, ggml_tensor * src1, ggml_tensor * dst) {
+ const int split = (src1 != nullptr);
+
+ std::vector<uint32_t> params = {
+ (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src0) / ggml_type_size(src0->type)),
+ src1 != nullptr ? (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src1) / ggml_type_size(src1->type)) : 0,
+ (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, dst) / ggml_type_size(dst->type)),
+ (uint32_t) (src0->nb[1] / ggml_type_size(src0->type)),
+ (uint32_t) (src0->nb[2] / ggml_type_size(src0->type)),
+ (uint32_t) (src0->nb[3] / ggml_type_size(src0->type)),
+ src1 != nullptr ? (uint32_t) (src1->nb[1] / ggml_type_size(src1->type)) :
+ (uint32_t) (src0->nb[1] / ggml_type_size(src0->type)),
+ src1 != nullptr ? (uint32_t) (src1->nb[2] / ggml_type_size(src1->type)) :
+ (uint32_t) (src0->nb[2] / ggml_type_size(src0->type)),
+ src1 != nullptr ? (uint32_t) (src1->nb[3] / ggml_type_size(src1->type)) :
+ (uint32_t) (src0->nb[3] / ggml_type_size(src0->type)),
+ (uint32_t) (dst->nb[1] / ggml_type_size(dst->type)),
+ (uint32_t) (dst->nb[2] / ggml_type_size(dst->type)),
+ (uint32_t) (dst->nb[3] / ggml_type_size(dst->type)),
+ (uint32_t) ggml_nelements(dst),
+ (uint32_t) dst->ne[0],
+ (uint32_t) dst->ne[1],
+ (uint32_t) dst->ne[2],
+ (uint32_t) ((int32_t *) dst->op_params)[1], // swapped
+ *(uint32_t *) &dst->op_params[2], // alpha, for swiglu_oai
+ *(uint32_t *) &dst->op_params[3], // limit, for swiglu_oai
+ };
+
+ std::vector<wgpu::BindGroupEntry> entries = {
+ { .binding = 0,
+ .buffer = ggml_webgpu_tensor_buf(src0),
+ .offset = ggml_webgpu_tensor_align_offset(ctx, src0),
+ .size = ggml_webgpu_tensor_binding_size(ctx, src0) },
+ };
+ uint32_t dst_binding = 1;
+ if (split) {
+ dst_binding = 2;
+ entries.push_back({ .binding = 1,
+ .buffer = ggml_webgpu_tensor_buf(src1),
+ .offset = ggml_webgpu_tensor_align_offset(ctx, src1),
+ .size = ggml_webgpu_tensor_binding_size(ctx, src1) });
+ }
+ entries.push_back({ .binding = dst_binding,
+ .buffer = ggml_webgpu_tensor_buf(dst),
+ .offset = ggml_webgpu_tensor_align_offset(ctx, dst),
+ .size = ggml_webgpu_tensor_binding_size(ctx, dst) });
+
+ webgpu_pipeline pipeline = ctx->glu_pipelines[ggml_get_glu_op(dst)][dst->type][split];
+ uint32_t wg_x = CEIL_DIV(ggml_nelements(dst), WEBGPU_MAX_WG_SIZE);
+ return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_buf_pool, pipeline, params, entries, wg_x);
+}
+
+static webgpu_command ggml_webgpu_scale(webgpu_context & ctx, ggml_tensor * src, ggml_tensor * dst) {
+ int inplace = ggml_webgpu_tensor_equal(src, dst);
+
+ std::vector<uint32_t> params = {
+ (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src) / ggml_type_size(src->type)),
+ (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, dst) / ggml_type_size(dst->type)),
+ (uint32_t) (src->nb[1] / ggml_type_size(src->type)),
+ (uint32_t) (src->nb[2] / ggml_type_size(src->type)),
+ (uint32_t) (src->nb[3] / ggml_type_size(src->type)),
+ (uint32_t) (dst->nb[1] / ggml_type_size(dst->type)),
+ (uint32_t) (dst->nb[2] / ggml_type_size(dst->type)),
+ (uint32_t) (dst->nb[3] / ggml_type_size(dst->type)),
+ (uint32_t) ggml_nelements(dst),
+ (uint32_t) src->ne[0],
+ (uint32_t) src->ne[1],
+ (uint32_t) src->ne[2],
+ *(uint32_t *) dst->op_params, // scale
+ *(uint32_t *) &dst->op_params[1] // bias
+ };
+
+ std::vector<wgpu::BindGroupEntry> entries = {
+ { .binding = 0,
+ .buffer = ggml_webgpu_tensor_buf(src),
+ .offset = ggml_webgpu_tensor_align_offset(ctx, src),
+ .size = ggml_webgpu_tensor_binding_size(ctx, src) }
+ };
+ if (!inplace) {
+ entries.push_back({ .binding = 1,
+ .buffer = ggml_webgpu_tensor_buf(dst),
+ .offset = ggml_webgpu_tensor_align_offset(ctx, dst),
+ .size = ggml_webgpu_tensor_binding_size(ctx, dst) });
+ }
+
+ uint32_t wg_x = CEIL_DIV(ggml_nelements(dst), WEBGPU_MAX_WG_SIZE);
+ return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_buf_pool, ctx->scale_pipelines[inplace], params,
+ entries, wg_x);
+}
+
+static webgpu_command ggml_webgpu_soft_max(webgpu_context & ctx,
+ ggml_tensor * src0,
+ ggml_tensor * src1,
+ ggml_tensor * src2,
+ ggml_tensor * dst) {
+ const int inplace = ggml_webgpu_tensor_equal(src0, dst);
+    const int mask_type = (src1 != nullptr) ? src1->type : 2;  // 0 = F32 mask, 1 = F16 mask, 2 = no mask
+ const int has_sink = (src2 != nullptr);
+ float max_bias;
+ memcpy(&max_bias, (float *) dst->op_params + 1, sizeof(float));
+ float n_head_log2 = float(1u << (uint32_t) floor(log2(src0->ne[2])));
+ float m0 = powf(2.0f, -(max_bias) / n_head_log2);
+ float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
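+    // ALiBi slope bases; the shader derives per-head slopes from m0/m1 as in the
+    // flash-attention path above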
+
+ std::vector<uint32_t> params = {
+ (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src0) / ggml_type_size(src0->type)),
+ mask_type < 2 ? (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src1) / ggml_type_size(src1->type)) : 0,
+ has_sink ? (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src2) / ggml_type_size(src2->type)) : 0,
+ (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, dst) / ggml_type_size(dst->type)),
+ (uint32_t) (src0->nb[1] / ggml_type_size(src0->type)),
+ (uint32_t) (src0->nb[2] / ggml_type_size(src0->type)),
+ (uint32_t) (src0->nb[3] / ggml_type_size(src0->type)),
+ mask_type < 2 ? (uint32_t) (src1->nb[1] / ggml_type_size(src1->type)) : 0,
+ mask_type < 2 ? (uint32_t) (src1->nb[2] / ggml_type_size(src1->type)) : 0,
+ mask_type < 2 ? (uint32_t) (src1->nb[3] / ggml_type_size(src1->type)) : 0,
+ (uint32_t) (dst->nb[1] / ggml_type_size(dst->type)),
+ (uint32_t) (dst->nb[2] / ggml_type_size(dst->type)),
+ (uint32_t) (dst->nb[3] / ggml_type_size(dst->type)),
+ (uint32_t) ggml_nelements(dst),
+ (uint32_t) src0->ne[0],
+ (uint32_t) src0->ne[1],
+ (uint32_t) src0->ne[2],
+ mask_type < 2 ? (uint32_t) src1->ne[2] : 0,
+ mask_type < 2 ? (uint32_t) src1->ne[3] : 0,
+ *(uint32_t *) dst->op_params, // scale
+ *(uint32_t *) &max_bias,
+ *(uint32_t *) &n_head_log2,
+ *(uint32_t *) &m0,
+ *(uint32_t *) &m1
+ };
+
+ std::vector<wgpu::BindGroupEntry> entries = {
+ { .binding = 0,
+ .buffer = ggml_webgpu_tensor_buf(src0),
+ .offset = ggml_webgpu_tensor_align_offset(ctx, src0),
+ .size = ggml_webgpu_tensor_binding_size(ctx, src0) }
+ };
+ uint32_t binding_num = 1;
+ if (mask_type < 2) {
+ entries.push_back({ .binding = binding_num,
+ .buffer = ggml_webgpu_tensor_buf(src1),
+ .offset = ggml_webgpu_tensor_align_offset(ctx, src1),
+ .size = ggml_webgpu_tensor_binding_size(ctx, src1) });
+ binding_num++;
+ }
+ if (has_sink) {
+ entries.push_back({ .binding = binding_num,
+ .buffer = ggml_webgpu_tensor_buf(src2),
+ .offset = ggml_webgpu_tensor_align_offset(ctx, src2),
+ .size = ggml_webgpu_tensor_binding_size(ctx, src2) });
+ binding_num++;
+ }
+ if (!inplace) {
+ entries.push_back({ .binding = binding_num,
+ .buffer = ggml_webgpu_tensor_buf(dst),
+ .offset = ggml_webgpu_tensor_align_offset(ctx, dst),
+ .size = ggml_webgpu_tensor_binding_size(ctx, dst) });
+ }
+
+ return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_buf_pool,
+ ctx->soft_max_pipelines[mask_type][has_sink][inplace], params, entries,
+ ggml_nrows(dst));
+}
+
+static webgpu_command ggml_webgpu_argmax(webgpu_context & ctx, ggml_tensor * src, ggml_tensor * dst) {
+ std::vector<uint32_t> params = { (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src) / ggml_type_size(src->type)),
+ (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, dst) / ggml_type_size(dst->type)),
+ (uint32_t) src->ne[0] };
+
+ std::vector<wgpu::BindGroupEntry> entries = {
+ { .binding = 0,
+ .buffer = ggml_webgpu_tensor_buf(src),
+ .offset = ggml_webgpu_tensor_align_offset(ctx, src),
+ .size = ggml_webgpu_tensor_binding_size(ctx, src) },
+ { .binding = 1,
+ .buffer = ggml_webgpu_tensor_buf(dst),
+ .offset = ggml_webgpu_tensor_align_offset(ctx, dst),
+ .size = ggml_webgpu_tensor_binding_size(ctx, dst) }
+ };
+
+ ggml_webgpu_generic_shader_lib_context shader_lib_ctx = {
+ .vec4 = src->ne[0] % 4 == 0,
+ .max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup,
+ };
+
+ webgpu_pipeline pipeline;
+ auto it = ctx->argmax_pipelines.find(shader_lib_ctx.vec4);
+ if (it != ctx->argmax_pipelines.end()) {
+ pipeline = it->second;
+ } else {
+ ggml_webgpu_processed_shader processed =
+ ggml_webgpu_preprocess_generic_shader(ctx->p, wgsl_argmax, shader_lib_ctx, "argmax");
+ pipeline =
+ ggml_webgpu_create_pipeline(ctx->global_ctx->device, processed.wgsl.c_str(), processed.variant.c_str());
+ ctx->argmax_pipelines.emplace(shader_lib_ctx.vec4, pipeline);
+ }
+    uint32_t wg_x = ggml_nelements(dst);  // one workgroup per reduced row of src
+ return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_buf_pool, pipeline, params, entries, wg_x);
+}
+
+static webgpu_command ggml_webgpu_argsort(webgpu_context & ctx, ggml_tensor * src, ggml_tensor * dst) {
+ bool is_top_k = dst->op == GGML_OP_TOP_K;
+ // ascending order is 0, descending order is 1
+ const int32_t order = is_top_k ? (int32_t) GGML_SORT_ORDER_DESC : (int32_t) ggml_get_op_params_i32(dst, 0);
+
+ ggml_webgpu_argsort_shader_lib_context shader_lib_ctx = {
+ .max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup,
+ .wg_mem_limit_bytes = ctx->global_ctx->capabilities.limits.maxComputeWorkgroupStorageSize,
+ .order = order
+ };
+
+ webgpu_pipeline argsort_pipeline;
+ auto it = ctx->argsort_pipelines.find(order);
+ if (it != ctx->argsort_pipelines.end()) {
+ argsort_pipeline = it->second;
+ } else {
+ ggml_webgpu_processed_shader processed =
+ ggml_webgpu_preprocess_argsort_shader(ctx->p, wgsl_argsort, shader_lib_ctx);
+ argsort_pipeline =
+ ggml_webgpu_create_pipeline(ctx->global_ctx->device, processed.wgsl.c_str(), processed.variant.c_str());
+ argsort_pipeline.context = processed.decisions;
+ ctx->argsort_pipelines.emplace(order, argsort_pipeline);
+ }
+ auto * argsort_decisions = static_cast<ggml_webgpu_argsort_shader_decisions *>(argsort_pipeline.context.get());
+
+ webgpu_pipeline argsort_merge_pipeline;
+ it = ctx->argsort_merge_pipelines.find(order);
+ if (it != ctx->argsort_merge_pipelines.end()) {
+ argsort_merge_pipeline = it->second;
+ } else {
+ ggml_webgpu_processed_shader processed =
+ ggml_webgpu_preprocess_argsort_merge_shader(ctx->p, wgsl_argsort_merge, shader_lib_ctx);
+ argsort_merge_pipeline =
+ ggml_webgpu_create_pipeline(ctx->global_ctx->device, processed.wgsl.c_str(), processed.variant.c_str());
+ argsort_merge_pipeline.context = processed.decisions;
+ ctx->argsort_merge_pipelines.emplace(order, argsort_merge_pipeline);
+ }
+
+ const uint32_t src_ne0 = (uint32_t) src->ne[0];
+ const uint32_t nrows = (uint32_t) ggml_nrows(src);
+    const uint32_t npr = CEIL_DIV(src_ne0, argsort_decisions->wg_size);  // number of sorted tiles (parts) per row
+ const uint32_t block_size =
+ is_top_k ? std::min(argsort_decisions->wg_size, (uint32_t) dst->ne[0]) : argsort_decisions->wg_size;
+ uint32_t out_ne0 = src_ne0;
+ if (is_top_k) {
+ if (npr > 1) {
+ const uint32_t last_tile = src_ne0 - (npr - 1) * argsort_decisions->wg_size;
+ out_ne0 = (npr - 1) * block_size + std::min(last_tile, block_size);
+ } else {
+ out_ne0 = block_size;
+ }
+ }
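+    // For top_k, each sorted tile keeps only its best block_size = min(wg_size, k)
+    // candidates, so out_ne0 is the total kept across all npr tiles. For example,
+    // src_ne0 = 1000, wg_size = 256, k = 100 -> npr = 4 tiles; the three full tiles and the
+    // final 232-element tile each keep 100, so out_ne0 = 400.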
+
+ uint32_t merge_len = block_size;
+ uint32_t merge_passes = 0;
+ while (merge_len < out_ne0) {
+ merge_len <<= 1;
+ merge_passes++;
+ }
+
+ const bool start_in_tmp = (merge_passes % 2) == 1;
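+    // Each merge pass doubles the sorted run length until it covers out_ne0. With an odd
+    // number of passes, the initial sorted tiles are written to the scratch region so the
+    // final pass lands in dst; the two regions ping-pong between passes.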
+
+ const size_t dst_offset = ggml_webgpu_tensor_offset(dst);
+ const size_t idx_nbytes = out_ne0 * ggml_nrows(dst) * sizeof(int32_t);
+ const size_t tmp_offset =
+ ROUNDUP_POW2(dst_offset + idx_nbytes, ctx->global_ctx->capabilities.limits.minStorageBufferOffsetAlignment);
+ const size_t tmp_binding_size = ROUNDUP_POW2(idx_nbytes, WEBGPU_STORAGE_BUF_BINDING_MULT);
+ const size_t dst_binding_size =
+ ROUNDUP_POW2(idx_nbytes + ggml_webgpu_tensor_misalignment(ctx, dst), WEBGPU_STORAGE_BUF_BINDING_MULT);
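+    // The scratch region for the ping-pong merge passes is carved out of the dst buffer
+    // just past the index data (the dst allocation is assumed to have enough slack).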
+
+ const uint32_t offset_src = (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src) / ggml_type_size(src->type));
+ const uint32_t offset_dst = (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, dst) / ggml_type_size(dst->type));
+ const uint32_t offset_tmp = 0;
+ const uint32_t stride_src1 = (uint32_t) (src->nb[1] / ggml_type_size(src->type));
+ const uint32_t stride_src2 = (uint32_t) (src->nb[2] / ggml_type_size(src->type));
+ const uint32_t stride_src3 = (uint32_t) (src->nb[3] / ggml_type_size(src->type));
+ const uint32_t stride_idx1 = out_ne0;
+ const uint32_t stride_idx2 = out_ne0 * (uint32_t) dst->ne[1];
+ const uint32_t stride_idx3 = stride_idx2 * (uint32_t) dst->ne[2];
+
+ std::vector<webgpu_pipeline> pipelines;
+ std::vector<std::vector<uint32_t>> params_list;
+ std::vector<std::vector<wgpu::BindGroupEntry>> entries_list;
+ std::vector<std::pair<uint32_t, uint32_t>> workgroups_list;
+
+ const uint32_t init_offset = start_in_tmp ? offset_tmp : offset_dst;
+ const size_t init_align_offset = start_in_tmp ? tmp_offset : ggml_webgpu_tensor_align_offset(ctx, dst);
+ const size_t init_binding_size = start_in_tmp ? tmp_binding_size : dst_binding_size;
+
+ std::vector<uint32_t> init_params = {
+ offset_src, init_offset, stride_src1, stride_src2, stride_src3, stride_idx1,
+ stride_idx2, stride_idx3, src_ne0, (uint32_t) src->ne[1], (uint32_t) src->ne[2], out_ne0,
+ block_size, npr, nrows
+ };
+
+ const uint32_t total_wg_init = npr * nrows;
+ const uint32_t max_wg = ctx->global_ctx->capabilities.limits.maxComputeWorkgroupsPerDimension;
+ const uint32_t wg_x_init = std::min(total_wg_init, max_wg);
+ const uint32_t wg_y_init = CEIL_DIV(total_wg_init, wg_x_init);
+ std::vector<wgpu::BindGroupEntry> init_entries = {
+ { .binding = 0,
+ .buffer = ggml_webgpu_tensor_buf(src),
+ .offset = ggml_webgpu_tensor_align_offset(ctx, src),
+ .size = ggml_webgpu_tensor_binding_size(ctx, src) },
+ { .binding = 1, .buffer = ggml_webgpu_tensor_buf(dst), .offset = init_align_offset, .size = init_binding_size }
+ };
+
+ pipelines.push_back(argsort_pipeline);
+ params_list.push_back(std::move(init_params));
+ entries_list.push_back(std::move(init_entries));
+ workgroups_list.push_back({ wg_x_init, wg_y_init });
+
+ if (merge_passes == 0) {
+ return ggml_backend_webgpu_build_multi(ctx->global_ctx, ctx->param_buf_pool, pipelines, params_list,
+ entries_list, workgroups_list);
+ }
+
+ bool in_is_tmp = start_in_tmp;
+ uint32_t len = block_size;
+ while (len < out_ne0) {
+        const uint32_t nm = CEIL_DIV(out_ne0, 2 * len);  // number of run pairs merged this pass
+
+ const bool out_is_tmp = !in_is_tmp;
+ const uint32_t offset_in = in_is_tmp ? offset_tmp : offset_dst;
+ const uint32_t offset_out = out_is_tmp ? offset_tmp : offset_dst;
+ const size_t align_in = in_is_tmp ? tmp_offset : ggml_webgpu_tensor_align_offset(ctx, dst);
+ const size_t align_out = out_is_tmp ? tmp_offset : ggml_webgpu_tensor_align_offset(ctx, dst);
+ const size_t size_in = in_is_tmp ? tmp_binding_size : dst_binding_size;
+ const size_t size_out = out_is_tmp ? tmp_binding_size : dst_binding_size;
+ const uint32_t top_k_out = (is_top_k && nm == 1) ? (uint32_t) dst->ne[0] : out_ne0;
+ const uint32_t stride_out1 = top_k_out;
+ const uint32_t stride_out2 = top_k_out * (uint32_t) dst->ne[1];
+ const uint32_t stride_out3 = stride_out2 * (uint32_t) dst->ne[2];
+
+ std::vector<uint32_t> merge_params = { offset_src,
+ offset_in,
+ offset_out,
+ stride_src1,
+ stride_src2,
+ stride_src3,
+ stride_idx1,
+ stride_idx2,
+ stride_idx3,
+ stride_out1,
+ stride_out2,
+ stride_out3,
+ out_ne0,
+ (uint32_t) src->ne[1],
+ (uint32_t) src->ne[2],
+ top_k_out,
+ len,
+ nm,
+ nrows };
+
+ std::vector<wgpu::BindGroupEntry> merge_entries = {
+ { .binding = 0,
+ .buffer = ggml_webgpu_tensor_buf(src),
+ .offset = ggml_webgpu_tensor_align_offset(ctx, src),
+ .size = ggml_webgpu_tensor_binding_size(ctx, src) },
+ { .binding = 1, .buffer = ggml_webgpu_tensor_buf(dst), .offset = align_in, .size = size_in },
+ { .binding = 2, .buffer = ggml_webgpu_tensor_buf(dst), .offset = align_out, .size = size_out }
+ };
+
+ const uint32_t total_wg_merge = nm * nrows;
+ const uint32_t wg_x_merge = std::min(total_wg_merge, max_wg);
+ const uint32_t wg_y_merge = CEIL_DIV(total_wg_merge, wg_x_merge);
+ workgroups_list.push_back({ wg_x_merge, wg_y_merge });
+ pipelines.push_back(argsort_merge_pipeline);
+ params_list.push_back(std::move(merge_params));
+ entries_list.push_back(std::move(merge_entries));
+
+ len <<= 1;
+ in_is_tmp = !in_is_tmp;
+ }
+
+ return ggml_backend_webgpu_build_multi(ctx->global_ctx, ctx->param_buf_pool, pipelines, params_list, entries_list,
+ workgroups_list);
+}
+
+static webgpu_command ggml_webgpu_cumsum(webgpu_context & ctx, ggml_tensor * src, ggml_tensor * dst) {
+ std::vector<uint32_t> params = { (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src) / ggml_type_size(src->type)),
+ (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, dst) / ggml_type_size(dst->type)),
+ (uint32_t) src->ne[0] };
+
+ std::vector<wgpu::BindGroupEntry> entries = {
+ { .binding = 0,
+ .buffer = ggml_webgpu_tensor_buf(src),
+ .offset = ggml_webgpu_tensor_align_offset(ctx, src),
+ .size = ggml_webgpu_tensor_binding_size(ctx, src) },
+ { .binding = 1,
+ .buffer = ggml_webgpu_tensor_buf(dst),
+ .offset = ggml_webgpu_tensor_align_offset(ctx, dst),
+ .size = ggml_webgpu_tensor_binding_size(ctx, dst) }
+ };
+
+ ggml_webgpu_generic_shader_lib_context shader_lib_ctx = {
+ .vec4 = false,
+ .max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup,
+ };
+ webgpu_pipeline pipeline;
+    auto it = ctx->cumsum_pipelines.find(1);  // only one variant exists; the cache is keyed by a constant
+ if (it != ctx->cumsum_pipelines.end()) {
+ pipeline = it->second;
+ } else {
+ ggml_webgpu_processed_shader processed =
+ ggml_webgpu_preprocess_generic_shader(ctx->p, wgsl_cumsum, shader_lib_ctx, "cumsum");
+ pipeline =
+ ggml_webgpu_create_pipeline(ctx->global_ctx->device, processed.wgsl.c_str(), processed.variant.c_str());
+ ctx->cumsum_pipelines.emplace(1, pipeline);
+ }
+ uint32_t wg_x = ggml_nrows(dst);
+ return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_buf_pool, pipeline, params, entries, wg_x);
+}
+
+static webgpu_command ggml_webgpu_sum_rows(webgpu_context & ctx, ggml_tensor * src, ggml_tensor * dst) {
+ bool total_sum = dst->op == GGML_OP_SUM;
+ std::vector<uint32_t> params = { (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src) / ggml_type_size(src->type)),
+ (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, dst) / ggml_type_size(dst->type)),
+ total_sum ? 0 : (uint32_t) (src->nb[1] / ggml_type_size(src->type)),
+ total_sum ? 0 : (uint32_t) (src->nb[2] / ggml_type_size(src->type)),
+ total_sum ? 0 : (uint32_t) (src->nb[3] / ggml_type_size(src->type)),
+ total_sum ? static_cast<uint32_t>(ggml_nelements(src)) : (uint32_t) src->ne[0],
+ total_sum ? 1 : (uint32_t) src->ne[1],
+ total_sum ? 1 : (uint32_t) src->ne[2] };
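+    // GGML_OP_SUM is handled as a single row spanning all elements: the strides collapse
+    // to 0 and ne0 becomes ggml_nelements(src), producing one total sum.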
+
+ std::vector<wgpu::BindGroupEntry> entries = {
+ { .binding = 0,
+ .buffer = ggml_webgpu_tensor_buf(src),
+ .offset = ggml_webgpu_tensor_align_offset(ctx, src),
+ .size = ggml_webgpu_tensor_binding_size(ctx, src) },
+ { .binding = 1,
+ .buffer = ggml_webgpu_tensor_buf(dst),
+ .offset = ggml_webgpu_tensor_align_offset(ctx, dst),
+ .size = ggml_webgpu_tensor_binding_size(ctx, dst) }
+ };
+
+ ggml_webgpu_generic_shader_lib_context shader_lib_ctx = {
+ .vec4 = false,
+ .max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup,
+ };
+
+ webgpu_pipeline pipeline;
+    auto it = ctx->sum_rows_pipelines.find(1);  // only one variant exists; the cache is keyed by a constant
+ if (it != ctx->sum_rows_pipelines.end()) {
+ pipeline = it->second;
+ } else {
+ ggml_webgpu_processed_shader processed =
+ ggml_webgpu_preprocess_generic_shader(ctx->p, wgsl_sum_rows, shader_lib_ctx, "sum_rows");
+ pipeline =
+ ggml_webgpu_create_pipeline(ctx->global_ctx->device, processed.wgsl.c_str(), processed.variant.c_str());
+ ctx->sum_rows_pipelines.emplace(1, pipeline);
+ }
+ uint32_t wg_x = total_sum ? 1 : ggml_nrows(dst);
+ return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_buf_pool, pipeline, params, entries, wg_x);
+}
+
+// Returns the encoded command, or std::nullopt if the operation is a no-op
+static std::optional<webgpu_command> ggml_webgpu_encode_node(webgpu_context ctx, ggml_tensor * node) {
+ if (ggml_is_empty(node)) {
+ return std::nullopt;
+ }
+ if ((node->flags & GGML_TENSOR_FLAG_COMPUTE) == 0) {
+ return std::nullopt;
+ }
+ WEBGPU_LOG_DEBUG("ggml_webgpu_encode_node(" << node << ", " << ggml_op_name(node->op) << ")");
+
+ ggml_tensor * src0 = node->src[0];
+ ggml_tensor * src1 = node->src[1];
+ ggml_tensor * src2 = node->src[2];
+
+ switch (node->op) {
+ // no-ops
+ case GGML_OP_NONE:
+ case GGML_OP_VIEW:
+ case GGML_OP_PERMUTE:
+ case GGML_OP_TRANSPOSE:
+ case GGML_OP_RESHAPE:
+ return std::nullopt;
+ case GGML_OP_CPY:
+ case GGML_OP_CONT:
+ return ggml_webgpu_cpy(ctx, src0, node);
+ case GGML_OP_SET_ROWS:
+ return ggml_webgpu_set_rows(ctx, src0, src1, node);
+ case GGML_OP_GET_ROWS:
+ return ggml_webgpu_get_rows(ctx, src0, src1, node);
+ case GGML_OP_MUL_MAT:
+ return ggml_webgpu_mul_mat(ctx, src0, src1, node);
+ case GGML_OP_FLASH_ATTN_EXT:
+#ifndef __EMSCRIPTEN__
+ return ggml_webgpu_flash_attn(ctx, src0, src1, src2, node->src[3], node->src[4], node);
+#else
+ return std::nullopt;
+#endif
+ case GGML_OP_ADD:
+ case GGML_OP_SUB:
+ case GGML_OP_MUL:
+ case GGML_OP_DIV:
+ return ggml_webgpu_binary_op(ctx, src0, src1, node);
+ case GGML_OP_RMS_NORM:
+ return ggml_webgpu_rms_norm(ctx, src0, node);
+ case GGML_OP_ROPE:
+ return ggml_webgpu_rope(ctx, src0, src1, src2, node);
+ case GGML_OP_GLU:
+ return ggml_webgpu_glu(ctx, src0, src1, node);
+ case GGML_OP_SCALE:
+ return ggml_webgpu_scale(ctx, src0, node);
+ case GGML_OP_SOFT_MAX:
+ return ggml_webgpu_soft_max(ctx, src0, src1, src2, node);
+        case GGML_OP_UNARY:
+        case GGML_OP_CLAMP:
+        case GGML_OP_FILL:
+        case GGML_OP_LOG:
+            return ggml_webgpu_unary_op(ctx, src0, node);
+ case GGML_OP_PAD:
+ return ggml_webgpu_pad(ctx, src0, node);
+ case GGML_OP_ARGMAX:
+ return ggml_webgpu_argmax(ctx, src0, node);
+ case GGML_OP_ARGSORT:
+ return ggml_webgpu_argsort(ctx, src0, node);
+ case GGML_OP_TOP_K:
+ // we reuse the same argsort implementation for top_k
+ return ggml_webgpu_argsort(ctx, src0, node);
+ case GGML_OP_CUMSUM:
+ return ggml_webgpu_cumsum(ctx, src0, node);
+ case GGML_OP_SUM:
+ case GGML_OP_SUM_ROWS:
+ return ggml_webgpu_sum_rows(ctx, src0, node);
+ default:
+ return std::nullopt;
+ }
+}
+
+static ggml_status ggml_backend_webgpu_graph_compute(ggml_backend_t backend, struct ggml_cgraph * cgraph) {
+ WEBGPU_LOG_DEBUG("ggml_backend_webgpu_graph_compute(" << cgraph->n_nodes << " nodes)");
+
+ ggml_backend_webgpu_context * backend_ctx = (ggml_backend_webgpu_context *) backend->context;
+ webgpu_context ctx = backend_ctx->webgpu_ctx;
+
+ WEBGPU_CPU_PROFILE_TOTAL_START(graph_compute);
+
+ ctx->global_ctx->inflight_threads++;
+
+ std::vector<webgpu_command> commands;
+ std::vector<webgpu_submission_futures> futures;
+ for (int i = 0; i < cgraph->n_nodes; i++) {
+ if (auto cmd = ggml_webgpu_encode_node(ctx, cgraph->nodes[i])) {
+ commands.push_back(*cmd);
+ }
+        // compute the batch size based on the number of inflight threads, so concurrent
+        // graphs share the fixed pool of parameter buffers without exhausting it
+ uint32_t inflight_threads = ctx->global_ctx->inflight_threads;
+ uint32_t batch_size = std::min(std::max(1u, WEBGPU_NUM_PARAM_BUFS / std::max(inflight_threads, 1u)),
+ WEBGPU_COMMAND_SUBMIT_BATCH_SIZE);
+ if (commands.size() >= batch_size) {
+ futures.push_back(ggml_backend_webgpu_submit(ctx->global_ctx, commands, ctx->param_buf_pool,
+ &ctx->set_rows_error_buf_pool));
+ // Process events and check for completed submissions
+ ctx->global_ctx->instance.ProcessEvents();
+ ggml_backend_webgpu_wait(ctx->global_ctx, futures, false);
+ commands.clear();
+ }
+ }
+ if (!commands.empty()) {
+ webgpu_submission_futures new_futures =
+ ggml_backend_webgpu_submit(ctx->global_ctx, commands, ctx->param_buf_pool, &ctx->set_rows_error_buf_pool);
+ futures.push_back(new_futures);
+ }
+
+ ggml_backend_webgpu_wait(ctx->global_ctx, futures);
+ ctx->global_ctx->inflight_threads--;
+ WEBGPU_CPU_PROFILE_TOTAL_END(graph_compute, ctx->global_ctx);
+ return GGML_STATUS_SUCCESS;
+}
+
+static ggml_backend_i ggml_backend_webgpu_i = {
+ /* .get_name = */ ggml_backend_webgpu_name,
+ /* .free = */ ggml_backend_webgpu_free,
+ /* .set_tensor_async = */ NULL,
+ /* .get_tensor_async = */ NULL,
+ /* .cpy_tensor_async = */ NULL,
+ /* .synchronize = */ NULL,
+ /* .graph_plan_create = */ NULL,
+ /* .graph_plan_free = */ NULL,
+ /* .graph_plan_update = */ NULL,
+ /* .graph_plan_compute = */ NULL,
+ /* .graph_compute = */ ggml_backend_webgpu_graph_compute,
+ /* .event_record = */ NULL,
+ /* .event_wait = */ NULL,
+ /* .graph_optimize = */ NULL,
+};
+
+/* End GGML Backend Interface */
+
+/* GGML Backend Buffer Interface */
+
+static void ggml_backend_webgpu_buffer_free_buffer(ggml_backend_buffer_t buffer) {
+    ggml_backend_webgpu_buffer_context * ctx = static_cast<ggml_backend_webgpu_buffer_context *>(buffer->context);
+    if (ctx != nullptr) {
+        if (ctx->buffer != nullptr) {
+            ctx->buffer.Destroy();
+        }
+        // free the context even if the wgpu buffer was never created, to avoid leaking it
+        delete ctx;
+    }
+}
+
+// Returns the "fake" base pointer.
+static void * ggml_backend_webgpu_buffer_get_base(ggml_backend_buffer_t buffer) {
+ GGML_UNUSED(buffer);
+ return webgpu_ptr_base;
+}
+
+static void ggml_backend_webgpu_buffer_memset_tensor(ggml_backend_buffer_t buffer,
+ ggml_tensor * tensor,
+ uint8_t value,
+ size_t offset,
+ size_t size) {
+ if (size == 0) {
+ WEBGPU_LOG_DEBUG("ggml_backend_webgpu_buffer_memset_tensor: size is zero, nothing to do.");
+ return;
+ }
+
+ WEBGPU_CPU_PROFILE_TOTAL_START(memset_tensor);
+
+ ggml_backend_webgpu_buffer_context * buf_ctx = (ggml_backend_webgpu_buffer_context *) buffer->context;
+
+ WEBGPU_LOG_DEBUG("ggml_backend_webgpu_buffer_memset_tensor(" << buf_ctx->label << ", " << tensor << ", " << value
+ << ", " << offset << ", " << size << ")");
+
+ size_t total_offset = webgpu_tensor_offset(tensor) + tensor->view_offs + offset;
+
+ // This is a trick to set all bytes of a u32 to the same 1 byte value.
+ uint32_t val32 = (uint32_t) value * 0x01010101;
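+    // e.g. value = 0xAB -> val32 = 0xABABABAB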
+ ggml_backend_webgpu_buffer_memset(buf_ctx->global_ctx, buf_ctx->buffer, val32, total_offset, size);
+ WEBGPU_CPU_PROFILE_TOTAL_END(memset_tensor, buf_ctx->global_ctx);
+}
+
+static void ggml_backend_webgpu_buffer_set_tensor(ggml_backend_buffer_t buffer,
+ ggml_tensor * tensor,
+ const void * data,
+ size_t offset,
+ size_t size) {
+ WEBGPU_CPU_PROFILE_TOTAL_START(set_tensor);
+ ggml_backend_webgpu_buffer_context * buf_ctx = (ggml_backend_webgpu_buffer_context *) buffer->context;
+
+ WEBGPU_LOG_DEBUG("ggml_backend_webgpu_buffer_set_tensor(" << buf_ctx->label << ", " << tensor << ", " << data
+ << ", " << offset << ", " << size << ")");
+
+ size_t total_offset = webgpu_tensor_offset(tensor) + tensor->view_offs + offset;
+
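+    // WebGPU requires WriteBuffer sizes to be a multiple of 4 bytes, so write the aligned
+    // prefix here and handle any 1-3 byte tail below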
+ buf_ctx->global_ctx->queue.WriteBuffer(buf_ctx->buffer, total_offset, data, (size / 4) * 4);
+
+ if (size % 4 != 0) {
+ // If size is not a multiple of 4, we need to memset the remaining bytes
+ size_t remaining_size = size % 4;
+
+ // pack the remaining bytes into a uint32_t
+ uint32_t val32 = 0;
+
+ for (size_t i = 0; i < remaining_size; i++) {
+ ((uint8_t *) &val32)[i] = ((const uint8_t *) data)[size - remaining_size + i];
+ }
+ // memset the remaining bytes
+ ggml_backend_webgpu_buffer_memset(buf_ctx->global_ctx, buf_ctx->buffer, val32,
+ total_offset + (size - remaining_size), remaining_size);
+    } else {
+        // wait for WriteBuffer to complete (the memset path above is assumed to
+        // synchronize internally)
+ buf_ctx->global_ctx->instance.WaitAny(buf_ctx->global_ctx->queue.OnSubmittedWorkDone(
+ wgpu::CallbackMode::AllowSpontaneous,
+ [](wgpu::QueueWorkDoneStatus status, wgpu::StringView message) {
+ if (status != wgpu::QueueWorkDoneStatus::Success) {
+ GGML_LOG_ERROR("ggml_webgpu: Failed to submit commands: %s\n",
+ std::string(message).c_str());
+ }
+ }),
+ UINT64_MAX);
+ }
+ WEBGPU_CPU_PROFILE_TOTAL_END(set_tensor, buf_ctx->global_ctx);
+}
+
+static void ggml_backend_webgpu_buffer_get_tensor(ggml_backend_buffer_t buffer,
+ const ggml_tensor * tensor,
+ void * data,
+ size_t offset,
+ size_t size) {
+ WEBGPU_CPU_PROFILE_TOTAL_START(get_tensor);
+ ggml_backend_webgpu_buffer_context * buf_ctx = (ggml_backend_webgpu_buffer_context *) buffer->context;
+ WEBGPU_LOG_DEBUG("ggml_backend_webgpu_buffer_get_tensor(" << buf_ctx->label << ", " << tensor << ", " << data
+ << ", " << offset << ", " << size << ")");
+ wgpu::Device device = buf_ctx->global_ctx->device;
+
+ size_t total_offset = webgpu_tensor_offset(tensor) + tensor->view_offs + offset;
+
+ size_t final_size = size;
+ if (size % 4 != 0) {
+ // If size is not a multiple of 4, we need to round it up to the next multiple of 4
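+        // (e.g. size = 10 -> final_size = 12; copy sizes and mapped ranges must be multiples of 4)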
+ final_size = size + (4 - (size % 4));
+ }
+
+ std::lock_guard<std::recursive_mutex> lock(buf_ctx->global_ctx->mutex);
+
+ if (buf_ctx->global_ctx->get_tensor_staging_buf == nullptr ||
+ buf_ctx->global_ctx->get_tensor_staging_buf.GetSize() < final_size) {
+ // Create a new staging buffer if it doesn't exist or is too small
+ if (buf_ctx->global_ctx->get_tensor_staging_buf) {
+ buf_ctx->global_ctx->get_tensor_staging_buf.Destroy();
+ }
+ ggml_webgpu_create_buffer(device, buf_ctx->global_ctx->get_tensor_staging_buf, final_size,
+ wgpu::BufferUsage::CopyDst | wgpu::BufferUsage::MapRead, "get_tensor_staging_buf");
+ }
+
+ // Copy the data from the buffer to the staging buffer
+ wgpu::CommandEncoder encoder = device.CreateCommandEncoder();
+ encoder.CopyBufferToBuffer(buf_ctx->buffer, total_offset, buf_ctx->global_ctx->get_tensor_staging_buf, 0,
+ final_size);
+ wgpu::CommandBuffer commands = encoder.Finish();
+
+ // Submit the command buffer to the queue
+ buf_ctx->global_ctx->queue.Submit(1, &commands);
+
+ // Map the staging buffer to read the data
+ ggml_backend_webgpu_map_buffer(buf_ctx->global_ctx, buf_ctx->global_ctx->get_tensor_staging_buf,
+ wgpu::MapMode::Read, 0, final_size);
+ // Must specify size here since the staging buffer might be larger than the tensor size
+ const void * mapped_range = buf_ctx->global_ctx->get_tensor_staging_buf.GetConstMappedRange(0, final_size);
+
+ // Copy the data from the mapped range to the output buffer
+ std::memcpy(data, mapped_range, size);
+ buf_ctx->global_ctx->get_tensor_staging_buf.Unmap();
+ WEBGPU_CPU_PROFILE_TOTAL_END(get_tensor, buf_ctx->global_ctx);
+}
+
+static void ggml_backend_webgpu_buffer_clear(ggml_backend_buffer_t buffer, uint8_t value) {
+ WEBGPU_LOG_DEBUG("ggml_backend_webgpu_buffer_clear(" << buffer << ", " << (uint32_t) value << ")");
+ WEBGPU_CPU_PROFILE_TOTAL_START(clear);
+ ggml_backend_webgpu_buffer_context * buf_ctx = (ggml_backend_webgpu_buffer_context *) buffer->context;
+ ggml_backend_webgpu_buffer_memset(buf_ctx->global_ctx, buf_ctx->buffer, value, 0, buffer->size);
+ WEBGPU_CPU_PROFILE_TOTAL_END(clear, buf_ctx->global_ctx);
+}
+
+static ggml_backend_buffer_i ggml_backend_webgpu_buffer_interface = {
+ /* .free_buffer = */ ggml_backend_webgpu_buffer_free_buffer,
+ /* .get_base = */ ggml_backend_webgpu_buffer_get_base,
+ /* .init_tensor = */ NULL, // TODO: optional, needed?
+ /* .memset_tensor = */ ggml_backend_webgpu_buffer_memset_tensor,
+ /* .set_tensor = */ ggml_backend_webgpu_buffer_set_tensor,
+ /* .get_tensor = */ ggml_backend_webgpu_buffer_get_tensor,
+ /* .cpy_tensor = */ NULL, // TODO: optional, implement this
+ /* .clear = */ ggml_backend_webgpu_buffer_clear,
+    /* .reset        = */ NULL,  // TODO: optional, appears to coordinate with .init_tensor
+};
+
+/* End GGML Backend Buffer Interface */
+
+/* GGML Backend Buffer Type Interface */
+
+static const char * ggml_backend_webgpu_buffer_type_get_name(ggml_backend_buffer_type_t buft) {
+ ggml_backend_webgpu_device_context * ctx = static_cast<ggml_backend_webgpu_device_context *>(buft->device->context);
+ return ctx->device_name.c_str();
+}
+
+static ggml_backend_buffer_t ggml_backend_webgpu_buffer_type_alloc_buffer(ggml_backend_buffer_type_t buft,
+ size_t size) {
+ static std::atomic<int> buffer_count;
+ int buffer_id = buffer_count++;
+ std::string buf_name = "tensor_buf" + std::to_string(buffer_id);
+ WEBGPU_LOG_DEBUG("ggml_backend_webgpu_buffer_type_alloc_buffer_" << buffer_id << ": " << size << " bytes");
+
+ ggml_backend_webgpu_device_context * ctx = static_cast<ggml_backend_webgpu_device_context *>(buft->device->context);
+ wgpu::Buffer buf;
+ ggml_webgpu_create_buffer(ctx->webgpu_global_ctx->device, buf, ROUNDUP_POW2(size, WEBGPU_STORAGE_BUF_BINDING_MULT),
+ wgpu::BufferUsage::Storage | wgpu::BufferUsage::CopySrc | wgpu::BufferUsage::CopyDst,
+ buf_name.c_str());
+
+ ggml_backend_webgpu_buffer_context * buf_ctx =
+ new ggml_backend_webgpu_buffer_context(buf, buf_name, ctx->webgpu_global_ctx);
+
+ return ggml_backend_buffer_init(buft, ggml_backend_webgpu_buffer_interface, buf_ctx, size);
+}
+
+static size_t ggml_backend_webgpu_buffer_type_get_alignment(ggml_backend_buffer_type_t buft) {
+ ggml_backend_webgpu_device_context * dev_ctx =
+ static_cast<ggml_backend_webgpu_device_context *>(buft->device->context);
+ return dev_ctx->webgpu_global_ctx->capabilities.limits.minStorageBufferOffsetAlignment;
+}
+
+// maxBufferSize might be larger, but you can't bind more than maxStorageBufferBindingSize to a single binding.
+static size_t ggml_backend_webgpu_buffer_type_get_max_size(ggml_backend_buffer_type_t buft) {
+ ggml_backend_webgpu_device_context * dev_ctx =
+ static_cast<ggml_backend_webgpu_device_context *>(buft->device->context);
+ return dev_ctx->webgpu_global_ctx->capabilities.limits.maxStorageBufferBindingSize;
+}
+
+static size_t ggml_backend_webgpu_buffer_type_get_alloc_size(ggml_backend_buffer_type_t buft,
+ const ggml_tensor * tensor) {
+ ggml_backend_webgpu_device_context * ctx = static_cast<ggml_backend_webgpu_device_context *>(buft->device->context);
+ size_t res = ggml_nbytes(tensor);
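+    // Some ops reserve scratch space beyond the tensor itself, e.g. argsort/top-k
+    // allocate a second region for merge passes, padded to the storage-buffer offset
+    // alignment so it can be bound at a separate offset.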
+ switch (tensor->op) {
+ case GGML_OP_ARGSORT:
+ res = ROUNDUP_POW2(res * 2 + ctx->webgpu_global_ctx->capabilities.limits.minStorageBufferOffsetAlignment,
+ WEBGPU_STORAGE_BUF_BINDING_MULT);
+ break;
+ case GGML_OP_TOP_K:
+ {
+ const ggml_tensor * src0 = tensor->src[0];
+ if (src0) {
+ const size_t full = sizeof(int32_t) * ggml_nelements(src0);
+ res = ROUNDUP_POW2(
+ full * 2 + ctx->webgpu_global_ctx->capabilities.limits.minStorageBufferOffsetAlignment,
+ WEBGPU_STORAGE_BUF_BINDING_MULT);
+ }
+ }
+ break;
+ default:
+ break;
+ }
+ return res;
+}
+
+/* End GGML Backend Buffer Type Interface */
+
+/* GGML Backend Device Interface */
+
+static const char * ggml_backend_webgpu_device_get_name(ggml_backend_dev_t dev) {
+ ggml_backend_webgpu_device_context * ctx = static_cast<ggml_backend_webgpu_device_context *>(dev->context);
+ return ctx->device_name.c_str();
+}
+
+static const char * ggml_backend_webgpu_device_get_description(ggml_backend_dev_t dev) {
+ ggml_backend_webgpu_device_context * ctx = static_cast<ggml_backend_webgpu_device_context *>(dev->context);
+ return ctx->device_desc.c_str();
+}
+
+static void ggml_backend_webgpu_device_get_memory(ggml_backend_dev_t dev, size_t * free, size_t * total) {
+ ggml_backend_webgpu_device_context * ctx = static_cast<ggml_backend_webgpu_device_context *>(dev->context);
+ // TODO: for now, return maxBufferSize as both free and total memory
+ // Track https://github.com/gpuweb/gpuweb/issues/5505 for updates.
+ uint64_t max_buffer_size = ctx->webgpu_global_ctx->capabilities.limits.maxBufferSize;
+ // If we're on a 32-bit system, clamp to UINTPTR_MAX
+#if UINTPTR_MAX < UINT64_MAX
+ uint64_t max_ptr_size = static_cast<uint64_t>(UINTPTR_MAX);
+ if (max_buffer_size > max_ptr_size) {
+ max_buffer_size = max_ptr_size;
+ }
+#endif
+ *free = static_cast<size_t>(max_buffer_size);
+ *total = static_cast<size_t>(max_buffer_size);
+}
+
+static enum ggml_backend_dev_type ggml_backend_webgpu_device_get_type(ggml_backend_dev_t dev) {
+ GGML_UNUSED(dev);
+ return GGML_BACKEND_DEVICE_TYPE_GPU;
+}
+
+static void ggml_backend_webgpu_device_get_props(ggml_backend_dev_t dev, struct ggml_backend_dev_props * props) {
+ props->name = ggml_backend_webgpu_device_get_name(dev);
+ props->description = ggml_backend_webgpu_device_get_description(dev);
+ props->type = ggml_backend_webgpu_device_get_type(dev);
+ ggml_backend_webgpu_device_get_memory(dev, &props->memory_free, &props->memory_total);
+ props->caps = {
+ /* .async = */ false,
+ /* .host_buffer = */ false,
+ /* .buffer_from_host_ptr = */ false,
+ /* .events = */ false,
+ };
+}
+
+static ggml_guid_t ggml_backend_webgpu_guid(void) {
+ static const char * guid_str = "__ggml_webgpu :)";
+ return reinterpret_cast<ggml_guid_t>((void *) guid_str);
+}
+
+// The workgroup size is a pipeline constant common to most shaders
+static std::vector<wgpu::ConstantEntry> ggml_webgpu_wg_size_entry(uint32_t wg_size) {
+ std::vector<wgpu::ConstantEntry> constants(1);
+ constants[0].key = "wg_size";
+ constants[0].value = wg_size;
+ return constants;
+}
+
+static void ggml_webgpu_init_memset_pipeline(webgpu_global_context & ctx) {
+ // we use the maximum workgroup size for the memset pipeline
+ size_t max_threads = WEBGPU_MAX_WG_SIZE * ctx->capabilities.limits.maxComputeWorkgroupsPerDimension;
+ // Size the bytes_per_thread so that the largest buffer size can be handled
+ ctx->capabilities.memset_bytes_per_thread =
+ CEIL_DIV(ctx->capabilities.limits.maxStorageBufferBindingSize, max_threads);
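+    // e.g. assuming WEBGPU_MAX_WG_SIZE = 256 and typical limits (65535 workgroups per
+    // dimension, 128 MiB max binding), this works out to 9 bytes per thread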
+ std::vector<wgpu::ConstantEntry> constants(2);
+ constants[0].key = "wg_size";
+ constants[0].value = WEBGPU_MAX_WG_SIZE;
+ constants[1].key = "bytes_per_thread";
+ constants[1].value = ctx->capabilities.memset_bytes_per_thread;
+ ctx->memset_pipelines[0] = ggml_webgpu_create_pipeline(ctx->device, wgsl_memset, "memset", constants);
+}
+
+static void ggml_webgpu_init_mul_mat_pipeline(webgpu_context & webgpu_ctx) {
+ // Q4/Q5/Q8 classic quantizations
+ webgpu_ctx->mul_mat_pipelines[GGML_TYPE_Q4_0][GGML_TYPE_F32][0] =
+ ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_mul_mat_q4_0_f32, "mul_mat_q4_0_f32");
+ webgpu_ctx->mul_mat_pipelines[GGML_TYPE_Q4_1][GGML_TYPE_F32][0] =
+ ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_mul_mat_q4_1_f32, "mul_mat_q4_1_f32");
+ webgpu_ctx->mul_mat_pipelines[GGML_TYPE_Q5_0][GGML_TYPE_F32][0] =
+ ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_mul_mat_q5_0_f32, "mul_mat_q5_0_f32");
+ webgpu_ctx->mul_mat_pipelines[GGML_TYPE_Q5_1][GGML_TYPE_F32][0] =
+ ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_mul_mat_q5_1_f32, "mul_mat_q5_1_f32");
+ webgpu_ctx->mul_mat_pipelines[GGML_TYPE_Q8_0][GGML_TYPE_F32][0] =
+ ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_mul_mat_q8_0_f32, "mul_mat_q8_0_f32");
+
+ // K-quantizations
+ webgpu_ctx->mul_mat_pipelines[GGML_TYPE_Q2_K][GGML_TYPE_F32][0] =
+ ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_mul_mat_q2_k_f32, "mul_mat_q2_k_f32");
+ webgpu_ctx->mul_mat_pipelines[GGML_TYPE_Q3_K][GGML_TYPE_F32][0] =
+ ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_mul_mat_q3_k_f32, "mul_mat_q3_k_f32");
+ webgpu_ctx->mul_mat_pipelines[GGML_TYPE_Q4_K][GGML_TYPE_F32][0] =
+ ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_mul_mat_q4_k_f32, "mul_mat_q4_k_f32");
+ webgpu_ctx->mul_mat_pipelines[GGML_TYPE_Q5_K][GGML_TYPE_F32][0] =
+ ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_mul_mat_q5_k_f32, "mul_mat_q5_k_f32");
+ webgpu_ctx->mul_mat_pipelines[GGML_TYPE_Q6_K][GGML_TYPE_F32][0] =
+ ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_mul_mat_q6_k_f32, "mul_mat_q6_k_f32");
+
+ // IQ quantizations (2-, 3-, 4-bit variants)
+ webgpu_ctx->mul_mat_pipelines[GGML_TYPE_IQ2_XXS][GGML_TYPE_F32][0] =
+ ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_mul_mat_iq2_xxs_f32, "mul_mat_iq2_xxs_f32");
+ webgpu_ctx->mul_mat_pipelines[GGML_TYPE_IQ2_XS][GGML_TYPE_F32][0] =
+ ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_mul_mat_iq2_xs_f32, "mul_mat_iq2_xs_f32");
+ webgpu_ctx->mul_mat_pipelines[GGML_TYPE_IQ2_S][GGML_TYPE_F32][0] =
+ ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_mul_mat_iq2_s_f32, "mul_mat_iq2_s_f32");
+
+ webgpu_ctx->mul_mat_pipelines[GGML_TYPE_IQ3_XXS][GGML_TYPE_F32][0] =
+ ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_mul_mat_iq3_xxs_f32, "mul_mat_iq3_xxs_f32");
+ webgpu_ctx->mul_mat_pipelines[GGML_TYPE_IQ3_S][GGML_TYPE_F32][0] =
+ ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_mul_mat_iq3_s_f32, "mul_mat_iq3_s_f32");
+
+ // 1-bit and 4-bit IQ variants
+ webgpu_ctx->mul_mat_pipelines[GGML_TYPE_IQ1_S][GGML_TYPE_F32][0] =
+ ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_mul_mat_iq1_s_f32, "mul_mat_iq1_s_f32");
+ webgpu_ctx->mul_mat_pipelines[GGML_TYPE_IQ1_M][GGML_TYPE_F32][0] =
+ ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_mul_mat_iq1_m_f32, "mul_mat_iq1_m_f32");
+ webgpu_ctx->mul_mat_pipelines[GGML_TYPE_IQ4_NL][GGML_TYPE_F32][0] =
+ ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_mul_mat_iq4_nl_f32, "mul_mat_iq4_nl_f32");
+ webgpu_ctx->mul_mat_pipelines[GGML_TYPE_IQ4_XS][GGML_TYPE_F32][0] =
+ ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_mul_mat_iq4_xs_f32, "mul_mat_iq4_xs_f32");
+
+ std::string proc_mul_mat_f32_f32;
+ std::string proc_mul_mat_f32_f32_vec;
+ std::string proc_mul_mat_f16_f32;
+ std::string proc_mul_mat_f16_f32_vec;
+ std::string proc_mul_mat_f16_f16;
+ std::string proc_mul_mat_f16_f16_vec;
+ std::string proc_mul_mat_q4_0_f32;
+ std::string proc_mul_mat_q4_0_f32_vec;
+
+ std::vector<wgpu::ConstantEntry> mul_mat_constants;
+#ifndef __EMSCRIPTEN__
+ if (webgpu_ctx->global_ctx->capabilities.supports_subgroup_matrix) {
+ std::map<std::string, std::string> sg_matrix_repls;
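+        // each key below is substituted textually into the WGSL template source before compilation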
+ sg_matrix_repls["WEBGPU_MAX_SUBGROUP_SIZE"] =
+ std::to_string(webgpu_ctx->global_ctx->capabilities.max_subgroup_size);
+ sg_matrix_repls["WEBGPU_TILE_K"] = std::to_string(WEBGPU_MUL_MAT_TILE_K);
+ sg_matrix_repls["WEBGPU_SUBGROUP_M"] = std::to_string(WEBGPU_MUL_MAT_SUBGROUP_M);
+ sg_matrix_repls["WEBGPU_SUBGROUP_N"] = std::to_string(WEBGPU_MUL_MAT_SUBGROUP_N);
+ sg_matrix_repls["WEBGPU_SUBGROUP_MATRIX_M"] = std::to_string(WEBGPU_MUL_MAT_SUBGROUP_MATRIX_M);
+ sg_matrix_repls["WEBGPU_SUBGROUP_MATRIX_N"] = std::to_string(WEBGPU_MUL_MAT_SUBGROUP_MATRIX_N);
+ sg_matrix_repls["WEBGPU_SG_MAT_M_SIZE"] = std::to_string(webgpu_ctx->global_ctx->capabilities.sg_mat_m);
+ sg_matrix_repls["WEBGPU_SG_MAT_N_SIZE"] = std::to_string(webgpu_ctx->global_ctx->capabilities.sg_mat_n);
+ sg_matrix_repls["WEBGPU_SG_MAT_K_SIZE"] = std::to_string(webgpu_ctx->global_ctx->capabilities.sg_mat_k);
+ proc_mul_mat_f32_f32 = ggml_webgpu_process_shader_repls(wgsl_mul_mat_subgroup_matrix_f32_f32, sg_matrix_repls);
+ proc_mul_mat_f32_f32_vec =
+ ggml_webgpu_process_shader_repls(wgsl_mul_mat_subgroup_matrix_f32_f32_vec, sg_matrix_repls);
+ proc_mul_mat_f16_f32 = ggml_webgpu_process_shader_repls(wgsl_mul_mat_subgroup_matrix_f16_f32, sg_matrix_repls);
+ proc_mul_mat_f16_f32_vec =
+ ggml_webgpu_process_shader_repls(wgsl_mul_mat_subgroup_matrix_f16_f32_vec, sg_matrix_repls);
+ proc_mul_mat_f16_f16 = ggml_webgpu_process_shader_repls(wgsl_mul_mat_subgroup_matrix_f16_f16, sg_matrix_repls);
+ proc_mul_mat_f16_f16_vec =
+ ggml_webgpu_process_shader_repls(wgsl_mul_mat_subgroup_matrix_f16_f16_vec, sg_matrix_repls);
+ proc_mul_mat_q4_0_f32 =
+ ggml_webgpu_process_shader_repls(wgsl_mul_mat_subgroup_matrix_q4_0_f32, sg_matrix_repls);
+ proc_mul_mat_q4_0_f32_vec =
+ ggml_webgpu_process_shader_repls(wgsl_mul_mat_subgroup_matrix_q4_0_f32_vec, sg_matrix_repls);
+ } else {
+#endif
+ mul_mat_constants.push_back({ .key = "TILE_K", .value = WEBGPU_MUL_MAT_TILE_K });
+ mul_mat_constants.push_back({ .key = "WORKGROUP_SIZE_M", .value = WEBGPU_MUL_MAT_WG_SIZE_M });
+ mul_mat_constants.push_back({ .key = "WORKGROUP_SIZE_N", .value = WEBGPU_MUL_MAT_WG_SIZE_N });
+
+ std::map<std::string, std::string> reg_repls;
+ reg_repls["WEBGPU_TILE_M"] = std::to_string(WEBGPU_MUL_MAT_TILE_M);
+ reg_repls["WEBGPU_TILE_N"] = std::to_string(WEBGPU_MUL_MAT_TILE_N);
+
+ proc_mul_mat_f32_f32 = ggml_webgpu_process_shader_repls(wgsl_mul_mat_reg_tile_f32_f32, reg_repls);
+ proc_mul_mat_f32_f32_vec = ggml_webgpu_process_shader_repls(wgsl_mul_mat_reg_tile_f32_f32_vec, reg_repls);
+ proc_mul_mat_f16_f32 = ggml_webgpu_process_shader_repls(wgsl_mul_mat_reg_tile_f16_f32, reg_repls);
+ proc_mul_mat_f16_f32_vec = ggml_webgpu_process_shader_repls(wgsl_mul_mat_reg_tile_f16_f32_vec, reg_repls);
+ proc_mul_mat_f16_f16 = ggml_webgpu_process_shader_repls(wgsl_mul_mat_reg_tile_f16_f16, reg_repls);
+ proc_mul_mat_f16_f16_vec = ggml_webgpu_process_shader_repls(wgsl_mul_mat_reg_tile_f16_f16_vec, reg_repls);
+ proc_mul_mat_q4_0_f32 = ggml_webgpu_process_shader_repls(wgsl_mul_mat_reg_tile_q4_0_f32, reg_repls);
+ proc_mul_mat_q4_0_f32_vec = ggml_webgpu_process_shader_repls(wgsl_mul_mat_reg_tile_q4_0_f32_vec, reg_repls);
+#ifndef __EMSCRIPTEN__
+ }
+#endif
+
+ webgpu_ctx->mul_mat_pipelines[GGML_TYPE_F32][GGML_TYPE_F32][0] = ggml_webgpu_create_pipeline(
+ webgpu_ctx->global_ctx->device, proc_mul_mat_f32_f32.c_str(), "mul_mat_f32_f32", mul_mat_constants);
+ webgpu_ctx->mul_mat_pipelines[GGML_TYPE_F32][GGML_TYPE_F32][1] = ggml_webgpu_create_pipeline(
+ webgpu_ctx->global_ctx->device, proc_mul_mat_f32_f32_vec.c_str(), "mul_mat_f32_f32_vec", mul_mat_constants);
+ webgpu_ctx->mul_mat_pipelines[GGML_TYPE_F16][GGML_TYPE_F32][0] = ggml_webgpu_create_pipeline(
+ webgpu_ctx->global_ctx->device, proc_mul_mat_f16_f32.c_str(), "mul_mat_f16_f32", mul_mat_constants);
+ webgpu_ctx->mul_mat_pipelines[GGML_TYPE_F16][GGML_TYPE_F32][1] = ggml_webgpu_create_pipeline(
+ webgpu_ctx->global_ctx->device, proc_mul_mat_f16_f32_vec.c_str(), "mul_mat_f16_f32_vec", mul_mat_constants);
+ webgpu_ctx->mul_mat_pipelines[GGML_TYPE_F16][GGML_TYPE_F16][0] = ggml_webgpu_create_pipeline(
+ webgpu_ctx->global_ctx->device, proc_mul_mat_f16_f16.c_str(), "mul_mat_f16_f16", mul_mat_constants);
+ webgpu_ctx->mul_mat_pipelines[GGML_TYPE_F16][GGML_TYPE_F16][1] = ggml_webgpu_create_pipeline(
+ webgpu_ctx->global_ctx->device, proc_mul_mat_f16_f16_vec.c_str(), "mul_mat_f16_f16_vec", mul_mat_constants);
+ webgpu_ctx->mul_mat_pipelines[GGML_TYPE_Q4_0][GGML_TYPE_F32][0] = ggml_webgpu_create_pipeline(
+ webgpu_ctx->global_ctx->device, proc_mul_mat_q4_0_f32.c_str(), "mul_mat_q4_0_f32", mul_mat_constants);
+ webgpu_ctx->mul_mat_pipelines[GGML_TYPE_Q4_0][GGML_TYPE_F32][1] = ggml_webgpu_create_pipeline(
+ webgpu_ctx->global_ctx->device, proc_mul_mat_q4_0_f32_vec.c_str(), "mul_mat_q4_0_f32_vec", mul_mat_constants);
+
+ std::vector<wgpu::ConstantEntry> mul_mat_vec_constants(3);
+ mul_mat_vec_constants[0].key = "WORKGROUP_SIZE";
+ mul_mat_vec_constants[0].value = WEBGPU_MUL_MAT_VEC_WG_SIZE;
+ mul_mat_vec_constants[1].key = "TILE_K";
+ mul_mat_vec_constants[1].value = WEBGPU_MUL_MAT_VEC_TILE_K;
+ mul_mat_vec_constants[2].key = "OUTPUTS_PER_WG";
+ mul_mat_vec_constants[2].value = WEBGPU_MUL_MAT_VEC_OUTPUTS_PER_WG;
+
+ webgpu_ctx->mul_mat_vec_pipelines[GGML_TYPE_F32][GGML_TYPE_F32][0] = ggml_webgpu_create_pipeline(
+ webgpu_ctx->global_ctx->device, wgsl_mul_mat_vec_f32_f32, "mul_mat_vec_f32_f32", mul_mat_vec_constants);
+ webgpu_ctx->mul_mat_vec_pipelines[GGML_TYPE_F32][GGML_TYPE_F32][1] = ggml_webgpu_create_pipeline(
+ webgpu_ctx->global_ctx->device, wgsl_mul_mat_vec_f32_f32_vec, "mul_mat_vec_f32_f32_vec", mul_mat_vec_constants);
+ webgpu_ctx->mul_mat_vec_pipelines[GGML_TYPE_F16][GGML_TYPE_F32][0] = ggml_webgpu_create_pipeline(
+ webgpu_ctx->global_ctx->device, wgsl_mul_mat_vec_f16_f32, "mul_mat_vec_f16_f32", mul_mat_vec_constants);
+ webgpu_ctx->mul_mat_vec_pipelines[GGML_TYPE_F16][GGML_TYPE_F32][1] = ggml_webgpu_create_pipeline(
+ webgpu_ctx->global_ctx->device, wgsl_mul_mat_vec_f16_f32_vec, "mul_mat_vec_f16_f32_vec", mul_mat_vec_constants);
+ webgpu_ctx->mul_mat_vec_pipelines[GGML_TYPE_F16][GGML_TYPE_F16][0] = ggml_webgpu_create_pipeline(
+ webgpu_ctx->global_ctx->device, wgsl_mul_mat_vec_f16_f16, "mul_mat_vec_f16_f16", mul_mat_vec_constants);
+ webgpu_ctx->mul_mat_vec_pipelines[GGML_TYPE_F16][GGML_TYPE_F16][1] = ggml_webgpu_create_pipeline(
+ webgpu_ctx->global_ctx->device, wgsl_mul_mat_vec_f16_f16_vec, "mul_mat_vec_f16_f16_vec", mul_mat_vec_constants);
+ webgpu_ctx->mul_mat_vec_pipelines[GGML_TYPE_Q4_0][GGML_TYPE_F32][0] = ggml_webgpu_create_pipeline(
+ webgpu_ctx->global_ctx->device, wgsl_mul_mat_vec_q4_0_f32, "mul_mat_vec_q4_0_f32", mul_mat_vec_constants);
+}
+
+static void ggml_webgpu_init_get_rows_pipeline(webgpu_context & webgpu_ctx) {
+ std::vector<wgpu::ConstantEntry> constants = ggml_webgpu_wg_size_entry(WEBGPU_MAX_WG_SIZE);
+
+ webgpu_ctx->get_rows_pipelines[GGML_TYPE_F32][0] =
+ ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_get_rows_f32, "get_rows_f32", constants);
+ webgpu_ctx->get_rows_pipelines[GGML_TYPE_F32][1] = ggml_webgpu_create_pipeline(
+ webgpu_ctx->global_ctx->device, wgsl_get_rows_f32_vec, "get_rows_f32_vec", constants);
+
+ webgpu_ctx->get_rows_pipelines[GGML_TYPE_F16][0] =
+ ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_get_rows_f16, "get_rows_f16", constants);
+ webgpu_ctx->get_rows_pipelines[GGML_TYPE_I32][0] =
+ ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_get_rows_i32, "get_rows_i32", constants);
+ webgpu_ctx->get_rows_pipelines[GGML_TYPE_Q4_0][0] =
+ ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_get_rows_q4_0, "get_rows_q4_0", constants);
+ webgpu_ctx->get_rows_pipelines[GGML_TYPE_Q4_1][0] =
+ ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_get_rows_q4_1, "get_rows_q4_1", constants);
+ webgpu_ctx->get_rows_pipelines[GGML_TYPE_Q5_0][0] =
+ ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_get_rows_q5_0, "get_rows_q5_0", constants);
+ webgpu_ctx->get_rows_pipelines[GGML_TYPE_Q5_1][0] =
+ ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_get_rows_q5_1, "get_rows_q5_1", constants);
+ webgpu_ctx->get_rows_pipelines[GGML_TYPE_Q8_0][0] =
+ ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_get_rows_q8_0, "get_rows_q8_0", constants);
+
+ webgpu_ctx->get_rows_pipelines[GGML_TYPE_Q2_K][0] =
+ ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_get_rows_q2_k, "get_rows_q2_k", constants);
+ webgpu_ctx->get_rows_pipelines[GGML_TYPE_Q3_K][0] =
+ ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_get_rows_q3_k, "get_rows_q3_k", constants);
+ webgpu_ctx->get_rows_pipelines[GGML_TYPE_Q4_K][0] =
+ ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_get_rows_q4_k, "get_rows_q4_k", constants);
+ webgpu_ctx->get_rows_pipelines[GGML_TYPE_Q5_K][0] =
+ ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_get_rows_q5_k, "get_rows_q5_k", constants);
+ webgpu_ctx->get_rows_pipelines[GGML_TYPE_Q6_K][0] =
+ ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_get_rows_q6_k, "get_rows_q6_k", constants);
+
+ webgpu_ctx->get_rows_pipelines[GGML_TYPE_IQ2_XXS][0] = ggml_webgpu_create_pipeline(
+ webgpu_ctx->global_ctx->device, wgsl_get_rows_iq2_xxs, "get_rows_iq2_xxs", constants);
+ webgpu_ctx->get_rows_pipelines[GGML_TYPE_IQ2_XS][0] =
+ ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_get_rows_iq2_xs, "get_rows_iq2_xs", constants);
+ webgpu_ctx->get_rows_pipelines[GGML_TYPE_IQ2_S][0] =
+ ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_get_rows_iq2_s, "get_rows_iq2_s", constants);
+ webgpu_ctx->get_rows_pipelines[GGML_TYPE_IQ3_XXS][0] = ggml_webgpu_create_pipeline(
+ webgpu_ctx->global_ctx->device, wgsl_get_rows_iq3_xxs, "get_rows_iq3_xxs", constants);
+ webgpu_ctx->get_rows_pipelines[GGML_TYPE_IQ3_S][0] =
+ ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_get_rows_iq3_s, "get_rows_iq3_s", constants);
+ webgpu_ctx->get_rows_pipelines[GGML_TYPE_IQ1_S][0] =
+ ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_get_rows_iq1_s, "get_rows_iq1_s", constants);
+ webgpu_ctx->get_rows_pipelines[GGML_TYPE_IQ1_M][0] =
+ ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_get_rows_iq1_m, "get_rows_iq1_m", constants);
+ webgpu_ctx->get_rows_pipelines[GGML_TYPE_IQ4_NL][0] =
+ ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_get_rows_iq4_nl, "get_rows_iq4_nl", constants);
+ webgpu_ctx->get_rows_pipelines[GGML_TYPE_IQ4_XS][0] =
+ ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_get_rows_iq4_xs, "get_rows_iq4_xs", constants);
+}
+
+static void ggml_webgpu_init_cpy_pipeline(webgpu_context & webgpu_ctx) {
+ std::vector<wgpu::ConstantEntry> constants = ggml_webgpu_wg_size_entry(WEBGPU_MAX_WG_SIZE);
+
+ webgpu_ctx->cpy_pipelines[GGML_TYPE_F32][GGML_TYPE_F32] =
+ ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_cpy_f32_f32, "cpy_f32_f32", constants);
+ webgpu_ctx->cpy_pipelines[GGML_TYPE_F32][GGML_TYPE_I32] =
+ ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_cpy_f32_i32, "cpy_f32_i32", constants);
+ webgpu_ctx->cpy_pipelines[GGML_TYPE_F32][GGML_TYPE_F16] =
+ ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_cpy_f32_f16, "cpy_f32_f16", constants);
+ webgpu_ctx->cpy_pipelines[GGML_TYPE_F16][GGML_TYPE_F32] =
+ ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_cpy_f16_f32, "cpy_f16_f32", constants);
+ webgpu_ctx->cpy_pipelines[GGML_TYPE_F16][GGML_TYPE_F16] =
+ ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_cpy_f16_f16, "cpy_f16_f16", constants);
+}
+
+static void ggml_webgpu_init_rms_norm_pipeline(webgpu_context & webgpu_ctx) {
+ std::vector<wgpu::ConstantEntry> constants = ggml_webgpu_wg_size_entry(WEBGPU_ROW_SPLIT_WG_SIZE);
+
+ webgpu_ctx->rms_norm_pipelines[0] =
+ ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_rms_norm, "rms_norm", constants);
+ webgpu_ctx->rms_norm_pipelines[1] = ggml_webgpu_create_pipeline(
+ webgpu_ctx->global_ctx->device, wgsl_rms_norm_inplace, "rms_norm_inplace", constants);
+}
+
+static void ggml_webgpu_init_rope_pipeline(webgpu_context & webgpu_ctx) {
+ std::vector<wgpu::ConstantEntry> constants = ggml_webgpu_wg_size_entry(WEBGPU_MAX_WG_SIZE);
+
+ webgpu_ctx->rope_pipelines[GGML_TYPE_F32][0][0] =
+ ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_rope_f32, "rope_f32", constants);
+ webgpu_ctx->rope_pipelines[GGML_TYPE_F32][0][1] = ggml_webgpu_create_pipeline(
+ webgpu_ctx->global_ctx->device, wgsl_rope_f32_inplace, "rope_f32_inplace", constants);
+ webgpu_ctx->rope_pipelines[GGML_TYPE_F32][1][0] =
+ ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_rope_f32_ff, "rope_f32_ff", constants);
+ webgpu_ctx->rope_pipelines[GGML_TYPE_F32][1][1] = ggml_webgpu_create_pipeline(
+ webgpu_ctx->global_ctx->device, wgsl_rope_f32_ff_inplace, "rope_f32_ff_inplace", constants);
+
+ webgpu_ctx->rope_pipelines[GGML_TYPE_F16][0][0] =
+ ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_rope_f16, "rope_f16", constants);
+ webgpu_ctx->rope_pipelines[GGML_TYPE_F16][0][1] = ggml_webgpu_create_pipeline(
+ webgpu_ctx->global_ctx->device, wgsl_rope_f16_inplace, "rope_f16_inplace", constants);
+ webgpu_ctx->rope_pipelines[GGML_TYPE_F16][1][0] =
+ ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_rope_f16_ff, "rope_f16_ff", constants);
+ webgpu_ctx->rope_pipelines[GGML_TYPE_F16][1][1] = ggml_webgpu_create_pipeline(
+ webgpu_ctx->global_ctx->device, wgsl_rope_f16_ff_inplace, "rope_f16_ff_inplace", constants);
+}
+
+static void ggml_webgpu_init_glu_pipeline(webgpu_context & webgpu_ctx) {
+ std::vector<wgpu::ConstantEntry> constants = ggml_webgpu_wg_size_entry(WEBGPU_MAX_WG_SIZE);
+
+ // REGLU
+ webgpu_ctx->glu_pipelines[GGML_GLU_OP_REGLU][GGML_TYPE_F32][0] =
+ ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_reglu_f32, "reglu_f32", constants);
+ webgpu_ctx->glu_pipelines[GGML_GLU_OP_REGLU][GGML_TYPE_F16][0] =
+ ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_reglu_f16, "reglu_f16", constants);
+ webgpu_ctx->glu_pipelines[GGML_GLU_OP_REGLU][GGML_TYPE_F32][1] =
+ ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_reglu_f32_split, "reglu_f32_split", constants);
+ webgpu_ctx->glu_pipelines[GGML_GLU_OP_REGLU][GGML_TYPE_F16][1] =
+ ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_reglu_f16_split, "reglu_f16_split", constants);
+
+ // GEGLU
+ webgpu_ctx->glu_pipelines[GGML_GLU_OP_GEGLU][GGML_TYPE_F32][0] =
+ ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_geglu_f32, "geglu_f32", constants);
+ webgpu_ctx->glu_pipelines[GGML_GLU_OP_GEGLU][GGML_TYPE_F16][0] =
+ ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_geglu_f16, "geglu_f16", constants);
+ webgpu_ctx->glu_pipelines[GGML_GLU_OP_GEGLU][GGML_TYPE_F32][1] =
+ ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_geglu_f32_split, "geglu_f32_split", constants);
+ webgpu_ctx->glu_pipelines[GGML_GLU_OP_GEGLU][GGML_TYPE_F16][1] =
+ ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_geglu_f16_split, "geglu_f16_split", constants);
+
+ // SWIGLU
+ webgpu_ctx->glu_pipelines[GGML_GLU_OP_SWIGLU][GGML_TYPE_F32][0] =
+ ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_swiglu_f32, "swiglu_f32", constants);
+ webgpu_ctx->glu_pipelines[GGML_GLU_OP_SWIGLU][GGML_TYPE_F16][0] =
+ ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_swiglu_f16, "swiglu_f16", constants);
+ webgpu_ctx->glu_pipelines[GGML_GLU_OP_SWIGLU][GGML_TYPE_F32][1] = ggml_webgpu_create_pipeline(
+ webgpu_ctx->global_ctx->device, wgsl_swiglu_f32_split, "swiglu_f32_split", constants);
+ webgpu_ctx->glu_pipelines[GGML_GLU_OP_SWIGLU][GGML_TYPE_F16][1] = ggml_webgpu_create_pipeline(
+ webgpu_ctx->global_ctx->device, wgsl_swiglu_f16_split, "swiglu_f16_split", constants);
+
+ // SWIGLU_OAI
+ webgpu_ctx->glu_pipelines[GGML_GLU_OP_SWIGLU_OAI][GGML_TYPE_F32][0] =
+ ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_swiglu_oai_f32, "swiglu_oai_f32", constants);
+ webgpu_ctx->glu_pipelines[GGML_GLU_OP_SWIGLU_OAI][GGML_TYPE_F32][1] = ggml_webgpu_create_pipeline(
+ webgpu_ctx->global_ctx->device, wgsl_swiglu_oai_f32_split, "swiglu_oai_f32_split", constants);
+
+ // GEGLU_ERF
+ webgpu_ctx->glu_pipelines[GGML_GLU_OP_GEGLU_ERF][GGML_TYPE_F32][0] =
+ ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_geglu_erf_f32, "geglu_erf_f32", constants);
+ webgpu_ctx->glu_pipelines[GGML_GLU_OP_GEGLU_ERF][GGML_TYPE_F16][0] =
+ ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_geglu_erf_f16, "geglu_erf_f16", constants);
+ webgpu_ctx->glu_pipelines[GGML_GLU_OP_GEGLU_ERF][GGML_TYPE_F32][1] = ggml_webgpu_create_pipeline(
+ webgpu_ctx->global_ctx->device, wgsl_geglu_erf_f32_split, "geglu_erf_f32_split", constants);
+ webgpu_ctx->glu_pipelines[GGML_GLU_OP_GEGLU_ERF][GGML_TYPE_F16][1] = ggml_webgpu_create_pipeline(
+ webgpu_ctx->global_ctx->device, wgsl_geglu_erf_f16_split, "geglu_erf_f16_split", constants);
+
+ // GEGLU_QUICK
+ webgpu_ctx->glu_pipelines[GGML_GLU_OP_GEGLU_QUICK][GGML_TYPE_F32][0] =
+ ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_geglu_quick_f32, "geglu_quick_f32", constants);
+ webgpu_ctx->glu_pipelines[GGML_GLU_OP_GEGLU_QUICK][GGML_TYPE_F16][0] =
+ ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_geglu_quick_f16, "geglu_quick_f16", constants);
+ webgpu_ctx->glu_pipelines[GGML_GLU_OP_GEGLU_QUICK][GGML_TYPE_F32][1] = ggml_webgpu_create_pipeline(
+ webgpu_ctx->global_ctx->device, wgsl_geglu_quick_f32_split, "geglu_quick_f32_split", constants);
+ webgpu_ctx->glu_pipelines[GGML_GLU_OP_GEGLU_QUICK][GGML_TYPE_F16][1] = ggml_webgpu_create_pipeline(
+ webgpu_ctx->global_ctx->device, wgsl_geglu_quick_f16_split, "geglu_quick_f16_split", constants);
+}
+
+static void ggml_webgpu_init_scale_pipeline(webgpu_context & webgpu_ctx) {
+ std::vector<wgpu::ConstantEntry> constants = ggml_webgpu_wg_size_entry(WEBGPU_MAX_WG_SIZE);
+
+ webgpu_ctx->scale_pipelines[0] =
+ ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_scale_f32, "scale_f32", constants);
+ webgpu_ctx->scale_pipelines[1] = ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_scale_f32_inplace,
+ "scale_f32_inplace", constants);
+}
+
+static void ggml_webgpu_init_soft_max_pipeline(webgpu_context & webgpu_ctx) {
+ std::vector<wgpu::ConstantEntry> constants = ggml_webgpu_wg_size_entry(WEBGPU_ROW_SPLIT_WG_SIZE);
+
+    // no mask (mask_type = 2)
+ webgpu_ctx->soft_max_pipelines[2][0][0] =
+ ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_soft_max_f32, "soft_max_f32", constants);
+ webgpu_ctx->soft_max_pipelines[2][0][1] = ggml_webgpu_create_pipeline(
+ webgpu_ctx->global_ctx->device, wgsl_soft_max_f32_inplace, "soft_max_f32_inplace", constants);
+ webgpu_ctx->soft_max_pipelines[2][1][0] = ggml_webgpu_create_pipeline(
+ webgpu_ctx->global_ctx->device, wgsl_soft_max_f32_sink, "soft_max_f32_sink", constants);
+ webgpu_ctx->soft_max_pipelines[2][1][1] = ggml_webgpu_create_pipeline(
+ webgpu_ctx->global_ctx->device, wgsl_soft_max_f32_sink_inplace, "soft_max_f32_sink_inplace", constants);
+
+ // f32 mask (mask_type = 0)
+ webgpu_ctx->soft_max_pipelines[0][0][0] = ggml_webgpu_create_pipeline(
+ webgpu_ctx->global_ctx->device, wgsl_soft_max_f32_mask_f32, "soft_max_f32_mask_f32", constants);
+ webgpu_ctx->soft_max_pipelines[0][0][1] = ggml_webgpu_create_pipeline(
+ webgpu_ctx->global_ctx->device, wgsl_soft_max_f32_mask_f32_inplace, "soft_max_f32_mask_f32_inplace", constants);
+ webgpu_ctx->soft_max_pipelines[0][1][0] = ggml_webgpu_create_pipeline(
+ webgpu_ctx->global_ctx->device, wgsl_soft_max_f32_mask_f32_sink, "soft_max_f32_mask_f32_sink", constants);
+ webgpu_ctx->soft_max_pipelines[0][1][1] =
+ ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_soft_max_f32_mask_f32_sink_inplace,
+ "soft_max_f32_mask_f32_sink_inplace", constants);
+
+ // f16 mask (mask_type = 1)
+ webgpu_ctx->soft_max_pipelines[1][0][0] = ggml_webgpu_create_pipeline(
+ webgpu_ctx->global_ctx->device, wgsl_soft_max_f32_mask_f16, "soft_max_f32_mask_f16", constants);
+ webgpu_ctx->soft_max_pipelines[1][0][1] = ggml_webgpu_create_pipeline(
+ webgpu_ctx->global_ctx->device, wgsl_soft_max_f32_mask_f16_inplace, "soft_max_f32_mask_f16_inplace", constants);
+ webgpu_ctx->soft_max_pipelines[1][1][0] = ggml_webgpu_create_pipeline(
+ webgpu_ctx->global_ctx->device, wgsl_soft_max_f32_mask_f16_sink, "soft_max_f32_mask_f16_sink", constants);
+ webgpu_ctx->soft_max_pipelines[1][1][1] =
+ ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_soft_max_f32_mask_f16_sink_inplace,
+ "soft_max_f32_mask_f16_sink_inplace", constants);
+}
+
+static bool create_webgpu_device(ggml_backend_webgpu_reg_context * ctx) {
+ wgpu::RequestAdapterOptions options = {};
+
+#ifndef __EMSCRIPTEN__
+ // TODO: track need for these toggles: https://issues.chromium.org/issues/42251215
+ const char * const adapterEnabledToggles[] = { "vulkan_enable_f16_on_nvidia", "use_vulkan_memory_model" };
+ wgpu::DawnTogglesDescriptor adapterTogglesDesc;
+ adapterTogglesDesc.enabledToggles = adapterEnabledToggles;
+ adapterTogglesDesc.enabledToggleCount = 2;
+ options.nextInChain = &adapterTogglesDesc;
+#endif
+
+ ctx->webgpu_global_ctx->instance.WaitAny(
+ ctx->webgpu_global_ctx->instance.RequestAdapter(
+ &options, wgpu::CallbackMode::AllowSpontaneous,
+ [&ctx](wgpu::RequestAdapterStatus status, wgpu::Adapter adapter, const char * message) {
+ if (status != wgpu::RequestAdapterStatus::Success) {
+ GGML_LOG_ERROR("ggml_webgpu: Failed to get an adapter: %s\n", message);
+ return;
+ }
+ ctx->webgpu_global_ctx->adapter = std::move(adapter);
+ }),
+ UINT64_MAX);
+ GGML_ASSERT(ctx->webgpu_global_ctx->adapter != nullptr);
+
+ ctx->webgpu_global_ctx->adapter.GetLimits(&ctx->webgpu_global_ctx->capabilities.limits);
+
+ wgpu::AdapterInfo info{};
+#ifndef __EMSCRIPTEN__
+ wgpu::AdapterPropertiesSubgroupMatrixConfigs subgroup_matrix_configs{};
+ if (ctx->webgpu_global_ctx->adapter.HasFeature(wgpu::FeatureName::ChromiumExperimentalSubgroupMatrix)) {
+ info.nextInChain = &subgroup_matrix_configs;
+ }
+#endif
+ ctx->webgpu_global_ctx->adapter.GetInfo(&info);
+ wgpu::SupportedFeatures features;
+ ctx->webgpu_global_ctx->adapter.GetFeatures(&features);
+ // we require f16 support
+ GGML_ASSERT(ctx->webgpu_global_ctx->adapter.HasFeature(wgpu::FeatureName::ShaderF16));
+
+#ifndef __EMSCRIPTEN__
+ // Only support square f16 matrices of size 8 or 16 for now
+ bool valid_subgroup_matrix_config = false;
+ if (ctx->webgpu_global_ctx->adapter.HasFeature(wgpu::FeatureName::ChromiumExperimentalSubgroupMatrix)) {
+ for (size_t i = 0; i < subgroup_matrix_configs.configCount; i++) {
+ const wgpu::SubgroupMatrixConfig config = subgroup_matrix_configs.configs[i];
+ if (config.M == config.N && config.N == config.K && (config.K == 8 || config.K == 16) &&
+ config.componentType == wgpu::SubgroupMatrixComponentType::F16 &&
+ config.resultComponentType == wgpu::SubgroupMatrixComponentType::F16) {
+ ctx->webgpu_global_ctx->capabilities.sg_mat_m = config.M;
+ ctx->webgpu_global_ctx->capabilities.sg_mat_n = config.N;
+ ctx->webgpu_global_ctx->capabilities.sg_mat_k = config.K;
+ valid_subgroup_matrix_config = true;
+ break;
+ }
+ }
+ }
+ ctx->webgpu_global_ctx->capabilities.supports_subgroup_matrix = valid_subgroup_matrix_config;
+#endif
+
+    // Subgroup matrix code is most efficient when the exact subgroup size is known in
+    // advance. WebGPU does not guarantee a fixed subgroup size, so we fall back to the
+    // maximum subgroup size reported by the adapter.
+ ctx->webgpu_global_ctx->capabilities.max_subgroup_size = info.subgroupMaxSize;
+ // Initialize device
+ std::vector<wgpu::FeatureName> required_features = { wgpu::FeatureName::ShaderF16 };
+
+#ifndef __EMSCRIPTEN__
+ required_features.push_back(wgpu::FeatureName::ImplicitDeviceSynchronization);
+ if (ctx->webgpu_global_ctx->capabilities.supports_subgroup_matrix) {
+ required_features.push_back(wgpu::FeatureName::Subgroups);
+ required_features.push_back(wgpu::FeatureName::ChromiumExperimentalSubgroupMatrix);
+ }
+#endif
+
+#ifdef GGML_WEBGPU_GPU_PROFILE
+ required_features.push_back(wgpu::FeatureName::TimestampQuery);
+#endif
+
+ wgpu::DeviceDescriptor dev_desc;
+ dev_desc.requiredLimits = &ctx->webgpu_global_ctx->capabilities.limits;
+ dev_desc.requiredFeatures = required_features.data();
+ dev_desc.requiredFeatureCount = required_features.size();
+ dev_desc.SetDeviceLostCallback(
+ wgpu::CallbackMode::AllowSpontaneous,
+ [](const wgpu::Device & device, wgpu::DeviceLostReason reason, wgpu::StringView message) {
+ if (reason == wgpu::DeviceLostReason::Destroyed) {
+ return;
+ }
+ GGML_UNUSED(device);
+ GGML_LOG_ERROR("ggml_webgpu: Device lost! Reason: %d, Message: %s\n", static_cast<int>(reason),
+ std::string(message).c_str());
+ });
+ dev_desc.SetUncapturedErrorCallback(
+ [](const wgpu::Device & device, wgpu::ErrorType reason, wgpu::StringView message) {
+ GGML_UNUSED(device);
+ GGML_ABORT("ggml_webgpu: Device error! Reason: %d, Message: %s\n", static_cast<int>(reason),
+ std::string(message).c_str());
+ });
+
+#ifndef __EMSCRIPTEN__
+ // Enable Dawn-specific toggles to increase native performance
+ // TODO: Maybe WebGPU needs a "fast" mode where you can request compilers skip adding checks like these,
+ // only for native performance?
+ const char * const deviceEnabledToggles[] = { "skip_validation", "disable_robustness", "disable_workgroup_init",
+ "disable_polyfills_on_integer_div_and_mod" };
+ const char * const deviceDisabledToggles[] = { "timestamp_quantization" };
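+    // timestamp quantization is disabled so GPU timestamps remain precise enough for profiling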
+ wgpu::DawnTogglesDescriptor deviceTogglesDesc;
+ deviceTogglesDesc.enabledToggles = deviceEnabledToggles;
+ deviceTogglesDesc.enabledToggleCount = 4;
+ deviceTogglesDesc.disabledToggles = deviceDisabledToggles;
+ deviceTogglesDesc.disabledToggleCount = 1;
+
+ dev_desc.nextInChain = &deviceTogglesDesc;
+#endif
+
+ ctx->webgpu_global_ctx->instance.WaitAny(
+ ctx->webgpu_global_ctx->adapter.RequestDevice(
+ &dev_desc, wgpu::CallbackMode::AllowSpontaneous,
+ [ctx](wgpu::RequestDeviceStatus status, wgpu::Device device, wgpu::StringView message) {
+ if (status != wgpu::RequestDeviceStatus::Success) {
+ GGML_LOG_ERROR("ggml_webgpu: Failed to get a device: %s\n", std::string(message).c_str());
+ return;
+ }
+ ctx->webgpu_global_ctx->device = std::move(device);
+ }),
+ UINT64_MAX);
+ GGML_ASSERT(ctx->webgpu_global_ctx->device != nullptr);
+
+ ggml_webgpu_init_memset_pipeline(ctx->webgpu_global_ctx);
+ ctx->webgpu_global_ctx->memset_buf_pool.init(ctx->webgpu_global_ctx->device, 1, WEBGPU_PARAMS_BUF_SIZE_BYTES,
+ wgpu::BufferUsage::CopyDst | wgpu::BufferUsage::Uniform,
+ wgpu::BufferUsage::CopySrc | wgpu::BufferUsage::MapWrite);
+ ctx->webgpu_global_ctx->queue = ctx->webgpu_global_ctx->device.GetQueue();
+
+#ifdef GGML_WEBGPU_GPU_PROFILE
+ // Initialize buffer pool for timestamp queries, used for profiling
+ ctx->webgpu_global_ctx->timestamp_query_buf_pool.init(
+ ctx->webgpu_global_ctx->device, WEBGPU_NUM_TIMESTAMP_QUERY_BUFS, WEBGPU_TIMESTAMP_QUERY_BUF_SIZE_BYTES,
+ wgpu::BufferUsage::QueryResolve | wgpu::BufferUsage::CopySrc,
+ wgpu::BufferUsage::MapRead | wgpu::BufferUsage::CopyDst);
+#endif
+
+ GGML_LOG_INFO(
+ "ggml_webgpu: adapter_info: vendor_id: %u | vendor: %s | architecture: %s | device_id: %u | name: %s | "
+ "device_desc: %s\n",
+ info.vendorID, std::string(info.vendor).c_str(), std::string(info.architecture).c_str(), info.deviceID,
+ std::string(info.device).c_str(), std::string(info.description).c_str());
+ return true;
+}
+
+static webgpu_context initialize_webgpu_context(ggml_backend_dev_t dev) {
+ ggml_backend_webgpu_device_context * dev_ctx = (ggml_backend_webgpu_device_context *) dev->context;
+ webgpu_context webgpu_ctx = std::make_shared<webgpu_context_struct>();
+ webgpu_ctx->global_ctx = dev_ctx->webgpu_global_ctx;
+ webgpu_ctx->param_buf_pool.init(webgpu_ctx->global_ctx->device, WEBGPU_NUM_PARAM_BUFS, WEBGPU_PARAMS_BUF_SIZE_BYTES,
+ wgpu::BufferUsage::CopyDst | wgpu::BufferUsage::Uniform,
+ wgpu::BufferUsage::CopySrc | wgpu::BufferUsage::MapWrite);
+ webgpu_ctx->set_rows_error_buf_pool.init(webgpu_ctx->global_ctx->device, WEBGPU_NUM_SET_ROWS_ERROR_BUFS,
+ WEBGPU_SET_ROWS_ERROR_BUF_SIZE_BYTES,
+ wgpu::BufferUsage::CopySrc | wgpu::BufferUsage::Storage,
+ wgpu::BufferUsage::CopyDst | wgpu::BufferUsage::MapRead);
+
+ ggml_webgpu_init_mul_mat_pipeline(webgpu_ctx);
+ ggml_webgpu_init_get_rows_pipeline(webgpu_ctx);
+ ggml_webgpu_init_cpy_pipeline(webgpu_ctx);
+ ggml_webgpu_init_rms_norm_pipeline(webgpu_ctx);
+ ggml_webgpu_init_rope_pipeline(webgpu_ctx);
+ ggml_webgpu_init_glu_pipeline(webgpu_ctx);
+ ggml_webgpu_init_scale_pipeline(webgpu_ctx);
+ ggml_webgpu_init_soft_max_pipeline(webgpu_ctx);
+#ifdef GGML_WEBGPU_DEBUG
+ // Initialize debug buffers
+ ggml_webgpu_create_buffer(webgpu_ctx->global_ctx->device, webgpu_ctx->global_ctx->debug_host_buf,
+ WEBGPU_DEBUG_BUF_ELEMS * sizeof(uint32_t),
+ wgpu::BufferUsage::CopyDst | wgpu::BufferUsage::MapRead, "debug_host_buf");
+ ggml_webgpu_create_buffer(webgpu_ctx->global_ctx->device, webgpu_ctx->global_ctx->debug_dev_buf,
+ WEBGPU_DEBUG_BUF_ELEMS * sizeof(uint32_t),
+ wgpu::BufferUsage::Storage | wgpu::BufferUsage::CopySrc, "debug_dev_buf");
+#endif
+ return webgpu_ctx;
+}
+
+static ggml_backend_t ggml_backend_webgpu_backend_init(ggml_backend_dev_t dev, const char * params) {
+ GGML_UNUSED(params);
+
+ WEBGPU_LOG_DEBUG("ggml_backend_webgpu_backend_init()");
+
+ ggml_backend_webgpu_device_context * dev_ctx = static_cast<ggml_backend_webgpu_device_context *>(dev->context);
+
+ auto * backend_ctx = new ggml_backend_webgpu_context();
+ backend_ctx->name = GGML_WEBGPU_NAME + std::string(": ") + dev_ctx->device_name;
+ backend_ctx->webgpu_ctx = initialize_webgpu_context(dev);
+
+ // See GGML Backend Interface section
+ auto * backend = new ggml_backend();
+ *backend = {
+ /* .guid = */ ggml_backend_webgpu_guid(),
+ /* .interface = */ ggml_backend_webgpu_i,
+ /* .device = */ dev,
+ /* .context = */ backend_ctx,
+ };
+ return backend;
+}
+
+static ggml_backend_buffer_type_t ggml_backend_webgpu_device_get_buffer_type(ggml_backend_dev_t dev) {
+ // See GGML Backend Buffer Type Interface section
+
+ static struct ggml_backend_buffer_type ggml_backend_webgpu_buffer_type = {
+ /* .iface = */ {
+ /* .get_name = */ ggml_backend_webgpu_buffer_type_get_name,
+ /* .alloc_buffer = */ ggml_backend_webgpu_buffer_type_alloc_buffer,
+ /* .get_alignment = */ ggml_backend_webgpu_buffer_type_get_alignment,
+ /* .get_max_size = */ ggml_backend_webgpu_buffer_type_get_max_size,
+ /* .get_alloc_size = */ ggml_backend_webgpu_buffer_type_get_alloc_size,
+ /* .is_host = */ NULL, // defaults to false
+ },
+ /* .device = */
+ dev,
+ /* .context = */
+ NULL
+ };
+
+ return &ggml_backend_webgpu_buffer_type;
+}
+
+static bool ggml_backend_webgpu_device_supports_buft(ggml_backend_dev_t dev, ggml_backend_buffer_type_t buft) {
+ GGML_UNUSED(dev);
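+    // buffer types are matched by their get_name function pointer, which is unique per backend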
+ return buft->iface.get_name == ggml_backend_webgpu_buffer_type_get_name;
+}
+
+static bool ggml_webgpu_supported_qtype(ggml_type type) {
+ switch (type) {
+ case GGML_TYPE_Q4_0:
+ case GGML_TYPE_Q4_1:
+ case GGML_TYPE_Q5_0:
+ case GGML_TYPE_Q5_1:
+ case GGML_TYPE_Q8_0:
+ case GGML_TYPE_Q2_K:
+ case GGML_TYPE_Q3_K:
+ case GGML_TYPE_Q4_K:
+ case GGML_TYPE_Q5_K:
+ case GGML_TYPE_Q6_K:
+ case GGML_TYPE_IQ2_XXS:
+ case GGML_TYPE_IQ2_XS:
+ case GGML_TYPE_IQ2_S:
+ case GGML_TYPE_IQ3_XXS:
+ case GGML_TYPE_IQ3_S:
+ case GGML_TYPE_IQ1_S:
+ case GGML_TYPE_IQ1_M:
+ case GGML_TYPE_IQ4_NL:
+ case GGML_TYPE_IQ4_XS:
+ return true;
+ default:
+ return false;
+ }
+}
+
+static bool ggml_backend_webgpu_device_supports_op(ggml_backend_dev_t dev, const ggml_tensor * op) {
+ ggml_backend_webgpu_device_context * ctx = static_cast<ggml_backend_webgpu_device_context *>(dev->context);
+
+ ggml_tensor * src0 = op->src[0];
+ ggml_tensor * src1 = op->src[1];
+ ggml_tensor * src2 = op->src[2];
+
+ bool supports_op = false;
+ switch (op->op) {
+ case GGML_OP_NONE:
+ case GGML_OP_VIEW:
+ case GGML_OP_PERMUTE:
+ case GGML_OP_TRANSPOSE:
+ case GGML_OP_RESHAPE:
+ supports_op = true;
+ break;
+ case GGML_OP_ADD:
+ case GGML_OP_SUB:
+ case GGML_OP_MUL:
+ case GGML_OP_DIV:
+ // TODO: support non-contiguous tensors, e.g. for MOE_EXPERT_REDUCE
+ // see https://github.com/ggml-org/llama.cpp/pull/16857
+ supports_op = (op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_F16) && (src0->type == op->type) &&
+ (src1->type == op->type) && ggml_is_contiguous(src0) && ggml_is_contiguous(src1);
+ break;
+ case GGML_OP_CPY:
+ case GGML_OP_CONT:
+ supports_op = ((op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_F16) &&
+ (src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16)) ||
+ (op->type == GGML_TYPE_I32 && src0->type == GGML_TYPE_F32);
+ break;
+ case GGML_OP_SET_ROWS:
+ supports_op = ((op->type == GGML_TYPE_F16 || op->type == GGML_TYPE_F32) && src0->type == GGML_TYPE_F32 &&
+ (src1->type == GGML_TYPE_I64 || src1->type == GGML_TYPE_I32));
+ break;
+ case GGML_OP_GET_ROWS:
+ if (src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16 || ggml_webgpu_supported_qtype(src0->type)) {
+ supports_op = (op->type == GGML_TYPE_F32);
+ } else if (src0->type == GGML_TYPE_I32) {
+ supports_op = op->type == GGML_TYPE_I32;
+ }
+ break;
+ case GGML_OP_MUL_MAT:
+ {
+ switch (src1->type) {
+ case GGML_TYPE_F16:
+ supports_op |= (src0->type == GGML_TYPE_F16);
+ break;
+ case GGML_TYPE_F32:
+ switch (src0->type) {
+ case GGML_TYPE_F32:
+ case GGML_TYPE_F16:
+ case GGML_TYPE_Q4_0:
+ case GGML_TYPE_Q4_1:
+ case GGML_TYPE_Q5_0:
+ case GGML_TYPE_Q5_1:
+ case GGML_TYPE_Q8_0:
+ case GGML_TYPE_Q2_K:
+ case GGML_TYPE_Q3_K:
+ case GGML_TYPE_Q4_K:
+ case GGML_TYPE_Q5_K:
+ case GGML_TYPE_Q6_K:
+ case GGML_TYPE_IQ2_XXS:
+ case GGML_TYPE_IQ2_XS:
+ case GGML_TYPE_IQ2_S:
+ case GGML_TYPE_IQ3_XXS:
+ case GGML_TYPE_IQ3_S:
+ case GGML_TYPE_IQ1_S:
+ case GGML_TYPE_IQ1_M:
+ case GGML_TYPE_IQ4_NL:
+ case GGML_TYPE_IQ4_XS:
+ supports_op = true;
+ break;
+ default:
+ break;
+                        }
+                        break;
+                    default:
+                        break;
+ }
+ break;
+ }
+ case GGML_OP_FLASH_ATTN_EXT:
+ {
+#ifndef __EMSCRIPTEN__
+ if (!ctx->webgpu_global_ctx->capabilities.supports_subgroup_matrix) {
+ break;
+ }
+ // Head dimensions must fit in workgroup memory with minimum tile sizes
+ size_t limit_bytes = ctx->webgpu_global_ctx->capabilities.limits.maxComputeWorkgroupStorageSize;
+ const bool has_mask = op->src[3] != nullptr;
+ const bool kv_direct = src1->type == GGML_TYPE_F16 &&
+ (src0->ne[0] % ctx->webgpu_global_ctx->capabilities.sg_mat_k) == 0 &&
+ (src1->ne[1] % GGML_WEBGPU_KV_SEQ_PAD) == 0;
+ const size_t min_bytes = ggml_webgpu_flash_attn_wg_mem_bytes(
+ ctx->webgpu_global_ctx->capabilities.sg_mat_m, ctx->webgpu_global_ctx->capabilities.sg_mat_n,
+ (uint32_t) src0->ne[0], (uint32_t) src2->ne[0], has_mask, kv_direct);
+ if (min_bytes > limit_bytes) {
+ break;
+ }
+
+ supports_op = src0->type == GGML_TYPE_F32 &&
+ (src1->type == GGML_TYPE_F32 || src1->type == GGML_TYPE_F16 ||
+ src1->type == GGML_TYPE_Q4_0 || src1->type == GGML_TYPE_Q8_0) &&
+ src2->type == src1->type && op->type == GGML_TYPE_F32;
+#endif
+ break;
+ }
+ case GGML_OP_RMS_NORM:
+ supports_op = op->type == GGML_TYPE_F32 && src0->type == GGML_TYPE_F32;
+ break;
+ case GGML_OP_ROPE:
+ supports_op = op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_F16;
+ break;
+ case GGML_OP_GLU:
+ switch (ggml_get_glu_op(op)) {
+ case GGML_GLU_OP_REGLU:
+ case GGML_GLU_OP_GEGLU:
+ case GGML_GLU_OP_SWIGLU:
+ case GGML_GLU_OP_GEGLU_ERF:
+ case GGML_GLU_OP_GEGLU_QUICK:
+ supports_op = op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_F16;
+ break;
+ case GGML_GLU_OP_SWIGLU_OAI:
+ supports_op = op->type == GGML_TYPE_F32;
+ break;
+ default:
+ break;
+ }
+ break;
+ case GGML_OP_SCALE:
+ supports_op = op->type == GGML_TYPE_F32;
+ break;
+ case GGML_OP_SOFT_MAX:
+ supports_op = op->type == GGML_TYPE_F32;
+ break;
+ case GGML_OP_UNARY:
+ {
+ const ggml_unary_op UNARY_OP = ggml_get_unary_op(op);
+
+ switch (UNARY_OP) {
+ case GGML_UNARY_OP_ABS:
+ case GGML_UNARY_OP_SGN:
+ case GGML_UNARY_OP_NEG:
+ case GGML_UNARY_OP_STEP:
+ case GGML_UNARY_OP_TANH:
+ case GGML_UNARY_OP_ELU:
+ case GGML_UNARY_OP_RELU:
+ case GGML_UNARY_OP_SIGMOID:
+ case GGML_UNARY_OP_GELU:
+ case GGML_UNARY_OP_GELU_QUICK:
+ case GGML_UNARY_OP_SILU:
+ case GGML_UNARY_OP_HARDSWISH:
+ case GGML_UNARY_OP_HARDSIGMOID:
+ case GGML_UNARY_OP_EXP:
+ case GGML_UNARY_OP_GELU_ERF:
+ case GGML_UNARY_OP_SOFTPLUS:
+ case GGML_UNARY_OP_EXPM1:
+ case GGML_UNARY_OP_FLOOR:
+ case GGML_UNARY_OP_CEIL:
+ case GGML_UNARY_OP_ROUND:
+ case GGML_UNARY_OP_TRUNC:
+ case GGML_UNARY_OP_XIELU:
+ supports_op =
+ (op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_F16) && (src0->type == op->type);
+ break;
+ default:
+ break;
+ }
+ }
+ break;
+ case GGML_OP_CLAMP:
+ supports_op = (op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_F16) && (src0->type == op->type);
+ break;
+ case GGML_OP_FILL:
+ supports_op = op->type == GGML_TYPE_F32 && src0->type == GGML_TYPE_F32;
+ break;
+ case GGML_OP_LOG:
+ supports_op = (op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_F16) && (src0->type == op->type);
+ break;
+ case GGML_OP_PAD:
+ supports_op = op->type == GGML_TYPE_F32 && src0->type == GGML_TYPE_F32;
+ break;
+ case GGML_OP_ARGMAX:
+ supports_op = op->type == GGML_TYPE_I32 && src0->type == GGML_TYPE_F32;
+ break;
+ case GGML_OP_ARGSORT:
+ supports_op = op->type == GGML_TYPE_I32 && src0->type == GGML_TYPE_F32 && ggml_is_contiguous_rows(src0);
+ break;
+ case GGML_OP_TOP_K:
+ supports_op = op->type == GGML_TYPE_I32 && src0->type == GGML_TYPE_F32 && ggml_is_contiguous_rows(src0);
+ break;
+ case GGML_OP_CUMSUM:
+ supports_op = op->type == GGML_TYPE_F32 && src0->type == op->type;
+ break;
+ case GGML_OP_SUM:
+ case GGML_OP_SUM_ROWS:
+ supports_op = op->type == GGML_TYPE_F32 && src0->type == op->type && ggml_is_contiguous_rows(src0);
+ break;
+ default:
+ break;
+ }
+    // on smaller devices (or CI), tensors may be larger than the max storage buffer size;
+    // this check covers the destination and all present sources, including src2
+    if (ggml_nbytes(op) > ctx->webgpu_global_ctx->capabilities.limits.maxStorageBufferBindingSize ||
+        (src0 != nullptr &&
+         ggml_nbytes(src0) > ctx->webgpu_global_ctx->capabilities.limits.maxStorageBufferBindingSize) ||
+        (src1 != nullptr &&
+         ggml_nbytes(src1) > ctx->webgpu_global_ctx->capabilities.limits.maxStorageBufferBindingSize) ||
+        (src2 != nullptr &&
+         ggml_nbytes(src2) > ctx->webgpu_global_ctx->capabilities.limits.maxStorageBufferBindingSize)) {
+        supports_op = false;
+        WEBGPU_LOG_DEBUG("ggml_webgpu op not supported due to size: " << ggml_op_name(op->op));
+    }
+
+    WEBGPU_LOG_DEBUG("ggml_webgpu op " << (supports_op ? "supported" : "not supported") << ": "
+                     << ggml_op_name(op->op) << " with types dst: " << ggml_type_name(op->type)
+                     << ", src0: " << (src0 ? ggml_type_name(src0->type) : "null")
+                     << ", src1: " << (src1 ? ggml_type_name(src1->type) : "null"));
+ return supports_op;
+}
+
+static struct ggml_backend_device_i ggml_backend_webgpu_device_i = {
+ /* .get_name = */ ggml_backend_webgpu_device_get_name,
+ /* .get_description = */ ggml_backend_webgpu_device_get_description,
+ /* .get_memory = */ ggml_backend_webgpu_device_get_memory,
+ /* .get_type = */ ggml_backend_webgpu_device_get_type,
+ /* .get_props = */ ggml_backend_webgpu_device_get_props,
+ /* .init_backend = */ ggml_backend_webgpu_backend_init,
+ /* .get_buffer_type = */ ggml_backend_webgpu_device_get_buffer_type,
+ /* .get_host_buffer_type = */ NULL,
+ /* .buffer_from_host_ptr = */ NULL,
+ /* .supports_op = */ ggml_backend_webgpu_device_supports_op,
+ /* .supports_buft = */ ggml_backend_webgpu_device_supports_buft,
+ /* .offload_op = */ NULL,
+ /* .event_new = */ NULL,
+ /* .event_free = */ NULL,
+ /* .event_synchronize = */ NULL,
+};
+
+/* End GGML Backend Device Interface */
+
+/* GGML Backend Registration Interface */
+
+static const char * ggml_backend_webgpu_reg_get_name(ggml_backend_reg_t reg) {
+ ggml_backend_webgpu_reg_context * ctx = static_cast<ggml_backend_webgpu_reg_context *>(reg->context);
+ return ctx->name;
+}
+
+static size_t ggml_backend_webgpu_reg_get_device_count(ggml_backend_reg_t reg) {
+ ggml_backend_webgpu_reg_context * ctx = static_cast<ggml_backend_webgpu_reg_context *>(reg->context);
+ return ctx->device_count;
+}
+
+// Only one device is supported for now
+static ggml_backend_dev_t ggml_backend_webgpu_reg_get_device(ggml_backend_reg_t reg, size_t index) {
+ GGML_ASSERT(index == 0);
+ WEBGPU_LOG_DEBUG("ggml_backend_reg_get_device()");
+
+ WEBGPU_CPU_PROFILE_TOTAL_START(reg_get_device);
+
+ ggml_backend_webgpu_reg_context * reg_ctx = static_cast<ggml_backend_webgpu_reg_context *>(reg->context);
+
+ create_webgpu_device(reg_ctx);
+
+ static ggml_backend_webgpu_device_context device_ctx;
+ device_ctx.device_name = GGML_WEBGPU_NAME;
+ device_ctx.device_desc = GGML_WEBGPU_NAME;
+ device_ctx.webgpu_global_ctx = reg_ctx->webgpu_global_ctx;
+ // See GGML Backend Device Interface section
+ static ggml_backend_device device = {
+ /* .iface = */ ggml_backend_webgpu_device_i,
+ /* .reg = */ reg,
+ /* .context = */ &device_ctx,
+ };
+
+ WEBGPU_CPU_PROFILE_TOTAL_END(reg_get_device, reg_ctx->webgpu_global_ctx);
+ return &device;
+}
+
+static const struct ggml_backend_reg_i ggml_backend_webgpu_reg_i = {
+ /* .get_name = */ ggml_backend_webgpu_reg_get_name,
+ /* .get_device_count = */ ggml_backend_webgpu_reg_get_device_count,
+ /* .get_device = */ ggml_backend_webgpu_reg_get_device,
+ /* .get_proc_address = */ NULL,
+};
+
+/* End GGML Backend Registration Interface */
+
+ggml_backend_reg_t ggml_backend_webgpu_reg() {
+ WEBGPU_LOG_DEBUG("ggml_backend_webgpu_reg()");
+
+ static ggml_backend_webgpu_reg_context ctx;
+ ctx.name = GGML_WEBGPU_NAME;
+ ctx.device_count = 1;
+
+ wgpu::InstanceDescriptor instance_descriptor{};
+ std::vector<wgpu::InstanceFeatureName> instance_features = { wgpu::InstanceFeatureName::TimedWaitAny };
+ instance_descriptor.requiredFeatures = instance_features.data();
+ instance_descriptor.requiredFeatureCount = instance_features.size();
+
+#ifndef __EMSCRIPTEN__
+ const char * const instanceEnabledToggles[] = { "allow_unsafe_apis" };
+ wgpu::DawnTogglesDescriptor instanceTogglesDesc;
+ instanceTogglesDesc.enabledToggles = instanceEnabledToggles;
+ instanceTogglesDesc.enabledToggleCount = 1;
+ instance_descriptor.nextInChain = &instanceTogglesDesc;
+#endif
+
+ wgpu::Instance inst = wgpu::CreateInstance(&instance_descriptor);
+ ctx.webgpu_global_ctx = webgpu_global_context(new webgpu_global_context_struct());
+ ctx.webgpu_global_ctx->instance = std::move(inst);
+
+#ifdef __EMSCRIPTEN__
+ if (ctx.webgpu_global_ctx->instance == nullptr) {
+ GGML_LOG_ERROR("ggml_webgpu: Failed to create WebGPU instance. Make sure either -sASYNCIFY or -sJSPI is set\n");
+ return nullptr;
+ }
+#endif
+ GGML_ASSERT(ctx.webgpu_global_ctx->instance != nullptr);
+
+ static ggml_backend_reg reg = {
+ /* .api_version = */ GGML_BACKEND_API_VERSION,
+ /* .iface = */ ggml_backend_webgpu_reg_i,
+ /* .context = */ &ctx,
+ };
+ return &reg;
+}
+
+ggml_backend_t ggml_backend_webgpu_init(void) {
+ ggml_backend_dev_t dev = ggml_backend_reg_dev_get(ggml_backend_webgpu_reg(), 0);
+
+ return ggml_backend_webgpu_backend_init(dev, nullptr);
+}
+
+GGML_BACKEND_DL_IMPL(ggml_backend_webgpu_reg)
diff --git a/llama.cpp/ggml/src/ggml-webgpu/pre_wgsl.hpp b/llama.cpp/ggml/src/ggml-webgpu/pre_wgsl.hpp
new file mode 100644
index 0000000..4d43594
--- /dev/null
+++ b/llama.cpp/ggml/src/ggml-webgpu/pre_wgsl.hpp
@@ -0,0 +1,778 @@
+#ifndef PRE_WGSL_HPP
+#define PRE_WGSL_HPP
+
+#include <cctype>
+#include <fstream>
+#include <sstream>
+#include <stdexcept>
+#include <string>
+#include <string_view>
+#include <unordered_map>
+#include <unordered_set>
+#include <vector>
+
+namespace pre_wgsl {
+
+//==============================================================
+// Options
+//==============================================================
+struct Options {
+ std::string include_path = ".";
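+    // each entry is either "NAME" or "NAME=VALUE"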
+ std::vector<std::string> macros;
+};
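+
+// Illustrative usage sketch (the include path and macro values below are
+// examples, not fixed defaults):
+//
+//   pre_wgsl::Preprocessor pp({ /*include_path=*/ "wgsl-shaders",
+//                               /*macros=*/ { "WG_SIZE=256" } });
+//   std::string wgsl = pp.preprocess_file("argmax.wgsl", { "VEC4" });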
+
+//==============================================================
+// Utility: trim
+//==============================================================
+static std::string trim(const std::string & s) {
+ size_t a = 0;
+ while (a < s.size() && std::isspace((unsigned char) s[a])) {
+ a++;
+ }
+ size_t b = s.size();
+ while (b > a && std::isspace((unsigned char) s[b - 1])) {
+ b--;
+ }
+ return s.substr(a, b - a);
+}
+
+static std::string trim_value(std::istream & is) {
+ std::string str;
+ std::getline(is, str);
+ return trim(str);
+}
+
+static bool isIdentChar(char c) {
+ return std::isalnum(static_cast<unsigned char>(c)) || c == '_';
+}
+
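+// Token-wise expansion of object-like macros (function-like macros with
+// arguments are not supported); `visiting` detects self-referential
+// definitions.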
+static std::string expandMacrosRecursiveInternal(const std::string & line,
+ const std::unordered_map<std::string, std::string> & macros,
+ std::unordered_set<std::string> & visiting);
+
+static std::string expandMacroValue(const std::string & name,
+ const std::unordered_map<std::string, std::string> & macros,
+ std::unordered_set<std::string> & visiting) {
+ if (visiting.count(name)) {
+ throw std::runtime_error("Recursive macro: " + name);
+ }
+ visiting.insert(name);
+
+ auto it = macros.find(name);
+ if (it == macros.end()) {
+ visiting.erase(name);
+ return name;
+ }
+
+ const std::string & value = it->second;
+ if (value.empty()) {
+ visiting.erase(name);
+ return "";
+ }
+
+ std::string expanded = expandMacrosRecursiveInternal(value, macros, visiting);
+ visiting.erase(name);
+ return expanded;
+}
+
+static std::string expandMacrosRecursiveInternal(const std::string & line,
+ const std::unordered_map<std::string, std::string> & macros,
+ std::unordered_set<std::string> & visiting) {
+ std::string result;
+ result.reserve(line.size());
+
+ size_t i = 0;
+ while (i < line.size()) {
+ if (isIdentChar(line[i])) {
+ size_t start = i;
+ while (i < line.size() && isIdentChar(line[i])) {
+ i++;
+ }
+ std::string token = line.substr(start, i - start);
+
+ auto it = macros.find(token);
+ if (it != macros.end()) {
+ result += expandMacroValue(token, macros, visiting);
+ } else {
+ result += token;
+ }
+ } else {
+ result += line[i];
+ i++;
+ }
+ }
+
+ return result;
+}
+
+static std::string expandMacrosRecursive(const std::string & line,
+ const std::unordered_map<std::string, std::string> & macros) {
+ std::unordered_set<std::string> visiting;
+ return expandMacrosRecursiveInternal(line, macros, visiting);
+}
+
+//==============================================================
+// Tokenizer for expressions in #if/#elif
+//==============================================================
+class ExprLexer {
+ public:
+ enum Kind { END, IDENT, NUMBER, OP, LPAREN, RPAREN };
+
+ struct Tok {
+ Kind kind;
+ std::string text;
+ };
+
+ explicit ExprLexer(std::string_view sv) : src(sv), pos(0) {}
+
+ Tok next() {
+ skipWS();
+ if (pos >= src.size()) {
+ return { END, "" };
+ }
+
+ char c = src[pos];
+
+ // number
+ if (std::isdigit((unsigned char) c)) {
+ size_t start = pos;
+ while (pos < src.size() && std::isdigit((unsigned char) src[pos])) {
+ pos++;
+ }
+ return { NUMBER, std::string(src.substr(start, pos - start)) };
+ }
+
+ // identifier
+ if (std::isalpha((unsigned char) c) || c == '_') {
+ size_t start = pos;
+ while (pos < src.size() && (std::isalnum((unsigned char) src[pos]) || src[pos] == '_')) {
+ pos++;
+ }
+ return { IDENT, std::string(src.substr(start, pos - start)) };
+ }
+
+ if (c == '(') {
+ pos++;
+ return { LPAREN, "(" };
+ }
+ if (c == ')') {
+ pos++;
+ return { RPAREN, ")" };
+ }
+
+ // multi-char operators
+ static const char * two_ops[] = { "==", "!=", "<=", ">=", "&&", "||", "<<", ">>" };
+ for (auto op : two_ops) {
+ if (src.substr(pos, 2) == op) {
+ pos += 2;
+ return { OP, std::string(op) };
+ }
+ }
+
+ // single-char operators
+ if (std::string("+-*/%<>!").find(c) != std::string::npos) {
+ pos++;
+ return { OP, std::string(1, c) };
+ }
+
+ // unexpected
+ pos++;
+ return { END, "" };
+ }
+
+ private:
+ std::string_view src;
+ size_t pos;
+
+ void skipWS() {
+ while (pos < src.size() && std::isspace((unsigned char) src[pos])) {
+ pos++;
+ }
+ }
+};
+
+//==============================================================
+// Expression Parser (recursive descent)
+//==============================================================
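+// Binary operator precedence, lowest to highest: ||, &&, ==/!=, relational,
+// shifts, +/-, then *,/,%; division and modulo by zero evaluate to 0.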
+class ExprParser {
+ public:
+ ExprParser(std::string_view expr,
+ const std::unordered_map<std::string, std::string> & macros,
+ std::unordered_set<std::string> & visiting) :
+ lex(expr),
+ macros(macros),
+ visiting(visiting) {
+ advance();
+ }
+
+ int parse() { return parseLogicalOr(); }
+
+ private:
+ ExprLexer lex;
+ ExprLexer::Tok tok;
+ const std::unordered_map<std::string, std::string> & macros;
+ std::unordered_set<std::string> & visiting;
+
+ void advance() { tok = lex.next(); }
+
+ bool acceptOp(const std::string & s) {
+ if (tok.kind == ExprLexer::OP && tok.text == s) {
+ advance();
+ return true;
+ }
+ return false;
+ }
+
+ bool acceptKind(ExprLexer::Kind k) {
+ if (tok.kind == k) {
+ advance();
+ return true;
+ }
+ return false;
+ }
+
+ int parseLogicalOr() {
+ int v = parseLogicalAnd();
+ while (acceptOp("||")) {
+ int rhs = parseLogicalAnd();
+ v = (v || rhs);
+ }
+ return v;
+ }
+
+ int parseLogicalAnd() {
+ int v = parseEquality();
+ while (acceptOp("&&")) {
+ int rhs = parseEquality();
+ v = (v && rhs);
+ }
+ return v;
+ }
+
+ int parseEquality() {
+ int v = parseRelational();
+ for (;;) {
+ if (acceptOp("==")) {
+ int rhs = parseRelational();
+ v = (v == rhs);
+ } else if (acceptOp("!=")) {
+ int rhs = parseRelational();
+ v = (v != rhs);
+ } else {
+ break;
+ }
+ }
+ return v;
+ }
+
+ int parseRelational() {
+ int v = parseShift();
+ for (;;) {
+ if (acceptOp("<")) {
+ int rhs = parseShift();
+ v = (v < rhs);
+ } else if (acceptOp(">")) {
+ int rhs = parseShift();
+ v = (v > rhs);
+ } else if (acceptOp("<=")) {
+ int rhs = parseShift();
+ v = (v <= rhs);
+ } else if (acceptOp(">=")) {
+ int rhs = parseShift();
+ v = (v >= rhs);
+ } else {
+ break;
+ }
+ }
+ return v;
+ }
+
+ int parseShift() {
+ int v = parseAdd();
+ for (;;) {
+ if (acceptOp("<<")) {
+ int rhs = parseAdd();
+ v = (v << rhs);
+ } else if (acceptOp(">>")) {
+ int rhs = parseAdd();
+ v = (v >> rhs);
+ } else {
+ break;
+ }
+ }
+ return v;
+ }
+
+ int parseAdd() {
+ int v = parseMult();
+ for (;;) {
+ if (acceptOp("+")) {
+ int rhs = parseMult();
+ v = (v + rhs);
+ } else if (acceptOp("-")) {
+ int rhs = parseMult();
+ v = (v - rhs);
+ } else {
+ break;
+ }
+ }
+ return v;
+ }
+
+ int parseMult() {
+ int v = parseUnary();
+ for (;;) {
+ if (acceptOp("*")) {
+ int rhs = parseUnary();
+ v = (v * rhs);
+ } else if (acceptOp("/")) {
+ int rhs = parseUnary();
+ v = (rhs == 0 ? 0 : v / rhs);
+ } else if (acceptOp("%")) {
+ int rhs = parseUnary();
+ v = (rhs == 0 ? 0 : v % rhs);
+ } else {
+ break;
+ }
+ }
+ return v;
+ }
+
+ int parseUnary() {
+ if (acceptOp("!")) {
+ return !parseUnary();
+ }
+ if (acceptOp("-")) {
+ return -parseUnary();
+ }
+ if (acceptOp("+")) {
+ return +parseUnary();
+ }
+ return parsePrimary();
+ }
+
+ int parsePrimary() {
+ // '(' expr ')'
+ if (acceptKind(ExprLexer::LPAREN)) {
+ int v = parse();
+ if (!acceptKind(ExprLexer::RPAREN)) {
+ throw std::runtime_error("missing ')'");
+ }
+ return v;
+ }
+
+ // number
+ if (tok.kind == ExprLexer::NUMBER) {
+ int v = std::stoi(tok.text);
+ advance();
+ return v;
+ }
+
+ // defined(identifier)
+ if (tok.kind == ExprLexer::IDENT && tok.text == "defined") {
+ advance();
+ if (acceptKind(ExprLexer::LPAREN)) {
+ if (tok.kind != ExprLexer::IDENT) {
+ throw std::runtime_error("expected identifier in defined()");
+ }
+ std::string name = tok.text;
+ advance();
+ if (!acceptKind(ExprLexer::RPAREN)) {
+ throw std::runtime_error("missing ) in defined()");
+ }
+ return macros.count(name) ? 1 : 0;
+ } else {
+ // defined NAME
+ if (tok.kind != ExprLexer::IDENT) {
+ throw std::runtime_error("expected identifier in defined NAME");
+ }
+ std::string name = tok.text;
+ advance();
+ return macros.count(name) ? 1 : 0;
+ }
+ }
+
+        // bare identifier: undefined -> 0, defined but empty -> 1, otherwise evaluate its value
+ if (tok.kind == ExprLexer::IDENT) {
+ std::string name = tok.text;
+ advance();
+ auto it = macros.find(name);
+ if (it == macros.end()) {
+ return 0;
+ }
+ if (it->second.empty()) {
+ return 1;
+ }
+ return evalMacroExpression(name, it->second);
+ }
+
+ // unexpected
+ return 0;
+ }
+
+ int evalMacroExpression(const std::string & name, const std::string & value) {
+ if (visiting.count(name)) {
+ throw std::runtime_error("Recursive macro: " + name);
+ }
+
+ visiting.insert(name);
+ ExprParser ep(value, macros, visiting);
+ int v = ep.parse();
+ visiting.erase(name);
+ return v;
+ }
+};
+
+//==============================================================
+// Preprocessor
+//==============================================================
+class Preprocessor {
+ public:
+ explicit Preprocessor(Options opts = {}) : opts_(std::move(opts)) {
+ // Treat empty include path as current directory
+ if (opts_.include_path.empty()) {
+ opts_.include_path = ".";
+ }
+ parseMacroDefinitions(opts_.macros);
+ }
+
+ std::string preprocess_file(const std::string & filename, const std::vector<std::string> & additional_macros = {}) {
+ std::unordered_map<std::string, std::string> macros;
+ std::unordered_set<std::string> predefined;
+ std::unordered_set<std::string> include_stack;
+ buildMacros(additional_macros, macros, predefined);
+
+ std::string result = processFile(filename, macros, predefined, include_stack, DirectiveMode::All);
+ return result;
+ }
+
+ std::string preprocess(const std::string & contents, const std::vector<std::string> & additional_macros = {}) {
+ std::unordered_map<std::string, std::string> macros;
+ std::unordered_set<std::string> predefined;
+ std::unordered_set<std::string> include_stack;
+ buildMacros(additional_macros, macros, predefined);
+
+ std::string result = processString(contents, macros, predefined, include_stack, DirectiveMode::All);
+ return result;
+ }
+
+ std::string preprocess_includes_file(const std::string & filename) {
+ std::unordered_map<std::string, std::string> macros;
+ std::unordered_set<std::string> predefined;
+ std::unordered_set<std::string> include_stack;
+ std::string result = processFile(filename, macros, predefined, include_stack, DirectiveMode::IncludesOnly);
+ return result;
+ }
+
+ std::string preprocess_includes(const std::string & contents) {
+ std::unordered_map<std::string, std::string> macros;
+ std::unordered_set<std::string> predefined;
+ std::unordered_set<std::string> include_stack;
+ std::string result = processString(contents, macros, predefined, include_stack, DirectiveMode::IncludesOnly);
+ return result;
+ }
+
+ private:
+ Options opts_;
+ std::unordered_map<std::string, std::string> global_macros;
+
+ enum class DirectiveMode { All, IncludesOnly };
+
+    struct Cond {
+        bool parent_active;  // enclosing conditional is active
+        bool active;         // this branch currently emits output
+        bool taken;          // a branch of this #if chain has already matched
+    };
+
+ //----------------------------------------------------------
+ // Parse macro definitions into global_macros
+ //----------------------------------------------------------
+ void parseMacroDefinitions(const std::vector<std::string> & macro_defs) {
+ for (const auto & def : macro_defs) {
+ size_t eq_pos = def.find('=');
+ if (eq_pos != std::string::npos) {
+ // Format: NAME=VALUE
+ std::string name = trim(def.substr(0, eq_pos));
+ std::string value = trim(def.substr(eq_pos + 1));
+ global_macros[name] = value;
+ } else {
+ // Format: NAME
+ std::string name = trim(def);
+ global_macros[name] = "";
+ }
+ }
+ }
+
+ //----------------------------------------------------------
+ // Build combined macro map and predefined set for a preprocessing operation
+ //----------------------------------------------------------
+ void buildMacros(const std::vector<std::string> & additional_macros,
+ std::unordered_map<std::string, std::string> & macros,
+ std::unordered_set<std::string> & predefined) {
+ macros = global_macros;
+ predefined.clear();
+
+ for (const auto & [name, value] : global_macros) {
+ predefined.insert(name);
+ }
+
+ for (const auto & def : additional_macros) {
+ size_t eq_pos = def.find('=');
+ std::string name, value;
+ if (eq_pos != std::string::npos) {
+ name = trim(def.substr(0, eq_pos));
+ value = trim(def.substr(eq_pos + 1));
+ } else {
+ name = trim(def);
+ value = "";
+ }
+
+ // Add to macros map (will override global if same name)
+ macros[name] = value;
+ predefined.insert(name);
+ }
+ }
+
+ //----------------------------------------------------------
+ // Helpers
+ //----------------------------------------------------------
+ std::string loadFile(const std::string & fname) {
+ std::ifstream f(fname);
+ if (!f.is_open()) {
+ throw std::runtime_error("Could not open file: " + fname);
+ }
+ std::stringstream ss;
+ ss << f.rdbuf();
+ return ss.str();
+ }
+
+ bool condActive(const std::vector<Cond> & cond) const {
+ if (cond.empty()) {
+ return true;
+ }
+ return cond.back().active;
+ }
+
+ //----------------------------------------------------------
+ // Process a file
+ //----------------------------------------------------------
+ std::string processFile(const std::string & name,
+ std::unordered_map<std::string, std::string> & macros,
+ const std::unordered_set<std::string> & predefined_macros,
+ std::unordered_set<std::string> & include_stack,
+ DirectiveMode mode) {
+ if (include_stack.count(name)) {
+ throw std::runtime_error("Recursive include: " + name);
+ }
+
+ include_stack.insert(name);
+ std::string shader_code = loadFile(name);
+ std::string out = processString(shader_code, macros, predefined_macros, include_stack, mode);
+ include_stack.erase(name);
+ return out;
+ }
+
+ std::string processIncludeFile(const std::string & fname,
+ std::unordered_map<std::string, std::string> & macros,
+ const std::unordered_set<std::string> & predefined_macros,
+ std::unordered_set<std::string> & include_stack,
+ DirectiveMode mode) {
+ std::string full_path = opts_.include_path + "/" + fname;
+ return processFile(full_path, macros, predefined_macros, include_stack, mode);
+ }
+
+ //----------------------------------------------------------
+ // Process text
+ //----------------------------------------------------------
+ std::string processString(const std::string & shader_code,
+ std::unordered_map<std::string, std::string> & macros,
+ const std::unordered_set<std::string> & predefined_macros,
+ std::unordered_set<std::string> & include_stack,
+ DirectiveMode mode) {
+ std::vector<Cond> cond; // Conditional stack for this shader
+ std::stringstream out;
+ std::istringstream in(shader_code);
+ std::string line;
+
+ while (std::getline(in, line)) {
+ std::string t = trim(line);
+
+ if (!t.empty() && t[0] == '#') {
+ bool handled = handleDirective(t, out, macros, predefined_macros, cond, include_stack, mode);
+ if (mode == DirectiveMode::IncludesOnly && !handled) {
+ out << line << "\n";
+ }
+ } else {
+ if (mode == DirectiveMode::IncludesOnly) {
+ out << line << "\n";
+ } else if (condActive(cond)) {
+ // Expand macros in the line before outputting
+ std::string expanded = expandMacrosRecursive(line, macros);
+ out << expanded << "\n";
+ }
+ }
+ }
+
+ if (mode == DirectiveMode::All && !cond.empty()) {
+ throw std::runtime_error("Unclosed #if directive");
+ }
+
+ return out.str();
+ }
+
+ //----------------------------------------------------------
+ // Directive handler
+ //----------------------------------------------------------
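+    // Recognizes: include, define, undef, ifdef, ifndef, if, elif, else and
+    // endif. Returns true when the directive was consumed; in IncludesOnly
+    // mode every directive except #include is left for the caller to echo.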
+ bool handleDirective(const std::string & t,
+ std::stringstream & out,
+ std::unordered_map<std::string, std::string> & macros,
+ const std::unordered_set<std::string> & predefined_macros,
+ std::vector<Cond> & cond,
+ std::unordered_set<std::string> & include_stack,
+ DirectiveMode mode) {
+ // split into tokens
+ std::string body = t.substr(1);
+ std::istringstream iss(body);
+ std::string cmd;
+ iss >> cmd;
+
+ if (cmd == "include") {
+ if (mode == DirectiveMode::All && !condActive(cond)) {
+ return true;
+ }
+ std::string file;
+ iss >> file;
+ if (file.size() >= 2 && file.front() == '"' && file.back() == '"') {
+ file = file.substr(1, file.size() - 2);
+ }
+ out << processIncludeFile(file, macros, predefined_macros, include_stack, mode);
+ return true;
+ }
+
+ if (mode == DirectiveMode::IncludesOnly) {
+ return false;
+ }
+
+ if (cmd == "define") {
+ if (!condActive(cond)) {
+ return true;
+ }
+ std::string name;
+ iss >> name;
+ // Don't override predefined macros from options
+ if (predefined_macros.count(name)) {
+ return true;
+ }
+ std::string value = trim_value(iss);
+ macros[name] = value;
+ return true;
+ }
+
+ if (cmd == "undef") {
+ if (!condActive(cond)) {
+ return true;
+ }
+ std::string name;
+ iss >> name;
+ // Don't undef predefined macros from options
+ if (predefined_macros.count(name)) {
+ return true;
+ }
+ macros.erase(name);
+ return true;
+ }
+
+ if (cmd == "ifdef") {
+ std::string name;
+ iss >> name;
+ bool p = condActive(cond);
+ bool v = macros.count(name);
+ cond.push_back({ p, p && v, p && v });
+ return true;
+ }
+
+ if (cmd == "ifndef") {
+ std::string name;
+ iss >> name;
+ bool p = condActive(cond);
+ bool v = !macros.count(name);
+ cond.push_back({ p, p && v, p && v });
+ return true;
+ }
+
+ if (cmd == "if") {
+ std::string expr = trim_value(iss);
+ bool p = condActive(cond);
+ bool v = false;
+ if (p) {
+ std::unordered_set<std::string> visiting;
+ ExprParser ep(expr, macros, visiting);
+ v = ep.parse() != 0;
+ }
+ cond.push_back({ p, p && v, p && v });
+ return true;
+ }
+
+ if (cmd == "elif") {
+ std::string expr = trim_value(iss);
+
+ if (cond.empty()) {
+ throw std::runtime_error("#elif without #if");
+ }
+
+ Cond & c = cond.back();
+ if (!c.parent_active) {
+ c.active = false;
+ return true;
+ }
+
+ if (c.taken) {
+ c.active = false;
+ return true;
+ }
+
+ std::unordered_set<std::string> visiting;
+ ExprParser ep(expr, macros, visiting);
+ bool v = ep.parse() != 0;
+ c.active = v;
+ if (v) {
+ c.taken = true;
+ }
+ return true;
+ }
+
+ if (cmd == "else") {
+ if (cond.empty()) {
+ throw std::runtime_error("#else without #if");
+ }
+
+ Cond & c = cond.back();
+ if (!c.parent_active) {
+ c.active = false;
+ return true;
+ }
+ if (c.taken) {
+ c.active = false;
+ } else {
+ c.active = true;
+ c.taken = true;
+ }
+ return true;
+ }
+
+ if (cmd == "endif") {
+ if (cond.empty()) {
+ throw std::runtime_error("#endif without #if");
+ }
+ cond.pop_back();
+ return true;
+ }
+
+ // Unknown directive
+ throw std::runtime_error("Unknown directive: #" + cmd);
+ }
+};
+
+} // namespace pre_wgsl
+
+#endif // PRE_WGSL_HPP
diff --git a/llama.cpp/ggml/src/ggml-webgpu/wgsl-shaders/argmax.wgsl b/llama.cpp/ggml/src/ggml-webgpu/wgsl-shaders/argmax.wgsl
new file mode 100644
index 0000000..ca5bfcc
--- /dev/null
+++ b/llama.cpp/ggml/src/ggml-webgpu/wgsl-shaders/argmax.wgsl
@@ -0,0 +1,72 @@
+@group(0) @binding(0)
+#ifdef VEC4
+var<storage, read_write> src: array<vec4<f32>>;
+#define VEC_SIZE 4
+#else
+var<storage, read_write> src: array<f32>;
+#define VEC_SIZE 1
+#endif
+
+@group(0) @binding(1)
+var<storage, read_write> dst: array<i32>;
+
+struct Params {
+ offset_src: u32, // in elements
+ offset_dst: u32, // in elements
+ ne0: u32,
+};
+
+@group(0) @binding(2)
+var<uniform> params: Params;
+
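+// sentinel initial value for the running maximum (an arbitrary very negative
+// value, not the true f32 minimum)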
+const FLOAT_MIN: f32 = -1.0e9;
+
+struct Pair {
+ value: f32,
+ index: i32
+};
+
+var<workgroup> shared_max: array<Pair, WG_SIZE>;
+
+@compute @workgroup_size(WG_SIZE)
+fn main(@builtin(workgroup_id) wid: vec3<u32>,
+ @builtin(local_invocation_id) lid: vec3<u32>) {
+ let row_idx = params.offset_src + wid.x * params.ne0;
+ var local_pair = Pair(FLOAT_MIN, -1);
+#ifdef VEC4
+ for (var col = lid.x; col < params.ne0/VEC_SIZE; col += WG_SIZE) {
+ let vec_val = src[row_idx / VEC_SIZE + col];
+ for (var v = 0u; v < VEC_SIZE; v++) {
+ let val = vec_val[v];
+ if (val >= local_pair.value) {
+ local_pair = Pair(val, i32(col * VEC_SIZE + v));
+ }
+ }
+ }
+#else
+ for (var col = lid.x; col < params.ne0; col += WG_SIZE) {
+ if (src[row_idx + col] >= local_pair.value) {
+ local_pair = Pair(src[row_idx + col], i32(col));
+ }
+ }
+#endif
+ shared_max[lid.x] = local_pair;
+ workgroupBarrier();
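+    // tree reduction: halve the stride each step, keeping the larger value
+    // and, on ties, the larger index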
+ var offset: u32 = WG_SIZE >> 1;
+ while (offset > 0) {
+ if (lid.x < offset) {
+ let a = shared_max[lid.x];
+ let b = shared_max[lid.x + offset];
+ if (b.value > a.value) {
+ shared_max[lid.x] = b;
+ } else if (b.value == a.value && b.index > a.index) {
+ shared_max[lid.x] = b;
+ }
+ }
+ workgroupBarrier();
+ offset >>= 1;
+ }
+ if (lid.x == 0u) {
+ dst[params.offset_dst + wid.x] = shared_max[0].index;
+ }
+}
diff --git a/llama.cpp/ggml/src/ggml-webgpu/wgsl-shaders/argsort.wgsl b/llama.cpp/ggml/src/ggml-webgpu/wgsl-shaders/argsort.wgsl
new file mode 100644
index 0000000..46ed19f
--- /dev/null
+++ b/llama.cpp/ggml/src/ggml-webgpu/wgsl-shaders/argsort.wgsl
@@ -0,0 +1,106 @@
+@group(0) @binding(0)
+var<storage, read_write> src: array<f32>;
+
+@group(0) @binding(1)
+var<storage, read_write> dst: array<i32>;
+
+struct Params {
+ offset_src: u32, // in elements
+ offset_dst: u32, // in elements
+
+ stride_src1: u32,
+ stride_src2: u32,
+ stride_src3: u32,
+
+ stride_dst1: u32,
+ stride_dst2: u32,
+ stride_dst3: u32,
+
+ // src/dst dimensions
+ src_ne0: u32,
+ ne1: u32,
+ ne2: u32,
+
+ ne0: u32,
+ top_k: u32,
+
+ npr: u32, // tiles per row
+ nrows: u32
+};
+
+@group(0) @binding(2)
+var<uniform> params: Params;
+
+var<workgroup> shmem_idx: array<u32, WG_SIZE>;
+
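+// ORDER == 0 selects ascending order; EXTREME_VALUE pads out-of-range lanes
+// so they always sort to the end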
+#if ORDER == 0
+#define EXTREME_VALUE 1e30
+#define SWAP_COMPARE_UP >
+#define SWAP_COMPARE_DOWN <
+#else
+#define EXTREME_VALUE -1e30
+#define SWAP_COMPARE_UP <
+#define SWAP_COMPARE_DOWN >
+#endif
+
+@compute @workgroup_size(WG_SIZE)
+fn main(@builtin(workgroup_id) wid: vec3<u32>,
+ @builtin(num_workgroups) num_wg: vec3<u32>,
+ @builtin(local_invocation_id) lid: vec3<u32>) {
+ let linear = wid.x + wid.y * num_wg.x;
+ // guard against overprovisioned workgroups
+ if (linear >= params.npr * params.nrows) {
+ return;
+ }
+ let tile = linear % params.npr;
+ var row = linear / params.npr;
+ let i3 = row / (params.ne2 * params.ne1);
+ row = row % (params.ne2 * params.ne1);
+ let i2 = row / params.ne1;
+ let i1 = row % params.ne1;
+
+ let row_base = params.offset_src +
+ i1 * params.stride_src1 +
+ i2 * params.stride_src2 +
+ i3 * params.stride_src3;
+
+ let tile_base = tile * WG_SIZE;
+ let idx = tile_base + lid.x;
+ shmem_idx[lid.x] = select(params.src_ne0, idx, idx < params.src_ne0);
+ workgroupBarrier();
+
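+    // bitonic sort of the tile's indices in shared memory; out-of-range lanes
+    // hold the sentinel index src_ne0, which compares as EXTREME_VALUE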
+ var k = 2u;
+ while (k <= WG_SIZE) {
+ var j = k >> 1;
+ while (j > 0) {
+ let ixj = lid.x ^ j;
+ if (ixj > lid.x) {
+ let dir_up = (lid.x & k) == 0;
+ let a_idx = shmem_idx[lid.x];
+ let b_idx = shmem_idx[ixj];
+ let a_val = select(EXTREME_VALUE, src[row_base + a_idx], a_idx < params.src_ne0);
+ let b_val = select(EXTREME_VALUE, src[row_base + b_idx], b_idx < params.src_ne0);
+ let should_swap = select(
+ (a_val SWAP_COMPARE_DOWN b_val),
+ (a_val SWAP_COMPARE_UP b_val),
+ dir_up);
+ if (should_swap) {
+ shmem_idx[lid.x] = b_idx;
+ shmem_idx[ixj] = a_idx;
+ }
+ }
+ workgroupBarrier();
+ j >>= 1;
+ }
+ k <<= 1;
+ }
+
+ let out_idx = tile * params.top_k + lid.x;
+ if (out_idx < params.ne0 && lid.x < params.top_k) {
+ let row_dst = params.offset_dst +
+ i1 * params.stride_dst1 +
+ i2 * params.stride_dst2 +
+ i3 * params.stride_dst3;
+ dst[row_dst + out_idx] = i32(shmem_idx[lid.x]);
+ }
+}
diff --git a/llama.cpp/ggml/src/ggml-webgpu/wgsl-shaders/argsort_merge.wgsl b/llama.cpp/ggml/src/ggml-webgpu/wgsl-shaders/argsort_merge.wgsl
new file mode 100644
index 0000000..9a77f6e
--- /dev/null
+++ b/llama.cpp/ggml/src/ggml-webgpu/wgsl-shaders/argsort_merge.wgsl
@@ -0,0 +1,134 @@
+@group(0) @binding(0)
+var<storage, read_write> src: array<f32>;
+
+@group(0) @binding(1)
+var<storage, read_write> idx_in: array<i32>;
+
+@group(0) @binding(2)
+var<storage, read_write> idx_out: array<i32>;
+
+struct Params {
+ offset_src: u32, // in elements
+ offset_in: u32, // in elements
+ offset_out: u32, // in elements
+
+ stride_src1: u32,
+ stride_src2: u32,
+ stride_src3: u32,
+
+ stride_idx1: u32,
+ stride_idx2: u32,
+ stride_idx3: u32,
+
+ stride_out1: u32,
+ stride_out2: u32,
+ stride_out3: u32,
+
+ ne0: u32,
+ ne1: u32,
+ ne2: u32,
+
+ top_k: u32,
+
+ len: u32,
+ nm: u32,
+ nrows: u32
+};
+
+@group(0) @binding(3)
+var<uniform> params: Params;
+
+fn take_left(a_idx: i32, b_idx: i32, row_base: u32) -> bool {
+ let a_val = src[row_base + u32(a_idx)];
+ let b_val = src[row_base + u32(b_idx)];
+#if ORDER == 0
+ return a_val <= b_val;
+#else
+ return a_val >= b_val;
+#endif
+}
+
+@compute @workgroup_size(WG_SIZE)
+fn main(@builtin(workgroup_id) wid: vec3<u32>,
+ @builtin(num_workgroups) num_wg: vec3<u32>,
+ @builtin(local_invocation_id) lid: vec3<u32>) {
+ let linear = wid.x + wid.y * num_wg.x;
+ // guard against overprovisioned workgroups
+ if (linear >= params.nm * params.nrows) {
+ return;
+ }
+
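+    // each workgroup merges one pair of adjacent sorted runs of length `len`
+    // within a row, emitting at most the first top_k elements of the result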
+ let start = (linear % params.nm) * params.len * 2;
+ let len0 = min(params.len, params.ne0 - start);
+ let rem1 = select(0, params.ne0 - (start + params.len), params.ne0 > (start + params.len));
+ let len1 = min(params.len, rem1);
+ let total = len0 + len1;
+ let chunk = (total + WG_SIZE - 1u) / WG_SIZE;
+ let k0 = lid.x * chunk;
+ let k1 = min(min(k0 + chunk, total), params.top_k);
+ // guard against overprovisioned threads
+ if (k0 >= params.top_k || k0 >= total) {
+ return;
+ }
+
+ var row = linear / params.nm;
+ let i3 = row / (params.ne2 * params.ne1);
+ row = row % (params.ne2 * params.ne1);
+ let i2 = row / params.ne1;
+ let i1 = row % params.ne1;
+
+ let row_src = params.offset_src +
+ i1 * params.stride_src1 +
+ i2 * params.stride_src2 +
+ i3 * params.stride_src3;
+
+ let row_in = params.offset_in +
+ i1 * params.stride_idx1 +
+ i2 * params.stride_idx2 +
+ i3 * params.stride_idx3;
+
+ let row_out = params.offset_out +
+ i1 * params.stride_out1 +
+ i2 * params.stride_out2 +
+ i3 * params.stride_out3;
+
+
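+    // merge path: binary-search the split point in the left run such that
+    // exactly k0 merged elements precede this thread's first output slot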
+ var low: u32 = select(0, k0 - len1, k0 > len1);
+ var high: u32 = min(k0, len0);
+
+ while (low < high) {
+ let mid = (low + high) >> 1;
+ let idx0 = idx_in[row_in + start + mid];
+ let idx1 = idx_in[row_in + start + params.len + (k0 - mid - 1)];
+ if (take_left(idx0, idx1, row_src)) {
+ low = mid + 1;
+ } else {
+ high = mid;
+ }
+ }
+
+ var i = low;
+ var j = k0 - i;
+ var k = k0;
+ while (k < k1) {
+ var take_l = false;
+ if (i >= len0) {
+ take_l = false;
+ } else if (j >= len1) {
+ take_l = true;
+ } else {
+ let idx0 = idx_in[row_in + start + i];
+ let idx1 = idx_in[row_in + start + params.len + j];
+ take_l = take_left(idx0, idx1, row_src);
+ }
+
+ let out_idx = select(
+ idx_in[row_in + start + params.len + j],
+ idx_in[row_in + start + i],
+ take_l);
+ idx_out[row_out + start + k] = out_idx;
+ i = select(i, i + 1, take_l);
+ j = select(j + 1, j, take_l);
+ k += 1;
+ }
+}
diff --git a/llama.cpp/ggml/src/ggml-webgpu/wgsl-shaders/binary.wgsl b/llama.cpp/ggml/src/ggml-webgpu/wgsl-shaders/binary.wgsl
new file mode 100644
index 0000000..55dd664
--- /dev/null
+++ b/llama.cpp/ggml/src/ggml-webgpu/wgsl-shaders/binary.wgsl
@@ -0,0 +1,107 @@
+enable f16;
+
+struct Params {
+ ne: u32,
+
+ // offsets in elements
+ offset_src0: u32,
+ offset_src1: u32,
+ offset_dst: u32,
+
+ stride_src1_0: u32,
+ stride_src1_1: u32,
+ stride_src1_2: u32,
+ stride_src1_3: u32,
+
+ a_ne0: u32,
+ a_ne1: u32,
+ a_ne2: u32,
+
+ b_ne0: u32,
+ b_ne1: u32,
+ b_ne2: u32,
+ b_ne3: u32,
+};
+
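+// map a flat index over src0/dst elements to the matching (possibly
+// broadcast) element of src1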
+fn src1_index(_i: u32) -> u32 {
+ var i = _i;
+ let a_i3 = i / (params.a_ne2 * params.a_ne1 * params.a_ne0);
+ i = i % (params.a_ne2 * params.a_ne1 * params.a_ne0);
+ let a_i2 = i / (params.a_ne1 * params.a_ne0);
+ i = i % (params.a_ne1 * params.a_ne0);
+ let a_i1 = i / params.a_ne0;
+ let a_i0 = i % params.a_ne0;
+
+    // handle broadcasting (repetition) of b: when a's extent exceeds b's
+    // along a dimension, b's index wraps back to the start (modulo b's extent)
+ let b_i0 = a_i0 % params.b_ne0;
+ let b_i1 = a_i1 % params.b_ne1;
+ let b_i2 = a_i2 % params.b_ne2;
+ let b_i3 = a_i3 % params.b_ne3;
+
+ // compute index for position in b's flat array
+ return b_i0 * params.stride_src1_0 +
+ b_i1 * params.stride_src1_1 +
+ b_i2 * params.stride_src1_2 +
+ b_i3 * params.stride_src1_3;
+}
+
+#ifdef TYPE_F32
+#define DataType f32
+#endif
+#ifdef TYPE_F16
+#define DataType f16
+#endif
+
+@group(0) @binding(0)
+var<storage, read_write> src0: array<DataType>;
+
+@group(0) @binding(1)
+var<storage, read_write> src1: array<DataType>;
+
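+// Output variants: INPLACE writes the result back into src0 and OVERLAP into
+// src1; otherwise a separate dst buffer takes binding 2 and the params
+// uniform moves to binding 3 (see update()).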
+#ifdef INPLACE
+@group(0) @binding(2)
+var<uniform> params: Params;
+
+#elif defined(OVERLAP)
+@group(0) @binding(2)
+var<uniform> params: Params;
+
+#else
+@group(0) @binding(2)
+var<storage, read_write> dst: array<DataType>;
+
+@group(0) @binding(3)
+var<uniform> params: Params;
+#endif
+
+fn op(a: DataType, b: DataType) -> DataType {
+#ifdef OP_ADD
+ return a + b;
+#elif defined(OP_SUB)
+ return a - b;
+#elif defined(OP_MUL)
+ return a * b;
+#elif defined(OP_DIV)
+ return a / b;
+#endif
+}
+
+fn update(dst_i: u32, src0_i: u32, src1_i: u32){
+ let result = op(src0[src0_i], src1[src1_i]);
+
+#ifdef INPLACE
+ src0[dst_i] = result;
+#elif defined(OVERLAP)
+ src1[dst_i] = result;
+#else
+ dst[dst_i] = result;
+#endif
+}
+
+@compute @workgroup_size(WG_SIZE)
+fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
+ if (gid.x < params.ne) {
+ update(params.offset_dst + gid.x, params.offset_src0 + gid.x, params.offset_src1 + src1_index(gid.x));
+ }
+}
diff --git a/llama.cpp/ggml/src/ggml-webgpu/wgsl-shaders/common_decls.tmpl b/llama.cpp/ggml/src/ggml-webgpu/wgsl-shaders/common_decls.tmpl
new file mode 100644
index 0000000..389c97b
--- /dev/null
+++ b/llama.cpp/ggml/src/ggml-webgpu/wgsl-shaders/common_decls.tmpl
@@ -0,0 +1,930 @@
+#decl(BYTE_HELPERS)
+
+fn get_byte(value: u32, index: u32) -> u32 {
+ return (value >> (index * 8)) & 0xFF;
+}
+
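+// sign-extend byte `index` of `value` to i32 by shifting it into the top
+// byte and arithmetic-shifting back down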
+fn get_byte_i32(value: u32, index: u32) -> i32 {
+ return bitcast<i32>(((value >> (index * 8)) & 0xFF) << 24) >> 24;
+}
+
+#enddecl(BYTE_HELPERS)
+
+#decl(Q4_0_T)
+struct q4_0 {
+ d: f16,
+ qs: array<f16, 8>
+};
+#enddecl(Q4_0_T)
+
+#decl(Q4_1_T)
+struct q4_1 {
+ d: f16,
+ m: f16,
+ qs: array<u32, 4>
+};
+#enddecl(Q4_1_T)
+
+#decl(Q5_0_T)
+struct q5_0 {
+ d: f16,
+ qh: array<f16, 2>,
+ qs: array<f16, 8>
+};
+#enddecl(Q5_0_T)
+
+#decl(Q5_1_T)
+struct q5_1 {
+ d: f16,
+ m: f16,
+ qh: u32,
+ qs: array<u32, 4>
+};
+#enddecl(Q5_1_T)
+
+#decl(Q8_0_T)
+struct q8_0 {
+ d: f16,
+ qs: array<f16, 16>
+};
+#enddecl(Q8_0_T)
+
+#decl(Q8_1_T)
+struct q8_1 {
+ d: f16,
+ m: f16,
+ qs: array<u32, 8>
+};
+#enddecl(Q8_1_T)
+
+#decl(Q2_K_T)
+struct q2_k {
+ scales: array<u32, 4>,
+ qs: array<u32, 16>,
+ d: f16,
+ dmin: f16
+};
+#enddecl(Q2_K_T)
+
+#decl(Q3_K_T)
+struct q3_k {
+ hmask: array<f16, 16>,
+ qs: array<f16, 32>,
+ scales: array<f16, 6>,
+ d: f16
+};
+#enddecl(Q3_K_T)
+
+#decl(Q45_K_SCALE_MIN)
+
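+// Unpack 6-bit (scale, min) pair `is` from the packed 12-byte scales field of
+// q4_k/q5_k: the first four pairs use the low 6 bits of their bytes directly;
+// the last four combine a low nibble with high bits of earlier bytes.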
+fn get_scale_min(is: u32, scales: array<u32, 3>) -> vec2<f32> {
+ if (is < 4) {
+ let sc_byte = get_byte(scales[is / 4], is % 4);
+ let min_byte = get_byte(scales[(is + 4) / 4], is % 4);
+ return vec2(f32(sc_byte & 63), f32(min_byte & 63));
+ } else {
+ let sc_min_lo = get_byte(scales[(is + 4) / 4], (is + 4) % 4);
+ let sc_hi = get_byte(scales[(is - 4) / 4], (is - 4) % 4);
+ let min_hi = get_byte(scales[is / 4], is % 4);
+ let sc = (sc_min_lo & 0xF) | ((sc_hi >> 6) << 4);
+ let m = (sc_min_lo >> 4) | ((min_hi >> 6) << 4);
+ return vec2(f32(sc), f32(m));
+ }
+}
+
+#enddecl(Q45_K_SCALE_MIN)
+
+#decl(Q4_K_T)
+struct q4_k {
+ d: f16,
+ dmin: f16,
+ scales: array<u32, 3>,
+ qs: array<u32, 32>
+};
+#enddecl(Q4_K_T)
+
+#decl(Q5_K_T)
+struct q5_k {
+ d: f16,
+ dmin: f16,
+ scales: array<u32, 3>,
+ qh: array<u32, 8>,
+ qs: array<u32, 32>
+};
+#enddecl(Q5_K_T)
+
+#decl(Q6_K_T)
+struct q6_k {
+ ql: array<f16, 64>,
+ qh: array<f16, 32>,
+ scales: array<f16, 8>,
+ d: f16
+};
+#enddecl(Q6_K_T)
+
+#decl(IQ2_XXS_T)
+struct iq2_xxs {
+ d: f16,
+ qs: array<f16, 32>
+};
+#enddecl(IQ2_XXS_T)
+
+#decl(IQ2_XS_T)
+struct iq2_xs {
+ d: f16,
+ qs: array<f16, 32>,
+ scales: array<f16, 4>
+};
+#enddecl(IQ2_XS_T)
+
+#decl(IQ2_S_T)
+struct iq2_s {
+ d: f16,
+ qs: array<f16, 32>,
+ qh: array<f16, 4>,
+ scales: array<f16, 4>
+};
+#enddecl(IQ2_S_T)
+
+#decl(IQ3_XSS_T)
+struct iq3_xxs {
+ d: f16,
+ qs: array<f16, 48>
+};
+#enddecl(IQ3_XSS_T)
+
+#decl(IQ3_S_T)
+struct iq3_s {
+ d: f16,
+ qs: array<f16, 32>,
+ qh: array<f16, 4>,
+ signs: array<f16, 16>,
+ scales: array<f16, 2>
+};
+#enddecl(IQ3_S_T)
+
+#decl(IQ1_S_T)
+struct iq1_s {
+ d: f16,
+ qs: array<f16, 16>,
+ qh: array<f16, 8>
+};
+#enddecl(IQ1_S_T)
+
+#decl(IQ1_M_T)
+struct iq1_m {
+ qs: array<u32, 8>,
+ qh: array<u32, 4>,
+ scales: array<u32, 2>
+};
+#enddecl(IQ1_M_T)
+
+#decl(IQ4_NL_T)
+struct iq4_nl {
+ d: f16,
+ qs: array<f16, 8>,
+};
+#enddecl(IQ4_NL_T)
+
+#decl(IQ4_XS_T)
+struct iq4_xs {
+ d: f16,
+ scales_h: f16,
+ scales_l: u32,
+ qs: array<u32, 32>
+};
+#enddecl(IQ4_XS_T)
+
+#decl(IQ23_TABLES)
+const kmask_iq2xs : array<u32, 2> = array<u32, 2>(
+ 0x08040201u, // 1, 2, 4, 8
+ 0x80402010u // 16, 32, 64, 128
+);
+
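+// packed per-byte sign patterns; bit 7 of each byte is a parity bit that
+// keeps the byte's popcount even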
+const ksigns_iq2xs: array<u32, 32> = array<u32, 32>(
+ 0x03828100,0x87060584,0x8b0a0988,0x0f8e8d0c,
+ 0x93121190,0x17969514,0x1b9a9918,0x9f1e1d9c,
+ 0xa32221a0,0x27a6a524,0x2baaa928,0xaf2e2dac,
+ 0x33b2b130,0xb73635b4,0xbb3a39b8,0x3fbebd3c,
+ 0xc34241c0,0x47c6c544,0x4bcac948,0xcf4e4dcc,
+ 0x53d2d150,0xd75655d4,0xdb5a59d8,0x5fdedd5c,
+ 0x63e2e160,0xe76665e4,0xeb6a69e8,0x6feeed6c,
+ 0xf37271f0,0x77f6f574,0x7bfaf978,0xff7e7dfc
+);
+#enddecl(IQ23_TABLES)
+
+#decl(IQ2_XXS_GRID)
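+// 256 grid points of 8 bytes each, stored as two u32 words per point
+// (low word first)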
+const iq2xxs_grid = array<u32, 512>(
+ 0x08080808, 0x08080808, 0x0808082b, 0x08080808, 0x08081919, 0x08080808, 0x08082b08, 0x08080808,
+ 0x08082b2b, 0x08080808, 0x08190819, 0x08080808, 0x08191908, 0x08080808, 0x082b0808, 0x08080808,
+ 0x082b082b, 0x08080808, 0x082b2b08, 0x08080808, 0x082b2b2b, 0x08080808, 0x19080819, 0x08080808,
+ 0x19081908, 0x08080808, 0x19190808, 0x08080808, 0x19192b08, 0x08080808, 0x192b0819, 0x08080808,
+ 0x192b1908, 0x08080808, 0x2b080808, 0x08080808, 0x2b08082b, 0x08080808, 0x2b082b2b, 0x08080808,
+ 0x2b2b082b, 0x08080808, 0x08080819, 0x08080819, 0x08081908, 0x08080819, 0x08190808, 0x08080819,
+ 0x08191919, 0x08080819, 0x19080808, 0x08080819, 0x2b081908, 0x08080819, 0x2b192b08, 0x08080819,
+ 0x08080808, 0x0808082b, 0x0808082b, 0x0808082b, 0x082b082b, 0x0808082b, 0x2b08082b, 0x0808082b,
+ 0x08080819, 0x08081908, 0x08081908, 0x08081908, 0x08190808, 0x08081908, 0x082b0819, 0x08081908,
+ 0x082b1908, 0x08081908, 0x19080808, 0x08081908, 0x1908082b, 0x08081908, 0x19082b08, 0x08081908,
+ 0x192b0808, 0x08081908, 0x2b080819, 0x08081908, 0x2b081908, 0x08081908, 0x2b190808, 0x08081908,
+ 0x2b2b1908, 0x08081908, 0x08080808, 0x08081919, 0x0808082b, 0x08081919, 0x08082b08, 0x08081919,
+ 0x082b0808, 0x08081919, 0x1908192b, 0x08081919, 0x192b2b19, 0x08081919, 0x2b080808, 0x08081919,
+ 0x2b190819, 0x08081919, 0x08082b19, 0x0808192b, 0x08190808, 0x0808192b, 0x19080808, 0x0808192b,
+ 0x2b081908, 0x0808192b, 0x2b2b1908, 0x0808192b, 0x08080808, 0x08082b08, 0x08081919, 0x08082b08,
+ 0x08082b08, 0x08082b08, 0x08191908, 0x08082b08, 0x082b2b08, 0x08082b08, 0x19080819, 0x08082b08,
+ 0x19081908, 0x08082b08, 0x19190808, 0x08082b08, 0x1919082b, 0x08082b08, 0x2b082b08, 0x08082b08,
+ 0x08081908, 0x08082b19, 0x19080808, 0x08082b19, 0x0808082b, 0x08082b2b, 0x08191908, 0x08082b2b,
+ 0x08080819, 0x08190808, 0x08081908, 0x08190808, 0x08190808, 0x08190808, 0x082b0819, 0x08190808,
+ 0x19080808, 0x08190808, 0x192b0808, 0x08190808, 0x2b081908, 0x08190808, 0x2b190808, 0x08190808,
+ 0x2b191919, 0x08190808, 0x08080808, 0x08190819, 0x08082b08, 0x08190819, 0x082b0808, 0x08190819,
+ 0x19190808, 0x08190819, 0x19192b2b, 0x08190819, 0x2b080808, 0x08190819, 0x082b1908, 0x0819082b,
+ 0x19081919, 0x0819082b, 0x08080808, 0x08191908, 0x08082b08, 0x08191908, 0x082b0808, 0x08191908,
+ 0x082b1919, 0x08191908, 0x19082b19, 0x08191908, 0x2b080808, 0x08191908, 0x08192b08, 0x08191919,
+ 0x192b082b, 0x08191919, 0x08080808, 0x0819192b, 0x0819192b, 0x0819192b, 0x08080819, 0x08192b08,
+ 0x08081908, 0x08192b08, 0x08190808, 0x08192b08, 0x19080808, 0x08192b08, 0x2b080819, 0x08192b08,
+ 0x08080808, 0x08192b19, 0x08081919, 0x08192b19, 0x2b2b0808, 0x08192b19, 0x19190819, 0x08192b2b,
+ 0x08080808, 0x082b0808, 0x0808082b, 0x082b0808, 0x08082b2b, 0x082b0808, 0x19081908, 0x082b0808,
+ 0x192b0819, 0x082b0808, 0x2b080808, 0x082b0808, 0x2b08082b, 0x082b0808, 0x082b2b19, 0x082b0819,
+ 0x19082b08, 0x082b0819, 0x08080808, 0x082b082b, 0x0808082b, 0x082b082b, 0x08080819, 0x082b1908,
+ 0x08081908, 0x082b1908, 0x08190808, 0x082b1908, 0x19080808, 0x082b1908, 0x1919192b, 0x082b1908,
+ 0x08080808, 0x082b1919, 0x19080819, 0x082b1919, 0x192b1908, 0x082b1919, 0x2b190808, 0x082b192b,
+ 0x08082b08, 0x082b2b08, 0x082b0808, 0x082b2b08, 0x2b191908, 0x082b2b08, 0x19081908, 0x082b2b2b,
+ 0x08080819, 0x19080808, 0x08081908, 0x19080808, 0x08190808, 0x19080808, 0x08192b08, 0x19080808,
+ 0x082b0819, 0x19080808, 0x082b1908, 0x19080808, 0x19080808, 0x19080808, 0x19082b08, 0x19080808,
+ 0x1919192b, 0x19080808, 0x192b0808, 0x19080808, 0x2b080819, 0x19080808, 0x2b081908, 0x19080808,
+ 0x2b190808, 0x19080808, 0x08080808, 0x19080819, 0x082b0808, 0x19080819, 0x192b0819, 0x19080819,
+ 0x2b080808, 0x19080819, 0x2b081919, 0x19080819, 0x08080819, 0x1908082b, 0x08190808, 0x1908082b,
+ 0x19082b08, 0x1908082b, 0x1919192b, 0x1908082b, 0x192b2b08, 0x1908082b, 0x08080808, 0x19081908,
+ 0x08082b08, 0x19081908, 0x082b0808, 0x19081908, 0x2b080808, 0x19081908, 0x2b192b19, 0x19081908,
+ 0x0819082b, 0x19081919, 0x082b1908, 0x19081919, 0x08080808, 0x1908192b, 0x08080819, 0x19082b08,
+ 0x08081908, 0x19082b08, 0x08190808, 0x19082b08, 0x19080808, 0x19082b08, 0x19081919, 0x19082b08,
+ 0x08080808, 0x19082b19, 0x19192b08, 0x19082b19, 0x192b0819, 0x19082b19, 0x2b08082b, 0x19082b19,
+ 0x19081919, 0x19082b2b, 0x2b190808, 0x19082b2b, 0x08080808, 0x19190808, 0x08082b08, 0x19190808,
+ 0x08190819, 0x19190808, 0x08192b19, 0x19190808, 0x082b0808, 0x19190808, 0x2b080808, 0x19190808,
+ 0x2b082b08, 0x19190808, 0x08081908, 0x19190819, 0x1908082b, 0x19190819, 0x2b2b1908, 0x19190819,
+ 0x2b190819, 0x1919082b, 0x2b190808, 0x19191908, 0x2b19082b, 0x19191908, 0x08082b2b, 0x19191919,
+ 0x08080819, 0x1919192b, 0x19191908, 0x1919192b, 0x08080808, 0x19192b08, 0x08190819, 0x19192b08,
+ 0x08192b19, 0x19192b08, 0x192b1908, 0x19192b08, 0x19080808, 0x19192b19, 0x08082b08, 0x19192b2b,
+ 0x08081908, 0x192b0808, 0x08190808, 0x192b0808, 0x19080808, 0x192b0808, 0x192b2b08, 0x192b0808,
+ 0x08080808, 0x192b0819, 0x19191919, 0x192b0819, 0x08192b08, 0x192b082b, 0x192b0808, 0x192b082b,
+ 0x08080808, 0x192b1908, 0x08081919, 0x192b1908, 0x08190808, 0x192b1919, 0x0819082b, 0x192b1919,
+ 0x2b081908, 0x192b1919, 0x1908082b, 0x192b2b08, 0x08080808, 0x2b080808, 0x0808082b, 0x2b080808,
+ 0x08082b2b, 0x2b080808, 0x19080819, 0x2b080808, 0x2b08082b, 0x2b080808, 0x08081908, 0x2b080819,
+ 0x08192b08, 0x2b080819, 0x19080808, 0x2b080819, 0x08190819, 0x2b08082b, 0x08080819, 0x2b081908,
+ 0x08081908, 0x2b081908, 0x08190808, 0x2b081908, 0x08191919, 0x2b081908, 0x19080808, 0x2b081908,
+ 0x192b0808, 0x2b081908, 0x08080808, 0x2b081919, 0x1908192b, 0x2b081919, 0x2b191908, 0x2b081919,
+ 0x08082b19, 0x2b08192b, 0x19080808, 0x2b08192b, 0x192b0808, 0x2b08192b, 0x0808082b, 0x2b082b08,
+ 0x08081908, 0x2b082b19, 0x08190819, 0x2b082b2b, 0x08081908, 0x2b190808, 0x08190808, 0x2b190808,
+ 0x082b1908, 0x2b190808, 0x19080808, 0x2b190808, 0x2b2b0819, 0x2b190808, 0x0819192b, 0x2b190819,
+ 0x2b080808, 0x2b190819, 0x19081919, 0x2b19082b, 0x08080808, 0x2b191908, 0x082b082b, 0x2b191908,
+ 0x19081908, 0x2b191908, 0x19190819, 0x2b191919, 0x2b080819, 0x2b192b08, 0x082b0808, 0x2b192b19,
+ 0x0808082b, 0x2b2b0808, 0x19190808, 0x2b2b0808, 0x2b081919, 0x2b2b0808, 0x08082b19, 0x2b2b0819,
+ 0x08080808, 0x2b2b082b, 0x08192b08, 0x2b2b1908, 0x19190808, 0x2b2b2b08, 0x08081908, 0x2b2b2b19
+);
+#enddecl(IQ2_XXS_GRID)
+
+#decl(IQ2_XS_GRID)
+const iq2xs_grid = array<u32, 1024>(
+ 0x08080808, 0x08080808, 0x0808082b, 0x08080808, 0x08081919, 0x08080808, 0x08082b08, 0x08080808,
+ 0x08082b2b, 0x08080808, 0x08190819, 0x08080808, 0x08191908, 0x08080808, 0x0819192b, 0x08080808,
+ 0x08192b19, 0x08080808, 0x082b0808, 0x08080808, 0x082b082b, 0x08080808, 0x082b1919, 0x08080808,
+ 0x082b2b08, 0x08080808, 0x19080819, 0x08080808, 0x19081908, 0x08080808, 0x1908192b, 0x08080808,
+ 0x19082b19, 0x08080808, 0x19190808, 0x08080808, 0x1919082b, 0x08080808, 0x19191919, 0x08080808,
+ 0x19192b08, 0x08080808, 0x192b0819, 0x08080808, 0x192b1908, 0x08080808, 0x2b080808, 0x08080808,
+ 0x2b08082b, 0x08080808, 0x2b081919, 0x08080808, 0x2b082b08, 0x08080808, 0x2b190819, 0x08080808,
+ 0x2b191908, 0x08080808, 0x2b192b19, 0x08080808, 0x2b2b0808, 0x08080808, 0x08080819, 0x08080819,
+ 0x08081908, 0x08080819, 0x0808192b, 0x08080819, 0x08082b19, 0x08080819, 0x08190808, 0x08080819,
+ 0x0819082b, 0x08080819, 0x08191919, 0x08080819, 0x08192b08, 0x08080819, 0x08192b2b, 0x08080819,
+ 0x082b0819, 0x08080819, 0x082b1908, 0x08080819, 0x19080808, 0x08080819, 0x1908082b, 0x08080819,
+ 0x19081919, 0x08080819, 0x19082b08, 0x08080819, 0x19190819, 0x08080819, 0x19191908, 0x08080819,
+ 0x192b0808, 0x08080819, 0x192b2b08, 0x08080819, 0x2b080819, 0x08080819, 0x2b081908, 0x08080819,
+ 0x2b190808, 0x08080819, 0x08080808, 0x0808082b, 0x0808082b, 0x0808082b, 0x08081919, 0x0808082b,
+ 0x08082b08, 0x0808082b, 0x08190819, 0x0808082b, 0x08191908, 0x0808082b, 0x082b0808, 0x0808082b,
+ 0x19080819, 0x0808082b, 0x19081908, 0x0808082b, 0x19190808, 0x0808082b, 0x19191919, 0x0808082b,
+ 0x2b080808, 0x0808082b, 0x2b082b2b, 0x0808082b, 0x08080819, 0x08081908, 0x08081908, 0x08081908,
+ 0x0808192b, 0x08081908, 0x08082b19, 0x08081908, 0x08190808, 0x08081908, 0x0819082b, 0x08081908,
+ 0x08191919, 0x08081908, 0x08192b08, 0x08081908, 0x082b0819, 0x08081908, 0x082b1908, 0x08081908,
+ 0x19080808, 0x08081908, 0x1908082b, 0x08081908, 0x19081919, 0x08081908, 0x19082b08, 0x08081908,
+ 0x19190819, 0x08081908, 0x19191908, 0x08081908, 0x1919192b, 0x08081908, 0x192b0808, 0x08081908,
+ 0x2b080819, 0x08081908, 0x2b081908, 0x08081908, 0x2b190808, 0x08081908, 0x08080808, 0x08081919,
+ 0x0808082b, 0x08081919, 0x08081919, 0x08081919, 0x08082b08, 0x08081919, 0x08190819, 0x08081919,
+ 0x08191908, 0x08081919, 0x082b0808, 0x08081919, 0x19080819, 0x08081919, 0x19081908, 0x08081919,
+ 0x19190808, 0x08081919, 0x192b0819, 0x08081919, 0x2b080808, 0x08081919, 0x08080819, 0x0808192b,
+ 0x08081908, 0x0808192b, 0x08190808, 0x0808192b, 0x082b192b, 0x0808192b, 0x19080808, 0x0808192b,
+ 0x1908082b, 0x0808192b, 0x2b081908, 0x0808192b, 0x08080808, 0x08082b08, 0x0808082b, 0x08082b08,
+ 0x08081919, 0x08082b08, 0x08082b08, 0x08082b08, 0x08082b2b, 0x08082b08, 0x08190819, 0x08082b08,
+ 0x08191908, 0x08082b08, 0x082b0808, 0x08082b08, 0x082b1919, 0x08082b08, 0x19080819, 0x08082b08,
+ 0x19081908, 0x08082b08, 0x19190808, 0x08082b08, 0x19192b08, 0x08082b08, 0x2b080808, 0x08082b08,
+ 0x2b2b0808, 0x08082b08, 0x2b2b2b2b, 0x08082b08, 0x08080819, 0x08082b19, 0x08081908, 0x08082b19,
+ 0x08190808, 0x08082b19, 0x19080808, 0x08082b19, 0x2b080819, 0x08082b19, 0x2b082b19, 0x08082b19,
+ 0x08080808, 0x08082b2b, 0x082b0808, 0x08082b2b, 0x082b2b08, 0x08082b2b, 0x2b19192b, 0x08082b2b,
+ 0x2b2b0808, 0x08082b2b, 0x08080819, 0x08190808, 0x08081908, 0x08190808, 0x0808192b, 0x08190808,
+ 0x08082b19, 0x08190808, 0x08190808, 0x08190808, 0x0819082b, 0x08190808, 0x08191919, 0x08190808,
+ 0x08192b08, 0x08190808, 0x082b0819, 0x08190808, 0x082b1908, 0x08190808, 0x19080808, 0x08190808,
+ 0x1908082b, 0x08190808, 0x19081919, 0x08190808, 0x19082b08, 0x08190808, 0x19190819, 0x08190808,
+ 0x19191908, 0x08190808, 0x192b0808, 0x08190808, 0x192b2b2b, 0x08190808, 0x2b080819, 0x08190808,
+ 0x2b081908, 0x08190808, 0x2b190808, 0x08190808, 0x08080808, 0x08190819, 0x0808082b, 0x08190819,
+ 0x08081919, 0x08190819, 0x08082b08, 0x08190819, 0x08190819, 0x08190819, 0x08191908, 0x08190819,
+ 0x082b0808, 0x08190819, 0x19080819, 0x08190819, 0x19081908, 0x08190819, 0x19190808, 0x08190819,
+ 0x2b080808, 0x08190819, 0x2b191908, 0x08190819, 0x2b19192b, 0x08190819, 0x08080819, 0x0819082b,
+ 0x08081908, 0x0819082b, 0x0808192b, 0x0819082b, 0x08190808, 0x0819082b, 0x19080808, 0x0819082b,
+ 0x192b0808, 0x0819082b, 0x08080808, 0x08191908, 0x0808082b, 0x08191908, 0x08081919, 0x08191908,
+ 0x08082b08, 0x08191908, 0x08190819, 0x08191908, 0x08191908, 0x08191908, 0x082b0808, 0x08191908,
+ 0x19080819, 0x08191908, 0x19081908, 0x08191908, 0x19082b19, 0x08191908, 0x19190808, 0x08191908,
+ 0x192b1908, 0x08191908, 0x2b080808, 0x08191908, 0x08080819, 0x08191919, 0x08081908, 0x08191919,
+ 0x08190808, 0x08191919, 0x19080808, 0x08191919, 0x08080808, 0x0819192b, 0x08191908, 0x0819192b,
+ 0x19082b19, 0x0819192b, 0x08080819, 0x08192b08, 0x08081908, 0x08192b08, 0x08190808, 0x08192b08,
+ 0x0819082b, 0x08192b08, 0x19080808, 0x08192b08, 0x19191908, 0x08192b08, 0x2b08192b, 0x08192b08,
+ 0x08080808, 0x08192b19, 0x08081919, 0x08192b19, 0x192b192b, 0x08192b19, 0x19190819, 0x08192b2b,
+ 0x2b2b2b19, 0x08192b2b, 0x08080808, 0x082b0808, 0x0808082b, 0x082b0808, 0x08081919, 0x082b0808,
+ 0x08082b08, 0x082b0808, 0x08082b2b, 0x082b0808, 0x08190819, 0x082b0808, 0x08191908, 0x082b0808,
+ 0x082b0808, 0x082b0808, 0x19080819, 0x082b0808, 0x19081908, 0x082b0808, 0x19190808, 0x082b0808,
+ 0x2b080808, 0x082b0808, 0x2b2b0808, 0x082b0808, 0x08080819, 0x082b0819, 0x08081908, 0x082b0819,
+ 0x08190808, 0x082b0819, 0x19080808, 0x082b0819, 0x19082b08, 0x082b0819, 0x192b1919, 0x082b0819,
+ 0x08080808, 0x082b082b, 0x082b082b, 0x082b082b, 0x2b080808, 0x082b082b, 0x2b2b2b08, 0x082b082b,
+ 0x08080819, 0x082b1908, 0x08081908, 0x082b1908, 0x08190808, 0x082b1908, 0x082b2b19, 0x082b1908,
+ 0x19080808, 0x082b1908, 0x08080808, 0x082b1919, 0x19080819, 0x082b1919, 0x1919082b, 0x082b1919,
+ 0x2b192b19, 0x082b1919, 0x08080819, 0x082b192b, 0x08192b2b, 0x082b192b, 0x2b2b192b, 0x082b192b,
+ 0x08080808, 0x082b2b08, 0x08082b08, 0x082b2b08, 0x08082b2b, 0x082b2b08, 0x082b0808, 0x082b2b08,
+ 0x19191919, 0x082b2b08, 0x2b082b08, 0x082b2b08, 0x2b2b082b, 0x082b2b08, 0x192b2b08, 0x082b2b19,
+ 0x2b190808, 0x082b2b19, 0x08082b08, 0x082b2b2b, 0x082b0808, 0x082b2b2b, 0x2b08082b, 0x082b2b2b,
+ 0x2b082b08, 0x082b2b2b, 0x2b082b2b, 0x082b2b2b, 0x08080819, 0x19080808, 0x08081908, 0x19080808,
+ 0x0808192b, 0x19080808, 0x08082b19, 0x19080808, 0x08190808, 0x19080808, 0x0819082b, 0x19080808,
+ 0x08191919, 0x19080808, 0x08192b08, 0x19080808, 0x082b0819, 0x19080808, 0x082b1908, 0x19080808,
+ 0x19080808, 0x19080808, 0x1908082b, 0x19080808, 0x19081919, 0x19080808, 0x19082b08, 0x19080808,
+ 0x19082b2b, 0x19080808, 0x19190819, 0x19080808, 0x19191908, 0x19080808, 0x192b0808, 0x19080808,
+ 0x192b1919, 0x19080808, 0x2b080819, 0x19080808, 0x2b081908, 0x19080808, 0x2b190808, 0x19080808,
+ 0x08080808, 0x19080819, 0x0808082b, 0x19080819, 0x08081919, 0x19080819, 0x08082b08, 0x19080819,
+ 0x08190819, 0x19080819, 0x08191908, 0x19080819, 0x082b0808, 0x19080819, 0x19080819, 0x19080819,
+ 0x19081908, 0x19080819, 0x19190808, 0x19080819, 0x2b080808, 0x19080819, 0x2b081919, 0x19080819,
+ 0x2b2b082b, 0x19080819, 0x08080819, 0x1908082b, 0x08081908, 0x1908082b, 0x08190808, 0x1908082b,
+ 0x0819082b, 0x1908082b, 0x082b2b19, 0x1908082b, 0x19080808, 0x1908082b, 0x08080808, 0x19081908,
+ 0x0808082b, 0x19081908, 0x08081919, 0x19081908, 0x08082b08, 0x19081908, 0x08190819, 0x19081908,
+ 0x08191908, 0x19081908, 0x08192b19, 0x19081908, 0x082b0808, 0x19081908, 0x19080819, 0x19081908,
+ 0x19081908, 0x19081908, 0x19190808, 0x19081908, 0x2b080808, 0x19081908, 0x2b191908, 0x19081908,
+ 0x08080819, 0x19081919, 0x08081908, 0x19081919, 0x08190808, 0x19081919, 0x082b1908, 0x19081919,
+ 0x19080808, 0x19081919, 0x2b192b2b, 0x19081919, 0x08080808, 0x1908192b, 0x08082b2b, 0x1908192b,
+ 0x19081908, 0x1908192b, 0x19190808, 0x1908192b, 0x08080819, 0x19082b08, 0x08081908, 0x19082b08,
+ 0x08190808, 0x19082b08, 0x19080808, 0x19082b08, 0x19081919, 0x19082b08, 0x19191908, 0x19082b08,
+ 0x192b082b, 0x19082b08, 0x08080808, 0x19082b19, 0x08190819, 0x19082b19, 0x19081908, 0x19082b19,
+ 0x19190808, 0x19082b19, 0x192b2b19, 0x19082b19, 0x08081908, 0x19082b2b, 0x08080808, 0x19190808,
+ 0x0808082b, 0x19190808, 0x08081919, 0x19190808, 0x08082b08, 0x19190808, 0x08190819, 0x19190808,
+ 0x08191908, 0x19190808, 0x082b0808, 0x19190808, 0x082b2b08, 0x19190808, 0x19080819, 0x19190808,
+ 0x19081908, 0x19190808, 0x19190808, 0x19190808, 0x2b080808, 0x19190808, 0x08080819, 0x19190819,
+ 0x08081908, 0x19190819, 0x08190808, 0x19190819, 0x08191919, 0x19190819, 0x19080808, 0x19190819,
+ 0x1908082b, 0x19190819, 0x08080808, 0x1919082b, 0x19081908, 0x1919082b, 0x2b2b2b2b, 0x1919082b,
+ 0x08080819, 0x19191908, 0x08081908, 0x19191908, 0x08190808, 0x19191908, 0x082b0819, 0x19191908,
+ 0x19080808, 0x19191908, 0x192b0808, 0x19191908, 0x2b080819, 0x19191908, 0x2b2b0819, 0x19191908,
+ 0x08080808, 0x19191919, 0x08082b08, 0x19191919, 0x2b080808, 0x19191919, 0x2b082b08, 0x19191919,
+ 0x082b0819, 0x1919192b, 0x192b2b08, 0x1919192b, 0x2b2b0819, 0x1919192b, 0x08080808, 0x19192b08,
+ 0x08191908, 0x19192b08, 0x19080819, 0x19192b08, 0x19190808, 0x19192b08, 0x2b192b19, 0x19192b08,
+ 0x08192b2b, 0x19192b19, 0x19080808, 0x19192b19, 0x1908082b, 0x19192b19, 0x2b081919, 0x19192b2b,
+ 0x08080819, 0x192b0808, 0x08081908, 0x192b0808, 0x08190808, 0x192b0808, 0x19080808, 0x192b0808,
+ 0x19191908, 0x192b0808, 0x192b082b, 0x192b0808, 0x2b08192b, 0x192b0808, 0x2b2b2b19, 0x192b0808,
+ 0x08080808, 0x192b0819, 0x082b1908, 0x192b082b, 0x19082b2b, 0x192b082b, 0x2b19082b, 0x192b082b,
+ 0x08080808, 0x192b1908, 0x0819192b, 0x192b1908, 0x08190808, 0x192b1919, 0x19080808, 0x192b1919,
+ 0x19081919, 0x192b1919, 0x2b2b1908, 0x192b1919, 0x08080819, 0x192b2b08, 0x192b2b2b, 0x192b2b08,
+ 0x082b1919, 0x192b2b19, 0x0808192b, 0x192b2b2b, 0x19191908, 0x192b2b2b, 0x192b082b, 0x192b2b2b,
+ 0x08080808, 0x2b080808, 0x0808082b, 0x2b080808, 0x08081919, 0x2b080808, 0x08082b08, 0x2b080808,
+ 0x08190819, 0x2b080808, 0x08191908, 0x2b080808, 0x082b0808, 0x2b080808, 0x082b2b2b, 0x2b080808,
+ 0x19080819, 0x2b080808, 0x19081908, 0x2b080808, 0x19190808, 0x2b080808, 0x2b080808, 0x2b080808,
+ 0x2b08082b, 0x2b080808, 0x2b2b2b08, 0x2b080808, 0x2b2b2b2b, 0x2b080808, 0x08080819, 0x2b080819,
+ 0x08081908, 0x2b080819, 0x0808192b, 0x2b080819, 0x08190808, 0x2b080819, 0x19080808, 0x2b080819,
+ 0x19190819, 0x2b080819, 0x19192b19, 0x2b080819, 0x08080808, 0x2b08082b, 0x082b0808, 0x2b08082b,
+ 0x2b080808, 0x2b08082b, 0x2b08082b, 0x2b08082b, 0x2b2b0808, 0x2b08082b, 0x2b2b2b08, 0x2b08082b,
+ 0x08080819, 0x2b081908, 0x08081908, 0x2b081908, 0x08190808, 0x2b081908, 0x0819082b, 0x2b081908,
+ 0x08191919, 0x2b081908, 0x19080808, 0x2b081908, 0x192b0808, 0x2b081908, 0x2b082b19, 0x2b081908,
+ 0x08080808, 0x2b081919, 0x19081908, 0x2b081919, 0x2b2b1919, 0x2b081919, 0x08192b08, 0x2b08192b,
+ 0x192b2b2b, 0x2b08192b, 0x08080808, 0x2b082b08, 0x08082b08, 0x2b082b08, 0x082b1919, 0x2b082b08,
+ 0x19192b2b, 0x2b082b08, 0x2b080808, 0x2b082b08, 0x2b08082b, 0x2b082b08, 0x2b2b2b08, 0x2b082b08,
+ 0x0808192b, 0x2b082b19, 0x082b082b, 0x2b082b2b, 0x2b080808, 0x2b082b2b, 0x2b082b08, 0x2b082b2b,
+ 0x2b19192b, 0x2b082b2b, 0x2b2b2b08, 0x2b082b2b, 0x08080819, 0x2b190808, 0x08081908, 0x2b190808,
+ 0x08190808, 0x2b190808, 0x19080808, 0x2b190808, 0x1919192b, 0x2b190808, 0x2b081908, 0x2b190808,
+ 0x08080808, 0x2b190819, 0x082b082b, 0x2b190819, 0x192b1908, 0x2b190819, 0x1919192b, 0x2b19082b,
+ 0x2b082b19, 0x2b19082b, 0x08080808, 0x2b191908, 0x08081919, 0x2b191908, 0x19081908, 0x2b191908,
+ 0x19190808, 0x2b191908, 0x19192b08, 0x2b191908, 0x082b2b19, 0x2b191919, 0x2b190808, 0x2b191919,
+ 0x2b19082b, 0x2b191919, 0x19080819, 0x2b19192b, 0x19190819, 0x2b192b08, 0x2b2b192b, 0x2b192b08,
+ 0x19082b19, 0x2b192b19, 0x08191919, 0x2b192b2b, 0x192b0808, 0x2b192b2b, 0x08080808, 0x2b2b0808,
+ 0x0808082b, 0x2b2b0808, 0x08082b08, 0x2b2b0808, 0x08082b2b, 0x2b2b0808, 0x082b0808, 0x2b2b0808,
+ 0x082b2b2b, 0x2b2b0808, 0x2b2b0808, 0x2b2b0808, 0x19190819, 0x2b2b0819, 0x19192b19, 0x2b2b0819,
+ 0x2b2b192b, 0x2b2b0819, 0x08080808, 0x2b2b082b, 0x0808082b, 0x2b2b082b, 0x08082b08, 0x2b2b082b,
+ 0x082b2b2b, 0x2b2b082b, 0x2b080808, 0x2b2b082b, 0x2b2b0808, 0x2b2b082b, 0x19080808, 0x2b2b1908,
+ 0x2b191919, 0x2b2b1908, 0x192b1919, 0x2b2b192b, 0x2b192b08, 0x2b2b192b, 0x08082b2b, 0x2b2b2b08,
+ 0x082b0808, 0x2b2b2b08, 0x082b082b, 0x2b2b2b08, 0x082b2b08, 0x2b2b2b08, 0x2b2b0808, 0x2b2b2b08,
+ 0x2b2b2b08, 0x2b2b2b08, 0x08081908, 0x2b2b2b19, 0x2b081908, 0x2b2b2b19, 0x2b08192b, 0x2b2b2b19,
+ 0x082b2b08, 0x2b2b2b2b, 0x082b2b2b, 0x2b2b2b2b, 0x2b190819, 0x2b2b2b2b, 0x2b2b2b2b, 0x2b2b2b2b
+);
+#enddecl(IQ2_XS_GRID)
+
+#decl(IQ2_S_GRID)
+const iq2s_grid = array<u32, 2048>(
+ 0x08080808, 0x08080808, 0x0808082b, 0x08080808, 0x08081919, 0x08080808, 0x08082b08, 0x08080808,
+ 0x08082b2b, 0x08080808, 0x08190819, 0x08080808, 0x08191908, 0x08080808, 0x0819192b, 0x08080808,
+ 0x08192b19, 0x08080808, 0x082b0808, 0x08080808, 0x082b082b, 0x08080808, 0x082b1919, 0x08080808,
+ 0x082b2b08, 0x08080808, 0x19080819, 0x08080808, 0x19081908, 0x08080808, 0x1908192b, 0x08080808,
+ 0x19082b19, 0x08080808, 0x19190808, 0x08080808, 0x1919082b, 0x08080808, 0x19191919, 0x08080808,
+ 0x19192b08, 0x08080808, 0x192b0819, 0x08080808, 0x192b1908, 0x08080808, 0x192b192b, 0x08080808,
+ 0x192b2b19, 0x08080808, 0x2b080808, 0x08080808, 0x2b08082b, 0x08080808, 0x2b081919, 0x08080808,
+ 0x2b082b08, 0x08080808, 0x2b190819, 0x08080808, 0x2b191908, 0x08080808, 0x2b2b0808, 0x08080808,
+ 0x2b2b1919, 0x08080808, 0x2b2b2b2b, 0x08080808, 0x08080819, 0x08080819, 0x08081908, 0x08080819,
+ 0x0808192b, 0x08080819, 0x08082b19, 0x08080819, 0x08190808, 0x08080819, 0x0819082b, 0x08080819,
+ 0x08191919, 0x08080819, 0x08192b08, 0x08080819, 0x082b0819, 0x08080819, 0x082b1908, 0x08080819,
+ 0x19080808, 0x08080819, 0x1908082b, 0x08080819, 0x19081919, 0x08080819, 0x19082b08, 0x08080819,
+ 0x19190819, 0x08080819, 0x19191908, 0x08080819, 0x1919192b, 0x08080819, 0x19192b19, 0x08080819,
+ 0x192b0808, 0x08080819, 0x192b1919, 0x08080819, 0x192b2b08, 0x08080819, 0x2b080819, 0x08080819,
+ 0x2b081908, 0x08080819, 0x2b190808, 0x08080819, 0x2b19082b, 0x08080819, 0x2b191919, 0x08080819,
+ 0x2b2b0819, 0x08080819, 0x2b2b1908, 0x08080819, 0x08080808, 0x0808082b, 0x0808082b, 0x0808082b,
+ 0x08081919, 0x0808082b, 0x08082b08, 0x0808082b, 0x08190819, 0x0808082b, 0x08191908, 0x0808082b,
+ 0x082b0808, 0x0808082b, 0x082b2b2b, 0x0808082b, 0x19080819, 0x0808082b, 0x19081908, 0x0808082b,
+ 0x1908192b, 0x0808082b, 0x19082b19, 0x0808082b, 0x19190808, 0x0808082b, 0x19191919, 0x0808082b,
+ 0x2b080808, 0x0808082b, 0x2b081919, 0x0808082b, 0x2b082b2b, 0x0808082b, 0x2b191908, 0x0808082b,
+ 0x2b2b082b, 0x0808082b, 0x08080819, 0x08081908, 0x08081908, 0x08081908, 0x0808192b, 0x08081908,
+ 0x08082b19, 0x08081908, 0x08190808, 0x08081908, 0x0819082b, 0x08081908, 0x08191919, 0x08081908,
+ 0x08192b08, 0x08081908, 0x082b0819, 0x08081908, 0x082b1908, 0x08081908, 0x082b192b, 0x08081908,
+ 0x082b2b19, 0x08081908, 0x19080808, 0x08081908, 0x1908082b, 0x08081908, 0x19081919, 0x08081908,
+ 0x19082b08, 0x08081908, 0x19082b2b, 0x08081908, 0x19190819, 0x08081908, 0x19191908, 0x08081908,
+ 0x1919192b, 0x08081908, 0x19192b19, 0x08081908, 0x192b0808, 0x08081908, 0x192b082b, 0x08081908,
+ 0x192b1919, 0x08081908, 0x2b080819, 0x08081908, 0x2b081908, 0x08081908, 0x2b08192b, 0x08081908,
+ 0x2b082b19, 0x08081908, 0x2b190808, 0x08081908, 0x2b191919, 0x08081908, 0x2b192b08, 0x08081908,
+ 0x2b2b0819, 0x08081908, 0x2b2b1908, 0x08081908, 0x08080808, 0x08081919, 0x0808082b, 0x08081919,
+ 0x08081919, 0x08081919, 0x08082b08, 0x08081919, 0x08082b2b, 0x08081919, 0x08190819, 0x08081919,
+ 0x08191908, 0x08081919, 0x0819192b, 0x08081919, 0x08192b19, 0x08081919, 0x082b0808, 0x08081919,
+ 0x082b1919, 0x08081919, 0x082b2b08, 0x08081919, 0x19080819, 0x08081919, 0x19081908, 0x08081919,
+ 0x1908192b, 0x08081919, 0x19082b19, 0x08081919, 0x19190808, 0x08081919, 0x1919082b, 0x08081919,
+ 0x19191919, 0x08081919, 0x19192b08, 0x08081919, 0x192b0819, 0x08081919, 0x192b1908, 0x08081919,
+ 0x2b080808, 0x08081919, 0x2b08082b, 0x08081919, 0x2b081919, 0x08081919, 0x2b082b08, 0x08081919,
+ 0x2b190819, 0x08081919, 0x2b191908, 0x08081919, 0x2b2b0808, 0x08081919, 0x08080819, 0x0808192b,
+ 0x08081908, 0x0808192b, 0x0808192b, 0x0808192b, 0x08082b19, 0x0808192b, 0x08190808, 0x0808192b,
+ 0x08191919, 0x0808192b, 0x19080808, 0x0808192b, 0x19081919, 0x0808192b, 0x19082b08, 0x0808192b,
+ 0x19190819, 0x0808192b, 0x19191908, 0x0808192b, 0x192b0808, 0x0808192b, 0x2b080819, 0x0808192b,
+ 0x2b081908, 0x0808192b, 0x2b190808, 0x0808192b, 0x08080808, 0x08082b08, 0x0808082b, 0x08082b08,
+ 0x08081919, 0x08082b08, 0x08082b08, 0x08082b08, 0x08190819, 0x08082b08, 0x08191908, 0x08082b08,
+ 0x0819192b, 0x08082b08, 0x08192b19, 0x08082b08, 0x082b0808, 0x08082b08, 0x082b1919, 0x08082b08,
+ 0x082b2b2b, 0x08082b08, 0x19080819, 0x08082b08, 0x19081908, 0x08082b08, 0x1908192b, 0x08082b08,
+ 0x19082b19, 0x08082b08, 0x19190808, 0x08082b08, 0x1919082b, 0x08082b08, 0x19191919, 0x08082b08,
+ 0x19192b08, 0x08082b08, 0x192b0819, 0x08082b08, 0x192b1908, 0x08082b08, 0x2b080808, 0x08082b08,
+ 0x2b081919, 0x08082b08, 0x2b191908, 0x08082b08, 0x2b2b2b2b, 0x08082b08, 0x08080819, 0x08082b19,
+ 0x08081908, 0x08082b19, 0x08190808, 0x08082b19, 0x0819082b, 0x08082b19, 0x08191919, 0x08082b19,
+ 0x08192b08, 0x08082b19, 0x082b0819, 0x08082b19, 0x19080808, 0x08082b19, 0x19081919, 0x08082b19,
+ 0x19082b08, 0x08082b19, 0x19190819, 0x08082b19, 0x19191908, 0x08082b19, 0x192b0808, 0x08082b19,
+ 0x2b080819, 0x08082b19, 0x2b190808, 0x08082b19, 0x08080808, 0x08082b2b, 0x08190819, 0x08082b2b,
+ 0x08191908, 0x08082b2b, 0x082b082b, 0x08082b2b, 0x082b2b08, 0x08082b2b, 0x082b2b2b, 0x08082b2b,
+ 0x19190808, 0x08082b2b, 0x2b192b19, 0x08082b2b, 0x08080819, 0x08190808, 0x08081908, 0x08190808,
+ 0x0808192b, 0x08190808, 0x08082b19, 0x08190808, 0x08190808, 0x08190808, 0x0819082b, 0x08190808,
+ 0x08191919, 0x08190808, 0x08192b08, 0x08190808, 0x082b0819, 0x08190808, 0x082b1908, 0x08190808,
+ 0x082b192b, 0x08190808, 0x19080808, 0x08190808, 0x1908082b, 0x08190808, 0x19081919, 0x08190808,
+ 0x19082b08, 0x08190808, 0x19190819, 0x08190808, 0x19191908, 0x08190808, 0x1919192b, 0x08190808,
+ 0x19192b19, 0x08190808, 0x192b0808, 0x08190808, 0x192b082b, 0x08190808, 0x192b1919, 0x08190808,
+ 0x192b2b08, 0x08190808, 0x2b080819, 0x08190808, 0x2b081908, 0x08190808, 0x2b08192b, 0x08190808,
+ 0x2b190808, 0x08190808, 0x2b191919, 0x08190808, 0x2b192b08, 0x08190808, 0x2b2b0819, 0x08190808,
+ 0x2b2b1908, 0x08190808, 0x08080808, 0x08190819, 0x0808082b, 0x08190819, 0x08081919, 0x08190819,
+ 0x08082b08, 0x08190819, 0x08082b2b, 0x08190819, 0x08190819, 0x08190819, 0x08191908, 0x08190819,
+ 0x0819192b, 0x08190819, 0x08192b19, 0x08190819, 0x082b0808, 0x08190819, 0x082b082b, 0x08190819,
+ 0x082b1919, 0x08190819, 0x082b2b08, 0x08190819, 0x19080819, 0x08190819, 0x19081908, 0x08190819,
+ 0x1908192b, 0x08190819, 0x19082b19, 0x08190819, 0x19190808, 0x08190819, 0x1919082b, 0x08190819,
+ 0x19191919, 0x08190819, 0x19192b08, 0x08190819, 0x192b0819, 0x08190819, 0x192b1908, 0x08190819,
+ 0x2b080808, 0x08190819, 0x2b08082b, 0x08190819, 0x2b081919, 0x08190819, 0x2b082b08, 0x08190819,
+ 0x2b190819, 0x08190819, 0x2b191908, 0x08190819, 0x08080819, 0x0819082b, 0x08081908, 0x0819082b,
+ 0x08082b19, 0x0819082b, 0x08190808, 0x0819082b, 0x08191919, 0x0819082b, 0x082b0819, 0x0819082b,
+ 0x082b1908, 0x0819082b, 0x19080808, 0x0819082b, 0x19081919, 0x0819082b, 0x19190819, 0x0819082b,
+ 0x19191908, 0x0819082b, 0x2b080819, 0x0819082b, 0x2b081908, 0x0819082b, 0x2b190808, 0x0819082b,
+ 0x08080808, 0x08191908, 0x0808082b, 0x08191908, 0x08081919, 0x08191908, 0x08082b08, 0x08191908,
+ 0x08190819, 0x08191908, 0x08191908, 0x08191908, 0x0819192b, 0x08191908, 0x08192b19, 0x08191908,
+ 0x082b0808, 0x08191908, 0x082b1919, 0x08191908, 0x082b2b08, 0x08191908, 0x19080819, 0x08191908,
+ 0x19081908, 0x08191908, 0x1908192b, 0x08191908, 0x19082b19, 0x08191908, 0x19190808, 0x08191908,
+ 0x1919082b, 0x08191908, 0x19191919, 0x08191908, 0x19192b08, 0x08191908, 0x192b0819, 0x08191908,
+ 0x192b1908, 0x08191908, 0x2b080808, 0x08191908, 0x2b08082b, 0x08191908, 0x2b081919, 0x08191908,
+ 0x2b082b08, 0x08191908, 0x2b190819, 0x08191908, 0x2b191908, 0x08191908, 0x2b2b0808, 0x08191908,
+ 0x08080819, 0x08191919, 0x08081908, 0x08191919, 0x0808192b, 0x08191919, 0x08082b19, 0x08191919,
+ 0x08190808, 0x08191919, 0x0819082b, 0x08191919, 0x08191919, 0x08191919, 0x08192b08, 0x08191919,
+ 0x082b0819, 0x08191919, 0x082b1908, 0x08191919, 0x19080808, 0x08191919, 0x1908082b, 0x08191919,
+ 0x19081919, 0x08191919, 0x19082b08, 0x08191919, 0x19190819, 0x08191919, 0x19191908, 0x08191919,
+ 0x192b0808, 0x08191919, 0x2b080819, 0x08191919, 0x2b081908, 0x08191919, 0x2b190808, 0x08191919,
+ 0x08080808, 0x0819192b, 0x08081919, 0x0819192b, 0x08082b08, 0x0819192b, 0x08190819, 0x0819192b,
+ 0x08191908, 0x0819192b, 0x082b0808, 0x0819192b, 0x19080819, 0x0819192b, 0x19081908, 0x0819192b,
+ 0x19190808, 0x0819192b, 0x2b080808, 0x0819192b, 0x2b2b2b2b, 0x0819192b, 0x08080819, 0x08192b08,
+ 0x08081908, 0x08192b08, 0x0808192b, 0x08192b08, 0x08082b19, 0x08192b08, 0x08190808, 0x08192b08,
+ 0x08191919, 0x08192b08, 0x08192b08, 0x08192b08, 0x082b0819, 0x08192b08, 0x19080808, 0x08192b08,
+ 0x1908082b, 0x08192b08, 0x19081919, 0x08192b08, 0x19082b08, 0x08192b08, 0x19190819, 0x08192b08,
+ 0x19191908, 0x08192b08, 0x192b0808, 0x08192b08, 0x2b080819, 0x08192b08, 0x2b081908, 0x08192b08,
+ 0x08080808, 0x08192b19, 0x0808082b, 0x08192b19, 0x08081919, 0x08192b19, 0x08082b08, 0x08192b19,
+ 0x08190819, 0x08192b19, 0x08191908, 0x08192b19, 0x082b0808, 0x08192b19, 0x19080819, 0x08192b19,
+ 0x19081908, 0x08192b19, 0x19190808, 0x08192b19, 0x192b2b19, 0x08192b19, 0x2b2b082b, 0x08192b19,
+ 0x08081908, 0x08192b2b, 0x08190808, 0x08192b2b, 0x19080808, 0x08192b2b, 0x1919192b, 0x08192b2b,
+ 0x08080808, 0x082b0808, 0x0808082b, 0x082b0808, 0x08081919, 0x082b0808, 0x08082b08, 0x082b0808,
+ 0x08190819, 0x082b0808, 0x08191908, 0x082b0808, 0x0819192b, 0x082b0808, 0x08192b19, 0x082b0808,
+ 0x082b0808, 0x082b0808, 0x082b1919, 0x082b0808, 0x082b2b2b, 0x082b0808, 0x19080819, 0x082b0808,
+ 0x19081908, 0x082b0808, 0x19190808, 0x082b0808, 0x1919082b, 0x082b0808, 0x19191919, 0x082b0808,
+ 0x192b1908, 0x082b0808, 0x2b080808, 0x082b0808, 0x2b082b2b, 0x082b0808, 0x2b191908, 0x082b0808,
+ 0x2b2b2b2b, 0x082b0808, 0x08080819, 0x082b0819, 0x08081908, 0x082b0819, 0x08190808, 0x082b0819,
+ 0x0819082b, 0x082b0819, 0x08191919, 0x082b0819, 0x082b0819, 0x082b0819, 0x19080808, 0x082b0819,
+ 0x1908082b, 0x082b0819, 0x19081919, 0x082b0819, 0x19190819, 0x082b0819, 0x19191908, 0x082b0819,
+ 0x192b0808, 0x082b0819, 0x2b080819, 0x082b0819, 0x2b081908, 0x082b0819, 0x2b190808, 0x082b0819,
+ 0x08080808, 0x082b082b, 0x08082b2b, 0x082b082b, 0x082b082b, 0x082b082b, 0x082b2b08, 0x082b082b,
+ 0x082b2b2b, 0x082b082b, 0x19081908, 0x082b082b, 0x19190808, 0x082b082b, 0x2b082b08, 0x082b082b,
+ 0x2b082b2b, 0x082b082b, 0x2b2b2b08, 0x082b082b, 0x08080819, 0x082b1908, 0x08081908, 0x082b1908,
+ 0x0808192b, 0x082b1908, 0x08082b19, 0x082b1908, 0x08190808, 0x082b1908, 0x08191919, 0x082b1908,
+ 0x08192b08, 0x082b1908, 0x082b0819, 0x082b1908, 0x082b1908, 0x082b1908, 0x19080808, 0x082b1908,
+ 0x1908082b, 0x082b1908, 0x19081919, 0x082b1908, 0x19082b08, 0x082b1908, 0x19190819, 0x082b1908,
+ 0x19191908, 0x082b1908, 0x192b0808, 0x082b1908, 0x2b080819, 0x082b1908, 0x2b081908, 0x082b1908,
+ 0x2b190808, 0x082b1908, 0x08080808, 0x082b1919, 0x08081919, 0x082b1919, 0x08082b08, 0x082b1919,
+ 0x08190819, 0x082b1919, 0x08191908, 0x082b1919, 0x082b0808, 0x082b1919, 0x19080819, 0x082b1919,
+ 0x19081908, 0x082b1919, 0x19190808, 0x082b1919, 0x192b192b, 0x082b1919, 0x2b080808, 0x082b1919,
+ 0x08080819, 0x082b192b, 0x08081908, 0x082b192b, 0x08190808, 0x082b192b, 0x19080808, 0x082b192b,
+ 0x19192b19, 0x082b192b, 0x08080808, 0x082b2b08, 0x08081919, 0x082b2b08, 0x08190819, 0x082b2b08,
+ 0x08191908, 0x082b2b08, 0x19080819, 0x082b2b08, 0x19081908, 0x082b2b08, 0x19190808, 0x082b2b08,
+ 0x2b082b2b, 0x082b2b08, 0x2b2b2b2b, 0x082b2b08, 0x08080819, 0x082b2b19, 0x08081908, 0x082b2b19,
+ 0x08190808, 0x082b2b19, 0x2b191919, 0x082b2b19, 0x08082b2b, 0x082b2b2b, 0x082b082b, 0x082b2b2b,
+ 0x192b1908, 0x082b2b2b, 0x2b082b08, 0x082b2b2b, 0x2b082b2b, 0x082b2b2b, 0x08080819, 0x19080808,
+ 0x08081908, 0x19080808, 0x0808192b, 0x19080808, 0x08082b19, 0x19080808, 0x08190808, 0x19080808,
+ 0x0819082b, 0x19080808, 0x08191919, 0x19080808, 0x08192b08, 0x19080808, 0x08192b2b, 0x19080808,
+ 0x082b0819, 0x19080808, 0x082b1908, 0x19080808, 0x082b192b, 0x19080808, 0x19080808, 0x19080808,
+ 0x1908082b, 0x19080808, 0x19081919, 0x19080808, 0x19082b08, 0x19080808, 0x19082b2b, 0x19080808,
+ 0x19190819, 0x19080808, 0x19191908, 0x19080808, 0x1919192b, 0x19080808, 0x19192b19, 0x19080808,
+ 0x192b0808, 0x19080808, 0x192b082b, 0x19080808, 0x192b1919, 0x19080808, 0x2b080819, 0x19080808,
+ 0x2b081908, 0x19080808, 0x2b190808, 0x19080808, 0x2b191919, 0x19080808, 0x2b192b08, 0x19080808,
+ 0x2b2b0819, 0x19080808, 0x2b2b1908, 0x19080808, 0x08080808, 0x19080819, 0x0808082b, 0x19080819,
+ 0x08081919, 0x19080819, 0x08082b08, 0x19080819, 0x08190819, 0x19080819, 0x08191908, 0x19080819,
+ 0x0819192b, 0x19080819, 0x08192b19, 0x19080819, 0x082b0808, 0x19080819, 0x082b082b, 0x19080819,
+ 0x082b1919, 0x19080819, 0x19080819, 0x19080819, 0x19081908, 0x19080819, 0x1908192b, 0x19080819,
+ 0x19082b19, 0x19080819, 0x19190808, 0x19080819, 0x1919082b, 0x19080819, 0x19191919, 0x19080819,
+ 0x19192b08, 0x19080819, 0x192b0819, 0x19080819, 0x192b1908, 0x19080819, 0x2b080808, 0x19080819,
+ 0x2b08082b, 0x19080819, 0x2b081919, 0x19080819, 0x2b082b08, 0x19080819, 0x2b190819, 0x19080819,
+ 0x2b191908, 0x19080819, 0x2b2b0808, 0x19080819, 0x08080819, 0x1908082b, 0x08081908, 0x1908082b,
+ 0x08190808, 0x1908082b, 0x0819082b, 0x1908082b, 0x08191919, 0x1908082b, 0x08192b08, 0x1908082b,
+ 0x082b1908, 0x1908082b, 0x19080808, 0x1908082b, 0x19081919, 0x1908082b, 0x19082b08, 0x1908082b,
+ 0x19190819, 0x1908082b, 0x19191908, 0x1908082b, 0x192b0808, 0x1908082b, 0x2b080819, 0x1908082b,
+ 0x2b081908, 0x1908082b, 0x08080808, 0x19081908, 0x0808082b, 0x19081908, 0x08081919, 0x19081908,
+ 0x08082b08, 0x19081908, 0x08082b2b, 0x19081908, 0x08190819, 0x19081908, 0x08191908, 0x19081908,
+ 0x0819192b, 0x19081908, 0x08192b19, 0x19081908, 0x082b0808, 0x19081908, 0x082b082b, 0x19081908,
+ 0x082b1919, 0x19081908, 0x082b2b08, 0x19081908, 0x19080819, 0x19081908, 0x19081908, 0x19081908,
+ 0x1908192b, 0x19081908, 0x19082b19, 0x19081908, 0x19190808, 0x19081908, 0x1919082b, 0x19081908,
+ 0x19191919, 0x19081908, 0x19192b08, 0x19081908, 0x192b0819, 0x19081908, 0x192b1908, 0x19081908,
+ 0x2b080808, 0x19081908, 0x2b08082b, 0x19081908, 0x2b081919, 0x19081908, 0x2b082b08, 0x19081908,
+ 0x2b190819, 0x19081908, 0x2b191908, 0x19081908, 0x2b2b0808, 0x19081908, 0x08080819, 0x19081919,
+ 0x08081908, 0x19081919, 0x0808192b, 0x19081919, 0x08082b19, 0x19081919, 0x08190808, 0x19081919,
+ 0x0819082b, 0x19081919, 0x08191919, 0x19081919, 0x08192b08, 0x19081919, 0x082b0819, 0x19081919,
+ 0x082b1908, 0x19081919, 0x19080808, 0x19081919, 0x1908082b, 0x19081919, 0x19081919, 0x19081919,
+ 0x19082b08, 0x19081919, 0x19190819, 0x19081919, 0x19191908, 0x19081919, 0x192b0808, 0x19081919,
+ 0x192b2b2b, 0x19081919, 0x2b080819, 0x19081919, 0x2b081908, 0x19081919, 0x2b190808, 0x19081919,
+ 0x08080808, 0x1908192b, 0x0808082b, 0x1908192b, 0x08081919, 0x1908192b, 0x08082b08, 0x1908192b,
+ 0x08190819, 0x1908192b, 0x08191908, 0x1908192b, 0x082b0808, 0x1908192b, 0x19080819, 0x1908192b,
+ 0x19081908, 0x1908192b, 0x19190808, 0x1908192b, 0x2b080808, 0x1908192b, 0x2b2b1919, 0x1908192b,
+ 0x08080819, 0x19082b08, 0x08081908, 0x19082b08, 0x08082b19, 0x19082b08, 0x08190808, 0x19082b08,
+ 0x0819082b, 0x19082b08, 0x08191919, 0x19082b08, 0x08192b08, 0x19082b08, 0x082b0819, 0x19082b08,
+ 0x082b1908, 0x19082b08, 0x19080808, 0x19082b08, 0x1908082b, 0x19082b08, 0x19081919, 0x19082b08,
+ 0x19082b08, 0x19082b08, 0x19190819, 0x19082b08, 0x19191908, 0x19082b08, 0x192b0808, 0x19082b08,
+ 0x2b081908, 0x19082b08, 0x2b190808, 0x19082b08, 0x08080808, 0x19082b19, 0x0808082b, 0x19082b19,
+ 0x08081919, 0x19082b19, 0x08082b08, 0x19082b19, 0x08190819, 0x19082b19, 0x08191908, 0x19082b19,
+ 0x082b0808, 0x19082b19, 0x19080819, 0x19082b19, 0x19081908, 0x19082b19, 0x19190808, 0x19082b19,
+ 0x2b080808, 0x19082b19, 0x2b19192b, 0x19082b19, 0x08080819, 0x19082b2b, 0x08081908, 0x19082b2b,
+ 0x08190808, 0x19082b2b, 0x19080808, 0x19082b2b, 0x08080808, 0x19190808, 0x0808082b, 0x19190808,
+ 0x08081919, 0x19190808, 0x08082b08, 0x19190808, 0x08190819, 0x19190808, 0x08191908, 0x19190808,
+ 0x0819192b, 0x19190808, 0x08192b19, 0x19190808, 0x082b0808, 0x19190808, 0x082b082b, 0x19190808,
+ 0x082b1919, 0x19190808, 0x082b2b08, 0x19190808, 0x19080819, 0x19190808, 0x19081908, 0x19190808,
+ 0x1908192b, 0x19190808, 0x19082b19, 0x19190808, 0x19190808, 0x19190808, 0x1919082b, 0x19190808,
+ 0x19191919, 0x19190808, 0x19192b08, 0x19190808, 0x192b0819, 0x19190808, 0x192b1908, 0x19190808,
+ 0x2b080808, 0x19190808, 0x2b08082b, 0x19190808, 0x2b081919, 0x19190808, 0x2b082b08, 0x19190808,
+ 0x2b190819, 0x19190808, 0x2b191908, 0x19190808, 0x08080819, 0x19190819, 0x08081908, 0x19190819,
+ 0x0808192b, 0x19190819, 0x08082b19, 0x19190819, 0x08190808, 0x19190819, 0x0819082b, 0x19190819,
+ 0x08191919, 0x19190819, 0x08192b08, 0x19190819, 0x082b0819, 0x19190819, 0x082b1908, 0x19190819,
+ 0x19080808, 0x19190819, 0x1908082b, 0x19190819, 0x19081919, 0x19190819, 0x19082b08, 0x19190819,
+ 0x19190819, 0x19190819, 0x19191908, 0x19190819, 0x192b0808, 0x19190819, 0x2b080819, 0x19190819,
+ 0x2b081908, 0x19190819, 0x2b190808, 0x19190819, 0x08080808, 0x1919082b, 0x08081919, 0x1919082b,
+ 0x08082b08, 0x1919082b, 0x08190819, 0x1919082b, 0x08191908, 0x1919082b, 0x082b0808, 0x1919082b,
+ 0x19080819, 0x1919082b, 0x19081908, 0x1919082b, 0x19190808, 0x1919082b, 0x192b2b19, 0x1919082b,
+ 0x2b080808, 0x1919082b, 0x08080819, 0x19191908, 0x08081908, 0x19191908, 0x0808192b, 0x19191908,
+ 0x08082b19, 0x19191908, 0x08190808, 0x19191908, 0x0819082b, 0x19191908, 0x08191919, 0x19191908,
+ 0x08192b08, 0x19191908, 0x082b0819, 0x19191908, 0x082b1908, 0x19191908, 0x19080808, 0x19191908,
+ 0x1908082b, 0x19191908, 0x19081919, 0x19191908, 0x19082b08, 0x19191908, 0x19190819, 0x19191908,
+ 0x19191908, 0x19191908, 0x192b0808, 0x19191908, 0x2b080819, 0x19191908, 0x2b081908, 0x19191908,
+ 0x2b190808, 0x19191908, 0x08080808, 0x19191919, 0x0808082b, 0x19191919, 0x08081919, 0x19191919,
+ 0x08082b08, 0x19191919, 0x08190819, 0x19191919, 0x08191908, 0x19191919, 0x082b0808, 0x19191919,
+ 0x19080819, 0x19191919, 0x19081908, 0x19191919, 0x19190808, 0x19191919, 0x2b080808, 0x19191919,
+ 0x08080819, 0x1919192b, 0x08081908, 0x1919192b, 0x08190808, 0x1919192b, 0x082b192b, 0x1919192b,
+ 0x19080808, 0x1919192b, 0x08080808, 0x19192b08, 0x0808082b, 0x19192b08, 0x08081919, 0x19192b08,
+ 0x08082b08, 0x19192b08, 0x08190819, 0x19192b08, 0x08191908, 0x19192b08, 0x082b0808, 0x19192b08,
+ 0x19080819, 0x19192b08, 0x19081908, 0x19192b08, 0x19190808, 0x19192b08, 0x19192b2b, 0x19192b08,
+ 0x2b080808, 0x19192b08, 0x08080819, 0x19192b19, 0x08081908, 0x19192b19, 0x08190808, 0x19192b19,
+ 0x19080808, 0x19192b19, 0x08080808, 0x19192b2b, 0x08192b19, 0x19192b2b, 0x2b081919, 0x19192b2b,
+ 0x2b2b2b08, 0x19192b2b, 0x08080819, 0x192b0808, 0x08081908, 0x192b0808, 0x0808192b, 0x192b0808,
+ 0x08190808, 0x192b0808, 0x0819082b, 0x192b0808, 0x08191919, 0x192b0808, 0x08192b08, 0x192b0808,
+ 0x082b0819, 0x192b0808, 0x082b1908, 0x192b0808, 0x19080808, 0x192b0808, 0x19081919, 0x192b0808,
+ 0x19082b08, 0x192b0808, 0x19190819, 0x192b0808, 0x19191908, 0x192b0808, 0x192b0808, 0x192b0808,
+ 0x2b081908, 0x192b0808, 0x2b190808, 0x192b0808, 0x08080808, 0x192b0819, 0x0808082b, 0x192b0819,
+ 0x08081919, 0x192b0819, 0x08082b08, 0x192b0819, 0x08190819, 0x192b0819, 0x08191908, 0x192b0819,
+ 0x082b0808, 0x192b0819, 0x19080819, 0x192b0819, 0x19081908, 0x192b0819, 0x19190808, 0x192b0819,
+ 0x2b080808, 0x192b0819, 0x2b192b19, 0x192b0819, 0x08081908, 0x192b082b, 0x08190808, 0x192b082b,
+ 0x19080808, 0x192b082b, 0x1919192b, 0x192b082b, 0x2b2b0819, 0x192b082b, 0x08080808, 0x192b1908,
+ 0x08081919, 0x192b1908, 0x08082b08, 0x192b1908, 0x08190819, 0x192b1908, 0x08191908, 0x192b1908,
+ 0x082b0808, 0x192b1908, 0x19080819, 0x192b1908, 0x19081908, 0x192b1908, 0x19190808, 0x192b1908,
+ 0x2b080808, 0x192b1908, 0x08080819, 0x192b1919, 0x08081908, 0x192b1919, 0x08190808, 0x192b1919,
+ 0x19080808, 0x192b1919, 0x19082b2b, 0x192b1919, 0x192b2b08, 0x192b1919, 0x2b19082b, 0x192b1919,
+ 0x08080808, 0x192b192b, 0x2b191908, 0x192b192b, 0x08080819, 0x192b2b08, 0x08081908, 0x192b2b08,
+ 0x08190808, 0x192b2b08, 0x192b1919, 0x192b2b08, 0x2b192b08, 0x192b2b08, 0x08080808, 0x192b2b19,
+ 0x082b2b2b, 0x192b2b19, 0x1908082b, 0x192b2b2b, 0x2b2b0819, 0x192b2b2b, 0x08080808, 0x2b080808,
+ 0x0808082b, 0x2b080808, 0x08081919, 0x2b080808, 0x08082b08, 0x2b080808, 0x08190819, 0x2b080808,
+ 0x08191908, 0x2b080808, 0x08192b19, 0x2b080808, 0x082b0808, 0x2b080808, 0x082b1919, 0x2b080808,
+ 0x19080819, 0x2b080808, 0x19081908, 0x2b080808, 0x19190808, 0x2b080808, 0x1919082b, 0x2b080808,
+ 0x19191919, 0x2b080808, 0x19192b08, 0x2b080808, 0x192b0819, 0x2b080808, 0x2b080808, 0x2b080808,
+ 0x2b081919, 0x2b080808, 0x2b190819, 0x2b080808, 0x2b191908, 0x2b080808, 0x08080819, 0x2b080819,
+ 0x08081908, 0x2b080819, 0x08082b19, 0x2b080819, 0x08190808, 0x2b080819, 0x0819082b, 0x2b080819,
+ 0x08191919, 0x2b080819, 0x08192b08, 0x2b080819, 0x082b0819, 0x2b080819, 0x082b1908, 0x2b080819,
+ 0x19080808, 0x2b080819, 0x1908082b, 0x2b080819, 0x19081919, 0x2b080819, 0x19082b08, 0x2b080819,
+ 0x19190819, 0x2b080819, 0x19191908, 0x2b080819, 0x2b080819, 0x2b080819, 0x2b081908, 0x2b080819,
+ 0x2b190808, 0x2b080819, 0x2b2b2b19, 0x2b080819, 0x08080808, 0x2b08082b, 0x08081919, 0x2b08082b,
+ 0x08082b2b, 0x2b08082b, 0x08190819, 0x2b08082b, 0x08191908, 0x2b08082b, 0x19080819, 0x2b08082b,
+ 0x19081908, 0x2b08082b, 0x19190808, 0x2b08082b, 0x08080819, 0x2b081908, 0x08081908, 0x2b081908,
+ 0x0808192b, 0x2b081908, 0x08082b19, 0x2b081908, 0x08190808, 0x2b081908, 0x0819082b, 0x2b081908,
+ 0x08191919, 0x2b081908, 0x08192b08, 0x2b081908, 0x082b0819, 0x2b081908, 0x19080808, 0x2b081908,
+ 0x1908082b, 0x2b081908, 0x19081919, 0x2b081908, 0x19082b08, 0x2b081908, 0x19190819, 0x2b081908,
+ 0x19191908, 0x2b081908, 0x192b0808, 0x2b081908, 0x2b080819, 0x2b081908, 0x2b081908, 0x2b081908,
+ 0x2b190808, 0x2b081908, 0x08080808, 0x2b081919, 0x0808082b, 0x2b081919, 0x08081919, 0x2b081919,
+ 0x08082b08, 0x2b081919, 0x08190819, 0x2b081919, 0x08191908, 0x2b081919, 0x082b0808, 0x2b081919,
+ 0x19080819, 0x2b081919, 0x19081908, 0x2b081919, 0x19190808, 0x2b081919, 0x2b080808, 0x2b081919,
+ 0x2b082b2b, 0x2b081919, 0x08080819, 0x2b08192b, 0x08081908, 0x2b08192b, 0x08190808, 0x2b08192b,
+ 0x082b2b19, 0x2b08192b, 0x19080808, 0x2b08192b, 0x08080808, 0x2b082b08, 0x08081919, 0x2b082b08,
+ 0x08190819, 0x2b082b08, 0x08191908, 0x2b082b08, 0x19080819, 0x2b082b08, 0x19081908, 0x2b082b08,
+ 0x19190808, 0x2b082b08, 0x2b2b082b, 0x2b082b08, 0x08080819, 0x2b082b19, 0x08081908, 0x2b082b19,
+ 0x19080808, 0x2b082b19, 0x192b1919, 0x2b082b19, 0x082b082b, 0x2b082b2b, 0x19192b08, 0x2b082b2b,
+ 0x19192b2b, 0x2b082b2b, 0x2b08082b, 0x2b082b2b, 0x2b2b082b, 0x2b082b2b, 0x08080819, 0x2b190808,
+ 0x08081908, 0x2b190808, 0x08082b19, 0x2b190808, 0x08190808, 0x2b190808, 0x0819082b, 0x2b190808,
+ 0x08191919, 0x2b190808, 0x08192b08, 0x2b190808, 0x082b1908, 0x2b190808, 0x19080808, 0x2b190808,
+ 0x1908082b, 0x2b190808, 0x19081919, 0x2b190808, 0x19082b08, 0x2b190808, 0x19190819, 0x2b190808,
+ 0x19191908, 0x2b190808, 0x192b0808, 0x2b190808, 0x2b080819, 0x2b190808, 0x2b081908, 0x2b190808,
+ 0x2b190808, 0x2b190808, 0x08080808, 0x2b190819, 0x08081919, 0x2b190819, 0x08190819, 0x2b190819,
+ 0x08191908, 0x2b190819, 0x19080819, 0x2b190819, 0x19081908, 0x2b190819, 0x19190808, 0x2b190819,
+ 0x19192b2b, 0x2b190819, 0x08080819, 0x2b19082b, 0x08081908, 0x2b19082b, 0x08190808, 0x2b19082b,
+ 0x19080808, 0x2b19082b, 0x2b2b192b, 0x2b19082b, 0x08080808, 0x2b191908, 0x0808082b, 0x2b191908,
+ 0x08081919, 0x2b191908, 0x08082b08, 0x2b191908, 0x08190819, 0x2b191908, 0x08191908, 0x2b191908,
+ 0x082b0808, 0x2b191908, 0x19080819, 0x2b191908, 0x19081908, 0x2b191908, 0x19190808, 0x2b191908,
+ 0x2b080808, 0x2b191908, 0x2b19192b, 0x2b191908, 0x08080819, 0x2b191919, 0x08081908, 0x2b191919,
+ 0x08190808, 0x2b191919, 0x19080808, 0x2b191919, 0x2b192b08, 0x2b191919, 0x2b2b0819, 0x2b191919,
+ 0x08080808, 0x2b19192b, 0x1908192b, 0x2b19192b, 0x192b1908, 0x2b19192b, 0x08080819, 0x2b192b08,
+ 0x08081908, 0x2b192b08, 0x08190808, 0x2b192b08, 0x082b192b, 0x2b192b08, 0x19080808, 0x2b192b08,
+ 0x2b2b2b19, 0x2b192b08, 0x08080808, 0x2b192b19, 0x19082b19, 0x2b192b19, 0x1919082b, 0x2b192b19,
+ 0x2b190808, 0x2b192b2b, 0x08080808, 0x2b2b0808, 0x08081919, 0x2b2b0808, 0x08082b2b, 0x2b2b0808,
+ 0x08191908, 0x2b2b0808, 0x082b082b, 0x2b2b0808, 0x082b2b2b, 0x2b2b0808, 0x19080819, 0x2b2b0808,
+ 0x19081908, 0x2b2b0808, 0x19190808, 0x2b2b0808, 0x2b2b082b, 0x2b2b0808, 0x2b2b2b2b, 0x2b2b0808,
+ 0x19080808, 0x2b2b0819, 0x192b1919, 0x2b2b0819, 0x0808082b, 0x2b2b082b, 0x08082b2b, 0x2b2b082b,
+ 0x082b082b, 0x2b2b082b, 0x082b2b08, 0x2b2b082b, 0x082b2b2b, 0x2b2b082b, 0x2b08082b, 0x2b2b082b,
+ 0x2b082b08, 0x2b2b082b, 0x2b082b2b, 0x2b2b082b, 0x2b2b2b08, 0x2b2b082b, 0x08080819, 0x2b2b1908,
+ 0x08081908, 0x2b2b1908, 0x08190808, 0x2b2b1908, 0x19080808, 0x2b2b1908, 0x2b082b19, 0x2b2b1908,
+ 0x2b2b1908, 0x2b2b1908, 0x08080808, 0x2b2b1919, 0x08192b19, 0x2b2b1919, 0x19190819, 0x2b2b192b,
+ 0x08082b2b, 0x2b2b2b08, 0x082b2b08, 0x2b2b2b08, 0x2b2b082b, 0x2b2b2b08, 0x19191908, 0x2b2b2b19,
+ 0x2b08192b, 0x2b2b2b19, 0x08082b08, 0x2b2b2b2b, 0x08082b2b, 0x2b2b2b2b, 0x082b0808, 0x2b2b2b2b,
+ 0x082b082b, 0x2b2b2b2b, 0x082b2b08, 0x2b2b2b2b, 0x2b082b08, 0x2b2b2b2b, 0x2b2b2b2b, 0x2b2b2b2b
+);
+#enddecl(IQ2_S_GRID)
+
+#decl(IQ3_XSS_GRID)
+
+const iq3xxs_grid = array<u32, 256>(
+ 0x04040404, 0x04040414, 0x04040424, 0x04040c0c, 0x04040c1c, 0x04040c3e, 0x04041404, 0x04041414,
+ 0x04041c0c, 0x04042414, 0x04043e1c, 0x04043e2c, 0x040c040c, 0x040c041c, 0x040c0c04, 0x040c0c14,
+ 0x040c140c, 0x040c142c, 0x040c1c04, 0x040c1c14, 0x040c240c, 0x040c2c24, 0x040c3e04, 0x04140404,
+ 0x04140414, 0x04140424, 0x04140c0c, 0x04141404, 0x04141414, 0x04141c0c, 0x04141c1c, 0x04141c3e,
+ 0x04142c0c, 0x04142c3e, 0x04143e2c, 0x041c040c, 0x041c043e, 0x041c0c04, 0x041c0c14, 0x041c142c,
+ 0x041c3e04, 0x04240c1c, 0x04241c3e, 0x04242424, 0x04242c3e, 0x04243e1c, 0x04243e2c, 0x042c040c,
+ 0x042c043e, 0x042c1c14, 0x042c2c14, 0x04341c2c, 0x04343424, 0x043e0c04, 0x043e0c24, 0x043e0c34,
+ 0x043e241c, 0x043e340c, 0x0c04040c, 0x0c04041c, 0x0c040c04, 0x0c040c14, 0x0c04140c, 0x0c04141c,
+ 0x0c041c04, 0x0c041c14, 0x0c041c24, 0x0c04243e, 0x0c042c04, 0x0c0c0404, 0x0c0c0414, 0x0c0c0c0c,
+ 0x0c0c1404, 0x0c0c1414, 0x0c14040c, 0x0c14041c, 0x0c140c04, 0x0c140c14, 0x0c14140c, 0x0c141c04,
+ 0x0c143e14, 0x0c1c0404, 0x0c1c0414, 0x0c1c1404, 0x0c1c1c0c, 0x0c1c2434, 0x0c1c3434, 0x0c24040c,
+ 0x0c24042c, 0x0c242c04, 0x0c2c1404, 0x0c2c1424, 0x0c2c2434, 0x0c2c3e0c, 0x0c34042c, 0x0c3e1414,
+ 0x0c3e2404, 0x14040404, 0x14040414, 0x14040c0c, 0x14040c1c, 0x14041404, 0x14041414, 0x14041434,
+ 0x14041c0c, 0x14042414, 0x140c040c, 0x140c041c, 0x140c042c, 0x140c0c04, 0x140c0c14, 0x140c140c,
+ 0x140c1c04, 0x140c341c, 0x140c343e, 0x140c3e04, 0x14140404, 0x14140414, 0x14140c0c, 0x14140c3e,
+ 0x14141404, 0x14141414, 0x14141c3e, 0x14142404, 0x14142c2c, 0x141c040c, 0x141c0c04, 0x141c0c24,
+ 0x141c3e04, 0x141c3e24, 0x14241c2c, 0x14242c1c, 0x142c041c, 0x142c143e, 0x142c240c, 0x142c3e24,
+ 0x143e040c, 0x143e041c, 0x143e0c34, 0x143e242c, 0x1c04040c, 0x1c040c04, 0x1c040c14, 0x1c04140c,
+ 0x1c04141c, 0x1c042c04, 0x1c04342c, 0x1c043e14, 0x1c0c0404, 0x1c0c0414, 0x1c0c1404, 0x1c0c1c0c,
+ 0x1c0c2424, 0x1c0c2434, 0x1c14040c, 0x1c14041c, 0x1c140c04, 0x1c14142c, 0x1c142c14, 0x1c143e14,
+ 0x1c1c0c0c, 0x1c1c1c1c, 0x1c241c04, 0x1c24243e, 0x1c243e14, 0x1c2c0404, 0x1c2c0434, 0x1c2c1414,
+ 0x1c2c2c2c, 0x1c340c24, 0x1c341c34, 0x1c34341c, 0x1c3e1c1c, 0x1c3e3404, 0x24040424, 0x24040c3e,
+ 0x24041c2c, 0x24041c3e, 0x24042c1c, 0x24042c3e, 0x240c3e24, 0x24141404, 0x24141c3e, 0x24142404,
+ 0x24143404, 0x24143434, 0x241c043e, 0x241c242c, 0x24240424, 0x24242c0c, 0x24243424, 0x242c142c,
+ 0x242c241c, 0x242c3e04, 0x243e042c, 0x243e0c04, 0x243e0c14, 0x243e1c04, 0x2c040c14, 0x2c04240c,
+ 0x2c043e04, 0x2c0c0404, 0x2c0c0434, 0x2c0c1434, 0x2c0c2c2c, 0x2c140c24, 0x2c141c14, 0x2c143e14,
+ 0x2c1c0414, 0x2c1c2c1c, 0x2c240c04, 0x2c24141c, 0x2c24143e, 0x2c243e14, 0x2c2c0414, 0x2c2c1c0c,
+ 0x2c342c04, 0x2c3e1424, 0x2c3e2414, 0x34041424, 0x34042424, 0x34042434, 0x34043424, 0x340c140c,
+ 0x340c340c, 0x34140c3e, 0x34143424, 0x341c1c04, 0x341c1c34, 0x34242424, 0x342c042c, 0x342c2c14,
+ 0x34341c1c, 0x343e041c, 0x343e140c, 0x3e04041c, 0x3e04042c, 0x3e04043e, 0x3e040c04, 0x3e041c14,
+ 0x3e042c14, 0x3e0c1434, 0x3e0c2404, 0x3e140c14, 0x3e14242c, 0x3e142c14, 0x3e1c0404, 0x3e1c0c2c,
+ 0x3e1c1c1c, 0x3e1c3404, 0x3e24140c, 0x3e24240c, 0x3e2c0404, 0x3e2c0414, 0x3e2c1424, 0x3e341c04
+);
+#enddecl(IQ3_XSS_GRID)
+
+#decl(IQ3_S_GRID)
+
+const iq3s_grid = array<u32, 512>(
+ 0x01010101, 0x01010103, 0x01010105, 0x0101010b, 0x0101010f, 0x01010301, 0x01010303, 0x01010305,
+ 0x01010309, 0x0101030d, 0x01010501, 0x01010503, 0x0101050b, 0x01010707, 0x01010901, 0x01010905,
+ 0x0101090b, 0x0101090f, 0x01010b03, 0x01010b07, 0x01010d01, 0x01010d05, 0x01010f03, 0x01010f09,
+ 0x01010f0f, 0x01030101, 0x01030103, 0x01030105, 0x01030109, 0x01030301, 0x01030303, 0x0103030b,
+ 0x01030501, 0x01030507, 0x0103050f, 0x01030703, 0x0103070b, 0x01030909, 0x01030d03, 0x01030d0b,
+ 0x01030f05, 0x01050101, 0x01050103, 0x0105010b, 0x0105010f, 0x01050301, 0x01050307, 0x0105030d,
+ 0x01050503, 0x0105050b, 0x01050701, 0x01050709, 0x01050905, 0x0105090b, 0x0105090f, 0x01050b03,
+ 0x01050b07, 0x01050f01, 0x01050f07, 0x01070107, 0x01070303, 0x0107030b, 0x01070501, 0x01070505,
+ 0x01070703, 0x01070707, 0x0107070d, 0x01070909, 0x01070b01, 0x01070b05, 0x01070d0f, 0x01070f03,
+ 0x01070f0b, 0x01090101, 0x01090307, 0x0109030f, 0x01090503, 0x01090509, 0x01090705, 0x01090901,
+ 0x01090907, 0x01090b03, 0x01090f01, 0x010b0105, 0x010b0109, 0x010b0501, 0x010b0505, 0x010b050d,
+ 0x010b0707, 0x010b0903, 0x010b090b, 0x010b090f, 0x010b0d0d, 0x010b0f07, 0x010d010d, 0x010d0303,
+ 0x010d0307, 0x010d0703, 0x010d0b05, 0x010d0f03, 0x010f0101, 0x010f0105, 0x010f0109, 0x010f0501,
+ 0x010f0505, 0x010f050d, 0x010f0707, 0x010f0b01, 0x010f0b09, 0x03010101, 0x03010103, 0x03010105,
+ 0x03010109, 0x03010301, 0x03010303, 0x03010307, 0x0301030b, 0x0301030f, 0x03010501, 0x03010505,
+ 0x03010703, 0x03010709, 0x0301070d, 0x03010b09, 0x03010b0d, 0x03010d03, 0x03010f05, 0x03030101,
+ 0x03030103, 0x03030107, 0x0303010d, 0x03030301, 0x03030309, 0x03030503, 0x03030701, 0x03030707,
+ 0x03030903, 0x03030b01, 0x03030b05, 0x03030f01, 0x03030f0d, 0x03050101, 0x03050305, 0x0305030b,
+ 0x0305030f, 0x03050501, 0x03050509, 0x03050705, 0x03050901, 0x03050907, 0x03050b0b, 0x03050d01,
+ 0x03050f05, 0x03070103, 0x03070109, 0x0307010f, 0x03070301, 0x03070307, 0x03070503, 0x0307050f,
+ 0x03070701, 0x03070709, 0x03070903, 0x03070d05, 0x03070f01, 0x03090107, 0x0309010b, 0x03090305,
+ 0x03090309, 0x03090703, 0x03090707, 0x03090905, 0x0309090d, 0x03090b01, 0x03090b09, 0x030b0103,
+ 0x030b0301, 0x030b0307, 0x030b0503, 0x030b0701, 0x030b0705, 0x030b0b03, 0x030d0501, 0x030d0509,
+ 0x030d050f, 0x030d0909, 0x030d090d, 0x030f0103, 0x030f0107, 0x030f0301, 0x030f0305, 0x030f0503,
+ 0x030f070b, 0x030f0903, 0x030f0d05, 0x030f0f01, 0x05010101, 0x05010103, 0x05010107, 0x0501010b,
+ 0x0501010f, 0x05010301, 0x05010305, 0x05010309, 0x0501030d, 0x05010503, 0x05010507, 0x0501050f,
+ 0x05010701, 0x05010705, 0x05010903, 0x05010907, 0x0501090b, 0x05010b01, 0x05010b05, 0x05010d0f,
+ 0x05010f01, 0x05010f07, 0x05010f0b, 0x05030101, 0x05030105, 0x05030301, 0x05030307, 0x0503030f,
+ 0x05030505, 0x0503050b, 0x05030703, 0x05030709, 0x05030905, 0x05030b03, 0x05050103, 0x05050109,
+ 0x0505010f, 0x05050503, 0x05050507, 0x05050701, 0x0505070f, 0x05050903, 0x05050b07, 0x05050b0f,
+ 0x05050f03, 0x05050f09, 0x05070101, 0x05070105, 0x0507010b, 0x05070303, 0x05070505, 0x05070509,
+ 0x05070703, 0x05070707, 0x05070905, 0x05070b01, 0x05070d0d, 0x05090103, 0x0509010f, 0x05090501,
+ 0x05090507, 0x05090705, 0x0509070b, 0x05090903, 0x05090f05, 0x05090f0b, 0x050b0109, 0x050b0303,
+ 0x050b0505, 0x050b070f, 0x050b0901, 0x050b0b07, 0x050b0f01, 0x050d0101, 0x050d0105, 0x050d010f,
+ 0x050d0503, 0x050d0b0b, 0x050d0d03, 0x050f010b, 0x050f0303, 0x050f050d, 0x050f0701, 0x050f0907,
+ 0x050f0b01, 0x07010105, 0x07010303, 0x07010307, 0x0701030b, 0x0701030f, 0x07010505, 0x07010703,
+ 0x07010707, 0x0701070b, 0x07010905, 0x07010909, 0x0701090f, 0x07010b03, 0x07010d07, 0x07010f03,
+ 0x07030103, 0x07030107, 0x0703010b, 0x07030309, 0x07030503, 0x07030507, 0x07030901, 0x07030d01,
+ 0x07030f05, 0x07030f0d, 0x07050101, 0x07050305, 0x07050501, 0x07050705, 0x07050709, 0x07050b01,
+ 0x07070103, 0x07070301, 0x07070309, 0x07070503, 0x07070507, 0x0707050f, 0x07070701, 0x07070903,
+ 0x07070907, 0x0707090f, 0x07070b0b, 0x07070f07, 0x07090107, 0x07090303, 0x0709030d, 0x07090505,
+ 0x07090703, 0x07090b05, 0x07090d01, 0x07090d09, 0x070b0103, 0x070b0301, 0x070b0305, 0x070b050b,
+ 0x070b0705, 0x070b0909, 0x070b0b0d, 0x070b0f07, 0x070d030d, 0x070d0903, 0x070f0103, 0x070f0107,
+ 0x070f0501, 0x070f0505, 0x070f070b, 0x09010101, 0x09010109, 0x09010305, 0x09010501, 0x09010509,
+ 0x0901050f, 0x09010705, 0x09010903, 0x09010b01, 0x09010f01, 0x09030105, 0x0903010f, 0x09030303,
+ 0x09030307, 0x09030505, 0x09030701, 0x0903070b, 0x09030907, 0x09030b03, 0x09030b0b, 0x09050103,
+ 0x09050107, 0x09050301, 0x0905030b, 0x09050503, 0x09050707, 0x09050901, 0x09050b0f, 0x09050d05,
+ 0x09050f01, 0x09070109, 0x09070303, 0x09070307, 0x09070501, 0x09070505, 0x09070703, 0x0907070b,
+ 0x09090101, 0x09090105, 0x09090509, 0x0909070f, 0x09090901, 0x09090f03, 0x090b010b, 0x090b010f,
+ 0x090b0503, 0x090b0d05, 0x090d0307, 0x090d0709, 0x090d0d01, 0x090f0301, 0x090f030b, 0x090f0701,
+ 0x090f0907, 0x090f0b03, 0x0b010105, 0x0b010301, 0x0b010309, 0x0b010505, 0x0b010901, 0x0b010909,
+ 0x0b01090f, 0x0b010b05, 0x0b010d0d, 0x0b010f09, 0x0b030103, 0x0b030107, 0x0b03010b, 0x0b030305,
+ 0x0b030503, 0x0b030705, 0x0b030f05, 0x0b050101, 0x0b050303, 0x0b050507, 0x0b050701, 0x0b05070d,
+ 0x0b050b07, 0x0b070105, 0x0b07010f, 0x0b070301, 0x0b07050f, 0x0b070909, 0x0b070b03, 0x0b070d0b,
+ 0x0b070f07, 0x0b090103, 0x0b090109, 0x0b090501, 0x0b090705, 0x0b09090d, 0x0b0b0305, 0x0b0b050d,
+ 0x0b0b0b03, 0x0b0b0b07, 0x0b0d0905, 0x0b0f0105, 0x0b0f0109, 0x0b0f0505, 0x0d010303, 0x0d010307,
+ 0x0d01030b, 0x0d010703, 0x0d010707, 0x0d010d01, 0x0d030101, 0x0d030501, 0x0d03050f, 0x0d030d09,
+ 0x0d050305, 0x0d050709, 0x0d050905, 0x0d050b0b, 0x0d050d05, 0x0d050f01, 0x0d070101, 0x0d070309,
+ 0x0d070503, 0x0d070901, 0x0d09050b, 0x0d090907, 0x0d090d05, 0x0d0b0101, 0x0d0b0107, 0x0d0b0709,
+ 0x0d0b0d01, 0x0d0d010b, 0x0d0d0901, 0x0d0f0303, 0x0d0f0307, 0x0f010101, 0x0f010109, 0x0f01010f,
+ 0x0f010501, 0x0f010505, 0x0f01070d, 0x0f010901, 0x0f010b09, 0x0f010d05, 0x0f030105, 0x0f030303,
+ 0x0f030509, 0x0f030907, 0x0f03090b, 0x0f050103, 0x0f050109, 0x0f050301, 0x0f05030d, 0x0f050503,
+ 0x0f050701, 0x0f050b03, 0x0f070105, 0x0f070705, 0x0f07070b, 0x0f070b07, 0x0f090103, 0x0f09010b,
+ 0x0f090307, 0x0f090501, 0x0f090b01, 0x0f0b0505, 0x0f0b0905, 0x0f0d0105, 0x0f0d0703, 0x0f0f0101
+);
+#enddecl(IQ3_S_GRID)
+
+#decl(IQ1_GRID)
+
+const IQ1_DELTA: f32 = 0.125;
+
+const iq1_grid = array<u32, 1024>(
+ 0xfffdffff, 0xfff7fff0, 0xffccfff5, 0xffdfffc0, 0xffd7ffdd, 0xff30ffd5, 0xff03ff0c, 0xff10ff01,
+ 0xff7dff7f, 0xff75ff77, 0xff5fff40, 0xff57ff5d, 0xfcf3ff55, 0xfcccfcf0, 0xfcc1fcc3, 0xfcc5fcc4,
+ 0xfc3cfcd0, 0xfc34fc31, 0xfc00fc0d, 0xfc1cfc05, 0xfc11fc13, 0xfc70fc17, 0xfc43fc4c, 0xfc50fc41,
+ 0xfdfdfdff, 0xfdf5fdf7, 0xfddffdc0, 0xfdd7fddd, 0xfd30fdd5, 0xfd04fd0c, 0xfd14fd13, 0xfd7dfd7f,
+ 0xfd75fd77, 0xfd40fd4c, 0xfd5ffd44, 0xfd57fd5d, 0xf3ccfd55, 0xf3c1f3c3, 0xf33cf3d0, 0xf300f334,
+ 0xf313f305, 0xf34cf310, 0xf350f344, 0xf0f3f0fc, 0xf0f1f0f0, 0xf0c7f0c0, 0xf0d4f0c5, 0xf030f03f,
+ 0xf00ff035, 0xf003f00c, 0xf001f000, 0xf01ff004, 0xf010f01d, 0xf015f017, 0xf04cf07c, 0xf047f040,
+ 0xf05cf045, 0xf050f053, 0xf054f051, 0xf1c4f1c3, 0xf133f13c, 0xf10df10f, 0xf107f100, 0xf11cf11f,
+ 0xf114f111, 0xf14cf170, 0xf144f143, 0xf7fdf7ff, 0xf7f5f7f7, 0xf7dff7c0, 0xf7d7f7dd, 0xf730f7d5,
+ 0xf701f70c, 0xf77ff710, 0xf777f77d, 0xf740f775, 0xf75df75f, 0xf755f757, 0xf4ccf4f0, 0xf4c4f4c3,
+ 0xf4d0f4d3, 0xf40ff43c, 0xf400f40c, 0xf413f41c, 0xf44cf414, 0xf441f443, 0xf450f444, 0xf5fdf5ff,
+ 0xf5f5f5f7, 0xf5dff5c0, 0xf5d7f5dd, 0xf530f5d5, 0xf504f50c, 0xf510f51c, 0xf57df57f, 0xf577f570,
+ 0xf540f575, 0xf55df55f, 0xf555f557, 0xcfcccfcf, 0xcfc4cfc3, 0xcfd0cfd3, 0xcf33cf3c, 0xcf00cf0f,
+ 0xcf1ccf07, 0xcf10cf13, 0xcf4ccf14, 0xcf41cf43, 0xcf50cf5c, 0xccf3ccfc, 0xccf4ccf1, 0xcccdcccf,
+ 0xccc7ccc0, 0xccd3ccdc, 0xcc30ccd4, 0xcc0fcc35, 0xcc0dcc0c, 0xcc00cc03, 0xcc04cc01, 0xcc10cc1f,
+ 0xcc4dcc73, 0xcc5ccc40, 0xcdcccc53, 0xcdc1cdc3, 0xcd3fcdd0, 0xcd34cd31, 0xcd00cd0d, 0xcd05cd07,
+ 0xcd11cd13, 0xcd4ccd70, 0xcd41cd43, 0xc3fccd50, 0xc3f4c3f1, 0xc3c0c3c3, 0xc3c4c3c7, 0xc3d1c3dc,
+ 0xc330c33c, 0xc337c331, 0xc30cc335, 0xc300c303, 0xc304c301, 0xc310c31d, 0xc373c317, 0xc34fc374,
+ 0xc340c343, 0xc344c347, 0xc35cc345, 0xc350c353, 0xc0fdc354, 0xc0f5c0f0, 0xc0c3c0cc, 0xc0c1c0c0,
+ 0xc0dfc0c4, 0xc0d0c0dd, 0xc0d5c0d7, 0xc033c03c, 0xc031c030, 0xc00dc00c, 0xc000c003, 0xc004c001,
+ 0xc01cc005, 0xc010c013, 0xc014c011, 0xc07dc07f, 0xc070c073, 0xc075c077, 0xc04cc04f, 0xc040c043,
+ 0xc044c041, 0xc05fc045, 0xc050c05d, 0xc1f3c1fc, 0xc1f1c1f0, 0xc1c1c1c0, 0xc1c5c1c7, 0xc1d1c1dc,
+ 0xc13dc13f, 0xc130c133, 0xc135c137, 0xc100c10c, 0xc107c101, 0xc11cc104, 0xc110c113, 0xc114c117,
+ 0xc171c115, 0xc14dc175, 0xc153c140, 0xc7ccc154, 0xc7d0c7c1, 0xc733c73c, 0xc734c731, 0xc700c70f,
+ 0xc705c707, 0xc71cc71f, 0xc711c713, 0xc770c714, 0xc743c74c, 0xc4cfc750, 0xc4c0c4cd, 0xc4dcc4c5,
+ 0xc43dc4d0, 0xc430c433, 0xc40cc437, 0xc400c403, 0xc404c401, 0xc41fc405, 0xc415c410, 0xc44cc474,
+ 0xc440c44d, 0xc45cc447, 0xc454c451, 0xc5c1c5f4, 0xc5d1c5d3, 0xc531c533, 0xc50fc534, 0xc500c50d,
+ 0xc51cc507, 0xc514c511, 0xc54cc570, 0xc545c541, 0xdffddfff, 0xdff5dff7, 0xdfdfdfc0, 0xdfd0dfdd,
+ 0xdfd5dfd7, 0xdf0cdf30, 0xdf1cdf04, 0xdf7fdf10, 0xdf77df7d, 0xdf40df75, 0xdf5ddf5f, 0xdf57df50,
+ 0xdcf0df55, 0xdcc3dccc, 0xdcd0dcc4, 0xdc33dc3d, 0xdc00dc34, 0xdc05dc07, 0xdc13dc1c, 0xdc11dc10,
+ 0xdc4fdc70, 0xdc44dc41, 0xddfcdc50, 0xddf5ddf7, 0xddc0ddcc, 0xdddddddf, 0xddd5ddd7, 0xdd0cdd30,
+ 0xdd04dd01, 0xdd7cdd10, 0xdd75dd77, 0xdd40dd4c, 0xdd5ddd5f, 0xdd55dd57, 0xd3c3d3f0, 0xd3c4d3c1,
+ 0xd333d3d0, 0xd331d330, 0xd30dd334, 0xd307d300, 0xd311d305, 0xd34cd370, 0xd344d343, 0xd350d35c,
+ 0xd0c0d0f4, 0xd0d4d0dc, 0xd030d03f, 0xd00cd037, 0xd000d003, 0xd01dd004, 0xd017d010, 0xd04fd074,
+ 0xd040d043, 0xd045d047, 0xd053d05c, 0xd054d051, 0xd1cfd1f0, 0xd1c4d1cd, 0xd13cd1d0, 0xd100d134,
+ 0xd11cd11f, 0xd173d114, 0xd14fd171, 0xd7ffd145, 0xd7f7d7fd, 0xd7c0d7f5, 0xd7ddd7df, 0xd7d5d7d7,
+ 0xd70cd730, 0xd710d703, 0xd77dd77f, 0xd775d777, 0xd75dd75f, 0xd755d757, 0xd4ccd4f4, 0xd4c4d4c3,
+ 0xd431d4d0, 0xd40dd434, 0xd41cd400, 0xd411d413, 0xd470d414, 0xd441d44f, 0xd453d444, 0xd5ffd450,
+ 0xd5f7d5fd, 0xd5dfd5f5, 0xd5d7d5dd, 0xd530d5d5, 0xd501d50c, 0xd510d504, 0xd57dd57f, 0xd575d577,
+ 0xd55fd540, 0xd557d55d, 0x3ff0d555, 0x3fc13fcc, 0x3f343fd0, 0x3f003f0d, 0x3f053f07, 0x3f133f1c,
+ 0x3f433f11, 0x3f5c3f44, 0x3cff3f51, 0x3cf33cfc, 0x3cf43cf1, 0x3cc03ccd, 0x3cc73cc1, 0x3cdc3cc5,
+ 0x3cd43cd1, 0x3c373c30, 0x3c0c3c35, 0x3c003c03, 0x3c043c01, 0x3c103c05, 0x3c153c17, 0x3c733c7c,
+ 0x3c4f3c71, 0x3c403c4d, 0x3c5c3c5f, 0x3df03c5d, 0x3dc33dcc, 0x3dd03dc1, 0x3d0d3d3c, 0x3d053d00,
+ 0x3d143d13, 0x3d433d74, 0x33fc3d50, 0x33c433c0, 0x333033d4, 0x33353337, 0x3303330c, 0x33013300,
+ 0x331d331c, 0x33173310, 0x337c3315, 0x33743371, 0x334d334f, 0x335f3340, 0x3354335c, 0x30fd30fc,
+ 0x30f530f0, 0x30c330cc, 0x30c130c0, 0x30df30c4, 0x30d530d0, 0x3033303c, 0x30313030, 0x300f3034,
+ 0x3003300c, 0x30013000, 0x30043007, 0x3013301c, 0x30113010, 0x307d3014, 0x30703073, 0x304c3077,
+ 0x30403043, 0x30443041, 0x30503045, 0x30553057, 0x31f031fc, 0x31c331f4, 0x31c731c0, 0x31dc31c5,
+ 0x31d431d3, 0x313d313f, 0x31373130, 0x310c310f, 0x3100310d, 0x31043101, 0x3110311d, 0x317c3117,
+ 0x31753170, 0x31403143, 0x3153315c, 0x37f03151, 0x37c037cc, 0x37d037c5, 0x3734373d, 0x3700370f,
+ 0x371c3707, 0x37113713, 0x37703714, 0x3743374c, 0x37443741, 0x34fc3750, 0x34f134f0, 0x34cf34f5,
+ 0x34c034c3, 0x34dc34c7, 0x34d134d3, 0x3430343f, 0x340c3435, 0x3403340d, 0x34013400, 0x341f3404,
+ 0x3410341d, 0x34153411, 0x34743471, 0x3440344d, 0x34473441, 0x3453345c, 0x34543451, 0x353335c1,
+ 0x35343531, 0x35073500, 0x35133505, 0x35433514, 0x0ffc3550, 0x0ff00ff3, 0x0ff40ff1, 0x0fc00fcd,
+ 0x0fdc0fc5, 0x0fd40fd3, 0x0f300f3f, 0x0f0c0f37, 0x0f000f03, 0x0f040f01, 0x0f170f10, 0x0f740f71,
+ 0x0f470f40, 0x0f5c0f5f, 0x0f540f51, 0x0cf70cf0, 0x0cf50cf4, 0x0cc30ccc, 0x0cc10cc0, 0x0cc40cc7,
+ 0x0cd00cdf, 0x0cd70cd1, 0x0c3c0cd5, 0x0c300c33, 0x0c340c31, 0x0c0c0c0f, 0x0c030c0d, 0x0c010c00,
+ 0x0c040c07, 0x0c1c0c05, 0x0c100c13, 0x0c140c11, 0x0c700c7d, 0x0c430c4c, 0x0c410c40, 0x0c5f0c44,
+ 0x0c550c50, 0x0df10dfc, 0x0dc00dcd, 0x0ddc0dc5, 0x0d3d0dd3, 0x0d350d30, 0x0d030d0c, 0x0d010d00,
+ 0x0d1d0d04, 0x0d700d10, 0x0d4d0d4f, 0x0d440d40, 0x0d530d45, 0x03f003f3, 0x03c303cc, 0x03c103c0,
+ 0x03c403c7, 0x03d003dc, 0x03d503d7, 0x0333033c, 0x03310330, 0x03350334, 0x030c030f, 0x03000303,
+ 0x03070301, 0x03050304, 0x031d031c, 0x03100313, 0x03140311, 0x0377037f, 0x034c0375, 0x03400343,
+ 0x03440341, 0x0353035c, 0x03550350, 0x00fd00fc, 0x00f000f3, 0x00f400f1, 0x00cc00cf, 0x00c300cd,
+ 0x00c100c0, 0x00c500c4, 0x00d300dc, 0x00d100d0, 0x003f00d4, 0x003d003c, 0x00300033, 0x00370031,
+ 0x000f0034, 0x000d000c, 0x00000003, 0x00070001, 0x00050004, 0x001c001f, 0x00100013, 0x00170011,
+ 0x00150014, 0x0073007c, 0x00740070, 0x004f0075, 0x0043004c, 0x00410040, 0x00440047, 0x0053005c,
+ 0x00510050, 0x01ff0054, 0x01fd01fc, 0x01f101f3, 0x01f401f7, 0x01c301cc, 0x01c701c0, 0x01df01c4,
+ 0x01dd01dc, 0x01d001d3, 0x01d701d1, 0x013c01d4, 0x01310130, 0x01340137, 0x010f0135, 0x010d010c,
+ 0x01000103, 0x01070101, 0x01050104, 0x0113011c, 0x01140110, 0x0170017d, 0x01770171, 0x01750174,
+ 0x0140014c, 0x015d0145, 0x01510150, 0x01540157, 0x07f007f3, 0x07f407f1, 0x07c007cf, 0x07dc07c7,
+ 0x073007d5, 0x07350737, 0x0703070c, 0x07010700, 0x07040707, 0x071d071f, 0x07100713, 0x0774077d,
+ 0x074d074f, 0x07470740, 0x0754075c, 0x04fd04fc, 0x04f504f0, 0x04c304cc, 0x04c104c0, 0x04d004c4,
+ 0x0433043c, 0x04310430, 0x040f0434, 0x040d040c, 0x04000403, 0x04070401, 0x04050404, 0x0413041c,
+ 0x04110410, 0x047c0414, 0x04740470, 0x0443044c, 0x04410440, 0x04440447, 0x05f30450, 0x05c005f7,
+ 0x05df05c5, 0x05d105d0, 0x053005d4, 0x05340537, 0x0500050c, 0x05070501, 0x051d0504, 0x05170510,
+ 0x057c0515, 0x054d0575, 0x05410540, 0x05450547, 0x1ff0055c, 0x1fc11fc3, 0x1fd01fc4, 0x1f0f1f33,
+ 0x1f011f00, 0x1f051f07, 0x1f131f1c, 0x1f141f11, 0x1f411f7c, 0x1cfc1f50, 0x1cf11cf3, 0x1ccd1cf4,
+ 0x1cdc1cc0, 0x1cd11cdd, 0x1c301cd4, 0x1c0c1c34, 0x1c011c00, 0x1c101c04, 0x1c151c11, 0x1c751c73,
+ 0x1c401c4d, 0x1c511c5c, 0x1dcc1c54, 0x1dc41dc1, 0x1d3c1d3f, 0x1d001d31, 0x1d071d01, 0x1d701d1f,
+ 0x1d411d4c, 0x13cc1d50, 0x13c013cd, 0x13c513c1, 0x13d113dc, 0x133f13d4, 0x1330133d, 0x13351337,
+ 0x1303130c, 0x13011300, 0x13051304, 0x131d131f, 0x13731310, 0x13741370, 0x134d134f, 0x13401343,
+ 0x13471341, 0x135c1345, 0x13541353, 0x10f710f0, 0x10cc10f5, 0x10c110c0, 0x103310c4, 0x10311030,
+ 0x100f1034, 0x1003100c, 0x10011000, 0x101c1004, 0x10101013, 0x10141011, 0x10741071, 0x104c1075,
+ 0x10411040, 0x10451044, 0x1050105d, 0x10571051, 0x11f411fd, 0x11df11c0, 0x11d711d1, 0x113f11d4,
+ 0x11371130, 0x110c1135, 0x11001103, 0x11071101, 0x111f1105, 0x11171110, 0x117d117f, 0x11751170,
+ 0x11411143, 0x11441147, 0x1153115f, 0x11551151, 0x17c417c1, 0x173c17d0, 0x1700170d, 0x171c1705,
+ 0x17701714, 0x1747174c, 0x14fc1751, 0x14cf14f3, 0x14dc14c0, 0x14d114d3, 0x143f14d4, 0x1430143c,
+ 0x14371431, 0x1403140c, 0x14011400, 0x141f1404, 0x14151410, 0x1473147d, 0x14401475, 0x1453145c,
+ 0x14541450, 0x15c115cc, 0x153c15c7, 0x15341533, 0x1500150f, 0x15051507, 0x15101513, 0x15711514,
+ 0x15471543, 0x15511545, 0x7ffd7fff, 0x7ff57ff7, 0x7fdd7fdf, 0x7fd57fd7, 0x7f0f7f30, 0x7f037f0c,
+ 0x7f047f01, 0x7f7f7f10, 0x7f777f7d, 0x7f407f75, 0x7f5d7f5f, 0x7f557f57, 0x7ccc7cf0, 0x7cc17cc3,
+ 0x7cd07cc4, 0x7c337c3c, 0x7c0f7c34, 0x7c007c0d, 0x7c077c01, 0x7c137c04, 0x7c147c11, 0x7c747c70,
+ 0x7c417c43, 0x7c507c44, 0x7dfd7dff, 0x7df57df7, 0x7ddf7dc0, 0x7dd77ddd, 0x7d0c7dd5, 0x7d047d03,
+ 0x7d7f7d10, 0x7d777d7d, 0x7d407d75, 0x7d5d7d5f, 0x7d557d57, 0x73c473c3, 0x7333733c, 0x7300730c,
+ 0x731c7305, 0x73147313, 0x73447343, 0x70f470fc, 0x70c070cd, 0x70d170c5, 0x703f70d4, 0x7030703c,
+ 0x700c7037, 0x70007003, 0x70047001, 0x70107005, 0x70177011, 0x707c7015, 0x70717073, 0x704f7074,
+ 0x7040704d, 0x70517047, 0x71c171cc, 0x71d071c4, 0x7133713c, 0x71357134, 0x7100710f, 0x71057104,
+ 0x7111711c, 0x71707115, 0x7145714c, 0x77ff7153, 0x77f777fd, 0x77c077f5, 0x77dd77df, 0x77d577d7,
+ 0x7730773c, 0x7703770c, 0x77107704, 0x777f7714, 0x7777777d, 0x77407775, 0x775d775f, 0x77557757,
+ 0x74f174f0, 0x74c374cc, 0x74d074c1, 0x7433743c, 0x74347431, 0x740d740f, 0x74057400, 0x7413741c,
+ 0x74417470, 0x74507444, 0x75fd75ff, 0x75f575f7, 0x75df75c0, 0x75d775dd, 0x753075d5, 0x7503750c,
+ 0x757f7501, 0x7577757d, 0x75407575, 0x755d755f, 0x75557557, 0x4fcc4ff0, 0x4fc74fc1, 0x4fd04fc4,
+ 0x4f314f3c, 0x4f004f34, 0x4f054f07, 0x4f154f14, 0x4f4c4f70, 0x4f414f43, 0x4f504f44, 0x4cf34cfc,
+ 0x4cf44cf1, 0x4cc04ccf, 0x4cc54cc7, 0x4cd34cdc, 0x4cd44cd1, 0x4c304c3f, 0x4c0c4c0f, 0x4c004c03,
+ 0x4c044c01, 0x4c104c1d, 0x4c714c73, 0x4c404c4d, 0x4c5c4c47, 0x4c514c53, 0x4df04c54, 0x4dc34dcc,
+ 0x4dd04dc4, 0x4d314d33, 0x4d0f4d34, 0x4d004d0d, 0x4d114d07, 0x4d704d14, 0x4d414d43, 0x43fc4d54,
+ 0x43f143f3, 0x43c043cf, 0x43d143c7, 0x4335433f, 0x4303430c, 0x43014300, 0x43044307, 0x431c431f,
+ 0x4310431d, 0x43714373, 0x4343434d, 0x43474340, 0x4354435c, 0x40f040ff, 0x40f540f7, 0x40cc40cf,
+ 0x40c040c3, 0x40c440c1, 0x40d040dc, 0x40d540d4, 0x4033403c, 0x40314030, 0x400f4034, 0x400d400c,
+ 0x40004003, 0x40074001, 0x40054004, 0x4013401c, 0x40114010, 0x407c4014, 0x40774070, 0x404d404c,
+ 0x40404043, 0x40444041, 0x405f4045, 0x4050405d, 0x40554057, 0x41f341fc, 0x41c041cf, 0x41df41c4,
+ 0x41d441d1, 0x41374130, 0x410c4134, 0x4100410d, 0x41044101, 0x41174110, 0x4173417d, 0x41754174,
+ 0x4143414d, 0x41534140, 0x41544151, 0x47c147f0, 0x47d047c4, 0x4731473c, 0x470d470f, 0x47014700,
+ 0x47134705, 0x47704710, 0x4741474c, 0x47504744, 0x44f144f3, 0x44cf44f4, 0x44c044cd, 0x44c544c7,
+ 0x44dc44df, 0x44d144d3, 0x443d443f, 0x44374430, 0x440c4435, 0x44004403, 0x44044401, 0x4410441d,
+ 0x44154411, 0x4473447c, 0x444d444f, 0x44454440, 0x4451445c, 0x45c045f0, 0x453345d0, 0x45344531,
+ 0x4500450f, 0x451c4507, 0x454c4570, 0x45404543, 0x5fff4541, 0x5ff75ffd, 0x5fc05ff5, 0x5fdd5fdf,
+ 0x5fd55fd7, 0x5f0c5f30, 0x5f015f03, 0x5f7f5f04, 0x5f775f7d, 0x5f405f75, 0x5f5d5f5f, 0x5f555f57,
+ 0x5cf45cf0, 0x5cc35ccc, 0x5cc45cc1, 0x5c315cc5, 0x5c0c5c34, 0x5c075c00, 0x5c1c5c05, 0x5c705c13,
+ 0x5c4d5c4f, 0x5c445c41, 0x5df75dfd, 0x5dcf5df5, 0x5ddd5dc4, 0x5dd55dd7, 0x5d0c5d30, 0x5d045d01,
+ 0x5d7f5d10, 0x5d775d7d, 0x5d405d75, 0x5d5d5d5f, 0x5d555d57, 0x53d053c4, 0x5333533c, 0x5303530f,
+ 0x53075300, 0x531c5305, 0x53115310, 0x53145317, 0x50f15370, 0x50cf50f4, 0x50c050cd, 0x50d150c7,
+ 0x503d50d4, 0x500c5030, 0x50005003, 0x50045001, 0x50155010, 0x5073507c, 0x50715070, 0x504d5074,
+ 0x50475040, 0x51cc51f0, 0x51c551c1, 0x51d051dc, 0x51315133, 0x510d5135, 0x51015100, 0x511f5107,
+ 0x5171511d, 0x5140514f, 0x51445141, 0x5153515c, 0x57ff5151, 0x57f757fd, 0x57df57f5, 0x57d757dd,
+ 0x570c57d5, 0x57015703, 0x577f5704, 0x5777577d, 0x57405775, 0x575d575f, 0x57555757, 0x54c354f0,
+ 0x54dc54c4, 0x543c54d0, 0x5400540f, 0x541c5405, 0x54145411, 0x5441544f, 0x55fd55ff, 0x55f555f7,
+ 0x55dd55df, 0x55d555d7, 0x5503550c, 0x557f5501, 0x5577557d, 0x55405575, 0x555d555f, 0x55555557
+);
+
+#enddecl(IQ1_GRID)
+
+#decl(IQ4_GRID)
+
+const kvalues_iq4nl = array<i32, 16>(
+ -127, -104, -83, -65, -49, -35, -22, -10, 1, 13, 25, 38, 53, 69, 89, 113
+);
+
+#enddecl(IQ4_GRID)
diff --git a/llama.cpp/ggml/src/ggml-webgpu/wgsl-shaders/cpy.tmpl.wgsl b/llama.cpp/ggml/src/ggml-webgpu/wgsl-shaders/cpy.tmpl.wgsl
new file mode 100644
index 0000000..b5e93b8
--- /dev/null
+++ b/llama.cpp/ggml/src/ggml-webgpu/wgsl-shaders/cpy.tmpl.wgsl
@@ -0,0 +1,107 @@
+#define(VARIANTS)
+
+[
+ {
+ "REPLS": {
+ "SRC_TYPE": "f32",
+ "DST_TYPE": "f32"
+ }
+ },
+ {
+ "REPLS": {
+ "SRC_TYPE": "f32",
+ "DST_TYPE": "i32"
+ }
+ },
+ {
+ "REPLS": {
+ "SRC_TYPE": "f32",
+ "DST_TYPE": "f16"
+ }
+ },
+ {
+ "REPLS": {
+ "SRC_TYPE": "f16",
+ "DST_TYPE": "f16"
+ }
+ },
+ {
+ "REPLS": {
+ "SRC_TYPE": "f16",
+ "DST_TYPE": "f32"
+ }
+ }
+]
+
+#end(VARIANTS)
+
+#define(SHADER)
+enable f16;
+
+@group(0) @binding(0)
+var<storage, read_write> src: array<{{SRC_TYPE}}>;
+
+@group(0) @binding(1)
+var<storage, read_write> dst: array<{{DST_TYPE}}>;
+
+struct Params {
+ ne: u32, // total number of elements
+ offset_src: u32, // in elements
+ offset_dst: u32, // in elements
+
+ // Strides (in elements) — may be permuted
+ stride_src0: u32,
+ stride_src1: u32,
+ stride_src2: u32,
+ stride_src3: u32,
+
+ stride_dst0: u32,
+ stride_dst1: u32,
+ stride_dst2: u32,
+ stride_dst3: u32,
+
+ // Logical shapes
+ src_ne0: u32,
+ src_ne1: u32,
+ src_ne2: u32,
+
+ dst_ne0: u32,
+ dst_ne1: u32,
+ dst_ne2: u32
+};
+
+@group(0) @binding(2)
+var<uniform> params: Params;
+
+override wg_size: u32;
+@compute @workgroup_size(wg_size)
+fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
+ if (gid.x >= params.ne) {
+ return;
+ }
+
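+    // Unflatten the linear index twice: against the source's logical shape
+    // (i0..i3) and against the destination's (j0..j3). The shapes may differ
+    // and the strides may be permuted, but both cover the same params.ne
+    // elements in row-major order.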
+ var i = gid.x;
+ let i3 = i / (params.src_ne2 * params.src_ne1 * params.src_ne0);
+ i = i % (params.src_ne2 * params.src_ne1 * params.src_ne0);
+ let i2 = i / (params.src_ne1 * params.src_ne0);
+ i = i % (params.src_ne1 * params.src_ne0);
+ let i1 = i / params.src_ne0;
+ let i0 = i % params.src_ne0;
+
+ var j = gid.x;
+ let j3 = j / (params.dst_ne2 * params.dst_ne1 * params.dst_ne0);
+ j = j % (params.dst_ne2 * params.dst_ne1 * params.dst_ne0);
+ let j2 = j / (params.dst_ne1 * params.dst_ne0);
+ j = j % (params.dst_ne1 * params.dst_ne0);
+ let j1 = j / params.dst_ne0;
+ let j0 = j % params.dst_ne0;
+
+ let src_idx = i0 * params.stride_src0 + i1 * params.stride_src1 +
+ i2 * params.stride_src2 + i3 * params.stride_src3;
+
+ let dst_idx = j0 * params.stride_dst0 + j1 * params.stride_dst1 +
+ j2 * params.stride_dst2 + j3 * params.stride_dst3;
+
+    dst[params.offset_dst + dst_idx] = {{DST_TYPE}}(src[params.offset_src + src_idx]);
+}
+#end(SHADER)
diff --git a/llama.cpp/ggml/src/ggml-webgpu/wgsl-shaders/cumsum.wgsl b/llama.cpp/ggml/src/ggml-webgpu/wgsl-shaders/cumsum.wgsl
new file mode 100644
index 0000000..e622552
--- /dev/null
+++ b/llama.cpp/ggml/src/ggml-webgpu/wgsl-shaders/cumsum.wgsl
@@ -0,0 +1,66 @@
+@group(0) @binding(0)
+var<storage, read_write> src: array<f32>;
+
+@group(0) @binding(1)
+var<storage, read_write> dst: array<f32>;
+
+struct Params {
+ offset_src: u32, // in elements
+ offset_dst: u32, // in elements
+ ne0: u32,
+};
+
+@group(0) @binding(2)
+var<uniform> params: Params;
+
+var<workgroup> shared_sum: array<f32, WG_SIZE>;
+
+@compute @workgroup_size(WG_SIZE)
+fn main(@builtin(workgroup_id) wid: vec3<u32>,
+ @builtin(local_invocation_id) lid: vec3<u32>) {
+ let row_idx = params.offset_src + wid.x * params.ne0;
+ let elems = (params.ne0 + WG_SIZE - 1) / WG_SIZE;
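+    // Each thread serially sums a contiguous chunk of ceil(ne0 / WG_SIZE)
+    // elements; the chunk totals are then scanned across the workgroup.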
+ var local_sum: f32 = 0.0;
+    for (var col = lid.x * elems; col < (lid.x + 1) * elems && col < params.ne0; col++) {
+ local_sum += src[row_idx + col];
+ }
+ shared_sum[lid.x] = local_sum;
+ workgroupBarrier();
+
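+    // Blelloch work-efficient scan over the per-thread partial sums: the
+    // upsweep builds a reduction tree in place, and the downsweep converts it
+    // into an exclusive prefix sum of the chunk totals.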
+ // upsweep
+ var offset = 1u;
+ while (offset < WG_SIZE) {
+ let idx = (lid.x + 1) * offset * 2 - 1;
+ if (idx < WG_SIZE) {
+ shared_sum[idx] = shared_sum[idx] + shared_sum[idx - offset];
+ }
+ workgroupBarrier();
+ offset <<= 1;
+ }
+
+ // set last to 0 for exclusive sum
+ if (lid.x == 0) {
+ shared_sum[WG_SIZE - 1] = 0.0;
+ }
+ workgroupBarrier();
+
+ // downsweep
+ offset = WG_SIZE >> 1;
+ while (offset > 0) {
+ let idx = (lid.x + 1) * offset * 2 - 1;
+ if (idx < WG_SIZE) {
+ let t = shared_sum[idx - offset];
+ shared_sum[idx - offset] = shared_sum[idx];
+ shared_sum[idx] = shared_sum[idx] + t;
+ }
+ workgroupBarrier();
+ offset = offset >> 1;
+ }
+
+    // shared_sum[lid.x] is the exclusive prefix sum of the chunk totals before this thread.
+    var running_sum = shared_sum[lid.x];
+    for (var col = lid.x * elems; col < (lid.x + 1) * elems && col < params.ne0; col++) {
+ running_sum += src[row_idx + col];
+ dst[params.offset_dst + wid.x * params.ne0 + col] = running_sum;
+ }
+}
diff --git a/llama.cpp/ggml/src/ggml-webgpu/wgsl-shaders/embed_wgsl.py b/llama.cpp/ggml/src/ggml-webgpu/wgsl-shaders/embed_wgsl.py
new file mode 100755
index 0000000..d61df5b
--- /dev/null
+++ b/llama.cpp/ggml/src/ggml-webgpu/wgsl-shaders/embed_wgsl.py
@@ -0,0 +1,147 @@
+import os
+import re
+import ast
+import argparse
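+
+# Expands templated WGSL shaders into concrete variants and embeds each one as
+# a C string constant. Templates use lightweight directives:
+#   #define(NAME) ... #end(NAME)    named blocks (VARIANTS, DECLS, SHADER, REPL_TEMPLATES)
+#   #decl(NAME) ... #enddecl(NAME)  reusable declaration snippets
+#   {{KEY}}                         per-variant replacement placeholders
+#   #include "file"                 textual includes, expanded recursively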
+
+
+def extract_block(text, name):
+ pattern = rf'#define\({name}\)\s*(.*?)#end\({name}\)'
+ match = re.search(pattern, text, re.DOTALL)
+ if not match:
+ raise ValueError(f"Missing block: {name}")
+ return match.group(1).strip()
+
+
+def parse_decls(decls_text):
+ decls = {}
+ for name, code in re.findall(r'#decl\((.*?)\)\s*(.*?)#enddecl\(\1\)', decls_text, re.DOTALL):
+ decls[name.strip()] = code.strip()
+ return decls
+
+
+def replace_repl_placeholders(variant, template_map):
+ for repl, code in variant["REPLS"].items():
+ for key, val in template_map.items():
+ # Match "key" and avoid matching subsequences using by using \b
+ code = re.sub(rf'\b{re.escape(str(key))}\b', str(val), code)
+ variant["REPLS"][repl] = code
+ return variant
+
+
+def replace_placeholders(shader_text, replacements):
+ for key, val in replacements.items():
+ # Match {{KEY}} literally, where KEY is escaped
+ pattern = r'{{\s*' + re.escape(key) + r'\s*}}'
+ shader_text = re.sub(pattern, str(val), shader_text)
+ return shader_text
+
+
+def expand_includes(shader, input_dir):
+ """
+ Replace #include "file" lines in the text with the contents of that file.
+ Searches for files relative to input_dir.
+ """
+ include_pattern = re.compile(r'^\s*#include\s+"([^"]+)"\s*$', re.MULTILINE)
+
+ def replacer(match):
+ fname = match.group(1)
+ file_path = os.path.join(input_dir, fname)
+ if not os.path.exists(file_path):
+ raise FileNotFoundError(f"Included file not found: {file_path}")
+ with open(file_path, "r", encoding="utf-8") as f:
+ included_code = f.read()
+ # Recursively expand includes inside the included file
+ return expand_includes(included_code, input_dir)
+
+ return include_pattern.sub(replacer, shader)
+
+
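+# Writes one generated shader: as a standalone .wgsl file when --output_dir is
+# given (handy for debugging), and always as a C raw string literal in the
+# embedded header.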
+def write_shader(shader_name, shader_code, output_dir, outfile):
+ if output_dir:
+ wgsl_filename = os.path.join(output_dir, f"{shader_name}.wgsl")
+ with open(wgsl_filename, "w", encoding="utf-8") as f_out:
+ f_out.write(shader_code)
+ outfile.write(f'const char* wgsl_{shader_name} = R"({shader_code})";\n\n')
+
+
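+# Shaders without a VARIANTS block are embedded verbatim. Templated shaders are
+# expanded once per variant, pulling shared decls from any *.tmpl files in the
+# input directory; the output name comes from SHADER_NAME, SHADER_SUFFIX, or
+# the variant's REPLS types.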
+def generate_variants(fname, input_dir, output_dir, outfile):
+ shader_path = os.path.join(input_dir, fname)
+ shader_base_name = fname.split(".")[0]
+
+ with open(shader_path, "r", encoding="utf-8") as f:
+ text = f.read()
+
+ try:
+ variants = ast.literal_eval(extract_block(text, "VARIANTS"))
+ except ValueError:
+ write_shader(shader_base_name, text, output_dir, outfile)
+ else:
+ try:
+ decls_map = parse_decls(extract_block(text, "DECLS"))
+ except ValueError:
+ decls_map = {}
+ try:
+ templates_map = ast.literal_eval(extract_block(text, "REPL_TEMPLATES"))
+ except ValueError:
+ templates_map = {}
+
+        for tmpl_fname in sorted(os.listdir(input_dir)):
+            if tmpl_fname.endswith(".tmpl"):
+                tmpl_path = os.path.join(input_dir, tmpl_fname)
+ with open(tmpl_path, "r", encoding="utf-8") as f_tmpl:
+ decls = f_tmpl.read()
+ decls_map.update(parse_decls(decls))
+
+ shader_template = extract_block(text, "SHADER")
+ for variant in variants:
+ if "DECLS" in variant:
+ decls = variant["DECLS"]
+ else:
+ decls = []
+ decls_code = ""
+ for key in decls:
+ if key not in decls_map:
+ raise ValueError(f"DECLS key '{key}' not found.")
+ decls_code += decls_map[key] + "\n\n"
+ final_shader = re.sub(r'\bDECLS\b', decls_code, shader_template)
+ if "REPLS" in variant:
+ variant = replace_repl_placeholders(variant, templates_map)
+ final_shader = replace_placeholders(final_shader, variant["REPLS"])
+            # second pass to expand placeholders introduced by the REPL templates themselves
+ final_shader = replace_placeholders(final_shader, variant["REPLS"])
+ final_shader = expand_includes(final_shader, input_dir)
+
+ if "SHADER_NAME" in variant:
+ output_name = variant["SHADER_NAME"]
+ elif "SHADER_SUFFIX" in variant:
+ output_name = f"{shader_base_name}_" + variant["SHADER_SUFFIX"]
+ elif "REPLS" in variant and "SRC0_TYPE" in variant["REPLS"] and "SRC1_TYPE" in variant["REPLS"]:
+ output_name = f"{shader_base_name}_" + "_".join([variant["REPLS"]["SRC0_TYPE"], variant["REPLS"]["SRC1_TYPE"]])
+ elif "REPLS" in variant and "SRC_TYPE" in variant["REPLS"] and "DST_TYPE" in variant["REPLS"]:
+ output_name = f"{shader_base_name}_" + "_".join([variant["REPLS"]["SRC_TYPE"], variant["REPLS"]["DST_TYPE"]])
+ elif "REPLS" in variant and "TYPE" in variant["REPLS"]:
+ output_name = f"{shader_base_name}_" + variant["REPLS"]["TYPE"]
+ else:
+ output_name = shader_base_name
+ write_shader(output_name, final_shader, output_dir, outfile)
+
+
+def main():
+ parser = argparse.ArgumentParser()
+ parser.add_argument("--input_dir", required=True)
+ parser.add_argument("--output_file", required=True)
+ parser.add_argument("--output_dir")
+ args = parser.parse_args()
+
+ if args.output_dir:
+ os.makedirs(args.output_dir, exist_ok=True)
+
+ with open(args.output_file, "w", encoding="utf-8") as out:
+ out.write("// Auto-generated shader embedding\n\n")
+ for fname in sorted(os.listdir(args.input_dir)):
+ if fname.endswith(".wgsl"):
+ generate_variants(fname, args.input_dir, args.output_dir, out)
+
+
+if __name__ == "__main__":
+ main()
diff --git a/llama.cpp/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn.wgsl b/llama.cpp/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn.wgsl
new file mode 100644
index 0000000..b682216
--- /dev/null
+++ b/llama.cpp/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn.wgsl
@@ -0,0 +1,636 @@
+diagnostic(off, chromium.subgroup_matrix_uniformity);
+diagnostic(off, subgroup_uniformity);
+enable f16;
+enable subgroups;
+enable chromium_experimental_subgroup_matrix;
+
+#ifdef KV_F32
+#define KV_TYPE f32
+#else
+#define KV_TYPE f16
+#endif
+
+// Default values
+#define HEAD_DIM_QK 64
+#define HEAD_DIM_V 64
+
+// The number of rows/columns/k in a subgroup matrix. MxK * KxN = MxN
+// Note that the "K" here does not correspond to the K in attention's Q/K/V, it's just the common dimension.
+#define SG_MAT_M 8
+#define SG_MAT_N 8
+#define SG_MAT_K 8
+
+// Each workgroup processes one subgroup matrix of Q rows
+#define Q_TILE SG_MAT_M
+#define KV_TILE 16
+#define WG_SIZE 64
+
+// Number of subgroup-matrix-width blocks that span the KV tile. SG_MAT_N must divide KV_TILE.
+#define KV_BLOCKS (KV_TILE / SG_MAT_N)
+
+// Quantization constants/helpers
+#define BLOCK_SIZE 32
+#define BLOCKS_K ((HEAD_DIM_QK + BLOCK_SIZE - 1) / BLOCK_SIZE)
+#define BLOCKS_V ((HEAD_DIM_V + BLOCK_SIZE - 1) / BLOCK_SIZE)
+// number of quantized elements processed per thread
+#if defined(KV_Q4_0)
+#define NQ 16
+// Q4_0 has 32 elements, 1 f16 for scale, 8 f16 for 4-bit weights
+#define F16_PER_BLOCK 9
+#define WEIGHTS_PER_F16 4
+#elif defined(KV_Q8_0)
+#define NQ 8
+// Q8_0 has 32 elements, 1 f16 for scale, 16 f16 for 8-bit weights
+#define F16_PER_BLOCK 17
+#define WEIGHTS_PER_F16 2
+#endif
+#define F16_PER_THREAD (NQ / WEIGHTS_PER_F16)
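+// Packed f16 words each thread dequantizes: NQ quantized elements at
+// WEIGHTS_PER_F16 values per f16 word.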
+
+// These helpers need not be guarded by a define block; the compiler removes them if unused.
+fn get_byte(value: u32, index: u32) -> u32 {
+ return (value >> (index * 8)) & 0xFF;
+}
+
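+// Sign-extends byte `index` of a packed u32: the byte is shifted into bits
+// 24..31, then an arithmetic shift right propagates the sign bit.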
+fn get_byte_i32(value: u32, index: u32) -> i32 {
+ return bitcast<i32>(((value >> (index * 8)) & 0xFF) << 24) >> 24;
+}
+
+struct Params {
+ offset_q: u32,
+ offset_k: u32,
+ offset_v: u32,
+ offset_mask: u32,
+ offset_sinks: u32,
+ offset_dst: u32,
+
+ // shapes of Q/K/V
+ n_heads: u32,
+ seq_len_q: u32,
+ seq_len_kv: u32,
+
+ // strides (in elements)
+ stride_q1: u32,
+ stride_q2: u32,
+ stride_q3: u32,
+ stride_k1: u32,
+ stride_k2: u32,
+ stride_k3: u32,
+ stride_v1: u32,
+ stride_v2: u32,
+ stride_v3: u32,
+ stride_mask3: u32,
+
+ // repeat factors for K/V, e.g., MHA vs. MQA vs. GQA
+ q_per_kv: u32,
+
+ // softmax params
+ scale: f32,
+ max_bias: f32,
+ logit_softcap: f32,
+ n_head_log2: f32,
+ m0: f32,
+ m1: f32,
+};
+
+@group(0) @binding(0) var<storage, read_write> Q: array<f32>;
+@group(0) @binding(1) var<storage, read_write> K: array<KV_TYPE>;
+@group(0) @binding(2) var<storage, read_write> V: array<KV_TYPE>;
+
+#if defined(MASK) && defined(SINKS)
+@group(0) @binding(3) var<storage, read_write> mask: array<f16>;
+@group(0) @binding(4) var<storage, read_write> sinks: array<f32>;
+#define DST_BINDING 5
+#define PARAMS_BINDING 6
+#elif defined(MASK)
+@group(0) @binding(3) var<storage, read_write> mask: array<f16>;
+#define DST_BINDING 4
+#define PARAMS_BINDING 5
+#elif defined(SINKS)
+@group(0) @binding(3) var<storage, read_write> sinks: array<f32>;
+#define DST_BINDING 4
+#define PARAMS_BINDING 5
+#else
+#define DST_BINDING 3
+#define PARAMS_BINDING 4
+#endif
+
+@group(0) @binding(DST_BINDING) var<storage, read_write> dst: array<vec4<f32>>;
+@group(0) @binding(PARAMS_BINDING) var<uniform> params: Params;
+
+// A very negative value, used as a stand-in for -inf.
+const FLOAT_MIN: f32 = -1.0e9;
+
+// Shared memory for the Q tile (Q_TILE rows are processed per workgroup)
+var<workgroup> q_shmem: array<f16, Q_TILE * HEAD_DIM_QK>;
+
+#ifndef KV_DIRECT
+const kv_shmem_size = KV_TILE * max(HEAD_DIM_QK, HEAD_DIM_V);
+// we can reuse the same shmem for K and V since we only need one at a time
+var<workgroup> kv_shmem: array<f16, kv_shmem_size>;
+#endif
+
+var<workgroup> o_shmem: array<f16, Q_TILE * HEAD_DIM_V>; // output shmem
+
+#ifdef MASK
+// storage for mask values
+var<workgroup> mask_shmem: array<f16, Q_TILE * KV_TILE>;
+#endif
+
+// storage for the Q*K^T scores used by the online softmax (the S matrix in the flash-attention formulation)
+// and, afterwards, for the exponentiated probabilities (the P matrix)
+// note that we reuse the same storage for both since we only need one at a time
+var<workgroup> inter_shmem: array<f16, Q_TILE * KV_TILE>;
+
+// Storage for row max and exp sum during online softmax
+var<workgroup> row_max_shmem: array<f32, Q_TILE>;
+var<workgroup> exp_sum_shmem: array<f32, Q_TILE>;
+
+fn calc_softmax_term(kv_idx: u32, q_tile_row: u32, slope: f32) -> f32 {
+ var v = select(FLOAT_MIN,
+ f32(inter_shmem[kv_idx + q_tile_row * KV_TILE]) * params.scale,
+ kv_idx < KV_TILE);
+#ifdef LOGIT_SOFTCAP
+ v = params.logit_softcap * tanh(v);
+#endif
+#ifdef MASK
+ let mask_val = select(0.0, f32(mask_shmem[q_tile_row * KV_TILE + kv_idx]), kv_idx < KV_TILE);
+ let mask_term = slope * mask_val;
+ v += mask_term;
+#endif
+ return v;
+}
+
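+// Vectorized load helpers (scalar_index is assumed to be 4-element aligned); currently unused
+// in this shader, and stripped by the compiler when not referenced.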
+fn load_f32x4(buf: ptr<storage, array<vec4<f32>>, read_write>, scalar_index: u32) -> vec4<f32> {
+ return (*buf)[scalar_index >> 2u];
+}
+
+fn load_kvx4(buf: ptr<storage, array<vec4<KV_TYPE>>, read_write>, scalar_index: u32) -> vec4<KV_TYPE> {
+ return (*buf)[scalar_index >> 2u];
+}
+
+@compute @workgroup_size(WG_SIZE)
+fn main(@builtin(workgroup_id) wg_id: vec3<u32>,
+ @builtin(local_invocation_id) local_id: vec3<u32>,
+ @builtin(subgroup_id) subgroup_id: u32,
+ @builtin(subgroup_size) subgroup_size: u32,
+ @builtin(num_subgroups) num_subgroups: u32,
+ @builtin(subgroup_invocation_id) sg_inv_id: u32) {
+
+ // initialize row max for online softmax
+ for (var i = local_id.x; i < Q_TILE; i += WG_SIZE) {
+ row_max_shmem[i] = FLOAT_MIN;
+ exp_sum_shmem[i] = 0.0;
+ }
+
+ for (var i = local_id.x; i < Q_TILE * HEAD_DIM_V; i += WG_SIZE) {
+ o_shmem[i] = 0.0;
+ }
+
+ // workgroups per head/batch
+ let wg_per_head = (params.seq_len_q + Q_TILE - 1u) / Q_TILE;
+ let wg_per_batch = wg_per_head * params.n_heads;
+
+ let dst2_stride = HEAD_DIM_V * params.n_heads;
+ let dst3_stride = dst2_stride * params.seq_len_q;
+
+ // batch index
+ let batch_idx = wg_id.x / wg_per_batch;
+ let q_batch_offset = params.offset_q + batch_idx * params.stride_q3;
+ let k_batch_offset = params.offset_k + batch_idx * params.stride_k3;
+ let v_batch_offset = params.offset_v + batch_idx * params.stride_v3;
+ let dst_batch_offset = params.offset_dst + batch_idx * dst3_stride;
+ let wg_in_batch = wg_id.x % wg_per_batch;
+
+ // head index
+ let head_idx = wg_in_batch / wg_per_head;
+ let q_head_offset = q_batch_offset + head_idx * params.stride_q2;
+ let k_head_idx = head_idx / params.q_per_kv;
+ let v_head_idx = k_head_idx;
+ let k_head_offset = k_batch_offset + k_head_idx * params.stride_k2;
+ let v_head_offset = v_batch_offset + v_head_idx * params.stride_v2;
+
+ // starting Q row for this workgroup
+ let wg_in_head = wg_in_batch % wg_per_head;
+ let q_row_start = wg_in_head * Q_TILE;
+
+#ifdef MASK
+ // mask offset
+ let mask_global_offset = params.offset_mask + batch_idx * params.stride_mask3 + q_row_start * params.seq_len_kv;
+#endif
+
+    // note that the output is permuted: the layout is [head_dim_v, n_heads, seq_len_q, batch_size]
+ let dst_global_offset = dst_batch_offset + q_row_start * dst2_stride + head_idx * HEAD_DIM_V;
+
+ let head = f32(head_idx);
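+    // ALiBi slope for this head: m0^(head+1) when head < n_head_log2,
+    // otherwise m1^(2*(head - n_head_log2) + 1); 1.0 when max_bias <= 0 (ALiBi disabled).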
+ let slope = select(1.0, select(pow(params.m1, 2.0 * (head - params.n_head_log2) + 1.0), pow(params.m0, head + 1.0), head < params.n_head_log2), params.max_bias > 0);
+
+ // load q tile into shared memory
+ for (var elem_idx = local_id.x; elem_idx < Q_TILE * HEAD_DIM_QK; elem_idx += WG_SIZE) {
+ let q_row = elem_idx / HEAD_DIM_QK;
+ let q_col = elem_idx % HEAD_DIM_QK;
+ let head_q_row = q_row_start + q_row;
+ let global_q_row_offset = q_head_offset + head_q_row * params.stride_q1;
+ q_shmem[elem_idx] = f16(select(
+ 0.0,
+ Q[global_q_row_offset + q_col],
+ head_q_row < params.seq_len_q && q_col < HEAD_DIM_QK));
+ }
+
+ for (var kv_tile = 0u; kv_tile < params.seq_len_kv; kv_tile += KV_TILE) {
+ // clear inter_shmem to ensure zero-initialized accumulators
+ for (var elem_idx = local_id.x; elem_idx < Q_TILE * KV_TILE; elem_idx += WG_SIZE) {
+ inter_shmem[elem_idx] = 0.0;
+ }
+
+ // load k tile into shared memory
+#if defined(KV_Q4_0)
+ for (var elem_idx = local_id.x * NQ; elem_idx < KV_TILE * HEAD_DIM_QK; elem_idx += WG_SIZE * NQ) {
+ let blck_idx = elem_idx / BLOCK_SIZE;
+ let block_offset = (elem_idx % BLOCK_SIZE) / WEIGHTS_PER_F16;
+ let k_row = blck_idx / BLOCKS_K;
+ let global_k_row = kv_tile + k_row;
+ let block_k = blck_idx % BLOCKS_K;
+ let row_offset = k_row * HEAD_DIM_QK;
+
+ if (global_k_row < params.seq_len_kv) {
+ let global_block_idx = k_head_offset + global_k_row * params.stride_k1 + block_k;
+ let base_idx = global_block_idx * F16_PER_BLOCK;
+ let d = K[base_idx]; // scale
+ for (var j = 0u; j < F16_PER_THREAD; j += 2) {
+ let q_0 = K[base_idx + 1u + block_offset + j];
+ let q_1 = K[base_idx + 1u + block_offset + j + 1];
+ let q_packed = bitcast<u32>(vec2(q_0, q_1));
+ for (var k = 0u; k < 4u; k++) {
+ let q_byte = get_byte(q_packed, k);
+ let q_hi = (f16((q_byte >> 4) & 0xF) - 8.0) * d;
+ let q_lo = (f16(q_byte & 0xF) - 8.0) * d;
+ let idx = block_k * BLOCK_SIZE + block_offset * 2u + j * 2u + k;
+ kv_shmem[row_offset + idx] = q_lo;
+ kv_shmem[row_offset + idx + 16u] = q_hi;
+ }
+ }
+ }
+ }
+#elif defined(KV_Q8_0)
+ for (var elem_idx = local_id.x * NQ; elem_idx < KV_TILE * HEAD_DIM_QK; elem_idx += WG_SIZE * NQ) {
+ let blck_idx = elem_idx / BLOCK_SIZE;
+ let block_offset = (elem_idx % BLOCK_SIZE) / WEIGHTS_PER_F16;
+ let k_row = blck_idx / BLOCKS_K;
+ let global_k_row = kv_tile + k_row;
+ let block_k = blck_idx % BLOCKS_K;
+ let row_offset = k_row * HEAD_DIM_QK;
+
+ if (global_k_row < params.seq_len_kv) {
+ let global_block_idx = k_head_offset + global_k_row * params.stride_k1 + block_k;
+ let base_idx = global_block_idx * F16_PER_BLOCK;
+ let d = K[base_idx]; // scale
+ for (var j = 0u; j < F16_PER_THREAD; j += 2) {
+ let q_0 = K[base_idx + 1u + block_offset + j];
+ let q_1 = K[base_idx + 1u + block_offset + j + 1];
+ let q_packed = bitcast<u32>(vec2(q_0, q_1));
+ for (var k = 0u; k < 4u; k++) {
+ let q_byte = get_byte_i32(q_packed, k);
+ let q_val = f16(q_byte) * d;
+ let idx = block_k * BLOCK_SIZE + block_offset * 2u + j * 2u + k;
+ kv_shmem[row_offset + idx] = q_val;
+ }
+ }
+ }
+ }
+#elif defined(KV_DIRECT)
+    // KV_DIRECT: no staging into shared memory; K is read straight from global memory by the subgroup matrix loads below
+#else
+ for (var elem_idx = local_id.x; elem_idx < KV_TILE * HEAD_DIM_QK; elem_idx += WG_SIZE) {
+ let k_row = elem_idx / HEAD_DIM_QK;
+ let k_col = elem_idx % HEAD_DIM_QK;
+ let global_k_row = kv_tile + k_row;
+ let global_k_row_offset = k_head_offset + global_k_row * params.stride_k1;
+ kv_shmem[elem_idx] = f16(select(
+ 0.0,
+ K[global_k_row_offset + k_col],
+ global_k_row < params.seq_len_kv && k_col < HEAD_DIM_QK));
+ }
+#endif
+
+ workgroupBarrier();
+
+ // accumulate q block * k block into registers across the entire KV tile
+ // TODO: this loop seems to be the current largest bottleneck
+ // this bracket exists to scope the lifetime of variables, reducing register pressure
+ {
+#ifdef KV_DIRECT
+ let k_block_row = kv_tile + subgroup_id * SG_MAT_N;
+ var k_global_offset = k_head_offset + k_block_row * params.stride_k1;
+#else
+ var k_block_offset = subgroup_id * SG_MAT_N * HEAD_DIM_QK;
+#endif
+ for (var kv_block = subgroup_id; kv_block < KV_BLOCKS; kv_block += num_subgroups) {
+ let inter_offset = kv_block * SG_MAT_N;
+ var acc: subgroup_matrix_result<f16, SG_MAT_M, SG_MAT_N> = subgroupMatrixLoad<subgroup_matrix_result<f16, SG_MAT_M, SG_MAT_N>>(&inter_shmem, inter_offset, false, KV_TILE);
+
+ var q_cur = subgroupMatrixLoad<subgroup_matrix_left<f16, SG_MAT_M, SG_MAT_K>>(&q_shmem, 0u, false, HEAD_DIM_QK);
+
+#ifdef KV_DIRECT
+ var k_cur = subgroupMatrixLoad<subgroup_matrix_right<f16, SG_MAT_K, SG_MAT_N>>(&K, k_global_offset + 0u, true, params.stride_k1);
+#else
+ var k_cur = subgroupMatrixLoad<subgroup_matrix_right<f16, SG_MAT_K, SG_MAT_N>>(&kv_shmem, k_block_offset + 0u, true, HEAD_DIM_QK);
+#endif
+
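+            // Software-pipelined: prefetch the next pair of Q/K tiles while multiplying the
+            // current pair, unrolled by two to keep loads ahead of the MMA operations.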
+ var t: u32 = 1u;
+ for (; t + 1u < HEAD_DIM_QK / SG_MAT_K; t += 2u) {
+ let h0 = t * SG_MAT_K;
+ var q0 = subgroupMatrixLoad<subgroup_matrix_left<f16, SG_MAT_M, SG_MAT_K>>(&q_shmem, h0, false, HEAD_DIM_QK);
+#ifdef KV_DIRECT
+ var k0 = subgroupMatrixLoad<subgroup_matrix_right<f16, SG_MAT_K, SG_MAT_N>>(&K, k_global_offset + h0, true, params.stride_k1);
+#else
+ var k0 = subgroupMatrixLoad<subgroup_matrix_right<f16, SG_MAT_K, SG_MAT_N>>(&kv_shmem, k_block_offset + h0, true, HEAD_DIM_QK);
+#endif
+ acc = subgroupMatrixMultiplyAccumulate(q_cur, k_cur, acc);
+ q_cur = q0;
+ k_cur = k0;
+
+ let h1 = (t + 1u) * SG_MAT_K;
+ var q1g = subgroupMatrixLoad<subgroup_matrix_left<f16, SG_MAT_M, SG_MAT_K>>(&q_shmem, h1, false, HEAD_DIM_QK);
+#ifdef KV_DIRECT
+ var k1g = subgroupMatrixLoad<subgroup_matrix_right<f16, SG_MAT_K, SG_MAT_N>>(&K, k_global_offset + h1, true, params.stride_k1);
+#else
+ var k1g = subgroupMatrixLoad<subgroup_matrix_right<f16, SG_MAT_K, SG_MAT_N>>(&kv_shmem, k_block_offset + h1, true, HEAD_DIM_QK);
+#endif
+ acc = subgroupMatrixMultiplyAccumulate(q_cur, k_cur, acc);
+ q_cur = q1g;
+ k_cur = k1g;
+ }
+
+ // handle odd tail
+ if (t < HEAD_DIM_QK / SG_MAT_K) {
+ let h = t * SG_MAT_K;
+ var qn = subgroupMatrixLoad<subgroup_matrix_left<f16, SG_MAT_M, SG_MAT_K>>(&q_shmem, h, false, HEAD_DIM_QK);
+#ifdef KV_DIRECT
+ var kn = subgroupMatrixLoad<subgroup_matrix_right<f16, SG_MAT_K, SG_MAT_N>>(&K, k_global_offset + h, true, params.stride_k1);
+#else
+ var kn = subgroupMatrixLoad<subgroup_matrix_right<f16, SG_MAT_K, SG_MAT_N>>(&kv_shmem, k_block_offset + h, true, HEAD_DIM_QK);
+#endif
+ acc = subgroupMatrixMultiplyAccumulate(q_cur, k_cur, acc);
+ q_cur = qn;
+ k_cur = kn;
+ }
+
+ acc = subgroupMatrixMultiplyAccumulate(q_cur, k_cur, acc);
+
+#ifdef KV_DIRECT
+ k_global_offset += num_subgroups * SG_MAT_N * params.stride_k1;
+#else
+ k_block_offset += num_subgroups * SG_MAT_N * HEAD_DIM_QK;
+#endif
+ subgroupMatrixStore(&inter_shmem, inter_offset, acc, false, KV_TILE);
+ }
+ }
+
+
+#ifdef MASK
+ // load mask tile into shared memory for this KV block
+ // TODO: optimize and skip if mask is -INF for the entire tile
+ for (var elem_idx = local_id.x; elem_idx < Q_TILE * KV_TILE; elem_idx += WG_SIZE) {
+ let mask_row = elem_idx / KV_TILE;
+ let mask_col = elem_idx % KV_TILE;
+ let global_q_row = q_row_start + mask_row;
+ let global_k_col = kv_tile + mask_col;
+ let mask_in_bounds = global_q_row < params.seq_len_q && global_k_col < params.seq_len_kv;
+ let mask_idx = mask_global_offset + mask_row * params.seq_len_kv + global_k_col;
+ mask_shmem[elem_idx] = select(0.0, mask[mask_idx], mask_in_bounds);
+ }
+#endif
+
+ workgroupBarrier();
+
+ // online softmax
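+    // Per row, with S the score chunk for this KV tile:
+    //   m_new = max(m_old, max_j S_j)
+    //   l_new = l_old * exp(m_old - m_new) + sum_j exp(S_j - m_new)
+    //   O     = O * exp(m_old - m_new)   (rescale the accumulated output)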
+ for (var q_tile_row = subgroup_id; q_tile_row < Q_TILE; q_tile_row += num_subgroups) {
+ let global_q_row = q_row_start + q_tile_row;
+ if (global_q_row >= params.seq_len_q) {
+ break;
+ }
+
+ // initialize running max for this row
+ var prev_max = row_max_shmem[q_tile_row];
+ var final_max = prev_max;
+ // pass 1: compute final max across the full KV tile in chunks
+ for (var kv_offset = 0u; kv_offset < KV_TILE; kv_offset += subgroup_size) {
+ let kv_idx = kv_offset + sg_inv_id;
+ let softmax_term = calc_softmax_term(kv_idx, q_tile_row, slope);
+ final_max = subgroupMax(max(final_max, softmax_term));
+ }
+
+ var total_exp_term: f32 = 0.0;
+ // pass 2: compute exp sum and write P using final_max
+ for (var kv_offset = 0u; kv_offset < KV_TILE; kv_offset += subgroup_size) {
+ let kv_idx = kv_offset + sg_inv_id;
+ let softmax_term = calc_softmax_term(kv_idx, q_tile_row, slope);
+ let cur_p = select(0.0,
+ exp(softmax_term - final_max),
+ kv_tile + kv_idx < params.seq_len_kv && kv_idx < KV_TILE);
+ total_exp_term += subgroupAdd(cur_p);
+ if (kv_idx < KV_TILE) {
+ inter_shmem[kv_idx + q_tile_row * KV_TILE] = f16(cur_p);
+ }
+ }
+
+ let cur_exp = exp(prev_max - final_max);
+
+ if (sg_inv_id == 0) {
+ row_max_shmem[q_tile_row] = final_max;
+ exp_sum_shmem[q_tile_row] = exp_sum_shmem[q_tile_row] * cur_exp + total_exp_term;
+ }
+
+ for (var elem_idx = sg_inv_id; elem_idx < HEAD_DIM_V; elem_idx += subgroup_size) {
+ let idx = q_tile_row * HEAD_DIM_V + elem_idx;
+ o_shmem[idx] = f16(f32(o_shmem[idx]) * cur_exp);
+ }
+ }
+
+ // load v tile into shared memory
+#if defined(KV_Q4_0)
+ for (var elem_idx = local_id.x * NQ; elem_idx < KV_TILE * HEAD_DIM_V; elem_idx += WG_SIZE * NQ) {
+ let blck_idx = elem_idx / BLOCK_SIZE;
+ let block_offset = (elem_idx % BLOCK_SIZE) / WEIGHTS_PER_F16;
+ let v_row = blck_idx / BLOCKS_V;
+ let global_v_row = kv_tile + v_row;
+ let block_k = blck_idx % BLOCKS_V;
+ let row_offset = v_row * HEAD_DIM_V;
+
+ if (global_v_row < params.seq_len_kv) {
+ let global_block_idx = v_head_offset + global_v_row * params.stride_v1 + block_k;
+ let base_idx = global_block_idx * F16_PER_BLOCK;
+ let d = V[base_idx]; // scale
+ for (var j = 0u; j < F16_PER_THREAD; j += 2) {
+ let q_0 = V[base_idx + 1u + block_offset + j];
+ let q_1 = V[base_idx + 1u + block_offset + j + 1];
+ let q_packed = bitcast<u32>(vec2(q_0, q_1));
+ for (var k = 0u; k < 4u; k++) {
+ let q_byte = get_byte(q_packed, k);
+ let q_hi = (f16((q_byte >> 4) & 0xF) - 8.0) * d;
+ let q_lo = (f16(q_byte & 0xF) - 8.0) * d;
+ let idx = block_k * BLOCK_SIZE + block_offset * 2u + j * 2u + k;
+ kv_shmem[row_offset + idx] = q_lo;
+ kv_shmem[row_offset + idx + 16u] = q_hi;
+ }
+ }
+ }
+ }
+#elif defined(KV_Q8_0)
+ for (var elem_idx = local_id.x * NQ; elem_idx < KV_TILE * HEAD_DIM_V; elem_idx += WG_SIZE * NQ) {
+ let blck_idx = elem_idx / BLOCK_SIZE;
+ let block_offset = (elem_idx % BLOCK_SIZE) / WEIGHTS_PER_F16;
+ let v_row = blck_idx / BLOCKS_V;
+ let global_v_row = kv_tile + v_row;
+ let block_k = blck_idx % BLOCKS_V;
+ let row_offset = v_row * HEAD_DIM_V;
+
+ if (global_v_row < params.seq_len_kv) {
+ let global_block_idx = v_head_offset + global_v_row * params.stride_v1 + block_k;
+ let base_idx = global_block_idx * F16_PER_BLOCK;
+ let d = V[base_idx]; // scale
+ for (var j = 0u; j < F16_PER_THREAD; j += 2) {
+ let q_0 = V[base_idx + 1u + block_offset + j];
+ let q_1 = V[base_idx + 1u + block_offset + j + 1];
+ let q_packed = bitcast<u32>(vec2(q_0, q_1));
+ for (var k = 0u; k < 4u; k++) {
+ let q_byte = get_byte_i32(q_packed, k);
+ let q_val = f16(q_byte) * d;
+ let idx = block_k * BLOCK_SIZE + block_offset * 2u + j * 2u + k;
+ kv_shmem[row_offset + idx] = q_val;
+ }
+ }
+ }
+ }
+#elif defined(KV_DIRECT)
+    // KV_DIRECT: no staging into shared memory; V is read straight from global memory by the subgroup matrix loads below
+#else
+ for (var elem_idx = local_id.x; elem_idx < KV_TILE * HEAD_DIM_V; elem_idx += WG_SIZE) {
+ let v_row = elem_idx / HEAD_DIM_V;
+ let v_col = elem_idx % HEAD_DIM_V;
+ let global_v_row = kv_tile + v_row;
+ let global_v_row_offset = v_head_offset + global_v_row * params.stride_v1;
+ kv_shmem[elem_idx] = f16(select(
+ 0.0,
+ V[global_v_row_offset + v_col],
+ global_v_row < params.seq_len_kv && v_col < HEAD_DIM_V));
+ }
+#endif
+
+ workgroupBarrier();
+
+ // we have P (Q_TILE x KV_TILE) in inter_shmem and V (KV_TILE x head_dim_v) in kv_shmem
+ // we want to compute O += P * V across the full KV tile
+ for (var head_dim_block = subgroup_id * SG_MAT_N;
+ head_dim_block < HEAD_DIM_V;
+ head_dim_block += num_subgroups * SG_MAT_N) {
+ // load O submatrix from shared memory
+ var o_sg_mat: subgroup_matrix_result<f16, SG_MAT_M, SG_MAT_N> = subgroupMatrixLoad<subgroup_matrix_result<f16, SG_MAT_M, SG_MAT_N>>(
+ &o_shmem,
+ head_dim_block,
+ false,
+ HEAD_DIM_V
+ );
+ for (var kv_block = 0u; kv_block < KV_BLOCKS; kv_block++) {
+ let p_offset = kv_block * SG_MAT_N;
+ var p_sg_mat: subgroup_matrix_left<f16, SG_MAT_M, SG_MAT_K> = subgroupMatrixLoad<subgroup_matrix_left<f16, SG_MAT_M, SG_MAT_K>>(
+ &inter_shmem,
+ p_offset,
+ false,
+ KV_TILE
+ );
+
+ // load V submatrix from global or shared memory
+#ifdef KV_DIRECT
+ let v_block_row = kv_tile + kv_block * SG_MAT_N;
+ let v_global_offset = v_head_offset + v_block_row * params.stride_v1 + head_dim_block;
+ var v_sg_mat: subgroup_matrix_right<f16, SG_MAT_K, SG_MAT_N> = subgroupMatrixLoad<subgroup_matrix_right<f16, SG_MAT_K, SG_MAT_N>>(
+ &V,
+ v_global_offset,
+ false,
+ params.stride_v1
+ );
+#else
+ let v_block_offset = kv_block * SG_MAT_N * HEAD_DIM_V;
+ var v_sg_mat: subgroup_matrix_right<f16, SG_MAT_K, SG_MAT_N> = subgroupMatrixLoad<subgroup_matrix_right<f16, SG_MAT_K, SG_MAT_N>>(
+ &kv_shmem,
+ v_block_offset + head_dim_block,
+ false,
+ HEAD_DIM_V
+ );
+#endif
+ // O += P * V
+ o_sg_mat = subgroupMatrixMultiplyAccumulate(p_sg_mat, v_sg_mat, o_sg_mat);
+ }
+ // store O back to shared memory
+ subgroupMatrixStore(&o_shmem, head_dim_block, o_sg_mat, false, HEAD_DIM_V);
+ }
+ workgroupBarrier();
+ }
+
+#ifdef SINKS
+ // add sinks (applied once after processing all KV tiles)
+ for (var q_tile_row = subgroup_id;
+ q_tile_row < Q_TILE;
+ q_tile_row += num_subgroups) {
+ // no need to process rows beyond seq_len_q
+ let global_q_row = q_row_start + q_tile_row;
+ if (global_q_row >= params.seq_len_q) {
+ break;
+ }
+
+ var prev_max = row_max_shmem[q_tile_row];
+
+ // for non-sink threads, exp(FLOAT_MIN) effectively zeroes out their contribution to the sum
+ let sink_val = select(FLOAT_MIN, sinks[params.offset_sinks + head_idx], sg_inv_id == 0);
+ let new_max = subgroupMax(max(prev_max, sink_val));
+ let max_exp = exp(prev_max - new_max);
+ let sink_exp = exp(sink_val - new_max);
+
+ let sink_exp_sum = subgroupAdd(sink_exp);
+
+ if (sg_inv_id == 0) {
+ exp_sum_shmem[q_tile_row] = exp_sum_shmem[q_tile_row] * max_exp + sink_exp_sum;
+ }
+
+ for (var elem_idx = sg_inv_id; elem_idx < HEAD_DIM_V; elem_idx += subgroup_size) {
+ let idx = q_tile_row * HEAD_DIM_V + elem_idx;
+ let val = f32(o_shmem[idx]) * max_exp;
+ o_shmem[idx] = f16(val);
+ }
+ }
+ workgroupBarrier();
+#endif
+ for (var q_tile_row = subgroup_id;
+ q_tile_row < Q_TILE;
+ q_tile_row += num_subgroups) {
+
+ let global_q_row = q_row_start + q_tile_row;
+ if (global_q_row >= params.seq_len_q) { break; }
+
+ let exp_sum = exp_sum_shmem[q_tile_row];
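+        // exp_sum == 0 means every position in the row was masked out; write zeros instead of dividing by zero.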
+ let scale = select(0.0, 1.0 / exp_sum, exp_sum != 0.0);
+
+ let row_base: u32 = dst_global_offset + q_tile_row * dst2_stride;
+
+ for (var elem_base = sg_inv_id * 4u;
+ elem_base < HEAD_DIM_V;
+ elem_base += subgroup_size * 4u) {
+
+ let i0 = q_tile_row * HEAD_DIM_V + (elem_base + 0u);
+ let i1 = q_tile_row * HEAD_DIM_V + (elem_base + 1u);
+ let i2 = q_tile_row * HEAD_DIM_V + (elem_base + 2u);
+ let i3 = q_tile_row * HEAD_DIM_V + (elem_base + 3u);
+
+ let v = vec4<f32>(
+ f32(o_shmem[i0]) * scale,
+ f32(o_shmem[i1]) * scale,
+ f32(o_shmem[i2]) * scale,
+ f32(o_shmem[i3]) * scale
+ );
+
+ let dst_vec_index: u32 = (row_base + elem_base) >> 2u;
+ dst[dst_vec_index] = v;
+ }
+ }
+}
diff --git a/llama.cpp/ggml/src/ggml-webgpu/wgsl-shaders/get_rows.tmpl.wgsl b/llama.cpp/ggml/src/ggml-webgpu/wgsl-shaders/get_rows.tmpl.wgsl
new file mode 100644
index 0000000..f80ce1f
--- /dev/null
+++ b/llama.cpp/ggml/src/ggml-webgpu/wgsl-shaders/get_rows.tmpl.wgsl
@@ -0,0 +1,874 @@
+#define(VARIANTS)
+
+[
+ {
+ "SHADER_SUFFIX": "f32_vec",
+ "REPLS": {
+ "TYPE" : "vec4<f32>",
+ "DST_TYPE": "vec4<f32>",
+ "BLOCK_SIZE": 4
+ },
+ "DECLS": ["F32_VEC"]
+ },
+ {
+ "REPLS": {
+ "TYPE" : "f32",
+ "DST_TYPE": "f32",
+ "BLOCK_SIZE": 1
+ },
+ "DECLS": ["F32"]
+ },
+ {
+ "REPLS": {
+ "TYPE" : "f16",
+ "DST_TYPE": "f32",
+ "BLOCK_SIZE": 1
+ },
+ "DECLS": ["F16"]
+ },
+ {
+ "REPLS": {
+ "TYPE" : "i32",
+ "DST_TYPE": "i32",
+ "BLOCK_SIZE": 1
+ },
+ "DECLS": ["I32"]
+ },
+ {
+ "REPLS": {
+ "TYPE" : "q4_0",
+ "DST_TYPE": "f32",
+ "BLOCK_SIZE": 32
+ },
+ "DECLS": ["BYTE_HELPERS", "Q4_0_T", "Q4_0"]
+ },
+ {
+ "REPLS": {
+ "TYPE" : "q4_1",
+ "DST_TYPE": "f32",
+ "BLOCK_SIZE": 32
+ },
+ "DECLS": ["BYTE_HELPERS", "Q4_1_T", "Q4_1"]
+ },
+ {
+ "REPLS": {
+ "TYPE" : "q5_0",
+ "DST_TYPE": "f32",
+ "BLOCK_SIZE": 32
+ },
+ "DECLS": ["BYTE_HELPERS", "Q5_0_T", "Q5_0"]
+ },
+ {
+ "REPLS": {
+ "TYPE" : "q5_1",
+ "DST_TYPE": "f32",
+ "BLOCK_SIZE": 32
+ },
+ "DECLS": ["BYTE_HELPERS", "Q5_1_T", "Q5_1"]
+ },
+ {
+ "REPLS": {
+ "TYPE" : "q8_0",
+ "DST_TYPE": "f32",
+ "BLOCK_SIZE": 32
+ },
+ "DECLS": ["BYTE_HELPERS", "Q8_0_T", "Q8_0"]
+ },
+ {
+ "REPLS": {
+ "TYPE" : "q2_k",
+ "DST_TYPE": "f32",
+ "BLOCK_SIZE": 256
+ },
+ "DECLS": ["BYTE_HELPERS", "Q2_K_T", "Q2_K"]
+ },
+ {
+ "REPLS": {
+ "TYPE" : "q3_k",
+ "DST_TYPE": "f32",
+ "BLOCK_SIZE": 256
+ },
+ "DECLS": ["BYTE_HELPERS", "Q3_K_T", "Q3_K"]
+ },
+ {
+ "REPLS": {
+ "TYPE" : "q4_k",
+ "DST_TYPE": "f32",
+ "BLOCK_SIZE": 256
+ },
+ "DECLS": ["Q45_K_SCALE_MIN", "BYTE_HELPERS", "Q4_K_T", "Q4_K"]
+ },
+ {
+ "REPLS": {
+ "TYPE" : "q5_k",
+ "DST_TYPE": "f32",
+ "BLOCK_SIZE": 256
+ },
+ "DECLS": ["Q45_K_SCALE_MIN", "BYTE_HELPERS", "Q5_K_T", "Q5_K"]
+ },
+ {
+ "REPLS": {
+ "TYPE" : "q6_k",
+ "DST_TYPE": "f32",
+ "BLOCK_SIZE": 256
+ },
+ "DECLS": ["BYTE_HELPERS", "Q6_K_T", "Q6_K"]
+ },
+ {
+ "REPLS": {
+ "TYPE" : "iq2_xxs",
+ "DST_TYPE": "f32",
+ "BLOCK_SIZE": 256
+ },
+ "DECLS": ["BYTE_HELPERS", "IQ23_TABLES", "IQ2_XXS_GRID", "IQ2_XXS_T", "IQ2_XXS"]
+ },
+ {
+ "REPLS": {
+ "TYPE" : "iq2_xs",
+ "DST_TYPE": "f32",
+ "BLOCK_SIZE": 256
+ },
+ "DECLS": ["BYTE_HELPERS", "IQ23_TABLES", "IQ2_XS_GRID", "IQ2_XS_T", "IQ2_XS"]
+ },
+ {
+ "REPLS": {
+ "TYPE": "iq2_s",
+ "DST_TYPE": "f32",
+ "BLOCK_SIZE": 256
+ },
+ "DECLS": ["BYTE_HELPERS", "IQ23_TABLES", "IQ2_S_GRID", "IQ2_S_T", "IQ2_S"]
+ },
+ {
+ "REPLS": {
+ "TYPE": "iq3_xxs",
+ "DST_TYPE": "f32",
+ "BLOCK_SIZE": 256
+ },
+ "DECLS": ["BYTE_HELPERS", "IQ23_TABLES", "IQ3_XSS_GRID", "IQ3_XSS_T", "IQ3_XSS"]
+ },
+ {
+ "REPLS": {
+ "TYPE": "iq3_s",
+ "DST_TYPE": "f32",
+ "BLOCK_SIZE": 256
+ },
+ "DECLS": ["BYTE_HELPERS", "IQ23_TABLES", "IQ3_S_GRID", "IQ3_S_T", "IQ3_S"]
+ },
+ {
+ "REPLS": {
+ "TYPE": "iq1_s",
+ "DST_TYPE": "f32",
+ "BLOCK_SIZE": 256
+ },
+ "DECLS": ["BYTE_HELPERS", "IQ1_GRID", "IQ1_S_T", "IQ1_S"]
+ },
+ {
+ "REPLS": {
+ "TYPE": "iq1_m",
+ "DST_TYPE": "f32",
+ "BLOCK_SIZE": 256
+ },
+ "DECLS": ["BYTE_HELPERS", "IQ1_GRID", "IQ1_M_T", "IQ1_M"]
+ },
+ {
+ "REPLS": {
+ "TYPE": "iq4_nl",
+ "DST_TYPE": "f32",
+ "BLOCK_SIZE": 32,
+ },
+ "DECLS": ["BYTE_HELPERS", "IQ4_GRID", "IQ4_NL_T", "IQ4_NL"]
+ },
+ {
+ "REPLS": {
+ "TYPE": "iq4_xs",
+ "DST_TYPE": "f32",
+ "BLOCK_SIZE": 256,
+ },
+ "DECLS": ["BYTE_HELPERS", "IQ4_GRID", "IQ4_XS_T", "IQ4_XS"]
+ }
+]
+
+#end(VARIANTS)
+
+#define(DECLS)
+
+#decl(F32_VEC)
+fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
+ dst[(dst_base / 4) + offset] = src[(src_base / 4) + offset];
+}
+#enddecl(F32_VEC)
+
+#decl(F32)
+fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
+ dst[dst_base + offset] = src[src_base + offset];
+}
+#enddecl(F32)
+
+#decl(F16)
+fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
+ dst[dst_base + offset] = f32(src[src_base + offset]);
+}
+#enddecl(F16)
+
+#decl(I32)
+fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
+ dst[dst_base + offset] = src[src_base + offset];
+}
+#enddecl(I32)
+
+#decl(Q4_0)
+fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
+ let block_q4_0 = src[src_base + offset];
+ let d = f32(block_q4_0.d);
+ for (var j: u32 = 0; j < 4; j++) {
+ let q_packed = bitcast<u32>(vec2(block_q4_0.qs[2 * j], block_q4_0.qs[2 * j + 1]));
+ for (var k: u32 = 0; k < 4; k++) {
+ let q_byte = get_byte(q_packed, k);
+ let q_hi = (f32((q_byte >> 4) & 0xF) - 8.0f) * d;
+ let q_lo = (f32(q_byte & 0xF) - 8.0f) * d;
+ let dst_offset = dst_base + offset * 32 + j * 4 + k;
+ dst[dst_offset] = q_lo;
+ dst[dst_offset + 16] = q_hi;
+ }
+ }
+}
+#enddecl(Q4_0)
+
+#decl(Q4_1)
+fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
+ let block_q4_1 = src[src_base + offset];
+ let d = f32(block_q4_1.d);
+ let m = f32(block_q4_1.m);
+ for (var j: u32 = 0; j < 4; j++) {
+ let q_packed = block_q4_1.qs[j];
+ for (var k: u32 = 0; k < 4; k++) {
+ let q_byte = get_byte(q_packed, k);
+ let q_hi = f32((q_byte >> 4) & 0xF) * d + m;
+ let q_lo = f32(q_byte & 0xF) * d + m;
+ let dst_offset = dst_base + offset * 32 + j * 4 + k;
+ dst[dst_offset] = q_lo;
+ dst[dst_offset + 16] = q_hi;
+ }
+ }
+}
+#enddecl(Q4_1)
+
+#decl(Q5_0)
+fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
+ let block_q5_0 = src[src_base + offset];
+ let d = f32(block_q5_0.d);
+ let qh_packed = bitcast<u32>(vec2(block_q5_0.qh[0], block_q5_0.qh[1]));
+ for (var j: u32 = 0; j < 4; j++) {
+ let q_packed = bitcast<u32>(vec2(block_q5_0.qs[2 * j], block_q5_0.qs[2 * j + 1]));
+ for (var k: u32 = 0; k < 4; k++) {
+ let q_byte = get_byte(q_packed, k);
+ let qh_hi = (qh_packed >> (j * 4 + k + 12)) & 0x10;
+ let q_hi = (f32(((q_byte >> 4) & 0xF) | qh_hi) - 16.0) * d;
+ let qh_lo = ((qh_packed >> (j * 4 + k)) << 4) & 0x10;
+ let q_lo = (f32((q_byte & 0xF) | qh_lo) - 16.0) * d;
+ let dst_offset = dst_base + offset * 32 + j * 4 + k;
+ dst[dst_offset] = q_lo;
+ dst[dst_offset + 16] = q_hi;
+ }
+ }
+}
+
+#enddecl(Q5_0)
+
+#decl(Q5_1)
+fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
+ let block_q5_1 = src[src_base + offset];
+ let d = f32(block_q5_1.d);
+ let m = f32(block_q5_1.m);
+ for (var j: u32 = 0; j < 4; j++) {
+ let q_packed = block_q5_1.qs[j];
+ for (var k: u32 = 0; k < 4; k++) {
+ let q_byte = get_byte(q_packed, k);
+ let qh_hi = (block_q5_1.qh >> (j * 4 + k + 12)) & 0x10;
+ let q_hi = f32(((q_byte >> 4) & 0xF) | qh_hi) * d + m;
+ let qh_lo = ((block_q5_1.qh >> (j * 4 + k)) << 4) & 0x10;
+ let q_lo = f32((q_byte & 0xF) | qh_lo) * d + m;
+ let dst_offset = dst_base + offset * 32 + j * 4 + k;
+ dst[dst_offset] = q_lo;
+ dst[dst_offset + 16] = q_hi;
+ }
+ }
+}
+#enddecl(Q5_1)
+
+#decl(Q8_0)
+fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
+ let block_q8_0 = src[src_base + offset];
+ let d = f32(block_q8_0.d);
+ for (var j: u32 = 0; j < 8; j++) {
+ let q_packed = bitcast<u32>(vec2(block_q8_0.qs[2 * j], block_q8_0.qs[2 * j + 1]));
+ for (var k: u32 = 0; k < 4; k++) {
+ let q_byte = get_byte_i32(q_packed, k);
+ let q_val = f32(q_byte) * d;
+ let dst_offset = dst_base + offset * 32 + j * 4 + k;
+ dst[dst_offset] = q_val;
+ }
+ }
+}
+#enddecl(Q8_0)
+
+#decl(Q2_K)
+fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
+ let block = src[src_base + offset];
+ let d = f32(block.d);
+ let m = f32(block.dmin);
+ var dst_i = dst_base + offset * 256;
+ var is: u32 = 0;
+ // 2 halves of the block (128 elements each)
+ for (var q_b_idx: u32 = 0; q_b_idx < 64; q_b_idx += 32) {
+ // 4 groups (each group has 2 blocks of 16 elements)
+ for (var shift: u32 = 0; shift < 8; shift += 2) {
+ // 2 blocks
+ for (var k: u32 = 0; k < 32; k += 16) {
+ let sc = get_byte(block.scales[is / 4], is % 4);
+ is++;
+ let dl = d * f32(sc & 0xF);
+ let ml = m * f32(sc >> 4);
+ for (var l: u32 = 0u; l < 16; l++) {
+ let q_idx = q_b_idx + k + l;
+ let q_byte = get_byte(block.qs[q_idx / 4], q_idx % 4);
+ let qs_val = (q_byte >> shift) & 3;
+ dst[dst_i] = (f32(qs_val) * dl - ml);
+ dst_i++;
+ }
+ }
+ }
+ }
+}
+#enddecl(Q2_K)
+
+#decl(Q3_K)
+fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
+ let block = src[src_base + offset];
+ let d = f32(block.d);
+
+    // extract the 6-bit scales: the low 4 bits come from the first 8 bytes of scales,
+    // and the high 2 bits from the last 4 bytes
+ let kmask1: u32 = 0x03030303;
+ let kmask2: u32 = 0x0f0f0f0f;
+ var scale_vals: array<u32, 4>;
+ for (var i: u32 = 0; i < 4; i++) {
+ scale_vals[i] = bitcast<u32>(vec2(block.scales[2 * i], block.scales[2 * i + 1]));
+ }
+ var tmp: u32 = scale_vals[2];
+ scale_vals[2] = ((scale_vals[0] >> 4) & kmask2) | (((tmp >> 4) & kmask1) << 4);
+ scale_vals[3] = ((scale_vals[1] >> 4) & kmask2) | (((tmp >> 6) & kmask1) << 4);
+ scale_vals[0] = (scale_vals[0] & kmask2) | ((tmp & kmask1) << 4);
+ scale_vals[1] = (scale_vals[1] & kmask2) | (((tmp >> 2) & kmask1) << 4);
+
+ // convert arrays of f16 -> u32
+ var hmask_vals: array<u32, 8>;
+ for (var i: u32 = 0; i < 8; i++) {
+ hmask_vals[i] = bitcast<u32>(vec2(block.hmask[2 * i], block.hmask[2 * i + 1]));
+ }
+ var qs_vals: array<u32, 16>;
+ for (var i: u32 = 0; i < 16; i++) {
+ qs_vals[i] = bitcast<u32>(vec2(block.qs[2 * i], block.qs[2 * i + 1]));
+ }
+
+ var dst_i = dst_base + offset * 256;
+ var is: u32 = 0;
+ var m: u32 = 1;
+ // 2 halves of the block (128 elements each)
+ for (var q_b_idx: u32 = 0; q_b_idx < 64; q_b_idx += 32) {
+ // 4 groups (each group has 2 blocks of 16 elements)
+ for (var shift: u32 = 0; shift < 8; shift += 2) {
+ // 2 blocks
+ for (var k: u32 = 0; k < 32; k += 16) {
+ let sc = get_byte(scale_vals[is / 4], is % 4);
+ is++;
+ let dl = d * (f32(sc) - 32.0);
+ for (var l: u32 = 0u; l < 16u; l++) {
+ let q_idx = q_b_idx + k + l;
+ let hm_idx = k + l;
+ let q_byte = get_byte(qs_vals[q_idx / 4], q_idx % 4);
+ let hmask_byte = get_byte(hmask_vals[hm_idx / 4], hm_idx % 4);
+ let hm = select(4.0, 0.0, (hmask_byte & m) != 0);
+ let qs_val = (q_byte >> shift) & 3;
+ dst[dst_i] = (f32(qs_val) - hm) * dl;
+ dst_i++;
+ }
+ }
+ m <<= 1;
+ }
+ }
+}
+#enddecl(Q3_K)
+
+#decl(Q4_K)
+// 8 blocks of 32 elements each
+fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
+ let block = src[src_base + offset];
+ let d = f32(block.d);
+ let m = f32(block.dmin);
+ var dst_i = dst_base + offset * 256;
+ var is: u32 = 0;
+ // 2 blocks each iteration
+ for (var q_b_idx: u32 = 0; q_b_idx < 128; q_b_idx += 32) {
+ for (var shift: u32 = 0; shift < 8; shift += 4) {
+ let scale_min = get_scale_min(is, block.scales);
+ is++;
+ let dl = d * scale_min.x;
+ let ml = m * scale_min.y;
+ for (var l: u32 = 0; l < 32; l++) {
+ let q_idx = q_b_idx + l;
+ let q_byte = get_byte(block.qs[q_idx / 4], q_idx % 4);
+ let qs_val = (q_byte >> shift) & 0xF;
+ dst[dst_i] = (f32(qs_val) * dl - ml);
+ dst_i++;
+ }
+ }
+ }
+}
+#enddecl(Q4_K)
+
+#decl(Q5_K)
+fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
+ let block = src[src_base + offset];
+ let d = f32(block.d);
+ let m = f32(block.dmin);
+ var dst_i = dst_base + offset * 256;
+ var is: u32 = 0;
+ var u: u32 = 1;
+ // 2 blocks each iteration
+ for (var q_b_idx: u32 = 0; q_b_idx < 128; q_b_idx += 32) {
+ for (var shift: u32 = 0; shift < 8; shift += 4) {
+ let scale_min = get_scale_min(is, block.scales);
+ is++;
+ let dl = d * scale_min.x;
+ let ml = m * scale_min.y;
+ for (var l: u32 = 0; l < 32; l++) {
+ let q_idx = q_b_idx + l;
+ let q_byte = get_byte(block.qs[q_idx / 4], q_idx % 4);
+ let qh_byte = get_byte(block.qh[l / 4], l % 4);
+ let qs_val = (q_byte >> shift) & 0xF;
+ let qh_val = select(0.0, 16.0, (qh_byte & u) != 0);
+ dst[dst_i] = (f32(qs_val) + qh_val) * dl - ml;
+ dst_i++;
+ }
+ u <<= 1;
+ }
+ }
+}
+#enddecl(Q5_K)
+
+#decl(Q6_K)
+// 16 blocks of 16 elements each
+fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
+ let block = src[src_base + offset];
+ let d = f32(block.d);
+
+ // convert arrays of f16 -> u32
+ var ql_vals: array<u32, 32>;
+ for (var i: u32 = 0; i < 32; i++) {
+ ql_vals[i] = bitcast<u32>(vec2(block.ql[2 * i], block.ql[2 * i + 1]));
+ }
+ var qh_vals: array<u32, 16>;
+ for (var i: u32 = 0; i < 16; i++) {
+ qh_vals[i] = bitcast<u32>(vec2(block.qh[2 * i], block.qh[2 * i + 1]));
+ }
+ var scale_vals: array<u32, 4>;
+ for (var i: u32 = 0; i < 4; i++) {
+ scale_vals[i] = bitcast<u32>(vec2(block.scales[2 * i], block.scales[2 * i + 1]));
+ }
+
+ var dst_i = dst_base + offset * 256;
+ var qh_b_idx: u32 = 0;
+ var sc_b_idx: u32 = 0;
+ for (var ql_b_idx: u32 = 0; ql_b_idx < 128; ql_b_idx += 64) {
+ for (var l: u32 = 0; l < 32; l++) {
+ let ql13_b = get_byte(ql_vals[(ql_b_idx + l) / 4], (ql_b_idx + l) % 4);
+ let ql24_b = get_byte(ql_vals[(ql_b_idx + l + 32) / 4], (ql_b_idx + l + 32) % 4);
+ let qh_b = get_byte(qh_vals[(qh_b_idx + l) / 4], (qh_b_idx + l) % 4);
+
+ let q1 = f32((ql13_b & 0xF) | ((qh_b & 3) << 4)) - 32.0;
+ let q2 = f32((ql24_b & 0xF) | (((qh_b >> 2) & 3) << 4)) - 32.0;
+ let q3 = f32((ql13_b >> 4) | (((qh_b >> 4) & 3) << 4)) - 32.0;
+ let q4 = f32((ql24_b >> 4) | (((qh_b >> 6) & 3) << 4)) - 32.0;
+
+ let is = l/16;
+ let is1 = sc_b_idx + is;
+ let sc1 = get_byte_i32(scale_vals[is1 / 4], is1 % 4);
+ let is2 = sc_b_idx + is + 2;
+ let sc2 = get_byte_i32(scale_vals[is2 / 4], is2 % 4);
+ let is3 = sc_b_idx + is + 4;
+ let sc3 = get_byte_i32(scale_vals[is3 / 4], is3 % 4);
+ let is4 = sc_b_idx + is + 6;
+ let sc4 = get_byte_i32(scale_vals[is4 / 4], is4 % 4);
+
+ dst[dst_i + l] = (q1 * f32(sc1)) * d;
+ dst[dst_i + l + 32] = (q2 * f32(sc2)) * d;
+ dst[dst_i + l + 64] = (q3 * f32(sc3)) * d;
+ dst[dst_i + l + 96] = (q4 * f32(sc4)) * d;
+ }
+ dst_i += 128;
+ qh_b_idx += 32;
+ sc_b_idx += 8;
+ }
+}
+
+#enddecl(Q6_K)
+
+#decl(IQ2_XXS)
+fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
+ let block = src[src_base + offset];
+ let d = f32(block.d);
+ var dst_i = dst_base + offset * 256;
+ for (var ib: u32 = 0; ib < 32; ib += 4) {
+ let aux0 = bitcast<u32>(vec2(block.qs[ib], block.qs[ib + 1]));
+ let aux1 = bitcast<u32>(vec2(block.qs[ib + 2], block.qs[ib + 3]));
+ let db = d * (0.5 + f32(aux1 >> 28)) * 0.25;
+ for (var l: u32 = 0; l < 4; l++) {
+ let ig = get_byte(aux0, l) * 8;
+ let is = (aux1 >> (7 * l)) & 127;
+ let signs = get_byte(ksigns_iq2xs[is / 4], is % 4);
+ for (var j: u32 = 0; j < 8; j++) {
+ let g = get_byte(iq2xxs_grid[(ig + j) / 4], (ig + j) % 4);
+ let m = select(1.0, -1.0, (get_byte(kmask_iq2xs[j / 4], j % 4) & signs) != 0);
+ dst[dst_i] = db * f32(g) * m;
+ dst_i++;
+ }
+ }
+ }
+}
+#enddecl(IQ2_XXS)
+
+#decl(IQ2_XS)
+fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
+ let block = src[src_base + offset];
+ let d = f32(block.d);
+ var dst_i = dst_base + offset * 256;
+ var scale_vals = array<u32, 2>(
+ bitcast<u32>(vec2(block.scales[0], block.scales[1])),
+ bitcast<u32>(vec2(block.scales[2], block.scales[3]))
+ );
+ for (var ib: u32 = 0; ib < 32; ib += 4) {
+ let s = get_byte(scale_vals[ib / 16], (ib % 16) / 4);
+ let db = array<f32, 2>(
+ d * (0.5 + f32(s & 0xF)) * 0.25,
+ d * (0.5 + f32(s >> 4)) * 0.25
+ );
+ for (var l: u32 = 0; l < 4; l++) {
+ let qs_val = bitcast<u32>(vec2(block.qs[ib + l], 0.0));
+ let ig = (qs_val & 511) * 8;
+ let is = qs_val >> 9;
+ let signs = get_byte(ksigns_iq2xs[is / 4], is % 4);
+ let dl = db[l/2];
+ for (var j: u32 = 0; j < 8; j++) {
+ let g = get_byte(iq2xs_grid[(ig + j) / 4], (ig + j) % 4);
+ let m = select(1.0, -1.0, (get_byte(kmask_iq2xs[j / 4], j % 4) & signs) != 0);
+ dst[dst_i] = dl * f32(g) * m;
+ dst_i++;
+ }
+ }
+ }
+}
+#enddecl(IQ2_XS)
+
+#decl(IQ2_S)
+fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
+ let block = src[src_base + offset];
+ let d = f32(block.d);
+ var dst_i = dst_base + offset * 256;
+ var qs_vals : array<u32, 16>;
+ for (var i: u32 = 0; i < 16; i++) {
+ qs_vals[i] = bitcast<u32>(vec2(block.qs[i * 2], block.qs[i * 2 + 1]));
+ }
+ var qh_vals = array<u32, 2>(
+ bitcast<u32>(vec2(block.qh[0], block.qh[1])),
+ bitcast<u32>(vec2(block.qh[2], block.qh[3]))
+ );
+ var scale_vals = array<u32, 2>(
+ bitcast<u32>(vec2(block.scales[0], block.scales[1])),
+ bitcast<u32>(vec2(block.scales[2], block.scales[3]))
+ );
+    for (var ib: u32 = 0; ib < 8; ib++) {
+ let s = get_byte(scale_vals[ib / 4], ib % 4);
+ let db = array<f32, 2>(
+ d * (0.5 + f32(s & 0xF)) * 0.25,
+ d * (0.5 + f32(s >> 4)) * 0.25
+ );
+ let qs_w = qs_vals[ib];
+ for (var l: u32 = 0; l < 4; l++) {
+ let qh_b = (get_byte(qh_vals[ib / 4], ib % 4) << (8 - 2 * l)) & 0x300;
+ let ig = (get_byte(qs_w, l) | qh_b) * 8;
+ let signs = get_byte(qs_vals[ib + 8], l);
+ let dl = db[l/2];
+ for (var j: u32 = 0; j < 8; j++) {
+ let g = get_byte(iq2s_grid[(ig + j) / 4], (ig + j) % 4);
+ let m = select(1.0, -1.0, (get_byte(kmask_iq2xs[j / 4], j % 4) & signs) != 0);
+ dst[dst_i] = dl * f32(g) * m;
+ dst_i++;
+ }
+ }
+ }
+}
+
+#enddecl(IQ2_S)
+
+#decl(IQ3_XSS)
+fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
+ let block = src[src_base + offset];
+ let d = f32(block.d);
+ var dst_i = dst_base + offset * 256;
+ for (var ib: u32 = 0; ib < 16; ib += 2) {
+ let sc_sign = bitcast<u32>(vec2(block.qs[ib + 32], block.qs[ib + 33]));
+ let db = d * (0.5 + f32(sc_sign >> 28)) * 0.5;
+ for (var l: u32 = 0; l < 4; l++) {
+ let is = (sc_sign >> (7 * l)) & 127;
+ let signs = get_byte(ksigns_iq2xs[is / 4], is % 4);
+ let ig_val = bitcast<u32>(vec2(block.qs[ib * 2 + l], 0.0));
+ let ig1 = get_byte(ig_val, 0);
+ let ig2 = get_byte(ig_val, 1);
+ for (var j: u32 = 0; j < 4; j++) {
+ let g1 = get_byte(iq3xxs_grid[ig1], j);
+ let g2 = get_byte(iq3xxs_grid[ig2], j);
+ let m1 = select(1.0, -1.0, (get_byte(kmask_iq2xs[0], j) & signs) != 0);
+ let m2 = select(1.0, -1.0, (get_byte(kmask_iq2xs[1], j) & signs) != 0);
+ dst[dst_i] = db * f32(g1) * m1;
+ dst[dst_i + 4] = db * f32(g2) * m2;
+ dst_i++;
+ }
+ dst_i += 4;
+ }
+ }
+}
+#enddecl(IQ3_XSS)
+
+#decl(IQ3_S)
+fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
+ let block = src[src_base + offset];
+ let d = f32(block.d);
+ var dst_i = dst_base + offset * 256;
+ var qh_vals = array<u32, 2>(
+ bitcast<u32>(vec2(block.qh[0], block.qh[1])),
+ bitcast<u32>(vec2(block.qh[2], block.qh[3]))
+ );
+ var sign_vals: array<u32, 8>;
+ for (var i: u32 = 0; i < 8; i++) {
+ sign_vals[i] = bitcast<u32>(vec2(block.signs[i * 2], block.signs[i * 2 + 1]));
+ }
+ var scale_vals = bitcast<u32>(vec2(block.scales[0], block.scales[1]));
+ for (var ib: u32 = 0; ib < 4; ib++) {
+ let s = get_byte(scale_vals, ib);
+ let db = array<f32, 2>(
+ d * (1.0 + 2.0 * f32(s & 0xF)),
+ d * (1.0 + 2.0 * f32(s >> 4))
+ );
+ for (var k: u32 = 0; k < 2; k++) {
+ let dl = db[k];
+ let qh_byte = get_byte(qh_vals[ib / 2], (ib % 2) * 2 + k);
+ let sign_w = sign_vals[ib * 2 + k];
+ for (var l: u32 = 0; l < 4; l++) {
+ let signs = get_byte(sign_w, l);
+ let ig_val = bitcast<u32>(vec2(block.qs[ib * 8 + k * 4 + l], 0.0));
+ let ig1 = get_byte(ig_val, 0) | ((qh_byte << ((8 - (2 * l)))) & 256);
+ let ig2 = get_byte(ig_val, 1) | ((qh_byte << ((7 - (2 * l)))) & 256);
+ for (var j: u32 = 0; j < 4; j++) {
+ let g1 = get_byte(iq3s_grid[ig1], j);
+ let g2 = get_byte(iq3s_grid[ig2], j);
+ let m1 = select(1.0, -1.0, (get_byte(kmask_iq2xs[0], j) & signs) != 0);
+ let m2 = select(1.0, -1.0, (get_byte(kmask_iq2xs[1], j) & signs) != 0);
+ dst[dst_i] = dl * f32(g1) * m1;
+ dst[dst_i + 4] = dl * f32(g2) * m2;
+ dst_i++;
+ }
+ dst_i += 4;
+ }
+ }
+ }
+}
+#enddecl(IQ3_S)
+
+#decl(IQ1_S)
+fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
+ let block = src[src_base + offset];
+ let d = f32(block.d);
+ var dst_i = dst_base + offset * 256;
+ for (var ib: u32 = 0; ib < 8; ib++) {
+ let qh = bitcast<u32>(vec2(block.qh[ib], 0.0));
+ let dl = d * (2 * f32((qh >> 12) & 7) + 1);
+ let delta = select(IQ1_DELTA, -IQ1_DELTA, (qh & 0x8000) != 0);
+ let qs_w = bitcast<u32>(vec2(block.qs[ib * 2], block.qs[ib * 2 + 1]));
+ for (var l: u32 = 0; l < 4; l++) {
+ let ig = (get_byte(qs_w, l) | (((qh >> (3 * l)) & 7) << 8)) * 8;
+ for (var j: u32 = 0; j < 8; j++) {
+ let gw = iq1_grid[(ig + j) / 16];
+ let g = (gw >> (((ig + j) % 16) * 2)) & 3;
+ let gs = bitcast<i32>(g << 30) >> 30;
+ dst[dst_i] = dl * (f32(gs) + delta);
+ dst_i++;
+ }
+ }
+ }
+}
+
+#enddecl(IQ1_S)
+
+#decl(IQ1_M)
+fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
+ let block = src[src_base + offset];
+
+ let scale = ((block.scales[0] >> 12) & 0xF) | ((block.scales[0] >> 24) & 0x00F0) | ((block.scales[1] >> 4) & 0x0F00) | ((block.scales[1] >> 16) & 0xF000);
+ let d = f32(bitcast<vec2<f16>>(scale).x);
+ var dst_i = dst_base + offset * 256;
+ for (var ib: u32 = 0; ib < 8; ib++) {
+ let sw = (block.scales[ib / 4] >> (16 * ((ib / 2) % 2))) & 0xFFFF;
+ let s1 : u32 = (sw >> (6 * (ib % 2))) & 0x7;
+ let s2 : u32 = (sw >> (6 * (ib % 2) + 3)) & 0x7;
+ var dl = array<f32, 2>(
+ d * f32(2 * s1 + 1),
+ d * f32(2 * s2 + 1)
+ );
+
+ let qh = block.qh[ib / 2] >> (16 * (ib % 2));
+ var idx = array<u32, 4>(
+ get_byte(block.qs[ib], 0) | ((qh << 8) & 0x700),
+ get_byte(block.qs[ib], 1) | ((qh << 4) & 0x700),
+ get_byte(block.qs[ib], 2) | ((qh) & 0x700),
+ get_byte(block.qs[ib], 3) | ((qh >> 4) & 0x700)
+ );
+ var delta = array<f32, 4>(
+ select(IQ1_DELTA, -IQ1_DELTA, (qh & 0x08) != 0),
+ select(IQ1_DELTA, -IQ1_DELTA, (qh & 0x80) != 0),
+ select(IQ1_DELTA, -IQ1_DELTA, ((qh >> 8) & 0x08) != 0),
+ select(IQ1_DELTA, -IQ1_DELTA, ((qh >> 8) & 0x80) != 0)
+ );
+ for (var l: u32 = 0; l < 4; l++) {
+ let ig = idx[l] * 8;
+ for (var j: u32 = 0; j < 8; j++) {
+ let gw = iq1_grid[(ig + j) / 16];
+ let g = (gw >> (((ig + j) % 16) * 2)) & 3;
+ let gs = bitcast<i32>(g << 30) >> 30;
+ dst[dst_i] = dl[l/2] * (f32(gs) + delta[l]);
+ dst_i++;
+ }
+ }
+ }
+}
+
+#enddecl(IQ1_M)
+
+#decl(IQ4_NL)
+fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
+ let block = src[src_base + offset];
+ let d = f32(block.d);
+ var dst_i = dst_base + offset * 32;
+ var qs: array<u32, 4>;
+ for (var i: u32 = 0; i < 4; i++) {
+ qs[i] = bitcast<u32>(vec2(block.qs[i * 2], block.qs[i * 2 + 1]));
+ }
+ for (var j: u32 = 0; j < 16; j++) {
+ let qsb = get_byte(qs[j / 4], j % 4);
+ dst[dst_i] = d * f32(kvalues_iq4nl[qsb & 0xF]);
+ dst[dst_i + 16] = d * f32(kvalues_iq4nl[qsb >> 4]);
+ dst_i++;
+ }
+}
+#enddecl(IQ4_NL)
+
+#decl(IQ4_XS)
+fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
+ let block = src[src_base + offset];
+ let d = f32(block.d);
+ let scales_h = bitcast<u32>(vec2(block.scales_h, 0.0));
+ var dst_i = dst_base + offset * 256;
+ for (var ib: u32 = 0; ib < 8; ib++) {
+ let ls = ((get_byte(block.scales_l, ib / 2) >> (4 * (ib % 2))) & 0xF) | (((scales_h >> (2 * ib)) & 3) << 4);
+ let dl = d * (f32(ls) - 32.0);
+ for (var j: u32 = 0; j < 16; j++) {
+ let iqs = ib * 16 + j;
+ let qsb = get_byte(block.qs[iqs / 4], iqs % 4);
+ dst[dst_i] = dl * f32(kvalues_iq4nl[qsb & 0xF]);
+ dst[dst_i + 16] = dl * f32(kvalues_iq4nl[qsb >> 4]);
+ dst_i++;
+ }
+ dst_i += 16;
+ }
+}
+#enddecl(IQ4_XS)
+
+#end(DECLS)
+
+#define(SHADER)
+
+enable f16;
+
+DECLS
+
+@group(0) @binding(0)
+var<storage, read_write> src: array<{{TYPE}}>;
+
+@group(0) @binding(1)
+var<storage, read_write> idx: array<i32>;
+
+@group(0) @binding(2)
+var<storage, read_write> dst: array<{{DST_TYPE}}>;
+
+struct Params {
+ offset_src: u32, // in elements
+ offset_idx: u32, // in elements
+ offset_dst: u32, // in elements
+
+ // Strides (in elements)
+ stride_src1: u32,
+ stride_src2: u32,
+ stride_src3: u32,
+
+ stride_idx0: u32,
+ stride_idx1: u32,
+ stride_idx2: u32,
+
+ stride_dst1: u32,
+ stride_dst2: u32,
+ stride_dst3: u32,
+
+ // Shape of dst
+ ne0: u32,
+ n_rows: u32,
+ ne2: u32,
+ ne3: u32,
+
+ // Shape of idx
+ idx1: u32,
+ idx2: u32,
+};
+
+@group(0) @binding(3)
+var<uniform> params: Params;
+
+override wg_size: u32;
+@compute @workgroup_size(wg_size)
+fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
+ if (gid.x >= params.n_rows * params.ne2 * params.ne3) {
+ return;
+ }
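+    // One invocation per destination row: decompose the flat index into (i_dst1, i_dst2, i_dst3),
+    // then gather the source row selected by idx, which broadcasts over the two outer dims.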
+ var i = gid.x;
+ let i_dst3 = i / (params.ne2 * params.n_rows);
+
+ i = i % (params.ne2 * params.n_rows);
+ let i_dst2 = i / params.n_rows;
+ let i_dst1 = i % params.n_rows;
+
+ let i_idx2 = i_dst3 % params.idx2;
+ let i_idx1 = i_dst2 % params.idx1;
+ let i_idx0 = i_dst1;
+
+ let i_idx = params.offset_idx + i_idx0 * params.stride_idx0 + i_idx1 * params.stride_idx1 + i_idx2 * params.stride_idx2;
+
+ let idx_val = u32(idx[i_idx]);
+
+ let i_src_row = params.offset_src + idx_val * params.stride_src1 + i_dst2 * params.stride_src2 + i_dst3 * params.stride_src3;
+ let i_dst_row = params.offset_dst + i_dst1 * params.stride_dst1 + i_dst2 * params.stride_dst2 + i_dst3 * params.stride_dst3;
+
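+    // Each iteration copies one {{BLOCK_SIZE}}-element block; ne0 is assumed to be a whole number of blocks.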
+ for (var i: u32 = 0; i < params.ne0/{{BLOCK_SIZE}}; i++) {
+ copy_elements(i_src_row, i_dst_row, i);
+ }
+}
+
+#end(SHADER)
diff --git a/llama.cpp/ggml/src/ggml-webgpu/wgsl-shaders/glu.tmpl.wgsl b/llama.cpp/ggml/src/ggml-webgpu/wgsl-shaders/glu.tmpl.wgsl
new file mode 100644
index 0000000..03fcd54
--- /dev/null
+++ b/llama.cpp/ggml/src/ggml-webgpu/wgsl-shaders/glu.tmpl.wgsl
@@ -0,0 +1,323 @@
+#define(VARIANTS)
+
+[
+ {
+ "SHADER_NAME": "reglu_f32",
+ "REPLS": {
+ "TYPE" : "f32",
+ },
+ "DECLS": ["NO_SPLIT", "REGLU"]
+ },
+ {
+ "SHADER_NAME": "reglu_f32_split",
+ "REPLS": {
+ "TYPE" : "f32",
+ },
+ "DECLS": ["SPLIT", "REGLU"]
+ },
+ {
+ "SHADER_NAME": "reglu_f16",
+ "REPLS": {
+ "TYPE" : "f16",
+ },
+ "DECLS": ["NO_SPLIT", "REGLU"]
+ },
+ {
+ "SHADER_NAME": "reglu_f16_split",
+ "REPLS": {
+ "TYPE" : "f16",
+ },
+ "DECLS": ["SPLIT", "REGLU"]
+ },
+ {
+ "SHADER_NAME": "geglu_f32",
+ "REPLS": {
+ "TYPE" : "f32",
+ },
+ "DECLS": ["NO_SPLIT", "GEGLU"]
+ },
+ {
+ "SHADER_NAME": "geglu_f32_split",
+ "REPLS": {
+ "TYPE" : "f32",
+ },
+ "DECLS": ["SPLIT", "GEGLU"]
+ },
+ {
+ "SHADER_NAME": "geglu_f16",
+ "REPLS": {
+ "TYPE" : "f16",
+ },
+ "DECLS": ["NO_SPLIT", "GEGLU"]
+ },
+ {
+ "SHADER_NAME": "geglu_f16_split",
+ "REPLS": {
+ "TYPE" : "f16",
+ },
+ "DECLS": ["SPLIT", "GEGLU"]
+ },
+ {
+ "SHADER_NAME": "swiglu_f32",
+ "REPLS": {
+ "TYPE" : "f32",
+ },
+ "DECLS": ["NO_SPLIT", "SWIGLU"]
+ },
+ {
+ "SHADER_NAME": "swiglu_f32_split",
+ "REPLS": {
+ "TYPE" : "f32",
+ },
+ "DECLS": ["SPLIT", "SWIGLU"]
+ },
+ {
+ "SHADER_NAME": "swiglu_f16",
+ "REPLS": {
+ "TYPE" : "f16",
+ },
+ "DECLS": ["NO_SPLIT", "SWIGLU"]
+ },
+ {
+ "SHADER_NAME": "swiglu_f16_split",
+ "REPLS": {
+ "TYPE" : "f16",
+ },
+ "DECLS": ["SPLIT", "SWIGLU"]
+ },
+ {
+ "SHADER_NAME": "swiglu_oai_f32",
+ "REPLS": {
+ "TYPE" : "f32",
+ },
+ "DECLS": ["NO_SPLIT", "SWIGLU_OAI"]
+ },
+ {
+ "SHADER_NAME": "swiglu_oai_f32_split",
+ "REPLS": {
+ "TYPE" : "f32",
+ },
+ "DECLS": ["SPLIT", "SWIGLU_OAI"]
+ },
+ {
+ "SHADER_NAME": "geglu_erf_f32",
+ "REPLS": {
+ "TYPE" : "f32",
+ },
+ "DECLS": ["NO_SPLIT", "GEGLU_ERF"]
+ },
+ {
+ "SHADER_NAME": "geglu_erf_f32_split",
+ "REPLS": {
+ "TYPE" : "f32",
+ },
+ "DECLS": ["SPLIT", "GEGLU_ERF"]
+ },
+ {
+ "SHADER_NAME": "geglu_erf_f16",
+ "REPLS": {
+ "TYPE" : "f16",
+ },
+ "DECLS": ["NO_SPLIT", "GEGLU_ERF"]
+ },
+ {
+ "SHADER_NAME": "geglu_erf_f16_split",
+ "REPLS": {
+ "TYPE" : "f16",
+ },
+ "DECLS": ["SPLIT", "GEGLU_ERF"]
+ },
+ {
+ "SHADER_NAME": "geglu_quick_f32",
+ "REPLS": {
+ "TYPE" : "f32",
+ },
+ "DECLS": ["NO_SPLIT", "GEGLU_QUICK"]
+ },
+ {
+ "SHADER_NAME": "geglu_quick_f32_split",
+ "REPLS": {
+ "TYPE" : "f32",
+ },
+ "DECLS": ["SPLIT", "GEGLU_QUICK"]
+ },
+ {
+ "SHADER_NAME": "geglu_quick_f16",
+ "REPLS": {
+ "TYPE" : "f16",
+ },
+ "DECLS": ["NO_SPLIT", "GEGLU_QUICK"]
+ },
+ {
+ "SHADER_NAME": "geglu_quick_f16_split",
+ "REPLS": {
+ "TYPE" : "f16",
+ },
+ "DECLS": ["SPLIT", "GEGLU_QUICK"]
+ },
+]
+
+#end(VARIANTS)
+
+#define(DECLS)
+
+#decl(REGLU)
+fn op(a: {{TYPE}}, b: {{TYPE}}) -> {{TYPE}} {
+ return max(a, 0) * b;
+}
+#enddecl(REGLU)
+
+#decl(GEGLU)
+const SQRT_2_OVER_PI: {{TYPE}} = 0.79788456080286535587989211986876;
+const GELU_COEF_A: {{TYPE}} = 0.044715;
+
+fn op(a: {{TYPE}}, b: {{TYPE}}) -> {{TYPE}} {
+ let val = SQRT_2_OVER_PI * a * (1.0 + GELU_COEF_A * a * a);
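+    // tanh(val) = 1 - 2 / (exp(2 * val) + 1), so this is the tanh GELU gate 0.5 * a * (1 + tanh(val)).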
+ return 0.5 * a * (2.0 - 2.0 / (exp(2 * val) + 1)) * b;
+}
+#enddecl(GEGLU)
+
+#decl(SWIGLU)
+fn op(a: {{TYPE}}, b: {{TYPE}}) -> {{TYPE}} {
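+    // silu(a) * b, with silu(a) = a * sigmoid(a) = a / (1 + exp(-a))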
+ return a / (1.0 + exp(-a)) * b;
+}
+#enddecl(SWIGLU)
+
+#decl(SWIGLU_OAI)
+fn op(a: f32, b: f32) -> f32 {
+ let xi = min(a, params.limit);
+ let gi = max(min(b, params.limit), -params.limit);
+ var out_glu = xi / (1.0 + exp(-xi * params.alpha));
+ out_glu = out_glu * (1.0 + gi);
+ return out_glu;
+}
+#enddecl(SWIGLU_OAI)
+
+#decl(GEGLU_ERF)
+const p_erf: {{TYPE}} = 0.3275911;
+const a1_erf: {{TYPE}} = 0.254829592;
+const a2_erf: {{TYPE}} = -0.284496736;
+const a3_erf: {{TYPE}} = 1.421413741;
+const a4_erf: {{TYPE}} = -1.453152027;
+const a5_erf: {{TYPE}} = 1.061405429;
+const SQRT_2_INV: {{TYPE}} = 0.7071067811865476;
+
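+// Polynomial approximation of erf (Abramowitz & Stegun 7.1.26); GELU(a) = 0.5 * a * (1 + erf(a / sqrt(2))).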
+fn op(a: {{TYPE}}, b: {{TYPE}}) -> {{TYPE}} {
+ let a_div_sqr2 = a * SQRT_2_INV;
+ let sign_x = sign(a_div_sqr2);
+ let x = abs(a_div_sqr2);
+ let t = 1.0 / (1.0 + p_erf * x);
+ let y = 1.0 - (((((a5_erf * t + a4_erf) * t + a3_erf) * t + a2_erf) * t + a1_erf) * t * exp(-x * x));
+ let erf_approx = sign_x * y;
+ return 0.5 * a * (1.0 + erf_approx) * b;
+}
+#enddecl(GEGLU_ERF)
+
+#decl(GEGLU_QUICK)
+const GELU_QUICK_COEF: {{TYPE}} = -1.702;
+
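+// "Quick" GELU gate: a * sigmoid(1.702 * a) * b, a cheap sigmoid approximation of the erf-based gate above.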
+fn op(a: {{TYPE}}, b: {{TYPE}}) -> {{TYPE}} {
+ return a * (1.0 / (1.0 + exp(GELU_QUICK_COEF * a))) * b;
+}
+#enddecl(GEGLU_QUICK)
+
+#decl(NO_SPLIT)
+@group(0) @binding(1)
+var<storage, read_write> dst: array<{{TYPE}}>;
+
+@group(0) @binding(2)
+var<uniform> params: Params;
+
+fn a_value(base: u32) -> {{TYPE}} {
+ let offset: u32 = select(0, params.ne0, params.swapped != 0);
+ return src0[base + offset];
+}
+
+fn b_value(base: u32) -> {{TYPE}} {
+ let offset: u32 = select(params.ne0, 0, params.swapped != 0);
+ return src0[base + offset];
+}
+#enddecl(NO_SPLIT)
+
+#decl(SPLIT)
+@group(0) @binding(1)
+var<storage, read_write> src1: array<{{TYPE}}>;
+
+@group(0) @binding(2)
+var<storage, read_write> dst: array<{{TYPE}}>;
+
+@group(0) @binding(3)
+var<uniform> params: Params;
+
+fn a_value(base: u32) -> {{TYPE}} {
+ return src0[base];
+}
+
+fn b_value(base: u32) -> {{TYPE}} {
+ return src1[base];
+}
+#enddecl(SPLIT)
+
+#end(DECLS)
+
+#define(SHADER)
+
+enable f16;
+
+struct Params {
+ offset_src0: u32,
+ offset_src1: u32,
+ offset_dst: u32,
+
+ // Strides (in elements)
+ stride_src01: u32,
+ stride_src02: u32,
+ stride_src03: u32,
+
+ stride_src11: u32,
+ stride_src12: u32,
+ stride_src13: u32,
+
+ stride_dst1: u32,
+ stride_dst2: u32,
+ stride_dst3: u32,
+
+ // shape of dst
+ ne: u32,
+ ne0: u32,
+ ne1: u32,
+ ne2: u32,
+
+ swapped: u32,
+ alpha: f32,
+ limit: f32,
+}
+
+@group(0) @binding(0)
+var<storage, read_write> src0: array<{{TYPE}}>;
+
+DECLS
+
+override wg_size: u32;
+@compute @workgroup_size(wg_size)
+fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
+ if (gid.x >= params.ne) {
+ return;
+ }
+
+ var i = gid.x;
+ let i3 = i / (params.ne2 * params.ne1 * params.ne0);
+ i = i % (params.ne2 * params.ne1 * params.ne0);
+ let i2 = i / (params.ne1 * params.ne0);
+ i = i % (params.ne1 * params.ne0);
+ let i1 = i / params.ne0;
+ let i0 = i % params.ne0;
+
+ let i_a = params.offset_src0 + i3 * params.stride_src03 + i2 * params.stride_src02 + i1 * params.stride_src01 + i0;
+ let i_b = params.offset_src1 + i3 * params.stride_src13 + i2 * params.stride_src12 + i1 * params.stride_src11 + i0;
+ let i_dst = params.offset_dst + i3 * params.stride_dst3 + i2 * params.stride_dst2 + i1 * params.stride_dst1 + i0;
+
+ dst[i_dst] = op(a_value(i_a), b_value(i_b));
+}
+
+#end(SHADER)
diff --git a/llama.cpp/ggml/src/ggml-webgpu/wgsl-shaders/memset.wgsl b/llama.cpp/ggml/src/ggml-webgpu/wgsl-shaders/memset.wgsl
new file mode 100644
index 0000000..194d2d6
--- /dev/null
+++ b/llama.cpp/ggml/src/ggml-webgpu/wgsl-shaders/memset.wgsl
@@ -0,0 +1,40 @@
+@group(0) @binding(0)
+var<storage, read_write> output_buffer: array<u32>;
+
+struct Params {
+ offset: u32, // in bytes
+ size: u32, // in bytes
+    value: u32, // four 8-bit values: either one byte repeated (memset_tensor) or distinct bytes (patching up unaligned set_tensor writes)
+};
+
+@group(0) @binding(1)
+var<uniform> params: Params;
+
+override wg_size: u32;
+override bytes_per_thread: u32;
+
+@compute @workgroup_size(wg_size)
+fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
+ let i = gid.x * bytes_per_thread;
+ let start = params.offset;
+ let end = params.offset + params.size;
+
+ for (var j: u32 = 0u; j < bytes_per_thread; j += 4) {
+ let byte_index = start + i + j;
+ if (byte_index + 4 <= end) {
+ output_buffer[byte_index >> 2] = params.value;
+ } else {
+ // Handle tail (unaligned)
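+            // Read-modify-write one byte at a time so bytes at or past 'end' keep their previous contents.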
+ for (var k: u32 = 0; k < 4; k++) {
+ let idx = byte_index + k;
+ if (idx < end) {
+ let word_idx = idx >> 2;
+ let bit_offset = (idx & 3) * 8u;
+ let mask = ~(0xffu << bit_offset);
+ let existing = output_buffer[word_idx];
+ output_buffer[word_idx] = (existing & mask) | (params.value & (0xffu << bit_offset));
+ }
+ }
+ }
+ }
+}
diff --git a/llama.cpp/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat.tmpl.wgsl b/llama.cpp/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat.tmpl.wgsl
new file mode 100644
index 0000000..0f8e6e5
--- /dev/null
+++ b/llama.cpp/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat.tmpl.wgsl
@@ -0,0 +1,907 @@
+#define(VARIANTS)
+
+[
+ {
+ "REPLS": {
+ "SRC0_TYPE" : "f32",
+ "SRC1_TYPE" : "f32",
+ "BLOCK_SIZE" : 1
+ },
+ "DECLS" : ["FLOAT"]
+ },
+ {
+ "REPLS": {
+ "SRC0_TYPE" : "f16",
+ "SRC1_TYPE" : "f16",
+ "BLOCK_SIZE" : 1
+ },
+ "DECLS" : ["FLOAT"]
+ },
+ {
+ "REPLS": {
+ "SRC0_TYPE" : "f16",
+ "SRC1_TYPE" : "f32",
+ "BLOCK_SIZE" : 1
+ },
+ "DECLS" : ["FLOAT"]
+ },
+ {
+ "REPLS": {
+ "SRC0_TYPE": "q4_0",
+ "SRC1_TYPE": "f32",
+ "BLOCK_SIZE": 32
+ },
+ "DECLS": ["BYTE_HELPERS", "Q4_0_T", "Q4_0"]
+ },
+ {
+ "REPLS": {
+ "SRC0_TYPE": "q4_1",
+ "SRC1_TYPE": "f32",
+ "BLOCK_SIZE": 32
+ },
+ "DECLS": ["BYTE_HELPERS", "Q4_1_T", "Q4_1"]
+ },
+ {
+ "REPLS": {
+ "SRC0_TYPE": "q5_0",
+ "SRC1_TYPE": "f32",
+ "BLOCK_SIZE": 32
+ },
+ "DECLS": ["BYTE_HELPERS", "Q5_0_T", "Q5_0"]
+ },
+ {
+ "REPLS": {
+ "SRC0_TYPE": "q5_1",
+ "SRC1_TYPE": "f32",
+ "BLOCK_SIZE": 32
+ },
+ "DECLS": ["BYTE_HELPERS", "Q5_1_T", "Q5_1"]
+ },
+ {
+ "REPLS": {
+ "SRC0_TYPE": "q8_0",
+ "SRC1_TYPE": "f32",
+ "BLOCK_SIZE": 32
+ },
+ "DECLS": ["BYTE_HELPERS", "Q8_0_T", "Q8_0"]
+ },
+ {
+ "REPLS": {
+ "SRC0_TYPE": "q2_k",
+ "SRC1_TYPE": "f32",
+ "BLOCK_SIZE": 256
+ },
+ "DECLS": ["BYTE_HELPERS", "Q2_K_T", "Q2_K"]
+ },
+ {
+ "REPLS": {
+ "SRC0_TYPE": "q3_k",
+ "SRC1_TYPE": "f32",
+ "BLOCK_SIZE": 256
+ },
+ "DECLS": ["BYTE_HELPERS", "Q3_K_T", "Q3_K"]
+ },
+ {
+ "REPLS": {
+ "SRC0_TYPE": "q4_k",
+ "SRC1_TYPE": "f32",
+ "BLOCK_SIZE": 256
+ },
+ "DECLS": ["Q45_K_SCALE_MIN", "BYTE_HELPERS", "Q4_K_T", "Q4_K"]
+ },
+ {
+ "REPLS": {
+ "SRC0_TYPE": "q5_k",
+ "SRC1_TYPE": "f32",
+ "BLOCK_SIZE": 256
+ },
+ "DECLS": ["Q45_K_SCALE_MIN", "BYTE_HELPERS", "Q5_K_T", "Q5_K"]
+ },
+ {
+ "REPLS": {
+ "SRC0_TYPE": "q6_k",
+ "SRC1_TYPE": "f32",
+ "BLOCK_SIZE": 256
+ },
+ "DECLS": ["BYTE_HELPERS", "Q6_K_T", "Q6_K"]
+ },
+ {
+ "REPLS": {
+ "SRC0_TYPE": "iq2_xxs",
+ "SRC1_TYPE": "f32",
+ "BLOCK_SIZE": 256
+ },
+ "DECLS": ["BYTE_HELPERS", "IQ23_TABLES", "IQ2_XXS_GRID", "IQ2_XXS_T", "IQ2_XXS"]
+ },
+ {
+ "REPLS": {
+ "SRC0_TYPE": "iq2_xs",
+ "SRC1_TYPE": "f32",
+ "BLOCK_SIZE": 256
+ },
+ "DECLS": ["BYTE_HELPERS", "IQ23_TABLES", "IQ2_XS_GRID", "IQ2_XS_T", "IQ2_XS"]
+ },
+ {
+ "REPLS": {
+ "SRC0_TYPE": "iq2_s",
+ "SRC1_TYPE": "f32",
+ "BLOCK_SIZE": 256
+ },
+ "DECLS": ["BYTE_HELPERS", "IQ23_TABLES", "IQ2_S_GRID", "IQ2_S_T", "IQ2_S"]
+ },
+ {
+ "REPLS": {
+ "SRC0_TYPE": "iq3_xxs",
+ "SRC1_TYPE": "f32",
+ "BLOCK_SIZE": 256
+ },
+ "DECLS": ["BYTE_HELPERS", "IQ23_TABLES", "IQ3_XSS_GRID", "IQ3_XSS_T", "IQ3_XSS"]
+ },
+ {
+ "REPLS": {
+ "SRC0_TYPE": "iq3_s",
+ "SRC1_TYPE": "f32",
+ "BLOCK_SIZE": 256
+ },
+ "DECLS": ["BYTE_HELPERS", "IQ23_TABLES", "IQ3_S_GRID", "IQ3_S_T", "IQ3_S"]
+ },
+ {
+ "REPLS": {
+ "SRC0_TYPE": "iq1_s",
+ "SRC1_TYPE": "f32",
+ "BLOCK_SIZE": 256
+ },
+ "DECLS": ["BYTE_HELPERS", "IQ1_GRID", "IQ1_S_T", "IQ1_S"]
+ },
+ {
+ "REPLS": {
+ "SRC0_TYPE": "iq1_m",
+ "SRC1_TYPE": "f32",
+ "BLOCK_SIZE": 256
+ },
+ "DECLS": ["BYTE_HELPERS", "IQ1_GRID", "IQ1_M_T", "IQ1_M"]
+ },
+ {
+ "REPLS": {
+ "SRC0_TYPE": "iq4_nl",
+ "SRC1_TYPE": "f32",
+ "BLOCK_SIZE": 32,
+ },
+ "DECLS": ["BYTE_HELPERS", "IQ4_GRID", "IQ4_NL_T", "IQ4_NL"]
+ },
+ {
+ "REPLS": {
+ "SRC0_TYPE": "iq4_xs",
+ "SRC1_TYPE": "f32",
+ "BLOCK_SIZE": 256,
+ },
+ "DECLS": ["BYTE_HELPERS", "IQ4_GRID", "IQ4_XS_T", "IQ4_XS"]
+ }
+]
+
+#end(VARIANTS)
+
+#define(DECLS)
+
+#decl(FLOAT)
+fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
+ return f32(src0[src0_idx_base + offset]) * f32(src1[src1_idx_base + offset]);
+}
+#enddecl(FLOAT)
+
+#decl(Q4_0)
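+// q4_0 layout: one f16 scale d followed by 16 bytes of 4-bit weights (32 elements).
+// Byte b packs element b in its low nibble and element b + 16 in its high nibble;
+// each element dequantizes as d * (q - 8).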
+fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
+ let block_q4_0 = src0[src0_idx_base + offset];
+ let d = f32(block_q4_0.d);
+ var sum: f32 = 0.0;
+ for (var j: u32 = 0; j < 4; j++) {
+ let q_packed = bitcast<u32>(vec2(block_q4_0.qs[2 * j], block_q4_0.qs[2 * j + 1]));
+ for (var k: u32 = 0; k < 4; k++) {
+ let q_byte = get_byte(q_packed, k);
+ let q_hi = (f32((q_byte >> 4) & 0xF) - 8.0f) * d;
+ let q_lo = (f32(q_byte & 0xF) - 8.0f) * d;
+ let src1_offset = src1_idx_base + offset * 32 + j * 4 + k;
+ sum += q_lo * f32(src1[src1_offset]);
+ sum += q_hi * f32(src1[src1_offset + 16]);
+ }
+ }
+ return sum;
+}
+#enddecl(Q4_0)
+
+#decl(Q4_1)
+fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
+ let block_q4_1 = src0[src0_idx_base + offset];
+ let d = f32(block_q4_1.d);
+ let m = f32(block_q4_1.m);
+ var sum: f32 = 0.0;
+ for (var j: u32 = 0; j < 4; j++) {
+ let q_packed = block_q4_1.qs[j];
+ for (var k: u32 = 0; k < 4; k++) {
+ let q_byte = get_byte(q_packed, k);
+ let q_hi = f32((q_byte >> 4) & 0xF) * d + m;
+ let q_lo = f32(q_byte & 0xF) * d + m;
+ let src1_offset = src1_idx_base + offset * 32 + j * 4 + k;
+ sum += q_lo * f32(src1[src1_offset]);
+ sum += q_hi * f32(src1[src1_offset + 16]);
+ }
+ }
+ return sum;
+}
+#enddecl(Q4_1)
+
+#decl(Q5_0)
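+// q5_0 layout: f16 scale d, 4 bytes qh holding the fifth bit of each of the 32
+// elements, and 16 bytes of 4-bit weights; qh bit i supplies bit 4 of element i,
+// and each element dequantizes as d * (q - 16).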
+fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
+ let block_q5_0 = src0[src0_idx_base + offset];
+ let d = f32(block_q5_0.d);
+ var sum: f32 = 0.0;
+ let qh_packed = bitcast<u32>(vec2(block_q5_0.qh[0], block_q5_0.qh[1]));
+ for (var j: u32 = 0; j < 4; j++) {
+ let q_packed = bitcast<u32>(vec2(block_q5_0.qs[2 * j], block_q5_0.qs[2 * j + 1]));
+ for (var k: u32 = 0; k < 4; k++) {
+ let q_byte = get_byte(q_packed, k);
+ let qh_hi = (qh_packed >> (j * 4 + k + 12)) & 0x10;
+ let q_hi = (f32(((q_byte >> 4) & 0xF) | qh_hi) - 16.0) * d;
+ let qh_lo = ((qh_packed >> (j * 4 + k)) << 4) & 0x10;
+ let q_lo = (f32((q_byte & 0xF) | qh_lo) - 16.0) * d;
+ let src1_offset = src1_idx_base + offset * 32 + j * 4 + k;
+ sum += q_lo * f32(src1[src1_offset]);
+ sum += q_hi * f32(src1[src1_offset + 16]);
+ }
+ }
+ return sum;
+}
+#enddecl(Q5_0)
+
+#decl(Q5_1)
+fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
+ let block_q5_1 = src0[src0_idx_base + offset];
+ let d = f32(block_q5_1.d);
+ let m = f32(block_q5_1.m);
+ var sum: f32 = 0.0;
+ for (var j: u32 = 0; j < 4; j++) {
+ let q_packed = block_q5_1.qs[j];
+ for (var k: u32 = 0; k < 4; k++) {
+ let q_byte = get_byte(q_packed, k);
+ let qh_hi = (block_q5_1.qh >> (j * 4 + k + 12)) & 0x10;
+ let q_hi = f32(((q_byte >> 4) & 0xF) | qh_hi) * d + m;
+ let qh_lo = ((block_q5_1.qh >> (j * 4 + k)) << 4) & 0x10;
+ let q_lo = f32((q_byte & 0xF) | qh_lo) * d + m;
+ let src1_offset = src1_idx_base + offset * 32 + j * 4 + k;
+ sum += q_lo * f32(src1[src1_offset]);
+ sum += q_hi * f32(src1[src1_offset + 16]);
+ }
+ }
+ return sum;
+}
+#enddecl(Q5_1)
+
+#decl(Q8_0)
+fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
+ let block_q8_0 = src0[src0_idx_base + offset];
+ let d = f32(block_q8_0.d);
+ var sum: f32 = 0.0;
+ for (var j: u32 = 0; j < 8; j++) {
+ let q_packed = bitcast<u32>(vec2(block_q8_0.qs[2 * j], block_q8_0.qs[2 * j + 1]));
+ for (var k: u32 = 0; k < 4; k++) {
+ let q_byte = get_byte_i32(q_packed, k);
+ let q_val = f32(q_byte) * d;
+ let src1_offset = src1_idx_base + offset * 32 + j * 4 + k;
+ sum += q_val * f32(src1[src1_offset]);
+ }
+ }
+ return sum;
+}
+#enddecl(Q8_0)
+
+#decl(Q8_1)
+fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
+ let block_q8_1 = src0[src0_idx_base + offset];
+ let d = f32(block_q8_1.d);
+ let m = f32(block_q8_1.m);
+ var sum: f32 = 0.0;
+ for (var j: u32 = 0; j < 8; j++) {
+ let q_packed = block_q8_1.qs[j];
+ for (var k: u32 = 0; k < 4; k++) {
+ let q_byte = get_byte_i32(q_packed, k);
+ let q_val = f32(q_byte) * d + m;
+ let src1_offset = src1_idx_base + offset * 32 + j * 4 + k;
+ sum += q_val * f32(src1[src1_offset]);
+ }
+ }
+ return sum;
+}
+#enddecl(Q8_1)
+
+#decl(Q2_K)
+// 16 blocks of 16 elements each
+fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
+ let block = src0[src0_idx_base + offset];
+ let d = f32(block.d);
+ let m = f32(block.dmin);
+ var sum = 0.0;
+ var src1_i = src1_idx_base + offset * 256;
+ var is: u32 = 0;
+ // 2 halves of the block (128 elements each)
+ for (var q_b_idx: u32 = 0; q_b_idx < 64; q_b_idx += 32) {
+ // 4 groups (each group has 2 blocks of 16 elements)
+ for (var shift: u32 = 0; shift < 8; shift += 2) {
+ // 2 blocks
+ for (var k: u32 = 0; k < 32; k += 16) {
+ let sc = get_byte(block.scales[is / 4], is % 4);
+ is++;
+ let dl = d * f32(sc & 0xF);
+ let ml = m * f32(sc >> 4);
+ for (var l: u32 = 0u; l < 16; l++) {
+ let q_idx = q_b_idx + k + l;
+ let q_byte = get_byte(block.qs[q_idx / 4], q_idx % 4);
+ let qs_val = (q_byte >> shift) & 3;
+ sum += (f32(qs_val) * dl - ml) * src1[src1_i];
+ src1_i++;
+ }
+ }
+ }
+ }
+ return sum;
+}
+
+#enddecl(Q2_K)
+
+#decl(Q3_K)
+// 16 blocks of 16 elements each
+fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
+ let block = src0[src0_idx_base + offset];
+ let d = f32(block.d);
+
+    // extract the 6-bit scales: the low 4 bits of each scale come from the first
+    // 8 bytes of block.scales, and the high 2 bits from the last 4 bytes
+ let kmask1: u32 = 0x03030303;
+ let kmask2: u32 = 0x0f0f0f0f;
+ var scale_vals: array<u32, 4>;
+ for (var i: u32 = 0; i < 4; i++) {
+ scale_vals[i] = bitcast<u32>(vec2(block.scales[2 * i], block.scales[2 * i + 1]));
+ }
+ var tmp: u32 = scale_vals[2];
+ scale_vals[2] = ((scale_vals[0] >> 4) & kmask2) | (((tmp >> 4) & kmask1) << 4);
+ scale_vals[3] = ((scale_vals[1] >> 4) & kmask2) | (((tmp >> 6) & kmask1) << 4);
+ scale_vals[0] = (scale_vals[0] & kmask2) | ((tmp & kmask1) << 4);
+ scale_vals[1] = (scale_vals[1] & kmask2) | (((tmp >> 2) & kmask1) << 4);
+
+ // convert arrays of f16 -> u32
+ var hmask_vals: array<u32, 8>;
+ for (var i: u32 = 0; i < 8; i++) {
+ hmask_vals[i] = bitcast<u32>(vec2(block.hmask[2 * i], block.hmask[2 * i + 1]));
+ }
+ var qs_vals: array<u32, 16>;
+ for (var i: u32 = 0; i < 16; i++) {
+ qs_vals[i] = bitcast<u32>(vec2(block.qs[2 * i], block.qs[2 * i + 1]));
+ }
+
+ var sum = 0.0;
+ var src1_i = src1_idx_base + offset * 256;
+ var is: u32 = 0;
+ var m: u32 = 1;
+ // 2 halves of the block (128 elements each)
+ for (var q_b_idx: u32 = 0; q_b_idx < 64; q_b_idx += 32) {
+ // 4 groups (each group has 2 blocks of 16 elements)
+ for (var shift: u32 = 0; shift < 8; shift += 2) {
+ // 2 blocks
+ for (var k: u32 = 0; k < 32; k += 16) {
+ let sc = get_byte(scale_vals[is / 4], is % 4);
+ is++;
+ let dl = d * (f32(sc) - 32.0);
+ for (var l: u32 = 0u; l < 16u; l++) {
+ let q_idx = q_b_idx + k + l;
+ let hm_idx = k + l;
+ let q_byte = get_byte(qs_vals[q_idx / 4], q_idx % 4);
+ let hmask_byte = get_byte(hmask_vals[hm_idx / 4], hm_idx % 4);
+ let hm = select(4.0, 0.0, (hmask_byte & m) != 0);
+ let qs_val = (q_byte >> shift) & 3;
+ sum += ((f32(qs_val) - hm) * dl) * src1[src1_i];
+ src1_i++;
+ }
+ }
+ m <<= 1;
+ }
+ }
+ return sum;
+}
+
+#enddecl(Q3_K)
+
+#decl(Q4_K)
+// 8 blocks of 32 elements each
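+// scales and mins are packed as 6-bit values in block.scales; get_scale_min
+// (from the Q45_K_SCALE_MIN decl) unpacks the (scale, min) pair for sub-block is.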
+fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
+ let block = src0[src0_idx_base + offset];
+ let d = f32(block.d);
+ let m = f32(block.dmin);
+ var sum = 0.0;
+ var src1_i = src1_idx_base + offset * 256;
+ var is: u32 = 0;
+ // 2 blocks each iteration
+ for (var q_b_idx: u32 = 0; q_b_idx < 128; q_b_idx += 32) {
+ for (var shift: u32 = 0; shift < 8; shift += 4) {
+ let scale_min = get_scale_min(is, block.scales);
+ is++;
+ let dl = d * scale_min.x;
+ let ml = m * scale_min.y;
+ for (var l: u32 = 0; l < 32; l++) {
+ let q_idx = q_b_idx + l;
+ let q_byte = get_byte(block.qs[q_idx / 4], q_idx % 4);
+ let qs_val = (q_byte >> shift) & 0xF;
+ sum += (f32(qs_val) * dl - ml) * src1[src1_i];
+ src1_i++;
+ }
+ }
+ }
+ return sum;
+}
+
+#enddecl(Q4_K)
+
+#decl(Q5_K)
+// 8 blocks of 32 elements each
+fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
+ let block = src0[src0_idx_base + offset];
+ let d = f32(block.d);
+ let m = f32(block.dmin);
+ var sum = 0.0;
+ var src1_i = src1_idx_base + offset * 256;
+ var is: u32 = 0;
+ var u: u32 = 1;
+ // 2 blocks each iteration
+ for (var q_b_idx: u32 = 0; q_b_idx < 128; q_b_idx += 32) {
+ for (var shift: u32 = 0; shift < 8; shift += 4) {
+ let scale_min = get_scale_min(is, block.scales);
+ is++;
+ let dl = d * scale_min.x;
+ let ml = m * scale_min.y;
+ for (var l: u32 = 0; l < 32; l++) {
+ let q_idx = q_b_idx + l;
+ let q_byte = get_byte(block.qs[q_idx / 4], q_idx % 4);
+ let qh_byte = get_byte(block.qh[l / 4], l % 4);
+ let qs_val = (q_byte >> shift) & 0xF;
+ let qh_val = select(0.0, 16.0, (qh_byte & u) != 0);
+ sum += ((f32(qs_val) + qh_val) * dl - ml) * src1[src1_i];
+ src1_i++;
+ }
+ u <<= 1;
+ }
+ }
+ return sum;
+}
+
+#enddecl(Q5_K)
+
+#decl(Q6_K)
+// 16 blocks of 16 elements each
+fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
+ let block = src0[src0_idx_base + offset];
+ let d = f32(block.d);
+
+ // convert arrays of f16 -> u32
+ var ql_vals: array<u32, 32>;
+ for (var i: u32 = 0; i < 32; i++) {
+ ql_vals[i] = bitcast<u32>(vec2(block.ql[2 * i], block.ql[2 * i + 1]));
+ }
+ var qh_vals: array<u32, 16>;
+ for (var i: u32 = 0; i < 16; i++) {
+ qh_vals[i] = bitcast<u32>(vec2(block.qh[2 * i], block.qh[2 * i + 1]));
+ }
+ var scale_vals: array<u32, 4>;
+ for (var i: u32 = 0; i < 4; i++) {
+ scale_vals[i] = bitcast<u32>(vec2(block.scales[2 * i], block.scales[2 * i + 1]));
+ }
+
+ var sum = 0.0;
+ var src1_i = src1_idx_base + offset * 256;
+ var qh_b_idx: u32 = 0;
+ var sc_b_idx: u32 = 0;
+ for (var ql_b_idx: u32 = 0; ql_b_idx < 128; ql_b_idx += 64) {
+ for (var l: u32 = 0; l < 32; l++) {
+ let ql13_b = get_byte(ql_vals[(ql_b_idx + l) / 4], (ql_b_idx + l) % 4);
+ let ql24_b = get_byte(ql_vals[(ql_b_idx + l + 32) / 4], (ql_b_idx + l + 32) % 4);
+ let qh_b = get_byte(qh_vals[(qh_b_idx + l) / 4], (qh_b_idx + l) % 4);
+
+ let q1 = f32((ql13_b & 0xF) | ((qh_b & 3) << 4)) - 32.0;
+ let q2 = f32((ql24_b & 0xF) | (((qh_b >> 2) & 3) << 4)) - 32.0;
+ let q3 = f32((ql13_b >> 4) | (((qh_b >> 4) & 3) << 4)) - 32.0;
+ let q4 = f32((ql24_b >> 4) | (((qh_b >> 6) & 3) << 4)) - 32.0;
+
+ let is = l/16;
+ let is1 = sc_b_idx + is;
+ let sc1 = get_byte_i32(scale_vals[is1 / 4], is1 % 4);
+ let is2 = sc_b_idx + is + 2;
+ let sc2 = get_byte_i32(scale_vals[is2 / 4], is2 % 4);
+ let is3 = sc_b_idx + is + 4;
+ let sc3 = get_byte_i32(scale_vals[is3 / 4], is3 % 4);
+ let is4 = sc_b_idx + is + 6;
+ let sc4 = get_byte_i32(scale_vals[is4 / 4], is4 % 4);
+
+ sum += d * f32(sc1) * q1 * src1[src1_i + l];
+ sum += d * f32(sc2) * q2 * src1[src1_i + l + 32];
+ sum += d * f32(sc3) * q3 * src1[src1_i + l + 64];
+ sum += d * f32(sc4) * q4 * src1[src1_i + l + 96];
+ }
+ src1_i += 128;
+ qh_b_idx += 32;
+ sc_b_idx += 8;
+ }
+ return sum;
+}
+
+#enddecl(Q6_K)
+
+#decl(IQ2_XXS)
+fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
+ let block = src0[src0_idx_base + offset];
+ let d = f32(block.d);
+ var src1_i = src1_idx_base + offset * 256;
+ var sum = 0.0;
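+    // each group of 8 qs bytes covers 32 elements: the four bytes of aux0 index
+    // iq2xxs_grid (8 weights each), while aux1 packs four 7-bit sign indices in its
+    // low 28 bits and a 4-bit scale in its top 4 bits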
+ for (var ib: u32 = 0; ib < 32; ib += 4) {
+ let aux0 = bitcast<u32>(vec2(block.qs[ib], block.qs[ib + 1]));
+ let aux1 = bitcast<u32>(vec2(block.qs[ib + 2], block.qs[ib + 3]));
+ let db = d * (0.5 + f32(aux1 >> 28)) * 0.25;
+ for (var l: u32 = 0; l < 4; l++) {
+ let ig = get_byte(aux0, l) * 8;
+ let is = (aux1 >> (7 * l)) & 127;
+ let signs = get_byte(ksigns_iq2xs[is / 4], is % 4);
+ for (var j: u32 = 0; j < 8; j++) {
+ let g = get_byte(iq2xxs_grid[(ig + j) / 4], (ig + j) % 4);
+ let m = select(1.0, -1.0, (get_byte(kmask_iq2xs[j / 4], j % 4) & signs) != 0);
+ sum += db * f32(g) * m * src1[src1_i];
+ src1_i++;
+ }
+ }
+ }
+ return sum;
+}
+
+#enddecl(IQ2_XXS)
+
+#decl(IQ2_XS)
+fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
+ let block = src0[src0_idx_base + offset];
+ let d = f32(block.d);
+ var src1_i = src1_idx_base + offset * 256;
+ var scale_vals = array<u32, 2>(
+ bitcast<u32>(vec2(block.scales[0], block.scales[1])),
+ bitcast<u32>(vec2(block.scales[2], block.scales[3]))
+ );
+ var sum = 0.0;
+ for (var ib: u32 = 0; ib < 32; ib += 4) {
+ let s = get_byte(scale_vals[ib / 16], (ib % 16) / 4);
+ let db = array<f32, 2>(
+ d * (0.5 + f32(s & 0xF)) * 0.25,
+ d * (0.5 + f32(s >> 4)) * 0.25
+ );
+ for (var l: u32 = 0; l < 4; l++) {
+ let qs_val = bitcast<u32>(vec2(block.qs[ib + l], 0.0));
+ let ig = (qs_val & 511) * 8;
+ let is = qs_val >> 9;
+ let signs = get_byte(ksigns_iq2xs[is / 4], is % 4);
+ let dl = db[l/2];
+ for (var j: u32 = 0; j < 8; j++) {
+ let g = get_byte(iq2xs_grid[(ig + j) / 4], (ig + j) % 4);
+ let m = select(1.0, -1.0, (get_byte(kmask_iq2xs[j / 4], j % 4) & signs) != 0);
+ sum += dl * f32(g) * m * src1[src1_i];
+ src1_i++;
+ }
+ }
+ }
+ return sum;
+}
+
+#enddecl(IQ2_XS)
+
+#decl(IQ2_S)
+fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
+ let block = src0[src0_idx_base + offset];
+ let d = f32(block.d);
+ var src1_i = src1_idx_base + offset * 256;
+ var qs_vals : array<u32, 16>;
+ for (var i: u32 = 0; i < 16; i++) {
+ qs_vals[i] = bitcast<u32>(vec2(block.qs[i * 2], block.qs[i * 2 + 1]));
+ }
+ var qh_vals = array<u32, 2>(
+ bitcast<u32>(vec2(block.qh[0], block.qh[1])),
+ bitcast<u32>(vec2(block.qh[2], block.qh[3]))
+ );
+ var scale_vals = array<u32, 2>(
+ bitcast<u32>(vec2(block.scales[0], block.scales[1])),
+ bitcast<u32>(vec2(block.scales[2], block.scales[3]))
+ );
+ var sum = 0.0;
+    for (var ib: u32 = 0; ib < 8; ib++) {
+ let s = get_byte(scale_vals[ib / 4], ib % 4);
+ let db = array<f32, 2>(
+ d * (0.5 + f32(s & 0xF)) * 0.25,
+ d * (0.5 + f32(s >> 4)) * 0.25
+ );
+ let qs_w = qs_vals[ib];
+ for (var l: u32 = 0; l < 4; l++) {
+ let qh_b = (get_byte(qh_vals[ib / 4], ib % 4) << (8 - 2 * l)) & 0x300;
+ let ig = (get_byte(qs_w, l) | qh_b) * 8;
+ let signs = get_byte(qs_vals[ib + 8], l);
+ let dl = db[l/2];
+ for (var j: u32 = 0; j < 8; j++) {
+ let g = get_byte(iq2s_grid[(ig + j) / 4], (ig + j) % 4);
+ let m = select(1.0, -1.0, (get_byte(kmask_iq2xs[j / 4], j % 4) & signs) != 0);
+ sum += dl * f32(g) * m * src1[src1_i];
+ src1_i++;
+ }
+ }
+ }
+ return sum;
+}
+
+#enddecl(IQ2_S)
+
+#decl(IQ3_XSS)
+fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
+ let block = src0[src0_idx_base + offset];
+ let d = f32(block.d);
+ var src1_i = src1_idx_base + offset * 256;
+ var sum = 0.0;
+ for (var ib: u32 = 0; ib < 16; ib += 2) {
+ let sc_sign = bitcast<u32>(vec2(block.qs[ib + 32], block.qs[ib + 33]));
+ let db = d * (0.5 + f32(sc_sign >> 28)) * 0.5;
+ for (var l: u32 = 0; l < 4; l++) {
+ let is = (sc_sign >> (7 * l)) & 127;
+ let signs = get_byte(ksigns_iq2xs[is / 4], is % 4);
+ let ig_val = bitcast<u32>(vec2(block.qs[ib * 2 + l], 0.0));
+ let ig1 = get_byte(ig_val, 0);
+ let ig2 = get_byte(ig_val, 1);
+ for (var j: u32 = 0; j < 4; j++) {
+ let g1 = get_byte(iq3xxs_grid[ig1], j);
+ let g2 = get_byte(iq3xxs_grid[ig2], j);
+ let m1 = select(1.0, -1.0, (get_byte(kmask_iq2xs[0], j) & signs) != 0);
+ let m2 = select(1.0, -1.0, (get_byte(kmask_iq2xs[1], j) & signs) != 0);
+ sum += db * f32(g1) * m1 * src1[src1_i];
+ sum += db * f32(g2) * m2 * src1[src1_i + 4];
+ src1_i++;
+ }
+ src1_i += 4;
+ }
+ }
+ return sum;
+}
+
+#enddecl(IQ3_XSS)
+
+#decl(IQ3_S)
+fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
+ let block = src0[src0_idx_base + offset];
+ let d = f32(block.d);
+ var src1_i = src1_idx_base + offset * 256;
+ var qh_vals = array<u32, 2>(
+ bitcast<u32>(vec2(block.qh[0], block.qh[1])),
+ bitcast<u32>(vec2(block.qh[2], block.qh[3]))
+ );
+ var sign_vals: array<u32, 8>;
+ for (var i: u32 = 0; i < 8; i++) {
+ sign_vals[i] = bitcast<u32>(vec2(block.signs[i * 2], block.signs[i * 2 + 1]));
+ }
+ var scale_vals = bitcast<u32>(vec2(block.scales[0], block.scales[1]));
+ var sum = 0.0;
+ for (var ib: u32 = 0; ib < 4; ib++) {
+ let s = get_byte(scale_vals, ib);
+ let db = array<f32, 2>(
+ d * (1.0 + 2.0 * f32(s & 0xF)),
+ d * (1.0 + 2.0 * f32(s >> 4))
+ );
+ for (var k: u32 = 0; k < 2; k++) {
+ let dl = db[k];
+ let qh_byte = get_byte(qh_vals[ib / 2], (ib % 2) * 2 + k);
+ let sign_w = sign_vals[ib * 2 + k];
+ for (var l: u32 = 0; l < 4; l++) {
+ let signs = get_byte(sign_w, l);
+ let ig_val = bitcast<u32>(vec2(block.qs[ib * 8 + k * 4 + l], 0.0));
+ let ig1 = get_byte(ig_val, 0) | ((qh_byte << ((8 - (2 * l)))) & 256);
+ let ig2 = get_byte(ig_val, 1) | ((qh_byte << ((7 - (2 * l)))) & 256);
+ for (var j: u32 = 0; j < 4; j++) {
+ let g1 = get_byte(iq3s_grid[ig1], j);
+ let g2 = get_byte(iq3s_grid[ig2], j);
+ let m1 = select(1.0, -1.0, (get_byte(kmask_iq2xs[0], j) & signs) != 0);
+ let m2 = select(1.0, -1.0, (get_byte(kmask_iq2xs[1], j) & signs) != 0);
+ sum += dl * f32(g1) * m1 * src1[src1_i];
+ sum += dl * f32(g2) * m2 * src1[src1_i + 4];
+ src1_i++;
+ }
+ src1_i += 4;
+ }
+ }
+ }
+ return sum;
+}
+#enddecl(IQ3_S)
+
+#decl(IQ1_S)
+fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
+ let block = src0[src0_idx_base + offset];
+ let d = f32(block.d);
+ var src1_i = src1_idx_base + offset * 256;
+ var sum = 0.0;
+ for (var ib: u32 = 0; ib < 8; ib++) {
+ let qh = bitcast<u32>(vec2(block.qh[ib], 0.0));
+ let dl = d * (2 * f32((qh >> 12) & 7) + 1);
+ let delta = select(IQ1_DELTA, -IQ1_DELTA, (qh & 0x8000) != 0);
+ let qs_w = bitcast<u32>(vec2(block.qs[ib * 2], block.qs[ib * 2 + 1]));
+ for (var l: u32 = 0; l < 4; l++) {
+ let ig = (get_byte(qs_w, l) | (((qh >> (3 * l)) & 7) << 8)) * 8;
+ for (var j: u32 = 0; j < 8; j++) {
+ let gw = iq1_grid[(ig + j) / 16];
+ let g = (gw >> (((ig + j) % 16) * 2)) & 3;
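+                // sign-extend the 2-bit grid value: shift it into the top two bits,
+                // then arithmetic-shift back down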
+ let gs = bitcast<i32>(g << 30) >> 30;
+ sum += dl * (f32(gs) + delta) * src1[src1_i];
+ src1_i++;
+ }
+ }
+ }
+ return sum;
+}
+
+#enddecl(IQ1_S)
+
+#decl(IQ1_M)
+fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
+ let block = src0[src0_idx_base + offset];
+
+ let scale = ((block.scales[0] >> 12) & 0xF) | ((block.scales[0] >> 24) & 0x00F0) | ((block.scales[1] >> 4) & 0x0F00) | ((block.scales[1] >> 16) & 0xF000);
+ let d = f32(bitcast<vec2<f16>>(scale).x);
+ var src1_i = src1_idx_base + offset * 256;
+ var sum = 0.0;
+ for (var ib: u32 = 0; ib < 8; ib++) {
+ let sw = (block.scales[ib / 4] >> (16 * ((ib / 2) % 2))) & 0xFFFF;
+ let s1 : u32 = (sw >> (6 * (ib % 2))) & 0x7;
+ let s2 : u32 = (sw >> (6 * (ib % 2) + 3)) & 0x7;
+ var dl = array<f32, 2>(
+ d * f32(2 * s1 + 1),
+ d * f32(2 * s2 + 1)
+ );
+
+ let qh = block.qh[ib / 2] >> (16 * (ib % 2));
+ var idx = array<u32, 4>(
+ get_byte(block.qs[ib], 0) | ((qh << 8) & 0x700),
+ get_byte(block.qs[ib], 1) | ((qh << 4) & 0x700),
+ get_byte(block.qs[ib], 2) | ((qh) & 0x700),
+ get_byte(block.qs[ib], 3) | ((qh >> 4) & 0x700)
+ );
+ var delta = array<f32, 4>(
+ select(IQ1_DELTA, -IQ1_DELTA, (qh & 0x08) != 0),
+ select(IQ1_DELTA, -IQ1_DELTA, (qh & 0x80) != 0),
+ select(IQ1_DELTA, -IQ1_DELTA, ((qh >> 8) & 0x08) != 0),
+ select(IQ1_DELTA, -IQ1_DELTA, ((qh >> 8) & 0x80) != 0)
+ );
+ for (var l: u32 = 0; l < 4; l++) {
+ let ig = idx[l] * 8;
+ for (var j: u32 = 0; j < 8; j++) {
+ let gw = iq1_grid[(ig + j) / 16];
+ let g = (gw >> (((ig + j) % 16) * 2)) & 3;
+ let gs = bitcast<i32>(g << 30) >> 30;
+ sum += dl[l/2] * (f32(gs) + delta[l]) * src1[src1_i];
+ src1_i++;
+ }
+ }
+ }
+ return sum;
+}
+
+#enddecl(IQ1_M)
+
+#decl(IQ4_NL)
+fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
+ let block = src0[src0_idx_base + offset];
+ let d = f32(block.d);
+ var src1_i = src1_idx_base + offset * 32;
+ var sum = 0.0;
+ var qs: array<u32, 4>;
+ for (var i: u32 = 0; i < 4; i++) {
+ qs[i] = bitcast<u32>(vec2(block.qs[i * 2], block.qs[i * 2 + 1]));
+ }
+ for (var j: u32 = 0; j < 16; j++) {
+ let qsb = get_byte(qs[j / 4], j % 4);
+ sum += d * f32(kvalues_iq4nl[qsb & 0xF]) * src1[src1_i];
+ sum += d * f32(kvalues_iq4nl[qsb >> 4]) * src1[src1_i + 16];
+ src1_i++;
+ }
+ return sum;
+}
+
+#enddecl(IQ4_NL)
+
+#decl(IQ4_XS)
+fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
+ let block = src0[src0_idx_base + offset];
+ let d = f32(block.d);
+ let scales_h = bitcast<u32>(vec2(block.scales_h, 0.0));
+ var src1_i = src1_idx_base + offset * 256;
+ var sum = 0.0;
+ for (var ib: u32 = 0; ib < 8; ib++) {
+ let ls = ((get_byte(block.scales_l, ib / 2) >> (4 * (ib % 2))) & 0xF) | (((scales_h >> (2 * ib)) & 3) << 4);
+ let dl = d * (f32(ls) - 32.0);
+ for (var j: u32 = 0; j < 16; j++) {
+ let iqs = ib * 16 + j;
+ let qsb = get_byte(block.qs[iqs / 4], iqs % 4);
+ sum += dl * f32(kvalues_iq4nl[qsb & 0xF]) * src1[src1_i];
+ sum += dl * f32(kvalues_iq4nl[qsb >> 4]) * src1[src1_i + 16];
+ src1_i++;
+ }
+ src1_i += 16;
+ }
+ return sum;
+}
+
+#enddecl(IQ4_XS)
+
+#end(DECLS)
+
+#define(SHADER)
+
+enable f16;
+
+DECLS
+
+struct MulMatParams {
+ offset_src0: u32, // in elements/blocks
+ offset_src1: u32, // in elements/blocks
+ offset_dst: u32, // in elements/blocks
+ m: u32,
+ n: u32,
+ k: u32,
+ // all strides are in elements/blocks
+ stride_01: u32,
+ stride_11: u32,
+ stride_02: u32,
+ stride_12: u32,
+ stride_03: u32,
+ stride_13: u32,
+
+ bs02: u32,
+ bs03: u32,
+ broadcast2: u32,
+ broadcast3: u32
+};
+
+@group(0) @binding(0) var<storage, read_write> src0: array<{{SRC0_TYPE}}>; // M rows, K columns
+@group(0) @binding(1) var<storage, read_write> src1: array<{{SRC1_TYPE}}>; // K rows, N columns (transposed)
+@group(0) @binding(2) var<storage, read_write> dst: array<f32>; // M rows, N columns (transposed)
+
+@group(0) @binding(3) var<uniform> params: MulMatParams;
+
+@compute @workgroup_size(256)
+fn main(@builtin(global_invocation_id) global_id: vec3<u32>) {
+ let total = params.m * params.n * params.bs02 * params.broadcast2 * params.bs03 * params.broadcast3;
+ if (global_id.x >= total) {
+ return;
+ }
+
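+    // global_id.x enumerates every output element across all batches: peel off the
+    // outermost batch index (dst3), then the inner batch index (dst2), and finally
+    // the position within one m x n output matrix.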
+ let dst2_stride = params.m * params.n;
+ let dst3_stride = dst2_stride * params.bs02 * params.broadcast2;
+
+ let dst3_idx = global_id.x / dst3_stride;
+ let src03_idx = dst3_idx / params.broadcast3; // src0 may be broadcast along the third dimension
+ let src13_idx = dst3_idx; // src1 is not broadcast
+ let dst3_rem = global_id.x % dst3_stride;
+
+ let dst2_idx = dst3_rem / dst2_stride;
+ let src02_idx = dst2_idx / params.broadcast2; // src0 may also be broadcast along the second dimension
+ let src12_idx = dst2_idx; // src1 is not broadcast
+
+ let dst2_rem = dst3_rem % dst2_stride;
+
+ let row = dst2_rem / params.m; // output row
+ let col = dst2_rem % params.m; // output column
+
+ let src0_idx_base = params.offset_src0 + src03_idx * params.stride_03 + src02_idx * params.stride_02 + col * params.stride_01;
+ let src1_idx_base = params.offset_src1 + src13_idx * params.stride_13 + src12_idx * params.stride_12 + row * params.stride_11;
+
+ var sum = 0.0;
+ for (var i: u32 = 0u; i < params.k/{{BLOCK_SIZE}}; i = i + 1u) {
+ sum += multiply_add(src0_idx_base, src1_idx_base, i);
+ }
+ dst[params.offset_dst + dst3_idx * dst3_stride + dst2_idx * dst2_stride + row * params.m + col] = sum;
+}
+
+#end(SHADER)
diff --git a/llama.cpp/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_decls.tmpl b/llama.cpp/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_decls.tmpl
new file mode 100644
index 0000000..109ff8d
--- /dev/null
+++ b/llama.cpp/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_decls.tmpl
@@ -0,0 +1,97 @@
+#decl(SHMEM_VEC)
+fn store_shmem(val: vec4<f16>, idx: u32) {
+ shmem[idx] = val.x;
+ shmem[idx + 1] = val.y;
+ shmem[idx + 2] = val.z;
+ shmem[idx + 3] = val.w;
+}
+#enddecl(SHMEM_VEC)
+
+#decl(SHMEM_SCALAR)
+fn store_shmem(val: f16, idx: u32) {
+ shmem[idx] = val;
+}
+#enddecl(SHMEM_SCALAR)
+
+#decl(INIT_SRC0_SHMEM_FLOAT)
+
+fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u32) {
+ for (var elem_idx = thread_id * {{VEC_SIZE}}; elem_idx < TILE_SRC0_SHMEM; elem_idx += TOTAL_WORKGROUP_SIZE * {{VEC_SIZE}}) {
+ let tile_m = elem_idx / TILE_K;
+ let tile_k = elem_idx % TILE_K;
+ let global_m = offset_m + tile_m;
+ let global_k = k_outer + tile_k;
+ let src0_idx = batch_offset + global_m * params.stride_01 + global_k;
+ let src0_val = select( // taking a slight performance hit to avoid oob
+ {{SRC0_TYPE}}(0.0),
+ src0[src0_idx/{{VEC_SIZE}}],
+ global_m < params.m && global_k < params.k);
+ store_shmem({{SHMEM_TYPE}}(src0_val), elem_idx);
+ }
+}
+
+#enddecl(INIT_SRC0_SHMEM_FLOAT)
+
+#decl(INIT_SRC1_SHMEM)
+
+fn init_shmem_src1(thread_id: u32, batch_offset: u32, offset_n: u32, k_outer: u32) {
+ for (var elem_idx = thread_id * {{VEC_SIZE}}; elem_idx < TILE_SRC1_SHMEM; elem_idx += TOTAL_WORKGROUP_SIZE * {{VEC_SIZE}}) {
+ let tile_n = elem_idx / TILE_K;
+ let tile_k = elem_idx % TILE_K;
+ let global_n = offset_n + tile_n;
+ let global_k = k_outer + tile_k;
+ let src1_idx = batch_offset + global_n * params.stride_11 + global_k;
+ let src1_val = select(
+ {{SRC1_TYPE}}(0.0),
+ src1[src1_idx/{{VEC_SIZE}}],
+ global_n < params.n && global_k < params.k);
+ store_shmem({{SHMEM_TYPE}}(src1_val), TILE_SRC0_SHMEM + elem_idx);
+ }
+}
+
+#enddecl(INIT_SRC1_SHMEM)
+
+#decl(INIT_SRC0_SHMEM_Q4_0)
+
+const BLOCK_SIZE = 32u;
+// the number of blocks per k-tile. Note that this currently only works if TILE_K is a multiple of BLOCK_SIZE, which may need to be rethought for larger quantized types.
+override BLOCKS_K = TILE_K/BLOCK_SIZE;
+const NQ = 16u;
+const F16_PER_BLOCK = 9u; // 1 scale + 8x4 packed weights
+const WEIGHTS_PER_F16 = 4u; // 4 weights per f16
+const F16_PER_THREAD = NQ / WEIGHTS_PER_F16;
+
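+// Dequantizes q4_0 blocks directly into shared memory as f16, preserving the
+// in-block element order: low nibbles land at [0, 16) and high nibbles at
+// [16, 32) within each 32-element block.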
+fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u32) {
+ for (var i = thread_id * NQ; i < TILE_SRC0_SHMEM; i += TOTAL_WORKGROUP_SIZE * NQ) {
+ let blck_idx = i / BLOCK_SIZE;
+ let block_offset = (i % BLOCK_SIZE) / WEIGHTS_PER_F16;
+ let shmem_idx = blck_idx * BLOCK_SIZE + block_offset * 2u;
+
+ let tile_m = blck_idx / BLOCKS_K;
+ let global_m = offset_m + tile_m;
+ let block_k = blck_idx % BLOCKS_K;
+ let global_k = k_outer / BLOCK_SIZE + block_k;
+
+ if (global_m < params.m && global_k < params.k / BLOCK_SIZE) {
+ let src0_idx = batch_offset + global_m * params.stride_01 + global_k;
+ let scale_idx = src0_idx * F16_PER_BLOCK;
+ let d = src0[scale_idx];
+
+ for (var j = 0u; j < F16_PER_THREAD; j += 2) {
+ let q_0 = src0[scale_idx + 1u + block_offset + j];
+ let q_1 = src0[scale_idx + 1u + block_offset + j + 1];
+
+ let q_packed = bitcast<u32>(vec2(q_0, q_1));
+ for (var k = 0u; k < 4u; k++) {
+ let q_byte = get_byte(q_packed, k);
+ let q_hi = (f16((q_byte >> 4) & 0xF) - 8.0) * d;
+ let q_lo = (f16(q_byte & 0xF) - 8.0) * d;
+ shmem[shmem_idx + j * 2 + k] = q_lo;
+ shmem[shmem_idx + j * 2 + k + 16u] = q_hi;
+ }
+ }
+ }
+ }
+}
+
+#enddecl(INIT_SRC0_SHMEM_Q4_0)
diff --git a/llama.cpp/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_reg_tile.tmpl.wgsl b/llama.cpp/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_reg_tile.tmpl.wgsl
new file mode 100644
index 0000000..6b1dd26
--- /dev/null
+++ b/llama.cpp/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_reg_tile.tmpl.wgsl
@@ -0,0 +1,247 @@
+#define(VARIANTS)
+[
+ {
+ "SHADER_SUFFIX": "f32_f32_vec",
+ "REPLS": {
+ "SRC0_TYPE" : "vec4<f32>",
+ "SRC1_TYPE" : "vec4<f32>",
+ "DST_TYPE" : "vec4<f32>",
+ "SHMEM_TYPE" : "vec4<f16>",
+ "VEC_SIZE" : 4,
+ },
+ "DECLS": ["VEC", "SHMEM_VEC", "INIT_SRC0_SHMEM_FLOAT", "INIT_SRC1_SHMEM"]
+ },
+ {
+ "SHADER_SUFFIX": "f32_f32",
+ "REPLS": {
+ "SRC0_TYPE" : "f32",
+ "SRC1_TYPE" : "f32",
+ "DST_TYPE" : "f32",
+ "SHMEM_TYPE" : "f16",
+ "VEC_SIZE" : 1,
+ },
+ "DECLS": ["SCALAR", "SHMEM_SCALAR", "INIT_SRC0_SHMEM_FLOAT", "INIT_SRC1_SHMEM"]
+ },
+ {
+ "SHADER_SUFFIX": "f16_f32_vec",
+ "REPLS": {
+ "SRC0_TYPE" : "vec4<f16>",
+ "SRC1_TYPE" : "vec4<f32>",
+ "DST_TYPE" : "vec4<f32>",
+ "SHMEM_TYPE" : "vec4<f16>",
+ "VEC_SIZE" : 4,
+ },
+ "DECLS": ["VEC", "SHMEM_VEC", "INIT_SRC0_SHMEM_FLOAT", "INIT_SRC1_SHMEM"]
+ },
+ {
+ "SHADER_SUFFIX": "f16_f32",
+ "REPLS": {
+ "SRC0_TYPE" : "f16",
+ "SRC1_TYPE" : "f32",
+ "DST_TYPE" : "f32",
+ "SHMEM_TYPE" : "f16",
+ "VEC_SIZE" : 1,
+ },
+ "DECLS": ["SCALAR", "SHMEM_SCALAR", "INIT_SRC0_SHMEM_FLOAT", "INIT_SRC1_SHMEM"]
+ },
+ {
+ "SHADER_SUFFIX": "f16_f16_vec",
+ "REPLS": {
+ "SRC0_TYPE" : "vec4<f16>",
+ "SRC1_TYPE" : "vec4<f16>",
+ "DST_TYPE" : "vec4<f32>",
+ "SHMEM_TYPE" : "vec4<f16>",
+ "VEC_SIZE" : 4,
+ },
+ "DECLS": ["VEC", "SHMEM_VEC", "INIT_SRC0_SHMEM_FLOAT", "INIT_SRC1_SHMEM"]
+ },
+ {
+ "SHADER_SUFFIX": "f16_f16",
+ "REPLS": {
+ "SRC0_TYPE" : "f16",
+ "SRC1_TYPE" : "f16",
+ "DST_TYPE" : "f32",
+ "SHMEM_TYPE" : "f16",
+ "VEC_SIZE" : 1,
+ },
+ "DECLS": ["SCALAR", "SHMEM_SCALAR", "INIT_SRC0_SHMEM_FLOAT", "INIT_SRC1_SHMEM"]
+ },
+ {
+ "SHADER_SUFFIX": "q4_0_f32_vec",
+ "REPLS": {
+ "SRC0_TYPE" : "f16",
+ "SRC1_TYPE" : "vec4<f32>",
+ "DST_TYPE" : "vec4<f32>",
+ "SHMEM_TYPE" : "vec4<f16>",
+ "VEC_SIZE" : 4,
+ },
+ "DECLS": ["BYTE_HELPERS", "VEC", "SHMEM_VEC", "INIT_SRC0_SHMEM_Q4_0", "INIT_SRC1_SHMEM"]
+ },
+ {
+ "SHADER_SUFFIX": "q4_0_f32",
+ "REPLS": {
+ "SRC0_TYPE" : "f16",
+ "SRC1_TYPE" : "f32",
+ "DST_TYPE" : "f32",
+ "SHMEM_TYPE" : "f16",
+ "VEC_SIZE" : 1,
+ },
+ "DECLS": ["BYTE_HELPERS", "SCALAR", "SHMEM_SCALAR", "INIT_SRC0_SHMEM_Q4_0", "INIT_SRC1_SHMEM"]
+ }
+]
+
+#end(VARIANTS)
+
+#define(DECLS)
+
+#decl(VEC)
+fn store_val(acc: array<array<f16, TILE_N>, TILE_M>, tn: u32, tm: u32) -> vec4<f32> {
+ return vec4<f32>(f32(acc[tm][tn]), f32(acc[tm + 1][tn]), f32(acc[tm + 2][tn]), f32(acc[tm + 3][tn]));
+}
+#enddecl(VEC)
+
+#decl(SCALAR)
+fn store_val(acc: array<array<f16, TILE_N>, TILE_M>, tn: u32, tm: u32) -> f32 {
+ return f32(acc[tm][tn]);
+}
+#enddecl(SCALAR)
+
+#end(DECLS)
+
+#define(SHADER)
+enable f16;
+
+struct MulMatParams {
+ offset_src0: u32,
+ offset_src1: u32,
+ offset_dst: u32,
+ m: u32,
+ n: u32,
+ k: u32,
+ stride_01: u32,
+ stride_11: u32,
+ stride_02: u32,
+ stride_12: u32,
+ stride_03: u32,
+ stride_13: u32,
+ bs02: u32,
+ bs03: u32,
+ broadcast2: u32,
+ broadcast3: u32
+};
+
+@group(0) @binding(0) var<storage, read_write> src0: array<{{SRC0_TYPE}}>; // M rows, K columns
+@group(0) @binding(1) var<storage, read_write> src1: array<{{SRC1_TYPE}}>; // K rows, N columns (transposed)
+@group(0) @binding(2) var<storage, read_write> dst: array<{{DST_TYPE}}>; // M rows, N columns (transposed)
+
+@group(0) @binding(3) var<uniform> params: MulMatParams;
+
+DECLS
+
+fn get_local_n(thread_id: u32) -> u32 {
+ return thread_id / WORKGROUP_SIZE_M;
+}
+fn get_local_m(thread_id: u32) -> u32 {
+ return thread_id % WORKGROUP_SIZE_M;
+}
+
+// TILE_M must be a multiple of 4 for the vec4 variants (dst is written four rows at a time)
+const TILE_M = {{WEBGPU_TILE_M}}u;
+const TILE_N = {{WEBGPU_TILE_N}}u;
+
+override WORKGROUP_SIZE_M: u32;
+override WORKGROUP_SIZE_N: u32;
+override TILE_K: u32;
+
+override TOTAL_WORKGROUP_SIZE = WORKGROUP_SIZE_M * WORKGROUP_SIZE_N;
+override TILE_SRC0_SHMEM = TILE_K * WORKGROUP_SIZE_M * TILE_M;
+override TILE_SRC1_SHMEM = TILE_K * WORKGROUP_SIZE_N * TILE_N;
+
+var<workgroup> shmem: array<f16, TILE_SRC0_SHMEM + TILE_SRC1_SHMEM>;
+
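+// Each workgroup computes a (WORKGROUP_SIZE_M * TILE_M) x (WORKGROUP_SIZE_N * TILE_N)
+// output tile. Each thread accumulates its own TILE_M x TILE_N register tile from
+// src0/src1 tiles staged in shared memory.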
+@compute @workgroup_size(TOTAL_WORKGROUP_SIZE)
+fn main(@builtin(workgroup_id) wg_id: vec3<u32>,
+ @builtin(local_invocation_id) local_id: vec3<u32>) {
+
+ let thread_id = local_id.x;
+ let local_m = get_local_m(thread_id);
+ let local_n = get_local_n(thread_id);
+
+ let wg_n_count = (params.n + WORKGROUP_SIZE_N * TILE_N - 1u) / (WORKGROUP_SIZE_N * TILE_N);
+ let wg_m_count = (params.m + WORKGROUP_SIZE_M * TILE_M - 1u) / (WORKGROUP_SIZE_M * TILE_M);
+ let wg_per_matrix = wg_m_count * wg_n_count;
+
+ let batch_idx = wg_id.x / wg_per_matrix;
+
+ let wg_in_batch = wg_id.x % wg_per_matrix;
+ let wg_m = wg_in_batch % wg_m_count;
+ let wg_n = wg_in_batch / wg_m_count;
+
+ let output_row_base = wg_m * WORKGROUP_SIZE_M * TILE_M + local_m * TILE_M;
+ let output_col_base = wg_n * WORKGROUP_SIZE_N * TILE_N + local_n * TILE_N;
+
+ let dst2_stride = params.m * params.n;
+ let dst3_stride = dst2_stride * params.bs02 * params.broadcast2;
+
+ let dst3_idx = batch_idx / (params.bs02 * params.broadcast2);
+ let src03_idx = dst3_idx / params.broadcast3;
+ let src13_idx = dst3_idx;
+ let dst2_idx = batch_idx % (params.bs02 * params.broadcast2);
+ let src02_idx = dst2_idx / params.broadcast2;
+ let src12_idx = dst2_idx;
+
+ let src0_batch_offset = params.offset_src0 + src03_idx * params.stride_03 + src02_idx * params.stride_02;
+ let src1_batch_offset = params.offset_src1 + src13_idx * params.stride_13 + src12_idx * params.stride_12;
+
+ let offset_m = wg_m * WORKGROUP_SIZE_M * TILE_M;
+ let offset_n = wg_n * WORKGROUP_SIZE_N * TILE_N;
+
+ var acc: array<array<f16, TILE_N>, TILE_M>;
+
+ for (var k_outer = 0u; k_outer < params.k; k_outer += TILE_K) {
+
+ // see mul_mat_decls.tmpl
+ init_shmem_src0(thread_id, src0_batch_offset, offset_m, k_outer);
+ init_shmem_src1(thread_id, src1_batch_offset, offset_n, k_outer);
+
+ workgroupBarrier();
+
+ let k_end = min(TILE_K, params.k - k_outer);
+
+ for (var k_inner = 0u; k_inner < k_end; k_inner++) {
+ var src0_tile: array<f16, TILE_M>;
+ for (var tm = 0u; tm < TILE_M; tm++) {
+ let src0_m = local_m * TILE_M + tm;
+ let src0_idx = k_inner + src0_m * TILE_K;
+ src0_tile[tm] = shmem[src0_idx];
+ }
+ for (var tn = 0u; tn < TILE_N; tn++) {
+ let src1_n = local_n * TILE_N + tn;
+ let src1_idx = src1_n * TILE_K + k_inner;
+ let src1_val = shmem[TILE_SRC0_SHMEM + src1_idx];
+ for (var tm = 0u; tm < TILE_M; tm++) {
+ acc[tm][tn] += src0_tile[tm] * src1_val;
+ }
+ }
+ }
+
+ workgroupBarrier();
+ }
+
+ let dst_batch_offset = params.offset_dst + dst3_idx * dst3_stride + dst2_idx * dst2_stride;
+
+ for (var tn = 0u; tn < TILE_N; tn++) {
+ let global_col = output_col_base + tn;
+ if (global_col < params.n) {
+ for (var tm = 0u; tm < TILE_M; tm += {{VEC_SIZE}}) {
+ let global_row = output_row_base + tm;
+ if (global_row < params.m) {
+ let dst_idx = dst_batch_offset + global_col * params.m + global_row;
+ dst[dst_idx/{{VEC_SIZE}}] = store_val(acc, tn, tm);
+ }
+ }
+ }
+ }
+}
+
+#end(SHADER)
diff --git a/llama.cpp/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_subgroup_matrix.tmpl.wgsl b/llama.cpp/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_subgroup_matrix.tmpl.wgsl
new file mode 100644
index 0000000..47c8ce3
--- /dev/null
+++ b/llama.cpp/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_subgroup_matrix.tmpl.wgsl
@@ -0,0 +1,302 @@
+#define(VARIANTS)
+[
+ {
+ "SHADER_SUFFIX": "f32_f32_vec",
+ "REPLS": {
+ "SRC0_TYPE" : "vec4<f32>",
+ "SRC1_TYPE" : "vec4<f32>",
+ "DST_TYPE" : "vec4<f32>",
+ "SHMEM_TYPE" : "vec4<f16>",
+ "VEC_SIZE" : 4,
+ },
+ "DECLS": ["VEC", "SHMEM_VEC", "INIT_SRC0_SHMEM_FLOAT", "INIT_SRC1_SHMEM"]
+ },
+ {
+ "SHADER_SUFFIX": "f32_f32",
+ "REPLS": {
+ "SRC0_TYPE" : "f32",
+ "SRC1_TYPE" : "f32",
+ "DST_TYPE" : "f32",
+ "SHMEM_TYPE" : "f16",
+ "VEC_SIZE" : 1,
+ },
+ "DECLS": ["SCALAR", "SHMEM_SCALAR", "INIT_SRC0_SHMEM_FLOAT", "INIT_SRC1_SHMEM"]
+ },
+ {
+ "SHADER_SUFFIX": "f16_f32_vec",
+ "REPLS": {
+ "SRC0_TYPE" : "vec4<f16>",
+ "SRC1_TYPE" : "vec4<f32>",
+ "DST_TYPE" : "vec4<f32>",
+ "SHMEM_TYPE" : "vec4<f16>",
+ "VEC_SIZE" : 4,
+ },
+ "DECLS": ["VEC", "SHMEM_VEC", "INIT_SRC0_SHMEM_FLOAT", "INIT_SRC1_SHMEM"]
+ },
+ {
+ "SHADER_SUFFIX": "f16_f32",
+ "REPLS": {
+ "SRC0_TYPE" : "f16",
+ "SRC1_TYPE" : "f32",
+ "DST_TYPE" : "f32",
+ "SHMEM_TYPE" : "f16",
+ "VEC_SIZE" : 1,
+ },
+ "DECLS": ["SCALAR", "SHMEM_SCALAR", "INIT_SRC0_SHMEM_FLOAT", "INIT_SRC1_SHMEM"]
+ },
+ {
+ "SHADER_SUFFIX": "f16_f16_vec",
+ "REPLS": {
+ "SRC0_TYPE" : "vec4<f16>",
+ "SRC1_TYPE" : "vec4<f16>",
+ "DST_TYPE" : "vec4<f32>",
+ "SHMEM_TYPE" : "vec4<f16>",
+ "VEC_SIZE" : 4,
+ },
+ "DECLS": ["VEC", "SHMEM_VEC", "INIT_SRC0_SHMEM_FLOAT", "INIT_SRC1_SHMEM"]
+ },
+ {
+ "SHADER_SUFFIX": "f16_f16",
+ "REPLS": {
+ "SRC0_TYPE" : "f16",
+ "SRC1_TYPE" : "f16",
+ "DST_TYPE" : "f32",
+ "SHMEM_TYPE" : "f16",
+ "VEC_SIZE" : 1,
+ },
+ "DECLS": ["SCALAR", "SHMEM_SCALAR", "INIT_SRC0_SHMEM_FLOAT", "INIT_SRC1_SHMEM"]
+ },
+ {
+ "SHADER_SUFFIX": "q4_0_f32_vec",
+ "REPLS": {
+ "SRC0_TYPE" : "f16",
+ "SRC1_TYPE" : "vec4<f32>",
+ "DST_TYPE" : "vec4<f32>",
+ "SHMEM_TYPE" : "vec4<f16>",
+ "VEC_SIZE" : 4,
+ },
+ "DECLS": ["BYTE_HELPERS", "VEC", "SHMEM_VEC", "INIT_SRC0_SHMEM_Q4_0", "INIT_SRC1_SHMEM"]
+ },
+ {
+ "SHADER_SUFFIX": "q4_0_f32",
+ "REPLS": {
+ "SRC0_TYPE" : "f16",
+ "SRC1_TYPE" : "f32",
+ "DST_TYPE" : "f32",
+ "SHMEM_TYPE" : "f16",
+ "VEC_SIZE" : 1,
+ },
+ "DECLS": ["BYTE_HELPERS", "SCALAR", "SHMEM_SCALAR", "INIT_SRC0_SHMEM_Q4_0", "INIT_SRC1_SHMEM"]
+ }
+]
+
+#end(VARIANTS)
+
+#define(DECLS)
+
+#decl(VEC)
+fn store_dst(shmem_idx: u32, dst_idx: u32) {
+ dst[dst_idx] = vec4<f32>(
+ f32(shmem[shmem_idx]),
+ f32(shmem[shmem_idx + 1]),
+ f32(shmem[shmem_idx + 2]),
+ f32(shmem[shmem_idx + 3])
+ );
+}
+#enddecl(VEC)
+
+#decl(SCALAR)
+fn store_dst(shmem_idx: u32, dst_idx: u32) {
+ dst[dst_idx] = f32(shmem[shmem_idx]);
+}
+#enddecl(SCALAR)
+
+#end(DECLS)
+
+#define(SHADER)
+diagnostic(off, chromium.subgroup_matrix_uniformity);
+enable f16;
+enable subgroups;
+enable chromium_experimental_subgroup_matrix;
+
+struct MulMatParams {
+ offset_src0: u32,
+ offset_src1: u32,
+ offset_dst: u32,
+ m: u32,
+ n: u32,
+ k: u32,
+ stride_01: u32,
+ stride_11: u32,
+ stride_02: u32,
+ stride_12: u32,
+ stride_03: u32,
+ stride_13: u32,
+ bs02: u32,
+ bs03: u32,
+ broadcast2: u32,
+ broadcast3: u32
+};
+
+@group(0) @binding(0) var<storage, read_write> src0: array<{{SRC0_TYPE}}>; // M rows, K columns
+@group(0) @binding(1) var<storage, read_write> src1: array<{{SRC1_TYPE}}>; // K rows, N columns (transposed)
+@group(0) @binding(2) var<storage, read_write> dst: array<{{DST_TYPE}}>; // M rows, N columns (transposed)
+
+@group(0) @binding(3) var<uniform> params: MulMatParams;
+
+DECLS
+
+// Note: these values are interpolated into the shader source at build time; they cannot be
+// override constants because the current Dawn version requires constant (not override) sizes
+// in subgroup matrix type definitions and matrix loads.
+const SUBGROUP_M = {{WEBGPU_SUBGROUP_M}}u;
+const SUBGROUP_N = {{WEBGPU_SUBGROUP_N}}u;
+// For portability we assume the max subgroup size, meaning some subgroups will be masked out if the
+// runtime subgroup size is smaller.
+const MAX_SUBGROUP_SIZE = {{WEBGPU_MAX_SUBGROUP_SIZE}}u;
+
+const EXPECTED_SUBGROUPS = SUBGROUP_M * SUBGROUP_N;
+
+const SUBGROUP_MATRIX_M_SIZE = {{WEBGPU_SG_MAT_M_SIZE}}u;
+const SUBGROUP_MATRIX_N_SIZE = {{WEBGPU_SG_MAT_N_SIZE}}u;
+const SUBGROUP_MATRIX_K_SIZE = {{WEBGPU_SG_MAT_K_SIZE}}u;
+
+const SUBGROUP_MATRIX_M = {{WEBGPU_SUBGROUP_MATRIX_M}}u;
+const SUBGROUP_MATRIX_N = {{WEBGPU_SUBGROUP_MATRIX_N}}u;
+
+const TILE_K = {{WEBGPU_TILE_K}}u;
+
+const WG_M_SG_TILE_SIZE = SUBGROUP_M * SUBGROUP_MATRIX_M * SUBGROUP_MATRIX_M_SIZE;
+const WG_N_SG_TILE_SIZE = SUBGROUP_N * SUBGROUP_MATRIX_N * SUBGROUP_MATRIX_N_SIZE;
+
+const TOTAL_WORKGROUP_SIZE = SUBGROUP_M * SUBGROUP_N * MAX_SUBGROUP_SIZE;
+const TILE_SRC0_SHMEM = TILE_K * SUBGROUP_M * SUBGROUP_MATRIX_M * SUBGROUP_MATRIX_M_SIZE;
+const TILE_SRC1_SHMEM = TILE_K * SUBGROUP_N * SUBGROUP_MATRIX_N * SUBGROUP_MATRIX_N_SIZE;
+
+const SG_MAT_ACCUM_SHMEM = SUBGROUP_M * SUBGROUP_MATRIX_M * SUBGROUP_N * SUBGROUP_MATRIX_N * SUBGROUP_MATRIX_M_SIZE * SUBGROUP_MATRIX_N_SIZE;
+
+// We reuse shmem for accumulation matrices
+const SHMEM_SIZE = max(TILE_SRC0_SHMEM + TILE_SRC1_SHMEM, SG_MAT_ACCUM_SHMEM);
+
+var<workgroup> shmem: array<f16, SHMEM_SIZE>;
+
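+// Each workgroup computes a WG_M_SG_TILE_SIZE x WG_N_SG_TILE_SIZE output tile with
+// subgroup matrix multiply-accumulate: src0/src1 tiles are staged in shared memory,
+// multiplied SUBGROUP_MATRIX_K_SIZE columns at a time, and the accumulators are then
+// staged back through shared memory for a bounds-checked cooperative write to dst.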
+@compute @workgroup_size(TOTAL_WORKGROUP_SIZE)
+fn main(@builtin(workgroup_id) wg_id: vec3<u32>,
+ @builtin(local_invocation_id) local_id: vec3<u32>,
+ @builtin(subgroup_id) subgroup_id: u32) {
+
+ let thread_id = local_id.x;
+ let subgroup_m = subgroup_id % SUBGROUP_M;
+ let subgroup_n = subgroup_id / SUBGROUP_M;
+
+ let wg_m_count = (params.m + WG_M_SG_TILE_SIZE - 1) / WG_M_SG_TILE_SIZE;
+ let wg_n_count = (params.n + WG_N_SG_TILE_SIZE - 1) / WG_N_SG_TILE_SIZE;
+ let wg_per_matrix = wg_m_count * wg_n_count;
+
+ let batch_idx = wg_id.x / wg_per_matrix;
+
+ let wg_in_batch = wg_id.x % wg_per_matrix;
+ let wg_m = wg_in_batch % wg_m_count;
+ let wg_n = wg_in_batch / wg_m_count;
+
+ let dst2_stride = params.m * params.n;
+ let dst3_stride = dst2_stride * params.bs02 * params.broadcast2;
+
+ let dst3_idx = batch_idx / (params.bs02 * params.broadcast2);
+ let src03_idx = dst3_idx / params.broadcast3;
+ let src13_idx = dst3_idx;
+ let dst2_idx = batch_idx % (params.bs02 * params.broadcast2);
+ let src02_idx = dst2_idx / params.broadcast2;
+ let src12_idx = dst2_idx;
+
+ let src0_batch_offset = params.offset_src0 + src03_idx * params.stride_03 + src02_idx * params.stride_02;
+ let src1_batch_offset = params.offset_src1 + src13_idx * params.stride_13 + src12_idx * params.stride_12;
+
+ let offset_m = wg_m * SUBGROUP_M * SUBGROUP_MATRIX_M * SUBGROUP_MATRIX_M_SIZE;
+ let offset_n = wg_n * SUBGROUP_N * SUBGROUP_MATRIX_N * SUBGROUP_MATRIX_N_SIZE;
+
+ var acc_sg_mat : array<array<subgroup_matrix_result<f16, SUBGROUP_MATRIX_N_SIZE, SUBGROUP_MATRIX_M_SIZE>, SUBGROUP_MATRIX_N>, SUBGROUP_MATRIX_M>;
+
+ for (var k_outer = 0u; k_outer < params.k; k_outer += TILE_K) {
+
+ // see mul_mat_decls.tmpl
+ init_shmem_src0(thread_id, src0_batch_offset, offset_m, k_outer);
+ init_shmem_src1(thread_id, src1_batch_offset, offset_n, k_outer);
+
+ workgroupBarrier();
+
+ if (subgroup_id < EXPECTED_SUBGROUPS) {
+
+ for (var k_inner = 0u; k_inner < TILE_K; k_inner += SUBGROUP_MATRIX_K_SIZE) {
+
+ let src0_shmem_idx_base = subgroup_m * SUBGROUP_MATRIX_M * SUBGROUP_MATRIX_M_SIZE * TILE_K + k_inner;
+ var src0_sg_mats: array<subgroup_matrix_left<f16, SUBGROUP_MATRIX_K_SIZE, SUBGROUP_MATRIX_M_SIZE>, SUBGROUP_MATRIX_M>;
+ for (var m = 0u; m < SUBGROUP_MATRIX_M; m++) {
+ src0_sg_mats[m] = subgroupMatrixLoad<subgroup_matrix_left<f16, SUBGROUP_MATRIX_K_SIZE, SUBGROUP_MATRIX_M_SIZE>>(
+ &shmem,
+ src0_shmem_idx_base + m * SUBGROUP_MATRIX_M_SIZE * TILE_K,
+ false,
+ TILE_K
+ );
+ }
+
+ let src1_shmem_idx_base = TILE_SRC0_SHMEM + subgroup_n * SUBGROUP_MATRIX_N * SUBGROUP_MATRIX_N_SIZE * TILE_K + k_inner;
+ for (var n = 0u; n < SUBGROUP_MATRIX_N; n++) {
+ let src1_sg_mat = subgroupMatrixLoad<subgroup_matrix_right<f16, SUBGROUP_MATRIX_N_SIZE, SUBGROUP_MATRIX_K_SIZE>>(
+ &shmem,
+ src1_shmem_idx_base + n * SUBGROUP_MATRIX_N_SIZE * TILE_K,
+ true,
+ TILE_K
+ );
+ for (var m = 0u; m < SUBGROUP_MATRIX_M; m++) {
+ acc_sg_mat[m][n] = subgroupMatrixMultiplyAccumulate(src0_sg_mats[m], src1_sg_mat, acc_sg_mat[m][n]);
+ }
+ }
+ }
+ }
+
+ workgroupBarrier();
+ }
+
+ let dst_batch_offset = params.offset_dst + dst3_idx * dst3_stride + dst2_idx * dst2_stride;
+
+ // Stage the subgroup matrix tiles into shared memory
+ // This uses WG_M_SG_TILE_SIZE as the stride (number of columns in the workgroup tile).
+ let WG_TILE_STRIDE = WG_M_SG_TILE_SIZE;
+ let tile_row_base_local = subgroup_n * SUBGROUP_MATRIX_N * SUBGROUP_MATRIX_N_SIZE;
+ let tile_col_base_local = subgroup_m * SUBGROUP_MATRIX_M * SUBGROUP_MATRIX_M_SIZE;
+
+ if (subgroup_id < EXPECTED_SUBGROUPS) { // 2-5% performance hit :(
+ for (var n = 0u; n < SUBGROUP_MATRIX_N; n++) {
+ for (var m = 0u; m < SUBGROUP_MATRIX_M; m++) {
+ let local_row = tile_row_base_local + n * SUBGROUP_MATRIX_N_SIZE;
+ let local_col = tile_col_base_local + m * SUBGROUP_MATRIX_M_SIZE;
+ let out_base = local_row * WG_TILE_STRIDE + local_col;
+ subgroupMatrixStore(&shmem, out_base, acc_sg_mat[m][n], true, WG_TILE_STRIDE);
+ }
+ }
+ }
+
+ workgroupBarrier();
+
+ // Cooperative write: iterate over the entire workgroup tile
+ let tile_rows = WG_N_SG_TILE_SIZE;
+ let tile_cols = WG_M_SG_TILE_SIZE;
+ let total_tile_elems = tile_rows * tile_cols;
+ let tile_dst_row_base = wg_m * SUBGROUP_M * SUBGROUP_MATRIX_M * SUBGROUP_MATRIX_M_SIZE;
+ let tile_dst_col_base = wg_n * SUBGROUP_N * SUBGROUP_MATRIX_N * SUBGROUP_MATRIX_N_SIZE;
+
+ for (var idx = thread_id * {{VEC_SIZE}}; idx < total_tile_elems; idx += TOTAL_WORKGROUP_SIZE * {{VEC_SIZE}}) {
+ let local_row = idx % WG_TILE_STRIDE;
+ let local_col = idx / WG_TILE_STRIDE;
+
+ let global_row = tile_dst_row_base + local_row;
+ let global_col = tile_dst_col_base + local_col;
+
+ if (global_col < params.n && global_row < params.m) {
+ let dst_idx = dst_batch_offset + global_col * params.m + global_row;
+ store_dst(idx, dst_idx/{{VEC_SIZE}});
+ }
+ }
+}
+
+#end(SHADER)
diff --git a/llama.cpp/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec.tmpl.wgsl b/llama.cpp/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec.tmpl.wgsl
new file mode 100644
index 0000000..ffbb640
--- /dev/null
+++ b/llama.cpp/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec.tmpl.wgsl
@@ -0,0 +1,267 @@
+#define(VARIANTS)
+[
+ {
+ "SHADER_SUFFIX": "f32_f32_vec",
+ "REPLS": {
+ "SRC0_TYPE" : "vec4<f32>",
+ "SRC1_TYPE" : "vec4<f32>",
+ "DST_TYPE": "vec4<f32>",
+ "VEC_SIZE" : 4,
+ },
+ "DECLS": ["VEC", "MUL_ACC_FLOAT"]
+ },
+ {
+ "SHADER_SUFFIX": "f32_f32",
+ "REPLS": {
+ "SRC0_TYPE" : "f32",
+ "SRC1_TYPE" : "f32",
+ "DST_TYPE": "f32",
+ "VEC_SIZE" : 1,
+ },
+ "DECLS": ["SCALAR", "MUL_ACC_FLOAT"]
+ },
+ {
+ "SHADER_SUFFIX": "f16_f32_vec",
+ "REPLS": {
+ "SRC0_TYPE" : "vec4<f16>",
+ "SRC1_TYPE" : "vec4<f32>",
+ "DST_TYPE": "vec4<f32>",
+ "VEC_SIZE" : 4,
+ },
+ "DECLS": ["VEC", "MUL_ACC_FLOAT"]
+ },
+ {
+ "SHADER_SUFFIX": "f16_f32",
+ "REPLS": {
+ "SRC0_TYPE" : "f16",
+ "SRC1_TYPE" : "f32",
+ "DST_TYPE": "f32",
+ "VEC_SIZE" : 1,
+ },
+ "DECLS": ["SCALAR", "MUL_ACC_FLOAT"]
+ },
+ {
+ "SHADER_SUFFIX": "f16_f16_vec",
+ "REPLS": {
+ "SRC0_TYPE" : "vec4<f16>",
+ "SRC1_TYPE" : "vec4<f16>",
+ "DST_TYPE": "vec4<f32>",
+ "VEC_SIZE" : 4,
+ },
+ "DECLS": ["VEC", "MUL_ACC_FLOAT"]
+ },
+ {
+ "SHADER_SUFFIX": "f16_f16",
+ "REPLS": {
+ "SRC0_TYPE" : "f16",
+ "SRC1_TYPE" : "f16",
+ "DST_TYPE": "f32",
+ "VEC_SIZE" : 1,
+ },
+ "DECLS": ["SCALAR", "MUL_ACC_FLOAT"]
+ },
+ {
+ "SHADER_SUFFIX": "q4_0_f32",
+ "REPLS": {
+ "SRC0_TYPE" : "f16",
+ "SRC1_TYPE" : "f32",
+ "DST_TYPE": "f32",
+ "VEC_SIZE" : 1,
+ },
+ "DECLS": ["BYTE_HELPERS", "SCALAR", "MUL_ACC_Q4_0"]
+ }
+]
+
+#end(VARIANTS)
+
+#define(DECLS)
+
+#decl(VEC)
+fn inner_dot(src0_val: {{SRC0_TYPE}}, src1_val: {{SRC1_TYPE}}) -> f32 {
+ return f32(dot({{SRC1_TYPE}}(src0_val), src1_val));
+}
+
+fn store_val(group_base: u32) -> vec4<f32> {
+ return vec4<f32>(partial_sums[group_base],
+ partial_sums[group_base + THREADS_PER_OUTPUT],
+ partial_sums[group_base + THREADS_PER_OUTPUT * 2],
+ partial_sums[group_base + THREADS_PER_OUTPUT * 3]);
+}
+#enddecl(VEC)
+
+#decl(SCALAR)
+fn inner_dot(src0_val: {{SRC0_TYPE}}, src1_val: {{SRC1_TYPE}}) -> f32 {
+ return f32(src0_val) * f32(src1_val);
+}
+
+fn store_val(group_base: u32) -> f32 {
+ return partial_sums[group_base];
+}
+#enddecl(SCALAR)
+
+#decl(MUL_ACC_FLOAT)
+
+fn mul_acc(tig:u32, tile_size: u32, idx_base: u32, k_outer: u32) -> f32 {
+ var local_sum = 0.0;
+ for (var i = tig * {{VEC_SIZE}}; i < tile_size; i += THREADS_PER_OUTPUT * {{VEC_SIZE}}) {
+ let a = src0[(idx_base + k_outer + i) / {{VEC_SIZE}}];
+ let b = shared_vector[i / {{VEC_SIZE}}];
+ local_sum += inner_dot(a, b);
+ }
+ return local_sum;
+}
+
+#enddecl(MUL_ACC_FLOAT)
+
+#decl(MUL_ACC_Q4_0)
+
+const BLOCK_SIZE = 32;
+const NQ = 16u; // number of weights per thread
+const F16_PER_BLOCK = 9u; // 1 scale + 8x4 packed weights
+const WEIGHTS_PER_F16 = 4u; // 4 weights per f16
+const F16_PER_THREAD = NQ / WEIGHTS_PER_F16;
+
+fn mul_acc(tig:u32, tile_size: u32, idx_base: u32, k_outer: u32) -> f32 {
+ var local_sum = 0.0;
+ for (var i = tig * NQ; i < tile_size; i += THREADS_PER_OUTPUT * NQ) {
+ let blck_idx = i / BLOCK_SIZE;
+ let block_offset = (i % BLOCK_SIZE) / WEIGHTS_PER_F16;
+ let scale_idx = (idx_base + k_outer / BLOCK_SIZE + blck_idx) * F16_PER_BLOCK;
+        // each f16 holds 2 bytes = 4 weights: byte b of a block packs weight b (low nibble)
+        // and weight b + 16 (high nibble), hence the factor of 2 in shmem_idx
+ let shmem_idx = blck_idx * BLOCK_SIZE + block_offset * 2u;
+ let d = f32(src0[scale_idx]);
+ for (var j = 0u; j < F16_PER_THREAD; j += 2) {
+ let q_0 = src0[scale_idx + 1 + block_offset + j];
+ let q_1 = src0[scale_idx + 1 + block_offset + j + 1];
+ let q_packed = bitcast<u32>(vec2(q_0, q_1));
+ for (var k: u32 = 0; k < 4; k++) {
+ let q_byte = get_byte(q_packed, k);
+ let q_hi = (f32((q_byte >> 4) & 0xF) - 8.0) * d;
+ let q_lo = (f32(q_byte & 0xF) - 8.0) * d;
+ local_sum += q_lo * shared_vector[shmem_idx + j * 2 + k];
+ local_sum += q_hi * shared_vector[shmem_idx + j * 2 + k + 16];
+ }
+ }
+ }
+ return local_sum;
+}
+
+#enddecl(MUL_ACC_Q4_0)
+
+#end(DECLS)
+
+#define(SHADER)
+enable f16;
+
+DECLS
+
+struct MulMatParams {
+ offset_src0: u32,
+ offset_src1: u32,
+ offset_dst: u32,
+ m: u32,
+ n: u32,
+ k: u32,
+ stride_01: u32,
+ stride_11: u32,
+ stride_02: u32,
+ stride_12: u32,
+ stride_03: u32,
+ stride_13: u32,
+ bs02: u32,
+ bs03: u32,
+ broadcast2: u32,
+ broadcast3: u32
+};
+
+@group(0) @binding(0) var<storage, read_write> src0: array<{{SRC0_TYPE}}>; // Matrix (M x K)
+@group(0) @binding(1) var<storage, read_write> src1: array<{{SRC1_TYPE}}>; // Vector (K x 1, transposed)
+@group(0) @binding(2) var<storage, read_write> dst: array<{{DST_TYPE}}>; // Result vector (transposed)
+
+@group(0) @binding(3) var<uniform> params: MulMatParams;
+
+override WORKGROUP_SIZE: u32;
+override TILE_K: u32;
+override OUTPUTS_PER_WG: u32;
+override THREADS_PER_OUTPUT = WORKGROUP_SIZE / OUTPUTS_PER_WG;
+
+// Shared memory for collaborative loading and reduction
+var<workgroup> shared_vector: array<{{SRC1_TYPE}}, TILE_K/{{VEC_SIZE}}>; // Cache vector tile
+var<workgroup> partial_sums: array<f32, WORKGROUP_SIZE>; // For reduction
+
+@compute @workgroup_size(WORKGROUP_SIZE)
+fn main(
+ @builtin(local_invocation_id) local_id: vec3<u32>,
+ @builtin(workgroup_id) wg_id: vec3<u32>,
+ @builtin(num_workgroups) num_wg: vec3<u32>) {
+ let thread_id = local_id.x;
+
+ // Handle batch dimensions
+ let total_batches = params.bs02 * params.broadcast2 * params.bs03 * params.broadcast3;
+ let wg_linear = wg_id.y * num_wg.x + wg_id.x;
+ let output_groups = (params.m + OUTPUTS_PER_WG - 1u) / OUTPUTS_PER_WG;
+ let batch_idx = wg_linear / output_groups;
+ if (batch_idx >= total_batches) {
+ return;
+ }
+
+ // Which of the outputs does this thread belong to?
+ let thread_group = thread_id / THREADS_PER_OUTPUT;
+ let thread_in_group = thread_id % THREADS_PER_OUTPUT;
+
+ // Each workgroup computes OUTPUTS_PER_WG consecutive outputs
+ let output_row = (wg_linear % output_groups) * OUTPUTS_PER_WG + thread_group;
+
+ let dst2_stride = params.m * params.n;
+ let dst2_idx = batch_idx % (params.bs02 * params.broadcast2);
+ let dst3_stride = dst2_stride * params.bs02 * params.broadcast2;
+ let dst3_idx = batch_idx / (params.bs02 * params.broadcast2);
+ let src03_idx = dst3_idx / params.broadcast3;
+ let src13_idx = dst3_idx;
+ let src02_idx = dst2_idx / params.broadcast2;
+ let src12_idx = dst2_idx;
+
+ let src0_idx_base = params.offset_src0 + src03_idx * params.stride_03 + src02_idx * params.stride_02 + output_row * params.stride_01;
+ let src1_idx_base = params.offset_src1 + src13_idx * params.stride_13 + src12_idx * params.stride_12;
+ let dst_idx = params.offset_dst + dst3_idx * dst3_stride + dst2_idx * dst2_stride + output_row;
+
+ var local_sum = 0.0;
+
+ // Each thread processes multiple K elements and accumulates
+ for (var k_tile = 0u; k_tile < params.k; k_tile += TILE_K) {
+ let tile_size = min(TILE_K, params.k - k_tile);
+
+ // Cooperatively load vector tile into shared memory (all threads)
+ for (var i = thread_id * {{VEC_SIZE}}; i < tile_size; i += WORKGROUP_SIZE * {{VEC_SIZE}}) {
+ shared_vector[i / {{VEC_SIZE}}] = src1[(src1_idx_base + k_tile + i) / {{VEC_SIZE}}];
+ }
+
+ workgroupBarrier();
+
+ if (output_row < params.m) {
+ local_sum += mul_acc(thread_in_group, tile_size, src0_idx_base, k_tile);
+ }
+
+ workgroupBarrier();
+ }
+
+ // Store partial sums and reduce within each partition
+ partial_sums[thread_id] = local_sum;
+ workgroupBarrier();
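+    // tree reduction within each group of THREADS_PER_OUTPUT threads: halve the
+    // number of active threads each step (assumes THREADS_PER_OUTPUT is a power of two)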
+ let group_base = thread_group * THREADS_PER_OUTPUT;
+ let thread_base = group_base + thread_in_group;
+ var offset = THREADS_PER_OUTPUT / 2;
+ while (offset > 0) {
+ if (thread_in_group < offset) {
+ partial_sums[thread_base] += partial_sums[thread_base + offset];
+ }
+ offset = offset / 2;
+ workgroupBarrier();
+ }
+
+ // Store back to global memory
+ if (output_row < params.m && thread_group % {{VEC_SIZE}} == 0 && thread_in_group == 0) {
+ dst[dst_idx / {{VEC_SIZE}}] = store_val(group_base);
+ }
+}
+#end(SHADER)
diff --git a/llama.cpp/ggml/src/ggml-webgpu/wgsl-shaders/pad.wgsl b/llama.cpp/ggml/src/ggml-webgpu/wgsl-shaders/pad.wgsl
new file mode 100644
index 0000000..ea63b9a
--- /dev/null
+++ b/llama.cpp/ggml/src/ggml-webgpu/wgsl-shaders/pad.wgsl
@@ -0,0 +1,86 @@
+@group(0) @binding(0)
+var<storage, read_write> src: array<f32>;
+
+@group(0) @binding(1)
+var<storage, read_write> dst: array<f32>;
+
+struct Params {
+ ne: u32, // total number of elements
+ offset_src: u32, // in elements
+ offset_dst: u32, // in elements
+
+ // Strides (in elements)
+ stride_src0: u32,
+ stride_src1: u32,
+ stride_src2: u32,
+ stride_src3: u32,
+
+ // Logical shapes
+ src_ne0: u32,
+ src_ne1: u32,
+ src_ne2: u32,
+ src_ne3: u32,
+
+ dst_ne0: u32,
+ dst_ne1: u32,
+ dst_ne2: u32,
+ dst_ne3: u32,
+
+ // Pad sizes (in elements)
+ lp0: u32,
+ rp0: u32,
+ lp1: u32,
+ rp1: u32,
+ lp2: u32,
+ rp2: u32,
+ lp3: u32,
+ rp3: u32,
+};
+
+@group(0) @binding(2)
+var<uniform> params: Params;
+
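+// Wraps a possibly-negative index into [0, n) for circular padding; adding n
+// before the modulo keeps the operand non-negative for any idx >= -n.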
+fn wrap_around(idx: i32, n: u32) -> u32 {
+ return u32(idx + i32(n)) % n;
+}
+
+@compute @workgroup_size(WG_SIZE)
+fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
+ if (gid.x >= params.ne) {
+ return;
+ }
+
+ var i = gid.x;
+ let dst_plane = params.dst_ne2 * params.dst_ne1 * params.dst_ne0;
+ let i3 = i / dst_plane;
+ i = i % dst_plane;
+ let i2 = i / (params.dst_ne1 * params.dst_ne0);
+ i = i % (params.dst_ne1 * params.dst_ne0);
+ let i1 = i / params.dst_ne0;
+ let i0 = i % params.dst_ne0;
+
+ var value: f32 = 0.0;
+
+#ifdef CIRCULAR
+ let ci0 = wrap_around(i32(i0) - i32(params.lp0), params.src_ne0);
+ let ci1 = wrap_around(i32(i1) - i32(params.lp1), params.src_ne1);
+ let ci2 = wrap_around(i32(i2) - i32(params.lp2), params.src_ne2);
+ let ci3 = wrap_around(i32(i3) - i32(params.lp3), params.src_ne3);
+ let circular_src_idx = ci0 * params.stride_src0 + ci1 * params.stride_src1 +
+ ci2 * params.stride_src2 + ci3 * params.stride_src3;
+ value = src[params.offset_src + circular_src_idx];
+#else
+ let is_src =
+ (i0 >= params.lp0 && i0 < params.dst_ne0 - params.rp0) &&
+ (i1 >= params.lp1 && i1 < params.dst_ne1 - params.rp1) &&
+ (i2 >= params.lp2 && i2 < params.dst_ne2 - params.rp2) &&
+ (i3 >= params.lp3 && i3 < params.dst_ne3 - params.rp3);
+ if (is_src) {
+ let src_idx = (i0 - params.lp0) * params.stride_src0 + (i1 - params.lp1) * params.stride_src1 +
+ (i2 - params.lp2) * params.stride_src2 + (i3 - params.lp3) * params.stride_src3;
+ value = src[params.offset_src + src_idx];
+ }
+#endif
+
+ dst[params.offset_dst + gid.x] = value;
+}
diff --git a/llama.cpp/ggml/src/ggml-webgpu/wgsl-shaders/rms_norm.wgsl b/llama.cpp/ggml/src/ggml-webgpu/wgsl-shaders/rms_norm.wgsl
new file mode 100644
index 0000000..712b921
--- /dev/null
+++ b/llama.cpp/ggml/src/ggml-webgpu/wgsl-shaders/rms_norm.wgsl
@@ -0,0 +1,123 @@
+#define(VARIANTS)
+
+[
+ {
+ "DECLS": ["NOT_INPLACE"]
+ },
+ {
+ "SHADER_SUFFIX": "inplace",
+ "DECLS": ["INPLACE"]
+    }
+]
+
+#end(VARIANTS)
+
+#define(DECLS)
+
+#decl(NOT_INPLACE)
+
+fn update(src_offset: u32, dst_offset: u32, scale: f32) {
+ dst[dst_offset] = scale * src[src_offset];
+}
+
+@group(0) @binding(1)
+var<storage, read_write> dst: array<f32>;
+
+@group(0) @binding(2)
+var<uniform> params: Params;
+
+#enddecl(NOT_INPLACE)
+
+#decl(INPLACE)
+
+fn update(src_offset: u32, dst_offset: u32, scale: f32) {
+ src[dst_offset] = scale * src[src_offset];
+}
+
+@group(0) @binding(1)
+var<uniform> params: Params;
+
+#enddecl(INPLACE)
+
+#end(DECLS)
+
+#define(SHADER)
+
+struct Params {
+ offset_src: u32, // in elements
+ offset_dst: u32, // in elements
+
+ // Strides (in elements)
+ stride_src1: u32,
+ stride_src2: u32,
+ stride_src3: u32,
+
+ stride_dst1: u32,
+ stride_dst2: u32,
+ stride_dst3: u32,
+
+ // Shape of src/dst
+ ne0: u32,
+ ne1: u32,
+ ne2: u32,
+ ne3: u32,
+
+ eps: f32
+};
+
+@group(0) @binding(0)
+var<storage, read_write> src: array<f32>;
+
+DECLS
+
+override wg_size: u32;
+var<workgroup> scratch: array<f32, wg_size>;
+
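+// One workgroup per row: accumulate the row's sum of squares, tree-reduce it in
+// shared memory, then scale every element by 1 / sqrt(mean(x^2) + eps).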
+@compute @workgroup_size(wg_size)
+fn main(@builtin(workgroup_id) wid: vec3<u32>,
+ @builtin(local_invocation_id) lid: vec3<u32>) {
+
+ // one thread per row
+ var i = wid.x;
+ let i3 = i / (params.ne2 * params.ne1);
+ i = i % (params.ne2 * params.ne1);
+ let i2 = i / params.ne1;
+ let i1 = i % params.ne1;
+ let i_src_row = params.offset_src + i3 * params.stride_src3 + i2 * params.stride_src2 + i1 * params.stride_src1;
+ let i_dst_row = params.offset_dst + i3 * params.stride_dst3 + i2 * params.stride_dst2 + i1 * params.stride_dst1;
+
+ let elems = (params.ne0 + wg_size - 1) / wg_size;
+
+ var sum = 0.0f;
+ var col = lid.x;
+ for (var j: u32 = 0; j < elems; j++) {
+ if (col >= params.ne0) {
+ break;
+ }
+ sum += pow(src[i_src_row + col], 2.0);
+ col += wg_size;
+ }
+
+ scratch[lid.x] = sum;
+ workgroupBarrier();
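+    // Tree-reduce the partial sums in workgroup memory (assumes wg_size is a power of two).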
+ var offset = wg_size / 2;
+ while (offset > 0) {
+ if (lid.x < offset) {
+ scratch[lid.x] += scratch[lid.x + offset];
+ }
+ offset = offset / 2;
+ workgroupBarrier();
+ }
+ sum = scratch[0];
+
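+    // RMS norm: scale = 1 / sqrt(mean(x^2) + eps).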
+    let scale = 1.0 / sqrt(sum / f32(params.ne0) + params.eps);
+ col = lid.x;
+ for (var j: u32 = 0; j < elems; j++) {
+ if (col >= params.ne0) {
+ break;
+ }
+ update(i_src_row + col, i_dst_row + col, scale);
+ col += wg_size;
+ }
+}
+#end(SHADER)
diff --git a/llama.cpp/ggml/src/ggml-webgpu/wgsl-shaders/rope.tmpl.wgsl b/llama.cpp/ggml/src/ggml-webgpu/wgsl-shaders/rope.tmpl.wgsl
new file mode 100644
index 0000000..84dc8db
--- /dev/null
+++ b/llama.cpp/ggml/src/ggml-webgpu/wgsl-shaders/rope.tmpl.wgsl
@@ -0,0 +1,295 @@
+#define(VARIANTS)
+
+[
+ {
+ "REPLS": {
+ "TYPE" : "f32",
+ },
+ "DECLS": ["NO_FF_BINDINGS", "NO_FF_FUNC", "ROTATE"]
+ },
+ {
+ "SHADER_SUFFIX": "f32_inplace",
+ "REPLS": {
+ "TYPE" : "f32",
+ },
+ "DECLS": ["NO_FF_BINDINGS_INPLACE", "NO_FF_FUNC", "ROTATE_INPLACE"]
+ },
+ {
+ "REPLS": {
+ "TYPE" : "f16",
+ },
+ "DECLS": ["NO_FF_BINDINGS", "NO_FF_FUNC", "ROTATE"]
+ },
+ {
+ "SHADER_SUFFIX": "f16_inplace",
+ "REPLS": {
+ "TYPE" : "f16",
+ },
+ "DECLS": ["NO_FF_BINDINGS_INPLACE", "NO_FF_FUNC", "ROTATE_INPLACE"]
+ },
+ {
+ "SHADER_SUFFIX": "f32_ff",
+ "REPLS": {
+ "TYPE" : "f32",
+ },
+ "DECLS": ["FF_BINDINGS", "FF_FUNC", "ROTATE"]
+ },
+ {
+ "SHADER_SUFFIX": "f32_ff_inplace",
+ "REPLS": {
+ "TYPE" : "f32",
+ },
+ "DECLS": ["FF_BINDINGS_INPLACE", "FF_FUNC", "ROTATE_INPLACE"]
+ },
+ {
+ "SHADER_SUFFIX": "f16_ff",
+ "REPLS": {
+ "TYPE" : "f16",
+ },
+ "DECLS": ["FF_BINDINGS", "FF_FUNC", "ROTATE"]
+ },
+ {
+ "SHADER_SUFFIX": "f16_ff_inplace",
+ "REPLS": {
+ "TYPE" : "f16",
+ },
+ "DECLS": ["FF_BINDINGS_INPLACE", "FF_FUNC", "ROTATE_INPLACE"]
+ }
+]
+
+#end(VARIANTS)
+
+#define(DECLS)
+
+#decl(ROTATE)
+fn rotate(i_dst0: u32, i_dst1: u32, out0: f32, out1: f32) {
+ dst[i_dst0] = {{TYPE}}(out0);
+ dst[i_dst1] = {{TYPE}}(out1);
+}
+#enddecl(ROTATE)
+
+#decl(ROTATE_INPLACE)
+fn rotate(i_dst0: u32, i_dst1: u32, out0: f32, out1: f32) {
+ src0[i_dst0] = {{TYPE}}(out0);
+ src0[i_dst1] = {{TYPE}}(out1);
+}
+#enddecl(ROTATE_INPLACE)
+
+#decl(NO_FF_FUNC)
+fn freq_factor(i: u32) -> f32 {
+ return 1.0f;
+}
+#enddecl(NO_FF_FUNC)
+
+#decl(FF_FUNC)
+fn freq_factor(i: u32) -> f32 {
+ return src2[params.offset_src2 + i/2];
+}
+#enddecl(FF_FUNC)
+
+#decl(NO_FF_BINDINGS)
+
+@group(0) @binding(2)
+var<storage, read_write> dst: array<{{TYPE}}>;
+
+@group(0) @binding(3)
+var<uniform> params: Params;
+
+#enddecl(NO_FF_BINDINGS)
+
+#decl(NO_FF_BINDINGS_INPLACE)
+
+@group(0) @binding(2)
+var<uniform> params: Params;
+
+#enddecl(NO_FF_BINDINGS_INPLACE)
+
+#decl(FF_BINDINGS)
+
+@group(0) @binding(2)
+var<storage, read_write> src2: array<f32>;
+
+@group(0) @binding(3)
+var<storage, read_write> dst: array<{{TYPE}}>;
+
+@group(0) @binding(4)
+var<uniform> params: Params;
+
+#enddecl(FF_BINDINGS)
+
+#decl(FF_BINDINGS_INPLACE)
+
+@group(0) @binding(2)
+var<storage, read_write> src2: array<f32>;
+
+@group(0) @binding(3)
+var<uniform> params: Params;
+
+#enddecl(FF_BINDINGS_INPLACE)
+
+#end(DECLS)
+
+#define(SHADER)
+
+enable f16;
+
+struct Params {
+ offset_src0: u32,
+ offset_src1: u32,
+ offset_src2: u32,
+ offset_dst: u32,
+
+ // Strides (in elements)
+ stride_src01: u32,
+ stride_src02: u32,
+ stride_src03: u32,
+
+ stride_dst1: u32,
+ stride_dst2: u32,
+ stride_dst3: u32,
+
+ n_threads: u32,
+ ne0: u32,
+ ne1: u32,
+ ne2: u32,
+
+ n_dims: u32,
+ mode: u32,
+ theta_scale: f32,
+ attn_factor: f32,
+ freq_scale: f32,
+ ext_factor: f32,
+ corr_dim0: f32,
+ corr_dim1: f32,
+ sections0: u32,
+ sections1: u32,
+ sections2: u32,
+ sections3: u32
+};
+
+@group(0) @binding(0)
+var<storage, read_write> src0: array<{{TYPE}}>;
+
+@group(0) @binding(1)
+var<storage, read_write> src1: array<i32>;
+
+DECLS
+
+fn rope_yarn_ramp(low: f32, high: f32, i: u32) -> f32 {
+ let y = (f32(i / 2) - low) / max(0.001f, high - low);
+ return 1.0f - min(1.0f, max(0.0f, y));
+}
+
+// returns vector of (cos_theta, sin_theta)
+// TODO: check the performance of computing these once on the CPU and passing them in as a buffer, since they are recomputed per row
+fn rope_yarn(theta_extrap: f32, i: u32) -> vec2<f32> {
+ var mscale = params.attn_factor;
+ var theta = params.freq_scale * theta_extrap;
+ if (params.ext_factor != 0.0f) {
+ let ramp_mix = rope_yarn_ramp(params.corr_dim0, params.corr_dim1, i) * params.ext_factor;
+ theta = theta * (1 - ramp_mix) + theta_extrap * ramp_mix;
+ mscale *= 1.0f + 0.1f * log(1.0f / params.freq_scale);
+ }
+ return vec2<f32>(cos(theta) * mscale, sin(theta) * mscale);
+}
+
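+// Pair layout: NORM mode rotates adjacent elements (i0, i0 + 1); NEOX/M-RoPE pair i0 / 2 with an
+// offset of n_dims / 2; vision mode pairs with an offset of n_dims.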
+fn pair_base(i0: u32, div_2: bool) -> u32 {
+ if (div_2) {
+ return i0 / 2;
+ } else {
+ return i0;
+ }
+}
+
+fn pair_offset(is_neox: bool, is_mrope: bool, is_vision: bool) -> u32 {
+ if (is_vision) {
+ return params.n_dims;
+ } else if (is_neox || is_mrope) {
+ return params.n_dims / 2;
+ } else {
+ return 1;
+ }
+}
+
+override wg_size: u32;
+@compute @workgroup_size(wg_size)
+fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
+ // two elements per thread
+ if (gid.x >= params.n_threads) {
+ return;
+ }
+
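+    // Rope mode bits (GGML_ROPE_TYPE_*): NEOX = 2, MROPE = 8, VISION = 24, IMROPE = 40.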
+ let is_neox = bool(params.mode & 2);
+ let is_mrope = bool(params.mode & 8);
+ let is_imrope = params.mode == 40;
+ let is_vision = params.mode == 24;
+
+ var i = gid.x * 2; // start index for this thread
+ let i3 = i / (params.ne2 * params.ne1 * params.ne0);
+ i = i % (params.ne2 * params.ne1 * params.ne0);
+ let i2 = i / (params.ne1 * params.ne0);
+ i = i % (params.ne1 * params.ne0);
+ let i1 = i / params.ne0;
+ let i0 = i % params.ne0;
+
+ let i_src_row = params.offset_src0 + i3 * params.stride_src03 + i2 * params.stride_src02 + i1 * params.stride_src01;
+ let i_dst_row = params.offset_dst + i3 * params.stride_dst3 + i2 * params.stride_dst2 + i1 * params.stride_dst1;
+
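+    // Columns beyond n_dims are not rotated: copy the pair through unchanged.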
+ if (i0 >= params.n_dims && !is_vision) {
+ let i_src = i_src_row + i0;
+ let i_dst = i_dst_row + i0;
+ rotate(i_dst, i_dst + 1, f32(src0[i_src]), f32(src0[i_src + 1]));
+ return;
+ }
+
+ var theta_base_mult: u32 = 0;
+ var theta_scale_pwr: u32 = i0 / 2;
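+    // M-RoPE: map the pair index into one of the four position sections (sections0..3) to pick
+    // the position stream and, for vision mode, the frequency power.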
+ if (is_mrope) {
+ let sect_dims = params.sections0 + params.sections1 + params.sections2 + params.sections3;
+ let sec_w = params.sections1 + params.sections0;
+ let sec_e = params.sections2 + sec_w;
+ let sector = (i0 / 2) % sect_dims;
+ if (is_imrope) {
+ if (sector % 3 == 1 && sector < 3 * params.sections1) {
+ theta_base_mult = 1;
+ } else if (sector % 3 == 2 && sector < 3 * params.sections2) {
+ theta_base_mult = 2;
+ } else if (sector % 3 == 0 && sector < 3 * params.sections0) {
+ theta_base_mult = 0;
+ } else {
+ theta_base_mult = 3;
+ }
+ } else {
+ if (sector >= params.sections0 && sector < sec_w) {
+ theta_base_mult = 1;
+ if (is_vision) {
+ theta_scale_pwr = sector - params.sections0;
+ }
+ } else if (sector >= sec_w && sector < sec_e) {
+ theta_base_mult = 2;
+ if (is_vision) {
+ theta_scale_pwr = sector - sec_w;
+ }
+ } else if (sector >= sec_e) {
+ if (is_vision) {
+                    theta_scale_pwr = (i0 / 2) % sec_e;
+ }
+ theta_base_mult = 3;
+ } else if (is_vision) {
+ theta_scale_pwr = sector;
+ }
+ }
+ }
+ let theta_base = f32(src1[params.offset_src1 + i2 + params.ne2 * theta_base_mult]) * pow(params.theta_scale, f32(theta_scale_pwr));
+ let thetas = rope_yarn(theta_base/freq_factor(i0), i0);
+
+ let i_src = i_src_row + pair_base(i0, is_neox || is_mrope || is_vision);
+ let i_dst = i_dst_row + pair_base(i0, is_neox || is_mrope || is_vision);
+
+ let x0 = f32(src0[i_src]);
+ let x1 = f32(src0[i_src + pair_offset(is_neox, is_mrope, is_vision)]);
+ rotate(i_dst, i_dst + pair_offset(is_neox, is_mrope, is_vision), x0 * thetas.x - x1 * thetas.y, x0 * thetas.y + x1 * thetas.x);
+}
+
+#end(SHADER)
diff --git a/llama.cpp/ggml/src/ggml-webgpu/wgsl-shaders/scale.tmpl.wgsl b/llama.cpp/ggml/src/ggml-webgpu/wgsl-shaders/scale.tmpl.wgsl
new file mode 100644
index 0000000..040e80d
--- /dev/null
+++ b/llama.cpp/ggml/src/ggml-webgpu/wgsl-shaders/scale.tmpl.wgsl
@@ -0,0 +1,90 @@
+#define(VARIANTS)
+
+[
+ {
+ "SHADER_NAME": "scale_f32",
+ "DECLS": ["NOT_INPLACE"]
+ },
+ {
+ "SHADER_NAME": "scale_f32_inplace",
+ "DECLS": ["INPLACE"]
+ }
+]
+
+#end(VARIANTS)
+
+#define(DECLS)
+
+#decl(NOT_INPLACE)
+@group(0) @binding(1)
+var<storage, read_write> dst: array<f32>;
+
+@group(0) @binding(2)
+var<uniform> params: Params;
+
+fn store_scale(val: f32, offset: u32) {
+ dst[offset] = val;
+}
+#enddecl(NOT_INPLACE)
+
+#decl(INPLACE)
+@group(0) @binding(1)
+var<uniform> params: Params;
+
+fn store_scale(val: f32, offset: u32) {
+ src[offset] = val;
+}
+#enddecl(INPLACE)
+
+#end(DECLS)
+
+#define(SHADER)
+
+struct Params {
+ offset_src: u32,
+ offset_dst: u32,
+
+ // Strides (in elements)
+ stride_src1: u32,
+ stride_src2: u32,
+ stride_src3: u32,
+
+ stride_dst1: u32,
+ stride_dst2: u32,
+ stride_dst3: u32,
+
+ ne: u32,
+ ne0: u32,
+ ne1: u32,
+ ne2: u32,
+
+ scale: f32,
+ bias: f32
+};
+
+@group(0) @binding(0)
+var<storage, read_write> src: array<f32>;
+
+DECLS
+
+override wg_size: u32;
+@compute @workgroup_size(wg_size)
+fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
+ if (gid.x >= params.ne) {
+ return;
+ }
+
+ var i = gid.x;
+ let i3 = i / (params.ne2 * params.ne1 * params.ne0);
+ i = i % (params.ne2 * params.ne1 * params.ne0);
+ let i2 = i / (params.ne1 * params.ne0);
+ i = i % (params.ne1 * params.ne0);
+ let i1 = i / params.ne0;
+ let i0 = i % params.ne0;
+
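+    // dst = src * scale + bias, applied element-wise.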
+ let i_src = params.offset_src + i3 * params.stride_src3 + i2 * params.stride_src2 + i1 * params.stride_src1 + i0;
+ let i_dst = params.offset_dst + i3 * params.stride_dst3 + i2 * params.stride_dst2 + i1 * params.stride_dst1 + i0;
+
+ store_scale(src[i_src] * params.scale + params.bias, i_dst);
+}
+#end(SHADER)
diff --git a/llama.cpp/ggml/src/ggml-webgpu/wgsl-shaders/set_rows.wgsl b/llama.cpp/ggml/src/ggml-webgpu/wgsl-shaders/set_rows.wgsl
new file mode 100644
index 0000000..99e9192
--- /dev/null
+++ b/llama.cpp/ggml/src/ggml-webgpu/wgsl-shaders/set_rows.wgsl
@@ -0,0 +1,109 @@
+enable f16;
+
+#ifdef DST_F32
+#define DST_INNER_TYPE f32
+#else
+#define DST_INNER_TYPE f16
+#endif
+
+#ifdef VEC4
+#define SRC_TYPE vec4<f32>
+#define DST_TYPE vec4<DST_INNER_TYPE>
+#define VEC_SIZE 4
+#else
+#define SRC_TYPE f32
+#define DST_TYPE DST_INNER_TYPE
+#define VEC_SIZE 1
+#endif
+
+@group(0) @binding(0)
+var<storage, read_write> src: array<SRC_TYPE>;
+
+@group(0) @binding(1)
+var<storage, read_write> idx: array<u32>;
+
+@group(0) @binding(2)
+var<storage, read_write> dst: array<DST_TYPE>;
+
+#ifdef I64_IDX
+@group(0) @binding(3)
+var<storage, read_write> error: atomic<u32>;
+#define PARAMS_BINDING 4
+#else
+#define PARAMS_BINDING 3
+#endif
+
+struct Params {
+ offset_src: u32, // in elements
+ offset_idx: u32, // in elements
+ offset_dst: u32, // in elements
+
+ // Strides (in elements)
+ stride_src1: u32,
+ stride_src2: u32,
+ stride_src3: u32,
+
+ stride_idx0: u32,
+ stride_idx1: u32,
+ stride_idx2: u32,
+
+ stride_dst1: u32,
+ stride_dst2: u32,
+ stride_dst3: u32,
+
+ // Shape of src
+ ne0: u32,
+ n_rows: u32,
+ ne2: u32,
+ ne3: u32,
+
+ // Shape of idx
+ idx1: u32,
+ idx2: u32,
+};
+
+@group(0) @binding(PARAMS_BINDING)
+var<uniform> params: Params;
+
+@compute @workgroup_size(WG_SIZE)
+fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
+ if (gid.x >= (params.ne3 * params.ne2 * params.n_rows * params.ne0) / VEC_SIZE) {
+ return;
+ }
+
+    // recover the source row indices from the flat thread id
+ let elems_per_row = params.ne0 / VEC_SIZE;
+ var i = gid.x / elems_per_row;
+
+ let i_src3 = i / (params.ne2 * params.n_rows);
+
+ i = i % (params.ne2 * params.n_rows);
+ let i_src2 = i / params.n_rows;
+ let i_src1 = i % params.n_rows;
+
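+    // The index tensor broadcasts over the two outer dims (modulo its shape).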
+ let i_idx2 = i_src3 % params.idx2;
+ let i_idx1 = i_src2 % params.idx1;
+ let i_idx0 = i_src1;
+
+#ifdef I64_IDX
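+    // i64 indices: each index occupies two u32 words, low word first (little-endian).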
+    let idx_base = (params.offset_idx + i_idx0 * params.stride_idx0 + i_idx1 * params.stride_idx1 + i_idx2 * params.stride_idx2) * 2;
+
+    let idx_val = idx[idx_base];
+    let idx_high_val = idx[idx_base + 1];
+
+    if (idx_high_val != 0) {
+        // The upper 32 bits of the index are non-zero: it cannot fit in u32, so flag an error
+ atomicStore(&error, 1);
+ return;
+ }
+#else
+ let idx_i = params.offset_idx + i_idx0 * params.stride_idx0 + i_idx1 * params.stride_idx1 + i_idx2 * params.stride_idx2;
+ let idx_val = idx[idx_i];
+#endif
+
+ let i_dst_row = params.offset_dst + idx_val * params.stride_dst1 + i_src2 * params.stride_dst2 + i_src3 * params.stride_dst3;
+ let i_src_row = params.offset_src + i_src1 * params.stride_src1 + i_src2 * params.stride_src2 + i_src3 * params.stride_src3;
+
+ let col_idx = (gid.x % elems_per_row);
+ dst[i_dst_row/VEC_SIZE + col_idx] = DST_TYPE(src[i_src_row/VEC_SIZE + col_idx]);
+}
diff --git a/llama.cpp/ggml/src/ggml-webgpu/wgsl-shaders/soft_max.tmpl.wgsl b/llama.cpp/ggml/src/ggml-webgpu/wgsl-shaders/soft_max.tmpl.wgsl
new file mode 100644
index 0000000..c74dc4c
--- /dev/null
+++ b/llama.cpp/ggml/src/ggml-webgpu/wgsl-shaders/soft_max.tmpl.wgsl
@@ -0,0 +1,345 @@
+#define(VARIANTS)
+[
+ {
+ "SHADER_NAME": "soft_max_f32",
+ "DECLS": ["BASE_BINDINGS", "NOT_INPLACE", "NO_MASK", "NO_SINK"]
+ },
+ {
+ "SHADER_NAME": "soft_max_f32_inplace",
+ "DECLS": ["BASE_BINDINGS_INPLACE", "INPLACE", "NO_MASK", "NO_SINK"]
+ },
+ {
+ "SHADER_NAME": "soft_max_f32_sink",
+ "DECLS": ["SINK_BINDINGS", "NOT_INPLACE", "NO_MASK", "SINK"]
+ },
+ {
+ "SHADER_NAME": "soft_max_f32_sink_inplace",
+ "DECLS": ["SINK_BINDINGS_INPLACE", "INPLACE", "NO_MASK", "SINK"]
+ },
+ {
+ "SHADER_NAME": "soft_max_f32_mask_f32",
+ "REPLS": {
+ "MASK_TYPE" : "f32",
+ },
+ "DECLS": ["MASK_BINDINGS", "NOT_INPLACE", "MASK", "NO_SINK"]
+ },
+ {
+ "SHADER_NAME": "soft_max_f32_mask_f32_inplace",
+ "REPLS": {
+ "MASK_TYPE" : "f32",
+ },
+ "DECLS": ["MASK_BINDINGS_INPLACE", "INPLACE", "MASK", "NO_SINK"]
+ },
+ {
+ "SHADER_NAME": "soft_max_f32_mask_f16",
+ "REPLS": {
+ "MASK_TYPE" : "f16",
+ },
+ "DECLS": ["MASK_BINDINGS", "NOT_INPLACE", "MASK", "NO_SINK"]
+ },
+ {
+ "SHADER_NAME": "soft_max_f32_mask_f16_inplace",
+ "REPLS": {
+ "MASK_TYPE" : "f16",
+ },
+ "DECLS": ["MASK_BINDINGS_INPLACE", "INPLACE", "MASK", "NO_SINK"]
+ },
+ {
+ "SHADER_NAME": "soft_max_f32_mask_f32_sink",
+ "REPLS": {
+ "MASK_TYPE" : "f32",
+ },
+ "DECLS": ["MASK_SINK_BINDINGS", "NOT_INPLACE", "MASK", "SINK"]
+ },
+ {
+ "SHADER_NAME": "soft_max_f32_mask_f32_sink_inplace",
+ "REPLS": {
+ "MASK_TYPE" : "f32",
+ },
+ "DECLS": ["MASK_SINK_BINDINGS_INPLACE", "INPLACE", "MASK", "SINK"]
+ },
+ {
+ "SHADER_NAME": "soft_max_f32_mask_f16_sink",
+ "REPLS": {
+ "MASK_TYPE" : "f16",
+ },
+ "DECLS": ["MASK_SINK_BINDINGS", "NOT_INPLACE", "MASK", "SINK"]
+ },
+ {
+ "SHADER_NAME": "soft_max_f32_mask_f16_sink_inplace",
+ "REPLS": {
+ "MASK_TYPE" : "f16",
+ },
+ "DECLS": ["MASK_SINK_BINDINGS_INPLACE", "INPLACE", "MASK", "SINK"]
+ }
+]
+#end(VARIANTS)
+
+#define(DECLS)
+
+#decl(BASE_BINDINGS)
+@group(0) @binding(1)
+var<storage, read_write> dst: array<f32>;
+
+@group(0) @binding(2)
+var<uniform> params: Params;
+#enddecl(BASE_BINDINGS)
+
+#decl(BASE_BINDINGS_INPLACE)
+@group(0) @binding(1)
+var<uniform> params: Params;
+#enddecl(BASE_BINDINGS_INPLACE)
+
+#decl(SINK_BINDINGS)
+@group(0) @binding(1)
+var<storage, read_write> sinks: array<f32>;
+
+@group(0) @binding(2)
+var<storage, read_write> dst: array<f32>;
+
+@group(0) @binding(3)
+var<uniform> params: Params;
+#enddecl(SINK_BINDINGS)
+
+#decl(SINK_BINDINGS_INPLACE)
+@group(0) @binding(1)
+var<storage, read_write> sinks: array<f32>;
+
+@group(0) @binding(2)
+var<uniform> params: Params;
+#enddecl(SINK_BINDINGS_INPLACE)
+
+#decl(MASK_BINDINGS)
+@group(0) @binding(1)
+var<storage, read_write> mask: array<{{MASK_TYPE}}>;
+
+@group(0) @binding(2)
+var<storage, read_write> dst: array<f32>;
+
+@group(0) @binding(3)
+var<uniform> params: Params;
+#enddecl(MASK_BINDINGS)
+
+#decl(MASK_BINDINGS_INPLACE)
+@group(0) @binding(1)
+var<storage, read_write> mask: array<{{MASK_TYPE}}>;
+
+@group(0) @binding(2)
+var<uniform> params: Params;
+#enddecl(MASK_BINDINGS_INPLACE)
+
+#decl(MASK_SINK_BINDINGS)
+@group(0) @binding(1)
+var<storage, read_write> mask: array<{{MASK_TYPE}}>;
+
+@group(0) @binding(2)
+var<storage, read_write> sinks: array<f32>;
+
+@group(0) @binding(3)
+var<storage, read_write> dst: array<f32>;
+
+@group(0) @binding(4)
+var<uniform> params: Params;
+#enddecl(MASK_SINK_BINDINGS)
+
+#decl(MASK_SINK_BINDINGS_INPLACE)
+@group(0) @binding(1)
+var<storage, read_write> mask: array<{{MASK_TYPE}}>;
+
+@group(0) @binding(2)
+var<storage, read_write> sinks: array<f32>;
+
+@group(0) @binding(3)
+var<uniform> params: Params;
+#enddecl(MASK_SINK_BINDINGS_INPLACE)
+
+#decl(NOT_INPLACE)
+fn inter_value(i: u32) -> f32 {
+ return dst[i];
+}
+
+fn update(i: u32, val: f32) {
+ dst[i] = val;
+}
+#enddecl(NOT_INPLACE)
+
+#decl(INPLACE)
+fn inter_value(i: u32) -> f32 {
+ return src[i];
+}
+
+fn update(i: u32, val: f32) {
+ src[i] = val;
+}
+#enddecl(INPLACE)
+
+#decl(NO_MASK)
+fn mask_val(i: u32) -> f32 {
+ return 0.0;
+}
+#enddecl(NO_MASK)
+
+#decl(MASK)
+fn mask_val(i: u32) -> f32 {
+ return f32(mask[i]);
+}
+#enddecl(MASK)
+
+#decl(NO_SINK)
+fn lower_max_bound(i2: u32) -> f32 {
+ return -1e30;
+}
+
+fn add_sinks(val: f32, i2: u32, max_val: f32) -> f32 {
+ return val;
+}
+#enddecl(NO_SINK)
+
+#decl(SINK)
+fn lower_max_bound(i2: u32) -> f32 {
+ return sinks[params.offset_sinks + i2];
+}
+
+fn add_sinks(val: f32, i2: u32, max_val: f32) -> f32 {
+ return val + exp(sinks[params.offset_sinks + i2] - max_val);
+}
+#enddecl(SINK)
+
+#end(DECLS)
+
+#define(SHADER)
+enable f16;
+
+struct Params {
+ offset_src0: u32,
+ offset_src1: u32,
+ offset_sinks: u32,
+ offset_dst: u32,
+
+ // Strides (in elements)
+ stride_src01: u32,
+ stride_src02: u32,
+ stride_src03: u32,
+
+ stride_src11: u32,
+ stride_src12: u32,
+ stride_src13: u32,
+
+ stride_dst1: u32,
+ stride_dst2: u32,
+ stride_dst3: u32,
+
+ // shape of src0/dst
+ ne: u32,
+ ne0: u32,
+ ne1: u32,
+ ne2: u32,
+
+ // shape of src1
+ ne12: u32,
+ ne13: u32,
+
+ scale: f32,
+ max_bias: f32,
+ n_head_log2: f32,
+ m0: f32,
+ m1: f32,
+};
+
+@group(0) @binding(0)
+var<storage, read_write> src: array<f32>;
+
+DECLS
+
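+// Small per-thread cache for the lowest-indexed columns: values cached in pass 1 stay in
+// registers for passes 2 and 3, skipping a global-memory round trip.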
+const CACHE_SIZE: u32 = 16;
+
+override wg_size: u32;
+var<workgroup> scratch: array<f32, wg_size>;
+
+@compute @workgroup_size(wg_size)
+fn main(@builtin(workgroup_id) wid: vec3<u32>,
+ @builtin(local_invocation_id) lid: vec3<u32>) {
+
+ var i = wid.x;
+ let i3 = i / (params.ne2 * params.ne1);
+ i = i % (params.ne2 * params.ne1);
+ let i2 = i / params.ne1;
+ let i1 = i % params.ne1;
+ let i_src0_row = params.offset_src0 + i3 * params.stride_src03 + i2 * params.stride_src02 + i1 * params.stride_src01;
+ let i_src1_row = params.offset_src1 + (i3 % params.ne13) * params.stride_src13 + (i2 % params.ne12) * params.stride_src12 + i1 * params.stride_src11;
+ let i_dst_row = params.offset_dst + i3 * params.stride_dst3 + i2 * params.stride_dst2 + i1 * params.stride_dst1;
+ let elems = (params.ne0 + wg_size - 1) / wg_size;
+
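+    // ALiBi slope for this head: pow(m0, head + 1) below n_head_log2, otherwise
+    // pow(m1, 2 * (head - n_head_log2) + 1); 1 when max_bias <= 0.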
+ let head = f32(i2);
+    let slope = select(
+        1,
+        select(pow(params.m1, 2 * (head - params.n_head_log2) + 1),
+               pow(params.m0, head + 1),
+               head < params.n_head_log2),
+        params.max_bias > 0);
+
+ var cache: array<f32, CACHE_SIZE>;
+
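+    // Pass 1: find the row maximum (for numerically stable softmax).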
+ var max_val = lower_max_bound(i2);
+ var col = lid.x;
+ for (var j: u32 = 0; j < elems; j++) {
+ if (col >= params.ne0) {
+ break;
+ }
+ let val = src[i_src0_row + col] * params.scale + slope * mask_val(i_src1_row + col);
+ max_val = max(max_val, val);
+ if (col < CACHE_SIZE) {
+ cache[col] = val;
+ }
+ col += wg_size;
+ }
+
+ scratch[lid.x] = max_val;
+ workgroupBarrier();
+ var offset = wg_size / 2;
+ while (offset > 0) {
+ if (lid.x < offset) {
+ scratch[lid.x] = max(scratch[lid.x], scratch[lid.x + offset]);
+ }
+ offset = offset / 2;
+ workgroupBarrier();
+ }
+ let row_max = scratch[0];
+ workgroupBarrier();
+
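+    // Pass 2: compute exp(x - max) and accumulate the row sum.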
+ var sum = 0.0f;
+ col = lid.x;
+ for (var j: u32 = 0; j < elems; j++) {
+ if (col >= params.ne0) {
+ break;
+ }
+ let val = select(src[i_src0_row + col] * params.scale + slope * mask_val(i_src1_row + col),
+ cache[col], col < CACHE_SIZE);
+ let ex = exp(val - row_max);
+ sum += ex;
+ if (col < CACHE_SIZE) {
+ cache[col] = ex;
+ } else {
+ update(i_dst_row + col, ex);
+ }
+ col += wg_size;
+ }
+
+ scratch[lid.x] = sum;
+ workgroupBarrier();
+ offset = wg_size / 2;
+ while (offset > 0) {
+ if (lid.x < offset) {
+ scratch[lid.x] += scratch[lid.x + offset];
+ }
+ offset = offset / 2;
+ workgroupBarrier();
+ }
+ let row_sum = add_sinks(scratch[0], i2, row_max);
+
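+    // Pass 3: scale everything by 1 / row_sum.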
+ let sum_recip = 1.0 / row_sum;
+ col = lid.x;
+ for (var j: u32 = 0; j < elems; j++) {
+ if (col >= params.ne0) {
+ break;
+ }
+ update(i_dst_row + col, select(inter_value(i_dst_row + col), cache[col], col < CACHE_SIZE) * sum_recip);
+ col += wg_size;
+ }
+}
+#end(SHADER)
diff --git a/llama.cpp/ggml/src/ggml-webgpu/wgsl-shaders/sum_rows.wgsl b/llama.cpp/ggml/src/ggml-webgpu/wgsl-shaders/sum_rows.wgsl
new file mode 100644
index 0000000..6ea2de9
--- /dev/null
+++ b/llama.cpp/ggml/src/ggml-webgpu/wgsl-shaders/sum_rows.wgsl
@@ -0,0 +1,55 @@
+@group(0) @binding(0)
+var<storage, read_write> src: array<f32>;
+
+@group(0) @binding(1)
+var<storage, read_write> dst: array<f32>;
+
+struct Params {
+ offset_src: u32, // in elements
+ offset_dst: u32, // in elements
+
+ // Strides (in elements)
+ stride_src1: u32,
+ stride_src2: u32,
+ stride_src3: u32,
+
+ ne0: u32,
+ ne1: u32,
+ ne2: u32
+};
+
+@group(0) @binding(2)
+var<uniform> params: Params;
+
+var<workgroup> shared_sum: array<f32, WG_SIZE>;
+
+@compute @workgroup_size(WG_SIZE)
+fn main(@builtin(workgroup_id) wid: vec3<u32>,
+ @builtin(local_invocation_id) lid: vec3<u32>) {
+
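+    // One workgroup per row: unravel wid.x into (i3, i2, i1).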
+ var i = wid.x;
+ let i3 = i / (params.ne2 * params.ne1);
+ i = i % (params.ne2 * params.ne1);
+ let i2 = i / params.ne1;
+ let i1 = i % params.ne1;
+ let i_src_row = params.offset_src + i3 * params.stride_src3 + i2 * params.stride_src2 + i1 * params.stride_src1;
+ var local_sum: f32 = 0.0;
+ for (var col = lid.x; col < params.ne0; col += WG_SIZE) {
+ local_sum += src[i_src_row + col];
+ }
+ shared_sum[lid.x] = local_sum;
+ workgroupBarrier();
+    // tree-reduce within the workgroup (assumes WG_SIZE is a power of two)
+ var offset: u32 = WG_SIZE >> 1;
+ while (offset > 0) {
+ if (lid.x < offset) {
+ shared_sum[lid.x] = shared_sum[lid.x] + shared_sum[lid.x + offset];
+ }
+ workgroupBarrier();
+ offset >>= 1;
+ }
+
+ if (lid.x == 0) {
+ dst[params.offset_dst + wid.x] = shared_sum[0];
+ }
+}
diff --git a/llama.cpp/ggml/src/ggml-webgpu/wgsl-shaders/unary.wgsl b/llama.cpp/ggml/src/ggml-webgpu/wgsl-shaders/unary.wgsl
new file mode 100644
index 0000000..d639d98
--- /dev/null
+++ b/llama.cpp/ggml/src/ggml-webgpu/wgsl-shaders/unary.wgsl
@@ -0,0 +1,179 @@
+#ifdef TYPE_F16
+enable f16;
+#define TYPE f16
+#else
+#define TYPE f32
+#endif
+
+@group(0) @binding(0)
+var<storage, read_write> src: array<TYPE>;
+
+#ifndef INPLACE
+@group(0) @binding(1)
+var<storage, read_write> dst: array<TYPE>;
+#define PARAMS_BINDING 2
+#else
+#define PARAMS_BINDING 1
+#endif
+
+struct Params {
+ ne: u32, // total number of elements
+ offset_src: u32, // in elements
+ offset_dst: u32, // in elements
+
+ // Strides (in elements)
+ stride_src0: u32,
+ stride_src1: u32,
+ stride_src2: u32,
+ stride_src3: u32,
+
+ // Logical shapes
+ ne0: u32,
+ ne1: u32,
+ ne2: u32,
+#ifdef CLAMP
+ clamp_min: f32,
+ clamp_max: f32,
+#endif
+#ifdef FILL
+ fill_val: f32,
+#endif
+#ifdef XIELU
+ alpha_n: f32,
+ alpha_p: f32,
+ beta: f32,
+ eps: f32,
+#endif
+
+};
+
+@group(0) @binding(PARAMS_BINDING)
+var<uniform> params: Params;
+
+@compute @workgroup_size(WG_SIZE)
+fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
+ if (gid.x >= params.ne) {
+ return;
+ }
+ var i = gid.x;
+ let i3 = i / (params.ne2 * params.ne1 * params.ne0);
+ i = i % (params.ne2 * params.ne1 * params.ne0);
+ let i2 = i / (params.ne1 * params.ne0);
+ i = i % (params.ne1 * params.ne0);
+ let i1 = i / params.ne0;
+ let i0 = i % params.ne0;
+
+ let src_idx = i0 * params.stride_src0 + i1 * params.stride_src1 +
+ i2 * params.stride_src2 + i3 * params.stride_src3;
+
+#ifdef ABS
+ let res = abs(src[params.offset_src + src_idx]);
+#endif
+#ifdef SGN
+ let res = select(TYPE(select(0.0, -1.0, src[params.offset_src + src_idx] < 0.0)), TYPE(1.0),
+ src[params.offset_src + src_idx] > 0.0);
+#endif
+#ifdef NEG
+ let res = -src[params.offset_src + src_idx];
+#endif
+#ifdef STEP
+ let res = TYPE(select(0.0, 1.0, src[params.offset_src + src_idx] > 0.0));
+#endif
+#ifdef TANH
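+    // The clamp keeps tanh's argument in a range where the result is already saturated,
+    // avoiding NaNs from naive exp-based tanh implementations on some backends.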
+ let res = tanh(clamp(src[params.offset_src + src_idx], -9.010913, 9.010913));
+#endif
+#ifdef RELU
+ let res = select(0.0, src[params.offset_src + src_idx], src[params.offset_src + src_idx] > 0.0);
+#endif
+#ifdef ELU
+ let res = select(exp(src[params.offset_src + src_idx]) - 1.0, src[params.offset_src + src_idx],
+ src[params.offset_src + src_idx] > 0.0);
+#endif
+#ifdef HARDSIGMOID
+ let res = min(1.0, max(0.0, (src[params.offset_src + src_idx] + 3.0) / 6.0));
+#endif
+#ifdef SIGMOID
+ let res = 1.0 / (1.0 + exp(-src[params.offset_src + src_idx]));
+#endif
+#ifdef SILU
+ let res = src[params.offset_src + src_idx] / (1.0 + exp(-src[params.offset_src + src_idx]));
+#endif
+#ifdef EXP
+ let res = exp(src[params.offset_src + src_idx]);
+#endif
+#ifdef LOG
+ let res = TYPE(log(f32(src[params.offset_src + src_idx])));
+#endif
+#ifdef CLAMP
+ let res = clamp(src[params.offset_src + src_idx], TYPE(params.clamp_min), TYPE(params.clamp_max));
+#endif
+#ifdef FILL
+ let res = TYPE(params.fill_val);
+#endif
+#ifdef HARDSWISH
+ let res = src[params.offset_src + src_idx] *
+ min(1.0, max(0.0, (src[params.offset_src + src_idx] + 3.0) / 6.0));
+#endif
+#ifdef GELU
+ let res = 0.5 * src[params.offset_src + src_idx] *
+ (1.0 + tanh(clamp(sqrt(2.0 / 3.14159265) *
+ (src[params.offset_src + src_idx] +
+ 0.044715 * pow(src[params.offset_src + src_idx], 3.0)),
+ -9.010913, 9.010913)));
+#endif
+#ifdef GELU_QUICK
+ let res = src[params.offset_src + src_idx] * 0.5 *
+ (1.0 + tanh(clamp(0.79788456 *
+ (src[params.offset_src + src_idx] +
+ 0.044715 * src[params.offset_src + src_idx] *
+ src[params.offset_src + src_idx] * src[params.offset_src + src_idx]),
+ -9.010913, 9.010913)));
+#endif
+#ifdef GELU_ERF
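+    // NOTE: WGSL has no erf builtin, so this variant falls back to the same tanh-based
+    // approximation as GELU_QUICK.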
+ let res = 0.5 * src[params.offset_src + src_idx] *
+ (1.0 + tanh(clamp(0.79788456 *
+ (src[params.offset_src + src_idx] +
+ 0.044715 * src[params.offset_src + src_idx] *
+ src[params.offset_src + src_idx] * src[params.offset_src + src_idx]),
+ -9.010913, 9.010913)));
+#endif
+#ifdef XIELU
+ let res =
+ select(((exp(min(src[params.offset_src + src_idx], TYPE(params.eps))) - 1.0) -
+ src[params.offset_src + src_idx]) *
+ TYPE(params.alpha_n) +
+ TYPE(params.beta) * src[params.offset_src + src_idx],
+ TYPE(params.alpha_p) * src[params.offset_src + src_idx] *
+ src[params.offset_src + src_idx] +
+ TYPE(params.beta) * src[params.offset_src + src_idx],
+ src[params.offset_src + src_idx] > 0.0);
+#endif
+#ifdef SOFTPLUS
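+    // softplus: for x > 20, log(1 + exp(x)) ~= x, and the branch avoids exp() overflow.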
+ let src_f32 = f32(src[params.offset_src + src_idx]);
+ let res = TYPE(select(log(1.0 + exp(src_f32)), src_f32, src_f32 > 20.0));
+#endif
+#ifdef EXPM1
+ let res = exp(src[params.offset_src + src_idx]) - 1.0;
+#endif
+#ifdef FLOOR
+ let res = floor(src[params.offset_src + src_idx]);
+#endif
+#ifdef CEIL
+ let res = ceil(src[params.offset_src + src_idx]);
+#endif
+#ifdef ROUND
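+    // WGSL round() rounds halves to even; this implements round-half-away-from-zero
+    // to match roundf() semantics.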
+ let src_f32 = f32(src[params.offset_src + src_idx]);
+ let result = select(ceil(src_f32 - 0.5), floor(src_f32 + 0.5), src_f32 >= 0.0);
+ let res = TYPE(result);
+#endif
+#ifdef TRUNC
+ let res = trunc(src[params.offset_src + src_idx]);
+#endif
+
+#ifdef INPLACE
+ src[params.offset_src + src_idx] = res;
+#else
+ dst[params.offset_dst + gid.x] = res;
+#endif
+}