aboutsummaryrefslogtreecommitdiff
path: root/llama.cpp/ggml/src/ggml-webgpu/wgsl-shaders/binary.wgsl
blob: 55dd66408a3e34671494b6759a8954b33fb44ac2 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
enable f16;

// Uniform parameters for the element-wise binary kernel.
// NOTE(review): field order and types define the uniform buffer layout and
// presumably must match the host-side parameter packing exactly — do not reorder.
struct Params {
    // total number of output elements; invocations with gid.x >= ne do nothing
    ne: u32,

    // offsets in elements
    offset_src0: u32,
    offset_src1: u32,
    offset_dst: u32,

    // per-dimension strides of src1 (in elements), combined in src1_index
    stride_src1_0: u32,
    stride_src1_1: u32,
    stride_src1_2: u32,
    stride_src1_3: u32,

    // extents of dimensions 0..2 of src0 ("a"); used to unflatten the thread
    // index into 4-D coordinates (the outermost extent is not needed for that)
    a_ne0: u32,
    a_ne1: u32,
    a_ne2: u32,

    // extents of src1 ("b"); a's coordinates are reduced modulo these so that
    // b is repeated (broadcast) across a
    b_ne0: u32,
    b_ne1: u32,
    b_ne2: u32,
    b_ne3: u32,
};

// Maps a flat element index of src0/dst to the matching element offset in
// src1, relative to params.offset_src1.
//
// The flat index is split into 4-D coordinates using src0's extents
// (a_ne0..a_ne2). Each coordinate is then wrapped modulo src1's extent in that
// dimension, so a smaller src1 repeats (broadcasts) across src0. The result is
// the dot product of the wrapped coordinates with src1's element strides.
fn src1_index(_i: u32) -> u32 {
    // Number of elements in one dim-1 plane and in one dim-2 volume of src0.
    let plane = params.a_ne1 * params.a_ne0;
    let volume = params.a_ne2 * plane;

    // Unflatten from the outermost dimension inward.
    let a_i3 = _i / volume;
    let in_volume = _i % volume;
    let a_i2 = in_volume / plane;
    let in_plane = in_volume % plane;
    let a_i1 = in_plane / params.a_ne0;
    let a_i0 = in_plane % params.a_ne0;

    // Wrap each coordinate into src1's extent (broadcast) and apply its stride.
    return (a_i0 % params.b_ne0) * params.stride_src1_0
         + (a_i1 % params.b_ne1) * params.stride_src1_1
         + (a_i2 % params.b_ne2) * params.stride_src1_2
         + (a_i3 % params.b_ne3) * params.stride_src1_3;
}

// Element type of all three buffers, selected at preprocess time:
// TYPE_F32 -> f32, TYPE_F16 -> f16.
// NOTE(review): exactly one of these is expected to be defined; if neither is,
// DataType is left undefined and the shader should fail to compile.
#ifdef TYPE_F32
#define DataType f32
#endif
#ifdef TYPE_F16
#define DataType f16
#endif

// Buffer bindings. src0 (binding 0) and src1 (binding 1) are always bound.
//   INPLACE: the result is written back into src0, so there is no dst buffer
//            and params sits at binding 2.
//   OVERLAP: the result is written into src1; params also sits at binding 2.
//   default: a separate dst buffer at binding 2, params at binding 3.
// NOTE(review): src0/src1 are read_write in every variant; the host bind group
// layout presumably depends on this — confirm before tightening access modes.
@group(0) @binding(0)
var<storage, read_write> src0: array<DataType>;

@group(0) @binding(1)
var<storage, read_write> src1 : array<DataType>;

#ifdef INPLACE
@group(0) @binding(2)
var<uniform> params: Params;

#elif defined(OVERLAP)
@group(0) @binding(2)
var<uniform> params: Params;

#else
@group(0) @binding(2)
var<storage, read_write> dst: array<DataType>;

@group(0) @binding(3)
var<uniform> params: Params;
#endif

// Element-wise binary operator chosen at preprocess time by one of
// OP_ADD, OP_SUB, OP_MUL or OP_DIV; returns a (op) b.
// NOTE(review): if no OP_* macro is defined, the body has no return statement,
// which WGSL rejects at compile time.
fn op(a: DataType, b: DataType) -> DataType {
#ifdef OP_ADD
    return a + b;
#elif defined(OP_SUB)
    return a - b;
#elif defined(OP_MUL)
    return a * b;
#elif defined(OP_DIV)
    return a / b;
#endif
}

// Combines src0[src0_i] and src1[src1_i] with op() and stores the result at
// element dst_i of the output: src0 when INPLACE, src1 when OVERLAP, otherwise
// the dedicated dst buffer.
fn update(dst_i: u32, src0_i: u32, src1_i: u32){
    let lhs = src0[src0_i];
    let rhs = src1[src1_i];
    let value = op(lhs, rhs);

#ifdef INPLACE
    src0[dst_i] = value;
#elif defined(OVERLAP)
    src1[dst_i] = value;
#else
    dst[dst_i] = value;
#endif
}

// Compute entry point: one invocation per output element, indexed by gid.x.
// Invocations past params.ne exit immediately. The rest pair element idx of
// src0 with its (possibly broadcast) counterpart in src1 and write element idx
// of the output, each shifted by its buffer's element offset.
@compute @workgroup_size(WG_SIZE)
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
    let idx = gid.x;
    if (idx >= params.ne) {
        return;
    }

    let dst_i = params.offset_dst + idx;
    let src0_i = params.offset_src0 + idx;
    let src1_i = params.offset_src1 + src1_index(idx);
    update(dst_i, src0_i, src1_i);
}