summaryrefslogtreecommitdiff
path: root/llama.cpp/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_subgroup_matrix.tmpl.wgsl
diff options
context:
space:
mode:
authorMitja Felicijan <mitja.felicijan@gmail.com>2026-02-12 20:57:17 +0100
committerMitja Felicijan <mitja.felicijan@gmail.com>2026-02-12 20:57:17 +0100
commitb333b06772c89d96aacb5490d6a219fba7c09cc6 (patch)
tree211df60083a5946baa2ed61d33d8121b7e251b06 /llama.cpp/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_subgroup_matrix.tmpl.wgsl
downloadllmnpc-b333b06772c89d96aacb5490d6a219fba7c09cc6.tar.gz
Engage!
Diffstat (limited to 'llama.cpp/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_subgroup_matrix.tmpl.wgsl')
-rw-r--r--llama.cpp/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_subgroup_matrix.tmpl.wgsl302
1 files changed, 302 insertions, 0 deletions
diff --git a/llama.cpp/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_subgroup_matrix.tmpl.wgsl b/llama.cpp/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_subgroup_matrix.tmpl.wgsl
new file mode 100644
index 0000000..47c8ce3
--- /dev/null
+++ b/llama.cpp/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_subgroup_matrix.tmpl.wgsl
@@ -0,0 +1,302 @@
1#define(VARIANTS)
2[
3 {
4 "SHADER_SUFFIX": "f32_f32_vec",
5 "REPLS": {
6 "SRC0_TYPE" : "vec4<f32>",
7 "SRC1_TYPE" : "vec4<f32>",
8 "DST_TYPE" : "vec4<f32>",
9 "SHMEM_TYPE" : "vec4<f16>",
10 "VEC_SIZE" : 4,
11 },
12 "DECLS": ["VEC", "SHMEM_VEC", "INIT_SRC0_SHMEM_FLOAT", "INIT_SRC1_SHMEM"]
13 },
14 {
15 "SHADER_SUFFIX": "f32_f32",
16 "REPLS": {
17 "SRC0_TYPE" : "f32",
18 "SRC1_TYPE" : "f32",
19 "DST_TYPE" : "f32",
20 "SHMEM_TYPE" : "f16",
21 "VEC_SIZE" : 1,
22 },
23 "DECLS": ["SCALAR", "SHMEM_SCALAR", "INIT_SRC0_SHMEM_FLOAT", "INIT_SRC1_SHMEM"]
24 },
25 {
26 "SHADER_SUFFIX": "f16_f32_vec",
27 "REPLS": {
28 "SRC0_TYPE" : "vec4<f16>",
29 "SRC1_TYPE" : "vec4<f32>",
30 "DST_TYPE" : "vec4<f32>",
31 "SHMEM_TYPE" : "vec4<f16>",
32 "VEC_SIZE" : 4,
33 },
34 "DECLS": ["VEC", "SHMEM_VEC", "INIT_SRC0_SHMEM_FLOAT", "INIT_SRC1_SHMEM"]
35 },
36 {
37 "SHADER_SUFFIX": "f16_f32",
38 "REPLS": {
39 "SRC0_TYPE" : "f16",
40 "SRC1_TYPE" : "f32",
41 "DST_TYPE" : "f32",
42 "SHMEM_TYPE" : "f16",
43 "VEC_SIZE" : 1,
44 },
45 "DECLS": ["SCALAR", "SHMEM_SCALAR", "INIT_SRC0_SHMEM_FLOAT", "INIT_SRC1_SHMEM"]
46 },
47 {
48 "SHADER_SUFFIX": "f16_f16_vec",
49 "REPLS": {
50 "SRC0_TYPE" : "vec4<f16>",
51 "SRC1_TYPE" : "vec4<f16>",
52 "DST_TYPE" : "vec4<f32>",
53 "SHMEM_TYPE" : "vec4<f16>",
54 "VEC_SIZE" : 4,
55 },
56 "DECLS": ["VEC", "SHMEM_VEC", "INIT_SRC0_SHMEM_FLOAT", "INIT_SRC1_SHMEM"]
57 },
58 {
59 "SHADER_SUFFIX": "f16_f16",
60 "REPLS": {
61 "SRC0_TYPE" : "f16",
62 "SRC1_TYPE" : "f16",
63 "DST_TYPE" : "f32",
64 "SHMEM_TYPE" : "f16",
65 "VEC_SIZE" : 1,
66 },
67 "DECLS": ["SCALAR", "SHMEM_SCALAR", "INIT_SRC0_SHMEM_FLOAT", "INIT_SRC1_SHMEM"]
68 },
69 {
70 "SHADER_SUFFIX": "q4_0_f32_vec",
71 "REPLS": {
72 "SRC0_TYPE" : "f16",
73 "SRC1_TYPE" : "vec4<f32>",
74 "DST_TYPE" : "vec4<f32>",
75 "SHMEM_TYPE" : "vec4<f16>",
76 "VEC_SIZE" : 4,
77 },
78 "DECLS": ["BYTE_HELPERS", "VEC", "SHMEM_VEC", "INIT_SRC0_SHMEM_Q4_0", "INIT_SRC1_SHMEM"]
79 },
80 {
81 "SHADER_SUFFIX": "q4_0_f32",
82 "REPLS": {
83 "SRC0_TYPE" : "f16",
84 "SRC1_TYPE" : "f32",
85 "DST_TYPE" : "f32",
86 "SHMEM_TYPE" : "f16",
87 "VEC_SIZE" : 1,
88 },
89 "DECLS": ["BYTE_HELPERS", "SCALAR", "SHMEM_SCALAR", "INIT_SRC0_SHMEM_Q4_0", "INIT_SRC1_SHMEM"]
90 }
91]
92
93#end(VARIANTS)
94
95#define(DECLS)
96
#decl(VEC)
// Widen four consecutive f16 staging values from workgroup shmem into one
// vec4<f32> element of dst. shmem_idx is an element (scalar) index; dst_idx
// indexes vec4 elements of the dst buffer.
fn store_dst(shmem_idx: u32, dst_idx: u32) {
    var widened: vec4<f32>;
    for (var lane = 0u; lane < 4u; lane++) {
        widened[lane] = f32(shmem[shmem_idx + lane]);
    }
    dst[dst_idx] = widened;
}
#enddecl(VEC)
107
#decl(SCALAR)
// Copy one f16 staging value from workgroup shmem into dst, widening to f32.
// shmem_idx and dst_idx are scalar element indices.
fn store_dst(shmem_idx: u32, dst_idx: u32) {
    dst[dst_idx] = f32(shmem[shmem_idx]);
}
#enddecl(SCALAR)
113
114#end(DECLS)
115
#define(SHADER)
diagnostic(off, chromium.subgroup_matrix_uniformity);
enable f16;
enable subgroups;
enable chromium_experimental_subgroup_matrix;

