1#define(VARIANTS)
 2
 3[
 4  {
 5    "SHADER_NAME": "scale_f32",
 6    "DECLS": ["NOT_INPLACE"]
 7  },
 8  {
 9    "SHADER_NAME": "scale_f32_inplace",
10    "DECLS": ["INPLACE"]
11  }
12]
13
14#end(VARIANTS)
15
16#define(DECLS)
17
#decl(NOT_INPLACE)
// Out-of-place variant: results are written to a separate dst buffer, and
// the uniform parameters sit at binding 2 (after src at 0 and dst at 1).
@group(0) @binding(2)
var<uniform> params: Params;

@group(0) @binding(1)
var<storage, read_write> dst: array<f32>;

// Store one computed element into dst at element index dst_index.
fn store_scale(result: f32, dst_index: u32) {
    dst[dst_index] = result;
}
#enddecl(NOT_INPLACE)
29
#decl(INPLACE)
// In-place variant: there is no dst buffer, so the uniform parameters take
// binding 1 and results are written back into the src buffer (binding 0).
@group(0) @binding(1)
var<uniform> params: Params;

// Store one computed element back into src at element index dst_index.
fn store_scale(result: f32, dst_index: u32) {
    src[dst_index] = result;
}
#enddecl(INPLACE)
38
39#end(DECLS)
40
41#define(SHADER)
42
// Uniform parameters for the scale kernel (bound as `var<uniform> params`).
// Field order and types determine the uniform buffer layout, so they must
// match the host code that fills the buffer. NOTE(review): the host-side
// packing is not visible here; confirm it before reordering or adding fields.
struct Params {
    // Element offsets added to every computed source / destination index.
    offset_src: u32,
    offset_dst: u32,

    // Strides (in elements) for dims 1..3. Dim 0 has no stride field: main
    // adds i0 directly, i.e. elements along dim 0 are treated as contiguous.
    stride_src1: u32,
    stride_src2: u32,
    stride_src3: u32,

    stride_dst1: u32,
    stride_dst2: u32,
    stride_dst3: u32,

    // Total number of elements; invocations with gid.x >= ne exit early.
    ne: u32,
    // Extents of dims 0..2, used to split the flat invocation index into
    // (i0, i1, i2, i3). The dim-3 extent is not stored; i3 follows from ne.
    ne0: u32,
    ne1: u32,
    ne2: u32,

    // Each output element is computed as input * scale + bias.
    scale: f32,
    bias: f32
};
64
// Input tensor of f32 elements. Declared read_write rather than read because
// the in-place variant's store_scale writes results back into this buffer.
@group(0) @binding(0)
var<storage, read_write> src: array<f32>;
67
68DECLS
69
// Workgroup size, supplied as a pipeline-overridable constant by the host.
override wg_size: u32;

// One invocation per element. The flat invocation index is decomposed into
// 4-D coordinates (i0 varies fastest), mapped through the source/destination
// offsets and strides, and src * scale + bias is handed to store_scale.
@compute @workgroup_size(wg_size)
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
    let elem = gid.x;

    // Invocations beyond the last element have nothing to do.
    if (elem >= params.ne) {
        return;
    }

    // plane  = elements per i2 step (ne0 * ne1)
    // volume = elements per i3 step (ne0 * ne1 * ne2)
    let plane = params.ne0 * params.ne1;
    let volume = plane * params.ne2;

    let i3 = elem / volume;
    let rem3 = elem % volume;
    let i2 = rem3 / plane;
    let rem2 = rem3 % plane;
    let i1 = rem2 / params.ne0;
    let i0 = rem2 % params.ne0;

    // Dim 0 is addressed with an implicit stride of 1.
    let i_src = params.offset_src
        + i3 * params.stride_src3
        + i2 * params.stride_src2
        + i1 * params.stride_src1
        + i0;
    let i_dst = params.offset_dst
        + i3 * params.stride_dst3
        + i2 * params.stride_dst2
        + i1 * params.stride_dst1
        + i0;

    let scaled = src[i_src] * params.scale + params.bias;
    store_scale(scaled, i_dst);
}
90#end(SHADER)