1#define(VARIANTS)
  2[
  3  {
  4    "SHADER_SUFFIX": "f32_f32_vec",
  5    "REPLS": {
  6      "SRC0_TYPE" : "vec4<f32>",
  7      "SRC1_TYPE" : "vec4<f32>",
  8      "DST_TYPE" : "vec4<f32>",
  9      "SHMEM_TYPE" : "vec4<f16>",
 10      "VEC_SIZE" : 4,
 11    },
 12    "DECLS": ["VEC", "SHMEM_VEC", "INIT_SRC0_SHMEM_FLOAT", "INIT_SRC1_SHMEM"]
 13  },
 14  {
 15    "SHADER_SUFFIX": "f32_f32",
 16    "REPLS": {
 17      "SRC0_TYPE" : "f32",
 18      "SRC1_TYPE" : "f32",
 19      "DST_TYPE" : "f32",
 20      "SHMEM_TYPE" : "f16",
 21      "VEC_SIZE" : 1,
 22    },
 23    "DECLS": ["SCALAR", "SHMEM_SCALAR", "INIT_SRC0_SHMEM_FLOAT", "INIT_SRC1_SHMEM"]
 24  },
 25  {
 26    "SHADER_SUFFIX": "f16_f32_vec",
 27    "REPLS": {
 28      "SRC0_TYPE" : "vec4<f16>",
 29      "SRC1_TYPE" : "vec4<f32>",
 30      "DST_TYPE" : "vec4<f32>",
 31      "SHMEM_TYPE" : "vec4<f16>",
 32      "VEC_SIZE" : 4,
 33    },
 34    "DECLS": ["VEC", "SHMEM_VEC", "INIT_SRC0_SHMEM_FLOAT", "INIT_SRC1_SHMEM"]
 35  },
 36  {
 37    "SHADER_SUFFIX": "f16_f32",
 38    "REPLS": {
 39      "SRC0_TYPE" : "f16",
 40      "SRC1_TYPE" : "f32",
 41      "DST_TYPE" : "f32",
 42      "SHMEM_TYPE" : "f16",
 43      "VEC_SIZE" : 1,
 44    },
 45    "DECLS": ["SCALAR", "SHMEM_SCALAR", "INIT_SRC0_SHMEM_FLOAT", "INIT_SRC1_SHMEM"]
 46  },
 47  {
 48    "SHADER_SUFFIX": "f16_f16_vec",
 49    "REPLS": {
 50      "SRC0_TYPE" : "vec4<f16>",
 51      "SRC1_TYPE" : "vec4<f16>",
 52      "DST_TYPE" : "vec4<f32>",
 53      "SHMEM_TYPE" : "vec4<f16>",
 54      "VEC_SIZE" : 4,
 55    },
 56    "DECLS": ["VEC", "SHMEM_VEC", "INIT_SRC0_SHMEM_FLOAT", "INIT_SRC1_SHMEM"]
 57  },
 58  {
 59    "SHADER_SUFFIX": "f16_f16",
 60    "REPLS": {
 61      "SRC0_TYPE" : "f16",
 62      "SRC1_TYPE" : "f16",
 63      "DST_TYPE" : "f32",
 64      "SHMEM_TYPE" : "f16",
 65      "VEC_SIZE" : 1,
 66    },
 67    "DECLS": ["SCALAR", "SHMEM_SCALAR", "INIT_SRC0_SHMEM_FLOAT", "INIT_SRC1_SHMEM"]
 68  },
 69  {
 70    "SHADER_SUFFIX": "q4_0_f32_vec",
 71    "REPLS": {
 72      "SRC0_TYPE" : "f16",
 73      "SRC1_TYPE" : "vec4<f32>",
 74      "DST_TYPE" : "vec4<f32>",
 75      "SHMEM_TYPE" : "vec4<f16>",
 76      "VEC_SIZE" : 4,
 77    },
 78    "DECLS": ["BYTE_HELPERS", "VEC", "SHMEM_VEC", "INIT_SRC0_SHMEM_Q4_0", "INIT_SRC1_SHMEM"]
 79  },
 80  {
 81    "SHADER_SUFFIX": "q4_0_f32",
 82    "REPLS": {
 83      "SRC0_TYPE" : "f16",
 84      "SRC1_TYPE" : "f32",
 85      "DST_TYPE" : "f32",
 86      "SHMEM_TYPE" : "f16",
 87      "VEC_SIZE" : 1,
 88    },
 89    "DECLS": ["BYTE_HELPERS", "SCALAR", "SHMEM_SCALAR", "INIT_SRC0_SHMEM_Q4_0", "INIT_SRC1_SHMEM"]
 90  }
 91]
 92
 93#end(VARIANTS)
 94
 95#define(DECLS)
 96
 97#decl(VEC)
 98fn store_val(acc: array<array<f16, TILE_N>, TILE_M>, tn: u32, tm: u32) -> vec4<f32> {
 99    return vec4<f32>(f32(acc[tm][tn]), f32(acc[tm + 1][tn]), f32(acc[tm + 2][tn]), f32(acc[tm + 3][tn]));
100}
101#enddecl(VEC)
102
103#decl(SCALAR)
// SCALAR variant: returns the single accumulator at tile row `tm`,
// tile column `tn`, widened from f16 to f32.
fn store_val(acc: array<array<f16, TILE_N>, TILE_M>, tn: u32, tm: u32) -> f32 {
    let widened = f32(acc[tm][tn]);
    return widened;
}
107#enddecl(SCALAR)
108
109#end(DECLS)
110
111#define(SHADER)
112enable f16;
113
// Uniform parameters for one matrix-multiplication dispatch.
// NOTE(review): offsets and strides look like element counts (they are added
// directly to array indices in main) - confirm against the host code.
struct MulMatParams {
    // Base offsets added to the per-batch offsets of src0, src1 and dst.
    offset_src0: u32,
    offset_src1: u32,
    offset_dst: u32,
    // Problem size: m bounds the output rows, n bounds the output columns,
    // k is the reduction length walked by the outer loop in main.
    m: u32,
    n: u32,
    k: u32,
    // Dimension-1 strides of src0 / src1. Not read in this file;
    // presumably consumed by the init_shmem_* helpers - TODO confirm.
    stride_01: u32,
    stride_11: u32,
    // Dimension-2 strides of src0 / src1, multiplied by the dim-2 batch index.
    stride_02: u32,
    stride_12: u32,
    // Dimension-3 strides of src0 / src1, multiplied by the dim-3 batch index.
    stride_03: u32,
    stride_13: u32,
    // bs02 * broadcast2 is the number of dim-2 batches in dst/src1.
    bs02: u32,
    // Not read in this file; presumably the dim-3 batch count of src0 - TODO confirm.
    bs03: u32,
    // The dst batch index is divided by these to get the src0 batch index,
    // so one src0 matrix serves broadcast2 (resp. broadcast3) dst batches.
    broadcast2: u32,
    broadcast3: u32
};
132
// Left-hand operand: M rows, K columns.
@group(0) @binding(0) var<storage, read_write> src0: array<{{SRC0_TYPE}}>; // M rows, K columns
// Right-hand operand: K rows, N columns, stored transposed.
@group(0) @binding(1) var<storage, read_write> src1: array<{{SRC1_TYPE}}>; // K rows, N columns (transposed)
// Output: M rows, N columns, stored transposed - main writes element
// (row, col) at index col * params.m + row, so one column is contiguous.
@group(0) @binding(2) var<storage, read_write> dst: array<{{DST_TYPE}}>; // M rows, N columns (transposed)

