1#define(VARIANTS)
  2[
  3  {
  4    "SHADER_SUFFIX": "f32_f32_vec",
  5    "REPLS": {
  6      "SRC0_TYPE" : "vec4<f32>",
  7      "SRC1_TYPE" : "vec4<f32>",
  8      "DST_TYPE": "vec4<f32>",
  9      "VEC_SIZE" : 4,
 10    },
 11    "DECLS": ["VEC", "MUL_ACC_FLOAT"]
 12  },
 13  {
 14    "SHADER_SUFFIX": "f32_f32",
 15    "REPLS": {
 16      "SRC0_TYPE" : "f32",
 17      "SRC1_TYPE" : "f32",
 18      "DST_TYPE": "f32",
 19      "VEC_SIZE" : 1,
 20    },
 21    "DECLS": ["SCALAR", "MUL_ACC_FLOAT"]
 22  },
 23  {
 24    "SHADER_SUFFIX": "f16_f32_vec",
 25    "REPLS": {
 26      "SRC0_TYPE" : "vec4<f16>",
 27      "SRC1_TYPE" : "vec4<f32>",
 28      "DST_TYPE": "vec4<f32>",
 29      "VEC_SIZE" : 4,
 30    },
 31    "DECLS": ["VEC", "MUL_ACC_FLOAT"]
 32  },
 33  {
 34    "SHADER_SUFFIX": "f16_f32",
 35    "REPLS": {
 36      "SRC0_TYPE" : "f16",
 37      "SRC1_TYPE" : "f32",
 38      "DST_TYPE": "f32",
 39      "VEC_SIZE" : 1,
 40    },
 41    "DECLS": ["SCALAR", "MUL_ACC_FLOAT"]
 42  },
 43  {
 44    "SHADER_SUFFIX": "f16_f16_vec",
 45    "REPLS": {
 46      "SRC0_TYPE" : "vec4<f16>",
 47      "SRC1_TYPE" : "vec4<f16>",
 48      "DST_TYPE": "vec4<f32>",
 49      "VEC_SIZE" : 4,
 50    },
 51    "DECLS": ["VEC", "MUL_ACC_FLOAT"]
 52  },
 53  {
 54    "SHADER_SUFFIX": "f16_f16",
 55    "REPLS": {
 56      "SRC0_TYPE" : "f16",
 57      "SRC1_TYPE" : "f16",
 58      "DST_TYPE": "f32",
 59      "VEC_SIZE" : 1,
 60    },
 61    "DECLS": ["SCALAR", "MUL_ACC_FLOAT"]
 62  },
 63  {
 64    "SHADER_SUFFIX": "q4_0_f32",
 65    "REPLS": {
 66      "SRC0_TYPE" : "f16",
 67      "SRC1_TYPE" : "f32",
 68      "DST_TYPE": "f32",
 69      "VEC_SIZE" : 1,
 70    },
 71    "DECLS": ["BYTE_HELPERS", "SCALAR", "MUL_ACC_Q4_0"]
 72  }
 73]
 74
 75#end(VARIANTS)
 76
 77#define(DECLS)
 78
 79#decl(VEC)
 80fn inner_dot(src0_val: {{SRC0_TYPE}}, src1_val: {{SRC1_TYPE}}) -> f32 {
 81    return f32(dot({{SRC1_TYPE}}(src0_val), src1_val));
 82}
 83
 84fn store_val(group_base: u32) -> vec4<f32> {
 85    return vec4<f32>(partial_sums[group_base],
 86                     partial_sums[group_base + THREADS_PER_OUTPUT],
 87                     partial_sums[group_base + THREADS_PER_OUTPUT * 2],
 88                     partial_sums[group_base + THREADS_PER_OUTPUT * 3]);
 89}
 90#enddecl(VEC)
 91
 92#decl(SCALAR)
 93fn inner_dot(src0_val: {{SRC0_TYPE}}, src1_val: {{SRC1_TYPE}}) -> f32 {
 94    return f32(src0_val) * f32(src1_val);
 95}
 96
 97fn store_val(group_base: u32) -> f32 {
 98    return partial_sums[group_base];
 99}
