#decl(SHMEM_VEC)

fn store_shmem(val: vec4<f16>, idx: u32) {
    shmem[idx] = val.x;
    shmem[idx + 1] = val.y;
    shmem[idx + 2] = val.z;
    shmem[idx + 3] = val.w;
}

#enddecl(SHMEM_VEC)

#decl(SHMEM_SCALAR)

fn store_shmem(val: f16, idx: u32) {
    shmem[idx] = val;
}

#enddecl(SHMEM_SCALAR)

#decl(INIT_SRC0_SHMEM_FLOAT)

fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u32) {
    // Each thread copies {{VEC_SIZE}} elements of the row-major [TILE_M, TILE_K]
    // src0 tile per iteration, striding by the workgroup size.
    for (var elem_idx = thread_id * {{VEC_SIZE}}; elem_idx < TILE_SRC0_SHMEM; elem_idx += TOTAL_WORKGROUP_SIZE * {{VEC_SIZE}}) {
        let tile_m = elem_idx / TILE_K;
        let tile_k = elem_idx % TILE_K;
        let global_m = offset_m + tile_m;
        let global_k = k_outer + tile_k;
        let src0_idx = batch_offset + global_m * params.stride_01 + global_k;
        let src0_val = select( // taking a slight performance hit to avoid oob
            {{SRC0_TYPE}}(0.0),
            src0[src0_idx / {{VEC_SIZE}}],
            global_m < params.m && global_k < params.k);
        store_shmem({{SHMEM_TYPE}}(src0_val), elem_idx);
    }
}

#enddecl(INIT_SRC0_SHMEM_FLOAT)

#decl(INIT_SRC1_SHMEM)

fn init_shmem_src1(thread_id: u32, batch_offset: u32, offset_n: u32, k_outer: u32) {
    for (var elem_idx = thread_id * {{VEC_SIZE}}; elem_idx < TILE_SRC1_SHMEM; elem_idx += TOTAL_WORKGROUP_SIZE * {{VEC_SIZE}}) {
        let tile_n = elem_idx / TILE_K;
        let tile_k = elem_idx % TILE_K;
        let global_n = offset_n + tile_n;
        let global_k = k_outer + tile_k;
        let src1_idx = batch_offset + global_n * params.stride_11 + global_k;
        let src1_val = select(
            {{SRC1_TYPE}}(0.0),
            src1[src1_idx / {{VEC_SIZE}}],
            global_n < params.n && global_k < params.k);
        // the src1 tile is stored after the src0 tile in shared memory
        store_shmem({{SHMEM_TYPE}}(src1_val), TILE_SRC0_SHMEM + elem_idx);
    }
}

#enddecl(INIT_SRC1_SHMEM)

#decl(INIT_SRC0_SHMEM_Q4_0)

const BLOCK_SIZE = 32u;
// The number of blocks per k-tile. Note that this currently only works if
// TILE_K is a multiple of BLOCK_SIZE, which may need to be rethought for
// larger quantized types.
override BLOCKS_K = TILE_K / BLOCK_SIZE;
const NQ = 16u; // quantized weights dequantized per thread per iteration
const F16_PER_BLOCK = 9u; // 1 scale + 8x4 packed weights
const WEIGHTS_PER_F16 = 4u; // 4 weights per f16
const F16_PER_THREAD = NQ / WEIGHTS_PER_F16;

fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u32) {
    for (var i = thread_id * NQ; i < TILE_SRC0_SHMEM; i += TOTAL_WORKGROUP_SIZE * NQ) {
        let blck_idx = i / BLOCK_SIZE;
        // index of this thread's first packed f16 within the block
        let block_offset = (i % BLOCK_SIZE) / WEIGHTS_PER_F16;
        // each packed f16 covers 2 bytes, i.e. 2 low-nibble weight slots
        let shmem_idx = blck_idx * BLOCK_SIZE + block_offset * 2u;
        let tile_m = blck_idx / BLOCKS_K;
        let global_m = offset_m + tile_m;
        let block_k = blck_idx % BLOCKS_K;
        let global_k = k_outer / BLOCK_SIZE + block_k;
        if (global_m < params.m && global_k < params.k / BLOCK_SIZE) {
            let src0_idx = batch_offset + global_m * params.stride_01 + global_k;
            let scale_idx = src0_idx * F16_PER_BLOCK;
            let d = src0[scale_idx]; // f16 scale of this block
            for (var j = 0u; j < F16_PER_THREAD; j += 2u) {
                // pack two f16s (4 bytes, i.e. 8 quantized weights) into a u32
                let q_0 = src0[scale_idx + 1u + block_offset + j];
                let q_1 = src0[scale_idx + 1u + block_offset + j + 1u];
                let q_packed = bitcast<u32>(vec2<f16>(q_0, q_1));
                for (var k = 0u; k < 4u; k++) {
                    let q_byte = get_byte(q_packed, k);
                    let q_hi = (f16((q_byte >> 4) & 0xF) - 8.0) * d;
                    let q_lo = (f16(q_byte & 0xF) - 8.0) * d;
                    // byte b of a block holds weight b in its low nibble and
                    // weight b + 16 in its high nibble
                    shmem[shmem_idx + j * 2u + k] = q_lo;
                    shmem[shmem_idx + j * 2u + k + 16u] = q_hi;
                }
            }
        }
    }
}

#enddecl(INIT_SRC0_SHMEM_Q4_0)
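
// Illustrative note (an assumption about the surrounding shader generator, not
// something fixed by this template): for the vectorized f32 path the
// placeholders would plausibly expand to {{VEC_SIZE}} = 4,
// {{SRC0_TYPE}} = vec4<f32>, and {{SHMEM_TYPE}} = vec4<f16>, so
// INIT_SRC0_SHMEM_FLOAT loads one vec4<f32>, narrows it to vec4<f16>, and the
// SHMEM_VEC store_shmem scatters the four lanes into consecutive shmem slots.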
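
// Shared-memory layout of the dequantized Q4_0 tile (derived from the indexing
// in INIT_SRC0_SHMEM_Q4_0): block blck_idx occupies
// shmem[blck_idx * 32 .. blck_idx * 32 + 31], with the 16 low-nibble weights
// in the first half and the 16 high-nibble weights in the second half, so
// shmem slot w holds dequantized weight w of the block.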
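
// Worked example of the nibble unpacking (a sketch, assuming the standard ggml
// Q4_0 encoding where each 4-bit quant is biased by 8): for d = 0.5 and
// q_byte = 0x3A at byte position b within a block,
//   q_lo = (f16(0xA) - 8.0) * d = ( 2.0) * 0.5 =  1.0  -> shmem slot b
//   q_hi = (f16(0x3) - 8.0) * d = (-5.0) * 0.5 = -2.5  -> shmem slot b + 16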