// Per-dispatch parameters for batched matrix multiplication.
// src0 is m x k, src1 is k x n (stored transposed), dst is m x n (stored
// transposed / column-major: dst index = col * m + row, see main()).
struct MulMatParams {
    offset_src0: u32,  // element offset of src0 data within its buffer
    offset_src1: u32,  // element offset of src1 data within its buffer
    offset_dst: u32,   // element offset of dst data within its buffer
    m: u32,            // rows of src0 / dst
    n: u32,            // columns of dst
    k: u32,            // shared reduction dimension
    stride_01: u32,    // src0 row stride — presumably consumed by the init_shmem_* decls (mul_mat_decls.tmpl); not used here
    stride_11: u32,    // src1 row stride — same note as stride_01
    stride_02: u32,    // src0 element stride for batch dim 2
    stride_12: u32,    // src1 element stride for batch dim 2
    stride_03: u32,    // src0 element stride for batch dim 3
    stride_13: u32,    // src1 element stride for batch dim 3
    bs02: u32,         // src0 batch size in dim 2
    bs03: u32,         // src0 batch size in dim 3
    broadcast2: u32,   // src1/dst broadcast factor over src0 in dim 2
    broadcast3: u32    // src1/dst broadcast factor over src0 in dim 3
};

@group(0) @binding(0) var<storage, read_write> src0: array<{{SRC0_TYPE}}>; // M rows, K columns
@group(0) @binding(1) var<storage, read_write> src1: array<{{SRC1_TYPE}}>; // K rows, N columns (transposed)
@group(0) @binding(2) var<storage, read_write> dst: array<{{DST_TYPE}}>; // M rows, N columns (transposed)

