#decl(SHMEM_VEC)
// Vectorized variant of store_shmem: spreads the four components of `val`
// across four consecutive shared-memory slots starting at `idx`.
fn store_shmem(val: vec4<f16>, idx: u32) {
    for (var c = 0u; c < 4u; c++) {
        shmem[idx + c] = val[c];
    }
}
#enddecl(SHMEM_VEC)
9
#decl(SHMEM_SCALAR)
// Scalar variant of store_shmem: writes a single f16 value into shared
// memory at `idx`. Selected (vs. SHMEM_VEC) by the shader-template
// preprocessor depending on the vector width in use.
fn store_shmem(val: f16, idx: u32) {
    shmem[idx] = val;
}
#enddecl(SHMEM_SCALAR)
15
#decl(INIT_SRC0_SHMEM_FLOAT)

// Cooperatively loads one TILE_SRC0_SHMEM-sized tile of src0 into shared
// memory. Threads stride through the tile in units of {{VEC_SIZE}} elements;
// elements that fall outside the m x k matrix bounds are replaced with zero.
fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u32) {
    var elem_idx = thread_id * {{VEC_SIZE}};
    while (elem_idx < TILE_SRC0_SHMEM) {
        let row_in_tile = elem_idx / TILE_K;
        let col_in_tile = elem_idx % TILE_K;
        let global_m = offset_m + row_in_tile;
        let global_k = k_outer + col_in_tile;
        let in_bounds = global_m < params.m && global_k < params.k;
        let src0_idx = batch_offset + global_m * params.stride_01 + global_k;
        // taking a slight performance hit to avoid oob: both select operands
        // are evaluated, but the out-of-bounds load result is discarded
        let src0_val = select(
            {{SRC0_TYPE}}(0.0),
            src0[src0_idx / {{VEC_SIZE}}],
            in_bounds);
        store_shmem({{SHMEM_TYPE}}(src0_val), elem_idx);
        elem_idx += TOTAL_WORKGROUP_SIZE * {{VEC_SIZE}};
    }
}

#enddecl(INIT_SRC0_SHMEM_FLOAT)
34
#decl(INIT_SRC1_SHMEM)

// Cooperatively loads one TILE_SRC1_SHMEM-sized tile of src1 into shared
// memory, placed immediately after the src0 tile (base offset
// TILE_SRC0_SHMEM). Elements outside the n x k matrix bounds are zero-filled.
fn init_shmem_src1(thread_id: u32, batch_offset: u32, offset_n: u32, k_outer: u32) {
    var elem_idx = thread_id * {{VEC_SIZE}};
    while (elem_idx < TILE_SRC1_SHMEM) {
        let row_in_tile = elem_idx / TILE_K;
        let col_in_tile = elem_idx % TILE_K;
        let global_n = offset_n + row_in_tile;
        let global_k = k_outer + col_in_tile;
        let in_bounds = global_n < params.n && global_k < params.k;
        let src1_idx = batch_offset + global_n * params.stride_11 + global_k;
        // both select operands are evaluated; the out-of-bounds load result
        // is discarded in favor of zero
        let src1_val = select(
            {{SRC1_TYPE}}(0.0),
            src1[src1_idx / {{VEC_SIZE}}],
            in_bounds);
        store_shmem({{SHMEM_TYPE}}(src1_val), TILE_SRC0_SHMEM + elem_idx);
        elem_idx += TOTAL_WORKGROUP_SIZE * {{VEC_SIZE}};
    }
}

#enddecl(INIT_SRC1_SHMEM)
53
#decl(INIT_SRC0_SHMEM_Q4_0)

// Q4_0 layout: blocks of 32 4-bit weights sharing a single f16 scale.
const BLOCK_SIZE = 32u;
// the number of blocks per k-tile. Note that this currently only works if TILE_K is a multiple of BLOCK_SIZE, which may need to be rethought for larger quantized types.
override BLOCKS_K = TILE_K/BLOCK_SIZE;
// Number of quantized weights dequantized per thread per outer-loop step.
const NQ = 16u;
const F16_PER_BLOCK = 9u; // 1 scale + 8x4 packed weights
const WEIGHTS_PER_F16 = 4u; // 4 weights per f16
const F16_PER_THREAD = NQ / WEIGHTS_PER_F16;

// Dequantizes a tile of Q4_0-packed src0 into f16 shared memory.
// Each thread handles NQ (16) weights per iteration, i.e. half of one
// 32-weight block, so block_offset below is either 0 (first half of the
// packed f16s) or 4 (second half).
fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u32) {
    for (var i = thread_id * NQ; i < TILE_SRC0_SHMEM; i += TOTAL_WORKGROUP_SIZE * NQ) {
        let blck_idx = i / BLOCK_SIZE; // which Q4_0 block within the tile
        let block_offset = (i % BLOCK_SIZE) / WEIGHTS_PER_F16; // first packed f16 this thread reads within the block
        let shmem_idx = blck_idx * BLOCK_SIZE + block_offset * 2u; // destination of this thread's first low-nibble weight

        let tile_m = blck_idx / BLOCKS_K; // tile row of this block
        let global_m = offset_m + tile_m;
        let block_k = blck_idx % BLOCKS_K; // block column within the tile
        let global_k = k_outer / BLOCK_SIZE + block_k; // global column, in units of blocks

        // NOTE(review): out-of-range blocks are skipped, leaving their shmem
        // region unwritten rather than zeroed — presumably the consumer never
        // reads past the valid tile extent; verify against the caller.
        if (global_m < params.m && global_k < params.k / BLOCK_SIZE) {
            let src0_idx = batch_offset + global_m * params.stride_01 + global_k;
            let scale_idx = src0_idx * F16_PER_BLOCK; // f16 index of the block's scale
            let d = src0[scale_idx];

            for (var j = 0u; j < F16_PER_THREAD; j += 2) {
                // Read two adjacent packed f16s (8 weights = 8 nibbles).
                let q_0 = src0[scale_idx + 1u + block_offset + j];
                let q_1 = src0[scale_idx + 1u + block_offset + j + 1];

                // Reinterpret the f16 pair as one u32 of raw nibble data.
                let q_packed = bitcast<u32>(vec2(q_0, q_1));
                for (var k = 0u; k < 4u; k++) {
                    let q_byte = get_byte(q_packed, k);
                    // Q4_0 dequant: each nibble is an unsigned 4-bit value
                    // biased by 8, scaled by the block scale d.
                    let q_hi = (f16((q_byte >> 4) & 0xF) - 8.0) * d;
                    let q_lo = (f16(q_byte & 0xF) - 8.0) * d;
                    // Low nibbles fill the first 16 slots of the block's
                    // shmem range, high nibbles the second 16.
                    shmem[shmem_idx + j * 2 + k] = q_lo;
                    shmem[shmem_idx + j * 2 + k + 16u] = q_hi;
                }
            }
        }
    }
}

#enddecl(INIT_SRC0_SHMEM_Q4_0)