100#enddecl(SCALAR)
101
102#decl(MUL_ACC_FLOAT)
103
// Accumulates this thread's share of one K tile for float src0 variants.
//   tig       - index of the thread within its output group
//   tile_size - number of K elements valid in the current tile
//   idx_base  - src0 index of the start of this thread group's row
//   k_outer   - K offset of the current tile within the row
// Threads of a group walk the tile interleaved: thread `tig` starts at element
// tig * VEC_SIZE and advances by THREADS_PER_OUTPUT * VEC_SIZE elements.
fn mul_acc(tig:u32, tile_size: u32, idx_base: u32, k_outer: u32) -> f32 {
    let step = THREADS_PER_OUTPUT * {{VEC_SIZE}};
    let row_start = idx_base + k_outer;
    var acc = 0.0;
    var elem = tig * {{VEC_SIZE}};
    while (elem < tile_size) {
        // src0 and shared_vector are arrays of VEC_SIZE-wide values, hence the division.
        let weight = src0[(row_start + elem) / {{VEC_SIZE}}];
        let cached = shared_vector[elem / {{VEC_SIZE}}];
        acc += inner_dot(weight, cached);
        elem += step;
    }
    return acc;
}
113
114#enddecl(MUL_ACC_FLOAT)
115
116#decl(MUL_ACC_Q4_0)
117
const BLOCK_SIZE = 32; // weights per quantized block
const NQ = 16u; // weights handled by one thread per outer-loop step
const F16_PER_BLOCK = 9u; // 1 f16 scale followed by 8 f16 words of packed weights
const WEIGHTS_PER_F16 = 4u; // each f16 word packs four 4-bit weights
const F16_PER_THREAD = NQ / WEIGHTS_PER_F16;

