| field | value | date |
|---|---|---|
| author | Mitja Felicijan <mitja.felicijan@gmail.com> | 2026-02-12 20:57:17 +0100 |
| committer | Mitja Felicijan <mitja.felicijan@gmail.com> | 2026-02-12 20:57:17 +0100 |
| commit | b333b06772c89d96aacb5490d6a219fba7c09cc6 (patch) | |
| tree | 211df60083a5946baa2ed61d33d8121b7e251b06 /llama.cpp/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_subgroup_matrix.tmpl.wgsl | |
| download | llmnpc-b333b06772c89d96aacb5490d6a219fba7c09cc6.tar.gz | |
Engage!
Diffstat (limited to 'llama.cpp/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_subgroup_matrix.tmpl.wgsl')
| mode | file | lines |
|---|---|---|
| -rw-r--r-- | llama.cpp/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_subgroup_matrix.tmpl.wgsl | 302 |

1 file changed, 302 insertions, 0 deletions
diff --git a/llama.cpp/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_subgroup_matrix.tmpl.wgsl b/llama.cpp/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_subgroup_matrix.tmpl.wgsl
new file mode 100644
index 0000000..47c8ce3
--- /dev/null
+++ b/llama.cpp/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_subgroup_matrix.tmpl.wgsl
@@ -0,0 +1,302 @@
#define(VARIANTS)
[
  {
    "SHADER_SUFFIX": "f32_f32_vec",
    "REPLS": {
      "SRC0_TYPE" : "vec4<f32>",
      "SRC1_TYPE" : "vec4<f32>",
      "DST_TYPE" : "vec4<f32>",
      "SHMEM_TYPE" : "vec4<f16>",
      "VEC_SIZE" : 4,
    },
    "DECLS": ["VEC", "SHMEM_VEC", "INIT_SRC0_SHMEM_FLOAT", "INIT_SRC1_SHMEM"]
  },
  {
    "SHADER_SUFFIX": "f32_f32",
    "REPLS": {
      "SRC0_TYPE" : "f32",
      "SRC1_TYPE" : "f32",
      "DST_TYPE" : "f32",
      "SHMEM_TYPE" : "f16",
      "VEC_SIZE" : 1,
    },
    "DECLS": ["SCALAR", "SHMEM_SCALAR", "INIT_SRC0_SHMEM_FLOAT", "INIT_SRC1_SHMEM"]
  },
  {
    "SHADER_SUFFIX": "f16_f32_vec",
    "REPLS": {
      "SRC0_TYPE" : "vec4<f16>",
      "SRC1_TYPE" : "vec4<f32>",
      "DST_TYPE" : "vec4<f32>",
      "SHMEM_TYPE" : "vec4<f16>",
      "VEC_SIZE" : 4,
    },
    "DECLS": ["VEC", "SHMEM_VEC", "INIT_SRC0_SHMEM_FLOAT", "INIT_SRC1_SHMEM"]
  },
  {
    "SHADER_SUFFIX": "f16_f32",
    "REPLS": {
      "SRC0_TYPE" : "f16",
      "SRC1_TYPE" : "f32",
      "DST_TYPE" : "f32",
      "SHMEM_TYPE" : "f16",
      "VEC_SIZE" : 1,
    },
    "DECLS": ["SCALAR", "SHMEM_SCALAR", "INIT_SRC0_SHMEM_FLOAT", "INIT_SRC1_SHMEM"]
  },
  {
    "SHADER_SUFFIX": "f16_f16_vec",
    "REPLS": {
      "SRC0_TYPE" : "vec4<f16>",
      "SRC1_TYPE" : "vec4<f16>",
      "DST_TYPE" : "vec4<f32>",
      "SHMEM_TYPE" : "vec4<f16>",
      "VEC_SIZE" : 4,
    },
    "DECLS": ["VEC", "SHMEM_VEC", "INIT_SRC0_SHMEM_FLOAT", "INIT_SRC1_SHMEM"]
  },
  {
    "SHADER_SUFFIX": "f16_f16",
    "REPLS": {
      "SRC0_TYPE" : "f16",
      "SRC1_TYPE" : "f16",
      "DST_TYPE" : "f32",
      "SHMEM_TYPE" : "f16",
      "VEC_SIZE" : 1,
    },
    "DECLS": ["SCALAR", "SHMEM_SCALAR", "INIT_SRC0_SHMEM_FLOAT", "INIT_SRC1_SHMEM"]
  },
  {
    "SHADER_SUFFIX": "q4_0_f32_vec",
    "REPLS": {
      "SRC0_TYPE" : "f16",
      "SRC1_TYPE" : "vec4<f32>",
      "DST_TYPE" : "vec4<f32>",
      "SHMEM_TYPE" : "vec4<f16>",
      "VEC_SIZE" : 4,
    },
    "DECLS": ["BYTE_HELPERS", "VEC", "SHMEM_VEC", "INIT_SRC0_SHMEM_Q4_0", "INIT_SRC1_SHMEM"]
  },
  {
    "SHADER_SUFFIX": "q4_0_f32",
    "REPLS": {
      "SRC0_TYPE" : "f16",
      "SRC1_TYPE" : "f32",
      "DST_TYPE" : "f32",
      "SHMEM_TYPE" : "f16",
      "VEC_SIZE" : 1,
    },
    "DECLS": ["BYTE_HELPERS", "SCALAR", "SHMEM_SCALAR", "INIT_SRC0_SHMEM_Q4_0", "INIT_SRC1_SHMEM"]
  }
]

#end(VARIANTS)

#define(DECLS)

#decl(VEC)
fn store_dst(shmem_idx: u32, dst_idx: u32) {
    dst[dst_idx] = vec4<f32>(
        f32(shmem[shmem_idx]),
        f32(shmem[shmem_idx + 1]),
        f32(shmem[shmem_idx + 2]),
        f32(shmem[shmem_idx + 3])
    );
}
#enddecl(VEC)

#decl(SCALAR)
fn store_dst(shmem_idx: u32, dst_idx: u32) {
    dst[dst_idx] = f32(shmem[shmem_idx]);
}
#enddecl(SCALAR)

#end(DECLS)

#define(SHADER)
diagnostic(off, chromium.subgroup_matrix_uniformity);
enable f16;
enable subgroups;
enable chromium_experimental_subgroup_matrix;

struct MulMatParams {
    offset_src0: u32,
    offset_src1: u32,
    offset_dst: u32,
    m: u32,
    n: u32,
    k: u32,
    stride_01: u32,
    stride_11: u32,
    stride_02: u32,
    stride_12: u32,
    stride_03: u32,
    stride_13: u32,
    bs02: u32,
    bs03: u32,
    broadcast2: u32,
    broadcast3: u32
};

@group(0) @binding(0) var<storage, read_write> src0: array<{{SRC0_TYPE}}>; // M rows, K columns
@group(0) @binding(1) var<storage, read_write> src1: array<{{SRC1_TYPE}}>; // K rows, N columns (transposed)
@group(0) @binding(2) var<storage, read_write> dst: array<{{DST_TYPE}}>;   // M rows, N columns (transposed)

@group(0) @binding(3) var<uniform> params: MulMatParams;

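// DECLS expands to the per-variant helper declarations chosen by each variant's
// "DECLS" list above: store_dst (defined in this file) plus the shared-memory
// loaders init_shmem_src0/init_shmem_src1 (see mul_mat_decls.tmpl).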
DECLS

// Note: these values are string-interpolated at build time. They cannot be override
// constants because the current Dawn version requires constant sizes for the
// subgroup-matrix type definitions and matrix loads below.
const SUBGROUP_M = {{WEBGPU_SUBGROUP_M}}u;
const SUBGROUP_N = {{WEBGPU_SUBGROUP_N}}u;
// For portability we assume the max subgroup size, meaning some subgroups will be
// masked out if the runtime subgroup size is smaller.
const MAX_SUBGROUP_SIZE = {{WEBGPU_MAX_SUBGROUP_SIZE}}u;

const EXPECTED_SUBGROUPS = SUBGROUP_M * SUBGROUP_N;

const SUBGROUP_MATRIX_M_SIZE = {{WEBGPU_SG_MAT_M_SIZE}}u;
const SUBGROUP_MATRIX_N_SIZE = {{WEBGPU_SG_MAT_N_SIZE}}u;
const SUBGROUP_MATRIX_K_SIZE = {{WEBGPU_SG_MAT_K_SIZE}}u;

const SUBGROUP_MATRIX_M = {{WEBGPU_SUBGROUP_MATRIX_M}}u;
const SUBGROUP_MATRIX_N = {{WEBGPU_SUBGROUP_MATRIX_N}}u;

const TILE_K = {{WEBGPU_TILE_K}}u;

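// Output tile size handled by one workgroup along M and N: subgroups per axis,
// times subgroup matrices per subgroup, times the per-matrix dimension.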
const WG_M_SG_TILE_SIZE = SUBGROUP_M * SUBGROUP_MATRIX_M * SUBGROUP_MATRIX_M_SIZE;
const WG_N_SG_TILE_SIZE = SUBGROUP_N * SUBGROUP_MATRIX_N * SUBGROUP_MATRIX_N_SIZE;

const TOTAL_WORKGROUP_SIZE = SUBGROUP_M * SUBGROUP_N * MAX_SUBGROUP_SIZE;
const TILE_SRC0_SHMEM = TILE_K * SUBGROUP_M * SUBGROUP_MATRIX_M * SUBGROUP_MATRIX_M_SIZE;
const TILE_SRC1_SHMEM = TILE_K * SUBGROUP_N * SUBGROUP_MATRIX_N * SUBGROUP_MATRIX_N_SIZE;

const SG_MAT_ACCUM_SHMEM = SUBGROUP_M * SUBGROUP_MATRIX_M * SUBGROUP_N * SUBGROUP_MATRIX_N * SUBGROUP_MATRIX_M_SIZE * SUBGROUP_MATRIX_N_SIZE;

// We reuse shmem for accumulation matrices
const SHMEM_SIZE = max(TILE_SRC0_SHMEM + TILE_SRC1_SHMEM, SG_MAT_ACCUM_SHMEM);

var<workgroup> shmem: array<f16, SHMEM_SIZE>;

@compute @workgroup_size(TOTAL_WORKGROUP_SIZE)
fn main(@builtin(workgroup_id) wg_id: vec3<u32>,
        @builtin(local_invocation_id) local_id: vec3<u32>,
        @builtin(subgroup_id) subgroup_id: u32) {

    let thread_id = local_id.x;
    let subgroup_m = subgroup_id % SUBGROUP_M;
    let subgroup_n = subgroup_id / SUBGROUP_M;

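    // Number of workgroup tiles needed to cover the output in M and N (ceiling
    // division), and the number of workgroups assigned to each matrix in the batch.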
    let wg_m_count = (params.m + WG_M_SG_TILE_SIZE - 1) / WG_M_SG_TILE_SIZE;
    let wg_n_count = (params.n + WG_N_SG_TILE_SIZE - 1) / WG_N_SG_TILE_SIZE;
    let wg_per_matrix = wg_m_count * wg_n_count;

    let batch_idx = wg_id.x / wg_per_matrix;

    let wg_in_batch = wg_id.x % wg_per_matrix;
    let wg_m = wg_in_batch % wg_m_count;
    let wg_n = wg_in_batch / wg_m_count;

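    // Split the flattened batch index into the two dst batch dimensions, then map
    // to the source batch indices: src0 indices divide out the broadcast factors,
    // while src1 follows dst directly.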
    let dst2_stride = params.m * params.n;
    let dst3_stride = dst2_stride * params.bs02 * params.broadcast2;

    let dst3_idx = batch_idx / (params.bs02 * params.broadcast2);
    let src03_idx = dst3_idx / params.broadcast3;
    let src13_idx = dst3_idx;
    let dst2_idx = batch_idx % (params.bs02 * params.broadcast2);
    let src02_idx = dst2_idx / params.broadcast2;
    let src12_idx = dst2_idx;

    let src0_batch_offset = params.offset_src0 + src03_idx * params.stride_03 + src02_idx * params.stride_02;
    let src1_batch_offset = params.offset_src1 + src13_idx * params.stride_13 + src12_idx * params.stride_12;

    let offset_m = wg_m * SUBGROUP_M * SUBGROUP_MATRIX_M * SUBGROUP_MATRIX_M_SIZE;
    let offset_n = wg_n * SUBGROUP_N * SUBGROUP_MATRIX_N * SUBGROUP_MATRIX_N_SIZE;

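    // Accumulator subgroup matrices for this subgroup, indexed [m][n] over its
    // SUBGROUP_MATRIX_M x SUBGROUP_MATRIX_N output sub-tiles.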
    var acc_sg_mat : array<array<subgroup_matrix_result<f16, SUBGROUP_MATRIX_N_SIZE, SUBGROUP_MATRIX_M_SIZE>, SUBGROUP_MATRIX_N>, SUBGROUP_MATRIX_M>;

    for (var k_outer = 0u; k_outer < params.k; k_outer += TILE_K) {

        // see mul_mat_decls.tmpl
        init_shmem_src0(thread_id, src0_batch_offset, offset_m, k_outer);
        init_shmem_src1(thread_id, src1_batch_offset, offset_n, k_outer);

        workgroupBarrier();

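        // Only the first EXPECTED_SUBGROUPS subgroups take part in the matrix math;
        // extra subgroups (present when the runtime subgroup size is smaller than
        // MAX_SUBGROUP_SIZE) are masked out here.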
        if (subgroup_id < EXPECTED_SUBGROUPS) {

            for (var k_inner = 0u; k_inner < TILE_K; k_inner += SUBGROUP_MATRIX_K_SIZE) {

                let src0_shmem_idx_base = subgroup_m * SUBGROUP_MATRIX_M * SUBGROUP_MATRIX_M_SIZE * TILE_K + k_inner;
                var src0_sg_mats: array<subgroup_matrix_left<f16, SUBGROUP_MATRIX_K_SIZE, SUBGROUP_MATRIX_M_SIZE>, SUBGROUP_MATRIX_M>;
                for (var m = 0u; m < SUBGROUP_MATRIX_M; m++) {
                    src0_sg_mats[m] = subgroupMatrixLoad<subgroup_matrix_left<f16, SUBGROUP_MATRIX_K_SIZE, SUBGROUP_MATRIX_M_SIZE>>(
                        &shmem,
                        src0_shmem_idx_base + m * SUBGROUP_MATRIX_M_SIZE * TILE_K,
                        false,
                        TILE_K
                    );
                }

                let src1_shmem_idx_base = TILE_SRC0_SHMEM + subgroup_n * SUBGROUP_MATRIX_N * SUBGROUP_MATRIX_N_SIZE * TILE_K + k_inner;
                for (var n = 0u; n < SUBGROUP_MATRIX_N; n++) {
                    let src1_sg_mat = subgroupMatrixLoad<subgroup_matrix_right<f16, SUBGROUP_MATRIX_N_SIZE, SUBGROUP_MATRIX_K_SIZE>>(
                        &shmem,
                        src1_shmem_idx_base + n * SUBGROUP_MATRIX_N_SIZE * TILE_K,
                        true,
                        TILE_K
                    );
                    for (var m = 0u; m < SUBGROUP_MATRIX_M; m++) {
                        acc_sg_mat[m][n] = subgroupMatrixMultiplyAccumulate(src0_sg_mats[m], src1_sg_mat, acc_sg_mat[m][n]);
                    }
                }
            }
        }

        workgroupBarrier();
    }

    let dst_batch_offset = params.offset_dst + dst3_idx * dst3_stride + dst2_idx * dst2_stride;

    // Stage the subgroup matrix tiles into shared memory
    // This uses WG_M_SG_TILE_SIZE as the stride (number of columns in the workgroup tile).
    let WG_TILE_STRIDE = WG_M_SG_TILE_SIZE;
    let tile_row_base_local = subgroup_n * SUBGROUP_MATRIX_N * SUBGROUP_MATRIX_N_SIZE;
    let tile_col_base_local = subgroup_m * SUBGROUP_MATRIX_M * SUBGROUP_MATRIX_M_SIZE;

    if (subgroup_id < EXPECTED_SUBGROUPS) { // 2-5% performance hit :(
        for (var n = 0u; n < SUBGROUP_MATRIX_N; n++) {
            for (var m = 0u; m < SUBGROUP_MATRIX_M; m++) {
                let local_row = tile_row_base_local + n * SUBGROUP_MATRIX_N_SIZE;
                let local_col = tile_col_base_local + m * SUBGROUP_MATRIX_M_SIZE;
                let out_base = local_row * WG_TILE_STRIDE + local_col;
                subgroupMatrixStore(&shmem, out_base, acc_sg_mat[m][n], true, WG_TILE_STRIDE);
            }
        }
    }

    workgroupBarrier();

    // Cooperative write: iterate over the entire workgroup tile
    let tile_rows = WG_N_SG_TILE_SIZE;
    let tile_cols = WG_M_SG_TILE_SIZE;
    let total_tile_elems = tile_rows * tile_cols;
    let tile_dst_row_base = wg_m * SUBGROUP_M * SUBGROUP_MATRIX_M * SUBGROUP_MATRIX_M_SIZE;
    let tile_dst_col_base = wg_n * SUBGROUP_N * SUBGROUP_MATRIX_N * SUBGROUP_MATRIX_N_SIZE;

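    // dst keeps M contiguous (see the binding comment above), so the linear index is
    // global_col * m + global_row; store_dst handles the scalar vs. vec4 cases.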
    for (var idx = thread_id * {{VEC_SIZE}}; idx < total_tile_elems; idx += TOTAL_WORKGROUP_SIZE * {{VEC_SIZE}}) {
        let local_row = idx % WG_TILE_STRIDE;
        let local_col = idx / WG_TILE_STRIDE;

        let global_row = tile_dst_row_base + local_row;
        let global_col = tile_dst_col_base + local_col;

        if (global_col < params.n && global_row < params.m) {
            let dst_idx = dst_batch_offset + global_col * params.m + global_row;
            store_dst(idx, dst_idx/{{VEC_SIZE}});
        }
    }
}

#end(SHADER)
