1#decl(SHMEM_VEC)
 2fn store_shmem(val: vec4<f16>, idx: u32) {
 3    shmem[idx] = val.x;
 4    shmem[idx + 1] = val.y;
 5    shmem[idx + 2] = val.z;
 6    shmem[idx + 3] = val.w;
 7}
 8#enddecl(SHMEM_VEC)
 9
10#decl(SHMEM_SCALAR)
// Scalar variant: store a single f16 into shared memory at `idx`.
// Selected via the SHMEM_SCALAR decl when VEC_SIZE == 1.
fn store_shmem(val: f16, idx: u32) {
    shmem[idx] = val;
}
14#enddecl(SHMEM_SCALAR)
15
16#decl(INIT_SRC0_SHMEM_FLOAT)
17
// Cooperatively load one TILE_M x TILE_K tile of src0 into shared memory.
// Each thread handles {{VEC_SIZE}} elements per step and the whole workgroup
// strides across the tile, so the loop covers TILE_SRC0_SHMEM elements total.
// Out-of-range elements (past params.m / params.k) are written as zero so the
// inner-product accumulation sees clean padding.
fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u32) {
    for (var elem_idx = thread_id * {{VEC_SIZE}}; elem_idx < TILE_SRC0_SHMEM; elem_idx += TOTAL_WORKGROUP_SIZE * {{VEC_SIZE}}) {
        // Decompose the flat tile index into (row within tile, column within tile).
        let tile_m = elem_idx / TILE_K;
        let tile_k = elem_idx % TILE_K;
        let global_m = offset_m + tile_m;
        let global_k = k_outer + tile_k;
        // NOTE(review): src0_idx is an element index; the load below divides by
        // {{VEC_SIZE}} because src0 is bound as a vectorized array — assumes
        // stride_01 and k_outer keep src0_idx {{VEC_SIZE}}-aligned. Confirm at call site.
        let src0_idx = batch_offset + global_m * params.stride_01 + global_k;
        let src0_val = select( // taking a slight performance hit to avoid oob
            {{SRC0_TYPE}}(0.0),
            src0[src0_idx/{{VEC_SIZE}}],
            global_m < params.m && global_k < params.k);
        store_shmem({{SHMEM_TYPE}}(src0_val), elem_idx);
    }
}
32
33#enddecl(INIT_SRC0_SHMEM_FLOAT)
34
35#decl(INIT_SRC1_SHMEM)
36
// Cooperatively load one TILE_N x TILE_K tile of src1 into shared memory,
// mirroring init_shmem_src0 but indexing by n and stride_11, and placing the
// tile after the src0 tile (offset TILE_SRC0_SHMEM) in the shared buffer.
// Out-of-range elements are written as zero.
fn init_shmem_src1(thread_id: u32, batch_offset: u32, offset_n: u32, k_outer: u32) {
    for (var elem_idx = thread_id * {{VEC_SIZE}}; elem_idx < TILE_SRC1_SHMEM; elem_idx += TOTAL_WORKGROUP_SIZE * {{VEC_SIZE}}) {
        // Decompose the flat tile index into (column within tile, k within tile).
        let tile_n = elem_idx / TILE_K;
        let tile_k = elem_idx % TILE_K;
        let global_n = offset_n + tile_n;
        let global_k = k_outer + tile_k;
        // NOTE(review): as in the src0 path, src1 is bound vectorized, hence the
        // /{{VEC_SIZE}} on load — assumes the index stays vector-aligned.
        let src1_idx = batch_offset + global_n * params.stride_11 + global_k;
        let src1_val = select(
            {{SRC1_TYPE}}(0.0),
            src1[src1_idx/{{VEC_SIZE}}],
            global_n < params.n && global_k < params.k);
        store_shmem({{SHMEM_TYPE}}(src1_val), TILE_SRC0_SHMEM + elem_idx);
    }
}
51
52#enddecl(INIT_SRC1_SHMEM)
53
54#decl(INIT_SRC0_SHMEM_Q4_0)
55
// q4_0 layout: each block packs 32 4-bit weights plus one f16 scale.
const BLOCK_SIZE = 32u;
// the number of blocks per k-tile. Note that this currently only works if TILE_K is a multiple of BLOCK_SIZE, which may need to be rethought for larger quantized types.
override BLOCKS_K = TILE_K/BLOCK_SIZE;
// number of quantized weights each thread dequantizes per loop iteration
const NQ = 16u;
const F16_PER_BLOCK = 9u; // 1 scale + 8x4 packed weights
const WEIGHTS_PER_F16 = 4u; // 4 weights per f16
// f16 words of packed quant data a thread consumes per iteration (NQ weights)
const F16_PER_THREAD = NQ / WEIGHTS_PER_F16;
63
// Cooperatively dequantize one tile of q4_0 src0 into f16 shared memory.
// Each loop step handles NQ (16) weights = F16_PER_THREAD (4) f16 words of
// packed nibbles. In a q4_0 block the low nibble of byte b is weight b and the
// high nibble is weight b + 16, which is why the two stores below are 16 slots
// apart.
fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u32) {
    for (var i = thread_id * NQ; i < TILE_SRC0_SHMEM; i += TOTAL_WORKGROUP_SIZE * NQ) {
        let blck_idx = i / BLOCK_SIZE;
        // Which f16 word inside the block this step starts at (0 or 4).
        let block_offset = (i % BLOCK_SIZE) / WEIGHTS_PER_F16;
        // Base shared-memory slot for this step's low-nibble weights.
        let shmem_idx = blck_idx * BLOCK_SIZE + block_offset * 2u;

        // Map the flat block index to (row in tile, block column in k).
        let tile_m = blck_idx / BLOCKS_K;
        let global_m = offset_m + tile_m;
        let block_k = blck_idx % BLOCKS_K;
        let global_k = k_outer / BLOCK_SIZE + block_k;

        if (global_m < params.m && global_k < params.k / BLOCK_SIZE) {
            // stride_01 here is in blocks; scale_idx is in f16 units.
            let src0_idx = batch_offset + global_m * params.stride_01 + global_k;
            let scale_idx = src0_idx * F16_PER_BLOCK;
            let d = src0[scale_idx]; // per-block f16 scale

            for (var j = 0u; j < F16_PER_THREAD; j += 2) {
                // Read two adjacent f16 words and repack as one u32 (8 nibbles).
                let q_0 = src0[scale_idx + 1u + block_offset + j];
                let q_1 = src0[scale_idx + 1u + block_offset + j + 1];

                let q_packed = bitcast<u32>(vec2(q_0, q_1));
                for (var k = 0u; k < 4u; k++) {
                    let q_byte = get_byte(q_packed, k);
                    // q4_0 dequant: value = (nibble - 8) * scale.
                    let q_hi = (f16((q_byte >> 4) & 0xF) - 8.0) * d;
                    let q_lo = (f16(q_byte & 0xF) - 8.0) * d;
                    shmem[shmem_idx + j * 2 + k] = q_lo;
                    shmem[shmem_idx + j * 2 + k + 16u] = q_hi;
                }
            }
        } else {
            // Fix: zero-fill out-of-range slots, matching the float-path loaders
            // which select() zero for out-of-bounds elements. Without this, stale
            // values from a previous k-iteration remain in shared memory and get
            // accumulated into in-range rows.
            for (var j = 0u; j < F16_PER_THREAD; j += 2) {
                for (var k = 0u; k < 4u; k++) {
                    shmem[shmem_idx + j * 2 + k] = 0.0;
                    shmem[shmem_idx + j * 2 + k + 16u] = 0.0;
                }
            }
        }
    }
}
96
97#enddecl(INIT_SRC0_SHMEM_Q4_0)