@group(0) @binding(3) var<uniform> params: MulMatParams;

DECLS

// Note: These are string interpolated at build time, cannot use override constants due to limitations in
// current Dawn version type definitions/matrix load requirements for constant memory sizes.
// Subgroup grid per workgroup: SUBGROUP_M x SUBGROUP_N subgroups.
const SUBGROUP_M = {{WEBGPU_SUBGROUP_M}}u;
const SUBGROUP_N = {{WEBGPU_SUBGROUP_N}}u;
// For portability we assume the max subgroup size, meaning some subgroups will be masked out if the
// runtime subgroup size is smaller.
const MAX_SUBGROUP_SIZE = {{WEBGPU_MAX_SUBGROUP_SIZE}}u;

// Number of subgroups that actually participate in the matrix math.
const EXPECTED_SUBGROUPS = SUBGROUP_M * SUBGROUP_N;

// Dimensions of one hardware subgroup-matrix tile (M x K times K x N).
const SUBGROUP_MATRIX_M_SIZE = {{WEBGPU_SG_MAT_M_SIZE}}u;
const SUBGROUP_MATRIX_N_SIZE = {{WEBGPU_SG_MAT_N_SIZE}}u;
const SUBGROUP_MATRIX_K_SIZE = {{WEBGPU_SG_MAT_K_SIZE}}u;

// How many subgroup-matrix tiles each subgroup accumulates along M and N.
const SUBGROUP_MATRIX_M = {{WEBGPU_SUBGROUP_MATRIX_M}}u;
const SUBGROUP_MATRIX_N = {{WEBGPU_SUBGROUP_MATRIX_N}}u;

// K-extent of the shared-memory tile loaded per outer loop iteration.
const TILE_K = {{WEBGPU_TILE_K}}u;

// Full workgroup output tile extent along M and N, in elements.
const WG_M_SG_TILE_SIZE = SUBGROUP_M * SUBGROUP_MATRIX_M * SUBGROUP_MATRIX_M_SIZE;
const WG_N_SG_TILE_SIZE = SUBGROUP_N * SUBGROUP_MATRIX_N * SUBGROUP_MATRIX_N_SIZE;

const TOTAL_WORKGROUP_SIZE = SUBGROUP_M * SUBGROUP_N * MAX_SUBGROUP_SIZE;
// Shared-memory element counts for the staged src0 and src1 tiles.
const TILE_SRC0_SHMEM = TILE_K * SUBGROUP_M * SUBGROUP_MATRIX_M * SUBGROUP_MATRIX_M_SIZE;
const TILE_SRC1_SHMEM = TILE_K * SUBGROUP_N * SUBGROUP_MATRIX_N * SUBGROUP_MATRIX_N_SIZE;

// Elements needed to stage all accumulator tiles before the final write-out.
const SG_MAT_ACCUM_SHMEM = SUBGROUP_M * SUBGROUP_MATRIX_M * SUBGROUP_N * SUBGROUP_MATRIX_N * SUBGROUP_MATRIX_M_SIZE * SUBGROUP_MATRIX_N_SIZE;

// We reuse shmem for accumulation matrices
const SHMEM_SIZE = max(TILE_SRC0_SHMEM + TILE_SRC1_SHMEM, SG_MAT_ACCUM_SHMEM);