// Accumulates this thread's share of one K tile for the Q4_0 variant.
//   tig       - index of the thread within its output group
//   tile_size - number of K elements valid in the current tile
//   idx_base  - base index of this row, in units of blocks (it is scaled by
//               F16_PER_BLOCK together with the block indices below)
//   k_outer   - K offset of the current tile within the row
// src0 is read as f16 words: each block is one scale word followed by the packed
// weight words. Every byte of a packed word holds two weights: the low nibble
// is multiplied with vector element p and the high nibble with element p + 16.
fn mul_acc(tig:u32, tile_size: u32, idx_base: u32, k_outer: u32) -> f32 {
    let stride = THREADS_PER_OUTPUT * NQ;
    var acc = 0.0;
    var w = tig * NQ;
    while (w < tile_size) {
        let block = w / BLOCK_SIZE;
        // Index of the first packed f16 word (within the block) this thread reads.
        let word_in_block = (w % BLOCK_SIZE) / WEIGHTS_PER_F16;
        // Position of the block's scale word in src0.
        let block_start = (idx_base + k_outer / BLOCK_SIZE + block) * F16_PER_BLOCK;
        let quant_start = block_start + 1 + word_in_block;
        // Each f16 word covers two low-nibble positions, so word index * 2 is the
        // first vector element matched by this thread's low nibbles.
        let vec_start = block * BLOCK_SIZE + word_in_block * 2u;
        let scale = f32(src0[block_start]);
        var word = 0u;
        while (word < F16_PER_THREAD) {
            // Fuse two adjacent f16 words into one u32 so four bytes can be extracted.
            let pair = bitcast<u32>(vec2(src0[quant_start + word], src0[quant_start + word + 1]));
            let lo_base = vec_start + word * 2;
            for (var b: u32 = 0; b < 4; b++) {
                let q = get_byte(pair, b);
                // Nibbles are stored with a bias of 8; remove it, then apply the block scale.
                let lo = (f32(q & 0xF) - 8.0) * scale;
                let hi = (f32((q >> 4) & 0xF) - 8.0) * scale;
                acc += lo * shared_vector[lo_base + b];
                acc += hi * shared_vector[lo_base + b + 16];
            }
            word += 2;
        }
        w += stride;
    }
    return acc;
}
148
149#enddecl(MUL_ACC_Q4_0)
150
151#end(DECLS)
152
153#define(SHADER)
154enable f16;
155
156DECLS
157
// Uniform parameters for the matrix-vector multiply. Field order defines the
// uniform buffer layout and must match what the host writes.
struct MulMatParams {
    offset_src0: u32, // added to every src0 row base index
    offset_src1: u32, // added to every src1 base index
    offset_dst: u32,  // added to every dst index
    m: u32,           // number of output rows; rows >= m are skipped
    n: u32,           // used only for the dst batch stride (m * n)
    k: u32,           // length of the reduction dimension (tile loop bound)
    stride_01: u32,   // src0 stride between consecutive output rows
    stride_11: u32,   // not referenced by this shader
    stride_02: u32,   // src0 stride per dim-2 batch index
    stride_12: u32,   // src1 stride per dim-2 batch index
    stride_03: u32,   // src0 stride per dim-3 batch index
    stride_13: u32,   // src1 stride per dim-3 batch index
    bs02: u32,        // src0 batch count in dim 2
    bs03: u32,        // src0 batch count in dim 3
    broadcast2: u32,  // dst dim-2 batches sharing one src0 dim-2 batch
    broadcast3: u32   // dst dim-3 batches sharing one src0 dim-3 batch
};
176
// Storage buffers. Element types come from the selected variant, so indices into
// these arrays are in units of SRC0_TYPE / SRC1_TYPE / DST_TYPE elements.
@group(0) @binding(0) var<storage, read_write> src0: array<{{SRC0_TYPE}}>; // Matrix (M x K)
@group(0) @binding(1) var<storage, read_write> src1: array<{{SRC1_TYPE}}>; // Vector (K x 1, transposed)
@group(0) @binding(2) var<storage, read_write> dst: array<{{DST_TYPE}}>;  // Result vector (transposed)

@group(0) @binding(3) var<uniform> params: MulMatParams;

// Pipeline-overridable constants supplied at pipeline creation.
override WORKGROUP_SIZE: u32;   // threads per workgroup
override TILE_K: u32;           // K elements of src1 cached per tile
override OUTPUTS_PER_WG: u32;   // output rows computed by one workgroup
// Threads cooperating on a single output row.
override THREADS_PER_OUTPUT = WORKGROUP_SIZE / OUTPUTS_PER_WG;