// Sizes, strides and offsets for this dispatch.
@group(0) @binding(3) var<uniform> params: MulMatParams;
138
139DECLS
140
// Position of an invocation along the N axis of the workgroup's thread grid:
// the flat thread id divided by WORKGROUP_SIZE_M.
fn get_local_n(thread_id: u32) -> u32 {
    let n_slot = thread_id / WORKGROUP_SIZE_M;
    return n_slot;
}
// Position of an invocation along the M axis of the workgroup's thread grid:
// the flat thread id modulo WORKGROUP_SIZE_M.
fn get_local_m(thread_id: u32) -> u32 {
    let m_slot = thread_id % WORKGROUP_SIZE_M;
    return m_slot;
}
147
// Per-invocation register tile: each invocation accumulates a
// TILE_M x TILE_N block of outputs.
// TILE_M must be multiple of 4 for vec4 loads
// (the vec4 variants also store rows in groups of four, reading
// acc[tm] .. acc[tm + 3]).
const TILE_M = {{WEBGPU_TILE_M}}u;
const TILE_N = {{WEBGPU_TILE_N}}u;

// Pipeline-overridable workgroup shape and depth of one K tile.
override WORKGROUP_SIZE_M: u32;
override WORKGROUP_SIZE_N: u32;
override TILE_K: u32;

// Flat workgroup size used by @workgroup_size on main.
override TOTAL_WORKGROUP_SIZE = WORKGROUP_SIZE_M * WORKGROUP_SIZE_N;
// Element counts of the src0 and src1 regions of shared memory:
// TILE_K values for every row (resp. column) the workgroup covers.
override TILE_SRC0_SHMEM = TILE_K * WORKGROUP_SIZE_M * TILE_M;
override TILE_SRC1_SHMEM = TILE_K * WORKGROUP_SIZE_N * TILE_N;