// Single f16 scratch buffer: holds src0+src1 tiles during the K loop, then
// the accumulator tiles during write-out (separated by workgroupBarrier).
var<workgroup> shmem: array<f16, SHMEM_SIZE>;
// Tiled, batched mat-mul using experimental subgroup-matrix instructions.
// Each workgroup computes one WG_M_SG_TILE_SIZE x WG_N_SG_TILE_SIZE output
// tile of one batch matrix: stage src0/src1 tiles into shmem, accumulate
// with subgroupMatrixMultiplyAccumulate over K, then stage the accumulators
// back through shmem for a bounds-checked cooperative write to dst.
@compute @workgroup_size(TOTAL_WORKGROUP_SIZE)
fn main(@builtin(workgroup_id) wg_id: vec3<u32>,
        @builtin(local_invocation_id) local_id: vec3<u32>,
        @builtin(subgroup_id) subgroup_id: u32) {

    let thread_id = local_id.x;
    // Position of this subgroup within the workgroup's SUBGROUP_M x SUBGROUP_N grid.
    let subgroup_m = subgroup_id % SUBGROUP_M;
    let subgroup_n = subgroup_id / SUBGROUP_M;

    // Workgroup tile counts along M and N (ceiling division for ragged edges).
    let wg_m_count = (params.m + WG_M_SG_TILE_SIZE - 1) / WG_M_SG_TILE_SIZE;
    let wg_n_count = (params.n + WG_N_SG_TILE_SIZE - 1) / WG_N_SG_TILE_SIZE;
    let wg_per_matrix = wg_m_count * wg_n_count;

    // Flatten dispatch: consecutive groups of wg_per_matrix workgroups handle
    // consecutive batch matrices.
    let batch_idx = wg_id.x / wg_per_matrix;

    let wg_in_batch = wg_id.x % wg_per_matrix;
    let wg_m = wg_in_batch % wg_m_count;
    let wg_n = wg_in_batch / wg_m_count;

    // dst batch strides in elements (dst has no padding between matrices).
    let dst2_stride = params.m * params.n;
    let dst3_stride = dst2_stride * params.bs02 * params.broadcast2;

    // Map the flat batch index to per-tensor dim-2/dim-3 indices; src0 indices
    // are divided by the broadcast factors so several dst matrices share it.
    let dst3_idx = batch_idx / (params.bs02 * params.broadcast2);
    let src03_idx = dst3_idx / params.broadcast3;
    let src13_idx = dst3_idx;
    let dst2_idx = batch_idx % (params.bs02 * params.broadcast2);
    let src02_idx = dst2_idx / params.broadcast2;
    let src12_idx = dst2_idx;

    let src0_batch_offset = params.offset_src0 + src03_idx * params.stride_03 + src02_idx * params.stride_02;
    let src1_batch_offset = params.offset_src1 + src13_idx * params.stride_13 + src12_idx * params.stride_12;

    // First output row/column handled by this workgroup's tile.
    let offset_m = wg_m * SUBGROUP_M * SUBGROUP_MATRIX_M * SUBGROUP_MATRIX_M_SIZE;
    let offset_n = wg_n * SUBGROUP_N * SUBGROUP_MATRIX_N * SUBGROUP_MATRIX_N_SIZE;

    // Per-subgroup accumulator tiles, indexed [m][n]; zero-initialized.
    var acc_sg_mat : array<array<subgroup_matrix_result<f16, SUBGROUP_MATRIX_N_SIZE, SUBGROUP_MATRIX_M_SIZE>, SUBGROUP_MATRIX_N>, SUBGROUP_MATRIX_M>;

    for (var k_outer = 0u; k_outer < params.k; k_outer += TILE_K) {

        // see mul_mat_decls.tmpl
        init_shmem_src0(thread_id, src0_batch_offset, offset_m, k_outer);
        init_shmem_src1(thread_id, src1_batch_offset, offset_n, k_outer);

        // All staged data must be visible before subgroup-matrix loads.
        workgroupBarrier();

        // Subgroups beyond EXPECTED_SUBGROUPS (possible when the runtime
        // subgroup size is below MAX_SUBGROUP_SIZE) only help with staging.
        if (subgroup_id < EXPECTED_SUBGROUPS) {

            for (var k_inner = 0u; k_inner < TILE_K; k_inner += SUBGROUP_MATRIX_K_SIZE) {

                // src0 tiles are staged row-major with a TILE_K stride.
                let src0_shmem_idx_base = subgroup_m * SUBGROUP_MATRIX_M * SUBGROUP_MATRIX_M_SIZE * TILE_K + k_inner;
                var src0_sg_mats: array<subgroup_matrix_left<f16, SUBGROUP_MATRIX_K_SIZE, SUBGROUP_MATRIX_M_SIZE>, SUBGROUP_MATRIX_M>;
                for (var m = 0u; m < SUBGROUP_MATRIX_M; m++) {
                    src0_sg_mats[m] = subgroupMatrixLoad<subgroup_matrix_left<f16, SUBGROUP_MATRIX_K_SIZE, SUBGROUP_MATRIX_M_SIZE>>(
                        &shmem,
                        src0_shmem_idx_base + m * SUBGROUP_MATRIX_M_SIZE * TILE_K,
                        false,
                        TILE_K
                    );
                }

                // src1 tiles live after the src0 region; loaded col_major=true.
                let src1_shmem_idx_base = TILE_SRC0_SHMEM + subgroup_n * SUBGROUP_MATRIX_N * SUBGROUP_MATRIX_N_SIZE * TILE_K + k_inner;
                for (var n = 0u; n < SUBGROUP_MATRIX_N; n++) {
                    let src1_sg_mat = subgroupMatrixLoad<subgroup_matrix_right<f16, SUBGROUP_MATRIX_N_SIZE, SUBGROUP_MATRIX_K_SIZE>>(
                        &shmem,
                        src1_shmem_idx_base + n * SUBGROUP_MATRIX_N_SIZE * TILE_K,
                        true,
                        TILE_K
                    );
                    // Reuse each loaded src1 tile against every src0 tile row.
                    for (var m = 0u; m < SUBGROUP_MATRIX_M; m++) {
                        acc_sg_mat[m][n] = subgroupMatrixMultiplyAccumulate(src0_sg_mats[m], src1_sg_mat, acc_sg_mat[m][n]);
                    }
                }
            }
        }

        // Protect shmem before the next iteration (or write-out) overwrites it.
        workgroupBarrier();
    }

    let dst_batch_offset = params.offset_dst + dst3_idx * dst3_stride + dst2_idx * dst2_stride;

    // Stage the subgroup matrix tiles into shared memory
    // This uses WG_M_SG_TILE_SIZE as the stride (number of columns in the workgroup tile).
    let WG_TILE_STRIDE = WG_M_SG_TILE_SIZE;
    let tile_row_base_local = subgroup_n * SUBGROUP_MATRIX_N * SUBGROUP_MATRIX_N_SIZE;
    let tile_col_base_local = subgroup_m * SUBGROUP_MATRIX_M * SUBGROUP_MATRIX_M_SIZE;

    if (subgroup_id < EXPECTED_SUBGROUPS) { // 2-5% performance hit :(
        for (var n = 0u; n < SUBGROUP_MATRIX_N; n++) {
            for (var m = 0u; m < SUBGROUP_MATRIX_M; m++) {
                let local_row = tile_row_base_local + n * SUBGROUP_MATRIX_N_SIZE;
                let local_col = tile_col_base_local + m * SUBGROUP_MATRIX_M_SIZE;
                let out_base = local_row * WG_TILE_STRIDE + local_col;
                subgroupMatrixStore(&shmem, out_base, acc_sg_mat[m][n], true, WG_TILE_STRIDE);
            }
        }
    }

    // Staged accumulators must be visible to all threads before write-out.
    workgroupBarrier();

    // Cooperative write: iterate over the entire workgroup tile
    let tile_rows = WG_N_SG_TILE_SIZE;
    let tile_cols = WG_M_SG_TILE_SIZE;
    let total_tile_elems = tile_rows * tile_cols;
    let tile_dst_row_base = wg_m * SUBGROUP_M * SUBGROUP_MATRIX_M * SUBGROUP_MATRIX_M_SIZE;
    let tile_dst_col_base = wg_n * SUBGROUP_N * SUBGROUP_MATRIX_N * SUBGROUP_MATRIX_N_SIZE;

    // Each thread writes strided {{VEC_SIZE}}-element chunks; edge tiles are
    // clipped against params.m / params.n.
    for (var idx = thread_id * {{VEC_SIZE}}; idx < total_tile_elems; idx += TOTAL_WORKGROUP_SIZE * {{VEC_SIZE}}) {
        let local_row = idx % WG_TILE_STRIDE;
        let local_col = idx / WG_TILE_STRIDE;

        let global_row = tile_dst_row_base + local_row;
        let global_col = tile_dst_col_base + local_col;

        if (global_col < params.n && global_row < params.m) {
            // dst is column-major: element (row, col) lives at col * m + row.
            let dst_idx = dst_batch_offset + global_col * params.m + global_row;
            store_dst(idx, dst_idx/{{VEC_SIZE}});
        }
    }
}
301
302#end(SHADER)