summaryrefslogtreecommitdiff
path: root/llama.cpp/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_decls.tmpl
diff options
context:
space:
mode:
authorMitja Felicijan <mitja.felicijan@gmail.com>2026-02-12 20:57:17 +0100
committerMitja Felicijan <mitja.felicijan@gmail.com>2026-02-12 20:57:17 +0100
commitb333b06772c89d96aacb5490d6a219fba7c09cc6 (patch)
tree211df60083a5946baa2ed61d33d8121b7e251b06 /llama.cpp/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_decls.tmpl
downloadllmnpc-b333b06772c89d96aacb5490d6a219fba7c09cc6.tar.gz
Engage!
Diffstat (limited to 'llama.cpp/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_decls.tmpl')
-rw-r--r--llama.cpp/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_decls.tmpl97
1 files changed, 97 insertions, 0 deletions
diff --git a/llama.cpp/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_decls.tmpl b/llama.cpp/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_decls.tmpl
new file mode 100644
index 0000000..109ff8d
--- /dev/null
+++ b/llama.cpp/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_decls.tmpl
@@ -0,0 +1,97 @@
+#decl(SHMEM_VEC)
+fn store_shmem(val: vec4<f16>, idx: u32) {
+ shmem[idx] = val.x;
+ shmem[idx + 1] = val.y;
+ shmem[idx + 2] = val.z;
+ shmem[idx + 3] = val.w;
+}
+#enddecl(SHMEM_VEC)
+
+#decl(SHMEM_SCALAR)
+fn store_shmem(val: f16, idx: u32) {
+ shmem[idx] = val;
+}
+#enddecl(SHMEM_SCALAR)
+
+#decl(INIT_SRC0_SHMEM_FLOAT)
+
+fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u32) {
+ for (var elem_idx = thread_id * {{VEC_SIZE}}; elem_idx < TILE_SRC0_SHMEM; elem_idx += TOTAL_WORKGROUP_SIZE * {{VEC_SIZE}}) {
+ let tile_m = elem_idx / TILE_K;
+ let tile_k = elem_idx % TILE_K;
+ let global_m = offset_m + tile_m;
+ let global_k = k_outer + tile_k;
+ let src0_idx = batch_offset + global_m * params.stride_01 + global_k;
+ let src0_val = select( // taking a slight performance hit to avoid oob
+ {{SRC0_TYPE}}(0.0),
+ src0[src0_idx/{{VEC_SIZE}}],
+ global_m < params.m && global_k < params.k);
+ store_shmem({{SHMEM_TYPE}}(src0_val), elem_idx);
+ }
+}
+
+#enddecl(INIT_SRC0_SHMEM_FLOAT)
+
+#decl(INIT_SRC1_SHMEM)
+
+fn init_shmem_src1(thread_id: u32, batch_offset: u32, offset_n: u32, k_outer: u32) {
+ for (var elem_idx = thread_id * {{VEC_SIZE}}; elem_idx < TILE_SRC1_SHMEM; elem_idx += TOTAL_WORKGROUP_SIZE * {{VEC_SIZE}}) {
+ let tile_n = elem_idx / TILE_K;
+ let tile_k = elem_idx % TILE_K;
+ let global_n = offset_n + tile_n;
+ let global_k = k_outer + tile_k;
+ let src1_idx = batch_offset + global_n * params.stride_11 + global_k;
+ let src1_val = select(
+ {{SRC1_TYPE}}(0.0),
+ src1[src1_idx/{{VEC_SIZE}}],
+ global_n < params.n && global_k < params.k);
+ store_shmem({{SHMEM_TYPE}}(src1_val), TILE_SRC0_SHMEM + elem_idx);
+ }
+}
+
+#enddecl(INIT_SRC1_SHMEM)
+
+#decl(INIT_SRC0_SHMEM_Q4_0)
+
+const BLOCK_SIZE = 32u;
+// the number of blocks per k-tile. Note that this currently only works if TILE_K is a multiple of BLOCK_SIZE, which may need to be rethought for larger quantized types.
+override BLOCKS_K = TILE_K/BLOCK_SIZE;
+const NQ = 16u;
+const F16_PER_BLOCK = 9u; // 1 scale + 8x4 packed weights
+const WEIGHTS_PER_F16 = 4u; // 4 weights per f16
+const F16_PER_THREAD = NQ / WEIGHTS_PER_F16;
+
+fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u32) {
+ for (var i = thread_id * NQ; i < TILE_SRC0_SHMEM; i += TOTAL_WORKGROUP_SIZE * NQ) {
+ let blck_idx = i / BLOCK_SIZE;
+ let block_offset = (i % BLOCK_SIZE) / WEIGHTS_PER_F16;
+ let shmem_idx = blck_idx * BLOCK_SIZE + block_offset * 2u;
+
+ let tile_m = blck_idx / BLOCKS_K;
+ let global_m = offset_m + tile_m;
+ let block_k = blck_idx % BLOCKS_K;
+ let global_k = k_outer / BLOCK_SIZE + block_k;
+
+ if (global_m < params.m && global_k < params.k / BLOCK_SIZE) {
+ let src0_idx = batch_offset + global_m * params.stride_01 + global_k;
+ let scale_idx = src0_idx * F16_PER_BLOCK;
+ let d = src0[scale_idx];
+
+ for (var j = 0u; j < F16_PER_THREAD; j += 2) {
+ let q_0 = src0[scale_idx + 1u + block_offset + j];
+ let q_1 = src0[scale_idx + 1u + block_offset + j + 1];
+
+ let q_packed = bitcast<u32>(vec2(q_0, q_1));
+ for (var k = 0u; k < 4u; k++) {
+ let q_byte = get_byte(q_packed, k);
+ let q_hi = (f16((q_byte >> 4) & 0xF) - 8.0) * d;
+ let q_lo = (f16(q_byte & 0xF) - 8.0) * d;
+ shmem[shmem_idx + j * 2 + k] = q_lo;
+ shmem[shmem_idx + j * 2 + k + 16u] = q_hi;
+ }
+ }
+ }
+ }
+}
+
+#enddecl(INIT_SRC0_SHMEM_Q4_0)