// Single workgroup buffer: the src0 tile occupies [0, TILE_SRC0_SHMEM),
// the src1 tile starts at index TILE_SRC0_SHMEM.
var<workgroup> shmem: array<f16, TILE_SRC0_SHMEM + TILE_SRC1_SHMEM>;
161
// Register-tiled matrix multiplication entry point.
//
// Workgroups are dispatched along x only. Each workgroup id is decomposed
// into a batch index and a (wg_m, wg_n) tile position inside that batch.
// Every invocation accumulates a TILE_M x TILE_N block of outputs in f16,
// walking the K dimension in steps of TILE_K through workgroup memory,
// and finally writes the in-range results to dst via store_val.
@compute @workgroup_size(TOTAL_WORKGROUP_SIZE)
fn main(@builtin(workgroup_id) wg_id: vec3<u32>,
        @builtin(local_invocation_id) local_id: vec3<u32>) {

    let thread_id = local_id.x;
    let local_m = get_local_m(thread_id);
    let local_n = get_local_n(thread_id);

    // Output rows / columns covered by one workgroup.
    let rows_per_wg = WORKGROUP_SIZE_M * TILE_M;
    let cols_per_wg = WORKGROUP_SIZE_N * TILE_N;

    // Workgroups needed per matrix along each axis (ceiling division).
    let wg_n_count = (params.n + cols_per_wg - 1u) / cols_per_wg;
    let wg_m_count = (params.m + rows_per_wg - 1u) / rows_per_wg;
    let wg_per_matrix = wg_m_count * wg_n_count;

    // Split the flat workgroup id into batch and in-batch tile coordinates;
    // the M tile index varies fastest.
    let batch_idx = wg_id.x / wg_per_matrix;
    let wg_in_batch = wg_id.x % wg_per_matrix;
    let wg_m = wg_in_batch % wg_m_count;
    let wg_n = wg_in_batch / wg_m_count;

    // First row / column of this workgroup, then of this invocation.
    let offset_m = wg_m * rows_per_wg;
    let offset_n = wg_n * cols_per_wg;
    let output_row_base = offset_m + local_m * TILE_M;
    let output_col_base = offset_n + local_n * TILE_N;

    // Batch bookkeeping: src1 and dst share batch indices, while src0 is
    // broadcast, so its indices are the dst indices divided by the factors.
    let dim2_batches = params.bs02 * params.broadcast2;
    let dst2_stride = params.m * params.n;
    let dst3_stride = dst2_stride * dim2_batches;

    let dst3_idx = batch_idx / dim2_batches;
    let dst2_idx = batch_idx % dim2_batches;
    let src03_idx = dst3_idx / params.broadcast3;
    let src02_idx = dst2_idx / params.broadcast2;
    let src13_idx = dst3_idx;
    let src12_idx = dst2_idx;

    let src0_batch_offset = params.offset_src0 + src03_idx * params.stride_03 + src02_idx * params.stride_02;
    let src1_batch_offset = params.offset_src1 + src13_idx * params.stride_13 + src12_idx * params.stride_12;

    // Row / column of this invocation's tile inside the workgroup tile.
    let shmem_row_base = local_m * TILE_M;
    let shmem_col_base = local_n * TILE_N;

    var acc: array<array<f16, TILE_N>, TILE_M>;

    for (var k_outer = 0u; k_outer < params.k; k_outer += TILE_K) {

        // Fill both shared-memory tiles for this K step
        // (helpers live in mul_mat_decls.tmpl).
        init_shmem_src0(thread_id, src0_batch_offset, offset_m, k_outer);
        init_shmem_src1(thread_id, src1_batch_offset, offset_n, k_outer);

        // Every invocation must finish writing before anyone reads.
        workgroupBarrier();

        // The last K tile may be partial.
        let k_end = min(TILE_K, params.k - k_outer);

        for (var k_inner = 0u; k_inner < k_end; k_inner++) {
            // Copy this invocation's src0 values for k_inner into registers.
            var src0_tile: array<f16, TILE_M>;
            for (var tm = 0u; tm < TILE_M; tm++) {
                src0_tile[tm] = shmem[k_inner + (shmem_row_base + tm) * TILE_K];
            }
            // One src1 value per tile column updates a whole accumulator column.
            for (var tn = 0u; tn < TILE_N; tn++) {
                let src1_idx = (shmem_col_base + tn) * TILE_K + k_inner;
                let src1_val = shmem[TILE_SRC0_SHMEM + src1_idx];
                for (var tm = 0u; tm < TILE_M; tm++) {
                    acc[tm][tn] += src0_tile[tm] * src1_val;
                }
            }
        }

        // Keep the tiles intact until every invocation has consumed them.
        workgroupBarrier();
    }

    let dst_batch_offset = params.offset_dst + dst3_idx * dst3_stride + dst2_idx * dst2_stride;

    // Write back, skipping columns and rows outside the matrix. Only the first
    // row of each {{VEC_SIZE}}-row group is bounds-checked.
    for (var tn = 0u; tn < TILE_N; tn++) {
        let global_col = output_col_base + tn;
        if (global_col >= params.n) {
            continue;
        }
        let dst_col_base = dst_batch_offset + global_col * params.m;
        for (var tm = 0u; tm < TILE_M; tm += {{VEC_SIZE}}) {
            let global_row = output_row_base + tm;
            if (global_row >= params.m) {
                continue;
            }
            let dst_idx = dst_col_base + global_row;
            dst[dst_idx/{{VEC_SIZE}}] = store_val(acc, tn, tm);
        }
    }
}
246
247#end(SHADER)