aboutsummaryrefslogtreecommitdiff
path: root/llama.cpp/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_decls.tmpl
blob: 109ff8d6159e1ac040b9b6302fe72134bc39badc (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
#decl(SHMEM_VEC)
// Writes the four components of `val` into four consecutive `shmem` slots,
// starting at `idx` (x -> idx, y -> idx + 1, z -> idx + 2, w -> idx + 3).
fn store_shmem(val: vec4<f16>, idx: u32) {
    for (var lane = 0u; lane < 4u; lane++) {
        shmem[idx + lane] = val[lane];
    }
}
#enddecl(SHMEM_VEC)

#decl(SHMEM_SCALAR)
// Scalar variant of store_shmem: writes the single value `val` into `shmem[idx]`.
fn store_shmem(val: f16, idx: u32) {
    shmem[idx] = val;
}
#enddecl(SHMEM_SCALAR)

#decl(INIT_SRC0_SHMEM_FLOAT)

// Copies one tile of src0 into the start of `shmem`.
//
//   thread_id    - index of the calling thread; selects its first element.
//   batch_offset - element offset added to every src0 index.
//   offset_m     - first row covered by this tile.
//   k_outer      - first k position covered by this tile.
//
// Each thread starts at element thread_id * VEC_SIZE and advances by
// TOTAL_WORKGROUP_SIZE * VEC_SIZE until TILE_SRC0_SHMEM elements are covered.
// Tile elements are laid out row-major with TILE_K elements per row.
fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u32) {
    var slot = thread_id * {{VEC_SIZE}};
    while (slot < TILE_SRC0_SHMEM) {
        let row = offset_m + slot / TILE_K;
        let col = k_outer + slot % TILE_K;
        let in_bounds = row < params.m && col < params.k;

        let flat_idx = batch_offset + row * params.stride_01 + col;
        let loaded = src0[flat_idx / {{VEC_SIZE}}];

        // Positions outside the m x k extent are stored as zero; select() is used
        // rather than a branch, at a slight performance cost, to avoid using oob data.
        let value = select({{SRC0_TYPE}}(0.0), loaded, in_bounds);
        store_shmem({{SHMEM_TYPE}}(value), slot);

        slot += TOTAL_WORKGROUP_SIZE * {{VEC_SIZE}};
    }
}

#enddecl(INIT_SRC0_SHMEM_FLOAT)

#decl(INIT_SRC1_SHMEM)

// Copies one tile of src1 into `shmem`, placed directly after the src0 tile
// (i.e. starting at offset TILE_SRC0_SHMEM).
//
//   thread_id    - index of the calling thread; selects its first element.
//   batch_offset - element offset added to every src1 index.
//   offset_n     - first row (n index) covered by this tile.
//   k_outer      - first k position covered by this tile.
//
// Each thread starts at element thread_id * VEC_SIZE and advances by
// TOTAL_WORKGROUP_SIZE * VEC_SIZE until TILE_SRC1_SHMEM elements are covered.
// Tile elements are laid out row-major with TILE_K elements per row.
fn init_shmem_src1(thread_id: u32, batch_offset: u32, offset_n: u32, k_outer: u32) {
    var slot = thread_id * {{VEC_SIZE}};
    while (slot < TILE_SRC1_SHMEM) {
        let row = offset_n + slot / TILE_K;
        let col = k_outer + slot % TILE_K;
        let in_bounds = row < params.n && col < params.k;

        let flat_idx = batch_offset + row * params.stride_11 + col;
        let loaded = src1[flat_idx / {{VEC_SIZE}}];

        // Positions outside the n x k extent are stored as zero.
        let value = select({{SRC1_TYPE}}(0.0), loaded, in_bounds);
        store_shmem({{SHMEM_TYPE}}(value), TILE_SRC0_SHMEM + slot);

        slot += TOTAL_WORKGROUP_SIZE * {{VEC_SIZE}};
    }
}

#enddecl(INIT_SRC1_SHMEM)

#decl(INIT_SRC0_SHMEM_Q4_0)

// Number of weights in one quantized block.
const BLOCK_SIZE: u32 = 32u;
// Number of blocks per k-tile. This only works while TILE_K is a multiple of
// BLOCK_SIZE, which may need to be rethought for larger quantized types.
override BLOCKS_K: u32 = TILE_K / BLOCK_SIZE;
// Number of weights each thread handles per loop iteration.
const NQ: u32 = 16u;
// f16 words per block: 1 scale + 8 words of packed weights.
const F16_PER_BLOCK: u32 = 9u;
// Each f16 word packs 4 weights.
const WEIGHTS_PER_F16: u32 = 4u;
// f16 words of packed weights read by each thread per loop iteration.
const F16_PER_THREAD: u32 = NQ / WEIGHTS_PER_F16;

// Unpacks one tile of 4-bit quantized src0 data into `shmem` as scaled f16 values.
//
//   thread_id    - index of the calling thread; selects its first group of NQ weights.
//   batch_offset - block offset added to every src0 block index.
//   offset_m     - first row covered by this tile.
//   k_outer      - first k position covered by this tile (in weights, not blocks).
//
// Each loop iteration handles NQ weights of one block. For every byte read, the
// low nibble is written to the first half of the block's 32 shmem slots and the
// high nibble 16 slots later; both are converted as (nibble - 8) * scale.
// NOTE(review): indexing treats src0 as an array of 16-bit words (scale word
// followed by packed words) - the binding is declared outside this view.
fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u32) {
    for (var weight_idx = thread_id * NQ; weight_idx < TILE_SRC0_SHMEM; weight_idx += TOTAL_WORKGROUP_SIZE * NQ) {
        let tile_block = weight_idx / BLOCK_SIZE;
        // Offset, in packed words past the scale, of the first word this thread reads.
        let word_offset = (weight_idx % BLOCK_SIZE) / WEIGHTS_PER_F16;

        let row = offset_m + tile_block / BLOCKS_K;
        let block_col = k_outer / BLOCK_SIZE + tile_block % BLOCKS_K;

        // Rows/blocks outside the matrix are skipped; their shmem slots are not written.
        if (row >= params.m || block_col >= params.k / BLOCK_SIZE) {
            continue;
        }

        // Word index of the block's scale; the packed weights follow it.
        let block_base = (batch_offset + row * params.stride_01 + block_col) * F16_PER_BLOCK;
        let scale = src0[block_base];
        let quant_base = block_base + 1u + word_offset;
        // Each packed word holds two bytes, so it yields two low-nibble slots.
        let out_base = tile_block * BLOCK_SIZE + word_offset * 2u;

        for (var word = 0u; word < F16_PER_THREAD; word += 2u) {
            // Combine two consecutive 16-bit words into one u32 (4 bytes, 8 nibbles).
            let packed = bitcast<u32>(vec2(src0[quant_base + word], src0[quant_base + word + 1u]));
            let dst = out_base + word * 2u;
            for (var byte_idx = 0u; byte_idx < 4u; byte_idx++) {
                // get_byte is defined elsewhere; presumably returns byte `byte_idx` of `packed`.
                let q = get_byte(packed, byte_idx);
                shmem[dst + byte_idx] = (f16(q & 0xF) - 8.0) * scale;
                shmem[dst + byte_idx + 16u] = (f16((q >> 4) & 0xF) - 8.0) * scale;
            }
        }
    }
}

#enddecl(INIT_SRC0_SHMEM_Q4_0)