aboutsummaryrefslogtreecommitdiff
path: root/llama.cpp/ggml/src/ggml-webgpu/wgsl-shaders/scale.tmpl.wgsl
blob: 040e80dfea24a507151112be7ee408c52ba91d78 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
#define(VARIANTS)

[
  {
    "SHADER_NAME": "scale_f32",
    "DECLS": ["NOT_INPLACE"]
  },
  {
    "SHADER_NAME": "scale_f32_inplace",
    "DECLS": ["INPLACE"]
  }
]

#end(VARIANTS)

#define(DECLS)

#decl(NOT_INPLACE)
@group(0) @binding(1)
var<storage, read_write> dst: array<f32>;

@group(0) @binding(2)
var<uniform> params: Params;

fn store_scale(val: f32, offset: u32) {
    dst[offset] = val;
}
#enddecl(NOT_INPLACE)

#decl(INPLACE)
@group(0) @binding(1)
var<uniform> params: Params;

fn store_scale(val: f32, offset: u32) {
    src[offset] = val;
}
#enddecl(INPLACE)

#end(DECLS)

#define(SHADER)

struct Params {
    offset_src: u32,
    offset_dst: u32,

    // Strides (in elements)
    stride_src1: u32,
    stride_src2: u32,
    stride_src3: u32,

    stride_dst1: u32,
    stride_dst2: u32,
    stride_dst3: u32,

    ne: u32,
    ne0: u32,
    ne1: u32,
    ne2: u32,

    scale: f32,
    bias: f32
};

@group(0) @binding(0)
var<storage, read_write> src: array<f32>;

DECLS

override wg_size: u32;
@compute @workgroup_size(wg_size)
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
    if (gid.x >= params.ne) {
        return;
    }

    var i = gid.x;
    let i3 = i / (params.ne2 * params.ne1 * params.ne0);
    i = i % (params.ne2 * params.ne1 * params.ne0);
    let i2 = i / (params.ne1 * params.ne0);
    i = i % (params.ne1 * params.ne0);
    let i1 = i / params.ne0;
    let i0 = i % params.ne0;

    let i_src = params.offset_src + i3 * params.stride_src3 + i2 * params.stride_src2 + i1 * params.stride_src1 + i0;
    let i_dst = params.offset_dst + i3 * params.stride_dst3 + i2 * params.stride_dst2 + i1 * params.stride_dst1 + i0;

    store_scale(src[i_src] * params.scale + params.bias, i_dst);
}
#end(SHADER)