// Shared memory for collaborative loading and reduction
var<workgroup> shared_vector: array<{{SRC1_TYPE}}, TILE_K/{{VEC_SIZE}}>;  // Cache vector tile
var<workgroup> partial_sums: array<f32, WORKGROUP_SIZE>;   // For reduction (one slot per thread)
191
// Entry point of the matrix-vector multiply.
//
// Each workgroup computes OUTPUTS_PER_WG consecutive output rows of one batch.
// The workgroup's threads are split into OUTPUTS_PER_WG groups of
// THREADS_PER_OUTPUT threads; every group accumulates one output row, then the
// group's partial sums are tree-reduced in workgroup memory and written to dst.
@compute @workgroup_size(WORKGROUP_SIZE)
fn main(
    @builtin(local_invocation_id) local_id: vec3<u32>,
    @builtin(workgroup_id) wg_id: vec3<u32>,
    @builtin(num_workgroups) num_wg: vec3<u32>) {
    let thread_id = local_id.x;

    // Handle batch dimensions. Workgroups are linearized over (x, y); surplus
    // workgroups beyond the last batch exit as a whole, so the barriers below
    // stay in uniform control flow.
    let total_batches = params.bs02 * params.broadcast2 * params.bs03 * params.broadcast3;
    let wg_linear = wg_id.y * num_wg.x + wg_id.x;
    let output_groups = (params.m + OUTPUTS_PER_WG - 1u) / OUTPUTS_PER_WG;
    let batch_idx = wg_linear / output_groups;
    if (batch_idx >= total_batches) {
        return;
    }

    // Which of the outputs does this thread belong to?
    let thread_group = thread_id / THREADS_PER_OUTPUT;
    let thread_in_group = thread_id % THREADS_PER_OUTPUT;

    // Each workgroup computes OUTPUTS_PER_WG consecutive outputs
    let output_row = (wg_linear % output_groups) * OUTPUTS_PER_WG + thread_group;

    // Decompose the batch index into dim-2 / dim-3 indices; src0 indices are
    // divided by the broadcast factors so several dst batches share one src0 batch.
    let dst2_stride = params.m * params.n;
    let dst2_idx = batch_idx % (params.bs02 * params.broadcast2);
    let dst3_stride = dst2_stride * params.bs02 * params.broadcast2;
    let dst3_idx = batch_idx / (params.bs02 * params.broadcast2);
    let src03_idx = dst3_idx / params.broadcast3;
    let src13_idx = dst3_idx;
    let src02_idx = dst2_idx / params.broadcast2;
    let src12_idx = dst2_idx;

    let src0_idx_base = params.offset_src0 + src03_idx * params.stride_03 + src02_idx * params.stride_02 + output_row * params.stride_01;
    let src1_idx_base = params.offset_src1 + src13_idx * params.stride_13 + src12_idx * params.stride_12;
    let dst_idx = params.offset_dst + dst3_idx * dst3_stride + dst2_idx * dst2_stride + output_row;

    var local_sum = 0.0;

    // Each thread processes multiple K elements and accumulates
    for (var k_tile = 0u; k_tile < params.k; k_tile += TILE_K) {
        let tile_size = min(TILE_K, params.k - k_tile);

        // Cooperatively load vector tile into shared memory (all threads)
        for (var i = thread_id * {{VEC_SIZE}}; i < tile_size; i += WORKGROUP_SIZE * {{VEC_SIZE}}) {
            shared_vector[i / {{VEC_SIZE}}] = src1[(src1_idx_base + k_tile + i) / {{VEC_SIZE}}];
        }

        // Make the freshly loaded tile visible to every thread before it is read.
        workgroupBarrier();

        // Threads mapped to rows past the end of the matrix still take part in
        // loading and in the barriers, but contribute nothing to the sum.
        if (output_row < params.m) {
            local_sum += mul_acc(thread_in_group, tile_size, src0_idx_base, k_tile);
        }

        // Do not overwrite the tile while other threads may still be reading it.
        workgroupBarrier();
    }

    // Store partial sums and reduce within each thread group
    partial_sums[thread_id] = local_sum;
    workgroupBarrier();
    let group_base = thread_group * THREADS_PER_OUTPUT;
    let thread_base = group_base + thread_in_group;

    // Tree reduction over the group's slots. `span` is the number of slots that
    // still hold live partial sums; each step folds the upper part onto the lower
    // part using a rounded-up half, so no slot is dropped when THREADS_PER_OUTPUT
    // is not a power of two. For powers of two this is the usual halving tree.
    var span = THREADS_PER_OUTPUT;
    while (span > 1u) {
        let stride = (span + 1u) / 2u;
        if (thread_in_group + stride < span) {
            partial_sums[thread_base] += partial_sums[thread_base + stride];
        }
        workgroupBarrier();
        span = stride;
    }

    // Store back to global memory. Only the first thread of a group writes; in
    // vectorized variants only every VEC_SIZE-th group writes, and store_val
    // collects the totals of the following groups into the same vector.
    if (output_row < params.m && thread_group % {{VEC_SIZE}} == 0 && thread_in_group == 0) {
        dst[dst_idx / {{VEC_SIZE}}] = store_val(group_base);
    }
}
267#end(SHADER)